"""Unlock ALL OneOCRFeatureExtract models (11-33).

Replaces the custom `OneOCRFeatureExtract` op (domain: com.microsoft.oneocr)
with a standard ONNX `Gemm` node. The weights are extracted from the
big-endian float32 config blob stored as a STRING tensor.

Config blob structure (for small/medium LM models 11-32):
  - W[input_dim × output_dim] as big-endian float32
  - b[output_dim] as big-endian float32
  - metadata[remaining] containing dimensions, flags, etc.

Usage:
    python unlock_models.py           # unlock all 11-33
    python unlock_models.py 11 22 33  # unlock specific models
"""

import onnx
from onnx import numpy_helper, helper
import numpy as np
from pathlib import Path
import copy
import sys

# Names that identify the custom op and its packed-weights initializer.
_FE_OP_TYPE = "OneOCRFeatureExtract"
_FE_CONFIG_NAME = "feature/config"
_FE_DOMAIN = "com.microsoft.oneocr"

# How many trailing floats of the blob are scanned for the dimension pair.
_TAIL_SCAN_FLOATS = 10


def _find_fe_node(graph):
    """Return the first OneOCRFeatureExtract node in *graph*, or None."""
    return next((node for node in graph.node if node.op_type == _FE_OP_TYPE), None)


def extract_fe_weights(model: onnx.ModelProto) -> tuple[np.ndarray, np.ndarray, dict]:
    """Extract weights from the OneOCRFeatureExtract config blob.

    The config blob is read as big-endian float32 and split as
    ``W[in_dim × out_dim] + b[out_dim] + metadata``. Dimensions are taken
    from a ``21, 50|51`` pair in the last few floats of the blob when
    present; otherwise they are inferred from the graph (the ``data`` input
    shape and the weight of the Gemm consuming the custom op's output).

    Args:
        model: Loaded ONNX model containing a ``feature/config`` initializer
            and a ``OneOCRFeatureExtract`` node.

    Returns:
        ``(weight_matrix, bias, metadata_dict)`` where ``weight_matrix`` has
        shape ``(in_dim, out_dim)``, ``bias`` has shape ``(out_dim,)`` (both
        native float32), and ``metadata_dict`` holds ``in_dim``, ``out_dim``,
        ``total_floats``, ``metadata_floats`` and ``metadata_values``.

    Raises:
        ValueError: If the initializer or node is missing, the blob is not a
            whole number of float32 values, the dimensions cannot be
            determined, or the blob is too small for them.
    """
    # Find the feature/config initializer (STRING tensors keep their payload
    # in string_data; fall back to raw_data otherwise).
    config_blob = None
    for init in model.graph.initializer:
        if init.name == _FE_CONFIG_NAME:
            config_blob = bytes(init.string_data[0]) if init.string_data else bytes(init.raw_data)
            break

    if config_blob is None:
        raise ValueError("No feature/config initializer found")

    # np.frombuffer would reject this with an opaque message; fail clearly.
    if len(config_blob) % 4 != 0:
        raise ValueError(
            f"Config blob length {len(config_blob)} is not a multiple of 4 bytes"
        )

    # Parse as big-endian float32 (copy: frombuffer returns a read-only view).
    be_arr = np.frombuffer(config_blob, dtype=">f4").copy()

    fe_node = _find_fe_node(model.graph)
    if fe_node is None:
        raise ValueError("No OneOCRFeatureExtract node found")

    in_dim = None
    out_dim = None

    # Try to get dims from metadata at the end of the blob.
    # Pattern: [..., in_dim, out_dim, num_classes, ...] near the end.
    # The start is clamped at 0 so a short blob cannot produce negative
    # indices that silently wrap around to the end of the array.
    tail_start = max(0, len(be_arr) - _TAIL_SCAN_FLOATS)
    for i in range(tail_start, len(be_arr) - 1):
        if be_arr[i] == 21.0 and be_arr[i + 1] in (50.0, 51.0):
            in_dim = int(be_arr[i])
            out_dim = int(be_arr[i + 1])
            break

    # Fallback: infer the input dimension from the "data" graph input.
    if in_dim is None:
        for graph_input in model.graph.input:
            if graph_input.name == "data":
                shape = [d.dim_value for d in graph_input.type.tensor_type.shape.dim]
                if len(shape) >= 2:
                    # A dim_value of 0 means the dimension is not fixed.
                    in_dim = shape[1] if shape[1] > 0 else 21
                break

    # Fallback: the Gemm consuming the custom op's output fixes our out_dim,
    # which is that Gemm's reduction dimension K.
    if out_dim is None:
        fe_output = fe_node.output[0]
        for node in model.graph.node:
            if node.op_type == "Gemm" and fe_output in node.input:
                weight_name = node.input[1]
                trans_b = any(
                    attr.name == "transB" and attr.i != 0 for attr in node.attribute
                )
                for init in model.graph.initializer:
                    if init.name == weight_name:
                        gemm_w = numpy_helper.to_array(init)
                        if gemm_w.ndim == 2:
                            # B is (K, N) normally and (N, K) when transB=1.
                            out_dim = int(gemm_w.shape[1] if trans_b else gemm_w.shape[0])
                        break
                break

    if in_dim is None or out_dim is None:
        raise ValueError(f"Could not determine dimensions: in={in_dim}, out={out_dim}")

    # Extract weights: first in_dim*out_dim floats = W, next out_dim = b.
    n_weights = in_dim * out_dim
    n_bias = out_dim

    if len(be_arr) < n_weights + n_bias:
        raise ValueError(f"Config blob too small: {len(be_arr)} < {n_weights + n_bias}")

    # astype(np.float32) also converts from big-endian to native byte order.
    W = be_arr[:n_weights].reshape(in_dim, out_dim).astype(np.float32)
    b = be_arr[n_weights:n_weights + n_bias].astype(np.float32)
    metadata = be_arr[n_weights + n_bias:]

    meta_dict = {
        "in_dim": in_dim,
        "out_dim": out_dim,
        "total_floats": len(be_arr),
        "metadata_floats": len(metadata),
        "metadata_values": metadata.tolist(),
    }

    return W, b, meta_dict


def unlock_model(model_path: Path, output_dir: Path) -> Path:
    """Replace OneOCRFeatureExtract with a standard Gemm in an ONNX model.

    The modified model is written to ``output_dir`` as
    ``<stem>_unlocked.onnx`` and then smoke-tested with onnxruntime on
    all-zero inputs; the outcome of that test is only printed.

    Args:
        model_path: Path to the original ONNX model.
        output_dir: Directory to save the modified model (created if needed).

    Returns:
        Path to the modified model. ``model_path`` itself is returned
        unchanged when the model has no OneOCRFeatureExtract node or when
        weight extraction fails (the failure is printed, not raised).
    """
    model = onnx.load(str(model_path))

    # Check if model uses OneOCRFeatureExtract.
    if _find_fe_node(model.graph) is None:
        print(f" {model_path.name}: No OneOCRFeatureExtract — skipping")
        return model_path

    # Extract weights; best-effort by design, so report and hand back the
    # original path instead of propagating.
    try:
        W, b, meta = extract_fe_weights(model)
    except Exception as e:
        print(f" {model_path.name}: Failed to extract weights: {e}")
        return model_path

    print(f" {model_path.name}: W[{meta['in_dim']}×{meta['out_dim']}] + b[{meta['out_dim']}] "
          f"(metadata: {meta['metadata_floats']} floats)")

    # Work on a copy so the loaded model is left untouched.
    new_model = copy.deepcopy(model)
    graph = new_model.graph

    fe_node = _find_fe_node(graph)
    fe_input = fe_node.input[0]
    fe_output = fe_node.output[0]

    # Replace initializers: remove feature/config, add W and b.
    new_inits = [init for init in graph.initializer if init.name != _FE_CONFIG_NAME]
    new_inits.append(numpy_helper.from_array(W.T, name="fe_weight"))  # [out, in] for transB=1
    new_inits.append(numpy_helper.from_array(b, name="fe_bias"))
    del graph.initializer[:]
    graph.initializer.extend(new_inits)

    # Replace the custom op node with Gemm, keeping node order intact.
    new_nodes = []
    for node in graph.node:
        if node.op_type == _FE_OP_TYPE:
            new_nodes.append(
                helper.make_node(
                    "Gemm",
                    inputs=[fe_input, "fe_weight", "fe_bias"],
                    outputs=[fe_output],
                    alpha=1.0,
                    beta=1.0,
                    transB=1,
                )
            )
        else:
            new_nodes.append(node)
    del graph.node[:]
    graph.node.extend(new_nodes)

    # Clean up inputs (remove feature/config).
    new_inputs = [inp for inp in graph.input if inp.name != _FE_CONFIG_NAME]
    del graph.input[:]
    graph.input.extend(new_inputs)

    # Remove custom opset domain.
    new_opsets = [op for op in new_model.opset_import if op.domain != _FE_DOMAIN]
    del new_model.opset_import[:]
    new_model.opset_import.extend(new_opsets)

    # Save.
    output_dir.mkdir(parents=True, exist_ok=True)
    out_path = output_dir / (model_path.stem + "_unlocked.onnx")
    onnx.save(new_model, str(out_path))

    # Verify it loads in onnxruntime. This is a best-effort smoke test: a
    # missing onnxruntime or a failing run is reported but does not change
    # the return value.
    try:
        import onnxruntime as ort

        sess = ort.InferenceSession(str(out_path))

        # Quick test with zero input; symbolic/unknown dims become 1.
        feeds = {}
        for inp in sess.get_inputs():
            shape = [d if isinstance(d, int) and d > 0 else 1 for d in inp.shape]
            feeds[inp.name] = np.zeros(shape, dtype=np.float32)

        result = sess.run(None, feeds)
        print(f" ✓ Inference OK — output shapes: {[r.shape for r in result]}")
    except Exception as e:
        print(f" ✗ Inference failed: {e}")

    return out_path


def main():
    """Unlock the models named on the command line, or 11-33 by default.

    Each positional argument is a model index; the matching file is looked
    up as ``model_<idx>_*`` under ``oneocr_extracted/onnx_models``. A summary
    of unlocked / skipped / failed indices is printed at the end.
    """
    models_dir = Path("oneocr_extracted/onnx_models")
    output_dir = Path("oneocr_extracted/onnx_models_unlocked")

    # Determine which models to process.
    if len(sys.argv) > 1:
        try:
            indices = [int(x) for x in sys.argv[1:]]
        except ValueError as e:
            # Exit with a readable message instead of a traceback.
            sys.exit(f"Invalid model index: {e}")
    else:
        indices = list(range(11, 34))  # models 11-33

    print(f"Unlocking {len(indices)} models...")
    print(f"Source: {models_dir}")
    print(f"Output: {output_dir}")
    print()

    results = {"success": [], "skip": [], "fail": []}

    for idx in indices:
        # Sorted so the chosen file is deterministic if several match.
        matches = sorted(models_dir.glob(f"model_{idx:02d}_*"))
        if not matches:
            print(f" model_{idx:02d}: NOT FOUND")
            results["fail"].append(idx)
            continue

        model_path = matches[0]
        try:
            out = unlock_model(model_path, output_dir)
            # unlock_model hands back the input path when nothing was written.
            if out == model_path:
                results["skip"].append(idx)
            else:
                results["success"].append(idx)
        except Exception as e:
            print(f" model_{idx:02d}: ERROR — {e}")
            results["fail"].append(idx)

    # Summary.
    print(f"\n{'='*60}")
    print("Results:")
    print(f" Unlocked: {len(results['success'])} — {results['success']}")
    print(f" Skipped: {len(results['skip'])} — {results['skip']}")
    print(f" Failed: {len(results['fail'])} — {results['fail']}")


if __name__ == "__main__":
    main()