shangeth committed on
Commit 72ecc83 · verified · 1 Parent(s): a14ffcf

Upload Wren-ASR-0.5B-multi checkpoint

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,201 @@
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ - de
6
+ - fr
7
+ - es
8
+ - nl
9
+ - it
10
+ - pl
11
+ - pt
12
+ library_name: pytorch
13
+ tags:
14
+ - automatic-speech-recognition
15
+ - asr
16
+ - audio
17
+ - speech-recognition
18
+ - multilingual
19
+ - wren
20
+ - mimi
21
+ - qwen2.5
22
+ - neural-codec
23
+ pipeline_tag: automatic-speech-recognition
24
+ datasets:
25
+ - shangeth/mls-mimi-codes
26
+ - shangeth/libritts-r-mimi-codes
27
+ - shangeth/vctk-mimi-codes
28
+ - shangeth/jenny-mimi-codes
29
+ - shangeth/ljspeech-mimi-codes
30
+ - shangeth/expresso-mimi-codes-tagged
31
+ - facebook/multilingual_librispeech
32
+ - mythicinfinity/libritts_r
33
+ - keithito/lj_speech
34
+ - CSTR-Edinburgh/vctk
35
+ - reach-vb/jenny_tts_dataset
36
+ - ylacombe/expresso
37
+ ---
38
+
39
+ # Wren-ASR-0.5B-multi
40
+
41
+ **Multilingual** automatic speech recognition model in the Wren series. Encodes
42
+ audio with the [Kyutai Mimi](https://huggingface.co/kyutai/mimi) neural codec,
43
+ then transcribes with a [Qwen/Qwen2.5-0.5B](https://huggingface.co/Qwen/Qwen2.5-0.5B)
44
+ backbone — no acoustic encoder, no CTC, just a small LLM consuming Mimi codes as
45
+ input embeddings.
46
+
47
+ Supports **8 languages**: English, German, French, Spanish, Dutch, Italian, Polish, Portuguese.
48
+
49
+ ## Links
50
+
51
+ - **Training & inference code:** [github.com/shangeth/wren-asr](https://github.com/shangeth/wren-asr)
52
+ - **Wren research project:** [github.com/shangeth/wren](https://github.com/shangeth/wren)
53
+ - **TTS counterpart:** [shangeth/Wren-TTS-0.5B-multi](https://huggingface.co/shangeth/Wren-TTS-0.5B-multi)
54
+ - **Dataset extraction (Mimi codes):** [github.com/shangeth/wren-datasets](https://github.com/shangeth/wren-datasets)
55
+ - **Demo Space:** [huggingface.co/spaces/shangeth/Wren-ASR-0.5B-multi-demo](https://huggingface.co/spaces/shangeth/Wren-ASR-0.5B-multi-demo)
56
+
57
+ ## Architecture
58
+
59
+ ```
60
+ audio ──► Mimi encoder (k=3) ──► Qwen2.5-0.5B (audio prefix → text) ──► transcript
61
+ ```
62
+
63
+ Mimi codes serve as a discrete audio prefix in the LLM's input embedding space.
64
+ At each audio frame the k=3 codebook codes go through k separate input embedding
65
+ tables; their sum (scaled by 1/√k) is the input embedding for that step. The
66
+ audio prefix is wrapped in `<|audio_start|>` / `<|audio_end|>` tokens, after
67
+ which the LLM autoregressively emits text using its native vocabulary and
68
+ `lm_head` — no new output heads were added.
69
+
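The summed-codebook step described above can be sketched without any dependencies. This is only an illustration of the arithmetic, not the repository's code: the real model uses `nn.Embedding` tables, and the helper names `make_table` and `frame_embedding`, the random initialisation, and the toy hidden size of 8 are assumptions made here.

```python
import math
import random


def make_table(codebook_size: int, hidden: int, rng: random.Random) -> list[list[float]]:
    """Build one toy embedding table: `codebook_size` rows of `hidden` floats."""
    return [[rng.gauss(0.0, 1.0) for _ in range(hidden)] for _ in range(codebook_size)]


def frame_embedding(codes: list[int], tables: list[list[list[float]]]) -> list[float]:
    """Sum the k per-codebook embeddings for one frame, then scale by 1/sqrt(k)."""
    k = len(tables)
    if len(codes) != k:
        raise ValueError(f"expected {k} codes, got {len(codes)}")
    hidden = len(tables[0][0])
    summed = [0.0] * hidden
    for q, code in enumerate(codes):
        row = tables[q][code]  # codebook q has its own table
        for d in range(hidden):
            summed[d] += row[d]
    scale = 1.0 / math.sqrt(k)
    return [value * scale for value in summed]


rng = random.Random(0)
k, codebook_size, hidden = 3, 2048, 8  # hidden=8 is a toy size, not the LLM's width
tables = [make_table(codebook_size, hidden, rng) for _ in range(k)]
emb = frame_embedding([17, 1024, 2047], tables)  # one code per codebook
print(len(emb))
```

One plausible reading of the 1/√k factor is that it keeps the variance of the sum close to that of a single embedding when the k tables are roughly independent; the model card does not state the motivation.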
70
+ - **Backbone:** Qwen2.5-0.5B (causal LM; transformer body ~358M params, 151k-token multilingual vocab)
71
+ - **Audio tokenizer:** Mimi (`kyutai/mimi`), 12.5 fps, 2048-entry codebooks
72
+ - **Codebooks used:** first 3 (semantic-content-rich); cuts audio embedding-table parameters by 8/3× relative to 8-codebook variants
73
+ - **Audio prefix:** `<|audio_start|>` + summed-codebook embeds × T_frames + `<|audio_end|>`
74
+ - **Output:** standard text autoregression via `model.llm.generate(inputs_embeds=...)`
75
+
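To make the prefix layout concrete, the sketch below estimates how many input positions a clip occupies: one per Mimi frame plus the two wrapper tokens. The 24 kHz sampling rate and 12.5 fps frame rate come from this card and `config.json`. The function name and the assumption that a trailing partial frame is rounded up are illustrative, so treat the result as an estimate.

```python
import math

SAMPLE_RATE = 24_000  # Hz, `sampling_rate` in config.json
FRAME_RATE = 12.5     # Mimi frames per second, as stated in this card


def audio_prefix_length(num_samples: int) -> int:
    """Estimate the LLM prefix length: T audio frames + <|audio_start|> + <|audio_end|>."""
    samples_per_frame = SAMPLE_RATE / FRAME_RATE  # 1920 samples per frame
    # Assumption: a trailing partial frame still yields one frame (rounded up).
    n_frames = math.ceil(num_samples / samples_per_frame)
    return n_frames + 2


print(audio_prefix_length(30 * SAMPLE_RATE))  # 375 frames + 2 wrapper tokens
```

At the 30 s cap this gives 377 positions before the first text token, which is small next to the backbone's context window.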
76
+ ## Training data
77
+
78
+ Trained on the **union of every dataset used to train Wren-TTS** — the same
79
+ 6 corpora that power the en/multi/expressive TTS recipes, with text used as the
80
+ ASR target:
81
+
82
+ | Dataset | Rows | Language(s) |
83
+ |---|---|---|
84
+ | VCTK | ~44k | en (109 speakers, multiple accents) |
85
+ | Jenny | ~21k | en (single speaker) |
86
+ | LibriTTS-R | ~360k | en (clean_100 + clean_360 + other_500) |
87
+ | LJSpeech | ~13k | en (single speaker) |
88
+ | MLS | ~6.0M | de · fr · es · it · nl · pl · pt |
89
+ | Expresso (tagged) | ~26k | en (style tags stripped at load time) |
90
+ | **Total** | **~6.46M** rows / epoch | |
91
+
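The total in the table can be re-derived from the approximate per-corpus row counts. The short check below does that and also prints each corpus's share of an epoch, which gives a feel for how heavily MLS dominates. The counts are the rounded figures from the table, so the shares are approximate as well.

```python
# Approximate row counts in thousands, copied from the table above.
rows_thousands = {
    "VCTK": 44,
    "Jenny": 21,
    "LibriTTS-R": 360,
    "LJSpeech": 13,
    "MLS": 6000,
    "Expresso (tagged)": 26,
}

total_thousands = sum(rows_thousands.values())
print(f"total: ~{total_thousands / 1000:.2f}M rows / epoch")

# Share of each corpus in one epoch (all sampling weights are 1.0).
shares = {name: count / total_thousands for name, count in rows_thousands.items()}
for name, share in sorted(shares.items(), key=lambda item: -item[1]):
    print(f"{name:>18}: {share:6.1%}")
```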
92
+ Mimi codes are pre-extracted and published as the per-corpus mimi-codes datasets
93
+ (see Datasets above) — no online encoding during training. Single-pass
94
+ from-scratch training, k=3 codebooks. Held-out validation combines LibriTTS-R
95
+ `dev_clean` + MLS `dev` (all 7 langs) + Expresso `dev` (tags stripped) + 5%
96
+ per single-speaker English source. All dataset sampling weights are set to 1.0 (every row, every
97
+ epoch, no subsampling). Trained on a single A100-40GB.
98
+
99
+ Text casing and punctuation are preserved in the ground-truth transcripts.
100
+
101
+ ## Usage
102
+
103
+ ```bash
104
+ pip install torch torchaudio transformers
105
+ ```
106
+
107
+ ```python
108
+ import torch
109
+ import torchaudio
110
+ from transformers import AutoModel, AutoProcessor
111
+
112
+ model_id = "shangeth/Wren-ASR-0.5B-multi"
113
+ device = "cuda" if torch.cuda.is_available() else "cpu"
114
+
115
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
116
+ model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
117
+
118
+ # Load a short mono clip (one of the 8 supported languages, ≤ 30 s)
119
+ wav, sr = torchaudio.load("input.wav")
120
+
121
+ inputs = processor(audio=wav, sampling_rate=sr)
122
+ inputs = {k: v.to(device) for k, v in inputs.items()}
123
+
124
+ ids = model.generate(**inputs, max_new_tokens=200)
125
+ text = processor.batch_decode(ids, skip_special_tokens=True)[0]
126
+ print(text)
127
+ ```
128
+
129
+ ## Sampling tips
130
+
131
+ Defaults: greedy decoding (`do_sample=False`). For longer / harder utterances:
132
+ - Pass `do_sample=True, temperature=0.7, top_p=0.9` to sample more diverse hypotheses
133
+ - Raise `max_new_tokens` if transcripts are getting cut off
134
+ - Audio is hard-capped at 30 s (375 frames @ 12.5 fps) by the training recipe;
135
+ for longer audio, segment first
136
+
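For audio longer than the 30 s cap, one simple option is fixed-length segmentation. The helper below is a hypothetical sketch, not part of this repository: it only computes sample index ranges, so it works with any array-like waveform. The name `split_into_segments` and its parameters are assumptions.

```python
def split_into_segments(
    num_samples: int,
    sample_rate: int,
    max_seconds: float = 30.0,
) -> list[tuple[int, int]]:
    """Return (start, end) sample index pairs, each covering at most `max_seconds`."""
    if num_samples < 0 or sample_rate <= 0 or max_seconds <= 0:
        raise ValueError("num_samples must be >= 0; sample_rate and max_seconds must be > 0")
    max_len = int(max_seconds * sample_rate)
    return [
        (start, min(start + max_len, num_samples))
        for start in range(0, num_samples, max_len)
    ]


# A 70 s clip at 24 kHz splits into 30 s + 30 s + 10 s.
for start, end in split_into_segments(70 * 24_000, 24_000):
    print(start, end)
```

Each `wav[..., start:end]` slice could then be passed through the processor and `generate` separately and the transcripts joined. Fixed-length cuts can split a word in half, so cutting at silences is likely to give cleaner results.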
137
+ ## Limitations & known issues
138
+
139
+ - **Language coverage:** only the 8 trained languages. Out-of-distribution
140
+ audio produces noise / hallucinated text in the closest matching language.
141
+ - **Per-language quality varies with data volume:** German / Dutch / French
142
+ are strongest (largest training shares); Polish / Portuguese / Italian have
143
+ less training data and may be less accurate.
144
+ - **Audiobook-style audio dominates training:** MLS / LibriTTS-R / LJSpeech /
145
+ Jenny are all studio-style read speech. Performance on conversational audio,
146
+ noisy environments, or accented far-field input may degrade.
147
+ - **0.5B backbone** — quality is below frontier ASR systems (Whisper-large-v3,
148
+ USM, etc.). The pitch is "small enough to run anywhere" + "shares architecture
149
+ with Wren-TTS-0.5B-multi for unified speech-text experimentation".
150
+ - **30s audio cap.** Hard-cap at training time; longer audio needs to be
151
+ segmented externally.
152
+ - **No speaker diarization.** Single-stream transcription only.
153
+
154
+ ## The Wren series
155
+
156
+ Wren is a family of compact (<3B parameter) multimodal speech LLMs — small
157
+ enough to run on a single consumer GPU, designed for open research on unified
158
+ speech understanding and synthesis.
159
+
160
+ - **Wren-TTS** — text → speech (English + multilingual + expressive variants)
161
+ - **Wren-ASR** — speech → text (this release)
162
+ - **Wren-LM** — speech-language modelling / dialog (planned)
163
+ - **Wren-Omni** — unified ASR + TTS + LM in one checkpoint (planned)
164
+
165
+ All Wren models share the same design principles: small backbone LLM + neural
166
+ audio codec, open weights, simple PyTorch checkpoints, reproducible training
167
+ recipes. Wren-ASR uses the same Qwen2.5-0.5B backbone as Wren-TTS-0.5B-multi
168
+ and is trained on the same corpora — making the pair a natural starting point
169
+ for unified speech-text modelling research.
170
+
171
+ ## Repository contents
172
+
173
+ | File | Purpose |
174
+ |---|---|
175
+ | `model.safetensors` | Model weights |
176
+ | `config.json` | `WrenASRConfig` (with `auto_map` for `trust_remote_code`) |
177
+ | `tokenizer.json` + friends | Qwen2.5 tokenizer with Wren-ASR's 2 special tokens added |
178
+ | `processor_config.json` | `WrenASRProcessor` auto_map |
179
+ | `configuration_wren_asr.py` | `WrenASRConfig(PretrainedConfig)` |
180
+ | `modeling_wren_asr.py` | `WrenForASR(PreTrainedModel)` — loads Mimi codec lazily on first call |
181
+ | `processing_wren_asr.py` | `WrenASRProcessor(ProcessorMixin)` — audio → Mimi codes + text decode |
182
+ | `README.md` | This model card |
183
+
184
+ ## Citation
185
+
186
+ ```bibtex
187
+ @misc{wren2026,
188
+ title = {Wren: A Family of Small Open-Weight Models for Unified Speech-Text Modelling},
189
+ author = {Shangeth Rajaa},
190
+ year = {2026},
191
+ url = {https://github.com/shangeth/wren}
192
+ }
193
+ ```
194
+
195
+ ## License
196
+
197
+ Apache-2.0 for the checkpoint weights and code in this repo.
198
+ Upstream components carry their own licenses — review before redistribution.
199
+ The Expresso dataset (used for English style robustness) is CC-BY-NC-4.0; if
200
+ you build derived models on this checkpoint and want to release them
201
+ commercially, retrain with Expresso excluded.
added_tokens.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "</tool_call>": 151658,
3
+ "<tool_call>": 151657,
4
+ "<|audio_end|>": 151666,
5
+ "<|audio_start|>": 151665,
6
+ "<|box_end|>": 151649,
7
+ "<|box_start|>": 151648,
8
+ "<|endoftext|>": 151643,
9
+ "<|file_sep|>": 151664,
10
+ "<|fim_middle|>": 151660,
11
+ "<|fim_pad|>": 151662,
12
+ "<|fim_prefix|>": 151659,
13
+ "<|fim_suffix|>": 151661,
14
+ "<|im_end|>": 151645,
15
+ "<|im_start|>": 151644,
16
+ "<|image_pad|>": 151655,
17
+ "<|object_ref_end|>": 151647,
18
+ "<|object_ref_start|>": 151646,
19
+ "<|quad_end|>": 151651,
20
+ "<|quad_start|>": 151650,
21
+ "<|repo_name|>": 151663,
22
+ "<|video_pad|>": 151656,
23
+ "<|vision_end|>": 151653,
24
+ "<|vision_pad|>": 151654,
25
+ "<|vision_start|>": 151652
26
+ }
chat_template.jinja ADDED
@@ -0,0 +1,54 @@
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0]['role'] == 'system' %}
4
+ {{- messages[0]['content'] }}
5
+ {%- else %}
6
+ {{- 'You are a helpful assistant.' }}
7
+ {%- endif %}
8
+ {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
9
+ {%- for tool in tools %}
10
+ {{- "\n" }}
11
+ {{- tool | tojson }}
12
+ {%- endfor %}
13
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
14
+ {%- else %}
15
+ {%- if messages[0]['role'] == 'system' %}
16
+ {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
17
+ {%- else %}
18
+ {{- '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}
19
+ {%- endif %}
20
+ {%- endif %}
21
+ {%- for message in messages %}
22
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
23
+ {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
24
+ {%- elif message.role == "assistant" %}
25
+ {{- '<|im_start|>' + message.role }}
26
+ {%- if message.content %}
27
+ {{- '\n' + message.content }}
28
+ {%- endif %}
29
+ {%- for tool_call in message.tool_calls %}
30
+ {%- if tool_call.function is defined %}
31
+ {%- set tool_call = tool_call.function %}
32
+ {%- endif %}
33
+ {{- '\n<tool_call>\n{"name": "' }}
34
+ {{- tool_call.name }}
35
+ {{- '", "arguments": ' }}
36
+ {{- tool_call.arguments | tojson }}
37
+ {{- '}\n</tool_call>' }}
38
+ {%- endfor %}
39
+ {{- '<|im_end|>\n' }}
40
+ {%- elif message.role == "tool" %}
41
+ {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
42
+ {{- '<|im_start|>user' }}
43
+ {%- endif %}
44
+ {{- '\n<tool_response>\n' }}
45
+ {{- message.content }}
46
+ {{- '\n</tool_response>' }}
47
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
48
+ {{- '<|im_end|>\n' }}
49
+ {%- endif %}
50
+ {%- endif %}
51
+ {%- endfor %}
52
+ {%- if add_generation_prompt %}
53
+ {{- '<|im_start|>assistant\n' }}
54
+ {%- endif %}
config.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "architectures": [
3
+ "WrenForASR"
4
+ ],
5
+ "audio_end_id": 151666,
6
+ "audio_start_id": 151665,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_wren_asr.WrenASRConfig",
9
+ "AutoModel": "modeling_wren_asr.WrenForASR"
10
+ },
11
+ "codebook_size": 2048,
12
+ "dtype": "bfloat16",
13
+ "eos_token_id": 151643,
14
+ "k_codebooks": 3,
15
+ "llm_name": "Qwen/Qwen2.5-0.5B",
16
+ "mimi_model_name": "kyutai/mimi",
17
+ "model_type": "wren_asr",
18
+ "sampling_rate": 24000,
19
+ "transformers_version": "4.57.6",
20
+ "vocab_size": 151672
21
+ }
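A quick sanity check on this config is that every special-token ID fits inside the resized vocabulary, since an out-of-range ID would index past the embedding table. The sketch below copies the relevant fields from the JSON above. The helper name `out_of_range_ids` is hypothetical and not part of the repository.

```python
import json

# Relevant fields copied from config.json in this commit.
CONFIG_JSON = """
{
  "audio_start_id": 151665,
  "audio_end_id": 151666,
  "eos_token_id": 151643,
  "vocab_size": 151672
}
"""


def out_of_range_ids(cfg: dict) -> list[str]:
    """Return the names of special-token fields whose ID is outside [0, vocab_size)."""
    names = ("audio_start_id", "audio_end_id", "eos_token_id")
    return [name for name in names if not 0 <= cfg[name] < cfg["vocab_size"]]


cfg = json.loads(CONFIG_JSON)
print(out_of_range_ids(cfg))  # an empty list means all IDs are addressable
```

The shipped `vocab_size` of 151672 overrides the 151944 default declared in `configuration_wren_asr.py`, so the check is run against the JSON value.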
configuration_wren_asr.py ADDED
@@ -0,0 +1,30 @@
1
+ """Wren ASR configuration — transformers-compatible."""
2
+ from transformers import PretrainedConfig
3
+
4
+
5
+ class WrenASRConfig(PretrainedConfig):
6
+ model_type = "wren_asr"
7
+
8
+ def __init__(
9
+ self,
10
+ llm_name: str = "Qwen/Qwen2.5-0.5B",
11
+ mimi_model_name: str = "kyutai/mimi",
12
+ k_codebooks: int = 3,
13
+ codebook_size: int = 2048,
14
+ vocab_size: int = 151944,
15
+ # Special-token IDs (in the resized text vocab)
16
+ audio_start_id: int = None, # <|audio_start|> — opens audio prefix
17
+ audio_end_id: int = None, # <|audio_end|> — closes audio prefix; text begins after
18
+ eos_token_id: int = None, # end of transcript (LLM's existing eos)
19
+ sampling_rate: int = 24000,
20
+ **kwargs,
21
+ ):
22
+ self.llm_name = llm_name
23
+ self.mimi_model_name = mimi_model_name
24
+ self.k_codebooks = k_codebooks
25
+ self.codebook_size = codebook_size
26
+ self.vocab_size = vocab_size
27
+ self.audio_start_id = audio_start_id
28
+ self.audio_end_id = audio_end_id
29
+ self.sampling_rate = sampling_rate
30
+ super().__init__(eos_token_id=eos_token_id, **kwargs)
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c209c4f1d83e711a818eaabf310ce0f9df18bbfdd7771ba2f59ca49e94f78ac
3
+ size 1009646296
modeling_wren_asr.py ADDED
@@ -0,0 +1,145 @@
1
+ """
2
+ Wren-ASR model — a transformers-compatible wrapper over Qwen2.5-0.5B + Mimi
3
+ input embedding tables.
4
+
5
+ Designed for use with `AutoModel.from_pretrained(..., trust_remote_code=True)`.
6
+ Self-contained: no imports from a `src/` folder.
7
+
8
+ Sequence layout:
9
+ [ <audio_start> | sum_q embed_q(codes[q, t]) for t in 0..T-1 | <audio_end> | text... | <eos> ]
10
+
11
+ Audio positions feed a single summed-codebook embedding per real frame (no delay
12
+ pattern). Text-token prediction uses the LLM's existing `lm_head`; no new output
13
+ heads are added.
14
+ """
15
+
16
+ import math
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
22
+
23
+ try:
24
+ from .configuration_wren_asr import WrenASRConfig # package context (HF trust_remote_code)
25
+ except ImportError:
26
+ import importlib
27
+ WrenASRConfig = importlib.import_module("configuration_wren_asr").WrenASRConfig
28
+
29
+
30
+ class WrenForASR(PreTrainedModel):
31
+ config_class = WrenASRConfig
32
+ base_model_prefix = "wren_asr"
33
+
34
+ def __init__(self, config: WrenASRConfig):
35
+ super().__init__(config)
36
+ self.k = config.k_codebooks
37
+
38
+ # Build backbone from its config only. Pretrained backbone weights are
39
+ # already in our state_dict; no need to re-download.
40
+ llm_cfg = AutoConfig.from_pretrained(config.llm_name)
41
+ llm_cfg.vocab_size = config.vocab_size
42
+ self.llm = AutoModelForCausalLM.from_config(llm_cfg)
43
+
44
+ hidden = self.llm.config.hidden_size
45
+
46
+ # k input embedding tables (codes are inputs only — no PAD row needed).
47
+ self.audio_embeds = nn.ModuleList([
48
+ nn.Embedding(config.codebook_size, hidden)
49
+ for _ in range(self.k)
50
+ ])
51
+
52
+ self.embed_scale = 1.0 / math.sqrt(self.k)
53
+ self._mimi = None # lazy-loaded on first use
54
+
55
+ # --- Mimi codec (lazy-loaded encoder for raw-waveform input) ---
56
+
57
+ @property
58
+ def mimi(self):
59
+ if self._mimi is None:
60
+ from transformers import MimiModel
61
+ self._mimi = MimiModel.from_pretrained(self.config.mimi_model_name).to(self.device)
62
+ self._mimi.eval()
63
+ for p in self._mimi.parameters():
64
+ p.requires_grad_(False)
65
+ return self._mimi
66
+
67
+ @torch.no_grad()
68
+ def encode_audio(
69
+ self,
70
+ waveform: torch.Tensor,
71
+ src_sample_rate: int = 24000,
72
+ ) -> torch.LongTensor:
73
+ """Encode a waveform to Mimi codes [k, n_frames]."""
74
+ if waveform.dim() == 1:
75
+ waveform = waveform.unsqueeze(0)
76
+ if src_sample_rate != self.config.sampling_rate:
77
+ import torchaudio.transforms as T
78
+ waveform = T.Resample(src_sample_rate, self.config.sampling_rate)(waveform)
79
+ x = waveform.unsqueeze(0).to(self.device)
80
+ out = self.mimi.encode(x, num_quantizers=self.k)
81
+ return out.audio_codes[0].cpu() # [k, n_frames]
82
+
83
+ # --- Generation ---
84
+
85
+ @torch.no_grad()
86
+ def generate(
87
+ self,
88
+ audio_codes: torch.LongTensor, # [k, T] or [B, k, T]
89
+ max_new_tokens: int = 200,
90
+ do_sample: bool = False,
91
+ temperature: float = 1.0,
92
+ top_k: int = 50,
93
+ top_p: float = 1.0,
94
+ eos_token_id: Optional[int] = None,
95
+ pad_token_id: Optional[int] = None,
96
+ **kwargs,
97
+ ) -> torch.LongTensor:
98
+ """Transcribe Mimi codes to text-token IDs.
99
+
100
+ Returns the generated token IDs (without the audio prefix). Decode them
101
+ with your tokenizer to get text — typically:
102
+
103
+ ids = model.generate(audio_codes=codes)
104
+ text = tokenizer.decode(ids[0], skip_special_tokens=True)
105
+ """
106
+ device = next(self.parameters()).device
107
+ self.eval()
108
+
109
+ audio_codes = audio_codes.to(device)
110
+ if audio_codes.dim() == 2:
111
+ audio_codes = audio_codes.unsqueeze(0) # [1, k, T]
112
+ B, k, T = audio_codes.shape
113
+ assert k == self.k, f"expected k={self.k}, got {k}"
114
+
115
+ embed_tokens = self.llm.get_input_embeddings()
116
+ llm_dtype = next(self.llm.parameters()).dtype
117
+
118
+ start_emb = embed_tokens(torch.tensor([[self.config.audio_start_id]], device=device))
119
+ end_emb = embed_tokens(torch.tensor([[self.config.audio_end_id]], device=device))
120
+
121
+ clamped = audio_codes.clamp(0, self.config.codebook_size - 1)
122
+ audio_sum = self.audio_embeds[0](clamped[:, 0, :])
123
+ for q in range(1, self.k):
124
+ audio_sum = audio_sum + self.audio_embeds[q](clamped[:, q, :])
125
+ audio_sum = audio_sum * self.embed_scale
126
+
127
+ prompt_embeds = torch.cat([
128
+ start_emb.expand(B, -1, -1),
129
+ audio_sum,
130
+ end_emb.expand(B, -1, -1),
131
+ ], dim=1).to(llm_dtype)
132
+
133
+ gen_ids = self.llm.generate(
134
+ inputs_embeds = prompt_embeds,
135
+ max_new_tokens = max_new_tokens,
136
+ do_sample = do_sample,
137
+ temperature = temperature if do_sample else 1.0,
138
+ top_k = top_k if do_sample else 50,  # 50 is the HF default; avoids a greedy-mode config warning
139
+ top_p = top_p if do_sample else 1.0,
140
+ eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id,
141
+ pad_token_id = pad_token_id,
142
+ )
143
+ # When called with `inputs_embeds`, HF generate returns ONLY the
144
+ # generated ids (the prompt has no token ids to echo back).
145
+ return gen_ids
processing_wren_asr.py ADDED
@@ -0,0 +1,78 @@
1
+ """
2
+ Wren-ASR processor: audio → Mimi codes (and optionally back to text via the
3
+ tokenizer for decoding model outputs).
4
+
5
+ Usage:
6
+ processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
7
+ inputs = processor(audio=wav, sampling_rate=sr) # → {"audio_codes": [1, k, T]}
8
+ ids = model.generate(**inputs, max_new_tokens=200)
9
+ text = processor.batch_decode(ids, skip_special_tokens=True)[0]
10
+ """
11
+
12
+ from typing import Optional, Union
13
+
14
+ import torch
15
+ from transformers.processing_utils import ProcessorMixin
16
+
17
+
18
+ class WrenASRProcessor(ProcessorMixin):
19
+ attributes = ["tokenizer"]
20
+ tokenizer_class = "AutoTokenizer"
21
+
22
+ def __init__(self, tokenizer, mimi_model_name: str = "kyutai/mimi", k_codebooks: int = 3, **kwargs):
23
+ super().__init__(tokenizer=tokenizer)
24
+ self.mimi_model_name = mimi_model_name
25
+ self.k_codebooks = k_codebooks
26
+ self._mimi = None
27
+
28
+ @property
29
+ def mimi(self):
30
+ if self._mimi is None:
31
+ from transformers import MimiModel
32
+ self._mimi = MimiModel.from_pretrained(self.mimi_model_name).eval()
33
+ for p in self._mimi.parameters():
34
+ p.requires_grad_(False)
35
+ return self._mimi
36
+
37
+ @torch.no_grad()
38
+ def __call__(
39
+ self,
40
+ audio: Optional[torch.Tensor] = None,
41
+ sampling_rate: Optional[int] = None,
42
+ audio_codes: Optional[torch.LongTensor] = None,
43
+ **kwargs,
44
+ ):
45
+ """Either pass `audio` (raw waveform) + `sampling_rate`, or pre-computed
46
+ `audio_codes` of shape [k, T] / [B, k, T].
47
+
48
+ Returns: {"audio_codes": LongTensor [B, k, T]}.
49
+ """
50
+ if audio_codes is not None:
51
+ codes = audio_codes
52
+ if codes.dim() == 2:
53
+ codes = codes.unsqueeze(0)
54
+ return {"audio_codes": codes}
55
+
56
+ if audio is None:
57
+ raise ValueError("Provide either `audio` (waveform) or `audio_codes`.")
58
+ if sampling_rate is None:
59
+ raise ValueError("`sampling_rate` is required when passing `audio`.")
60
+
61
+ wav = audio
62
+ if wav.dim() == 1:
63
+ wav = wav.unsqueeze(0)
64
+
65
+ if sampling_rate != 24000:
66
+ import torchaudio.transforms as T
67
+ wav = T.Resample(sampling_rate, 24000)(wav)
68
+
69
+ x = wav.unsqueeze(0) # [1, 1, T]
70
+ out = self.mimi.encode(x, num_quantizers=self.k_codebooks)
71
+ codes = out.audio_codes # [1, k, T]
72
+ return {"audio_codes": codes}
73
+
74
+ def batch_decode(self, *args, **kwargs):
75
+ return self.tokenizer.batch_decode(*args, **kwargs)
76
+
77
+ def decode(self, *args, **kwargs):
78
+ return self.tokenizer.decode(*args, **kwargs)
processor_config.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "processor_class": "WrenASRProcessor",
3
+ "auto_map": {
4
+ "AutoProcessor": "processing_wren_asr.WrenASRProcessor"
5
+ },
6
+ "mimi_model_name": "kyutai/mimi",
7
+ "k_codebooks": 3
8
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ {
4
+ "content": "<|audio_start|>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ {
11
+ "content": "<|audio_end|>",
12
+ "lstrip": false,
13
+ "normalized": false,
14
+ "rstrip": false,
15
+ "single_word": false
16
+ }
17
+ ],
18
+ "eos_token": {
19
+ "content": "<|endoftext|>",
20
+ "lstrip": false,
21
+ "normalized": false,
22
+ "rstrip": false,
23
+ "single_word": false
24
+ },
25
+ "pad_token": {
26
+ "content": "<|endoftext|>",
27
+ "lstrip": false,
28
+ "normalized": false,
29
+ "rstrip": false,
30
+ "single_word": false
31
+ }
32
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e6df08c131c86d31c1366adc78bf78a1c59f2ef0e05bfd93f48c963c3abcea9
3
+ size 11422278
tokenizer_config.json ADDED
@@ -0,0 +1,212 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<|audio_start|>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": true
188
+ },
189
+ "151666": {
190
+ "content": "<|audio_end|>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": true
196
+ }
197
+ },
198
+ "additional_special_tokens": [
199
+ "<|audio_start|>",
200
+ "<|audio_end|>"
201
+ ],
202
+ "bos_token": null,
203
+ "clean_up_tokenization_spaces": false,
204
+ "eos_token": "<|endoftext|>",
205
+ "errors": "replace",
206
+ "extra_special_tokens": {},
207
+ "model_max_length": 131072,
208
+ "pad_token": "<|endoftext|>",
209
+ "split_special_tokens": false,
210
+ "tokenizer_class": "Qwen2Tokenizer",
211
+ "unk_token": null
212
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff