Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import sys | |
| from dataclasses import dataclass | |
| from functools import lru_cache | |
| from pathlib import Path | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
# Directory containing this app; the bundled checkpoint and text tokenizer are
# located relative to it.
ROOT = Path(__file__).resolve().parent
# Vendored WavTokenizer checkout, expected to provide the `decoder` package
# imported right after this setup.
WAVTOKENIZER_ROOT = ROOT.joinpath("third_party", "WavTokenizer")
_wavtokenizer_dir = str(WAVTOKENIZER_ROOT)
if _wavtokenizer_dir not in sys.path:
    # Prepend so this checkout is searched before anything else on the path.
    sys.path.insert(0, _wavtokenizer_dir)
| from decoder.pretrained import WavTokenizer # noqa: E402 | |
| from model import GPT, GPTConfig # noqa: E402 | |
| from tokenizer import create_joint_tokenizer # noqa: E402 | |
# --- Files resolved relative to ROOT -----------------------------------------
CHECKPOINT_PATH = ROOT.joinpath("checkpoints", "ckpt_025000_inference.pt")
TEXT_TOKENIZER_PATH = ROOT.joinpath("text_tokenizer", "libritts_bpe.json")

# --- WavTokenizer codec, fetched from the Hugging Face Hub ------------------
WAVTOKENIZER_REPO_ID = "novateur/WavTokenizer"
WAVTOKENIZER_CONFIG = (
    "wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn"
    ".yaml"
)
WAVTOKENIZER_CHECKPOINT = "WavTokenizer_small_600_24k_4096.ckpt"

# --- Audio output and generation defaults -----------------------------------
# Sample rate reported alongside the waveform returned by synthesize();
# consistent with the "_24k_" codec checkpoint name above.
SAMPLE_RATE = 24_000
DEFAULT_MAX_NEW_TOKENS = 500  # initial value of the "Max audio tokens" slider
DEFAULT_TEMPERATURE = 0.9  # initial value of the "Temperature" slider
DEFAULT_TOP_K = 50  # initial value of the "Top-k" slider
# Longest prompt accepted after whitespace normalization.
MAX_PROMPT_CHARS = 500
@dataclass(frozen=True)
class ModelBundle:
    """Immutable container for everything ``synthesize`` needs at inference time.

    ``load_bundle`` builds this with keyword arguments
    (``ModelBundle(model=..., joint_tokenizer=..., device=...)``). That only
    works because ``@dataclass`` generates the matching ``__init__``; a plain
    class with bare annotations rejects those arguments with ``TypeError``.
    ``frozen=True`` prevents accidental mutation of a bundle that may be
    shared across requests.
    """

    model: GPT  # the nanoTTS GPT; load_bundle moves it to `device` and calls eval()
    joint_tokenizer: object  # result of create_joint_tokenizer(...): encodes text, decodes audio
    device: torch.device  # device the GPT model was moved to
| def _device() -> torch.device: | |
| return torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| def _clean_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: | |
| unwanted_prefix = "_orig_mod." | |
| cleaned = {} | |
| for key, value in state_dict.items(): | |
| if key.startswith(unwanted_prefix): | |
| key = key[len(unwanted_prefix) :] | |
| cleaned[key] = value | |
| return cleaned | |
@lru_cache(maxsize=1)
def load_bundle() -> ModelBundle:
    """Load the WavTokenizer codec, the joint tokenizer and the nanoTTS GPT model.

    The result is memoised with ``lru_cache(maxsize=1)``. The first call does
    the expensive work: Hub downloads, ``torch.load`` and model construction.
    Every later call returns the same bundle, so ``synthesize`` no longer
    rebuilds the models on each request. A raised exception is not cached,
    so a missing checkpoint is checked again on the next call.

    Returns:
        A ``ModelBundle`` whose GPT model is in eval mode on ``_device()``.

    Raises:
        FileNotFoundError: if ``CHECKPOINT_PATH`` does not exist.
    """
    if not CHECKPOINT_PATH.exists():
        raise FileNotFoundError(
            f"Missing nanoTTS checkpoint at {CHECKPOINT_PATH}. "
            "Upload checkpoints/ckpt_025000_inference.pt to this Space first."
        )
    device = _device()

    # WavTokenizer codec files; hf_hub_download returns local (cached) file paths.
    wavtokenizer_config_path = hf_hub_download(
        repo_id=WAVTOKENIZER_REPO_ID,
        filename=WAVTOKENIZER_CONFIG,
    )
    wavtokenizer_checkpoint_path = hf_hub_download(
        repo_id=WAVTOKENIZER_REPO_ID,
        filename=WAVTOKENIZER_CHECKPOINT,
    )
    # The codec stays on CPU; synthesize() moves generated tokens to CPU before decoding.
    wavtokenizer = WavTokenizer.from_pretrained0802(
        wavtokenizer_config_path,
        wavtokenizer_checkpoint_path,
    ).to("cpu")
    # NOTE(review): bandwidth index 0 is presumably the only setting for this
    # single-quantizer ("nq1") config — confirm against WavTokenizer.
    wavtokenizer.bandwidth_id = torch.tensor([0], device="cpu")
    joint_tokenizer = create_joint_tokenizer(str(TEXT_TOKENIZER_PATH), wavtokenizer)

    # NOTE(review): torch.load unpickles the file, so the checkpoint must be
    # trusted. If it holds only tensors and plain containers, consider
    # torch.load(..., weights_only=True).
    checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
    model = GPT(GPTConfig(**checkpoint["model_args"]), joint_tokenizer)
    model.load_state_dict(_clean_state_dict(checkpoint["model"]))
    model.to(device)
    model.eval()
    return ModelBundle(model=model, joint_tokenizer=joint_tokenizer, device=device)
def _validate_prompt(text: str, max_new_tokens: int, min_new_tokens: int = 40) -> str:
    """Normalize the user's text and reject requests the demo should not run.

    Args:
        text: Raw prompt text; ``None`` is treated as empty.
        max_new_tokens: Requested audio-token budget.
        min_new_tokens: Smallest accepted budget. The default of 40 matches
            the minimum of the "Max audio tokens" slider.

    Returns:
        ``text`` with leading/trailing whitespace removed and every internal
        whitespace run collapsed to a single space.

    Raises:
        gr.Error: if the normalized text is empty or longer than
            ``MAX_PROMPT_CHARS``, or if ``max_new_tokens < min_new_tokens``.
    """
    # str.split() with no argument already discards leading/trailing whitespace.
    text = " ".join((text or "").split())
    if not text:
        raise gr.Error("Enter some text to synthesize.")
    if len(text) > MAX_PROMPT_CHARS:
        # A prompt of exactly MAX_PROMPT_CHARS is accepted, so say "at most".
        raise gr.Error(f"Keep the prompt to at most {MAX_PROMPT_CHARS} characters.")
    if max_new_tokens < min_new_tokens:
        raise gr.Error(f"Generate at least {min_new_tokens} audio tokens.")
    return text
def synthesize(
    text: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    seed: int,
) -> tuple[int, np.ndarray]:
    """Generate speech for ``text``; this is the Gradio callback for the demo.

    Args:
        text: Prompt text. ``_validate_prompt`` normalizes whitespace and
            checks its length.
        max_new_tokens: Maximum number of audio tokens to sample.
        temperature: Sampling temperature; must be strictly positive.
        top_k: Top-k cutoff forwarded to ``model.generate``. Values ``<= 0``
            are passed as ``None``.
        seed: RNG seed. A positive value reseeds torch, and CUDA when
            available. ``0``, a negative value, or an empty field (``None``)
            leaves the RNG state untouched.

    Returns:
        ``(SAMPLE_RATE, waveform)``, where ``waveform`` is a float32 numpy
        array clipped to [-1, 1]. This is the tuple form that
        ``gr.Audio(type="numpy")`` displays.

    Raises:
        gr.Error: for invalid input, a prompt that fills the model's context
            window, or when decoding produced no audio.
    """
    max_new_tokens = int(max_new_tokens)
    text = _validate_prompt(text, max_new_tokens)
    temperature = float(temperature)
    if temperature <= 0:
        # The UI slider starts at 0.1, but API callers can still send 0 or less.
        raise gr.Error("Temperature must be greater than 0.")
    top_k_value = int(top_k)
    top_k_value = top_k_value if top_k_value > 0 else None
    # gr.Number can deliver None when the box is cleared; treat that like 0.
    seed = int(seed or 0)

    # Load the bundle before seeding, so any randomness used while building
    # the model cannot shift the seeded sampling stream.
    bundle = load_bundle()
    if seed > 0:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    prompt_ids = bundle.joint_tokenizer.encode_text(f"<BOS>{text}<AUDIO_START>")
    # Reject prompts that already fill the whole context window.
    if len(prompt_ids) >= bundle.model.config.block_size:
        raise gr.Error("Prompt is too long for this checkpoint's context window.")

    batch = torch.tensor(prompt_ids, dtype=torch.long, device=bundle.device).unsqueeze(0)
    with torch.inference_mode():
        generated = bundle.model.generate(
            batch,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k_value,
        )
    # load_bundle keeps the WavTokenizer codec on CPU, so decode CPU tensors.
    audio = bundle.joint_tokenizer.decode(generated.cpu())
    if audio is None or audio.numel() == 0:
        raise gr.Error("The model did not generate any audio tokens. Try again or change the seed.")

    audio_np = np.clip(audio.squeeze(0).float().cpu().numpy(), -1.0, 1.0)
    return SAMPLE_RATE, audio_np
# Gradio UI for the demo. The order in which components are created, and the
# `with` nesting, define the page layout.
# NOTE(review): the nesting below was inferred from a flattened listing —
# confirm where the Accordion and Examples sit relative to the Row/Columns.
with gr.Blocks(title="nanoTTS") as demo:
    gr.Markdown(
        "# nanoTTS\n"
        "A minimal GPT-style text-to-speech model trained on LibriTTS. Code available at [psandovalsegura/nanoTTS](https://github.com/psandovalsegura/nanoTTS)."
    )
    with gr.Row():
        # Left: prompt entry and the button that triggers synthesize().
        with gr.Column(scale=3):
            text_input = gr.Textbox(
                label="Text",
                placeholder="Type a short sentence...",
                lines=4,
                max_lines=8,
                value="Hello friend! This is nano text to speech.",
            )
            generate_button = gr.Button("Generate", variant="primary")
        # Right: plays the (sample_rate, waveform) tuple that synthesize() returns.
        with gr.Column(scale=2):
            audio_output = gr.Audio(label="Generated audio", type="numpy")
    # Sampling controls, collapsed by default. These module-level names are
    # Gradio components. synthesize()'s same-named parameters receive their
    # values and shadow these names inside that function.
    with gr.Accordion("Generation settings", open=False):
        max_new_tokens = gr.Slider(
            minimum=40,  # matches the 40-token floor checked in _validate_prompt
            maximum=800,
            value=DEFAULT_MAX_NEW_TOKENS,
            step=20,
            label="Max audio tokens",
        )
        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.5,
            value=DEFAULT_TEMPERATURE,
            step=0.05,
            label="Temperature",
        )
        top_k = gr.Slider(
            minimum=0,  # 0 -> synthesize() passes top_k=None to generate()
            maximum=200,
            value=DEFAULT_TOP_K,
            step=1,
            label="Top-k",
        )
        seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
    # Each example row lists values in the same order as `inputs`:
    # [text, max_new_tokens, temperature, top_k, seed].
    gr.Examples(
        examples=[
            ["Nano text to speech", 500, 0.9, 50, 0],
            ["Created at the University of Maryland", 500, 0.9, 50, 0],
            ["Trained using only two hundred and forty five hours of speech data", 600, 0.9, 50, 0],
            ["In the year two thousand and twenty six", 500, 0.9, 50, 0],
            ["The quick brown fox jumps over the lazy dog.", 500, 0.9, 50, 0],
        ],
        inputs=[text_input, max_new_tokens, temperature, top_k, seed],
        outputs=audio_output,
        fn=synthesize,
        cache_examples=False,  # don't pre-generate example audio at startup
    )
    # Input order must match synthesize(text, max_new_tokens, temperature, top_k, seed).
    generate_button.click(
        fn=synthesize,
        inputs=[text_input, max_new_tokens, temperature, top_k, seed],
        outputs=audio_output,
    )
if __name__ == "__main__":
    # Enable the request queue. By default, each event listener runs at most
    # one call at a time unless it sets its own concurrency limit.
    queued_demo = demo.queue(default_concurrency_limit=1)
    queued_demo.launch()