tsp-stefano committed on
Commit
22a85ee
·
verified ·
1 Parent(s): 8198ed0

Single-file Marlin INT4: direct RTN quantization, no GPTQ intermediate


Replace 2-file GPTQ+Marlin format with single consolidated.safetensors.

- Single-step BF16 → RTN → Marlin pack (no intermediate GPTQ, scales computed once)
- 4.07 GB single file (was 5.8 GB across two files)
- Python server: remove DequantLinear/MarlinLinear, add PrepackedMarlinLinear
- Add quantize_marlin.py for reproducibility
- Tested on Jetson Orin Nano: identical transcription, 15.2 tok/s

README.md CHANGED
@@ -8,6 +8,7 @@ tags:
8
  - mistral
9
  - int4
10
  - quantized
 
11
  - jetson
12
  - edge
13
  - realtime
@@ -32,35 +33,49 @@ language:
32
 
33
  INT4 quantized [Voxtral Mini 4B Realtime](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602) for edge deployment on NVIDIA Jetson Orin Nano (8 GB).
34
 
35
- **4.1 GB total** — fits in 8 GB unified memory with room for KV cache and runtime.
36
 
37
  ## What's in this repo
38
 
39
  | File | Size | Description |
40
  |------|------|-------------|
41
- | `consolidated.safetensors` | 4.1 GB | INT4 GPTQ-packed weights (encoder fp16 + decoder int4) |
42
  | `params.json` | 1.6 KB | Model architecture config (Mistral native format) |
43
  | `tekken.json` | 15 MB | Mistral tekken tokenizer |
44
- | `scripts/jetson_serve_sdpa.py` | 53 KB | Self-contained inference server (no HF/vLLM deps) |
 
45
  | `kernels/fused_ops.cu` | 8.5 KB | Fused CUDA kernels (JIT compiled, SM87) |
46
 
47
  ## Quantization details
48
 
49
- - **Method**: RTN (Round-To-Nearest) with INT4 packing in GPTQ format
50
- - **Bits**: 4-bit (decoder linear layers), fp16 (audio encoder, norms, embeddings)
51
  - **Group size**: 128
52
- - **Encoding**: uint4b8 (value + 8 bias), compatible with Marlin fused INT4 kernel
53
- - **Why RTN over GPTQ**: GPTQ's Hessian optimization destroys the critical SPAD-to-text transition boundary in Voxtral's streaming architecture. RTN preserves it perfectly. See the [full quantization report](https://huggingface.co/Teaspoon-AI/Voxtral-Mini-4B-INT4-Jetson/blob/main/README.md#why-rtn-not-gptq) below.
54
 
55
  ## Architecture
56
 
57
  | Component | Params | Precision | Size |
58
  |-----------|--------|-----------|------|
59
- | Audio encoder (Whisper-style, 32 layers) | ~600M | fp16 | 1.86 GB |
60
- | Projector (5120 → 3072 → 3072) | ~25M | fp16 | 0.05 GB |
61
- | LM decoder (26 layers, 3072 hidden, GQA 32/8 heads) | ~3B | INT4 | ~2.2 GB |
62
- | ada_rms_norm_t_cond (52 tensors) | ~1M | fp16 | 0.01 GB |
63
- | **Total** | **~3.6B** | | **4.1 GB** |
 
64
 
65
  ## Transcription quality
66
 
@@ -81,7 +96,7 @@ Tested on Fleurs en_us samples — near-perfect output matching the fp16 baseline
81
  No HuggingFace or vLLM dependencies needed. Runs inside the [PyTorch Jetson container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/l4t-pytorch).
82
 
83
  ```bash
84
- pip install safetensors websockets soundfile numpy librosa
85
 
86
  # Test with an audio file
87
  python scripts/jetson_serve_sdpa.py --test audio.wav
@@ -98,21 +113,6 @@ The server exposes `ws://localhost:8000/v1/realtime` for streaming transcription
98
  - Pre-allocated KV cache (eliminates per-token torch.cat)
99
  - Fused CUDA kernels for RMSNorm, RoPE, SiLU·Mul (~500 kernel launches/token → ~80)
100
 
101
- ### Option 2: vLLM serving
102
-
103
- ```bash
104
- pip install -U vllm --extra-index-url https://wheels.vllm.ai/nightly --pre
105
- pip install librosa soxr
106
-
107
- python -m vllm.entrypoints.openai.api_server \
108
- --model /path/to/this/repo \
109
- --tokenizer-mode mistral --config-format mistral --load-format mistral \
110
- --max-model-len 8192 --dtype float16 --enforce-eager \
111
- --gpu-memory-utilization 0.5
112
- ```
113
-
114
- **Note**: Requires vLLM nightly (>=0.15.2dev) for `/v1/realtime` WebSocket support.
115
-
116
  ### WebSocket client example
117
 
118
  ```python
@@ -165,4 +165,5 @@ GPTQ quantization fails on this model at every bit precision (4-bit and 8-bit) w
165
  ## Credits
166
 
167
  - Base model: [Voxtral Mini 4B Realtime](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602) by Mistral AI
 
168
  - Quantization and Jetson optimization by [Teaspoon AI](https://huggingface.co/Teaspoon-AI)
 
8
  - mistral
9
  - int4
10
  - quantized
11
+ - marlin
12
  - jetson
13
  - edge
14
  - realtime
 
33
 
34
  INT4 quantized [Voxtral Mini 4B Realtime](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602) for edge deployment on NVIDIA Jetson Orin Nano (8 GB).
35
 
36
+ **4.1 GB single file** — fits in 8 GB unified memory with room for KV cache and runtime.
37
 
38
  ## What's in this repo
39
 
40
  | File | Size | Description |
41
  |------|------|-------------|
42
+ | `consolidated.safetensors` | 4.1 GB | Marlin-packed INT4 decoder + BF16 encoder/norms/embeddings |
43
  | `params.json` | 1.6 KB | Model architecture config (Mistral native format) |
44
  | `tekken.json` | 15 MB | Mistral tekken tokenizer |
45
+ | `scripts/jetson_serve_sdpa.py` | ~50 KB | Self-contained inference server (no HF/vLLM deps) |
46
+ | `scripts/quantize_marlin.py` | ~6 KB | Quantization script to reproduce this model |
47
  | `kernels/fused_ops.cu` | 8.5 KB | Fused CUDA kernels (JIT compiled, SM87) |
48
 
49
  ## Quantization details
50
 
51
+ - **Method**: RTN (Round-To-Nearest) quantized directly into Marlin-packed format
52
+ - **Bits**: 4-bit (decoder linear layers), BF16 (audio encoder, norms, embeddings)
53
  - **Group size**: 128
54
+ - **Encoding**: uint4b8 (value + 8 bias), Marlin tiled INT4 layout
55
+ - **Single step**: BF16 → RTN quantize → Marlin pack (no intermediate GPTQ format, scales computed once)
56
+ - **Why RTN over GPTQ**: GPTQ's Hessian optimization destroys the critical SPAD-to-text transition boundary in Voxtral's streaming architecture. RTN preserves it perfectly. See [below](#why-rtn-not-gptq).
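As a minimal sketch of the uint4b8 round-trip described above (plain PyTorch, illustrative only — not the actual packing code in this repo):

```python
import torch

GROUP_SIZE = 128
BIAS = 8  # uint4b8: stored nibble = int4 value + 8

def rtn_group(w):
    """RTN-quantize one group to uint4b8; returns (stored nibbles, scale)."""
    scale = w.abs().max().clamp(min=1e-10) / BIAS
    q = torch.round(w / scale).clamp(-BIAS, BIAS - 1)  # int4 range [-8, 7]
    return (q + BIAS).to(torch.uint8), scale           # stored in [0, 15]

def dequant_group(stored, scale):
    return (stored.float() - BIAS) * scale

w = torch.randn(GROUP_SIZE)
stored, scale = rtn_group(w)
w_hat = dequant_group(stored, scale)
# Error is at most one quantization step: the positive extreme rounds to 8
# and clamps to 7, so the bound is `scale`, not `scale / 2`.
assert (w - w_hat).abs().max() <= scale + 1e-6
```

The `scale = max_abs / 8` choice matches the symmetric, per-group scheme in `quantize_marlin.py`.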
57
+
58
+ ### Reproducing the quantization
59
+
60
+ ```bash
61
+ pip install torch safetensors numpy
62
+
63
+ # From the original HuggingFace model:
64
+ python scripts/quantize_marlin.py \
65
+ --model-dir path/to/Voxtral-Mini-4B-Realtime-2602 \
66
+ --output-dir ./output
67
+ ```
68
 
69
  ## Architecture
70
 
71
  | Component | Params | Precision | Size |
72
  |-----------|--------|-----------|------|
73
+ | Audio encoder (Whisper-style, 32 layers) | ~600M | BF16 | 1.86 GB |
74
+ | Projector (5120 → 3072 → 3072) | ~25M | BF16 | 0.05 GB |
75
+ | LM decoder (26 layers, 3072 hidden, GQA 32/8 heads) | ~3B | Marlin INT4 | ~1.58 GB |
76
+ | Token embeddings (131072 Γ— 3072) | ~400M | BF16 | 0.77 GB |
77
+ | ada_rms_norm_t_cond + norms | ~1M | BF16 | 0.01 GB |
78
+ | **Total** | **~4B** | | **4.1 GB** |
79
 
80
  ## Transcription quality
81
 
 
96
  No HuggingFace or vLLM dependencies needed. Runs inside the [PyTorch Jetson container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/l4t-pytorch).
97
 
98
  ```bash
99
+ pip install safetensors websockets soundfile numpy librosa marlin
100
 
101
  # Test with an audio file
102
  python scripts/jetson_serve_sdpa.py --test audio.wav
 
113
  - Pre-allocated KV cache (eliminates per-token torch.cat)
114
  - Fused CUDA kernels for RMSNorm, RoPE, SiLU·Mul (~500 kernel launches/token → ~80)
115
 
116
  ### WebSocket client example
117
 
118
  ```python
 
165
  ## Credits
166
 
167
  - Base model: [Voxtral Mini 4B Realtime](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602) by Mistral AI
168
+ - Marlin INT4 kernel: [IST-DASLab/marlin](https://github.com/IST-DASLab/marlin) (Apache 2.0)
169
  - Quantization and Jetson optimization by [Teaspoon AI](https://huggingface.co/Teaspoon-AI)
consolidated.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:fd25f9d675042c37b0a9b051a5333ef001c129d302461995bdc3e7b321c3b2b6
3
- size 4382321392
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a88ff5bf15ec7ca11fc7b0ff51148721dcca585f7c356baa2576eee785250d44
3
+ size 4367478888
params.json CHANGED
@@ -54,12 +54,12 @@
54
  "ada_rms_norm_t_cond": true,
55
  "ada_rms_norm_t_cond_dim": 32,
56
  "quantization_config": {
57
- "quant_method": "gptq",
58
  "bits": 4,
59
  "group_size": 128,
60
- "desc_act": false,
61
  "sym": true,
62
- "checkpoint_format": "gptq",
63
- "pack_dtype": "int32"
 
64
  }
65
  }
 
54
  "ada_rms_norm_t_cond": true,
55
  "ada_rms_norm_t_cond_dim": 32,
56
  "quantization_config": {
57
+ "quant_method": "rtn",
58
  "bits": 4,
59
  "group_size": 128,
 
60
  "sym": true,
61
+ "checkpoint_format": "marlin",
62
+ "pack_dtype": "int32",
63
+ "encoding": "uint4b8"
64
  }
65
  }
scripts/jetson_serve_sdpa.py CHANGED
@@ -1,7 +1,7 @@
1
  #!/usr/bin/env python3
2
- """Voxtral Mini 4B Realtime β€” Jetson Orin Nano 8GB inference server.
3
 
4
- Loads INT4-packed GPTQ weights from Mistral native format and serves
5
  transcription via WebSocket at ws://localhost:8000/v1/realtime.
6
 
7
  Key architecture detail: at each decoder position, the input embedding is
@@ -55,6 +55,7 @@ try:
55
  HAS_MARLIN = True
56
  except ImportError:
57
  HAS_MARLIN = False
 
58
 
59
  # Try to JIT-compile fused CUDA kernels (collapses ~500 kernel launches/token to ~80)
60
  HAS_FUSED = False
@@ -103,54 +104,24 @@ DOWNSAMPLE_FACTOR = 4
103
 
104
  # ─── Marlin Fused INT4 Linear ────────────────────────────────────────────────
105
 
106
- class MarlinLinear(nn.Module):
107
- """Linear layer using Marlin fused INT4 dequant+matmul CUDA kernel.
108
 
109
- Repacks GPTQ INT4 weights into Marlin's optimized format at construction time.
110
- Forward pass is a single fused kernel call — ~50x faster than on-the-fly dequant.
111
- Memory footprint is identical to GPTQ INT4 (no extra memory needed).
112
  """
113
 
114
- def __init__(self, qweight, scales, qzeros, unpermute=None):
115
  super().__init__()
116
- in_features = qweight.shape[0] * PACK_FACTOR
117
- out_features = qweight.shape[1]
118
- n_groups = scales.shape[0]
119
- self.in_features = in_features
120
- self.out_features = out_features
121
-
122
- # Dequantize GPTQ β†’ fp16, then repack into Marlin format
123
- shifts = torch.arange(0, 32, BITS, device=qweight.device, dtype=torch.int32)
124
- unpacked = (qweight.unsqueeze(0) >> shifts.view(-1, 1, 1)) & 0xF
125
- unpacked = unpacked.permute(1, 0, 2).reshape(in_features, out_features)
126
- unpacked = unpacked.T.reshape(out_features, n_groups, GROUP_SIZE)
127
- s = scales.T.float().unsqueeze(-1)
128
- w_fp16 = ((unpacked.float() - BIAS) * s).reshape(out_features, in_features).half()
129
- del unpacked, s
130
-
131
- if unpermute is not None:
132
- n_heads, hidden_size = unpermute
133
- head_dim = w_fp16.shape[0] // n_heads
134
- w_fp16 = (w_fp16.view(n_heads, 2, head_dim // 2, hidden_size)
135
- .transpose(1, 2)
136
- .reshape(out_features, in_features))
137
-
138
- # Create temporary nn.Linear for Marlin's pack()
139
- linear = nn.Linear(in_features, out_features, bias=False,
140
- dtype=torch.half, device=qweight.device)
141
- linear.weight.data = w_fp16
142
-
143
- # Create Marlin layer and pack (handles permutation + bit packing)
144
- ml = _marlin.Layer(in_features, out_features, groupsize=GROUP_SIZE)
145
- ml.pack(linear, scales.T)
146
- del linear, w_fp16
147
-
148
- # Store Marlin buffers
149
- self.register_buffer('B', ml.B.to(qweight.device))
150
- self.register_buffer('s', ml.s.to(qweight.device))
151
  self.register_buffer('workspace',
152
- torch.zeros(out_features // 128 * 16,
153
- dtype=torch.int, device=qweight.device),
154
  persistent=False)
155
 
156
  def forward(self, x):
@@ -161,85 +132,6 @@ class MarlinLinear(nn.Module):
161
  return C
162
 
163
 
164
- # ─── GPTQ INT4 Dequantization (fallback when Marlin unavailable) ────────────
165
-
166
- class DequantLinear(nn.Module):
167
- """Linear layer with INT4 GPTQ packed weights.
168
-
169
- Supports two modes:
170
- - On-the-fly dequantization (default): dequantizes each forward call
171
- - Cached mode: stores pre-dequantized fp16 weight for fast matmul
172
- """
173
-
174
- _shifts = None # class-level cached shifts tensor
175
-
176
- def __init__(self, qweight, scales, qzeros, unpermute=None):
177
- super().__init__()
178
- self.register_buffer('qweight', qweight)
179
- self.register_buffer('scales', scales)
180
- self.register_buffer('qzeros', qzeros)
181
- self.in_features = qweight.shape[0] * PACK_FACTOR
182
- self.out_features = qweight.shape[1]
183
- self.unpermute = unpermute
184
- self._cached_w = None # pre-dequantized fp16 weight [out, in]
185
-
186
- def cache_weight(self, free_int4=True):
187
- """Pre-dequantize and cache the fp16 weight.
188
- If free_int4=True, frees INT4 buffers (saves memory, not reversible).
189
- """
190
- self._cached_w = self._dequantize()
191
- if free_int4:
192
- self.qweight = None
193
- self.scales = None
194
- self.qzeros = None
195
-
196
- def uncache_weight(self):
197
- """Free the cached weight (e.g., before re-loading INT4 weights)."""
198
- self._cached_w = None
199
-
200
- @property
201
- def cached_bytes(self):
202
- """Memory used by cached weight in bytes."""
203
- if self._cached_w is not None:
204
- return self._cached_w.nelement() * self._cached_w.element_size()
205
- return 0
206
-
207
- def _dequantize(self):
208
- """Dequantize INT4 packed weights to fp16 [out, in]."""
209
- qw = self.qweight
210
- in_packed, out = qw.shape
211
- n_groups = self.scales.shape[0]
212
-
213
- # Cached shifts tensor (shared across all instances)
214
- if DequantLinear._shifts is None or DequantLinear._shifts.device != qw.device:
215
- DequantLinear._shifts = torch.arange(0, 32, BITS, device=qw.device, dtype=torch.int32)
216
- shifts = DequantLinear._shifts
217
-
218
- # Vectorized unpack: [8, in/8, out]
219
- unpacked = (qw.unsqueeze(0) >> shifts.view(-1, 1, 1)) & 0xF
220
- # Interleave to [in, out] then transpose+group to [out, groups, GROUP_SIZE]
221
- unpacked = unpacked.permute(1, 0, 2).reshape(self.in_features, out)
222
- unpacked = unpacked.T.reshape(out, n_groups, GROUP_SIZE)
223
- # Scale: (val - 8) * scale
224
- s = self.scales.T.float().unsqueeze(-1)
225
- w = ((unpacked.float() - BIAS) * s).reshape(out, self.in_features).half()
226
- del unpacked, s
227
-
228
- if self.unpermute is not None:
229
- n_heads, hidden_size = self.unpermute
230
- head_dim = w.shape[0] // n_heads
231
- w = (w.view(n_heads, 2, head_dim // 2, hidden_size)
232
- .transpose(1, 2)
233
- .reshape(out, self.in_features))
234
- return w
235
-
236
- def forward(self, x):
237
- if self._cached_w is not None:
238
- return F.linear(x, self._cached_w)
239
- w = self._dequantize()
240
- result = F.linear(x, w)
241
- del w
242
- return result
243
 
244
 
245
  # ─── Building Blocks ─────────────────────────────────────────────────────────
@@ -629,14 +521,13 @@ class VoxtralModel:
629
  return F.linear(h, self.embed.weight)
630
 
631
  def _dql(self, f, prefix, dev, unpermute=None):
632
- qw = f.get_tensor(f'{prefix}.qweight').to(dev)
633
- sc = f.get_tensor(f'{prefix}.scales').to(dev)
634
- qz = f.get_tensor(f'{prefix}.qzeros').to(dev)
635
- in_f = qw.shape[0] * PACK_FACTOR
636
- out_f = qw.shape[1]
637
- if HAS_MARLIN and in_f % 128 == 0 and out_f % 256 == 0:
638
- return MarlinLinear(qw, sc, qz, unpermute=unpermute)
639
- return DequantLinear(qw, sc, qz, unpermute=unpermute)
640
 
641
  def _set(self, module, name, tensor):
642
  """Replace a meta parameter with a real CUDA tensor."""
@@ -761,8 +652,7 @@ class VoxtralModel:
761
  print(f" LM layers {start}-{end-1} loaded")
762
  self._load_section(path, load_dec_batch)
763
 
764
- backend = "Marlin fused INT4" if HAS_MARLIN else "DequantLinear"
765
- print(f" LM decoder loaded ({self.n_layers} layers, {backend})")
766
  gc.collect()
767
  torch.cuda.empty_cache()
768
  mem = torch.cuda.memory_allocated() / 1024**3
@@ -785,128 +675,6 @@ class VoxtralModel:
785
  self.tokenizer = None
786
  print(" WARNING: mistral_common not available, using fallback decoder")
787
 
788
- def _pre_dequant(self):
789
- """Offload encoder to CPU and pre-dequantize decoder weights into GPU cache.
790
-
791
- After encoding is done, the encoder (~1.86 GB) is no longer needed on GPU.
792
- Offloading it frees memory for caching pre-dequantized decoder weights,
793
- which eliminates the per-token dequantization overhead.
794
- """
795
- import gc
796
-
797
- if hasattr(self, '_decoder_cached') and self._decoder_cached:
798
- return # already cached
799
-
800
- t0 = time.time()
801
-
802
- # Move encoder + projector to CPU to free GPU memory
803
- self.encoder.cpu()
804
- self.projector.cpu()
805
- gc.collect()
806
- torch.cuda.empty_cache()
807
- self._evict_cache()
808
-
809
- free, _ = torch.cuda.mem_get_info(0)
810
- print(f" After encoder offload: {free/1024**3:.2f} GB free")
811
-
812
- # Budget: leave 500 MB for KV cache + intermediates
813
- budget = free - 500 * 1024 * 1024
814
- used_bytes = 0
815
- cached_count = 0
816
-
817
- for i, dl in enumerate(self.layers):
818
- projs = [dl.attn.q_proj, dl.attn.k_proj, dl.attn.v_proj, dl.attn.o_proj,
819
- dl.gate_proj, dl.up_proj, dl.down_proj]
820
-
821
- # Estimate net memory cost (fp16 weight minus freed INT4 buffers)
822
- layer_fp16 = sum(
823
- p.in_features * p.out_features * 2
824
- for p in projs if isinstance(p, DequantLinear) and p._cached_w is None
825
- )
826
- layer_int4 = sum(
827
- p.qweight.nelement() * 4 + p.scales.nelement() * 2 + p.qzeros.nelement() * 4
828
- for p in projs
829
- if isinstance(p, DequantLinear) and p.qweight is not None
830
- )
831
- net = layer_fp16 - layer_int4 # net increase in memory
832
-
833
- if used_bytes + net > budget:
834
- break
835
-
836
- for p in projs:
837
- if isinstance(p, DequantLinear) and p._cached_w is None and p.qweight is not None:
838
- p.cache_weight(free_int4=True)
839
-
840
- used_bytes += net
841
- cached_count += 1
842
- # Periodic cleanup to keep peak memory low
843
- if cached_count % 5 == 0:
844
- gc.collect()
845
- torch.cuda.empty_cache()
846
-
847
- gc.collect()
848
- torch.cuda.empty_cache()
849
- free2, _ = torch.cuda.mem_get_info(0)
850
- print(f" Pre-dequantized {cached_count}/{self.n_layers} layers in {time.time()-t0:.1f}s, "
851
- f"{free2/1024**3:.2f} GB free")
852
- self._decoder_cached = True
853
-
854
- def _restore_encoder(self):
855
- """Move encoder back to GPU for the next transcription.
856
-
857
- Frees cached decoder weights first to make room, then reloads
858
- INT4 weights for layers that had their buffers freed.
859
- """
860
- import gc
861
-
862
- if not hasattr(self, '_decoder_cached') or not self._decoder_cached:
863
- return
864
-
865
- t0 = time.time()
866
-
867
- # Free cached decoder weights
868
- needs_reload = []
869
- for i, dl in enumerate(self.layers):
870
- for p in [dl.attn.q_proj, dl.attn.k_proj, dl.attn.v_proj, dl.attn.o_proj,
871
- dl.gate_proj, dl.up_proj, dl.down_proj]:
872
- if isinstance(p, DequantLinear):
873
- if p._cached_w is not None and p.qweight is None:
874
- needs_reload.append(i)
875
- p.uncache_weight()
876
-
877
- gc.collect()
878
- torch.cuda.empty_cache()
879
- self._evict_cache()
880
-
881
- # Move encoder + projector back to GPU
882
- self.encoder.to(self.device)
883
- self.projector.to(self.device)
884
-
885
- # Reload INT4 weights for layers that were freed
886
- if needs_reload:
887
- needs_reload = sorted(set(needs_reload))
888
- path = os.path.join(self.model_path, 'consolidated.safetensors')
889
- D = str(self.device)
890
- with safe_open(path, framework='pt', device=D) as f:
891
- for i in needs_reload:
892
- lp = f'layers.{i}'
893
- dl = self.layers[i]
894
- dl.attn.q_proj = self._dql(f, f'{lp}.self_attn.q_proj', D)
895
- dl.attn.k_proj = self._dql(f, f'{lp}.self_attn.k_proj', D)
896
- dl.attn.v_proj = self._dql(f, f'{lp}.self_attn.v_proj', D)
897
- dl.attn.o_proj = self._dql(f, f'{lp}.self_attn.o_proj', D)
898
- dl.gate_proj = self._dql(f, f'{lp}.mlp.gate_proj', D)
899
- dl.up_proj = self._dql(f, f'{lp}.mlp.up_proj', D)
900
- dl.down_proj = self._dql(f, f'{lp}.mlp.down_proj', D)
901
- gc.collect()
902
- torch.cuda.empty_cache()
903
- print(f" Reloaded {len(needs_reload)} decoder layers from disk")
904
-
905
- gc.collect()
906
- torch.cuda.empty_cache()
907
- self._decoder_cached = False
908
- print(f" Encoder restored in {time.time()-t0:.1f}s")
909
-
910
  def decode_tokens(self, ids):
911
  if self.tokenizer is not None:
912
  try:
@@ -950,10 +718,6 @@ class VoxtralModel:
950
  free, _ = torch.cuda.mem_get_info(0)
951
  print(f" CUDA free before inference: {free/1024**3:.2f} GB")
952
 
953
- # Restore encoder to GPU if it was offloaded (only needed without Marlin)
954
- if not HAS_MARLIN:
955
- self._restore_encoder()
956
-
957
  # 0. Pad audio for streaming alignment
958
  audio = self._pad_audio(audio)
959
  print(f" padded audio: {len(audio)} samples ({len(audio)/SAMPLE_RATE:.1f}s)")
@@ -985,11 +749,7 @@ class VoxtralModel:
985
  del enc_ds
986
  print(f" adapter: {adapter.shape}")
987
 
988
- # 5. Offload encoder, pre-dequantize decoder weights (only without Marlin)
989
- if not HAS_MARLIN:
990
- self._pre_dequant()
991
-
992
- # 6. Build prompt: [BOS] + [SPAD] * (n_left_pad + delay_tokens)
993
  prompt_len = 1 + self.n_left_pad + self.delay_tokens
994
  prompt_ids = [TOKEN_BOS] + [TOKEN_STREAMING_PAD] * (self.n_left_pad + self.delay_tokens)
995
 
 
1
  #!/usr/bin/env python3
2
+ """Voxtral Mini 4B Realtime β€” Jetson Orin Nano inference server.
3
 
4
+ Loads Marlin-packed INT4 weights from consolidated.safetensors and serves
5
  transcription via WebSocket at ws://localhost:8000/v1/realtime.
6
 
7
  Key architecture detail: at each decoder position, the input embedding is
 
55
  HAS_MARLIN = True
56
  except ImportError:
57
  HAS_MARLIN = False
58
+ print("WARNING: Marlin not installed. Install with: pip install marlin")
59
 
60
  # Try to JIT-compile fused CUDA kernels (collapses ~500 kernel launches/token to ~80)
61
  HAS_FUSED = False
 
104
 
105
  # ─── Marlin Fused INT4 Linear ────────────────────────────────────────────────
106
 
107
+ class PrepackedMarlinLinear(nn.Module):
108
+ """Linear layer using pre-packed Marlin INT4 weights from safetensors.
109
 
110
+ Loads .B and .s tensors directly — no GPTQ→Marlin conversion needed.
111
+ Used with single-file consolidated.safetensors that already contains
112
+ Marlin-format weights.
113
  """
114
 
115
+ def __init__(self, B, s):
116
  super().__init__()
117
+ # B: [K//16, 2*N] int32, s: [K//groupsize, N] fp16
118
+ self.in_features = B.shape[0] * 16
119
+ self.out_features = B.shape[1] // 2
120
+ self.register_buffer('B', B)
121
+ self.register_buffer('s', s)
 
122
  self.register_buffer('workspace',
123
+ torch.zeros(self.out_features // 128 * 16,
124
+ dtype=torch.int, device=B.device),
125
  persistent=False)
126
 
127
  def forward(self, x):
 
132
  return C
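The buffer-shape arithmetic in `__init__` above can be checked standalone. K and N here are hypothetical layer dimensions, not values taken from the repo:

```python
import torch

K, N, GROUP_SIZE = 3072, 3072, 128  # hypothetical in/out features

# Marlin stores a K -> N 4-bit linear as 16x16 tiles, 8 nibbles per int32:
B = torch.zeros(K // 16, 2 * N, dtype=torch.int32)
s = torch.zeros(K // GROUP_SIZE, N, dtype=torch.half)

in_features = B.shape[0] * 16   # recovers K, as in PrepackedMarlinLinear
out_features = B.shape[1] // 2  # recovers N

assert in_features == K and out_features == N
assert B.numel() * 8 == K * N   # one int4 value per weight element
```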
133
 
134
 
135
 
136
 
137
  # ─── Building Blocks ─────────────────────────────────────────────────────────
 
521
  return F.linear(h, self.embed.weight)
522
 
523
  def _dql(self, f, prefix, dev, unpermute=None):
524
+ B = f.get_tensor(f'{prefix}.B').to(dev)
525
+ s = f.get_tensor(f'{prefix}.s').to(dev)
526
+ if not HAS_MARLIN:
527
+ raise RuntimeError(
528
+ "Marlin INT4 kernel required but not installed. "
529
+ "Install with: pip install marlin")
530
+ return PrepackedMarlinLinear(B, s)
 
531
 
532
  def _set(self, module, name, tensor):
533
  """Replace a meta parameter with a real CUDA tensor."""
 
652
  print(f" LM layers {start}-{end-1} loaded")
653
  self._load_section(path, load_dec_batch)
654
 
655
+ print(f" LM decoder loaded ({self.n_layers} layers, Marlin fused INT4)")
 
656
  gc.collect()
657
  torch.cuda.empty_cache()
658
  mem = torch.cuda.memory_allocated() / 1024**3
 
675
  self.tokenizer = None
676
  print(" WARNING: mistral_common not available, using fallback decoder")
677
 
678
  def decode_tokens(self, ids):
679
  if self.tokenizer is not None:
680
  try:
 
718
  free, _ = torch.cuda.mem_get_info(0)
719
  print(f" CUDA free before inference: {free/1024**3:.2f} GB")
720
 
721
  # 0. Pad audio for streaming alignment
722
  audio = self._pad_audio(audio)
723
  print(f" padded audio: {len(audio)} samples ({len(audio)/SAMPLE_RATE:.1f}s)")
 
749
  del enc_ds
750
  print(f" adapter: {adapter.shape}")
751
 
752
+ # 5. Build prompt: [BOS] + [SPAD] * (n_left_pad + delay_tokens)
 
753
  prompt_len = 1 + self.n_left_pad + self.delay_tokens
754
  prompt_ids = [TOKEN_BOS] + [TOKEN_STREAMING_PAD] * (self.n_left_pad + self.delay_tokens)
755
 
scripts/quantize_marlin.py ADDED
@@ -0,0 +1,266 @@
1
+ #!/usr/bin/env python3
2
+ """Single-step BF16 β†’ Marlin INT4 quantization for Voxtral Realtime 4B.
3
+
4
+ Produces a single consolidated.safetensors with:
5
+ - Encoder + adapter + tok_embeddings + norms: BF16 (copied as-is)
6
+ - Decoder linear weights: Marlin-packed INT4 (group_size=128)
7
+
8
+ The decoder linears are RTN-quantized (round-to-nearest, symmetric, per-group)
9
+ and packed directly into Marlin's tiled INT4 format in one step — no intermediate
10
+ GPTQ format, no multiple requantization cycles.
11
+
12
+ Why RTN over GPTQ: GPTQ's Hessian optimization destroys the critical SPAD-to-text
13
+ transition boundary in Voxtral's streaming architecture because calibration runs
14
+ through MistralForCausalLM (without ada_rms_norm_t_cond). RTN preserves it.
15
+
16
+ Marlin pack logic from IST-DASLab/marlin (Apache 2.0):
17
+ https://github.com/IST-DASLab/marlin
18
+
19
+ Usage:
20
+ # From original HuggingFace BF16 model:
21
+ python3 quantize_marlin.py --model-dir path/to/Voxtral-Mini-4B-Realtime-2602
22
+
23
+ # Output (default: ./output/consolidated.safetensors):
24
+ python3 quantize_marlin.py --model-dir path/to/model --output-dir ./my-output
25
+
26
+ Requires: torch, numpy, safetensors
27
+ """
28
+
29
+ import argparse
30
+ import gc
31
+ import json
32
+ import os
33
+ import shutil
34
+ import sys
35
+ import time
36
+
37
+ import numpy as np
38
+ import torch
39
+ from safetensors import safe_open
40
+ from safetensors.torch import save_file
41
+
42
+
43
+ # ─── Model constants ─────────────────────────────────────────────────────────
44
+
45
+ N_LAYERS = 26
46
+ N_HEADS = 32
47
+ N_KV_HEADS = 8
48
+ DIM = 3072
49
+ HEAD_DIM = 128
50
+
51
+ # ─── Quantization constants ──────────────────────────────────────────────────
52
+
53
+ BITS = 4
54
+ GROUP_SIZE = 128
55
+ PACK_FACTOR = 32 // BITS # 8 int4 values per int32
56
+ BIAS = 1 << (BITS - 1) # 8 (uint4b8 encoding: stored = value + 8)
57
+ MAXQ = (1 << BITS) - 1 # 15
58
+
59
+ # ─── Mistral → HF naming for decoder linears ─────────────────────────────────
60
+
61
+ DECODER_LINEARS = {
62
+ "attention.wq": ("self_attn.q_proj", True, N_HEADS), # needs Q/K permute
63
+ "attention.wk": ("self_attn.k_proj", True, N_KV_HEADS), # needs Q/K permute
64
+ "attention.wv": ("self_attn.v_proj", False, None),
65
+ "attention.wo": ("self_attn.o_proj", False, None),
66
+ "feed_forward.w1": ("mlp.gate_proj", False, None),
67
+ "feed_forward.w2": ("mlp.down_proj", False, None),
68
+ "feed_forward.w3": ("mlp.up_proj", False, None),
69
+ }
70
+
71
+
72
+ # ─── Marlin permutation tables (from IST-DASLab/marlin, Apache 2.0) ─────────
73
+
74
+ def _get_perms():
75
+ perm = []
76
+ for i in range(32):
77
+ perm1 = []
78
+ col = i // 4
79
+ for block in [0, 1]:
80
+ for row in [
81
+ 2 * (i % 4),
82
+ 2 * (i % 4) + 1,
83
+ 2 * (i % 4 + 4),
84
+ 2 * (i % 4 + 4) + 1,
85
+ ]:
86
+ perm1.append(16 * row + col + 8 * block)
87
+ for j in range(4):
88
+ perm.extend([p + 256 * j for p in perm1])
89
+
90
+ perm = np.array(perm)
91
+ interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
92
+ perm = perm.reshape((-1, 8))[:, interleave].ravel()
93
+ perm = torch.from_numpy(perm)
94
+
95
+ scale_perm = []
96
+ for i in range(8):
97
+ scale_perm.extend([i + 8 * j for j in range(8)])
98
+
99
+ return perm, scale_perm
100
+
101
+
102
+ _perm, _scale_perm = _get_perms()
103
+
104
+
105
+ # ─── Q/K head permutation (Mistral → HF interleaving) ────────────────────────
106
+
107
+ def permute_qk(w, n_heads, hidden_size):
108
+ """Apply Mistral→HF head dimension interleaving for Q/K weights."""
109
+ head_dim = w.shape[0] // n_heads
110
+ return (
111
+ w.view(n_heads, head_dim // 2, 2, hidden_size)
112
+ .transpose(1, 2)
113
+ .reshape(n_heads * head_dim, hidden_size)
114
+ )
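A toy check of the interleave (single head, head_dim 4, hidden 2; values are arbitrary and chosen only to make the row reordering visible):

```python
import torch

def permute_qk(w, n_heads, hidden_size):
    """Same Mistral -> HF interleave as above, reproduced for a standalone test."""
    head_dim = w.shape[0] // n_heads
    return (w.view(n_heads, head_dim // 2, 2, hidden_size)
             .transpose(1, 2)
             .reshape(n_heads * head_dim, hidden_size))

w = torch.arange(8.0).reshape(4, 2)      # rows 0..3 of a tiny Q weight
out = permute_qk(w, n_heads=1, hidden_size=2)
# HF layout groups the first halves of each rotary pair first: rows 0,2,1,3
assert out[:, 0].tolist() == [0.0, 4.0, 2.0, 6.0]
```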
115
+
116
+
117
+ # ─── Single-step RTN quantize + Marlin pack ──────────────────────────────────
118
+
119
+ def quantize_and_pack_marlin(w_bf16, group_size=GROUP_SIZE):
120
+ """RTN-quantize a BF16 weight and pack into Marlin format in one step.
121
+
122
+ Args:
123
+ w_bf16: [N_out, K] BF16/FP16 weight tensor
124
+
125
+ Returns:
126
+ B: [K//16, 2*N_out] int32 (Marlin-packed weights)
127
+ s: [K//group_size, N_out] fp16 (Marlin-permuted scales)
128
+ """
129
+ N_out, K = w_bf16.shape
130
+ n_groups = K // group_size
131
+ tile = 16
132
+
133
+ # ── Step 1: Compute per-group RTN scales ──
134
+ # Work in [K, N] layout for Marlin packing
135
+ w = w_bf16.t().float().contiguous() # [K, N]
136
+ w_grouped = w.reshape(n_groups, group_size, N_out)
137
+ max_val = w_grouped.abs().amax(dim=1).clamp(min=1e-10) # [n_groups, N]
138
+ scales = (max_val / BIAS).half() # [n_groups, N] — scale = max_abs / 8
139
+
140
+ # ── Step 2: Quantize to uint4 ──
141
+ s_expanded = scales.float().unsqueeze(1).expand_as(w_grouped) # [n_groups, gs, N]
142
+ w_int = torch.round(w_grouped / s_expanded).clamp(-BIAS, BIAS - 1).int()
143
+ w_uint = (w_int + BIAS).clamp(0, MAXQ) # uint4b8: [-8,7] → [0,15]
144
+ w_uint = w_uint.reshape(K, N_out) # [K, N]
145
+
146
+ # ── Step 3: Permute scales for Marlin ──
147
+ s = scales.clone() # [n_groups, N]
148
+ s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
149
+ s = s.reshape((-1, N_out)).contiguous()
150
+
151
+ # ── Step 4: Tile into 16Γ—16 blocks ──
152
+ w_tiled = w_uint.reshape(K // tile, tile, N_out // tile, tile)
153
+ w_tiled = w_tiled.permute(0, 2, 1, 3)
154
+ w_tiled = w_tiled.reshape(K // tile, N_out * tile)
155
+
156
+ # ── Step 5: Apply Marlin permutation ──
157
+ res = w_tiled.reshape((-1, _perm.numel()))[:, _perm].reshape(w_tiled.shape)
158
+
159
+ # ── Step 6: Pack 8 int4 values into each int32 ──
160
+ q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32)
161
+ res_np = res.cpu().numpy().astype(np.uint32)
162
+ for i in range(8):
163
+ q |= res_np[:, i::8] << (4 * i)
164
+ B = torch.from_numpy(q.astype(np.int32))
165
+
166
+ return B, s.half()
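The Step 6 packing loop can be sanity-checked by inverting it. This is standalone NumPy with random data standing in for the permuted tiles:

```python
import numpy as np

rng = np.random.default_rng(0)
res = rng.integers(0, 16, size=(4, 64), dtype=np.uint32)  # stand-in tiles

# Pack: nibble i of each int32 takes column slice i::8 (same loop as Step 6)
q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32)
for i in range(8):
    q |= res[:, i::8] << (4 * i)

# Unpack by reversing the shifts
out = np.zeros_like(res)
for i in range(8):
    out[:, i::8] = (q >> (4 * i)) & 0xF

assert (out == res).all()
```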
167
+
168
+
169
+ # ─── Main ────────────────────────────────────────────────────────────────────
170
+
171
+ def main():
172
+ parser = argparse.ArgumentParser(
173
+ description="Quantize Voxtral BF16 → single-file Marlin INT4")
174
+ parser.add_argument("--model-dir", required=True,
175
+ help="Directory with consolidated.safetensors (BF16, Mistral format)")
176
+ parser.add_argument("--output-dir", default="./output",
177
+ help="Output directory (default: ./output)")
178
+ args = parser.parse_args()
179
+
180
+ sf_path = os.path.join(args.model_dir, "consolidated.safetensors")
181
+ if not os.path.exists(sf_path):
182
+ print(f"Error: {sf_path} not found", file=sys.stderr)
183
+ sys.exit(1)
184
+
185
+ os.makedirs(args.output_dir, exist_ok=True)
186
+ output_path = os.path.join(args.output_dir, "consolidated.safetensors")
187
+
188
+ print(f"Input: {sf_path}")
189
+ print(f"Output: {output_path}")
190
+ print(f"Quantization: RTN {BITS}-bit, group_size={GROUP_SIZE}, uint4b8 Marlin")
191
+ print()
192
+
193
+ sf = safe_open(sf_path, framework="pt", device="cpu")
194
+ all_keys = list(sf.keys())
195
+ tensors = {}
196
+ t0 = time.time()
197
+
198
+ # ── Pass 1: Copy non-decoder-linear tensors as-is ──
199
+ # These are encoder, adapter, tok_embeddings, norms, ada_rms_norm, final norm
200
+ decoder_linear_keys = set()
201
+ for layer_idx in range(N_LAYERS):
202
+ for mistral_name in DECODER_LINEARS:
203
+ decoder_linear_keys.add(f"layers.{layer_idx}.{mistral_name}.weight")
204
+
205
+ n_copied = 0
206
+ for key in all_keys:
207
+ if key in decoder_linear_keys:
208
+ continue
209
+ tensors[key] = sf.get_tensor(key)
210
+ n_copied += 1
211
+
212
+ print(f"Copied {n_copied} non-linear tensors (encoder, norms, embeddings, etc.)")
213
+
214
+ # ── Pass 2: Quantize decoder linears → Marlin ──
215
+ n_quantized = 0
216
+ for layer_idx in range(N_LAYERS):
217
+ for mistral_name, (hf_name, needs_permute, n_heads) in DECODER_LINEARS.items():
218
+ src_key = f"layers.{layer_idx}.{mistral_name}.weight"
219
+ w = sf.get_tensor(src_key).half() # bf16 → fp16 for torch ops
220
+
221
+ # Apply Q/K head permutation if needed
222
+ if needs_permute:
223
+ w = permute_qk(w, n_heads, DIM)
224
+
225
+ # Single-step quantize + Marlin pack
226
+ B, s = quantize_and_pack_marlin(w)
227
+ del w
228
+
229
+ out_prefix = f"layers.{layer_idx}.{hf_name}"
230
+ tensors[f"{out_prefix}.B"] = B
231
+ tensors[f"{out_prefix}.s"] = s
232
+ n_quantized += 1
233
+
234
+ gc.collect()
235
+ elapsed = time.time() - t0
236
+ print(f" Layer {layer_idx + 1}/{N_LAYERS} quantized ({elapsed:.1f}s)")
237
+
238
+ print(f"\nQuantized {n_quantized} decoder linear weights to Marlin INT4")
239
+ print(f"Total tensors in output: {len(tensors)}")
240
+
241
+ # ── Save ──
242
+ print(f"\nSaving to {output_path}...")
243
+ save_file(tensors, output_path)
244
+ file_size = os.path.getsize(output_path)
245
+ print(f"Output: {file_size / (1024**3):.2f} GB ({len(tensors)} tensors)")
246
+
247
+ # ── Copy auxiliary files ──
248
+ for aux in ["params.json", "tekken.json"]:
249
+ src = os.path.join(args.model_dir, aux)
250
+ if os.path.exists(src):
251
+ shutil.copy2(src, os.path.join(args.output_dir, aux))
252
+ print(f"Copied {aux}")
253
+
254
+ print(f"\nDone in {time.time() - t0:.1f}s")
255
+
256
+ # ── Verify tensor names ──
257
+ print(f"\nSample Marlin tensor names:")
258
+ marlin_keys = sorted(k for k in tensors if k.endswith(".B"))[:5]
259
+ for k in marlin_keys:
260
+ print(f" {k}: {list(tensors[k].shape)} {tensors[k].dtype}")
261
+ sk = k[:-2] + ".s"
262
+ print(f" {sk}: {list(tensors[sk].shape)} {tensors[sk].dtype}")
263
+
264
+
265
+ if __name__ == "__main__":
266
+ main()