"""WitnessBox on Modal — the runtime that serves the game's three models and
pre-generates its scripted beats.

Deploy:   modal deploy modal_app.py
Then run the Space with WITNESSBOX_BACKEND=modal and the Modal token set as
Space secrets (MODAL_TOKEN_ID / MODAL_TOKEN_SECRET).

How this is a genuine *best use of the platform* (not just hosting), mapped to
the README's "Best Use of Modal" section:

1. GPU inference behind `@app.cls`, **scale-to-zero** — three models, three
   right-sized GPUs, $0 when idle (`scaledown_window`).
2. **`keep_warm` / min_containers** on the witness brain + voice so a live
   examination doesn't pay a cold start every turn (the honest latency story).
3. **Parallel `.map()`** pre-generates every fixed beat at deploy time, fanning
   the 32 voice-crack takes across containers at once and keeping the best.
4. **Volume** persists the designed CFO reference voice + model cache + chosen
   beats across cold starts.
5. **Memory snapshots** cut CPU-side init on cold start.

NOTE: model-call signatures follow PRD.md / HACKATHON-CONTEXT.md (verified). The
exact VoxCPM2 / Nemotron import paths may need a one-line pin against the shipped
package versions at deploy time; each is isolated in a `_load` / `_synth` helper.
"""
from __future__ import annotations

import os

import modal

import config
from witnessbox import script

app = modal.App(config.MODAL_APP_NAME)
cache = modal.Volume.from_name("witnessbox-cache", create_if_missing=True)
CACHE_DIR = "/cache"
REF_VOICE_PATH = f"{CACHE_DIR}/cfo_reference.wav"
BEATS_DIR = f"{CACHE_DIR}/beats"

# Keep-warm is OPT-IN. Default 0 => true scale-to-zero, $0 when idle (the honest
# Best-Use-of-Modal story, and it won't burn credits between demos). Flip it on
# only for a live demo recording / judging window:
#     WITNESSBOX_KEEP_WARM=1 modal deploy modal_app.py
# Warm turns are then ~5.3s (reply) + ~8.6s (voice); a cold first turn pays the
# model-load once (memory snapshots + the Volume model cache keep that bounded).
_KEEP_WARM = int(os.environ.get("WITNESSBOX_KEEP_WARM", "0"))

# Per-model images keep conflicting deps (notably torch pins) apart.
_HF = {"HF_HOME": CACHE_DIR, "HF_HUB_ENABLE_HF_TRANSFER": "1"}

llm_image = (
    modal.Image.debian_slim(python_version="3.11")
    # MiniCPM4.1-8B is a standard text model — clean transformers deps, no omni
    # dependency cascade (PIL/librosa/soundfile/minicpmo/vocos/...).
    # transformers <5: MiniCPM4.1-8B's remote code imports is_torch_fx_available,
    # which transformers 5.x removed.
    .pip_install("torch>=2.5.0", "transformers>=4.46,<5", "accelerate",
                 "sentencepiece", "hf_transfer", "numpy")
    .env(_HF)
    .add_local_python_source("config", "witnessbox")
)
voice_image = (
    modal.Image.debian_slim(python_version="3.11")
    .apt_install("ffmpeg")
    .pip_install("torch>=2.5.0", "soundfile", "librosa", "numpy", "hf_transfer",
                 "voxcpm")  # the VoxCPM2 runtime package
    .env(_HF)
    .add_local_python_source("config", "witnessbox")
)
asr_image = (
    modal.Image.debian_slim(python_version="3.11")
    .apt_install("ffmpeg")
    .pip_install("torch>=2.5.0", "transformers>=4.49", "soundfile", "librosa",
                 "numpy", "hf_transfer")
    .env(_HF)
    .add_local_python_source("config", "witnessbox")
)


# --------------------------------------------------------------------------- #
# Witness brain — MiniCPM4.1-8B (standard text model; clean transformers deps)
# --------------------------------------------------------------------------- #
@app.cls(
    image=llm_image,
    gpu="A100",
    volumes={CACHE_DIR: cache},
    scaledown_window=300,        # scale-to-zero after 5 min idle
    min_containers=_KEEP_WARM,   # 0 = $0 idle; set WITNESSBOX_KEEP_WARM=1 for live demos
    enable_memory_snapshot=True,
)
class WitnessLLM:
    @modal.enter()
    def load(self):
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        # Standard causal-LM load. sdpa avoids a flash-attn dependency.
        # Verified: https://huggingface.co/openbmb/MiniCPM4.1-8B
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.WITNESS_LLM, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            config.WITNESS_LLM,
            trust_remote_code=True,
            attn_implementation="sdpa",
            torch_dtype=torch.bfloat16,  # transformers 4.x uses torch_dtype, not dtype
            device_map="cuda",
        ).eval()

    @modal.method()
    def respond(self, system_prompt: str, messages: list[dict]) -> str:
        import re
        import torch

        msgs = [{"role": "system", "content": system_prompt}]
        for m in messages:
            msgs.append({"role": m["role"], "content": m["content"]})
        # enable_thinking=False -> direct in-character reply, no <think> trace.
        try:
            prompt = self.tokenizer.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=True, enable_thinking=False
            )
        except TypeError:
            prompt = self.tokenizer.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=True
            )
        inputs = self.tokenizer([prompt], return_tensors="pt").to("cuda")
        with torch.no_grad():
            out = self.model.generate(
                **inputs, max_new_tokens=160, do_sample=True, temperature=0.7, top_p=0.95
            )
        text = self.tokenizer.decode(
            out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
        )
        text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)  # safety net
        return text.strip()


# --------------------------------------------------------------------------- #
# Witness voice — VoxCPM2, style tag = game state
# --------------------------------------------------------------------------- #
@app.cls(
    image=voice_image,
    gpu="A10G",
    volumes={CACHE_DIR: cache},
    scaledown_window=300,
    min_containers=_KEEP_WARM,   # 0 = $0 idle; set WITNESSBOX_KEEP_WARM=1 for live demos
    enable_memory_snapshot=True,
)
class WitnessVoice:
    @modal.enter()
    def load(self):
        import os
        from voxcpm import VoxCPM  # class is VoxCPM; the model id is openbmb/VoxCPM2

        # torch>=2.5.0 enforced by the image. Denoiser off for speed.
        # Verified: https://voxcpm.readthedocs.io / pip install voxcpm
        # optimize=False: skip torch.compile. Compilation costs minutes on every
        # cold start (and would recompile on each scaled-up container); the
        # per-line speedup isn't worth that for a turn-based game. Documented
        # escape hatch in the VoxCPM docs.
        self.tts = VoxCPM.from_pretrained(
            config.WITNESS_VOICE, load_denoiser=False, optimize=False
        )
        self.sr = int(self.tts.tts_model.sample_rate)  # 48000 for VoxCPM2

        # Design the CFO reference voice ONCE and persist it on the Volume, so
        # every line is a controllable clone of the same designed voice.
        if not os.path.exists(REF_VOICE_PATH):
            os.makedirs(CACHE_DIR, exist_ok=True)
            wav = self._synth(
                "(a composed, measured, late-50s American male executive; dry, controlled)"
                "Counselor, I have nothing to hide.",
                reference=None,
            )
            _write_wav(REF_VOICE_PATH, wav, self.sr)
            cache.commit()

    def _synth(self, styled_text: str, reference: str | None):
        """One VoxCPM generate call. Voice-design when reference is None, else
        controllable-clone of the designed CFO voice (style tag in parens)."""
        kwargs = dict(text=styled_text, cfg_value=2.0, inference_timesteps=10)
        if reference is not None:
            kwargs["reference_wav_path"] = reference
        wav = self.tts.generate(**kwargs)
        import numpy as np
        return np.asarray(wav, dtype=np.float32).reshape(-1)

    @modal.method()
    def speak(self, text: str, style: str):
        wav = self._synth(f"({style}){text}", reference=REF_VOICE_PATH)
        return wav, self.sr

    @modal.method()
    def bake(self, key: str, idx: int, text: str, style: str) -> dict:
        """Render ONE beat take, write the WAV straight to the mounted Volume, and
        return only small metadata (path + break score).

        Why write-to-Volume instead of returning (wav, sr): `.map()/.starmap()`
        fetch large results through Modal's input-plane blob path, which errors
        `BlobGet UNIMPLEMENTED` on this deploy. Returning a tiny dict keeps the
        result inline (no blob), and doing the librosa break-scoring here fans
        that cost across containers too (it was a serial bottleneck before)."""
        import os
        wav = self._synth(f"({style}){text}", reference=REF_VOICE_PATH)
        os.makedirs(BEATS_DIR, exist_ok=True)
        path = f"{BEATS_DIR}/_take_{key}_{int(idx):02d}.wav"
        _write_wav(path, wav, self.sr)
        score = _break_score(wav, self.sr) if key == "break" else 0.0
        cache.commit()  # make this take visible to the orchestrator container
        return {"key": key, "idx": int(idx), "path": path,
                "score": float(score), "samples": int(len(wav)), "sr": self.sr}

    @modal.method()
    def beat(self, key: str):
        """Return a cached pre-generated beat, or render it live as a fallback."""
        import os
        path = f"{BEATS_DIR}/{key}.wav"
        if os.path.exists(path):
            wav, sr = _read_wav(path)
            return wav, sr
        spec = script.scripted_beats().get(key)
        if not spec:
            return None
        wav = self._synth(f"({spec['style']}){spec['text']}", reference=REF_VOICE_PATH)
        return wav, self.sr


# --------------------------------------------------------------------------- #
# Player ASR — whisper-small by default; Nemotron streaming is an opt-in swap
# --------------------------------------------------------------------------- #
@app.cls(
    image=asr_image,
    gpu="A10G",
    volumes={CACHE_DIR: cache},
    scaledown_window=300,
    enable_memory_snapshot=True,
)
class PlayerASR:
    @modal.enter()
    def load(self):
        # First deploy uses whisper-small: light, reliable, and a real transformers
        # pipeline. Nemotron 0.6b is NeMo-ONLY (not a transformers model), so to
        # chase the Nemotron prize, add `nemo_toolkit[asr]` to asr_image and swap to:
        #   import nemo.collections.asr as nemo_asr
        #   self.model = nemo_asr.models.ASRModel.from_pretrained(config.PLAYER_ASR)
        #   # transcribe(["/tmp/x.wav"]) -> [hypothesis]; .text on the hypothesis
        from transformers import pipeline
        self.pipe = pipeline("automatic-speech-recognition",
                             model=config.PLAYER_ASR_FALLBACK, device=0)
        self.kind = "whisper-small"

    @modal.method()
    def transcribe(self, audio, sr: int) -> str:
        import numpy as np
        y = np.asarray(audio, dtype=np.float32).reshape(-1)
        out = self.pipe({"array": y, "sampling_rate": int(sr)})
        return (out.get("text", "") if isinstance(out, dict) else str(out)).strip()


# --------------------------------------------------------------------------- #
# Pre-generate every fixed beat in parallel (.starmap) and keep the best break take
# --------------------------------------------------------------------------- #
@app.function(image=voice_image, volumes={CACHE_DIR: cache}, timeout=1800)
def pregenerate_beats():
    """Fan the scripted beats across containers with `.starmap()`; the 32 break
    takes are generated concurrently and the most-broken one is cached.

    Writes a result/error JSON to the Volume so a local client can read the
    outcome from the file (dodges the flaky gRPC blob-fetch on long .get())."""
    import json
    import os
    import traceback

    result = {"ok": False}
    try:
        os.makedirs(BEATS_DIR, exist_ok=True)
        voice = WitnessVoice()
        beats = script.scripted_beats()

        # One (key, idx, text, style) per take: each single beat once, the break
        # N times. Fan ALL of them across containers with .starmap(); workers
        # write WAVs to the Volume and return only metadata (no audio blobs).
        args = [(k, i, b["text"], b["style"])
                for k, b in beats.items() for i in range(b["takes"])]
        metas = [m for m in voice.bake.starmap(args) if m]
        cache.reload()  # surface the WAVs the worker containers committed

        written = []
        # Single beats: promote _take_<key>_00.wav -> <key>.wav.
        for key, b in beats.items():
            if b["takes"] == 1:
                src = f"{BEATS_DIR}/_take_{key}_00.wav"
                if os.path.exists(src):
                    os.replace(src, f"{BEATS_DIR}/{key}.wav")
                    written.append(key)
        # The climax: keep the take whose voiced pitch is most unstable (cracks most).
        break_metas = [m for m in metas if m["key"] == "break"]
        best = max(break_metas, key=lambda m: m["score"], default=None)
        best_score = best["score"] if best else -1.0
        if best and os.path.exists(best["path"]):
            os.replace(best["path"], f"{BEATS_DIR}/break.wav")
            written.append("break")
        # Tidy up the losing takes.
        for m in metas:
            if os.path.exists(m["path"]):
                try:
                    os.remove(m["path"])
                except OSError:
                    pass
        result = {"ok": True, "break_score": float(best_score),
                  "written": written, "takes": len(args),
                  "break_scores": sorted((round(m["score"], 2) for m in break_metas), reverse=True)[:5]}
    except Exception as e:
        result = {"ok": False, "error": repr(e), "trace": traceback.format_exc()[-2500:]}

    os.makedirs(CACHE_DIR, exist_ok=True)
    with open(f"{CACHE_DIR}/beats_result.json", "w") as f:
        json.dump(result, f)
    cache.commit()
    print("PREGEN RESULT:", json.dumps(result)[:400])
    return result
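

# Illustrative sketch, not called by the app: the best-take selection used in
# pregenerate_beats(), isolated as a pure function so it can be unit-tested
# without Modal or a Volume. The name `_pick_best_break` is hypothetical; it
# assumes each metadata dict has the "key" and "score" fields that
# WitnessVoice.bake returns.
def _pick_best_break(metas: list) -> "dict | None":
    """Return the 'break' take with the highest score, or None if there is none."""
    break_metas = [m for m in metas if m.get("key") == "break"]
    return max(break_metas, key=lambda m: m["score"], default=None)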


# --------------------------------------------------------------------------- #
# Server-side end-to-end smoke (dodges flaky local gRPC: spawn + read Volume)
# --------------------------------------------------------------------------- #
@app.function(
    # needs the local source too, since the container imports modal_app (-> config)
    image=modal.Image.debian_slim(python_version="3.11").pip_install("numpy")
    .add_local_python_source("config", "witnessbox"),
    volumes={CACHE_DIR: cache},
    timeout=1800,
)
def smoke():
    """One LLM reply + one voice line, orchestrated *inside* Modal. Writes the
    result to the Volume so a local client only has to .spawn() (instant) and
    later read a tiny file — never hold a multi-minute streaming wait."""
    import json
    import os
    import numpy as np

    llm = WitnessLLM()
    voice = WitnessVoice()
    reply = llm.respond.remote(
        "You are Marcus Reid, a guarded CFO under oath. Answer in ONE short sentence, in character.",
        [{"role": "user", "content": "Did you authorize the twelve-million-dollar wire to Meridian?"}],
    )
    wav, sr = voice.speak.remote(
        "I have nothing to hide, counselor.", "calm, composed, faintly condescending"
    )
    result = {
        "reply": reply,
        "voice_samples": int(np.asarray(wav).size),
        "sr": int(sr),
        "ok": bool(reply) and int(np.asarray(wav).size) > 0,
    }
    os.makedirs(CACHE_DIR, exist_ok=True)
    with open(f"{CACHE_DIR}/smoke_result.json", "w") as f:
        json.dump(result, f)
    cache.commit()
    print("SMOKE RESULT:", json.dumps(result)[:300])
    return result


# --------------------------------------------------------------------------- #
# small audio io helpers (run inside the images)
# --------------------------------------------------------------------------- #
def _write_wav(path: str, wav, sr: int):
    import soundfile as sf
    import numpy as np
    sf.write(path, np.asarray(wav, dtype=np.float32).reshape(-1), int(sr))


def _read_wav(path: str):
    import soundfile as sf
    wav, sr = sf.read(path, dtype="float32")
    return wav.reshape(-1), int(sr)


def _break_score(wav, sr: int) -> float:
    """Heuristic 'how much does this take crack' — pitch instability of voiced f0."""
    try:
        import librosa
        import numpy as np
        f0, _, _ = librosa.pyin(np.asarray(wav, dtype=np.float32).reshape(-1),
                                fmin=65.0, fmax=400.0, sr=sr)
        vf = f0[np.isfinite(f0)]
        return float(np.std(vf)) if vf.size > 5 else 0.0
    except Exception:
        return 0.0
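

# Illustrative sketch, not called by the app: the statistic inside
# _break_score() restated with the standard library only, so the scoring rule
# can be checked on a hand-written f0 track without librosa. The name
# `_pitch_instability` and the `min_voiced` parameter are hypothetical. It
# assumes unvoiced frames are NaN (as librosa.pyin reports them) and uses the
# population standard deviation, which matches np.std's default.
def _pitch_instability(f0_hz, min_voiced: int = 6) -> float:
    """Std-dev (Hz) of the voiced f0 frames; 0.0 when too few frames are voiced."""
    import math
    import statistics

    voiced = [float(f) for f in f0_hz if math.isfinite(f)]
    # Mirrors the `vf.size > 5` guard above: at least six voiced frames.
    return float(statistics.pstdev(voiced)) if len(voiced) >= min_voiced else 0.0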


@app.local_entrypoint()
def warm():
    """`modal run modal_app.py` — pre-generate beats and report the break score."""
    print(pregenerate_beats.remote())