| |
| """Single-step BF16 β Marlin INT4 quantization for Voxtral Realtime 4B. |
| |
| Produces a single consolidated.safetensors with: |
| - Encoder + adapter + tok_embeddings + norms: BF16 (copied as-is) |
| - Decoder linear weights: Marlin-packed INT4 (group_size=128) |
| |
| The decoder linears are RTN-quantized (round-to-nearest, symmetric, per-group) |
and packed directly into Marlin's tiled INT4 format in one step — no intermediate
| GPTQ format, no multiple requantization cycles. |
| |
| Why RTN over GPTQ: GPTQ's Hessian optimization destroys the critical SPAD-to-text |
| transition boundary in Voxtral's streaming architecture because calibration runs |
| through MistralForCausalLM (without ada_rms_norm_t_cond). RTN preserves it. |
| |
| Marlin pack logic from IST-DASLab/marlin (Apache 2.0): |
| https://github.com/IST-DASLab/marlin |
| |
| Usage: |
| # From original HuggingFace BF16 model: |
| python3 quantize_marlin.py --model-dir path/to/Voxtral-Mini-4B-Realtime-2602 |
| |
| # Output (default: ./output/consolidated.safetensors): |
| python3 quantize_marlin.py --model-dir path/to/model --output-dir ./my-output |
| |
| Requires: torch, numpy, safetensors |
| """ |
|
|
| import argparse |
| import gc |
| import json |
| import os |
| import shutil |
| import sys |
| import time |
|
|
| import numpy as np |
| import torch |
| from safetensors import safe_open |
| from safetensors.torch import save_file |
|
|
|
|
| |
|
|
| N_LAYERS = 26 |
| N_HEADS = 32 |
| N_KV_HEADS = 8 |
| DIM = 3072 |
| HEAD_DIM = 128 |
|
|
| |
|
|
| BITS = 4 |
| GROUP_SIZE = 128 |
| PACK_FACTOR = 32 // BITS |
| BIAS = 1 << (BITS - 1) |
| MAXQ = (1 << BITS) - 1 |
|
|
| |
|
|
| DECODER_LINEARS = { |
| "attention.wq": ("self_attn.q_proj", True, N_HEADS), |
| "attention.wk": ("self_attn.k_proj", True, N_KV_HEADS), |
| "attention.wv": ("self_attn.v_proj", False, None), |
| "attention.wo": ("self_attn.o_proj", False, None), |
| "feed_forward.w1": ("mlp.gate_proj", False, None), |
| "feed_forward.w2": ("mlp.down_proj", False, None), |
| "feed_forward.w3": ("mlp.up_proj", False, None), |
| } |
|
|
|
|
| |
|
|
| def _get_perms(): |
| perm = [] |
| for i in range(32): |
| perm1 = [] |
| col = i // 4 |
| for block in [0, 1]: |
| for row in [ |
| 2 * (i % 4), |
| 2 * (i % 4) + 1, |
| 2 * (i % 4 + 4), |
| 2 * (i % 4 + 4) + 1, |
| ]: |
| perm1.append(16 * row + col + 8 * block) |
| for j in range(4): |
| perm.extend([p + 256 * j for p in perm1]) |
|
|
| perm = np.array(perm) |
| interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) |
| perm = perm.reshape((-1, 8))[:, interleave].ravel() |
| perm = torch.from_numpy(perm) |
|
|
| scale_perm = [] |
| for i in range(8): |
| scale_perm.extend([i + 8 * j for j in range(8)]) |
|
|
| return perm, scale_perm |
|
|
|
|
| _perm, _scale_perm = _get_perms() |
|
|
|
|
| |
|
|
| def permute_qk(w, n_heads, hidden_size): |
| """Apply MistralβHF head dimension interleaving for Q/K weights.""" |
| head_dim = w.shape[0] // n_heads |
| return ( |
| w.view(n_heads, head_dim // 2, 2, hidden_size) |
| .transpose(1, 2) |
| .reshape(n_heads * head_dim, hidden_size) |
| ) |
|
|
|
|
| |
|
|
| def quantize_and_pack_marlin(w_bf16, group_size=GROUP_SIZE): |
| """RTN-quantize a BF16 weight and pack into Marlin format in one step. |
| |
| Args: |
| w_bf16: [N_out, K] BF16/FP16 weight tensor |
| |
| Returns: |
| B: [K//16, 2*N_out] int32 (Marlin-packed weights) |
| s: [K//group_size, N_out] fp16 (Marlin-permuted scales) |
| """ |
| N_out, K = w_bf16.shape |
| n_groups = K // group_size |
| tile = 16 |
|
|
| |
| |
| w = w_bf16.t().float().contiguous() |
| w_grouped = w.reshape(n_groups, group_size, N_out) |
| max_val = w_grouped.abs().amax(dim=1).clamp(min=1e-10) |
| scales = (max_val / BIAS).half() |
|
|
| |
| s_expanded = scales.float().unsqueeze(1).expand_as(w_grouped) |
| w_int = torch.round(w_grouped / s_expanded).clamp(-BIAS, BIAS - 1).int() |
| w_uint = (w_int + BIAS).clamp(0, MAXQ) |
| w_uint = w_uint.reshape(K, N_out) |
|
|
| |
| s = scales.clone() |
| s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm] |
| s = s.reshape((-1, N_out)).contiguous() |
|
|
| |
| w_tiled = w_uint.reshape(K // tile, tile, N_out // tile, tile) |
| w_tiled = w_tiled.permute(0, 2, 1, 3) |
| w_tiled = w_tiled.reshape(K // tile, N_out * tile) |
|
|
| |
| res = w_tiled.reshape((-1, _perm.numel()))[:, _perm].reshape(w_tiled.shape) |
|
|
| |
| q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32) |
| res_np = res.cpu().numpy().astype(np.uint32) |
| for i in range(8): |
| q |= res_np[:, i::8] << (4 * i) |
| B = torch.from_numpy(q.astype(np.int32)) |
|
|
| return B, s.half() |
|
|
|
|
| |
|
|
| def main(): |
| parser = argparse.ArgumentParser( |
| description="Quantize Voxtral BF16 β single-file Marlin INT4") |
| parser.add_argument("--model-dir", required=True, |
| help="Directory with consolidated.safetensors (BF16, Mistral format)") |
| parser.add_argument("--output-dir", default="./output", |
| help="Output directory (default: ./output)") |
| args = parser.parse_args() |
|
|
| sf_path = os.path.join(args.model_dir, "consolidated.safetensors") |
| if not os.path.exists(sf_path): |
| print(f"Error: {sf_path} not found", file=sys.stderr) |
| sys.exit(1) |
|
|
| os.makedirs(args.output_dir, exist_ok=True) |
| output_path = os.path.join(args.output_dir, "consolidated.safetensors") |
|
|
| print(f"Input: {sf_path}") |
| print(f"Output: {output_path}") |
| print(f"Quantization: RTN {BITS}-bit, group_size={GROUP_SIZE}, uint4b8 Marlin") |
| print() |
|
|
| sf = safe_open(sf_path, framework="pt", device="cpu") |
| all_keys = list(sf.keys()) |
| tensors = {} |
| t0 = time.time() |
|
|
| |
| |
| decoder_linear_keys = set() |
| for layer_idx in range(N_LAYERS): |
| for mistral_name in DECODER_LINEARS: |
| decoder_linear_keys.add(f"layers.{layer_idx}.{mistral_name}.weight") |
|
|
| n_copied = 0 |
| for key in all_keys: |
| if key in decoder_linear_keys: |
| continue |
| tensors[key] = sf.get_tensor(key) |
| n_copied += 1 |
|
|
| print(f"Copied {n_copied} non-linear tensors (encoder, norms, embeddings, etc.)") |
|
|
| |
| n_quantized = 0 |
| for layer_idx in range(N_LAYERS): |
| for mistral_name, (hf_name, needs_permute, n_heads) in DECODER_LINEARS.items(): |
| src_key = f"layers.{layer_idx}.{mistral_name}.weight" |
| w = sf.get_tensor(src_key).half() |
|
|
| |
| if needs_permute: |
| w = permute_qk(w, n_heads, DIM) |
|
|
| |
| B, s = quantize_and_pack_marlin(w) |
| del w |
|
|
| out_prefix = f"layers.{layer_idx}.{hf_name}" |
| tensors[f"{out_prefix}.B"] = B |
| tensors[f"{out_prefix}.s"] = s |
| n_quantized += 1 |
|
|
| gc.collect() |
| elapsed = time.time() - t0 |
| print(f" Layer {layer_idx + 1}/{N_LAYERS} quantized ({elapsed:.1f}s)") |
|
|
| print(f"\nQuantized {n_quantized} decoder linear weights to Marlin INT4") |
| print(f"Total tensors in output: {len(tensors)}") |
|
|
| |
| print(f"\nSaving to {output_path}...") |
| save_file(tensors, output_path) |
| file_size = os.path.getsize(output_path) |
| print(f"Output: {file_size / (1024**3):.2f} GB ({len(tensors)} tensors)") |
|
|
| |
| for aux in ["params.json", "tekken.json"]: |
| src = os.path.join(args.model_dir, aux) |
| if os.path.exists(src): |
| shutil.copy2(src, os.path.join(args.output_dir, aux)) |
| print(f"Copied {aux}") |
|
|
| print(f"\nDone in {time.time() - t0:.1f}s") |
|
|
| |
| print(f"\nSample Marlin tensor names:") |
| marlin_keys = sorted(k for k in tensors if k.endswith(".B"))[:5] |
| for k in marlin_keys: |
| print(f" {k}: {list(tensors[k].shape)} {tensors[k].dtype}") |
| sk = k[:-2] + ".s" |
| print(f" {sk}: {list(tensors[sk].shape)} {tensors[sk].dtype}") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|