| """ |
| Wren-ASR processor: audio → Mimi codes (and optionally back to text via the |
| tokenizer for decoding model outputs). |
| |
| Usage: |
| processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True) |
| inputs = processor(audio=wav, sampling_rate=sr) # → {"audio_codes": [B, k, T]} |
| ids = model.generate(**inputs, max_new_tokens=200) |
| text = processor.batch_decode(ids, skip_special_tokens=True)[0] |
| """ |
|
|
| from typing import Optional, Union |
|
|
| import torch |
| from transformers.processing_utils import ProcessorMixin |
|
|
|
|
class WrenASRProcessor(ProcessorMixin):
    """Processor for Wren-ASR: encodes waveforms to Mimi codes and decodes ids to text.

    The tokenizer is the only sub-component registered with ``ProcessorMixin``.
    The Mimi codec is loaded lazily on first access of :attr:`mimi`, so building
    the processor, or calling it with pre-computed ``audio_codes``, never loads
    the codec.

    Args:
        tokenizer: Tokenizer used by :meth:`decode` and :meth:`batch_decode`.
        mimi_model_name: Identifier passed to ``MimiModel.from_pretrained``.
        k_codebooks: Number of quantizers requested from ``MimiModel.encode``.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"

    # Sampling rate (Hz) the waveform is resampled to before Mimi encoding.
    _MIMI_SAMPLING_RATE = 24000

    def __init__(self, tokenizer, mimi_model_name: str = "kyutai/mimi", k_codebooks: int = 3, **kwargs):
        super().__init__(tokenizer=tokenizer)
        self.mimi_model_name = mimi_model_name
        self.k_codebooks = k_codebooks
        # Populated on first access of `self.mimi`.
        self._mimi = None

    @property
    def mimi(self):
        """Return the Mimi codec, loading it on first access.

        The model is put in eval mode and gradients are disabled on all of
        its parameters; the loaded instance is cached for later calls.
        """
        if self._mimi is None:
            # Imported lazily so the codec class is only needed when audio is encoded.
            from transformers import MimiModel

            codec = MimiModel.from_pretrained(self.mimi_model_name).eval()
            codec.requires_grad_(False)
            self._mimi = codec
        return self._mimi

    @torch.no_grad()
    def __call__(
        self,
        audio: Optional[torch.Tensor] = None,
        sampling_rate: Optional[int] = None,
        audio_codes: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Convert a waveform (or pre-computed codes) into batched Mimi codes.

        Either pass `audio` (raw waveform) + `sampling_rate`, or pre-computed
        `audio_codes` of shape [k, T] / [B, k, T]. When `audio_codes` is given
        it takes precedence and the codec is not touched.

        Args:
            audio: Floating-point waveform, as a tensor or anything
                ``torch.as_tensor`` accepts. Shape [T], [C, T] (a 2-D input is
                interpreted as channels x time, not batch x time) or
                [B, C, T].
            sampling_rate: Sampling rate of `audio` in Hz. Audio that is not
                already at 24000 Hz is resampled with torchaudio.
            audio_codes: Pre-computed codes of shape [k, T] or [B, k, T].
            **kwargs: Accepted and ignored.

        Returns:
            {"audio_codes": LongTensor [B, k, T]}.

        Raises:
            ValueError: If neither `audio` nor `audio_codes` is given, if
                `sampling_rate` is missing, or if an input has an unsupported
                rank or no samples.
            TypeError: If `audio` is not floating point.
        """
        if audio_codes is not None:
            codes = torch.as_tensor(audio_codes)
            if codes.dim() == 2:
                codes = codes.unsqueeze(0)
            elif codes.dim() != 3:
                raise ValueError(
                    "`audio_codes` must have shape [k, T] or [B, k, T], "
                    f"got {tuple(codes.shape)}."
                )
            return {"audio_codes": codes}

        if audio is None:
            raise ValueError("Provide either `audio` (waveform) or `audio_codes`.")
        if sampling_rate is None:
            raise ValueError("`sampling_rate` is required when passing `audio`.")

        wav = torch.as_tensor(audio)
        if not wav.is_floating_point():
            # Casting integer PCM silently would feed the codec an unscaled signal.
            raise TypeError(
                f"`audio` must be a floating-point waveform, got dtype {wav.dtype}."
            )

        # Normalise to the [B, C, T] layout expected by `MimiModel.encode`
        # without adding a fourth axis to input that is already batched.
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)
        if wav.dim() == 2:
            wav = wav.unsqueeze(0)
        if wav.dim() != 3:
            raise ValueError(
                "`audio` must have shape [T], [C, T] or [B, C, T], "
                f"got {tuple(wav.shape)}."
            )
        if wav.shape[-1] == 0:
            raise ValueError("`audio` contains no samples.")

        target_rate = self._MIMI_SAMPLING_RATE
        if sampling_rate != target_rate:
            import torchaudio.transforms as T

            # The resampling kernel is a module buffer; move it next to the
            # waveform so non-CPU / non-float32 input does not fail.
            resampler = T.Resample(sampling_rate, target_rate).to(
                device=wav.device, dtype=wav.dtype
            )
            wav = resampler(wav)

        codec = self.mimi
        # Match the codec's weights so device or precision mismatches do not crash.
        wav = wav.to(device=codec.device, dtype=codec.dtype)
        encoded = codec.encode(wav, num_quantizers=self.k_codebooks)
        return {"audio_codes": encoded.audio_codes}

    def batch_decode(self, *args, **kwargs):
        """Forward all arguments to the tokenizer's ``batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward all arguments to the tokenizer's ``decode``."""
        return self.tokenizer.decode(*args, **kwargs)
|
|