"""Gradio app for a Nano-TTS voice-cloning Hugging Face ZeroGPU Space.

Wires local model weights (``weights/tts`` + ``weights/codec``) into a
Blocks UI: demo-case selection, optional uploaded reference audio, and a
GPU-decorated inference entry point.  Generated audio is returned in
memory and the temporary output file is deleted afterwards.
"""

from __future__ import annotations

import argparse
import functools
import json
import logging
import os
import time
from dataclasses import dataclass
from pathlib import Path

import gradio as gr
import torch

try:
    import spaces
except ImportError:
    # Outside Hugging Face Spaces the `spaces` package is absent; this
    # fallback makes `@spaces.GPU(...)` a no-op decorator so the module
    # still imports and runs locally.
    class _SpacesFallback:
        @staticmethod
        def GPU(*_args, **_kwargs):
            def _decorator(func):
                return func
            return _decorator
    spaces = _SpacesFallback()

from nano_tts_runtime import DEFAULT_VOICE, NanoTTSService
from text_normalization_pipeline import WeTextProcessingManager, prepare_tts_request_texts

# Paths are resolved relative to this file so the Space works regardless
# of the process working directory.
APP_DIR = Path(__file__).resolve().parent
CHECKPOINT_PATH = APP_DIR / "weights" / "tts"
AUDIO_TOKENIZER_PATH = APP_DIR / "weights" / "codec"
OUTPUT_DIR = Path("/tmp") / "nano-tts-space"
PRELOAD_ENV_VAR = "NANO_TTS_PRELOAD_AT_STARTUP"
DEMO_METADATA_PATH = APP_DIR / "assets" / "demo.jsonl"
MODE_VOICE_CLONE = "voice_clone"


@dataclass(frozen=True)
class DemoEntry:
    """One built-in demo case loaded from ``assets/demo.jsonl``."""

    demo_id: str            # stable id of the form "demo-<n>"
    name: str               # dropdown label shown in the UI
    prompt_audio_path: Path # resolved absolute path to the reference audio
    text: str               # default target text for this case


def load_demo_entries() -> list[DemoEntry]:
    """Parse the demo JSONL metadata into validated :class:`DemoEntry` objects.

    Malformed lines, lines missing ``role``/``text``, and lines whose audio
    file does not exist are logged and skipped rather than raising, so one
    bad entry never breaks Space startup.  Returns an empty list when the
    metadata file is absent.
    """
    if not DEMO_METADATA_PATH.is_file():
        logging.warning("demo metadata file not found: %s", DEMO_METADATA_PATH)
        return []
    demo_entries: list[DemoEntry] = []
    for line_index, raw_line in enumerate(DEMO_METADATA_PATH.read_text(encoding="utf-8").splitlines(), start=1):
        line = raw_line.strip()
        if not line:
            continue
        try:
            payload = json.loads(line)
        except Exception:
            logging.warning("failed to parse demo metadata line=%s", line_index, exc_info=True)
            continue
        # NOTE: the JSONL field carrying the audio path is named "role".
        relative_audio_path = str(payload.get("role", "")).strip()
        text = str(payload.get("text", "")).strip()
        if not relative_audio_path or not text:
            logging.warning("skip invalid demo metadata line=%s role/text missing", line_index)
            continue
        prompt_audio_path = (APP_DIR / relative_audio_path).resolve()
        if not prompt_audio_path.is_file():
            logging.warning("skip demo metadata line=%s prompt audio missing: %s", line_index, prompt_audio_path)
            continue
        # Fall back to a generated label when the entry has no "name".
        name = str(payload.get("name", "")).strip() or f"Demo {len(demo_entries) + 1}: {prompt_audio_path.stem}"
        demo_entries.append(
            DemoEntry(
                demo_id=f"demo-{len(demo_entries) + 1}",
                name=name,
                prompt_audio_path=prompt_audio_path,
                text=text,
            )
        )
    return demo_entries


# Demo cases are loaded once at import time; the lookup maps below let the
# UI resolve a selection by id, by display name, or by audio path.
DEMO_ENTRIES = load_demo_entries()
DEMO_ENTRY_MAP = {entry.demo_id: entry for entry in DEMO_ENTRIES}
DEMO_AUDIO_PATH_MAP = {str(entry.prompt_audio_path): entry for entry in DEMO_ENTRIES}
DEMO_ENTRY_NAME_MAP = {entry.name: entry for entry in DEMO_ENTRIES}
DEFAULT_DEMO_ENTRY = DEMO_ENTRIES[0] if DEMO_ENTRIES else None
DEFAULT_DEMO_CASE_ID = DEFAULT_DEMO_ENTRY.demo_id if DEFAULT_DEMO_ENTRY is not None else ""
DEFAULT_DEMO_AUDIO_PATH = str(DEFAULT_DEMO_ENTRY.prompt_audio_path) if DEFAULT_DEMO_ENTRY is not None else ""
DEFAULT_DEMO_TEXT = DEFAULT_DEMO_ENTRY.text if DEFAULT_DEMO_ENTRY is not None else ""
DEMO_CASE_CHOICES = [(entry.name, entry.demo_id) for entry in DEMO_ENTRIES]


def parse_bool_env(name: str, default: bool) -> bool:
    """Read env var *name* as a boolean; unset returns *default*."""
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in {"1", "true", "yes", "y", "on"}


def parse_port(value: str | None, default: int) -> int:
    """Parse *value* as an int port, returning *default* on empty/invalid input."""
    if not value:
        return default
    try:
        return int(value)
    except ValueError:
        return default


def maybe_delete_file(path: str | Path | None) -> None:
    """Best-effort deletion of a temporary file; never raises.

    A missing file is fine (``missing_ok=True``); other OS errors are
    logged and swallowed since cleanup failure must not fail a request.
    """
    if not path:
        return
    try:
        Path(path).unlink(missing_ok=True)
    except OSError:
        logging.warning("failed to delete temporary file: %s", path, exc_info=True)


def normalize_demo_case_id(demo_case_id: str | None) -> str:
    """Map a dropdown value to a canonical demo id, or "" if unknown.

    Accepts either a demo id or a display name — the Dropdown uses
    ``allow_custom_value=True``, so the incoming value may be either.
    """
    normalized = str(demo_case_id or "").strip()
    if not normalized:
        return ""
    if normalized in DEMO_ENTRY_MAP:
        return normalized
    matched_entry = DEMO_ENTRY_NAME_MAP.get(normalized)
    if matched_entry is not None:
        return matched_entry.demo_id
    return ""


@functools.lru_cache(maxsize=2)
def get_tts_service(runtime_has_cuda: bool) -> NanoTTSService:
    """Build (and cache) the TTS service singleton.

    ``runtime_has_cuda`` is unused in the body — it exists only as an
    lru_cache key so CPU and CUDA runtimes get separate service instances
    (ZeroGPU may call with and without a GPU attached).
    """
    return NanoTTSService(
        checkpoint_path=CHECKPOINT_PATH,
        audio_tokenizer_path=AUDIO_TOKENIZER_PATH,
        device="auto",
        dtype="auto",
        attn_implementation="auto",
        output_dir=OUTPUT_DIR,
    )


def get_runtime_tts_service() -> NanoTTSService:
    """Return the cached service matching the current CUDA availability."""
    return get_tts_service(bool(torch.cuda.is_available()))


@functools.lru_cache(maxsize=1)
def get_text_normalizer_manager() -> WeTextProcessingManager:
    """Create and start the WeTextProcessing manager once per process."""
    manager = WeTextProcessingManager()
    manager.start()
    return manager


def preload_service() -> None:
    """Eagerly instantiate the service and load model weights, with timing logs."""
    started_at = time.monotonic()
    service = get_runtime_tts_service()
    logging.info(
        "preloading Nano-TTS model checkpoint=%s codec=%s device=%s",
        CHECKPOINT_PATH,
        AUDIO_TOKENIZER_PATH,
        service.device,
    )
    service.get_model()
    logging.info("Nano-TTS preload finished in %.2fs", time.monotonic() - started_at)


def render_mode_hint() -> str:
    """Return the static Markdown hint shown under the text box."""
    return (
        "Current mode: **Voice Clone** \n"
        "Select a Default Case or upload your own reference audio. Uploaded audio overrides the selected Default Case."
    )


def resolve_default_demo_entry() -> DemoEntry | None:
    """Return the first demo entry, or ``None`` when no demos loaded."""
    return DEFAULT_DEMO_ENTRY


def resolve_selected_demo_entry(demo_case_id: str | None) -> DemoEntry | None:
    """Resolve the dropdown selection to an entry, falling back to the default."""
    normalized_demo_case_id = normalize_demo_case_id(demo_case_id)
    if normalized_demo_case_id:
        demo_entry = DEMO_ENTRY_MAP.get(normalized_demo_case_id)
        if demo_entry is not None:
            return demo_entry
    return resolve_default_demo_entry()


def resolve_effective_prompt_audio_path(
    prompt_audio_path: str | None,
    selected_demo_audio_path: str | None,
) -> str | None:
    """Pick the reference audio actually used for cloning.

    Priority: (1) an existing uploaded file, (2) the selected demo case's
    file, (3) the default demo entry, (4) ``None``.  Paths are
    expanduser()+resolve()d and verified to exist before being accepted.
    """
    if prompt_audio_path:
        resolved_path = Path(prompt_audio_path).expanduser().resolve()
        if resolved_path.is_file():
            return str(resolved_path)
    if selected_demo_audio_path:
        resolved_path = Path(selected_demo_audio_path).expanduser().resolve()
        if resolved_path.is_file():
            return str(resolved_path)
    demo_entry = resolve_default_demo_entry()
    if demo_entry is not None:
        return str(demo_entry.prompt_audio_path)
    return None


def build_prompt_source_text(
    *,
    prompt_audio_path: str | None,
    selected_demo_audio_path: str | None,
) -> str:
    """Describe, for display, where the effective reference audio came from."""
    effective_prompt_audio_path = resolve_effective_prompt_audio_path(
        prompt_audio_path,
        selected_demo_audio_path,
    )
    if effective_prompt_audio_path:
        # An upload (if present and valid) always wins over demo cases.
        if prompt_audio_path:
            return f"Uploaded reference audio: {Path(effective_prompt_audio_path).name}"
        demo_entry = DEMO_AUDIO_PATH_MAP.get(effective_prompt_audio_path)
        if demo_entry is not None:
            return f"Default case: {demo_entry.name}"
        return f"Default case: {Path(effective_prompt_audio_path).name}"
    return "No default case available"


def refresh_prompt_preview(
    prompt_audio_path: str | None,
    selected_demo_audio_path: str | None,
):
    """Event handler: recompute the preview player path and its source label."""
    preview_path = resolve_effective_prompt_audio_path(
        prompt_audio_path,
        selected_demo_audio_path,
    )
    prompt_source = build_prompt_source_text(
        prompt_audio_path=prompt_audio_path,
        selected_demo_audio_path=selected_demo_audio_path,
    )
    return preview_path, prompt_source


def apply_demo_case_selection(
    demo_case_id: str,
    prompt_audio_path: str | None,
):
    """Event handler for the Default Case dropdown.

    Returns ``(text, preview_path, selected_demo_audio_path, prompt_source)``
    matching the ``demo_case.change`` outputs.  When no entry resolves, the
    text box is left untouched via ``gr.update()`` and the stored demo path
    is cleared.
    """
    demo_entry = resolve_selected_demo_entry(demo_case_id)
    if demo_entry is None:
        preview_path, prompt_source = refresh_prompt_preview(prompt_audio_path, "")
        return (
            gr.update(),
            preview_path,
            "",
            prompt_source,
        )
    selected_prompt_path = str(demo_entry.prompt_audio_path)
    preview_path, prompt_source = refresh_prompt_preview(
        prompt_audio_path,
        selected_prompt_path,
    )
    return (
        demo_entry.text,
        preview_path,
        selected_prompt_path,
        prompt_source,
    )


def validate_request(
    *,
    text: str,
    effective_prompt_audio_path: str | None,
) -> str:
    """Validate a synthesis request; return the stripped text.

    Raises:
        ValueError: if the text is empty or no reference audio is available.
    """
    normalized_text = str(text or "").strip()
    if not normalized_text:
        raise ValueError("Please enter text to synthesize.")
    if not effective_prompt_audio_path:
        raise ValueError("No reference audio is available. Please select a Default Case or upload prompt audio.")
    return normalized_text


def build_status_text(
    *,
    result: dict[str, object],
    prepared_texts: dict[str, object],
    reference_source: str,
    runtime_device: str,
) -> str:
    """Format the one-line status summary shown after synthesis."""
    text_chunks = result.get("voice_clone_text_chunks") or []
    # Treat a missing/empty chunk list as a single chunk for display.
    chunk_count = len(text_chunks) if isinstance(text_chunks, list) and text_chunks else 1
    return (
        f"Done | mode={result['mode']} | ref={reference_source} | elapsed={result['elapsed_seconds']:.2f}s | "
        f"device={runtime_device} | sample_rate={result['sample_rate']} | "
        f"attn={result['effective_global_attn_implementation']} | "
        f"chunks={chunk_count} | normalization={prepared_texts['normalization_method']}"
    )


def estimate_gpu_duration(
    *args,
    **kwargs,
) -> int:
    """Estimate the ZeroGPU reservation (seconds) for a run_inference call.

    Receives the same arguments as :func:`run_inference`; the positional
    indices (0=text, 5=max_new_frames, 6=voice_clone_max_text_tokens) must
    stay in sync with that signature.  The heuristic scales with text
    length and generation limits, clamped to [90, 240] seconds.
    """
    text = kwargs.get("text", args[0] if len(args) > 0 else "")
    max_new_frames = kwargs.get("max_new_frames", args[5] if len(args) > 5 else 375)
    voice_clone_max_text_tokens = (
        kwargs.get("voice_clone_max_text_tokens", args[6] if len(args) > 6 else 75)
    )
    text_len = len(str(text or "").strip())
    estimated = 75 + (text_len // 12) + int(max_new_frames) // 8 + int(voice_clone_max_text_tokens) // 10
    return max(90, min(240, estimated))


@spaces.GPU(size="large", duration=estimate_gpu_duration)
def run_inference(
    text: str,
    prompt_audio_path: str | None,
    selected_demo_audio_path: str | None,
    enable_wetext_processing: bool,
    enable_normalize_tts_text: bool,
    max_new_frames: int,
    voice_clone_max_text_tokens: int,
    do_sample: bool,
    text_temperature: float,
    text_top_p: float,
    text_top_k: int,
    audio_temperature: float,
    audio_top_p: float,
    audio_top_k: int,
    audio_repetition_penalty: float,
    seed: float | int,
):
    """GPU-decorated synthesis entry point bound to the Generate button.

    Returns ``((sample_rate, waveform), status, normalized_text, prompt_source)``
    matching the ``run_btn.click`` outputs.  Any failure is logged and
    re-raised as :class:`gr.Error` so the UI shows the message; the
    temporary output file is always deleted in ``finally`` (the waveform
    is returned in memory, so the file is no longer needed).
    """
    generated_audio_path: str | None = None
    try:
        service = get_runtime_tts_service()
        text_normalizer_manager = get_text_normalizer_manager() if enable_wetext_processing else None
        effective_prompt_audio_path = resolve_effective_prompt_audio_path(
            prompt_audio_path,
            selected_demo_audio_path,
        )
        normalized_text = validate_request(
            text=text,
            effective_prompt_audio_path=effective_prompt_audio_path,
        )
        prepared_texts = prepare_tts_request_texts(
            text=normalized_text,
            prompt_text="",
            voice=DEFAULT_VOICE,
            enable_wetext=bool(enable_wetext_processing),
            enable_normalize_tts_text=bool(enable_normalize_tts_text),
            text_normalizer_manager=text_normalizer_manager,
        )
        prompt_source = build_prompt_source_text(
            prompt_audio_path=prompt_audio_path,
            selected_demo_audio_path=selected_demo_audio_path,
        )
        # Seed semantics: 0 (the UI default), "", and None all mean
        # "random" — only a non-zero value is forwarded as a fixed seed.
        normalized_seed = None
        if seed not in {"", None}:
            resolved_seed = int(seed)
            if resolved_seed != 0:
                normalized_seed = resolved_seed
        result = service.synthesize(
            text=str(prepared_texts["text"]),
            mode=MODE_VOICE_CLONE,
            voice=DEFAULT_VOICE,
            prompt_audio_path=effective_prompt_audio_path or None,
            max_new_frames=int(max_new_frames),
            voice_clone_max_text_tokens=int(voice_clone_max_text_tokens),
            do_sample=bool(do_sample),
            text_temperature=float(text_temperature),
            text_top_p=float(text_top_p),
            text_top_k=int(text_top_k),
            audio_temperature=float(audio_temperature),
            audio_top_p=float(audio_top_p),
            audio_top_k=int(audio_top_k),
            audio_repetition_penalty=float(audio_repetition_penalty),
            seed=normalized_seed,
        )
        generated_audio_path = str(result["audio_path"])
        return (
            (int(result["sample_rate"]), result["waveform_numpy"]),
            build_status_text(
                result=result,
                prepared_texts=prepared_texts,
                reference_source=prompt_source,
                runtime_device=str(service.device),
            ),
            str(prepared_texts["normalized_text"]),
            prompt_source,
        )
    except Exception as exc:
        logging.exception("Nano-TTS inference failed")
        raise gr.Error(str(exc)) from exc
    finally:
        maybe_delete_file(generated_audio_path)


def build_demo():
    """Construct the Gradio Blocks UI and wire its event handlers."""
    with gr.Blocks(title="Nano-TTS ZeroGPU Space") as demo:
        gr.Markdown(
            """
Nano-TTS ZeroGPU
Hugging Face Space edition backed by local weights/tts and weights/codec. ZeroGPU requests a GPU only during inference, and audio is returned after full synthesis.

MOSS-TTS-Nano is a zero-shot TTS model with approximately 100M parameters, supporting 48 kHz stereo input and output, streaming generation, multilingual synthesis, and long-form text. It is developed by the OpenMOSS Team. For more details, see the GitHub repository and blog.

"""
        )
        with gr.Row(equal_height=False):
            with gr.Column(scale=3):
                demo_case = gr.Dropdown(
                    choices=DEMO_CASE_CHOICES,
                    value=DEFAULT_DEMO_CASE_ID,
                    label="Default Case",
                    info="Select a built-in case to auto-fill the text and prompt preview.",
                    allow_custom_value=True,
                )
                text = gr.Textbox(
                    label="Target Text",
                    lines=10,
                    value=DEFAULT_DEMO_TEXT,
                    placeholder="Enter the text to synthesize.",
                )
                # NOTE(review): mode_hint is never referenced after creation;
                # the assignment only keeps a handle to the rendered Markdown.
                mode_hint = gr.Markdown(render_mode_hint())
                prompt_audio = gr.Audio(
                    label="Reference Audio Upload (optional; overrides Default Case)",
                    type="filepath",
                    sources=["upload"],
                )
                prompt_preview = gr.Audio(
                    label="Effective Prompt Preview",
                    value=DEFAULT_DEMO_AUDIO_PATH or None,
                    type="filepath",
                    interactive=False,
                )
                gr.Markdown(
                    "Runtime device and backbone are fixed by the Space and are not user-configurable. Uploaded reference audio overrides the selected Default Case."
                )
                with gr.Accordion("Advanced Parameters", open=False):
                    enable_wetext_processing = gr.Checkbox(
                        value=True,
                        label="Enable WeTextProcessing",
                    )
                    enable_normalize_tts_text = gr.Checkbox(
                        value=True,
                        label="Enable normalize_tts_text",
                    )
                    max_new_frames = gr.Slider(
                        minimum=64,
                        maximum=512,
                        step=16,
                        value=375,
                        label="max_new_frames",
                    )
                    voice_clone_max_text_tokens = gr.Slider(
                        minimum=25,
                        maximum=200,
                        step=5,
                        value=75,
                        label="voice_clone_max_text_tokens",
                    )
                    do_sample = gr.Checkbox(
                        value=True,
                        label="Enable Sampling",
                    )
                    seed = gr.Number(
                        value=0,
                        precision=0,
                        label="Seed (0 = random)",
                    )
                    text_temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        step=0.05,
                        value=1.0,
                        label="text_temperature",
                    )
                    text_top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        step=0.01,
                        value=1.0,
                        label="text_top_p",
                    )
                    text_top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=50,
                        label="text_top_k",
                    )
                    audio_temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        step=0.05,
                        value=0.8,
                        label="audio_temperature",
                    )
                    audio_top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        step=0.01,
                        value=0.95,
                        label="audio_top_p",
                    )
                    audio_top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=25,
                        label="audio_top_k",
                    )
                    audio_repetition_penalty = gr.Slider(
                        minimum=0.8,
                        maximum=2.0,
                        step=0.05,
                        value=1.2,
                        label="audio_repetition_penalty",
                    )
                run_btn = gr.Button("Generate Speech", variant="primary", elem_id="run-btn")
            with gr.Column(scale=2):
                output_audio = gr.Audio(label="Output Audio", type="numpy")
                status = gr.Textbox(label="Status", lines=4, interactive=False)
                normalized_text = gr.Textbox(label="Normalized Text", lines=6, interactive=False)
                prompt_source = gr.Textbox(
                    label="Prompt Source",
                    value=build_prompt_source_text(
                        prompt_audio_path=None,
                        selected_demo_audio_path=DEFAULT_DEMO_AUDIO_PATH or None,
                    ),
                    lines=4,
                    interactive=False,
                )
        # Hidden per-session state: the currently selected demo audio path.
        selected_demo_audio_path = gr.State(DEFAULT_DEMO_AUDIO_PATH)
        demo_case.change(
            fn=apply_demo_case_selection,
            inputs=[demo_case, prompt_audio],
            outputs=[text, prompt_preview, selected_demo_audio_path, prompt_source],
        )
        prompt_audio.change(
            fn=refresh_prompt_preview,
            inputs=[prompt_audio, selected_demo_audio_path],
            outputs=[prompt_preview, prompt_source],
        )
        # Input order must match run_inference's parameter order (and the
        # positional indices assumed by estimate_gpu_duration).
        run_btn.click(
            fn=run_inference,
            inputs=[
                text,
                prompt_audio,
                selected_demo_audio_path,
                enable_wetext_processing,
                enable_normalize_tts_text,
                max_new_frames,
                voice_clone_max_text_tokens,
                do_sample,
                text_temperature,
                text_top_p,
                text_top_k,
                audio_temperature,
                audio_top_p,
                audio_top_k,
                audio_repetition_penalty,
                seed,
            ],
            outputs=[output_audio, status, normalized_text, prompt_source],
        )
    return demo


def main() -> None:
    """CLI entry point: parse args/env, optionally preload, and launch the app."""
    parser = argparse.ArgumentParser(description="Nano-TTS ZeroGPU Hugging Face Space")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument(
        "--port",
        type=int,
        # NOTE(review): int() here raises if GRADIO_SERVER_PORT/PORT is set
        # to a non-numeric value; the parse_port() call below is the lenient
        # path and overrides this default anyway.
        default=int(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT", "7860"))),
    )
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()
    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
        level=logging.INFO,
    )
    # Environment variables take precedence over CLI flags for host/port.
    args.host = os.getenv("GRADIO_SERVER_NAME", args.host)
    args.port = parse_port(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT")), args.port)
    # Start the text-normalizer worker up front so the first request is fast.
    get_text_normalizer_manager()
    # Default: preload locally, skip on Spaces (SPACE_ID set) where ZeroGPU
    # attaches the GPU only during inference.
    preload_enabled = parse_bool_env(PRELOAD_ENV_VAR, default=not bool(os.getenv("SPACE_ID")))
    if preload_enabled:
        preload_service()
    else:
        logging.info("Skipping model preload (set %s=1 to enable).", PRELOAD_ENV_VAR)
    demo = build_demo()
    demo.queue(max_size=4, default_concurrency_limit=4).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        ssr_mode=False,
    )


if __name__ == "__main__":
    main()