---
license: apache-2.0
base_model: mistralai/Voxtral-Mini-4B-Realtime-2602
base_model_relation: quantized
tags:
- speech-to-text
- voxtral
- mistral
- int4
- quantized
- marlin
- jetson
- edge
- realtime
- streaming
language:
- en
- fr
- es
- de
- ru
- zh
- ja
- it
- pt
- nl
- ar
- hi
- ko
---

# Voxtral Mini 4B INT4 — Jetson Orin Nano

INT4 quantized [Voxtral Mini 4B Realtime](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602) for edge deployment on NVIDIA Jetson Orin Nano (8 GB). **4.4 GB** — fits in 8 GB unified memory with room for KV cache and runtime.

## What's in this repo

| File | Size | Description |
|------|------|-------------|
| `consolidated.safetensors` | 4.4 GB | Marlin-packed INT4 decoder + BF16 encoder/norms/embeddings |
| `params.json` | 1.6 KB | Model architecture config (Mistral native format) |
| `tekken.json` | 15 MB | Mistral tekken tokenizer |
| `requirements.txt` | — | Pinned Python dependencies for Jetson |
| `scripts/jetson_serve_sdpa.py` | ~50 KB | Self-contained inference server (no HF/vLLM deps) |
| `scripts/quantize_marlin.py` | ~10 KB | Quantization script to reproduce this model |
| `kernels/fused_ops.cu` | 8.5 KB | Fused CUDA kernels (JIT compiled, SM87) |

## Quantization details

- **Method**: RTN (Round-To-Nearest) quantized directly into Marlin-packed format
- **Bits**: 4-bit (decoder linear layers), BF16 (audio encoder, norms, embeddings)
- **Group size**: 128
- **Encoding**: uint4b8 (value + 8 bias), Marlin tiled INT4 layout
- **Why RTN over GPTQ**: GPTQ's Hessian optimization destroys the critical SPAD-to-text transition boundary in Voxtral's streaming architecture. RTN preserves it perfectly. See [below](#why-rtn-not-gptq).

### Reproducing the quantization

```bash
pip install torch safetensors numpy

# From the original HuggingFace model:
python scripts/quantize_marlin.py \
    --model-dir path/to/Voxtral-Mini-4B-Realtime-2602 \
    --output-dir ./output
```

## Architecture

| Component | Params | Precision | Size |
|-----------|--------|-----------|------|
| Audio encoder (Whisper-style, 32 layers) | ~600M | BF16 | 1.86 GB |
| Projector (5120 → 3072 → 3072) | ~25M | BF16 | 0.05 GB |
| LM decoder (26 layers, 3072 hidden, GQA 32/8 heads) | ~3B | Marlin INT4 | ~1.70 GB |
| Token embeddings (131072 × 3072) | ~400M | BF16 | 0.77 GB |
| ada_rms_norm_t_cond + norms | ~1M | BF16 | 0.01 GB |
| **Total** | **~4B** | | **4.4 GB** |

## Transcription quality

Tested on Fleurs en_us samples — near-perfect output matching the fp16 baseline:

| Sample | Quality | Notes |
|--------|---------|-------|
| 0 — communication channels | Excellent | Punctuation added, matches reference |
| 1 — capital letters | Good | "sie" → "say" (phonetic) |
| 2 — town of Sintra | Excellent | Full match |
| 3 — cabbage juice | Excellent | Full match |
| 4 — dinosaurs with feathers | Perfect | Exact match |

## Usage

### Self-contained server (recommended for Jetson)

No HuggingFace or vLLM dependencies needed. Tested on JetPack 6.x (R36.5.0), Python 3.10, CUDA 12.6.

```bash
pip install -r requirements.txt

# Test with an audio file
python scripts/jetson_serve_sdpa.py --test audio.wav

# Start WebSocket server on port 8000
python scripts/jetson_serve_sdpa.py
```

The server exposes `ws://localhost:8000/v1/realtime` for streaming transcription.
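Once the server is up, a quick way to confirm the endpoint responds (a minimal sketch, assuming the `session.created` event shown in the client example below) is:

```python
# Minimal connectivity check for the realtime endpoint (illustrative sketch).
import asyncio, json, websockets

async def check():
    async with websockets.connect("ws://localhost:8000/v1/realtime") as ws:
        msg = json.loads(await ws.recv())       # first event should be session.created
        print("server event:", msg.get("type"))

asyncio.run(check())
```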
**Key optimizations in the server:**

- Marlin fused INT4 dequant+matmul (~50x faster than on-the-fly dequant)
- F.scaled_dot_product_attention (fused attention kernel)
- Pre-allocated KV cache (eliminates per-token torch.cat)
- Fused CUDA kernels for RMSNorm, RoPE, SiLU·Mul (~500 kernel launches/token → ~80)

### WebSocket client example

```python
import asyncio, base64, json, numpy as np, soundfile as sf, websockets

async def transcribe(audio_path):
    audio, sr = sf.read(audio_path, dtype="float32")
    pcm16 = (audio * 32768.0).clip(-32768, 32767).astype(np.int16)

    async with websockets.connect("ws://localhost:8000/v1/realtime") as ws:
        await ws.recv()  # session.created
        await ws.send(json.dumps({"type": "session.update"}))

        # Send audio in 500ms chunks
        for i in range(0, len(pcm16), 8000):
            chunk = base64.b64encode(pcm16[i:i+8000].tobytes()).decode()
            await ws.send(json.dumps({"type": "input_audio_buffer.append", "audio": chunk}))
        await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))

        text = ""
        while True:
            msg = json.loads(await asyncio.wait_for(ws.recv(), timeout=60))
            if msg["type"] == "transcription.delta":
                text += msg["delta"]
            elif msg["type"] == "transcription.done":
                break
        return text
```

## Memory budget (Jetson Orin Nano 8 GB)

| Component | Size |
|-----------|------|
| Model weights | 4.4 GB |
| Runtime + KV cache | ~1.5 GB |
| OS + system | ~2 GB |
| **Total** | **~7.9 GB** |

## Why RTN, not GPTQ?

GPTQ quantization fails on this model at every bit precision (4-bit and 8-bit) with every calibration strategy tested. The root cause:

1. **Architecture mismatch during calibration**: GPTQ processes layers through the standard `MistralDecoderLayer`, which lacks `ada_rms_norm_t_cond`. The MLP therefore sees wrong activations during Hessian estimation.
2. **Critical decision boundary**: Voxtral's streaming protocol requires the model to transition from STREAMING_PAD tokens to text tokens at precise positions. This transition margin is only ~5-10 logit points. GPTQ's optimization noise is enough to prevent the transition entirely.
3. **RTN preserves the boundary**: Simple round-to-nearest quantization at 4-bit with group_size=128 preserves the SPAD→text transition perfectly, producing output identical to the fp16 baseline. A minimal sketch of this scheme is included in the appendix at the end of this card.

## Credits

- Base model: [Voxtral Mini 4B Realtime](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602) by Mistral AI
- Marlin INT4 kernel: [IST-DASLab/marlin](https://github.com/IST-DASLab/marlin) (Apache 2.0)
- Quantization and Jetson optimization by [Teaspoon AI](https://huggingface.co/Teaspoon-AI)
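## Appendix: RTN group quantization sketch

For illustration, a minimal sketch of the per-group round-to-nearest uint4b8 scheme described under [Quantization details](#quantization-details). This is not the repo's `scripts/quantize_marlin.py` (which additionally packs the result into the Marlin tiled layout); the function name and return values here are illustrative only.

```python
import torch

def rtn_uint4b8(w: torch.Tensor, group_size: int = 128):
    """Round-to-nearest 4-bit quantization with one scale per group of inputs.

    Stored codes are 0..15 (signed value + 8 bias, i.e. uint4b8); a kernel
    reconstructs (q - 8) * scale at matmul time.
    """
    out_features, in_features = w.shape
    assert in_features % group_size == 0, "in_features must be divisible by group_size"
    wg = w.float().reshape(out_features, in_features // group_size, group_size)

    # Symmetric per-group scale: the largest magnitude in each group maps to +/-7.
    scale = wg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0
    q = torch.clamp(torch.round(wg / scale) + 8, 0, 15).to(torch.uint8)

    # Dequantized reference, handy for inspecting per-layer quantization error.
    deq = (q.float() - 8) * scale
    return (q.reshape(out_features, in_features),
            scale.squeeze(-1),
            deq.reshape(out_features, in_features))
```

Each group of 128 weights shares one scale, which is what keeps the 4-bit rounding error small enough for the decision-boundary argument in the section above.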