|
|
"""Quick inspection of specific ONNX models.""" |
|
|
from collections import Counter
from pathlib import Path

import onnx
|
|
|
|
|
for idx in [11, 12, 22, 33]: |
|
|
matches = list(Path('oneocr_extracted/onnx_models').glob(f'model_{idx:02d}_*')) |
|
|
if not matches: |
|
|
continue |
|
|
m = onnx.load(str(matches[0])) |
|
|
print(f'\n=== model_{idx:02d} ({matches[0].name}) ===') |
|
|
print('Opsets:', [(o.domain, o.version) for o in m.opset_import]) |
|
|
custom_ops = set() |
|
|
for n in m.graph.node: |
|
|
if n.domain: |
|
|
custom_ops.add((n.domain, n.op_type)) |
|
|
print('Custom ops:', list(custom_ops)) |
|
|
for i in m.graph.input: |
|
|
dims = [] |
|
|
if i.type.tensor_type.HasField('shape'): |
|
|
for d in i.type.tensor_type.shape.dim: |
|
|
v = str(d.dim_value) if d.dim_value else (d.dim_param or '?') |
|
|
dims.append(v) |
|
|
etype = i.type.tensor_type.elem_type |
|
|
print(f' Input: {i.name} shape=[{", ".join(dims)}] dtype={etype}') |
|
|
for o in m.graph.output: |
|
|
dims = [] |
|
|
if o.type.tensor_type.HasField('shape'): |
|
|
for d in o.type.tensor_type.shape.dim: |
|
|
v = str(d.dim_value) if d.dim_value else (d.dim_param or '?') |
|
|
dims.append(v) |
|
|
etype = o.type.tensor_type.elem_type |
|
|
print(f' Output: {o.name} shape=[{", ".join(dims)}] dtype={etype}') |
|
|
cnt = {} |
|
|
for n in m.graph.node: |
|
|
key = f'{n.domain}::{n.op_type}' if n.domain else n.op_type |
|
|
cnt[key] = cnt.get(key, 0) + 1 |
|
|
print(' Nodes:', dict(cnt)) |
|
|
|
|
|
|
|
|
for n in m.graph.node: |
|
|
if n.domain: |
|
|
attrs = {a.name: a for a in n.attribute} |
|
|
attr_summary = {} |
|
|
for k, a in attrs.items(): |
|
|
if a.type == 2: |
|
|
attr_summary[k] = a.f |
|
|
elif a.type == 1: |
|
|
attr_summary[k] = a.i |
|
|
elif a.type == 3: |
|
|
attr_summary[k] = a.s.decode() |
|
|
print(f' CustomOp {n.op_type} attrs: {attr_summary}') |
|
|
|
|
|
print(f' inputs: {list(n.input)}') |
|
|
print(f' outputs: {list(n.output)}') |
|
|
break |
|
|
|