# Mumble cleanup model: handbook
The complete picture for the optional transcript cleanup model: the research
behind it, the fine-tuning mechanism in detail, the metrics, and what every file
should do. This is the reference to read before writing the code.
Nothing in this project runs automatically. Training happens on a rented GPU
only when you decide to launch it.
---
## 1. Where this fits
Mumble's dictation pipeline today:

    push-to-talk -> capture (cpal, 16 kHz mono, gain normalized, pre-roll)
      -> Parakeet-TDT-0.6B-v3 (int8 ONNX via sherpa-onnx)
      -> custom dictionary substitution
      -> paste at cursor

The ASR model transcribes verbatim, including disfluencies ("um", "uh"),
repeated words, and false starts, and it has no idea about list or paragraph
structure. The cleanup model is an OPTIONAL pass inserted between transcription
and paste:
... -> Parakeet text -> [cleanup model, if enabled] -> paste
It removes fillers and disfluencies, fixes punctuation and capitalization, and
does light formatting, WITHOUT rewording, adding facts, or answering questions
in the text. This mirrors what Wispr Flow does.
---
## 2. Research recap
### 2.1 The ASR model and int8 vs fp32
Mumble ships **Parakeet-TDT-0.6B-v3** quantized to **int8** (the
`encoder.int8.onnx` etc. assets), run through sherpa-onnx on ONNX Runtime. We
benchmarked int8 against the fp32 build on 150 utterances of LibriSpeech
test-clean (see `bench/BENCHMARK.md`):
| Variant | WER % | CER % | RTF | Model RAM | Peak RAM |
|---|---|---|---|---|---|
| int8 (shipped) | 1.69 | 0.64 | 0.036 (27x realtime) | 723 MB | 2.17 GB |
| fp32 | 1.44 | 0.47 | 0.08 (12.5x realtime) | 2.29 GB | 3.88 GB |
NVIDIA's published reference WER is 1.93% on the full test-clean set; our
150-utterance numbers fall in the same range, which suggests the setup is sound.
Decision: **int8 stays the default** (2.2x faster, about a third of the model
RAM, at a cost of only 0.25 WER points); fp32 becomes an optional download for
users who want maximum quality. Both sets of numbers go in the Settings
model-quality selector.
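The ratios quoted in the decision can be re-derived from the table. The snippet below is only a sanity check of that arithmetic; the dictionary layout and the helper name `realtime_factor` are illustrative and not part of the benchmark code.

```python
# Figures copied from the benchmark table above (RAM in GB).
int8 = {"wer": 1.69, "rtf": 0.036, "model_ram_gb": 0.723}
fp32 = {"wer": 1.44, "rtf": 0.08, "model_ram_gb": 2.29}


def realtime_factor(rtf: float) -> float:
    """RTF is processing time divided by audio duration, so 1/RTF is the speed-up over realtime."""
    return 1.0 / rtf


speedup = fp32["rtf"] / int8["rtf"]                      # how much faster int8 runs than fp32
ram_ratio = int8["model_ram_gb"] / fp32["model_ram_gb"]  # int8 model RAM as a fraction of fp32
wer_cost = int8["wer"] - fp32["wer"]                     # absolute WER points given up

print(f"int8 runs at {realtime_factor(int8['rtf']):.1f}x realtime")
print(f"int8 is {speedup:.1f}x faster than fp32 and uses {ram_ratio:.0%} of the model RAM")
print(f"WER cost of int8: {wer_cost:.2f} points")
```

The model-RAM ratio comes out near one third; peak RAM shrinks less (2.17 GB against 3.88 GB), so "a third" refers to the model weights, not the whole process.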
### 2.2 Why a cleanup model, and why fine-tune
A zero-shot prototype (Qwen2.5-0.5B-Instruct, plain prompt) cleaned real Mumble
transcripts well: it removed "and and" -> "and", "Claude Claude" -> "Claude",
fillers, and fixed punctuation, at about 1.7s per transcript on CPU. But on
short or ambiguous inputs it slipped into assistant mode ("Yes, ...", "Sure, I
can ...") and lightly reworded. A source-overlap guard caught those cases.
Fine-tuning fixes this at the root: instead of hoping a prompt constrains the
model, we teach the behavior from data. This is also why a 0.5B model is enough.
Google's whole pitch for tiny models is task-specific fine-tuning.
### 2.3 Base model choice
| | Gemma 3 270M | Qwen2.5-0.5B (chosen) | Qwen3-0.6B |
|---|---|---|---|
| License | Gemma ToU, gated, passes down to shipped weights | Apache-2.0, ungated | Apache-2.0, ungated |
| Notes | best tiny FT base on merits, but license friction for a paid app | already prototyped, zero migration | a bit more headroom, must disable thinking |
Gemma 4 is Apache-2.0 and ungated but its smallest size is 2.3B, far over our
latency budget. **Decision: fine-tune Qwen2.5-0.5B-Instruct.** Licensing is the
deciding factor for a shippable product; quality after fine-tuning on a task
this narrow is a wash across the candidates.
---
## 3. The data: synthetic injection only
### 3.1 The idea
We do NOT collect raw->clean pairs from real STT. Instead we take **clean
written text** (the target) and programmatically **corrupt** it to make the raw
input. Because every corruption is an insertion plus punctuation/casing removal,
the clean target is recoverable from the raw by deleting tokens and restoring
punctuation and casing, nothing else. The training data therefore never rewards
**inventing content** (the exact failure we saw). Faithfulness is built into the
data rather than requested in a prompt. This is the whole reason for choosing
injection.
### 3.2 The corruption recipe (`configs/inject.yaml`, `src/cleanup/inject.py`)
Given a clean sentence:
- **false start**: prepend a duplicated 1-4 word head ("i want, i want to go").
- **repetition**: duplicate an internal 1-3 word span ("go to to the store").
- **fillers**: insert "um/uh/like/you know" at word gaps.
- **lowercase + strip punctuation**: always, so the model learns to restore
casing and punctuation.
Each corruption fires with a probability from the config, so one clean sentence
yields many distinct raw variants (a few hundred clean sentences can generate
thousands of pairs).
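A minimal sketch of the recipe above, assuming one possible shape for the two functions named in `src/cleanup/inject.py`. The signatures, the default probabilities, and the order in which corruptions are applied are assumptions made for illustration; the real values live in `configs/inject.yaml`. The `is_subsequence` helper is hypothetical and shows the kind of faithfulness invariant a unit test could check.

```python
import random
import re

FILLERS = ["um", "uh", "like", "you know"]


def strip_punctuation_and_lowercase(text: str) -> str:
    """Lowercase and drop punctuation, keeping apostrophes inside words."""
    return " ".join(re.findall(r"[a-z0-9']+", text.lower()))


def make_raw(clean: str, rng: random.Random, p_false_start: float = 0.3,
             p_repeat: float = 0.3, p_filler: float = 0.15) -> str:
    """Corrupt a clean sentence using insertions only (probabilities are placeholders)."""
    words = strip_punctuation_and_lowercase(clean).split()
    if len(words) < 2:
        return " ".join(words)
    # repetition: duplicate a 1-3 word span in place ("go to to the store")
    if rng.random() < p_repeat:
        n = rng.randint(1, min(3, len(words)))
        i = rng.randint(0, len(words) - n)
        words = words[:i + n] + words[i:i + n] + words[i + n:]
    # false start: prepend a copy of the first 1-4 words ("i want i want to go")
    if rng.random() < p_false_start:
        n = rng.randint(1, min(4, len(words)))
        words = words[:n] + words
    # fillers: maybe insert one before each word
    out = []
    for word in words:
        if rng.random() < p_filler:
            out.extend(rng.choice(FILLERS).split())
        out.append(word)
    return " ".join(out)


def is_subsequence(needle: list, haystack: list) -> bool:
    """True if needle can be obtained from haystack by deletions alone."""
    it = iter(haystack)
    return all(tok in it for tok in needle)


clean = "I want to go to the store, but it's closed."
rng = random.Random(0)
pairs = [(make_raw(clean, rng), clean) for _ in range(5)]
for raw, target in pairs:
    print(f"{raw!r} -> {target!r}")
```

Because every step only inserts tokens, the stripped clean sentence is always a subsequence of the raw variant, which is the property `tests/test_inject.py` is meant to guard.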
### 3.3 Scope and the one caveat
- **v1 scope:** disfluency removal + punctuation/casing. Light formatting
(lists, paragraphs) is NOT learnable from injection and is deferred to v2
(which would add teacher distillation).
- **Caveat:** the clean source MUST be properly punctuated, capitalized written
text, or the model never learns to restore punctuation. `01_data.py` prints
sample pairs so you can confirm this on the first run. If the default source
is bare speech transcript, switch `clean_source` in the config.
- **Train synthetic, evaluate real.** Training is synthetic, but the held-out
test set is the REAL DisfluencySpeech test split (real disfluent speech with a
gold clean target). That measures whether the synthetic->real jump holds.
---
## 4. The fine-tuning mechanism (read this before writing 02_train.py)
### 4.1 What supervised fine-tuning (SFT) is
The base model is a next-token predictor: given tokens so far, it outputs a
probability distribution over the next token. Fine-tuning continues training it
on our (prompt, target) pairs so that, given the cleanup prompt + a raw
transcript, it produces the clean transcript.
Each example is rendered with the chat template into one token sequence:
<|im_start|>system\n {SYSTEM_PROMPT} <|im_end|>
<|im_start|>user\n {raw} <|im_end|>
<|im_start|>assistant\n {clean} <|im_end|>
### 4.2 The loss
Causal-LM cross-entropy (negative log-likelihood). For target tokens y_1..y_T,
loss = - (1/T) * sum_t log p_theta(y_t | y_<t, prompt)
The model is rewarded for putting high probability on the actual next clean
token at every position. **Completion-only loss:** we mask every token up to and
including `<|im_start|>assistant\n`, so the loss is computed ONLY on the clean
target tokens. The model is never trained to reproduce the prompt or the raw
input, only to emit the clean output. (This is the
`DataCollatorForCompletionOnlyLM` with the assistant response template.)
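To make the masking concrete, here is a toy sketch that renders one pair in the ChatML layout shown in 4.1, masks everything up to and including the assistant header, and averages the negative log-likelihood over the remaining tokens. Several things are assumptions for illustration: the placeholder `SYSTEM_PROMPT`, the hand-rolled `render_chatml` (a real run would rely on the tokenizer's own chat template), the one-character-per-token `toy_tokenize`, and the made-up log-probabilities. The shift-by-one between inputs and labels that a causal LM applies internally is left out for clarity.

```python
import math

IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy skips by default
RESPONSE_TEMPLATE = "<|im_start|>assistant\n"
SYSTEM_PROMPT = "Clean up the dictated text without rewording it."  # placeholder prompt


def build_messages(raw: str, clean: str) -> list:
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": raw},
        {"role": "assistant", "content": clean},
    ]


def render_chatml(messages: list) -> str:
    """Hand-rolled ChatML rendering, for illustration only."""
    return "".join(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages)


def toy_tokenize(text: str) -> list:
    """Stand-in tokenizer: one token per character."""
    return list(text)


def completion_only_labels(tokens: list, template_tokens: list) -> list:
    """Copy tokens to labels, masking everything up to and including the response template."""
    n = len(template_tokens)
    for start in range(len(tokens) - n, -1, -1):  # search from the end: last assistant turn
        if tokens[start:start + n] == template_tokens:
            cut = start + n
            return [IGNORE_INDEX] * cut + tokens[cut:]
    raise ValueError("response template not found")


def masked_nll(token_logprobs: list, labels: list) -> float:
    """Mean negative log-likelihood over the unmasked (target) positions only."""
    kept = [lp for lp, lab in zip(token_logprobs, labels) if lab != IGNORE_INDEX]
    return -sum(kept) / len(kept)


tokens = toy_tokenize(render_chatml(build_messages("um hi there", "Hi there.")))
labels = completion_only_labels(tokens, toy_tokenize(RESPONSE_TEMPLATE))
target = "".join(t for t in labels if t != IGNORE_INDEX)

# Pretend the model assigns p=0.9 to every prompt token and p=0.5 to every target token.
n_target = len(target)
logprobs = [math.log(0.9)] * (len(tokens) - n_target) + [math.log(0.5)] * n_target
loss = masked_nll(logprobs, labels)
print(repr(target), round(loss, 4))
```

The prompt-side probabilities have no effect on `loss`: only the clean text and its end-of-turn marker are scored.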
### 4.3 Gradients and optimization
- **Forward pass:** tokens -> logits -> loss.
- **Backward pass (backprop):** autograd computes the gradient of the loss with
respect to every trainable parameter, d(loss)/d(theta), by the chain rule.
- **Optimizer = AdamW.** For each trainable parameter it keeps two running
averages: m (first moment, the mean/momentum of recent gradients) and v
(second moment, the mean of squared gradients, i.e. an uncentered
per-parameter variance). The update is roughly
theta <- theta - lr * m_hat / (sqrt(v_hat) + eps) then theta <- theta - lr * wd * theta
where m_hat and v_hat are the bias-corrected averages. Dividing by sqrt(v_hat)
gives each parameter its own effective step size (large for
consistently-signed gradients, small for noisy ones), which is why Adam-family
optimizers are stable for transformers. The "W" is decoupled weight decay (the
second term): the decay is applied directly to the weights instead of being
folded into the gradient as an L2 penalty, as classic Adam does.
- **Learning-rate schedule:** linear **warmup** for the first few percent of
steps (ramp lr from 0 so early, large, random-direction updates do not
destabilize the model), then **cosine decay** down toward 0.
- **Effective batch size = per_device_batch * grad_accum (* num_gpus).** Gradient
accumulation runs several micro-batches, sums their gradients, and only then
steps the optimizer, so you get the stability of a big batch without the
memory of one.
- **Gradient clipping (`max_grad_norm`)** rescales the gradient if its norm
exceeds a threshold, preventing a single bad batch from blowing up the weights.
- **Mixed precision:** bf16 (or fp16) for the forward/backward math is ~2x
faster and uses less memory than fp32; tf32 speeds up matmuls on NVIDIA Ampere
and newer. On CPU these are off (the smoke path).
- **Epochs:** full passes over the data. 2-3 is right for a few-thousand-pair
SFT; more risks overfitting the synthetic distribution.
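The schedule, the AdamW update, and gradient clipping described in the list above can be sketched for a single scalar parameter. This is a teaching sketch, not the trainer's code: the function names, the 3% warmup ratio, and the default hyperparameters are assumptions, and the decay is applied after the gradient step to match the formula in the text (libraries may apply it before the step).

```python
import math


def lr_at(step: int, total_steps: int, peak_lr: float, warmup_ratio: float = 0.03) -> float:
    """Linear warmup from 0 to peak_lr, then cosine decay toward 0."""
    warmup_steps = max(1, int(total_steps * warmup_ratio))
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


def clip_by_norm(grads: list, max_norm: float) -> list:
    """Rescale the whole gradient vector if its L2 norm exceeds max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    return [g * (max_norm / norm) for g in grads]


def adamw_step(theta, grad, m, v, t, lr, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
    """One AdamW update for a scalar parameter; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad          # first moment (momentum)
    v = beta2 * v + (1 - beta2) * grad * grad   # second moment (uncentered variance)
    m_hat = m / (1 - beta1 ** t)                # bias correction for the zero init
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    theta = theta - lr * wd * theta             # decoupled weight decay
    return theta, m, v


# Toy run: minimize (theta - 3)^2 with weight decay off.
theta, m, v = 0.0, 0.0, 0.0
total_steps = 500
for t in range(1, total_steps + 1):
    grad = 2.0 * (theta - 3.0)
    lr = lr_at(t, total_steps, peak_lr=0.1)
    theta, m, v = adamw_step(theta, grad, m, v, t, lr, wd=0.0)
print(theta)
```

One property worth noticing: on the very first step the bias-corrected ratio m_hat / sqrt(v_hat) is plus or minus one, so the parameter moves by about `lr` regardless of the gradient's magnitude. That is the per-parameter step-size normalization the text describes.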
### 4.4 LoRA, in detail (this is the key part)
**Problem with full fine-tuning.** Updating all ~0.5B parameters means storing a
gradient and two AdamW moment buffers per parameter (so several bytes x 0.5B =
multiple GB of optimizer state), and you end up with a full copy of the model per
task. Overkill for adapting to one narrow task.
**LoRA (Low-Rank Adaptation).** Freeze the pretrained weight matrix W0 (shape
d x k). Do not update it. Represent the fine-tuning update as a **low-rank**
product:
W_effective = W0 + dW, dW = (alpha / r) * B @ A
where A is r x k, B is d x r, and the **rank r is tiny** (we use r = 16) compared
to d and k (hundreds to thousands). Only A and B are trainable. The forward pass
becomes:
h = W0 @ x + (alpha / r) * B @ (A @ x)
**Why it works.** The weight update needed to adapt a big pretrained model to a
specific task has low "intrinsic rank" - a small-rank correction is enough. You
are not relearning language, just nudging behavior.
**Initialization.** A is initialized small-random (Gaussian), B is initialized to
**zero**, so dW = 0 at step 0 and training starts from exactly the pretrained
model. The adapters then learn the correction.
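The forward pass and the zero initialization can be checked with a few lines of plain Python. The dimensions, the seed, and the helper names below are arbitrary choices for illustration (alpha and r are picked so that alpha/r = 2, matching the scale in the config); real adapters are tensors managed by a LoRA library.

```python
import random


def matvec(M: list, x: list) -> list:
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in M]


def lora_forward(W0, A, B, x, alpha, r):
    """h = W0 @ x + (alpha / r) * B @ (A @ x); the d x k product B @ A is never formed."""
    base = matvec(W0, x)
    delta = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b + scale * dv for b, dv in zip(base, delta)]


rng = random.Random(0)
d, k, r, alpha = 6, 5, 2, 4
W0 = [[rng.gauss(0, 1) for _ in range(k)] for _ in range(d)]     # frozen, d x k
A = [[rng.gauss(0, 0.01) for _ in range(k)] for _ in range(r)]   # trainable, r x k, small random
B = [[0.0] * r for _ in range(d)]                                # trainable, d x r, zeros
x = [rng.gauss(0, 1) for _ in range(k)]

base_out = matvec(W0, x)
out_init = lora_forward(W0, A, B, x, alpha, r)   # B is zero, so this equals the base output

B[0][0] = 1.0                                    # stand-in for a training update to B
out_trained = lora_forward(W0, A, B, x, alpha, r)
print(out_init == base_out, out_trained[0] - base_out[0])
```

With B at zero the adapter path contributes nothing, so the first training step starts from the pretrained behavior. Once B moves, only the output rows it touches change.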
**The knobs:**
- `lora_r` (rank, 16): capacity of the adapter. Higher = more expressive, more
params.
- `lora_alpha` (32): a scaling factor; the effective scale applied to dW is
alpha/r (= 2 here). Think of it as the adapter's learning-rate gain.
- `lora_target_modules` (q_proj, k_proj, v_proj, o_proj): which weight matrices
get an adapter. The attention projections are the standard choice; you can add
the MLP projections for more capacity.
- `lora_dropout` (0.05): dropout on the adapter path, light regularization.
**Parameter and memory win.** Per adapted matrix, trainable params drop from
d*k to r*(d+k). For Qwen2.5-0.5B with adapters on the four attention projections
at r=16, you train roughly 2 million parameters (well under 1% of the model)
instead of ~490M. Gradients and AdamW moments exist only for those parameters,
so it fits easily on a modest GPU, and the saved adapter is a few MB, not a
gigabyte.
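The count can be worked out from r*(d+k). The dimensions below are taken from the published Qwen2.5-0.5B configuration as best recalled (hidden size 896, 24 layers, 14 query heads and 2 key/value heads of width 64); treat them as assumptions and confirm against the model's `config.json` before relying on the result.

```python
# Assumed Qwen2.5-0.5B dimensions (verify against the model's config.json).
HIDDEN = 896
N_LAYERS = 24
N_HEADS = 14
N_KV_HEADS = 2
HEAD_DIM = 64
BASE_PARAMS = 490_000_000  # approximate model size quoted in the text
R = 16

# (input dim, output dim) of each adapted projection.
projections = {
    "q_proj": (HIDDEN, N_HEADS * HEAD_DIM),
    "k_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "v_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "o_proj": (N_HEADS * HEAD_DIM, HIDDEN),
}


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """A is r x d_in and B is d_out x r, so the adapter holds r * (d_in + d_out) weights."""
    return r * (d_in + d_out)


per_layer = sum(lora_params(d_in, d_out, R) for d_in, d_out in projections.values())
total_lora = per_layer * N_LAYERS
fraction = total_lora / BASE_PARAMS
adapter_mb_bf16 = total_lora * 2 / 1e6  # 2 bytes per bf16 weight

print(f"{total_lora:,} trainable params, {fraction:.2%} of the base, ~{adapter_mb_bf16:.1f} MB in bf16")
```

Under these assumptions the adapters hold about 2.2M weights. Adding the MLP projections or raising r would grow that number proportionally.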
**Gradients with LoRA.** W0 has `requires_grad = False`, so backprop computes
gradients only into A and B. Everything else about the optimization (AdamW,
schedule, clipping) is unchanged, just applied to far fewer parameters.
**Merging for deployment (`04_export.py`).** At inference you do not want the
extra B@A matmul. `merge_and_unload()` folds the adapter back in:
W0 <- W0 + (alpha/r)*B@A, producing a standalone model identical in shape to the
base. That merged model is what we export to ONNX for the Rust `ort` backend.
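The merge is plain matrix algebra, which the small check below verifies on hand-picked numbers: folding (alpha/r)*B@A into W0 gives the same output as running the adapter path separately. The helper names and the tiny matrices are illustrative only; in the real pipeline the fold is done by the `merge_and_unload()` call mentioned above.

```python
def matvec(M: list, x: list) -> list:
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in M]


def matmul(X: list, Y: list) -> list:
    return [[sum(x_it * Y[t][j] for t, x_it in enumerate(row)) for j in range(len(Y[0]))]
            for row in X]


def merge(W0: list, A: list, B: list, alpha: float, r: int) -> list:
    """Return W0 + (alpha / r) * B @ A as a new matrix with the same shape as W0."""
    scale = alpha / r
    BA = matmul(B, A)
    return [[w + scale * u for w, u in zip(w_row, u_row)] for w_row, u_row in zip(W0, BA)]


W0 = [[1.0, 2.0, 0.0],
      [0.0, 1.0, -1.0]]     # d = 2, k = 3
A = [[0.5, -1.0, 2.0]]      # r x k with r = 1
B = [[2.0], [-4.0]]         # d x r
alpha, r = 2, 1
x = [1.0, 2.0, 3.0]

# Adapter path: base output plus the scaled low-rank correction.
adapter_out = [b + (alpha / r) * dv
               for b, dv in zip(matvec(W0, x), matvec(B, matvec(A, x)))]

# Merged path: a single matmul against the folded weights.
W_merged = merge(W0, A, B, alpha, r)
merged_out = matvec(W_merged, x)
print(adapter_out, merged_out)
```

After merging there is one weight matrix per layer again, so the exported graph has no extra adapter operations.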
**QLoRA vs LoRA.** QLoRA additionally quantizes the frozen base to 4-bit to save
even more memory. At 0.5B we do not need it; plain LoRA in bf16 fits comfortably,
and skipping 4-bit avoids a small quality hit.
---
## 5. Metrics
Training optimizes the cross-entropy loss above. To judge whether the fine-tune
is actually better than the base model, `03_evaluate.py` runs BOTH on the
held-out real test set and scores a suite that balances "did it edit correctly"
against "did it stay faithful":
- **chrF** (sacrebleu): character n-gram F-score of the output against the gold
clean. General overlap/fluency signal, stable on short text.
- **disfluency-removal F1:** treat cleanup as deleting tokens from the raw. Gold
deletions = tokens in raw not in gold; predicted deletions = tokens in raw not
in output. F1 over those deleted multisets. Measures the core job directly.
- **added-content rate:** fraction of output content tokens NOT present in the
raw input. Should be about 0. **This is the hallucination guard** - the metric
that catches the model inventing or answering. A fine-tune that improves chrF
but raises added-content is a regression, not a win.
- **source-overlap:** fraction of output tokens that are present in the raw.
Should be near 1.
(Optional, heavier: ERRANT F0.5 via spacy for formal edit scoring; a
sentence-embedding cosine for semantic drift. Left as extras behind the `errant`
optional dependency.)
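The three token-level metrics can be sketched with `collections.Counter`. This is one reading of the definitions above, not the project's `evaluate.py`: the regex tokenizer is an assumption, every token is treated as a content token (no stopword filtering), and the function names are illustrative. chrF is omitted because it comes from sacrebleu.

```python
import re
from collections import Counter


def tokens(text: str) -> list:
    """Lowercased word tokens with punctuation dropped (apostrophes kept)."""
    return re.findall(r"[a-z0-9']+", text.lower())


def disfluency_f1(raw: str, gold: str, output: str) -> float:
    """F1 between gold deletions (raw minus gold) and predicted deletions (raw minus output)."""
    raw_counts = Counter(tokens(raw))
    gold_del = raw_counts - Counter(tokens(gold))
    pred_del = raw_counts - Counter(tokens(output))
    tp = sum((gold_del & pred_del).values())
    n_pred, n_gold = sum(pred_del.values()), sum(gold_del.values())
    if n_pred == 0 and n_gold == 0:
        return 1.0  # nothing to delete and nothing deleted
    if tp == 0:
        return 0.0
    precision, recall = tp / n_pred, tp / n_gold
    return 2 * precision * recall / (precision + recall)


def added_content_rate(raw: str, output: str) -> float:
    """Fraction of output tokens that never appear in the raw input."""
    raw_set, out = set(tokens(raw)), tokens(output)
    return sum(t not in raw_set for t in out) / len(out) if out else 0.0


def source_overlap(raw: str, output: str) -> float:
    """Fraction of output tokens that do appear in the raw input."""
    raw_set, out = set(tokens(raw)), tokens(output)
    return sum(t in raw_set for t in out) / len(out) if out else 0.0


raw = "um what's the capital of of france"
gold = "What's the capital of France?"
cleaned = "What's the capital of France?"
answered = "The capital of France is Paris."

for name, out in [("cleaned", cleaned), ("answered", answered)]:
    print(name, round(disfluency_f1(raw, gold, out), 3),
          round(added_content_rate(raw, out), 3), round(source_overlap(raw, out), 3))
```

On this example the answering output still scores a deceptively decent deletion F1, because it drops the filler and the repeated word, while its added-content rate jumps well above zero. That is why the two metrics are read together.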
**The decisive behavioral test:** feed dictated questions ("um what's the capital
of france") and confirm the fine-tune **cleans** them ("What's the capital of
France?") rather than **answering**. This is the exact failure the base model
made, so it is the single most important check.
**Protocol:** held-out real test never seen in training, greedy decoding for
determinism, one table with rows {base, fine-tune} and the four columns, plus a
qualitative before/after on real transcripts. Win condition: higher disfluency
F1 and chrF with added-content held at about 0.
---
## 6. File-by-file guide
Infra (already written; these files do not need rewriting):
- `pyproject.toml`, `.python-version`, `uv.lock` (generated by `uv sync`): the
pinned, self-contained environment.
- `Makefile`: one make target per pipeline stage, RUN_ID / LR / EPOCHS settable.
- `configs/inject.yaml`: the injection recipe and the clean/eval data sources.
- `configs/train.yaml`: base model, LoRA knobs, optimizer, schedule, precision.
- `README.md`: quickstart and the Vast.ai steps.
Code to write (scaffolded with signatures + specs):
- `src/cleanup/config.py`: load the two yaml files into `TrainConfig` /
`InjectConfig` dataclasses.
- `src/cleanup/prompts.py`: the `SYSTEM_PROMPT` and `build_messages` (a starting
prompt is provided; tune it).
- `src/cleanup/inject.py`: `strip_punctuation_and_lowercase` and `make_raw` (the
injection algorithm from section 3.2). Pure functions, unit-tested.
- `src/cleanup/data.py`: load clean sentences (hf stream or a local file), build
pairs, split, load the real eval set, jsonl read/write.
- `src/cleanup/train.py`: `build_dataset` (chat-template the pairs) and `train`
(LoRA + trl SFTTrainer, completion-only loss). Section 4 is the spec.
- `src/cleanup/infer.py`: `load_model` (base + optional adapter) and `clean_text`
(greedy generate, output length capped near the input).
- `src/cleanup/evaluate.py`: the metric functions from section 5.
- `src/cleanup/export.py`: `merge_adapter` (fold LoRA in) and `export_onnx`.
- `src/cleanup/pack.py`: `render_report` and `pack_run` (tar + sha256).
- `scripts/01_data.py .. 06_pack_and_ship.py`: thin CLI entry points that wire
the src functions together per stage. Each has its argparse and a step list.
- `tests/test_inject.py`: assert the injection invariants (the faithfulness one
matters most).
---
## 7. Vast.ai workflow
1. Rent a CUDA PyTorch instance and note the SSH host and port.
2. SSH in, clone the mumble repo, `cd models/cleanup`.
3. `uv sync` (generates `uv.lock` on first run; commit it).
4. `make all RUN_ID=r1` (or run stages one at a time and inspect between them).
5. `pack` prints a sha256 and an `scp -P <port> root@<host>:... .` line. Pull the
`dist/<run>.tar.gz` off the box, run `shasum -a 256` to verify, then destroy
the instance.
`02_train.py` is cuda-aware (bf16 when the GPU supports it, else fp16, off on
CPU), so `make smoke` runs the same code on CPU locally with a few rows and one
epoch to validate wiring before you rent anything.
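The precision rule in the parenthesis above reduces to a small pure function. The name `precision_flags` and the returned keys are hypothetical; in a real run the two booleans would typically come from `torch.cuda.is_available()` and `torch.cuda.is_bf16_supported()`, which are kept out of the sketch so it runs without a GPU stack installed.

```python
def precision_flags(has_cuda: bool, bf16_supported: bool) -> dict:
    """Pick mixed-precision flags: bf16 if the GPU supports it, else fp16, both off on CPU."""
    if not has_cuda:
        return {"bf16": False, "fp16": False}   # CPU smoke path: plain fp32
    if bf16_supported:
        return {"bf16": True, "fp16": False}
    return {"bf16": False, "fp16": True}


for case in [(False, False), (True, True), (True, False)]:
    print(case, precision_flags(*case))
```

Keeping the decision in a function with plain boolean inputs makes it easy to unit-test the smoke path locally before renting anything.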
Do not launch a Vast.ai run until the code is written and the plan reviewed.