# apolinario
# jam(dev): log frame_ms + cudagraph/eager mode in buffer line
# 19fe543
import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
os.environ["GRADIO_SSR_MODE"] = "false" # serve via Python (FastAPI) directly so /ws works
import json
import time
import uuid
import base64
import numpy as np
import spaces
import torch
from gradio import Server
from fastapi import Request
from fastapi.responses import HTMLResponse
from huggingface_hub import hf_hub_download
import asyncio, threading, queue as _q, contextvars, struct
import gradio as gr
from gradio.context import LocalContext
from fastapi import WebSocket, WebSocketDisconnect
from magenta_rt import paths
# Use /data as the storage root when that directory exists, otherwise a folder under /tmp.
MAGENTA_HOME = "/tmp/magenta" if not os.path.isdir("/data") else "/data"
home = os.path.join(MAGENTA_HOME, "magenta-rt-v2")
os.makedirs(home, exist_ok=True)
# Download both checkpoints into `home` at import time.
for ck in ("mrt2_small.safetensors", "mrt2_base.safetensors"):
    hf_hub_download(
        repo_id="google/magenta-realtime-2",
        filename=f"checkpoints/{ck}",
        local_dir=home,
    )
# Register the directory with magenta_rt; this runs before the magenta_rt.torch imports that follow.
paths.set_magenta_home(home)
from magenta_rt.torch import MagentaRT2
from magenta_rt.torch.musiccoca import MusicCoCa
# The style encoder is created on CPU here; gpu_stream moves it to CUDA on first use.
style_model = MusicCoCa(device="cpu")
# Both models fit on ZeroGPU — preload both; the dropdown switches between them.
mrt_small = MagentaRT2(
    size="mrt2_small",
    device="cuda",
    dtype=torch.bfloat16,
    style_model=style_model,
)
mrt_base = MagentaRT2(
    size="mrt2_base",
    device="cuda",
    dtype=torch.bfloat16,
    style_model=style_model,
)
MODELS = {"mrt2_small": mrt_small, "mrt2_base": mrt_base}
# Try to attach the precompiled (AOTI) artefacts; on any failure the model simply stays eager.
for name, repo in {
    "mrt2_base": "magenta-torch/magenta-rt-aoti",
    "mrt2_small": "magenta-torch/magenta-rt-aoti-small",
}.items():
    try:
        MODELS[name].load_compiled(repo_id=repo)
        print(f"Loaded AOTI for {name} from {repo}")
    except Exception as e:
        print(f"AOTI load failed for {name} ({repo}) -> eager:", e)
# Samples per second; gpu_stream divides its emitted sample count by SR to track how far ahead of real time it is.
SR = 48000
# Directory holding the per-session files: conditioning JSON slots, client buffer lead, uploaded audio, state banks.
SESSION_DIR = "/tmp/mrt_jam2_sessions"
os.makedirs(SESSION_DIR, exist_ok=True)
# Text label -> float32 numpy embedding, filled lazily by _embed().
_EMB_CACHE = {}
def _embed(label):
    """Return the float32 numpy embedding of a text label, computing it at most once.

    Empty, whitespace-only or ``None`` labels are replaced by
    ``"instrumental music"``. Results are memoised in ``_EMB_CACHE``.
    """
    key = (label or "").strip()
    if not key:
        key = "instrumental music"
    try:
        return _EMB_CACHE[key]
    except KeyError:
        vector = style_model.embed(key).cpu().numpy()
        _EMB_CACHE[key] = np.asarray(vector, np.float32)
        return _EMB_CACHE[key]
def _slot(sid):
    """Return the path of the JSON conditioning slot for session ``sid``.

    Only the basename of ``sid`` is used, so the result always lies directly
    inside ``SESSION_DIR``.
    """
    stem = os.path.basename(sid)
    return os.path.join(SESSION_DIR, f"{stem}.json")
def write_slot(sid, d):
    """Persist the conditioning dict ``d`` as the JSON slot of session ``sid``.

    The payload is written to a uniquely named temporary file next to the
    slot and then moved over it with ``os.replace``, so readers see either
    the previous slot or the complete new one.

    Args:
        sid: session identifier (reduced to its basename by ``_slot``).
        d: JSON-serialisable dict.

    Raises:
        TypeError / ValueError: if ``d`` cannot be serialised.
        OSError: if the file cannot be written or moved.
    """
    dest = _slot(sid)
    tmp = dest + "." + uuid.uuid4().hex
    try:
        with open(tmp, "w") as f:
            json.dump(d, f)
        os.replace(tmp, dest)
    except BaseException:
        # Don't leave an orphaned temp file behind when serialisation or the
        # rename fails; the original error is re-raised unchanged.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
def read_slot(sid):
    """Load the JSON conditioning slot of session ``sid``.

    Returns the parsed object, or ``None`` when the slot is missing,
    unreadable or not valid JSON (any exception is treated as "no slot").
    """
    try:
        with open(_slot(sid)) as fh:
            payload = json.load(fh)
    except Exception:
        return None
    return payload
def read_client_lead(sid):
    """Return the client-reported buffer lead stored for session ``sid``.

    The value is read from ``<SESSION_DIR>/<sid>_buf.txt`` and parsed as a
    float. Returns ``None`` if the file is missing or cannot be parsed.
    """
    try:
        filename = f"{os.path.basename(sid)}_buf.txt"
        with open(os.path.join(SESSION_DIR, filename)) as fh:
            text = fh.read()
        return float(text)
    except Exception:
        return None
def _apply_set(body):  # build + write the conditioning slot (shared by HTTP /set and the WS)
    """Normalise a client "set" message and store it as the session's slot.

    ``body`` must contain ``session_id``; every other field is optional and
    falls back to the default shown below. ``model`` and ``bank_op`` fall
    back to the values of the previously stored slot when omitted. The slot
    is stamped with the current time under ``ts``.

    Raises:
        KeyError: if ``session_id`` is missing.
        TypeError / ValueError: if a numeric field cannot be converted.
    """
    sid = body["session_id"]
    prev = read_slot(sid) or {}
    get = body.get
    slot = {
        "prompts": get("prompts") or ["instrumental music"],
        "weights": get("weights") or [1.0],
        "audio": get("audio") or [],
        "temperature": float(get("temperature", 1.1)),
        "top_k": int(get("top_k", 50)),
        "cfg_musiccoca": float(get("cfg", 1.6)),
        "cfg_notes": float(get("cfg_notes", 2.4)),
        "cfg_drums": float(get("cfg_drums", 4.0)),
        "model": get("model", prev.get("model", "mrt2_base")),
        "buffer": int(get("buffer", 0)),
        "notes": list(get("notes", [])),
        "unmaskwidth": int(get("unmaskwidth", 0)),
        "drumless": bool(get("drumless", False)),
        "onsetmode": bool(get("onsetmode", False)),
        "reset": int(get("reset", 0)),
        "seed": int(get("seed", 0)),
        "bank_op": get("bank_op", prev.get("bank_op")),
        "ts": time.time(),
    }
    write_slot(sid, slot)
def _apply_buffer(sid, lead):
    """Persist the client's reported buffer lead for session ``sid``.

    The value is stored as text in ``<SESSION_DIR>/<sid>_buf.txt`` (only the
    basename of ``sid`` is used), where ``read_client_lead`` picks it up.

    Args:
        sid: session identifier; converted with ``str``.
        lead: number (or numeric string) reported by the client.

    Raises:
        TypeError / ValueError: if ``lead`` cannot be converted to ``float``.
        OSError: if the file cannot be written or moved.
    """
    sid = os.path.basename(str(sid))
    dest = os.path.join(SESSION_DIR, f"{sid}_buf.txt")
    # Convert first so an invalid value fails before any file is created.
    text = str(float(lead))
    # Unique temp name (same scheme as write_slot) so overlapping writers
    # never truncate or rename each other's temp file.
    tmp = f"{dest}.{uuid.uuid4().hex}.tmp"
    try:
        with open(tmp, "w") as f:
            f.write(text)
        os.replace(tmp, dest)
    except BaseException:
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
# Application object (gradio `Server`); the HTTP, API and WebSocket routes defined below are registered on it.
app = Server()
@app.post("/buffer")
async def set_clientbuffer(request: Request):
    """HTTP endpoint: record the buffer lead reported by the client.

    Expects a JSON body with ``session_id`` and an optional ``lead``
    (default 0). Always answers ``{"ok": True}`` when no exception occurs.
    """
    payload = await request.json()
    session = payload["session_id"]
    lead = payload.get("lead", 0)
    _apply_buffer(session, lead)
    return {"ok": True}
@app.post("/set")
async def set_collider(request: Request):
    """HTTP endpoint: store the conditioning parameters sent by the client.

    The JSON body is handed unchanged to ``_apply_set``; the response is
    ``{"ok": True}`` when no exception occurs.
    """
    _apply_set(await request.json())
    return {"ok": True}
@app.post("/audio")
async def set_audio(request: Request):
    """Store or clear an uploaded clip (native-rate mono float32) per (session, slot).

    Expected JSON body:
        session_id:  session identifier (only its basename is used in the file name).
        slot:        integer slot index, default 0.
        samples:     base64 of the raw float32 samples; empty or missing removes the clip.
        sample_rate: sample rate of the clip, default 48000.

    The clip is written to ``<SESSION_DIR>/<sid>_aud<slot>.bin`` as a 4-byte
    little-endian sample rate followed by the raw sample bytes; the GPU
    stream reads that file and embeds it with ``MusicCoCa.embed_audio``.

    Returns:
        ``{"ok": True}`` when the clip was stored or removed, or
        ``{"ok": False, "error": <message>}`` when the payload is rejected.
        Rejecting here keeps a malformed upload from raising later inside
        the running GPU stream, which parses the file as float32.
    """
    body = await request.json()
    sid = os.path.basename(body["session_id"])
    slot = int(body.get("slot", 0))
    path = os.path.join(SESSION_DIR, f"{sid}_aud{slot}.bin")
    samples = body.get("samples")
    if not samples:
        try:
            os.remove(path)
        except OSError:
            pass
        return {"ok": True}
    try:
        raw = base64.b64decode(samples)
        sr = int(body.get("sample_rate", 48000))
    except (TypeError, ValueError) as err:  # binascii.Error is a ValueError
        return {"ok": False, "error": f"invalid audio payload: {err}"}
    if not raw or len(raw) % 4:
        # The reader interprets the payload as float32, i.e. 4 bytes per sample.
        return {"ok": False, "error": "samples must decode to whole float32 values"}
    if not 0 < sr < 2 ** 32:
        # The header stores the rate in 4 unsigned bytes.
        return {"ok": False, "error": "sample_rate out of range"}
    # Unique temp name so overlapping uploads never share a temp file.
    tmp = f"{path}.{uuid.uuid4().hex}.tmp"
    try:
        with open(tmp, "wb") as f:
            f.write(sr.to_bytes(4, "little"))
            f.write(raw)
        os.replace(tmp, path)
    except BaseException:
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
    return {"ok": True}
@app.get("/banks")
async def banks(session_id: str = ""):
    """Report which of the three state banks exist on disk for a session.

    Returns ``{"bankStatus": [bool, bool, bool]}``, one flag per file
    ``<SESSION_DIR>/<sid>_bank<i>.pt`` for ``i`` in 0..2.
    """
    sid = os.path.basename(session_id)
    status = []
    for i in range(3):
        bank_path = os.path.join(SESSION_DIR, f"{sid}_bank{i}.pt")
        status.append(os.path.exists(bank_path))
    return {"bankStatus": status}
@spaces.GPU(duration=90)
def gpu_stream(session_id):
"""Generate audio continuously for one session and yield it chunk by chunk.

Polls the session's conditioning slot (``read_slot``) and re-initialises the
streaming state whenever the requested model name or the ``reset`` counter
changes, so the model can be switched live. Each pass of the main loop
generates ``CHUNK`` token frames, decodes them with ``model.decode_stream``
and yields ``(pcm_bytes, frame_ms)``: ``pcm_bytes`` is little-endian int16
audio, ``frame_ms`` the measured wall-clock milliseconds per generated frame.
The main loop stops after about 55 seconds; ``_worker`` calls this generator
again to continue.

NOTE(review): indentation was lost in this copy of the file, so the exact
nesting of a few statements below cannot be confirmed from here.
"""
# NOTE(review): FRAME_SAMPLES is imported but not referenced in the visible body.
from magenta_rt.torch.system import make_sampler, discretize_cfg, _float_to_int16, FRAME_SAMPLES
from magenta_rt.torch.cudagraph import CudaGraphStreamer
USE_CG = os.environ.get('MRT_CUDAGRAPH', '1') == '1' # single-dispatch CUDA-graph stepping (eager fallback)
# Cleared after the first CUDA-graph failure so later frames use the eager path.
cg_ok = True
# The style model is created on CPU at import time; move it to the GPU here.
if style_model.device != "cuda":
style_model.to("cuda")
dev, dt = "cuda", torch.bfloat16
# Wait up to 8 s for the client to publish its first conditioning slot.
deadline = time.time() + 8.0
while read_slot(session_id) is None and time.time() < deadline:
time.sleep(0.05)
gen = torch.Generator(device=dev).manual_seed(0)
# The warm-up always runs on mrt_base, whichever model the session selects; failures are only logged.
try: # one-time GPU warm-up so the first audio frame isn't slow
style_model.embed("warmup")
_wd = mrt_base.model.decoder.init_streaming_f(1, dev, dt)
_wc = mrt_base._conditioning([-1] * mrt_base.num_musiccoca, [-1] * 128, [-1],
[discretize_cfg(1.6, 0.2, 40), discretize_cfg(2.4, 0.2, 40), discretize_cfg(4.0, 1.0, 8)])
_ws = mrt_base.model.encode(_wc).to(dt)
for _ in range(2):
mrt_base.model.decoder.step_f(_wd, _ws, sampler=make_sampler(1.1, 50, gen),
temporal_step=mrt_base._temporal_step, depth_step=mrt_base._depth_step)
print("[warmup] done", flush=True)
except Exception as _e:
print("[warmup]", repr(_e), flush=True)
# NOTE(review): `notes` and `drums` are not read again in the visible body — possibly leftovers.
notes, drums = [-1] * 128, [-1]
# Streaming state; (re)initialised below when the model name or reset counter changes.
cur_name = model = dstate = source = None
streamer = last_src_for_graph = None
decode_state = {}
emitted_samples = 0
t0 = time.time()
# Lead measurements collected for the periodic "[buffer]" log line.
buf_log, buf_t = [], t0
cur_style_sig = cur_note_sig = cur_tokens = None
prev_active = set()
cur_reset = cur_seed = 0
had_onsets = False
# Embedding caches local to this invocation: text label -> embedding, slot index -> (mtime, embedding).
txt_cache, aud_cache = {}, {}
CHUNK = 3 # frames per yield (smaller = finer delivery, lower buffer floor)
cur_bank_ver = 0
# Main loop: runs for at most ~55 s of wall-clock time per invocation.
while time.time() - t0 < 55.0:
c = read_slot(session_id)
if c is None:
time.sleep(0.02)
continue
mname = c.get("model", "mrt2_base")
reset = int(c.get("reset", 0))
if mname != cur_name or reset != cur_reset: # model switch / state reset -> re-init
# Unknown model names fall back to mrt_base.
cur_name, model, cur_reset = mname, MODELS.get(mname, mrt_base), reset
dstate = model.model.decoder.init_streaming_f(1, dev, dt)
decode_state = model.init_decode_state()
emitted_samples, source, prev_active = 0, None, set()
cur_style_sig = cur_note_sig = cur_tokens = None; had_onsets = False
streamer = last_src_for_graph = None # rebuild CUDA graph on model switch / reset
seed = int(c.get("seed", 0))
if seed != cur_seed:
cur_seed = seed
gen = torch.Generator(device=dev).manual_seed(seed)
streamer = None # re-seed => re-capture (graph RNG fixed at capture)
# bank_op carries 'ver', 'idx' and 'action' ('save' or 'load'); it is acted on once per new 'ver'.
bop = c.get("bank_op")
if bop and int(bop.get("ver", 0)) != cur_bank_ver and not USE_CG: # save/recall (eager only; cudagraph KV is static)
cur_bank_ver = int(bop.get("ver", 0))
bpath = os.path.join(SESSION_DIR, f"{os.path.basename(session_id)}_bank{int(bop.get('idx', 0))}.pt")
try:
if bop.get("action") == "save":
torch.save({"dstate": dstate, "decode_state": decode_state, "emitted": emitted_samples}, bpath)
elif bop.get("action") == "load" and os.path.exists(bpath):
# NOTE(review): torch.load unpickles the file; confirm bank files are only ever written by the torch.save above.
d = torch.load(bpath, map_location=dev)
dstate, decode_state, emitted_samples = d["dstate"], d["decode_state"], int(d["emitted"])
# Clearing `source` forces the conditioning to be re-encoded on the next frame.
source, cur_note_sig = None, None
except Exception as e:
print("[bank] error:", repr(e), flush=True)
# Token frames generated in this pass; concatenated along dim 1 before decoding.
toks = []
gen_t = time.time()
for _ in range(CHUNK): # per-frame conditioning (native parity: re-derive every 40ms frame, no debounce)
# Re-read the slot for every frame; keep the previous one if the read fails.
c = read_slot(session_id) or c
prompts = c.get("prompts") or ["instrumental music"]
weights = c.get("weights") or [1.0] * len(prompts)
# Pad the audio flags with False so there is at least one flag per prompt.
aflags = (c.get("audio") or []) + [False] * len(prompts)
active = set(c.get("notes") or [])
unmask = int(c.get("unmaskwidth", 0))
onsetmode = bool(c.get("onsetmode", False))
drumless = bool(c.get("drumless", False))
# ---- style tokens on GPU: embed+blend+tokenize, cached until prompts/weights/audio change ----
style_sig = (tuple(prompts), tuple(round(float(w), 4) for w in weights),
tuple(bool(a) for a in aflags[:len(prompts)]))
if cur_tokens is None or style_sig != cur_style_sig:
cur_style_sig = style_sig
emb = torch.zeros(style_model.embedding_dim, device=dev, dtype=torch.float32)
# Normalise by the sum of the positive weights (1.0 when there are none).
tot = float(sum(w for w in weights if w > 0)) or 1.0
for i in range(len(prompts)):
w = float(weights[i]) if i < len(weights) else 0.0
if w <= 0:
continue
if i < len(aflags) and aflags[i]:
# Audio prompt: embed the clip stored for this slot; a missing file skips the slot.
ap = os.path.join(SESSION_DIR, f"{os.path.basename(session_id)}_aud{i}.bin")
try: mt = os.path.getmtime(ap)
except OSError: continue
ce = aud_cache.get(i)
if ce is None or ce[0] != mt:
# NOTE(review): the file object is never closed explicitly; a `with open(...)` block would release it deterministically.
data = open(ap, "rb").read()
# Layout: 4-byte little-endian sample rate, then the raw float32 samples.
asr = int.from_bytes(data[:4], "little")
samp = np.frombuffer(data[4:], dtype=np.float32)
ce = (mt, style_model.embed_audio(samp, asr).float()); aud_cache[i] = ce
e = ce[1]
else:
# Text prompt: blank labels are skipped, embeddings are cached per label.
lbl = (prompts[i] or "").strip()
if not lbl:
continue
e = txt_cache.get(lbl)
if e is None:
e = style_model.embed(lbl).float(); txt_cache[lbl] = e
emb = emb + (w / tot) * e
cur_tokens = style_model.tokenize(emb).tolist()
# ---- notes + cfg -> re-encode on ANY change every frame (no debounce); onset(2)->sustain(1) via had_onsets ----
note_sig = (tuple(sorted(active)), unmask, onsetmode, drumless,
round(float(c.get("cfg_musiccoca", 1.6)), 3), round(float(c.get("cfg_notes", 2.4)), 3),
round(float(c.get("cfg_drums", 4.0)), 3), cur_style_sig)
if source is None or note_sig != cur_note_sig or had_onsets:
# Notes that are active now but were not at the previous re-encode.
onsets = active - prev_active
prev_active = set(active)
had_onsets = onsetmode and bool(onsets) # carry one frame so the onset settles to sustain next frame
cur_note_sig = note_sig
nvec = [] # per-pitch: -1 masked / 0 off / 1 cont / 2 onset / 3 on
for pitch in range(128):
if pitch in active:
nvec.append((2 if pitch in onsets else 1) if onsetmode else 3)
elif unmask >= 127 or (active and any(abs(pitch - h) <= unmask for h in active)):
nvec.append(0) # solo (>=127) or within unmask window -> off
else:
nvec.append(-1) # masked (model free)
drm = [0] if drumless else [-1] # 0 no-drum / -1 masked
cfg_notes_v = 7.0 if (unmask >= 127 and not active) else c.get("cfg_notes", 2.4) # solo + no note held -> ramp cfg_notes to max bin (native JamApp suppresses melody)
cfgs = [discretize_cfg(c.get("cfg_musiccoca", 1.6), 0.2, 40),
discretize_cfg(cfg_notes_v, 0.2, 40),
discretize_cfg(c.get("cfg_drums", 4.0), 1.0, 8)]
# Style tokens are padded with -1 and truncated to exactly model.num_musiccoca entries.
cond = model._conditioning((list(cur_tokens) + [-1] * model.num_musiccoca)[:model.num_musiccoca],
nvec, drm, cfgs)
source = model.model.encode(cond).to(dt)
temp = c.get("temperature", 1.1); topk = int(c.get("top_k", 50))
# `ok` records whether the CUDA-graph path produced this frame; otherwise the eager path below runs.
ok = False
if USE_CG and cg_ok: # single-dispatch CUDA-graph step
try:
if streamer is None: # build + capture on first frame (~2-3s warmup)
streamer = CudaGraphStreamer(model.model.decoder, source, dt,
temperature=temp, top_k=topk, seed=cur_seed)
last_src_for_graph = source
elif source is not last_src_for_graph: # conditioning changed -> update static buffer
streamer.set_source(source); last_src_for_graph = source
# NOTE(review): top_k is only passed when the streamer is built; only the temperature is updated afterwards.
streamer.set_temperature(temp)
toks.append(streamer.step()); ok = True
except Exception as _cge:
print("[cudagraph] fallback to eager:", repr(_cge), flush=True)
cg_ok = False; streamer = None
dstate = model.model.decoder.init_streaming_f(1, dev, dt)
if not ok: # eager fallback path
sampler = make_sampler(temp, topk, gen)
toks.append(model.model.decoder.step_f(dstate, source, sampler=sampler,
temporal_step=model._temporal_step, depth_step=model._depth_step))
new_codes = torch.cat(toks, dim=1)
audio = model.decode_stream(new_codes, decode_state) # FLOP-optimal stateful streaming decode
emitted_samples += audio.shape[1]
frame_ms = (time.time() - gen_t) * 1000.0 / CHUNK # real per-frame inference time
if audio.shape[1] > 0:
# Convert the first row of `audio` to little-endian int16 bytes for the client.
ab = _float_to_int16(audio[0].float().cpu().numpy()).astype("<i2").tobytes()
yield (ab, frame_ms)
# Target lead in seconds, selected by the slot's "buffer" value clamped to index 0..2.
buf = (0.15, 0.3, 0.5)[max(0, min(2, int(c.get("buffer", 0))))]
# Seconds of audio emitted minus seconds of wall-clock time elapsed since t0.
ahead = (emitted_samples / SR) - (time.time() - t0)
if ahead < -0.3: # fell behind (cold-start/hiccup) -> resync, no catch-up burst
# NOTE(review): t0 also bounds the 55 s main loop, so a resync shifts that limit too.
t0 = time.time() - (emitted_samples / SR) + buf
ahead = buf
buf_log.append(ahead)
# Emit one summary line roughly every 15 s, then start a new measurement window.
if time.time() - buf_t >= 15.0 and buf_log:
print(f"[buffer] avg={sum(buf_log)/len(buf_log):.2f}s min={min(buf_log):.2f} "
f"max={max(buf_log):.2f} target={buf:.2f}s n={len(buf_log)} "
f"frame={frame_ms:.1f}ms mode={'cudagraph' if (USE_CG and cg_ok) else 'eager'}", flush=True)
buf_log, buf_t = [], time.time()
# Pace generation: sleep off any lead beyond the target.
sleep_for = max(0.0, ahead - buf)
cl = read_client_lead(session_id) # client's real Web-Audio queue (ms)
if cl is not None and cl - buf * 1000.0 > 200: # only drain a genuinely-deep queue (target+200ms)
sleep_for = max(sleep_for, min((cl - buf * 1000.0) / 1000.0 * 0.3, 0.2)) # gentle + capped
# A single sleep is capped at 0.5 s.
if sleep_for > 0:
time.sleep(min(sleep_for, 0.5))
# Registry of running sessions: session_id -> {"audio": queue of (pcm_bytes, frame_ms), "stop": threading.Event}.
_sessions = {}
# Held while entries are added to or removed from _sessions.
_slock = threading.Lock()
def _worker(session_id, audio_q, stop_ev):
    """Background loop that feeds ``audio_q`` with chunks from ``gpu_stream``.

    ``gpu_stream`` is invoked repeatedly until ``stop_ev`` is set. When the
    queue is full the oldest chunk is dropped to make room for the newest.
    Exceptions whose message contains "abort" or "duration" restart the
    stream; any other exception is logged and ends the worker. ``stop_ev``
    is always set on exit.
    """

    def _push_latest(item):
        # Drop the oldest chunk when the queue is full, then enqueue best-effort.
        if audio_q.full():
            try:
                audio_q.get_nowait()
            except _q.Empty:
                pass
        try:
            audio_q.put_nowait(item)
        except _q.Full:
            pass

    try:
        while not stop_ev.is_set():
            try:
                for ab, fm in gpu_stream(session_id):  # ~55s GPU grant; re-grants on loop
                    if stop_ev.is_set():
                        break
                    _push_latest((ab, fm))
            except Exception as err:
                message = str(err).lower()
                if "abort" in message or "duration" in message:
                    continue  # grant expired -> re-grant
                print("[worker]", repr(err), flush=True)
                break
    finally:
        stop_ev.set()
@app.api(name="start")
def start(session_id: str = "", request: gr.Request = None) -> str:
    """Start (or restart) the background generation worker for a session.

    A fresh audio queue and stop event are registered under ``session_id``
    and a daemon thread running ``_worker`` is started. A worker already
    registered under the same id is signalled to stop.

    Args:
        session_id: client-chosen session identifier; empty means "do nothing".
        request: the caller's gradio request; when absent, the request stored
            in ``LocalContext`` (if any) is used instead.

    Returns:
        ``session_id`` on success, ``""`` when no id was given.
    """
    if not session_id:
        return ""
    req = request or LocalContext.request.get(None)  # user's gr.Request -> carries X-IP-Token (ZeroGPU quota)
    audio_q = _q.Queue(maxsize=48)
    stop_ev = threading.Event()

    def run():
        if req is not None:
            LocalContext.request.set(req)  # graft user request so worker @spaces.GPU bills the USER
        _worker(session_id, audio_q, stop_ev)

    ctx = contextvars.copy_context()
    t = threading.Thread(target=ctx.run, args=(run,), daemon=True)
    # Swap the registry entry in ONE critical section. With a separate
    # pop and insert, two concurrent start() calls for the same id could both
    # see "no previous session" and leave a worker running that nothing can
    # stop, and the id was briefly unregistered in between.
    with _slock:
        prev = _sessions.get(session_id)
        _sessions[session_id] = {"audio": audio_q, "stop": stop_ev}
    if prev:
        prev["stop"].set()
    t.start()
    return session_id
@app.websocket("/ws")
async def audio_ws(websocket: WebSocket, session_id: str = ""):
    """WebSocket for one running session: audio out, control messages in.

    The session must already be registered by ``start``; otherwise the socket
    is closed with code 1008.

    Server -> client: binary messages made of a 4-byte little-endian float32
    (``frame_ms``) followed by the audio bytes of one chunk, and a final JSON
    ``{"type": "ended"}`` once the worker has stopped.

    Client -> server: JSON objects with ``type`` ``"set"`` (handled by
    ``_apply_set``) or ``"buffer"`` (handled by ``_apply_buffer``). Other
    messages are ignored.

    When either direction finishes, the session's stop event is set.
    """
    await websocket.accept()
    sess = _sessions.get(session_id)
    if not sess:
        await websocket.close(code=1008)
        return
    audio_q, stop_ev = sess["audio"], sess["stop"]
    loop = asyncio.get_running_loop()

    def next_chunk():
        # Blocking read with a timeout so the sender re-checks stop_ev regularly.
        return audio_q.get(timeout=0.5)

    async def send_audio():  # server -> client: binary audio frames
        try:
            while not stop_ev.is_set():
                try:
                    ab, fm = await loop.run_in_executor(None, next_chunk)
                except _q.Empty:
                    # Only a timeout means "try again"; any other error must
                    # end the sender instead of spinning in this loop forever.
                    continue
                await websocket.send_bytes(struct.pack("<f", fm) + ab)
            try:
                await websocket.send_json({"type": "ended"})  # worker ended (quota/grant)
            except Exception:
                pass  # best effort: the client may already be gone
        finally:
            stop_ev.set()

    async def recv_control():  # client -> server: low-latency steering (notes/params), no HTTP RTT
        try:
            while not stop_ev.is_set():
                msg = await websocket.receive_json()  # raises WebSocketDisconnect on close
                if not isinstance(msg, dict):
                    continue  # ignore malformed control messages instead of ending the session
                kind = msg.get("type")
                if kind == "set":
                    try:
                        _apply_set(msg)
                    except Exception as e:
                        print("[ws set]", repr(e), flush=True)
                elif kind == "buffer":
                    try:
                        _apply_buffer(msg.get("session_id", session_id), msg.get("lead", 0))
                    except Exception:
                        pass  # buffer reports are best effort
        finally:
            stop_ev.set()

    try:
        await asyncio.gather(send_audio(), recv_control(), return_exceptions=True)
    finally:
        stop_ev.set()
        with _slock:
            # Only unregister OUR session: start() may already have replaced
            # the entry with a newer one that must stay reachable.
            if _sessions.get(session_id) is sess:
                _sessions.pop(session_id, None)
@app.get("/")
async def index():
    """Serve the front-end page: ``index.html`` located next to this module.

    The file is read as UTF-8 explicitly, so the result does not depend on
    the locale's default encoding.
    """
    page = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    with open(page, encoding="utf-8") as f:
        return HTMLResponse(f.read())
# Start serving at import time; SSR is disabled here as well as via GRADIO_SSR_MODE at the top of the file.
app.launch(show_error=True, ssr_mode=False)