|
|
"""Verify extracted .bin files as valid ONNX models.""" |
|
|
import os |
|
|
import struct |
|
|
from pathlib import Path |
|
|
|
|
|
# Directory holding the raw .bin blobs previously carved out of the OneOCR data.
EXTRACT_DIR = Path(r"c:\Users\MattyMroz\Desktop\PROJECTS\ONEOCR\extracted_models")

# Destination for blobs that pass ONNX validation (Phase 2 copies them here).
VERIFIED_DIR = Path(r"c:\Users\MattyMroz\Desktop\PROJECTS\ONEOCR\verified_models")

# Create the output directory up front; reuse it if it already exists.
VERIFIED_DIR.mkdir(exist_ok=True)
|
|
|
|
|
def try_parse_onnx_protobuf(data: bytes) -> dict | None: |
|
|
"""Try to parse the first few fields of an ONNX ModelProto protobuf.""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if len(data) < 4: |
|
|
return None |
|
|
|
|
|
pos = 0 |
|
|
result = {} |
|
|
|
|
|
try: |
|
|
|
|
|
if data[pos] != 0x08: |
|
|
return None |
|
|
pos += 1 |
|
|
|
|
|
|
|
|
ir_version = 0 |
|
|
shift = 0 |
|
|
while pos < len(data): |
|
|
b = data[pos] |
|
|
pos += 1 |
|
|
ir_version |= (b & 0x7F) << shift |
|
|
if not (b & 0x80): |
|
|
break |
|
|
shift += 7 |
|
|
|
|
|
if ir_version < 1 or ir_version > 12: |
|
|
return None |
|
|
result['ir_version'] = ir_version |
|
|
|
|
|
|
|
|
if pos >= len(data): |
|
|
return None |
|
|
|
|
|
tag = data[pos] |
|
|
field_num = tag >> 3 |
|
|
wire_type = tag & 0x07 |
|
|
|
|
|
|
|
|
|
|
|
if tag == 0x12: |
|
|
pos += 1 |
|
|
|
|
|
length = 0 |
|
|
shift = 0 |
|
|
while pos < len(data): |
|
|
b = data[pos] |
|
|
pos += 1 |
|
|
length |= (b & 0x7F) << shift |
|
|
if not (b & 0x80): |
|
|
break |
|
|
shift += 7 |
|
|
|
|
|
if length > 0 and length < len(data): |
|
|
result['has_opset_or_producer'] = True |
|
|
result['next_field_len'] = length |
|
|
else: |
|
|
return None |
|
|
elif tag == 0x1a: |
|
|
pos += 1 |
|
|
length = 0 |
|
|
shift = 0 |
|
|
while pos < len(data): |
|
|
b = data[pos] |
|
|
pos += 1 |
|
|
length |= (b & 0x7F) << shift |
|
|
if not (b & 0x80): |
|
|
break |
|
|
shift += 7 |
|
|
|
|
|
if length > 0 and length < 1000: |
|
|
producer = data[pos:pos+length] |
|
|
try: |
|
|
result['producer_name'] = producer.decode('utf-8', errors='strict') |
|
|
except: |
|
|
result['producer_name'] = f"<binary {length}b>" |
|
|
result['has_opset_or_producer'] = True |
|
|
else: |
|
|
return None |
|
|
|
|
|
return result |
|
|
|
|
|
except (IndexError, ValueError): |
|
|
return None |
|
|
|
|
|
|
|
|
def check_onnx_with_lib(filepath: str) -> dict | None: |
|
|
"""Try loading with onnx library.""" |
|
|
try: |
|
|
import onnx |
|
|
model = onnx.load(filepath) |
|
|
return { |
|
|
'ir_version': model.ir_version, |
|
|
'producer': model.producer_name, |
|
|
'model_version': model.model_version, |
|
|
'opset': [f"{o.domain or 'ai.onnx'}:{o.version}" for o in model.opset_import], |
|
|
'graph_name': model.graph.name if model.graph else None, |
|
|
'num_nodes': len(model.graph.node) if model.graph else 0, |
|
|
'num_inputs': len(model.graph.input) if model.graph else 0, |
|
|
'num_outputs': len(model.graph.output) if model.graph else 0, |
|
|
} |
|
|
except Exception as e: |
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
# --- Phase 1: cheap protobuf-header screen over every extracted blob ---
print("=" * 70)
print("PHASE 1: Quick protobuf header scan")
print("=" * 70)

candidates = []
# Largest files first: real models tend to dominate by size.
files = sorted(EXTRACT_DIR.glob("*.bin"), key=lambda p: p.stat().st_size, reverse=True)
print(f"Total files: {len(files)}")

for bin_path in files:
    file_size = bin_path.stat().st_size
    if file_size < 1000:
        # Too small to be a meaningful model.
        continue

    with bin_path.open('rb') as handle:
        head = handle.read(256)

    parsed = try_parse_onnx_protobuf(head)
    # Only modern models (ir_version >= 3) are worth the expensive Phase-2 load.
    if parsed and parsed.get('ir_version', 0) >= 3:
        candidates.append((bin_path, file_size, parsed))

print(f"Candidates with valid ONNX protobuf header: {len(candidates)}")
print()
|
|
|
|
|
|
|
|
from collections import Counter

# Tally candidate counts and cumulative sizes per IR version in one pass.
file_counts = Counter()
size_totals = Counter()
for _path, nbytes, hdr in candidates:
    ver = hdr['ir_version']
    file_counts[ver] += 1
    size_totals[ver] += nbytes

print("IR version distribution:")
for ver in sorted(file_counts):
    print(f"  ir_version={ver}: {file_counts[ver]} files, total {size_totals[ver]/1024/1024:.1f} MB")
|
|
|
|
|
|
|
|
print()
print("=" * 70)
print("PHASE 2: Verify with onnx library (top candidates by size)")
print("=" * 70)

# Collapse near-duplicates: the same model is often carved out repeatedly,
# so keep only the first candidate seen for each whole-KB size bucket.
# (dict preserves insertion order, so the size-descending order survives.)
first_by_kb = {}
for entry in candidates:
    kb_bucket = entry[1] // 1024
    if kb_bucket not in first_by_kb:
        first_by_kb[kb_bucket] = entry
unique_candidates = list(first_by_kb.values())

print(f"Unique-size candidates: {len(unique_candidates)}")
print()
|
|
|
|
|
import shutil  # hoisted: the original re-imported shutil on every verified file


def _copy_to_verified(src, size_bytes, result, index):
    """Copy a verified model into VERIFIED_DIR under a descriptive name.

    The name encodes rank, IR version, graph name and size; path-hostile
    characters from the graph name are replaced. Returns the name used.
    """
    raw_name = (
        f"model_{index:02d}_ir{result['ir_version']}_"
        f"{result['graph_name'] or 'unknown'}_{size_bytes//1024}KB.onnx"
    )
    # Graph names may contain separators or drive colons; sanitize them.
    safe_name = raw_name.replace('/', '_').replace('\\', '_').replace(':', '_')
    shutil.copy2(src, VERIFIED_DIR / safe_name)
    return safe_name


verified = []

# First pass: only the 50 largest unique-size candidates (most likely hits).
for f, size, _info in unique_candidates[:50]:
    result = check_onnx_with_lib(str(f))
    if not result:
        continue
    verified.append((f, size, result))
    print(f"  VALID ONNX: {f.name}")
    print(f"    Size: {size/1024:.0f} KB")
    print(f"    ir={result['ir_version']} producer='{result['producer']}' "
          f"opset={result['opset']}")
    print(f"    graph='{result['graph_name']}' nodes={result['num_nodes']} "
          f"inputs={result['num_inputs']} outputs={result['num_outputs']}")
    dest_name = _copy_to_verified(f, size, result, len(verified))
    print(f"    -> Saved as {dest_name}")
    print()

if not verified:
    print("  No files passed onnx.load validation in top 50.")
    print()

    print("  Trying ALL candidates...")
    # Fallback pass: the rest of the list, skipping the 50 already tried
    # (the original used enumerate + `if i < 50: continue`; slicing is clearer).
    for f, size, _info in unique_candidates[50:]:
        result = check_onnx_with_lib(str(f))
        if not result:
            continue
        verified.append((f, size, result))
        print(f"  VALID ONNX: {f.name}")
        print(f"    Size: {size/1024:.0f} KB, ir={result['ir_version']}, "
              f"producer='{result['producer']}', nodes={result['num_nodes']}")
        _copy_to_verified(f, size, result, len(verified))
|
|
|
|
|
# --- Final report ---
print()
print("=" * 70)
print(f"SUMMARY: {len(verified)} verified ONNX models out of {len(candidates)} candidates")
print("=" * 70)

if verified:
    megabytes = sum(entry[1] for entry in verified) / 1024 / 1024
    print(f"Total size: {megabytes:.1f} MB")
    for path, nbytes, meta in verified:
        print(f"  {path.name}: {nbytes/1024:.0f}KB, {meta['num_nodes']} nodes, "
              f"graph='{meta['graph_name']}'")
|