# NOTE: captured from a Hugging Face Space file view (18,830 bytes; Space status at capture time: "Runtime error"); the per-line blame hashes and line-number gutter of the page were dropped.
import os
# Process-wide environment tweaks; they are placed ahead of the torch / gradio
# imports below (presumably so those libraries see them at import time — TODO confirm).
# setdefault keeps any allocator policy that the deployment already exported.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
os.environ["GRADIO_SSR_MODE"] = "false" # serve via Python (FastAPI) directly so /ws works
import json
import time
import uuid
import base64
import numpy as np
import spaces
import torch
from gradio import Server
from fastapi import Request
from fastapi.responses import HTMLResponse
from huggingface_hub import hf_hub_download
import asyncio, threading, queue as _q, contextvars, struct
import gradio as gr
from gradio.context import LocalContext
from fastapi import WebSocket, WebSocketDisconnect
from magenta_rt import paths
# Use /data when that directory exists (presumably the Space's persistent
# volume — TODO confirm), otherwise fall back to a scratch directory under /tmp.
MAGENTA_HOME = "/data" if os.path.isdir("/data") else "/tmp/magenta"
home = os.path.join(MAGENTA_HOME, "magenta-rt-v2")
os.makedirs(home, exist_ok=True)
# Download both checkpoint files into `home` (under a checkpoints/ sub-path of the repo).
for ck in ("mrt2_small.safetensors", "mrt2_base.safetensors"):
    hf_hub_download("google/magenta-realtime-2", f"checkpoints/{ck}", local_dir=home)
# Register the directory with magenta_rt before the magenta_rt.torch imports that follow.
paths.set_magenta_home(home)
from magenta_rt.torch import MagentaRT2
from magenta_rt.torch.musiccoca import MusicCoCa
# Style encoder starts on CPU; gpu_stream moves it to CUDA on first use.
style_model = MusicCoCa(device="cpu")
# Both models fit on ZeroGPU — preload both; the dropdown switches between them.
mrt_small = MagentaRT2(size="mrt2_small", device="cuda", dtype=torch.bfloat16, style_model=style_model)
mrt_base = MagentaRT2(size="mrt2_base", device="cuda", dtype=torch.bfloat16, style_model=style_model)
# Lookup used by gpu_stream to resolve the "model" field of a session slot.
MODELS = {"mrt2_small": mrt_small, "mrt2_base": mrt_base}
# Try to load a pre-compiled (AOTI) build for each model; a failure is only
# logged, leaving that model on its eager path.
for name, repo in (("mrt2_base", "magenta-torch/magenta-rt-aoti"),
                   ("mrt2_small", "magenta-torch/magenta-rt-aoti-small")):
    try:
        MODELS[name].load_compiled(repo_id=repo)
        print(f"Loaded AOTI for {name} from {repo}")
    except Exception as e:
        print(f"AOTI load failed for {name} ({repo}) -> eager:", e)
# Sample rate (Hz) that gpu_stream uses to convert emitted sample counts into seconds.
SR = 48000
# Directory holding the per-session files exchanged between the HTTP/WS handlers
# and the GPU stream: <sid>.json, <sid>_buf.txt, <sid>_aud<slot>.bin, <sid>_bank<i>.pt.
SESSION_DIR = "/tmp/mrt_jam2_sessions"
os.makedirs(SESSION_DIR, exist_ok=True)
# Memoized text embeddings, keyed by the normalized label string.
_EMB_CACHE = {}


def _embed(label):
    """Return the cached float32 style embedding for *label*.

    The label is stripped; a missing or blank label falls back to
    "instrumental music". On a cache miss the embedding comes from
    ``style_model.embed`` and is stored as a float32 NumPy array.
    """
    key = (label or "").strip()
    if not key:
        key = "instrumental music"
    try:
        return _EMB_CACHE[key]
    except KeyError:
        vector = style_model.embed(key).cpu().numpy()
        cached = _EMB_CACHE[key] = np.asarray(vector, np.float32)
        return cached
def _slot(sid):
    """Return the path of the JSON slot file for *sid* inside SESSION_DIR.

    Only the basename of *sid* is used, so directory components are dropped.
    """
    filename = "{}.json".format(os.path.basename(sid))
    return os.path.join(SESSION_DIR, filename)
def write_slot(sid, d):
    """Atomically store dict *d* as the JSON slot of session *sid*.

    The JSON is written to a uniquely named temp file next to the slot and
    then moved over it with ``os.replace``, so a reader never opens a
    half-written slot. If serialization or the rename fails, the temp file is
    removed before the error is re-raised, so failures do not pile up stray
    files in the session directory.

    Raises whatever ``open``, ``json.dump`` or ``os.replace`` raise.
    """
    target = _slot(sid)
    tmp = target + "." + uuid.uuid4().hex
    try:
        with open(tmp, "w") as f:
            json.dump(d, f)
        os.replace(tmp, target)
    except BaseException:
        # Best-effort cleanup; the original error is what the caller needs to see.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
def read_slot(sid):
    """Load the JSON slot of session *sid*.

    Returns the decoded object, or None when the slot cannot be located,
    opened or parsed (any Exception is treated as "no slot").
    """
    try:
        slot_path = _slot(sid)
        with open(slot_path) as handle:
            payload = json.load(handle)
    except Exception:
        payload = None
    return payload
def read_client_lead(sid):
    """Return the number stored in ``<sid>_buf.txt`` as a float.

    Returns None when the file is missing, unreadable or not a number
    (any Exception is treated as "no value").
    """
    try:
        lead_name = "{}_buf.txt".format(os.path.basename(sid))
        with open(os.path.join(SESSION_DIR, lead_name)) as handle:
            text = handle.read()
        return float(text)
    except Exception:
        return None
def _apply_set(body):
    """Build the conditioning slot from *body* and write it for the session.

    Shared by the HTTP /set route and the WebSocket "set" message. *body* must
    contain "session_id"; every other field is optional and gets the default
    shown below. "model" and "bank_op" fall back to the values of the
    previously stored slot when the request omits them. A "ts" timestamp is
    added on every write.
    """
    session = body["session_id"]
    previous = read_slot(session) or {}
    field = body.get

    slot = {}
    slot["prompts"] = field("prompts") or ["instrumental music"]
    slot["weights"] = field("weights") or [1.0]
    slot["audio"] = field("audio") or []
    slot["temperature"] = float(field("temperature", 1.1))
    slot["top_k"] = int(field("top_k", 50))
    slot["cfg_musiccoca"] = float(field("cfg", 1.6))
    slot["cfg_notes"] = float(field("cfg_notes", 2.4))
    slot["cfg_drums"] = float(field("cfg_drums", 4.0))
    # Sticky field: an omitted value keeps the model from the previous slot.
    slot["model"] = field("model", previous.get("model", "mrt2_base"))
    slot["buffer"] = int(field("buffer", 0))
    slot["notes"] = list(field("notes", []))
    slot["unmaskwidth"] = int(field("unmaskwidth", 0))
    slot["drumless"] = bool(field("drumless", False))
    slot["onsetmode"] = bool(field("onsetmode", False))
    slot["reset"] = int(field("reset", 0))
    slot["seed"] = int(field("seed", 0))
    # Sticky field: an omitted value keeps the last bank operation.
    slot["bank_op"] = field("bank_op", previous.get("bank_op"))
    slot["ts"] = time.time()

    write_slot(session, slot)
def _apply_buffer(sid, lead):
    """Persist the client-reported queue lead for session *sid*.

    The value is stored as text in ``<sid>_buf.txt`` inside SESSION_DIR, where
    ``read_client_lead`` picks it up (the GPU stream treats it as milliseconds).

    *lead* is converted with ``float`` before any file is created, so an
    invalid value raises TypeError/ValueError without leaving a temp file
    behind. The write goes through a uniquely named temp file and
    ``os.replace`` (same scheme as ``write_slot``), so two writers never share
    a temp path and readers never see a partial file.
    """
    sid = os.path.basename(str(sid))
    target = os.path.join(SESSION_DIR, f"{sid}_buf.txt")
    value = str(float(lead))
    tmp = f"{target}.{uuid.uuid4().hex}.tmp"
    try:
        with open(tmp, "w") as f:
            f.write(value)
        os.replace(tmp, target)
    except BaseException:
        # Best-effort cleanup of the temp file; re-raise the original error.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
# Gradio Server instance; the HTTP, API and WebSocket handlers below register on it.
app = Server()
@app.post("/buffer")
async def set_clientbuffer(request: Request):
    """HTTP endpoint: record the client's reported queue lead.

    Expects a JSON body with "session_id" and an optional "lead" (default 0)
    and forwards both to ``_apply_buffer``.
    """
    payload = await request.json()
    session = payload["session_id"]
    lead = payload.get("lead", 0)
    _apply_buffer(session, lead)
    return dict(ok=True)
@app.post("/set")
async def set_collider(request: Request):
    """HTTP endpoint: store new conditioning parameters for a session.

    The JSON body is handed unchanged to ``_apply_set``.
    """
    payload = await request.json()
    _apply_set(payload)
    return dict(ok=True)
@app.post("/audio")
async def set_audio(request: Request):
    """Store an uploaded clip (native-rate mono float32) per (session, slot);
    the GPU stream embeds it with MusicCoCa.embed_audio (resampy -> 16kHz).

    JSON body: "session_id" (required), "slot" (default 0), "samples"
    (base64 of the raw sample bytes) and "sample_rate" (default 48000).
    An empty or missing "samples" deletes the stored clip for that slot.

    File layout: 4-byte little-endian sample rate followed by the raw bytes.

    The payload is validated before it is stored, because the GPU stream
    reads it back as float32 and a bad file would otherwise break generation:
    the decoded data must be non-empty with a length that is a multiple of 4,
    and the sample rate must be positive and fit in 32 bits. Invalid uploads
    return ``{"ok": False, "error": ...}`` and leave any existing clip alone.
    """
    body = await request.json()
    sid = os.path.basename(body["session_id"])
    slot = int(body.get("slot", 0))
    path = os.path.join(SESSION_DIR, f"{sid}_aud{slot}.bin")
    samples = body.get("samples")
    if not samples:
        try:
            os.remove(path)
        except OSError:
            pass
        return {"ok": True}
    try:
        raw = base64.b64decode(samples)
        sr = int(body.get("sample_rate", 48000))
    except (ValueError, TypeError):  # binascii.Error is a ValueError
        return {"ok": False, "error": "invalid samples or sample_rate"}
    if not raw or len(raw) % 4:
        return {"ok": False, "error": "samples must be whole float32 values"}
    if not 0 < sr < 2 ** 32:
        return {"ok": False, "error": "sample_rate out of range"}
    # Unique temp name + os.replace: readers only ever see a complete file.
    tmp = f"{path}.{uuid.uuid4().hex}.tmp"
    try:
        with open(tmp, "wb") as f:
            f.write(sr.to_bytes(4, "little"))
            f.write(raw)
        os.replace(tmp, path)
    except BaseException:
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
    return {"ok": True}
@app.get("/banks")
async def banks(session_id: str = ""):
    """HTTP endpoint: report which of the three state banks exist for a session.

    Returns ``{"bankStatus": [bool, bool, bool]}``, one flag per
    ``<sid>_bank<i>.pt`` file in SESSION_DIR.
    """
    sid = os.path.basename(session_id)
    status = []
    for index in range(3):
        bank_path = os.path.join(SESSION_DIR, f"{sid}_bank{index}.pt")
        status.append(os.path.exists(bank_path))
    return {"bankStatus": status}
@spaces.GPU(duration=90)
def gpu_stream(session_id):
    """Continuous gen; switches model live when the dropdown changes.

    Generator that runs inside one GPU grant. It waits up to 8 s for the
    session slot to appear, warms up the base model, then generates audio for
    about 55 s of wall-clock time, re-reading the session slot before every
    frame so prompt / note / cfg changes take effect immediately.

    Yields:
        ``(pcm_bytes, frame_ms)`` per chunk of CHUNK frames: ``pcm_bytes`` is
        the first channel of the decoded audio as little-endian int16 bytes,
        ``frame_ms`` the measured generation time per frame in milliseconds.

    Notes:
        * The run limit is measured from a dedicated start time. The pacing
          clock ``t0`` is moved forward on every resync (slow chunk, model
          switch, reset), so using it for the limit would let the loop outlive
          the GPU grant.
        * Uploaded clips are read with a context manager, and malformed files
          (no samples, or a byte count that is not a multiple of 4) are
          skipped instead of raising out of the stream.
    """
    from magenta_rt.torch.system import make_sampler, discretize_cfg, _float_to_int16, FRAME_SAMPLES
    if style_model.device != "cuda":
        style_model.to("cuda")
    dev, dt = "cuda", torch.bfloat16
    # Give the client a moment to post its first slot before generating.
    deadline = time.time() + 8.0
    while read_slot(session_id) is None and time.time() < deadline:
        time.sleep(0.05)
    gen = torch.Generator(device=dev).manual_seed(0)
    try:  # one-time GPU warm-up so the first audio frame isn't slow
        style_model.embed("warmup")
        _wd = mrt_base.model.decoder.init_streaming_f(1, dev, dt)
        _wc = mrt_base._conditioning([-1] * mrt_base.num_musiccoca, [-1] * 128, [-1],
                                     [discretize_cfg(1.6, 0.2, 40), discretize_cfg(2.4, 0.2, 40), discretize_cfg(4.0, 1.0, 8)])
        _ws = mrt_base.model.encode(_wc).to(dt)
        for _ in range(2):
            mrt_base.model.decoder.step_f(_wd, _ws, sampler=make_sampler(1.1, 50, gen),
                                          temporal_step=mrt_base._temporal_step, depth_step=mrt_base._depth_step)
        print("[warmup] done", flush=True)
    except Exception as _e:
        print("[warmup]", repr(_e), flush=True)
    cur_name = model = dstate = source = None
    decode_state = {}
    emitted_samples = 0
    run_start = time.time()  # fixed reference for the ~55 s run limit
    t0 = run_start           # pacing clock; shifted on resync below
    buf_log, buf_t = [], t0
    cur_style_sig = cur_note_sig = cur_tokens = None
    prev_active = set()
    cur_reset = cur_seed = 0
    had_onsets = False
    txt_cache, aud_cache = {}, {}
    CHUNK = 3  # frames per yield (smaller = finer delivery, lower buffer floor)
    cur_bank_ver = 0
    while time.time() - run_start < 55.0:
        c = read_slot(session_id)
        if c is None:
            time.sleep(0.02)
            continue
        mname = c.get("model", "mrt2_base")
        reset = int(c.get("reset", 0))
        if mname != cur_name or reset != cur_reset:  # model switch / state reset -> re-init
            cur_name, model, cur_reset = mname, MODELS.get(mname, mrt_base), reset
            dstate = model.model.decoder.init_streaming_f(1, dev, dt)
            decode_state = model.init_decode_state()
            emitted_samples, source, prev_active = 0, None, set()
            cur_style_sig = cur_note_sig = cur_tokens = None
            had_onsets = False
        seed = int(c.get("seed", 0))
        if seed != cur_seed:
            cur_seed = seed
            gen = torch.Generator(device=dev).manual_seed(seed)
        bop = c.get("bank_op")
        if bop and int(bop.get("ver", 0)) != cur_bank_ver:  # save/recall generation state
            cur_bank_ver = int(bop.get("ver", 0))
            bpath = os.path.join(SESSION_DIR, f"{os.path.basename(session_id)}_bank{int(bop.get('idx', 0))}.pt")
            try:
                if bop.get("action") == "save":
                    torch.save({"dstate": dstate, "decode_state": decode_state, "emitted": emitted_samples}, bpath)
                elif bop.get("action") == "load" and os.path.exists(bpath):
                    # NOTE(review): torch.load unpickles the file. Bank files are only
                    # written by the torch.save call above; SESSION_DIR must stay
                    # writable by this process alone.
                    d = torch.load(bpath, map_location=dev)
                    dstate, decode_state, emitted_samples = d["dstate"], d["decode_state"], int(d["emitted"])
                    source, cur_note_sig = None, None
            except Exception as e:
                print("[bank] error:", repr(e), flush=True)
        toks = []
        gen_t = time.time()
        for _ in range(CHUNK):  # per-frame conditioning (native parity: re-derive every 40ms frame, no debounce)
            c = read_slot(session_id) or c
            prompts = c.get("prompts") or ["instrumental music"]
            weights = c.get("weights") or [1.0] * len(prompts)
            aflags = (c.get("audio") or []) + [False] * len(prompts)
            active = set(c.get("notes") or [])
            unmask = int(c.get("unmaskwidth", 0))
            onsetmode = bool(c.get("onsetmode", False))
            drumless = bool(c.get("drumless", False))
            # ---- style tokens on GPU: embed+blend+tokenize, cached until prompts/weights/audio change ----
            style_sig = (tuple(prompts), tuple(round(float(w), 4) for w in weights),
                         tuple(bool(a) for a in aflags[:len(prompts)]))
            if cur_tokens is None or style_sig != cur_style_sig:
                cur_style_sig = style_sig
                emb = torch.zeros(style_model.embedding_dim, device=dev, dtype=torch.float32)
                tot = float(sum(w for w in weights if w > 0)) or 1.0
                for i in range(len(prompts)):
                    w = float(weights[i]) if i < len(weights) else 0.0
                    if w <= 0:
                        continue
                    if i < len(aflags) and aflags[i]:
                        ap = os.path.join(SESSION_DIR, f"{os.path.basename(session_id)}_aud{i}.bin")
                        try:
                            mt = os.path.getmtime(ap)
                        except OSError:
                            continue
                        ce = aud_cache.get(i)
                        if ce is None or ce[0] != mt:
                            with open(ap, "rb") as clip_file:
                                data = clip_file.read()
                            # Layout: 4-byte little-endian sample rate + float32 samples.
                            if len(data) < 8 or (len(data) - 4) % 4:
                                print("[audio] malformed clip skipped:", ap, flush=True)
                                continue
                            asr = int.from_bytes(data[:4], "little")
                            samp = np.frombuffer(data[4:], dtype=np.float32)
                            ce = (mt, style_model.embed_audio(samp, asr).float())
                            aud_cache[i] = ce
                        e = ce[1]
                    else:
                        lbl = (prompts[i] or "").strip()
                        if not lbl:
                            continue
                        e = txt_cache.get(lbl)
                        if e is None:
                            e = style_model.embed(lbl).float()
                            txt_cache[lbl] = e
                    emb = emb + (w / tot) * e
                cur_tokens = style_model.tokenize(emb).tolist()
            # ---- notes + cfg -> re-encode on ANY change every frame (no debounce); onset(2)->sustain(1) via had_onsets ----
            note_sig = (tuple(sorted(active)), unmask, onsetmode, drumless,
                        round(float(c.get("cfg_musiccoca", 1.6)), 3), round(float(c.get("cfg_notes", 2.4)), 3),
                        round(float(c.get("cfg_drums", 4.0)), 3), cur_style_sig)
            if source is None or note_sig != cur_note_sig or had_onsets:
                onsets = active - prev_active
                prev_active = set(active)
                had_onsets = onsetmode and bool(onsets)  # carry one frame so the onset settles to sustain next frame
                cur_note_sig = note_sig
                nvec = []  # per-pitch: -1 masked / 0 off / 1 cont / 2 onset / 3 on
                for pitch in range(128):
                    if pitch in active:
                        nvec.append((2 if pitch in onsets else 1) if onsetmode else 3)
                    elif unmask >= 127 or (active and any(abs(pitch - h) <= unmask for h in active)):
                        nvec.append(0)  # solo (>=127) or within unmask window -> off
                    else:
                        nvec.append(-1)  # masked (model free)
                drm = [0] if drumless else [-1]  # 0 no-drum / -1 masked
                cfg_notes_v = 7.0 if (unmask >= 127 and not active) else c.get("cfg_notes", 2.4)  # solo + no note held -> ramp cfg_notes to max bin (native JamApp suppresses melody)
                cfgs = [discretize_cfg(c.get("cfg_musiccoca", 1.6), 0.2, 40),
                        discretize_cfg(cfg_notes_v, 0.2, 40),
                        discretize_cfg(c.get("cfg_drums", 4.0), 1.0, 8)]
                cond = model._conditioning((list(cur_tokens) + [-1] * model.num_musiccoca)[:model.num_musiccoca],
                                           nvec, drm, cfgs)
                source = model.model.encode(cond).to(dt)
            sampler = make_sampler(c.get("temperature", 1.1), c.get("top_k", 50), gen)
            toks.append(model.model.decoder.step_f(dstate, source, sampler=sampler,
                                                   temporal_step=model._temporal_step, depth_step=model._depth_step))
        new_codes = torch.cat(toks, dim=1)
        audio = model.decode_stream(new_codes, decode_state)  # FLOP-optimal stateful streaming decode
        emitted_samples += audio.shape[1]
        frame_ms = (time.time() - gen_t) * 1000.0 / CHUNK  # real per-frame inference time
        if audio.shape[1] > 0:
            ab = _float_to_int16(audio[0].float().cpu().numpy()).astype("<i2").tobytes()
            yield (ab, frame_ms)
        # ---- pacing: stay `buf` seconds of audio ahead of wall-clock time ----
        buf = (0.15, 0.3, 0.5)[max(0, min(2, int(c.get("buffer", 0))))]
        ahead = (emitted_samples / SR) - (time.time() - t0)
        if ahead < -0.3:  # fell behind (cold-start/hiccup) -> resync, no catch-up burst
            t0 = time.time() - (emitted_samples / SR) + buf
            ahead = buf
        buf_log.append(ahead)
        if time.time() - buf_t >= 15.0 and buf_log:
            print(f"[buffer] avg={sum(buf_log)/len(buf_log):.2f}s min={min(buf_log):.2f} "
                  f"max={max(buf_log):.2f} target={buf:.2f}s n={len(buf_log)}", flush=True)
            buf_log, buf_t = [], time.time()
        sleep_for = max(0.0, ahead - buf)
        cl = read_client_lead(session_id)  # client's real Web-Audio queue (ms)
        if cl is not None and cl - buf * 1000.0 > 200:  # only drain a genuinely-deep queue (target+200ms)
            sleep_for = max(sleep_for, min((cl - buf * 1000.0) / 1000.0 * 0.3, 0.2))  # gentle + capped
        if sleep_for > 0:
            time.sleep(min(sleep_for, 0.5))
# Live sessions: session_id -> {"audio": queue of (pcm_bytes, frame_ms) items, "stop": threading.Event}.
_sessions = {}
# Guards mutations of _sessions (start() runs alongside the WebSocket handlers).
_slock = threading.Lock()
def _worker(session_id, audio_q, stop_ev):
    """Thread body: pump frames from ``gpu_stream`` into *audio_q*.

    ``gpu_stream`` is restarted each time it finishes until *stop_ev* is set.
    When the queue is full the oldest frame is dropped to make room. An
    exception whose text contains "abort" or "duration" is taken to mean the
    GPU grant expired and triggers another run; any other exception is logged
    and ends the worker. *stop_ev* is always set on exit.
    """

    def _push(item):
        # Drop-oldest policy: never block the generator on a slow consumer.
        if audio_q.full():
            try:
                audio_q.get_nowait()
            except _q.Empty:
                pass
        try:
            audio_q.put_nowait(item)
        except _q.Full:
            pass

    try:
        while not stop_ev.is_set():
            try:
                for ab, fm in gpu_stream(session_id):  # ~55s GPU grant; re-grants on loop
                    if stop_ev.is_set():
                        break
                    _push((ab, fm))
            except Exception as e:
                reason = str(e).lower()
                if "abort" in reason or "duration" in reason:
                    continue  # grant expired -> re-grant
                print("[worker]", repr(e), flush=True)
                break
    finally:
        stop_ev.set()
@app.api(name="start")
def start(session_id: str = "", request: gr.Request = None) -> str:
    """Start (or restart) the generation worker for *session_id*.

    Creates a fresh audio queue and stop event, registers them in
    ``_sessions`` and launches a daemon thread running ``_worker``. Any
    previously registered session with the same id is signalled to stop.

    The registry entry is swapped inside a single critical section, so two
    concurrent calls for the same id can never both miss the other's entry
    and leave a worker running that nobody can stop.

    Returns:
        *session_id*, or "" when no session id was given.
    """
    if not session_id:
        return ""
    req = request or LocalContext.request.get(None)  # user's gr.Request -> carries X-IP-Token (ZeroGPU quota)
    audio_q = _q.Queue(maxsize=48)
    stop_ev = threading.Event()

    def run():
        if req is not None:
            LocalContext.request.set(req)  # graft user request so worker @spaces.GPU bills the USER
        _worker(session_id, audio_q, stop_ev)

    ctx = contextvars.copy_context()
    t = threading.Thread(target=ctx.run, args=(run,), daemon=True)
    with _slock:
        # Atomic swap: read the old entry and publish the new one together.
        prev = _sessions.get(session_id)
        _sessions[session_id] = {"audio": audio_q, "stop": stop_ev}
    if prev:
        prev["stop"].set()
    t.start()
    return session_id
@app.websocket("/ws")
async def audio_ws(websocket: WebSocket, session_id: str = ""):
    """WebSocket endpoint: stream audio to the client and accept live control.

    The session must already have been created by ``start``; otherwise the
    socket is closed with code 1008. Two tasks run side by side:

    * ``send_audio`` forwards queued frames as binary messages (4-byte
      little-endian float ``frame_ms`` followed by the PCM bytes) and sends
      ``{"type": "ended"}`` once the worker has stopped.
    * ``recv_control`` applies "set" and "buffer" JSON messages.

    On exit the stop event is set and the session is removed from
    ``_sessions`` only if the registry still holds this connection's session.
    A restart via ``start`` may already have registered a newer one under the
    same id, and that entry must survive.
    """
    await websocket.accept()
    sess = _sessions.get(session_id)
    if not sess:
        await websocket.close(code=1008)
        return
    audio_q, stop_ev = sess["audio"], sess["stop"]
    loop = asyncio.get_running_loop()

    async def send_audio():  # server -> client: binary audio frames
        try:
            while not stop_ev.is_set():
                try:
                    # Blocking get runs in the default executor; the timeout lets
                    # the loop re-check stop_ev twice a second.
                    ab, fm = await loop.run_in_executor(None, lambda: audio_q.get(timeout=0.5))
                except _q.Empty:
                    continue
                await websocket.send_bytes(struct.pack("<f", fm) + ab)
            try:
                await websocket.send_json({"type": "ended"})  # worker ended (quota/grant)
            except Exception:
                pass
        finally:
            stop_ev.set()

    async def recv_control():  # client -> server: low-latency steering (notes/params), no HTTP RTT
        try:
            while not stop_ev.is_set():
                msg = await websocket.receive_json()  # raises WebSocketDisconnect on close
                t = msg.get("type")
                if t == "set":
                    try:
                        _apply_set(msg)
                    except Exception as e:
                        print("[ws set]", repr(e), flush=True)
                elif t == "buffer":
                    try:
                        _apply_buffer(msg.get("session_id", session_id), msg.get("lead", 0))
                    except Exception:
                        pass
        finally:
            stop_ev.set()

    try:
        await asyncio.gather(send_audio(), recv_control(), return_exceptions=True)
    finally:
        stop_ev.set()
        with _slock:
            # Identity check: don't evict a session that replaced ours.
            if _sessions.get(session_id) is sess:
                _sessions.pop(session_id, None)
@app.get("/")
async def index():
    """Serve the ``index.html`` that sits next to this module as the front page."""
    here = os.path.dirname(os.path.abspath(__file__))
    page = os.path.join(here, "index.html")
    with open(page) as handle:
        html = handle.read()
    return HTMLResponse(html)
# Start the server; SSR stays disabled here as well, matching GRADIO_SSR_MODE above.
app.launch(show_error=True, ssr_mode=False)
|