File size: 9,879 Bytes
22a85ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
#!/usr/bin/env python3
"""Single-step BF16 β†’ Marlin INT4 quantization for Voxtral Realtime 4B.

Produces a single consolidated.safetensors with:
  - Encoder + adapter + tok_embeddings + norms: BF16 (copied as-is)
  - Decoder linear weights: Marlin-packed INT4 (group_size=128)

The decoder linears are RTN-quantized (round-to-nearest, symmetric, per-group)
and packed directly into Marlin's tiled INT4 format in one step β€” no intermediate
GPTQ format, no multiple requantization cycles.

Why RTN over GPTQ: GPTQ's Hessian optimization destroys the critical SPAD-to-text
transition boundary in Voxtral's streaming architecture because calibration runs
through MistralForCausalLM (without ada_rms_norm_t_cond). RTN preserves it.

Marlin pack logic from IST-DASLab/marlin (Apache 2.0):
  https://github.com/IST-DASLab/marlin

Usage:
    # From original HuggingFace BF16 model:
    python3 quantize_marlin.py --model-dir path/to/Voxtral-Mini-4B-Realtime-2602

    # Output (default: ./output/consolidated.safetensors):
    python3 quantize_marlin.py --model-dir path/to/model --output-dir ./my-output

Requires: torch, numpy, safetensors
"""

import argparse
import gc
import json
import os
import shutil
import sys
import time

import numpy as np
import torch
from safetensors import safe_open
from safetensors.torch import save_file


# ─── Model constants ─────────────────────────────────────────────────────────

# Voxtral Realtime 4B decoder geometry (Mistral-format checkpoint).
N_LAYERS = 26      # decoder layers; bounds both quantization passes below
N_HEADS = 32       # head count passed to permute_qk for wq
N_KV_HEADS = 8     # head count passed to permute_qk for wk
DIM = 3072         # hidden size (second dim of the decoder linear weights)
HEAD_DIM = 128     # NOTE(review): unused in this file — presumably per-head dim

# ─── Quantization constants ──────────────────────────────────────────────────

BITS = 4
GROUP_SIZE = 128
PACK_FACTOR = 32 // BITS   # 8 int4 values per int32
BIAS = 1 << (BITS - 1)     # 8 (uint4b8 encoding: stored = value + 8)
MAXQ = (1 << BITS) - 1     # 15

# ─── Mistral → HF naming for decoder linears ─────────────────────────────────

# Maps each Mistral-format decoder linear to
#   (HF-format name, needs Q/K head permute, head count for the permute).
# Applied per layer in main(); every other tensor is copied verbatim.
DECODER_LINEARS = {
    "attention.wq": ("self_attn.q_proj", True,  N_HEADS),     # needs Q/K permute
    "attention.wk": ("self_attn.k_proj", True,  N_KV_HEADS),  # needs Q/K permute
    "attention.wv": ("self_attn.v_proj", False, None),
    "attention.wo": ("self_attn.o_proj", False, None),
    "feed_forward.w1": ("mlp.gate_proj", False, None),
    "feed_forward.w2": ("mlp.down_proj", False, None),
    "feed_forward.w3": ("mlp.up_proj",   False, None),
}


# ─── Marlin permutation tables (from IST-DASLab/marlin, Apache 2.0) ─────────

def _get_perms():
    perm = []
    for i in range(32):
        perm1 = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                2 * (i % 4),
                2 * (i % 4) + 1,
                2 * (i % 4 + 4),
                2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm.extend([p + 256 * j for p in perm1])

    perm = np.array(perm)
    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    perm = perm.reshape((-1, 8))[:, interleave].ravel()
    perm = torch.from_numpy(perm)

    scale_perm = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])

    return perm, scale_perm


# Precompute the module-wide Marlin permutation tables once at import time.
_perm, _scale_perm = _get_perms()


# ─── Q/K head permutation (Mistral → HF interleaving) ────────────────────────

def permute_qk(w, n_heads, hidden_size):
    """Interleave the head dimension of a Q/K weight (Mistral → HF layout).

    Args:
        w: [n_heads * head_dim, hidden_size] projection weight.
        n_heads: number of attention heads in this projection.
        hidden_size: second dimension of w.

    Returns:
        Tensor of the same shape with each head's rows reordered so the two
        rotary halves are interleaved.
    """
    head_dim = w.shape[0] // n_heads
    # Split each head into (head_dim // 2) rotary pairs, then swap the pair
    # axis with the half axis before flattening back to 2-D.
    per_head = w.view(n_heads, head_dim // 2, 2, hidden_size)
    interleaved = per_head.transpose(1, 2)
    return interleaved.reshape(n_heads * head_dim, hidden_size)


# ─── Single-step RTN quantize + Marlin pack ──────────────────────────────────

def quantize_and_pack_marlin(w_bf16, group_size=GROUP_SIZE):
    """RTN-quantize a BF16/FP16 weight and pack into Marlin format in one step.

    Symmetric per-group round-to-nearest quantization to 4 bits, stored in
    uint4b8 encoding (stored = value + 8), followed by Marlin's 16x16 tiling,
    lane permutation, and 8-nibbles-per-int32 packing.

    Args:
        w_bf16: [N_out, K] BF16/FP16 weight tensor.
        group_size: quantization group length along K (default: GROUP_SIZE).

    Returns:
        B: [K//16, 2*N_out] int32 (Marlin-packed weights)
        s: [K//group_size, N_out] fp16 (Marlin-permuted scales)

    Raises:
        ValueError: if the weight shape is incompatible with the group size
            or with Marlin's 16x16 tiling.
    """
    N_out, K = w_bf16.shape
    tile = 16
    # Validate shapes up front: a bad shape would otherwise fail deep inside
    # a reshape with an opaque error.
    if K % group_size != 0:
        raise ValueError(f"K={K} is not divisible by group_size={group_size}")
    if K % tile != 0 or N_out % tile != 0:
        raise ValueError(
            f"weight shape [{N_out}, {K}] does not tile into {tile}x{tile} blocks")
    n_groups = K // group_size

    # ── Step 1: Compute per-group RTN scales ──
    # Work in [K, N] layout for Marlin packing.
    w = w_bf16.t().float().contiguous()  # [K, N]
    w_grouped = w.reshape(n_groups, group_size, N_out)
    max_val = w_grouped.abs().amax(dim=1).clamp(min=1e-10)  # [n_groups, N]
    scales = (max_val / BIAS).half()  # [n_groups, N] — scale = max_abs / 8

    # ── Step 2: Quantize to uint4 ──
    # Divide by the fp16-rounded scales (not the fp32 originals) so that
    # dequantization with the stored scales reproduces the grid exactly.
    s_expanded = scales.float().unsqueeze(1).expand_as(w_grouped)  # [n_groups, gs, N]
    w_int = torch.round(w_grouped / s_expanded).clamp(-BIAS, BIAS - 1).int()
    w_uint = (w_int + BIAS).clamp(0, MAXQ)  # uint4b8: [-8,7] → [0,15]
    w_uint = w_uint.reshape(K, N_out)  # [K, N]

    # ── Step 3: Permute scales for Marlin ──
    # Advanced indexing copies, so no defensive clone is needed.
    s = scales.reshape((-1, len(_scale_perm)))[:, _scale_perm]
    s = s.reshape((-1, N_out)).contiguous()

    # ── Step 4: Tile into 16x16 blocks ──
    w_tiled = w_uint.reshape(K // tile, tile, N_out // tile, tile)
    w_tiled = w_tiled.permute(0, 2, 1, 3)
    w_tiled = w_tiled.reshape(K // tile, N_out * tile)

    # ── Step 5: Apply Marlin lane permutation ──
    res = w_tiled.reshape((-1, _perm.numel()))[:, _perm].reshape(w_tiled.shape)

    # ── Step 6: Pack 8 int4 values into each int32 ──
    q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32)
    res_np = res.cpu().numpy().astype(np.uint32)
    for i in range(8):
        q |= res_np[:, i::8] << (4 * i)
    B = torch.from_numpy(q.astype(np.int32))

    return B, s.half()


# ─── Main ────────────────────────────────────────────────────────────────────

def main():
    """Entry point: quantize a Voxtral BF16 checkpoint to Marlin INT4.

    Reads consolidated.safetensors from --model-dir, copies every tensor that
    is not a decoder linear verbatim, RTN-quantizes the decoder linears into
    Marlin format, and writes a single consolidated.safetensors (plus
    params.json / tekken.json if present) into --output-dir.
    """
    parser = argparse.ArgumentParser(
        description="Quantize Voxtral BF16 → single-file Marlin INT4")
    parser.add_argument("--model-dir", required=True,
                        help="Directory with consolidated.safetensors (BF16, Mistral format)")
    parser.add_argument("--output-dir", default="./output",
                        help="Output directory (default: ./output)")
    args = parser.parse_args()

    sf_path = os.path.join(args.model_dir, "consolidated.safetensors")
    if not os.path.exists(sf_path):
        print(f"Error: {sf_path} not found", file=sys.stderr)
        sys.exit(1)

    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, "consolidated.safetensors")

    print(f"Input:  {sf_path}")
    print(f"Output: {output_path}")
    print(f"Quantization: RTN {BITS}-bit, group_size={GROUP_SIZE}, uint4b8 Marlin")
    print()

    tensors = {}
    t0 = time.time()

    # Full set of decoder linear keys — these get quantized; everything else
    # (encoder, adapter, tok_embeddings, norms, ada_rms_norm, final norm)
    # is copied through unchanged.
    decoder_linear_keys = {
        f"layers.{layer_idx}.{mistral_name}.weight"
        for layer_idx in range(N_LAYERS)
        for mistral_name in DECODER_LINEARS
    }

    # Context manager ensures the input file handle is released promptly
    # (the previous version leaked the safe_open handle).
    with safe_open(sf_path, framework="pt", device="cpu") as sf:
        all_keys = list(sf.keys())

        # ── Pass 1: Copy non-decoder-linear tensors as-is ──
        n_copied = 0
        for key in all_keys:
            if key in decoder_linear_keys:
                continue
            tensors[key] = sf.get_tensor(key)
            n_copied += 1

        print(f"Copied {n_copied} non-linear tensors (encoder, norms, embeddings, etc.)")

        # ── Pass 2: Quantize decoder linears → Marlin ──
        n_quantized = 0
        for layer_idx in range(N_LAYERS):
            for mistral_name, (hf_name, needs_permute, n_heads) in DECODER_LINEARS.items():
                src_key = f"layers.{layer_idx}.{mistral_name}.weight"
                w = sf.get_tensor(src_key).half()  # bf16 -> fp16 for torch ops

                # Apply Q/K head permutation (Mistral -> HF layout) if needed
                if needs_permute:
                    w = permute_qk(w, n_heads, DIM)

                # Single-step quantize + Marlin pack
                B, s = quantize_and_pack_marlin(w)
                del w  # release the fp16 copy before the next linear

                out_prefix = f"layers.{layer_idx}.{hf_name}"
                tensors[f"{out_prefix}.B"] = B
                tensors[f"{out_prefix}.s"] = s
                n_quantized += 1

            gc.collect()
            elapsed = time.time() - t0
            print(f"  Layer {layer_idx + 1}/{N_LAYERS} quantized ({elapsed:.1f}s)")

    print(f"\nQuantized {n_quantized} decoder linear weights to Marlin INT4")
    print(f"Total tensors in output: {len(tensors)}")

    # ── Save ──
    print(f"\nSaving to {output_path}...")
    save_file(tensors, output_path)
    file_size = os.path.getsize(output_path)
    print(f"Output: {file_size / (1024**3):.2f} GB ({len(tensors)} tensors)")

    # ── Copy auxiliary files (tokenizer + params) if present ──
    for aux in ["params.json", "tekken.json"]:
        src = os.path.join(args.model_dir, aux)
        if os.path.exists(src):
            shutil.copy2(src, os.path.join(args.output_dir, aux))
            print(f"Copied {aux}")

    print(f"\nDone in {time.time() - t0:.1f}s")

    # ── Sanity-print a few output tensor names/shapes ──
    print("\nSample Marlin tensor names:")
    marlin_keys = sorted(k for k in tensors if k.endswith(".B"))[:5]
    for k in marlin_keys:
        print(f"  {k}: {list(tensors[k].shape)} {tensors[k].dtype}")
        sk = k[:-2] + ".s"
        print(f"  {sk}: {list(tensors[sk].shape)} {tensors[sk].dtype}")


if __name__ == "__main__":
    main()