import spaces import argparse import base64 import functools import json import sys import threading import time from collections import OrderedDict from dataclasses import dataclass from pathlib import Path from typing import Iterator, Sequence import gradio as gr import numpy as np import os os.environ["TORCHDYNAMO_DISABLE"] = "1" import torch import torchaudio import torch._dynamo from transformers import AutoModel, AutoTokenizer from mossttsrealtime import MossTTSRealtime, MossTTSRealtimeProcessor from mossttsrealtime.streaming_mossttsrealtime import ( AudioStreamDecoder, MossTTSRealtimeInference, MossTTSRealtimeStreamingSession, ) torch._dynamo.config.cache_size_limit = 64 APP_DIR = Path(__file__).resolve().parent AUDIO_DIR = APP_DIR / "asset" LOG_DIR = APP_DIR / "logs" SAMPLE_RATE = 24000 CODEC_MODEL_PATH = "OpenMOSS-Team/MOSS-Audio-Tokenizer" MODEL_PATH = "OpenMOSS-Team/MOSS-TTS-Realtime" TOKENIZER_PATH = "OpenMOSS-Team/MOSS-TTS-Realtime" PROMPT_WAV = "asset/prompt_audio.mp3" USER_WAV = "asset/user1.wav" WARMUP_POLL_INTERVAL_SECONDS = 0.5 DEFAULT_REPETITION_WINDOW = 50 WARMUP_STEP_TOKENS = DEFAULT_REPETITION_WINDOW + 1 WARMUP_USER_TEXT = "Hello!" WARMUP_BASE_ASSISTANT_TEXT = ( "This startup warmup request primes the streaming text to speech path " "so the first real user request avoids the cold compile stall." ) def _apply_seed(seed: int | None) -> None: if seed is None: return # ZeroGPU: avoid touching torch.cuda outside the managed GPU call. torch.manual_seed(seed) def _load_audio(path: Path, target_sample_rate: int = SAMPLE_RATE) -> torch.Tensor: wav, sr = torchaudio.load(path) if sr != target_sample_rate: wav = torchaudio.functional.resample(wav, sr, target_sample_rate) if wav.shape[0] > 1: wav = wav.mean(dim=0, keepdim=True) return wav def _load_codec(device: torch.device, codec_model_path: str): codec = AutoModel.from_pretrained(codec_model_path, trust_remote_code=True).eval() return codec.to(device) def _extract_codes(encode_result): if isinstance(encode_result, dict): codes = encode_result["audio_codes"] elif isinstance(encode_result, (list, tuple)) and encode_result: codes = encode_result[0] else: codes = encode_result if isinstance(codes, np.ndarray): codes = torch.from_numpy(codes) if isinstance(codes, torch.Tensor) and codes.dim() == 3: if codes.shape[1] == 1: codes = codes[:, 0, :] elif codes.shape[0] == 1: codes = codes[0] else: raise ValueError(f"Unsupported 3D audio code shape: {tuple(codes.shape)}") return codes @dataclass(frozen=True) class BackendPaths: model_path: str tokenizer_path: str codec_model_path: str device_str: str attn_impl: str @dataclass(frozen=True) class GenerationConfig: temperature: float top_p: float top_k: int repetition_penalty: float repetition_window: int do_sample: bool max_length: int seed: int | None @dataclass(frozen=True) class StreamingConfig: text_chunk_tokens: int input_delay: float decode_chunk_frames: int decode_overlap_frames: int chunk_duration: float prebuffer_seconds: float buffer_threshold_seconds: float = 0.0 @dataclass(frozen=True) class StreamingRequest: user_text: str assistant_text: str prompt_audio: str | None user_audio: str | None use_default_prompt: bool use_default_user: bool generation: GenerationConfig streaming: StreamingConfig backend: BackendPaths @dataclass(frozen=True) class StreamEvent: message: str audio: tuple[int, np.ndarray] | None = None @dataclass(frozen=True) class WarmupSnapshot: state: str progress: float message: str detail: str | None = None error: str | None = None @property def ready(self) -> bool: return self.state == "ready" @property def failed(self) -> bool: return self.state == "failed" def _make_log_path(prefix: str) -> Path: LOG_DIR.mkdir(parents=True, exist_ok=True) stamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) return LOG_DIR / f"{prefix}_{stamp}_{time.time_ns() % 1_000_000_000:09d}.jsonl" def _compute_rtf_metrics(sample_count: int, sample_rate: int, started_at: float) -> dict[str, float | None]: elapsed_s = max(0.0, time.monotonic() - started_at) audio_s = float(sample_count) / float(sample_rate) if sample_count > 0 and sample_rate > 0 else 0.0 rtf = (elapsed_s / audio_s) if audio_s > 0 else None return { "elapsed_s": elapsed_s, "audio_s": audio_s, "rtf": rtf, } class StreamRTFLogger: def __init__(self, path: Path, started_at: float): self.path = path self.started_at = started_at self.chunk_count = 0 self.sample_rate = SAMPLE_RATE self.samples_emitted = 0 @classmethod def create(cls, request: "StreamingRequest", started_at: float) -> "StreamRTFLogger": logger = cls(_make_log_path("rtf"), started_at) logger.log_request_started(request) print(f"[MossTTSRealtime][rtf-log] {logger.path}", flush=True) return logger def log_request_started(self, request: "StreamingRequest") -> None: self._append( { "event": "request_started", "user_text_chars": len(request.user_text), "assistant_text_chars": len(request.assistant_text), "text_chunk_tokens": request.streaming.text_chunk_tokens, "decode_chunk_frames": request.streaming.decode_chunk_frames, "decode_overlap_frames": request.streaming.decode_overlap_frames, "chunk_duration_s": request.streaming.chunk_duration, "prebuffer_seconds": request.streaming.prebuffer_seconds, "temperature": request.generation.temperature, "top_p": request.generation.top_p, "top_k": request.generation.top_k, "repetition_penalty": request.generation.repetition_penalty, "repetition_window": request.generation.repetition_window, "do_sample": request.generation.do_sample, "max_length": request.generation.max_length, "seed": request.generation.seed, "device": request.backend.device_str, "attn_implementation": request.backend.attn_impl, } ) def log_chunk( self, *, event_message: str, sample_rate: int, chunk: np.ndarray, first_audio_time: float | None, ) -> None: chunk = np.asarray(chunk).reshape(-1) if chunk.size == 0: return self.chunk_count += 1 self.sample_rate = int(sample_rate) self.samples_emitted += int(chunk.size) metrics = _compute_rtf_metrics(self.samples_emitted, self.sample_rate, self.started_at) record = { "event": "stream_chunk", "message": event_message, "chunk_idx": self.chunk_count, "chunk_audio_s": float(chunk.size) / float(self.sample_rate), "audio_s_emitted": metrics["audio_s"], "elapsed_s": metrics["elapsed_s"], "rtf": metrics["rtf"], } if first_audio_time is not None: record["time_to_first_audio_ms"] = max(0.0, (first_audio_time - self.started_at) * 1000.0) self._append(record) def log_completion(self, *, first_audio_time: float | None) -> None: metrics = _compute_rtf_metrics(self.samples_emitted, self.sample_rate, self.started_at) record = { "event": "stream_complete", "chunk_count": self.chunk_count, "audio_s_total": metrics["audio_s"], "elapsed_s": metrics["elapsed_s"], "rtf": metrics["rtf"], } if first_audio_time is not None: record["time_to_first_audio_ms"] = max(0.0, (first_audio_time - self.started_at) * 1000.0) self._append(record) def log_no_audio(self) -> None: metrics = _compute_rtf_metrics(0, self.sample_rate, self.started_at) self._append( { "event": "stream_complete", "chunk_count": 0, "audio_s_total": 0.0, "elapsed_s": metrics["elapsed_s"], "rtf": None, "warning": "No audio chunks emitted.", } ) def log_error(self, exc: Exception, *, first_audio_time: float | None) -> None: metrics = _compute_rtf_metrics(self.samples_emitted, self.sample_rate, self.started_at) record = { "event": "stream_error", "error_type": type(exc).__name__, "error": str(exc), "chunk_count": self.chunk_count, "audio_s_emitted": metrics["audio_s"], "elapsed_s": metrics["elapsed_s"], "rtf": metrics["rtf"], } if first_audio_time is not None: record["time_to_first_audio_ms"] = max(0.0, (first_audio_time - self.started_at) * 1000.0) self._append(record) def _append(self, payload: dict[str, object]) -> None: record = { "ts": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), **payload, } with self.path.open("a", encoding="utf-8") as handle: handle.write(json.dumps(record, ensure_ascii=False) + "\n") class TokenChunkStream: def __init__( self, tokens: Sequence[int], chunk_size: int, ): self._tokens = list(tokens) self._chunk_size = int(chunk_size) def __iter__(self) -> Iterator[list[int]]: if not self._tokens: return step = len(self._tokens) if self._chunk_size <= 0 else self._chunk_size for idx in range(0, len(self._tokens), step): yield self._tokens[idx : idx + step] class BufferedAudioTracker: def __init__(self, sample_rate: int): self.sample_rate = sample_rate self.start_time: float | None = None self.samples_emitted = 0 def add_chunk(self, chunk: np.ndarray) -> None: if chunk.size == 0: return if self.start_time is None: self.start_time = time.monotonic() self.samples_emitted += int(chunk.size) def buffered_seconds(self) -> float: if self.start_time is None: return 0.0 elapsed = time.monotonic() - self.start_time buffered = self.samples_emitted / self.sample_rate - elapsed return max(0.0, buffered) class AudioFrameDecoder: def __init__( self, decoder: AudioStreamDecoder, codebook_size: int, audio_eos_token: int, ): self.decoder = decoder self.codebook_size = codebook_size self.audio_eos_token = audio_eos_token def decode_frames(self, audio_frames: list[torch.Tensor]) -> Iterator[np.ndarray]: for frame in audio_frames: tokens = frame if tokens.dim() == 3: tokens = tokens[0] if tokens.dim() != 2: raise ValueError(f"Expected [T, C] audio tokens, got {tuple(tokens.shape)}") tokens, stop = _sanitize_tokens(tokens, self.codebook_size, self.audio_eos_token) if tokens.numel() == 0: if stop: break continue self.decoder.push_tokens(tokens.detach()) for wav in self.decoder.audio_chunks(): if wav.numel() == 0: continue yield wav.detach().cpu().numpy().reshape(-1) if stop: break def flush(self) -> Iterator[np.ndarray]: final_chunk = self.decoder.flush() if final_chunk is not None and final_chunk.numel() > 0: yield final_chunk.detach().cpu().numpy().reshape(-1) class StreamAudioEmitter: def __init__(self, sample_rate: int, prebuffer_seconds: float): self.sample_rate = sample_rate self._buffer_tracker = BufferedAudioTracker(sample_rate) self._prebuffer_target = max(0.0, float(prebuffer_seconds)) self._prebuffering = self._prebuffer_target > 0.0 self._pending_chunks: list[np.ndarray] = [] self._pending_samples = 0 self.chunk_count = 0 self.has_audio = False def wait_for_capacity(self, threshold_seconds: float) -> None: _maybe_wait_for_buffer(self._buffer_tracker, threshold_seconds) def emit_many(self, chunks: Iterator[np.ndarray], message_prefix: str) -> Iterator[StreamEvent]: for chunk in chunks: yield from self.emit(chunk, message_prefix) def emit(self, chunk: np.ndarray, message_prefix: str) -> Iterator[StreamEvent]: chunk = np.asarray(chunk).reshape(-1) if chunk.size == 0: return if self._prebuffering: self._pending_chunks.append(chunk) self._pending_samples += int(chunk.size) if (self._pending_samples / self.sample_rate) < self._prebuffer_target: return self._prebuffering = False pending_chunks = self._pending_chunks self._pending_chunks = [] self._pending_samples = 0 for pending in pending_chunks: yield self._make_event(pending, message_prefix) return yield self._make_event(chunk, message_prefix) def flush(self, message_prefix: str) -> Iterator[StreamEvent]: if not self._prebuffering or not self._pending_chunks: self._prebuffering = False return self._prebuffering = False pending_chunks = self._pending_chunks self._pending_chunks = [] self._pending_samples = 0 for chunk in pending_chunks: yield self._make_event(chunk, message_prefix) def _make_event(self, chunk: np.ndarray, message_prefix: str) -> StreamEvent: self.chunk_count += 1 self.has_audio = True self._buffer_tracker.add_chunk(chunk) return StreamEvent( message=f"{message_prefix} chunk {self.chunk_count}", audio=(self.sample_rate, chunk), ) def _maybe_wait_for_buffer(buffer_tracker: BufferedAudioTracker, threshold_seconds: float) -> None: if threshold_seconds <= 0: return while buffer_tracker.buffered_seconds() > threshold_seconds: time.sleep(0.01) def _sanitize_tokens( tokens: torch.Tensor, codebook_size: int, audio_eos_token: int, ) -> tuple[torch.Tensor, bool]: if tokens.dim() == 1: tokens = tokens.unsqueeze(0) if tokens.numel() == 0: return tokens, False eos_rows = (tokens[:, 0] == audio_eos_token).nonzero(as_tuple=False) invalid_rows = ((tokens < 0) | (tokens >= codebook_size)).any(dim=1) stop_idx = None if eos_rows.numel() > 0: stop_idx = int(eos_rows[0].item()) if invalid_rows.any(): invalid_idx = int(invalid_rows.nonzero(as_tuple=False)[0].item()) stop_idx = invalid_idx if stop_idx is None else min(stop_idx, invalid_idx) if stop_idx is not None: tokens = tokens[:stop_idx] return tokens, True return tokens, False def _build_streaming_session( model: MossTTSRealtime, tokenizer, processor: MossTTSRealtimeProcessor, codec, *, max_length: int, chunk_duration: float, temperature: float, top_p: float, top_k: int, do_sample: bool, repetition_penalty: float, repetition_window: int, ) -> tuple[MossTTSRealtimeStreamingSession, MossTTSRealtimeInference]: inferencer = MossTTSRealtimeInference(model, tokenizer, max_length=max_length) inferencer.reset_generation_state(keep_cache=False) session = MossTTSRealtimeStreamingSession( inferencer, processor, codec=codec, codec_sample_rate=SAMPLE_RATE, codec_encode_kwargs={"chunk_duration": chunk_duration}, prefill_text_len=processor.delay_tokens_len, temperature=temperature, top_p=top_p, top_k=top_k, do_sample=do_sample, repetition_penalty=repetition_penalty, repetition_window=repetition_window, ) return session, inferencer def _build_frame_decoder( codec, inferencer: MossTTSRealtimeInference, device: torch.device, *, chunk_frames: int, overlap_frames: int, ) -> AudioFrameDecoder: decoder = AudioStreamDecoder( codec, chunk_frames=chunk_frames, overlap_frames=overlap_frames, decode_kwargs={"chunk_duration": -1}, device=device, ) return AudioFrameDecoder( decoder, int(getattr(codec, "codebook_size", 1024)), int(getattr(inferencer, "audio_eos_token", 1026)), ) def _normalize_seed(value: float | int | None) -> int | None: if value is None: return None seed = int(value) return None if seed == 0 else seed def _format_completion_status( chunk_count: int, sample_rate: int, full_audio: np.ndarray, started_at: float, first_audio_time: float | None, ) -> str: elapsed = time.monotonic() - started_at audio_seconds = float(full_audio.size) / float(sample_rate) if full_audio.size > 0 else 0.0 rtf = (elapsed / audio_seconds) if audio_seconds > 0 else float("inf") parts = [ "Done", ] return " | ".join(parts) @functools.lru_cache(maxsize=1) def _load_backend( model_path: str, tokenizer_path: str, codec_model_path: str, device_str: str, attn_impl: str, ): # ZeroGPU: do not call torch.cuda.is_available() here; it may trigger low-level CUDA init. device = torch.device(device_str) tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) processor = MossTTSRealtimeProcessor(tokenizer) # ZeroGPU: avoid torch.cuda.is_bf16_supported() before CUDA is fully managed. dtype = torch.bfloat16 if attn_impl and attn_impl.lower() not in {"none", ""}: model = MossTTSRealtime.from_pretrained(model_path, attn_implementation=attn_impl, torch_dtype=dtype).to(device) if ( attn_impl.lower() == "flash_attention_2" and hasattr(model, "language_model") and hasattr(model.language_model, "config") ): model.language_model.config.attn_implementation = "flash_attention_2" else: model = MossTTSRealtime.from_pretrained(model_path, torch_dtype=dtype).to(device) model.eval() codec = _load_codec(device, codec_model_path) return model, tokenizer, processor, codec, device def _resolve_audio_path(audio_path: str | None, use_default: bool, default_path: str | Path) -> Path | None: if audio_path: return Path(audio_path).expanduser() if use_default: return Path(default_path).expanduser() return None class StreamingTTSDemo: def __init__(self, audio_token_cache_size: int = 8): self._audio_token_cache_size = max(1, int(audio_token_cache_size)) self._audio_token_cache: OrderedDict[tuple[str, int, float], np.ndarray] = OrderedDict() def get_or_load_backend(self, backend: BackendPaths): return _load_backend( backend.model_path, backend.tokenizer_path, backend.codec_model_path, backend.device_str, backend.attn_impl, ) def _validate_request(self, request: StreamingRequest) -> tuple[Path | None, Path | None]: if not request.user_text.strip(): raise ValueError("user_text is required.") if not request.assistant_text.strip(): raise ValueError("assistant_text is required.") if request.streaming.text_chunk_tokens <= 0: raise ValueError("text_chunk_tokens must be greater than 0.") if request.streaming.decode_chunk_frames <= 0: raise ValueError("decode_chunk_frames must be greater than 0.") if request.streaming.chunk_duration <= 0: raise ValueError("chunk_duration must be greater than 0.") prompt_path = _resolve_audio_path(request.prompt_audio, request.use_default_prompt, PROMPT_WAV) user_path = _resolve_audio_path(request.user_audio, request.use_default_user, USER_WAV) if prompt_path is not None and not prompt_path.exists(): raise FileNotFoundError(f"Prompt wav not found: {prompt_path}") if user_path is not None and not user_path.exists(): raise FileNotFoundError(f"User wav not found: {user_path}") return prompt_path, user_path def _encode_audio_tokens( self, path: Path, codec, device: torch.device, chunk_duration: float, ) -> np.ndarray: resolved_path = path.expanduser().resolve() cache_key = (str(resolved_path), int(resolved_path.stat().st_mtime_ns), float(chunk_duration)) cached_tokens = self._audio_token_cache.get(cache_key) if cached_tokens is not None: self._audio_token_cache.move_to_end(cache_key) return cached_tokens with torch.inference_mode(): audio_tensor = _load_audio(resolved_path) waveform = audio_tensor.to(device) if waveform.dim() == 2: waveform = waveform.unsqueeze(0) encode_result = codec.encode(waveform, chunk_duration=chunk_duration) tokens = _extract_codes(encode_result) if isinstance(tokens, torch.Tensor): tokens = tokens.detach().cpu().numpy() else: tokens = np.asarray(tokens) self._audio_token_cache[cache_key] = tokens self._audio_token_cache.move_to_end(cache_key) while len(self._audio_token_cache) > self._audio_token_cache_size: self._audio_token_cache.popitem(last=False) return tokens @staticmethod def _build_text_only_turn_input( processor: MossTTSRealtimeProcessor, user_text: str, prompt_tokens: np.ndarray | None, ) -> np.ndarray: system_prompt = processor.make_ensemble(prompt_tokens) user_prompt_text = "<|im_end|>\n<|im_start|>user\n" + user_text + "<|im_end|>\n<|im_start|>assistant\n" user_prompt_tokens = processor.tokenizer(user_prompt_text)["input_ids"] user_prompt = np.full( shape=(len(user_prompt_tokens), processor.channels + 1), fill_value=processor.audio_channel_pad, dtype=np.int64, ) user_prompt[:, 0] = np.asarray(user_prompt_tokens, dtype=np.int64) return np.concatenate([system_prompt, user_prompt], axis=0) def _prepare_session_turn( self, session: MossTTSRealtimeStreamingSession, processor: MossTTSRealtimeProcessor, user_text: str, prompt_tokens: np.ndarray | None, user_tokens: np.ndarray | None, ) -> str | None: if user_tokens is None: turn_input_ids = self._build_text_only_turn_input(processor, user_text, prompt_tokens) session.reset_turn(input_ids=turn_input_ids, include_system_prompt=True, reset_cache=True) return "No user audio provided, running text-only turn." session.reset_turn( user_text=user_text, user_audio_tokens=user_tokens, include_system_prompt=True, reset_cache=True, ) return None def run_stream(self, request: StreamingRequest) -> Iterator[StreamEvent]: prompt_path, user_path = self._validate_request(request) model, tokenizer, processor, codec, device = self.get_or_load_backend(request.backend) _apply_seed(request.generation.seed) prompt_tokens = ( self._encode_audio_tokens( prompt_path, codec, device, chunk_duration=request.streaming.chunk_duration, ) if prompt_path is not None else None ) user_tokens = ( self._encode_audio_tokens( user_path, codec, device, chunk_duration=request.streaming.chunk_duration, ) if user_path is not None else None ) session, inferencer = _build_streaming_session( model, tokenizer, processor, codec, max_length=request.generation.max_length, chunk_duration=request.streaming.chunk_duration, temperature=request.generation.temperature, top_p=request.generation.top_p, top_k=request.generation.top_k, do_sample=request.generation.do_sample, repetition_penalty=request.generation.repetition_penalty, repetition_window=request.generation.repetition_window, ) if prompt_tokens is not None: session.set_voice_prompt_tokens(prompt_tokens) else: session.clear_voice_prompt() turn_message = self._prepare_session_turn( session, processor, request.user_text, prompt_tokens, user_tokens, ) if turn_message: yield StreamEvent(message=turn_message) frame_decoder = _build_frame_decoder( codec, inferencer, device, chunk_frames=request.streaming.decode_chunk_frames, overlap_frames=request.streaming.decode_overlap_frames, ) text_tokens = tokenizer.encode(request.assistant_text, add_special_tokens=False) if not text_tokens: raise RuntimeError("Assistant text tokenization returned no tokens.") token_stream = TokenChunkStream(text_tokens, request.streaming.text_chunk_tokens) audio_emitter = StreamAudioEmitter(SAMPLE_RATE, request.streaming.prebuffer_seconds) with codec.streaming(batch_size=1): for token_chunk in token_stream: audio_emitter.wait_for_capacity(request.streaming.buffer_threshold_seconds) audio_frames = session.push_text_tokens(token_chunk) yield from audio_emitter.emit_many(frame_decoder.decode_frames(audio_frames), "Streaming") if request.streaming.input_delay > 0: time.sleep(request.streaming.input_delay) final_frames = session.end_text() yield from audio_emitter.emit_many(frame_decoder.decode_frames(final_frames), "Finalizing") while True: drain_frames = session.drain(max_steps=1) if not drain_frames: break yield from audio_emitter.emit_many(frame_decoder.decode_frames(drain_frames), "Finalizing") if session.inferencer.is_finished: break yield from audio_emitter.emit_many(frame_decoder.flush(), "Final") yield from audio_emitter.flush("Final") if not audio_emitter.has_audio: raise RuntimeError("No audio waveform chunks decoded from streaming inference.") yield StreamEvent(message="Streaming complete.") class WarmupManager: def __init__(self, tts_demo: "StreamingTTSDemo", backend: BackendPaths): self.tts_demo = tts_demo self.backend = backend self._lock = threading.Lock() self._thread: threading.Thread | None = None self._started = False # ZeroGPU: startup warmup is disabled because it initializes CUDA outside @spaces.GPU. self._state = "ready" self._progress = 1.0 self._message = "Ready." self._detail = "Startup warmup disabled for ZeroGPU; the first generation will load the model." self._error: str | None = None def start(self) -> None: with self._lock: if self._started: return self._started = True self._thread = threading.Thread(target=self._run, name="tts-startup-warmup", daemon=True) self._thread.start() def snapshot(self) -> WarmupSnapshot: with self._lock: return WarmupSnapshot( state=self._state, progress=self._progress, message=self._message, detail=self._detail, error=self._error, ) def _set_state( self, *, state: str | None = None, progress: float | None = None, message: str | None = None, detail: str | None = None, error: str | None = None, ) -> None: with self._lock: if state is not None: self._state = state if progress is not None: self._progress = max(0.0, min(1.0, float(progress))) if message is not None: self._message = message if detail is not None: self._detail = detail self._error = error @staticmethod def _consume_audio(chunks: Iterator[np.ndarray]) -> None: for _chunk in chunks: pass @staticmethod def _ensure_warmup_text(tokenizer, minimum_tokens: int) -> tuple[str, list[int]]: text = WARMUP_BASE_ASSISTANT_TEXT tokens = tokenizer.encode(text, add_special_tokens=False) while len(tokens) < minimum_tokens: text = f"{text} {WARMUP_BASE_ASSISTANT_TEXT}" tokens = tokenizer.encode(text, add_special_tokens=False) return text, tokens @staticmethod def _warmup_step_detail(step_idx: int, total_steps: int) -> str: if step_idx == 1: return "First incremental step is compiling the cold streaming path." if step_idx == 2: return "Second incremental step is warming the next steady-state path." if step_idx == DEFAULT_REPETITION_WINDOW: return "Warming the first full repetition-window step." if step_idx == WARMUP_STEP_TOKENS: return "Confirming the post-window steady-state step." return f"Warming token step {step_idx}/{total_steps}." def _run(self) -> None: try: self._set_state( state="running", progress=0.02, message="Starting startup warmup.", detail="Preparing backend state for the first real request.", error=None, ) self._set_state( progress=0.08, message="Loading backend.", detail="Model, tokenizer, codec, and CUDA runtime are warming up.", error=None, ) model, tokenizer, processor, codec, device = self.tts_demo.get_or_load_backend(self.backend) self._set_state( progress=0.32, message="Preparing streaming session.", detail="Building a text-only warmup turn and its decoder.", error=None, ) session, inferencer = _build_streaming_session( model, tokenizer, processor, codec, max_length=256, chunk_duration=0.24, temperature=0.8, top_p=0.6, top_k=30, do_sample=True, repetition_penalty=1.1, repetition_window=DEFAULT_REPETITION_WINDOW, ) session.clear_voice_prompt() session.reset_turn( input_ids=self.tts_demo._build_text_only_turn_input(processor, WARMUP_USER_TEXT, None), include_system_prompt=True, reset_cache=True, ) frame_decoder = _build_frame_decoder( codec, inferencer, device, chunk_frames=WARMUP_STEP_TOKENS, overlap_frames=0, ) _, warmup_tokens = self._ensure_warmup_text( tokenizer, processor.delay_tokens_len + WARMUP_STEP_TOKENS, ) with codec.streaming(batch_size=1): self._set_state( progress=0.45, message="Running prefill.", detail="Building the first KV cache and warming the backbone path.", error=None, ) prefill_frames = session.push_text_tokens(warmup_tokens[: processor.delay_tokens_len]) self._consume_audio(frame_decoder.decode_frames(prefill_frames)) step_tokens = warmup_tokens[ processor.delay_tokens_len : processor.delay_tokens_len + WARMUP_STEP_TOKENS ] total_steps = max(1, len(step_tokens)) for idx, token in enumerate(step_tokens, start=1): self._set_state( progress=0.55 + 0.25 * (idx - 1) / total_steps, message="Compiling first streaming steps.", detail=self._warmup_step_detail(idx, total_steps), error=None, ) step_frames = session.push_text_tokens([token]) self._consume_audio(frame_decoder.decode_frames(step_frames)) self._set_state( progress=0.86, message="Warming finalization path.", detail="Priming end-text, drain, and decoder flush before user traffic.", error=None, ) final_frames = session.end_text() self._consume_audio(frame_decoder.decode_frames(final_frames)) drain_frames = session.drain(max_steps=1) self._consume_audio(frame_decoder.decode_frames(drain_frames)) self._consume_audio(frame_decoder.flush()) self._set_state( state="ready", progress=1.0, message="Warmup complete.", detail="The first real request should avoid the cold-start stall.", error=None, ) except Exception as exc: self._set_state( state="failed", progress=1.0, message="Warmup failed.", detail="The app did not finish startup warmup.", error=str(exc), ) print(f"[MossTTSRealtime][warmup-error] {exc}", file=sys.stderr, flush=True) def _warmup_button_update(snapshot: WarmupSnapshot): if snapshot.ready: return gr.update(value="Generate", interactive=True) if snapshot.failed: return gr.update(value="Warmup Failed", interactive=False) return gr.update(value="Warming Up...", interactive=False) def _warmup_gate_message(snapshot: WarmupSnapshot) -> str: progress_pct = int(round(max(0.0, min(1.0, snapshot.progress)) * 100.0)) if snapshot.failed: return f"Warmup failed: {snapshot.error or snapshot.message}" return f"Warmup in progress ({progress_pct}%): {snapshot.message}" def _status_from_snapshot(snapshot: WarmupSnapshot) -> str: return "Ready." if snapshot.ready else _warmup_gate_message(snapshot) def _warmup_status_update(snapshot: WarmupSnapshot): return gr.update(value=_status_from_snapshot(snapshot)) def _warmup_timer_update(snapshot: WarmupSnapshot): return gr.update(active=not (snapshot.ready or snapshot.failed)) def _encode_chunk(sr: int, chunk: np.ndarray, idx: int) -> str: if chunk.dtype != np.float32: chunk = chunk.astype(np.float32) if chunk.ndim != 1: chunk = chunk.reshape(-1) payload = { "sr": int(sr), "idx": int(idx), "data": base64.b64encode(chunk.tobytes()).decode("ascii"), } return json.dumps(payload) def _build_request( args: argparse.Namespace, *, user_text: str | None, assistant_text: str | None, prompt_audio: str | None, user_audio: str | None, use_default_prompt: bool, use_default_user: bool, temperature: float, top_p: float, top_k: int, repetition_penalty: float, repetition_window: int, do_sample: bool, max_length: int, seed: float | int | None, text_chunk_tokens: int, input_delay: float, decode_chunk_frames: int, decode_overlap_frames: int, chunk_duration: float, prebuffer_seconds: float, ) -> StreamingRequest: return StreamingRequest( user_text=str(user_text or "Hello!"), assistant_text=str(assistant_text or ""), prompt_audio=prompt_audio, user_audio=user_audio, use_default_prompt=use_default_prompt, use_default_user=use_default_user, generation=GenerationConfig( temperature=float(temperature), top_p=float(top_p), top_k=int(top_k), repetition_penalty=float(repetition_penalty), repetition_window=int(repetition_window), do_sample=bool(do_sample), max_length=int(max_length), seed=_normalize_seed(seed), ), streaming=StreamingConfig( text_chunk_tokens=int(text_chunk_tokens), input_delay=float(input_delay), decode_chunk_frames=int(decode_chunk_frames), decode_overlap_frames=int(decode_overlap_frames), chunk_duration=float(chunk_duration), prebuffer_seconds=float(prebuffer_seconds), ), backend=BackendPaths( model_path=args.model_path, tokenizer_path=args.tokenizer_path, codec_model_path=args.codec_model_path, device_str=args.device, attn_impl=args.attn_implementation, ), ) STREAM_PLAYER_HTML = """ """ STREAM_PLAYER_JS = r""" const elemId = "pcm_stream"; if (window.__pcm_streaming_inited__) { return; } window.__pcm_streaming_inited__ = true; let audioCtx = null; let nextTime = 0; let lastIdx = -1; let lastValue = ""; let boundField = null; let usingSetterHook = false; const FADE_MS = 6; const MIN_BUFFER_SEC = 0.25; function initAudio(sr) { if (audioCtx && audioCtx.sampleRate !== sr) { audioCtx.close(); audioCtx = null; } if (!audioCtx) { audioCtx = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: sr }); nextTime = audioCtx.currentTime; } if (audioCtx.state === "suspended") { audioCtx.resume(); } } function decodeBase64ToFloat32(base64) { const binary = atob(base64); const len = binary.length; const bytes = new Uint8Array(len); for (let i = 0; i < len; i++) { bytes[i] = binary.charCodeAt(i); } return new Float32Array(bytes.buffer); } function playChunk(samples, sr, idx) { initAudio(sr); const buffer = audioCtx.createBuffer(1, samples.length, sr); buffer.copyToChannel(samples, 0); const source = audioCtx.createBufferSource(); source.buffer = buffer; const gain = audioCtx.createGain(); source.connect(gain); gain.connect(audioCtx.destination); const now = audioCtx.currentTime; if (nextTime < now + MIN_BUFFER_SEC) { nextTime = now + MIN_BUFFER_SEC; } const startTime = Math.max(now, nextTime); const endTime = startTime + buffer.duration; const fade = Math.min(FADE_MS / 1000.0, buffer.duration / 4); gain.gain.setValueAtTime(0.0, startTime); gain.gain.linearRampToValueAtTime(1.0, startTime + fade); gain.gain.setValueAtTime(1.0, Math.max(startTime + fade, endTime - fade)); gain.gain.linearRampToValueAtTime(0.0, endTime); source.start(startTime); nextTime = endTime; } function handlePayload(text) { if (!text) return; let payload; try { payload = JSON.parse(text); } catch (e) { return; } if (Array.isArray(payload)) { for (const item of payload) { handlePayloadObject(item); } return; } handlePayloadObject(payload); } function handlePayloadObject(payload) { if (!payload) return; if (payload.reset) { lastIdx = -1; lastValue = ""; if (audioCtx) { audioCtx.close(); audioCtx = null; } return; } const idx = payload.idx ?? 0; if (idx <= lastIdx) return; lastIdx = idx; const sr = payload.sr || 24000; const samples = decodeBase64ToFloat32(payload.data); playChunk(samples, sr, idx); } function hookField(field) { if (!field || field === boundField) return; boundField = field; const proto = field.tagName === "TEXTAREA" ? HTMLTextAreaElement.prototype : HTMLInputElement.prototype; const desc = Object.getOwnPropertyDescriptor(proto, "value"); if (!desc || !desc.get || !desc.set) { usingSetterHook = false; return; } usingSetterHook = true; const nativeGet = desc.get; const nativeSet = desc.set; Object.defineProperty(field, "value", { configurable: true, get() { return nativeGet.call(field); }, set(v) { nativeSet.call(field, v); if (v && v !== lastValue) { lastValue = v; handlePayload(v); } }, }); const initial = field.value; if (initial && initial !== lastValue) { lastValue = initial; handlePayload(initial); } } function pollField() { const field = document.querySelector(`#${elemId} textarea, #${elemId} input`); if (!field) { boundField = null; usingSetterHook = false; setTimeout(pollField, 300); return; } if (field !== boundField) { hookField(field); } setTimeout(pollField, 300); } function pollValue() { if (usingSetterHook) { setTimeout(pollValue, 500); return; } const field = document.querySelector(`#${elemId} textarea, #${elemId} input`); if (!field) { setTimeout(pollValue, 300); return; } const value = field.value; if (value && value !== lastValue) { lastValue = value; handlePayload(value); } setTimeout(pollValue, 40); } function tryUnlockAudio() { if (!audioCtx) { audioCtx = new (window.AudioContext || window.webkitAudioContext)(); } if (audioCtx.state === "suspended") { audioCtx.resume(); } } document.addEventListener("click", (event) => { const btn = event.target.closest("#tts_generate"); if (btn) { tryUnlockAudio(); } }); pollField(); pollValue(); """ def _build_demo( args: argparse.Namespace, tts_demo: StreamingTTSDemo, warmup_manager: WarmupManager, ): initial_warmup_snapshot = warmup_manager.snapshot() with gr.Blocks(title="MossTTSRealtime") as demo: gr.Markdown("MossTTSRealtime demo") gr.HTML(STREAM_PLAYER_HTML, js_on_load=STREAM_PLAYER_JS) with gr.Row(): with gr.Column(): assistant_text = gr.Textbox(label="Assistant Text", lines=6) prompt_audio = gr.Audio(label="Prompt WAV (optional)", type="filepath") with gr.Accordion("User Input Options", open=False): user_text = gr.Textbox(label="User Text(optional)", lines=2) user_audio = gr.Audio(label="User WAV (optional)", type="filepath") use_default_prompt = gr.Checkbox(label="Use Default Prompt WAV (fallback)", value=False) use_default_user = gr.Checkbox(label="Use Default User WAV (fallback)", value=False) with gr.Accordion("Generation Options", open=False): temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.05, label="Temperature") top_p = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="Top P") top_k = gr.Slider(1, 100, value=30, step=1, label="Top K") repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.05, label="Repetition Penalty") repetition_window = gr.Slider( 1, 200, value=DEFAULT_REPETITION_WINDOW, step=1, label="Repetition Window" ) do_sample = gr.Checkbox(label="Do Sample", value=True) max_length = gr.Slider(100, 10000, value=2000, step=10, label="Max Length") seed = gr.Number(value=0, precision=0, label="Seed (0 for random)") with gr.Accordion("Streaming Options", open=False): stream_text_chunk_tokens = gr.Slider(1, 64, value=12, step=1, label="Text Chunk Tokens") stream_input_delay = gr.Slider(0.0, 0.5, value=0.0, step=0.05, label="Input Delay (s)") stream_decode_chunk_frames = gr.Slider(1, 20, value=6, step=1, label="Decode Chunk Frames") stream_decode_overlap_frames = gr.Slider(0, 10, value=0, step=1, label="Decode Overlap Frames") chunk_duration = gr.Slider(0.08, 4.0, value=0.96, step=0.08, label="Codec Chunk Duration (s)") stream_prebuffer_seconds = gr.Slider(0.0, 20.0, value=0.0, step=0.05, label="Initial Buffer (s)") run_btn = gr.Button( "Generate" if initial_warmup_snapshot.ready else "Warming Up...", elem_id="tts_generate", interactive=initial_warmup_snapshot.ready, ) with gr.Column(): stream_data = gr.Textbox(label="PCM Stream (JSON)", elem_id="pcm_stream", interactive=False, lines=6) output_audio = gr.Audio(label="Final Audio", type="numpy") initial_status = _status_from_snapshot(initial_warmup_snapshot) status = gr.Textbox(label="Status", lines=3, value=initial_status) warmup_timer = gr.Timer(value=WARMUP_POLL_INTERVAL_SECONDS, active=not initial_warmup_snapshot.ready) def _poll_warmup_state(): snapshot = warmup_manager.snapshot() return ( _warmup_button_update(snapshot), _warmup_status_update(snapshot), _warmup_timer_update(snapshot), ) @spaces.GPU def _on_generate( user_text_value, assistant_text_value, prompt_audio_value, user_audio_value, use_default_prompt_value, use_default_user_value, temperature_value, top_p_value, top_k_value, repetition_penalty_value, repetition_window_value, do_sample_value, max_length_value, seed_value, stream_text_chunk_tokens_value, stream_input_delay_value, stream_decode_chunk_frames_value, stream_decode_overlap_frames_value, chunk_duration_value, stream_prebuffer_seconds_value, ): try: started_at = time.monotonic() full_chunks: list[np.ndarray] = [] first_audio_time: float | None = None sample_rate = SAMPLE_RATE rtf_logger: StreamRTFLogger | None = None request = _build_request( args, user_text=user_text_value, assistant_text=assistant_text_value, prompt_audio=prompt_audio_value, user_audio=user_audio_value, use_default_prompt=bool(use_default_prompt_value), use_default_user=bool(use_default_user_value), temperature=float(temperature_value), top_p=float(top_p_value), top_k=int(top_k_value), repetition_penalty=float(repetition_penalty_value), repetition_window=int(repetition_window_value), do_sample=bool(do_sample_value), max_length=int(max_length_value), seed=seed_value, text_chunk_tokens=int(stream_text_chunk_tokens_value), input_delay=float(stream_input_delay_value), decode_chunk_frames=int(stream_decode_chunk_frames_value), decode_overlap_frames=int(stream_decode_overlap_frames_value), chunk_duration=float(chunk_duration_value), prebuffer_seconds=float(stream_prebuffer_seconds_value), ) rtf_logger = StreamRTFLogger.create(request, started_at) for event in tts_demo.run_stream(request): if event.audio is None: continue sr, chunk = event.audio chunk = np.asarray(chunk).reshape(-1) if chunk.size == 0: continue full_chunks.append(chunk) sample_rate = sr if first_audio_time is None: first_audio_time = time.monotonic() if rtf_logger is not None: rtf_logger.log_chunk( event_message=event.message, sample_rate=sr, chunk=chunk, first_audio_time=first_audio_time, ) if full_chunks: full_audio = np.concatenate(full_chunks) if rtf_logger is not None: rtf_logger.log_completion(first_audio_time=first_audio_time) done_msg = _format_completion_status( len(full_chunks), sample_rate, full_audio, started_at, first_audio_time, ) return "", (sample_rate, full_audio), done_msg if rtf_logger is not None: rtf_logger.log_no_audio() return "", None, "Done | no audio chunks emitted" except Exception as exc: import traceback traceback.print_exc() if rtf_logger is not None: rtf_logger.log_error(exc, first_audio_time=first_audio_time) return "", None, f"Error: {exc}" run_btn.click( _on_generate, inputs=[ user_text, assistant_text, prompt_audio, user_audio, use_default_prompt, use_default_user, temperature, top_p, top_k, repetition_penalty, repetition_window, do_sample, max_length, seed, stream_text_chunk_tokens, stream_input_delay, stream_decode_chunk_frames, stream_decode_overlap_frames, chunk_duration, stream_prebuffer_seconds, ], outputs=[stream_data, output_audio, status], ) demo.load( _poll_warmup_state, outputs=[run_btn, status, warmup_timer], queue=False, show_progress="hidden", ) warmup_timer.tick( _poll_warmup_state, outputs=[run_btn, status, warmup_timer], queue=False, show_progress="hidden", ) return demo def main(): parser = argparse.ArgumentParser(description="MossTTSRealtime streaming TTS Gradio demo") parser.add_argument("--model_path", type=str, default=MODEL_PATH) parser.add_argument("--tokenizer_path", type=str, default=TOKENIZER_PATH) parser.add_argument("--codec_model_path", type=str, default=CODEC_MODEL_PATH) parser.add_argument("--device", type=str, default="cuda:0") parser.add_argument( "--attn_implementation", type=str, default="sdpa", choices=["sdpa", "flash_attention_2", "eager", "none"], ) parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--port", type=int, default=7860) parser.add_argument("--share", action="store_true") args = parser.parse_args() tts_demo = StreamingTTSDemo() warmup_manager = WarmupManager( tts_demo, BackendPaths( model_path=args.model_path, tokenizer_path=args.tokenizer_path, codec_model_path=args.codec_model_path, device_str=args.device, attn_impl=args.attn_implementation, ), ) # ZeroGPU: do not run startup warmup, because it would initialize CUDA # in a background thread outside @spaces.GPU. # warmup_manager.start() demo = _build_demo(args, tts_demo, warmup_manager) demo.queue(max_size=10, default_concurrency_limit=1).launch( server_name=args.host, server_port=args.port, share=args.share, ssr_mode=False, ) if __name__ == "__main__": main()