# Nano-TTS Hugging Face Space (Spaces: Running on Zero / ZeroGPU).
from __future__ import annotations

import importlib
import logging
import threading
import time
import uuid
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Iterator, Optional

import numpy as np
import torch
from transformers import AutoModel, AutoModelForCausalLM
| MOSS_AUDIO_TOKENIZER_TYPE = "moss-audio-tokenizer-nano" | |
| APP_DIR = Path(__file__).resolve().parent | |
| DEFAULT_CHECKPOINT_PATH = APP_DIR / "weights" / "tts" | |
| DEFAULT_AUDIO_TOKENIZER_PATH = APP_DIR / "weights" / "codec" | |
| DEFAULT_PROMPT_AUDIO_DIR = APP_DIR / "asserts" / "audio" | |
| DEFAULT_OUTPUT_DIR = APP_DIR / "generated_audio" | |
| _DEFAULT_VOICE_FILES: dict[str, tuple[str, str]] = { | |
| "Junhao": ("zh_1.wav", "Chinese male voice A"), | |
| "Zhiming": ("zh_2.wav", "Chinese male voice B"), | |
| "Weiguo": ("zh_5.wav", "Chinese male voice C"), | |
| "Xiaoyu": ("zh_3.wav", "Chinese female voice A"), | |
| "Yuewen": ("zh_4.wav", "Chinese female voice B"), | |
| "Lingyu": ("zh_6.wav", "Chinese female voice C"), | |
| "Trump": ("en_1.wav", "Trump reference voice"), | |
| "Ava": ("en_2.wav", "English female voice A"), | |
| "Bella": ("en_3.wav", "English female voice B"), | |
| "Adam": ("en_4.wav", "English male voice A"), | |
| "Nathan": ("en_5.wav", "English male voice B"), | |
| "Yui": ("jp_2.wav", "Japanese female voice B"), | |
| "Aoi": ("jp_3.wav", "Japanese female voice C"), | |
| "Hina": ("jp_4.wav", "Japanese female voice D"), | |
| "Mei": ("jp_5.wav", "Japanese female voice E"), | |
| } | |
| DEFAULT_VOICE = "Junhao" | |
| FLASH_ATTENTION_DTYPES = {torch.float16, torch.bfloat16} | |
| class VoicePreset: | |
| name: str | |
| prompt_audio_path: Path | |
| description: str | |
| def build_default_voice_presets() -> dict[str, VoicePreset]: | |
| presets: dict[str, VoicePreset] = {} | |
| for voice_name, (file_name, description) in _DEFAULT_VOICE_FILES.items(): | |
| prompt_path = (DEFAULT_PROMPT_AUDIO_DIR / file_name).resolve() | |
| presets[voice_name] = VoicePreset( | |
| name=voice_name, | |
| prompt_audio_path=prompt_path, | |
| description=description, | |
| ) | |
| return presets | |
| def resolve_device(device_arg: str) -> torch.device: | |
| if device_arg == "auto": | |
| return torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| return torch.device(device_arg) | |
| def resolve_dtype(dtype_arg: str, device: torch.device) -> torch.dtype: | |
| if dtype_arg == "float32": | |
| return torch.float32 | |
| if dtype_arg == "float16": | |
| return torch.float16 | |
| if dtype_arg == "bfloat16": | |
| return torch.bfloat16 | |
| if device.type == "cuda": | |
| if hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported(): | |
| return torch.bfloat16 | |
| return torch.float16 | |
| return torch.float32 | |
| def waveform_to_numpy(waveform: torch.Tensor | np.ndarray) -> np.ndarray: | |
| if torch.is_tensor(waveform): | |
| array = waveform.detach().cpu().numpy() | |
| else: | |
| array = np.asarray(waveform) | |
| if array.ndim == 1: | |
| return array.astype(np.float32, copy=False) | |
| if array.ndim != 2: | |
| raise ValueError(f"Unsupported waveform shape: {array.shape}") | |
| if array.shape[0] <= 8 and array.shape[0] < array.shape[1]: | |
| array = array.T | |
| return array.astype(np.float32, copy=False) | |
| def _has_flash_attn() -> bool: | |
| try: | |
| importlib.import_module("flash_attn") | |
| except Exception: | |
| return False | |
| return True | |
| class NanoTTSService: | |
| def __init__( | |
| self, | |
| *, | |
| checkpoint_path: str | Path = DEFAULT_CHECKPOINT_PATH, | |
| audio_tokenizer_path: str | Path = DEFAULT_AUDIO_TOKENIZER_PATH, | |
| device: str = "auto", | |
| dtype: str = "auto", | |
| attn_implementation: str = "auto", | |
| output_dir: str | Path = DEFAULT_OUTPUT_DIR, | |
| voice_presets: Optional[dict[str, VoicePreset]] = None, | |
| ) -> None: | |
| self.checkpoint_path = Path(checkpoint_path).expanduser().resolve() | |
| self.audio_tokenizer_path = Path(audio_tokenizer_path).expanduser().resolve() | |
| self.output_dir = Path(output_dir).expanduser().resolve() | |
| self.output_dir.mkdir(parents=True, exist_ok=True) | |
| self.voice_presets = voice_presets or build_default_voice_presets() | |
| self.default_voice = DEFAULT_VOICE if DEFAULT_VOICE in self.voice_presets else next(iter(self.voice_presets)) | |
| self.device = resolve_device(device) | |
| self.dtype = resolve_dtype(dtype, self.device) | |
| self.attn_implementation = self._resolve_attn_implementation(attn_implementation) | |
| self._lock = threading.RLock() | |
| self._model = None | |
| self._audio_tokenizer = None | |
| self._checkpoint_global_attn_implementation: str | None = None | |
| self._checkpoint_local_attn_implementation: str | None = None | |
| self._configured_global_attn_implementation: str | None = None | |
| self._configured_local_attn_implementation: str | None = None | |
| self._configured_audio_tokenizer_attn_implementation: str | None = None | |
| self._configured_audio_tokenizer_compute_dtype: str | None = None | |
| def _can_use_flash_attention(self) -> bool: | |
| return self.device.type == "cuda" and self.dtype in FLASH_ATTENTION_DTYPES and _has_flash_attn() | |
| def _resolve_runtime_default_attn_implementation(self) -> str: | |
| return "flash_attention_2" if self._can_use_flash_attention() else "sdpa" | |
| def _resolve_attn_implementation(self, requested: str | None) -> str | None: | |
| normalized = str(requested).strip().lower() if requested is not None else "auto" | |
| if not normalized or normalized in {"auto", "default", "model_default"}: | |
| return None | |
| if normalized not in {"sdpa", "flash_attention_2", "eager"}: | |
| raise ValueError( | |
| "attn_implementation must be one of: model_default/auto, sdpa, flash_attention_2, eager" | |
| ) | |
| if normalized == "flash_attention_2": | |
| if not self._can_use_flash_attention(): | |
| logging.warning( | |
| "flash_attention_2 requires CUDA, flash_attn, and fp16/bf16; falling back to sdpa " | |
| "(device=%s dtype=%s flash_attn=%s)", | |
| self.device, | |
| self.dtype, | |
| _has_flash_attn(), | |
| ) | |
| return "sdpa" | |
| return normalized | |
| def _normalize_loaded_attn_implementation(attn_implementation: object) -> str: | |
| normalized = str(attn_implementation).strip().lower() if attn_implementation is not None else "" | |
| if not normalized or normalized == "none": | |
| return "eager" | |
| return normalized | |
| def _resolve_request_attention_implementation( | |
| self, | |
| requested: Optional[str], | |
| ) -> tuple[str, str, str]: | |
| normalized = str(requested).strip().lower() if requested is not None else "" | |
| resolved = self._resolve_attn_implementation(normalized) | |
| if resolved is not None: | |
| return normalized, resolved, resolved | |
| if self.attn_implementation is not None: | |
| return self.attn_implementation, self.attn_implementation, self.attn_implementation | |
| runtime_default = self._resolve_runtime_default_attn_implementation() | |
| return "auto", runtime_default, runtime_default | |
| def _resolve_codec_attention_implementation(tts_attn_implementation: str) -> str: | |
| return "flash_attention_2" if tts_attn_implementation == "flash_attention_2" else "sdpa" | |
| def _resolve_codec_compute_dtype(self, codec_attn_implementation: str) -> str: | |
| if codec_attn_implementation == "flash_attention_2": | |
| return "bf16" if self.dtype == torch.bfloat16 else "fp16" | |
| return "fp32" | |
| def _apply_model_attention_implementation(model, *, global_attn: str, local_attn: str) -> None: | |
| if hasattr(model, "_set_attention_implementation"): | |
| model._set_attention_implementation(global_attn, local_attn_implementation=local_attn) | |
| def _install_stream_decode_budget_patch(self, model) -> None: | |
| if self.device.type != "cuda": | |
| return | |
| model_cls = model.__class__ | |
| if getattr(model_cls, "_nanotts_stream_decode_budget_patch_installed", False): | |
| return | |
| compute_stream_lead = getattr(model_cls, "_compute_stream_lead_seconds", None) | |
| resolve_stream_budget = getattr(model_cls, "_resolve_stream_decode_frame_budget", None) | |
| if not callable(compute_stream_lead) or not callable(resolve_stream_budget): | |
| return | |
| def _patched_resolve_stream_decode_frame_budget( | |
| *, | |
| emitted_samples_total: int, | |
| sample_rate: int, | |
| first_audio_emitted_at, | |
| ) -> int: | |
| # The upstream streaming policy starts with one decode frame | |
| # (about 80 ms audio), which makes CUDA realtime decode emit many | |
| # tiny chunks and underrun browser playback on this checkpoint. | |
| lead_seconds = compute_stream_lead( | |
| emitted_samples_total=emitted_samples_total, | |
| sample_rate=sample_rate, | |
| first_audio_emitted_at=first_audio_emitted_at, | |
| ) | |
| if first_audio_emitted_at is None or lead_seconds < 0.20: | |
| return 4 | |
| if lead_seconds < 0.55: | |
| return 6 | |
| if lead_seconds < 1.10: | |
| return 8 | |
| return 12 | |
| model_cls._nanotts_original_resolve_stream_decode_frame_budget = resolve_stream_budget | |
| model_cls._resolve_stream_decode_frame_budget = staticmethod(_patched_resolve_stream_decode_frame_budget) | |
| model_cls._nanotts_stream_decode_budget_patch_installed = True | |
| logging.info("installed Nano-TTS CUDA streaming decode budget patch") | |
| def _discard_loaded_model_locked(self, reason: str) -> None: | |
| if self._model is None: | |
| return | |
| logging.warning("discarding loaded Nano-TTS model state: %s", reason) | |
| self._model = None | |
| if self.device.type == "cuda": | |
| torch.cuda.empty_cache() | |
| def _discard_loaded_audio_tokenizer_locked(self, reason: str) -> None: | |
| if self._audio_tokenizer is None: | |
| return | |
| logging.warning("discarding loaded Nano-TTS audio tokenizer state: %s", reason) | |
| self._audio_tokenizer = None | |
| self._configured_audio_tokenizer_attn_implementation = None | |
| self._configured_audio_tokenizer_compute_dtype = None | |
| if self.device.type == "cuda": | |
| torch.cuda.empty_cache() | |
| def _restore_model_execution_state(self, model): | |
| current_parameter = next(model.parameters(), None) | |
| if current_parameter is None or current_parameter.dtype == self.dtype: | |
| return model | |
| self._discard_loaded_model_locked( | |
| f"current_dtype={current_parameter.dtype} expected_dtype={self.dtype}; reloading checkpoint" | |
| ) | |
| return self._load_model_locked() | |
| def _read_model_attention_implementation(self, model) -> tuple[str, str]: | |
| global_attn = self._normalize_loaded_attn_implementation( | |
| getattr(getattr(model, "transformer", None), "attn_implementation", None) | |
| ) | |
| local_attn = self._normalize_loaded_attn_implementation( | |
| getattr(getattr(model, "local_transformer", None), "attn_implementation", None) | |
| ) | |
| return global_attn, local_attn | |
| def _ensure_paths(self) -> None: | |
| if not self.checkpoint_path.exists(): | |
| raise FileNotFoundError(f"Nano-TTS checkpoint not found: {self.checkpoint_path}") | |
| if not self.audio_tokenizer_path.exists(): | |
| raise FileNotFoundError(f"Audio tokenizer checkpoint not found: {self.audio_tokenizer_path}") | |
| def _load_audio_tokenizer_locked(self, *, tts_attn_implementation: str): | |
| codec_attn_implementation = self._resolve_codec_attention_implementation(tts_attn_implementation) | |
| codec_compute_dtype = self._resolve_codec_compute_dtype(codec_attn_implementation) | |
| if self._audio_tokenizer is None: | |
| logging.info( | |
| "loading Nano-TTS audio tokenizer checkpoint=%s device=%s attn=%s compute_dtype=%s", | |
| self.audio_tokenizer_path, | |
| self.device, | |
| codec_attn_implementation, | |
| codec_compute_dtype, | |
| ) | |
| audio_tokenizer = AutoModel.from_pretrained( | |
| str(self.audio_tokenizer_path), | |
| trust_remote_code=True, | |
| local_files_only=True, | |
| ) | |
| if hasattr(audio_tokenizer, "eval"): | |
| audio_tokenizer.eval() | |
| self._audio_tokenizer = audio_tokenizer | |
| audio_tokenizer = self._audio_tokenizer | |
| if hasattr(audio_tokenizer, "to"): | |
| audio_tokenizer = audio_tokenizer.to(self.device) | |
| if hasattr(audio_tokenizer, "set_attention_implementation"): | |
| audio_tokenizer.set_attention_implementation(codec_attn_implementation) | |
| if hasattr(audio_tokenizer, "set_compute_dtype"): | |
| audio_tokenizer.set_compute_dtype(codec_compute_dtype) | |
| if hasattr(audio_tokenizer, "eval"): | |
| audio_tokenizer.eval() | |
| self._audio_tokenizer = audio_tokenizer | |
| self._configured_audio_tokenizer_attn_implementation = codec_attn_implementation | |
| self._configured_audio_tokenizer_compute_dtype = codec_compute_dtype | |
| return self._audio_tokenizer | |
| def _load_model_locked(self): | |
| if self._model is not None: | |
| return self._model | |
| self._ensure_paths() | |
| logging.info( | |
| "loading Nano-TTS checkpoint=%s audio_tokenizer=%s device=%s dtype=%s attn=%s", | |
| self.checkpoint_path, | |
| self.audio_tokenizer_path, | |
| self.device, | |
| self.dtype, | |
| self.attn_implementation or "model_default", | |
| ) | |
| model = AutoModelForCausalLM.from_pretrained( | |
| str(self.checkpoint_path), | |
| trust_remote_code=True, | |
| local_files_only=True, | |
| ) | |
| model.to(device=self.device, dtype=self.dtype) | |
| self._checkpoint_global_attn_implementation, self._checkpoint_local_attn_implementation = ( | |
| self._read_model_attention_implementation(model) | |
| ) | |
| _, default_global_attn, default_local_attn = self._resolve_request_attention_implementation(None) | |
| self._apply_model_attention_implementation( | |
| model, | |
| global_attn=default_global_attn, | |
| local_attn=default_local_attn, | |
| ) | |
| self._install_stream_decode_budget_patch(model) | |
| model.eval() | |
| self._configured_global_attn_implementation, self._configured_local_attn_implementation = ( | |
| self._read_model_attention_implementation(model) | |
| ) | |
| self._model = model | |
| return self._model | |
| def get_model(self): | |
| with self._lock: | |
| return self._load_model_locked() | |
| def list_voice_names(self) -> list[str]: | |
| return list(self.voice_presets.keys()) | |
| def get_voice_preset(self, voice_name: Optional[str]) -> VoicePreset: | |
| if voice_name and voice_name in self.voice_presets: | |
| return self.voice_presets[voice_name] | |
| return self.voice_presets[self.default_voice] | |
| def resolve_prompt_audio_path( | |
| self, | |
| *, | |
| voice: Optional[str] = None, | |
| prompt_audio_path: Optional[str | Path] = None, | |
| ) -> Path: | |
| if prompt_audio_path: | |
| resolved = Path(prompt_audio_path).expanduser().resolve() | |
| if not resolved.exists(): | |
| raise FileNotFoundError(f"Prompt audio not found: {resolved}") | |
| return resolved | |
| preset = self.get_voice_preset(voice) | |
| if not preset.prompt_audio_path.exists(): | |
| raise FileNotFoundError(f"Voice preset prompt audio not found: {preset.prompt_audio_path}") | |
| return preset.prompt_audio_path | |
| def preload(self, *, voices: Optional[list[str]] = None, load_model: bool = True) -> dict[str, object]: | |
| loaded_voices: list[str] = [] | |
| if load_model: | |
| self.get_model() | |
| for voice_name in voices or [self.default_voice]: | |
| preset = self.get_voice_preset(voice_name) | |
| if preset.prompt_audio_path.exists(): | |
| loaded_voices.append(preset.name) | |
| return { | |
| "loaded_voices": loaded_voices, | |
| "device": str(self.device), | |
| "dtype": str(self.dtype), | |
| "attn_implementation": self.attn_implementation or "auto", | |
| "checkpoint_default_attn_implementation": self._checkpoint_global_attn_implementation or "eager", | |
| "checkpoint_default_local_attn_implementation": self._checkpoint_local_attn_implementation or "eager", | |
| "configured_attn_implementation": self._configured_global_attn_implementation or "eager", | |
| "configured_local_attn_implementation": self._configured_local_attn_implementation or "eager", | |
| "configured_codec_attn_implementation": self._configured_audio_tokenizer_attn_implementation or "unknown", | |
| "configured_codec_compute_dtype": self._configured_audio_tokenizer_compute_dtype or "unknown", | |
| } | |
| def _build_output_path(self, prefix: str) -> Path: | |
| timestamp = time.strftime("%Y%m%d_%H%M%S") | |
| random_suffix = uuid.uuid4().hex[:8] | |
| return self.output_dir / f"{prefix}_{timestamp}_{random_suffix}.wav" | |
| def synthesize( | |
| self, | |
| *, | |
| text: str, | |
| voice: Optional[str] = None, | |
| mode: str = "voice_clone", | |
| output_audio_path: Optional[str | Path] = None, | |
| prompt_audio_path: Optional[str | Path] = None, | |
| prompt_text: Optional[str] = None, | |
| max_new_frames: int = 375, | |
| voice_clone_max_text_tokens: int = 75, | |
| voice_clone_max_memory_per_sample_gb: float = 1.0, | |
| tts_max_batch_size: int = 0, | |
| codec_max_batch_size: int = 0, | |
| do_sample: bool = True, | |
| text_temperature: float = 1.0, | |
| text_top_p: float = 1.0, | |
| text_top_k: int = 50, | |
| audio_temperature: float = 0.8, | |
| audio_top_p: float = 0.95, | |
| audio_top_k: int = 25, | |
| audio_repetition_penalty: float = 1.2, | |
| nq: Optional[int] = None, | |
| seed: Optional[int] = None, | |
| attn_implementation: Optional[str] = None, | |
| ) -> dict[str, object]: | |
| normalized_text = str(text or "").strip() | |
| if not normalized_text: | |
| raise ValueError("text is required") | |
| normalized_mode = str(mode).strip().lower() | |
| if normalized_mode not in {"continuation", "voice_clone"}: | |
| raise ValueError("mode must be either 'continuation' or 'voice_clone'") | |
| effective_prompt_audio_path: Optional[Path] = None | |
| resolved_voice = self.get_voice_preset(voice).name | |
| if normalized_mode == "voice_clone": | |
| effective_prompt_audio_path = self.resolve_prompt_audio_path( | |
| voice=resolved_voice, | |
| prompt_audio_path=prompt_audio_path, | |
| ) | |
| elif prompt_audio_path is not None: | |
| effective_prompt_audio_path = self.resolve_prompt_audio_path(prompt_audio_path=prompt_audio_path) | |
| if not prompt_text: | |
| raise ValueError("continuation mode with prompt_audio_path also requires prompt_text") | |
| output_path = ( | |
| Path(output_audio_path).expanduser().resolve() | |
| if output_audio_path is not None | |
| else self._build_output_path(prefix=f"{resolved_voice}_{normalized_mode}") | |
| ) | |
| output_path.parent.mkdir(parents=True, exist_ok=True) | |
| started_at = time.monotonic() | |
| with self._lock: | |
| model = self._load_model_locked() | |
| model = self._restore_model_execution_state(model) | |
| requested_attn_implementation, effective_global_attn_implementation, effective_local_attn_implementation = ( | |
| self._resolve_request_attention_implementation(attn_implementation) | |
| ) | |
| audio_tokenizer = self._load_audio_tokenizer_locked( | |
| tts_attn_implementation=effective_global_attn_implementation | |
| ) | |
| self._apply_model_attention_implementation( | |
| model, | |
| global_attn=effective_global_attn_implementation, | |
| local_attn=effective_local_attn_implementation, | |
| ) | |
| if seed is not None: | |
| torch.manual_seed(seed) | |
| if torch.cuda.is_available(): | |
| torch.cuda.manual_seed_all(seed) | |
| try: | |
| result = model.inference( | |
| text=normalized_text, | |
| output_audio_path=str(output_path), | |
| mode=normalized_mode, | |
| prompt_text=prompt_text, | |
| prompt_audio_path=None if effective_prompt_audio_path is None else str(effective_prompt_audio_path), | |
| text_tokenizer_path=str(self.checkpoint_path), | |
| audio_tokenizer=audio_tokenizer, | |
| device=self.device, | |
| nq=nq, | |
| max_new_frames=int(max_new_frames), | |
| voice_clone_max_text_tokens=int(voice_clone_max_text_tokens), | |
| voice_clone_max_memory_per_sample_gb=float(voice_clone_max_memory_per_sample_gb), | |
| tts_max_batch_size=int(tts_max_batch_size), | |
| codec_max_batch_size=int(codec_max_batch_size), | |
| do_sample=bool(do_sample), | |
| use_kv_cache=True, | |
| text_temperature=float(text_temperature), | |
| text_top_p=float(text_top_p), | |
| text_top_k=int(text_top_k), | |
| audio_temperature=float(audio_temperature), | |
| audio_top_p=float(audio_top_p), | |
| audio_top_k=int(audio_top_k), | |
| audio_repetition_penalty=float(audio_repetition_penalty), | |
| ) | |
| except Exception: | |
| self._discard_loaded_audio_tokenizer_locked( | |
| "inference failed; reloading audio tokenizer on next request" | |
| ) | |
| self._discard_loaded_model_locked("inference failed; reloading checkpoint on next request") | |
| raise | |
| effective_global_attn_implementation, effective_local_attn_implementation = ( | |
| self._read_model_attention_implementation(model) | |
| ) | |
| current_parameter = next(model.parameters(), None) | |
| if current_parameter is not None and current_parameter.dtype != self.dtype: | |
| self._discard_loaded_model_locked( | |
| f"inference left model in dtype={current_parameter.dtype}; reloading checkpoint on next request" | |
| ) | |
| waveform = result["waveform"].detach().cpu() | |
| waveform_numpy = waveform_to_numpy(waveform) | |
| return { | |
| "audio_path": str(output_path), | |
| "sample_rate": int(result["sample_rate"]), | |
| "waveform": waveform, | |
| "waveform_numpy": waveform_numpy, | |
| "audio_token_ids": result["audio_token_ids"], | |
| "reference_audio_token_ids": result["reference_audio_token_ids"], | |
| "elapsed_seconds": time.monotonic() - started_at, | |
| "voice": resolved_voice, | |
| "mode": normalized_mode, | |
| "prompt_audio_path": None if effective_prompt_audio_path is None else str(effective_prompt_audio_path), | |
| "requested_attn_implementation": requested_attn_implementation, | |
| "effective_global_attn_implementation": effective_global_attn_implementation, | |
| "effective_local_attn_implementation": effective_local_attn_implementation, | |
| "voice_clone_text_chunks": result.get("voice_clone_text_chunks"), | |
| "voice_clone_chunk_batch_size": result.get("voice_clone_chunk_batch_size"), | |
| "voice_clone_codec_batch_size": result.get("voice_clone_codec_batch_size"), | |
| } | |
| def synthesize_stream( | |
| self, | |
| *, | |
| text: str, | |
| voice: Optional[str] = None, | |
| mode: str = "voice_clone", | |
| output_audio_path: Optional[str | Path] = None, | |
| prompt_audio_path: Optional[str | Path] = None, | |
| prompt_text: Optional[str] = None, | |
| max_new_frames: int = 375, | |
| voice_clone_max_text_tokens: int = 75, | |
| voice_clone_max_memory_per_sample_gb: float = 1.0, | |
| tts_max_batch_size: int = 0, | |
| codec_max_batch_size: int = 0, | |
| do_sample: bool = True, | |
| text_temperature: float = 1.0, | |
| text_top_p: float = 1.0, | |
| text_top_k: int = 50, | |
| audio_temperature: float = 0.8, | |
| audio_top_p: float = 0.95, | |
| audio_top_k: int = 25, | |
| audio_repetition_penalty: float = 1.2, | |
| nq: Optional[int] = None, | |
| seed: Optional[int] = None, | |
| attn_implementation: Optional[str] = None, | |
| ) -> Iterator[dict[str, object]]: | |
| normalized_text = str(text or "").strip() | |
| if not normalized_text: | |
| raise ValueError("text is required") | |
| normalized_mode = str(mode).strip().lower() | |
| if normalized_mode not in {"continuation", "voice_clone"}: | |
| raise ValueError("mode must be either 'continuation' or 'voice_clone'") | |
| effective_prompt_audio_path: Optional[Path] = None | |
| resolved_voice = self.get_voice_preset(voice).name | |
| if normalized_mode == "voice_clone": | |
| effective_prompt_audio_path = self.resolve_prompt_audio_path( | |
| voice=resolved_voice, | |
| prompt_audio_path=prompt_audio_path, | |
| ) | |
| elif prompt_audio_path is not None: | |
| effective_prompt_audio_path = self.resolve_prompt_audio_path(prompt_audio_path=prompt_audio_path) | |
| if not prompt_text: | |
| raise ValueError("continuation mode with prompt_audio_path also requires prompt_text") | |
| output_path = ( | |
| Path(output_audio_path).expanduser().resolve() | |
| if output_audio_path is not None | |
| else self._build_output_path(prefix=f"{resolved_voice}_{normalized_mode}_stream") | |
| ) | |
| output_path.parent.mkdir(parents=True, exist_ok=True) | |
| started_at = time.monotonic() | |
| final_result: dict[str, object] | None = None | |
| with self._lock: | |
| model = self._load_model_locked() | |
| model = self._restore_model_execution_state(model) | |
| requested_attn_implementation, effective_global_attn_implementation, effective_local_attn_implementation = ( | |
| self._resolve_request_attention_implementation(attn_implementation) | |
| ) | |
| audio_tokenizer = self._load_audio_tokenizer_locked( | |
| tts_attn_implementation=effective_global_attn_implementation | |
| ) | |
| self._apply_model_attention_implementation( | |
| model, | |
| global_attn=effective_global_attn_implementation, | |
| local_attn=effective_local_attn_implementation, | |
| ) | |
| if seed is not None: | |
| torch.manual_seed(seed) | |
| if torch.cuda.is_available(): | |
| torch.cuda.manual_seed_all(seed) | |
| try: | |
| for event in model.inference_stream( | |
| text=normalized_text, | |
| output_audio_path=str(output_path), | |
| mode=normalized_mode, | |
| prompt_text=prompt_text, | |
| prompt_audio_path=None if effective_prompt_audio_path is None else str(effective_prompt_audio_path), | |
| text_tokenizer_path=str(self.checkpoint_path), | |
| audio_tokenizer=audio_tokenizer, | |
| device=self.device, | |
| nq=nq, | |
| max_new_frames=int(max_new_frames), | |
| voice_clone_max_text_tokens=int(voice_clone_max_text_tokens), | |
| voice_clone_max_memory_per_sample_gb=float(voice_clone_max_memory_per_sample_gb), | |
| tts_max_batch_size=int(tts_max_batch_size), | |
| codec_max_batch_size=int(codec_max_batch_size), | |
| do_sample=bool(do_sample), | |
| use_kv_cache=True, | |
| text_temperature=float(text_temperature), | |
| text_top_p=float(text_top_p), | |
| text_top_k=int(text_top_k), | |
| audio_temperature=float(audio_temperature), | |
| audio_top_p=float(audio_top_p), | |
| audio_top_k=int(audio_top_k), | |
| audio_repetition_penalty=float(audio_repetition_penalty), | |
| ): | |
| event_type = str(event.get("type", "")) | |
| if event_type == "audio": | |
| waveform = torch.as_tensor(event["waveform"], dtype=torch.float32).cpu() | |
| yield { | |
| "type": "audio", | |
| "waveform": waveform, | |
| "waveform_numpy": waveform_to_numpy(waveform), | |
| "sample_rate": int(event["sample_rate"]), | |
| "chunk_index": int(event.get("chunk_index", 0)), | |
| "is_pause": bool(event.get("is_pause", False)), | |
| "emitted_audio_seconds": float(event.get("emitted_audio_seconds", 0.0)), | |
| "lead_seconds": float(event.get("lead_seconds", 0.0)), | |
| } | |
| continue | |
| if event_type == "result": | |
| final_result = dict(event) | |
| except Exception: | |
| self._discard_loaded_audio_tokenizer_locked( | |
| "streaming inference failed; reloading audio tokenizer on next request" | |
| ) | |
| self._discard_loaded_model_locked("streaming inference failed; reloading checkpoint on next request") | |
| raise | |
| effective_global_attn_implementation, effective_local_attn_implementation = ( | |
| self._read_model_attention_implementation(model) | |
| ) | |
| current_parameter = next(model.parameters(), None) | |
| if current_parameter is not None and current_parameter.dtype != self.dtype: | |
| self._discard_loaded_model_locked( | |
| f"streaming inference left model in dtype={current_parameter.dtype}; reloading checkpoint on next request" | |
| ) | |
| if final_result is None: | |
| raise RuntimeError("Streaming synthesis finished without a final result.") | |
| waveform = torch.as_tensor(final_result["waveform"], dtype=torch.float32).cpu() | |
| yield { | |
| "type": "result", | |
| "audio_path": str(final_result["audio_path"]), | |
| "sample_rate": int(final_result["sample_rate"]), | |
| "waveform": waveform, | |
| "waveform_numpy": waveform_to_numpy(waveform), | |
| "audio_token_ids": final_result["audio_token_ids"], | |
| "reference_audio_token_ids": final_result["reference_audio_token_ids"], | |
| "elapsed_seconds": time.monotonic() - started_at, | |
| "voice": resolved_voice, | |
| "mode": normalized_mode, | |
| "prompt_audio_path": None if effective_prompt_audio_path is None else str(effective_prompt_audio_path), | |
| "requested_attn_implementation": requested_attn_implementation, | |
| "effective_global_attn_implementation": effective_global_attn_implementation, | |
| "effective_local_attn_implementation": effective_local_attn_implementation, | |
| "voice_clone_text_chunks": final_result.get("voice_clone_text_chunks"), | |
| "voice_clone_chunk_batch_size": final_result.get("voice_clone_chunk_batch_size"), | |
| "voice_clone_codec_batch_size": final_result.get("voice_clone_codec_batch_size"), | |
| } | |
| def warmup( | |
| self, | |
| *, | |
| text: str = "你好,欢迎使用 Nano-TTS。", | |
| voice: Optional[str] = None, | |
| ) -> dict[str, object]: | |
| return self.synthesize( | |
| text=text, | |
| voice=voice or self.default_voice, | |
| mode="voice_clone", | |
| output_audio_path=self.output_dir / "_warmup" / "warmup.wav", | |
| max_new_frames=96, | |
| voice_clone_max_text_tokens=75, | |
| do_sample=False, | |
| text_temperature=1.0, | |
| text_top_p=1.0, | |
| text_top_k=50, | |
| audio_temperature=0.8, | |
| audio_top_p=0.95, | |
| audio_top_k=25, | |
| audio_repetition_penalty=1.0, | |
| ) | |