"""Verify extracted .bin files as valid ONNX models.""" import os import struct from pathlib import Path EXTRACT_DIR = Path(r"c:\Users\MattyMroz\Desktop\PROJECTS\ONEOCR\extracted_models") VERIFIED_DIR = Path(r"c:\Users\MattyMroz\Desktop\PROJECTS\ONEOCR\verified_models") VERIFIED_DIR.mkdir(exist_ok=True) def try_parse_onnx_protobuf(data: bytes) -> dict | None: """Try to parse the first few fields of an ONNX ModelProto protobuf.""" # ONNX ModelProto: # field 1 (varint) = ir_version # field 2 (len-delimited) = opset_import (repeated) # field 3 (len-delimited) = producer_name # field 4 (len-delimited) = producer_version # field 5 (len-delimited) = domain # field 6 (varint) = model_version # field 7 (len-delimited) = doc_string # field 8 (len-delimited) = graph (GraphProto) if len(data) < 4: return None pos = 0 result = {} try: # Field 1: ir_version (varint, field tag = 0x08) if data[pos] != 0x08: return None pos += 1 # Read varint ir_version = 0 shift = 0 while pos < len(data): b = data[pos] pos += 1 ir_version |= (b & 0x7F) << shift if not (b & 0x80): break shift += 7 if ir_version < 1 or ir_version > 12: return None result['ir_version'] = ir_version # Next field - check tag if pos >= len(data): return None tag = data[pos] field_num = tag >> 3 wire_type = tag & 0x07 # We expect field 2 (opset_import, len-delimited, tag=0x12) or # field 3 (producer_name, len-delimited, tag=0x1a) if tag == 0x12: # field 2, length-delimited pos += 1 # Read length varint length = 0 shift = 0 while pos < len(data): b = data[pos] pos += 1 length |= (b & 0x7F) << shift if not (b & 0x80): break shift += 7 if length > 0 and length < len(data): result['has_opset_or_producer'] = True result['next_field_len'] = length else: return None elif tag == 0x1a: # field 3, length-delimited pos += 1 length = 0 shift = 0 while pos < len(data): b = data[pos] pos += 1 length |= (b & 0x7F) << shift if not (b & 0x80): break shift += 7 if length > 0 and length < 1000: producer = data[pos:pos+length] try: result['producer_name'] = producer.decode('utf-8', errors='strict') except: result['producer_name'] = f"" result['has_opset_or_producer'] = True else: return None return result except (IndexError, ValueError): return None def check_onnx_with_lib(filepath: str) -> dict | None: """Try loading with onnx library.""" try: import onnx model = onnx.load(filepath) return { 'ir_version': model.ir_version, 'producer': model.producer_name, 'model_version': model.model_version, 'opset': [f"{o.domain or 'ai.onnx'}:{o.version}" for o in model.opset_import], 'graph_name': model.graph.name if model.graph else None, 'num_nodes': len(model.graph.node) if model.graph else 0, 'num_inputs': len(model.graph.input) if model.graph else 0, 'num_outputs': len(model.graph.output) if model.graph else 0, } except Exception as e: return None # Phase 1: Quick protobuf header scan print("=" * 70) print("PHASE 1: Quick protobuf header scan") print("=" * 70) candidates = [] files = sorted(EXTRACT_DIR.glob("*.bin"), key=lambda f: f.stat().st_size, reverse=True) print(f"Total files: {len(files)}") for f in files: size = f.stat().st_size if size < 1000: # Skip tiny files continue with open(f, 'rb') as fh: header = fh.read(256) info = try_parse_onnx_protobuf(header) if info and info.get('ir_version', 0) >= 3: candidates.append((f, size, info)) print(f"Candidates with valid ONNX protobuf header: {len(candidates)}") print() # Group by ir_version from collections import Counter ir_counts = Counter(c[2]['ir_version'] for c in candidates) print("IR version distribution:") for v, cnt in sorted(ir_counts.items()): total_size = sum(c[1] for c in candidates if c[2]['ir_version'] == v) print(f" ir_version={v}: {cnt} files, total {total_size/1024/1024:.1f} MB") # Phase 2: Try onnx.load on top candidates (by size, unique sizes to avoid duplicates) print() print("=" * 70) print("PHASE 2: Verify with onnx library (top candidates by size)") print("=" * 70) # Take unique sizes - many files may be near-duplicates from overlapping memory seen_sizes = set() unique_candidates = [] for f, size, info in candidates: # Round to nearest 1KB to detect near-duplicates size_key = size // 1024 if size_key not in seen_sizes: seen_sizes.add(size_key) unique_candidates.append((f, size, info)) print(f"Unique-size candidates: {len(unique_candidates)}") print() verified = [] for i, (f, size, info) in enumerate(unique_candidates[:50]): # Check top 50 by size result = check_onnx_with_lib(str(f)) if result: verified.append((f, size, result)) print(f" VALID ONNX: {f.name}") print(f" Size: {size/1024:.0f} KB") print(f" ir={result['ir_version']} producer='{result['producer']}' " f"opset={result['opset']}") print(f" graph='{result['graph_name']}' nodes={result['num_nodes']} " f"inputs={result['num_inputs']} outputs={result['num_outputs']}") # Copy to verified dir import shutil dest_name = f"model_{len(verified):02d}_ir{result['ir_version']}_{result['graph_name'] or 'unknown'}_{size//1024}KB.onnx" # Clean filename dest_name = dest_name.replace('/', '_').replace('\\', '_').replace(':', '_') dest = VERIFIED_DIR / dest_name shutil.copy2(f, dest) print(f" -> Saved as {dest_name}") print() if not verified: print(" No files passed onnx.load validation in top 50.") print() # Try even more print(" Trying ALL candidates...") for i, (f, size, info) in enumerate(unique_candidates): if i < 50: continue result = check_onnx_with_lib(str(f)) if result: verified.append((f, size, result)) print(f" VALID ONNX: {f.name}") print(f" Size: {size/1024:.0f} KB, ir={result['ir_version']}, " f"producer='{result['producer']}', nodes={result['num_nodes']}") import shutil dest_name = f"model_{len(verified):02d}_ir{result['ir_version']}_{result['graph_name'] or 'unknown'}_{size//1024}KB.onnx" dest_name = dest_name.replace('/', '_').replace('\\', '_').replace(':', '_') dest = VERIFIED_DIR / dest_name shutil.copy2(f, dest) print() print("=" * 70) print(f"SUMMARY: {len(verified)} verified ONNX models out of {len(candidates)} candidates") print("=" * 70) if verified: total_size = sum(v[1] for v in verified) print(f"Total size: {total_size/1024/1024:.1f} MB") for f, size, result in verified: print(f" {f.name}: {size/1024:.0f}KB, {result['num_nodes']} nodes, " f"graph='{result['graph_name']}'")