| import os |
| os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") |
| os.environ["GRADIO_SSR_MODE"] = "false" |
| import json |
| import time |
| import uuid |
| import base64 |
| import numpy as np |
| import spaces |
| import torch |
| from gradio import Server |
| from fastapi import Request |
| from fastapi.responses import HTMLResponse |
| from huggingface_hub import hf_hub_download |
| import asyncio, threading, queue as _q, contextvars, struct |
| import gradio as gr |
| from gradio.context import LocalContext |
| from fastapi import WebSocket, WebSocketDisconnect |
|
|
| from magenta_rt import paths |
|
|
# Checkpoint root: "/data" when that directory exists, otherwise a scratch
# directory under /tmp (presumably "/data" is persistent storage — confirm).
MAGENTA_HOME = "/data" if os.path.isdir("/data") else "/tmp/magenta"
home = os.path.join(MAGENTA_HOME, "magenta-rt-v2")
os.makedirs(home, exist_ok=True)
# Download both checkpoint files into `home` before pointing magenta_rt at it.
for ck in ("mrt2_small.safetensors", "mrt2_base.safetensors"):
    hf_hub_download("google/magenta-realtime-2", f"checkpoints/{ck}", local_dir=home)
# Must run before the magenta_rt.torch imports that follow in this file.
# NOTE(review): assumed those imports resolve paths at import time — confirm.
paths.set_magenta_home(home)
|
|
| from magenta_rt.torch import MagentaRT2 |
| from magenta_rt.torch.musiccoca import MusicCoCa |
|
|
# The style encoder is created on the CPU; gpu_stream moves it to CUDA the
# first time it runs.
style_model = MusicCoCa(device="cpu")



# Only the base model is instantiated. "mrt2_small" is an alias for the same
# object, so selecting either name in MODELS uses the base weights.
mrt_base = MagentaRT2(size="mrt2_base", device="cuda", dtype=torch.bfloat16, style_model=style_model)
mrt_small = mrt_base
MODELS = {"mrt2_base": mrt_base, "mrt2_small": mrt_base}
# Best-effort load of precompiled artifacts; any failure is logged and the
# model object is left as constructed.
for name, repo in (("mrt2_base", "magenta-community/magenta-rt-aoti-base"),):
    try:
        MODELS[name].load_compiled(repo_id=repo)
        print(f"Loaded AOTI for {name} from {repo}")
    except Exception as e:
        print(f"AOTI load failed for {name} ({repo}) -> eager:", e)
|
|
# Sample rate used by gpu_stream to convert emitted sample counts to seconds.
SR = 48000
# Directory holding per-session files: "<sid>.json" control state,
# "<sid>_buf.txt" client lead, "<sid>_aud<slot>.bin" clips, "<sid>_bank<i>.pt".
SESSION_DIR = "/tmp/mrt_jam2_sessions"
os.makedirs(SESSION_DIR, exist_ok=True)
# Process-wide memo of text-label embeddings, filled by _embed.
_EMB_CACHE = {}
|
|
|
|
def _embed(label):
    """Return the float32 NumPy style embedding for ``label``.

    Blank or missing labels fall back to "instrumental music". Results are
    memoised in the module-level ``_EMB_CACHE`` keyed by the stripped label.
    """
    key = (label or "").strip()
    if not key:
        key = "instrumental music"
    try:
        return _EMB_CACHE[key]
    except KeyError:
        pass
    vector = np.asarray(style_model.embed(key).cpu().numpy(), np.float32)
    _EMB_CACHE[key] = vector
    return vector
|
|
|
|
def _slot(sid):
    """Return the path of the JSON control file for session ``sid``.

    Only the final path component of ``sid`` is used to build the file name.
    """
    filename = "{}.json".format(os.path.basename(sid))
    return os.path.join(SESSION_DIR, filename)
|
|
|
|
def write_slot(sid, d):
    """Persist dict ``d`` as the JSON control state of session ``sid``.

    The data is written to a uniquely named temporary file next to the
    target and then moved into place with ``os.replace``, so readers never
    observe a partially written file.

    Args:
        sid: Session identifier (passed to ``_slot`` to build the path).
        d: JSON-serialisable object to store.

    Raises:
        Whatever ``open``, ``json.dump`` or ``os.replace`` raise; the
        temporary file is removed before the exception propagates.
    """
    target = _slot(sid)
    tmp = target + "." + uuid.uuid4().hex
    try:
        with open(tmp, "w") as f:
            json.dump(d, f)
        os.replace(tmp, target)
    except BaseException:
        # Do not leave orphaned temp files behind when serialisation or the
        # rename fails (e.g. a non-serialisable value in ``d``).
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
|
|
|
|
def read_slot(sid):
    """Load the JSON control state of session ``sid``.

    Returns the decoded object, or ``None`` when the file is missing,
    unreadable or cannot be parsed (every ``Exception`` is treated alike).
    """
    try:
        with open(_slot(sid)) as fh:
            state = json.load(fh)
    except Exception:
        state = None
    return state
|
|
|
|
def read_client_lead(sid):
    """Return the client-reported lead for session ``sid`` as a float.

    The value is read from ``<sid>_buf.txt`` in ``SESSION_DIR`` (written by
    ``_apply_buffer``). ``None`` is returned when the file is missing or its
    content is not a float.
    """
    try:
        name = "{}_buf.txt".format(os.path.basename(sid))
        with open(os.path.join(SESSION_DIR, name)) as fh:
            text = fh.read()
        return float(text)
    except Exception:
        return None
|
|
|
|
def _apply_set(body):
    """Normalise a "set" request and store it as the session's control state.

    ``body`` must contain "session_id". Every other field is optional and
    replaced by the default shown below when absent, except "model" and
    "bank_op", which fall back to the previously stored value. The stored
    record also receives a "ts" wall-clock timestamp.

    Raises:
        KeyError: if "session_id" is missing.
        TypeError / ValueError: if a numeric field cannot be converted.
    """
    sid = body["session_id"]
    prev = read_slot(sid) or {}
    get = body.get

    state = {}
    state["prompts"] = get("prompts") or ["instrumental music"]
    state["weights"] = get("weights") or [1.0]
    state["audio"] = get("audio") or []
    state["temperature"] = float(get("temperature", 1.1))
    state["top_k"] = int(get("top_k", 50))
    # The request field is called "cfg"; it is stored as "cfg_musiccoca".
    state["cfg_musiccoca"] = float(get("cfg", 1.6))
    state["cfg_notes"] = float(get("cfg_notes", 2.4))
    state["cfg_drums"] = float(get("cfg_drums", 4.0))
    state["model"] = get("model", prev.get("model", "mrt2_base"))
    state["buffer"] = int(get("buffer", 0))
    state["notes"] = list(get("notes", []))
    state["unmaskwidth"] = int(get("unmaskwidth", 0))
    state["drumless"] = bool(get("drumless", False))
    state["onsetmode"] = bool(get("onsetmode", False))
    state["reset"] = int(get("reset", 0))
    state["seed"] = int(get("seed", 0))
    state["bank_op"] = get("bank_op", prev.get("bank_op"))
    state["ts"] = time.time()

    write_slot(sid, state)
|
|
|
|
def _apply_buffer(sid, lead):
    """Store the client-reported lead for session ``sid``.

    The value is written as text to ``<sid>_buf.txt`` in ``SESSION_DIR``
    (read back by ``read_client_lead``) via a temporary file plus
    ``os.replace``.

    Args:
        sid: Session identifier; only its final path component is used.
        lead: Anything ``float()`` accepts.

    Raises:
        TypeError / ValueError: if ``lead`` is not convertible to float.
        OSError: if the file cannot be written.
    """
    # Convert first so invalid input never creates or truncates a file.
    value = str(float(lead))
    sid = os.path.basename(str(sid))
    p = os.path.join(SESSION_DIR, f"{sid}_buf.txt")
    # Unique temp name, matching write_slot, so two writers for the same
    # session cannot clobber each other's temporary file.
    tmp = p + "." + uuid.uuid4().hex
    try:
        with open(tmp, "w") as f:
            f.write(value)
        os.replace(tmp, p)
    except BaseException:
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
|
|
|
|
# Application object; the HTTP routes, the websocket endpoint and the "start"
# API defined below are all registered on it.
app = Server()
|
|
|
|
@app.post("/buffer")
async def set_clientbuffer(request: Request):
    """POST /buffer: record the client's reported lead for a session.

    Expects a JSON body with "session_id" and an optional "lead" (default 0).
    """
    payload = await request.json()
    session = payload["session_id"]
    lead = payload.get("lead", 0)
    _apply_buffer(session, lead)
    return {"ok": True}
|
|
|
|
@app.post("/set")
async def set_collider(request: Request):
    """POST /set: store the JSON body as the session's control state."""
    _apply_set(await request.json())
    return {"ok": True}
|
|
|
|
@app.post("/audio")
async def set_audio(request: Request):
    """Store or clear an uploaded clip for one (session, slot) pair.

    The JSON body carries "session_id", an optional "slot" (default 0),
    base64 "samples" and an optional "sample_rate" (default 48000). The clip
    is saved as a 4-byte little-endian sample rate followed by the decoded
    bytes; gpu_stream reads that payload as float32 and embeds it with
    MusicCoCa.embed_audio (per the original note: native-rate mono float32,
    resampled to 16 kHz there). An empty or missing "samples" deletes the
    stored clip instead.
    """
    body = await request.json()
    sid = os.path.basename(body["session_id"])
    slot = int(body.get("slot", 0))
    path = os.path.join(SESSION_DIR, f"{sid}_aud{slot}.bin")
    samples = body.get("samples")

    if samples:
        raw = base64.b64decode(samples)
        sr = int(body.get("sample_rate", 48000))
        tmp = path + ".tmp"
        # Write to a sibling file, then rename so readers see a whole clip.
        with open(tmp, "wb") as f:
            f.write(int(sr).to_bytes(4, "little"))
            f.write(raw)
        os.replace(tmp, path)
    else:
        # Clearing a slot that was never filled is not an error.
        try:
            os.remove(path)
        except OSError:
            pass
    return {"ok": True}
|
|
|
|
@app.get("/banks")
async def banks(session_id: str = ""):
    """GET /banks: report which of the three bank files exist for a session.

    Returns {"bankStatus": [bool, bool, bool]}, one flag per
    ``<sid>_bank<i>.pt`` file in ``SESSION_DIR``.
    """
    sid = os.path.basename(session_id)
    status = []
    for i in range(3):
        bank_path = os.path.join(SESSION_DIR, f"{sid}_bank{i}.pt")
        status.append(os.path.exists(bank_path))
    return {"bankStatus": status}
|
|
|
|
@spaces.GPU(duration=60)
def gpu_stream(session_id):
    """Generate audio continuously for one session, yielding PCM chunks.

    The session's control slot (see ``read_slot``) is re-read on every outer
    iteration and before every decoder step, so model choice, reset counter,
    seed, bank save/load requests, prompts, notes and sampling parameters
    take effect while streaming. The outer loop runs while
    ``time.time() - t0 < 40.0``; note that the pacing logic may move ``t0``.

    Args:
        session_id: Session whose control files in ``SESSION_DIR`` are read.

    Yields:
        ``(pcm_bytes, frame_ms)``: little-endian int16 audio bytes for the
        ``CHUNK`` steps just decoded, and the mean wall-clock milliseconds
        spent per step.
    """
    from magenta_rt.torch.system import make_sampler, discretize_cfg, _float_to_int16, FRAME_SAMPLES
    if style_model.device != "cuda":
        style_model.to("cuda")
    dev, dt = "cuda", torch.bfloat16
    # Give the client up to 8 s to publish its first control state.
    deadline = time.time() + 8.0
    while read_slot(session_id) is None and time.time() < deadline:
        time.sleep(0.05)
    gen = torch.Generator(device=dev).manual_seed(0)
    # Best-effort warmup: run the style encoder and two decoder steps on a
    # throwaway state; failures are logged and otherwise ignored.
    try:
        style_model.embed("warmup")
        _wd = mrt_base.model.decoder.init_streaming_f(1, dev, dt)
        _wc = mrt_base._conditioning(
            [-1] * mrt_base.num_musiccoca, [-1] * 128, [-1],
            [discretize_cfg(1.6, 0.2, 40), discretize_cfg(2.4, 0.2, 40), discretize_cfg(4.0, 1.0, 8)])
        _ws = mrt_base.model.encode(_wc).to(dt)
        for _ in range(2):
            mrt_base.model.decoder.step_f(
                _wd, _ws, sampler=make_sampler(1.1, 50, gen),
                temporal_step=mrt_base._temporal_step, depth_step=mrt_base._depth_step)
        print("[warmup] done", flush=True)
    except Exception as _e:
        print("[warmup]", repr(_e), flush=True)
    cur_name = model = dstate = source = None
    decode_state = {}
    emitted_samples = 0
    t0 = time.time()
    buf_log, buf_t = [], t0
    # Signatures of the inputs last used to build the style tokens and the
    # encoded conditioning; both are rebuilt only when a signature changes.
    cur_style_sig = cur_note_sig = cur_tokens = None
    prev_active = set()
    cur_reset = cur_seed = 0
    had_onsets = False
    txt_cache, aud_cache = {}, {}
    CHUNK = 3  # decoder steps generated per yielded audio chunk
    cur_bank_ver = 0
    while time.time() - t0 < 40.0:
        c = read_slot(session_id)
        if c is None:
            time.sleep(0.02)
            continue
        mname = c.get("model", "mrt2_base")
        reset = int(c.get("reset", 0))
        # A model switch or a changed reset counter restarts the stream state.
        if mname != cur_name or reset != cur_reset:
            cur_name, model, cur_reset = mname, MODELS.get(mname, mrt_base), reset
            dstate = model.model.decoder.init_streaming_f(1, dev, dt)
            decode_state = model.init_decode_state()
            emitted_samples, source, prev_active = 0, None, set()
            cur_style_sig = cur_note_sig = cur_tokens = None
            had_onsets = False
        seed = int(c.get("seed", 0))
        if seed != cur_seed:
            cur_seed = seed
            gen = torch.Generator(device=dev).manual_seed(seed)
        # Bank operations are versioned; each new "ver" is executed once.
        bop = c.get("bank_op")
        if bop and int(bop.get("ver", 0)) != cur_bank_ver:
            cur_bank_ver = int(bop.get("ver", 0))
            bpath = os.path.join(SESSION_DIR, f"{os.path.basename(session_id)}_bank{int(bop.get('idx', 0))}.pt")
            try:
                if bop.get("action") == "save":
                    torch.save({"dstate": dstate, "decode_state": decode_state, "emitted": emitted_samples}, bpath)
                elif bop.get("action") == "load" and os.path.exists(bpath):
                    # NOTE(review): torch.load unpickles the file; within this
                    # file only the "save" branch above writes bank files.
                    d = torch.load(bpath, map_location=dev)
                    dstate, decode_state, emitted_samples = d["dstate"], d["decode_state"], int(d["emitted"])
                    # Force the conditioning to be rebuilt for the restored state.
                    source, cur_note_sig = None, None
            except Exception as e:
                print("[bank] error:", repr(e), flush=True)
        toks = []
        gen_t = time.time()
        for _ in range(CHUNK):
            # Re-read controls per step; keep the previous ones if the read fails.
            c = read_slot(session_id) or c
            prompts = c.get("prompts") or ["instrumental music"]
            weights = c.get("weights") or [1.0] * len(prompts)
            # Pad so every prompt index has an "is audio" flag.
            aflags = (c.get("audio") or []) + [False] * len(prompts)
            active = set(c.get("notes") or [])
            unmask = int(c.get("unmaskwidth", 0))
            onsetmode = bool(c.get("onsetmode", False))
            drumless = bool(c.get("drumless", False))

            style_sig = (tuple(prompts), tuple(round(float(w), 4) for w in weights),
                         tuple(bool(a) for a in aflags[:len(prompts)]))
            if cur_tokens is None or style_sig != cur_style_sig:
                cur_style_sig = style_sig
                # Weighted blend of per-prompt embeddings (text or audio clip).
                emb = torch.zeros(style_model.embedding_dim, device=dev, dtype=torch.float32)
                # NOTE: the normaliser sums every positive weight, including
                # those of prompts skipped below.
                tot = float(sum(w for w in weights if w > 0)) or 1.0
                for i in range(len(prompts)):
                    w = float(weights[i]) if i < len(weights) else 0.0
                    if w <= 0:
                        continue
                    if i < len(aflags) and aflags[i]:
                        ap = os.path.join(SESSION_DIR, f"{os.path.basename(session_id)}_aud{i}.bin")
                        try:
                            mt = os.path.getmtime(ap)
                        except OSError:
                            continue
                        # Audio embeddings are cached per slot, keyed by mtime.
                        ce = aud_cache.get(i)
                        if ce is None or ce[0] != mt:
                            # Close the clip file deterministically instead of
                            # relying on garbage collection.
                            with open(ap, "rb") as fh:
                                data = fh.read()
                            # Layout written by set_audio: 4-byte LE sample
                            # rate, then the raw sample bytes.
                            asr = int.from_bytes(data[:4], "little")
                            samp = np.frombuffer(data[4:], dtype=np.float32)
                            ce = (mt, style_model.embed_audio(samp, asr).float())
                            aud_cache[i] = ce
                        e = ce[1]
                    else:
                        lbl = (prompts[i] or "").strip()
                        if not lbl:
                            continue
                        e = txt_cache.get(lbl)
                        if e is None:
                            e = style_model.embed(lbl).float()
                            txt_cache[lbl] = e
                    emb = emb + (w / tot) * e
                cur_tokens = style_model.tokenize(emb).tolist()

            note_sig = (tuple(sorted(active)), unmask, onsetmode, drumless,
                        round(float(c.get("cfg_musiccoca", 1.6)), 3), round(float(c.get("cfg_notes", 2.4)), 3),
                        round(float(c.get("cfg_drums", 4.0)), 3), cur_style_sig)
            # Rebuild after an onset step too, so the onset marker (2) is
            # replaced by the held marker (1) on the following step.
            if source is None or note_sig != cur_note_sig or had_onsets:
                onsets = active - prev_active
                prev_active = set(active)
                had_onsets = onsetmode and bool(onsets)
                cur_note_sig = note_sig
                # Per-pitch conditioning value: active pitches get 3 (or 2/1
                # in onset mode), pitches within `unmask` of an active pitch
                # get 0, everything else -1.
                nvec = []
                for pitch in range(128):
                    if pitch in active:
                        nvec.append((2 if pitch in onsets else 1) if onsetmode else 3)
                    elif unmask >= 127 or (active and any(abs(pitch - h) <= unmask for h in active)):
                        nvec.append(0)
                    else:
                        nvec.append(-1)
                drm = [0] if drumless else [-1]
                cfg_notes_v = 7.0 if (unmask >= 127 and not active) else c.get("cfg_notes", 2.4)
                cfgs = [discretize_cfg(c.get("cfg_musiccoca", 1.6), 0.2, 40),
                        discretize_cfg(cfg_notes_v, 0.2, 40),
                        discretize_cfg(c.get("cfg_drums", 4.0), 1.0, 8)]
                # Pad/truncate the style tokens to exactly num_musiccoca entries.
                cond = model._conditioning((list(cur_tokens) + [-1] * model.num_musiccoca)[:model.num_musiccoca],
                                           nvec, drm, cfgs)
                source = model.model.encode(cond).to(dt)
            sampler = make_sampler(c.get("temperature", 1.1), c.get("top_k", 50), gen)
            toks.append(model.model.decoder.step_f(dstate, source, sampler=sampler,
                                                   temporal_step=model._temporal_step, depth_step=model._depth_step))
        new_codes = torch.cat(toks, dim=1)
        audio = model.decode_stream(new_codes, decode_state)
        emitted_samples += audio.shape[1]
        frame_ms = (time.time() - gen_t) * 1000.0 / CHUNK
        if audio.shape[1] > 0:
            ab = _float_to_int16(audio[0].float().cpu().numpy()).astype("<i2").tobytes()
            yield (ab, frame_ms)
        # Pacing: keep generated audio roughly `buf` seconds ahead of wall time.
        buf = (0.15, 0.3, 0.5)[max(0, min(2, int(c.get("buffer", 0))))]
        ahead = (emitted_samples / SR) - (time.time() - t0)
        if ahead < -0.3:
            # Fell too far behind: re-anchor the clock rather than catching up.
            t0 = time.time() - (emitted_samples / SR) + buf
            ahead = buf
        buf_log.append(ahead)
        if time.time() - buf_t >= 15.0 and buf_log:
            print(f"[buffer] avg={sum(buf_log)/len(buf_log):.2f}s min={min(buf_log):.2f} "
                  f"max={max(buf_log):.2f} target={buf:.2f}s n={len(buf_log)}", flush=True)
            buf_log, buf_t = [], time.time()
        sleep_for = max(0.0, ahead - buf)
        # The client lead is compared against buf * 1000, i.e. treated as
        # milliseconds; a large surplus adds a capped extra sleep.
        cl = read_client_lead(session_id)
        if cl is not None and cl - buf * 1000.0 > 200:
            sleep_for = max(sleep_for, min((cl - buf * 1000.0) / 1000.0 * 0.3, 0.2))
        if sleep_for > 0:
            time.sleep(min(sleep_for, 0.5))
|
|
|
|
# Live sessions: session_id -> {"audio": queue of (pcm_bytes, frame_ms),
# "stop": threading.Event}. Entries are added by start() and removed by
# start() (on restart) and audio_ws() (on disconnect).
_sessions = {}
# Guards mutations of _sessions.
_slock = threading.Lock()
|
|
|
|
def _worker(session_id, audio_q, stop_ev):
    """Pump chunks from ``gpu_stream`` into ``audio_q`` until stopped.

    ``gpu_stream`` is restarted whenever it finishes, and also when it raises
    an exception whose message contains "abort" or "duration". Any other
    exception is logged and ends the worker. When the queue is full the
    oldest chunk is dropped to make room. ``stop_ev`` is always set on exit.
    """

    def offer(item):
        # Drop the oldest entry first so the newest audio is never blocked.
        if audio_q.full():
            try:
                audio_q.get_nowait()
            except _q.Empty:
                pass
        try:
            audio_q.put_nowait(item)
        except _q.Full:
            pass

    try:
        while not stop_ev.is_set():
            try:
                for pcm, frame_ms in gpu_stream(session_id):
                    if stop_ev.is_set():
                        break
                    offer((pcm, frame_ms))
            except Exception as exc:
                message = str(exc).lower()
                if "abort" in message or "duration" in message:
                    continue
                print("[worker]", repr(exc), flush=True)
                break
    finally:
        stop_ev.set()
|
|
|
|
@app.api(name="start")
def start(session_id: str = "", request: gr.Request = None) -> str:
    """Start (or restart) the background generation worker for a session.

    Any worker already registered under ``session_id`` is signalled to stop
    and replaced. The worker thread runs inside a copy of the current
    context with the originating request installed in ``LocalContext``.

    Returns:
        ``session_id`` on success, or "" when no session id was supplied.
    """
    if not session_id:
        return ""

    origin_request = request or LocalContext.request.get(None)

    with _slock:
        previous = _sessions.pop(session_id, None)
    if previous:
        previous["stop"].set()

    audio_q = _q.Queue(maxsize=48)
    stop_ev = threading.Event()

    def run():
        # Make the request visible inside the worker thread's context.
        if origin_request is not None:
            LocalContext.request.set(origin_request)
        _worker(session_id, audio_q, stop_ev)

    context = contextvars.copy_context()
    thread = threading.Thread(target=context.run, args=(run,), daemon=True)

    entry = {"audio": audio_q, "stop": stop_ev}
    with _slock:
        _sessions[session_id] = entry
    thread.start()
    return session_id
|
|
|
|
@app.websocket("/ws")
async def audio_ws(websocket: WebSocket, session_id: str = ""):
    """Websocket endpoint: stream audio out and accept control messages in.

    The session must already have been created by ``start``; otherwise the
    socket is closed with code 1008. Outgoing binary frames are a
    little-endian float32 (ms per generated step) followed by int16 PCM
    bytes. Incoming JSON messages of type "set" and "buffer" are applied via
    ``_apply_set`` / ``_apply_buffer``. When either direction ends, the
    session's stop event is set.
    """
    await websocket.accept()
    sess = _sessions.get(session_id)
    if not sess:
        await websocket.close(code=1008)
        return
    audio_q, stop_ev = sess["audio"], sess["stop"]
    # get_running_loop() is the supported way to obtain the loop in a coroutine.
    loop = asyncio.get_running_loop()

    async def send_audio():
        try:
            while not stop_ev.is_set():
                try:
                    # Poll with a timeout so stop_ev is re-checked regularly.
                    ab, fm = await loop.run_in_executor(None, lambda: audio_q.get(timeout=0.5))
                except _q.Empty:
                    # Only the poll timeout is expected here; other errors
                    # propagate and end the stream rather than spinning.
                    continue
                await websocket.send_bytes(struct.pack("<f", fm) + ab)
            try:
                await websocket.send_json({"type": "ended"})
            except Exception:
                pass
        finally:
            stop_ev.set()

    async def recv_control():
        try:
            while not stop_ev.is_set():
                msg = await websocket.receive_json()
                t = msg.get("type")
                if t == "set":
                    try:
                        _apply_set(msg)
                    except Exception as e:
                        print("[ws set]", repr(e), flush=True)
                elif t == "buffer":
                    try:
                        _apply_buffer(msg.get("session_id", session_id), msg.get("lead", 0))
                    except Exception:
                        pass
        finally:
            stop_ev.set()

    try:
        await asyncio.gather(send_audio(), recv_control(), return_exceptions=True)
    finally:
        stop_ev.set()
        with _slock:
            # Remove only our own entry: start() may have registered a newer
            # session under the same id while this socket was still open.
            if _sessions.get(session_id) is sess:
                del _sessions[session_id]
|
|
|
|
@app.get("/")
async def index():
    """GET /: serve the ``index.html`` that sits next to this source file."""
    here = os.path.dirname(os.path.abspath(__file__))
    page_path = os.path.join(here, "index.html")
    with open(page_path) as f:
        html = f.read()
    return HTMLResponse(html)
|
|
|
|
# Start the server at import time; SSR is disabled here as well as through the
# GRADIO_SSR_MODE environment variable set at the top of the file.
app.launch(show_error=True, ssr_mode=False)
|
|