from __future__ import annotations import argparse import functools import json import logging import os import time from dataclasses import dataclass from pathlib import Path import gradio as gr import torch try: import spaces except ImportError: class _SpacesFallback: @staticmethod def GPU(*_args, **_kwargs): def _decorator(func): return func return _decorator spaces = _SpacesFallback() from nano_tts_runtime import DEFAULT_VOICE, NanoTTSService from text_normalization_pipeline import WeTextProcessingManager, prepare_tts_request_texts APP_DIR = Path(__file__).resolve().parent CHECKPOINT_PATH = APP_DIR / "weights" / "tts" AUDIO_TOKENIZER_PATH = APP_DIR / "weights" / "codec" OUTPUT_DIR = Path("/tmp") / "nano-tts-space" PRELOAD_ENV_VAR = "NANO_TTS_PRELOAD_AT_STARTUP" DEMO_METADATA_PATH = APP_DIR / "assets" / "demo.jsonl" MODE_VOICE_CLONE = "voice_clone" @dataclass(frozen=True) class DemoEntry: demo_id: str name: str prompt_audio_path: Path text: str def load_demo_entries() -> list[DemoEntry]: if not DEMO_METADATA_PATH.is_file(): logging.warning("demo metadata file not found: %s", DEMO_METADATA_PATH) return [] demo_entries: list[DemoEntry] = [] for line_index, raw_line in enumerate(DEMO_METADATA_PATH.read_text(encoding="utf-8").splitlines(), start=1): line = raw_line.strip() if not line: continue try: payload = json.loads(line) except Exception: logging.warning("failed to parse demo metadata line=%s", line_index, exc_info=True) continue relative_audio_path = str(payload.get("role", "")).strip() text = str(payload.get("text", "")).strip() if not relative_audio_path or not text: logging.warning("skip invalid demo metadata line=%s role/text missing", line_index) continue prompt_audio_path = (APP_DIR / relative_audio_path).resolve() if not prompt_audio_path.is_file(): logging.warning("skip demo metadata line=%s prompt audio missing: %s", line_index, prompt_audio_path) continue name = str(payload.get("name", "")).strip() or f"Demo {len(demo_entries) + 1}: {prompt_audio_path.stem}" demo_entries.append( DemoEntry( demo_id=f"demo-{len(demo_entries) + 1}", name=name, prompt_audio_path=prompt_audio_path, text=text, ) ) return demo_entries DEMO_ENTRIES = load_demo_entries() DEMO_ENTRY_MAP = {entry.demo_id: entry for entry in DEMO_ENTRIES} DEMO_AUDIO_PATH_MAP = {str(entry.prompt_audio_path): entry for entry in DEMO_ENTRIES} DEMO_ENTRY_NAME_MAP = {entry.name: entry for entry in DEMO_ENTRIES} DEFAULT_DEMO_ENTRY = DEMO_ENTRIES[0] if DEMO_ENTRIES else None DEFAULT_DEMO_CASE_ID = DEFAULT_DEMO_ENTRY.demo_id if DEFAULT_DEMO_ENTRY is not None else "" DEFAULT_DEMO_AUDIO_PATH = str(DEFAULT_DEMO_ENTRY.prompt_audio_path) if DEFAULT_DEMO_ENTRY is not None else "" DEFAULT_DEMO_TEXT = DEFAULT_DEMO_ENTRY.text if DEFAULT_DEMO_ENTRY is not None else "" DEMO_CASE_CHOICES = [(entry.name, entry.demo_id) for entry in DEMO_ENTRIES] def parse_bool_env(name: str, default: bool) -> bool: value = os.getenv(name) if value is None: return default return value.strip().lower() in {"1", "true", "yes", "y", "on"} def parse_port(value: str | None, default: int) -> int: if not value: return default try: return int(value) except ValueError: return default def maybe_delete_file(path: str | Path | None) -> None: if not path: return try: Path(path).unlink(missing_ok=True) except OSError: logging.warning("failed to delete temporary file: %s", path, exc_info=True) def normalize_demo_case_id(demo_case_id: str | None) -> str: normalized = str(demo_case_id or "").strip() if not normalized: return "" if normalized in DEMO_ENTRY_MAP: return normalized matched_entry = DEMO_ENTRY_NAME_MAP.get(normalized) if matched_entry is not None: return matched_entry.demo_id return "" @functools.lru_cache(maxsize=2) def get_tts_service(runtime_has_cuda: bool) -> NanoTTSService: return NanoTTSService( checkpoint_path=CHECKPOINT_PATH, audio_tokenizer_path=AUDIO_TOKENIZER_PATH, device="auto", dtype="auto", attn_implementation="auto", output_dir=OUTPUT_DIR, ) def get_runtime_tts_service() -> NanoTTSService: return get_tts_service(bool(torch.cuda.is_available())) @functools.lru_cache(maxsize=1) def get_text_normalizer_manager() -> WeTextProcessingManager: manager = WeTextProcessingManager() manager.start() return manager def preload_service() -> None: started_at = time.monotonic() service = get_runtime_tts_service() logging.info( "preloading Nano-TTS model checkpoint=%s codec=%s device=%s", CHECKPOINT_PATH, AUDIO_TOKENIZER_PATH, service.device, ) service.get_model() logging.info("Nano-TTS preload finished in %.2fs", time.monotonic() - started_at) def render_mode_hint() -> str: return ( "Current mode: **Voice Clone** \n" "Select a Default Case or upload your own reference audio. Uploaded audio overrides the selected Default Case." ) def resolve_default_demo_entry() -> DemoEntry | None: return DEFAULT_DEMO_ENTRY def resolve_selected_demo_entry(demo_case_id: str | None) -> DemoEntry | None: normalized_demo_case_id = normalize_demo_case_id(demo_case_id) if normalized_demo_case_id: demo_entry = DEMO_ENTRY_MAP.get(normalized_demo_case_id) if demo_entry is not None: return demo_entry return resolve_default_demo_entry() def resolve_effective_prompt_audio_path( prompt_audio_path: str | None, selected_demo_audio_path: str | None, ) -> str | None: if prompt_audio_path: resolved_path = Path(prompt_audio_path).expanduser().resolve() if resolved_path.is_file(): return str(resolved_path) if selected_demo_audio_path: resolved_path = Path(selected_demo_audio_path).expanduser().resolve() if resolved_path.is_file(): return str(resolved_path) demo_entry = resolve_default_demo_entry() if demo_entry is not None: return str(demo_entry.prompt_audio_path) return None def build_prompt_source_text( *, prompt_audio_path: str | None, selected_demo_audio_path: str | None, ) -> str: effective_prompt_audio_path = resolve_effective_prompt_audio_path( prompt_audio_path, selected_demo_audio_path, ) if effective_prompt_audio_path: if prompt_audio_path: return f"Uploaded reference audio: {Path(effective_prompt_audio_path).name}" demo_entry = DEMO_AUDIO_PATH_MAP.get(effective_prompt_audio_path) if demo_entry is not None: return f"Default case: {demo_entry.name}" return f"Default case: {Path(effective_prompt_audio_path).name}" return "No default case available" def refresh_prompt_preview( prompt_audio_path: str | None, selected_demo_audio_path: str | None, ): preview_path = resolve_effective_prompt_audio_path( prompt_audio_path, selected_demo_audio_path, ) prompt_source = build_prompt_source_text( prompt_audio_path=prompt_audio_path, selected_demo_audio_path=selected_demo_audio_path, ) return preview_path, prompt_source def apply_demo_case_selection( demo_case_id: str, prompt_audio_path: str | None, ): demo_entry = resolve_selected_demo_entry(demo_case_id) if demo_entry is None: preview_path, prompt_source = refresh_prompt_preview(prompt_audio_path, "") return ( gr.update(), preview_path, "", prompt_source, ) selected_prompt_path = str(demo_entry.prompt_audio_path) preview_path, prompt_source = refresh_prompt_preview( prompt_audio_path, selected_prompt_path, ) return ( demo_entry.text, preview_path, selected_prompt_path, prompt_source, ) def validate_request( *, text: str, effective_prompt_audio_path: str | None, ) -> str: normalized_text = str(text or "").strip() if not normalized_text: raise ValueError("Please enter text to synthesize.") if not effective_prompt_audio_path: raise ValueError("No reference audio is available. Please select a Default Case or upload prompt audio.") return normalized_text def build_status_text( *, result: dict[str, object], prepared_texts: dict[str, object], reference_source: str, runtime_device: str, ) -> str: text_chunks = result.get("voice_clone_text_chunks") or [] chunk_count = len(text_chunks) if isinstance(text_chunks, list) and text_chunks else 1 return ( f"Done | mode={result['mode']} | ref={reference_source} | elapsed={result['elapsed_seconds']:.2f}s | " f"device={runtime_device} | sample_rate={result['sample_rate']} | " f"attn={result['effective_global_attn_implementation']} | " f"chunks={chunk_count} | normalization={prepared_texts['normalization_method']}" ) def estimate_gpu_duration( *args, **kwargs, ) -> int: text = kwargs.get("text", args[0] if len(args) > 0 else "") max_new_frames = kwargs.get("max_new_frames", args[5] if len(args) > 5 else 375) voice_clone_max_text_tokens = ( kwargs.get("voice_clone_max_text_tokens", args[6] if len(args) > 6 else 75) ) text_len = len(str(text or "").strip()) estimated = 75 + (text_len // 12) + int(max_new_frames) // 8 + int(voice_clone_max_text_tokens) // 10 return max(90, min(240, estimated)) @spaces.GPU(size="large", duration=estimate_gpu_duration) def run_inference( text: str, prompt_audio_path: str | None, selected_demo_audio_path: str | None, enable_wetext_processing: bool, enable_normalize_tts_text: bool, max_new_frames: int, voice_clone_max_text_tokens: int, do_sample: bool, text_temperature: float, text_top_p: float, text_top_k: int, audio_temperature: float, audio_top_p: float, audio_top_k: int, audio_repetition_penalty: float, seed: float | int, ): generated_audio_path: str | None = None try: service = get_runtime_tts_service() text_normalizer_manager = get_text_normalizer_manager() if enable_wetext_processing else None effective_prompt_audio_path = resolve_effective_prompt_audio_path( prompt_audio_path, selected_demo_audio_path, ) normalized_text = validate_request( text=text, effective_prompt_audio_path=effective_prompt_audio_path, ) prepared_texts = prepare_tts_request_texts( text=normalized_text, prompt_text="", voice=DEFAULT_VOICE, enable_wetext=bool(enable_wetext_processing), enable_normalize_tts_text=bool(enable_normalize_tts_text), text_normalizer_manager=text_normalizer_manager, ) prompt_source = build_prompt_source_text( prompt_audio_path=prompt_audio_path, selected_demo_audio_path=selected_demo_audio_path, ) normalized_seed = None if seed not in {"", None}: resolved_seed = int(seed) if resolved_seed != 0: normalized_seed = resolved_seed result = service.synthesize( text=str(prepared_texts["text"]), mode=MODE_VOICE_CLONE, voice=DEFAULT_VOICE, prompt_audio_path=effective_prompt_audio_path or None, max_new_frames=int(max_new_frames), voice_clone_max_text_tokens=int(voice_clone_max_text_tokens), do_sample=bool(do_sample), text_temperature=float(text_temperature), text_top_p=float(text_top_p), text_top_k=int(text_top_k), audio_temperature=float(audio_temperature), audio_top_p=float(audio_top_p), audio_top_k=int(audio_top_k), audio_repetition_penalty=float(audio_repetition_penalty), seed=normalized_seed, ) generated_audio_path = str(result["audio_path"]) return ( (int(result["sample_rate"]), result["waveform_numpy"]), build_status_text( result=result, prepared_texts=prepared_texts, reference_source=prompt_source, runtime_device=str(service.device), ), str(prepared_texts["normalized_text"]), prompt_source, ) except Exception as exc: logging.exception("Nano-TTS inference failed") raise gr.Error(str(exc)) from exc finally: maybe_delete_file(generated_audio_path) def build_demo(): with gr.Blocks(title="Nano-TTS ZeroGPU Space") as demo: gr.Markdown( """
weights/tts and weights/codec.
ZeroGPU requests a GPU only during inference, and audio is returned after full synthesis.
MOSS-TTS-Nano is a zero-shot TTS model with approximately 100M parameters, supporting 48 kHz stereo input and output, streaming generation, multilingual synthesis, and long-form text. It is developed by the OpenMOSS Team. For more details, see the GitHub repository and blog.