cstr commited on
Commit
73e1c32
Β·
verified Β·
1 Parent(s): b19c63d

Upload stream_int8.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. stream_int8.py +269 -0
stream_int8.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Streaming weight-only INT8 quantizer for large ONNX models.
4
+
5
+ Implements the same transformation as:
6
+ quantize_dynamic(..., MatMulConstBOnly=True, per_channel=False, weight_type=QInt8)
7
+
8
+ Fully streaming: reads and writes one tensor at a time.
9
+ Peak RAM: ~1.5 GB (for the largest single tensor, the embedding table ~1.2 GB).
10
+
11
+ Usage: python stream_int8.py
12
+ """
13
+
14
+ import gc
15
+ from pathlib import Path
16
+ import numpy as np
17
+ import onnx
18
+ from onnx import TensorProto, numpy_helper, helper
19
+
20
+ FP32_ONNX = Path("/Volumes/backups/ai/zerank_fp32_tmp/model_fp32.onnx")
21
+ FP32_DATA = Path("/Volumes/backups/ai/zerank_fp32_tmp/model_fp32.onnx_data")
22
+ INT8_OUT = Path("/Volumes/backups/ai/zerank_onnx_int8/model_int8.onnx")
23
+ INT8_DATA = Path("/Volumes/backups/ai/zerank_onnx_int8/model_int8.onnx_data")
24
+ MODEL_ID = "zeroentropy/zerank-1-small"
25
+
26
+ INT8_OUT.parent.mkdir(parents=True, exist_ok=True)
27
+
28
+
29
+ def quantize_tensor_per_tensor(arr: np.ndarray):
30
+ """Symmetric per-tensor INT8 quantization (zero_point = 0)."""
31
+ arr = arr.astype(np.float32)
32
+ abs_max = np.max(np.abs(arr))
33
+ if abs_max == 0:
34
+ scale = np.float32(1.0)
35
+ quantized = np.zeros_like(arr, dtype=np.int8)
36
+ else:
37
+ scale = np.float32(abs_max / 127.0)
38
+ quantized = np.clip(np.round(arr / scale), -127, 127).astype(np.int8)
39
+ return quantized, scale
40
+
41
+
42
+ def add_external_data(init: onnx.TensorProto, offset: int, length: int, data_file_name: str):
43
+ """Update an initializer proto to point to external data."""
44
+ init.data_location = TensorProto.EXTERNAL
45
+ init.ClearField("external_data")
46
+ for k, v in [("location", data_file_name), ("offset", str(offset)), ("length", str(length))]:
47
+ e = init.external_data.add()
48
+ e.key, e.value = k, v
49
+
50
+
51
+ def quantize_model():
52
+ print(f"Loading proto skeleton (no external data)...")
53
+ m = onnx.load(str(FP32_ONNX), load_external_data=False)
54
+ print(f" Nodes: {len(m.graph.node)}, Initializers: {len(m.graph.initializer)}")
55
+
56
+ # Build index of external initializers
57
+ ext_index = {} # name β†’ (offset, length, dtype, dims)
58
+ inline_index = {} # name β†’ data bytes (for inline tensors)
59
+ for init in m.graph.initializer:
60
+ if init.data_location == TensorProto.EXTERNAL:
61
+ info = {e.key: e.value for e in init.external_data}
62
+ ext_index[init.name] = {
63
+ "offset": int(info.get("offset", 0)),
64
+ "length": int(info.get("length", 0)),
65
+ "dtype": init.data_type,
66
+ "dims": list(init.dims),
67
+ }
68
+ else:
69
+ inline_index[init.name] = init
70
+
71
+ # Find all MatMul nodes with constant B (initializer)
72
+ matmul_b_names = set()
73
+ for node in m.graph.node:
74
+ if node.op_type == "MatMul" and len(node.input) >= 2:
75
+ b_name = node.input[1]
76
+ if b_name in ext_index or b_name in inline_index:
77
+ matmul_b_names.add(b_name)
78
+
79
+ print(f" MatMul B weights to quantize: {len(matmul_b_names)}")
80
+ non_matmul = [name for name, meta in ext_index.items() if name not in matmul_b_names]
81
+ print(f" Non-MatMul external tensors (kept as FP32): {len(non_matmul)}")
82
+
83
+ # ── Phase 1: Stream all tensors to INT8 data file ─────────────────────────
84
+ print(f"\nPhase 1: Writing tensor data to {INT8_DATA.name}")
85
+ data_file_name = INT8_DATA.name # just the filename, not full path
86
+
87
+ # Track where each tensor ends up in the output data file
88
+ # key β†’ (offset, length) for the output
89
+ out_positions = {} # name β†’ (offset, length)
90
+ # For quantized weights: also store scale values (tiny, inline later)
91
+ scale_values = {} # weight_name β†’ float32 scale
92
+
93
+ try:
94
+ from tqdm import tqdm
95
+ except ImportError:
96
+ tqdm = None
97
+
98
+ offset = 0
99
+ with open(str(FP32_DATA), "rb") as fp32_f, open(str(INT8_DATA), "wb") as int8_f:
100
+ # 1a. Write quantized MatMul weights (INT8)
101
+ matmul_list = sorted(matmul_b_names)
102
+ if tqdm:
103
+ it = tqdm(matmul_list, desc=" Quantizing MatMul weights")
104
+ else:
105
+ it = matmul_list
106
+
107
+ for w_name in it:
108
+ if w_name in ext_index:
109
+ meta = ext_index[w_name]
110
+ fp32_f.seek(meta["offset"])
111
+ raw = fp32_f.read(meta["length"])
112
+ arr = np.frombuffer(raw, dtype=np.float32).reshape(meta["dims"])
113
+ else:
114
+ arr = numpy_helper.to_array(inline_index[w_name]).astype(np.float32)
115
+
116
+ q_arr, scale_val = quantize_tensor_per_tensor(arr)
117
+ del arr
118
+ scale_values[w_name] = scale_val
119
+
120
+ raw_int8 = q_arr.tobytes()
121
+ int8_f.write(raw_int8)
122
+ out_positions[w_name + "_quantized"] = (offset, len(raw_int8))
123
+ offset += len(raw_int8)
124
+ del q_arr
125
+
126
+ # 1b. Copy non-MatMul external tensors verbatim (already FP32/int64/etc.)
127
+ print(f" Copying {len(non_matmul)} non-MatMul tensors...")
128
+ for name in non_matmul:
129
+ meta = ext_index[name]
130
+ fp32_f.seek(meta["offset"])
131
+ raw = fp32_f.read(meta["length"])
132
+ int8_f.write(raw)
133
+ out_positions[name] = (offset, len(raw))
134
+ offset += len(raw)
135
+
136
+ print(f" Data file written: {INT8_DATA.stat().st_size / 1e9:.2f} GB")
137
+
138
+ # ── Phase 2: Rebuild the ONNX proto ───────────────────────────────────────
139
+ print("\nPhase 2: Rebuilding ONNX proto...")
140
+
141
+ # Rebuild graph: replace MatMul nodes with DQL β†’ MatMul
142
+ new_nodes = []
143
+ dql_inserted = set()
144
+ for node in m.graph.node:
145
+ if node.op_type == "MatMul" and node.input[1] in matmul_b_names:
146
+ b_name = node.input[1]
147
+ dql_out_name = b_name + "_dequant"
148
+ if b_name not in dql_inserted:
149
+ dql_node = helper.make_node(
150
+ "DequantizeLinear",
151
+ inputs=[b_name + "_quantized", b_name + "_scale", b_name + "_zero_point"],
152
+ outputs=[dql_out_name],
153
+ )
154
+ new_nodes.append(dql_node)
155
+ dql_inserted.add(b_name)
156
+ new_node = helper.make_node(
157
+ "MatMul",
158
+ inputs=[node.input[0], dql_out_name],
159
+ outputs=list(node.output),
160
+ name=node.name,
161
+ )
162
+ new_nodes.append(new_node)
163
+ else:
164
+ new_nodes.append(node)
165
+
166
+ del m.graph.node[:]
167
+ m.graph.node.extend(new_nodes)
168
+
169
+ # Rebuild initializers
170
+ new_initializers = []
171
+
172
+ # a. Quantized MatMul weights (external data)
173
+ for w_name in matmul_b_names:
174
+ meta = ext_index.get(w_name) or {
175
+ "dims": list(numpy_helper.to_array(inline_index[w_name]).shape)
176
+ }
177
+ dims = meta["dims"]
178
+
179
+ q_init = TensorProto()
180
+ q_init.name = w_name + "_quantized"
181
+ q_init.data_type = TensorProto.INT8
182
+ q_init.dims.extend(dims)
183
+ off, length = out_positions[w_name + "_quantized"]
184
+ add_external_data(q_init, off, length, data_file_name)
185
+
186
+ scale_init = numpy_helper.from_array(
187
+ np.array([scale_values[w_name]], dtype=np.float32), name=w_name + "_scale"
188
+ )
189
+ zp_init = numpy_helper.from_array(
190
+ np.array([0], dtype=np.int8), name=w_name + "_zero_point"
191
+ )
192
+
193
+ new_initializers.extend([q_init, scale_init, zp_init])
194
+
195
+ # b. Non-MatMul external tensors (external data, already written)
196
+ for name in non_matmul:
197
+ meta = ext_index[name]
198
+ init = TensorProto()
199
+ init.name = name
200
+ init.data_type = meta["dtype"]
201
+ init.dims.extend(meta["dims"])
202
+ off, length = out_positions[name]
203
+ add_external_data(init, off, length, data_file_name)
204
+ new_initializers.append(init)
205
+
206
+ # c. Inline initializers from FP32 model (already inline in proto β€” not external data)
207
+ for init in m.graph.initializer:
208
+ if init.name not in ext_index: # it's inline
209
+ new_initializers.append(init)
210
+
211
+ del m.graph.initializer[:]
212
+ m.graph.initializer.extend(new_initializers)
213
+ del m.graph.value_info[:] # clear stale type annotations
214
+
215
+ print(f" Saving proto β†’ {INT8_OUT}")
216
+ onnx.save(m, str(INT8_OUT))
217
+ print(f" Proto size: {INT8_OUT.stat().st_size / 1e6:.1f} MB")
218
+ total_gb = (INT8_OUT.stat().st_size + INT8_DATA.stat().st_size) / 1e9
219
+ print(f" Total INT8 size: {total_gb:.2f} GB")
220
+
221
+
222
+ def verify():
223
+ import onnxruntime as ort
224
+ from transformers import AutoTokenizer
225
+
226
+ print(f"\nVerifying {INT8_OUT.name}...")
227
+ sess_opts = ort.SessionOptions()
228
+ sess = ort.InferenceSession(
229
+ str(INT8_OUT), sess_opts, providers=["CPUExecutionProvider"]
230
+ )
231
+ for inp in sess.get_inputs():
232
+ print(f" in: {inp.name} {inp.shape}")
233
+ for out in sess.get_outputs():
234
+ print(f" out: {out.name} {out.shape}")
235
+
236
+ tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
237
+ pairs = [
238
+ ("what is a panda?", "A panda is a large black-and-white bear native to China."),
239
+ ("what is a panda?", "The sky is blue and the grass is green."),
240
+ ]
241
+ scores = []
242
+ for q, d in pairs:
243
+ enc = tok(q, d, return_tensors="np", truncation=True, max_length=256)
244
+ logit = sess.run(["logits"], {
245
+ "input_ids": enc["input_ids"].astype(np.int64),
246
+ "attention_mask": enc["attention_mask"].astype(np.int64),
247
+ })[0]
248
+ scores.append(float(logit[0][0]))
249
+
250
+ print(f" logits: {[f'{s:.3f}' for s in scores]}")
251
+ assert scores[0] > scores[1], \
252
+ f"Relevant doc should score higher: {scores[0]:.3f} vs {scores[1]:.3f}"
253
+ print(" OK β€” relevant doc ranked higher")
254
+
255
+
256
+ if __name__ == "__main__":
257
+ for p in [INT8_OUT, INT8_DATA]:
258
+ if p.exists():
259
+ p.unlink()
260
+ print(f"Deleted {p.name}")
261
+
262
+ quantize_model()
263
+ gc.collect()
264
+ verify()
265
+
266
+ print("\nAll done. Upload commands:")
267
+ print(" huggingface-cli upload cstr/zerank-1-small-ONNX /private/tmp/zerank_export/zerank_onnx . --repo-type model")
268
+ print(f" huggingface-cli upload cstr/zerank-1-small-ONNX {INT8_OUT.parent}/ . --commit-message 'add INT8' --repo-type model --include '*.onnx*'")
269
+ print(f" huggingface-cli upload cstr/zerank-1-small-ONNX /Volumes/backups/ai/zerank_onnx_int4/model_int4_full.onnx model_int4_full.onnx --repo-type model")