"""Decode the OneOCRFeatureExtract config blob (the 'feature/config' ONNX initializer)."""

import struct
from pathlib import Path

import numpy as np
import onnx


# Hard-coded input locations, relative to the current working directory.
MODEL_PATH = 'oneocr_extracted/onnx_models/model_11_ir6_1.9_26KB.onnx'
RNN_INFO_PATH = 'oneocr_extracted/config_data/chunk_36_rnn_info.rnn_info'
# Name of the graph initializer whose string payload is inspected.
CONFIG_NAME = 'feature/config'


def find_config_blob(model, name=CONFIG_NAME):
    """Return the first ``string_data`` entry of the initializer called *name*.

    Args:
        model: Object exposing ``graph.initializer``, as returned by
            ``onnx.load``.
        name: Initializer name to look for.

    Returns:
        The first ``string_data`` element of the matching initializer, or
        ``None`` when no initializer matches or the match has no string data.
    """
    for init in model.graph.initializer:
        if init.name == name:
            return init.string_data[0] if len(init.string_data) > 0 else None
    return None


def unpack_or_none(fmt, raw, offset=0):
    """Unpack *fmt* from *raw* at *offset*, or return None if *raw* is too short.

    Guarding on ``struct.calcsize`` avoids the ``struct.error`` that
    ``struct.unpack_from`` raises when fewer bytes than *fmt* needs remain.
    """
    if len(raw) - offset < struct.calcsize(fmt):
        return None
    return struct.unpack_from(fmt, raw, offset)


def find_float_encodings(raw, value):
    """Locate *value* inside *raw* as little-endian float32 and float64.

    Returns:
        ``(pos_f32, pos_f64)``: the byte offset of the first occurrence of each
        encoding, or ``-1`` for an encoding that does not occur. The search is
        a plain byte search, so matches need not be 4- or 8-byte aligned.
    """
    as_f32 = struct.pack('<f', np.float32(value))
    as_f64 = struct.pack('<d', value)
    return raw.find(as_f32), raw.find(as_f64)


def floats_from_blob(raw):
    """View *raw* as an array of little-endian float32 values.

    Trailing bytes that do not fill a complete 4-byte word are ignored;
    ``np.frombuffer`` would otherwise raise ``ValueError`` for such a buffer.
    """
    word_count = len(raw) // 4
    return np.frombuffer(raw, dtype=np.dtype('<f4'), count=word_count)


def find_true_runs(mask):
    """Return ``(start, end)`` pairs for every run of True values in *mask*.

    ``end`` is exclusive, so ``end - start`` is the run length. The mask is
    padded with one False on each side before differencing; without the
    padding a run touching either edge of the mask has no rising or falling
    edge, and starts and ends would be paired with the wrong partner.
    """
    flags = np.asarray(mask, dtype=bool).astype(np.int8)
    edges = np.diff(np.concatenate(([0], flags, [0])))
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1)
    return list(zip(starts.tolist(), ends.tolist()))


def parse_rnn_sections(lines):
    """Scan rnn_info text lines for ``<...>`` headers and collect their counts.

    A line that is exactly ``<...>`` only switches the current section. A line
    that starts with ``<`` and contains ``>`` followed by more text is treated
    as ``<tag> ... count``: the last whitespace-separated token is parsed as
    the count, recorded, and printed. Any other line records a count of 0 for
    the current section if that section has no count yet.

    Returns:
        Dict mapping section header to count, in order of first appearance.

    Raises:
        ValueError: if the last token of a ``<tag> ...`` line is not an integer.
    """
    section = None
    counts = {}
    for line in lines:
        if line.startswith('<') and line.endswith('>'):
            section = line
        elif line.startswith('<') and '>' in line:
            parts = line.strip().split()
            section = parts[0].rstrip('>') + '>'
            count = int(parts[-1]) if len(parts) > 1 else 0
            counts[section] = count
            print(f'Section: {section} count={count}')
        elif section and section not in counts:
            counts[section] = 0
    return counts


def main():
    """Load the model and rnn_info file, run every probe, and print the results."""
    model = onnx.load(MODEL_PATH)
    raw = find_config_blob(model)
    if raw is None:
        raise SystemExit(
            f"Initializer '{CONFIG_NAME}' with string data not found in {MODEL_PATH}"
        )

    print(f'Total bytes: {len(raw)}')
    print(f'First 100 bytes hex: {raw[:100].hex()}')

    # Reinterpret the start of the blob as float32 at a few word offsets.
    # Prints None for an offset when the blob is too short to read there.
    for offset in (0, 4, 8, 12):
        vals = unpack_or_none('<4f', raw, offset)
        print(f'Offset {offset:3d} as 4xfloat32: {vals}')

    # First line of rnn_info ends with a count; the second line is one float.
    rnn_lines = Path(RNN_INFO_PATH).read_text().strip().split('\n')
    lp_count = int(rnn_lines[0].split()[-1])
    print(f'\nLogPrior count from rnn_info: {lp_count}')
    lp_val = float(rnn_lines[1])
    print(f'LogPrior[0] = {lp_val}')

    # Check whether that first value appears verbatim inside the blob.
    pos_f32, pos_f64 = find_float_encodings(raw, lp_val)
    print(f'LogPrior as float32 at pos: {pos_f32}')
    print(f'LogPrior as float64 at pos: {pos_f64}')

    # Report runs of non-zero floats with magnitude below 20. Indices are
    # float32 word indices into the blob, not byte offsets.
    arr_f32 = floats_from_blob(raw)
    reasonable = (np.abs(arr_f32) < 20) & (arr_f32 != 0)
    print('\nSections of reasonable float32 values:')
    for s, e in find_true_runs(reasonable)[:10]:
        print(f' [{s}:{e}] ({e-s} values) first: {arr_f32[s:s+3]}')

    header_ints = unpack_or_none('<8I', raw)
    print(f'\nFirst 8 uint32: {header_ints}')

    header_shorts = unpack_or_none('<16H', raw)
    print(f'First 16 uint16: {header_shorts}')

    print('\n=== rnn_info structure ===')
    counts = parse_rnn_sections(rnn_lines)
    print(f'Sections found: {counts}')


if __name__ == '__main__':
    main()