nanoTTS / app.py
Pedro Sandoval
Add checkpoint with Git LFS
ff050fc
Raw
History Blame Contribute Delete
6.84 kB
from __future__ import annotations
import sys
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
# Directory that contains this file; bundled assets are resolved relative to it.
ROOT = Path(__file__).resolve().parent
WAVTOKENIZER_ROOT = ROOT.joinpath("third_party", "WavTokenizer")

# Put the vendored WavTokenizer checkout at the front of the import path (once),
# so the `decoder.pretrained` import that follows can be resolved from it.
if str(WAVTOKENIZER_ROOT) not in sys.path:
    sys.path.insert(0, str(WAVTOKENIZER_ROOT))
from decoder.pretrained import WavTokenizer # noqa: E402
from model import GPT, GPTConfig # noqa: E402
from tokenizer import create_joint_tokenizer # noqa: E402
# Sample rate that synthesize() reports alongside the generated waveform.
SAMPLE_RATE = 24_000
# Initial values of the "Max audio tokens", "Temperature" and "Top-k" sliders.
DEFAULT_MAX_NEW_TOKENS = 500
DEFAULT_TEMPERATURE = 0.9
DEFAULT_TOP_K = 50
# Longest prompt _validate_prompt() accepts, counted in characters after
# whitespace has been collapsed.
MAX_PROMPT_CHARS = 500
# Local assets: the GPT checkpoint read by load_bundle() and the text tokenizer
# file handed to create_joint_tokenizer().
CHECKPOINT_PATH = ROOT / "checkpoints" / "ckpt_025000_inference.pt"
TEXT_TOKENIZER_PATH = ROOT / "text_tokenizer" / "libritts_bpe.json"
# Hugging Face Hub repository and file names that load_bundle() fetches with
# hf_hub_download() to build the WavTokenizer.
WAVTOKENIZER_REPO_ID = "novateur/WavTokenizer"
WAVTOKENIZER_CONFIG = "wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
WAVTOKENIZER_CHECKPOINT = "WavTokenizer_small_600_24k_4096.ckpt"
@dataclass(frozen=True)
class ModelBundle:
    """Immutable group of the objects one synthesis request needs.

    Attributes:
        model: The GPT network used for generation.
        joint_tokenizer: Object used to encode the text prompt and to decode
            generated token ids back into audio (typed loosely as ``object``).
        device: Device on which prompt tensors are created before generation.
    """

    model: GPT
    joint_tokenizer: object
    device: torch.device
def _device() -> torch.device:
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
def _clean_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
unwanted_prefix = "_orig_mod."
cleaned = {}
for key, value in state_dict.items():
if key.startswith(unwanted_prefix):
key = key[len(unwanted_prefix) :]
cleaned[key] = value
return cleaned
@lru_cache(maxsize=1)
def load_bundle() -> ModelBundle:
    """Build the model, tokenizer and device bundle, caching the result.

    The function takes no arguments, so the ``lru_cache`` decorator means a
    successful call is computed once and reused afterwards.

    Steps, in order: verify the local GPT checkpoint exists, choose the
    device, download the WavTokenizer config and weights from the Hub, build
    the WavTokenizer on CPU, create the joint tokenizer, then load the GPT
    weights and move the model to the chosen device in eval mode.

    Returns:
        A ``ModelBundle`` holding the GPT model, the joint tokenizer and the
        device the model was moved to.

    Raises:
        FileNotFoundError: If ``CHECKPOINT_PATH`` does not exist.
    """
    if not CHECKPOINT_PATH.exists():
        message = (
            f"Missing nanoTTS checkpoint at {CHECKPOINT_PATH}. "
            "Upload checkpoints/ckpt_025000_inference.pt to this Space first."
        )
        raise FileNotFoundError(message)

    target_device = _device()

    # Fetch the two WavTokenizer files from the Hub.
    config_file = hf_hub_download(repo_id=WAVTOKENIZER_REPO_ID, filename=WAVTOKENIZER_CONFIG)
    weights_file = hf_hub_download(repo_id=WAVTOKENIZER_REPO_ID, filename=WAVTOKENIZER_CHECKPOINT)

    # The WavTokenizer is kept on the CPU even when the GPT model runs elsewhere;
    # synthesize() moves generated tokens back to the CPU before decoding.
    codec = WavTokenizer.from_pretrained0802(config_file, weights_file)
    codec = codec.to("cpu")
    codec.bandwidth_id = torch.tensor([0], device="cpu")

    tokenizer = create_joint_tokenizer(str(TEXT_TOKENIZER_PATH), codec)

    # Load onto the CPU first; the model is moved to the target device only
    # after its weights are in place.
    saved = torch.load(CHECKPOINT_PATH, map_location="cpu")
    gpt = GPT(GPTConfig(**saved["model_args"]), tokenizer)
    gpt.load_state_dict(_clean_state_dict(saved["model"]))
    gpt.to(target_device)
    gpt.eval()

    return ModelBundle(model=gpt, joint_tokenizer=tokenizer, device=target_device)
def _validate_prompt(text: str, max_new_tokens: int) -> str:
    """Normalize the prompt and reject requests that cannot be synthesized.

    Args:
        text: Raw prompt; ``None`` or an empty value is treated as empty text.
        max_new_tokens: Requested number of audio tokens to generate.

    Returns:
        The prompt with leading/trailing whitespace removed and every run of
        whitespace collapsed to a single space.

    Raises:
        gr.Error: If the normalized prompt is empty, is longer than
            ``MAX_PROMPT_CHARS`` characters, or ``max_new_tokens`` is below 40.
    """
    # str.split() with no separator already discards leading and trailing
    # whitespace, so joining its pieces both trims and collapses.
    normalized = " ".join((text or "").split())

    if normalized == "":
        raise gr.Error("Enter some text to synthesize.")
    if MAX_PROMPT_CHARS < len(normalized):
        raise gr.Error(f"Keep the prompt under {MAX_PROMPT_CHARS} characters.")
    if max_new_tokens < 40:
        raise gr.Error("Generate at least 40 audio tokens.")
    return normalized
def synthesize(
    text: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    seed: int,
) -> tuple[int, np.ndarray]:
    """Generate speech for *text* and return it as ``(sample_rate, samples)``.

    Args:
        text: Prompt to speak; validated and whitespace-normalized first.
        max_new_tokens: Number of tokens to ask the model to generate.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k cutoff; values ``<= 0`` disable it (``None`` is passed on).
        seed: Positive values reseed PyTorch for this call. ``0``, negative
            values and ``None`` (an emptied seed field) leave the RNG state
            untouched, i.e. "random".

    Returns:
        ``(SAMPLE_RATE, samples)`` where ``samples`` is a float32 NumPy array
        clipped to ``[-1.0, 1.0]``.

    Raises:
        gr.Error: If the prompt fails validation, the encoded prompt does not
            fit in the model's ``block_size``, or decoding yields no audio.
    """
    max_new_tokens = int(max_new_tokens)
    text = _validate_prompt(text, max_new_tokens)
    bundle = load_bundle()

    # A cleared number field arrives as None, and `None > 0` would raise a
    # TypeError. Normalize once so the comparison and the seeding calls agree
    # on the same integer (a fractional value such as 0.5 now counts as 0).
    seed_value = 0 if seed is None else int(seed)
    if seed_value > 0:
        torch.manual_seed(seed_value)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed_value)

    prompt = f"<BOS>{text}<AUDIO_START>"
    prompt_ids = bundle.joint_tokenizer.encode_text(prompt)
    if len(prompt_ids) >= bundle.model.config.block_size:
        raise gr.Error("Prompt is too long for this checkpoint's context window.")
    # NOTE(review): prompt length plus max_new_tokens is not checked against
    # block_size; how generate() behaves past that limit is not visible here.

    top_k_value = None if int(top_k) <= 0 else int(top_k)
    batch = torch.tensor(prompt_ids, dtype=torch.long, device=bundle.device).unsqueeze(0)

    with torch.inference_mode():
        generated = bundle.model.generate(
            batch,
            max_new_tokens=max_new_tokens,
            temperature=float(temperature),
            top_k=top_k_value,
        )
        # Decode on the CPU and inside inference mode, so the result never
        # tracks gradients and can be converted to NumPy below.
        audio = bundle.joint_tokenizer.decode(generated.detach().cpu())

    if audio is None or audio.numel() == 0:
        raise gr.Error("The model did not generate any audio tokens. Try again or change the seed.")

    audio_np = audio.squeeze(0).float().cpu().numpy()
    audio_np = np.clip(audio_np, -1.0, 1.0)
    return SAMPLE_RATE, audio_np
# Gradio UI definition: a prompt box with a Generate button, an audio player,
# a collapsible panel of sampling settings and a list of example prompts, all
# feeding synthesize().
# NOTE(review): the nesting of the layout containers (Row / Column / Accordion)
# is not visible in this listing — confirm which components belong to which
# container before relying on the layout.
with gr.Blocks(title="nanoTTS") as demo:
gr.Markdown(
"# nanoTTS\n"
"A minimal GPT-style text-to-speech model trained on LibriTTS. Code available at [psandovalsegura/nanoTTS](https://github.com/psandovalsegura/nanoTTS)."
)
with gr.Row():
with gr.Column(scale=3):
# Prompt input; its value is passed to synthesize() as `text`.
text_input = gr.Textbox(
label="Text",
placeholder="Type a short sentence...",
lines=4,
max_lines=8,
value="Hello friend! This is nano text to speech.",
)
generate_button = gr.Button("Generate", variant="primary")
with gr.Column(scale=2):
# type="numpy" matches the (sample_rate, ndarray) tuple synthesize() returns.
audio_output = gr.Audio(label="Generated audio", type="numpy")
with gr.Accordion("Generation settings", open=False):
# Slider minimum of 40 mirrors the lower bound enforced in _validate_prompt().
max_new_tokens = gr.Slider(
minimum=40,
maximum=800,
value=DEFAULT_MAX_NEW_TOKENS,
step=20,
label="Max audio tokens",
)
temperature = gr.Slider(
minimum=0.1,
maximum=1.5,
value=DEFAULT_TEMPERATURE,
step=0.05,
label="Temperature",
)
# A top-k of 0 is turned into None (no top-k cutoff) by synthesize().
top_k = gr.Slider(
minimum=0,
maximum=200,
value=DEFAULT_TOP_K,
step=1,
label="Top-k",
)
# synthesize() only reseeds PyTorch when this value is greater than 0.
seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
# Each example row lists values in the same order as `inputs`:
# text, max audio tokens, temperature, top-k, seed.
gr.Examples(
examples=[
["Nano text to speech", 500, 0.9, 50, 0],
["Created at the University of Maryland", 500, 0.9, 50, 0],
["Trained using only two hundred and forty five hours of speech data", 600, 0.9, 50, 0],
["In the year two thousand and twenty six", 500, 0.9, 50, 0],
["The quick brown fox jumps over the lazy dog.", 500, 0.9, 50, 0],
],
inputs=[text_input, max_new_tokens, temperature, top_k, seed],
outputs=audio_output,
fn=synthesize,
cache_examples=False,
)
# Clicking the button calls synthesize() with the current control values and
# routes its return value to the audio player.
generate_button.click(
fn=synthesize,
inputs=[text_input, max_new_tokens, temperature, top_k, seed],
outputs=audio_output,
)
if __name__ == "__main__":
    # Script entry point: enable the request queue with a default concurrency
    # limit of 1, then start the app.
    queued_demo = demo.queue(default_concurrency_limit=1)
    queued_demo.launch()