|
|
"""Deep-dive into model_11 and model_22 graph structure to understand OneOCRFeatureExtract.""" |
|
|
import onnx |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
|
|
|
# Location of the extracted ONNX models; a relative path, so it is resolved
# against the current working directory when globbed below.
models_dir = Path("oneocr_extracted") / "onnx_models"
|
|
|
|
|
def _value_range(data):
    """Return the ``range=[min, max]`` fragment printed for an initializer array.

    Empty arrays make ``min()``/``max()`` raise, and non-numeric dtypes (e.g.
    string tensors) cannot be reduced or formatted with ``:.4f``; both get a
    placeholder instead of aborting the whole dump.
    """
    if data.size == 0:
        return "range=[empty]"
    try:
        return f"range=[{data.min():.4f}, {data.max():.4f}]"
    except (TypeError, ValueError):
        return "range=[n/a]"


def _format_attribute(attr):
    """Return a one-line ``name = value`` description of an ONNX node attribute.

    Uses the named ``onnx.AttributeProto`` type constants rather than bare
    integers. Attribute kinds that are not expanded here (graphs, sparse
    tensors, type protos, ...) are reported by enum name instead of being
    silently skipped.
    """
    kinds = onnx.AttributeProto
    name = attr.name
    if attr.type == kinds.INT:
        return f"{name} = {attr.i}"
    if attr.type == kinds.FLOAT:
        return f"{name} = {attr.f}"
    if attr.type == kinds.STRING:
        val = attr.s
        # Long byte strings (e.g. embedded blobs) are summarised by length only.
        if len(val) > 100:
            return f"{name} = bytes({len(val)})"
        return f"{name} = {val!r}"
    if attr.type == kinds.TENSOR:
        t = attr.t
        # NOTE(review): raw_bytes can be 0 if the tensor keeps its values in
        # typed fields (float_data, int64_data, ...) instead of raw_data.
        return (f"{name} = tensor(dtype={t.data_type}, dims={list(t.dims)}, "
                f"raw_bytes={len(t.raw_data)})")
    if attr.type == kinds.INTS:
        return f"{name} = {list(attr.ints)}"
    if attr.type == kinds.FLOATS:
        floats = list(attr.floats)
        # Only mark truncation when more than 10 values were actually dropped.
        more = "..." if len(floats) > 10 else ""
        return f"{name} = {floats[:10]}{more}"
    if attr.type == kinds.STRINGS:
        return f"{name} = {list(attr.strings)}"
    return f"{name} = <{kinds.AttributeType.Name(attr.type)}>"


def _dump_initializers(model):
    """Print name, shape, dtype and value range of every graph initializer."""
    inits = model.graph.initializer
    print(f"\n Initializers ({len(inits)}):")
    for init in inits:
        data = onnx.numpy_helper.to_array(init)
        print(f" {init.name}: shape={data.shape}, dtype={data.dtype}, "
              f"{_value_range(data)}")


def _dump_nodes(model):
    """Print each node's op type, domain (if set), inputs, outputs and attributes."""
    nodes = model.graph.node
    print(f"\n Nodes ({len(nodes)}):")
    for i, node in enumerate(nodes):
        # The domain is shown only when it is non-empty.
        domain_str = f" (domain={node.domain!r})" if node.domain else ""
        print(f" [{i}] {node.op_type}{domain_str}")
        print(f" in: {list(node.input)}")
        print(f" out: {list(node.output)}")
        for attr in node.attribute:
            print(f" {_format_attribute(attr)}")


def _dump_blob(name, raw):
    """Print a hex preview and little-endian uint32/float32 decoding of *raw*."""
    print(f"\n feature/config blob {name!r}: {len(raw)} bytes")
    print(f" First 64 bytes (hex): {raw[:64].hex()}")
    print(f" Last 32 bytes (hex): {raw[-32:].hex()}")
    # Decode only whole 4-byte words: np.frombuffer rejects a partial trailing
    # word, and decoding short/empty slices would print bogus zero values.
    head = raw[:32]
    head = head[:len(head) - len(head) % 4]
    words = np.frombuffer(head, dtype="<u4").tolist()
    floats = ", ".join(f"{v:.7g}" for v in np.frombuffer(head, dtype="<f4"))
    print(f" As uint32 first 8 values: {words}")
    print(f" As float32 first 8 values: [{floats}]")


def _dump_config_blobs(model):
    """Dump raw bytes of initializers whose name mentions "config" or "feature"."""
    for init in model.graph.initializer:
        lowered = init.name.lower()
        if "config" in lowered or "feature" in lowered:
            # NOTE(review): raw_data is empty if the initializer stores its
            # values in typed fields instead; the preview then shows 0 bytes.
            _dump_blob(init.name, init.raw_data)


for idx in [11, 22]:
    # Sorted so the chosen file is deterministic when several names match.
    matches = sorted(models_dir.glob(f"model_{idx:02d}_*"))
    if not matches:
        print(f"\nmodel_{idx:02d}: no file matching 'model_{idx:02d}_*' in {models_dir}")
        continue
    model = onnx.load(str(matches[0]))

    print(f"\n{'='*70}")
    print(f"FULL GRAPH: model_{idx:02d}")
    print(f"{'='*70}")

    _dump_initializers(model)
    _dump_nodes(model)
    _dump_config_blobs(model)
|
|
|