|
|
""" |
|
|
Extract SeglinkProposals custom op attributes from the detection model (model_00).

Also dump all XY/Seglink-related config from the ONNX graph: matching nodes and
their attributes, the model inputs/outputs, and a count of every op type used.
|
|
""" |
|
|
import os
from collections import Counter

import onnx
from onnx import numpy_helper
|
|
|
|
|
# Candidate locations for the detection model, tried in order: the unlocked
# extraction output first, then the archived copy.
MODEL_CANDIDATES = (
    "oneocr_extracted/onnx_models_unlocked/model_00_ir7_onnx_quantize_13118KB.onnx",
    "_archive/onnx_models/model_00_ir7_onnx_quantize_13118KB.onnx",
)

# Pick the first candidate that exists as a regular file.
model_path = next((p for p in MODEL_CANDIDATES if os.path.isfile(p)), None)
if model_path is None:
    # Fail with an error naming every path tried, instead of letting
    # onnx.load report only the last fallback path.
    raise FileNotFoundError(
        "Detection model not found; tried: " + ", ".join(MODEL_CANDIDATES)
    )

print(f"Loading: {model_path}")
model = onnx.load(model_path)
graph = model.graph
|
|
|
|
|
def _format_attr_value(attr):
    """Return a printable value for an ONNX ``AttributeProto``.

    Float/int/string scalars and lists are rendered as Python values; byte
    strings are decoded as UTF-8 with undecodable bytes replaced, so odd
    custom-op attributes cannot crash the dump. Any other attribute kind
    (tensor, graph, ...) is rendered as ``<type N>`` using its numeric
    ``AttributeProto`` type code.
    """
    kinds = onnx.AttributeProto
    if attr.type == kinds.FLOAT:
        return attr.f
    if attr.type == kinds.INT:
        return attr.i
    if attr.type == kinds.STRING:
        return attr.s.decode("utf-8", errors="replace")
    if attr.type == kinds.FLOATS:
        return list(attr.floats)
    if attr.type == kinds.INTS:
        return list(attr.ints)
    if attr.type == kinds.STRINGS:
        return [s.decode("utf-8", errors="replace") for s in attr.strings]
    return f"<type {attr.type}>"


def _print_attributes(node):
    """Print an ' Attributes:' header plus one line per attribute of *node*.

    Prints nothing when the node has no attributes.
    """
    if not node.attribute:
        return
    print(" Attributes:")
    for attr in node.attribute:
        print(f" {attr.name} = {_format_attr_value(attr)}")


# Case-insensitive substrings searched for in each node's op_type and name.
SEGLINK_KEYWORDS = ("seglink", "proposal", "feature_extract")


def _matches_any_keyword(node, keywords):
    """Return True if any keyword occurs in node.op_type or node.name.

    The comparison is case-insensitive. op_type and name are checked
    separately, so a keyword cannot match across the boundary of the two
    strings (which concatenating them would allow).
    """
    fields = (node.op_type.lower(), node.name.lower())
    return any(kw in field for field in fields for kw in keywords)


# --- Section: every op whose domain mentions OneOCR --------------------------
print("\n" + "=" * 60)
print("ALL CUSTOM OPS IN DETECTION MODEL")
print("=" * 60)

for node in graph.node:
    # node.domain is "" for default-domain ops, so this also skips those.
    if "OneOCR" not in node.domain:
        continue
    print(f"\n Op: {node.op_type}")
    print(f" Domain: {node.domain}")
    print(f" Name: {node.name}")
    print(f" Inputs: {list(node.input)}")
    print(f" Outputs: {list(node.output)}")
    _print_attributes(node)

# --- Section: nodes whose op_type or name contains a Seglink keyword ---------
print("\n" + "=" * 60)
# Header is derived from the keyword list so it always matches the search.
print("ALL NODES WITH " + "/".join(SEGLINK_KEYWORDS) + " IN NAME OR OP")
print("=" * 60)

for node in graph.node:
    if not _matches_any_keyword(node, SEGLINK_KEYWORDS):
        continue
    print(f"\n Op: {node.op_type} (domain: {node.domain})")
    print(f" Name: {node.name}")
    print(f" Inputs: {list(node.input)}")
    print(f" Outputs: {list(node.output)}")
    _print_attributes(node)
|
|
|
|
|
def _format_dims(value_info):
    """Return the dimensions of a graph input/output as a list.

    Each entry is the fixed size (``dim_value``), the symbolic name
    (``dim_param``), or ``"?"`` when neither is set. The protobuf oneof is
    consulted, so an explicit size of 0 is reported as 0 instead of falling
    through to an empty ``dim_param``. A value whose type carries no tensor
    shape yields an empty list.
    """
    dims = []
    for dim in value_info.type.tensor_type.shape.dim:
        which = dim.WhichOneof("value")
        if which == "dim_value":
            dims.append(dim.dim_value)
        elif which == "dim_param":
            dims.append(dim.dim_param)
        else:
            dims.append("?")
    return dims


# Inputs and outputs are printed identically, each under its own header.
for section_title, value_infos in (
    ("MODEL INPUTS", graph.input),
    ("MODEL OUTPUTS", graph.output),
):
    print("\n" + "=" * 60)
    print(section_title)
    print("=" * 60)
    for value_info in value_infos:
        print(f" {value_info.name}: {_format_dims(value_info)}")
|
|
|
|
|
print("\n" + "=" * 60)
print("ALL NODE TYPES (unique)")
print("=" * 60)

# ONNX allows the default operator-set domain to be written as either "" or
# "ai.onnx"; treat both the same so e.g. Conv is not counted twice.
_DEFAULT_DOMAINS = ("", "ai.onnx")

# Ops from non-default domains (such as the OneOCR custom domain) are keyed
# as "domain::op_type"; default-domain ops are keyed by op_type alone.
ops = Counter(
    node.op_type
    if node.domain in _DEFAULT_DOMAINS
    else f"{node.domain}::{node.op_type}"
    for node in graph.node
)

# Alphabetical by key, one line per distinct op with its occurrence count.
for op, count in sorted(ops.items()):
    print(f" {op}: {count}x")
|
|
|