File size: 7,898 Bytes
ce847d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 |
"""Verify extracted .bin files as valid ONNX models."""
import os
import struct
from pathlib import Path
# Source directory containing the raw extracted .bin blobs to be scanned.
EXTRACT_DIR = Path(r"c:\Users\MattyMroz\Desktop\PROJECTS\ONEOCR\extracted_models")
# Destination directory for blobs that pass ONNX validation.
VERIFIED_DIR = Path(r"c:\Users\MattyMroz\Desktop\PROJECTS\ONEOCR\verified_models")
VERIFIED_DIR.mkdir(exist_ok=True)  # no-op if the directory already exists
def try_parse_onnx_protobuf(data: bytes) -> dict | None:
"""Try to parse the first few fields of an ONNX ModelProto protobuf."""
# ONNX ModelProto:
# field 1 (varint) = ir_version
# field 2 (len-delimited) = opset_import (repeated)
# field 3 (len-delimited) = producer_name
# field 4 (len-delimited) = producer_version
# field 5 (len-delimited) = domain
# field 6 (varint) = model_version
# field 7 (len-delimited) = doc_string
# field 8 (len-delimited) = graph (GraphProto)
if len(data) < 4:
return None
pos = 0
result = {}
try:
# Field 1: ir_version (varint, field tag = 0x08)
if data[pos] != 0x08:
return None
pos += 1
# Read varint
ir_version = 0
shift = 0
while pos < len(data):
b = data[pos]
pos += 1
ir_version |= (b & 0x7F) << shift
if not (b & 0x80):
break
shift += 7
if ir_version < 1 or ir_version > 12:
return None
result['ir_version'] = ir_version
# Next field - check tag
if pos >= len(data):
return None
tag = data[pos]
field_num = tag >> 3
wire_type = tag & 0x07
# We expect field 2 (opset_import, len-delimited, tag=0x12) or
# field 3 (producer_name, len-delimited, tag=0x1a)
if tag == 0x12: # field 2, length-delimited
pos += 1
# Read length varint
length = 0
shift = 0
while pos < len(data):
b = data[pos]
pos += 1
length |= (b & 0x7F) << shift
if not (b & 0x80):
break
shift += 7
if length > 0 and length < len(data):
result['has_opset_or_producer'] = True
result['next_field_len'] = length
else:
return None
elif tag == 0x1a: # field 3, length-delimited
pos += 1
length = 0
shift = 0
while pos < len(data):
b = data[pos]
pos += 1
length |= (b & 0x7F) << shift
if not (b & 0x80):
break
shift += 7
if length > 0 and length < 1000:
producer = data[pos:pos+length]
try:
result['producer_name'] = producer.decode('utf-8', errors='strict')
except:
result['producer_name'] = f"<binary {length}b>"
result['has_opset_or_producer'] = True
else:
return None
return result
except (IndexError, ValueError):
return None
def check_onnx_with_lib(filepath: str) -> dict | None:
"""Try loading with onnx library."""
try:
import onnx
model = onnx.load(filepath)
return {
'ir_version': model.ir_version,
'producer': model.producer_name,
'model_version': model.model_version,
'opset': [f"{o.domain or 'ai.onnx'}:{o.version}" for o in model.opset_import],
'graph_name': model.graph.name if model.graph else None,
'num_nodes': len(model.graph.node) if model.graph else 0,
'num_inputs': len(model.graph.input) if model.graph else 0,
'num_outputs': len(model.graph.output) if model.graph else 0,
}
except Exception as e:
return None
# Phase 1: Quick protobuf header scan over every extracted blob.
print("=" * 70)
print("PHASE 1: Quick protobuf header scan")
print("=" * 70)
# Largest files first — big blobs are the most likely complete models.
files = sorted(EXTRACT_DIR.glob("*.bin"), key=lambda p: p.stat().st_size, reverse=True)
print(f"Total files: {len(files)}")
candidates = []
for path in files:
    nbytes = path.stat().st_size
    if nbytes < 1000:  # tiny fragments cannot hold a model
        continue
    with open(path, 'rb') as handle:
        head = handle.read(256)  # the header fields fit well within 256 bytes
    parsed = try_parse_onnx_protobuf(head)
    # ir_version >= 3 filters out ancient/implausible headers.
    if parsed and parsed.get('ir_version', 0) >= 3:
        candidates.append((path, nbytes, parsed))
print(f"Candidates with valid ONNX protobuf header: {len(candidates)}")
print()
# Summarise how the candidates split across IR versions.
from collections import Counter
ir_counts = Counter(entry[2]['ir_version'] for entry in candidates)
print("IR version distribution:")
for version, count in sorted(ir_counts.items()):
    size_sum = sum(entry[1] for entry in candidates if entry[2]['ir_version'] == version)
    print(f" ir_version={version}: {count} files, total {size_sum/1024/1024:.1f} MB")
# Phase 2 setup: pick candidates with unique sizes — many blobs are
# near-duplicates from overlapping memory regions.
print()
print("=" * 70)
print("PHASE 2: Verify with onnx library (top candidates by size)")
print("=" * 70)
# Bucket by size rounded to 1 KB; keep the first entry per bucket
# (candidates is already sorted largest-first, and dicts preserve order).
by_bucket: dict = {}
for entry in candidates:
    by_bucket.setdefault(entry[1] // 1024, entry)
unique_candidates = list(by_bucket.values())
print(f"Unique-size candidates: {len(unique_candidates)}")
print()
verified = []


def _save_verified(path, size, result):
    """Copy a validated model into VERIFIED_DIR under a descriptive name.

    Numbering uses len(verified) AFTER the append, so names start at 01
    (matches the original behaviour). Returns the destination file name.
    """
    import shutil
    graph = result['graph_name'] or 'unknown'
    dest_name = f"model_{len(verified):02d}_ir{result['ir_version']}_{graph}_{size//1024}KB.onnx"
    # Strip path separators / drive colons so graph names can't escape the dir.
    dest_name = dest_name.replace('/', '_').replace('\\', '_').replace(':', '_')
    shutil.copy2(path, VERIFIED_DIR / dest_name)
    return dest_name


# First pass: the 50 largest unique-size candidates.
for f, size, info in unique_candidates[:50]:
    result = check_onnx_with_lib(str(f))
    if not result:
        continue
    verified.append((f, size, result))
    print(f" VALID ONNX: {f.name}")
    print(f" Size: {size/1024:.0f} KB")
    print(f" ir={result['ir_version']} producer='{result['producer']}' "
          f"opset={result['opset']}")
    print(f" graph='{result['graph_name']}' nodes={result['num_nodes']} "
          f"inputs={result['num_inputs']} outputs={result['num_outputs']}")
    dest_name = _save_verified(f, size, result)
    print(f" -> Saved as {dest_name}")
    print()

if not verified:
    print(" No files passed onnx.load validation in top 50.")
    print()
    print(" Trying ALL candidates...")
    # Second pass: everything beyond the first 50 (terser per-hit output).
    for f, size, info in unique_candidates[50:]:
        result = check_onnx_with_lib(str(f))
        if not result:
            continue
        verified.append((f, size, result))
        print(f" VALID ONNX: {f.name}")
        print(f" Size: {size/1024:.0f} KB, ir={result['ir_version']}, "
              f"producer='{result['producer']}', nodes={result['num_nodes']}")
        _save_verified(f, size, result)
# Final summary of everything that passed full onnx.load validation.
print()
banner = "=" * 70
print(banner)
print(f"SUMMARY: {len(verified)} verified ONNX models out of {len(candidates)} candidates")
print(banner)
if verified:
    megabytes = sum(entry[1] for entry in verified) / 1024 / 1024
    print(f"Total size: {megabytes:.1f} MB")
    for path, nbytes, meta in verified:
        print(f" {path.name}: {nbytes/1024:.0f}KB, {meta['num_nodes']} nodes, "
              f"graph='{meta['graph_name']}'")