chore: merge master → dev/video-fer (SSE transcribe-stream)
Changes from master:
- Switch to evoxtral-rl model
- RL training pipeline, technical report (docs only, PDF excluded)
- Updated README and model card
From dev/video-fer:
- FER LFS pointer detection and auto-download
- Streaming emotion/valence/arousal from bracket tags
- Inline bracket badges with seek-on-click
- Async job polling in proxy (fallback for local inference)
- README.md +82 -164
- api/main.py +2 -2
- docs/model_card/README.md +94 -11
- docs/research/references.md +84 -0
- docs/technical_report.md +300 -0
- docs/technical_report.tex +383 -0
- space/app.py +1 -1
- training/scripts/rl_modal.py +506 -0
- training/scripts/serve_modal.py +120 -34
- training/scripts/train_modal.py +165 -17
- web/src/app/api/speech-to-text/route.ts +9 -4
- web/src/app/api/transcribe-stream/route.ts +72 -0
- web/src/app/studio/page.tsx +96 -54
README.md
CHANGED
@@ -10,209 +10,127 @@ pinned: false

# Ethos Studio — Emotional Speech Recognition

##

- **Character-level text highlighting** — transcript text sweeps dark→gray in sync with playback position, character by character
- **Click-to-seek** — click any character in the transcript to jump the timeline to that exact moment; uses `caretRangeFromPoint` for precision
- **Inline `[bracket]` badges** — paralinguistic tags produced by Voxtral (e.g. `[laughs]`, `[sighs]`) render as pill badges at their exact inline position, not appended at the end; clicking a badge seeks to the moment just before it
- **Bidirectional timeline ↔ transcript sync** — scrolling/clicking the timeline highlights the active segment in the transcript and auto-scrolls it into view; clicking a segment row seeks the timeline
- **Per-segment state** (`past` / `active` / `future`) with opacity transitions

###

- **Streaming speech emotion** — the Speech emotion badge updates sub-segment as playback passes each `[bracket]` tag; timing is estimated from the tag's character position proportional to segment duration
- **Streaming valence & arousal bars** — both bars transition to the bracket tag's valence/arousal values at the same moment, creating a continuous emotional arc within each segment
- **Per-second face emotion** (video only) — the Face badge updates every second from the `face_emotion_timeline` returned by the FER pipeline, more granular than the per-segment majority vote
- **Live indicator** — animated green dot appears during playback

- **Click-to-seek** on the track area
- **Active segment highlight** with ring indicator
- **Played-region overlay** — subtle tint left of the playhead
- **Dot + line playhead** design

- Video files (`.mp4`, `.mkv`, `.avi`, `.mov`, `.m4v`, `.webm`) display inline in the right panel preview area
- FER runs on video frames and produces both per-segment majority-vote emotion and a per-second `face_emotion_timeline`

```
├── proxy/            # Node.js/Express — API gateway for the frontend
├── web/              # Next.js — Studio editor UI
├── training/         # Fine-tuning code (Voxtral LoRA), data prep, eval
├── docs/             # Specs, model card, hackathon guidelines
├── models/           # ONNX weights (emotion_model_web.onnx — tracked via Git LFS)
├── Dockerfile        # Single-container HF Spaces build
├── nginx.conf        # Reverse proxy config (port 7860 → :3000/:3030)
└── supervisord.conf  # Process manager for all four services
```

```
Python API (:8000)
├─ POST /transcribe-diarize — VAD + Voxtral STT + emotion tags + FER timeline
└─ POST /fer — per-frame FER via MobileViT-XXS ONNX
```

| Directory | Port | Role |
|-----------|------|------|
| `api/` | 8000 | Voxtral local inference; VAD segmentation; per-segment emotion; FER timeline |
| `proxy/` | 3000 | API entrypoint; proxies to `api/` |
| `web/` | 3030 | Next.js Studio UI |

## API response format

`POST /api/transcribe-diarize` returns:

```json
{
  "filename": "interview.mp4",
  "duration": 42.5,
  "text": "Full transcript...",
  "segments": [
    {
      "id": 1,
      "speaker": "SPEAKER_00",
      "start": 0.0,
      "end": 5.2,
      "text": "Welcome to the show. [laughs]",
      "emotion": "Happy",
      "valence": 0.7,
      "arousal": 0.6,
      "face_emotion": "Happy"
    }
  ],
  "has_video": true,
  "face_emotion_timeline": {
    "0": "Neutral",
    "1": "Happy",
    "2": "Happy"
  }
}
```

##
| Tag | Emotion | Valence | Arousal |
|-----|---------|---------|---------|
| `[laughs]` / `[laughing]` | Happy | +0.70 | +0.60 |
| `[sighs]` / `[sighing]` | Sad | −0.30 | −0.30 |
| `[whispers]` / `[whispering]` | Calm | +0.10 | −0.50 |
| `[shouts]` / `[shouting]` | Angry | −0.50 | +0.80 |
| `[exclaims]` | Excited | +0.50 | +0.70 |
| `[gasps]` | Surprised | +0.20 | +0.70 |
| `[hesitates]` / `[stutters]` / `[stammers]` | Anxious | −0.20 | +0.35 |
| `[cries]` / `[crying]` | Sad | −0.70 | +0.40 |
| `[claps]` / `[applause]` | Happy | +0.60 | +0.50 |
| `[clears throat]` / `[pause]` | Neutral | 0.00 | ±0.10 |

###

```bash
cd
```

###

```bash
cd
npm install
npm run dev
```

### Quick health check

```bash
curl -X POST http://localhost:3000/api/transcribe-diarize -F "audio=@/path/to/audio.m4a"
curl -X POST http://localhost:3000/api/transcribe-diarize -F "audio=@/path/to/video.mov"
```

## Models

| Model | Purpose | Source |
|-------|---------|--------|
| `mistralai/Voxtral-Mini-3B-2507` | Speech-to-text base | HF Hub (downloaded at runtime) |
| `YongkangZOU/evoxtral-lora` | LoRA adapter — emotion-aware transcription | HF Hub (downloaded at runtime) |
| `models/emotion_model_web.onnx` | MobileViT-XXS 8-class FER | Stored in repo via Git LFS |

FER emotion classes: `Anger | Contempt | Disgust | Fear | Happy | Neutral | Sad | Surprise`

## Training

```bash
pip install -r requirements.txt
# fine-tune on Modal
python scripts/train_modal.py
# push adapter to HF Hub
python scripts/push_hub.py
```

```bash
# Build and test locally
docker build -t ethos-studio .
docker run -p 7860:7860 ethos-studio
```

> ```bash
> MODELS_SHA=$(git ls-tree space/main | grep $'\tmodels$' | awk '{print $3}')
> TREE_SHA=$((git ls-tree HEAD | grep -v $'\tmodels$'; echo "040000 tree $MODELS_SHA\tmodels") | git mktree)
> PARENT=$(git rev-parse space/main)
> COMMIT_SHA=$(git commit-tree "$TREE_SHA" -p "$PARENT" -m "your message")
> git push space "${COMMIT_SHA}:refs/heads/main"
> ```

# Ethos Studio — Emotional Speech Recognition

Built for the **Mistral AI Online Hackathon 2026** (W&B Fine-Tuning Track).

Ethos Studio is a full-stack emotional speech recognition platform combining real-time transcription, facial emotion recognition, and expressive audio tagging. It turns raw speech into richly annotated transcripts with emotions, non-verbal sounds, and delivery cues.

## Key Components

### Evoxtral — Expressive Tagged Transcription

LoRA finetune of [Voxtral-Mini-3B-2507](https://huggingface.co/mistralai/Voxtral-Mini-3B-2507) that produces transcriptions with inline [ElevenLabs v3](https://elevenlabs.io/docs/api-reference/text-to-speech) audio tags. Two-stage pipeline: **SFT** (3 epochs) → **RL via RAFT** (rejection sampling, 1 epoch).

**Standard ASR:** `So I was thinking maybe we could try that new restaurant downtown.`

**Evoxtral:** `[nervous] So... [stammers] I was thinking maybe we could... [clears throat] try that new restaurant downtown? [laughs nervously]`

**Two model variants:**
- **[Evoxtral SFT](https://huggingface.co/YongkangZOU/evoxtral-lora)** — Best transcription accuracy (lowest WER)
- **[Evoxtral RL](https://huggingface.co/YongkangZOU/evoxtral-rl)** — Best expressive tag accuracy (highest Tag F1)

| Metric | Base Voxtral | Evoxtral SFT | Evoxtral RL | Best |
|--------|--------------|--------------|-------------|------|
| **WER** ↓ | 6.64% | **4.47%** | 5.12% | SFT |
| **CER** ↓ | 2.72% | **1.23%** | 1.48% | SFT |
| **Tag F1** ↑ | 22.0% | 67.2% | **69.4%** | RL |
| **Tag Recall** ↑ | 22.0% | 69.4% | **72.7%** | RL |
| **Emphasis F1** ↑ | 42.0% | 84.0% | **86.0%** | RL |

- [SFT Model](https://huggingface.co/YongkangZOU/evoxtral-lora) | [RL Model](https://huggingface.co/YongkangZOU/evoxtral-rl)
- [Live Demo (HF Space)](https://huggingface.co/spaces/YongkangZOU/evoxtral)
- [API (Swagger UI)](https://yongkang-zou1999--evoxtral-api-evoxtralmodel-web.modal.run/docs)
- [W&B Dashboard](https://wandb.ai/yongkang-zou-ai/evoxtral)
- [Technical Report (PDF)](Evoxtral%20Technical%20Report.pdf) | [LaTeX source](docs/technical_report.tex)

### FER — Facial Emotion Recognition

MobileViT-XXS model trained on 8 emotion classes, exported to ONNX for real-time browser inference.

**Classes:** Anger, Contempt, Disgust, Fear, Happy, Neutral, Sad, Surprise
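
For a quick local sanity check of the exported model outside the browser runtime, the ONNX file can be run with Python `onnxruntime`. A minimal sketch; the input layout, spatial size, and normalization below are assumptions, not the pipeline's documented preprocessing:

```python
import numpy as np
import onnxruntime as ort

# Sketch: run the exported FER model locally with onnxruntime.
# Assumptions (not documented here): one face crop, RGB, float32, NCHW,
# values scaled to [0, 1]. The real web pipeline may preprocess differently.
CLASSES = ["Anger", "Contempt", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]

sess = ort.InferenceSession("models/emotion_model_web.onnx")
inp = sess.get_inputs()[0]
# Use the model's declared input shape; fall back to 1x3x256x256 (assumed)
# wherever a dimension is dynamic.
shape = [d if isinstance(d, int) else s for d, s in zip(inp.shape, (1, 3, 256, 256))]

x = np.random.rand(*shape).astype(np.float32)  # stand-in for a preprocessed face crop
logits = sess.run(None, {inp.name: x})[0][0]
e = np.exp(logits - logits.max())              # numerically stable softmax
probs = e / e.sum()
print(CLASSES[int(probs.argmax())], float(probs.max()))
```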

### Voxtral Server — Speech-to-Text + Emotion

Speech-to-text service with VAD sentence segmentation and per-segment emotion analysis, powered by [Voxtral Mini 4B](https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602).

## Architecture

```
Browser (port 3030) → Server layer (Node, :3000)      → Model layer (Python, :8000)
  ↑ Studio UI          POST /api/speech-to-text         POST /transcribe
  ↑ Upload dialog      POST /api/transcribe-diarize     POST /transcribe-diarize
                       GET /health                      GET /health
```

| Layer | Path | Role |
|-------|------|------|
| **Model** | `model/voxtral-server` | Voxtral inference, VAD segmentation, emotion analysis |
| **Server** | `demo/server` | API entrypoint; proxies to Model |
| **Frontend** | `demo` | Next.js UI (upload, Studio editor, waveform, timeline) |
| **Evoxtral** | `training/scripts/` | Training, eval, RL, serving for expressive transcription |
| **FER** | `models/` | Facial emotion recognition ONNX model |

See [demo/README.md](demo/README.md) for full API and usage; [model/voxtral-server/README.md](model/voxtral-server/README.md) for the Model API.
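
For reference, a minimal Python client for the Server layer above. The multipart field name `audio` and the response keys follow the `transcribe-diarize` response format documented in this repo; treat both as assumptions if the API has since changed:

```python
import requests

# Sketch of a client for the Server layer on :3000. Field name "audio" and the
# response keys (text, segments, face_emotion_timeline) follow the repo's
# documented response format; verify against demo/README.md before relying on it.
with open("interview.mp4", "rb") as f:
    resp = requests.post(
        "http://localhost:3000/api/transcribe-diarize",
        files={"audio": f},
        timeout=600,  # local inference without a GPU can be slow
    )
resp.raise_for_status()
result = resp.json()

print(result["text"])
for seg in result.get("segments", []):
    print(f'{seg["start"]:6.1f}-{seg["end"]:6.1f}s  {seg.get("emotion", "?"):>9}  {seg["text"]}')
if result.get("has_video"):
    print("Face emotion per second:", result.get("face_emotion_timeline"))
```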

## Project Structure

```
├── api/          # Python FastAPI — local Voxtral inference + FER
├── proxy/        # Node.js/Express — API gateway for frontend
├── web/          # Next.js — Studio editor UI
├── training/     # Fine-tuning code (SFT + RL), data prep, eval
│   └── scripts/  # Modal scripts: train, RL (RAFT), eval, serve
├── space/        # HuggingFace Space (Gradio demo)
├── models/       # FER ONNX model (MobileViT-XXS)
├── docs/         # Technical report, design docs, research refs
├── data/         # Training data scripts (audio files gitignored)
└── Dockerfile    # Single-container HF Spaces build
```

## How to Run

**Requirements**: Python 3.10+, Node.js 20+, ffmpeg; GPU recommended.

### Model layer (port 8000)

```bash
cd model/voxtral-server
python -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt
uvicorn main:app --host 0.0.0.0 --port 8000 --reload
```

### Server layer (port 3000)

```bash
cd demo/server && npm install && npm run dev
```

### Frontend (port 3030)

```bash
cd demo && npm install && npm run dev
```

Open [http://localhost:3030](http://localhost:3030).

### Evoxtral API (Modal)

```bash
modal deploy training/scripts/serve_modal.py
```

## Tech Stack

- **Models**: Voxtral-Mini-3B + LoRA, Voxtral-Mini-4B, MobileViT-XXS
- **Training**: PyTorch, PEFT, Weights & Biases
- **Inference**: Modal (serverless GPU), HuggingFace ZeroGPU, ONNX Runtime
- **Backend**: FastAPI, Node.js
- **Frontend**: Next.js, Gradio

## Links

- [W&B Project](https://wandb.ai/yongkang-zou-ai/evoxtral) | [W&B Eval Report](https://wandb.ai/yongkang-zou-ai/evoxtral/reports/Evoxtral-—-Evaluation-Results:-Base-vs-SFT-vs-RL--VmlldzoxNjA3MzI3Nw==)
- [Evoxtral SFT Model](https://huggingface.co/YongkangZOU/evoxtral-lora) | [Evoxtral RL Model](https://huggingface.co/YongkangZOU/evoxtral-rl)
- [Evoxtral Demo](https://huggingface.co/spaces/YongkangZOU/evoxtral)
- [Evoxtral API (Swagger)](https://yongkang-zou1999--evoxtral-api-evoxtralmodel-web.modal.run/docs)
- [Technical Report (PDF)](Evoxtral%20Technical%20Report.pdf) | [LaTeX](docs/technical_report.tex)
api/main.py
CHANGED
@@ -1,6 +1,6 @@
"""
Evoxtral speech-to-text server (Model layer).
Runs Voxtral-Mini-3B + evoxtral-lora locally for transcription with expressive
tags. For video files, also runs FER (MobileViT-XXS ONNX) per segment.
"""
import asyncio

@@ -19,7 +19,7 @@ from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware

MODEL_ID = os.environ.get("MODEL_ID", "mistralai/Voxtral-Mini-3B-2507")
ADAPTER_ID = os.environ.get("ADAPTER_ID", "YongkangZOU/evoxtral-lora")
MAX_UPLOAD_BYTES = int(os.environ.get("MAX_UPLOAD_MB", "100")) * 1024 * 1024
TARGET_SR = 16000

"""
Evoxtral speech-to-text server (Model layer).
Runs Voxtral-Mini-3B + evoxtral-rl locally for transcription with expressive
tags. For video files, also runs FER (MobileViT-XXS ONNX) per segment.
"""
import asyncio

from fastapi.middleware.cors import CORSMiddleware

MODEL_ID = os.environ.get("MODEL_ID", "mistralai/Voxtral-Mini-3B-2507")
ADAPTER_ID = os.environ.get("ADAPTER_ID", "YongkangZOU/evoxtral-rl")
MAX_UPLOAD_BYTES = int(os.environ.get("MAX_UPLOAD_MB", "100")) * 1024 * 1024
TARGET_SR = 16000
docs/model_card/README.md
CHANGED
@@ -9,6 +9,8 @@ tags:
- audio
- mistral
- hackathon
datasets:
- custom
language:

@@ -23,6 +25,10 @@ A LoRA adapter for [Voxtral-Mini-3B-2507](https://huggingface.co/mistralai/Voxtr

Built for the **Mistral AI Online Hackathon 2026** (W&B Fine-Tuning Track).

## What It Does

Standard ASR:

@@ -31,17 +37,63 @@ Standard ASR:
Evoxtral:
> [nervous] So... [stammers] I was thinking maybe we could... [clears throat] try that new restaurant downtown? [laughs nervously] I mean, if you're free this weekend?

##

| **WER** (Word Error Rate) | 6.64% | **4.47%** | 32.7% better |
| **Tag F1** (Expressive Tag Accuracy) | 22.0% | **67.2%** | 3x better |

## Training Details

| Parameter | Value |
|-----------|-------|
| Base model | `mistralai/Voxtral-Mini-3B-2507` |

@@ -60,6 +112,20 @@ Evaluated on 50 held-out test samples. The finetuned model dramatically improves
| Training time | ~25 minutes |
| Trainable params | 124.8M / 4.8B (2.6%) |

## Dataset

Custom synthetic dataset of 1,010 audio samples generated with ElevenLabs TTS v3:

@@ -76,7 +142,8 @@ from transformers import VoxtralForConditionalGeneration, AutoProcessor
from peft import PeftModel

repo_id = "mistralai/Voxtral-Mini-3B-2507"

processor = AutoProcessor.from_pretrained(repo_id)
base_model = VoxtralForConditionalGeneration.from_pretrained(

@@ -102,12 +169,26 @@ print(transcription)
# [nervous] So... I was thinking maybe we could [clears throat] try that new restaurant downtown?
```

## W&B Tracking

All training and evaluation runs are tracked on Weights & Biases:
- [
- [
- [
- [Project dashboard](https://wandb.ai/yongkang-zou-ai/evoxtral)

## Supported Tags

@@ -119,7 +200,9 @@ The model can produce any tag from the ElevenLabs v3 expressive tag set, includi
## Limitations

- Trained on synthetic (TTS-generated) audio, not natural speech recordings
- English only
- Best results on conversational and emotionally expressive speech
- audio
- mistral
- hackathon
- rl
- raft
datasets:
- custom
language:

Built for the **Mistral AI Online Hackathon 2026** (W&B Fine-Tuning Track).

**Two model variants available:**
- **[Evoxtral SFT](https://huggingface.co/YongkangZOU/evoxtral-lora)** — Best overall transcription accuracy (lowest WER)
- **[Evoxtral RL](https://huggingface.co/YongkangZOU/evoxtral-rl)** — Best expressive tag accuracy (highest Tag F1)

## What It Does

Standard ASR:

Evoxtral:
> [nervous] So... [stammers] I was thinking maybe we could... [clears throat] try that new restaurant downtown? [laughs nervously] I mean, if you're free this weekend?

## Training Pipeline

```
Base Voxtral-Mini-3B → SFT (LoRA, 3 epochs) → RL (RAFT, 1 epoch)
```

1. **SFT**: LoRA finetuning on 808 synthetic audio samples with expressive tags (lr=2e-4, 3 epochs)
2. **RL (RAFT)**: Rejection sampling — generate 4 completions per sample, score with a rule-based reward (WER accuracy + Tag F1 − hallucination penalty), keep the best, then SFT on the curated data (lr=5e-5, 1 epoch)

This follows the approach from [GRPO for Speech Recognition](https://arxiv.org/abs/2509.01939) and Voxtral's own SFT→DPO training recipe.

## Evaluation Results

Evaluated on 50 held-out test samples. Full benchmark (Evoxtral-Bench) with 7 metrics:

### Core Metrics — Base vs SFT vs RL

| Metric | Base Voxtral | Evoxtral SFT | Evoxtral RL | Best |
|--------|--------------|--------------|-------------|------|
| **WER** ↓ | 6.64% | **4.47%** | 5.12% | SFT |
| **CER** ↓ | 2.72% | **1.23%** | 1.48% | SFT |
| **Tag F1** ↑ | 22.0% | 67.2% | **69.4%** | RL |
| **Tag Precision** ↑ | 22.0% | 67.4% | **68.5%** | RL |
| **Tag Recall** ↑ | 22.0% | 69.4% | **72.7%** | RL |
| **Emphasis F1** ↑ | 42.0% | 84.0% | **86.0%** | RL |
| **Tag Hallucination** ↓ | 0.0% | **19.3%** | 20.2% | SFT |

**SFT** excels at raw transcription accuracy (best WER/CER). **RL** further improves expressive tag generation (+2.2% Tag F1, +3.3% Tag Recall, +2% Emphasis F1) at a small cost to WER.

### Per-Tag F1 Breakdown (SFT → RL)

| Tag | SFT F1 | RL F1 | Change | Support |
|-----|--------|-------|--------|---------|
| `[sighs]` | 1.000 | **1.000** | — | 9 |
| `[clears throat]` | 0.889 | **1.000** | +12.5% | 8 |
| `[gasps]` | 0.957 | **0.957** | — | 12 |
| `[pause]` | 0.885 | **0.902** | +1.9% | 25 |
| `[nervous]` | 0.800 | **0.846** | +5.8% | 13 |
| `[stammers]` | 0.889 | 0.842 | -5.3% | 8 |
| `[laughs]` | 0.800 | **0.815** | +1.9% | 12 |
| `[sad]` | 0.667 | **0.750** | +12.4% | 4 |
| `[whispers]` | 0.636 | **0.667** | +4.9% | 13 |
| `[crying]` | 0.750 | 0.571 | -23.9% | 5 |
| `[excited]` | 0.615 | 0.571 | -7.2% | 5 |
| `[shouts]` | 0.400 | **0.500** | +25.0% | 3 |
| `[calm]` | 0.200 | **0.400** | +100% | 6 |
| `[frustrated]` | 0.444 | 0.444 | — | 3 |
| `[angry]` | 0.667 | 0.667 | — | 2 |
| `[confused]` | 0.000 | 0.000 | — | 1 |
| `[scared]` | 0.000 | 0.000 | — | 1 |

RL improved 9 tags, kept 4 stable, and regressed 3. Biggest gains on `[clears throat]` (+12.5%), `[calm]` (+100%), `[sad]` (+12.4%), and `[shouts]` (+25%).

## Training Details

### SFT Stage

| Parameter | Value |
|-----------|-------|
| Base model | `mistralai/Voxtral-Mini-3B-2507` |

| Training time | ~25 minutes |
| Trainable params | 124.8M / 4.8B (2.6%) |

### RL Stage (RAFT)

| Parameter | Value |
|-----------|-------|
| Method | Rejection sampling + SFT (RAFT) |
| Samples per input | 4 (temperature=0.7, top_p=0.9) |
| Reward function | 0.4×(1-WER) + 0.4×Tag_F1 + 0.2×(1-hallucination) |
| Curated samples | 727 (bottom 10% filtered, reward > 0.954) |
| Avg reward | 0.980 |
| Learning rate | 5e-5 |
| Epochs | 1 |
| Final loss | 0.021 |
| Training time | ~7 minutes |

## Dataset

Custom synthetic dataset of 1,010 audio samples generated with ElevenLabs TTS v3:

from peft import PeftModel

repo_id = "mistralai/Voxtral-Mini-3B-2507"
# Use "YongkangZOU/evoxtral-lora" for SFT or "YongkangZOU/evoxtral-rl" for RL
adapter_id = "YongkangZOU/evoxtral-rl"

processor = AutoProcessor.from_pretrained(repo_id)
base_model = VoxtralForConditionalGeneration.from_pretrained(

# [nervous] So... I was thinking maybe we could [clears throat] try that new restaurant downtown?
```

## API

A serverless API with Swagger UI is available on Modal:

```bash
curl -X POST https://yongkang-zou1999--evoxtral-api-evoxtralmodel-web.modal.run/transcribe \
  -F "file=@audio.wav"
```

- [Swagger UI](https://yongkang-zou1999--evoxtral-api-evoxtralmodel-web.modal.run/docs)
- [Live Demo (HF Space)](https://huggingface.co/spaces/YongkangZOU/evoxtral)

## W&B Tracking

All training and evaluation runs are tracked on Weights & Biases:
- [SFT Training](https://wandb.ai/yongkang-zou-ai/evoxtral/runs/t8ak7a20)
- [RL Training (RAFT)](https://wandb.ai/yongkang-zou-ai/evoxtral)
- [Base model eval](https://wandb.ai/yongkang-zou-ai/evoxtral/runs/bvqa4ioo)
- [SFT model eval](https://wandb.ai/yongkang-zou-ai/evoxtral/runs/ayx4ldyd)
- [RL model eval](https://wandb.ai/yongkang-zou-ai/evoxtral)
- [Project dashboard](https://wandb.ai/yongkang-zou-ai/evoxtral)

## Supported Tags

## Limitations

- Trained on synthetic (TTS-generated) audio, not natural speech recordings
- ~20% tag hallucination rate — model occasionally predicts tags not in the reference
- Rare/subtle tags (`[calm]`, `[confused]`, `[scared]`) have low accuracy due to limited training examples
- RL variant trades ~0.65% WER for better tag accuracy
- English only
- Best results on conversational and emotionally expressive speech
docs/research/references.md
ADDED
@@ -0,0 +1,84 @@

# Evoxtral — Research References & Sources

References for the technical report covering expressive tagged transcription, ASR evaluation, and post-training methods.

## Core Model

- **Voxtral: An Audio-Language Model** — Mistral AI, 2025. Voxtral Mini and Small are open-weights audio chat models trained with SFT + DPO + Online DPO. Online DPO delivers crisper grounding, fewer hallucinations, and more helpful responses.
  - Paper: https://arxiv.org/abs/2507.13264
  - Model: https://huggingface.co/mistralai/Voxtral-Mini-3B-2507

## RL / Post-Training for ASR

- **Group Relative Policy Optimization for Speech Recognition** — Proposes GRPO with rule-based rewards for LLM-based ASR. Achieved 18% relative WER reduction on AMI-IHM and 27.9% on AMI-SDM compared to SFT-adapted models.
  - Paper: https://arxiv.org/abs/2509.01939

- **Advancing Speech Understanding in Speech-Aware Language Models with GRPO** — Applies GRPO to large audio language models, investigating different rule-based reward functions and RL data construction strategies.
  - Paper: https://arxiv.org/abs/2509.16990

- **Reinforcement Learning Outperforms Supervised Fine-Tuning: A Case Study on Audio Question Answering** — Demonstrates RL (GRPO) achieves state-of-the-art on audio QA tasks, outperforming SFT baselines.
  - Paper: https://arxiv.org/abs/2503.11197

- **Explore the Reinforcement Learning for LLM-based ASR and TTS System** — Surveys RL applications to speech models, noting that while RL has enhanced text-based LLMs, application to ASR/TTS remains underexplored.
  - Paper: https://arxiv.org/abs/2509.18569

## ASR Evaluation Metrics

- **Speech Recognition Accuracy: Production Metrics & Optimization 2025** — Deepgram. Covers WER, CER, Keyword Recall Rate (KRR), Real-Time Factor (RTF), and end-to-end latency. Production systems need blended metrics depending on use case.
  - Source: https://deepgram.com/learn/speech-recognition-accuracy-production-metrics

- **Moving Beyond Word Error Rate to Evaluate ASR in Clinical Samples** — Argues WER alone is insufficient; error type, meaning, and context matter. A single substitution can drastically change intent with the same WER penalty.
  - Paper: https://www.sciencedirect.com/science/article/pii/S0165178125003385

- **Measuring the Accuracy of Automatic Speech Recognition Solutions** — ACM survey on ASR evaluation methodology, limitations of WER/CER, and alternative metrics.
  - Paper: https://dl.acm.org/doi/10.1145/3636513

- **On the Robust Approximation of ASR Metrics** — ACL 2025. Novel label-free approach for approximating ASR performance using multimodal embeddings.
  - Paper: https://arxiv.org/abs/2502.12408

- **ProfASR-Bench: A Professional-talk ASR Dataset for High-Stakes Applications** — Evaluation suite supporting conventional metrics plus entity-aware scores and slice-wise reporting by accent and gender.
  - Paper: https://arxiv.org/abs/2512.23686

## Expressive Speech & TTS

- **ElevenLabs v3 Text-to-Speech** — TTS model supporting inline expressive audio tags for emotions, non-verbal sounds, and delivery cues. Tag set used as target vocabulary for Evoxtral.
  - Docs: https://elevenlabs.io/docs/api-reference/text-to-speech

## Fine-Tuning Frameworks

- **ms-swift** — ModelScope framework supporting SFT/DPO/GRPO for 600+ LLMs and 300+ MLLMs. AAAI 2025.
  - GitHub: https://github.com/modelscope/ms-swift

- **OpenRLHF** — Scalable agentic RL framework based on Ray, supporting PPO, DAPO, REINFORCE++, and more.
  - GitHub: https://github.com/OpenRLHF/OpenRLHF

- **PEFT (Parameter-Efficient Fine-Tuning)** — HuggingFace library for LoRA and other adapter methods.
  - GitHub: https://github.com/huggingface/peft

## Evaluation Methodology

- **jiwer** — Python library for WER, CER, and other ASR metrics based on edit distance.
  - GitHub: https://github.com/jitsi/jiwer

## Our Benchmark: Evoxtral-Bench

Metrics computed by `scripts/train_modal.py::evaluate()`:

| Metric | Description | Direction |
|--------|-------------|-----------|
| **WER** | Word Error Rate on plain text (tags stripped) | lower = better |
| **CER** | Character Error Rate on plain text | lower = better |
| **Tag F1** | F1 score for tag extraction (multiset intersection) | higher = better |
| **Tag Precision** | Fraction of predicted tags that match reference | higher = better |
| **Tag Recall** | Fraction of reference tags captured by prediction | higher = better |
| **Tag Hallucination Rate** | Fraction of predicted tags not in reference at all | lower = better |
| **Emphasis F1** | F1 for CAPITALIZED emphasis words | higher = better |
| **Per-Tag F1** | Breakdown by individual tag type (e.g. [laughs], [sighs]) | higher = better |

### Why These Metrics

- **WER + CER**: Standard ASR quality. CER is more granular and catches character-level errors WER misses.
- **Tag Precision vs Recall**: F1 alone hides whether the model hallucinates tags (low precision) or misses them (low recall). Judges care about this distinction.
- **Tag Hallucination Rate**: Critical for downstream TTS — hallucinated tags produce wrong prosody/effects.
- **Per-Tag Breakdown**: Shows which expressive cues the model handles well vs struggles with, informing future data collection.
- **Emphasis F1**: CAPS words (e.g., "BELIEVE", "NEVER") carry prosodic emphasis in ElevenLabs v3; measuring this separately tracks delivery accuracy. A sketch of the tag-metric computations follows this list.
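
To make these definitions concrete, a minimal sketch of the multiset tag precision/recall/F1 and hallucination-rate computation. It illustrates the table's definitions only; the authoritative implementation is `scripts/train_modal.py::evaluate()`, which may normalize text differently:

```python
import re
from collections import Counter

# Sketch of the tag metrics defined above (multiset intersection).
TAG_RE = re.compile(r"\[[^\[\]]+\]")

def tag_metrics(pred: str, ref: str) -> dict:
    pred_tags = Counter(t.lower() for t in TAG_RE.findall(pred))
    ref_tags = Counter(t.lower() for t in TAG_RE.findall(ref))
    overlap = sum((pred_tags & ref_tags).values())  # multiset intersection
    n_pred, n_ref = sum(pred_tags.values()), sum(ref_tags.values())
    precision = overlap / n_pred if n_pred else 0.0
    recall = overlap / n_ref if n_ref else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    # Hallucination: predicted tag occurrences whose type never appears in the reference.
    halluc = sum(c for t, c in pred_tags.items() if t not in ref_tags)
    return {
        "tag_precision": precision,
        "tag_recall": recall,
        "tag_f1": f1,
        "tag_hallucination_rate": halluc / n_pred if n_pred else 0.0,
    }

# Example: one hallucinated tag ([gasps]) out of three predicted.
print(tag_metrics("[laughs] sure [pause] ok [gasps]", "[laughs] sure, [pause] okay"))
```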
docs/technical_report.md
ADDED
@@ -0,0 +1,300 @@

# Evoxtral: Expressive Tagged Transcription via Supervised Fine-Tuning and Rejection Sampling

**Mistral AI Online Hackathon 2026 — W&B Fine-Tuning Track**

---

## Abstract

Standard automatic speech recognition (ASR) systems discard paralinguistic and expressive information present in spoken audio, producing plain text that fails to capture sighs, laughs, hesitations, emotional tone, and other prosodic cues. We present **Evoxtral**, a LoRA adapter for `mistralai/Voxtral-Mini-3B-2507` [1] that produces transcriptions enriched with inline expressive audio tags drawn from the ElevenLabs v3 tag vocabulary [6]. We apply a two-stage post-training pipeline: supervised fine-tuning (SFT) followed by rejection sampling fine-tuning (RAFT) [4]. SFT reduces word error rate (WER) by 33% relative (6.64% → 4.47%) and increases Tag F1 from 22.0% to 67.2%. The subsequent RAFT stage further improves Tag F1 to 69.4% and Tag Recall to 72.7% at a marginal WER cost. We release two model variants, a serverless inference API, and a live interactive demo.

---

## 1. Introduction

Modern ASR pipelines excel at converting speech to text with low word error rates, yet they systematically strip out the expressive dimension of human communication. When a speaker sighs before a sentence, laughs nervously mid-phrase, or whispers for emphasis, these paralinguistic signals carry meaning that plain transcription cannot represent. This information is especially critical for downstream text-to-speech (TTS) synthesis: next-generation TTS systems such as ElevenLabs v3 [6] consume inline expressive tags to control prosody, affect, and delivery at a fine-grained level.

We ask: can a multimodal audio-language model be trained to produce ASR output that preserves expressive content through inline tags? We answer affirmatively with **Evoxtral**, built on Voxtral-Mini-3B-2507 [1].

**Why Voxtral.** We chose Voxtral-Mini-3B as our base model for several reasons. First, as a *generative* audio-language model (rather than a CTC or encoder-only ASR), Voxtral decodes transcriptions autoregressively — meaning it can naturally produce arbitrary inline tokens such as `[laughs]` or `[nervous]` within the text stream, without requiring architectural changes. Traditional ASR models constrain their output vocabulary to words and punctuation; Voxtral's LLM decoder has no such limitation. Second, Voxtral's compact 3B-parameter architecture makes LoRA fine-tuning feasible on a single A10G GPU within hackathon time constraints, while still delivering competitive ASR quality. Third, Voxtral was itself trained with post-training alignment (SFT + DPO + Online DPO) [1], meaning the model is already instruction-following and amenable to further fine-tuning — a strong foundation for adding new capabilities. Finally, as a Mistral AI model released under Apache 2.0, Voxtral aligns with the hackathon's focus on the Mistral ecosystem and enables open redistribution of our adapters.

Our approach fine-tunes Voxtral-Mini-3B-2507 using parameter-efficient LoRA adapters [3] on a synthetically generated dataset of expressive speech paired with tagged transcriptions.

To illustrate the contrast between standard ASR and Evoxtral output, consider the following example:

**Standard ASR output:**
> "So I was thinking maybe we could try that new restaurant downtown. I mean if you're free this weekend."

**Evoxtral output:**
> "[nervous] So... [stammers] I was thinking maybe we could... [clears throat] try that new restaurant downtown? [laughs nervously] I mean, if you're free this weekend?"

The Evoxtral output captures hesitation, nervous laughter, and a throat clear — paralinguistic content that is acoustically present but conventionally discarded. Our two-stage training pipeline (SFT followed by RAFT) achieves strong tag generation performance while maintaining competitive transcription accuracy. We make the following contributions:

1. A synthetic dataset of 1,010 expressive speech samples paired with tagged transcriptions across 17 ElevenLabs v3 tag types.
2. A two-stage fine-tuning recipe (SFT → RAFT) for expressive ASR using LoRA on Voxtral-Mini-3B.
3. A custom evaluation benchmark, **Evoxtral-Bench**, with seven metrics covering both transcription accuracy and tag generation quality.
4. Two released model variants optimized for different use cases: accuracy-critical (SFT) and expressiveness-critical (RL).

---

## 2. Related Work

**Voxtral [1].** Voxtral-Mini-3B-2507 is a multimodal audio-language model released by Mistral AI. It is built on a Whisper-based audio encoder fused with a Mistral language model backbone, trained via SFT, DPO, and online DPO. Our work builds directly on this foundation, adding expressive tagging capability via LoRA adaptation.

**Reinforcement Learning for LLM-based ASR [2].** Shi et al. apply group relative policy optimization (GRPO) to LLM-based ASR, achieving an 18% relative WER reduction without paired preference data. Closely related work [8, 9, 12] further demonstrates that RL-based training consistently outperforms SFT alone for speech understanding tasks. Our RAFT stage is philosophically aligned with this line of work but uses rule-based rejection sampling rather than policy gradient methods, making it simpler to implement and more stable to train.

**LoRA [3].** Low-rank adaptation (LoRA) [3] inserts trainable low-rank matrices into the attention projections and feed-forward layers of a frozen base model. This reduces the number of trainable parameters by orders of magnitude while matching full fine-tuning performance. We use LoRA with rank 64 and alpha 128, implemented via HuggingFace PEFT [10].

**Rejection Sampling Fine-Tuning (RAFT/RFT) [4].** Yuan et al. [4] propose generating multiple model completions for each training input, scoring them with a reward function, and performing SFT on the highest-scoring completions. This iterative rejection sampling approach is computationally simpler than policy gradient methods while still providing a reinforcement learning signal. We adopt this approach for our RL stage.

**NEFTune [5].** Jain et al. [5] demonstrate that adding uniform random noise to embedding vectors during training improves instruction-following performance. We apply NEFTune with noise alpha = 5.0 during SFT to regularize training on our small dataset.

**ElevenLabs v3 Audio Tags [6].** ElevenLabs v3 TTS introduces a structured vocabulary of inline expressive tags that control how synthesized speech is delivered. These tags — such as `[sighs]`, `[laughs]`, `[whispers]`, and `[nervous]` — are the target output vocabulary for Evoxtral.

**ASR Evaluation [7, 11].** We compute WER and CER using the `jiwer` library [7], following standard definitions from Morris et al. [11].

---

## 3. Method

### 3.1 Dataset

We construct a synthetic dataset of 1,010 audio samples generated using the ElevenLabs TTS v3 API [6]. Each sample consists of a short spoken utterance (5–30 seconds) paired with a reference tagged transcription containing inline ElevenLabs v3 expressive tags. The dataset covers 17 tag types:

`[sighs]`, `[laughs]`, `[whispers]`, `[nervous]`, `[frustrated]`, `[clears throat]`, `[pause]`, `[excited]`, `[stammers]`, `[gasps]`, `[sad]`, `[angry]`, `[calm]`, `[crying]`, `[shouts]`, `[confused]`, `[scared]`

The dataset is split into 808 training, 101 validation, and 101 test samples. Tag frequency follows a long-tail distribution: `[pause]` is the most common tag while `[confused]` and `[scared]` appear rarely. The audio encoder (Whisper-based) is kept frozen throughout training; only the language model backbone and projector are fine-tuned.

### 3.2 Stage 1: Supervised Fine-Tuning (SFT)

We fine-tune `mistralai/Voxtral-Mini-3B-2507` [1] using LoRA [3] with the following configuration:

| Hyperparameter | Value |
|---|---|
| LoRA rank | 64 |
| LoRA alpha | 128 |
| LoRA dropout | 0.05 |
| Target modules | `q/k/v/o_proj`, `gate/up/down_proj`, `multi_modal_projector` |
| Learning rate | 2e-4 |
| LR schedule | Cosine decay |
| Epochs | 3 |
| Batch size | 2 |
| Gradient accumulation steps | 8 (effective batch = 16) |
| NEFTune noise alpha | 5.0 |
| Precision | bf16 |
| Hardware | NVIDIA A10G (24 GB) |
| Training time | ~25 minutes |
| Trainable parameters | 124.8M / 4.8B (2.6%) |

The SFT objective is standard next-token prediction (cross-entropy) on the tagged transcription tokens, conditioned on the audio encoder outputs. NEFTune [5] noise is applied to the input embeddings to reduce overfitting on the small training set.
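
For concreteness, a minimal PEFT sketch matching the configuration table above. The exact module names inside Voxtral, in particular how the multimodal projector is addressed, are assumptions; the authoritative setup lives in `training/scripts/train_modal.py`:

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import VoxtralForConditionalGeneration

# Sketch of the LoRA setup from the table above. Module names are assumptions
# about Voxtral's internals; see training/scripts/train_modal.py for the real config.
base = VoxtralForConditionalGeneration.from_pretrained(
    "mistralai/Voxtral-Mini-3B-2507", torch_dtype=torch.bfloat16
)
lora_cfg = LoraConfig(
    r=64,
    lora_alpha=128,
    lora_dropout=0.05,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention projections
        "gate_proj", "up_proj", "down_proj",     # MLP projections
    ],
    modules_to_save=["multi_modal_projector"],   # train the projector too (assumed)
)
model = get_peft_model(base, lora_cfg)
model.print_trainable_parameters()  # report states 124.8M / 4.8B (2.6%) trainable
```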

### 3.3 Stage 2: Rejection Sampling Fine-Tuning (RAFT)

Following Yuan et al. [4] and inspired by Voxtral's own SFT→DPO recipe [1] and GRPO-based ASR work [2], we apply a rejection sampling stage to refine tag generation quality. The procedure is as follows:

```
┌─────────────────────────────────────────────────────────────────┐
│                     RAFT Training Pipeline                      │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Training Set (808 samples)                                     │
│        │                                                        │
│        ▼                                                        │
│  ┌─────────────────────────────────────────┐                    │
│  │ Generate N=4 completions per sample     │                    │
│  │ (temperature=0.7, top_p=0.9)            │                    │
│  └─────────────────────────────────────────┘                    │
│        │                                                        │
│        ▼                                                        │
│  ┌─────────────────────────────────────────┐                    │
│  │ Score each completion:                  │                    │
│  │   R = 0.4 × (1 - WER)                   │                    │
│  │     + 0.4 × Tag_F1                      │                    │
│  │     + 0.2 × (1 - hallucination_rate)    │                    │
│  └─────────────────────────────────────────┘                    │
│        │                                                        │
│        ▼                                                        │
│  ┌─────────────────────────────────────────┐                    │
│  │ Keep best completion per sample         │                    │
│  │ Filter bottom 10% (reward ≤ 0.954)      │                    │
│  │ → 727 curated samples remain            │                    │
│  │   (avg reward: 0.980)                   │                    │
│  └─────────────────────────────────────────┘                    │
│        │                                                        │
│        ▼                                                        │
│  ┌─────────────────────────────────────────┐                    │
│  │ SFT on curated data                     │                    │
│  │ (lr=5e-5, 1 epoch, ~7 minutes)          │                    │
│  └─────────────────────────────────────────┘                    │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```

**Figure 1.** RAFT training pipeline. For each training sample, four completions are generated and scored by a rule-based reward function balancing transcription accuracy, tag quality, and hallucination rate. Only high-reward completions are retained for the final SFT pass.

The reward function is designed to balance three objectives: transcription accuracy (WER), tag generation quality (Tag F1), and avoidance of hallucinated tags. The 0.4/0.4/0.2 weighting reflects an equal priority on accuracy and expressiveness, with a penalty for hallucination. After filtering, 727 of the original 808 training samples remain, with a mean reward of 0.980. The RAFT SFT stage trains for one epoch at a reduced learning rate of 5e-5, completing in approximately 7 minutes. Final training loss: 0.021.
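
To ground the reward definition, a compact sketch of the scoring and best-of-N selection, using `jiwer` [7] for WER. The tag regex and text normalization are simplifications; the authoritative implementation is `training/scripts/rl_modal.py`:

```python
import re
from collections import Counter

import jiwer  # WER library, reference [7]

# Sketch of the Section 3.3 reward and selection; simplifications throughout.
TAG_RE = re.compile(r"\[[^\[\]]+\]")

def _tag_f1_halluc(pred: str, ref: str) -> tuple[float, float]:
    p, r = Counter(TAG_RE.findall(pred)), Counter(TAG_RE.findall(ref))
    overlap = sum((p & r).values())
    n_p, n_r = sum(p.values()), sum(r.values())
    prec = overlap / n_p if n_p else 0.0
    rec = overlap / n_r if n_r else 0.0
    f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
    halluc = sum(c for t, c in p.items() if t not in r) / n_p if n_p else 0.0
    return f1, halluc

def raft_reward(pred: str, ref: str) -> float:
    """R = 0.4*(1 - WER) + 0.4*Tag_F1 + 0.2*(1 - hallucination_rate)."""
    strip = lambda s: " ".join(TAG_RE.sub(" ", s).split())  # drop tags before WER
    wer = min(jiwer.wer(strip(ref), strip(pred)), 1.0)      # clamp: WER can exceed 1
    f1, halluc = _tag_f1_halluc(pred, ref)
    return 0.4 * (1.0 - wer) + 0.4 * f1 + 0.2 * (1.0 - halluc)

def raft_select(candidates: list[str], ref: str, cutoff: float = 0.954):
    """Best-of-N selection with the bottom-decile reward cutoff from the text."""
    best = max(candidates, key=lambda c: raft_reward(c, ref))
    reward = raft_reward(best, ref)
    return (best, reward) if reward > cutoff else (None, reward)
```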

The full two-stage pipeline is illustrated below:

```
┌──────────────────────────────────────────────────────────────────┐
│                    Evoxtral Training Overview                    │
├──────────────────────────────────────────────────────────────────┤
│                                                                  │
│  Voxtral-Mini-3B-2507 (frozen audio encoder)                     │
│        │                                                         │
│        ▼                                                         │
│  ┌─────────────────────┐                                         │
│  │ Stage 1: SFT        │  lr=2e-4, 3 epochs, NEFTune             │
│  │ (808 samples)       │  ~25 min on A10G                        │
│  └─────────────────────┘                                         │
│        │                                                         │
│        ▼                                                         │
│  ┌─────────────────────┐                                         │
│  │ Evoxtral-SFT        │  WER=4.47%, Tag F1=67.2%                │
│  └─────────────────────┘                                         │
│        │                                                         │
│        ▼                                                         │
│  ┌─────────────────────┐                                         │
│  │ Stage 2: RAFT       │  N=4 samples, reward filter             │
│  │ (727 samples)       │  lr=5e-5, 1 epoch, ~7 min               │
│  └─────────────────────┘                                         │
│        │                                                         │
│        ▼                                                         │
│  ┌─────────────────────┐                                         │
│  │ Evoxtral-RL         │  WER=5.12%, Tag F1=69.4%                │
│  └─────────────────────┘                                         │
│                                                                  │
└──────────────────────────────────────────────────────────────────┘
```

**Figure 2.** Full two-stage Evoxtral training pipeline from base Voxtral-Mini-3B to the two released model variants.

---

## 4. Evaluation

### 4.1 Evoxtral-Bench

We evaluate on **Evoxtral-Bench**, a held-out benchmark of 50 test samples drawn from the 101-sample test split. This subset enables fast iteration while maintaining statistical validity. We compute seven evaluation metrics:

- **WER** (Word Error Rate, ↓) — computed via `jiwer` [7], tags stripped before comparison
- **CER** (Character Error Rate, ↓) — character-level transcription accuracy via `jiwer` [7]
- **Tag F1** (↑) — token-level F1 on predicted vs. reference tag sequences
- **Tag Precision** (↑) — fraction of predicted tags present in reference
- **Tag Recall** (↑) — fraction of reference tags predicted by the model
- **Tag Hallucination Rate** (↓) — fraction of predicted tags not present in any reference tag
- **Emphasis F1** (↑) — F1 on emphasized word spans (ellipsis and capitalization markers)

Per-tag F1 is additionally computed across all 17 tag types to diagnose per-class performance.

### 4.2 Core Results

**Table 1.** Core evaluation results on Evoxtral-Bench (50 samples). Bold indicates best per metric. ↓ lower is better; ↑ higher is better.

| Metric | Base Voxtral | Evoxtral SFT | Evoxtral RL | Best |
|---|---|---|---|---|
| WER ↓ | 6.64% | **4.47%** | 5.12% | SFT |
| CER ↓ | 2.72% | **1.23%** | 1.48% | SFT |
| Tag F1 ↑ | 22.0% | 67.2% | **69.4%** | RL |
| Tag Precision ↑ | 22.0% | 67.4% | **68.5%** | RL |
| Tag Recall ↑ | 22.0% | 69.4% | **72.7%** | RL |
| Emphasis F1 ↑ | 42.0% | 84.0% | **86.0%** | RL |
| Tag Hallucination Rate ↓ | 0.0% | **19.3%** | 20.2% | SFT |

The base Voxtral model achieves 22.0% Tag F1, suggesting some limited native capability to produce expressive tokens but with low precision and recall. SFT provides the dominant improvement: WER decreases by 33% relative (6.64% → 4.47%) and Tag F1 increases by 45 percentage points (22.0% → 67.2%). RAFT further refines tag metrics: Tag F1 improves by 2.2pp (67.2% → 69.4%), Tag Recall by 3.3pp (69.4% → 72.7%), and Emphasis F1 by 2.0pp (84.0% → 86.0%). However, RAFT introduces a small WER regression (4.47% → 5.12%), reflecting a Pareto tradeoff between transcription accuracy and expressive richness.

Tag hallucination — predicted tags absent from the reference — is 19.3% for SFT and 20.2% for RL. The base model has 0% hallucination because it rarely predicts any tags at all.

### 4.3 Per-Tag F1 Breakdown

**Table 2.** Per-tag F1 scores for Evoxtral-SFT and Evoxtral-RL on Evoxtral-Bench. Support indicates the number of test samples containing each tag.

| Tag | SFT F1 | RL F1 | Delta | Support |
|---|---|---|---|---|
| [sighs] | 1.000 | 1.000 | — | 9 |
| [clears throat] | 0.889 | **1.000** | +12.5% | 8 |
| [gasps] | 0.957 | 0.957 | — | 12 |
| [pause] | 0.885 | **0.902** | +1.9% | 25 |
| [nervous] | 0.800 | **0.846** | +5.8% | 13 |
| [stammers] | **0.889** | 0.842 | -5.3% | 8 |
| [laughs] | 0.800 | **0.815** | +1.9% | 12 |
| [sad] | 0.667 | **0.750** | +12.4% | 4 |
| [whispers] | 0.636 | **0.667** | +4.9% | 13 |
| [crying] | **0.750** | 0.571 | -23.9% | 5 |
| [excited] | **0.615** | 0.571 | -7.2% | 5 |
| [shouts] | 0.400 | **0.500** | +25.0% | 3 |
| [calm] | 0.200 | **0.400** | +100.0% | 6 |
| [frustrated] | 0.444 | 0.444 | — | 3 |
| [angry] | 0.667 | 0.667 | — | 2 |
| [confused] | 0.000 | 0.000 | — | 1 |
| [scared] | 0.000 | 0.000 | — | 1 |

RAFT improves 9 tags, maintains 4 stable, and regresses on 3. The largest gains are observed for `[calm]` (+100%, 0.200 → 0.400), `[shouts]` (+25.0%), `[clears throat]` (+12.5%), and `[sad]` (+12.4%). Regressions are noted for `[crying]` (-23.9%), `[excited]` (-7.2%), and `[stammers]` (-5.3%). The two zero-F1 tags (`[confused]`, `[scared]`) each appear only once in the test set, making estimation unreliable.

---

## 5. Analysis and Discussion

**SFT as the primary driver of improvement.** The SFT stage accounts for the vast majority of the performance gain: WER drops 33% relative and Tag F1 increases by 45 percentage points. This aligns with findings from GRPO-based ASR work [2, 8, 9] suggesting that a well-supervised initial adaptation is a strong foundation for subsequent RL refinement.

**The WER-Tag tradeoff.** RAFT improves tag metrics at the cost of a modest WER regression (4.47% → 5.12%). This suggests the existence of a Pareto frontier between transcription accuracy and expressive richness: optimizing for tag generation pushes the model toward producing more tags, which can introduce minor word-level errors. This motivates releasing two model variants — Evoxtral-SFT for accuracy-critical applications (e.g., professional transcription) and Evoxtral-RL for expressiveness-critical applications (e.g., downstream TTS synthesis with ElevenLabs v3 [6]).

**Tag hallucination.** Approximately 20% of predicted tags are not present in the reference transcription. Hallucination occurs when the model infers an expressive tone from acoustic cues that are present in the audio but absent or differently annotated in the reference. This may partly reflect annotation noise in synthetic data rather than pure model error. Future work should address this with contrastive or calibration-based training objectives.

**Effect of NEFTune.** Applying NEFTune [5] with noise alpha = 5.0 during SFT provided a regularization benefit on the small 808-sample training set, consistent with Jain et al.'s [5] findings on instruction-following tasks. Ablating this component was not feasible within hackathon time constraints but remains a planned analysis.

**Rare tag performance.** Tags with very low test support (`[confused]`, `[scared]`, support = 1) have zero F1, which is uninformative. Tags with support 2–6 (`[angry]`, `[shouts]`, `[calm]`, `[sad]`, `[frustrated]`) show high variance in F1 estimates. A larger, more balanced evaluation set would provide more reliable per-tag metrics.

**Reward function design.** The RAFT reward R = 0.4×(1−WER) + 0.4×Tag\_F1 + 0.2×(1−hallucination\_rate) explicitly encodes the design preference for equal weight on accuracy and expressiveness. The 0.2 weight on hallucination acts as a weak regularizer. An ablation across reward weightings would quantify the sensitivity of the final model to this design choice.

---

## 6. Limitations

The following limitations apply to this work:

- **Synthetic training data.** All 1,010 samples are synthesized using ElevenLabs TTS v3 [6]. The acoustic properties of synthetic speech differ from natural human speech (e.g., prosodic consistency, noise, spontaneous disfluencies). Performance on natural speech recordings may differ from the reported results.
- **Tag hallucination.** Approximately 20% of predicted tags in the RL model are not present in the reference. This may limit applicability in settings requiring precise expressive annotation.
- **Rare tag coverage.** Seventeen tag types are represented in the dataset, but several occur in fewer than 5 test samples. Per-tag F1 estimates for rare categories are unreliable.
- **English only.** The dataset and training are English-only. Generalization to other languages is not evaluated.
- **Small dataset.** 808 training samples is a small fine-tuning set. Scaling to thousands or tens of thousands of examples with natural speech could substantially improve performance.
- **Evaluation scope.** Evoxtral-Bench covers 50 test samples for fast iteration. A larger evaluation set would yield more statistically robust estimates across all metrics and tag types.

---

## 7. Conclusion

We presented Evoxtral, a LoRA-adapted version of Voxtral-Mini-3B-2507 [1] that produces expressive tagged transcriptions using ElevenLabs v3 audio tags [6]. Our two-stage training pipeline — SFT followed by RAFT [4] — demonstrates that expressive tagging capability can be effectively injected into a pre-trained ASR model with parameter-efficient fine-tuning [3, 10].

SFT achieves a 33% relative WER reduction and a 45 percentage-point improvement in Tag F1 over the base model. RAFT further improves tag recall and F1 by targeting tag generation quality directly through a rule-based reward signal, at a modest transcription accuracy cost. The two resulting model variants cover a Pareto frontier between accuracy and expressiveness, allowing practitioners to select the appropriate trade-off for their application.

Future directions include: (1) collecting natural speech data with crowd-sourced expressive annotations to reduce the synthetic data gap; (2) replacing RAFT with GRPO [2] or DPO [1] for more sample-efficient RL training; (3) expanding to multilingual settings leveraging Voxtral's multilingual audio encoder; and (4) developing joint ASR+TTS evaluation protocols that measure downstream TTS quality when Evoxtral output is used as input to ElevenLabs v3 [6].

---

## References

[1] Mistral AI. "Voxtral." arXiv:2507.13264, 2025. https://arxiv.org/abs/2507.13264

[2] "Group Relative Policy Optimization for Speech Recognition." arXiv:2509.01939, 2025. https://arxiv.org/abs/2509.01939

[3] Hu, E.J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S., Wang, L., and Chen, W. "LoRA: Low-Rank Adaptation of Large Language Models." In *Proceedings of ICLR*, 2022. arXiv:2106.09685. https://arxiv.org/abs/2106.09685

[4] Yuan, Z., Yuan, H., Li, C., Dong, G., Tan, C., and Zhou, C. "Scaling Relationship on Learning Mathematical Reasoning with Large Language Models." arXiv:2308.01825, 2023. https://arxiv.org/abs/2308.01825
|
| 285 |
+
|
| 286 |
+
[5] Jain, N., Chiang, P., Yeh, Y., Kirchenbauer, J., Chu, C., Somepalli, G., Bartoldson, B., Kailkhura, B., Schwarzschild, A., Saha, A., Goldblum, M., Geiping, J., and Goldstein, T. "NEFTune: Noisy Embeddings Improve Instruction Finetuning." In *Proceedings of ICLR*, 2024. arXiv:2310.05914. https://arxiv.org/abs/2310.05914
|
| 287 |
+
|
| 288 |
+
[6] ElevenLabs. "Text-to-Speech v3 Audio Tags." ElevenLabs Developer Documentation, 2025. https://elevenlabs.io/docs/api-reference/text-to-speech
|
| 289 |
+
|
| 290 |
+
[7] "JiWER: Evaluate your speech recognition system." Python library for ASR evaluation metrics. https://github.com/jitsi/jiwer
|
| 291 |
+
|
| 292 |
+
[8] "Advancing Speech Understanding in Speech-Aware Language Models with GRPO." arXiv:2509.16990, 2025. https://arxiv.org/abs/2509.16990
|
| 293 |
+
|
| 294 |
+
[9] "Explore the Reinforcement Learning for LLM-based ASR and TTS System." arXiv:2509.18569, 2025. https://arxiv.org/abs/2509.18569
|
| 295 |
+
|
| 296 |
+
[10] Mangrulkar, S., Gugger, S., Debut, L., Belkada, Y., Paul, S., and Bossan, B. "PEFT: State-of-the-art Parameter-Efficient Fine-Tuning methods." HuggingFace, 2022. https://github.com/huggingface/peft
|
| 297 |
+
|
| 298 |
+
[11] Morris, A.C., Maier, V., and Green, P. "From WER and RIL to MER and WIL: Improved evaluation measures for connected speech recognition." In *Proceedings of INTERSPEECH*, 2004.
|
| 299 |
+
|
| 300 |
+
[12] "Reinforcement Learning Outperforms Supervised Fine-Tuning: A Case Study on Audio Question Answering." arXiv:2503.11197, 2025. https://arxiv.org/abs/2503.11197
|
docs/technical_report.tex
ADDED
|
@@ -0,0 +1,383 @@
|
| 1 |
+
\documentclass[11pt,a4paper]{article}
|
| 2 |
+
|
| 3 |
+
% ── Packages ──────────────────────────────────────────────────────────
|
| 4 |
+
\usepackage[utf8]{inputenc}
|
| 5 |
+
\usepackage[T1]{fontenc}
|
| 6 |
+
\usepackage{lmodern}
|
| 7 |
+
\usepackage[margin=1in]{geometry}
|
| 8 |
+
\usepackage{amsmath,amssymb}
|
| 9 |
+
\usepackage{graphicx}
|
| 10 |
+
\usepackage{booktabs}
|
| 11 |
+
\usepackage{hyperref}
|
| 12 |
+
\usepackage{xcolor}
|
| 13 |
+
\usepackage{tikz}
|
| 14 |
+
\usetikzlibrary{arrows.meta,positioning,shapes.geometric,fit,calc}
|
| 15 |
+
\usepackage{caption}
|
| 16 |
+
\usepackage{subcaption}
|
| 17 |
+
\usepackage{enumitem}
|
| 18 |
+
\usepackage{multirow}
|
| 19 |
+
\usepackage{array}
|
| 20 |
+
\usepackage{tabularx}
|
| 21 |
+
\usepackage{float}
|
| 22 |
+
\usepackage{natbib}
|
| 23 |
+
\usepackage{authblk}
|
| 24 |
+
\usepackage{microtype}
|
| 25 |
+
|
| 26 |
+
\hypersetup{
|
| 27 |
+
colorlinks=true,
|
| 28 |
+
linkcolor=blue!60!black,
|
| 29 |
+
citecolor=blue!60!black,
|
| 30 |
+
urlcolor=blue!60!black,
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
% ── Inline tag command ────────────────────────────────────────────────
|
| 34 |
+
\newcommand{\etag}[1]{\texttt{[#1]}}
|
| 35 |
+
|
| 36 |
+
% ── Title ─────────────────────────────────────────────────────────────
|
| 37 |
+
\title{\textbf{Evoxtral: Expressive Tagged Transcription\\via Supervised Fine-Tuning and Rejection Sampling}}
|
| 38 |
+
\author[1]{Yongkang Zou}
|
| 39 |
+
\affil[1]{Mistral AI Online Hackathon 2026 --- W\&B Fine-Tuning Track}
|
| 40 |
+
\date{}
|
| 41 |
+
|
| 42 |
+
\begin{document}
|
| 43 |
+
\maketitle
|
| 44 |
+
|
| 45 |
+
% ══════════════════════════════════════════════════════════════════════
|
| 46 |
+
\begin{abstract}
|
| 47 |
+
Standard automatic speech recognition (ASR) systems discard paralinguistic and expressive information present in spoken audio, producing plain text that fails to capture sighs, laughs, hesitations, emotional tone, and other prosodic cues. We present \textbf{Evoxtral}, a LoRA adapter for \texttt{Voxtral-Mini-3B-2507}~\cite{voxtral} that produces transcriptions enriched with inline expressive audio tags drawn from the ElevenLabs v3 tag vocabulary~\cite{elevenlabs}. We apply a two-stage post-training pipeline: supervised fine-tuning (SFT) followed by rejection sampling fine-tuning (RAFT)~\cite{yuan2023rft}. SFT reduces word error rate (WER) by 33\% relative (6.64\%~$\to$~4.47\%) and increases Tag~F1 from 22.0\% to 67.2\%. The subsequent RAFT stage further improves Tag~F1 to 69.4\% and Tag~Recall to 72.7\% at a marginal WER cost. We release two model variants, a serverless inference API, and a live interactive demo.
|
| 48 |
+
\end{abstract}
|
| 49 |
+
|
| 50 |
+
% ══════════════════════════════════════════════════════════════════════
|
| 51 |
+
\section{Introduction}
|
| 52 |
+
|
| 53 |
+
Modern ASR pipelines excel at converting speech to text with low word error rates, yet they systematically strip out the expressive dimension of human communication. When a speaker sighs before a sentence, laughs nervously mid-phrase, or whispers for emphasis, these paralinguistic signals carry meaning that plain transcription cannot represent. This information is especially critical for downstream text-to-speech (TTS) synthesis: next-generation TTS systems such as ElevenLabs v3~\cite{elevenlabs} consume inline expressive tags to control prosody, affect, and delivery at a fine-grained level.
|
| 54 |
+
|
| 55 |
+
We ask: \emph{can a multimodal audio-language model be trained to produce ASR output that preserves expressive content through inline tags?} We answer affirmatively with \textbf{Evoxtral}, built on Voxtral-Mini-3B-2507~\cite{voxtral}.
|
| 56 |
+
|
| 57 |
+
\paragraph{Why Voxtral.} We chose Voxtral-Mini-3B as our base model for several reasons. First, as a \emph{generative} audio-language model (rather than a CTC or encoder-only ASR), Voxtral decodes transcriptions autoregressively---meaning it can naturally produce arbitrary inline tokens such as \etag{laughs} or \etag{nervous} within the text stream, without requiring architectural changes. Traditional ASR models constrain their output vocabulary to words and punctuation; Voxtral's LLM decoder has no such limitation. Second, Voxtral's compact 3B-parameter architecture makes LoRA fine-tuning feasible on a single A10G GPU within hackathon time constraints, while still delivering competitive ASR quality. Third, Voxtral was itself trained with post-training alignment (SFT~+~DPO~+~Online~DPO)~\cite{voxtral}, meaning the model is already instruction-following and amenable to further fine-tuning---a strong foundation for adding new capabilities. Finally, as a Mistral AI model released under Apache~2.0, Voxtral aligns with the hackathon's focus on the Mistral ecosystem and enables open redistribution of our adapters.
|
| 58 |
+
|
| 59 |
+
Our approach fine-tunes Voxtral-Mini-3B-2507 using parameter-efficient LoRA adapters~\cite{hu2021lora} on a synthetically generated dataset of expressive speech paired with tagged transcriptions.
|
| 60 |
+
|
| 61 |
+
\medskip
|
| 62 |
+
To illustrate the contrast, consider the following example:
|
| 63 |
+
|
| 64 |
+
\begin{quote}
|
| 65 |
+
\textbf{Standard ASR:} ``So I was thinking maybe we could try that new restaurant downtown. I mean if you're free this weekend.''
|
| 66 |
+
|
| 67 |
+
\textbf{Evoxtral:} ``\etag{nervous} So\ldots\ \etag{stammers} I was thinking maybe we could\ldots\ \etag{clears throat} try that new restaurant downtown? \etag{laughs nervously} I mean, if you're free this weekend?''
|
| 68 |
+
\end{quote}
|
| 69 |
+
|
| 70 |
+
\noindent The Evoxtral output captures hesitation, nervous laughter, and a throat clear---paralinguistic content that is acoustically present but conventionally discarded. We make the following contributions:
|
| 71 |
+
|
| 72 |
+
\begin{enumerate}[leftmargin=*,itemsep=2pt]
|
| 73 |
+
\item A synthetic dataset of 1,010 expressive speech samples paired with tagged transcriptions across 17 ElevenLabs v3 tag types.
|
| 74 |
+
\item A two-stage fine-tuning recipe (SFT~$\to$~RAFT) for expressive ASR using LoRA on Voxtral-Mini-3B.
|
| 75 |
+
\item A custom evaluation benchmark, \textbf{Evoxtral-Bench}, with seven metrics covering both transcription accuracy and tag generation quality.
|
| 76 |
+
\item Two released model variants optimized for different use cases: accuracy-critical (SFT) and expressiveness-critical (RL).
|
| 77 |
+
\end{enumerate}
|
| 78 |
+
|
| 79 |
+
% ══════════════════════════════════════════════════════════════════════
|
| 80 |
+
\section{Related Work}
|
| 81 |
+
|
| 82 |
+
\paragraph{Voxtral~\cite{voxtral}.} Voxtral-Mini-3B-2507 is a multimodal audio-language model released by Mistral~AI. It is built on a Whisper-based audio encoder fused with a Mistral language model backbone, trained via SFT, DPO, and online DPO. Our work builds directly on this foundation, adding expressive tagging capability via LoRA adaptation.
|
| 83 |
+
|
| 84 |
+
\paragraph{Reinforcement Learning for LLM-based ASR~\cite{grpo_asr}.} Shi et al.\ apply group relative policy optimization (GRPO) to LLM-based ASR, achieving an 18\% relative WER reduction without paired preference data. Closely related work~\cite{grpo_speech,rl_asr_tts,rl_audio_qa} further demonstrates that RL-based training consistently outperforms SFT alone for speech understanding tasks. Our RAFT stage is philosophically aligned with this line of work but uses rule-based rejection sampling rather than policy gradient methods, making it simpler to implement and more stable to train.
|
| 85 |
+
|
| 86 |
+
\paragraph{LoRA~\cite{hu2021lora}.} Low-rank adaptation inserts trainable low-rank matrices into the attention projections and feed-forward layers of a frozen base model, reducing the number of trainable parameters by orders of magnitude while matching full fine-tuning performance. We use LoRA with rank~64 and alpha~128, implemented via HuggingFace PEFT~\cite{peft}.
|
| 87 |
+
|
| 88 |
+
\paragraph{Rejection Sampling Fine-Tuning~\cite{yuan2023rft}.} Yuan et al.\ propose generating multiple model completions for each training input, scoring them with a reward function, and performing SFT on the highest-scoring completions. This approach is computationally simpler than policy gradient methods while still providing a reinforcement learning signal.
|
| 89 |
+
|
| 90 |
+
\paragraph{NEFTune~\cite{neftune}.} Jain et al.\ demonstrate that adding uniform random noise to embedding vectors during training improves instruction-following performance. We apply NEFTune with noise alpha~$=5.0$ during SFT to regularize training on our small dataset.
|
| 91 |
+
|
| 92 |
+
\paragraph{ElevenLabs v3 Audio Tags~\cite{elevenlabs}.} ElevenLabs v3 TTS introduces a structured vocabulary of inline expressive tags that control how synthesized speech is delivered. These tags---such as \etag{sighs}, \etag{laughs}, \etag{whispers}, and \etag{nervous}---are the target output vocabulary for Evoxtral.
|
| 93 |
+
|
| 94 |
+
\paragraph{ASR Evaluation~\cite{jiwer,morris2004wer}.} We compute WER and CER using the \texttt{jiwer} library~\cite{jiwer}, following standard definitions from Morris et al.~\cite{morris2004wer}.
|
| 95 |
+
|
| 96 |
+
% ══════════════════════════════════════════════════════════════════════
|
| 97 |
+
\section{Method}
|
| 98 |
+
|
| 99 |
+
\subsection{Dataset}
|
| 100 |
+
|
| 101 |
+
We construct a synthetic dataset of 1,010 audio samples generated using the ElevenLabs TTS v3 API~\cite{elevenlabs}. Each sample consists of a short spoken utterance (5--30\,s) paired with a reference tagged transcription containing inline expressive tags. The dataset covers 17~tag types: \etag{sighs}, \etag{laughs}, \etag{whispers}, \etag{nervous}, \etag{frustrated}, \etag{clears~throat}, \etag{pause}, \etag{excited}, \etag{stammers}, \etag{gasps}, \etag{sad}, \etag{angry}, \etag{calm}, \etag{crying}, \etag{shouts}, \etag{confused}, and \etag{scared}.
|
| 102 |
+
|
| 103 |
+
The dataset is split into 808~training, 101~validation, and 101~test samples. Tag frequency follows a long-tail distribution: \etag{pause} is the most common while \etag{confused} and \etag{scared} appear rarely. The Whisper-based audio encoder is kept frozen throughout training; only the language model backbone and multi-modal projector are fine-tuned.
|
| 104 |
+
|
| 105 |
+
\subsection{Stage 1: Supervised Fine-Tuning (SFT)}
|
| 106 |
+
|
| 107 |
+
We fine-tune Voxtral-Mini-3B-2507~\cite{voxtral} using LoRA~\cite{hu2021lora} with the configuration shown in Table~\ref{tab:sft_config}.
|
| 108 |
+
|
| 109 |
+
\begin{table}[h]
|
| 110 |
+
\centering
|
| 111 |
+
\caption{SFT hyperparameters.}
|
| 112 |
+
\label{tab:sft_config}
|
| 113 |
+
\begin{tabular}{@{}ll@{}}
|
| 114 |
+
\toprule
|
| 115 |
+
\textbf{Hyperparameter} & \textbf{Value} \\
|
| 116 |
+
\midrule
|
| 117 |
+
LoRA rank & 64 \\
|
| 118 |
+
LoRA alpha & 128 \\
|
| 119 |
+
LoRA dropout & 0.05 \\
|
| 120 |
+
Target modules & \texttt{q/k/v/o\_proj}, \texttt{gate/up/down\_proj}, \texttt{mm\_projector} \\
|
| 121 |
+
Learning rate & $2 \times 10^{-4}$ \\
|
| 122 |
+
LR schedule & Cosine decay \\
|
| 123 |
+
Epochs & 3 \\
|
| 124 |
+
Batch size & 2 (effective 16 via gradient accumulation $\times 8$) \\
|
| 125 |
+
NEFTune noise alpha & 5.0 \\
|
| 126 |
+
Precision & bf16 \\
|
| 127 |
+
Hardware & NVIDIA A10G (24\,GB) \\
|
| 128 |
+
Training time & $\sim$25 minutes \\
|
| 129 |
+
Trainable parameters & 124.8\,M / 4.8\,B (2.6\%) \\
|
| 130 |
+
\bottomrule
|
| 131 |
+
\end{tabular}
|
| 132 |
+
\end{table}
|
| 133 |
+
|
| 134 |
+
The SFT objective is standard next-token prediction (cross-entropy) on the tagged transcription tokens, conditioned on the audio encoder outputs. NEFTune~\cite{neftune} noise is applied to the input embeddings to reduce overfitting on the small training set.
|
| 135 |
+
|
| 136 |
+
\subsection{Stage 2: Rejection Sampling Fine-Tuning (RAFT)}
|
| 137 |
+
|
| 138 |
+
Following Yuan et al.~\cite{yuan2023rft} and inspired by Voxtral's own SFT$\to$DPO recipe~\cite{voxtral} and GRPO-based ASR work~\cite{grpo_asr}, we apply a rejection sampling stage to refine tag generation quality. The pipeline is illustrated in Figure~\ref{fig:raft_pipeline}.
|
| 139 |
+
|
| 140 |
+
\begin{figure}[t]
|
| 141 |
+
\centering
|
| 142 |
+
\begin{tikzpicture}[
|
| 143 |
+
node distance=1.0cm,
|
| 144 |
+
box/.style={draw, rounded corners, minimum width=6.5cm, minimum height=0.9cm, align=center, font=\small},
|
| 145 |
+
arrow/.style={-{Stealth[length=3mm]}, thick},
|
| 146 |
+
]
|
| 147 |
+
\node[box, fill=blue!8] (data) {Training Set (808 samples)};
|
| 148 |
+
\node[box, fill=orange!10, below=of data] (gen) {Generate $N{=}4$ completions per sample\\($T{=}0.7$, top-$p{=}0.9$)};
|
| 149 |
+
\node[box, fill=red!8, below=of gen] (score) {Score each: $R = 0.4(1{-}\text{WER}) + 0.4\,\text{Tag\_F1}$\\$+ \; 0.2(1{-}\text{hallucination\_rate})$};
|
| 150 |
+
\node[box, fill=yellow!10, below=of score] (filter) {Keep best per sample, filter bottom 10\%\\$\to$ 727 curated samples (avg reward: 0.980)};
|
| 151 |
+
\node[box, fill=green!10, below=of filter] (sft) {SFT on curated data\\(lr$\,{=}\,5{\times}10^{-5}$, 1 epoch, $\sim$7\,min)};
|
| 152 |
+
|
| 153 |
+
\draw[arrow] (data) -- (gen);
|
| 154 |
+
\draw[arrow] (gen) -- (score);
|
| 155 |
+
\draw[arrow] (score) -- (filter);
|
| 156 |
+
\draw[arrow] (filter) -- (sft);
|
| 157 |
+
\end{tikzpicture}
|
| 158 |
+
\caption{RAFT training pipeline. For each training sample, four completions are generated and scored by a rule-based reward function balancing transcription accuracy, tag quality, and hallucination avoidance. Only high-reward completions are retained for the final SFT pass.}
|
| 159 |
+
\label{fig:raft_pipeline}
|
| 160 |
+
\end{figure}
|
| 161 |
+
|
| 162 |
+
The reward function balances three objectives: transcription accuracy (WER), tag generation quality (Tag~F1), and avoidance of hallucinated tags:
|
| 163 |
+
\begin{equation}
|
| 164 |
+
R = 0.4 \times (1 - \text{WER}) + 0.4 \times \text{Tag\_F1} + 0.2 \times (1 - \text{Hallucination\_Rate})
|
| 165 |
+
\label{eq:reward}
|
| 166 |
+
\end{equation}
|
| 167 |
+
|
| 168 |
+
The 0.4/0.4/0.2 weighting reflects an equal priority on accuracy and expressiveness, with a penalty for hallucination. After filtering the bottom 10\% by reward (threshold: $R > 0.954$), 727 of the original 808 training samples remain with a mean reward of 0.980. The RAFT SFT stage trains for one epoch at a reduced learning rate of $5 \times 10^{-5}$, completing in approximately 7~minutes with a final training loss of 0.021.
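For intuition, an illustrative completion with $\text{WER}=0.05$, $\text{Tag\_F1}=0.70$, and a hallucination rate of $0.20$ receives $R = 0.4(0.95) + 0.4(0.70) + 0.2(0.80) = 0.82$ under Equation~\ref{eq:reward}.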
|
| 169 |
+
|
| 170 |
+
The full two-stage pipeline overview is shown in Figure~\ref{fig:overview}.
|
| 171 |
+
|
| 172 |
+
\begin{figure}[t]
|
| 173 |
+
\centering
|
| 174 |
+
\begin{tikzpicture}[
|
| 175 |
+
node distance=1.0cm,
|
| 176 |
+
box/.style={draw, rounded corners, minimum width=5.5cm, minimum height=0.8cm, align=center, font=\small},
|
| 177 |
+
result/.style={draw, rounded corners, minimum width=5.5cm, minimum height=0.8cm, align=center, font=\small, fill=green!8},
|
| 178 |
+
arrow/.style={-{Stealth[length=3mm]}, thick},
|
| 179 |
+
label/.style={font=\scriptsize\itshape, text=gray!70!black},
|
| 180 |
+
]
|
| 181 |
+
\node[box, fill=blue!10] (base) {Voxtral-Mini-3B-2507\\(frozen audio encoder)};
|
| 182 |
+
\node[box, fill=orange!10, below=of base] (sft) {Stage 1: SFT\\(808 samples, 3 epochs)};
|
| 183 |
+
\node[result, below=of sft] (sft_r) {\textbf{Evoxtral-SFT}\\WER\,=\,4.47\%, Tag F1\,=\,67.2\%};
|
| 184 |
+
\node[box, fill=red!8, below=of sft_r] (raft) {Stage 2: RAFT\\(727 curated samples, 1 epoch)};
|
| 185 |
+
\node[result, below=of raft] (rl_r) {\textbf{Evoxtral-RL}\\WER\,=\,5.12\%, Tag F1\,=\,69.4\%};
|
| 186 |
+
|
| 187 |
+
\draw[arrow] (base) -- (sft) node[midway, right=0.3cm, label] {lr\,$=$\,$2{\times}10^{-4}$, NEFTune};
|
| 188 |
+
\draw[arrow] (sft) -- (sft_r);
|
| 189 |
+
\draw[arrow] (sft_r) -- (raft) node[midway, right=0.3cm, label] {lr\,$=$\,$5{\times}10^{-5}$, reward filter};
|
| 190 |
+
\draw[arrow] (raft) -- (rl_r);
|
| 191 |
+
\end{tikzpicture}
|
| 192 |
+
\caption{Full two-stage Evoxtral training pipeline from base Voxtral-Mini-3B to the two released model variants.}
|
| 193 |
+
\label{fig:overview}
|
| 194 |
+
\end{figure}
|
| 195 |
+
|
| 196 |
+
% ══════════════════════════════════════════════════════════════════════
|
| 197 |
+
\section{Evaluation}
|
| 198 |
+
|
| 199 |
+
\subsection{Evoxtral-Bench}
|
| 200 |
+
|
| 201 |
+
We evaluate on \textbf{Evoxtral-Bench}, a held-out benchmark of 50 test samples drawn from the 101-sample test split. We compute seven evaluation metrics:
|
| 202 |
+
|
| 203 |
+
\begin{itemize}[leftmargin=*,itemsep=1pt]
|
| 204 |
+
\item \textbf{WER} (Word Error Rate, $\downarrow$) --- via \texttt{jiwer}~\cite{jiwer}, tags stripped before comparison
|
| 205 |
+
\item \textbf{CER} (Character Error Rate, $\downarrow$) --- character-level accuracy via \texttt{jiwer}
|
| 206 |
+
\item \textbf{Tag F1} ($\uparrow$) --- token-level F1 on predicted vs.\ reference tag multisets
|
| 207 |
+
\item \textbf{Tag Precision} ($\uparrow$) --- fraction of predicted tags present in reference
|
| 208 |
+
\item \textbf{Tag Recall} ($\uparrow$) --- fraction of reference tags captured by the model
|
| 209 |
+
\item \textbf{Tag Hallucination Rate} ($\downarrow$) --- fraction of predicted tags absent from reference
|
| 210 |
+
\item \textbf{Emphasis F1} ($\uparrow$) --- F1 on CAPITALIZED emphasis words
|
| 211 |
+
\end{itemize}
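As an illustrative example of the multiset computation: if the reference contains \etag{sighs}, \etag{pause}, \etag{pause} and the prediction contains \etag{sighs}, \etag{pause}, \etag{laughs}, two tag tokens match, giving Tag Precision $=2/3$, Tag Recall $=2/3$, Tag F1 $=2/3$, and a Tag Hallucination Rate of $1/3$ (the unmatched \etag{laughs}).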
|
| 212 |
+
|
| 213 |
+
Per-tag F1 is additionally computed across all 17 tag types.
|
| 214 |
+
|
| 215 |
+
\subsection{Core Results}
|
| 216 |
+
|
| 217 |
+
\begin{table}[t]
|
| 218 |
+
\centering
|
| 219 |
+
\caption{Core evaluation results on Evoxtral-Bench (50 samples). Bold indicates best per metric. $\downarrow$ lower is better; $\uparrow$ higher is better.}
|
| 220 |
+
\label{tab:core_results}
|
| 221 |
+
\begin{tabular}{@{}lcccc@{}}
|
| 222 |
+
\toprule
|
| 223 |
+
\textbf{Metric} & \textbf{Base Voxtral} & \textbf{Evoxtral SFT} & \textbf{Evoxtral RL} & \textbf{Best} \\
|
| 224 |
+
\midrule
|
| 225 |
+
WER $\downarrow$ & 6.64\% & \textbf{4.47\%} & 5.12\% & SFT \\
|
| 226 |
+
CER $\downarrow$ & 2.72\% & \textbf{1.23\%} & 1.48\% & SFT \\
|
| 227 |
+
Tag F1 $\uparrow$ & 22.0\% & 67.2\% & \textbf{69.4\%} & RL \\
|
| 228 |
+
Tag Precision $\uparrow$ & 22.0\% & 67.4\% & \textbf{68.5\%} & RL \\
|
| 229 |
+
Tag Recall $\uparrow$ & 22.0\% & 69.4\% & \textbf{72.7\%} & RL \\
|
| 230 |
+
Emphasis F1 $\uparrow$ & 42.0\% & 84.0\% & \textbf{86.0\%} & RL \\
|
| 231 |
+
Tag Hallucination $\downarrow$ & 0.0\% & \textbf{19.3\%} & 20.2\% & SFT \\
|
| 232 |
+
\bottomrule
|
| 233 |
+
\end{tabular}
|
| 234 |
+
\end{table}
|
| 235 |
+
|
| 236 |
+
The base Voxtral model achieves 22.0\% Tag~F1, suggesting some limited native capability to produce expressive tokens but with low precision and recall. SFT provides the dominant improvement: WER decreases by 33\% relative (6.64\%~$\to$~4.47\%) and Tag~F1 increases by 45~percentage points (22.0\%~$\to$~67.2\%). RAFT further refines tag metrics: Tag~F1 improves by 2.2\,pp (67.2\%~$\to$~69.4\%), Tag~Recall by 3.3\,pp (69.4\%~$\to$~72.7\%), and Emphasis~F1 by 2.0\,pp (84.0\%~$\to$~86.0\%). However, RAFT introduces a small WER regression (4.47\%~$\to$~5.12\%), reflecting a Pareto tradeoff between transcription accuracy and expressive richness.
|
| 237 |
+
|
| 238 |
+
Tag hallucination---predicted tags absent from the reference---is 19.3\% for SFT and 20.2\% for RL. The base model has 0\% hallucination because it rarely predicts any tags at all.
|
| 239 |
+
|
| 240 |
+
\subsection{Per-Tag F1 Breakdown}
|
| 241 |
+
|
| 242 |
+
\begin{table}[t]
|
| 243 |
+
\centering
|
| 244 |
+
\caption{Per-tag F1 scores for SFT and RL on Evoxtral-Bench. Support indicates the number of test samples containing each tag.}
|
| 245 |
+
\label{tab:per_tag}
|
| 246 |
+
\begin{tabular}{@{}lcccc@{}}
|
| 247 |
+
\toprule
|
| 248 |
+
\textbf{Tag} & \textbf{SFT F1} & \textbf{RL F1} & \textbf{$\Delta$} & \textbf{Support} \\
|
| 249 |
+
\midrule
|
| 250 |
+
\etag{sighs} & 1.000 & 1.000 & --- & 9 \\
|
| 251 |
+
\etag{clears throat} & 0.889 & \textbf{1.000} & +12.5\% & 8 \\
|
| 252 |
+
\etag{gasps} & 0.957 & 0.957 & --- & 12 \\
|
| 253 |
+
\etag{pause} & 0.885 & \textbf{0.902} & +1.9\% & 25 \\
|
| 254 |
+
\etag{nervous} & 0.800 & \textbf{0.846} & +5.8\% & 13 \\
|
| 255 |
+
\etag{stammers} & \textbf{0.889} & 0.842 & $-$5.3\% & 8 \\
|
| 256 |
+
\etag{laughs} & 0.800 & \textbf{0.815} & +1.9\% & 12 \\
|
| 257 |
+
\etag{sad} & 0.667 & \textbf{0.750} & +12.4\% & 4 \\
|
| 258 |
+
\etag{whispers} & 0.636 & \textbf{0.667} & +4.9\% & 13 \\
|
| 259 |
+
\etag{crying} & \textbf{0.750} & 0.571 & $-$23.9\% & 5 \\
|
| 260 |
+
\etag{excited} & \textbf{0.615} & 0.571 & $-$7.2\% & 5 \\
|
| 261 |
+
\etag{shouts} & 0.400 & \textbf{0.500} & +25.0\% & 3 \\
|
| 262 |
+
\etag{calm} & 0.200 & \textbf{0.400} & +100\% & 6 \\
|
| 263 |
+
\etag{frustrated} & 0.444 & 0.444 & --- & 3 \\
|
| 264 |
+
\etag{angry} & 0.667 & 0.667 & --- & 2 \\
|
| 265 |
+
\etag{confused} & 0.000 & 0.000 & --- & 1 \\
|
| 266 |
+
\etag{scared} & 0.000 & 0.000 & --- & 1 \\
|
| 267 |
+
\bottomrule
|
| 268 |
+
\end{tabular}
|
| 269 |
+
\end{table}
|
| 270 |
+
|
| 271 |
+
RAFT improves 8~tags, leaves 6~unchanged, and regresses on~3 (Table~\ref{tab:per_tag}). The largest gains are observed for \etag{calm} (+100\%, 0.200~$\to$~0.400), \etag{shouts} (+25.0\%), \etag{clears~throat} (+12.5\%), and \etag{sad} (+12.4\%). Regressions are noted for \etag{crying} ($-$23.9\%), \etag{excited} ($-$7.2\%), and \etag{stammers} ($-$5.3\%). The two zero-F1 tags (\etag{confused}, \etag{scared}) each appear only once in the test set, making estimation unreliable.
|
| 272 |
+
|
| 273 |
+
% ══════════════════════════════════════════════════════════════════════
|
| 274 |
+
\section{Analysis and Discussion}
|
| 275 |
+
|
| 276 |
+
\paragraph{SFT as the primary driver of improvement.} The SFT stage accounts for the vast majority of the performance gain: WER drops 33\% relative and Tag~F1 increases by 45~percentage points. This aligns with findings from GRPO-based ASR work~\cite{grpo_asr,grpo_speech,rl_asr_tts} suggesting that a well-supervised initial adaptation is a strong foundation for subsequent RL refinement.
|
| 277 |
+
|
| 278 |
+
\paragraph{The WER--Tag tradeoff.} RAFT improves tag metrics at the cost of a modest WER regression (4.47\%~$\to$~5.12\%). This suggests the existence of a Pareto frontier between transcription accuracy and expressive richness: optimizing for tag generation pushes the model toward producing more tags, which can introduce minor word-level errors. This motivates releasing two model variants---Evoxtral-SFT for accuracy-critical applications (e.g., professional transcription) and Evoxtral-RL for expressiveness-critical applications (e.g., downstream TTS synthesis with ElevenLabs v3~\cite{elevenlabs}).
|
| 279 |
+
|
| 280 |
+
\paragraph{Tag hallucination.} Approximately 20\% of predicted tags are not present in the reference transcription. Hallucination may occur when the model infers an expressive tone from acoustic cues that are present in the audio but absent or differently annotated in the reference. This may partly reflect annotation noise in synthetic data rather than pure model error. Future work should address this with contrastive or calibration-based training objectives.
|
| 281 |
+
|
| 282 |
+
\paragraph{Effect of NEFTune.} Applying NEFTune~\cite{neftune} with noise alpha~$=5.0$ during SFT provided a regularization benefit on the small 808-sample training set, consistent with Jain et al.'s findings on instruction-following tasks. Ablating this component was not feasible within hackathon time constraints but remains a planned analysis.
|
| 283 |
+
|
| 284 |
+
\paragraph{Rare tag performance.} Tags with very low test support (\etag{confused}, \etag{scared}, support\,$=$\,1) have zero~F1, which is uninformative. Tags with support 2--6 show high variance in F1 estimates. A larger, more balanced evaluation set would provide more reliable per-tag metrics.
|
| 285 |
+
|
| 286 |
+
\paragraph{Reward function design.} The RAFT reward (Equation~\ref{eq:reward}) explicitly encodes the design preference for equal weight on accuracy and expressiveness. The 0.2 weight on hallucination acts as a weak regularizer. An ablation across reward weightings would quantify the sensitivity of the final model to this design choice.
|
| 287 |
+
|
| 288 |
+
% ══════════════════════════════════════════════════════════════════════
|
| 289 |
+
\section{Limitations}
|
| 290 |
+
|
| 291 |
+
\begin{itemize}[leftmargin=*,itemsep=2pt]
|
| 292 |
+
\item \textbf{Synthetic training data.} All 1,010 samples are synthesized using ElevenLabs TTS v3~\cite{elevenlabs}. The acoustic properties of synthetic speech differ from natural human speech. Performance on natural speech recordings may differ.
|
| 293 |
+
\item \textbf{Tag hallucination.} Approximately 20\% of predicted tags in the RL model are not present in the reference, which may limit applicability in settings requiring precise expressive annotation.
|
| 294 |
+
\item \textbf{Rare tag coverage.} Seventeen tag types are represented, but several occur in fewer than 5 test samples. Per-tag F1 estimates for rare categories are unreliable.
|
| 295 |
+
\item \textbf{English only.} The dataset and training are English-only. Generalization to other languages is not evaluated.
|
| 296 |
+
\item \textbf{Small dataset.} 808 training samples is a small fine-tuning set. Scaling to thousands of examples with natural speech could substantially improve performance.
|
| 297 |
+
\item \textbf{Evaluation scope.} Evoxtral-Bench covers 50 test samples. A larger evaluation set would yield more statistically robust estimates.
|
| 298 |
+
\end{itemize}
|
| 299 |
+
|
| 300 |
+
% ══════════════════════════════════════════════════════════════════════
|
| 301 |
+
\section{Conclusion}
|
| 302 |
+
|
| 303 |
+
We presented Evoxtral, a LoRA-adapted version of Voxtral-Mini-3B-2507~\cite{voxtral} that produces expressive tagged transcriptions using ElevenLabs v3 audio tags~\cite{elevenlabs}. Our two-stage training pipeline---SFT followed by RAFT~\cite{yuan2023rft}---demonstrates that expressive tagging capability can be effectively injected into a pre-trained ASR model with parameter-efficient fine-tuning~\cite{hu2021lora,peft}.
|
| 304 |
+
|
| 305 |
+
SFT achieves a 33\% relative WER reduction and a 45~percentage-point improvement in Tag~F1 over the base model. RAFT further improves tag recall and F1 by targeting tag generation quality directly through a rule-based reward signal, at a modest transcription accuracy cost. The two resulting model variants cover a Pareto frontier between accuracy and expressiveness, allowing practitioners to select the appropriate trade-off for their application.
|
| 306 |
+
|
| 307 |
+
Future directions include: (1)~collecting natural speech data with crowd-sourced expressive annotations to reduce the synthetic data gap; (2)~replacing RAFT with GRPO~\cite{grpo_asr} or DPO~\cite{voxtral} for more sample-efficient RL training; (3)~expanding to multilingual settings leveraging Voxtral's multilingual audio encoder; and (4)~developing joint ASR+TTS evaluation protocols that measure downstream TTS quality when Evoxtral output is used as input to ElevenLabs v3~\cite{elevenlabs}.
|
| 308 |
+
|
| 309 |
+
% ══════════════════════════════════════════════════════════════════════
|
| 310 |
+
\bibliographystyle{plainnat}
|
| 311 |
+
|
| 312 |
+
\begin{thebibliography}{12}
|
| 313 |
+
|
| 314 |
+
\bibitem[{Mistral AI}(2025)]{voxtral}
|
| 315 |
+
Mistral AI.
|
| 316 |
+
\newblock Voxtral.
|
| 317 |
+
\newblock \emph{arXiv preprint arXiv:2507.13264}, 2025.
|
| 318 |
+
\newblock \url{https://arxiv.org/abs/2507.13264}
|
| 319 |
+
|
| 320 |
+
\bibitem[{Shi et~al.}(2025)]{grpo_asr}
|
| 321 |
+
Shi, B. et~al.
|
| 322 |
+
\newblock Group relative policy optimization for speech recognition.
|
| 323 |
+
\newblock \emph{arXiv preprint arXiv:2509.01939}, 2025.
|
| 324 |
+
\newblock \url{https://arxiv.org/abs/2509.01939}
|
| 325 |
+
|
| 326 |
+
\bibitem[{Hu et~al.}(2022)]{hu2021lora}
|
| 327 |
+
Hu, E.J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S., Wang, L., and Chen, W.
|
| 328 |
+
\newblock {LoRA}: Low-rank adaptation of large language models.
|
| 329 |
+
\newblock In \emph{Proceedings of ICLR}, 2022.
|
| 330 |
+
\newblock \url{https://arxiv.org/abs/2106.09685}
|
| 331 |
+
|
| 332 |
+
\bibitem[{Yuan et~al.}(2023)]{yuan2023rft}
|
| 333 |
+
Yuan, Z., Yuan, H., Li, C., Dong, G., Tan, C., and Zhou, C.
|
| 334 |
+
\newblock Scaling relationship on learning mathematical reasoning with large language models.
|
| 335 |
+
\newblock \emph{arXiv preprint arXiv:2308.01825}, 2023.
|
| 336 |
+
\newblock \url{https://arxiv.org/abs/2308.01825}
|
| 337 |
+
|
| 338 |
+
\bibitem[{Jain et~al.}(2024)]{neftune}
|
| 339 |
+
Jain, N., Chiang, P., Yeh, Y., Kirchenbauer, J., et~al.
|
| 340 |
+
\newblock {NEFTune}: Noisy embeddings improve instruction finetuning.
|
| 341 |
+
\newblock In \emph{Proceedings of ICLR}, 2024.
|
| 342 |
+
\newblock \url{https://arxiv.org/abs/2310.05914}
|
| 343 |
+
|
| 344 |
+
\bibitem[{ElevenLabs}(2025)]{elevenlabs}
|
| 345 |
+
ElevenLabs.
|
| 346 |
+
\newblock Text-to-speech v3 audio tags.
|
| 347 |
+
\newblock ElevenLabs Developer Documentation, 2025.
|
| 348 |
+
\newblock \url{https://elevenlabs.io/docs/api-reference/text-to-speech}
|
| 349 |
+
|
| 350 |
+
\bibitem[{JiWER}(2024)]{jiwer}
|
| 351 |
+
JiWER: Evaluate your speech recognition system.
|
| 352 |
+
\newblock Python library for ASR evaluation metrics.
|
| 353 |
+
\newblock \url{https://github.com/jitsi/jiwer}
|
| 354 |
+
|
| 355 |
+
\bibitem[{}(2025a)]{grpo_speech}
|
| 356 |
+
Advancing speech understanding in speech-aware language models with {GRPO}.
|
| 357 |
+
\newblock \emph{arXiv preprint arXiv:2509.16990}, 2025.
|
| 358 |
+
\newblock \url{https://arxiv.org/abs/2509.16990}
|
| 359 |
+
|
| 360 |
+
\bibitem[{}(2025b)]{rl_asr_tts}
|
| 361 |
+
Explore the reinforcement learning for {LLM}-based {ASR} and {TTS} system.
|
| 362 |
+
\newblock \emph{arXiv preprint arXiv:2509.18569}, 2025.
|
| 363 |
+
\newblock \url{https://arxiv.org/abs/2509.18569}
|
| 364 |
+
|
| 365 |
+
\bibitem[{Mangrulkar et~al.}(2022)]{peft}
|
| 366 |
+
Mangrulkar, S., Gugger, S., Debut, L., Belkada, Y., Paul, S., and Bossan, B.
|
| 367 |
+
\newblock {PEFT}: State-of-the-art parameter-efficient fine-tuning methods.
|
| 368 |
+
\newblock HuggingFace, 2022.
|
| 369 |
+
\newblock \url{https://github.com/huggingface/peft}
|
| 370 |
+
|
| 371 |
+
\bibitem[{Morris et~al.}(2004)]{morris2004wer}
|
| 372 |
+
Morris, A.C., Maier, V., and Green, P.
|
| 373 |
+
\newblock From {WER} and {RIL} to {MER} and {WIL}: Improved evaluation measures for connected speech recognition.
|
| 374 |
+
\newblock In \emph{Proceedings of INTERSPEECH}, 2004.
|
| 375 |
+
|
| 376 |
+
\bibitem[{}(2025c)]{rl_audio_qa}
|
| 377 |
+
Reinforcement learning outperforms supervised fine-tuning: A case study on audio question answering.
|
| 378 |
+
\newblock \emph{arXiv preprint arXiv:2503.11197}, 2025.
|
| 379 |
+
\newblock \url{https://arxiv.org/abs/2503.11197}
|
| 380 |
+
|
| 381 |
+
\end{thebibliography}
|
| 382 |
+
|
| 383 |
+
\end{document}
|
space/app.py
CHANGED
|
@@ -8,7 +8,7 @@ from transformers import VoxtralForConditionalGeneration, AutoProcessor
|
|
| 8 |
from peft import PeftModel
|
| 9 |
|
| 10 |
MODEL_ID = "mistralai/Voxtral-Mini-3B-2507"
|
| 11 |
-
ADAPTER_ID = "YongkangZOU/evoxtral-
|
| 12 |
|
| 13 |
# Load model on CPU at startup, ZeroGPU moves to GPU on demand
|
| 14 |
print("Loading model...")
|
|
|
|
| 8 |
from peft import PeftModel
|
| 9 |
|
| 10 |
MODEL_ID = "mistralai/Voxtral-Mini-3B-2507"
|
| 11 |
+
ADAPTER_ID = "YongkangZOU/evoxtral-rl"
|
| 12 |
|
| 13 |
# Load model on CPU at startup, ZeroGPU moves to GPU on demand
|
| 14 |
print("Loading model...")
|
training/scripts/rl_modal.py
ADDED
|
@@ -0,0 +1,506 @@
|
| 1 |
+
"""Evoxtral RL — Rejection sampling + SFT on best completions (RAFT).
|
| 2 |
+
|
| 3 |
+
Follows the GRPO-for-ASR approach (arXiv:2509.01939), simplified for the hackathon:
|
| 4 |
+
1. Generate N completions per training sample (with sampling)
|
| 5 |
+
2. Score each with rule-based reward (WER + Tag F1 + hallucination penalty)
|
| 6 |
+
3. Keep best completion per sample
|
| 7 |
+
4. SFT on the curated high-quality dataset (1 epoch, lower LR)
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
modal run scripts/rl_modal.py
|
| 11 |
+
modal run scripts/rl_modal.py --num-samples 4 --lr 5e-5
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import modal
|
| 16 |
+
|
| 17 |
+
image = (
|
| 18 |
+
modal.Image.debian_slim(python_version="3.11")
|
| 19 |
+
.apt_install("ffmpeg", "libsndfile1")
|
| 20 |
+
.pip_install(
|
| 21 |
+
"torch>=2.4.0",
|
| 22 |
+
"torchaudio>=2.4.0",
|
| 23 |
+
"transformers==4.56.0",
|
| 24 |
+
"datasets>=2.14.0",
|
| 25 |
+
"accelerate>=1.0.0",
|
| 26 |
+
"peft>=0.13.0",
|
| 27 |
+
"wandb>=0.18.0",
|
| 28 |
+
"jiwer>=3.0.0",
|
| 29 |
+
"librosa>=0.10.0",
|
| 30 |
+
"soundfile>=0.12.0",
|
| 31 |
+
"huggingface_hub",
|
| 32 |
+
"safetensors",
|
| 33 |
+
"sentencepiece",
|
| 34 |
+
"mistral-common",
|
| 35 |
+
"torchcodec",
|
| 36 |
+
gpu="A10G",
|
| 37 |
+
)
|
| 38 |
+
.env({
|
| 39 |
+
"HF_HUB_CACHE": "/cache/huggingface",
|
| 40 |
+
})
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
app = modal.App("evoxtral-rl", image=image)
|
| 44 |
+
|
| 45 |
+
hf_cache = modal.Volume.from_name("evoxtral-hf-cache", create_if_missing=True)
|
| 46 |
+
data_vol = modal.Volume.from_name("evoxtral-data", create_if_missing=True)
|
| 47 |
+
output_vol = modal.Volume.from_name("evoxtral-output", create_if_missing=True)
|
| 48 |
+
|
| 49 |
+
VOLUMES = {
|
| 50 |
+
"/cache/huggingface": hf_cache,
|
| 51 |
+
"/data": data_vol,
|
| 52 |
+
"/output": output_vol,
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
MODEL_ID = "mistralai/Voxtral-Mini-3B-2507"
|
| 56 |
+
SFT_ADAPTER = "/output/evoxtral-lora"
|
| 57 |
+
RL_OUTPUT = "/output/evoxtral-rl"
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@app.function(
|
| 61 |
+
gpu="A10G",
|
| 62 |
+
volumes=VOLUMES,
|
| 63 |
+
secrets=[
|
| 64 |
+
modal.Secret.from_name("wandb-secret"),
|
| 65 |
+
modal.Secret.from_name("huggingface-secret"),
|
| 66 |
+
],
|
| 67 |
+
timeout=7200,
|
| 68 |
+
memory=32768,
|
| 69 |
+
)
|
| 70 |
+
def generate_and_score(num_samples: int = 4, temperature: float = 0.7):
|
| 71 |
+
"""Step 1: Generate N completions per sample and score them."""
|
| 72 |
+
import torch
|
| 73 |
+
import json
|
| 74 |
+
import re
|
| 75 |
+
from pathlib import Path
|
| 76 |
+
from collections import Counter
|
| 77 |
+
from jiwer import wer as compute_wer
|
| 78 |
+
from datasets import load_from_disk, Audio
|
| 79 |
+
from transformers import VoxtralForConditionalGeneration, AutoProcessor
|
| 80 |
+
from peft import PeftModel
|
| 81 |
+
|
| 82 |
+
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
| 83 |
+
|
| 84 |
+
# --- Reward helpers ---
|
| 85 |
+
def extract_tags(text):
|
| 86 |
+
return [m.group(1).lower() for m in re.finditer(r'\[([^\]]+)\]', text)]
|
| 87 |
+
|
| 88 |
+
def strip_tags(text):
|
| 89 |
+
text = re.sub(r'\[[^\]]+\]\s*', '', text)
|
| 90 |
+
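# Lowercase CAPITALIZED emphasis words so emphasis styling is not scored as word errors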
text = re.sub(r'\b[A-Z]{2,}\b', lambda m: m.group(0).lower(), text)
|
| 91 |
+
return text.strip()
|
| 92 |
+
|
| 93 |
+
def compute_reward(prediction, reference):
|
| 94 |
+
"""Rule-based reward: WER accuracy + Tag F1 - hallucination penalty."""
|
| 95 |
+
# WER component (accuracy = 1 - WER)
|
| 96 |
+
ref_plain = strip_tags(reference)
|
| 97 |
+
pred_plain = strip_tags(prediction)
|
| 98 |
+
if ref_plain.strip():
|
| 99 |
+
wer_score = compute_wer(ref_plain, pred_plain)
|
| 100 |
+
wer_accuracy = max(0.0, 1.0 - wer_score)
|
| 101 |
+
else:
|
| 102 |
+
wer_accuracy = 1.0
|
| 103 |
+
|
| 104 |
+
# Tag F1 component
|
| 105 |
+
pred_tags = Counter(extract_tags(prediction))
|
| 106 |
+
ref_tags = Counter(extract_tags(reference))
|
| 107 |
+
if not ref_tags and not pred_tags:
|
| 108 |
+
tag_f1 = 1.0
|
| 109 |
+
hall_rate = 0.0
|
| 110 |
+
elif not ref_tags or not pred_tags:
|
| 111 |
+
tag_f1 = 0.0
|
| 112 |
+
hall_rate = 1.0 if pred_tags else 0.0
|
| 113 |
+
else:
|
| 114 |
+
common = sum((pred_tags & ref_tags).values())
|
| 115 |
+
prec = common / sum(pred_tags.values())
|
| 116 |
+
rec = common / sum(ref_tags.values())
|
| 117 |
+
tag_f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
|
| 118 |
+
# Hallucination rate
|
| 119 |
+
ref_set = set(ref_tags.keys())
|
| 120 |
+
hallucinated = sum(v for k, v in pred_tags.items() if k not in ref_set)
|
| 121 |
+
hall_rate = hallucinated / sum(pred_tags.values())
|
| 122 |
+
|
| 123 |
+
# Combined reward
|
| 124 |
+
reward = 0.4 * wer_accuracy + 0.4 * tag_f1 + 0.2 * (1.0 - hall_rate)
|
| 125 |
+
return reward, {"wer_accuracy": wer_accuracy, "tag_f1": tag_f1, "hall_rate": hall_rate}
|
| 126 |
+
|
| 127 |
+
# --- Load model ---
|
| 128 |
+
print("Loading SFT model...")
|
| 129 |
+
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
| 130 |
+
base_model = VoxtralForConditionalGeneration.from_pretrained(
|
| 131 |
+
MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto",
|
| 132 |
+
)
|
| 133 |
+
model = PeftModel.from_pretrained(base_model, SFT_ADAPTER)
|
| 134 |
+
model.eval()
|
| 135 |
+
print(f"Model loaded on {model.device}")
|
| 136 |
+
|
| 137 |
+
# --- Load training data ---
|
| 138 |
+
ds = load_from_disk("/data/processed")
|
| 139 |
+
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
|
| 140 |
+
train_ds = ds["train"]
|
| 141 |
+
|
| 142 |
+
print(f"Generating {num_samples} completions per sample for {len(train_ds)} training examples...")
|
| 143 |
+
import time
|
| 144 |
+
start_time = time.time()
|
| 145 |
+
|
| 146 |
+
# Resume from checkpoint if exists
|
| 147 |
+
checkpoint_path = "/output/rl_curated_checkpoint.json"
|
| 148 |
+
if Path(checkpoint_path).exists():
|
| 149 |
+
with open(checkpoint_path) as f:
|
| 150 |
+
curated_data = json.load(f)
|
| 151 |
+
start_idx = len(curated_data)
|
| 152 |
+
total_reward = sum(d["reward"] for d in curated_data)
|
| 153 |
+
print(f"Resuming from checkpoint: {start_idx} samples already done")
|
| 154 |
+
else:
|
| 155 |
+
curated_data = []
|
| 156 |
+
total_reward = 0.0
|
| 157 |
+
start_idx = 0
|
| 158 |
+
|
| 159 |
+
for i in range(start_idx, len(train_ds)):
|
| 160 |
+
row = train_ds[i]
|
| 161 |
+
reference = row["tagged_text"]
|
| 162 |
+
audio_array = row["audio"]["array"]
|
| 163 |
+
|
| 164 |
+
# Build inputs
|
| 165 |
+
inputs = processor.apply_transcription_request(
|
| 166 |
+
language="en",
|
| 167 |
+
audio=[audio_array],
|
| 168 |
+
format=["WAV"],
|
| 169 |
+
model_id=MODEL_ID,
|
| 170 |
+
return_tensors="pt",
|
| 171 |
+
)
|
| 172 |
+
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
| 173 |
+
|
| 174 |
+
# Generate all N completions in one call with num_return_sequences
|
| 175 |
+
with torch.no_grad():
|
| 176 |
+
output_ids = model.generate(
|
| 177 |
+
**inputs,
|
| 178 |
+
max_new_tokens=512,
|
| 179 |
+
do_sample=True,
|
| 180 |
+
temperature=temperature,
|
| 181 |
+
top_p=0.9,
|
| 182 |
+
num_return_sequences=num_samples,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
input_len = inputs["input_ids"].shape[1]
|
| 186 |
+
|
| 187 |
+
# Score each completion, keep best
|
| 188 |
+
best_reward = -1.0
|
| 189 |
+
best_prediction = None
|
| 190 |
+
best_details = None
|
| 191 |
+
|
| 192 |
+
for s in range(num_samples):
|
| 193 |
+
prediction = processor.tokenizer.decode(
|
| 194 |
+
output_ids[s][input_len:], skip_special_tokens=True
|
| 195 |
+
)
|
| 196 |
+
reward, details = compute_reward(prediction, reference)
|
| 197 |
+
if reward > best_reward:
|
| 198 |
+
best_reward = reward
|
| 199 |
+
best_prediction = prediction
|
| 200 |
+
best_details = details
|
| 201 |
+
|
| 202 |
+
curated_data.append({
|
| 203 |
+
"audio_idx": i,
|
| 204 |
+
"reference": reference,
|
| 205 |
+
"best_prediction": best_prediction,
|
| 206 |
+
"reward": best_reward,
|
| 207 |
+
**best_details,
|
| 208 |
+
})
|
| 209 |
+
total_reward += best_reward
|
| 210 |
+
|
| 211 |
+
if i < 5 or i % 50 == 0:
|
| 212 |
+
elapsed = time.time() - start_time
|
| 213 |
+
done = i - start_idx + 1
|
| 214 |
+
rate = done / elapsed if elapsed > 0 else 0
|
| 215 |
+
eta = (len(train_ds) - i - 1) / rate if rate > 0 else 0
|
| 216 |
+
print(f" [{i}/{len(train_ds)}] reward={best_reward:.3f} "
|
| 217 |
+
f"wer_acc={best_details['wer_accuracy']:.3f} "
|
| 218 |
+
f"tag_f1={best_details['tag_f1']:.3f} "
|
| 219 |
+
f"hall={best_details['hall_rate']:.3f} "
|
| 220 |
+
f"({rate:.1f} samples/s, ETA {eta/60:.0f}min)")
|
| 221 |
+
if i < 3:
|
| 222 |
+
print(f" ref: {reference[:80]}...")
|
| 223 |
+
print(f" best: {best_prediction[:80]}...")
|
| 224 |
+
|
| 225 |
+
# Save checkpoint every 50 samples
|
| 226 |
+
if (i + 1) % 50 == 0:
|
| 227 |
+
with open(checkpoint_path, "w") as f:
|
| 228 |
+
json.dump(curated_data, f)
|
| 229 |
+
output_vol.commit()
|
| 230 |
+
print(f" [checkpoint saved: {len(curated_data)} samples]")
|
| 231 |
+
|
| 232 |
+
avg_reward = total_reward / len(curated_data)
|
| 233 |
+
print(f"\nGeneration complete! Avg reward: {avg_reward:.4f}")
|
| 234 |
+
print(f"Curated {len(curated_data)} samples")
|
| 235 |
+
|
| 236 |
+
# Save curated data
|
| 237 |
+
output_path = "/output/rl_curated_data.json"
|
| 238 |
+
with open(output_path, "w") as f:
|
| 239 |
+
json.dump(curated_data, f)
|
| 240 |
+
output_vol.commit()
|
| 241 |
+
print(f"Saved curated data to {output_path}")
|
| 242 |
+
|
| 243 |
+
return {"avg_reward": avg_reward, "num_samples": len(curated_data)}
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
@app.function(
|
| 247 |
+
gpu="A10G",
|
| 248 |
+
volumes=VOLUMES,
|
| 249 |
+
secrets=[
|
| 250 |
+
modal.Secret.from_name("wandb-secret"),
|
| 251 |
+
modal.Secret.from_name("huggingface-secret"),
|
| 252 |
+
],
|
| 253 |
+
timeout=7200,
|
| 254 |
+
memory=32768,
|
| 255 |
+
)
|
| 256 |
+
def rl_finetune(learning_rate: float = 5e-5, num_epochs: int = 1, push_to_hub: bool = True):
|
| 257 |
+
"""Step 2: SFT on curated best completions (RAFT)."""
|
| 258 |
+
import torch
|
| 259 |
+
import wandb
|
| 260 |
+
import json
|
| 261 |
+
from pathlib import Path
|
| 262 |
+
from datasets import Dataset, Audio, load_from_disk
|
| 263 |
+
from transformers import (
|
| 264 |
+
VoxtralForConditionalGeneration,
|
| 265 |
+
AutoProcessor,
|
| 266 |
+
TrainingArguments,
|
| 267 |
+
Trainer,
|
| 268 |
+
)
|
| 269 |
+
from peft import PeftModel
|
| 270 |
+
|
| 271 |
+
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
| 272 |
+
|
| 273 |
+
```python
    # --- Load curated data ---
    with open("/output/rl_curated_data.json") as f:
        curated_data = json.load(f)
    print(f"Loaded {len(curated_data)} curated samples")

    # Filter out low-reward samples (bottom 10%)
    rewards = [d["reward"] for d in curated_data]
    threshold = sorted(rewards)[len(rewards) // 10]
    curated_data = [d for d in curated_data if d["reward"] > threshold]
    print(f"After filtering (reward > {threshold:.3f}): {len(curated_data)} samples")

    # --- Load original audio dataset ---
    ds = load_from_disk("/data/processed")
    ds = ds.cast_column("audio", Audio(sampling_rate=16000))
    train_ds = ds["train"]

    # Build RL training dataset: use original audio + best prediction as target
    rl_examples = []
    for d in curated_data:
        idx = d["audio_idx"]
        row = train_ds[idx]
        rl_examples.append({
            "audio": row["audio"],
            "tagged_text": d["best_prediction"],  # RL target = best sampled completion
        })

    rl_dataset = Dataset.from_list(rl_examples)
    rl_dataset = rl_dataset.cast_column("audio", Audio(sampling_rate=16000))
    print(f"RL training dataset: {len(rl_dataset)} samples")

    # --- W&B ---
    run = wandb.init(
        project="evoxtral",
        name=f"rl-raft-lr{learning_rate}-ep{num_epochs}",
        config={
            "method": "RAFT (rejection sampling + SFT)",
            "base_adapter": "evoxtral-lora (SFT)",
            "learning_rate": learning_rate,
            "epochs": num_epochs,
            "num_curated": len(rl_dataset),
            "reward_threshold": threshold,
        },
        tags=["evoxtral", "rl", "raft", "rejection-sampling"],
    )

    avg_reward = sum(d["reward"] for d in curated_data) / len(curated_data)
    wandb.log({"rl/curated_samples": len(curated_data), "rl/avg_reward": avg_reward})

    # --- Load SFT model ---
    print("Loading SFT model for RL finetuning...")
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    base_model = VoxtralForConditionalGeneration.from_pretrained(
        MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto",
    )
    model = PeftModel.from_pretrained(base_model, SFT_ADAPTER)
    # Unfreeze LoRA for continued training
    for name, param in model.named_parameters():
        if "lora" in name.lower():
            param.requires_grad = True
    model.print_trainable_parameters()

    # --- Data Collator (same as SFT) ---
    class VoxtralDataCollator:
        def __init__(self, processor, model_id, max_text_len=512):
            self.processor = processor
            self.model_id = model_id
            self.max_text_len = max_text_len
            self.pad_id = processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id

        def __call__(self, examples):
            texts = [ex["tagged_text"] for ex in examples]
            audios = [ex["audio"]["array"] for ex in examples]

            prompt = self.processor.apply_transcription_request(
                language="en",
                model_id=self.model_id,
                audio=audios,
                format=["WAV"] * len(audios),
                return_tensors="pt",
            )
            passthrough = {k: v for k, v in prompt.items()
                           if k not in ("input_ids", "attention_mask")}

            prompt_ids = prompt["input_ids"]
            prompt_attn = prompt["attention_mask"]
            B = prompt_ids.size(0)

            tok = self.processor.tokenizer
            text_tok = tok(
                texts,
                add_special_tokens=False,
                padding=False,
                truncation=True,
                max_length=self.max_text_len,
                return_tensors=None,
            )
            text_ids_list = text_tok["input_ids"]

            input_ids, attention_mask, labels = [], [], []
            for i in range(B):
                p_ids = prompt_ids[i].tolist()
                p_att = prompt_attn[i].tolist()
                t_ids = text_ids_list[i]

                ids = p_ids + t_ids + [tok.eos_token_id]
                attn = p_att + [1] * (len(t_ids) + 1)
                lab = [-100] * len(p_ids) + t_ids + [tok.eos_token_id]

                input_ids.append(ids)
                attention_mask.append(attn)
                labels.append(lab)

            max_len = max(len(x) for x in input_ids)

            def pad_to(seq, fill, L):
                return seq + [fill] * (L - len(seq))

            input_ids = [pad_to(x, self.pad_id, max_len) for x in input_ids]
            attention_mask = [pad_to(x, 0, max_len) for x in attention_mask]
            labels = [pad_to(x, -100, max_len) for x in labels]

            batch = {
                "input_ids": torch.tensor(input_ids, dtype=torch.long),
                "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
                "labels": torch.tensor(labels, dtype=torch.long),
            }
            for k, v in passthrough.items():
                batch[k] = v
            return batch

    collator = VoxtralDataCollator(processor, MODEL_ID)

    # --- Training args (lower LR, 1 epoch) ---
    training_args = TrainingArguments(
        output_dir=RL_OUTPUT,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=learning_rate,
        lr_scheduler_type="cosine",
        warmup_steps=20,
        weight_decay=0.01,
        max_grad_norm=1.0,
        bf16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        logging_steps=5,
        save_strategy="epoch",
        save_total_limit=2,
        report_to="wandb",
        remove_unused_columns=False,
        dataloader_pin_memory=True,
        dataloader_num_workers=4,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=rl_dataset,
        data_collator=collator,
    )

    print("Starting RL finetuning...")
    train_result = trainer.train()

    wandb.log({
        "rl/final_loss": train_result.metrics.get("train_loss", 0),
        "rl/runtime_seconds": train_result.metrics.get("train_runtime", 0),
    })

    # Save adapter
    print(f"Saving RL adapter to {RL_OUTPUT}")
    trainer.save_model(RL_OUTPUT)
    processor.save_pretrained(RL_OUTPUT)

    # Log as W&B artifact
    artifact = wandb.Artifact(
        "evoxtral-rl-adapter",
        type="model",
        metadata={"method": "RAFT", "base": "evoxtral-lora"},
    )
    artifact.add_dir(RL_OUTPUT)
    run.log_artifact(artifact)

    # Push to Hub
    if push_to_hub:
        from huggingface_hub import HfApi
        HUB_ID = "YongkangZOU/evoxtral-rl"
        print(f"Pushing to HuggingFace Hub: {HUB_ID}")
        try:
            api = HfApi(token=os.environ.get("HF_TOKEN"))
            api.create_repo(HUB_ID, repo_type="model", exist_ok=True)
            api.upload_folder(
                folder_path=RL_OUTPUT,
                repo_id=HUB_ID,
                repo_type="model",
                commit_message=f"RL adapter (RAFT): lr={learning_rate}, ep={num_epochs}",
            )
            print(f"Pushed to {HUB_ID}")
        except Exception as e:
            print(f"Hub push failed: {e}")

    output_vol.commit()
    wandb.finish()
    print("RL finetuning complete!")
    return train_result.metrics


@app.local_entrypoint()
def main(
    num_samples: int = 4,
    temperature: float = 0.7,
    lr: float = 5e-5,
    epochs: int = 1,
    push_to_hub: bool = True,
    finetune_only: bool = False,
):
    if not finetune_only:
        print("Step 1: Generating and scoring completions...")
        gen_results = generate_and_score.remote(
            num_samples=num_samples,
            temperature=temperature,
        )
        print(f"Generation results: {gen_results}")

    print("\nStep 2: RL finetuning on curated data...")
    ft_results = rl_finetune.remote(
        learning_rate=lr,
        num_epochs=epochs,
        push_to_hub=push_to_hub,
    )
    print(f"RL finetune results: {ft_results}")
    print("\nDone! Run eval with: modal run scripts/train_modal.py --eval-only")
    print("(Update adapter_path in evaluate() to /output/evoxtral-rl)")
```
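The finetuning step above only reads `/output/rl_curated_data.json`; the curation itself happens in `generate_and_score`, earlier in this file. As a reading aid, here is a minimal sketch of what RAFT-style curation amounts to; `sample_completions` and `reward_fn` are hypothetical stand-ins, not the script's actual function names:

```python
# Sketch of RAFT curation (rejection sampling): keep only the best-of-N
# completion per clip. `sample_completions` and `reward_fn` are hypothetical.
import json

def curate(train_ds, sample_completions, reward_fn, num_samples=4):
    curated = []
    for idx, row in enumerate(train_ds):
        # Draw N stochastic completions for the same audio clip
        candidates = sample_completions(row["audio"], n=num_samples)
        # Score each candidate against the reference tagged transcript
        scored = [(reward_fn(c, row["tagged_text"]), c) for c in candidates]
        reward, best = max(scored, key=lambda x: x[0])
        curated.append({"audio_idx": idx, "best_prediction": best, "reward": reward})
    with open("/output/rl_curated_data.json", "w") as f:
        json.dump(curated, f)
```

Keeping only the highest-reward completion per clip turns rejection sampling into a plain SFT dataset, which is why `rl_finetune` can reuse the SFT collator unchanged.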
training/scripts/serve_modal.py
CHANGED

```diff
@@ -4,14 +4,18 @@ Swagger UI: https://yongkang-zou1999--evoxtral-api-evoxtralmodel-web.modal.run/docs
 
 Usage:
     # Deploy:
-    modal deploy scripts/serve_modal.py
+    modal deploy training/scripts/serve_modal.py
 
     # Test locally:
-    modal serve scripts/serve_modal.py
+    modal serve training/scripts/serve_modal.py
 
-    # Call the API:
+    # Call the API (JSON response):
    curl -X POST https://yongkang-zou1999--evoxtral-api-evoxtralmodel-web.modal.run/transcribe \
        -F "file=@audio.wav"
+
+    # Call the streaming API (Server-Sent Events):
+    curl -N -X POST https://yongkang-zou1999--evoxtral-api-evoxtralmodel-web.modal.run/transcribe/stream \
+        -F "file=@audio.wav"
 """
 
 import modal
@@ -34,6 +38,7 @@ image = (
         "sentencepiece",
         "fastapi",
         "python-multipart",
+        "sse-starlette",
         gpu="A10G",
     )
     .env({"HF_HUB_CACHE": "/cache/huggingface"})
@@ -43,7 +48,50 @@ app = modal.App("evoxtral-api", image=image)
 hf_cache = modal.Volume.from_name("evoxtral-hf-cache", create_if_missing=True)
 
 MODEL_ID = "mistralai/Voxtral-Mini-3B-2507"
-ADAPTER_ID = "YongkangZOU/evoxtral-lora"
+ADAPTER_ID = "YongkangZOU/evoxtral-rl"
+
+
+def _decode_audio(audio_bytes):
+    """Decode audio bytes to float32 numpy array at 16kHz.
+
+    Uses librosa (backed by ffmpeg) so all common formats work:
+    WAV, FLAC, MP3, MP4, M4A, WebM, OGG, etc.
+    """
+    import numpy as np
+    import librosa
+    import tempfile
+    import os
+
+    # librosa needs a file path (uses ffmpeg under the hood for non-WAV)
+    with tempfile.NamedTemporaryFile(suffix=".audio", delete=False) as f:
+        f.write(audio_bytes)
+        tmp_path = f.name
+    try:
+        audio_array, sr = librosa.load(tmp_path, sr=16000, mono=True)
+    finally:
+        os.unlink(tmp_path)
+
+    return audio_array.astype(np.float32)
+
+
+def _prepare_inputs(processor, audio_array, language, device):
+    """Prepare model inputs from audio array."""
+    import torch
+
+    inputs = processor.apply_transcription_request(
+        language=language,
+        audio=[audio_array],
+        format=["WAV"],
+        model_id=MODEL_ID,
+        return_tensors="pt",
+    )
+    inputs = {
+        k: v.to(device, dtype=torch.bfloat16)
+        if v.dtype in (torch.float32, torch.float16, torch.bfloat16)
+        else v.to(device)
+        for k, v in inputs.items()
+    }
+    return inputs
 
 
 @app.cls(
@@ -73,23 +121,30 @@ class EvoxtralModel:
     @modal.asgi_app()
     def web(self):
         import torch
+        import json
+        import asyncio
         import numpy as np
+        from threading import Thread
         from fastapi import FastAPI, UploadFile, File, Form, HTTPException
-        from fastapi.responses import JSONResponse
+        from fastapi.responses import JSONResponse, StreamingResponse
+        from transformers import TextIteratorStreamer
 
         web_app = FastAPI(
             title="Evoxtral API",
             description=(
                 "Expressive tagged transcription powered by Voxtral-Mini-3B + LoRA. "
                 "Upload audio and get transcriptions with inline expressive tags like "
-                "[sighs], [laughs], [whispers], etc."
+                "[sighs], [laughs], [whispers], etc.\n\n"
+                "**Endpoints:**\n"
+                "- `POST /transcribe` — Returns full transcription as JSON\n"
+                "- `POST /transcribe/stream` — Streams tokens via Server-Sent Events (SSE)"
             ),
-            version="1.0.0",
+            version="2.0.0",
         )
 
         @web_app.get("/health", summary="Health check")
         async def health():
-            return {"status": "ok", "model": "evoxtral-lora", "base": MODEL_ID}
+            return {"status": "ok", "model": "evoxtral-rl", "base": MODEL_ID}
 
         @web_app.post(
             "/transcribe",
@@ -100,39 +155,16 @@ class EvoxtralModel:
             file: UploadFile = File(..., description="Audio file (WAV, MP3, FLAC, etc.)"),
             language: str = Form("en", description="Language code (e.g. 'en', 'fr', 'es')"),
         ):
-            import librosa
-            import soundfile as sf
-            import io
-
             audio_bytes = await file.read()
             if not audio_bytes:
                 raise HTTPException(status_code=400, detail="Empty audio file")
 
-            # Decode audio
             try:
-                audio_array, sr = sf.read(io.BytesIO(audio_bytes))
-                audio_array = audio_array.astype(np.float32)
-                if audio_array.ndim > 1:
-                    audio_array = audio_array.mean(axis=1)
-                if sr != 16000:
-                    audio_array = librosa.resample(audio_array, orig_sr=sr, target_sr=16000)
+                audio_array = _decode_audio(audio_bytes)
             except Exception as e:
                 raise HTTPException(status_code=400, detail=f"Failed to decode audio: {e}")
 
-            inputs = self.processor.apply_transcription_request(
-                language=language,
-                audio=[audio_array],
-                format=["WAV"],
-                model_id=MODEL_ID,
-                return_tensors="pt",
-            )
-            inputs = {
-                k: v.to(self.model.device, dtype=torch.bfloat16)
-                if v.dtype in (torch.float32, torch.float16, torch.bfloat16)
-                else v.to(self.model.device)
-                for k, v in inputs.items()
-            }
+            inputs = _prepare_inputs(self.processor, audio_array, language, self.model.device)
 
             with torch.no_grad():
                 output_ids = self.model.generate(**inputs, max_new_tokens=512, do_sample=False)
@@ -145,7 +177,61 @@ class EvoxtralModel:
             return {
                 "transcription": transcription,
                 "language": language,
-                "model": "evoxtral-lora",
+                "model": "evoxtral-rl",
             }
 
+        @web_app.post(
+            "/transcribe/stream",
+            summary="Transcribe audio with streaming (SSE)",
+            response_description="Server-Sent Events stream of transcription tokens",
+        )
+        async def transcribe_stream(
+            file: UploadFile = File(..., description="Audio file (WAV, MP3, FLAC, etc.)"),
+            language: str = Form("en", description="Language code (e.g. 'en', 'fr', 'es')"),
+        ):
+            audio_bytes = await file.read()
+            if not audio_bytes:
+                raise HTTPException(status_code=400, detail="Empty audio file")
+
+            try:
+                audio_array = _decode_audio(audio_bytes)
+            except Exception as e:
+                raise HTTPException(status_code=400, detail=f"Failed to decode audio: {e}")
+
+            inputs = _prepare_inputs(self.processor, audio_array, language, self.model.device)
+
+            streamer = TextIteratorStreamer(
+                self.processor.tokenizer,
+                skip_prompt=True,
+                skip_special_tokens=True,
+            )
+
+            generate_kwargs = dict(
+                **inputs,
+                max_new_tokens=512,
+                do_sample=False,
+                streamer=streamer,
+            )
+
+            thread = Thread(target=lambda: self.model.generate(**generate_kwargs))
+            thread.start()
+
+            async def event_generator():
+                full_text = ""
+                for token_text in streamer:
+                    if token_text:
+                        full_text += token_text
+                        yield f"data: {json.dumps({'token': token_text})}\n\n"
+                yield f"data: {json.dumps({'done': True, 'transcription': full_text, 'language': language, 'model': 'evoxtral-rl'})}\n\n"
+
+            return StreamingResponse(
+                event_generator(),
+                media_type="text/event-stream",
+                headers={
+                    "Cache-Control": "no-cache",
+                    "Connection": "keep-alive",
+                    "X-Accel-Buffering": "no",
+                },
+            )
+
         return web_app
```
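On the wire, `/transcribe/stream` emits one `data: {"token": ...}` frame per decoded chunk and a final `data: {"done": true, ...}` frame. A minimal Python client, sketched here under the assumption that the `requests` library is available (the script itself does not depend on it):

```python
# Minimal SSE client for /transcribe/stream. Assumes `requests` is installed
# and the endpoint URL from the docstring above.
import json
import requests

URL = "https://yongkang-zou1999--evoxtral-api-evoxtralmodel-web.modal.run/transcribe/stream"

with open("audio.wav", "rb") as f:
    resp = requests.post(URL, files={"file": ("audio.wav", f)}, stream=True)
resp.raise_for_status()

for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue  # skip blank keep-alive lines between SSE frames
    event = json.loads(line[len("data: "):])
    if event.get("done"):
        print("\nFull transcription:", event["transcription"])
        break
    print(event["token"], end="", flush=True)  # tokens render as they decode
```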
training/scripts/train_modal.py
CHANGED

```diff
@@ -361,14 +361,14 @@ def train(
     timeout=3600,
     memory=32768,
 )
-def evaluate(adapter_path: str = "/output/evoxtral-lora"):
+def evaluate(adapter_path: str = "/output/evoxtral-lora", eval_name: str = "finetuned"):
     """Run Evoxtral-Bench evaluation (base vs finetuned)."""
     import torch
     import wandb
     import weave
     import json
     from pathlib import Path
-    from jiwer import wer as compute_wer
+    from jiwer import wer as compute_wer, cer as compute_cer_jiwer
     from datasets import load_from_disk, Audio
     from transformers import VoxtralForConditionalGeneration, AutoProcessor
     from peft import PeftModel
@@ -403,9 +403,24 @@ def evaluate(adapter_path: str = "/output/evoxtral-lora"):
     ds = ds.cast_column("audio", Audio(sampling_rate=16000))
     test_ds = ds["test"]
 
+    def compute_cer(ref, hyp):
+        """Character Error Rate."""
+        if not ref.strip():
+            return 0.0
+        return compute_cer_jiwer(ref, hyp)
+
     def run_model_eval(model, processor, model_name):
         model.eval()
-        results = {"wer": []}
+        results = {
+            "wer": [], "cer": [],
+            "tag_f1": [], "tag_precision": [], "tag_recall": [],
+            "tag_hallucination_rate": [],
+            "emphasis_f1": [],
+        }
+        per_tag_tp = Counter()  # true positives per tag type
+        per_tag_fp = Counter()  # false positives (predicted but not in ref)
+        per_tag_fn = Counter()  # false negatives (in ref but not predicted)
+        all_predictions = []
 
         for i in range(min(len(test_ds), 50)):  # eval on up to 50 samples
             row = test_ds[i]
@@ -428,22 +443,108 @@ def evaluate(adapter_path: str = "/output/evoxtral-lora"):
             input_len = inputs["input_ids"].shape[1]
             prediction = processor.tokenizer.decode(output_ids[0][input_len:], skip_special_tokens=True)
 
-            # …
+            # --- Text metrics ---
             ref_plain = strip_tags(tagged_text)
             pred_plain = strip_tags(prediction)
             if ref_plain.strip():
                 results["wer"].append(compute_wer(ref_plain, pred_plain))
+                results["cer"].append(compute_cer(ref_plain, pred_plain))
+
+            # --- Tag metrics (precision, recall, F1) ---
+            pred_tags = Counter(extract_tags(prediction))
+            ref_tags = Counter(extract_tags(tagged_text))
+
+            common = pred_tags & ref_tags
+            tp = sum(common.values())
+            fp = sum(pred_tags.values()) - tp
+            fn = sum(ref_tags.values()) - tp
+
+            prec = tp / (tp + fp) if (tp + fp) > 0 else (1.0 if not ref_tags else 0.0)
+            rec = tp / (tp + fn) if (tp + fn) > 0 else (1.0 if not pred_tags else 0.0)
+            f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else (1.0 if not ref_tags and not pred_tags else 0.0)
+
+            results["tag_precision"].append(prec)
+            results["tag_recall"].append(rec)
+            results["tag_f1"].append(f1)
+
+            # Tag hallucination rate
+            ref_tag_set = set(ref_tags.keys())
+            hallucinated = sum(v for k, v in pred_tags.items() if k not in ref_tag_set)
+            hall_rate = hallucinated / sum(pred_tags.values()) if pred_tags else 0.0
+            results["tag_hallucination_rate"].append(hall_rate)
+
+            # Per-tag breakdown
+            for tag in set(list(pred_tags.keys()) + list(ref_tags.keys())):
+                matched = min(pred_tags.get(tag, 0), ref_tags.get(tag, 0))
+                per_tag_tp[tag] += matched
+                per_tag_fp[tag] += max(0, pred_tags.get(tag, 0) - matched)
+                per_tag_fn[tag] += max(0, ref_tags.get(tag, 0) - matched)
+
+            # Emphasis F1
+            pred_emph = Counter([m.group(0).lower() for m in re.finditer(r'\b[A-Z]{2,}\b', prediction)])
+            ref_emph = Counter([m.group(0).lower() for m in re.finditer(r'\b[A-Z]{2,}\b', tagged_text)])
+            emph_common = sum((pred_emph & ref_emph).values())
+            emph_total_p = sum(pred_emph.values())
+            emph_total_r = sum(ref_emph.values())
+            if emph_total_p == 0 and emph_total_r == 0:
+                emph_f1 = 1.0
+            elif emph_total_p == 0 or emph_total_r == 0:
+                emph_f1 = 0.0
+            else:
+                ep = emph_common / emph_total_p
+                er = emph_common / emph_total_r
+                emph_f1 = 2 * ep * er / (ep + er) if (ep + er) > 0 else 0.0
+            results["emphasis_f1"].append(emph_f1)
+
+            # Store prediction for W&B table
+            all_predictions.append({
+                "sample_idx": i,
+                "reference": tagged_text,
+                "prediction": prediction,
+                "wer": results["wer"][-1] if results["wer"] else None,
+                "tag_f1": f1,
+                "tag_hallucination_rate": hall_rate,
+            })
 
             if i < 5:
                 print(f"\n[{model_name} Sample {i}]")
                 print(f"  Reference: {tagged_text[:100]}...")
                 print(f"  Predicted: {prediction[:100]}...")
 
-        …
+        # Compute averages
+        def avg(lst):
+            return sum(lst) / max(len(lst), 1)
+
+        avg_metrics = {f"avg_{k}": avg(v) for k, v in results.items()}
+
+        # Per-tag F1 breakdown
+        per_tag_f1 = {}
+        for tag in sorted(per_tag_tp.keys() | per_tag_fp.keys() | per_tag_fn.keys()):
+            tp = per_tag_tp[tag]
+            fp = per_tag_fp[tag]
+            fn = per_tag_fn[tag]
+            p = tp / (tp + fp) if (tp + fp) > 0 else 0.0
+            r = tp / (tp + fn) if (tp + fn) > 0 else 0.0
+            f = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
+            per_tag_f1[tag] = {"precision": round(p, 3), "recall": round(r, 3), "f1": round(f, 3),
+                               "support": tp + fn}
+
+        avg_metrics["per_tag_f1"] = per_tag_f1
+        avg_metrics["predictions"] = all_predictions
+
+        print(f"\n{model_name} Results:")
+        print(f"  WER: {avg_metrics['avg_wer']:.4f}")
+        print(f"  CER: {avg_metrics['avg_cer']:.4f}")
+        print(f"  Tag F1: {avg_metrics['avg_tag_f1']:.4f}")
+        print(f"  Tag Precision: {avg_metrics['avg_tag_precision']:.4f}")
+        print(f"  Tag Recall: {avg_metrics['avg_tag_recall']:.4f}")
+        print(f"  Tag Hallucination: {avg_metrics['avg_tag_hallucination_rate']:.4f}")
+        print(f"  Emphasis F1: {avg_metrics['avg_emphasis_f1']:.4f}")
+        print(f"\n  Per-tag breakdown:")
+        for tag, m in sorted(per_tag_f1.items(), key=lambda x: -x[1]["support"]):
+            print(f"    [{tag}]: F1={m['f1']:.3f} P={m['precision']:.3f} R={m['recall']:.3f} (n={m['support']})")
+
+        return avg_metrics
 
     # ── Evaluate base model ──
     wandb.init(project="evoxtral", name="eval-base", job_type="evaluation", tags=["eval", "base"])
@@ -453,7 +554,26 @@ def evaluate(adapter_path: str = "/output/evoxtral-lora"):
         MODEL_ID, dtype=torch.bfloat16, device_map="auto",
     )
     base_results = run_model_eval(base_model, processor, "BASE")
-    wandb.log({…})
+    wandb.log({
+        "eval/wer": base_results["avg_wer"],
+        "eval/cer": base_results["avg_cer"],
+        "eval/tag_f1": base_results["avg_tag_f1"],
+        "eval/tag_precision": base_results["avg_tag_precision"],
+        "eval/tag_recall": base_results["avg_tag_recall"],
+        "eval/tag_hallucination_rate": base_results["avg_tag_hallucination_rate"],
+        "eval/emphasis_f1": base_results["avg_emphasis_f1"],
+    })
+    # Log per-tag breakdown as table
+    tag_table = wandb.Table(columns=["tag", "f1", "precision", "recall", "support"])
+    for tag, m in base_results["per_tag_f1"].items():
+        tag_table.add_data(tag, m["f1"], m["precision"], m["recall"], m["support"])
+    wandb.log({"eval/per_tag_breakdown": tag_table})
+    # Log predictions table
+    pred_table = wandb.Table(columns=["idx", "reference", "prediction", "wer", "tag_f1", "hallucination_rate"])
+    for p in base_results["predictions"]:
+        pred_table.add_data(p["sample_idx"], p["reference"], p["prediction"],
+                            p["wer"], p["tag_f1"], p["tag_hallucination_rate"])
+    wandb.log({"eval/predictions": pred_table})
     wandb.finish()
 
     # Free base model memory
@@ -462,21 +582,45 @@ def evaluate(adapter_path: str = "/output/evoxtral-lora"):
 
     # ── Evaluate finetuned model ──
     if Path(adapter_path).exists():
-        wandb.init(project="evoxtral", name="eval-finetuned", job_type="evaluation", tags=["eval", "finetuned"])
+        wandb.init(project="evoxtral", name=f"eval-{eval_name}", job_type="evaluation", tags=["eval", eval_name])
         print("Loading finetuned model...")
         base_model = VoxtralForConditionalGeneration.from_pretrained(
             MODEL_ID, dtype=torch.bfloat16, device_map="auto",
         )
         ft_model = PeftModel.from_pretrained(base_model, adapter_path)
         ft_results = run_model_eval(ft_model, processor, "FINETUNED")
-        wandb.log({…})
+        wandb.log({
+            "eval/wer": ft_results["avg_wer"],
+            "eval/cer": ft_results["avg_cer"],
+            "eval/tag_f1": ft_results["avg_tag_f1"],
+            "eval/tag_precision": ft_results["avg_tag_precision"],
+            "eval/tag_recall": ft_results["avg_tag_recall"],
+            "eval/tag_hallucination_rate": ft_results["avg_tag_hallucination_rate"],
+            "eval/emphasis_f1": ft_results["avg_emphasis_f1"],
+        })
+        # Log per-tag breakdown as table
+        tag_table = wandb.Table(columns=["tag", "f1", "precision", "recall", "support"])
+        for tag, m in ft_results["per_tag_f1"].items():
+            tag_table.add_data(tag, m["f1"], m["precision"], m["recall"], m["support"])
+        wandb.log({"eval/per_tag_breakdown": tag_table})
+        # Log predictions table
+        pred_table = wandb.Table(columns=["idx", "reference", "prediction", "wer", "tag_f1", "hallucination_rate"])
+        for p in ft_results["predictions"]:
+            pred_table.add_data(p["sample_idx"], p["reference"], p["prediction"],
+                                p["wer"], p["tag_f1"], p["tag_hallucination_rate"])
+        wandb.log({"eval/predictions": pred_table})
         wandb.finish()
 
-        print(f"\n{'='*…}")
+        print(f"\n{'='*60}")
         print(f"COMPARISON: Base vs Finetuned")
-        print(f"{'='*…}")
-        print(f"WER: …")
-        print(f"…")
+        print(f"{'='*60}")
+        print(f"WER: {base_results['avg_wer']:.4f} → {ft_results['avg_wer']:.4f}")
+        print(f"CER: {base_results['avg_cer']:.4f} → {ft_results['avg_cer']:.4f}")
+        print(f"Tag F1: {base_results['avg_tag_f1']:.4f} → {ft_results['avg_tag_f1']:.4f}")
+        print(f"Tag Precision: {base_results['avg_tag_precision']:.4f} → {ft_results['avg_tag_precision']:.4f}")
+        print(f"Tag Recall: {base_results['avg_tag_recall']:.4f} → {ft_results['avg_tag_recall']:.4f}")
+        print(f"Tag Hallucination: {base_results['avg_tag_hallucination_rate']:.4f} → {ft_results['avg_tag_hallucination_rate']:.4f}")
+        print(f"Emphasis F1: {base_results['avg_emphasis_f1']:.4f} → {ft_results['avg_emphasis_f1']:.4f}")
     else:
         print(f"No adapter found at {adapter_path}, skipping finetuned eval")
         ft_results = None
@@ -494,8 +638,12 @@ def main(
     batch_size: int = 2,
     push_to_hub: bool = True,
     eval_only: bool = False,
+    eval_rl: bool = False,
 ):
-    if eval_only:
+    if eval_rl:
+        results = evaluate.remote(adapter_path="/output/evoxtral-rl", eval_name="rl")
+        print(results)
+    elif eval_only:
         results = evaluate.remote()
         print(results)
     else:
```
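The tag metrics above compare tags as multisets, so every occurrence counts: a tag predicted twice but referenced once yields one true positive and one false positive. A small worked example, using the same `Counter` arithmetic as the eval loop:

```python
# Worked example of the multiset tag metrics used in run_model_eval.
from collections import Counter

ref_tags = Counter(["laughs", "sighs", "laughs"])  # reference has 3 tag occurrences
pred_tags = Counter(["laughs", "whispers"])        # prediction has 2

tp = sum((pred_tags & ref_tags).values())  # 1: "laughs" matched once
fp = sum(pred_tags.values()) - tp          # 1: "whispers"
fn = sum(ref_tags.values()) - tp           # 2: second "laughs" + "sighs"

prec = tp / (tp + fp)                      # 0.5
rec = tp / (tp + fn)                       # 0.333...
f1 = 2 * prec * rec / (prec + rec)         # 0.4

# Hallucination rate counts predicted tag types absent from the reference
hallucinated = sum(v for k, v in pred_tags.items() if k not in ref_tags)
print(prec, rec, f1, hallucinated / sum(pred_tags.values()))  # 0.5 0.333 0.4 0.5
```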
web/src/app/api/speech-to-text/route.ts
CHANGED

```diff
@@ -20,12 +20,17 @@ export async function POST(req: Request) {
     );
   }
 
-  const …
+  const MODAL_API_URL =
+    process.env.MODAL_API_URL ??
+    "https://yongkang-zou1999--evoxtral-api-evoxtralmodel-web.modal.run";
 
-  // Forward the formData to the …
-  const …
+  // Forward the formData to the Modal API (non-streaming)
+  const upstream = new FormData();
+  const originalName = (audioFile as File).name || "audio.wav";
+  upstream.append("file", audioFile, originalName);
+  const response = await fetch(`${MODAL_API_URL}/transcribe`, {
     method: "POST",
-    body: formData,
+    body: upstream,
     // Signal timeout after 5 minutes
     signal: AbortSignal.timeout(5 * 60 * 1000),
   });
```
web/src/app/api/transcribe-stream/route.ts
ADDED

```ts
/**
 * SSE proxy: forwards audio to Modal /transcribe/stream and pipes tokens back.
 * Browser cannot call Modal directly (no CORS), so this Next.js route acts as relay.
 */

const MODAL_API_URL =
  process.env.MODAL_API_URL ??
  "https://yongkang-zou1999--evoxtral-api-evoxtralmodel-web.modal.run";

export async function POST(req: Request) {
  const formData = await req.formData();
  const audioFile = formData.get("audio") as Blob | null;

  if (!audioFile) {
    return new Response(JSON.stringify({ error: "Audio file is required" }), {
      status: 400,
      headers: { "Content-Type": "application/json" },
    });
  }

  const MAX_UPLOAD_BYTES = 100 * 1024 * 1024;
  if (audioFile.size > MAX_UPLOAD_BYTES) {
    return new Response(
      JSON.stringify({ error: `File exceeds ${MAX_UPLOAD_BYTES / 1024 / 1024}MB limit` }),
      { status: 413, headers: { "Content-Type": "application/json" } }
    );
  }

  // Forward to Modal streaming endpoint, preserving original filename for format detection
  const upstream = new FormData();
  const originalName = (audioFile as File).name || "audio.wav";
  upstream.append("file", audioFile, originalName);

  const language = formData.get("language") as string | null;
  if (language) upstream.append("language", language);

  try {
    const res = await fetch(`${MODAL_API_URL}/transcribe/stream`, {
      method: "POST",
      body: upstream,
      signal: AbortSignal.timeout(5 * 60 * 1000),
    });

    if (!res.ok) {
      const errText = await res.text().catch(() => "Upstream error");
      return new Response(
        JSON.stringify({ error: errText }),
        { status: res.status, headers: { "Content-Type": "application/json" } }
      );
    }

    // Pipe the SSE stream through to the client
    return new Response(res.body, {
      headers: {
        "Content-Type": "text/event-stream",
        "Cache-Control": "no-cache",
        Connection: "keep-alive",
        "X-Accel-Buffering": "no",
      },
    });
  } catch (error: unknown) {
    const isTimeout =
      error instanceof Error &&
      (error.name === "TimeoutError" || error.name === "AbortError");
    return new Response(
      JSON.stringify({
        error: isTimeout ? "Transcription timed out" : "Failed to reach Modal API",
      }),
      { status: isTimeout ? 504 : 502, headers: { "Content-Type": "application/json" } }
    );
  }
}
```
web/src/app/studio/page.tsx
CHANGED

```diff
@@ -28,8 +28,6 @@ import { cn } from "@/lib/utils"
 import { MagnifyingGlass } from "@phosphor-icons/react"
 import NextImage from "next/image"
 
-const API_BASE = process.env.NEXT_PUBLIC_API_URL ?? "http://localhost:3000"
-
 // --- Constants ---
 const SPEAKER_COLORS = [
   { avatar: "bg-blue-400", track: "bg-blue-200" },
@@ -669,12 +667,13 @@ function StudioContent() {
   // Per-segment DOM element refs for auto-scroll
   const segmentRefs = useRef<Map<number, HTMLDivElement>>(new Map())
 
-  const [session, setSession] = useState…
+  const [session, setSession] = useState<ReturnType<typeof getSession>>(null)
   const [activeId, setActiveId] = useState<number>(1)
   const [isPlaying, setIsPlaying] = useState(false)
   const [currentTime, setCurrentTime] = useState(0)
   const [isProcessing, setIsProcessing] = useState(false)
   const [processError, setProcessError] = useState<string | null>(null)
+  const [streamingText, setStreamingText] = useState<string | null>(null)
 
   // Sync session state with sessionId param
   useEffect(() => {
@@ -688,8 +687,8 @@ function StudioContent() {
   }, [sessionId])
 
   // Automatic processing for pending sessions.
-  // Uses …
-  // …
+  // Uses Modal streaming API: tokens arrive via SSE for live display,
+  // then the full transcription is converted to a DiarizeResult.
   useEffect(() => {
     if (!session || processingRef.current || processError) return
@@ -698,58 +697,75 @@ function StudioContent() {
     const process = async () => {
       setIsProcessing(true)
       setProcessError(null)
+      setStreamingText("")
       try {
-        // 1. Submit job — server responds immediately with job_id (202)
         const formData = new FormData()
         formData.append("audio", session.file!, session.filename)
 
-        const …
+        const res = await fetch("/api/transcribe-stream", {
           method: "POST",
           body: formData,
         })
 
-        if (!…
-          const errData = await …
-          throw new Error(errData.error ?? "…
+        if (!res.ok) {
+          const errData = await res.json().catch(() => ({}))
+          throw new Error(errData.error ?? "Transcription failed")
         }
 
-        …
-            if (pollData.status === "done") {
-              resolve(pollData.data as DiarizeResult)
-            } else if (pollData.status === "error") {
-              reject(new Error(pollData.error ?? "Processing failed"))
-            } else {
-              setTimeout(tick, POLL_INTERVAL)
-            }
-          } catch (e) {
-            reject(e)
-          }
-        …
+        // Consume SSE stream
+        const reader = res.body!.getReader()
+        const decoder = new TextDecoder()
+        let fullText = ""
+        let buffer = ""
+
+        while (true) {
+          const { done, value } = await reader.read()
+          if (done) break
+
+          buffer += decoder.decode(value, { stream: true })
+          const lines = buffer.split("\n")
+          buffer = lines.pop() ?? ""
+
+          for (const line of lines) {
+            if (!line.startsWith("data: ")) continue
+            try {
+              const payload = JSON.parse(line.slice(6))
+              if (payload.token) {
+                fullText += payload.token
+                setStreamingText(fullText)
+              }
+              if (payload.done) {
+                fullText = payload.transcription ?? fullText
+                setStreamingText(fullText)
+              }
+            } catch {
+              // skip malformed SSE lines
+            }
+          }
+        }
+
+        // Get audio duration from the media element
+        const mediaDuration = mediaRef.current?.duration || 0
+
+        // Derive emotion from the first bracket tag in transcription
+        const firstTagMatch = fullText.match(/\[([^\]]+)\]/)
+        const firstTag = firstTagMatch ? getTagEntry(firstTagMatch[1]) : null
+
+        // Build DiarizeResult from plain transcription
+        const data: DiarizeResult = {
+          segments: fullText.trim() ? [{
+            id: 1,
+            speaker: "SPEAKER_00",
+            start: 0,
+            end: mediaDuration || 30,
+            text: fullText.trim(),
+            emotion: firstTag?.emotion ?? "Neutral",
+            valence: firstTag?.valence ?? 0,
+            arousal: firstTag?.arousal ?? 0,
+          }] : [],
+          duration: mediaDuration || 30,
+          text: fullText.trim(),
+          filename: session.filename,
+        }
 
         updateSession(session.id, data)
@@ -758,9 +774,11 @@ function StudioContent() {
       if (updated?.data.segments && updated.data.segments.length > 0) {
         setActiveId(updated.data.segments[0].id)
       }
+      setStreamingText(null)
     } catch (e) {
       processingRef.current = false
       setProcessError(e instanceof Error ? e.message : "Request failed")
+      setStreamingText(null)
     } finally {
       setIsProcessing(false)
     }
@@ -850,7 +868,9 @@ function StudioContent() {
           {isProcessing && (
             <Badge variant="secondary" className="bg-blue-500/10 text-blue-500 hover:bg-blue-500/10 border-blue-500/20 gap-2 font-medium px-3 h-8">
               <div className="size-2 rounded-full bg-blue-500 animate-pulse" />
-              …
+              {streamingText !== null && streamingText.length > 0
+                ? "Streaming..."
+                : "Connecting to Evoxtral..."}
             </Badge>
           )}
@@ -889,15 +909,37 @@ function StudioContent() {
         <div className="flex-1 overflow-hidden flex flex-col min-w-0 relative">
           {isProcessing && segments.length === 0 && (
             <div className="absolute inset-0 z-20 bg-background/80 backdrop-blur-sm flex flex-col items-center justify-center p-8 text-center animate-in fade-in duration-500">
-              …
-              <div className="…
-              …
+              {streamingText !== null && streamingText.length > 0 ? (
+                /* Live streaming tokens */
+                <div className="max-w-lg w-full space-y-4">
+                  <div className="flex items-center justify-center gap-2 mb-4">
+                    <div className="size-2 rounded-full bg-emerald-500 animate-pulse" />
+                    <span className="text-sm font-semibold text-emerald-600">Streaming transcription...</span>
+                  </div>
+                  <div className="bg-muted/40 rounded-xl p-6 text-left max-h-64 overflow-y-auto">
+                    <p className="text-[15px] leading-relaxed font-medium tracking-tight whitespace-pre-wrap">
+                      {streamingText}
+                      <span className="inline-block w-[2px] h-[1em] bg-primary animate-pulse ml-0.5 align-text-bottom" />
+                    </p>
+                  </div>
+                  <p className="text-muted-foreground text-xs">
+                    Powered by Evoxtral — expressive tagged transcription
+                  </p>
+                </div>
+              ) : (
+                /* Initial loading spinner before first token arrives */
+                <>
+                  <div className="size-16 mb-6 relative">
+                    <div className="absolute inset-0 border-4 border-muted rounded-full" />
+                    <div className="absolute inset-0 border-4 border-primary border-t-transparent rounded-full animate-spin" />
+                    <Waveform size={24} className="absolute inset-0 m-auto text-primary animate-pulse" />
+                  </div>
+                  <h2 className="text-xl font-bold mb-2">Transcribing Audio</h2>
+                  <p className="text-muted-foreground text-sm max-w-xs mx-auto">
+                    Connecting to Evoxtral... tokens will stream in real-time.
+                  </p>
+                </>
+              )}
             </div>
           )}
```