File size: 2,302 Bytes
e6ab073
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8549eb6
 
 
 
 
e6ab073
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
Wren processor: text tokenization + audio saving.

Text casing is preserved as-is. Pass text naturally ("Hello, World!") — the model
is trained on mixed-case data (LJSpeech mixed-case, LibriTTS with punctuation).
The `<|audio_start|>` separator is always appended so `model.generate(**processor(text))`
"just works".
"""

from typing import List, Union

import torch
from transformers.processing_utils import ProcessorMixin


class WrenProcessor(ProcessorMixin):
    """Tokenize text prompts for Wren and save generated waveforms.

    Wraps a single tokenizer and appends the ``<|audio_start|>`` separator to
    every prompt, so the output of ``__call__`` can be splatted directly into
    ``model.generate``.
    """

    attributes        = ["tokenizer"]
    tokenizer_class   = "AutoTokenizer"

    def __init__(self, tokenizer, **kwargs):
        """Build the processor around ``tokenizer``.

        Args:
            tokenizer: tokenizer whose vocabulary contains ``<|audio_start|>``.
            **kwargs: extra fields injected by ``transformers`` at load time;
                ignored (see comment below).

        Raises:
            ValueError: if the tokenizer does not know ``<|audio_start|>``.
        """
        # Recent `transformers` injects `chat_template` (and possibly other future
        # fields) into processor_dict at load time — Qwen2.5's tokenizer ships one
        # via chat_template.jinja. We don't use chat templates for TTS, so swallow
        # any extra kwargs to stay forward-compatible.
        super().__init__(tokenizer=tokenizer)
        # The separator is appended to every prompt, so fail fast if it is missing
        # rather than silently emitting None / the unk id.
        self.audio_start_id     = self._require_token_id(tokenizer, "<|audio_start|>")
        self.reference_start_id = tokenizer.convert_tokens_to_ids("<|reference_start|>")
        self.reference_end_id   = tokenizer.convert_tokens_to_ids("<|reference_end|>")

    @staticmethod
    def _require_token_id(tokenizer, token: str) -> int:
        """Return the id of ``token``, raising ValueError if it is not in the vocab."""
        token_id = tokenizer.convert_tokens_to_ids(token)
        unk_id = getattr(tokenizer, "unk_token_id", None)
        # Unknown tokens map to None (no unk token) or to the unk id.
        is_unknown = token_id is None or (
            unk_id is not None
            and token_id == unk_id
            and token != getattr(tokenizer, "unk_token", None)
        )
        if is_unknown:
            raise ValueError(
                f"Tokenizer has no {token!r} token; it is required by WrenProcessor."
            )
        return token_id

    def __call__(
        self,
        text:           Union[str, List[str]],
        return_tensors: str = "pt",
        **kwargs,
    ):
        """Tokenize ``text`` and append ``<|audio_start|>`` to each prompt.

        Args:
            text: a prompt or a batch of prompts (casing is preserved).
            return_tensors: must be ``"pt"``; the separator is added with torch ops.
            **kwargs: forwarded to the tokenizer (e.g. ``padding=True``).

        Returns:
            dict with ``"input_ids"`` of shape ``[batch, seq + 1]``, plus
            ``"attention_mask"`` of the same shape whenever the tokenizer
            returned one. For padded batches the separator sits directly after
            each row's real tokens, and the padding side is preserved.

        Raises:
            ValueError: if ``return_tensors`` is not ``"pt"``.
        """
        if return_tensors != "pt":
            raise ValueError(
                f"WrenProcessor only supports return_tensors='pt', got {return_tensors!r}"
            )
        enc = self.tokenizer(
            text,
            add_special_tokens = False,
            return_tensors     = return_tensors,
            **kwargs,
        )
        ids = enc["input_ids"]
        if ids.dim() == 1:
            ids = ids.unsqueeze(0)

        mask = enc.get("attention_mask")
        return_mask = mask is not None
        if mask is None:
            # No mask requested: treat every position as a real token.
            mask = torch.ones_like(ids)
        elif mask.dim() == 1:
            mask = mask.unsqueeze(0)

        if bool(mask.all()):
            # Unpadded batch: the separator simply becomes the final prompt token.
            sep = torch.full(
                (ids.shape[0], 1),
                self.audio_start_id,
                dtype=ids.dtype,
                device=ids.device,
            )
            ids = torch.cat([ids, sep], dim=1)
            mask = torch.cat(
                [mask, torch.ones((mask.shape[0], 1), dtype=mask.dtype, device=mask.device)],
                dim=1,
            )
        else:
            ids, mask = self._append_separator_padded(ids, mask, kwargs.get("padding_side"))

        out = {"input_ids": ids}
        if return_mask:
            out["attention_mask"] = mask
        return out

    def _append_separator_padded(self, ids, mask, padding_side=None):
        """Re-lay a padded batch so each row ends in ``<|audio_start|>`` before any padding."""
        side = padding_side or getattr(self.tokenizer, "padding_side", "right")
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = 0  # only reached at masked-out positions
        keep = mask.bool()
        batch, width = ids.shape[0], ids.shape[1] + 1  # same width as the unpadded path
        new_ids = torch.full((batch, width), pad_id, dtype=ids.dtype, device=ids.device)
        new_mask = torch.zeros((batch, width), dtype=mask.dtype, device=mask.device)
        for row in range(batch):
            real = ids[row][keep[row]]
            seq = torch.cat([real, real.new_full((1,), self.audio_start_id)])
            n = seq.numel()
            span = slice(width - n, width) if side == "left" else slice(0, n)
            new_ids[row, span] = seq
            new_mask[row, span] = 1
        return new_ids, new_mask

    def save_audio(
        self,
        waveform:       torch.Tensor,
        path:           str,
        sampling_rate:  int = 24000,
    ) -> None:
        """Save a [1, T] or [T] waveform to disk at the given sample rate.

        Half-precision waveforms (float16 / bfloat16) are upcast to float32 and
        the tensor is detached from autograd first, because the audio writer
        cannot consume those directly.
        """
        import torchaudio
        waveform = waveform.detach()
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        if waveform.dtype in (torch.float16, torch.bfloat16):
            waveform = waveform.float()
        torchaudio.save(path, waveform.cpu(), sampling_rate)