File size: 10,479 Bytes
73e1c32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
#!/usr/bin/env python3
"""
Streaming weight-only INT8 quantizer for large ONNX models.

Implements the same transformation as:
  quantize_dynamic(..., MatMulConstBOnly=True, per_channel=False, weight_type=QInt8)

Fully streaming: reads and writes one tensor at a time.
Peak RAM: ~1.5 GB (for the largest single tensor, the embedding table ~1.2 GB).

Usage: python stream_int8.py
"""

import gc
from pathlib import Path
import numpy as np
import onnx
from onnx import TensorProto, numpy_helper, helper

FP32_ONNX = Path("/Volumes/backups/ai/zerank_fp32_tmp/model_fp32.onnx")
FP32_DATA = Path("/Volumes/backups/ai/zerank_fp32_tmp/model_fp32.onnx_data")
INT8_OUT  = Path("/Volumes/backups/ai/zerank_onnx_int8/model_int8.onnx")
INT8_DATA = Path("/Volumes/backups/ai/zerank_onnx_int8/model_int8.onnx_data")
MODEL_ID  = "zeroentropy/zerank-1-small"

INT8_OUT.parent.mkdir(parents=True, exist_ok=True)


def quantize_tensor_per_tensor(arr: np.ndarray):
    """Quantize *arr* to symmetric per-tensor INT8 (zero_point = 0).

    Returns ``(int8_array, float32_scale)`` such that the dequantized
    value is ``int8_array * scale``.
    """
    values = arr.astype(np.float32)
    peak = np.max(np.abs(values))
    if peak == 0:
        # All-zero tensor: any scale works, so pick 1.0 with a zero payload.
        return np.zeros_like(values, dtype=np.int8), np.float32(1.0)
    scale = np.float32(peak / 127.0)
    quantized = np.clip(np.round(values / scale), -127, 127).astype(np.int8)
    return quantized, scale


def add_external_data(init: onnx.TensorProto, offset: int, length: int, data_file_name: str):
    """Point *init* at the byte range [offset, offset+length) of an external file.

    Replaces any existing external_data entries on the initializer.
    """
    init.data_location = TensorProto.EXTERNAL
    init.ClearField("external_data")
    fields = {
        "location": data_file_name,
        "offset": str(offset),
        "length": str(length),
    }
    for key, value in fields.items():
        entry = init.external_data.add()
        entry.key = key
        entry.value = value


def quantize_model():
    """Stream-quantize every constant MatMul B weight of FP32_ONNX to INT8.

    Phase 1 streams tensor payloads from FP32_DATA into INT8_DATA one at a
    time (quantizing MatMul weights, copying everything else verbatim), so
    peak memory stays bounded by the largest single tensor.  Phase 2 rewrites
    the graph proto: each affected MatMul gets a DequantizeLinear in front of
    its B input, and the initializer list is rebuilt against the new data file.
    """
    print("Loading proto skeleton (no external data)...")
    m = onnx.load(str(FP32_ONNX), load_external_data=False)
    print(f"  Nodes: {len(m.graph.node)}, Initializers: {len(m.graph.initializer)}")

    # Index where every initializer's payload lives.
    ext_index = {}     # name -> {"offset", "length", "dtype", "dims"} in FP32_DATA
    inline_index = {}  # name -> TensorProto with data embedded in the proto
    for init in m.graph.initializer:
        if init.data_location == TensorProto.EXTERNAL:
            info = {e.key: e.value for e in init.external_data}
            ext_index[init.name] = {
                "offset": int(info.get("offset", 0)),
                "length": int(info.get("length", 0)),
                "dtype": init.data_type,
                "dims": list(init.dims),
            }
        else:
            inline_index[init.name] = init

    # Collect every MatMul whose B input is a constant (initializer).
    matmul_b_names = set()
    for node in m.graph.node:
        if node.op_type == "MatMul" and len(node.input) >= 2:
            b_name = node.input[1]
            if b_name in ext_index or b_name in inline_index:
                matmul_b_names.add(b_name)

    print(f"  MatMul B weights to quantize: {len(matmul_b_names)}")
    non_matmul = [name for name in ext_index if name not in matmul_b_names]
    print(f"  Non-MatMul external tensors (kept as FP32): {len(non_matmul)}")

    # ── Phase 1: Stream all tensors to INT8 data file ─────────────────────────
    print(f"\nPhase 1: Writing tensor data to {INT8_DATA.name}")
    data_file_name = INT8_DATA.name  # external_data 'location' is relative to the proto

    out_positions = {}  # output tensor name -> (offset, length) in INT8_DATA
    scale_values = {}   # weight name -> float32 per-tensor scale

    try:
        from tqdm import tqdm
    except ImportError:
        tqdm = None  # progress bar is optional

    offset = 0
    with open(str(FP32_DATA), "rb") as fp32_f, open(str(INT8_DATA), "wb") as int8_f:
        # 1a. Quantize and write the MatMul weights (sorted for determinism).
        matmul_list = sorted(matmul_b_names)
        it = tqdm(matmul_list, desc="  Quantizing MatMul weights") if tqdm else matmul_list

        for w_name in it:
            if w_name in ext_index:
                meta = ext_index[w_name]
                # frombuffer blindly reinterprets bytes; fail loudly rather
                # than silently misreading a non-float32 payload.
                if meta["dtype"] != TensorProto.FLOAT:
                    raise ValueError(
                        f"Expected FLOAT weight, got data_type {meta['dtype']} for {w_name}"
                    )
                fp32_f.seek(meta["offset"])
                raw = fp32_f.read(meta["length"])
                arr = np.frombuffer(raw, dtype=np.float32).reshape(meta["dims"])
            else:
                arr = numpy_helper.to_array(inline_index[w_name]).astype(np.float32)

            q_arr, scale_val = quantize_tensor_per_tensor(arr)
            del arr  # release the FP32 copy before the next big read
            scale_values[w_name] = scale_val

            raw_int8 = q_arr.tobytes()
            int8_f.write(raw_int8)
            out_positions[w_name + "_quantized"] = (offset, len(raw_int8))
            offset += len(raw_int8)
            del q_arr

        # 1b. Copy non-MatMul external tensors verbatim (FP32 / int64 / etc.).
        print(f"  Copying {len(non_matmul)} non-MatMul tensors...")
        for name in non_matmul:
            meta = ext_index[name]
            fp32_f.seek(meta["offset"])
            raw = fp32_f.read(meta["length"])
            int8_f.write(raw)
            out_positions[name] = (offset, len(raw))
            offset += len(raw)

    print(f"  Data file written: {INT8_DATA.stat().st_size / 1e9:.2f} GB")

    # ── Phase 2: Rebuild the ONNX proto ───────────────────────────────────────
    print("\nPhase 2: Rebuilding ONNX proto...")

    # Insert a DequantizeLinear in front of every quantized MatMul B input.
    new_nodes = []
    dql_inserted = set()  # weights already given a DQL node (shared weights)
    for node in m.graph.node:
        if (node.op_type == "MatMul" and len(node.input) >= 2
                and node.input[1] in matmul_b_names):
            b_name = node.input[1]
            dql_out_name = b_name + "_dequant"
            if b_name not in dql_inserted:
                new_nodes.append(helper.make_node(
                    "DequantizeLinear",
                    inputs=[b_name + "_quantized", b_name + "_scale", b_name + "_zero_point"],
                    outputs=[dql_out_name],
                ))
                dql_inserted.add(b_name)
            new_nodes.append(helper.make_node(
                "MatMul",
                inputs=[node.input[0], dql_out_name],
                outputs=list(node.output),
                name=node.name,
            ))
        else:
            new_nodes.append(node)

    del m.graph.node[:]
    m.graph.node.extend(new_nodes)

    # Rebuild initializers.
    new_initializers = []

    # a. Quantized MatMul weights (external data), sorted for a stable proto.
    for w_name in sorted(matmul_b_names):
        if w_name in ext_index:
            dims = ext_index[w_name]["dims"]
        else:
            dims = list(numpy_helper.to_array(inline_index[w_name]).shape)

        q_init = TensorProto()
        q_init.name = w_name + "_quantized"
        q_init.data_type = TensorProto.INT8
        q_init.dims.extend(dims)
        off, length = out_positions[w_name + "_quantized"]
        add_external_data(q_init, off, length, data_file_name)

        # Per-tensor scale/zero_point must be rank-0 scalars: a 1-D tensor
        # selects DequantizeLinear's per-axis mode (default axis=1), which
        # would fail shape checks against a 2-D weight.
        scale_init = numpy_helper.from_array(
            np.array(scale_values[w_name], dtype=np.float32), name=w_name + "_scale"
        )
        zp_init = numpy_helper.from_array(
            np.array(0, dtype=np.int8), name=w_name + "_zero_point"
        )

        new_initializers.extend([q_init, scale_init, zp_init])

    # b. Non-MatMul external tensors (payload already copied in phase 1b).
    for name in non_matmul:
        meta = ext_index[name]
        init = TensorProto()
        init.name = name
        init.data_type = meta["dtype"]
        init.dims.extend(meta["dims"])
        off, length = out_positions[name]
        add_external_data(init, off, length, data_file_name)
        new_initializers.append(init)

    # c. Inline initializers carried over unchanged — but drop the FP32
    # copies of weights we just quantized: nothing references them any more,
    # and keeping them would leave dead FP32 data bloating the proto.
    for init in m.graph.initializer:
        if init.name not in ext_index and init.name not in matmul_b_names:
            new_initializers.append(init)

    del m.graph.initializer[:]
    m.graph.initializer.extend(new_initializers)
    del m.graph.value_info[:]  # stale FP32 type annotations would now be wrong

    print(f"  Saving proto → {INT8_OUT}")
    onnx.save(m, str(INT8_OUT))
    print(f"  Proto size: {INT8_OUT.stat().st_size / 1e6:.1f} MB")
    total_gb = (INT8_OUT.stat().st_size + INT8_DATA.stat().st_size) / 1e9
    print(f"  Total INT8 size: {total_gb:.2f} GB")


def verify():
    """Smoke-test the quantized model with ONNX Runtime.

    Loads INT8_OUT on CPU, scores two query/document pairs with the
    reranker tokenizer, and raises AssertionError if the relevant
    document does not outscore the irrelevant one.
    """
    import onnxruntime as ort
    from transformers import AutoTokenizer

    print(f"\nVerifying {INT8_OUT.name}...")
    sess_opts = ort.SessionOptions()
    sess = ort.InferenceSession(
        str(INT8_OUT), sess_opts, providers=["CPUExecutionProvider"]
    )
    for inp in sess.get_inputs():
        print(f"  in:  {inp.name} {inp.shape}")
    for out in sess.get_outputs():
        print(f"  out: {out.name} {out.shape}")

    tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    # First pair is relevant, second is not — the model should rank them apart.
    pairs = [
        ("what is a panda?", "A panda is a large black-and-white bear native to China."),
        ("what is a panda?", "The sky is blue and the grass is green."),
    ]
    scores = []
    for q, d in pairs:
        enc = tok(q, d, return_tensors="np", truncation=True, max_length=256)
        logit = sess.run(["logits"], {
            "input_ids":      enc["input_ids"].astype(np.int64),
            "attention_mask": enc["attention_mask"].astype(np.int64),
        })[0]
        scores.append(float(logit[0][0]))

    print(f"  logits: {[f'{s:.3f}' for s in scores]}")
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if scores[0] <= scores[1]:
        raise AssertionError(
            f"Relevant doc should score higher: {scores[0]:.3f} vs {scores[1]:.3f}"
        )
    print("  OK — relevant doc ranked higher")


if __name__ == "__main__":
    # Start from a clean slate so stale outputs never mix with this run.
    for path in (INT8_OUT, INT8_DATA):
        if path.exists():
            path.unlink()
            print(f"Deleted {path.name}")

    quantize_model()
    gc.collect()  # drop phase-1/2 leftovers before loading the model again
    verify()

    print("\nAll done. Upload commands:")
    print("  huggingface-cli upload cstr/zerank-1-small-ONNX /private/tmp/zerank_export/zerank_onnx . --repo-type model")
    print(f"  huggingface-cli upload cstr/zerank-1-small-ONNX {INT8_OUT.parent}/ . --commit-message 'add INT8' --repo-type model --include '*.onnx*'")
    print(f"  huggingface-cli upload cstr/zerank-1-small-ONNX /Volumes/backups/ai/zerank_onnx_int4/model_int4_full.onnx model_int4_full.onnx --repo-type model")