# Wren-ASR-0.5B-multi / processing_wren_asr.py
# shangeth's picture
# Upload Wren-ASR-0.5B-multi checkpoint
# 72ecc83 verified
"""
Wren-ASR processor: audio → Mimi codes (and optionally back to text via the
tokenizer for decoding model outputs).

Usage:
    processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
    inputs = processor(audio=wav, sampling_rate=sr)  # → {"audio_codes": [1, k, T]}
    ids = model.generate(**inputs, max_new_tokens=200)
    text = processor.batch_decode(ids, skip_special_tokens=True)[0]
"""
from typing import Optional, Union
import torch
from transformers.processing_utils import ProcessorMixin
class WrenASRProcessor(ProcessorMixin):
    """Turn raw audio into Mimi codec codes and decode generated token ids.

    The processor wraps a text tokenizer (used only for decoding model
    outputs) and a lazily loaded Mimi audio codec (used to encode waveforms
    into discrete codes).

    Args:
        tokenizer: Tokenizer used by `decode` / `batch_decode`.
        mimi_model_name: Hub id or local path of the Mimi checkpoint to load.
        k_codebooks: Number of Mimi quantizers (codebooks) requested per frame.
        **kwargs: Accepted for loader compatibility and ignored.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"

    def __init__(self, tokenizer, mimi_model_name: str = "kyutai/mimi", k_codebooks: int = 3, **kwargs):
        super().__init__(tokenizer=tokenizer)
        self.mimi_model_name = mimi_model_name
        self.k_codebooks = k_codebooks
        # The codec is only loaded on first access to `self.mimi`, so building
        # a processor for decode-only use does not fetch the Mimi weights.
        self._mimi = None

    @property
    def mimi(self):
        """Return the frozen, eval-mode Mimi model, loading it on first use."""
        if self._mimi is None:
            from transformers import MimiModel

            codec = MimiModel.from_pretrained(self.mimi_model_name).eval()
            for param in codec.parameters():
                param.requires_grad_(False)
            self._mimi = codec
        return self._mimi

    @torch.no_grad()
    def __call__(
        self,
        audio: Optional[torch.Tensor] = None,
        sampling_rate: Optional[int] = None,
        audio_codes: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Encode a waveform into Mimi codes, or pass pre-computed codes through.

        Either pass `audio` (raw waveform) + `sampling_rate`, or pre-computed
        `audio_codes` of shape [k, T] / [B, k, T].

        Args:
            audio: Floating-point waveform tensor. A 1-D tensor is treated as a
                single mono channel.
            sampling_rate: Sampling rate of `audio` in Hz. Audio is resampled
                when it differs from the rate the Mimi codec expects.
            audio_codes: Pre-computed codes; a 2-D tensor gets a leading batch
                dimension, anything else is returned unchanged.
            **kwargs: Ignored.

        Returns:
            {"audio_codes": LongTensor [B, k, T]}.

        Raises:
            ValueError: If neither `audio` nor `audio_codes` is given, or if
                `audio` is given without `sampling_rate`.
            TypeError: If `audio` is not a floating-point tensor.
        """
        if audio_codes is not None:
            codes = audio_codes
            if codes.dim() == 2:
                codes = codes.unsqueeze(0)
            return {"audio_codes": codes}

        if audio is None:
            raise ValueError("Provide either `audio` (waveform) or `audio_codes`.")
        if sampling_rate is None:
            raise ValueError("`sampling_rate` is required when passing `audio`.")
        if not audio.is_floating_point():
            # Integer PCM would need rescaling that cannot be guessed here;
            # fail early instead of deep inside the resampler / conv stack.
            raise TypeError(
                "`audio` must be a floating-point waveform tensor, "
                f"got dtype {audio.dtype}."
            )

        wav = audio
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)  # [1, T]

        codec = self.mimi
        reference = next(codec.parameters())
        # Rate expected by the codec; 24 kHz is the fallback if the config
        # does not expose it.
        target_rate = getattr(codec.config, "sampling_rate", 24000)

        # Run the remaining steps on the codec's device so that a codec moved
        # to GPU still accepts CPU audio (and vice versa).
        wav = wav.to(device=reference.device, dtype=torch.float32)
        if sampling_rate != target_rate:
            import torchaudio.functional as audio_fn

            # The functional resampler builds its kernel on the waveform's
            # device/dtype, unlike a freshly constructed transform module.
            wav = audio_fn.resample(wav, int(sampling_rate), int(target_rate))

        # Match the codec weights' dtype (e.g. a half-precision codec).
        wav = wav.to(dtype=reference.dtype)

        batch = wav.unsqueeze(0)  # [1, 1, T]
        encoded = codec.encode(batch, num_quantizers=self.k_codebooks)
        codes = encoded.audio_codes  # [1, k, T]
        return {"audio_codes": codes}

    def batch_decode(self, *args, **kwargs):
        """Forward to the tokenizer's `batch_decode`."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to the tokenizer's `decode`."""
        return self.tokenizer.decode(*args, **kwargs)