# Source: oneocr/_archive/attempts/verify_models.py (OneOCR Dev)
# OneOCR - reverse engineering complete, ONNX pipeline 53% match rate
# Commit: ce847d4
"""Verify extracted .bin files as valid ONNX models."""
import os
import struct
from pathlib import Path
EXTRACT_DIR = Path(r"c:\Users\MattyMroz\Desktop\PROJECTS\ONEOCR\extracted_models")
VERIFIED_DIR = Path(r"c:\Users\MattyMroz\Desktop\PROJECTS\ONEOCR\verified_models")
VERIFIED_DIR.mkdir(exist_ok=True)
def try_parse_onnx_protobuf(data: bytes) -> dict | None:
"""Try to parse the first few fields of an ONNX ModelProto protobuf."""
# ONNX ModelProto:
# field 1 (varint) = ir_version
# field 2 (len-delimited) = opset_import (repeated)
# field 3 (len-delimited) = producer_name
# field 4 (len-delimited) = producer_version
# field 5 (len-delimited) = domain
# field 6 (varint) = model_version
# field 7 (len-delimited) = doc_string
# field 8 (len-delimited) = graph (GraphProto)
if len(data) < 4:
return None
pos = 0
result = {}
try:
# Field 1: ir_version (varint, field tag = 0x08)
if data[pos] != 0x08:
return None
pos += 1
# Read varint
ir_version = 0
shift = 0
while pos < len(data):
b = data[pos]
pos += 1
ir_version |= (b & 0x7F) << shift
if not (b & 0x80):
break
shift += 7
if ir_version < 1 or ir_version > 12:
return None
result['ir_version'] = ir_version
# Next field - check tag
if pos >= len(data):
return None
tag = data[pos]
field_num = tag >> 3
wire_type = tag & 0x07
# We expect field 2 (opset_import, len-delimited, tag=0x12) or
# field 3 (producer_name, len-delimited, tag=0x1a)
if tag == 0x12: # field 2, length-delimited
pos += 1
# Read length varint
length = 0
shift = 0
while pos < len(data):
b = data[pos]
pos += 1
length |= (b & 0x7F) << shift
if not (b & 0x80):
break
shift += 7
if length > 0 and length < len(data):
result['has_opset_or_producer'] = True
result['next_field_len'] = length
else:
return None
elif tag == 0x1a: # field 3, length-delimited
pos += 1
length = 0
shift = 0
while pos < len(data):
b = data[pos]
pos += 1
length |= (b & 0x7F) << shift
if not (b & 0x80):
break
shift += 7
if length > 0 and length < 1000:
producer = data[pos:pos+length]
try:
result['producer_name'] = producer.decode('utf-8', errors='strict')
except:
result['producer_name'] = f"<binary {length}b>"
result['has_opset_or_producer'] = True
else:
return None
return result
except (IndexError, ValueError):
return None
def check_onnx_with_lib(filepath: str) -> dict | None:
"""Try loading with onnx library."""
try:
import onnx
model = onnx.load(filepath)
return {
'ir_version': model.ir_version,
'producer': model.producer_name,
'model_version': model.model_version,
'opset': [f"{o.domain or 'ai.onnx'}:{o.version}" for o in model.opset_import],
'graph_name': model.graph.name if model.graph else None,
'num_nodes': len(model.graph.node) if model.graph else 0,
'num_inputs': len(model.graph.input) if model.graph else 0,
'num_outputs': len(model.graph.output) if model.graph else 0,
}
except Exception as e:
return None
# Phase 1: Quick protobuf header scan
print("=" * 70)
print("PHASE 1: Quick protobuf header scan")
print("=" * 70)
candidates = []
files = sorted(EXTRACT_DIR.glob("*.bin"), key=lambda f: f.stat().st_size, reverse=True)
print(f"Total files: {len(files)}")
for f in files:
size = f.stat().st_size
if size < 1000: # Skip tiny files
continue
with open(f, 'rb') as fh:
header = fh.read(256)
info = try_parse_onnx_protobuf(header)
if info and info.get('ir_version', 0) >= 3:
candidates.append((f, size, info))
print(f"Candidates with valid ONNX protobuf header: {len(candidates)}")
print()
# Group by ir_version
from collections import Counter
ir_counts = Counter(c[2]['ir_version'] for c in candidates)
print("IR version distribution:")
for v, cnt in sorted(ir_counts.items()):
total_size = sum(c[1] for c in candidates if c[2]['ir_version'] == v)
print(f" ir_version={v}: {cnt} files, total {total_size/1024/1024:.1f} MB")
# Phase 2: Try onnx.load on top candidates (by size, unique sizes to avoid duplicates)
print()
print("=" * 70)
print("PHASE 2: Verify with onnx library (top candidates by size)")
print("=" * 70)
# Take unique sizes - many files may be near-duplicates from overlapping memory
seen_sizes = set()
unique_candidates = []
for f, size, info in candidates:
# Round to nearest 1KB to detect near-duplicates
size_key = size // 1024
if size_key not in seen_sizes:
seen_sizes.add(size_key)
unique_candidates.append((f, size, info))
print(f"Unique-size candidates: {len(unique_candidates)}")
print()
verified = []
for i, (f, size, info) in enumerate(unique_candidates[:50]): # Check top 50 by size
result = check_onnx_with_lib(str(f))
if result:
verified.append((f, size, result))
print(f" VALID ONNX: {f.name}")
print(f" Size: {size/1024:.0f} KB")
print(f" ir={result['ir_version']} producer='{result['producer']}' "
f"opset={result['opset']}")
print(f" graph='{result['graph_name']}' nodes={result['num_nodes']} "
f"inputs={result['num_inputs']} outputs={result['num_outputs']}")
# Copy to verified dir
import shutil
dest_name = f"model_{len(verified):02d}_ir{result['ir_version']}_{result['graph_name'] or 'unknown'}_{size//1024}KB.onnx"
# Clean filename
dest_name = dest_name.replace('/', '_').replace('\\', '_').replace(':', '_')
dest = VERIFIED_DIR / dest_name
shutil.copy2(f, dest)
print(f" -> Saved as {dest_name}")
print()
if not verified:
print(" No files passed onnx.load validation in top 50.")
print()
# Try even more
print(" Trying ALL candidates...")
for i, (f, size, info) in enumerate(unique_candidates):
if i < 50:
continue
result = check_onnx_with_lib(str(f))
if result:
verified.append((f, size, result))
print(f" VALID ONNX: {f.name}")
print(f" Size: {size/1024:.0f} KB, ir={result['ir_version']}, "
f"producer='{result['producer']}', nodes={result['num_nodes']}")
import shutil
dest_name = f"model_{len(verified):02d}_ir{result['ir_version']}_{result['graph_name'] or 'unknown'}_{size//1024}KB.onnx"
dest_name = dest_name.replace('/', '_').replace('\\', '_').replace(':', '_')
dest = VERIFIED_DIR / dest_name
shutil.copy2(f, dest)
print()
print("=" * 70)
print(f"SUMMARY: {len(verified)} verified ONNX models out of {len(candidates)} candidates")
print("=" * 70)
if verified:
total_size = sum(v[1] for v in verified)
print(f"Total size: {total_size/1024/1024:.1f} MB")
for f, size, result in verified:
print(f" {f.name}: {size/1024:.0f}KB, {result['num_nodes']} nodes, "
f"graph='{result['graph_name']}'")