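"""Gradio app for the MOSS-TTS-Nano (Nano-TTS) ZeroGPU Hugging Face Space.

Loads the local TTS and codec checkpoints from ``weights/``, optionally
normalizes input text with WeTextProcessing, and runs voice-clone synthesis
inside a ZeroGPU-allocated GPU window.
"""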
from __future__ import annotations
import argparse
import functools
import json
import logging
import os
import time
from dataclasses import dataclass
from pathlib import Path
import gradio as gr
import torch
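# `spaces` supplies the ZeroGPU decorator on Hugging Face Spaces; outside a Space the
# import fails, so fall back to a no-op decorator that returns the function unchanged.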
try:
import spaces
except ImportError:
class _SpacesFallback:
@staticmethod
def GPU(*_args, **_kwargs):
def _decorator(func):
return func
return _decorator
spaces = _SpacesFallback()
from nano_tts_runtime import DEFAULT_VOICE, NanoTTSService
from text_normalization_pipeline import WeTextProcessingManager, prepare_tts_request_texts
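# Bundled model weight paths, the scratch output directory, demo metadata, and the
# environment switch that controls model preloading at startup.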
APP_DIR = Path(__file__).resolve().parent
CHECKPOINT_PATH = APP_DIR / "weights" / "tts"
AUDIO_TOKENIZER_PATH = APP_DIR / "weights" / "codec"
OUTPUT_DIR = Path("/tmp") / "nano-tts-space"
PRELOAD_ENV_VAR = "NANO_TTS_PRELOAD_AT_STARTUP"
DEMO_METADATA_PATH = APP_DIR / "assets" / "demo.jsonl"
MODE_VOICE_CLONE = "voice_clone"
@dataclass(frozen=True)
class DemoEntry:
demo_id: str
name: str
prompt_audio_path: Path
text: str
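# assets/demo.jsonl holds one JSON object per line with "role" (prompt audio path
# relative to the app dir), "text", and an optional "name"; malformed or incomplete
# lines are skipped with a warning. Illustrative line (not taken from the real file):
#   {"role": "assets/demo1.wav", "text": "Hello there.", "name": "Demo 1: greeting"}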
def load_demo_entries() -> list[DemoEntry]:
if not DEMO_METADATA_PATH.is_file():
logging.warning("demo metadata file not found: %s", DEMO_METADATA_PATH)
return []
demo_entries: list[DemoEntry] = []
for line_index, raw_line in enumerate(DEMO_METADATA_PATH.read_text(encoding="utf-8").splitlines(), start=1):
line = raw_line.strip()
if not line:
continue
try:
payload = json.loads(line)
except Exception:
logging.warning("failed to parse demo metadata line=%s", line_index, exc_info=True)
continue
relative_audio_path = str(payload.get("role", "")).strip()
text = str(payload.get("text", "")).strip()
if not relative_audio_path or not text:
logging.warning("skip invalid demo metadata line=%s role/text missing", line_index)
continue
prompt_audio_path = (APP_DIR / relative_audio_path).resolve()
if not prompt_audio_path.is_file():
logging.warning("skip demo metadata line=%s prompt audio missing: %s", line_index, prompt_audio_path)
continue
name = str(payload.get("name", "")).strip() or f"Demo {len(demo_entries) + 1}: {prompt_audio_path.stem}"
demo_entries.append(
DemoEntry(
demo_id=f"demo-{len(demo_entries) + 1}",
name=name,
prompt_audio_path=prompt_audio_path,
text=text,
)
)
return demo_entries
DEMO_ENTRIES = load_demo_entries()
DEMO_ENTRY_MAP = {entry.demo_id: entry for entry in DEMO_ENTRIES}
DEMO_AUDIO_PATH_MAP = {str(entry.prompt_audio_path): entry for entry in DEMO_ENTRIES}
DEMO_ENTRY_NAME_MAP = {entry.name: entry for entry in DEMO_ENTRIES}
DEFAULT_DEMO_ENTRY = DEMO_ENTRIES[0] if DEMO_ENTRIES else None
DEFAULT_DEMO_CASE_ID = DEFAULT_DEMO_ENTRY.demo_id if DEFAULT_DEMO_ENTRY is not None else ""
DEFAULT_DEMO_AUDIO_PATH = str(DEFAULT_DEMO_ENTRY.prompt_audio_path) if DEFAULT_DEMO_ENTRY is not None else ""
DEFAULT_DEMO_TEXT = DEFAULT_DEMO_ENTRY.text if DEFAULT_DEMO_ENTRY is not None else ""
DEMO_CASE_CHOICES = [(entry.name, entry.demo_id) for entry in DEMO_ENTRIES]
def parse_bool_env(name: str, default: bool) -> bool:
value = os.getenv(name)
if value is None:
return default
return value.strip().lower() in {"1", "true", "yes", "y", "on"}
def parse_port(value: str | None, default: int) -> int:
if not value:
return default
try:
return int(value)
except ValueError:
return default
def maybe_delete_file(path: str | Path | None) -> None:
if not path:
return
try:
Path(path).unlink(missing_ok=True)
except OSError:
logging.warning("failed to delete temporary file: %s", path, exc_info=True)
def normalize_demo_case_id(demo_case_id: str | None) -> str:
normalized = str(demo_case_id or "").strip()
if not normalized:
return ""
if normalized in DEMO_ENTRY_MAP:
return normalized
matched_entry = DEMO_ENTRY_NAME_MAP.get(normalized)
if matched_entry is not None:
return matched_entry.demo_id
return ""
@functools.lru_cache(maxsize=2)
def get_tts_service(runtime_has_cuda: bool) -> NanoTTSService:
return NanoTTSService(
checkpoint_path=CHECKPOINT_PATH,
audio_tokenizer_path=AUDIO_TOKENIZER_PATH,
device="auto",
dtype="auto",
attn_implementation="auto",
output_dir=OUTPUT_DIR,
)
def get_runtime_tts_service() -> NanoTTSService:
return get_tts_service(bool(torch.cuda.is_available()))
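# The WeTextProcessing normalizer is started once and cached for the app's lifetime;
# run_inference only consults it when "Enable WeTextProcessing" is checked.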
@functools.lru_cache(maxsize=1)
def get_text_normalizer_manager() -> WeTextProcessingManager:
manager = WeTextProcessingManager()
manager.start()
return manager
def preload_service() -> None:
started_at = time.monotonic()
service = get_runtime_tts_service()
logging.info(
"preloading Nano-TTS model checkpoint=%s codec=%s device=%s",
CHECKPOINT_PATH,
AUDIO_TOKENIZER_PATH,
service.device,
)
service.get_model()
logging.info("Nano-TTS preload finished in %.2fs", time.monotonic() - started_at)
def render_mode_hint() -> str:
return (
"Current mode: **Voice Clone** \n"
"Select a Default Case or upload your own reference audio. Uploaded audio overrides the selected Default Case."
)
def resolve_default_demo_entry() -> DemoEntry | None:
return DEFAULT_DEMO_ENTRY
def resolve_selected_demo_entry(demo_case_id: str | None) -> DemoEntry | None:
normalized_demo_case_id = normalize_demo_case_id(demo_case_id)
if normalized_demo_case_id:
demo_entry = DEMO_ENTRY_MAP.get(normalized_demo_case_id)
if demo_entry is not None:
return demo_entry
return resolve_default_demo_entry()
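# Reference-audio precedence: an uploaded file wins over the selected Default Case,
# which in turn falls back to the first available demo entry.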
def resolve_effective_prompt_audio_path(
prompt_audio_path: str | None,
selected_demo_audio_path: str | None,
) -> str | None:
if prompt_audio_path:
resolved_path = Path(prompt_audio_path).expanduser().resolve()
if resolved_path.is_file():
return str(resolved_path)
if selected_demo_audio_path:
resolved_path = Path(selected_demo_audio_path).expanduser().resolve()
if resolved_path.is_file():
return str(resolved_path)
demo_entry = resolve_default_demo_entry()
if demo_entry is not None:
return str(demo_entry.prompt_audio_path)
return None
def build_prompt_source_text(
*,
prompt_audio_path: str | None,
selected_demo_audio_path: str | None,
) -> str:
effective_prompt_audio_path = resolve_effective_prompt_audio_path(
prompt_audio_path,
selected_demo_audio_path,
)
if effective_prompt_audio_path:
if prompt_audio_path:
return f"Uploaded reference audio: {Path(effective_prompt_audio_path).name}"
demo_entry = DEMO_AUDIO_PATH_MAP.get(effective_prompt_audio_path)
if demo_entry is not None:
return f"Default case: {demo_entry.name}"
return f"Default case: {Path(effective_prompt_audio_path).name}"
return "No default case available"
def refresh_prompt_preview(
prompt_audio_path: str | None,
selected_demo_audio_path: str | None,
):
preview_path = resolve_effective_prompt_audio_path(
prompt_audio_path,
selected_demo_audio_path,
)
prompt_source = build_prompt_source_text(
prompt_audio_path=prompt_audio_path,
selected_demo_audio_path=selected_demo_audio_path,
)
return preview_path, prompt_source
def apply_demo_case_selection(
demo_case_id: str,
prompt_audio_path: str | None,
):
demo_entry = resolve_selected_demo_entry(demo_case_id)
if demo_entry is None:
preview_path, prompt_source = refresh_prompt_preview(prompt_audio_path, "")
return (
gr.update(),
preview_path,
"",
prompt_source,
)
selected_prompt_path = str(demo_entry.prompt_audio_path)
preview_path, prompt_source = refresh_prompt_preview(
prompt_audio_path,
selected_prompt_path,
)
return (
demo_entry.text,
preview_path,
selected_prompt_path,
prompt_source,
)
def validate_request(
*,
text: str,
effective_prompt_audio_path: str | None,
) -> str:
normalized_text = str(text or "").strip()
if not normalized_text:
raise ValueError("Please enter text to synthesize.")
if not effective_prompt_audio_path:
raise ValueError("No reference audio is available. Please select a Default Case or upload prompt audio.")
return normalized_text
def build_status_text(
*,
result: dict[str, object],
prepared_texts: dict[str, object],
reference_source: str,
runtime_device: str,
) -> str:
text_chunks = result.get("voice_clone_text_chunks") or []
chunk_count = len(text_chunks) if isinstance(text_chunks, list) and text_chunks else 1
return (
f"Done | mode={result['mode']} | ref={reference_source} | elapsed={result['elapsed_seconds']:.2f}s | "
f"device={runtime_device} | sample_rate={result['sample_rate']} | "
f"attn={result['effective_global_attn_implementation']} | "
f"chunks={chunk_count} | normalization={prepared_texts['normalization_method']}"
)
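# Heuristic for the ZeroGPU decorator below: estimate the GPU reservation in seconds
# from the text length and generation limits, clamped to the 90-240 s range.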
def estimate_gpu_duration(
*args,
**kwargs,
) -> int:
text = kwargs.get("text", args[0] if len(args) > 0 else "")
max_new_frames = kwargs.get("max_new_frames", args[5] if len(args) > 5 else 375)
voice_clone_max_text_tokens = (
kwargs.get("voice_clone_max_text_tokens", args[6] if len(args) > 6 else 75)
)
text_len = len(str(text or "").strip())
estimated = 75 + (text_len // 12) + int(max_new_frames) // 8 + int(voice_clone_max_text_tokens) // 10
return max(90, min(240, estimated))
@spaces.GPU(size="large", duration=estimate_gpu_duration)
def run_inference(
text: str,
prompt_audio_path: str | None,
selected_demo_audio_path: str | None,
enable_wetext_processing: bool,
enable_normalize_tts_text: bool,
max_new_frames: int,
voice_clone_max_text_tokens: int,
do_sample: bool,
text_temperature: float,
text_top_p: float,
text_top_k: int,
audio_temperature: float,
audio_top_p: float,
audio_top_k: int,
audio_repetition_penalty: float,
seed: float | int,
):
generated_audio_path: str | None = None
try:
service = get_runtime_tts_service()
text_normalizer_manager = get_text_normalizer_manager() if enable_wetext_processing else None
effective_prompt_audio_path = resolve_effective_prompt_audio_path(
prompt_audio_path,
selected_demo_audio_path,
)
normalized_text = validate_request(
text=text,
effective_prompt_audio_path=effective_prompt_audio_path,
)
prepared_texts = prepare_tts_request_texts(
text=normalized_text,
prompt_text="",
voice=DEFAULT_VOICE,
enable_wetext=bool(enable_wetext_processing),
enable_normalize_tts_text=bool(enable_normalize_tts_text),
text_normalizer_manager=text_normalizer_manager,
)
prompt_source = build_prompt_source_text(
prompt_audio_path=prompt_audio_path,
selected_demo_audio_path=selected_demo_audio_path,
)
normalized_seed = None
if seed not in {"", None}:
resolved_seed = int(seed)
if resolved_seed != 0:
normalized_seed = resolved_seed
result = service.synthesize(
text=str(prepared_texts["text"]),
mode=MODE_VOICE_CLONE,
voice=DEFAULT_VOICE,
prompt_audio_path=effective_prompt_audio_path or None,
max_new_frames=int(max_new_frames),
voice_clone_max_text_tokens=int(voice_clone_max_text_tokens),
do_sample=bool(do_sample),
text_temperature=float(text_temperature),
text_top_p=float(text_top_p),
text_top_k=int(text_top_k),
audio_temperature=float(audio_temperature),
audio_top_p=float(audio_top_p),
audio_top_k=int(audio_top_k),
audio_repetition_penalty=float(audio_repetition_penalty),
seed=normalized_seed,
)
generated_audio_path = str(result["audio_path"])
return (
(int(result["sample_rate"]), result["waveform_numpy"]),
build_status_text(
result=result,
prepared_texts=prepared_texts,
reference_source=prompt_source,
runtime_device=str(service.device),
),
str(prepared_texts["normalized_text"]),
prompt_source,
)
except Exception as exc:
logging.exception("Nano-TTS inference failed")
raise gr.Error(str(exc)) from exc
finally:
maybe_delete_file(generated_audio_path)
def build_demo():
with gr.Blocks(title="Nano-TTS ZeroGPU Space") as demo:
gr.Markdown(
"""
<div class="app-card">
<div class="app-title">Nano-TTS ZeroGPU</div>
<div class="app-subtitle">
Hugging Face Space edition backed by local <code>weights/tts</code> and <code>weights/codec</code>.
ZeroGPU requests a GPU only during inference, and audio is returned after full synthesis.
</div>
<p>
MOSS-TTS-Nano is a zero-shot TTS model with approximately 100M parameters, supporting 48 kHz stereo
input and output, streaming generation, multilingual synthesis, and long-form text. It is developed by
the <a href="https://openmoss.github.io/" target="_blank" rel="noopener noreferrer">OpenMOSS Team</a>.
For more details, see the
<a href="https://github.com/OpenMOSS/MOSS-TTS-Nano" target="_blank" rel="noopener noreferrer">GitHub repository</a>
and
<a href="https://openmoss.github.io/MOSS-TTS-Nano-Demo/" target="_blank" rel="noopener noreferrer">blog</a>.
</p>
</div>
"""
)
with gr.Row(equal_height=False):
with gr.Column(scale=3):
demo_case = gr.Dropdown(
choices=DEMO_CASE_CHOICES,
value=DEFAULT_DEMO_CASE_ID,
label="Default Case",
info="Select a built-in case to auto-fill the text and prompt preview.",
allow_custom_value=True,
)
text = gr.Textbox(
label="Target Text",
lines=10,
value=DEFAULT_DEMO_TEXT,
placeholder="Enter the text to synthesize.",
)
mode_hint = gr.Markdown(render_mode_hint())
prompt_audio = gr.Audio(
label="Reference Audio Upload (optional; overrides Default Case)",
type="filepath",
sources=["upload"],
)
prompt_preview = gr.Audio(
label="Effective Prompt Preview",
value=DEFAULT_DEMO_AUDIO_PATH or None,
type="filepath",
interactive=False,
)
gr.Markdown(
"Runtime device and backbone are fixed by the Space and are not user-configurable. Uploaded reference audio overrides the selected Default Case."
)
with gr.Accordion("Advanced Parameters", open=False):
enable_wetext_processing = gr.Checkbox(
value=True,
label="Enable WeTextProcessing",
)
enable_normalize_tts_text = gr.Checkbox(
value=True,
label="Enable normalize_tts_text",
)
max_new_frames = gr.Slider(
minimum=64,
maximum=512,
step=16,
value=375,
label="max_new_frames",
)
voice_clone_max_text_tokens = gr.Slider(
minimum=25,
maximum=200,
step=5,
value=75,
label="voice_clone_max_text_tokens",
)
do_sample = gr.Checkbox(
value=True,
label="Enable Sampling",
)
seed = gr.Number(
value=0,
precision=0,
label="Seed (0 = random)",
)
text_temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
step=0.05,
value=1.0,
label="text_temperature",
)
text_top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
step=0.01,
value=1.0,
label="text_top_p",
)
text_top_k = gr.Slider(
minimum=1,
maximum=100,
step=1,
value=50,
label="text_top_k",
)
audio_temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
step=0.05,
value=0.8,
label="audio_temperature",
)
audio_top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
step=0.01,
value=0.95,
label="audio_top_p",
)
audio_top_k = gr.Slider(
minimum=1,
maximum=100,
step=1,
value=25,
label="audio_top_k",
)
audio_repetition_penalty = gr.Slider(
minimum=0.8,
maximum=2.0,
step=0.05,
value=1.2,
label="audio_repetition_penalty",
)
run_btn = gr.Button("Generate Speech", variant="primary", elem_id="run-btn")
with gr.Column(scale=2):
output_audio = gr.Audio(label="Output Audio", type="numpy")
status = gr.Textbox(label="Status", lines=4, interactive=False)
normalized_text = gr.Textbox(label="Normalized Text", lines=6, interactive=False)
prompt_source = gr.Textbox(
label="Prompt Source",
value=build_prompt_source_text(
prompt_audio_path=None,
selected_demo_audio_path=DEFAULT_DEMO_AUDIO_PATH or None,
),
lines=4,
interactive=False,
)
selected_demo_audio_path = gr.State(DEFAULT_DEMO_AUDIO_PATH)
demo_case.change(
fn=apply_demo_case_selection,
inputs=[demo_case, prompt_audio],
outputs=[text, prompt_preview, selected_demo_audio_path, prompt_source],
)
prompt_audio.change(
fn=refresh_prompt_preview,
inputs=[prompt_audio, selected_demo_audio_path],
outputs=[prompt_preview, prompt_source],
)
run_btn.click(
fn=run_inference,
inputs=[
text,
prompt_audio,
selected_demo_audio_path,
enable_wetext_processing,
enable_normalize_tts_text,
max_new_frames,
voice_clone_max_text_tokens,
do_sample,
text_temperature,
text_top_p,
text_top_k,
audio_temperature,
audio_top_p,
audio_top_k,
audio_repetition_penalty,
seed,
],
outputs=[output_audio, status, normalized_text, prompt_source],
)
return demo
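# CLI entry point. GRADIO_SERVER_NAME / GRADIO_SERVER_PORT / PORT override --host and
# --port, and NANO_TTS_PRELOAD_AT_STARTUP toggles eager model loading. Illustrative
# local run:
#   NANO_TTS_PRELOAD_AT_STARTUP=1 python app.py --port 7860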
def main() -> None:
parser = argparse.ArgumentParser(description="Nano-TTS ZeroGPU Hugging Face Space")
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument(
"--port",
type=int,
default=int(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT", "7860"))),
)
parser.add_argument("--share", action="store_true")
args = parser.parse_args()
logging.basicConfig(
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
level=logging.INFO,
)
args.host = os.getenv("GRADIO_SERVER_NAME", args.host)
args.port = parse_port(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT")), args.port)
get_text_normalizer_manager()
preload_enabled = parse_bool_env(PRELOAD_ENV_VAR, default=not bool(os.getenv("SPACE_ID")))
if preload_enabled:
preload_service()
else:
logging.info("Skipping model preload (set %s=1 to enable).", PRELOAD_ENV_VAR)
demo = build_demo()
demo.queue(max_size=4, default_concurrency_limit=4).launch(
server_name=args.host,
server_port=args.port,
share=args.share,
ssr_mode=False,
)
if __name__ == "__main__":
main()