Spaces:
Running on Zero
Running on Zero
| import os | |
| os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") | |
| os.environ["GRADIO_SSR_MODE"] = "false" # serve via Python (FastAPI) directly so /ws works | |
| import json | |
| import time | |
| import uuid | |
| import base64 | |
| import numpy as np | |
| import spaces | |
| import torch | |
| from gradio import Server | |
| from fastapi import Request | |
| from fastapi.responses import HTMLResponse | |
| from huggingface_hub import hf_hub_download | |
| import asyncio, threading, queue as _q, contextvars, struct | |
| import gradio as gr | |
| from gradio.context import LocalContext | |
| from fastapi import WebSocket, WebSocketDisconnect | |
| from magenta_rt import paths | |
# Use /data when that directory exists; otherwise fall back to /tmp/magenta.
if os.path.isdir("/data"):
    MAGENTA_HOME = "/data"
else:
    MAGENTA_HOME = "/tmp/magenta"

# Checkpoints are downloaded under <MAGENTA_HOME>/magenta-rt-v2, which is then
# registered with magenta_rt as its home directory.
home = os.path.join(MAGENTA_HOME, "magenta-rt-v2")
os.makedirs(home, exist_ok=True)
for ck in ("mrt2_small.safetensors", "mrt2_base.safetensors"):
    hf_hub_download(
        "google/magenta-realtime-2",
        f"checkpoints/{ck}",
        local_dir=home,
    )
paths.set_magenta_home(home)
| from magenta_rt.torch import MagentaRT2 | |
| from magenta_rt.torch.musiccoca import MusicCoCa | |
# The style (MusicCoCa) model is created on CPU here; gpu_stream moves it to CUDA.
style_model = MusicCoCa(device="cpu")

# Both models fit on ZeroGPU — preload both; the dropdown switches between them.
mrt_small = MagentaRT2(
    size="mrt2_small",
    device="cuda",
    dtype=torch.bfloat16,
    style_model=style_model,
)
mrt_base = MagentaRT2(
    size="mrt2_base",
    device="cuda",
    dtype=torch.bfloat16,
    style_model=style_model,
)
MODELS = {"mrt2_small": mrt_small, "mrt2_base": mrt_base}

# Try to load precompiled (AOTI) artifacts for each model; on any failure the
# model is left as constructed and the error is only logged.
for name, repo in (
    ("mrt2_base", "magenta-torch/magenta-rt-aoti"),
    ("mrt2_small", "magenta-torch/magenta-rt-aoti-small"),
):
    try:
        MODELS[name].load_compiled(repo_id=repo)
        print(f"Loaded AOTI for {name} from {repo}")
    except Exception as e:
        print(f"AOTI load failed for {name} ({repo}) -> eager:", e)

SR = 48000  # sample rate used for the emitted-audio clock in gpu_stream
SESSION_DIR = "/tmp/mrt_jam2_sessions"  # per-session slot / buffer / audio / bank files
os.makedirs(SESSION_DIR, exist_ok=True)
_EMB_CACHE = {}  # label -> float32 numpy embedding, filled by _embed
def _embed(label):
    """Return the float32 numpy style embedding for `label`, memoised in _EMB_CACHE.

    A missing, empty, or whitespace-only label is replaced by
    "instrumental music" before lookup.
    """
    key = (label or "").strip()
    if not key:
        key = "instrumental music"
    try:
        return _EMB_CACHE[key]
    except KeyError:
        pass
    vec = np.asarray(style_model.embed(key).cpu().numpy(), np.float32)
    _EMB_CACHE[key] = vec
    return vec
def _slot(sid):
    """Return the path of the JSON slot file for session `sid` inside SESSION_DIR.

    Only the basename of `sid` is used, so directory components in a
    client-supplied id cannot escape SESSION_DIR.
    """
    filename = "%s.json" % os.path.basename(sid)
    return os.path.join(SESSION_DIR, filename)
def write_slot(sid, d):
    """Write `d` as JSON to the slot file of session `sid`.

    The data is first written to a uniquely named sibling temp file and then
    moved over the slot with os.replace, so a reader never opens a
    half-written slot. If serialisation or the replace fails, the temp file
    is removed (best effort) and the original exception is re-raised.
    """
    dest = _slot(sid)
    tmp = dest + "." + uuid.uuid4().hex
    try:
        with open(tmp, "w") as f:
            json.dump(d, f)
        os.replace(tmp, dest)
    except BaseException:
        # Do not leave stray "<slot>.<hex>" files behind on failure.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
def read_slot(sid):
    """Load and return the JSON slot of session `sid`.

    Returns None on any failure (file missing, unreadable, or not valid JSON).
    """
    parsed = None
    try:
        with open(_slot(sid)) as fh:
            parsed = json.loads(fh.read())
    except Exception:
        parsed = None
    return parsed
def read_client_lead(sid):
    """Return the float stored in the session's `<sid>_buf.txt` file.

    Returns None when the file cannot be opened or its content does not
    parse as a float.
    """
    try:
        name = f"{os.path.basename(sid)}_buf.txt"
        with open(os.path.join(SESSION_DIR, name)) as fh:
            text = fh.read()
        lead = float(text)
    except Exception:
        return None
    return lead
def _apply_set(body):  # build + write the conditioning slot (shared by HTTP /set and the WS)
    """Normalise a control message and store it as the session's slot.

    `body` must contain "session_id". Missing fields take the defaults
    below, except "model" and "bank_op", which fall back to the value in the
    previously stored slot. The incoming "cfg" key is stored as
    "cfg_musiccoca". A "ts" timestamp (time.time()) is added.

    Raises:
        KeyError: if "session_id" is absent.
        ValueError / TypeError: if a numeric field cannot be converted.
    """
    sid = body["session_id"]
    prev = read_slot(sid) or {}
    fetch = body.get

    slot = {}
    # Style conditioning.
    slot["prompts"] = fetch("prompts") or ["instrumental music"]
    slot["weights"] = fetch("weights") or [1.0]
    slot["audio"] = fetch("audio") or []
    # Sampling and guidance parameters.
    slot["temperature"] = float(fetch("temperature", 1.1))
    slot["top_k"] = int(fetch("top_k", 50))
    slot["cfg_musiccoca"] = float(fetch("cfg", 1.6))
    slot["cfg_notes"] = float(fetch("cfg_notes", 2.4))
    slot["cfg_drums"] = float(fetch("cfg_drums", 4.0))
    # Sticky across messages: keep the previous model when none is given.
    slot["model"] = fetch("model", prev.get("model", "mrt2_base"))
    slot["buffer"] = int(fetch("buffer", 0))
    # Note / drum conditioning.
    slot["notes"] = list(fetch("notes", []))
    slot["unmaskwidth"] = int(fetch("unmaskwidth", 0))
    slot["drumless"] = bool(fetch("drumless", False))
    slot["onsetmode"] = bool(fetch("onsetmode", False))
    # Counters / one-shot operations consumed by gpu_stream.
    slot["reset"] = int(fetch("reset", 0))
    slot["seed"] = int(fetch("seed", 0))
    slot["bank_op"] = fetch("bank_op", prev.get("bank_op"))
    slot["ts"] = time.time()

    write_slot(sid, slot)
def _apply_buffer(sid, lead):
    """Store `lead` (converted to float) in the session's `<sid>_buf.txt` file.

    The value is written to a ".tmp" sibling and then moved into place with
    os.replace. read_client_lead reads the same file back.
    """
    safe_sid = os.path.basename(str(sid))
    final_path = os.path.join(SESSION_DIR, f"{safe_sid}_buf.txt")
    scratch = final_path + ".tmp"
    with open(scratch, "w") as out:
        out.write(str(float(lead)))
    os.replace(scratch, final_path)
# Gradio Server instance; started at the bottom of the file via app.launch().
# NOTE(review): no route registrations for the handlers below are visible in
# this view — confirm how they are attached to `app`.
app = Server()
async def set_clientbuffer(request: Request):
    """Record the client-reported buffer lead for a session.

    Expects a JSON body with "session_id" and optional "lead" (default 0);
    the value is persisted through _apply_buffer.
    """
    payload = await request.json()
    sid = payload["session_id"]
    lead = payload.get("lead", 0)
    _apply_buffer(sid, lead)
    return {"ok": True}
async def set_collider(request: Request):
    """Apply a conditioning update sent as a JSON request body (see _apply_set)."""
    _apply_set(await request.json())
    return {"ok": True}
async def set_audio(request: Request):
    """Persist or clear one uploaded audio clip for a (session, slot) pair.

    The JSON body carries "session_id", optional "slot" (default 0), optional
    "sample_rate" (default 48000) and "samples" (base64 text). A falsy
    "samples" deletes any stored clip. Otherwise the file written is a
    4-byte little-endian sample rate followed by the decoded bytes;
    gpu_stream reads that payload back as float32 samples and passes it to
    style_model.embed_audio.
    """
    payload = await request.json()
    sid = os.path.basename(payload["session_id"])
    slot = int(payload.get("slot", 0))
    path = os.path.join(SESSION_DIR, f"{sid}_aud{slot}.bin")

    encoded = payload.get("samples")
    if encoded:
        raw = base64.b64decode(encoded)
        sr = int(payload.get("sample_rate", 48000))
        scratch = path + ".tmp"
        with open(scratch, "wb") as out:
            out.write(sr.to_bytes(4, "little"))
            out.write(raw)
        os.replace(scratch, path)
    else:
        # No samples: clear the slot; a missing file is not an error.
        try:
            os.remove(path)
        except OSError:
            pass
    return {"ok": True}
async def banks(session_id: str = ""):
    """Report which of the three bank files (`<sid>_bank0..2.pt`) exist for a session."""
    sid = os.path.basename(session_id)
    status = []
    for idx in range(3):
        bank_path = os.path.join(SESSION_DIR, f"{sid}_bank{idx}.pt")
        status.append(os.path.exists(bank_path))
    return {"bankStatus": status}
def gpu_stream(session_id):
    """Continuous gen; switches model live when the dropdown changes.

    Generator. Polls the session's slot (read_slot) for the current
    conditioning, re-derives style and note conditioning for every frame, and
    yields audio in chunks of CHUNK frames. The streaming state is rebuilt
    when the slot's "model" or "reset" value changes. The main loop runs
    while ``time.time() - t0 < 55.0`` (t0 is shifted when the stream
    resyncs), then the generator returns; _worker calls it again.

    NOTE(review): comments elsewhere in this file refer to @spaces.GPU, but
    no decorator is visible on this function in this view — confirm.

    Args:
        session_id: id whose slot / audio / bank files in SESSION_DIR are read.

    Yields:
        (pcm_bytes, frame_ms): little-endian int16 audio bytes for one chunk
        and the measured per-frame generation time in milliseconds.
    """
    from magenta_rt.torch.system import make_sampler, discretize_cfg, _float_to_int16, FRAME_SAMPLES
    from magenta_rt.torch.cudagraph import CudaGraphStreamer
    USE_CG = os.environ.get('MRT_CUDAGRAPH', '1') == '1'  # single-dispatch CUDA-graph stepping (eager fallback)
    cg_ok = True
    if style_model.device != "cuda":
        style_model.to("cuda")
    dev, dt = "cuda", torch.bfloat16
    # Wait up to 8 s for the client to publish its first slot.
    deadline = time.time() + 8.0
    while read_slot(session_id) is None and time.time() < deadline:
        time.sleep(0.05)
    gen = torch.Generator(device=dev).manual_seed(0)
    try:  # one-time GPU warm-up so the first audio frame isn't slow
        style_model.embed("warmup")
        _wd = mrt_base.model.decoder.init_streaming_f(1, dev, dt)
        _wc = mrt_base._conditioning([-1] * mrt_base.num_musiccoca, [-1] * 128, [-1],
                                     [discretize_cfg(1.6, 0.2, 40), discretize_cfg(2.4, 0.2, 40), discretize_cfg(4.0, 1.0, 8)])
        _ws = mrt_base.model.encode(_wc).to(dt)
        for _ in range(2):
            mrt_base.model.decoder.step_f(_wd, _ws, sampler=make_sampler(1.1, 50, gen),
                                          temporal_step=mrt_base._temporal_step, depth_step=mrt_base._depth_step)
        print("[warmup] done", flush=True)
    except Exception as _e:
        print("[warmup]", repr(_e), flush=True)
    cur_name = model = dstate = source = None
    streamer = last_src_for_graph = None
    decode_state = {}
    emitted_samples = 0
    t0 = time.time()
    buf_log, buf_t = [], t0
    cur_style_sig = cur_note_sig = cur_tokens = None
    prev_active = set()
    cur_reset = cur_seed = 0
    had_onsets = False
    txt_cache, aud_cache = {}, {}
    CHUNK = 3  # frames per yield (smaller = finer delivery, lower buffer floor)
    cur_bank_ver = 0
    while time.time() - t0 < 55.0:
        c = read_slot(session_id)
        if c is None:
            time.sleep(0.02)
            continue
        mname = c.get("model", "mrt2_base")
        reset = int(c.get("reset", 0))
        if mname != cur_name or reset != cur_reset:  # model switch / state reset -> re-init
            cur_name, model, cur_reset = mname, MODELS.get(mname, mrt_base), reset
            dstate = model.model.decoder.init_streaming_f(1, dev, dt)
            decode_state = model.init_decode_state()
            emitted_samples, source, prev_active = 0, None, set()
            cur_style_sig = cur_note_sig = cur_tokens = None
            had_onsets = False
            streamer = last_src_for_graph = None  # rebuild CUDA graph on model switch / reset
        seed = int(c.get("seed", 0))
        if seed != cur_seed:
            cur_seed = seed
            gen = torch.Generator(device=dev).manual_seed(seed)
            streamer = None  # re-seed => re-capture (graph RNG fixed at capture)
        bop = c.get("bank_op")
        if bop and int(bop.get("ver", 0)) != cur_bank_ver and not USE_CG:  # save/recall (eager only; cudagraph KV is static)
            cur_bank_ver = int(bop.get("ver", 0))
            bpath = os.path.join(SESSION_DIR, f"{os.path.basename(session_id)}_bank{int(bop.get('idx', 0))}.pt")
            try:
                if bop.get("action") == "save":
                    torch.save({"dstate": dstate, "decode_state": decode_state, "emitted": emitted_samples}, bpath)
                elif bop.get("action") == "load" and os.path.exists(bpath):
                    # NOTE(review): torch.load unpickles; these files are written by the
                    # "save" branch above — confirm nothing else can place files here.
                    d = torch.load(bpath, map_location=dev)
                    dstate, decode_state, emitted_samples = d["dstate"], d["decode_state"], int(d["emitted"])
                    source, cur_note_sig = None, None
            except Exception as e:
                print("[bank] error:", repr(e), flush=True)
        toks = []
        gen_t = time.time()
        for _ in range(CHUNK):  # per-frame conditioning (native parity: re-derive every 40ms frame, no debounce)
            c = read_slot(session_id) or c
            prompts = c.get("prompts") or ["instrumental music"]
            weights = c.get("weights") or [1.0] * len(prompts)
            aflags = (c.get("audio") or []) + [False] * len(prompts)
            active = set(c.get("notes") or [])
            unmask = int(c.get("unmaskwidth", 0))
            onsetmode = bool(c.get("onsetmode", False))
            drumless = bool(c.get("drumless", False))
            # ---- style tokens on GPU: embed+blend+tokenize, cached until prompts/weights/audio change ----
            style_sig = (tuple(prompts), tuple(round(float(w), 4) for w in weights),
                         tuple(bool(a) for a in aflags[:len(prompts)]))
            if cur_tokens is None or style_sig != cur_style_sig:
                cur_style_sig = style_sig
                emb = torch.zeros(style_model.embedding_dim, device=dev, dtype=torch.float32)
                tot = float(sum(w for w in weights if w > 0)) or 1.0
                for i in range(len(prompts)):
                    w = float(weights[i]) if i < len(weights) else 0.0
                    if w <= 0:
                        continue
                    if i < len(aflags) and aflags[i]:
                        # Audio prompt: embed the uploaded clip, re-embedding only when the file's mtime changes.
                        ap = os.path.join(SESSION_DIR, f"{os.path.basename(session_id)}_aud{i}.bin")
                        try:
                            mt = os.path.getmtime(ap)
                        except OSError:
                            continue
                        ce = aud_cache.get(i)
                        if ce is None or ce[0] != mt:
                            # Close the handle deterministically instead of relying on GC.
                            with open(ap, "rb") as fh:
                                data = fh.read()
                            asr = int.from_bytes(data[:4], "little")  # 4-byte sample-rate header written by set_audio
                            samp = np.frombuffer(data[4:], dtype=np.float32)
                            ce = (mt, style_model.embed_audio(samp, asr).float())
                            aud_cache[i] = ce
                        e = ce[1]
                    else:
                        lbl = (prompts[i] or "").strip()
                        if not lbl:
                            continue
                        e = txt_cache.get(lbl)
                        if e is None:
                            e = style_model.embed(lbl).float()
                            txt_cache[lbl] = e
                    emb = emb + (w / tot) * e
                cur_tokens = style_model.tokenize(emb).tolist()
            # ---- notes + cfg -> re-encode on ANY change every frame (no debounce); onset(2)->sustain(1) via had_onsets ----
            note_sig = (tuple(sorted(active)), unmask, onsetmode, drumless,
                        round(float(c.get("cfg_musiccoca", 1.6)), 3), round(float(c.get("cfg_notes", 2.4)), 3),
                        round(float(c.get("cfg_drums", 4.0)), 3), cur_style_sig)
            if source is None or note_sig != cur_note_sig or had_onsets:
                onsets = active - prev_active
                prev_active = set(active)
                had_onsets = onsetmode and bool(onsets)  # carry one frame so the onset settles to sustain next frame
                cur_note_sig = note_sig
                nvec = []  # per-pitch: -1 masked / 0 off / 1 cont / 2 onset / 3 on
                for pitch in range(128):
                    if pitch in active:
                        nvec.append((2 if pitch in onsets else 1) if onsetmode else 3)
                    elif unmask >= 127 or (active and any(abs(pitch - h) <= unmask for h in active)):
                        nvec.append(0)  # solo (>=127) or within unmask window -> off
                    else:
                        nvec.append(-1)  # masked (model free)
                drm = [0] if drumless else [-1]  # 0 no-drum / -1 masked
                cfg_notes_v = 7.0 if (unmask >= 127 and not active) else c.get("cfg_notes", 2.4)  # solo + no note held -> ramp cfg_notes to max bin (native JamApp suppresses melody)
                cfgs = [discretize_cfg(c.get("cfg_musiccoca", 1.6), 0.2, 40),
                        discretize_cfg(cfg_notes_v, 0.2, 40),
                        discretize_cfg(c.get("cfg_drums", 4.0), 1.0, 8)]
                cond = model._conditioning((list(cur_tokens) + [-1] * model.num_musiccoca)[:model.num_musiccoca],
                                           nvec, drm, cfgs)
                source = model.model.encode(cond).to(dt)
            temp = c.get("temperature", 1.1)
            topk = int(c.get("top_k", 50))
            ok = False
            if USE_CG and cg_ok:  # single-dispatch CUDA-graph step
                try:
                    if streamer is None:  # build + capture on first frame (~2-3s warmup)
                        streamer = CudaGraphStreamer(model.model.decoder, source, dt,
                                                     temperature=temp, top_k=topk, seed=cur_seed)
                        last_src_for_graph = source
                    elif source is not last_src_for_graph:  # conditioning changed -> update static buffer
                        streamer.set_source(source)
                        last_src_for_graph = source
                    streamer.set_temperature(temp)
                    toks.append(streamer.step())
                    ok = True
                except Exception as _cge:
                    # Any CUDA-graph failure disables that path for the rest of this call.
                    print("[cudagraph] fallback to eager:", repr(_cge), flush=True)
                    cg_ok = False
                    streamer = None
                    dstate = model.model.decoder.init_streaming_f(1, dev, dt)
            if not ok:  # eager fallback path
                sampler = make_sampler(temp, topk, gen)
                toks.append(model.model.decoder.step_f(dstate, source, sampler=sampler,
                                                       temporal_step=model._temporal_step, depth_step=model._depth_step))
        new_codes = torch.cat(toks, dim=1)
        audio = model.decode_stream(new_codes, decode_state)  # FLOP-optimal stateful streaming decode
        emitted_samples += audio.shape[1]
        frame_ms = (time.time() - gen_t) * 1000.0 / CHUNK  # real per-frame inference time
        if audio.shape[1] > 0:
            ab = _float_to_int16(audio[0].float().cpu().numpy()).astype("<i2").tobytes()
            yield (ab, frame_ms)
        # ---- pacing: keep generated audio about `buf` seconds ahead of wall-clock ----
        buf = (0.15, 0.3, 0.5)[max(0, min(2, int(c.get("buffer", 0))))]
        ahead = (emitted_samples / SR) - (time.time() - t0)
        if ahead < -0.3:  # fell behind (cold-start/hiccup) -> resync, no catch-up burst
            t0 = time.time() - (emitted_samples / SR) + buf
            ahead = buf
        buf_log.append(ahead)
        if time.time() - buf_t >= 15.0 and buf_log:
            print(f"[buffer] avg={sum(buf_log)/len(buf_log):.2f}s min={min(buf_log):.2f} "
                  f"max={max(buf_log):.2f} target={buf:.2f}s n={len(buf_log)} "
                  f"frame={frame_ms:.1f}ms mode={'cudagraph' if (USE_CG and cg_ok) else 'eager'}", flush=True)
            buf_log, buf_t = [], time.time()
        sleep_for = max(0.0, ahead - buf)
        cl = read_client_lead(session_id)  # client's real Web-Audio queue (ms)
        if cl is not None and cl - buf * 1000.0 > 200:  # only drain a genuinely-deep queue (target+200ms)
            sleep_for = max(sleep_for, min((cl - buf * 1000.0) / 1000.0 * 0.3, 0.2))  # gentle + capped
        if sleep_for > 0:
            time.sleep(min(sleep_for, 0.5))
# session_id -> {"audio": queue of (pcm_bytes, frame_ms) tuples, "stop": threading.Event}.
# Entries are inserted by start() and removed by start() / audio_ws().
_sessions = {}
# Held around the inserts and pops on _sessions.
_slock = threading.Lock()
def _worker(session_id, audio_q, stop_ev):
    """Pump chunks from gpu_stream into `audio_q` until `stop_ev` is set.

    gpu_stream is restarted whenever it finishes, and also when it raises an
    exception whose message contains "abort" or "duration". Any other
    exception is logged and ends the worker. When the queue is full the
    oldest chunk is dropped to make room. `stop_ev` is always set on exit.
    """
    def _offer(item):
        # Drop the oldest queued chunk when full so the newest one can go in.
        if audio_q.full():
            try:
                audio_q.get_nowait()
            except _q.Empty:
                pass
        try:
            audio_q.put_nowait(item)
        except _q.Full:
            pass

    try:
        while not stop_ev.is_set():
            try:
                for ab, fm in gpu_stream(session_id):  # ~55s GPU grant; re-grants on loop
                    if stop_ev.is_set():
                        break
                    _offer((ab, fm))
            except Exception as err:
                reason = str(err).lower()
                if "abort" in reason or "duration" in reason:
                    continue  # grant expired -> re-grant
                print("[worker]", repr(err), flush=True)
                break
    finally:
        stop_ev.set()
def start(session_id: str = "", request: gr.Request = None) -> str:
    """Start (or restart) the background generation worker for a session.

    Any existing session with the same id is removed and signalled to stop.
    A new bounded queue (48 chunks) and stop event are registered in
    _sessions, and a daemon thread running _worker is started inside a copy
    of the current contextvars context.

    Returns:
        `session_id`, or "" when no session id was given (nothing is started).
    """
    if not session_id:
        return ""

    # user's gr.Request -> carries X-IP-Token (ZeroGPU quota)
    user_req = request or LocalContext.request.get(None)

    with _slock:
        previous = _sessions.pop(session_id, None)
    if previous:
        previous["stop"].set()

    chunk_queue = _q.Queue(maxsize=48)
    stop_flag = threading.Event()

    def run():
        if user_req is not None:
            # graft user request so worker @spaces.GPU bills the USER
            LocalContext.request.set(user_req)
        _worker(session_id, chunk_queue, stop_flag)

    context = contextvars.copy_context()
    thread = threading.Thread(target=context.run, args=(run,), daemon=True)
    entry = {"audio": chunk_queue, "stop": stop_flag}
    with _slock:
        _sessions[session_id] = entry
    thread.start()
    return session_id
async def audio_ws(websocket: WebSocket, session_id: str = ""):
    """Bidirectional WebSocket for one session: audio out, control messages in.

    The socket is accepted and then closed with code 1008 if `session_id`
    has no entry in _sessions (i.e. start() was not called first).
    Otherwise two tasks run until the session's stop event is set or the
    client disconnects:

    * send_audio: forwards queued chunks as binary frames made of a 4-byte
      little-endian float32 (per-frame ms) followed by the PCM bytes, then
      tries to send {"type": "ended"} once the stop event is set.
    * recv_control: applies incoming JSON messages of type "set"
      (_apply_set) and "buffer" (_apply_buffer).

    On exit the stop event is set and the session is removed from
    _sessions — but only if the registered entry is still the one this
    socket was serving, so a session re-registered by start() under the same
    id is not deleted by an older socket closing.
    """
    await websocket.accept()
    sess = _sessions.get(session_id)
    if not sess:
        await websocket.close(code=1008)
        return
    audio_q, stop_ev = sess["audio"], sess["stop"]
    loop = asyncio.get_running_loop()

    async def send_audio():  # server -> client: binary audio frames
        try:
            while not stop_ev.is_set():
                try:
                    # Blocking queue read runs in the default executor; the
                    # 0.5 s timeout lets the loop re-check stop_ev.
                    ab, fm = await loop.run_in_executor(None, lambda: audio_q.get(timeout=0.5))
                except _q.Empty:
                    continue
                await websocket.send_bytes(struct.pack("<f", fm) + ab)
            try:
                await websocket.send_json({"type": "ended"})  # worker ended (quota/grant)
            except Exception:
                pass  # best effort: the peer may already be gone
        finally:
            stop_ev.set()

    async def recv_control():  # client -> server: low-latency steering (notes/params), no HTTP RTT
        try:
            while not stop_ev.is_set():
                msg = await websocket.receive_json()  # raises WebSocketDisconnect on close
                t = msg.get("type")
                if t == "set":
                    try:
                        # Default to this socket's session, as the "buffer" branch does.
                        msg.setdefault("session_id", session_id)
                        _apply_set(msg)
                    except Exception as e:
                        print("[ws set]", repr(e), flush=True)
                elif t == "buffer":
                    try:
                        _apply_buffer(msg.get("session_id", session_id), msg.get("lead", 0))
                    except Exception:
                        pass  # buffer reports are advisory; ignore bad ones
        finally:
            stop_ev.set()

    try:
        await asyncio.gather(send_audio(), recv_control(), return_exceptions=True)
    finally:
        stop_ev.set()
        with _slock:
            # Only drop the registry entry if it is still ours.
            if _sessions.get(session_id) is sess:
                _sessions.pop(session_id, None)
async def index():
    """Return the contents of index.html (located next to this module) as an HTML response.

    The file is read as UTF-8 explicitly so the result does not depend on
    the process locale.
    """
    page = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    with open(page, encoding="utf-8") as f:
        return HTMLResponse(f.read())
# Start serving. ssr_mode=False matches GRADIO_SSR_MODE="false" set at the top of this file.
app.launch(show_error=True, ssr_mode=False)