# Octen-Embedding-0.6B-ONNX-INT8 / quantize_octen_int8.py
# Uploaded via huggingface_hub (revision b81fa7f, verified) by cstr.
"""
Quantize the exported Octen-Embedding-0.6B ONNX to dynamic INT8.
Why this script won't OOM
--------------------------
Default quantize_dynamic loads the full protobuf into Python memory and writes
the quantized weights inline β†’ peak is ~3-4Γ— model size. We avoid that via:
1. op_types_to_quantize=['MatMul']
Skip Gather (embedding lookup table β‰ˆ 600 MB for 151 936 Γ— 1024 vocab).
Quantizing Gather saves nothing at inference (it's indexed element-wise)
but would waste RAM on copying the largest single tensor.
2. use_external_data_format=True (output)
Writes quantized weights to a separate binary file instead of inlining
them in the protobuf. Inline bytes are kept as Python list[int] β†’ 2-3Γ—
the actual byte count. External format writes direct byte buffers.
3. per_channel=False
Per-tensor calibration β€” much cheaper than per-channel.
4. Pass file-path strings, not pre-loaded ModelProto.
Lets ORT handle mmap / lazy loading instead of pulling everything into
the Python heap up front.
5. Explicit gc.collect() between load / quantize / save phases.
Expected peak RAM: ~2 GB (FP32 weights in ORT's internal buffers) + ~500 MB
(INT8 copy) + Python overhead β‰ˆ 3 GB total β€” well within 4 GB headroom.
"""
import gc
import os
import sys
from pathlib import Path
# ── Paths ─────────────────────────────────────────────────────────────────────
STAGING = Path(__file__).parent       # directory containing this script
INPUT = STAGING / "model.onnx"        # FP32 graph (tensor data lives in model.onnx.data)
OUT_DIR = STAGING / "model_int8"      # destination directory for INT8 artifacts
OUTPUT = OUT_DIR / "model.int8.onnx"  # quantized graph (data in model.int8.onnx.data)

# ── Sanity checks ─────────────────────────────────────────────────────────────
# Both the graph protobuf and its external-data companion must be present.
_weights_file = STAGING / "model.onnx.data"
if not INPUT.exists():
    sys.exit(f"ERROR: model not found at {INPUT}")
if not _weights_file.exists():
    sys.exit("ERROR: model.onnx.data not found alongside model.onnx")
OUT_DIR.mkdir(parents=True, exist_ok=True)
# ── Quick inspection ──────────────────────────────────────────────────────────
# Load only the graph protobuf (load_external_data=False) so the ~GB weight
# file stays on disk; we just want op-type counts before quantizing.
print("Inspecting model graph (no weights loaded) …")
import onnx
proto = onnx.load(str(INPUT), load_external_data=False)
op_types = [node.op_type for node in proto.graph.node]
print(f" nodes={len(op_types)} MatMul={op_types.count('MatMul')} Gather={op_types.count('Gather')}")
# Drop the proto before the memory-hungry quantization phase.
del proto
gc.collect()
# ── Quantize ──────────────────────────────────────────────────────────────────
print("\nQuantizing MatMul weights to dynamic INT8 …")
weights_file = STAGING / "model.onnx.data"
print(f" input : {INPUT} ({INPUT.stat().st_size/1e6:.0f} MB graph)")
print(f" data : {weights_file} ({weights_file.stat().st_size/1e9:.2f} GB)")
print(f" output : {OUTPUT}")
print()
from onnxruntime.quantization import quantize_dynamic, QuantType

# File paths (not a pre-loaded ModelProto) are passed so ORT can mmap /
# lazy-load instead of pulling everything into the Python heap up front.
quantize_dynamic(
    model_input=str(INPUT),
    model_output=str(OUTPUT),
    op_types_to_quantize=["MatMul"],  # skip Gather: ~600 MB embedding table, no inference benefit
    per_channel=False,                # per-tensor scales — cheaper, adequate for embeddings
    weight_type=QuantType.QInt8,      # signed INT8 weights
    use_external_data_format=True,    # weights go to a companion .data file, not inline bytes
)
gc.collect()
# ── Verify ────────────────────────────────────────────────────────────────────
# Confirm the output files exist, then reload just the quantized graph
# (weights stay on disk) to report its I/O names and added quant nodes.
print("\nVerifying output …")
graph_path = OUTPUT
data_path = OUT_DIR / "model.int8.onnx.data"
if not graph_path.exists():
    sys.exit("ERROR: output ONNX file not created!")
print(f" graph file : {graph_path.stat().st_size/1e6:.1f} MB")
if data_path.exists():
    print(f" data file : {data_path.stat().st_size/1e9:.2f} GB")
proto_q = onnx.load(str(graph_path), load_external_data=False)
print(f" inputs : {[i.name for i in proto_q.graph.input]}")
print(f" outputs : {[o.name for o in proto_q.graph.output]}")
# Dynamic quantization inserts e.g. DynamicQuantizeLinear / DequantizeLinear.
qdq_nodes = sum("Quant" in n.op_type or "Dequant" in n.op_type for n in proto_q.graph.node)
print(f" quant/dequant nodes added: {qdq_nodes}")
del proto_q
gc.collect()
# ── Copy tokenizer files ───────────────────────────────────────────────────────
# Ship the tokenizer/config alongside the quantized model so OUT_DIR is a
# self-contained, loadable artifact. Missing files are skipped silently
# (not every tokenizer format provides all of these).
import shutil
tok_files = [
    "tokenizer.json", "tokenizer_config.json", "special_tokens_map.json",
    "vocab.json", "merges.txt", "added_tokens.json",
    "config.json",
]
print("\nCopying tokenizer files …")
for fname in tok_files:
    source = STAGING / fname
    if not source.exists():
        continue
    shutil.copy2(source, OUT_DIR / fname)  # copy2 preserves mtime/metadata
    print(f" copied {fname}")
# Final report: list everything that landed in the output directory.
print(f"\nDone. Quantized model written to: {OUT_DIR}/")
print("Files:")
for entry in sorted(OUT_DIR.iterdir()):
    size_mb = entry.stat().st_size / 1e6
    print(f" {entry.name:40s} {size_mb:.1f} MB")