"""Magenta RT 2 live-jam server: streams generated audio to a browser over a WebSocket.

Bootstrap section. Statement order matters here:
- environment variables are set before torch/gradio are imported;
- checkpoints are downloaded and the magenta home is set before the
  ``magenta_rt.torch`` imports.
"""
import os
# setdefault: an externally provided value wins. Set before `import torch` (below) so the
# CUDA caching allocator sees it when it initializes.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
os.environ["GRADIO_SSR_MODE"] = "false"  # serve via Python (FastAPI) directly so /ws works
import json
import time
import uuid
import base64
import numpy as np
import spaces
import torch
from gradio import Server
from fastapi import Request
from fastapi.responses import HTMLResponse
from huggingface_hub import hf_hub_download
import asyncio, threading, queue as _q, contextvars, struct
import gradio as gr
from gradio.context import LocalContext
from fastapi import WebSocket, WebSocketDisconnect
from magenta_rt import paths
# Prefer /data when it exists (presumably the Space's persistent storage — confirm),
# otherwise fall back to /tmp.
MAGENTA_HOME = "/data" if os.path.isdir("/data") else "/tmp/magenta"
home = os.path.join(MAGENTA_HOME, "magenta-rt-v2")
os.makedirs(home, exist_ok=True)
# Fetch both model checkpoints into <home>/checkpoints/ at import time.
for ck in ("mrt2_small.safetensors", "mrt2_base.safetensors"):
    hf_hub_download("google/magenta-realtime-2", f"checkpoints/{ck}", local_dir=home)
# Keep this before the magenta_rt.torch imports below; presumably they resolve assets from
# the magenta home — confirm.
paths.set_magenta_home(home)
from magenta_rt.torch import MagentaRT2
from magenta_rt.torch.musiccoca import MusicCoCa
# Shared style (text/audio -> embedding -> tokens) model. It is created on CPU here; the GPU
# stream moves it to CUDA inside its GPU grant.
style_model = MusicCoCa(device="cpu")
# Both models fit on ZeroGPU — preload both; the dropdown switches between them.
# Both sizes are loaded onto CUDA at import time and share the single MusicCoCa style model.
mrt_small = MagentaRT2(size="mrt2_small", device="cuda", dtype=torch.bfloat16, style_model=style_model)
mrt_base = MagentaRT2(size="mrt2_base", device="cuda", dtype=torch.bfloat16, style_model=style_model)
MODELS = {"mrt2_small": mrt_small, "mrt2_base": mrt_base}

# Try the compiled (AOTI) artifacts from the Hub; on any failure the model stays in eager mode.
for name, repo in (("mrt2_base", "magenta-torch/magenta-rt-aoti"),
                   ("mrt2_small", "magenta-torch/magenta-rt-aoti-small")):
    try:
        MODELS[name].load_compiled(repo_id=repo)
        print(f"Loaded AOTI for {name} from {repo}")
    except Exception as e:
        print(f"AOTI load failed for {name} ({repo}) -> eager:", e)

SR = 48000  # sample rate (Hz) used to convert emitted sample counts into seconds for pacing
# Per-session files shared between the HTTP/WS handlers and the GPU stream:
#   <sid>.json (conditioning slot), <sid>_buf.txt (client lead, ms),
#   <sid>_aud<slot>.bin (uploaded clip), <sid>_bank<i>.pt (saved generation state).
SESSION_DIR = "/tmp/mrt_jam2_sessions"
os.makedirs(SESSION_DIR, exist_ok=True)
_EMB_CACHE = {}


def _embed(label):
    """Return the MusicCoCa text embedding of *label* as a float32 numpy array (memoized).

    Blank or missing labels fall back to "instrumental music".
    """
    label = (label or "").strip() or "instrumental music"
    if label not in _EMB_CACHE:
        _EMB_CACHE[label] = np.asarray(style_model.embed(label).cpu().numpy(), np.float32)
    return _EMB_CACHE[label]


def _slot(sid):
    """Path of the JSON conditioning slot for *sid* (basename only, so no path traversal)."""
    return os.path.join(SESSION_DIR, f"{os.path.basename(sid)}.json")


def _audio_clip_path(sid, slot):
    """Path of the uploaded conditioning clip for (session, slot). Written by /audio, read by gpu_stream."""
    return os.path.join(SESSION_DIR, f"{os.path.basename(sid)}_aud{int(slot)}.bin")


def _mtime_or_none(path):
    """Return ``os.path.getmtime(path)``, or None when the file does not exist or is unreadable."""
    try:
        return os.path.getmtime(path)
    except OSError:
        return None


def write_slot(sid, d):
    """Atomically replace the conditioning slot of *sid* with the JSON dump of *d*.

    Writes to a uniquely named temp file, then ``os.replace`` it, so readers never see a
    half-written file.
    """
    tmp = _slot(sid) + "." + uuid.uuid4().hex
    with open(tmp, "w") as f:
        json.dump(d, f)
    os.replace(tmp, _slot(sid))


def read_slot(sid):
    """Return the parsed conditioning slot of *sid*, or None if it is missing or unparsable."""
    try:
        with open(_slot(sid)) as f:
            return json.load(f)
    except (OSError, ValueError):  # ValueError covers json.JSONDecodeError
        return None


def read_client_lead(sid):
    """Return the last client-reported Web-Audio queue depth (ms) for *sid*, or None."""
    try:
        with open(os.path.join(SESSION_DIR, f"{os.path.basename(sid)}_buf.txt")) as f:
            return float(f.read())
    except (OSError, ValueError):
        return None


def _apply_set(body):
    """Build and write the conditioning slot for ``body["session_id"]``.

    Shared by HTTP ``/set`` and the WebSocket ``set`` message.
    - ``model`` and ``bank_op`` fall back to the previously stored slot when absent.
    - Every other field falls back to a fixed default.
    Raises KeyError when ``session_id`` is missing, and ValueError/TypeError on
    non-numeric numeric fields.
    """
    sid = body["session_id"]
    prev = read_slot(sid) or {}
    prompts = body.get("prompts") or ["instrumental music"]
    write_slot(sid, {
        "prompts": prompts,
        # One equal weight per prompt when none are sent. This matches gpu_stream's own
        # fallback; a bare [1.0] gave every prompt after the first weight 0.
        "weights": body.get("weights") or [1.0] * len(prompts),
        "audio": body.get("audio") or [],
        "temperature": float(body.get("temperature", 1.1)),
        "top_k": int(body.get("top_k", 50)),
        "cfg_musiccoca": float(body.get("cfg", 1.6)),
        "cfg_notes": float(body.get("cfg_notes", 2.4)),
        "cfg_drums": float(body.get("cfg_drums", 4.0)),
        "model": body.get("model", prev.get("model", "mrt2_base")),
        "buffer": int(body.get("buffer", 0)),
        "notes": list(body.get("notes", [])),
        "unmaskwidth": int(body.get("unmaskwidth", 0)),
        "drumless": bool(body.get("drumless", False)),
        "onsetmode": bool(body.get("onsetmode", False)),
        "reset": int(body.get("reset", 0)),
        "seed": int(body.get("seed", 0)),
        "bank_op": body.get("bank_op", prev.get("bank_op")),
        "ts": time.time(),
    })


def _apply_buffer(sid, lead):
    """Atomically store the client's reported playback lead (ms) for *sid*."""
    sid = os.path.basename(str(sid))
    p = os.path.join(SESSION_DIR, f"{sid}_buf.txt")
    with open(p + ".tmp", "w") as f:
        f.write(str(float(lead)))
    os.replace(p + ".tmp", p)


app = Server()


@app.post("/buffer")
async def set_clientbuffer(request: Request):
    """HTTP fallback for reporting the client lead: JSON {session_id, lead}."""
    body = await request.json()
    _apply_buffer(body["session_id"], body.get("lead", 0))
    return {"ok": True}


@app.post("/set")
async def set_collider(request: Request):
    """HTTP fallback for updating the conditioning slot (see _apply_set for fields)."""
    body = await request.json()
    _apply_set(body)
    return {"ok": True}


@app.post("/audio")
async def set_audio(request: Request):
    """Store an uploaded clip (native-rate mono float32) per (session, slot); the GPU stream embeds it with MusicCoCa.embed_audio (resampy -> 16kHz).

    On-disk format: a 4-byte little-endian sample rate, followed by the raw base64-decoded
    sample bytes. An empty/missing ``samples`` field deletes the slot's clip.
    """
    body = await request.json()
    slot = int(body.get("slot", 0))
    path = _audio_clip_path(body["session_id"], slot)
    samples = body.get("samples")
    if not samples:
        try:
            os.remove(path)
        except OSError:
            pass  # nothing stored for this slot
        return {"ok": True}
    raw = base64.b64decode(samples)
    sr = int(body.get("sample_rate", 48000))
    with open(path + ".tmp", "wb") as f:
        f.write(int(sr).to_bytes(4, "little"))
        f.write(raw)
    os.replace(path + ".tmp", path)
    return {"ok": True}


@app.get("/banks")
async def banks(session_id: str = ""):
    """Report which of the three state banks currently have a saved file for the session."""
    sid = os.path.basename(session_id)
    return {"bankStatus": [os.path.exists(os.path.join(SESSION_DIR, f"{sid}_bank{i}.pt")) for i in range(3)]}


@spaces.GPU(duration=90)
def gpu_stream(session_id):
    """Generate audio continuously for one session inside a single ZeroGPU grant.

    The session's conditioning slot is re-read every frame, so prompt, weight, note and cfg
    edits, seed/reset bumps and model-dropdown switches take effect live.

    Yields ``(pcm_bytes, frame_ms)`` once per CHUNK frames:
    - ``pcm_bytes`` is little-endian int16 audio;
    - ``frame_ms`` is the measured per-frame generation time.

    Output is paced against wall-clock time. The function returns after ~55 s of wall time,
    well inside the 90 s ``duration`` requested above; ``_worker`` then calls it again.
    """
    from magenta_rt.torch.system import make_sampler, discretize_cfg, _float_to_int16, FRAME_SAMPLES
    if style_model.device != "cuda":
        style_model.to("cuda")
    dev, dt = "cuda", torch.bfloat16
    # Give the client up to 8 s to write its first conditioning slot.
    deadline = time.time() + 8.0
    while read_slot(session_id) is None and time.time() < deadline:
        time.sleep(0.05)
    gen = torch.Generator(device=dev).manual_seed(0)
    try:  # one-time GPU warm-up so the first audio frame isn't slow
        style_model.embed("warmup")
        _wd = mrt_base.model.decoder.init_streaming_f(1, dev, dt)
        _wc = mrt_base._conditioning([-1] * mrt_base.num_musiccoca, [-1] * 128, [-1],
                                     [discretize_cfg(1.6, 0.2, 40), discretize_cfg(2.4, 0.2, 40),
                                      discretize_cfg(4.0, 1.0, 8)])
        _ws = mrt_base.model.encode(_wc).to(dt)
        for _ in range(2):
            mrt_base.model.decoder.step_f(_wd, _ws, sampler=make_sampler(1.1, 50, gen),
                                          temporal_step=mrt_base._temporal_step,
                                          depth_step=mrt_base._depth_step)
        print("[warmup] done", flush=True)
    except Exception as _e:
        print("[warmup]", repr(_e), flush=True)
    cur_name = model = dstate = source = None
    decode_state = {}
    emitted_samples = 0
    # Pacing clock: the wall time at which emitted audio position 0 "plays". Resyncs move it.
    t0 = time.time()
    # Grant budget clock. It is never re-anchored, so resyncs/resets cannot push the 55 s
    # deadline past the GPU grant (previously the loop tested against t0).
    run_start = t0
    buf_log, buf_t = [], t0
    cur_style_sig = cur_note_sig = cur_tokens = None
    prev_active = set()
    cur_reset = cur_seed = 0
    had_onsets = False
    txt_cache, aud_cache = {}, {}
    CHUNK = 3  # frames per yield (smaller = finer delivery, lower buffer floor)
    cur_bank_ver = 0
    while time.time() - run_start < 55.0:
        c = read_slot(session_id)
        if c is None:
            time.sleep(0.02)
            continue
        mname = c.get("model", "mrt2_base")
        reset = int(c.get("reset", 0))
        if mname != cur_name or reset != cur_reset:  # model switch / state reset -> re-init
            cur_name, model, cur_reset = mname, MODELS.get(mname, mrt_base), reset
            dstate = model.model.decoder.init_streaming_f(1, dev, dt)
            decode_state = model.init_decode_state()
            emitted_samples, source, prev_active = 0, None, set()
            cur_style_sig = cur_note_sig = cur_tokens = None
            had_onsets = False
        seed = int(c.get("seed", 0))
        if seed != cur_seed:
            cur_seed = seed
            gen = torch.Generator(device=dev).manual_seed(seed)
        bop = c.get("bank_op")
        if bop and int(bop.get("ver", 0)) != cur_bank_ver:  # save/recall generation state
            cur_bank_ver = int(bop.get("ver", 0))
            bpath = os.path.join(SESSION_DIR, f"{os.path.basename(session_id)}_bank{int(bop.get('idx', 0))}.pt")
            try:
                if bop.get("action") == "save":
                    torch.save({"dstate": dstate, "decode_state": decode_state, "emitted": emitted_samples}, bpath)
                elif bop.get("action") == "load" and os.path.exists(bpath):
                    # NOTE(review): torch>=2.6 defaults to weights_only=True. If dstate/decode_state
                    # contain non-tensor objects this load raises and is only logged — confirm.
                    d = torch.load(bpath, map_location=dev)
                    prev_emitted = emitted_samples
                    dstate, decode_state, emitted_samples = d["dstate"], d["decode_state"], int(d["emitted"])
                    # Re-anchor the pacing clock so the current lead is unchanged. Without this,
                    # the jump in emitted_samples would throttle (or burst) the stream.
                    t0 -= (emitted_samples - prev_emitted) / SR
                    source, cur_note_sig = None, None
            except Exception as e:
                print("[bank] error:", repr(e), flush=True)
        toks = []
        gen_t = time.time()
        for _ in range(CHUNK):  # per-frame conditioning (native parity: re-derive every 40ms frame, no debounce)
            c = read_slot(session_id) or c
            prompts = c.get("prompts") or ["instrumental music"]
            weights = c.get("weights") or [1.0] * len(prompts)
            aflags = (c.get("audio") or []) + [False] * len(prompts)  # padded: always >= len(prompts)
            active = set(c.get("notes") or [])
            unmask = int(c.get("unmaskwidth", 0))
            onsetmode = bool(c.get("onsetmode", False))
            drumless = bool(c.get("drumless", False))
            # ---- style tokens on GPU: embed+blend+tokenize, cached until prompts/weights/audio change ----
            # Clip mtimes are part of the signature, so re-uploading a clip into an already
            # flagged slot triggers a re-embed (previously it was ignored until something else changed).
            aud_mts = tuple(
                _mtime_or_none(_audio_clip_path(session_id, i)) if aflags[i] else None
                for i in range(len(prompts))
            )
            style_sig = (tuple(prompts), tuple(round(float(w), 4) for w in weights),
                         tuple(bool(a) for a in aflags[:len(prompts)]), aud_mts)
            if cur_tokens is None or style_sig != cur_style_sig:
                cur_style_sig = style_sig
                emb = torch.zeros(style_model.embedding_dim, device=dev, dtype=torch.float32)
                tot = float(sum(w for w in weights if w > 0)) or 1.0
                for i in range(len(prompts)):
                    w = float(weights[i]) if i < len(weights) else 0.0
                    if w <= 0:
                        continue
                    if aflags[i]:
                        mt = aud_mts[i]
                        if mt is None:  # flagged as audio but no clip on disk -> skip this prompt
                            continue
                        ce = aud_cache.get(i)
                        if ce is None or ce[0] != mt:
                            try:
                                with open(_audio_clip_path(session_id, i), "rb") as fh:
                                    data = fh.read()
                            except OSError:  # clip removed between stat and read -> skip
                                continue
                            asr = int.from_bytes(data[:4], "little")  # sample-rate header written by /audio
                            samp = np.frombuffer(data[4:], dtype=np.float32)
                            ce = (mt, style_model.embed_audio(samp, asr).float())
                            aud_cache[i] = ce
                        e = ce[1]
                    else:
                        lbl = (prompts[i] or "").strip()
                        if not lbl:
                            continue
                        e = txt_cache.get(lbl)
                        if e is None:
                            e = style_model.embed(lbl).float()
                            txt_cache[lbl] = e
                    emb = emb + (w / tot) * e
                cur_tokens = style_model.tokenize(emb).tolist()
            # ---- notes + cfg -> re-encode on ANY change every frame (no debounce); onset(2)->sustain(1) via had_onsets ----
            note_sig = (tuple(sorted(active)), unmask, onsetmode, drumless,
                        round(float(c.get("cfg_musiccoca", 1.6)), 3),
                        round(float(c.get("cfg_notes", 2.4)), 3),
                        round(float(c.get("cfg_drums", 4.0)), 3), cur_style_sig)
            if source is None or note_sig != cur_note_sig or had_onsets:
                onsets = active - prev_active
                prev_active = set(active)
                had_onsets = onsetmode and bool(onsets)  # carry one frame so the onset settles to sustain next frame
                cur_note_sig = note_sig
                nvec = []  # per-pitch: -1 masked / 0 off / 1 cont / 2 onset / 3 on
                for pitch in range(128):
                    if pitch in active:
                        nvec.append((2 if pitch in onsets else 1) if onsetmode else 3)
                    elif unmask >= 127 or (active and any(abs(pitch - h) <= unmask for h in active)):
                        nvec.append(0)  # solo (>=127) or within unmask window -> off
                    else:
                        nvec.append(-1)  # masked (model free)
                drm = [0] if drumless else [-1]  # 0 no-drum / -1 masked
                # solo + no note held -> ramp cfg_notes to max bin (native JamApp suppresses melody)
                cfg_notes_v = 7.0 if (unmask >= 127 and not active) else c.get("cfg_notes", 2.4)
                cfgs = [discretize_cfg(c.get("cfg_musiccoca", 1.6), 0.2, 40),
                        discretize_cfg(cfg_notes_v, 0.2, 40),
                        discretize_cfg(c.get("cfg_drums", 4.0), 1.0, 8)]
                cond = model._conditioning((list(cur_tokens) + [-1] * model.num_musiccoca)[:model.num_musiccoca],
                                           nvec, drm, cfgs)
                source = model.model.encode(cond).to(dt)
            sampler = make_sampler(c.get("temperature", 1.1), c.get("top_k", 50), gen)
            toks.append(model.model.decoder.step_f(dstate, source, sampler=sampler,
                                                   temporal_step=model._temporal_step,
                                                   depth_step=model._depth_step))
        new_codes = torch.cat(toks, dim=1)
        audio = model.decode_stream(new_codes, decode_state)  # FLOP-optimal stateful streaming decode
        emitted_samples += audio.shape[1]
        frame_ms = (time.time() - gen_t) * 1000.0 / CHUNK  # real per-frame inference time
        if audio.shape[1] > 0:
            ab = _float_to_int16(audio[0].float().cpu().numpy()).astype("<i2").tobytes()
            yield ab, frame_ms
        # NOTE(review): the original text between `.astype("` above and `-> resync` below was lost
        # (a "<...>" span was stripped from the file). The lines from `.astype("<i2")` through the
        # resync test are a RECONSTRUCTION, inferred from:
        #   - the surviving resync formula;
        #   - `target={buf:.2f}s` (buf is in seconds);
        #   - `buf * 1000.0` being compared with a ms lead.
        # Confirm against the original, especially how `buf` is derived from the slot.
        buf = max(0.0, float(c.get("buffer", 0)) / 1000.0)  # target server lead, s (assumes "buffer" is ms)
        ahead = emitted_samples / SR - (time.time() - t0)  # seconds of audio produced ahead of real time
        if ahead < 0:  # fell behind real time -> resync, no catch-up burst
            t0 = time.time() - (emitted_samples / SR) + buf
            ahead = buf
        buf_log.append(ahead)
        if time.time() - buf_t >= 15.0 and buf_log:
            print(f"[buffer] avg={sum(buf_log)/len(buf_log):.2f}s min={min(buf_log):.2f} "
                  f"max={max(buf_log):.2f} target={buf:.2f}s n={len(buf_log)}", flush=True)
            buf_log, buf_t = [], time.time()
        sleep_for = max(0.0, ahead - buf)
        cl = read_client_lead(session_id)  # client's real Web-Audio queue (ms)
        if cl is not None and cl - buf * 1000.0 > 200:  # only drain a genuinely-deep queue (target+200ms)
            sleep_for = max(sleep_for, min((cl - buf * 1000.0) / 1000.0 * 0.3, 0.2))  # gentle + capped
        if sleep_for > 0:
            time.sleep(min(sleep_for, 0.5))


_sessions = {}  # session_id -> {"audio": Queue of (pcm_bytes, frame_ms), "stop": threading.Event}
_slock = threading.Lock()


def _worker(session_id, audio_q, stop_ev):
    """Run gpu_stream back-to-back, feeding *audio_q* until *stop_ev* is set.

    When the queue is full, the oldest item is dropped, so latency stays bounded. Exceptions
    whose text mentions "abort"/"duration" are treated as an expired GPU grant and retried.
    Any other exception ends the worker. *stop_ev* is always set on exit.
    """
    try:
        while not stop_ev.is_set():
            try:
                for ab, fm in gpu_stream(session_id):  # ~55s GPU grant; re-grants on loop
                    if stop_ev.is_set():
                        break
                    if audio_q.full():
                        try:
                            audio_q.get_nowait()  # drop the oldest frame rather than block
                        except _q.Empty:
                            pass
                    try:
                        audio_q.put_nowait((ab, fm))
                    except _q.Full:
                        pass
            except Exception as e:
                if "abort" in str(e).lower() or "duration" in str(e).lower():
                    continue  # grant expired -> re-grant
                print("[worker]", repr(e), flush=True)
                break
    finally:
        stop_ev.set()


@app.api(name="start")
def start(session_id: str = "", request: gr.Request = None) -> str:
    """Start (or restart) the generation worker for *session_id*; return the id ("" if none given).

    Any previous worker for the same id is signalled to stop. The caller's gr.Request is
    grafted into the worker thread's context, so the @spaces.GPU call is attributed to
    that user.
    """
    if not session_id:
        return ""
    req = request or LocalContext.request.get(None)  # user's gr.Request -> carries X-IP-Token (ZeroGPU quota)
    with _slock:
        prev = _sessions.pop(session_id, None)
    if prev:
        prev["stop"].set()
    audio_q = _q.Queue(maxsize=48)
    stop_ev = threading.Event()

    def run():
        if req is not None:
            LocalContext.request.set(req)  # graft user request so worker @spaces.GPU bills the USER
        _worker(session_id, audio_q, stop_ev)

    ctx = contextvars.copy_context()
    t = threading.Thread(target=ctx.run, args=(run,), daemon=True)
    with _slock:
        _sessions[session_id] = {"audio": audio_q, "stop": stop_ev}
    t.start()
    return session_id


@app.websocket("/ws")
async def audio_ws(websocket: WebSocket, session_id: str = ""):
    """Stream a started session's audio to the client and accept steering messages.

    - Outgoing binary frames: a packed ``frame_ms`` header followed by the PCM bytes.
    - Incoming JSON messages: ``{"type": "set", ...}`` goes to _apply_set;
      ``{"type": "buffer", "lead": ...}`` goes to _apply_buffer.
    - If the session was never started, the socket is closed with code 1008.
    - When either direction ends, the session is stopped, the other direction is cancelled
      and the socket is closed.
    """
    await websocket.accept()
    sess = _sessions.get(session_id)
    if not sess:
        await websocket.close(code=1008)
        return
    audio_q, stop_ev = sess["audio"], sess["stop"]
    loop = asyncio.get_running_loop()

    async def send_audio():  # server -> client: binary audio frames
        try:
            while not stop_ev.is_set():
                try:
                    # Blocking get runs in the default executor. The 0.5 s timeout lets us
                    # re-check stop_ev.
                    ab, fm = await loop.run_in_executor(None, lambda: audio_q.get(timeout=0.5))
                except _q.Empty:
                    continue
                # NOTE(review): the header format and the end of send_audio (its finally) were
                # lost to the same "<...>" stripping and are reconstructed. "<f" (float32 LE
                # frame_ms) must match the client's parser — confirm.
                await websocket.send_bytes(struct.pack("<f", fm) + ab)
        finally:
            stop_ev.set()

    async def recv_control():  # client -> server: low-latency steering (notes/params), no HTTP RTT
        try:
            while not stop_ev.is_set():
                msg = await websocket.receive_json()  # raises WebSocketDisconnect on close
                t = msg.get("type")
                if t == "set":
                    msg.setdefault("session_id", session_id)  # same default as the "buffer" branch
                    try:
                        _apply_set(msg)
                    except Exception as e:
                        print("[ws set]", repr(e), flush=True)
                elif t == "buffer":
                    try:
                        _apply_buffer(msg.get("session_id", session_id), msg.get("lead", 0))
                    except Exception:
                        pass  # best-effort telemetry
        finally:
            stop_ev.set()

    tasks = [asyncio.ensure_future(send_audio()), asyncio.ensure_future(recv_control())]
    try:
        # Return as soon as either side ends. With gather, a dead worker left the handler
        # blocked in receive_json until the client happened to send something.
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        stop_ev.set()
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        with _slock:
            # Only unregister our own session. start() may already have registered a newer
            # one under the same id, and a stale socket must not remove it.
            if _sessions.get(session_id) is sess:
                _sessions.pop(session_id, None)
        try:
            await websocket.close()
        except Exception:
            pass  # already closed by the client


@app.get("/")
async def index():
    """Serve the single-page client (index.html next to this file)."""
    with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html"), encoding="utf-8") as f:
        return HTMLResponse(f.read())


app.launch(show_error=True, ssr_mode=False)