ai-time-machine / scripts /modal_warmup.py
manikandanj's picture
Prepare AI Time Machine hackathon Space
5862322 verified
Raw
History Blame Contribute Delete
5.9 kB
from __future__ import annotations
import base64
import json
import math
import os
import struct
import sys
import time
import urllib.error
import urllib.request
from pathlib import Path
# Repository root: two levels above this file (scripts/modal_warmup.py -> repo).
ROOT = Path(__file__).resolve().parents[1]
# Env files read by main(): the repo-level .env, then data/local.env.
ROOT_ENV = ROOT.joinpath(".env")
LOCAL_ENV = ROOT.joinpath("data", "local.env")
def main() -> int:
    """Send one small warmup request to the Modal STT and TTS endpoints.

    Settings come from the process environment after loading ``.env`` and then
    ``data/local.env`` from the repository root; values in those files
    override variables that are already set.

    The STT endpoint is warmed with a short synthetic tone unless
    ``TIME_MACHINE_MODAL_WARMUP_STT`` is set to a non-truthy value. The TTS
    endpoint is always warmed. Wall time and server-reported timings are
    printed for each request.

    Returns:
        0 when the warmups complete.

    Raises:
        RuntimeError: if a required URL is missing, a numeric setting is not a
            valid number, or an endpoint request fails.
    """
    _load_local_env(ROOT_ENV, override=True)
    _load_local_env(LOCAL_ENV, override=True)
    stt_url = os.getenv("TIME_MACHINE_MODAL_STT_URL", "").strip()
    # Resolved before any request so a missing TTS URL fails fast.
    tts_url = _required_env("TIME_MACHINE_MODAL_TTS_URL")
    # Stripped like the URLs so stray whitespace from a shell export does not
    # end up in the Authorization header; empty means "no token".
    bearer_token = os.getenv("TIME_MACHINE_MODAL_BEARER_TOKEN", "").strip() or None
    if _env_flag("TIME_MACHINE_MODAL_WARMUP_STT", default=True):
        if not stt_url:
            raise RuntimeError(
                "TIME_MACHINE_MODAL_STT_URL is required unless TIME_MACHINE_MODAL_WARMUP_STT=0."
            )
        _warm_stt(stt_url, bearer_token)
    else:
        print("Skipping Modal STT warmup.")
    _warm_tts(tts_url, bearer_token)
    return 0


def _float_env(name: str, default: str) -> float:
    """Read environment variable *name* as a float, falling back to *default*.

    Raises:
        RuntimeError: naming the variable when its value is not a valid float.
    """
    raw = os.getenv(name, default)
    try:
        return float(raw)
    except ValueError as exc:
        raise RuntimeError(f"{name} must be a number, got {raw!r}.") from exc


def _warm_stt(stt_url: str, bearer_token: str | None) -> None:
    """POST a short synthetic WAV tone to the STT endpoint and print timings."""
    print("Warming Modal Nemotron STT endpoint...")
    started = time.perf_counter()
    response = _post_json(
        stt_url,
        {
            "audio_b64": base64.b64encode(_tone_wav()).decode("ascii"),
            "audio_mime_type": "audio/wav",
            "language": os.getenv("TIME_MACHINE_MODAL_STT_LANGUAGE", "en"),
        },
        bearer_token=bearer_token,
        timeout_seconds=_float_env("TIME_MACHINE_MODAL_STT_TIMEOUT", "240"),
    )
    print(f"STT warmup wall time: {time.perf_counter() - started:.1f}s")
    print(f"STT response timings: {response.get('timings')}")


def _warm_tts(tts_url: str, bearer_token: str | None) -> None:
    """POST a short phrase to the TTS endpoint and print timings and duration."""
    model_family = os.getenv("TIME_MACHINE_MODAL_TTS_MODEL_FAMILY", "chatterbox_turbo")
    print(f"Warming Modal {model_family} TTS endpoint...")
    started = time.perf_counter()
    response = _post_json(
        tts_url,
        {
            "text": os.getenv("TIME_MACHINE_MODAL_WARMUP_TEXT", "The signal is open."),
            "language": os.getenv("TIME_MACHINE_MODAL_TTS_LANGUAGE", "English"),
            "prosody_hint": "brief, conversational, low latency",
            "model_family": model_family,
            "latency_profile": os.getenv("TIME_MACHINE_MODAL_TTS_LATENCY_PROFILE", "turbo"),
            "exaggeration": _float_env("TIME_MACHINE_CHATTERBOX_EXAGGERATION", "0.65"),
            "cfg_weight": _float_env("TIME_MACHINE_CHATTERBOX_CFG_WEIGHT", "0.35"),
            "temperature": _float_env("TIME_MACHINE_CHATTERBOX_TEMPERATURE", "0.8"),
            "voice_profile": {
                "voice_id": f"{model_family.replace('_', '-')}-warmup",
                "description": "Natural expressive character voice for warmup.",
                "pace": "fast",
                "emotion": "engaged",
                "accent_hint": None,
            },
        },
        bearer_token=bearer_token,
        timeout_seconds=_float_env("TIME_MACHINE_MODAL_TTS_TIMEOUT", "240"),
    )
    print(f"TTS warmup wall time: {time.perf_counter() - started:.1f}s")
    print(f"TTS response timings: {response.get('timings')}")
    print(f"TTS audio duration: {response.get('duration_seconds')}s")
def _load_local_env(path: Path, *, override: bool = False) -> None:
if not path.exists():
return
for raw_line in path.read_text(encoding="utf-8").splitlines():
line = raw_line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
key, value = line.split("=", 1)
key = key.strip()
value = value.strip().strip('"').strip("'")
if key and (override or key not in os.environ):
os.environ[key] = value
def _required_env(name: str) -> str:
value = os.getenv(name)
if not value or not value.strip():
raise RuntimeError(f"{name} is required.")
return value.strip()
def _env_flag(name: str, default: bool) -> bool:
value = os.getenv(name)
if value is None:
return default
return value.strip().lower() in {"1", "true", "yes", "on"}
def _post_json(
url: str,
payload: dict,
*,
bearer_token: str | None,
timeout_seconds: float,
) -> dict:
body = json.dumps(payload).encode("utf-8")
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
if bearer_token:
headers["Authorization"] = f"Bearer {bearer_token}"
request = urllib.request.Request(url, data=body, headers=headers, method="POST")
try:
with urllib.request.urlopen(request, timeout=timeout_seconds) as response:
decoded = json.loads(response.read().decode("utf-8"))
except urllib.error.HTTPError as exc:
detail = exc.read().decode("utf-8", errors="replace")
raise RuntimeError(f"{url} returned HTTP {exc.code}: {detail}") from exc
if not isinstance(decoded, dict):
raise RuntimeError(f"{url} returned a non-object response.")
return decoded
def _tone_wav(sample_rate: int = 16000, seconds: float = 0.35) -> bytes:
samples = int(sample_rate * seconds)
frames = bytearray()
for index in range(samples):
value = int(0.08 * 32767 * math.sin(2 * math.pi * 440 * index / sample_rate))
frames.extend(struct.pack("<h", value))
data_size = len(frames)
header = b"".join(
[
b"RIFF",
struct.pack("<I", 36 + data_size),
b"WAVEfmt ",
struct.pack("<IHHIIHH", 16, 1, 1, sample_rate, sample_rate * 2, 2, 16),
b"data",
struct.pack("<I", data_size),
]
)
return header + bytes(frames)
if __name__ == "__main__":
    # Report any failure as a one-line message on stderr and exit with status 1
    # rather than printing a traceback; SystemExit from main() passes through.
    try:
        exit_code = main()
    except Exception as exc:
        print(f"Modal warmup failed: {exc}", file=sys.stderr)
        exit_code = 1
    raise SystemExit(exit_code)