#!/usr/bin/env python3
"""
Streaming weight-only INT8 quantizer for large ONNX models.
Implements the same transformation as:
quantize_dynamic(..., MatMulConstBOnly=True, per_channel=False, weight_type=QInt8)
Fully streaming: reads and writes one tensor at a time.
Peak RAM: ~1.5 GB (for the largest single tensor, the embedding table ~1.2 GB).
Usage: python stream_int8.py
"""
import gc
from pathlib import Path
import numpy as np
import onnx
from onnx import TensorProto, numpy_helper, helper
FP32_ONNX = Path("/Volumes/backups/ai/zerank_fp32_tmp/model_fp32.onnx")
FP32_DATA = Path("/Volumes/backups/ai/zerank_fp32_tmp/model_fp32.onnx_data")
INT8_OUT = Path("/Volumes/backups/ai/zerank_onnx_int8/model_int8.onnx")
INT8_DATA = Path("/Volumes/backups/ai/zerank_onnx_int8/model_int8.onnx_data")
MODEL_ID = "zeroentropy/zerank-1-small"
INT8_OUT.parent.mkdir(parents=True, exist_ok=True)
def quantize_tensor_per_tensor(arr: np.ndarray):
    """Quantize *arr* to symmetric per-tensor INT8 (zero_point = 0).

    The scale is max(|arr|) / 127, so the extreme value maps exactly to
    +/-127. Returns a tuple (int8_array, float32_scale).
    """
    values = arr.astype(np.float32)
    peak = np.max(np.abs(values))
    if peak == 0:
        # All-zero tensor: any scale is valid; use 1.0 by convention.
        return np.zeros_like(values, dtype=np.int8), np.float32(1.0)
    scale = np.float32(peak / 127.0)
    quantized = np.clip(np.round(values / scale), -127, 127).astype(np.int8)
    return quantized, scale
def add_external_data(init: onnx.TensorProto, offset: int, length: int, data_file_name: str):
    """Point *init* at a byte range inside an external data file.

    Sets the tensor's data_location to EXTERNAL and replaces any existing
    external_data entries with location/offset/length. Offset and length
    are stored as strings, per the ONNX external-data convention.
    """
    init.data_location = TensorProto.EXTERNAL
    init.ClearField("external_data")
    entries = {
        "location": data_file_name,
        "offset": str(offset),
        "length": str(length),
    }
    for key, value in entries.items():
        kv = init.external_data.add()
        kv.key = key
        kv.value = value
def quantize_model():
    """Quantize all constant MatMul B weights to INT8, streaming tensor data.

    Two phases:
      1. Stream each tensor from the FP32 external-data file into the new
         INT8 data file one at a time (MatMul B weights quantized, all other
         external tensors copied verbatim), recording every tensor's output
         offset/length as it is written.
      2. Rebuild the ONNX proto: insert a DequantizeLinear before each MatMul
         whose B input was quantized, and repoint all external initializers
         at the new data file.
    """
    print(f"Loading proto skeleton (no external data)...")
    m = onnx.load(str(FP32_ONNX), load_external_data=False)
    print(f" Nodes: {len(m.graph.node)}, Initializers: {len(m.graph.initializer)}")
    # Index every initializer: external ones by their byte range in the FP32
    # data file, inline ones by their proto (data embedded in the model file).
    ext_index = {}     # name -> {"offset", "length", "dtype", "dims"}
    inline_index = {}  # name -> TensorProto holding inline data
    for init in m.graph.initializer:
        if init.data_location == TensorProto.EXTERNAL:
            info = {e.key: e.value for e in init.external_data}
            ext_index[init.name] = {
                "offset": int(info.get("offset", 0)),
                "length": int(info.get("length", 0)),
                "dtype": init.data_type,
                "dims": list(init.dims),
            }
        else:
            inline_index[init.name] = init
    # Collect B inputs of MatMul nodes that are constants (initializers);
    # only these get quantized (MatMulConstBOnly semantics).
    matmul_b_names = set()
    for node in m.graph.node:
        if node.op_type == "MatMul" and len(node.input) >= 2:
            b_name = node.input[1]
            if b_name in ext_index or b_name in inline_index:
                matmul_b_names.add(b_name)
    print(f" MatMul B weights to quantize: {len(matmul_b_names)}")
    non_matmul = [name for name, meta in ext_index.items() if name not in matmul_b_names]
    print(f" Non-MatMul external tensors (kept as FP32): {len(non_matmul)}")
    # -- Phase 1: stream all tensor data into the INT8 data file --------------
    print(f"\nPhase 1: Writing tensor data to {INT8_DATA.name}")
    data_file_name = INT8_DATA.name  # just the filename, not full path
    # Where each tensor lands in the output data file.
    out_positions = {}  # name -> (offset, length)
    # Per-weight quantization scales (tiny; stored inline in Phase 2).
    scale_values = {}  # weight_name -> float32 scale
    try:
        from tqdm import tqdm  # optional progress bar
    except ImportError:
        tqdm = None
    offset = 0  # running write position in the output data file
    with open(str(FP32_DATA), "rb") as fp32_f, open(str(INT8_DATA), "wb") as int8_f:
        # 1a. Quantize and write each MatMul B weight as INT8.
        matmul_list = sorted(matmul_b_names)  # deterministic output layout
        if tqdm:
            it = tqdm(matmul_list, desc=" Quantizing MatMul weights")
        else:
            it = matmul_list
        for w_name in it:
            if w_name in ext_index:
                meta = ext_index[w_name]
                fp32_f.seek(meta["offset"])
                raw = fp32_f.read(meta["length"])
                # NOTE(review): reinterprets raw bytes as float32 -- assumes
                # all external MatMul B weights are FP32; confirm for model.
                arr = np.frombuffer(raw, dtype=np.float32).reshape(meta["dims"])
            else:
                arr = numpy_helper.to_array(inline_index[w_name]).astype(np.float32)
            q_arr, scale_val = quantize_tensor_per_tensor(arr)
            del arr  # release the FP32 copy promptly to cap peak RAM
            scale_values[w_name] = scale_val
            raw_int8 = q_arr.tobytes()
            int8_f.write(raw_int8)
            out_positions[w_name + "_quantized"] = (offset, len(raw_int8))
            offset += len(raw_int8)
            del q_arr
        # 1b. Copy non-MatMul external tensors verbatim (FP32/int64/etc.).
        print(f" Copying {len(non_matmul)} non-MatMul tensors...")
        for name in non_matmul:
            meta = ext_index[name]
            fp32_f.seek(meta["offset"])
            raw = fp32_f.read(meta["length"])
            int8_f.write(raw)
            out_positions[name] = (offset, len(raw))
            offset += len(raw)
    print(f" Data file written: {INT8_DATA.stat().st_size / 1e9:.2f} GB")
    # -- Phase 2: rebuild the ONNX proto --------------------------------------
    print("\nPhase 2: Rebuilding ONNX proto...")
    # Replace each quantized MatMul with DequantizeLinear -> MatMul.
    new_nodes = []
    dql_inserted = set()  # weights whose DequantizeLinear is already emitted
    for node in m.graph.node:
        if node.op_type == "MatMul" and node.input[1] in matmul_b_names:
            b_name = node.input[1]
            dql_out_name = b_name + "_dequant"
            # A weight shared by several MatMuls gets a single DQL node.
            if b_name not in dql_inserted:
                dql_node = helper.make_node(
                    "DequantizeLinear",
                    inputs=[b_name + "_quantized", b_name + "_scale", b_name + "_zero_point"],
                    outputs=[dql_out_name],
                )
                new_nodes.append(dql_node)
                dql_inserted.add(b_name)
            new_node = helper.make_node(
                "MatMul",
                inputs=[node.input[0], dql_out_name],
                outputs=list(node.output),
                name=node.name,
            )
            new_nodes.append(new_node)
        else:
            new_nodes.append(node)
    del m.graph.node[:]
    m.graph.node.extend(new_nodes)
    # Rebuild the initializer list in three groups.
    new_initializers = []
    # a. Quantized MatMul weights: INT8 data external, scale/zero_point inline.
    for w_name in matmul_b_names:
        # Inline weights have no ext_index entry; only dims are needed here.
        meta = ext_index.get(w_name) or {
            "dims": list(numpy_helper.to_array(inline_index[w_name]).shape)
        }
        dims = meta["dims"]
        q_init = TensorProto()
        q_init.name = w_name + "_quantized"
        q_init.data_type = TensorProto.INT8
        q_init.dims.extend(dims)
        off, length = out_positions[w_name + "_quantized"]
        add_external_data(q_init, off, length, data_file_name)
        scale_init = numpy_helper.from_array(
            np.array([scale_values[w_name]], dtype=np.float32), name=w_name + "_scale"
        )
        zp_init = numpy_helper.from_array(
            np.array([0], dtype=np.int8), name=w_name + "_zero_point"
        )
        new_initializers.extend([q_init, scale_init, zp_init])
    # b. Non-MatMul external tensors: same dtype/dims, new offsets.
    for name in non_matmul:
        meta = ext_index[name]
        init = TensorProto()
        init.name = name
        init.data_type = meta["dtype"]
        init.dims.extend(meta["dims"])
        off, length = out_positions[name]
        add_external_data(init, off, length, data_file_name)
        new_initializers.append(init)
    # c. Inline initializers carried over unchanged from the FP32 proto.
    # NOTE(review): inline MatMul B weights (if any) are kept here as well,
    # so their FP32 copies remain in the proto unused -- confirm acceptable.
    for init in m.graph.initializer:
        if init.name not in ext_index:  # it's inline
            new_initializers.append(init)
    del m.graph.initializer[:]
    m.graph.initializer.extend(new_initializers)
    del m.graph.value_info[:]  # clear stale type annotations
    print(f" Saving proto β {INT8_OUT}")
    onnx.save(m, str(INT8_OUT))
    print(f" Proto size: {INT8_OUT.stat().st_size / 1e6:.1f} MB")
    total_gb = (INT8_OUT.stat().st_size + INT8_DATA.stat().st_size) / 1e9
    print(f" Total INT8 size: {total_gb:.2f} GB")
def verify():
    """Run a quick ranking sanity check on the quantized model.

    Loads the INT8 model with onnxruntime on CPU, scores two
    (query, document) pairs with the original tokenizer, and asserts
    that the relevant document outscores the irrelevant one.
    """
    import onnxruntime as ort
    from transformers import AutoTokenizer

    print(f"\nVerifying {INT8_OUT.name}...")
    session = ort.InferenceSession(
        str(INT8_OUT), ort.SessionOptions(), providers=["CPUExecutionProvider"]
    )
    # Echo the model's I/O signature for quick eyeballing.
    for inp in session.get_inputs():
        print(f" in: {inp.name} {inp.shape}")
    for out in session.get_outputs():
        print(f" out: {out.name} {out.shape}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    pairs = [
        ("what is a panda?", "A panda is a large black-and-white bear native to China."),
        ("what is a panda?", "The sky is blue and the grass is green."),
    ]
    scores = []
    for query, document in pairs:
        enc = tokenizer(query, document, return_tensors="np", truncation=True, max_length=256)
        feeds = {
            "input_ids": enc["input_ids"].astype(np.int64),
            "attention_mask": enc["attention_mask"].astype(np.int64),
        }
        logit = session.run(["logits"], feeds)[0]
        scores.append(float(logit[0][0]))
    print(f" logits: {[f'{s:.3f}' for s in scores]}")
    assert scores[0] > scores[1], \
        f"Relevant doc should score higher: {scores[0]:.3f} vs {scores[1]:.3f}"
    print(" OK β relevant doc ranked higher")
if __name__ == "__main__":
    # Remove stale outputs so the data file is rebuilt from offset 0.
    for stale in (INT8_OUT, INT8_DATA):
        if stale.exists():
            stale.unlink()
            print(f"Deleted {stale.name}")
    quantize_model()
    gc.collect()  # release any lingering large buffers before inference
    verify()
    print("\nAll done. Upload commands:")
    print(" huggingface-cli upload cstr/zerank-1-small-ONNX /private/tmp/zerank_export/zerank_onnx . --repo-type model")
    print(f" huggingface-cli upload cstr/zerank-1-small-ONNX {INT8_OUT.parent}/ . --commit-message 'add INT8' --repo-type model --include '*.onnx*'")
    print(f" huggingface-cli upload cstr/zerank-1-small-ONNX /Volumes/backups/ai/zerank_onnx_int4/model_int4_full.onnx model_int4_full.onnx --repo-type model")