WitnessBox / modal_app.py
Farseen0's picture
Deploy WitnessBox
c519923 verified
Raw
History Blame Contribute Delete
17 kB
"""WitnessBox on Modal — the runtime that serves the game's three models and
pre-generates its scripted beats.
Deploy: modal deploy modal_app.py
Then run the Space with WITNESSBOX_BACKEND=modal and the Modal token set as
Space secrets (MODAL_TOKEN_ID / MODAL_TOKEN_SECRET).
How this is a genuine *best use of the platform* (not just hosting), mapped to
the README's "Best Use of Modal" section:
1. GPU inference behind `@app.cls`, **scale-to-zero** — three models, each on a
right-sized GPU (A100 for the LLM, A10G for voice and ASR), $0 when idle (`scaledown_window`).
2. **Opt-in keep-warm** (`min_containers`, via `WITNESSBOX_KEEP_WARM`) on the witness
brain + voice, so a live examination doesn't pay a cold start every turn; it
defaults to 0 so the deployment still scales to zero when idle.
3. **Parallel `.starmap()`** pre-generates every fixed beat (run via `modal run
modal_app.py`), fanning the 32 voice-crack takes across containers at once and keeping the best.
4. **Volume** persists the designed CFO reference voice + model cache + chosen
beats across cold starts.
5. **Memory snapshots** cut CPU-side init on cold start.
NOTE: model-call signatures follow PRD.md / HACKATHON-CONTEXT.md (verified). The
exact VoxCPM2 / Nemotron import paths may need a one-line pin against the shipped
package versions at deploy time; each is isolated in a `_load` / `_synth` helper.
"""
from __future__ import annotations
import os
import modal
import config
from witnessbox import script
app = modal.App(config.MODAL_APP_NAME)

# One shared Volume holds the HF model cache (HF_HOME points here), the designed
# reference voice and the pre-rendered beats, so they persist across cold starts.
cache = modal.Volume.from_name("witnessbox-cache", create_if_missing=True)
CACHE_DIR = "/cache"
REF_VOICE_PATH = CACHE_DIR + "/cfo_reference.wav"
BEATS_DIR = CACHE_DIR + "/beats"
# Keep-warm is OPT-IN. Default 0 => true scale-to-zero, $0 when idle (the honest
# Best-Use-of-Modal story, and it won't burn credits between demos). Flip it on
# only for a live demo recording / judging window:
#   WITNESSBOX_KEEP_WARM=1 modal deploy modal_app.py
# Warm turns are then ~5.3s (reply) + ~8.6s (voice); a cold first turn pays the
# model-load once (memory snapshots + the Volume model cache keep that bounded).
def _parse_keep_warm(raw: str) -> int:
    """Parse the WITNESSBOX_KEEP_WARM value into a ``min_containers`` count.

    Accepts a non-negative integer ("0", "1", "2", ...) or a boolean word
    ("true"/"yes"/"on" -> 1, "false"/"no"/"off"/empty -> 0), case-insensitive
    and whitespace-tolerant.

    Raises:
        ValueError: for anything else (including negative counts), with a
            message naming the variable instead of a bare int() traceback
            at import/deploy time.
    """
    value = raw.strip().lower()
    if value in ("", "false", "no", "off"):
        return 0
    if value in ("true", "yes", "on"):
        return 1
    try:
        count = int(value)
    except ValueError:
        count = -1  # fall through to the single, descriptive error below
    if count < 0:
        raise ValueError(
            "WITNESSBOX_KEEP_WARM must be a non-negative integer or "
            f"true/false/yes/no/on/off, got {raw!r}"
        )
    return count


_KEEP_WARM = _parse_keep_warm(os.environ.get("WITNESSBOX_KEEP_WARM", "0"))
# Per-model images keep conflicting deps (notably torch pins) apart. They share
# one recipe — debian-slim 3.11, optional apt layer, pip layer, HF cache env,
# local sources — and differ only in their apt/pip package lists.
_HF = {"HF_HOME": CACHE_DIR, "HF_HUB_ENABLE_HF_TRANSFER": "1"}


def _model_image(*pip_packages: str, apt_packages: tuple[str, ...] = ()) -> modal.Image:
    """Build one per-model image with the shared base, env and local sources."""
    image = modal.Image.debian_slim(python_version="3.11")
    if apt_packages:
        image = image.apt_install(*apt_packages)
    image = image.pip_install(*pip_packages)
    image = image.env(_HF)
    return image.add_local_python_source("config", "witnessbox")


# MiniCPM4.1-8B is a standard text model — clean transformers deps, no omni
# dependency cascade (PIL/librosa/soundfile/minicpmo/vocos/...).
# transformers <5: MiniCPM4.1-8B's remote code imports is_torch_fx_available,
# which transformers 5.x removed.
llm_image = _model_image(
    "torch>=2.5.0", "transformers>=4.46,<5", "accelerate",
    "sentencepiece", "hf_transfer", "numpy",
)
voice_image = _model_image(
    "torch>=2.5.0", "soundfile", "librosa", "numpy", "hf_transfer",
    "voxcpm",  # the VoxCPM2 runtime package
    apt_packages=("ffmpeg",),
)
asr_image = _model_image(
    "torch>=2.5.0", "transformers>=4.49", "soundfile", "librosa",
    "numpy", "hf_transfer",
    apt_packages=("ffmpeg",),
)
# --------------------------------------------------------------------------- #
# Witness brain — MiniCPM4.1-8B (standard text model; clean transformers deps)
# --------------------------------------------------------------------------- #
@app.cls(
    image=llm_image,
    gpu="A100",
    volumes={CACHE_DIR: cache},
    scaledown_window=300,  # scale-to-zero after 5 min idle
    min_containers=_KEEP_WARM,  # 0 = $0 idle; set WITNESSBOX_KEEP_WARM=1 for live demos
    enable_memory_snapshot=True,
)
class WitnessLLM:
    """Witness brain: MiniCPM4.1-8B generating the CFO's in-character replies."""

    @modal.enter()
    def load(self):
        """Load the tokenizer and the bf16 model onto the GPU once per container."""
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        # Standard causal-LM load. sdpa avoids a flash-attn dependency.
        # Verified: https://huggingface.co/openbmb/MiniCPM4.1-8B
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.WITNESS_LLM, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            config.WITNESS_LLM,
            trust_remote_code=True,
            attn_implementation="sdpa",
            torch_dtype=torch.bfloat16,  # transformers 4.x uses torch_dtype, not dtype
            device_map="cuda",
        ).eval()

    @modal.method()
    def respond(self, system_prompt: str, messages: list[dict]) -> str:
        """Generate one witness reply.

        Args:
            system_prompt: persona/instructions, sent as the leading system turn.
            messages: prior conversation turns; only each item's "role" and
                "content" keys are used.

        Returns:
            The decoded reply with any <think> reasoning removed and surrounding
            whitespace stripped. Can be "" if generation produced only reasoning.
        """
        import re
        import torch

        msgs = [{"role": "system", "content": system_prompt}]
        msgs += [{"role": m["role"], "content": m["content"]} for m in messages]
        # enable_thinking=False -> direct in-character reply, no <think> trace.
        # If the tokenizer rejects the kwarg with TypeError, retry without it.
        try:
            prompt = self.tokenizer.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=True, enable_thinking=False
            )
        except TypeError:
            prompt = self.tokenizer.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=True
            )
        inputs = self.tokenizer([prompt], return_tensors="pt").to("cuda")
        with torch.no_grad():
            out = self.model.generate(
                **inputs, max_new_tokens=160, do_sample=True, temperature=0.7, top_p=0.95
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        text = self.tokenizer.decode(
            out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
        )
        # Safety net: strip any <think> trace. Matching up to `</think>` OR end
        # of text also removes a trace left unterminated because generation hit
        # max_new_tokens mid-thought, so reasoning never reaches the voice.
        text = re.sub(r"<think>.*?(?:</think>|\Z)", "", text, flags=re.DOTALL)
        return text.strip()
# --------------------------------------------------------------------------- #
# Witness voice — VoxCPM2, style tag = game state
# --------------------------------------------------------------------------- #
@app.cls(
    image=voice_image,
    gpu="A10G",
    volumes={CACHE_DIR: cache},
    scaledown_window=300,
    min_containers=_KEEP_WARM,  # 0 = $0 idle; set WITNESSBOX_KEEP_WARM=1 for live demos
    enable_memory_snapshot=True,
)
class WitnessVoice:
    """Witness voice: VoxCPM2 cloning one designed CFO reference voice, with a
    parenthesised style tag prefixed to each line to steer delivery."""

    @modal.enter()
    def load(self):
        """Load VoxCPM2 and make sure the designed reference voice exists on the Volume."""
        import os
        from voxcpm import VoxCPM  # class is VoxCPM; the model id is openbmb/VoxCPM2

        # torch>=2.5.0 is enforced by the image; the denoiser stays off for speed.
        # Verified: https://voxcpm.readthedocs.io / pip install voxcpm
        # optimize=False skips torch.compile: compiling costs minutes on every
        # cold start (and again on each scaled-up container), which a turn-based
        # game can't justify. Documented escape hatch in the VoxCPM docs.
        self.tts = VoxCPM.from_pretrained(
            config.WITNESS_VOICE, load_denoiser=False, optimize=False
        )
        self.sr = int(self.tts.tts_model.sample_rate)  # 48000 for VoxCPM2

        if os.path.exists(REF_VOICE_PATH):
            return
        # No reference yet: design the CFO voice once and persist it, so every
        # later line is a controllable clone of that same voice.
        os.makedirs(CACHE_DIR, exist_ok=True)
        designed = self._synth(
            "(a composed, measured, late-50s American male executive; dry, controlled)"
            "Counselor, I have nothing to hide.",
            reference=None,
        )
        _write_wav(REF_VOICE_PATH, designed, self.sr)
        cache.commit()

    def _synth(self, styled_text: str, reference: str | None):
        """Run a single VoxCPM generate call and return flat float32 samples.

        With ``reference=None`` this is voice design; otherwise it clones the
        voice in the WAV at ``reference`` (style tag in parens in the text).
        """
        import numpy as np

        request = {"text": styled_text, "cfg_value": 2.0, "inference_timesteps": 10}
        if reference is not None:
            request["reference_wav_path"] = reference
        return np.asarray(self.tts.generate(**request), dtype=np.float32).reshape(-1)

    def _clone_line(self, text: str, style: str):
        """Speak ``text`` in the cloned CFO voice with ``style`` as the tag."""
        return self._synth(f"({style}){text}", reference=REF_VOICE_PATH)

    @modal.method()
    def speak(self, text: str, style: str):
        """Render one live line; returns (samples, sample_rate)."""
        return self._clone_line(text, style), self.sr

    @modal.method()
    def bake(self, key: str, idx: int, text: str, style: str) -> dict:
        """Render ONE beat take straight to the mounted Volume and return only
        small metadata (path + break score).

        Writing to the Volume instead of returning (wav, sr) matters because
        `.map()/.starmap()` fetch large results through Modal's input-plane
        blob path, which errors `BlobGet UNIMPLEMENTED` on this deploy. A tiny
        dict stays inline, and scoring the break here spreads the librosa cost
        across containers instead of serialising it in the orchestrator.
        """
        import os

        audio = self._clone_line(text, style)
        os.makedirs(BEATS_DIR, exist_ok=True)
        take = int(idx)
        path = f"{BEATS_DIR}/_take_{key}_{take:02d}.wav"
        _write_wav(path, audio, self.sr)
        score = 0.0 if key != "break" else _break_score(audio, self.sr)
        cache.commit()  # make this take visible to the orchestrator container
        return {
            "key": key,
            "idx": take,
            "path": path,
            "score": float(score),
            "samples": int(len(audio)),
            "sr": self.sr,
        }

    @modal.method()
    def beat(self, key: str):
        """Serve a pre-generated beat from the Volume, rendering it live if it
        was never cached; returns None for a key the script doesn't define."""
        import os

        cached = f"{BEATS_DIR}/{key}.wav"
        if os.path.exists(cached):
            return _read_wav(cached)
        spec = script.scripted_beats().get(key)
        if not spec:
            return None
        return self._clone_line(spec["text"], spec["style"]), self.sr
# --------------------------------------------------------------------------- #
# Player ASR — whisper-small via transformers; Nemotron (NeMo) is an optional swap-in
# --------------------------------------------------------------------------- #
@app.cls(
    image=asr_image,
    gpu="A10G",
    volumes={CACHE_DIR: cache},
    scaledown_window=300,
    enable_memory_snapshot=True,
)
class PlayerASR:
    """Transcribes the player's spoken questions (whisper-small today)."""

    @modal.enter()
    def load(self):
        """Build the transformers ASR pipeline on GPU 0."""
        # First deploy uses whisper-small: light, reliable, and a real transformers
        # pipeline. Nemotron 0.6b is NeMo-ONLY (not a transformers model), so to
        # chase the Nemotron prize, add `nemo_toolkit[asr]` to asr_image and swap to:
        #   import nemo.collections.asr as nemo_asr
        #   self.model = nemo_asr.models.ASRModel.from_pretrained(config.PLAYER_ASR)
        #   # transcribe(["/tmp/x.wav"]) -> [hypothesis]; .text on the hypothesis
        from transformers import pipeline

        # chunk_length_s=30: Whisper reads 30 s windows; chunking lets longer
        # player turns through instead of relying on long-form generation, which
        # some transformers releases reject without return_timestamps=True.
        self.pipe = pipeline("automatic-speech-recognition",
                             model=config.PLAYER_ASR_FALLBACK, device=0,
                             chunk_length_s=30)
        self.kind = "whisper-small"

    @modal.method()
    def transcribe(self, audio, sr: int) -> str:
        """Transcribe one player utterance.

        Args:
            audio: array-like samples. 2-D input is downmixed to mono; any other
                shape is flattened.
            sr: sample rate of ``audio`` in Hz.

        Returns:
            The stripped transcript, or "" for empty audio.
        """
        import librosa
        import numpy as np

        y = np.asarray(audio, dtype=np.float32)
        if y.ndim == 2:
            # Average the channels instead of flattening (which would interleave
            # or concatenate them). Assumes the channel axis is the shorter one,
            # e.g. (samples, channels) — TODO confirm against the caller.
            y = y.mean(axis=int(np.argmin(y.shape)))
        y = y.reshape(-1)
        if y.size == 0:
            return ""
        rate = int(sr)
        # The HF ASR pipeline can only resample a mismatched rate through
        # torchaudio, which asr_image does not install; resample with librosa
        # (already in the image) to the feature extractor's rate instead.
        target = int(self.pipe.feature_extractor.sampling_rate)
        if rate != target:
            y = librosa.resample(y, orig_sr=rate, target_sr=target)
            rate = target
        out = self.pipe({"array": y, "sampling_rate": rate})
        return (out.get("text", "") if isinstance(out, dict) else str(out)).strip()
# --------------------------------------------------------------------------- #
# Pre-generate every fixed beat in parallel (.starmap) and keep the best break take
# --------------------------------------------------------------------------- #
@app.function(image=voice_image, volumes={CACHE_DIR: cache}, timeout=1800)
def pregenerate_beats():
    """Render every scripted beat in parallel and cache the keepers on the Volume.

    Each beat is rendered ``takes`` times via ``WitnessVoice.bake.starmap``.
    Single-take beats are promoted to ``<key>.wav``; for the ``break`` beat the
    take with the highest ``_break_score`` (most unstable voiced pitch) becomes
    ``break.wav``. A failing take is recorded in ``failed_takes``/``failures``
    (and sets ``ok`` to False) instead of discarding every successful take;
    beats that end up missing are rendered live by ``WitnessVoice.beat``.

    Writes the result/error JSON to the Volume so a local client can read the
    outcome from the file (dodges the flaky gRPC blob-fetch on long .get()),
    and also returns it.
    """
    import json
    import os
    import traceback

    result = {"ok": False}
    try:
        os.makedirs(BEATS_DIR, exist_ok=True)
        voice = WitnessVoice()
        beats = script.scripted_beats()

        # Boot ONE voice container before fanning out. WitnessVoice.load()
        # designs a new reference voice whenever REF_VOICE_PATH is missing, so
        # on an empty Volume every container started in parallel would design
        # its OWN voice and the takes could disagree. This blocking call lets a
        # single container design + commit the reference first; containers
        # started for the fan-out afterwards should mount that committed file.
        voice.speak.remote("Counselor.", "calm, composed")

        # One (key, idx, text, style) per take: each single beat once, the break
        # N times. Workers write WAVs to the Volume and return only metadata
        # (no audio blobs). return_exceptions=True keeps one bad take from
        # throwing away all the others.
        args = [(key, i, spec["text"], spec["style"])
                for key, spec in beats.items() for i in range(spec["takes"])]
        outputs = list(voice.bake.starmap(args, return_exceptions=True))
        metas = [m for m in outputs if isinstance(m, dict)]
        failures = [repr(m)[:300] for m in outputs if not isinstance(m, dict)]
        cache.reload()  # surface the WAVs the worker containers committed

        written = []
        # Single beats: promote _take_<key>_00.wav -> <key>.wav.
        for key, spec in beats.items():
            src = f"{BEATS_DIR}/_take_{key}_00.wav"
            if spec["takes"] == 1 and os.path.exists(src):
                os.replace(src, f"{BEATS_DIR}/{key}.wav")
                written.append(key)

        # The climax: keep the take whose voiced pitch is most unstable (cracks most).
        break_metas = [m for m in metas if m["key"] == "break"]
        best = max(break_metas, key=lambda m: m["score"], default=None)
        best_score = best["score"] if best else -1.0
        if best and os.path.exists(best["path"]):
            os.replace(best["path"], f"{BEATS_DIR}/break.wav")
            written.append("break")

        # Tidy up the losing takes (promoted ones were already moved away).
        for m in metas:
            if os.path.exists(m["path"]):
                try:
                    os.remove(m["path"])
                except OSError:
                    pass

        result = {
            "ok": not failures,
            "break_score": float(best_score),
            "written": written,
            "takes": len(args),
            "failed_takes": len(failures),
            "failures": failures[:3],
            "break_scores": sorted(
                (round(m["score"], 2) for m in break_metas), reverse=True
            )[:5],
        }
    except Exception as e:
        result = {"ok": False, "error": repr(e), "trace": traceback.format_exc()[-2500:]}
    os.makedirs(CACHE_DIR, exist_ok=True)
    with open(f"{CACHE_DIR}/beats_result.json", "w") as f:
        json.dump(result, f)
    cache.commit()
    print("PREGEN RESULT:", json.dumps(result)[:400])
    return result
# --------------------------------------------------------------------------- #
# Server-side end-to-end smoke (dodges flaky local gRPC: spawn + read Volume)
# --------------------------------------------------------------------------- #
@app.function(
    # needs the local source too, since the container imports modal_app (-> config)
    image=modal.Image.debian_slim(python_version="3.11")
    .pip_install("numpy")
    .add_local_python_source("config", "witnessbox"),
    volumes={CACHE_DIR: cache},
    timeout=1800,
)
def smoke():
    """End-to-end check run inside Modal: one LLM reply plus one voice line.

    The outcome is persisted to ``smoke_result.json`` on the Volume, so a local
    client only needs an instant ``.spawn()`` and a later read of that small
    file, never a multi-minute streaming wait. Also returns the result dict.
    """
    import json
    import os
    import numpy as np

    brain = WitnessLLM()
    mouth = WitnessVoice()
    reply = brain.respond.remote(
        "You are Marcus Reid, a guarded CFO under oath. Answer in ONE short sentence, in character.",
        [{"role": "user", "content": "Did you authorize the twelve-million-dollar wire to Meridian?"}],
    )
    wav, sr = mouth.speak.remote(
        "I have nothing to hide, counselor.", "calm, composed, faintly condescending"
    )

    sample_count = int(np.asarray(wav).size)
    result = {
        "reply": reply,
        "voice_samples": sample_count,
        "sr": int(sr),
        "ok": bool(reply) and sample_count > 0,
    }

    os.makedirs(CACHE_DIR, exist_ok=True)
    with open(f"{CACHE_DIR}/smoke_result.json", "w") as fh:
        json.dump(result, fh)
    cache.commit()
    print("SMOKE RESULT:", json.dumps(result)[:300])
    return result
# --------------------------------------------------------------------------- #
# small audio io helpers (run inside the images)
# --------------------------------------------------------------------------- #
def _write_wav(path: str, wav, sr: int):
    """Flatten ``wav`` to 1-D float32 samples and write them to ``path`` at
    ``sr`` Hz with soundfile (format inferred from the file extension)."""
    import numpy as np
    import soundfile as sf

    samples = np.asarray(wav, dtype=np.float32).reshape(-1)
    sf.write(path, samples, int(sr))
def _read_wav(path: str):
    """Load ``path`` as float32 samples; returns (flattened samples, sample_rate)."""
    import soundfile as sf

    samples, rate = sf.read(path, dtype="float32")
    return samples.reshape(-1), int(rate)
def _break_score(wav, sr: int, *, fmin: float = 65.0, fmax: float = 400.0,
                 min_voiced: int = 5) -> float:
    """Heuristic 'how much does this take crack' — pitch instability of voiced f0.

    Tracks f0 with librosa's pYIN between ``fmin`` and ``fmax`` Hz and returns
    the standard deviation of the voiced (finite) f0 values.

    Args:
        wav: audio samples; flattened to 1-D float32.
        sr: sample rate in Hz.
        fmin, fmax: pYIN pitch search range in Hz (defaults unchanged).
        min_voiced: a take needs MORE than this many voiced frames to score.

    Returns:
        The f0 standard deviation, or 0.0 when there are too few voiced frames
        or pitch tracking fails. A failure is reported on stderr rather than
        swallowed silently, since an all-zero batch makes the break pick arbitrary.
    """
    import sys

    try:
        import librosa
        import numpy as np

        samples = np.asarray(wav, dtype=np.float32).reshape(-1)
        f0, _, _ = librosa.pyin(samples, fmin=fmin, fmax=fmax, sr=sr)
        voiced = f0[np.isfinite(f0)]
        return float(np.std(voiced)) if voiced.size > min_voiced else 0.0
    except Exception as exc:  # best-effort: a failed score must not sink the bake
        print(f"_break_score: pitch tracking failed ({exc!r}); scoring take as 0.0",
              file=sys.stderr)
        return 0.0
@app.local_entrypoint()
def warm():
    """`modal run modal_app.py` — pre-generate beats and report the break score."""
    outcome = pregenerate_beats.remote()
    print(outcome)