# Nano-TTS ZeroGPU Hugging Face Space — Gradio app entry point.
| from __future__ import annotations | |
| import argparse | |
| import functools | |
| import json | |
| import logging | |
| import os | |
| import time | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| import gradio as gr | |
| import torch | |
try:
    import spaces
except ImportError:
    # Outside Hugging Face Spaces the `spaces` package is unavailable; this
    # stand-in makes `spaces.GPU(...)` a no-op decorator factory so the same
    # code runs locally.
    class _SpacesFallback:
        def GPU(*_args, **_kwargs):
            def _passthrough(func):
                return func
            return _passthrough
    spaces = _SpacesFallback()
| from nano_tts_runtime import DEFAULT_VOICE, NanoTTSService | |
| from text_normalization_pipeline import WeTextProcessingManager, prepare_tts_request_texts | |
# Filesystem layout: model weights live alongside this app file.
APP_DIR = Path(__file__).resolve().parent
CHECKPOINT_PATH = APP_DIR / "weights" / "tts"
AUDIO_TOKENIZER_PATH = APP_DIR / "weights" / "codec"
# Generated audio goes to /tmp so nothing accumulates in the Space's workdir.
OUTPUT_DIR = Path("/tmp") / "nano-tts-space"
# Env var that toggles eager model loading at startup (consumed in main()).
PRELOAD_ENV_VAR = "NANO_TTS_PRELOAD_AT_STARTUP"
# JSONL metadata describing the built-in demo cases (one JSON object per line).
DEMO_METADATA_PATH = APP_DIR / "assets" / "demo.jsonl"
# This Space exposes a single synthesis mode.
MODE_VOICE_CLONE = "voice_clone"
@dataclass
class DemoEntry:
    """One built-in demo case parsed from assets/demo.jsonl.

    The @dataclass decorator is required: instances are constructed with
    keyword arguments in load_demo_entries(), which fails on a plain class
    that only carries annotations.
    """

    demo_id: str            # canonical id, e.g. "demo-1"
    name: str               # display label for the dropdown
    prompt_audio_path: Path  # absolute path to the reference audio file
    text: str               # default target text for this case
def load_demo_entries() -> list[DemoEntry]:
    """Parse assets/demo.jsonl into DemoEntry records, skipping bad lines.

    A line is skipped (with a warning) when it is not valid JSON, when the
    "role" (relative audio path) or "text" field is missing/blank, or when
    the referenced audio file does not exist.
    """
    if not DEMO_METADATA_PATH.is_file():
        logging.warning("demo metadata file not found: %s", DEMO_METADATA_PATH)
        return []
    entries: list[DemoEntry] = []
    raw_lines = DEMO_METADATA_PATH.read_text(encoding="utf-8").splitlines()
    for line_no, raw in enumerate(raw_lines, start=1):
        stripped = raw.strip()
        if not stripped:
            continue
        try:
            payload = json.loads(stripped)
        except Exception:
            logging.warning("failed to parse demo metadata line=%s", line_no, exc_info=True)
            continue
        # NOTE(review): the audio path is stored under the "role" key in the
        # metadata format — presumably a historical field name.
        rel_audio = str(payload.get("role", "")).strip()
        entry_text = str(payload.get("text", "")).strip()
        if not rel_audio or not entry_text:
            logging.warning("skip invalid demo metadata line=%s role/text missing", line_no)
            continue
        audio_path = (APP_DIR / rel_audio).resolve()
        if not audio_path.is_file():
            logging.warning("skip demo metadata line=%s prompt audio missing: %s", line_no, audio_path)
            continue
        display_name = str(payload.get("name", "")).strip() or f"Demo {len(entries) + 1}: {audio_path.stem}"
        entries.append(
            DemoEntry(
                demo_id=f"demo-{len(entries) + 1}",
                name=display_name,
                prompt_audio_path=audio_path,
                text=entry_text,
            )
        )
    return entries
# Demo registry, built once at import time.
DEMO_ENTRIES = load_demo_entries()
# Lookup tables: by canonical id, by absolute audio path, and by display name.
DEMO_ENTRY_MAP = {entry.demo_id: entry for entry in DEMO_ENTRIES}
DEMO_AUDIO_PATH_MAP = {str(entry.prompt_audio_path): entry for entry in DEMO_ENTRIES}
DEMO_ENTRY_NAME_MAP = {entry.name: entry for entry in DEMO_ENTRIES}
# Defaults used to pre-populate the UI; empty strings when no demos are available.
DEFAULT_DEMO_ENTRY = DEMO_ENTRIES[0] if DEMO_ENTRIES else None
DEFAULT_DEMO_CASE_ID = DEFAULT_DEMO_ENTRY.demo_id if DEFAULT_DEMO_ENTRY is not None else ""
DEFAULT_DEMO_AUDIO_PATH = str(DEFAULT_DEMO_ENTRY.prompt_audio_path) if DEFAULT_DEMO_ENTRY is not None else ""
DEFAULT_DEMO_TEXT = DEFAULT_DEMO_ENTRY.text if DEFAULT_DEMO_ENTRY is not None else ""
# (label, value) pairs for the gr.Dropdown choices.
DEMO_CASE_CHOICES = [(entry.name, entry.demo_id) for entry in DEMO_ENTRIES]
def parse_bool_env(name: str, default: bool) -> bool:
    """Interpret environment variable `name` as a boolean flag.

    Returns `default` when the variable is unset; otherwise True only for
    the usual truthy spellings (1/true/yes/y/on, case-insensitive).
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    truthy_values = {"1", "true", "yes", "y", "on"}
    return raw.strip().lower() in truthy_values
| def parse_port(value: str | None, default: int) -> int: | |
| if not value: | |
| return default | |
| try: | |
| return int(value) | |
| except ValueError: | |
| return default | |
def maybe_delete_file(path: str | Path | None) -> None:
    """Best-effort removal of a temporary file.

    Empty/None paths and already-missing files are ignored; other OS errors
    are logged but never raised.
    """
    if path:
        try:
            Path(path).unlink(missing_ok=True)
        except OSError:
            logging.warning("failed to delete temporary file: %s", path, exc_info=True)
def normalize_demo_case_id(demo_case_id: str | None) -> str:
    """Map a dropdown value (canonical id or display name) to a demo id.

    Returns "" for blank input or when the value matches neither an id nor
    a display name.
    """
    candidate = str(demo_case_id or "").strip()
    if not candidate:
        return ""
    if candidate in DEMO_ENTRY_MAP:
        return candidate
    by_name = DEMO_ENTRY_NAME_MAP.get(candidate)
    return by_name.demo_id if by_name is not None else ""
@functools.cache
def get_tts_service(runtime_has_cuda: bool) -> NanoTTSService:
    """Build and memoize the Nano-TTS service.

    Memoization matters: preload_service() warms the model via
    service.get_model() at startup, and run_inference() must receive that
    same instance back rather than constructing a fresh (cold) service on
    every request. `runtime_has_cuda` keys the cache so a CPU-constructed
    service is not reused once CUDA availability changes.
    """
    return NanoTTSService(
        checkpoint_path=CHECKPOINT_PATH,
        audio_tokenizer_path=AUDIO_TOKENIZER_PATH,
        device="auto",
        dtype="auto",
        attn_implementation="auto",
        output_dir=OUTPUT_DIR,
    )
def get_runtime_tts_service() -> NanoTTSService:
    """Return the service configured for the current CUDA availability."""
    runtime_has_cuda = bool(torch.cuda.is_available())
    return get_tts_service(runtime_has_cuda)
def get_text_normalizer_manager() -> WeTextProcessingManager:
    """Create and start a WeTextProcessing manager ready for normalization."""
    normalizer = WeTextProcessingManager()
    normalizer.start()
    return normalizer
def preload_service() -> None:
    """Eagerly load the TTS model so the first request avoids load latency."""
    t0 = time.monotonic()
    service = get_runtime_tts_service()
    logging.info(
        "preloading Nano-TTS model checkpoint=%s codec=%s device=%s",
        CHECKPOINT_PATH,
        AUDIO_TOKENIZER_PATH,
        service.device,
    )
    service.get_model()
    logging.info("Nano-TTS preload finished in %.2fs", time.monotonic() - t0)
def render_mode_hint() -> str:
    """Markdown blurb describing the Space's single (fixed) synthesis mode."""
    hint = (
        "Current mode: **Voice Clone** \n"
        "Select a Default Case or upload your own reference audio. Uploaded audio overrides the selected Default Case."
    )
    return hint
def resolve_default_demo_entry() -> DemoEntry | None:
    """Return the first demo entry loaded at import time, or None when none exist."""
    return DEFAULT_DEMO_ENTRY
def resolve_selected_demo_entry(demo_case_id: str | None) -> DemoEntry | None:
    """Look up the demo entry for a dropdown value, falling back to the default."""
    canonical_id = normalize_demo_case_id(demo_case_id)
    if canonical_id:
        entry = DEMO_ENTRY_MAP.get(canonical_id)
        if entry is not None:
            return entry
    return resolve_default_demo_entry()
def resolve_effective_prompt_audio_path(
    prompt_audio_path: str | None,
    selected_demo_audio_path: str | None,
) -> str | None:
    """Pick the reference audio: upload first, then selected demo, then default.

    A candidate is only used when it resolves to an existing file; otherwise
    the next one is tried. Returns None when no reference audio exists at all.
    """
    for candidate in (prompt_audio_path, selected_demo_audio_path):
        if not candidate:
            continue
        resolved = Path(candidate).expanduser().resolve()
        if resolved.is_file():
            return str(resolved)
    fallback_entry = resolve_default_demo_entry()
    if fallback_entry is not None:
        return str(fallback_entry.prompt_audio_path)
    return None
def build_prompt_source_text(
    *,
    prompt_audio_path: str | None,
    selected_demo_audio_path: str | None,
) -> str:
    """Return a human-readable description of which reference audio is in effect."""
    effective = resolve_effective_prompt_audio_path(
        prompt_audio_path,
        selected_demo_audio_path,
    )
    if not effective:
        return "No default case available"
    if prompt_audio_path:
        return f"Uploaded reference audio: {Path(effective).name}"
    matched_entry = DEMO_AUDIO_PATH_MAP.get(effective)
    if matched_entry is not None:
        return f"Default case: {matched_entry.name}"
    return f"Default case: {Path(effective).name}"
def refresh_prompt_preview(
    prompt_audio_path: str | None,
    selected_demo_audio_path: str | None,
):
    """Recompute the effective prompt path and its description for the UI.

    Returns (preview_audio_path, prompt_source_text).
    """
    effective_path = resolve_effective_prompt_audio_path(
        prompt_audio_path,
        selected_demo_audio_path,
    )
    source_label = build_prompt_source_text(
        prompt_audio_path=prompt_audio_path,
        selected_demo_audio_path=selected_demo_audio_path,
    )
    return effective_path, source_label
def apply_demo_case_selection(
    demo_case_id: str,
    prompt_audio_path: str | None,
):
    """Handle a Default Case dropdown change.

    Returns updates for (text, prompt_preview, selected_demo_audio_path,
    prompt_source). When no demo entry matches, the text box is left
    untouched via gr.update() and the stored demo path is cleared.
    """
    entry = resolve_selected_demo_entry(demo_case_id)
    if entry is None:
        preview, source = refresh_prompt_preview(prompt_audio_path, "")
        return gr.update(), preview, "", source
    demo_path = str(entry.prompt_audio_path)
    preview, source = refresh_prompt_preview(prompt_audio_path, demo_path)
    return entry.text, preview, demo_path, source
def validate_request(
    *,
    text: str,
    effective_prompt_audio_path: str | None,
) -> str:
    """Validate the synthesis inputs and return the stripped text.

    Raises:
        ValueError: when the text is blank, or when no reference audio
            could be resolved.
    """
    cleaned = str(text or "").strip()
    if not cleaned:
        raise ValueError("Please enter text to synthesize.")
    if not effective_prompt_audio_path:
        raise ValueError("No reference audio is available. Please select a Default Case or upload prompt audio.")
    return cleaned
def build_status_text(
    *,
    result: dict[str, object],
    prepared_texts: dict[str, object],
    reference_source: str,
    runtime_device: str,
) -> str:
    """Compose the one-line status summary shown after synthesis."""
    chunks = result.get("voice_clone_text_chunks") or []
    # A missing or empty chunk list still counts as one chunk of text.
    num_chunks = len(chunks) if isinstance(chunks, list) and chunks else 1
    parts = [
        f"Done | mode={result['mode']} | ref={reference_source} | elapsed={result['elapsed_seconds']:.2f}s | ",
        f"device={runtime_device} | sample_rate={result['sample_rate']} | ",
        f"attn={result['effective_global_attn_implementation']} | ",
        f"chunks={num_chunks} | normalization={prepared_texts['normalization_method']}",
    ]
    return "".join(parts)
def estimate_gpu_duration(
    *args,
    **kwargs,
) -> int:
    """Heuristic GPU time budget (seconds) for a run_inference call.

    Receives the same arguments as run_inference, positionally or by
    keyword; only the text length, max_new_frames, and
    voice_clone_max_text_tokens influence the estimate. The result is
    clamped to [90, 240] seconds.
    """
    def _arg(name: str, index: int, fallback):
        # Prefer an explicit keyword; fall back to the positional slot.
        if name in kwargs:
            return kwargs[name]
        return args[index] if len(args) > index else fallback

    text_value = _arg("text", 0, "")
    frames = _arg("max_new_frames", 5, 375)
    clone_tokens = _arg("voice_clone_max_text_tokens", 6, 75)
    char_count = len(str(text_value or "").strip())
    raw_estimate = 75 + (char_count // 12) + int(frames) // 8 + int(clone_tokens) // 10
    return max(90, min(240, raw_estimate))
@spaces.GPU(duration=estimate_gpu_duration)
def run_inference(
    text: str,
    prompt_audio_path: str | None,
    selected_demo_audio_path: str | None,
    enable_wetext_processing: bool,
    enable_normalize_tts_text: bool,
    max_new_frames: int,
    voice_clone_max_text_tokens: int,
    do_sample: bool,
    text_temperature: float,
    text_top_p: float,
    text_top_k: int,
    audio_temperature: float,
    audio_top_p: float,
    audio_top_k: int,
    audio_repetition_penalty: float,
    seed: float | int,
):
    """Run one voice-clone synthesis request end to end.

    Returns a 4-tuple for the Gradio outputs:
    ((sample_rate, waveform), status_text, normalized_text, prompt_source).

    The @spaces.GPU decorator requests a ZeroGPU slot sized by
    estimate_gpu_duration, which reads the same argument list (text at
    index 0, max_new_frames at 5, voice_clone_max_text_tokens at 6).
    Outside Hugging Face Spaces the fallback decorator is a no-op.

    Raises:
        gr.Error: with a user-facing message when any step fails; the
            original exception is logged and chained.
    """
    generated_audio_path: str | None = None
    try:
        service = get_runtime_tts_service()
        # The normalizer manager is only started when WeTextProcessing is on.
        text_normalizer_manager = get_text_normalizer_manager() if enable_wetext_processing else None
        effective_prompt_audio_path = resolve_effective_prompt_audio_path(
            prompt_audio_path,
            selected_demo_audio_path,
        )
        normalized_text = validate_request(
            text=text,
            effective_prompt_audio_path=effective_prompt_audio_path,
        )
        prepared_texts = prepare_tts_request_texts(
            text=normalized_text,
            prompt_text="",
            voice=DEFAULT_VOICE,
            enable_wetext=bool(enable_wetext_processing),
            enable_normalize_tts_text=bool(enable_normalize_tts_text),
            text_normalizer_manager=text_normalizer_manager,
        )
        prompt_source = build_prompt_source_text(
            prompt_audio_path=prompt_audio_path,
            selected_demo_audio_path=selected_demo_audio_path,
        )
        # Seed 0 (the UI default) and blank values both mean "random seed".
        normalized_seed = None
        if seed not in {"", None}:
            resolved_seed = int(seed)
            if resolved_seed != 0:
                normalized_seed = resolved_seed
        result = service.synthesize(
            text=str(prepared_texts["text"]),
            mode=MODE_VOICE_CLONE,
            voice=DEFAULT_VOICE,
            prompt_audio_path=effective_prompt_audio_path or None,
            max_new_frames=int(max_new_frames),
            voice_clone_max_text_tokens=int(voice_clone_max_text_tokens),
            do_sample=bool(do_sample),
            text_temperature=float(text_temperature),
            text_top_p=float(text_top_p),
            text_top_k=int(text_top_k),
            audio_temperature=float(audio_temperature),
            audio_top_p=float(audio_top_p),
            audio_top_k=int(audio_top_k),
            audio_repetition_penalty=float(audio_repetition_penalty),
            seed=normalized_seed,
        )
        generated_audio_path = str(result["audio_path"])
        return (
            # The waveform is returned in-memory, so the on-disk file can be
            # deleted in the finally block.
            (int(result["sample_rate"]), result["waveform_numpy"]),
            build_status_text(
                result=result,
                prepared_texts=prepared_texts,
                reference_source=prompt_source,
                runtime_device=str(service.device),
            ),
            str(prepared_texts["normalized_text"]),
            prompt_source,
        )
    except Exception as exc:
        logging.exception("Nano-TTS inference failed")
        raise gr.Error(str(exc)) from exc
    finally:
        maybe_delete_file(generated_audio_path)
def build_demo():
    """Assemble and return the Gradio Blocks UI.

    Layout: a left column with the demo-case dropdown, target text,
    reference-audio upload/preview and advanced sampling parameters; a
    right column with the output audio and status/diagnostic textboxes.
    Event wiring: dropdown changes fill text/preview, uploads refresh the
    preview, and the run button calls run_inference.
    """
    with gr.Blocks(title="Nano-TTS ZeroGPU Space") as demo:
        # Static header card describing the model and linking to the project.
        gr.Markdown(
            """
            <div class="app-card">
            <div class="app-title">Nano-TTS ZeroGPU</div>
            <div class="app-subtitle">
            Hugging Face Space edition backed by local <code>weights/tts</code> and <code>weights/codec</code>.
            ZeroGPU requests a GPU only during inference, and audio is returned after full synthesis.
            </div>
            <p>
            MOSS-TTS-Nano is a zero-shot TTS model with approximately 100M parameters, supporting 48 kHz stereo
            input and output, streaming generation, multilingual synthesis, and long-form text. It is developed by
            the <a href="https://openmoss.github.io/" target="_blank" rel="noopener noreferrer">OpenMOSS Team</a>.
            For more details, see the
            <a href="https://github.com/OpenMOSS/MOSS-TTS-Nano" target="_blank" rel="noopener noreferrer">GitHub repository</a>
            and
            <a href="https://openmoss.github.io/MOSS-TTS-Nano-Demo/" target="_blank" rel="noopener noreferrer">blog</a>.
            </p>
            </div>
            """
        )
        with gr.Row(equal_height=False):
            # --- Left column: inputs ---
            with gr.Column(scale=3):
                demo_case = gr.Dropdown(
                    choices=DEMO_CASE_CHOICES,
                    value=DEFAULT_DEMO_CASE_ID,
                    label="Default Case",
                    info="Select a built-in case to auto-fill the text and prompt preview.",
                    allow_custom_value=True,
                )
                text = gr.Textbox(
                    label="Target Text",
                    lines=10,
                    value=DEFAULT_DEMO_TEXT,
                    placeholder="Enter the text to synthesize.",
                )
                mode_hint = gr.Markdown(render_mode_hint())
                # Uploaded reference audio takes precedence over the demo case
                # (see resolve_effective_prompt_audio_path).
                prompt_audio = gr.Audio(
                    label="Reference Audio Upload (optional; overrides Default Case)",
                    type="filepath",
                    sources=["upload"],
                )
                prompt_preview = gr.Audio(
                    label="Effective Prompt Preview",
                    value=DEFAULT_DEMO_AUDIO_PATH or None,
                    type="filepath",
                    interactive=False,
                )
                gr.Markdown(
                    "Runtime device and backbone are fixed by the Space and are not user-configurable. Uploaded reference audio overrides the selected Default Case."
                )
                # Sampling/generation knobs; defaults mirror the service defaults.
                with gr.Accordion("Advanced Parameters", open=False):
                    enable_wetext_processing = gr.Checkbox(
                        value=True,
                        label="Enable WeTextProcessing",
                    )
                    enable_normalize_tts_text = gr.Checkbox(
                        value=True,
                        label="Enable normalize_tts_text",
                    )
                    max_new_frames = gr.Slider(
                        minimum=64,
                        maximum=512,
                        step=16,
                        value=375,
                        label="max_new_frames",
                    )
                    voice_clone_max_text_tokens = gr.Slider(
                        minimum=25,
                        maximum=200,
                        step=5,
                        value=75,
                        label="voice_clone_max_text_tokens",
                    )
                    do_sample = gr.Checkbox(
                        value=True,
                        label="Enable Sampling",
                    )
                    seed = gr.Number(
                        value=0,
                        precision=0,
                        label="Seed (0 = random)",
                    )
                    text_temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        step=0.05,
                        value=1.0,
                        label="text_temperature",
                    )
                    text_top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        step=0.01,
                        value=1.0,
                        label="text_top_p",
                    )
                    text_top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=50,
                        label="text_top_k",
                    )
                    audio_temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        step=0.05,
                        value=0.8,
                        label="audio_temperature",
                    )
                    audio_top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        step=0.01,
                        value=0.95,
                        label="audio_top_p",
                    )
                    audio_top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=25,
                        label="audio_top_k",
                    )
                    audio_repetition_penalty = gr.Slider(
                        minimum=0.8,
                        maximum=2.0,
                        step=0.05,
                        value=1.2,
                        label="audio_repetition_penalty",
                    )
                run_btn = gr.Button("Generate Speech", variant="primary", elem_id="run-btn")
            # --- Right column: outputs ---
            with gr.Column(scale=2):
                output_audio = gr.Audio(label="Output Audio", type="numpy")
                status = gr.Textbox(label="Status", lines=4, interactive=False)
                normalized_text = gr.Textbox(label="Normalized Text", lines=6, interactive=False)
                prompt_source = gr.Textbox(
                    label="Prompt Source",
                    value=build_prompt_source_text(
                        prompt_audio_path=None,
                        selected_demo_audio_path=DEFAULT_DEMO_AUDIO_PATH or None,
                    ),
                    lines=4,
                    interactive=False,
                )
        # Hidden per-session state holding the selected demo's audio path.
        selected_demo_audio_path = gr.State(DEFAULT_DEMO_AUDIO_PATH)
        demo_case.change(
            fn=apply_demo_case_selection,
            inputs=[demo_case, prompt_audio],
            outputs=[text, prompt_preview, selected_demo_audio_path, prompt_source],
        )
        prompt_audio.change(
            fn=refresh_prompt_preview,
            inputs=[prompt_audio, selected_demo_audio_path],
            outputs=[prompt_preview, prompt_source],
        )
        run_btn.click(
            fn=run_inference,
            inputs=[
                text,
                prompt_audio,
                selected_demo_audio_path,
                enable_wetext_processing,
                enable_normalize_tts_text,
                max_new_frames,
                voice_clone_max_text_tokens,
                do_sample,
                text_temperature,
                text_top_p,
                text_top_k,
                audio_temperature,
                audio_top_p,
                audio_top_k,
                audio_repetition_penalty,
                seed,
            ],
            outputs=[output_audio, status, normalized_text, prompt_source],
        )
    return demo
def main() -> None:
    """Parse CLI/env configuration, optionally preload the model, and launch the UI."""
    parser = argparse.ArgumentParser(description="Nano-TTS ZeroGPU Hugging Face Space")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument(
        "--port",
        type=int,
        # parse_port tolerates a non-numeric GRADIO_SERVER_PORT/PORT instead
        # of crashing with an unhandled ValueError while building the parser.
        default=parse_port(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT")), 7860),
    )
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()
    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
        level=logging.INFO,
    )
    # Environment variables (set by the Space runtime) take precedence over CLI flags.
    args.host = os.getenv("GRADIO_SERVER_NAME", args.host)
    args.port = parse_port(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT")), args.port)
    # Start the text normalizer early so the first request does not pay for it.
    get_text_normalizer_manager()
    # Preload defaults to ON locally and OFF on a hosted Space (SPACE_ID set),
    # where ZeroGPU requests the GPU per inference call anyway.
    preload_enabled = parse_bool_env(PRELOAD_ENV_VAR, default=not bool(os.getenv("SPACE_ID")))
    if preload_enabled:
        preload_service()
    else:
        logging.info("Skipping model preload (set %s=1 to enable).", PRELOAD_ENV_VAR)
    demo = build_demo()
    demo.queue(max_size=4, default_concurrency_limit=4).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        ssr_mode=False,
    )


if __name__ == "__main__":
    main()