"""Gradio Space for nanoTTS, a minimal GPT-style text-to-speech demo.

A GPT checkpoint is conditioned on the encoded text prompt and samples a
continuation; the joint tokenizer (built around WavTokenizer) decodes the
generated sequence back into a waveform that is served through ``gr.Audio``.
"""

from __future__ import annotations

import sys
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download

ROOT = Path(__file__).resolve().parent
WAVTOKENIZER_ROOT = ROOT / "third_party" / "WavTokenizer"
# WavTokenizer is vendored under third_party/ rather than installed, so its
# root must be on sys.path before the `decoder` import below can resolve.
if str(WAVTOKENIZER_ROOT) not in sys.path:
    sys.path.insert(0, str(WAVTOKENIZER_ROOT))

from decoder.pretrained import WavTokenizer  # noqa: E402
from model import GPT, GPTConfig  # noqa: E402
from tokenizer import create_joint_tokenizer  # noqa: E402

SAMPLE_RATE = 24_000
DEFAULT_MAX_NEW_TOKENS = 500
DEFAULT_TEMPERATURE = 0.9
DEFAULT_TOP_K = 50
MAX_PROMPT_CHARS = 500
# Smallest generation length accepted; shared by the validator and the slider.
MIN_NEW_TOKENS = 40
CHECKPOINT_PATH = ROOT / "checkpoints" / "ckpt_025000_inference.pt"
TEXT_TOKENIZER_PATH = ROOT / "text_tokenizer" / "libritts_bpe.json"
WAVTOKENIZER_REPO_ID = "novateur/WavTokenizer"
WAVTOKENIZER_CONFIG = "wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
WAVTOKENIZER_CHECKPOINT = "WavTokenizer_small_600_24k_4096.ckpt"


@dataclass(frozen=True)
class ModelBundle:
    """Everything ``synthesize`` needs; built once by ``load_bundle``."""

    model: GPT
    # Object returned by create_joint_tokenizer(); synthesize() uses its
    # encode_text() and decode() methods.
    joint_tokenizer: object
    device: torch.device


def _device() -> torch.device:
    """Return the CUDA device when available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _clean_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Return a copy of *state_dict* with any ``_orig_mod.`` key prefix removed.

    ``_orig_mod.`` is the prefix a ``torch.compile``-wrapped module puts on its
    parameter names; stripping it lets the weights load into a plain ``GPT``.
    """
    unwanted_prefix = "_orig_mod."
    cleaned = {}
    for key, value in state_dict.items():
        if key.startswith(unwanted_prefix):
            key = key[len(unwanted_prefix):]
        cleaned[key] = value
    return cleaned


@lru_cache(maxsize=1)
def load_bundle() -> ModelBundle:
    """Load the GPT checkpoint and tokenizers, caching the result for reuse.

    Returns:
        A ``ModelBundle`` whose model is on ``_device()`` in eval mode.

    Raises:
        FileNotFoundError: if the nanoTTS checkpoint file is not present.
            (``lru_cache`` does not cache exceptions, so a later call retries.)
    """
    if not CHECKPOINT_PATH.exists():
        raise FileNotFoundError(
            f"Missing nanoTTS checkpoint at {CHECKPOINT_PATH}. "
            "Upload checkpoints/ckpt_025000_inference.pt to this Space first."
        )
    device = _device()

    # WavTokenizer config and weights are fetched from the Hugging Face Hub.
    wavtokenizer_config_path = hf_hub_download(
        repo_id=WAVTOKENIZER_REPO_ID,
        filename=WAVTOKENIZER_CONFIG,
    )
    wavtokenizer_checkpoint_path = hf_hub_download(
        repo_id=WAVTOKENIZER_REPO_ID,
        filename=WAVTOKENIZER_CHECKPOINT,
    )
    # The audio tokenizer is kept on CPU; only the GPT is moved to `device`.
    wavtokenizer = WavTokenizer.from_pretrained0802(
        wavtokenizer_config_path,
        wavtokenizer_checkpoint_path,
    ).to("cpu")
    wavtokenizer.bandwidth_id = torch.tensor([0], device="cpu")
    joint_tokenizer = create_joint_tokenizer(str(TEXT_TOKENIZER_PATH), wavtokenizer)

    checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
    model_args = checkpoint["model_args"]
    model = GPT(GPTConfig(**model_args), joint_tokenizer)
    model.load_state_dict(_clean_state_dict(checkpoint["model"]))
    model.to(device)
    model.eval()
    return ModelBundle(model=model, joint_tokenizer=joint_tokenizer, device=device)


def _validate_prompt(text: str, max_new_tokens: int) -> str:
    """Collapse whitespace in *text* and validate the generation request.

    Returns:
        The prompt with all whitespace runs collapsed to single spaces.

    Raises:
        gr.Error: if the prompt is empty, longer than ``MAX_PROMPT_CHARS``,
            or ``max_new_tokens`` is below ``MIN_NEW_TOKENS``.
    """
    # str.split() with no argument already drops leading/trailing whitespace.
    text = " ".join((text or "").split())
    if not text:
        raise gr.Error("Enter some text to synthesize.")
    if len(text) > MAX_PROMPT_CHARS:
        raise gr.Error(f"Keep the prompt to {MAX_PROMPT_CHARS} characters or fewer.")
    if max_new_tokens < MIN_NEW_TOKENS:
        raise gr.Error(f"Generate at least {MIN_NEW_TOKENS} audio tokens.")
    return text


def _seed_rng(seed: float | None) -> None:
    """Seed torch's RNGs from the UI seed value.

    A positive seed makes generation reproducible. Zero, a negative value, or an
    empty field (``None``) means "random": torch is re-seeded from a
    non-deterministic source so the run does not replay the RNG stream left
    behind by an earlier fixed-seed request.
    """
    seed = int(seed or 0)
    if seed > 0:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    else:
        torch.seed()


def synthesize(
    text: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    seed: int | None,
) -> tuple[int, np.ndarray]:
    """Generate speech for *text*.

    Args:
        text: Prompt to speak; whitespace is normalized before encoding.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature forwarded to ``GPT.generate``.
        top_k: Top-k cutoff; values <= 0 disable it.
        seed: RNG seed; 0, negative, or empty means a fresh random seed.

    Returns:
        ``(SAMPLE_RATE, waveform)`` with the waveform clipped to [-1, 1], the
        tuple form expected by ``gr.Audio(type="numpy")``.

    Raises:
        gr.Error: on invalid input, a prompt that does not fit the model's
            context window, or when nothing decodable was generated.
    """
    max_new_tokens = int(max_new_tokens)
    text = _validate_prompt(text, max_new_tokens)
    bundle = load_bundle()
    _seed_rng(seed)

    prompt_ids = bundle.joint_tokenizer.encode_text(text)
    if len(prompt_ids) >= bundle.model.config.block_size:
        raise gr.Error("Prompt is too long for this checkpoint's context window.")

    # The UI uses 0 for "no top-k limit"; generate() receives None in that case.
    top_k_value = int(top_k) if int(top_k) > 0 else None
    batch = torch.tensor(prompt_ids, dtype=torch.long, device=bundle.device).unsqueeze(0)

    with torch.inference_mode():
        generated = bundle.model.generate(
            batch,
            max_new_tokens=max_new_tokens,
            temperature=float(temperature),
            top_k=top_k_value,
        )
        audio = bundle.joint_tokenizer.decode(generated.detach().cpu())

    if audio is None or audio.numel() == 0:
        raise gr.Error("The model did not generate any audio tokens. Try again or change the seed.")

    audio_np = audio.squeeze(0).float().cpu().numpy()
    return SAMPLE_RATE, np.clip(audio_np, -1.0, 1.0)


with gr.Blocks(title="nanoTTS") as demo:
    gr.Markdown(
        "# nanoTTS\n"
        "A minimal GPT-style text-to-speech model trained on LibriTTS. Code available at [psandovalsegura/nanoTTS](https://github.com/psandovalsegura/nanoTTS)."
    )
    with gr.Row():
        with gr.Column(scale=3):
            text_input = gr.Textbox(
                label="Text",
                placeholder="Type a short sentence...",
                lines=4,
                max_lines=8,
                value="Hello friend! This is nano text to speech.",
            )
            generate_button = gr.Button("Generate", variant="primary")
        with gr.Column(scale=2):
            audio_output = gr.Audio(label="Generated audio", type="numpy")

    # NOTE(review): the original indentation was lost; this accordion is
    # assumed to sit at Blocks level below the row — confirm the intended layout.
    with gr.Accordion("Generation settings", open=False):
        max_new_tokens = gr.Slider(
            minimum=MIN_NEW_TOKENS,
            maximum=800,
            value=DEFAULT_MAX_NEW_TOKENS,
            step=20,
            label="Max audio tokens",
        )
        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.5,
            value=DEFAULT_TEMPERATURE,
            step=0.05,
            label="Temperature",
        )
        top_k = gr.Slider(
            minimum=0,
            maximum=200,
            value=DEFAULT_TOP_K,
            step=1,
            label="Top-k",
        )
        seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")

    gr.Examples(
        examples=[
            ["Nano text to speech", 500, 0.9, 50, 0],
            ["Created at the University of Maryland", 500, 0.9, 50, 0],
            ["Trained using only two hundred and forty five hours of speech data", 600, 0.9, 50, 0],
            ["In the year two thousand and twenty six", 500, 0.9, 50, 0],
            ["The quick brown fox jumps over the lazy dog.", 500, 0.9, 50, 0],
        ],
        inputs=[text_input, max_new_tokens, temperature, top_k, seed],
        outputs=audio_output,
        fn=synthesize,
        cache_examples=False,
    )

    generate_button.click(
        fn=synthesize,
        inputs=[text_input, max_new_tokens, temperature, top_k, seed],
        outputs=audio_output,
    )


if __name__ == "__main__":
    # Process one request at a time against the single cached model bundle.
    demo.queue(default_concurrency_limit=1).launch()