---
license: apache-2.0
base_model: mistralai/Voxtral-Mini-4B-Realtime-2602
base_model_relation: quantized
tags:
- speech-to-text
- voxtral
- mistral
- int4
- quantized
- marlin
- jetson
- edge
- realtime
- streaming
language:
- en
- fr
- es
- de
- ru
- zh
- ja
- it
- pt
- nl
- ar
- hi
- ko
---
# Voxtral Mini 4B INT4 — Jetson Orin Nano
INT4 quantized [Voxtral Mini 4B Realtime](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602) for edge deployment on NVIDIA Jetson Orin Nano (8 GB).
**4.4 GB** — fits in 8 GB unified memory with room for KV cache and runtime.
## What's in this repo
| File | Size | Description |
|------|------|-------------|
| `consolidated.safetensors` | 4.4 GB | Marlin-packed INT4 decoder + BF16 encoder/norms/embeddings |
| `params.json` | 1.6 KB | Model architecture config (Mistral native format) |
| `tekken.json` | 15 MB | Mistral tekken tokenizer |
| `requirements.txt` | — | Pinned Python dependencies for Jetson |
| `scripts/jetson_serve_sdpa.py` | ~50 KB | Self-contained inference server (no HF/vLLM deps) |
| `scripts/quantize_marlin.py` | ~10 KB | Quantization script to reproduce this model |
| `kernels/fused_ops.cu` | 8.5 KB | Fused CUDA kernels (JIT compiled, SM87) |
## Quantization details
- **Method**: RTN (Round-To-Nearest) quantized directly into Marlin-packed format
- **Bits**: 4-bit (decoder linear layers), BF16 (audio encoder, norms, embeddings)
- **Group size**: 128
- **Encoding**: uint4b8 (value + 8 bias), Marlin tiled INT4 layout
- **Why RTN over GPTQ**: GPTQ's Hessian-based optimization disrupts the STREAMING_PAD (SPAD)-to-text transition that Voxtral's streaming architecture depends on, while RTN preserves it. See [below](#why-rtn-not-gptq).
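
The settings above can be illustrated with a minimal NumPy sketch of group-wise RTN using the uint4b8 encoding (signed value plus a bias of 8). This is not the code in `scripts/quantize_marlin.py`: the function names are hypothetical, the symmetric `absmax / 7` scale is an assumption, and the Marlin tile packing step is left out entirely.

```python
import numpy as np


def rtn_quantize_uint4b8(weight, group_size=128):
    """Round-to-nearest 4-bit quantization with one scale per group of inputs."""
    out_features, in_features = weight.shape
    assert in_features % group_size == 0, "in_features must be a multiple of group_size"
    groups = weight.reshape(out_features, in_features // group_size, group_size)

    # Assumption: symmetric scale so the largest magnitude in a group maps to +/-7.
    absmax = np.abs(groups).max(axis=-1, keepdims=True)
    scale = np.where(absmax > 0, absmax / 7.0, 1.0)

    # Round to the nearest signed 4-bit level, then add the bias of 8 (uint4b8).
    q = np.clip(np.rint(groups / scale), -8, 7)
    stored = (q + 8).astype(np.uint8)  # values in [0, 15]
    return stored.reshape(out_features, in_features), scale.squeeze(-1).astype(np.float32)


def dequantize_uint4b8(stored, scale, group_size=128):
    """Invert the encoding: subtract the bias of 8, then multiply by the group scale."""
    out_features, in_features = stored.shape
    groups = stored.reshape(out_features, -1, group_size).astype(np.float32)
    return ((groups - 8.0) * scale[..., None]).reshape(out_features, in_features)
```

With this scale choice no value is clipped, so each reconstructed weight differs from the original by at most half a quantization step for its group.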
### Reproducing the quantization
```bash
pip install torch safetensors numpy

# From the original HuggingFace model:
python scripts/quantize_marlin.py \
    --model-dir path/to/Voxtral-Mini-4B-Realtime-2602 \
    --output-dir ./output
```
## Architecture
| Component | Params | Precision | Size |
|-----------|--------|-----------|------|
| Audio encoder (Whisper-style, 32 layers) | ~600M | BF16 | 1.86 GB |
| Projector (5120 → 3072 → 3072) | ~25M | BF16 | 0.05 GB |
| LM decoder (26 layers, 3072 hidden, GQA 32/8 heads) | ~3B | Marlin INT4 | ~1.70 GB |
| Token embeddings (131072 × 3072) | ~400M | BF16 | 0.77 GB |
| ada_rms_norm_t_cond + norms | ~1M | BF16 | 0.01 GB |
| **Total** | **~4B** | | **4.4 GB** |
## Transcription quality
Tested on five FLEURS en_us samples. Output closely matches the fp16 baseline:
| Sample | Quality | Notes |
|--------|---------|-------|
| 0 — communication channels | Excellent | Punctuation added, matches reference |
| 1 — capital letters | Good | "sie" → "say" (phonetic) |
| 2 — town of Sintra | Excellent | Full match |
| 3 — cabbage juice | Excellent | Full match |
| 4 — dinosaurs with feathers | Perfect | Exact match |
## Usage
### Self-contained server (recommended for Jetson)
No HuggingFace or vLLM dependencies needed. Tested on JetPack 6.x (R36.5.0), Python 3.10, CUDA 12.6.
```bash
pip install -r requirements.txt
# Test with an audio file
python scripts/jetson_serve_sdpa.py --test audio.wav
# Start WebSocket server on port 8000
python scripts/jetson_serve_sdpa.py
```
The server exposes `ws://localhost:8000/v1/realtime` for streaming transcription.
**Key optimizations in the server:**
- Marlin fused INT4 dequant+matmul (~50x faster than on-the-fly dequant)
- F.scaled_dot_product_attention (fused attention kernel)
- Pre-allocated KV cache (eliminates per-token torch.cat)
- Fused CUDA kernels for RMSNorm, RoPE, SiLU·Mul (~500 kernel launches/token → ~80)
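
To illustrate the pre-allocated KV cache item, here is a minimal sketch of the pattern. It is not taken from `scripts/jetson_serve_sdpa.py`: the class name and shapes are hypothetical, and NumPy arrays stand in for the torch tensors a real server would use, purely to keep the example dependency-light.

```python
import numpy as np


class PreallocatedKVCache:
    """Fixed-size key/value buffers for one layer, filled in place as tokens arrive."""

    def __init__(self, n_kv_heads, max_len, head_dim, dtype=np.float16):
        # Allocate once up front instead of growing by concatenation on every token.
        self.k = np.zeros((n_kv_heads, max_len, head_dim), dtype=dtype)
        self.v = np.zeros((n_kv_heads, max_len, head_dim), dtype=dtype)
        self.length = 0

    def append(self, k_new, v_new):
        """Write new tokens of shape (n_kv_heads, t, head_dim); return views of the valid prefix."""
        end = self.length + k_new.shape[1]
        if end > self.k.shape[1]:
            raise ValueError("KV cache is full")
        self.k[:, self.length:end] = k_new
        self.v[:, self.length:end] = v_new
        self.length = end
        # Slices are views into the buffers, so no copy or reallocation happens here.
        return self.k[:, :end], self.v[:, :end]
```

Each decode step writes one slot and reads a view of the filled prefix, so memory use stays constant and the allocation cost is paid once at startup.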
### WebSocket client example
```python
import asyncio
import base64
import json

import numpy as np
import soundfile as sf
import websockets


async def transcribe(audio_path):
    # The 500 ms chunking below implies 16 kHz input; this example also assumes mono audio.
    audio, sr = sf.read(audio_path, dtype="float32")
    # Convert float samples in [-1, 1] to 16-bit PCM
    pcm16 = (audio * 32768.0).clip(-32768, 32767).astype(np.int16)

    async with websockets.connect("ws://localhost:8000/v1/realtime") as ws:
        await ws.recv()  # session.created
        await ws.send(json.dumps({"type": "session.update"}))

        # Send audio in 500 ms chunks (8000 samples at 16 kHz)
        for i in range(0, len(pcm16), 8000):
            chunk = base64.b64encode(pcm16[i:i+8000].tobytes()).decode()
            await ws.send(json.dumps({"type": "input_audio_buffer.append", "audio": chunk}))
        await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))

        # Accumulate transcript deltas until the server signals completion
        text = ""
        while True:
            msg = json.loads(await asyncio.wait_for(ws.recv(), timeout=60))
            if msg["type"] == "transcription.delta":
                text += msg["delta"]
            elif msg["type"] == "transcription.done":
                break
    return text


if __name__ == "__main__":
    print(asyncio.run(transcribe("audio.wav")))
```
## Memory budget (Jetson Orin Nano 8 GB)
| Component | Size |
|-----------|------|
| Model weights | 4.4 GB |
| Runtime + KV cache | ~1.5 GB |
| OS + system | ~2 GB |
| **Total** | **~7.9 GB** |
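
As a rough cross-check of this budget, the sketch below sums the table and estimates KV cache growth per token. The layer count (26) and KV head count (8) come from the architecture table. The head dimension of 128 and the 2-byte cache precision are assumptions not stated in this card, so treat the per-token figure as illustrative only.

```python
def kv_cache_bytes(seq_len, n_layers=26, n_kv_heads=8, head_dim=128, bytes_per_elem=2):
    """Bytes needed for keys plus values across all decoder layers.

    head_dim=128 and bytes_per_elem=2 are assumed values, not taken from params.json.
    """
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem * seq_len


# Figures copied from the memory budget table, in GB.
budget_gb = {"weights": 4.4, "runtime_and_kv": 1.5, "os_and_system": 2.0}
total = sum(budget_gb.values())

print(f"total = {total:.1f} GB, headroom = {8.0 - total:.1f} GB")
print(f"KV cache per token = {kv_cache_bytes(1) / 1024:.0f} KiB")
```

Under these assumptions the cache grows by about 104 KiB per token, which gives a feel for how much context fits inside the ~1.5 GB runtime allowance.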
## Why RTN, not GPTQ?
GPTQ quantization failed on this model at both bit widths tested (4-bit and 8-bit) and with every calibration strategy tested. Two factors explain the failure, and the third point explains why RTN avoids it:
1. **Architecture mismatch during calibration**: GPTQ processes layers through standard `MistralDecoderLayer` which lacks `ada_rms_norm_t_cond`. The MLP sees wrong activations during Hessian estimation.
2. **Critical decision boundary**: Voxtral's streaming protocol requires the model to transition from STREAMING_PAD tokens to text tokens at precise positions. This transition margin is only ~5-10 logit points. GPTQ's optimization noise is enough to prevent the transition entirely.
3. **RTN preserves the boundary**: Simple round-to-nearest quantization at 4-bit with group_size=128 preserves the SPAD→text transition perfectly, producing output identical to the fp16 baseline.
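
To make the decision-boundary argument concrete, the sketch below shows one plausible way to measure such a margin: the best text-token logit minus the STREAMING_PAD logit at a given position. The function name, the token index used for STREAMING_PAD, and the logit values are all made up for illustration and do not come from the model or the quantization script.

```python
import numpy as np


def transition_margin(logits, pad_id):
    """Best text-token logit minus the STREAMING_PAD logit at one position.

    A positive margin means greedy decoding emits text; a negative one keeps emitting the pad token.
    """
    text_logits = np.delete(logits, pad_id)
    return float(text_logits.max() - logits[pad_id])


pad_id = 2  # hypothetical index for STREAMING_PAD in this toy vocabulary
baseline = np.array([1.0, 7.0, 2.0])   # text token 1 leads the pad token by 5
perturbed = np.array([1.0, 4.0, 5.0])  # quantization noise of 3 logits on each side

print(transition_margin(baseline, pad_id))   # 5.0
print(transition_margin(perturbed, pad_id))  # -1.0
```

In this toy case a shift of a few logits on each side is enough to flip the sign of the margin, which is the failure mode described in point 2.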
## Credits
- Base model: [Voxtral Mini 4B Realtime](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602) by Mistral AI
- Marlin INT4 kernel: [IST-DASLab/marlin](https://github.com/IST-DASLab/marlin) (Apache 2.0)
- Quantization and Jetson optimization by [Teaspoon AI](https://huggingface.co/Teaspoon-AI)