| """ |
| Quantize the exported Octen-Embedding-0.6B ONNX to dynamic INT8. |
| |
| Why this script won't OOM |
| -------------------------- |
| Default quantize_dynamic loads the full protobuf into Python memory and writes |
| the quantized weights inline β peak is ~3-4Γ model size. We avoid that via: |
| |
| 1. op_types_to_quantize=['MatMul'] |
| Skip Gather (embedding lookup table β 600 MB for 151 936 Γ 1024 vocab). |
| Quantizing Gather saves nothing at inference (it's indexed element-wise) |
| but would waste RAM on copying the largest single tensor. |
| |
| 2. use_external_data_format=True (output) |
| Writes quantized weights to a separate binary file instead of inlining |
| them in the protobuf. Inline bytes are kept as Python list[int] β 2-3Γ |
| the actual byte count. External format writes direct byte buffers. |
| |
| 3. per_channel=False |
| Per-tensor calibration β much cheaper than per-channel. |
| |
| 4. Pass file-path strings, not pre-loaded ModelProto. |
| Lets ORT handle mmap / lazy loading instead of pulling everything into |
| the Python heap up front. |
| |
| 5. Explicit gc.collect() between load / quantize / save phases. |
| |
| Expected peak RAM: ~2 GB (FP32 weights in ORT's internal buffers) + ~500 MB |
| (INT8 copy) + Python overhead β 3 GB total β well within 4 GB headroom. |
| """ |
|
|
| import gc |
| import os |
| import sys |
| from pathlib import Path |
|
|
| |
|
|
# All paths are resolved relative to this script's own directory.
STAGING = Path(__file__).parent
INPUT = STAGING.joinpath("model.onnx")        # FP32 graph (external weights sit beside it)
OUT_DIR = STAGING.joinpath("model_int8")      # destination for the quantized artifacts
OUTPUT = OUT_DIR.joinpath("model.int8.onnx")
|
|
| |
|
|
# Fail fast unless both halves of the exported model are present: the
# protobuf graph and its companion external-weights file must sit together.
if not INPUT.exists():
    raise SystemExit(f"ERROR: model not found at {INPUT}")
if not (STAGING / "model.onnx.data").exists():
    raise SystemExit("ERROR: model.onnx.data not found alongside model.onnx")


# Safe to re-run: existing output directory is reused, parents created as needed.
OUT_DIR.mkdir(parents=True, exist_ok=True)
|
|
| |
|
|
# Cheap structural pass: load only the graph protobuf (a few MB) with
# load_external_data=False so the multi-GB weight file is never touched.
# Fix: the "…" in the banner was mojibake ("β¦") from a mis-decoded export.
print("Inspecting model graph (no weights loaded) …")
import onnx
proto = onnx.load(str(INPUT), load_external_data=False)
# Count the op types the quantization step cares about (MatMul is quantized,
# Gather is deliberately skipped — see module docstring).
matmul_count = sum(1 for n in proto.graph.node if n.op_type == "MatMul")
gather_count = sum(1 for n in proto.graph.node if n.op_type == "Gather")
total_nodes = len(proto.graph.node)
print(f" nodes={total_nodes} MatMul={matmul_count} Gather={gather_count}")
# Release the proto before the memory-heavy quantization phase.
del proto
gc.collect()
|
|
| |
|
|
# Report input/output sizes before the long-running quantization step.
# Fix: the "…" in the banner was mojibake ("β¦") from a mis-decoded export.
print("\nQuantizing MatMul weights to dynamic INT8 …")
print(f" input : {INPUT} ({INPUT.stat().st_size/1e6:.0f} MB graph)")
print(f" data : {STAGING/'model.onnx.data'} "
f"({(STAGING/'model.onnx.data').stat().st_size/1e9:.2f} GB)")
print(f" output : {OUTPUT}")
print()
|
|
# Deferred import: onnxruntime is heavy, so pull it in only once the
# pre-flight checks and graph inspection have passed.
from onnxruntime.quantization import quantize_dynamic, QuantType


# Dynamic quantization: weights are quantized offline here; activation
# scales are computed at inference time. Note we pass file-path strings,
# not a pre-loaded ModelProto, so ORT can lazy-load the external weights.
quantize_dynamic(
    model_input=str(INPUT),
    model_output=str(OUTPUT),

    # Quantize only MatMul weights. Gather (the embedding table, the
    # largest single tensor) is left in FP32 on purpose — quantizing an
    # indexed lookup saves nothing at inference and costs RAM to copy.
    op_types_to_quantize=["MatMul"],

    # Per-tensor scales: much cheaper than per-channel calibration.
    per_channel=False,

    # Signed 8-bit weights.
    weight_type=QuantType.QInt8,

    # Write quantized weights to a sidecar .data file rather than inlining
    # them in the protobuf — keeps peak Python memory down (see docstring).
    use_external_data_format=True,
)


# Drop any transient buffers before the verification pass below.
gc.collect()
|
|
| |
|
|
# Post-quantization sanity checks: the graph file must exist; report sizes,
# then reload just the protobuf (no weights) to confirm the I/O signature
# survived and count the inserted quant/dequant nodes.
# Fix: the "…" in the banner was mojibake ("β¦") from a mis-decoded export.
print("\nVerifying output …")
out_graph = OUTPUT
out_data = OUT_DIR / "model.int8.onnx.data"


if not out_graph.exists():
    sys.exit("ERROR: output ONNX file not created!")


print(f" graph file : {out_graph.stat().st_size/1e6:.1f} MB")
# The sidecar file only exists when external-data output was honored.
if out_data.exists():
    print(f" data file : {out_data.stat().st_size/1e9:.2f} GB")


proto_q = onnx.load(str(out_graph), load_external_data=False)
print(f" inputs : {[i.name for i in proto_q.graph.input]}")
print(f" outputs : {[o.name for o in proto_q.graph.output]}")
# "Quant" matches QuantizeLinear / DynamicQuantizeLinear; "Dequant" is needed
# separately because DequantizeLinear has no capital-Q "Quant" substring.
qdq_nodes = sum(1 for n in proto_q.graph.node if "Quant" in n.op_type or "Dequant" in n.op_type)
print(f" quant/dequant nodes added: {qdq_nodes}")
del proto_q
gc.collect()
|
|
| |
|
|
# Ship the tokenizer/config files next to the quantized model so OUT_DIR is
# a self-contained, directly loadable artifact. Missing files are skipped
# silently — not every export produces all of them.
# Fix: the "…" in the banner was mojibake ("β¦") from a mis-decoded export.
import shutil
tok_files = [
    "tokenizer.json", "tokenizer_config.json", "special_tokens_map.json",
    "vocab.json", "merges.txt", "added_tokens.json",
    "config.json",
]
print("\nCopying tokenizer files …")
for f in tok_files:
    src = STAGING / f
    if src.exists():
        # copy2 preserves metadata (mtime), useful for cache validation.
        shutil.copy2(src, OUT_DIR / f)
        print(f" copied {f}")
|
|
# Final summary: list every artifact that ended up in the output directory.
print(f"\nDone. Quantized model written to: {OUT_DIR}/")
print("Files:")
artifacts = sorted(OUT_DIR.iterdir())
for artifact in artifacts:
    size_mb = artifact.stat().st_size / 1e6
    print(f" {artifact.name:40s} {size_mb:.1f} MB")
|
|