# Wren-TTS-0.5B-multi / processing_wren.py
# Uploaded by shangeth: "Upload Wren-TTS-0.5B-multi checkpoint" (commit 8549eb6, verified)
"""
Wren processor: text tokenization + audio saving.
Text casing is preserved as-is. Pass text naturally ("Hello, World!") — the model
is trained on mixed-case data (LJSpeech mixed-case, LibriTTS with punctuation).
The `<|audio_start|>` separator is always appended so `model.generate(**processor(text))`
"just works".
"""
from typing import List, Union
import torch
from transformers.processing_utils import ProcessorMixin
class WrenProcessor(ProcessorMixin):
    """Processor for Wren TTS: text tokenization plus waveform saving.

    ``__call__`` turns text into a generation prompt of the form
    ``[text tokens..., <|audio_start|>]`` so that
    ``model.generate(**processor(text))`` works directly. ``save_audio`` writes
    a waveform to disk with ``torchaudio``.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"

    def __init__(self, tokenizer, **kwargs):
        """Wrap ``tokenizer`` and cache the ids of Wren's special tokens.

        Args:
            tokenizer: Hugging Face tokenizer expected to contain the
                ``<|audio_start|>``, ``<|reference_start|>`` and
                ``<|reference_end|>`` tokens. Only ``<|audio_start|>`` is used
                by this class's methods; the reference ids are cached for
                callers (presumably reference-audio prompting — TODO confirm).
            **kwargs: Ignored on purpose. Recent ``transformers`` versions
                inject ``chat_template`` (and possibly other future fields)
                into the processor dict at load time — Qwen2.5's tokenizer
                ships one via ``chat_template.jinja``. Chat templates are not
                used for TTS, so extra kwargs are swallowed to stay
                forward-compatible.
        """
        super().__init__(tokenizer=tokenizer)
        # convert_tokens_to_ids does not raise for unknown tokens; it typically
        # returns the unk id (or None when no unk token is configured).
        # __call__ validates the one id it depends on.
        self.audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_start|>")
        self.reference_start_id = tokenizer.convert_tokens_to_ids("<|reference_start|>")
        self.reference_end_id = tokenizer.convert_tokens_to_ids("<|reference_end|>")

    def __call__(
        self,
        text: Union[str, List[str]],
        return_tensors: str = "pt",
        **kwargs,
    ):
        """Tokenize ``text`` and append ``<|audio_start|>`` to every sequence.

        Casing and punctuation are passed through unchanged. No tokenizer
        special tokens are added (``add_special_tokens=False``).

        Args:
            text: A single string or a list of strings (a batch).
            return_tensors: Must be ``"pt"``. The separator is concatenated
                with ``torch``, so other tensor formats are not supported.
            **kwargs: Forwarded to the tokenizer (e.g. ``padding=True`` for
                batches of unequal length).

        Returns:
            ``{"input_ids": LongTensor}`` of shape ``[batch, seq_len + 1]``,
            where ``seq_len`` is the tokenized (possibly padded) length and
            the last column is the ``<|audio_start|>`` id.

        Raises:
            ValueError: If ``return_tensors`` is not ``"pt"``, or the tokenizer
                does not define ``<|audio_start|>``.

        Note:
            The separator is appended after any padding. With right-side
            padding it therefore follows the pad tokens; with left-side
            padding it stays adjacent to the text.
        """
        if return_tensors != "pt":
            raise ValueError(
                "WrenProcessor only supports return_tensors='pt', "
                f"got {return_tensors!r}"
            )

        audio_start_id = self.audio_start_id
        unk_id = getattr(self.tokenizer, "unk_token_id", None)
        if audio_start_id is None or audio_start_id == unk_id:
            raise ValueError(
                "Tokenizer does not define the '<|audio_start|>' special token; "
                "use the tokenizer that ships with the Wren checkpoint."
            )

        enc = self.tokenizer(
            text,
            add_special_tokens=False,
            return_tensors=return_tensors,
            **kwargs,
        )
        # Force an integer dtype: some transformers versions materialize an
        # empty tokenization ([[]]) as a float tensor, which would otherwise
        # make the separator (and the whole prompt) floating point.
        ids = enc["input_ids"].to(torch.long)
        if ids.dim() == 1:
            ids = ids.unsqueeze(0)

        # Append <|audio_start|> as the final prompt token of every row.
        sep = torch.full(
            (ids.shape[0], 1),
            audio_start_id,
            dtype=ids.dtype,
            device=ids.device,
        )
        return {"input_ids": torch.cat([ids, sep], dim=1)}

    def save_audio(
        self,
        waveform: torch.Tensor,
        path: str,
        sampling_rate: int = 24000,
    ) -> None:
        """Save ``waveform`` to ``path`` at ``sampling_rate`` via ``torchaudio``.

        Accepted shapes:
            - ``[T]``: mono; a channel axis is added.
            - ``[C, T]``: channels-first; written as-is.
            - ``[1, C, T]``: a single batched clip; the batch axis is dropped.

        The tensor is detached and moved to CPU first. ``float16`` and
        ``bfloat16`` inputs are upcast to ``float32``, because torchaudio's
        save backends generally reject half-precision dtypes.
        """
        import torchaudio  # imported lazily: only needed when saving audio

        waveform = waveform.detach().cpu()
        if waveform.dtype in (torch.float16, torch.bfloat16):
            waveform = waveform.to(torch.float32)
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        elif waveform.dim() == 3 and waveform.shape[0] == 1:
            waveform = waveform[0]
        torchaudio.save(path, waveform, sampling_rate)