Single-file Marlin INT4: direct RTN quantization, no GPTQ intermediate

Replace the 2-file GPTQ+Marlin format with a single consolidated.safetensors.

- Single-step BF16 → RTN → Marlin pack (no intermediate GPTQ, scales computed once)
- 4.07 GB single file (was 5.8 GB across two files)
- Python server: remove DequantLinear/MarlinLinear, add PrepackedMarlinLinear
- Add quantize_marlin.py for reproducibility
- Tested on Jetson Orin Nano: identical transcription, 15.2 tok/s

Files changed:
- README.md: +29 -28
- consolidated.safetensors: +2 -2
- params.json: +4 -4
- scripts/jetson_serve_sdpa.py: +25 -265
- scripts/quantize_marlin.py: +266 -0
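The "single-step BF16 → RTN → Marlin pack" in the summary can be sketched as plain per-group symmetric round-to-nearest with the uint4b8 (+8 bias) encoding. This is an illustrative reconstruction, not the repo's quantize_marlin.py; the group size and encoding match the diff below, but function names and the exact scale formula are assumptions.

```python
import torch

GROUP_SIZE = 128
BIAS = 8    # uint4b8 encoding: stored code = signed value + 8
QMAX = 7    # symmetric grid over the signed int4 range -8..7

def rtn_quantize(w: torch.Tensor, group_size: int = GROUP_SIZE):
    """Per-group symmetric round-to-nearest INT4 quantization (sketch)."""
    out_f, in_f = w.shape
    g = w.float().reshape(out_f, in_f // group_size, group_size)
    # One scale per group of 128 input weights, computed exactly once
    scales = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / QMAX
    codes = torch.clamp(torch.round(g / scales) + BIAS, 0, 15).to(torch.uint8)
    return codes, scales

def rtn_dequantize(codes: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """Invert the uint4b8 encoding: (code - 8) * scale."""
    return (codes.float() - BIAS) * scales

w = torch.randn(4, 256)
codes, scales = rtn_quantize(w)
w_hat = rtn_dequantize(codes, scales).reshape(w.shape)
# Round-to-nearest error is at most half a quantization step per weight
max_err = (w - w_hat).abs().max().item()
```

Because the scales are computed once from the BF16 weights, packing these codes straight into Marlin's layout avoids the extra rounding a GPTQ-then-repack pipeline introduces.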
README.md CHANGED

@@ -8,6 +8,7 @@ tags:
 - mistral
 - int4
 - quantized
+- marlin
 - jetson
 - edge
 - realtime

@@ -32,35 +33,49 @@ language:
 
 INT4 quantized [Voxtral Mini 4B Realtime](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602) for edge deployment on NVIDIA Jetson Orin Nano (8 GB).
 
-**4.1 GB
+**4.1 GB single file** — fits in 8 GB unified memory with room for KV cache and runtime.
 
 ## What's in this repo
 
 | File | Size | Description |
 |------|------|-------------|
-| `consolidated.safetensors` | 4.1 GB |
+| `consolidated.safetensors` | 4.1 GB | Marlin-packed INT4 decoder + BF16 encoder/norms/embeddings |
 | `params.json` | 1.6 KB | Model architecture config (Mistral native format) |
 | `tekken.json` | 15 MB | Mistral tekken tokenizer |
-| `scripts/jetson_serve_sdpa.py` |
+| `scripts/jetson_serve_sdpa.py` | ~50 KB | Self-contained inference server (no HF/vLLM deps) |
+| `scripts/quantize_marlin.py` | ~6 KB | Quantization script to reproduce this model |
 | `kernels/fused_ops.cu` | 8.5 KB | Fused CUDA kernels (JIT compiled, SM87) |
 
 ## Quantization details
 
-- **Method**: RTN (Round-To-Nearest)
-- **Bits**: 4-bit (decoder linear layers),
+- **Method**: RTN (Round-To-Nearest) quantized directly into Marlin-packed format
+- **Bits**: 4-bit (decoder linear layers), BF16 (audio encoder, norms, embeddings)
 - **Group size**: 128
-- **Encoding**: uint4b8 (value + 8 bias),
-- **
+- **Encoding**: uint4b8 (value + 8 bias), Marlin tiled INT4 layout
+- **Single step**: BF16 → RTN quantize → Marlin pack (no intermediate GPTQ format, scales computed once)
+- **Why RTN over GPTQ**: GPTQ's Hessian optimization destroys the critical SPAD-to-text transition boundary in Voxtral's streaming architecture. RTN preserves it perfectly. See [below](#why-rtn-not-gptq).
+
+### Reproducing the quantization
+
+```bash
+pip install torch safetensors numpy
+
+# From the original HuggingFace model:
+python scripts/quantize_marlin.py \
+    --model-dir path/to/Voxtral-Mini-4B-Realtime-2602 \
+    --output-dir ./output
+```
 
 ## Architecture
 
 | Component | Params | Precision | Size |
 |-----------|--------|-----------|------|
-| Audio encoder (Whisper-style, 32 layers) | ~600M |
-| Projector (5120 → 3072 → 3072) | ~25M |
-| LM decoder (26 layers, 3072 hidden, GQA 32/8 heads) | ~3B | INT4 | ~
-
-
+| Audio encoder (Whisper-style, 32 layers) | ~600M | BF16 | 1.86 GB |
+| Projector (5120 → 3072 → 3072) | ~25M | BF16 | 0.05 GB |
+| LM decoder (26 layers, 3072 hidden, GQA 32/8 heads) | ~3B | Marlin INT4 | ~1.58 GB |
+| Token embeddings (131072 × 3072) | ~400M | BF16 | 0.77 GB |
+| ada_rms_norm_t_cond + norms | ~1M | BF16 | 0.01 GB |
+| **Total** | **~4B** | | **4.1 GB** |
 
 ## Transcription quality
 

@@ -81,7 +96,7 @@ Tested on Fleurs en_us samples — near-perfect output matching the fp16 baseline
 No HuggingFace or vLLM dependencies needed. Runs inside the [PyTorch Jetson container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/l4t-pytorch).
 
 ```bash
-pip install safetensors websockets soundfile numpy librosa
+pip install safetensors websockets soundfile numpy librosa marlin
 
 # Test with an audio file
 python scripts/jetson_serve_sdpa.py --test audio.wav

@@ -98,21 +113,6 @@ The server exposes `ws://localhost:8000/v1/realtime` for streaming transcription
 - Pre-allocated KV cache (eliminates per-token torch.cat)
 - Fused CUDA kernels for RMSNorm, RoPE, SiLU·Mul (~500 kernel launches/token → ~80)
 
-### Option 2: vLLM serving
-
-```bash
-pip install -U vllm --extra-index-url https://wheels.vllm.ai/nightly --pre
-pip install librosa soxr
-
-python -m vllm.entrypoints.openai.api_server \
-    --model /path/to/this/repo \
-    --tokenizer-mode mistral --config-format mistral --load-format mistral \
-    --max-model-len 8192 --dtype float16 --enforce-eager \
-    --gpu-memory-utilization 0.5
-```
-
-**Note**: Requires vLLM nightly (>=0.15.2dev) for `/v1/realtime` WebSocket support.
-
 ### WebSocket client example
 
 ```python

@@ -165,4 +165,5 @@ GPTQ quantization fails on this model at every bit precision (4-bit and 8-bit) w
 ## Credits
 
 - Base model: [Voxtral Mini 4B Realtime](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602) by Mistral AI
+- Marlin INT4 kernel: [IST-DASLab/marlin](https://github.com/IST-DASLab/marlin) (Apache 2.0)
 - Quantization and Jetson optimization by [Teaspoon AI](https://huggingface.co/Teaspoon-AI)
consolidated.safetensors CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:a88ff5bf15ec7ca11fc7b0ff51148721dcca585f7c356baa2576eee785250d44
+size 4367478888
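The LFS pointer pins the payload by content hash. A downloaded file can be checked against the recorded oid with a streaming sha256; the helper below is a generic sketch (demonstrated on a throwaway file, since the 4.4 GB weights aren't reproduced here).

```python
import hashlib
import os
import tempfile

def lfs_sha256(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream a file and return the hex sha256 an LFS pointer records as oid."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# Demo on a small temp file; for the model, compare the result against
# the oid line in the consolidated.safetensors pointer.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"hello")
digest = lfs_sha256(tmp.name)
os.unlink(tmp.name)
```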
params.json CHANGED

@@ -54,12 +54,12 @@
     "ada_rms_norm_t_cond": true,
     "ada_rms_norm_t_cond_dim": 32,
     "quantization_config": {
-        "quant_method": "
+        "quant_method": "rtn",
         "bits": 4,
         "group_size": 128,
-        "desc_act": false,
         "sym": true,
-        "checkpoint_format": "
-        "pack_dtype": "int32"
+        "checkpoint_format": "marlin",
+        "pack_dtype": "int32",
+        "encoding": "uint4b8"
     }
 }
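A loader can branch on the new fields without any GPTQ machinery. A sketch of that dispatch, using the quantization_config values from the new params.json; the is_prepacked_marlin helper is hypothetical, not part of the repo.

```python
import json

# The quantization_config block as it appears in the new params.json
cfg = json.loads("""
{
  "quant_method": "rtn",
  "bits": 4,
  "group_size": 128,
  "sym": true,
  "checkpoint_format": "marlin",
  "pack_dtype": "int32",
  "encoding": "uint4b8"
}
""")

def is_prepacked_marlin(cfg: dict) -> bool:
    # Hypothetical helper: take the pre-packed Marlin load path only when
    # the checkpoint advertises exactly the format the server expects.
    return (cfg.get("checkpoint_format") == "marlin"
            and cfg.get("bits") == 4
            and cfg.get("encoding") == "uint4b8"
            and cfg.get("sym") is True)

use_marlin = is_prepacked_marlin(cfg)
```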
scripts/jetson_serve_sdpa.py CHANGED

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
-"""Voxtral Mini 4B Realtime — Jetson Orin Nano
+"""Voxtral Mini 4B Realtime — Jetson Orin Nano inference server.
 
-Loads
+Loads Marlin-packed INT4 weights from consolidated.safetensors and serves
 transcription via WebSocket at ws://localhost:8000/v1/realtime.
 
 Key architecture detail: at each decoder position, the input embedding is

@@ -55,6 +55,7 @@ try:
     HAS_MARLIN = True
 except ImportError:
     HAS_MARLIN = False
+    print("WARNING: Marlin not installed. Install with: pip install marlin")
 
 # Try to JIT-compile fused CUDA kernels (collapses ~500 kernel launches/token to ~80)
 HAS_FUSED = False

@@ -103,54 +104,24 @@ DOWNSAMPLE_FACTOR = 4
 
 # ─── Marlin Fused INT4 Linear ────────────────────────────────────────────────
 
-class
-    """Linear layer using Marlin
+class PrepackedMarlinLinear(nn.Module):
+    """Linear layer using pre-packed Marlin INT4 weights from safetensors.
 
-
-
-
+    Loads .B and .s tensors directly — no GPTQ→Marlin conversion needed.
+    Used with single-file consolidated.safetensors that already contains
+    Marlin-format weights.
     """
 
-    def __init__(self,
+    def __init__(self, B, s):
         super().__init__()
-
-
-
-        self.
-        self.
-
-        # Dequantize GPTQ → fp16, then repack into Marlin format
-        shifts = torch.arange(0, 32, BITS, device=qweight.device, dtype=torch.int32)
-        unpacked = (qweight.unsqueeze(0) >> shifts.view(-1, 1, 1)) & 0xF
-        unpacked = unpacked.permute(1, 0, 2).reshape(in_features, out_features)
-        unpacked = unpacked.T.reshape(out_features, n_groups, GROUP_SIZE)
-        s = scales.T.float().unsqueeze(-1)
-        w_fp16 = ((unpacked.float() - BIAS) * s).reshape(out_features, in_features).half()
-        del unpacked, s
-
-        if unpermute is not None:
-            n_heads, hidden_size = unpermute
-            head_dim = w_fp16.shape[0] // n_heads
-            w_fp16 = (w_fp16.view(n_heads, 2, head_dim // 2, hidden_size)
-                      .transpose(1, 2)
-                      .reshape(out_features, in_features))
-
-        # Create temporary nn.Linear for Marlin's pack()
-        linear = nn.Linear(in_features, out_features, bias=False,
-                           dtype=torch.half, device=qweight.device)
-        linear.weight.data = w_fp16
-
-        # Create Marlin layer and pack (handles permutation + bit packing)
-        ml = _marlin.Layer(in_features, out_features, groupsize=GROUP_SIZE)
-        ml.pack(linear, scales.T)
-        del linear, w_fp16
-
-        # Store Marlin buffers
-        self.register_buffer('B', ml.B.to(qweight.device))
-        self.register_buffer('s', ml.s.to(qweight.device))
+        # B: [K//16, 2*N] int32, s: [K//groupsize, N] fp16
+        self.in_features = B.shape[0] * 16
+        self.out_features = B.shape[1] // 2
+        self.register_buffer('B', B)
+        self.register_buffer('s', s)
         self.register_buffer('workspace',
-                             torch.zeros(out_features // 128 * 16,
-                                         dtype=torch.int, device=
+                             torch.zeros(self.out_features // 128 * 16,
+                                         dtype=torch.int, device=B.device),
                              persistent=False)
 
     def forward(self, x):

@@ -161,85 +132,6 @@ class MarlinLinear(nn.Module):
         return C
 
 
-# ─── GPTQ INT4 Dequantization (fallback when Marlin unavailable) ─────────────
-
-class DequantLinear(nn.Module):
-    """Linear layer with INT4 GPTQ packed weights.
-
-    Supports two modes:
-    - On-the-fly dequantization (default): dequantizes each forward call
-    - Cached mode: stores pre-dequantized fp16 weight for fast matmul
-    """
-
-    _shifts = None  # class-level cached shifts tensor
-
-    def __init__(self, qweight, scales, qzeros, unpermute=None):
-        super().__init__()
-        self.register_buffer('qweight', qweight)
-        self.register_buffer('scales', scales)
-        self.register_buffer('qzeros', qzeros)
-        self.in_features = qweight.shape[0] * PACK_FACTOR
-        self.out_features = qweight.shape[1]
-        self.unpermute = unpermute
-        self._cached_w = None  # pre-dequantized fp16 weight [out, in]
-
-    def cache_weight(self, free_int4=True):
-        """Pre-dequantize and cache the fp16 weight.
-        If free_int4=True, frees INT4 buffers (saves memory, not reversible).
-        """
-        self._cached_w = self._dequantize()
-        if free_int4:
-            self.qweight = None
-            self.scales = None
-            self.qzeros = None
-
-    def uncache_weight(self):
-        """Free the cached weight (e.g., before re-loading INT4 weights)."""
-        self._cached_w = None
-
-    @property
-    def cached_bytes(self):
-        """Memory used by cached weight in bytes."""
-        if self._cached_w is not None:
-            return self._cached_w.nelement() * self._cached_w.element_size()
-        return 0
-
-    def _dequantize(self):
-        """Dequantize INT4 packed weights to fp16 [out, in]."""
-        qw = self.qweight
-        in_packed, out = qw.shape
-        n_groups = self.scales.shape[0]
-
-        # Cached shifts tensor (shared across all instances)
-        if DequantLinear._shifts is None or DequantLinear._shifts.device != qw.device:
-            DequantLinear._shifts = torch.arange(0, 32, BITS, device=qw.device, dtype=torch.int32)
-        shifts = DequantLinear._shifts
-
-        # Vectorized unpack: [8, in/8, out]
-        unpacked = (qw.unsqueeze(0) >> shifts.view(-1, 1, 1)) & 0xF
-        # Interleave to [in, out] then transpose+group to [out, groups, GROUP_SIZE]
-        unpacked = unpacked.permute(1, 0, 2).reshape(self.in_features, out)
-        unpacked = unpacked.T.reshape(out, n_groups, GROUP_SIZE)
-        # Scale: (val - 8) * scale
-        s = self.scales.T.float().unsqueeze(-1)
-        w = ((unpacked.float() - BIAS) * s).reshape(out, self.in_features).half()
-        del unpacked, s
-
-        if self.unpermute is not None:
-            n_heads, hidden_size = self.unpermute
-            head_dim = w.shape[0] // n_heads
-            w = (w.view(n_heads, 2, head_dim // 2, hidden_size)
-                 .transpose(1, 2)
-                 .reshape(out, self.in_features))
-        return w
-
-    def forward(self, x):
-        if self._cached_w is not None:
-            return F.linear(x, self._cached_w)
-        w = self._dequantize()
-        result = F.linear(x, w)
-        del w
-        return result
 
 
 # ─── Building Blocks ─────────────────────────────────────────────────────────

@@ -629,14 +521,13 @@ class VoxtralModel:
         return F.linear(h, self.embed.weight)
 
     def _dql(self, f, prefix, dev, unpermute=None):
-
-
-
-
-
-
-
-        return DequantLinear(qw, sc, qz, unpermute=unpermute)
+        B = f.get_tensor(f'{prefix}.B').to(dev)
+        s = f.get_tensor(f'{prefix}.s').to(dev)
+        if not HAS_MARLIN:
+            raise RuntimeError(
+                "Marlin INT4 kernel required but not installed. "
+                "Install with: pip install marlin")
+        return PrepackedMarlinLinear(B, s)
 
     def _set(self, module, name, tensor):
         """Replace a meta parameter with a real CUDA tensor."""

@@ -761,8 +652,7 @@ class VoxtralModel:
         print(f" LM layers {start}-{end-1} loaded")
         self._load_section(path, load_dec_batch)
 
-
-        print(f" LM decoder loaded ({self.n_layers} layers, {backend})")
+        print(f" LM decoder loaded ({self.n_layers} layers, Marlin fused INT4)")
         gc.collect()
         torch.cuda.empty_cache()
         mem = torch.cuda.memory_allocated() / 1024**3

@@ -785,128 +675,6 @@ class VoxtralModel:
         self.tokenizer = None
         print(" WARNING: mistral_common not available, using fallback decoder")
 
-    def _pre_dequant(self):
-        """Offload encoder to CPU and pre-dequantize decoder weights into GPU cache.
-
-        After encoding is done, the encoder (~1.86 GB) is no longer needed on GPU.
-        Offloading it frees memory for caching pre-dequantized decoder weights,
-        which eliminates the per-token dequantization overhead.
-        """
-        import gc
-
-        if hasattr(self, '_decoder_cached') and self._decoder_cached:
-            return  # already cached
-
-        t0 = time.time()
-
-        # Move encoder + projector to CPU to free GPU memory
-        self.encoder.cpu()
-        self.projector.cpu()
-        gc.collect()
-        torch.cuda.empty_cache()
-        self._evict_cache()
-
-        free, _ = torch.cuda.mem_get_info(0)
-        print(f" After encoder offload: {free/1024**3:.2f} GB free")
-
-        # Budget: leave 500 MB for KV cache + intermediates
-        budget = free - 500 * 1024 * 1024
-        used_bytes = 0
-        cached_count = 0
-
-        for i, dl in enumerate(self.layers):
-            projs = [dl.attn.q_proj, dl.attn.k_proj, dl.attn.v_proj, dl.attn.o_proj,
-                     dl.gate_proj, dl.up_proj, dl.down_proj]
-
-            # Estimate net memory cost (fp16 weight minus freed INT4 buffers)
-            layer_fp16 = sum(
-                p.in_features * p.out_features * 2
-                for p in projs if isinstance(p, DequantLinear) and p._cached_w is None
-            )
-            layer_int4 = sum(
-                p.qweight.nelement() * 4 + p.scales.nelement() * 2 + p.qzeros.nelement() * 4
-                for p in projs
-                if isinstance(p, DequantLinear) and p.qweight is not None
-            )
-            net = layer_fp16 - layer_int4  # net increase in memory
-
-            if used_bytes + net > budget:
-                break
-
-            for p in projs:
-                if isinstance(p, DequantLinear) and p._cached_w is None and p.qweight is not None:
-                    p.cache_weight(free_int4=True)
-
-            used_bytes += net
-            cached_count += 1
-            # Periodic cleanup to keep peak memory low
-            if cached_count % 5 == 0:
-                gc.collect()
-                torch.cuda.empty_cache()
-
-        gc.collect()
-        torch.cuda.empty_cache()
-        free2, _ = torch.cuda.mem_get_info(0)
-        print(f" Pre-dequantized {cached_count}/{self.n_layers} layers in {time.time()-t0:.1f}s, "
-              f"{free2/1024**3:.2f} GB free")
-        self._decoder_cached = True
-
-    def _restore_encoder(self):
-        """Move encoder back to GPU for the next transcription.
-
-        Frees cached decoder weights first to make room, then reloads
-        INT4 weights for layers that had their buffers freed.
-        """
-        import gc
-
-        if not hasattr(self, '_decoder_cached') or not self._decoder_cached:
-            return
-
-        t0 = time.time()
-
-        # Free cached decoder weights
-        needs_reload = []
-        for i, dl in enumerate(self.layers):
-            for p in [dl.attn.q_proj, dl.attn.k_proj, dl.attn.v_proj, dl.attn.o_proj,
-                      dl.gate_proj, dl.up_proj, dl.down_proj]:
-                if isinstance(p, DequantLinear):
-                    if p._cached_w is not None and p.qweight is None:
-                        needs_reload.append(i)
-                    p.uncache_weight()
-
-        gc.collect()
-        torch.cuda.empty_cache()
-        self._evict_cache()
-
-        # Move encoder + projector back to GPU
-        self.encoder.to(self.device)
-        self.projector.to(self.device)
-
-        # Reload INT4 weights for layers that were freed
-        if needs_reload:
-            needs_reload = sorted(set(needs_reload))
-            path = os.path.join(self.model_path, 'consolidated.safetensors')
-            D = str(self.device)
-            with safe_open(path, framework='pt', device=D) as f:
-                for i in needs_reload:
-                    lp = f'layers.{i}'
-                    dl = self.layers[i]
-                    dl.attn.q_proj = self._dql(f, f'{lp}.self_attn.q_proj', D)
-                    dl.attn.k_proj = self._dql(f, f'{lp}.self_attn.k_proj', D)
-                    dl.attn.v_proj = self._dql(f, f'{lp}.self_attn.v_proj', D)
-                    dl.attn.o_proj = self._dql(f, f'{lp}.self_attn.o_proj', D)
-                    dl.gate_proj = self._dql(f, f'{lp}.mlp.gate_proj', D)
-                    dl.up_proj = self._dql(f, f'{lp}.mlp.up_proj', D)
-                    dl.down_proj = self._dql(f, f'{lp}.mlp.down_proj', D)
-                    gc.collect()
-                    torch.cuda.empty_cache()
-            print(f" Reloaded {len(needs_reload)} decoder layers from disk")
-
-        gc.collect()
-        torch.cuda.empty_cache()
-        self._decoder_cached = False
-        print(f" Encoder restored in {time.time()-t0:.1f}s")
-
     def decode_tokens(self, ids):
         if self.tokenizer is not None:
             try:

@@ -950,10 +718,6 @@ class VoxtralModel:
         free, _ = torch.cuda.mem_get_info(0)
         print(f" CUDA free before inference: {free/1024**3:.2f} GB")
 
-        # Restore encoder to GPU if it was offloaded (only needed without Marlin)
-        if not HAS_MARLIN:
-            self._restore_encoder()
-
         # 0. Pad audio for streaming alignment
         audio = self._pad_audio(audio)
         print(f" padded audio: {len(audio)} samples ({len(audio)/SAMPLE_RATE:.1f}s)")

@@ -985,11 +749,7 @@ class VoxtralModel:
         del enc_ds
         print(f" adapter: {adapter.shape}")
 
-        # 5.
-        if not HAS_MARLIN:
-            self._pre_dequant()
-
-        # 6. Build prompt: [BOS] + [SPAD] * (n_left_pad + delay_tokens)
+        # 5. Build prompt: [BOS] + [SPAD] * (n_left_pad + delay_tokens)
         prompt_len = 1 + self.n_left_pad + self.delay_tokens
         prompt_ids = [TOKEN_BOS] + [TOKEN_STREAMING_PAD] * (self.n_left_pad + self.delay_tokens)
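The shape recovery in PrepackedMarlinLinear.__init__ can be checked with plain arithmetic on dummy buffers. The 3072×3072 layer below is illustrative (a square decoder projection on this model); actual per-layer shapes vary.

```python
import torch

# Marlin buffer shapes, per the comment in PrepackedMarlinLinear.__init__:
#   B: [K // 16, 2 * N] int32   (packed INT4 weights)
#   s: [K // group_size, N] fp16 (per-group scales)
K, N, GROUP_SIZE = 3072, 3072, 128

B = torch.zeros(K // 16, 2 * N, dtype=torch.int32)
s = torch.zeros(K // GROUP_SIZE, N, dtype=torch.float16)

# Recover the layer dimensions from the packed buffers alone,
# exactly as the server does when loading from safetensors:
in_features = B.shape[0] * 16
out_features = B.shape[1] // 2

# Sanity check: each int32 in B packs eight 4-bit values, so the
# buffer averages exactly 4 bits per logical weight.
bits_per_weight = B.numel() * 32 / (in_features * out_features)
```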
scripts/quantize_marlin.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Single-step BF16 β Marlin INT4 quantization for Voxtral Realtime 4B.
|
| 3 |
+
|
| 4 |
+
Produces a single consolidated.safetensors with:
|
| 5 |
+
- Encoder + adapter + tok_embeddings + norms: BF16 (copied as-is)
|
| 6 |
+
- Decoder linear weights: Marlin-packed INT4 (group_size=128)
|
| 7 |
+
|
| 8 |
+
The decoder linears are RTN-quantized (round-to-nearest, symmetric, per-group)
|
| 9 |
+
and packed directly into Marlin's tiled INT4 format in one step β no intermediate
|
| 10 |
+
GPTQ format, no multiple requantization cycles.
|
| 11 |
+
|
| 12 |
+
Why RTN over GPTQ: GPTQ's Hessian optimization destroys the critical SPAD-to-text
|
| 13 |
+
transition boundary in Voxtral's streaming architecture because calibration runs
|
| 14 |
+
through MistralForCausalLM (without ada_rms_norm_t_cond). RTN preserves it.
|
| 15 |
+
|
| 16 |
+
Marlin pack logic from IST-DASLab/marlin (Apache 2.0):
|
| 17 |
+
https://github.com/IST-DASLab/marlin
|
| 18 |
+
|
| 19 |
+
Usage:
|
| 20 |
+
# From original HuggingFace BF16 model:
|
| 21 |
+
python3 quantize_marlin.py --model-dir path/to/Voxtral-Mini-4B-Realtime-2602
|
| 22 |
+
|
| 23 |
+
# Output (default: ./output/consolidated.safetensors):
|
| 24 |
+
python3 quantize_marlin.py --model-dir path/to/model --output-dir ./my-output
|
| 25 |
+
|
| 26 |
+
Requires: torch, numpy, safetensors
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import argparse
|
| 30 |
+
import gc
|
| 31 |
+
import json
|
| 32 |
+
import os
|
| 33 |
+
import shutil
|
| 34 |
+
import sys
|
| 35 |
+
import time
|
| 36 |
+
|
| 37 |
+
import numpy as np
|
| 38 |
+
import torch
|
| 39 |
+
from safetensors import safe_open
|
| 40 |
+
from safetensors.torch import save_file
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# βββ Model constants βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 44 |
+
|
| 45 |
+
N_LAYERS = 26
|
| 46 |
+
N_HEADS = 32
|
| 47 |
+
N_KV_HEADS = 8
|
| 48 |
+
DIM = 3072
|
| 49 |
+
HEAD_DIM = 128
|
| 50 |
+
|
| 51 |
+
# βββ Quantization constants ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 52 |
+
|
| 53 |
+
BITS = 4
|
| 54 |
+
GROUP_SIZE = 128
|
| 55 |
+
PACK_FACTOR = 32 // BITS # 8 int4 values per int32
|
| 56 |
+
BIAS = 1 << (BITS - 1) # 8 (uint4b8 encoding: stored = value + 8)
|
| 57 |
+
MAXQ = (1 << BITS) - 1 # 15

# ─── Mistral → HF naming for decoder linears ─────────────────────────────────

DECODER_LINEARS = {
    "attention.wq": ("self_attn.q_proj", True, N_HEADS),     # needs Q/K permute
    "attention.wk": ("self_attn.k_proj", True, N_KV_HEADS),  # needs Q/K permute
    "attention.wv": ("self_attn.v_proj", False, None),
    "attention.wo": ("self_attn.o_proj", False, None),
    "feed_forward.w1": ("mlp.gate_proj", False, None),
    "feed_forward.w2": ("mlp.down_proj", False, None),
    "feed_forward.w3": ("mlp.up_proj", False, None),
}
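For illustration, the renaming this table drives in Pass 2 of `main`: each Mistral-format weight key maps to an HF-style prefix that carries two Marlin tensors, `.B` (packed weights) and `.s` (scales). A standalone sketch of the name mapping:

```python
mapping = {
    "attention.wq": "self_attn.q_proj",
    "attention.wk": "self_attn.k_proj",
    "attention.wv": "self_attn.v_proj",
    "attention.wo": "self_attn.o_proj",
    "feed_forward.w1": "mlp.gate_proj",
    "feed_forward.w2": "mlp.down_proj",
    "feed_forward.w3": "mlp.up_proj",
}

layer_idx, mistral_name = 0, "attention.wq"
src_key = f"layers.{layer_idx}.{mistral_name}.weight"        # tensor read from BF16 file
out_prefix = f"layers.{layer_idx}.{mapping[mistral_name]}"   # prefix written to output

assert src_key == "layers.0.attention.wq.weight"
assert (f"{out_prefix}.B", f"{out_prefix}.s") == (
    "layers.0.self_attn.q_proj.B", "layers.0.self_attn.q_proj.s")
```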


# ─── Marlin permutation tables (from IST-DASLab/marlin, Apache 2.0) ──────────

def _get_perms():
    perm = []
    for i in range(32):
        perm1 = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                2 * (i % 4),
                2 * (i % 4) + 1,
                2 * (i % 4 + 4),
                2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm.extend([p + 256 * j for p in perm1])

    perm = np.array(perm)
    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    perm = perm.reshape((-1, 8))[:, interleave].ravel()
    perm = torch.from_numpy(perm)

    scale_perm = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])

    return perm, scale_perm


_perm, _scale_perm = _get_perms()
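One property worth noting: both tables are pure permutations (bijections), so the Marlin layout only reorders values and loses nothing. A standalone check, duplicating `_get_perms` with numpy only so it runs in isolation:

```python
import numpy as np

def get_perms():
    # Same construction as _get_perms above, minus the torch conversion
    perm = []
    for i in range(32):
        perm1 = []
        col = i // 4
        for block in [0, 1]:
            for row in [2 * (i % 4), 2 * (i % 4) + 1,
                        2 * (i % 4 + 4), 2 * (i % 4 + 4) + 1]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm.extend([p + 256 * j for p in perm1])
    perm = np.array(perm)
    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    perm = perm.reshape((-1, 8))[:, interleave].ravel()
    scale_perm = [i + 8 * j for i in range(8) for j in range(8)]
    return perm, scale_perm

perm, scale_perm = get_perms()
assert sorted(perm.tolist()) == list(range(1024))  # bijection over 4 groups of 16x16 int4 tiles
assert sorted(scale_perm) == list(range(64))       # bijection over 64 scale slots
```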


# ─── Q/K head permutation (Mistral → HF interleaving) ────────────────────────

def permute_qk(w, n_heads, hidden_size):
    """Apply Mistral→HF head-dimension interleaving for Q/K weights."""
    head_dim = w.shape[0] // n_heads
    return (
        w.view(n_heads, head_dim // 2, 2, hidden_size)
        .transpose(1, 2)
        .reshape(n_heads * head_dim, hidden_size)
    )
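This permutation only reorders rows within each head, so it is exactly invertible. A standalone numpy sketch mirroring the torch code above (the `permute_qk_inv` helper for the reverse direction is illustrative, not part of the script):

```python
import numpy as np

def permute_qk(w, n_heads, hidden_size):
    # numpy equivalent of the torch version: split head_dim, swap the inner pair axis
    head_dim = w.shape[0] // n_heads
    return (w.reshape(n_heads, head_dim // 2, 2, hidden_size)
             .transpose(0, 2, 1, 3)
             .reshape(n_heads * head_dim, hidden_size))

def permute_qk_inv(w, n_heads, hidden_size):
    # hypothetical inverse: swap the same two inner axes back
    head_dim = w.shape[0] // n_heads
    return (w.reshape(n_heads, 2, head_dim // 2, hidden_size)
             .transpose(0, 2, 1, 3)
             .reshape(n_heads * head_dim, hidden_size))

w = np.arange(4 * 8 * 3).reshape(4 * 8, 3)     # 4 heads, head_dim=8, hidden=3
out = permute_qk(w, 4, 3)
assert out.shape == w.shape                    # a pure row permutation
assert (permute_qk_inv(out, 4, 3) == w).all()  # exactly invertible
```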


# ─── Single-step RTN quantize + Marlin pack ──────────────────────────────────

def quantize_and_pack_marlin(w_bf16, group_size=GROUP_SIZE):
    """RTN-quantize a BF16 weight and pack into Marlin format in one step.

    Args:
        w_bf16: [N_out, K] BF16/FP16 weight tensor

    Returns:
        B: [K//16, 2*N_out] int32 (Marlin-packed weights)
        s: [K//group_size, N_out] fp16 (Marlin-permuted scales)
    """
    N_out, K = w_bf16.shape
    n_groups = K // group_size
    tile = 16

    # ── Step 1: Compute per-group RTN scales ──
    # Work in [K, N] layout for Marlin packing
    w = w_bf16.t().float().contiguous()  # [K, N]
    w_grouped = w.reshape(n_groups, group_size, N_out)
    max_val = w_grouped.abs().amax(dim=1).clamp(min=1e-10)  # [n_groups, N]
    scales = (max_val / BIAS).half()  # [n_groups, N] → scale = max_abs / 8

    # ── Step 2: Quantize to uint4 ──
    s_expanded = scales.float().unsqueeze(1).expand_as(w_grouped)  # [n_groups, gs, N]
    w_int = torch.round(w_grouped / s_expanded).clamp(-BIAS, BIAS - 1).int()
    w_uint = (w_int + BIAS).clamp(0, MAXQ)  # uint4b8: [-8, 7] → [0, 15]
    w_uint = w_uint.reshape(K, N_out)  # [K, N]

    # ── Step 3: Permute scales for Marlin ──
    s = scales.clone()  # [n_groups, N]
    s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
    s = s.reshape((-1, N_out)).contiguous()

    # ── Step 4: Tile into 16×16 blocks ──
    w_tiled = w_uint.reshape(K // tile, tile, N_out // tile, tile)
    w_tiled = w_tiled.permute(0, 2, 1, 3)
    w_tiled = w_tiled.reshape(K // tile, N_out * tile)

    # ── Step 5: Apply Marlin permutation ──
    res = w_tiled.reshape((-1, _perm.numel()))[:, _perm].reshape(w_tiled.shape)

    # ── Step 6: Pack 8 int4 values into each int32 ──
    q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32)
    res_np = res.cpu().numpy().astype(np.uint32)
    for i in range(8):
        q |= res_np[:, i::8] << (4 * i)
    B = torch.from_numpy(q.astype(np.int32))

    return B, s.half()
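The scale choice in Step 1 (scale = max_abs / 8) bounds the RTN error per element: at most half a step for interior values, and up to one full step for the group maximum when it rounds to +8 and is clamped back to +7. A standalone numpy sketch of that bound:

```python
import numpy as np

rng = np.random.default_rng(0)
group = rng.standard_normal(128).astype(np.float32)  # one quantization group

scale = np.abs(group).max() / 8                  # Step 1: scale = max_abs / BIAS
q = np.clip(np.round(group / scale), -8, 7)      # Step 2: round to signed int4 grid
deq = q * scale                                  # what the kernel reconstructs

assert q.min() >= -8 and q.max() <= 7
# Half a step (scale/2) in general; up to one step where +8 clamps to +7
assert np.abs(deq - group).max() <= scale + 1e-6
```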


# ─── Main ────────────────────────────────────────────────────────────────────

def main():
    parser = argparse.ArgumentParser(
        description="Quantize Voxtral BF16 → single-file Marlin INT4")
    parser.add_argument("--model-dir", required=True,
                        help="Directory with consolidated.safetensors (BF16, Mistral format)")
    parser.add_argument("--output-dir", default="./output",
                        help="Output directory (default: ./output)")
    args = parser.parse_args()

    sf_path = os.path.join(args.model_dir, "consolidated.safetensors")
    if not os.path.exists(sf_path):
        print(f"Error: {sf_path} not found", file=sys.stderr)
        sys.exit(1)

    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, "consolidated.safetensors")

    print(f"Input: {sf_path}")
    print(f"Output: {output_path}")
    print(f"Quantization: RTN {BITS}-bit, group_size={GROUP_SIZE}, uint4b8 Marlin")
    print()

    sf = safe_open(sf_path, framework="pt", device="cpu")
    all_keys = list(sf.keys())
    tensors = {}
    t0 = time.time()

    # ── Pass 1: Copy non-decoder-linear tensors as-is ──
    # These are encoder, adapter, tok_embeddings, norms, ada_rms_norm, final norm
    decoder_linear_keys = set()
    for layer_idx in range(N_LAYERS):
        for mistral_name in DECODER_LINEARS:
            decoder_linear_keys.add(f"layers.{layer_idx}.{mistral_name}.weight")

    n_copied = 0
    for key in all_keys:
        if key in decoder_linear_keys:
            continue
        tensors[key] = sf.get_tensor(key)
        n_copied += 1

    print(f"Copied {n_copied} non-linear tensors (encoder, norms, embeddings, etc.)")

    # ── Pass 2: Quantize decoder linears → Marlin ──
    n_quantized = 0
    for layer_idx in range(N_LAYERS):
        for mistral_name, (hf_name, needs_permute, n_heads) in DECODER_LINEARS.items():
            src_key = f"layers.{layer_idx}.{mistral_name}.weight"
            w = sf.get_tensor(src_key).half()  # bf16 → fp16 for torch ops

            # Apply Q/K head permutation if needed
            if needs_permute:
                w = permute_qk(w, n_heads, DIM)

            # Single-step quantize + Marlin pack
            B, s = quantize_and_pack_marlin(w)
            del w

            out_prefix = f"layers.{layer_idx}.{hf_name}"
            tensors[f"{out_prefix}.B"] = B
            tensors[f"{out_prefix}.s"] = s
            n_quantized += 1

        gc.collect()
        elapsed = time.time() - t0
        print(f"  Layer {layer_idx + 1}/{N_LAYERS} quantized ({elapsed:.1f}s)")

    print(f"\nQuantized {n_quantized} decoder linear weights to Marlin INT4")
    print(f"Total tensors in output: {len(tensors)}")

    # ── Save ──
    print(f"\nSaving to {output_path}...")
    save_file(tensors, output_path)
    file_size = os.path.getsize(output_path)
    print(f"Output: {file_size / (1024**3):.2f} GB ({len(tensors)} tensors)")

    # ── Copy auxiliary files ──
    for aux in ["params.json", "tekken.json"]:
        src = os.path.join(args.model_dir, aux)
        if os.path.exists(src):
            shutil.copy2(src, os.path.join(args.output_dir, aux))
            print(f"Copied {aux}")

    print(f"\nDone in {time.time() - t0:.1f}s")

    # ── Verify tensor names ──
    print("\nSample Marlin tensor names:")
    marlin_keys = sorted(k for k in tensors if k.endswith(".B"))[:5]
    for k in marlin_keys:
        print(f"  {k}: {list(tensors[k].shape)} {tensors[k].dtype}")
        sk = k[:-2] + ".s"
        print(f"  {sk}: {list(tensors[sk].shape)} {tensors[sk].dtype}")


if __name__ == "__main__":
    main()
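As a consistency check on the shapes documented in `quantize_and_pack_marlin`, the expected Marlin tensor shapes can be computed from the model constants alone (a standalone sketch; `marlin_shapes` is an illustrative helper, and the q_proj dimensions follow from N_HEADS * HEAD_DIM = 4096 outputs over DIM = 3072 inputs):

```python
N_HEADS, HEAD_DIM, DIM, GROUP_SIZE = 32, 128, 3072, 128

def marlin_shapes(n_out, k, group_size=GROUP_SIZE):
    # Per the docstring: B is [K//16, 2*N_out] int32, s is [K//group_size, N_out] fp16
    return (k // 16, 2 * n_out), (k // group_size, n_out)

B_shape, s_shape = marlin_shapes(N_HEADS * HEAD_DIM, DIM)  # q_proj weight is [4096, 3072]
assert B_shape == (192, 8192)
assert s_shape == (24, 4096)
# Bit accounting: each int32 in B holds 8 int4 values, covering every weight exactly once
assert B_shape[0] * B_shape[1] * 8 == (N_HEADS * HEAD_DIM) * DIM
```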