"""
Wren-ASR processor: audio → Mimi codes (and optionally back to text via the
tokenizer for decoding model outputs).

Usage:
    processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
    inputs = processor(audio=wav, sampling_rate=sr)   # → {"audio_codes": [1, k, T]}
    ids = model.generate(**inputs, max_new_tokens=200)
    text = processor.batch_decode(ids, skip_special_tokens=True)[0]
"""
from typing import Optional, Union

import torch
from transformers.processing_utils import ProcessorMixin


class WrenASRProcessor(ProcessorMixin):
    """Processor that turns raw audio into Mimi codec codes and decodes token ids.

    The tokenizer is the only sub-component registered with `ProcessorMixin`.
    The Mimi codec is not loaded in `__init__`; it is fetched on first access
    of the `mimi` property and cached on the instance.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"

    def __init__(self, tokenizer, mimi_model_name: str = "kyutai/mimi", k_codebooks: int = 3, **kwargs):
        """Store the tokenizer and the Mimi settings.

        Args:
            tokenizer: Tokenizer used by `decode` / `batch_decode`.
            mimi_model_name: Identifier passed to `MimiModel.from_pretrained`.
            k_codebooks: Number of quantizers requested from `MimiModel.encode`.
            **kwargs: Accepted and ignored, so extra entries handed to the
                constructor do not raise.
        """
        super().__init__(tokenizer=tokenizer)
        self.mimi_model_name = mimi_model_name
        self.k_codebooks = k_codebooks
        self._mimi = None  # populated lazily by the `mimi` property

    @property
    def mimi(self):
        """Return the Mimi codec, loading it on first access.

        The model is put in eval mode and all of its parameters are frozen.
        """
        if self._mimi is None:
            # Imported here so constructing the processor does not require
            # importing the Mimi model class.
            from transformers import MimiModel

            codec = MimiModel.from_pretrained(self.mimi_model_name).eval()
            codec.requires_grad_(False)
            self._mimi = codec
        return self._mimi

    @torch.no_grad()
    def __call__(
        self,
        audio: Optional[torch.Tensor] = None,
        sampling_rate: Optional[int] = None,
        audio_codes: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Encode a waveform into Mimi codes, or pass pre-computed codes through.

        Either pass `audio` (raw waveform) + `sampling_rate`, or pre-computed
        `audio_codes` of shape [k, T] / [B, k, T]. When both are given,
        `audio_codes` wins and `audio` is ignored.

        Args:
            audio: Floating-point waveform as a tensor (or anything
                `torch.as_tensor` accepts). Accepted layouts: [T],
                [C, T] (treated as a single item) or [B, C, T].
            sampling_rate: Sampling rate of `audio`; required with `audio`.
                The waveform is resampled when it differs from the codec's
                rate.
            audio_codes: Pre-computed codes; a 2-D input gets a leading batch
                dimension, any other rank is returned unchanged.
            **kwargs: Accepted and ignored.

        Returns:
            {"audio_codes": LongTensor [B, k, T]}.

        Raises:
            ValueError: If neither `audio` nor `audio_codes` is given, if
                `sampling_rate` is missing, or if `audio` is not a non-empty
                floating-point waveform with 1 to 3 dimensions.
        """
        if audio_codes is not None:
            codes = torch.as_tensor(audio_codes)
            if codes.dim() == 2:
                codes = codes.unsqueeze(0)
            return {"audio_codes": codes}

        if audio is None:
            raise ValueError("Provide either `audio` (waveform) or `audio_codes`.")
        if sampling_rate is None:
            raise ValueError("`sampling_rate` is required when passing `audio`.")

        # Accept NumPy arrays / lists as well as tensors.
        wav = torch.as_tensor(audio)
        if not wav.is_floating_point():
            raise ValueError(
                f"`audio` must be a floating-point waveform; got dtype {wav.dtype}."
            )

        # Normalise to the [B, C, T] layout handed to `MimiModel.encode`.
        if wav.dim() == 1:
            wav = wav[None, None, :]
        elif wav.dim() == 2:
            wav = wav.unsqueeze(0)
        elif wav.dim() != 3:
            raise ValueError(
                "`audio` must have shape [T], [C, T] or [B, C, T]; "
                f"got {tuple(wav.shape)}."
            )
        if wav.shape[-1] == 0:
            raise ValueError("`audio` is empty.")

        codec = self.mimi
        # Prefer the rate declared by the loaded codec; 24 kHz is the fallback.
        target_rate = getattr(codec.config, "sampling_rate", 24000)
        if sampling_rate != target_rate:
            import torchaudio.transforms as T

            # The resampling kernel is float32 and created on CPU by default,
            # so align it with the waveform's device and feed it float32.
            resampler = T.Resample(sampling_rate, target_rate).to(wav.device)
            wav = resampler(wav.float())

        # Match the codec's device and dtype to avoid mismatch errors in encode.
        wav = wav.to(device=codec.device, dtype=codec.dtype)
        encoded = codec.encode(wav, num_quantizers=self.k_codebooks)
        return {"audio_codes": encoded.audio_codes}

    def batch_decode(self, *args, **kwargs):
        """Forward all arguments to the tokenizer's `batch_decode`."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward all arguments to the tokenizer's `decode`."""
        return self.tokenizer.decode(*args, **kwargs)