File size: 5,842 Bytes
cecbc0f
 
 
 
 
 
 
 
 
 
35401f2
 
cecbc0f
 
 
 
 
86825c9
 
 
 
cecbc0f
35401f2
 
 
 
 
d84e6ad
cecbc0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35401f2
 
 
 
 
 
 
 
 
 
 
 
 
 
cecbc0f
7847a40
cecbc0f
 
 
 
 
 
35401f2
 
 
 
 
 
cecbc0f
 
 
 
 
86825c9
 
 
 
 
 
 
cecbc0f
 
 
 
 
7847a40
cecbc0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35401f2
cecbc0f
 
 
 
 
7847a40
86825c9
c271af2
86825c9
 
cecbc0f
 
 
 
 
35401f2
86825c9
 
c271af2
86825c9
35401f2
cecbc0f
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os

# Set before the imports below so the flag is already in the environment when they run;
# setdefault keeps any value the deployment has exported itself.
# NOTE(review): presumably read by a dependency to disable torch.compile — confirm which one.
os.environ.setdefault("NO_TORCH_COMPILE", "1")

import spaces
import numpy as np
import torch
import torchaudio
import gradio as gr

from transformers import AutoProcessor, MoonshineForConditionalGeneration

from generator import Segment, load_miso_8b

# Use CUDA when torch reports it as available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Loaded once at import time; synthesize() reuses this single module-level instance.
generator = load_miso_8b(device=device, model_path_or_repo_id="MisoLabs/MisoTTS")
# Rate that reference clips are resampled to and that synthesize() reports for its output.
SAMPLE_RATE = generator.sample_rate
# Mimi encodes in fixed-size frames. moshi 0.2.12 and the repo's pinned 0.2.2 pad a
# partial trailing frame differently, so trim the reference to a whole number of frames
# to get byte-identical reference codes (every full frame already matches 1:1).
# synthesize() uses this value as a sample count when trimming the reference clip.
MIMI_FRAME_SIZE = int(generator._audio_tokenizer.frame_size)

# Moonshine ASR for auto-transcribing reference clips. Kept on CPU and never called
# from inside an @spaces.GPU function, so it does not consume the ZeroGPU quota.
# transcribe() resamples clips to ASR_SAMPLE_RATE before handing them to the processor.
ASR_SAMPLE_RATE = 16000
asr_processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-base")
# .eval() puts the model in inference mode; it is only ever used for generate().
asr_model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-base").eval()

# Longest (stripped) input text, in characters, that synthesize() accepts.
MAX_INPUT_CHARS = 1000

# Markdown shown at the top of the UI via gr.Markdown.
DESCRIPTION = """
# Miso TTS 8B

Text-to-speech with the [MisoLabs/MisoTTS](https://huggingface.co/MisoLabs/MisoTTS) model — an
8B [Sesame CSM](https://github.com/SesameAILabs/csm)-style model that generates Mimi audio codes
from text, with optional voice continuation from a reference clip.
"""


def _resample_to_model(audio: torch.Tensor, sr: int) -> torch.Tensor:
    """Downmix *audio* to a single channel at the generator's sample rate.

    A tensor with more than one dimension is averaged over dim 0; a 1-D tensor is
    used as is. Resampling is skipped when *sr* already equals ``SAMPLE_RATE``.
    """
    mono = audio if audio.ndim <= 1 else audio.mean(dim=0)
    if sr == SAMPLE_RATE:
        return mono
    return torchaudio.functional.resample(mono, orig_freq=sr, new_freq=SAMPLE_RATE)


def transcribe(ref_audio_path):
    """CPU-only auto-transcription of the reference clip (runs on the always-on host).

    Args:
        ref_audio_path: Filesystem path of the reference clip, or a falsy value when
            the audio component holds no clip.

    Returns:
        The stripped transcript, or ``gr.update()`` (which leaves the transcript box
        as it is) when there is no clip or the clip contains no samples.
    """
    if not ref_audio_path:
        return gr.update()
    wav, sr = torchaudio.load(ref_audio_path)
    wav = wav.mean(dim=0) if wav.ndim > 1 else wav
    if sr != ASR_SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=ASR_SAMPLE_RATE)
    if wav.numel() == 0:
        # Nothing to transcribe; do not hand an empty array to the processor.
        return gr.update()
    inputs = asr_processor(wav.numpy(), sampling_rate=ASR_SAMPLE_RATE, return_tensors="pt")
    # Cap generation by clip duration. The Moonshine documentation recommends a budget
    # of about 6.5 tokens per second of audio to avoid hallucination loops; without a
    # cap, a short or noisy clip can keep decoding until the model's default limit.
    tokens_per_second = 6.5
    max_new_tokens = max(1, int(wav.shape[-1] * tokens_per_second / ASR_SAMPLE_RATE))
    with torch.no_grad():
        tokens = asr_model.generate(**inputs, max_new_tokens=max_new_tokens)
    return asr_processor.decode(tokens[0], skip_special_tokens=True).strip()


@spaces.GPU(duration=120)
def synthesize(text, ref_audio_path, ref_text, speaker_id, max_length_s, temperature, topk):
    """Generate speech for *text*, optionally conditioned on a reference clip.

    Args:
        text: Text to speak. It is stripped and must then be non-empty and at most
            ``MAX_INPUT_CHARS`` characters long.
        ref_audio_path: Path of a reference clip, or a falsy value for no reference.
        ref_text: Transcript of the reference clip; required when a clip is given.
        speaker_id: Speaker index (cast to ``int``) used for both the reference
            segment and the generated utterance.
        max_length_s: Upper bound on the generated audio in seconds; passed to the
            generator in milliseconds.
        temperature: Sampling temperature, passed on as ``float``.
        topk: Top-k value, passed on as ``int``.

    Returns:
        ``(SAMPLE_RATE, samples)`` where ``samples`` is an int16 NumPy array.

    Raises:
        gr.Error: If the text is empty or too long, or a reference clip is given
            without a transcript.
    """
    prompt = (text or "").strip()
    if not prompt:
        raise gr.Error("Please enter some text to synthesize.")
    if len(prompt) > MAX_INPUT_CHARS:
        raise gr.Error(f"Text too long (>{MAX_INPUT_CHARS} characters).")

    # ZeroGPU streams weights to the real GPU on first entry but leaves the torchtune
    # KV-cache's non-persistent buffers (e.g. cache_pos) behind, causing a cuda/cpu
    # device mismatch. Re-place both modules on the device here, inside the GPU worker.
    for module in (generator._model, generator._audio_tokenizer):
        module.to(device)

    reference = []
    if ref_audio_path:
        transcript = (ref_text or "").strip()
        if not transcript:
            raise gr.Error("Please provide the transcript of the reference audio.")
        loaded, loaded_sr = torchaudio.load(ref_audio_path)
        clip = _resample_to_model(loaded, loaded_sr)
        # Keep only whole Mimi frames; any partial trailing frame is dropped.
        n_samples = clip.shape[-1]
        keep = n_samples - n_samples % MIMI_FRAME_SIZE
        if keep < SAMPLE_RATE:  # under ~1s of usable audio
            gr.Warning("The reference audio may be too short; result quality may suffer.")
        if keep > 0:
            clip = clip[:keep].to(device)
            reference.append(Segment(speaker=int(speaker_id), text=transcript, audio=clip))

    waveform = generator.generate(
        text=prompt,
        speaker=int(speaker_id),
        context=reference,
        max_audio_length_ms=float(max_length_s) * 1000.0,
        temperature=float(temperature),
        topk=int(topk),
    )

    # Scale to the int16 range, clamping so a sample at +1.0 does not overflow.
    pcm = (waveform * 32768).clamp(-32768, 32767).to(torch.int16)
    return SAMPLE_RATE, pcm.cpu().numpy()


# Build the UI: one row with two columns, the first holding all inputs and the
# Generate button, the second holding the generated audio.
with gr.Blocks(title="Miso TTS 8B") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                label="Text to synthesize",
                placeholder="Hello from Miso.",
                lines=3,
                value="Hello from Miso. This is an eight billion parameter text to speech model.",
            )
            with gr.Accordion("Voice cloning (optional)", open=False):
                # type="filepath": handlers receive a path, which they pass to torchaudio.load.
                ref_audio = gr.Audio(label="Reference audio", type="filepath")
                ref_text = gr.Textbox(
                    label="Reference transcript (auto-filled on upload)",
                    placeholder="The exact words spoken in the reference audio.",
                    lines=2,
                )
            with gr.Accordion("Advanced", open=False):
                # Integer slider limited to the two values 0 and 1.
                speaker_id = gr.Slider(0, 1, value=0, step=1, label="Speaker ID")
                max_length = gr.Slider(2, 60, value=10, step=1, label="Max audio length (s)")
                temperature = gr.Slider(
                    0.1, 1.5, value=0.7, step=0.05,
                    label="Temperature (auto-lowered when cloning a voice)",
                )
                topk = gr.Slider(1, 100, value=50, step=1, label="Top-k")
            run = gr.Button("Generate", variant="primary")
        with gr.Column():
            out = gr.Audio(label="Generated speech")

    # Two independent handlers are registered on the reference audio's change event.
    # This one writes the ASR transcript into the transcript box.
    ref_audio.change(transcribe, inputs=[ref_audio], outputs=[ref_text])
    # Cloning tracks the reference much more closely at low temperature.
    # Sets 0.4 while a clip is present and 0.7 (the slider's initial value) otherwise.
    ref_audio.change(
        lambda p: 0.4 if p else 0.7, inputs=[ref_audio], outputs=[temperature]
    )

    # The inputs list follows the positional parameter order of synthesize().
    run.click(
        synthesize,
        inputs=[text, ref_audio, ref_text, speaker_id, max_length, temperature, topk],
        outputs=[out],
    )

# Enable the request queue and start the app. There is no __main__ guard, so this
# runs whenever the module is executed or imported.
demo.queue().launch()