# MisoTTS / app.py
# Uploaded by multimodalart (HF Staff) with huggingface_hub — commit c271af2 (verified).
import os
os.environ.setdefault("NO_TORCH_COMPILE", "1")
import spaces
import numpy as np
import torch
import torchaudio
import gradio as gr
from transformers import AutoProcessor, MoonshineForConditionalGeneration
from generator import Segment, load_miso_8b
_MISO_REPO_ID = "MisoLabs/MisoTTS"

# Use CUDA whenever torch reports it as available; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
generator = load_miso_8b(model_path_or_repo_id=_MISO_REPO_ID, device=device)

# Rate of the audio the generator returns; reference clips are resampled to it too.
SAMPLE_RATE = generator.sample_rate

# Mimi encodes in fixed-size frames. moshi 0.2.12 and the repo's pinned 0.2.2 pad a
# partial trailing frame differently, so reference clips are trimmed to a whole number
# of frames to get byte-identical reference codes (every full frame already matches 1:1).
MIMI_FRAME_SIZE = int(generator._audio_tokenizer.frame_size)
# Moonshine ASR for auto-transcribing reference clips. Kept on CPU and never called
# from inside an @spaces.GPU function, so it does not consume the ZeroGPU quota.
_ASR_REPO_ID = "UsefulSensors/moonshine-base"

# Sample rate that reference clips are resampled to before being fed to Moonshine.
ASR_SAMPLE_RATE = 16000

asr_processor = AutoProcessor.from_pretrained(_ASR_REPO_ID)
asr_model = MoonshineForConditionalGeneration.from_pretrained(_ASR_REPO_ID)
asr_model.eval()  # inference only

# Longest text (in characters) that synthesize() accepts.
MAX_INPUT_CHARS = 1000
# Markdown header rendered at the top of the Gradio page (passed to gr.Markdown in the UI).
DESCRIPTION = """
# Miso TTS 8B
Text-to-speech with the [MisoLabs/MisoTTS](https://huggingface.co/MisoLabs/MisoTTS) model — an
8B [Sesame CSM](https://github.com/SesameAILabs/csm)-style model that generates Mimi audio codes
from text, with optional voice continuation from a reference clip.
"""
def _resample_to_model(audio: torch.Tensor, sr: int) -> torch.Tensor:
audio = audio.mean(dim=0) if audio.ndim > 1 else audio
if sr != SAMPLE_RATE:
audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=SAMPLE_RATE)
return audio
def transcribe(ref_audio_path):
    """Auto-transcribe a reference clip with Moonshine on the CPU.

    Runs on the always-on host: it is wired to the reference-audio change event and
    is never called from inside an ``@spaces.GPU`` function.

    Args:
        ref_audio_path: Path of the uploaded reference clip, or a falsy value when
            the audio component was cleared.

    Returns:
        The stripped transcript, or ``gr.update()`` (leave the transcript box
        unchanged) when there is no audio to transcribe.
    """
    if not ref_audio_path:
        return gr.update()
    wav, sr = torchaudio.load(ref_audio_path)
    # Down-mix to mono by averaging over dim 0 (channels in torchaudio.load output).
    wav = wav.mean(dim=0) if wav.ndim > 1 else wav
    if wav.numel() == 0:
        # Empty file: nothing to transcribe, and resampling / ASR on zero samples
        # may fail. Keep whatever the user already typed.
        return gr.update()
    if sr != ASR_SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=ASR_SAMPLE_RATE)
    inputs = asr_processor(wav.numpy(), sampling_rate=ASR_SAMPLE_RATE, return_tensors="pt")
    # Bound decoding by the clip's duration, following the Moonshine model card
    # (at most ~6.5 tokens per second of audio). Without an explicit limit,
    # generate() falls back to the generation config's length limit, which has no
    # relation to this clip: it can cut long transcripts short or let the decoder
    # run on in a repetition loop. The floor leaves room for a few words on very
    # short clips.
    tokens_per_second = 6.5
    duration_s = wav.shape[-1] / ASR_SAMPLE_RATE
    max_length = max(int(duration_s * tokens_per_second), 10)
    with torch.no_grad():
        tokens = asr_model.generate(**inputs, max_length=max_length)
    return asr_processor.decode(tokens[0], skip_special_tokens=True).strip()
def _reference_context(ref_audio_path, ref_text, speaker):
    """Build the voice-cloning context (a one-element ``Segment`` list) from a reference clip.

    Raises ``gr.Error`` when the transcript is missing. Returns an empty list, with an
    explicit warning, when the clip is shorter than one Mimi frame and so cannot be used.
    """
    if not (ref_text or "").strip():
        raise gr.Error("Please provide the transcript of the reference audio.")
    wav, sr = torchaudio.load(ref_audio_path)
    wav = _resample_to_model(wav, sr)
    # Trim to a whole number of Mimi frames: a partial trailing frame is padded
    # differently across moshi versions, whole frames encode identically.
    usable = (wav.shape[-1] // MIMI_FRAME_SIZE) * MIMI_FRAME_SIZE
    if usable == 0:
        # Not even one frame: the clip cannot be used, so say so rather than
        # implying it was used at reduced quality.
        gr.Warning("The reference audio is too short to use; generating without voice cloning.")
        return []
    if usable < SAMPLE_RATE:  # under ~1s of usable audio
        gr.Warning("The reference audio may be too short; result quality may suffer.")
    return [Segment(speaker=speaker, text=ref_text.strip(), audio=wav[:usable].to(device))]


@spaces.GPU(duration=120)
def synthesize(text, ref_audio_path, ref_text, speaker_id, max_length_s, temperature, topk):
    """Synthesize *text*, optionally continuing the voice of a reference clip.

    Args:
        text: Text to speak; stripped, must be non-empty and at most MAX_INPUT_CHARS long.
        ref_audio_path: Optional path of a reference clip; when given, *ref_text* is required.
        ref_text: Transcript of the reference clip.
        speaker_id: Speaker index passed to the generator (cast to int).
        max_length_s: Upper bound on the generated audio length, in seconds.
        temperature: Sampling temperature.
        topk: Top-k sampling cutoff.

    Returns:
        ``(SAMPLE_RATE, samples)`` where *samples* is an int16 NumPy array, the
        value the output ``gr.Audio`` component receives.

    Raises:
        gr.Error: on empty or over-long text, or a reference clip without a transcript.
    """
    text = (text or "").strip()
    if not text:
        raise gr.Error("Please enter some text to synthesize.")
    if len(text) > MAX_INPUT_CHARS:
        raise gr.Error(f"Text too long (>{MAX_INPUT_CHARS} characters).")
    # ZeroGPU streams weights to the real GPU on first entry but leaves the torchtune
    # KV-cache's non-persistent buffers (e.g. cache_pos) behind, causing a cuda/cpu
    # device mismatch. Re-place the model on the device here, inside the GPU worker.
    generator._model.to(device)
    generator._audio_tokenizer.to(device)
    speaker = int(speaker_id)
    context = _reference_context(ref_audio_path, ref_text, speaker) if ref_audio_path else []
    audio = generator.generate(
        text=text,
        speaker=speaker,
        context=context,
        max_audio_length_ms=float(max_length_s) * 1000.0,
        temperature=float(temperature),
        topk=int(topk),
    )
    # Scale float samples (presumably in [-1, 1] — confirm against the generator)
    # to 16-bit PCM, clamping so +1.0 does not overflow int16.
    audio_np = (audio * 32768).clamp(-32768, 32767).to(torch.int16).cpu().numpy()
    return SAMPLE_RATE, audio_np
# Sampling temperature used by default, and the lower one applied once a reference
# clip is attached. Cloning tracks the reference much more closely at low temperature.
_DEFAULT_TEMPERATURE = 0.7
_CLONING_TEMPERATURE = 0.4

# Layout: prompt, cloning inputs and advanced controls on the left; output on the right.
with gr.Blocks(title="Miso TTS 8B") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                label="Text to synthesize",
                placeholder="Hello from Miso.",
                lines=3,
                value="Hello from Miso. This is an eight billion parameter text to speech model.",
            )
            with gr.Accordion("Voice cloning (optional)", open=False):
                ref_audio = gr.Audio(label="Reference audio", type="filepath")
                ref_text = gr.Textbox(
                    label="Reference transcript (auto-filled on upload)",
                    placeholder="The exact words spoken in the reference audio.",
                    lines=2,
                )
            with gr.Accordion("Advanced", open=False):
                speaker_id = gr.Slider(0, 1, value=0, step=1, label="Speaker ID")
                max_length = gr.Slider(2, 60, value=10, step=1, label="Max audio length (s)")
                temperature = gr.Slider(
                    0.1,
                    1.5,
                    value=_DEFAULT_TEMPERATURE,
                    step=0.05,
                    label="Temperature (auto-lowered when cloning a voice)",
                )
                topk = gr.Slider(1, 100, value=50, step=1, label="Top-k")
            run = gr.Button("Generate", variant="primary")
        with gr.Column():
            out = gr.Audio(label="Generated speech")

    # A new reference clip fills in its transcript (CPU-side ASR)...
    ref_audio.change(fn=transcribe, inputs=[ref_audio], outputs=[ref_text])
    # ...and switches the temperature to the cloning value (back to default when cleared).
    ref_audio.change(
        fn=lambda path: _CLONING_TEMPERATURE if path else _DEFAULT_TEMPERATURE,
        inputs=[ref_audio],
        outputs=[temperature],
    )

    synth_inputs = [text, ref_audio, ref_text, speaker_id, max_length, temperature, topk]
    run.click(fn=synthesize, inputs=synth_inputs, outputs=[out])

demo.queue().launch()