from __future__ import annotations import base64 import hashlib import inspect import os import random import tempfile from pathlib import Path from typing import Any from time import perf_counter import modal app = modal.App("ai-time-machine-audio") NEMOTRON_STT_MODEL_ID = "nvidia/nemotron-3.5-asr-streaming-0.6b" QWEN_TTS_MODEL_ID = "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign" base_image = ( modal.Image.debian_slim(python_version="3.12") .apt_install("ffmpeg", "git", "libsndfile1", "sox") .pip_install("fastapi[standard]") ) nemotron_image = ( base_image .pip_install("Cython", "packaging") .pip_install("git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]") ) qwen_tts_image = base_image.pip_install("qwen-tts", "soundfile") chatterbox_tts_image = base_image.pip_install("chatterbox-tts", "soundfile", "torchaudio") _asr_model: Any | None = None _tts_model: Any | None = None _chatterbox_tts_model: Any | None = None _asr_loaded_at: float | None = None _tts_loaded_at: float | None = None _chatterbox_tts_loaded_at: float | None = None MIN_CONTAINERS = int(os.getenv("TIME_MACHINE_MODAL_MIN_CONTAINERS", "1")) MAX_CONTAINERS = int(os.getenv("TIME_MACHINE_MODAL_MAX_CONTAINERS", "1")) ACTIVE_TTS_MODEL_FAMILY = os.getenv( "TIME_MACHINE_MODAL_TTS_MODEL_FAMILY", "chatterbox_turbo", ).strip().lower().replace("-", "_") QWEN_TTS_ENABLED = ACTIVE_TTS_MODEL_FAMILY == "qwen" CHATTERBOX_TTS_ENABLED = ACTIVE_TTS_MODEL_FAMILY in {"chatterbox", "chatterbox_turbo", "turbo"} SCALEDOWN_WINDOW_SECONDS = int(os.getenv("TIME_MACHINE_MODAL_SCALEDOWN_SECONDS", "1800")) STARTUP_TIMEOUT_SECONDS = int(os.getenv("TIME_MACHINE_MODAL_STARTUP_TIMEOUT_SECONDS", "900")) WARMUP_TTS = os.getenv("TIME_MACHINE_MODAL_WARMUP_TTS", "1").strip().lower() in { "1", "true", "yes", "on", } print( "Modal audio service config: " f"active_tts={ACTIVE_TTS_MODEL_FAMILY} " f"stt_min_containers={MIN_CONTAINERS} " f"qwen_tts_registered={QWEN_TTS_ENABLED} " f"chatterbox_tts_registered={CHATTERBOX_TTS_ENABLED}" ) # 
# Persistent volume to share and cache downloaded model weights
hf_volume = modal.Volume.from_name("hf-cache-vol", create_if_missing=True)

# Language codes passed through to the ASR model; anything else maps to "en".
_STT_SUPPORTED_LANGUAGES = frozenset(
    {"en-US", "en", "en-GB", "enGB", "es-ES", "esES", "es-US", "es", "zh-CN", "zh-ZH"}
)


@app.cls(
    image=nemotron_image,
    gpu="A10G",
    timeout=600,
    startup_timeout=STARTUP_TIMEOUT_SECONDS,
    scaledown_window=SCALEDOWN_WINDOW_SECONDS,
    min_containers=MIN_CONTAINERS,
    max_containers=MAX_CONTAINERS,
    volumes={"/root/.cache/huggingface": hf_volume},
)
class NemotronSTTService:
    """Speech-to-text web endpoint backed by the Nemotron ASR model."""

    model: Any

    @modal.enter()
    def load(self) -> None:
        """Load (or reuse) the ASR model when the container starts."""
        self.model = _load_asr_model()

    @modal.fastapi_endpoint(method="POST", label="time-machine-nemotron-stt")
    def transcribe(self, item: dict[str, Any]) -> dict[str, Any]:
        """Transcribe base64-encoded audio.

        Reads ``item["audio_b64"]`` (required) and ``item["language"]``
        (optional, default ``"auto"``). Returns the transcript text together
        with per-stage timings.

        Raises:
            ValueError: if ``audio_b64`` is missing/empty or not valid base64.
        """
        request_started = perf_counter()
        audio_b64 = _required_string(item, "audio_b64")
        language = str(item.get("language") or "auto")
        audio_path = _write_request_audio(audio_b64)
        proc_path = audio_path
        try:
            target_lang = _target_language(language)
            if hasattr(self.model, "set_inference_prompt"):
                self.model.set_inference_prompt(target_lang)

            preprocess_started = perf_counter()
            proc_path = _convert_to_mono_16k(audio_path)
            duration = _get_audio_duration(proc_path)
            preprocess_seconds = perf_counter() - preprocess_started

            inference_started = perf_counter()
            result = _transcribe_manifest(self.model, proc_path, target_lang, duration)
            inference_seconds = perf_counter() - inference_started
        finally:
            # Remove request temp files even when conversion or inference
            # fails, so failed requests do not accumulate files on disk.
            if proc_path != audio_path:
                proc_path.unlink(missing_ok=True)
            audio_path.unlink(missing_ok=True)

        text = _extract_transcript_text(result)
        return {
            "text": text,
            "confidence": None,
            "language": None if language == "auto" else language,
            "is_final": True,
            "timings": {
                "preprocess_seconds": round(preprocess_seconds, 3),
                "inference_seconds": round(inference_seconds, 3),
                "total_seconds": round(perf_counter() - request_started, 3),
                "model_loaded_at": _asr_loaded_at,
            },
        }


if QWEN_TTS_ENABLED:

    @app.cls(
        image=qwen_tts_image,
        gpu="A10G",
        timeout=600,
        startup_timeout=STARTUP_TIMEOUT_SECONDS,
        scaledown_window=SCALEDOWN_WINDOW_SECONDS,
        min_containers=MIN_CONTAINERS,
        max_containers=MAX_CONTAINERS,
        volumes={"/root/.cache/huggingface": hf_volume},
    )
    class QwenTTSService:
        """Text-to-speech web endpoint backed by Qwen3-TTS VoiceDesign."""

        model: Any

        @modal.enter()
        def load(self) -> None:
            """Load the TTS model and optionally run one warm-up synthesis."""
            self.model = _load_tts_model()
            if WARMUP_TTS:
                self._warm_up()

        @modal.fastapi_endpoint(method="POST", label="time-machine-qwen-tts")
        def synthesize(self, item: dict[str, Any]) -> dict[str, Any]:
            """Synthesize ``item["text"]`` with a voice described by ``item["voice_profile"]``.

            Returns base64 WAV audio, its duration and timings.

            Raises:
                ValueError: if ``text`` is missing/empty or ``voice_profile``
                    is not a dict.
            """
            request_started = perf_counter()
            text = _required_string(item, "text")
            voice_profile = item.get("voice_profile")
            if not isinstance(voice_profile, dict):
                raise ValueError("voice_profile must be an object.")
            language = str(item.get("language") or "English")
            prosody_hint = item.get("prosody_hint")
            instruction = _voice_instruction(voice_profile, prosody_hint)
            voice_seed = _voice_seed(voice_profile, item.get("voice_seed"))
            inference_started = perf_counter()
            audio_bytes, duration_seconds = _synthesize_to_wav_bytes(
                self.model,
                text=text,
                language=language,
                instruction=instruction,
                seed=voice_seed,
            )
            inference_seconds = perf_counter() - inference_started
            return {
                "audio_b64": base64.b64encode(audio_bytes).decode("ascii"),
                "mime_type": "audio/wav",
                "duration_seconds": duration_seconds,
                "description": "Qwen3-TTS VoiceDesign synthesis on warm Modal GPU.",
                "timings": {
                    "inference_seconds": round(inference_seconds, 3),
                    "total_seconds": round(perf_counter() - request_started, 3),
                    "model_loaded_at": _tts_loaded_at,
                },
            }

        def _warm_up(self) -> None:
            """Run one short synthesis; failures are logged, not raised."""
            try:
                _synthesize_to_wav_bytes(
                    self.model,
                    text="The signal is open.",
                    language="English",
                    instruction="Natural conversational voice. Pace: fast. Emotion: curious.",
                )
            except Exception as exc:
                print(f"Qwen TTS warmup failed; first request may still pay setup cost: {exc}")

else:
    print("Qwen TTS service not registered for this Modal serve run.")


if CHATTERBOX_TTS_ENABLED:

    @app.cls(
        image=chatterbox_tts_image,
        gpu=os.getenv("TIME_MACHINE_CHATTERBOX_GPU", "L4"),
        timeout=600,
        startup_timeout=STARTUP_TIMEOUT_SECONDS,
        scaledown_window=SCALEDOWN_WINDOW_SECONDS,
        min_containers=MIN_CONTAINERS,
        max_containers=MAX_CONTAINERS,
        volumes={"/root/.cache/huggingface": hf_volume},
    )
    class ChatterboxTurboTTSService:
        """Text-to-speech web endpoint backed by Chatterbox (Turbo or standard)."""

        model: Any

        @modal.enter()
        def load(self) -> None:
            """Load the first available Chatterbox runtime, optionally warming it up."""
            self.runtime_name, self.model = _load_chatterbox_tts_model()
            if WARMUP_TTS:
                self._warm_up()

        @modal.fastapi_endpoint(method="POST", label="time-machine-chatterbox-turbo-tts")
        def synthesize(self, item: dict[str, Any]) -> dict[str, Any]:
            """Synthesize ``item["text"]`` with Chatterbox.

            Optional numeric knobs are ``exaggeration``, ``cfg_weight`` and
            ``temperature``. Non-numeric or non-finite values fall back to the
            defaults.

            Raises:
                ValueError: if ``text`` is missing/empty or ``voice_profile``
                    is not a dict.
            """
            request_started = perf_counter()
            text = _required_string(item, "text")
            voice_profile = item.get("voice_profile")
            if not isinstance(voice_profile, dict):
                raise ValueError("voice_profile must be an object.")
            prosody_hint = item.get("prosody_hint")
            prompt = _voice_instruction(voice_profile, prosody_hint)
            voice_seed = _voice_seed(voice_profile, item.get("voice_seed"))
            exaggeration = _float_item(item, "exaggeration", 0.65)
            cfg_weight = _float_item(item, "cfg_weight", 0.35)
            temperature = _float_item(item, "temperature", 0.8)
            latency_profile = str(item.get("latency_profile") or "turbo")
            runtime_name = str(getattr(self, "runtime_name", "turbo"))
            inference_started = perf_counter()
            audio_bytes, duration_seconds = _synthesize_chatterbox_to_wav_bytes(
                self.model,
                text=text,
                prompt=prompt,
                seed=voice_seed,
                runtime_name=runtime_name,
                exaggeration=exaggeration,
                cfg_weight=cfg_weight,
                temperature=temperature,
            )
            inference_seconds = perf_counter() - inference_started
            return {
                "audio_b64": base64.b64encode(audio_bytes).decode("ascii"),
                "mime_type": "audio/wav",
                "duration_seconds": duration_seconds,
                "description": _chatterbox_description(
                    runtime_name,
                    latency_profile,
                    exaggeration,
                    cfg_weight,
                ),
                "timings": {
                    "inference_seconds": round(inference_seconds, 3),
                    "total_seconds": round(perf_counter() - request_started, 3),
                    "model_loaded_at": _chatterbox_tts_loaded_at,
                },
            }

        def _warm_up(self) -> None:
            """Run one short synthesis; failures are logged, not raised."""
            try:
                _synthesize_chatterbox_to_wav_bytes(
                    self.model,
                    text="The signal is open.",
                    prompt="Natural expressive character voice. Pace: fast. Emotion: curious.",
                    runtime_name=str(getattr(self, "runtime_name", "turbo")),
                    exaggeration=0.6,
                    cfg_weight=0.35,
                    temperature=0.8,
                )
            except Exception as exc:
                print(f"Chatterbox Turbo TTS warmup failed; first request may still pay setup cost: {exc}")

else:
    print("Chatterbox TTS service not registered for this Modal serve run.")


def _wall_clock_timestamp() -> float:
    """Return the current Unix time in seconds, rounded to milliseconds.

    Used for the ``model_loaded_at`` markers. ``perf_counter()`` has an
    undefined reference point and is only meaningful for differences, so it
    cannot serve as a timestamp.
    """
    import time

    return round(time.time(), 3)


def _load_asr_model() -> Any:
    """Load the Nemotron ASR model once per process and cache it globally."""
    global _asr_model, _asr_loaded_at
    if _asr_model is not None:
        return _asr_model
    import nemo.collections.asr as nemo_asr

    print(f"Loading Modal STT model: provider=nvidia runtime=nemo model={NEMOTRON_STT_MODEL_ID}")
    _asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=NEMOTRON_STT_MODEL_ID)
    print(
        "Loaded Modal STT model: "
        f"{_asr_model.__class__.__module__}.{_asr_model.__class__.__name__} "
        f"model={NEMOTRON_STT_MODEL_ID}"
    )
    _asr_loaded_at = _wall_clock_timestamp()
    return _asr_model


def _load_tts_model() -> Any:
    """Load Qwen3-TTS once per process and cache it globally.

    Uses bfloat16 on ``cuda:0`` when CUDA is available, otherwise float32 on
    CPU. It first tries ``flash_attention_2`` and falls back to the default
    attention implementation if that load raises.
    """
    global _tts_model, _tts_loaded_at
    if _tts_model is not None:
        return _tts_model
    import torch
    from qwen_tts import Qwen3TTSModel

    has_cuda = torch.cuda.is_available()
    kwargs: dict[str, Any] = {
        "device_map": "cuda:0" if has_cuda else "cpu",
        "dtype": torch.bfloat16 if has_cuda else torch.float32,
    }
    print(
        "Loading Modal TTS model: "
        f"provider=qwen runtime=qwen-tts model={QWEN_TTS_MODEL_ID} "
        f"device_map={kwargs['device_map']} dtype={kwargs['dtype']}"
    )
    try:
        _tts_model = Qwen3TTSModel.from_pretrained(
            QWEN_TTS_MODEL_ID,
            attn_implementation="flash_attention_2",
            **kwargs,
        )
        attention = "flash_attention_2"
    except Exception as exc:
        # flash-attn is optional; log why it was skipped instead of hiding it.
        print(f"Qwen TTS flash_attention_2 load failed; retrying with default attention: {exc}")
        _tts_model = Qwen3TTSModel.from_pretrained(
            QWEN_TTS_MODEL_ID,
            **kwargs,
        )
        attention = "default"
    print(
        "Loaded Modal TTS model: "
        f"{_tts_model.__class__.__module__}.{_tts_model.__class__.__name__} "
        f"model={QWEN_TTS_MODEL_ID} attention={attention}"
    )
    _tts_loaded_at = _wall_clock_timestamp()
    return _tts_model


def _load_chatterbox_tts_model() -> tuple[str, Any]:
    """Load the first working Chatterbox runtime once per process.

    Returns:
        ``(runtime_name, model)``, where ``runtime_name`` is ``"turbo"`` or
        ``"standard"``.
    """
    global _chatterbox_tts_model, _chatterbox_tts_loaded_at
    if _chatterbox_tts_model is not None:
        return _infer_chatterbox_runtime(_chatterbox_tts_model), _chatterbox_tts_model
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Must run before the chatterbox modules are imported/instantiated.
    _ensure_chatterbox_watermarker()
    runtime_name, _chatterbox_tts_model = _load_first_available_chatterbox_model(device)
    print(
        "Loaded Modal TTS model: "
        f"provider=resemble_ai runtime=chatterbox-{runtime_name} "
        f"configured_models={','.join(_chatterbox_model_candidates())} "
        f"{runtime_name} ({_chatterbox_tts_model.__class__.__module__}."
        f"{_chatterbox_tts_model.__class__.__name__}) on {device}"
    )
    _chatterbox_tts_loaded_at = _wall_clock_timestamp()
    return runtime_name, _chatterbox_tts_model


def _ensure_chatterbox_watermarker() -> None:
    """Make sure ``perth.PerthImplicitWatermarker`` is callable.

    If the attribute is missing or not callable, it is replaced with a
    pass-through class whose ``apply_watermark``/``watermark`` return the
    input unchanged.
    """
    try:
        import perth
    except Exception as exc:
        # NOTE(review): nothing is installed on this path despite the message;
        # chatterbox will likely fail to import perth itself.
        print(f"Chatterbox Perth watermarker unavailable; using no-op watermarker: {exc}")
        return
    watermarker_cls = getattr(perth, "PerthImplicitWatermarker", None)
    if callable(watermarker_cls):
        return

    class _NoOpWatermarker:
        def apply_watermark(self, wav: Any, *args: Any, **kwargs: Any) -> Any:
            return wav

        def watermark(self, wav: Any, *args: Any, **kwargs: Any) -> Any:
            return wav

    perth.PerthImplicitWatermarker = _NoOpWatermarker
    print("Chatterbox PerthImplicitWatermarker is missing; using no-op watermarker.")


def _load_first_available_chatterbox_model(device: str) -> tuple[str, Any]:
    """Try each importable Chatterbox class in order and return the first that loads.

    Raises:
        The last load exception if every runtime failed, or ``RuntimeError``
        if no runtime could be imported at all.
    """
    last_error: Exception | None = None
    for runtime_name, chatterbox_cls in _chatterbox_model_classes():
        try:
            print(f"Trying Chatterbox TTS runtime: {runtime_name}")
            return runtime_name, _load_chatterbox_from_pretrained(chatterbox_cls, device)
        except Exception as exc:
            last_error = exc
            print(f"Chatterbox {runtime_name} load failed; trying next fallback: {exc}")
    if last_error is not None:
        raise last_error
    raise RuntimeError("No Chatterbox TTS runtime is available.")


def _chatterbox_model_classes() -> list[tuple[str, Any]]:
    """Return importable Chatterbox classes, Turbo first, then standard."""
    classes: list[tuple[str, Any]] = []
    try:
        from chatterbox.tts_turbo import ChatterboxTurboTTS

        classes.append(("turbo", ChatterboxTurboTTS))
    except Exception as exc:
        print(f"Chatterbox Turbo runtime unavailable; trying standard Chatterbox: {exc}")
    try:
        from chatterbox.tts import ChatterboxTTS

        classes.append(("standard", ChatterboxTTS))
    except Exception as exc:
        print(f"Standard Chatterbox runtime unavailable: {exc}")
    return classes


def _load_chatterbox_from_pretrained(chatterbox_cls: Any, device: str) -> Any:
    """Instantiate *chatterbox_cls* via ``from_pretrained``.

    When the signature accepts a model id, each configured candidate is tried
    in turn. The final fallback calls ``from_pretrained(device=device)``.
    """
    from_pretrained = chatterbox_cls.from_pretrained
    if _supports_chatterbox_model_id(from_pretrained):
        for model_id in _chatterbox_model_candidates():
            try:
                return from_pretrained(model_id, device=device)
            except Exception as exc:
                print(f"Chatterbox model id {model_id!r} failed; trying next fallback: {exc}")
    return from_pretrained(device=device)


def _supports_chatterbox_model_id(from_pretrained: Any) -> bool:
    """Return True when the first positional parameter is named like a model/repo id."""
    try:
        parameters = inspect.signature(from_pretrained).parameters
    except (TypeError, ValueError):
        return False
    positional = [
        parameter
        for parameter in parameters.values()
        if parameter.kind
        in {
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        }
    ]
    model_id_parameter_names = {
        "model_id",
        "repo_id",
        "pretrained_model_name_or_path",
        "checkpoint",
    }
    return bool(positional) and positional[0].name in model_id_parameter_names


def _chatterbox_model_candidates() -> list[str]:
    """Return configured Chatterbox model ids, de-duplicated, in priority order.

    ``TIME_MACHINE_CHATTERBOX_MODEL_IDS`` (comma-separated) overrides
    ``TIME_MACHINE_CHATTERBOX_TURBO_MODEL_ID``. ``ResembleAI/chatterbox`` is
    always appended as a last resort.
    """
    raw = os.getenv("TIME_MACHINE_CHATTERBOX_TURBO_MODEL_ID", "ResembleAI/chatterbox-turbo")
    raw_list = os.getenv("TIME_MACHINE_CHATTERBOX_MODEL_IDS", raw)
    # dict.fromkeys keeps first-seen order while dropping repeated ids.
    candidates = list(dict.fromkeys(item.strip() for item in raw_list.split(",") if item.strip()))
    if "ResembleAI/chatterbox" not in candidates:
        candidates.append("ResembleAI/chatterbox")
    return candidates


def _target_language(language: str) -> str:
    """Map a request language to an ASR target language, defaulting to ``"en"``."""
    return language if language in _STT_SUPPORTED_LANGUAGES else "en"


def _transcribe_manifest(
    model: Any,
    proc_path: Path,
    target_lang: str,
    duration: float,
) -> Any:
    """Transcribe *proc_path* through a one-line JSON manifest file.

    The manifest sits next to the audio and is always deleted afterwards.
    Returns whatever ``model.transcribe`` returns.
    """
    import json

    manifest_path = proc_path.with_suffix(".json")
    try:
        with open(manifest_path, "w", encoding="utf-8") as f:
            f.write(
                json.dumps(
                    {
                        "audio_filepath": str(proc_path),
                        "duration": duration,
                        "text": "",
                        "target_lang": target_lang,
                        "lang": target_lang,
                        "language": target_lang,
                    }
                )
                + "\n"
            )
        return model.transcribe([str(manifest_path)])
    finally:
        manifest_path.unlink(missing_ok=True)


def _encode_wav(audio: Any, sample_rate: int) -> tuple[bytes, float]:
    """Encode *audio* as WAV bytes in memory and compute its duration.

    ``len(audio)`` is treated as the frame count. The duration is rounded to
    milliseconds.
    """
    import io

    import soundfile as sf

    buffer = io.BytesIO()
    # A file-like target has no extension to infer from, so the container
    # format is given explicitly; the subtype stays soundfile's WAV default.
    sf.write(buffer, audio, sample_rate, format="WAV")
    return buffer.getvalue(), round(len(audio) / float(sample_rate), 3)


def _synthesize_to_wav_bytes(
    model: Any,
    text: str,
    language: str,
    instruction: str,
    seed: int | None = None,
) -> tuple[bytes, float]:
    """Run Qwen VoiceDesign generation and return ``(wav_bytes, duration_seconds)``."""
    if seed is not None:
        _seed_tts_generation(seed)
    wavs, sample_rate = model.generate_voice_design(
        text=text,
        language=language,
        instruct=instruction,
    )
    return _encode_wav(wavs[0], sample_rate)


def _synthesize_chatterbox_to_wav_bytes(
    model: Any,
    text: str,
    prompt: str,
    seed: int | None = None,
    runtime_name: str | None = None,
    exaggeration: float = 0.65,
    cfg_weight: float = 0.35,
    temperature: float = 0.8,
) -> tuple[bytes, float]:
    """Run Chatterbox generation and return ``(wav_bytes, duration_seconds)``.

    Only the keyword arguments accepted by ``model.generate`` are passed.
    ``exaggeration``/``cfg_weight`` are offered only to non-turbo runtimes.
    """
    if seed is not None:
        _seed_tts_generation(seed)
    runtime = runtime_name or _infer_chatterbox_runtime(model)
    generate_values: dict[str, Any] = {
        "text": text,
        "prompt": prompt,
        "condition_prompt": prompt,
        "temperature": temperature,
    }
    if runtime != "turbo":
        generate_values["exaggeration"] = exaggeration
        generate_values["cfg_weight"] = cfg_weight
    generate_kwargs = _supported_kwargs(model.generate, generate_values)
    if "text" in generate_kwargs:
        wav = model.generate(**generate_kwargs)
    else:
        wav = model.generate(text, **generate_kwargs)
    sample_rate = int(getattr(model, "sr", getattr(model, "sample_rate", 24000)))
    audio = _to_numpy_audio(wav)
    return _encode_wav(audio, sample_rate)


def _chatterbox_description(
    runtime_name: str,
    latency_profile: str,
    exaggeration: float,
    cfg_weight: float,
) -> str:
    """Build the human-readable ``description`` field for a Chatterbox response."""
    if runtime_name == "turbo":
        return f"Chatterbox Turbo TTS synthesis on warm Modal GPU ({latency_profile})."
    return (
        "Chatterbox TTS synthesis on warm Modal GPU "
        f"({latency_profile}, exaggeration={exaggeration:g}, cfg={cfg_weight:g})."
    )


def _infer_chatterbox_runtime(model: Any) -> str:
    """Return ``"turbo"`` if the model's class path mentions turbo, else ``"standard"``."""
    class_path = f"{model.__class__.__module__}.{model.__class__.__name__}".lower()
    return "turbo" if "turbo" in class_path else "standard"


def _supported_kwargs(callable_obj: Any, values: dict[str, Any]) -> dict[str, Any]:
    """Filter *values* to the keyword names *callable_obj* accepts.

    Everything is passed through when the callable takes ``**kwargs`` or its
    signature cannot be inspected.
    """
    try:
        parameters = inspect.signature(callable_obj).parameters
    except (TypeError, ValueError):
        return values
    if any(parameter.kind == inspect.Parameter.VAR_KEYWORD for parameter in parameters.values()):
        return values
    return {key: value for key, value in values.items() if key in parameters}


def _to_numpy_audio(wav: Any) -> Any:
    """Convert a generate() result (tuple/list/tensor) into a squeezed array-like."""
    if isinstance(wav, tuple) and wav:
        wav = wav[0]
    if isinstance(wav, list) and wav and hasattr(wav[0], "__len__"):
        wav = wav[0]
    if hasattr(wav, "detach"):
        wav = wav.detach().cpu().numpy()
    elif hasattr(wav, "cpu"):
        wav = wav.cpu().numpy()
    elif isinstance(wav, list):
        import numpy as np

        wav = np.asarray(wav, dtype="float32")
    if hasattr(wav, "squeeze"):
        wav = wav.squeeze()
    return wav


def _seed_tts_generation(seed: int) -> None:
    """Seed ``random``, numpy and torch (when importable) for repeatable voices."""
    normalized = seed % (2**31 - 1)
    random.seed(normalized)
    try:
        import numpy as np

        np.random.seed(normalized)
    except Exception:
        pass  # best effort: numpy is optional here
    try:
        import torch

        torch.manual_seed(normalized)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(normalized)
    except Exception:
        pass  # best effort: torch is optional here


def _convert_to_mono_16k(input_path: Path) -> Path:
    """Return a mono 16 kHz 16-bit WAV version of *input_path*.

    The input is returned unchanged if it already matches. Otherwise it is
    converted with ffmpeg into a sibling ``*_mono_16k.wav`` file. On any
    conversion failure the original path is returned (best effort).
    """
    if _is_mono_16k_wav(input_path):
        return input_path
    import subprocess

    output_path = input_path.with_name(input_path.stem + "_mono_16k.wav")
    cmd = ["ffmpeg", "-y", "-i", str(input_path), "-ac", "1", "-ar", "16000", str(output_path)]
    try:
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception as e:
        # Drop any partially written output so it does not leak on disk.
        output_path.unlink(missing_ok=True)
        print(f"ffmpeg conversion failed: {e}")
        stderr = getattr(e, "stderr", None)
        if stderr:
            print(stderr.decode("utf-8", errors="replace")[-2000:])
        return input_path
    return output_path


def _is_mono_16k_wav(path: Path) -> bool:
    """Return True if *path* is a readable mono, 16 kHz, 16-bit WAV file."""
    try:
        import wave

        with wave.open(str(path), "rb") as handle:
            return (
                handle.getnchannels() == 1
                and handle.getframerate() == 16000
                and handle.getsampwidth() == 2
            )
    except Exception:
        return False


def _get_audio_duration(path: Path) -> float:
    """Return the WAV duration in seconds (minimum 0.1), or 10.0 if unreadable."""
    try:
        import wave

        with wave.open(str(path), "rb") as f:
            frames = f.getnframes()
            rate = f.getframerate()
            return max(0.1, frames / float(rate))
    except Exception:
        return 10.0


def _write_request_audio(audio_b64: str) -> Path:
    """Decode *audio_b64* into a new ``.wav`` temp file and return its path.

    The caller is responsible for deleting the file.

    Raises:
        binascii.Error: (a ``ValueError``) if the payload is not valid base64.
    """
    audio_bytes = base64.b64decode(audio_b64)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as handle:
        handle.write(audio_bytes)
    return Path(handle.name)


def _extract_transcript_text(result: Any) -> str:
    """Pull the first transcript string out of an ASR ``transcribe`` result.

    Accepts a plain string, an object with ``.text``, a dict with ``"text"``,
    or any nesting of lists/tuples of those. For example, some NeMo versions
    return a ``(best_hypotheses, all_hypotheses)`` tuple. The first element is
    used at every level. Returns ``""`` when nothing usable is found.
    """
    while isinstance(result, (list, tuple)):
        if not result:
            return ""
        result = result[0]
    if isinstance(result, str):
        return result.strip()
    if isinstance(result, dict):
        return str(result.get("text", "")).strip()
    text = getattr(result, "text", None)
    # Guard against hypotheses whose text is None (would otherwise yield "None").
    return "" if text is None else str(text).strip()


def _voice_instruction(
    voice_profile: dict[str, Any],
    prosody_hint: object,
) -> str:
    """Compose a natural-language voice instruction from a voice profile and prosody hint."""
    voice_id = str(voice_profile.get("voice_id") or "character")
    parts = [
        (
            f"Consistent speaker identity: {voice_id}. "
            "Keep the same timbre and apparent speaker across separate lines."
        ),
        str(voice_profile.get("description") or "Natural conversational character voice."),
        f"Pace: {voice_profile.get('pace') or 'medium'}.",
        f"Emotion: {voice_profile.get('emotion') or 'curious'}.",
    ]
    accent_hint = voice_profile.get("accent_hint")
    if accent_hint:
        parts.append(f"Accent or local color: {accent_hint}.")
    if prosody_hint:
        parts.append(f"Prosody: {prosody_hint}.")
    return " ".join(parts)


def _voice_seed(voice_profile: dict[str, Any], provided: object = None) -> int:
    """Return *provided* if it is a positive int.

    Otherwise return a seed derived from a SHA-256 hash of the profile fields,
    so the same profile maps to the same seed.
    """
    try:
        parsed = int(provided)
    except (TypeError, ValueError):
        parsed = 0
    if parsed > 0:
        return parsed
    payload = "\n".join(
        [
            str(voice_profile.get("voice_id") or ""),
            str(voice_profile.get("description") or ""),
            str(voice_profile.get("pace") or ""),
            str(voice_profile.get("emotion") or ""),
            str(voice_profile.get("accent_hint") or ""),
        ]
    )
    return int(hashlib.sha256(payload.encode("utf-8")).hexdigest()[:8], 16)


def _float_item(item: dict[str, Any], key: str, default: float) -> float:
    """Read ``item[key]`` as a finite float, else return *default*.

    *default* is returned when the key is absent, the value is not numeric,
    or the value is NaN/inf (which would break sampling).
    """
    import math

    try:
        value = float(item.get(key, default))
    except (TypeError, ValueError, OverflowError):
        return default
    return value if math.isfinite(value) else default


def _required_string(item: dict[str, Any], key: str) -> str:
    """Return ``item[key]`` if it is a non-empty string.

    Raises:
        ValueError: if the value is missing, empty, or not a string.
    """
    value = item.get(key)
    if not isinstance(value, str) or not value:
        raise ValueError(f"{key} is required.")
    return value