ai-time-machine / scripts /modal_audio.py
manikandanj's picture
Prepare AI Time Machine hackathon Space
5862322 verified
Raw
History Blame Contribute Delete
25.9 kB
from __future__ import annotations
import base64
import hashlib
import inspect
import os
import random
import tempfile
from pathlib import Path
from typing import Any
from time import perf_counter
import modal
# Modal application object that every service class below registers with.
app = modal.App("ai-time-machine-audio")

# Model identifiers passed to the NeMo and Qwen loaders further down.
NEMOTRON_STT_MODEL_ID = "nvidia/nemotron-3.5-asr-streaming-0.6b"
QWEN_TTS_MODEL_ID = "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign"

# Shared base image: Debian slim + audio tooling (ffmpeg is invoked by
# _convert_to_mono_16k) + FastAPI for the web endpoints.
base_image = (
    modal.Image.debian_slim(python_version="3.12")
    .apt_install("ffmpeg", "git", "libsndfile1", "sox")
    .pip_install("fastapi[standard]")
)
# STT image. NOTE(review): NeMo is installed from the moving `main` branch, so
# image rebuilds are not reproducible -- consider pinning a tag or commit.
nemotron_image = (
    base_image
    .pip_install("Cython", "packaging")
    .pip_install("git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]")
)
# One image per TTS family so each container only installs what it needs.
qwen_tts_image = base_image.pip_install("qwen-tts", "soundfile")
chatterbox_tts_image = base_image.pip_install("chatterbox-tts", "soundfile", "torchaudio")

# Per-process model caches, filled lazily by the _load_* helpers.
_asr_model: Any | None = None
_tts_model: Any | None = None
_chatterbox_tts_model: Any | None = None
# perf_counter() readings captured when each loader started (see _load_*).
_asr_loaded_at: float | None = None
_tts_loaded_at: float | None = None
_chatterbox_tts_loaded_at: float | None = None

# Autoscaling bounds. NOTE(review): these apply to every service class below,
# although the startup log labels the first value "stt_min_containers".
MIN_CONTAINERS = int(os.getenv("TIME_MACHINE_MODAL_MIN_CONTAINERS", "1"))
MAX_CONTAINERS = int(os.getenv("TIME_MACHINE_MODAL_MAX_CONTAINERS", "1"))

# Selected TTS family, normalised to lower snake case ("Chatterbox-Turbo" ->
# "chatterbox_turbo") before it is compared against the known names.
ACTIVE_TTS_MODEL_FAMILY = os.getenv(
    "TIME_MACHINE_MODAL_TTS_MODEL_FAMILY",
    "chatterbox_turbo",
).strip().lower().replace("-", "_")
# Both flags are False for an unrecognised family, in which case no TTS
# service is registered at all.
QWEN_TTS_ENABLED = ACTIVE_TTS_MODEL_FAMILY == "qwen"
CHATTERBOX_TTS_ENABLED = ACTIVE_TTS_MODEL_FAMILY in {"chatterbox", "chatterbox_turbo", "turbo"}

# Container lifecycle settings forwarded to @app.cls (values in seconds).
SCALEDOWN_WINDOW_SECONDS = int(os.getenv("TIME_MACHINE_MODAL_SCALEDOWN_SECONDS", "1800"))
STARTUP_TIMEOUT_SECONDS = int(os.getenv("TIME_MACHINE_MODAL_STARTUP_TIMEOUT_SECONDS", "900"))

# When true, each TTS service runs one throwaway synthesis right after loading.
WARMUP_TTS = os.getenv("TIME_MACHINE_MODAL_WARMUP_TTS", "1").strip().lower() in {
    "1",
    "true",
    "yes",
    "on",
}

# Logged at import time so the effective configuration shows up in the logs.
print(
    "Modal audio service config: "
    f"active_tts={ACTIVE_TTS_MODEL_FAMILY} "
    f"stt_min_containers={MIN_CONTAINERS} "
    f"qwen_tts_registered={QWEN_TTS_ENABLED} "
    f"chatterbox_tts_registered={CHATTERBOX_TTS_ENABLED}"
)

# Persistent volume to share and cache downloaded model weights
hf_volume = modal.Volume.from_name("hf-cache-vol", create_if_missing=True)
@app.cls(
    image=nemotron_image,
    gpu="A10G",
    timeout=600,
    startup_timeout=STARTUP_TIMEOUT_SECONDS,
    scaledown_window=SCALEDOWN_WINDOW_SECONDS,
    min_containers=MIN_CONTAINERS,
    max_containers=MAX_CONTAINERS,
    volumes={"/root/.cache/huggingface": hf_volume},
)
class NemotronSTTService:
    """Speech-to-text web endpoint backed by the NeMo ASR model."""

    # ASR model instance, assigned once per container in load().
    model: Any

    @modal.enter()
    def load(self) -> None:
        """Load (or reuse) the ASR model when the container starts."""
        self.model = _load_asr_model()

    @modal.fastapi_endpoint(method="POST", label="time-machine-nemotron-stt")
    def transcribe(self, item: dict[str, Any]) -> dict[str, Any]:
        """Transcribe base64-encoded audio.

        Args:
            item: Request payload. ``audio_b64`` is required; ``language`` is
                optional and defaults to ``"auto"``.

        Returns:
            A dict with the transcript ``text``, the echoed ``language``
            (``None`` for ``"auto"``), and a ``timings`` breakdown.

        Raises:
            ValueError: If ``audio_b64`` is missing or not a non-empty string.
        """
        request_started = perf_counter()
        audio_b64 = _required_string(item, "audio_b64")
        language = str(item.get("language") or "auto")
        audio_path = _write_request_audio(audio_b64)
        # Until conversion succeeds the only temp file is the upload itself.
        proc_path = audio_path
        try:
            target_lang = _target_language(language)
            if hasattr(self.model, "set_inference_prompt"):
                self.model.set_inference_prompt(target_lang)
            preprocess_started = perf_counter()
            proc_path = _convert_to_mono_16k(audio_path)
            duration = _get_audio_duration(proc_path)
            preprocess_seconds = perf_counter() - preprocess_started
            inference_started = perf_counter()
            result = _transcribe_manifest(self.model, proc_path, target_lang, duration)
            inference_seconds = perf_counter() - inference_started
        finally:
            # Always remove the temp files, including when preprocessing or
            # inference raises; otherwise failed requests fill up /tmp.
            if proc_path != audio_path:
                proc_path.unlink(missing_ok=True)
            audio_path.unlink(missing_ok=True)
        text = _extract_transcript_text(result)
        return {
            "text": text,
            "confidence": None,
            "language": None if language == "auto" else language,
            "is_final": True,
            "timings": {
                "preprocess_seconds": round(preprocess_seconds, 3),
                "inference_seconds": round(inference_seconds, 3),
                "total_seconds": round(perf_counter() - request_started, 3),
                "model_loaded_at": _asr_loaded_at,
            },
        }
if QWEN_TTS_ENABLED:

    @app.cls(
        image=qwen_tts_image,
        gpu="A10G",
        timeout=600,
        startup_timeout=STARTUP_TIMEOUT_SECONDS,
        scaledown_window=SCALEDOWN_WINDOW_SECONDS,
        min_containers=MIN_CONTAINERS,
        max_containers=MAX_CONTAINERS,
        volumes={"/root/.cache/huggingface": hf_volume},
    )
    class QwenTTSService:
        """Text-to-speech web endpoint backed by the Qwen VoiceDesign model.

        Only registered when the active TTS family is ``qwen``.
        """

        # TTS model instance, assigned once per container in load().
        model: Any

        @modal.enter()
        def load(self) -> None:
            """Load the model and, if enabled, run one warm-up synthesis."""
            self.model = _load_tts_model()
            if not WARMUP_TTS:
                return
            self._warm_up()

        def _warm_up(self) -> None:
            """Run a throwaway synthesis; failures are logged, never raised."""
            try:
                _synthesize_to_wav_bytes(
                    self.model,
                    text="The signal is open.",
                    language="English",
                    instruction="Natural conversational voice. Pace: fast. Emotion: curious.",
                )
            except Exception as exc:
                print(f"Qwen TTS warmup failed; first request may still pay setup cost: {exc}")

        @modal.fastapi_endpoint(method="POST", label="time-machine-qwen-tts")
        def synthesize(self, item: dict[str, Any]) -> dict[str, Any]:
            """Synthesize speech for ``item["text"]`` with the given voice.

            Args:
                item: Request payload. ``text`` and ``voice_profile`` (a dict)
                    are required; ``language`` (default ``"English"``),
                    ``prosody_hint`` and ``voice_seed`` are optional.

            Returns:
                A dict with base64 WAV audio, its duration, a description and
                a ``timings`` breakdown.

            Raises:
                ValueError: If ``text`` is missing or ``voice_profile`` is not
                    a dict.
            """
            started_at = perf_counter()
            text = _required_string(item, "text")
            profile = item.get("voice_profile")
            if not isinstance(profile, dict):
                raise ValueError("voice_profile must be an object.")
            spoken_language = str(item.get("language") or "English")
            instruction = _voice_instruction(profile, item.get("prosody_hint"))
            seed = _voice_seed(profile, item.get("voice_seed"))

            inference_started_at = perf_counter()
            wav_bytes, duration_seconds = _synthesize_to_wav_bytes(
                self.model,
                text=text,
                language=spoken_language,
                instruction=instruction,
                seed=seed,
            )
            inference_seconds = perf_counter() - inference_started_at

            encoded_audio = base64.b64encode(wav_bytes).decode("ascii")
            # Built after encoding so total_seconds includes the encoding cost.
            timings = {
                "inference_seconds": round(inference_seconds, 3),
                "total_seconds": round(perf_counter() - started_at, 3),
                "model_loaded_at": _tts_loaded_at,
            }
            return {
                "audio_b64": encoded_audio,
                "mime_type": "audio/wav",
                "duration_seconds": duration_seconds,
                "description": "Qwen3-TTS VoiceDesign synthesis on warm Modal GPU.",
                "timings": timings,
            }

else:
    print("Qwen TTS service not registered for this Modal serve run.")
if CHATTERBOX_TTS_ENABLED:

    @app.cls(
        image=chatterbox_tts_image,
        gpu=os.getenv("TIME_MACHINE_CHATTERBOX_GPU", "L4"),
        timeout=600,
        startup_timeout=STARTUP_TIMEOUT_SECONDS,
        scaledown_window=SCALEDOWN_WINDOW_SECONDS,
        min_containers=MIN_CONTAINERS,
        max_containers=MAX_CONTAINERS,
        volumes={"/root/.cache/huggingface": hf_volume},
    )
    class ChatterboxTurboTTSService:
        """Text-to-speech web endpoint backed by a Chatterbox model.

        Only registered when the active TTS family is one of the Chatterbox
        names. ``load`` also records which runtime ("turbo" or "standard")
        was actually loaded in ``self.runtime_name``.
        """

        # TTS model instance, assigned once per container in load().
        model: Any

        @modal.enter()
        def load(self) -> None:
            """Load the model and, if enabled, run one warm-up synthesis."""
            self.runtime_name, self.model = _load_chatterbox_tts_model()
            if not WARMUP_TTS:
                return
            self._warm_up()

        def _warm_up(self) -> None:
            """Run a throwaway synthesis; failures are logged, never raised."""
            try:
                _synthesize_chatterbox_to_wav_bytes(
                    self.model,
                    text="The signal is open.",
                    prompt="Natural expressive character voice. Pace: fast. Emotion: curious.",
                    runtime_name=str(getattr(self, "runtime_name", "turbo")),
                    exaggeration=0.6,
                    cfg_weight=0.35,
                    temperature=0.8,
                )
            except Exception as exc:
                print(f"Chatterbox Turbo TTS warmup failed; first request may still pay setup cost: {exc}")

        @modal.fastapi_endpoint(method="POST", label="time-machine-chatterbox-turbo-tts")
        def synthesize(self, item: dict[str, Any]) -> dict[str, Any]:
            """Synthesize speech for ``item["text"]`` with the given voice.

            Args:
                item: Request payload. ``text`` and ``voice_profile`` (a dict)
                    are required; ``prosody_hint``, ``voice_seed``,
                    ``exaggeration``, ``cfg_weight``, ``temperature`` and
                    ``latency_profile`` are optional.

            Returns:
                A dict with base64 WAV audio, its duration, a description and
                a ``timings`` breakdown.

            Raises:
                ValueError: If ``text`` is missing or ``voice_profile`` is not
                    a dict.
            """
            started_at = perf_counter()
            text = _required_string(item, "text")
            profile = item.get("voice_profile")
            if not isinstance(profile, dict):
                raise ValueError("voice_profile must be an object.")
            prompt = _voice_instruction(profile, item.get("prosody_hint"))
            seed = _voice_seed(profile, item.get("voice_seed"))
            exaggeration = _float_item(item, "exaggeration", 0.65)
            cfg_weight = _float_item(item, "cfg_weight", 0.35)
            temperature = _float_item(item, "temperature", 0.8)
            latency_profile = str(item.get("latency_profile") or "turbo")
            runtime_name = str(getattr(self, "runtime_name", "turbo"))

            inference_started_at = perf_counter()
            wav_bytes, duration_seconds = _synthesize_chatterbox_to_wav_bytes(
                self.model,
                text=text,
                prompt=prompt,
                seed=seed,
                runtime_name=runtime_name,
                exaggeration=exaggeration,
                cfg_weight=cfg_weight,
                temperature=temperature,
            )
            inference_seconds = perf_counter() - inference_started_at

            encoded_audio = base64.b64encode(wav_bytes).decode("ascii")
            description = _chatterbox_description(
                runtime_name,
                latency_profile,
                exaggeration,
                cfg_weight,
            )
            # Built last so total_seconds covers encoding and description too.
            timings = {
                "inference_seconds": round(inference_seconds, 3),
                "total_seconds": round(perf_counter() - started_at, 3),
                "model_loaded_at": _chatterbox_tts_loaded_at,
            }
            return {
                "audio_b64": encoded_audio,
                "mime_type": "audio/wav",
                "duration_seconds": duration_seconds,
                "description": description,
                "timings": timings,
            }

else:
    print("Chatterbox TTS service not registered for this Modal serve run.")
def _load_asr_model() -> Any:
    """Return the process-wide ASR model, loading it on first use.

    The loaded model is cached in the module-level ``_asr_model``; later
    calls return that same object without touching NeMo again.
    """
    global _asr_model, _asr_loaded_at
    if _asr_model is None:
        load_started = perf_counter()
        # Imported lazily: NeMo is only installed in the STT container image.
        import nemo.collections.asr as nemo_asr

        print(f"Loading Modal STT model: provider=nvidia runtime=nemo model={NEMOTRON_STT_MODEL_ID}")
        _asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=NEMOTRON_STT_MODEL_ID)
        model_cls = _asr_model.__class__
        print(
            "Loaded Modal STT model: "
            f"{model_cls.__module__}.{model_cls.__name__} "
            f"model={NEMOTRON_STT_MODEL_ID}"
        )
        # NOTE(review): this stores the perf_counter() reading taken when the
        # load started, not a wall-clock timestamp or a duration -- confirm
        # that is what consumers of "model_loaded_at" expect.
        _asr_loaded_at = round(load_started, 3)
    return _asr_model
def _load_tts_model() -> Any:
    """Return the process-wide Qwen TTS model, loading it on first use.

    The model is loaded on ``cuda:0`` in bfloat16 when CUDA is available and
    on CPU in float32 otherwise. FlashAttention 2 is tried first; if that
    load fails for any reason the failure is logged and the model is loaded
    again with the library's default attention implementation.

    Raises:
        Exception: Whatever the fallback ``from_pretrained`` call raises.
    """
    global _tts_model, _tts_loaded_at
    if _tts_model is not None:
        return _tts_model
    started = perf_counter()
    # Imported lazily: these packages only exist in the Qwen container image.
    import torch
    from qwen_tts import Qwen3TTSModel

    has_cuda = torch.cuda.is_available()
    kwargs: dict[str, Any] = {
        "device_map": "cuda:0" if has_cuda else "cpu",
        "dtype": torch.bfloat16 if has_cuda else torch.float32,
    }
    print(
        "Loading Modal TTS model: "
        f"provider=qwen runtime=qwen-tts model={QWEN_TTS_MODEL_ID} "
        f"device_map={kwargs['device_map']} dtype={kwargs['dtype']}"
    )
    try:
        _tts_model = Qwen3TTSModel.from_pretrained(
            QWEN_TTS_MODEL_ID,
            attn_implementation="flash_attention_2",
            **kwargs,
        )
        attention = "flash_attention_2"
    except Exception as exc:
        # Log why the fast path failed instead of swallowing it, so a missing
        # flash-attn build can be told apart from e.g. a download error.
        print(f"Qwen TTS flash_attention_2 load failed; retrying with default attention: {exc}")
        _tts_model = Qwen3TTSModel.from_pretrained(
            QWEN_TTS_MODEL_ID,
            **kwargs,
        )
        attention = "default"
    print(
        "Loaded Modal TTS model: "
        f"{_tts_model.__class__.__module__}.{_tts_model.__class__.__name__} "
        f"model={QWEN_TTS_MODEL_ID} attention={attention}"
    )
    # NOTE(review): perf_counter() reading at load start, not wall-clock time.
    _tts_loaded_at = round(started, 3)
    return _tts_model
def _load_chatterbox_tts_model() -> tuple[str, Any]:
    """Return ``(runtime_name, model)`` for Chatterbox, loading on first use.

    A cached model is returned together with a runtime name inferred from its
    class; otherwise the first loadable runtime is loaded, cached and logged.
    """
    global _chatterbox_tts_model, _chatterbox_tts_loaded_at
    cached = _chatterbox_tts_model
    if cached is not None:
        return _infer_chatterbox_runtime(cached), cached

    load_started = perf_counter()
    # Imported lazily: torch only exists in the TTS container images.
    import torch

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    # Must run before the Chatterbox classes are instantiated.
    _ensure_chatterbox_watermarker()
    runtime_name, loaded = _load_first_available_chatterbox_model(device)
    _chatterbox_tts_model = loaded

    model_cls = loaded.__class__
    configured = ",".join(_chatterbox_model_candidates())
    print(
        "Loaded Modal TTS model: "
        f"provider=resemble_ai runtime=chatterbox-{runtime_name} "
        f"configured_models={configured} "
        f"{runtime_name} ({model_cls.__module__}."
        f"{model_cls.__name__}) on {device}"
    )
    # NOTE(review): perf_counter() reading at load start, not wall-clock time.
    _chatterbox_tts_loaded_at = round(load_started, 3)
    return runtime_name, loaded
def _ensure_chatterbox_watermarker() -> None:
    """Guarantee that ``perth.PerthImplicitWatermarker`` is callable.

    Two failure modes are covered, and both end with a pass-through
    watermarker in place:

    * ``import perth`` fails -- a stub ``perth`` module is registered in
      ``sys.modules`` so that later ``import perth`` statements succeed.
    * ``perth`` imports but ``PerthImplicitWatermarker`` is missing or not
      callable -- the attribute is replaced on the real module.

    When a working watermarker is already present nothing is changed.
    """
    import sys
    import types

    class _NoOpWatermarker:
        """Stand-in that returns the audio unchanged."""

        def apply_watermark(self, wav: Any, *args: Any, **kwargs: Any) -> Any:
            return wav

        def watermark(self, wav: Any, *args: Any, **kwargs: Any) -> Any:
            return wav

    try:
        import perth
    except Exception as exc:
        print(f"Chatterbox Perth watermarker unavailable; using no-op watermarker: {exc}")
        # Previously this branch only logged and returned, so the promised
        # no-op watermarker was never installed. Register a stub module so
        # the log message is true and later imports of `perth` resolve.
        stub = types.ModuleType("perth")
        stub.PerthImplicitWatermarker = _NoOpWatermarker
        sys.modules["perth"] = stub
        return

    if callable(getattr(perth, "PerthImplicitWatermarker", None)):
        return
    perth.PerthImplicitWatermarker = _NoOpWatermarker
    print("Chatterbox PerthImplicitWatermarker is missing; using no-op watermarker.")
def _load_first_available_chatterbox_model(device: str) -> tuple[str, Any]:
    """Load the first Chatterbox runtime that initialises successfully.

    Args:
        device: Device string forwarded to the model loader.

    Returns:
        ``(runtime_name, model)`` for the first runtime that loads.

    Raises:
        Exception: The error from the last runtime tried, if all of them fail.
        RuntimeError: If no runtime class could be imported at all.
    """
    failures: list[Exception] = []
    for runtime_name, chatterbox_cls in _chatterbox_model_classes():
        try:
            print(f"Trying Chatterbox TTS runtime: {runtime_name}")
            model = _load_chatterbox_from_pretrained(chatterbox_cls, device)
        except Exception as exc:
            failures.append(exc)
            print(f"Chatterbox {runtime_name} load failed; trying next fallback: {exc}")
        else:
            return runtime_name, model
    if not failures:
        raise RuntimeError("No Chatterbox TTS runtime is available.")
    # Surface the most recent failure, matching the order runtimes were tried.
    raise failures[-1]
def _chatterbox_model_classes() -> list[tuple[str, Any]]:
    """Collect the importable Chatterbox model classes, Turbo first.

    Returns:
        ``(runtime_name, class)`` pairs for each runtime whose import
        succeeded; the list is empty when neither can be imported.
    """
    available: list[tuple[str, Any]] = []

    try:
        from chatterbox.tts_turbo import ChatterboxTurboTTS
    except Exception as exc:
        print(f"Chatterbox Turbo runtime unavailable; trying standard Chatterbox: {exc}")
    else:
        available.append(("turbo", ChatterboxTurboTTS))

    try:
        from chatterbox.tts import ChatterboxTTS
    except Exception as exc:
        print(f"Standard Chatterbox runtime unavailable: {exc}")
    else:
        available.append(("standard", ChatterboxTTS))

    return available
def _load_chatterbox_from_pretrained(chatterbox_cls: Any, device: str) -> Any:
    """Instantiate ``chatterbox_cls`` via its ``from_pretrained`` loader.

    When the loader takes a model id as its first positional argument, each
    configured candidate id is tried in order. If none works, or the loader
    takes no model id, it is called with only ``device``.
    """
    loader = chatterbox_cls.from_pretrained
    # Candidates are only looked up when the loader can actually accept one.
    candidate_ids = _chatterbox_model_candidates() if _supports_chatterbox_model_id(loader) else []
    for model_id in candidate_ids:
        try:
            return loader(model_id, device=device)
        except Exception as exc:
            print(f"Chatterbox model id {model_id!r} failed; trying next fallback: {exc}")
    return loader(device=device)
def _supports_chatterbox_model_id(from_pretrained: Any) -> bool:
    """Tell whether a loader's first positional parameter is a model id.

    Returns ``False`` when the signature cannot be inspected, when there is
    no positional parameter, or when the first one has an unrecognised name.
    """
    try:
        signature = inspect.signature(from_pretrained)
    except (TypeError, ValueError):
        return False

    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    first_positional = next(
        (
            candidate
            for candidate in signature.parameters.values()
            if candidate.kind in positional_kinds
        ),
        None,
    )
    if first_positional is None:
        return False
    return first_positional.name in {
        "model_id",
        "repo_id",
        "pretrained_model_name_or_path",
        "checkpoint",
    }
def _chatterbox_model_candidates() -> list[str]:
    """List the Chatterbox model ids to try, in priority order.

    ``TIME_MACHINE_CHATTERBOX_MODEL_IDS`` (comma-separated) wins when set;
    otherwise the single id from ``TIME_MACHINE_CHATTERBOX_TURBO_MODEL_ID``
    (default ``ResembleAI/chatterbox-turbo``) is used. The standard
    ``ResembleAI/chatterbox`` id is appended when not already present.
    """
    default_id = os.getenv("TIME_MACHINE_CHATTERBOX_TURBO_MODEL_ID", "ResembleAI/chatterbox-turbo")
    configured = os.getenv("TIME_MACHINE_CHATTERBOX_MODEL_IDS", default_id)

    model_ids: list[str] = []
    for part in configured.split(","):
        cleaned = part.strip()
        if cleaned:
            model_ids.append(cleaned)

    fallback_id = "ResembleAI/chatterbox"
    if fallback_id not in model_ids:
        model_ids.append(fallback_id)
    return model_ids
def _target_language(language: str) -> str:
    """Return ``language`` if it is a recognised code, otherwise ``"en"``.

    Matching is exact and case-sensitive, so values such as ``"auto"`` or
    ``"EN"`` fall back to English.
    """
    recognised = frozenset(
        (
            "en-US", "en", "en-GB", "enGB",
            "es-ES", "esES", "es-US", "es",
            "zh-CN", "zh-ZH",
        )
    )
    if language not in recognised:
        return "en"
    return language
def _transcribe_manifest(
    model: Any,
    proc_path: Path,
    target_lang: str,
    duration: float,
) -> Any:
    """Transcribe one audio file through a single-entry JSON manifest.

    The manifest is written next to ``proc_path`` (same name, ``.json``
    suffix), handed to ``model.transcribe`` and removed afterwards, whether
    or not transcription succeeds.

    Returns:
        Whatever ``model.transcribe`` returns.
    """
    import json

    manifest_path = proc_path.with_suffix(".json")
    # The language is repeated under three keys.
    # NOTE(review): presumably because prompt-aware models differ in which key they read -- confirm.
    entry = {
        "audio_filepath": str(proc_path),
        "duration": duration,
        "text": "",
        "target_lang": target_lang,
        "lang": target_lang,
        "language": target_lang,
    }
    try:
        manifest_path.write_text(json.dumps(entry) + "\n", encoding="utf-8")
        return model.transcribe([str(manifest_path)])
    finally:
        manifest_path.unlink(missing_ok=True)
def _synthesize_to_wav_bytes(
    model: Any,
    text: str,
    language: str,
    instruction: str,
    seed: int | None = None,
) -> tuple[bytes, float]:
    """Synthesize ``text`` with the Qwen VoiceDesign model and encode as WAV.

    Args:
        model: Object exposing ``generate_voice_design``.
        text: Text to speak.
        language: Language name forwarded to the model.
        instruction: Voice-design instruction forwarded as ``instruct``.
        seed: When given, the RNGs are seeded before generation.

    Returns:
        ``(wav_bytes, duration_seconds)`` for the first generated waveform,
        with the duration rounded to milliseconds.
    """
    if seed is not None:
        _seed_tts_generation(seed)
    wavs, sample_rate = model.generate_voice_design(
        text=text,
        language=language,
        instruct=instruction,
    )
    import soundfile as sf

    # Reserve a temp path and close the handle right away: the previous
    # `NamedTemporaryFile(...).name` left the open descriptor to be reclaimed
    # by garbage collection.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as handle:
        output_path = Path(handle.name)
    try:
        sf.write(output_path, wavs[0], sample_rate)
        audio_bytes = output_path.read_bytes()
        duration_seconds = round(len(wavs[0]) / float(sample_rate), 3)
    finally:
        output_path.unlink(missing_ok=True)
    return audio_bytes, duration_seconds
def _synthesize_chatterbox_to_wav_bytes(
    model: Any,
    text: str,
    prompt: str,
    seed: int | None = None,
    runtime_name: str | None = None,
    exaggeration: float = 0.65,
    cfg_weight: float = 0.35,
    temperature: float = 0.8,
) -> tuple[bytes, float]:
    """Synthesize ``text`` with a Chatterbox model and encode it as WAV.

    Only the keyword arguments that ``model.generate`` accepts are passed
    on. ``exaggeration`` and ``cfg_weight`` are offered only when the
    runtime is not ``"turbo"``.

    Args:
        model: Object exposing ``generate`` and, optionally, ``sr`` or
            ``sample_rate``.
        text: Text to speak.
        prompt: Voice description offered as ``prompt`` and
            ``condition_prompt``.
        seed: When given, the RNGs are seeded before generation.
        runtime_name: ``"turbo"`` or ``"standard"``; inferred from the model
            class when omitted.
        exaggeration: Expressiveness knob for the standard runtime.
        cfg_weight: Guidance weight for the standard runtime.
        temperature: Sampling temperature.

    Returns:
        ``(wav_bytes, duration_seconds)`` with the duration rounded to
        milliseconds.
    """
    if seed is not None:
        _seed_tts_generation(seed)
    runtime = runtime_name or _infer_chatterbox_runtime(model)
    generate_values: dict[str, Any] = {
        "text": text,
        "prompt": prompt,
        "condition_prompt": prompt,
        "temperature": temperature,
    }
    if runtime != "turbo":
        generate_values["exaggeration"] = exaggeration
        generate_values["cfg_weight"] = cfg_weight
    generate_kwargs = _supported_kwargs(model.generate, generate_values)
    if "text" in generate_kwargs:
        wav = model.generate(**generate_kwargs)
    else:
        # The signature has no parameter named "text": pass it positionally.
        wav = model.generate(text, **generate_kwargs)
    # Falls back to 24 kHz when the model exposes neither attribute.
    sample_rate = int(getattr(model, "sr", getattr(model, "sample_rate", 24000)))
    audio = _to_numpy_audio(wav)
    import soundfile as sf

    # Reserve a temp path and close the handle right away: the previous
    # `NamedTemporaryFile(...).name` left the open descriptor to be reclaimed
    # by garbage collection.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as handle:
        output_path = Path(handle.name)
    try:
        sf.write(output_path, audio, sample_rate)
        audio_bytes = output_path.read_bytes()
        duration_seconds = round(len(audio) / float(sample_rate), 3)
    finally:
        output_path.unlink(missing_ok=True)
    return audio_bytes, duration_seconds
def _chatterbox_description(
    runtime_name: str,
    latency_profile: str,
    exaggeration: float,
    cfg_weight: float,
) -> str:
    """Build the human-readable description returned with synthesized audio.

    The Turbo runtime reports only the latency profile; any other runtime
    also reports the exaggeration and cfg values it was given.
    """
    if runtime_name != "turbo":
        knobs = f"({latency_profile}, exaggeration={exaggeration:g}, cfg={cfg_weight:g})."
        return "Chatterbox TTS synthesis on warm Modal GPU " + knobs
    return f"Chatterbox Turbo TTS synthesis on warm Modal GPU ({latency_profile})."
def _infer_chatterbox_runtime(model: Any) -> str:
    """Guess the runtime from the model's class path.

    Returns ``"turbo"`` when "turbo" appears (case-insensitively) in the
    class's module or name, and ``"standard"`` otherwise.
    """
    model_cls = model.__class__
    qualified_name = "{}.{}".format(model_cls.__module__, model_cls.__name__).lower()
    if "turbo" in qualified_name:
        return "turbo"
    return "standard"
def _supported_kwargs(callable_obj: Any, values: dict[str, Any]) -> dict[str, Any]:
    """Keep only the entries of ``values`` that ``callable_obj`` accepts.

    ``values`` itself is returned unchanged when the signature cannot be
    inspected or when the callable takes ``**kwargs``.
    """
    try:
        accepted = inspect.signature(callable_obj).parameters
    except (TypeError, ValueError):
        return values

    for parameter in accepted.values():
        if parameter.kind == inspect.Parameter.VAR_KEYWORD:
            return values

    filtered: dict[str, Any] = {}
    for name, value in values.items():
        if name in accepted:
            filtered[name] = value
    return filtered
def _to_numpy_audio(wav: Any) -> Any:
    """Normalise a model's audio output into a squeezed array.

    A non-empty tuple is reduced to its first element, and a non-empty list
    whose first item is sized is reduced to that item. Objects with
    ``detach``/``cpu`` (tensor-like) are converted through ``.numpy()``;
    plain lists become float32 arrays. The result is squeezed when possible.
    """
    if isinstance(wav, tuple) and wav:
        wav = wav[0]

    is_batched_list = isinstance(wav, list) and wav and hasattr(wav[0], "__len__")
    if is_batched_list:
        wav = wav[0]

    if hasattr(wav, "detach"):
        wav = wav.detach().cpu().numpy()
    elif hasattr(wav, "cpu"):
        wav = wav.cpu().numpy()
    elif isinstance(wav, list):
        import numpy as np

        wav = np.asarray(wav, dtype="float32")

    squeeze = getattr(wav, "squeeze", None)
    return wav if squeeze is None else squeeze()
def _seed_tts_generation(seed: int) -> None:
    """Seed Python, NumPy and torch RNGs.

    The seed is reduced modulo ``2**31 - 1`` first. NumPy and torch seeding
    is best effort: a missing package or a seeding error is ignored.
    """
    normalized = seed % (2**31 - 1)
    random.seed(normalized)

    def _seed_numpy() -> None:
        import numpy as np

        np.random.seed(normalized)

    def _seed_torch() -> None:
        import torch

        torch.manual_seed(normalized)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(normalized)

    for seeder in (_seed_numpy, _seed_torch):
        try:
            seeder()
        except Exception:
            # Deliberately best effort: seeding is a nicety, not a requirement.
            pass
def _convert_to_mono_16k(input_path: Path) -> Path:
    """Convert audio to mono 16 kHz WAV with ffmpeg when it is not already.

    Args:
        input_path: Audio file to convert.

    Returns:
        ``input_path`` itself when it already is mono / 16 kHz / 16-bit WAV
        or when conversion fails; otherwise the path of a new sibling file
        named ``<stem>_mono_16k.wav``, which the caller must delete.
    """
    if _is_mono_16k_wav(input_path):
        return input_path
    import subprocess

    output_path = input_path.with_name(input_path.stem + "_mono_16k.wav")
    cmd = [
        "ffmpeg",
        "-y",
        "-i", str(input_path),
        "-ac", "1",
        "-ar", "16000",
        str(output_path),
    ]
    try:
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception as e:
        print(f"ffmpeg conversion failed: {e}")
        # ffmpeg may have created a partial output before failing. The caller
        # only receives input_path back, so nothing else would remove it.
        output_path.unlink(missing_ok=True)
        return input_path
    return output_path
def _is_mono_16k_wav(path: Path) -> bool:
    """Report whether ``path`` is a 1-channel, 16 kHz, 16-bit WAV file.

    Any error while opening or parsing the file counts as "no".
    """
    try:
        import wave

        with wave.open(str(path), "rb") as wav_file:
            layout = (
                wav_file.getnchannels(),
                wav_file.getframerate(),
                wav_file.getsampwidth(),
            )
    except Exception:
        return False
    return layout == (1, 16000, 2)
def _get_audio_duration(path: Path) -> float:
    """Return the WAV duration in seconds, never less than 0.1.

    Falls back to 10.0 seconds when the file cannot be read as WAV.
    """
    try:
        import wave

        with wave.open(str(path), "rb") as wav_file:
            frame_count = wav_file.getnframes()
            frame_rate = wav_file.getframerate()
            seconds = frame_count / float(frame_rate)
        return max(0.1, seconds)
    except Exception:
        return 10.0
def _write_request_audio(audio_b64: str) -> Path:
    """Decode base64 audio into a new temporary ``.wav`` file.

    Args:
        audio_b64: Base64-encoded audio bytes.

    Returns:
        Path of the temp file. The caller is responsible for deleting it.

    Raises:
        binascii.Error: If the payload is not valid base64 (a subclass of
            ``ValueError``).
    """
    # Decode first so an invalid payload never leaves a temp file behind.
    audio_bytes = base64.b64decode(audio_b64)
    # Write through the handle and close it deterministically, instead of
    # discarding an open NamedTemporaryFile and reopening the path by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as handle:
        handle.write(audio_bytes)
        audio_path = Path(handle.name)
    return audio_path
def _extract_transcript_text(result: Any) -> str:
    """Pull the transcript string out of whatever ``transcribe`` returned.

    Handles a plain string, an object with a ``text`` attribute, a dict with
    a ``"text"`` key, and a list or tuple of any of those (only the first
    element is used, recursively). A missing or ``None`` text yields ``""``
    rather than the literal string ``"None"``.

    Returns:
        The stripped transcript, or ``""`` when none can be found.
    """
    if result is None:
        return ""
    if isinstance(result, str):
        return result.strip()
    if isinstance(result, (list, tuple)):
        # Nested and tuple-shaped results are unwrapped via their first item.
        return _extract_transcript_text(result[0]) if result else ""
    if hasattr(result, "text"):
        text = result.text
    elif isinstance(result, dict):
        text = result.get("text")
    else:
        return ""
    # Recursing maps a None text to "" and strips/stringifies anything else.
    if text is None or isinstance(text, str):
        return (text or "").strip()
    return str(text).strip()
def _voice_instruction(
    voice_profile: dict[str, Any],
    prosody_hint: object,
) -> str:
    """Compose the natural-language voice instruction for the TTS model.

    The sentence always covers speaker identity, description, pace and
    emotion (with defaults for missing fields); accent and prosody clauses
    are appended only when those values are truthy.
    """
    speaker = str(voice_profile.get("voice_id") or "character")
    description = str(voice_profile.get("description") or "Natural conversational character voice.")
    pace = voice_profile.get("pace") or "medium"
    emotion = voice_profile.get("emotion") or "curious"

    sentences = [
        f"Consistent speaker identity: {speaker}. "
        "Keep the same timbre and apparent speaker across separate lines.",
        description,
        f"Pace: {pace}.",
        f"Emotion: {emotion}.",
    ]

    accent = voice_profile.get("accent_hint")
    if accent:
        sentences.append(f"Accent or local color: {accent}.")
    if prosody_hint:
        sentences.append(f"Prosody: {prosody_hint}.")
    return " ".join(sentences)
def _voice_seed(voice_profile: dict[str, Any], provided: object = None) -> int:
    """Return the RNG seed to use for a voice.

    A caller-provided value is used when it converts to a positive integer.
    Otherwise the seed is derived from a SHA-256 hash of the profile's
    identity fields, so the same profile always maps to the same seed.

    Args:
        voice_profile: Voice description; ``voice_id``, ``description``,
            ``pace``, ``emotion`` and ``accent_hint`` feed the hash.
        provided: Optional explicit seed from the request.

    Returns:
        A positive provided seed, or a 32-bit value derived from the profile.
    """
    try:
        parsed = int(provided)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers infinite floats, which int() rejects with a
        # different exception type than NaN or non-numeric input.
        parsed = 0
    if parsed > 0:
        return parsed
    payload = "\n".join(
        [
            str(voice_profile.get("voice_id") or ""),
            str(voice_profile.get("description") or ""),
            str(voice_profile.get("pace") or ""),
            str(voice_profile.get("emotion") or ""),
            str(voice_profile.get("accent_hint") or ""),
        ]
    )
    # The first 8 hex digits give a stable 32-bit seed.
    return int(hashlib.sha256(payload.encode("utf-8")).hexdigest()[:8], 16)
def _float_item(item: dict[str, Any], key: str, default: float) -> float:
    """Read ``item[key]`` as a finite float, falling back to ``default``.

    ``default`` is returned when the key is missing, the value cannot be
    converted (including ``None`` and integers too large for a float), or
    the converted value is NaN or infinite.
    """
    import math

    if key not in item:
        return float(default)
    try:
        value = float(item[key])
    except (TypeError, ValueError, OverflowError):
        return default
    # NaN/inf parse successfully (e.g. from the string "nan") but are not
    # usable as generation parameters, so treat them like unparseable input.
    if not math.isfinite(value):
        return default
    return value
def _required_string(item: dict[str, Any], key: str) -> str:
    """Return ``item[key]`` when it is a non-empty string.

    Raises:
        ValueError: If the key is missing, empty, or not a string.
    """
    candidate = item.get(key)
    if isinstance(candidate, str) and candidate:
        return candidate
    raise ValueError(f"{key} is required.")