# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "moshi==0.2.11",
#     "torch",
#     "sphn",
#     "sounddevice",
# ]
# ///
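# The inline metadata above lets this script run directly with uv, which
# resolves the dependencies on the fly; e.g. (hypothetical file names):
#   uv run tts_pytorch.py input.txt output.wav
#   echo "Hello world" | uv run tts_pytorch.py - -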
import argparse
import queue
import sys
import time

import numpy as np
import sphn
import torch
from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel


def main():
    parser = argparse.ArgumentParser(
        description="Run Kyutai TTS using the PyTorch implementation"
    )
    parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
    parser.add_argument(
        "out", type=str, help="Output file to generate, use - for playing the audio"
    )
    parser.add_argument(
        "--hf-repo",
        type=str,
        default=DEFAULT_DSM_TTS_REPO,
        help="HF repo in which to look for the pretrained models.",
    )
    parser.add_argument(
        "--voice-repo",
        default=DEFAULT_DSM_TTS_VOICE_REPO,
        help="HF repo in which to look for pre-computed voice embeddings.",
    )
    parser.add_argument(
        "--voice",
        default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
        help="The voice to use, relative to the voice repo root. "
        f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device on which to run, defaults to 'cuda'.",
    )
    args = parser.parse_args()

    print("Loading model...")
    checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
    # n_q selects how many codebooks to model; temp is the sampling temperature.
    tts_model = TTSModel.from_checkpoint_info(
        checkpoint_info, n_q=32, temp=0.6, device=args.device
    )

    if args.inp == "-":
        if sys.stdin.isatty():  # Interactive
            print("Enter text to synthesize (Ctrl+D to end input):")
        text = sys.stdin.read().strip()
    else:
        with open(args.inp, "r", encoding="utf-8") as fobj:
            text = fobj.read().strip()

    # To generate a dialog, pass one entry per turn:
    # [text_speaker_1, text_speaker_2, text_2_speaker_1, ...] (see the sketch below).
    entries = tts_model.prepare_script([text], padding_between=1)
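    # A minimal sketch of a two-turn dialog script (hypothetical lines):
    #   entries = tts_model.prepare_script(
    #       ["Hello there!", "Hi, nice to meet you."],
    #       padding_between=1,
    #   )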
    if args.voice.endswith(".safetensors"):
        voice_path = args.voice
    else:
        voice_path = tts_model.get_voice_path(args.voice)
    # The CFG coefficient is passed here rather than at generation time because
    # the model was trained with CFG distillation, so it is not _actually_
    # doing classifier-free guidance at inference.
    # For a dialog, the list should contain one voice per speaker (see the
    # sketch below).
    condition_attributes = tts_model.make_condition_attributes(
        [voice_path], cfg_coef=2.0
    )
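    # A minimal sketch for two speakers, assuming a second voice file
    # (voice_path_speaker_2 is hypothetical):
    #   condition_attributes = tts_model.make_condition_attributes(
    #       [voice_path, voice_path_speaker_2], cfg_coef=2.0
    #   )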
    _frames_cnt = 0

    if args.out == "-":
        # Stream the audio to the speakers using sounddevice.
        import sounddevice as sd

        pcms = queue.Queue()

        def _on_frame(frame):
            nonlocal _frames_cnt
            # Skip frames that still contain placeholder (-1) tokens.
            if (frame != -1).all():
                # The first stream carries text tokens; decode only the audio codebooks.
                pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
                pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
                _frames_cnt += 1
                # Mimi generates 12.5 frames per second.
                print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)

        def audio_callback(outdata, _a, _b, _c):
            # Play the next decoded block, or silence if generation has not
            # caught up yet.
            try:
                pcm_data = pcms.get(block=False)
                outdata[:, 0] = pcm_data
            except queue.Empty:
                outdata[:] = 0

        with sd.OutputStream(
            samplerate=tts_model.mimi.sample_rate,
            blocksize=1920,  # one Mimi frame: 80 ms at 24 kHz
            channels=1,
            callback=audio_callback,
        ):
            # streaming(1) keeps Mimi's decoder state across frames for a batch of one.
            with tts_model.mimi.streaming(1):
                tts_model.generate(
                    [entries], [condition_attributes], on_frame=_on_frame
                )
            # Give playback a head start, then wait for the queue to drain.
            time.sleep(3)
            while not pcms.empty():
                time.sleep(1)
    else:

        def _on_frame(frame):
            nonlocal _frames_cnt
            # Count only fully generated frames (no placeholder -1 tokens).
            if (frame != -1).all():
                _frames_cnt += 1
                print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)

        start_time = time.time()
        result = tts_model.generate(
            [entries], [condition_attributes], on_frame=_on_frame
        )
        print(f"\nTotal time: {time.time() - start_time:.2f}s")
        with tts_model.mimi.streaming(1), torch.no_grad():
            pcms = []
            # The first delay_steps frames precede the audio; skip them.
            for frame in result.frames[tts_model.delay_steps :]:
                pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
                pcms.append(np.clip(pcm[0, 0], -1, 1))
            pcm = np.concatenate(pcms, axis=-1)
        sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)


if __name__ == "__main__":
    main()