File size: 18,830 Bytes
910b66a
 
53f14f2
910b66a
 
 
 
 
 
 
 
 
 
 
7e2ebf9
b7ce408
 
7e2ebf9
910b66a
 
 
 
 
 
25044e6
 
910b66a
 
 
 
 
25044e6
 
 
 
 
 
 
 
 
 
 
 
910b66a
 
25044e6
910b66a
25044e6
 
 
 
 
 
 
 
910b66a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52b09db
 
 
 
 
 
 
 
976d4cf
910b66a
25044e6
910b66a
9354f9b
 
 
af88b54
 
 
 
 
25044e6
ed36257
8cc4a31
 
 
 
 
 
4c7e7bd
7d8b7bc
910b66a
976d4cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9354f9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
910b66a
 
4c7e7bd
 
 
 
 
 
25044e6
7e2ebf9
25044e6
 
910b66a
 
25044e6
910b66a
 
 
25044e6
82e98f6
 
 
 
 
 
 
 
 
 
 
 
25044e6
794e4b7
 
 
25044e6
15da1e9
319bbdf
8cc4a31
 
319bbdf
9354f9b
794e4b7
4c7e7bd
25044e6
 
 
 
 
 
8cc4a31
 
 
25044e6
794e4b7
319bbdf
 
8cc4a31
 
 
 
4c7e7bd
dfd4eb6
4c7e7bd
 
 
 
794e4b7
4c7e7bd
 
794e4b7
319bbdf
4c7e7bd
 
c50ec43
0cfac40
319bbdf
c50ec43
9354f9b
 
 
8cc4a31
 
 
 
319bbdf
 
 
 
 
9354f9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319bbdf
 
 
 
 
 
 
 
 
 
 
8cc4a31
 
 
d54323c
319bbdf
8cc4a31
319bbdf
 
33d4d65
8cc4a31
33d4d65
af88b54
319bbdf
8cc4a31
c50ec43
dfd4eb6
 
 
794e4b7
 
 
a0edc57
25044e6
7e2ebf9
 
0b202ed
794e4b7
adc3f9e
794e4b7
adc3f9e
15da1e9
 
 
 
 
52b09db
 
82e98f6
 
52b09db
 
03884b8
 
7e2ebf9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7ce408
7e2ebf9
b7ce408
 
7e2ebf9
 
 
 
 
b7ce408
 
 
 
7e2ebf9
b7ce408
7e2ebf9
 
 
b7ce408
7e2ebf9
 
 
 
 
 
 
 
 
 
976d4cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e2ebf9
976d4cf
7e2ebf9
 
 
 
 
 
5781d9a
 
ed36257
 
5781d9a
 
53f14f2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
os.environ["GRADIO_SSR_MODE"] = "false"   # serve via Python (FastAPI) directly so /ws works
import json
import time
import uuid
import base64
import numpy as np
import spaces
import torch
from gradio import Server
from fastapi import Request
from fastapi.responses import HTMLResponse
from huggingface_hub import hf_hub_download
import asyncio, threading, queue as _q, contextvars, struct
import gradio as gr
from gradio.context import LocalContext
from fastapi import WebSocket, WebSocketDisconnect

from magenta_rt import paths

# Pick the storage root: use /data when that directory exists, otherwise fall
# back to a scratch directory under /tmp.
MAGENTA_HOME = "/data" if os.path.isdir("/data") else "/tmp/magenta"
home = os.path.join(MAGENTA_HOME, "magenta-rt-v2")
os.makedirs(home, exist_ok=True)
# Fetch both checkpoints into `home` at import time, before magenta_rt is
# pointed at that directory via set_magenta_home().
for ck in ("mrt2_small.safetensors", "mrt2_base.safetensors"):
    hf_hub_download("google/magenta-realtime-2", f"checkpoints/{ck}", local_dir=home)
paths.set_magenta_home(home)

from magenta_rt.torch import MagentaRT2
from magenta_rt.torch.musiccoca import MusicCoCa

# The style model is created on CPU here; gpu_stream() moves it to CUDA on first use.
style_model = MusicCoCa(device="cpu")
# Both models fit on ZeroGPU — preload both; the dropdown switches between them.
mrt_small = MagentaRT2(size="mrt2_small", device="cuda", dtype=torch.bfloat16, style_model=style_model)
mrt_base = MagentaRT2(size="mrt2_base", device="cuda", dtype=torch.bfloat16, style_model=style_model)
# Lookup table keyed by the "model" field of the per-session conditioning slot.
MODELS = {"mrt2_small": mrt_small, "mrt2_base": mrt_base}
# Try to load a precompiled (AOTI) artifact for each model; on any failure the
# error is printed and the model is left as constructed above.
for name, repo in (("mrt2_base", "magenta-torch/magenta-rt-aoti"),
                   ("mrt2_small", "magenta-torch/magenta-rt-aoti-small")):
    try:
        MODELS[name].load_compiled(repo_id=repo)
        print(f"Loaded AOTI for {name} from {repo}")
    except Exception as e:
        print(f"AOTI load failed for {name} ({repo}) -> eager:", e)

# Sample rate used by gpu_stream() to convert emitted sample counts into seconds.
SR = 48000
# Directory holding per-session files: "<sid>.json" conditioning slots,
# "<sid>_buf.txt" client lead, "<sid>_aud<N>.bin" uploads, "<sid>_bank<N>.pt" banks.
SESSION_DIR = "/tmp/mrt_jam2_sessions"
os.makedirs(SESSION_DIR, exist_ok=True)
# label -> float32 numpy embedding, filled lazily by _embed().
_EMB_CACHE = {}


def _embed(label):
    """Return the cached float32 embedding of a text label from ``style_model``.

    A missing or blank label falls back to ``"instrumental music"``. Each
    distinct label is embedded once and memoised in ``_EMB_CACHE``.
    """
    key = (label or "").strip()
    if not key:
        key = "instrumental music"
    if key in _EMB_CACHE:
        return _EMB_CACHE[key]
    vector = style_model.embed(key).cpu().numpy()
    _EMB_CACHE[key] = np.asarray(vector, np.float32)
    return _EMB_CACHE[key]


def _slot(sid):
    """Return the path of the JSON conditioning slot for session *sid*.

    Only the final path component of *sid* is used, so the file always lives
    directly inside ``SESSION_DIR``.
    """
    filename = "{}.json".format(os.path.basename(sid))
    return os.path.join(SESSION_DIR, filename)


def write_slot(sid, d):
    """Atomically replace the session's conditioning slot with *d* (as JSON).

    The payload is written to a uniquely named sibling file first and then
    moved over the real slot with ``os.replace``.
    """
    target = _slot(sid)
    scratch = ".".join((target, uuid.uuid4().hex))
    with open(scratch, "w") as out:
        json.dump(d, out)
    os.replace(scratch, target)


def read_slot(sid):
    """Load the session's conditioning slot, or return None if that fails.

    Any error (missing file, invalid JSON, bad *sid*) is treated as
    "no slot available" rather than raised.
    """
    try:
        with open(_slot(sid)) as handle:
            payload = json.load(handle)
    except Exception:
        return None
    return payload


def read_client_lead(sid):
    """Return the client-reported lead stored in ``<sid>_buf.txt`` as a float.

    Returns None when the file is missing, unreadable, or not a number.
    """
    try:
        filename = "{}_buf.txt".format(os.path.basename(sid))
        with open(os.path.join(SESSION_DIR, filename)) as handle:
            text = handle.read()
        return float(text)
    except Exception:
        return None


def _apply_set(body):
    """Build the conditioning slot from a control message and write it.

    Shared by the HTTP ``/set`` route and the WebSocket ``set`` message.
    ``model`` and ``bank_op`` fall back to the previously stored slot when the
    message omits them; every other field falls back to a fixed default.

    Raises:
        KeyError: if *body* has no ``"session_id"``.
        TypeError/ValueError: if a numeric field cannot be converted.
    """
    sid = body["session_id"]
    previous = read_slot(sid) or {}

    # Fields are inserted in the same order as they are stored on disk.
    slot = {
        "prompts": body.get("prompts") or ["instrumental music"],
        "weights": body.get("weights") or [1.0],
        "audio": body.get("audio") or [],
    }
    # (stored key, incoming key, converter, default)
    sampling_fields = (
        ("temperature", "temperature", float, 1.1),
        ("top_k", "top_k", int, 50),
        ("cfg_musiccoca", "cfg", float, 1.6),
        ("cfg_notes", "cfg_notes", float, 2.4),
        ("cfg_drums", "cfg_drums", float, 4.0),
    )
    for stored_key, incoming_key, convert, default in sampling_fields:
        slot[stored_key] = convert(body.get(incoming_key, default))

    slot["model"] = body.get("model", previous.get("model", "mrt2_base"))

    control_fields = (
        ("buffer", int, 0),
        ("notes", list, []),
        ("unmaskwidth", int, 0),
        ("drumless", bool, False),
        ("onsetmode", bool, False),
        ("reset", int, 0),
        ("seed", int, 0),
    )
    for key, convert, default in control_fields:
        slot[key] = convert(body.get(key, default))

    slot["bank_op"] = body.get("bank_op", previous.get("bank_op"))
    slot["ts"] = time.time()
    write_slot(sid, slot)


def _apply_buffer(sid, lead):
    """Persist the client's reported playback lead for a session.

    The value is stored as text in ``<SESSION_DIR>/<sid>_buf.txt`` (only the
    final path component of *sid* is used) and is read back by
    ``read_client_lead``.

    Args:
        sid: session identifier; converted with ``str`` before use.
        lead: value convertible to ``float``.

    Raises:
        TypeError/ValueError: if *lead* cannot be converted to float.
        OSError: if the file cannot be written or replaced.
    """
    sid = os.path.basename(str(sid))
    dest = os.path.join(SESSION_DIR, f"{sid}_buf.txt")
    # Convert first so a bad value never leaves a scratch file behind.
    text = str(float(lead))
    # A unique scratch name (same scheme as write_slot) keeps concurrent writers
    # -- the HTTP /buffer route and WebSocket "buffer" messages -- from
    # clobbering each other's half-written file.
    scratch = f"{dest}.{uuid.uuid4().hex}"
    with open(scratch, "w") as f:
        f.write(text)
    os.replace(scratch, dest)


# gradio Server instance; the HTTP, API and WebSocket routes below are registered on it.
app = Server()


@app.post("/buffer")
async def set_clientbuffer(request: Request):
    """HTTP endpoint: record the client's playback lead for a session.

    Expects a JSON body with ``session_id`` and an optional ``lead`` (default 0).
    """
    payload = await request.json()
    session = payload["session_id"]
    lead = payload.get("lead", 0)
    _apply_buffer(session, lead)
    return {"ok": True}


@app.post("/set")
async def set_collider(request: Request):
    """HTTP endpoint: store new conditioning parameters from the JSON body."""
    _apply_set(await request.json())
    return {"ok": True}


@app.post("/audio")
async def set_audio(request: Request):
    """Store an uploaded clip (native-rate mono float32) per (session, slot);
    the GPU stream embeds it with MusicCoCa.embed_audio (resampy -> 16kHz).

    JSON body fields:
        session_id: required; only its final path component is used.
        slot: integer slot index (default 0).
        samples: base64-encoded sample bytes; an empty/missing value deletes
            the stored clip for that slot.
        sample_rate: integer rate (default 48000), stored as a 4-byte
            little-endian header in front of the raw samples.

    Returns ``{"ok": True}``.
    """
    body = await request.json()
    sid = os.path.basename(body["session_id"])
    slot = int(body.get("slot", 0))
    path = os.path.join(SESSION_DIR, f"{sid}_aud{slot}.bin")
    samples = body.get("samples")
    if not samples:
        # Best-effort delete: the clip may never have been uploaded.
        try:
            os.remove(path)
        except OSError:
            pass
        return {"ok": True}
    raw = base64.b64decode(samples)
    # Build the header before touching disk so an out-of-range sample_rate
    # fails without leaving a scratch file behind.
    header = int(body.get("sample_rate", 48000)).to_bytes(4, "little")
    # Unique scratch name (same scheme as write_slot) so two uploads to the
    # same slot cannot overwrite each other's partial file.
    scratch = f"{path}.{uuid.uuid4().hex}"
    with open(scratch, "wb") as f:
        f.write(header)
        f.write(raw)
    os.replace(scratch, path)
    return {"ok": True}


@app.get("/banks")
async def banks(session_id: str = ""):
    """Report, for bank indices 0-2, whether a saved state file exists for the session."""
    sid = os.path.basename(session_id)
    status = []
    for idx in range(3):
        bank_file = os.path.join(SESSION_DIR, f"{sid}_bank{idx}.pt")
        status.append(os.path.exists(bank_file))
    return {"bankStatus": status}


@spaces.GPU(duration=90)
def gpu_stream(session_id):
    """Continuous gen; switches model live when the dropdown changes.

    Generator that reads the session's conditioning slot (``read_slot``) and
    yields ``(pcm_bytes, frame_ms)`` tuples, where ``pcm_bytes`` is the decoded
    audio as little-endian int16 bytes and ``frame_ms`` is the measured
    wall-clock time per generated frame for that chunk.

    The main loop runs while ``time.time() - t0 < 55.0``; note that ``t0`` is
    also re-based when the stream falls behind, so this is not a strict 55 s
    wall-clock limit. After each chunk the loop sleeps so that emitted audio
    stays about ``buf`` seconds ahead of real time.

    Args:
        session_id: identifier used to locate the per-session files in
            ``SESSION_DIR``.
    """
    # NOTE(review): FRAME_SAMPLES is imported here but not used in this function.
    from magenta_rt.torch.system import make_sampler, discretize_cfg, _float_to_int16, FRAME_SAMPLES
    if style_model.device != "cuda":
        style_model.to("cuda")
    dev, dt = "cuda", torch.bfloat16
    # Wait up to 8 s for the client to write its first conditioning slot.
    deadline = time.time() + 8.0
    while read_slot(session_id) is None and time.time() < deadline:
        time.sleep(0.05)
    gen = torch.Generator(device=dev).manual_seed(0)
    # The warm-up always runs on mrt_base, independent of the model selected in the slot.
    try:                                          # one-time GPU warm-up so the first audio frame isn't slow
        style_model.embed("warmup")
        _wd = mrt_base.model.decoder.init_streaming_f(1, dev, dt)
        _wc = mrt_base._conditioning([-1] * mrt_base.num_musiccoca, [-1] * 128, [-1],
                                     [discretize_cfg(1.6, 0.2, 40), discretize_cfg(2.4, 0.2, 40), discretize_cfg(4.0, 1.0, 8)])
        _ws = mrt_base.model.encode(_wc).to(dt)
        for _ in range(2):
            mrt_base.model.decoder.step_f(_wd, _ws, sampler=make_sampler(1.1, 50, gen),
                                          temporal_step=mrt_base._temporal_step, depth_step=mrt_base._depth_step)
        print("[warmup] done", flush=True)
    except Exception as _e:
        print("[warmup]", repr(_e), flush=True)
    # NOTE(review): `notes` and `drums` are assigned here but not read later in this function.
    notes, drums = [-1] * 128, [-1]
    cur_name = model = dstate = source = None
    decode_state = {}
    emitted_samples = 0
    t0 = time.time()
    buf_log, buf_t = [], t0
    cur_style_sig = cur_note_sig = cur_tokens = None
    prev_active = set()
    cur_reset = cur_seed = 0
    had_onsets = False
    # Per-call embedding caches: text label -> embedding, audio slot index -> (mtime, embedding).
    txt_cache, aud_cache = {}, {}
    CHUNK = 3                                   # frames per yield (smaller = finer delivery, lower buffer floor)
    cur_bank_ver = 0
    while time.time() - t0 < 55.0:
        c = read_slot(session_id)
        if c is None:
            time.sleep(0.02)
            continue
        mname = c.get("model", "mrt2_base")
        reset = int(c.get("reset", 0))
        if mname != cur_name or reset != cur_reset:   # model switch / state reset -> re-init
            # Unknown model names fall back to mrt_base.
            cur_name, model, cur_reset = mname, MODELS.get(mname, mrt_base), reset
            dstate = model.model.decoder.init_streaming_f(1, dev, dt)
            decode_state = model.init_decode_state()
            emitted_samples, source, prev_active = 0, None, set()
            cur_style_sig = cur_note_sig = cur_tokens = None; had_onsets = False
        seed = int(c.get("seed", 0))
        if seed != cur_seed:
            cur_seed = seed
            gen = torch.Generator(device=dev).manual_seed(seed)
        bop = c.get("bank_op")
        # A bank operation is applied once per distinct "ver" value.
        if bop and int(bop.get("ver", 0)) != cur_bank_ver:    # save/recall generation state
            cur_bank_ver = int(bop.get("ver", 0))
            bpath = os.path.join(SESSION_DIR, f"{os.path.basename(session_id)}_bank{int(bop.get('idx', 0))}.pt")
            try:
                if bop.get("action") == "save":
                    torch.save({"dstate": dstate, "decode_state": decode_state, "emitted": emitted_samples}, bpath)
                elif bop.get("action") == "load" and os.path.exists(bpath):
                    # NOTE(review): torch.load unpickles the file; this relies on bank files
                    # only ever being written by the save branch above.
                    d = torch.load(bpath, map_location=dev)
                    dstate, decode_state, emitted_samples = d["dstate"], d["decode_state"], int(d["emitted"])
                    # Force the conditioning to be re-encoded on the next frame.
                    source, cur_note_sig = None, None
            except Exception as e:
                print("[bank] error:", repr(e), flush=True)
        toks = []
        gen_t = time.time()
        for _ in range(CHUNK):                        # per-frame conditioning (native parity: re-derive every 40ms frame, no debounce)
            # Re-read the slot each frame; keep the previous value if the read fails.
            c = read_slot(session_id) or c
            prompts = c.get("prompts") or ["instrumental music"]
            weights = c.get("weights") or [1.0] * len(prompts)
            # Padded with False so aflags is always at least as long as prompts.
            aflags = (c.get("audio") or []) + [False] * len(prompts)
            active = set(c.get("notes") or [])
            unmask = int(c.get("unmaskwidth", 0))
            onsetmode = bool(c.get("onsetmode", False))
            drumless = bool(c.get("drumless", False))
            # ---- style tokens on GPU: embed+blend+tokenize, cached until prompts/weights/audio change ----
            style_sig = (tuple(prompts), tuple(round(float(w), 4) for w in weights),
                         tuple(bool(a) for a in aflags[:len(prompts)]))
            if cur_tokens is None or style_sig != cur_style_sig:
                cur_style_sig = style_sig
                emb = torch.zeros(style_model.embedding_dim, device=dev, dtype=torch.float32)
                # NOTE(review): tot sums every positive weight, including entries that are
                # skipped below (blank label / missing audio file) or lie beyond len(prompts),
                # so the blend is not re-normalised in those cases.
                tot = float(sum(w for w in weights if w > 0)) or 1.0
                for i in range(len(prompts)):
                    w = float(weights[i]) if i < len(weights) else 0.0
                    if w <= 0:
                        continue
                    if i < len(aflags) and aflags[i]:
                        ap = os.path.join(SESSION_DIR, f"{os.path.basename(session_id)}_aud{i}.bin")
                        try: mt = os.path.getmtime(ap)
                        except OSError: continue
                        ce = aud_cache.get(i)
                        # Re-embed only when the uploaded file's mtime changed.
                        if ce is None or ce[0] != mt:
                            # NOTE(review): this file handle is never explicitly closed.
                            data = open(ap, "rb").read()
                            # File layout written by /audio: 4-byte little-endian sample rate, then raw samples.
                            asr = int.from_bytes(data[:4], "little")
                            samp = np.frombuffer(data[4:], dtype=np.float32)
                            ce = (mt, style_model.embed_audio(samp, asr).float()); aud_cache[i] = ce
                        e = ce[1]
                    else:
                        lbl = (prompts[i] or "").strip()
                        if not lbl:
                            continue
                        e = txt_cache.get(lbl)
                        if e is None:
                            e = style_model.embed(lbl).float(); txt_cache[lbl] = e
                    emb = emb + (w / tot) * e
                cur_tokens = style_model.tokenize(emb).tolist()
            # ---- notes + cfg -> re-encode on ANY change every frame (no debounce); onset(2)->sustain(1) via had_onsets ----
            note_sig = (tuple(sorted(active)), unmask, onsetmode, drumless,
                        round(float(c.get("cfg_musiccoca", 1.6)), 3), round(float(c.get("cfg_notes", 2.4)), 3),
                        round(float(c.get("cfg_drums", 4.0)), 3), cur_style_sig)
            if source is None or note_sig != cur_note_sig or had_onsets:
                onsets = active - prev_active
                prev_active = set(active)
                had_onsets = onsetmode and bool(onsets)        # carry one frame so the onset settles to sustain next frame
                cur_note_sig = note_sig
                nvec = []                                      # per-pitch: -1 masked / 0 off / 1 cont / 2 onset / 3 on
                for pitch in range(128):
                    if pitch in active:
                        nvec.append((2 if pitch in onsets else 1) if onsetmode else 3)
                    elif unmask >= 127 or (active and any(abs(pitch - h) <= unmask for h in active)):
                        nvec.append(0)                         # solo (>=127) or within unmask window -> off
                    else:
                        nvec.append(-1)                        # masked (model free)
                drm = [0] if drumless else [-1]                # 0 no-drum / -1 masked
                cfg_notes_v = 7.0 if (unmask >= 127 and not active) else c.get("cfg_notes", 2.4)   # solo + no note held -> ramp cfg_notes to max bin (native JamApp suppresses melody)
                cfgs = [discretize_cfg(c.get("cfg_musiccoca", 1.6), 0.2, 40),
                        discretize_cfg(cfg_notes_v, 0.2, 40),
                        discretize_cfg(c.get("cfg_drums", 4.0), 1.0, 8)]
                # Style tokens are padded with -1 / truncated to exactly model.num_musiccoca entries.
                cond = model._conditioning((list(cur_tokens) + [-1] * model.num_musiccoca)[:model.num_musiccoca],
                                           nvec, drm, cfgs)
                source = model.model.encode(cond).to(dt)
            sampler = make_sampler(c.get("temperature", 1.1), c.get("top_k", 50), gen)
            toks.append(model.model.decoder.step_f(dstate, source, sampler=sampler,
                        temporal_step=model._temporal_step, depth_step=model._depth_step))
        new_codes = torch.cat(toks, dim=1)
        audio = model.decode_stream(new_codes, decode_state)      # FLOP-optimal stateful streaming decode
        emitted_samples += audio.shape[1]
        frame_ms = (time.time() - gen_t) * 1000.0 / CHUNK     # real per-frame inference time
        if audio.shape[1] > 0:
            ab = _float_to_int16(audio[0].float().cpu().numpy()).astype("<i2").tobytes()
            yield (ab, frame_ms)
        # Target lead in seconds, selected by the slot's "buffer" index (clamped to 0..2).
        buf = (0.15, 0.3, 0.5)[max(0, min(2, int(c.get("buffer", 0))))]
        # Seconds of audio emitted minus seconds elapsed since t0.
        ahead = (emitted_samples / SR) - (time.time() - t0)
        if ahead < -0.3:                              # fell behind (cold-start/hiccup) -> resync, no catch-up burst
            t0 = time.time() - (emitted_samples / SR) + buf
            ahead = buf
        buf_log.append(ahead)
        # Print lead statistics roughly every 15 s.
        if time.time() - buf_t >= 15.0 and buf_log:
            print(f"[buffer] avg={sum(buf_log)/len(buf_log):.2f}s min={min(buf_log):.2f} "
                  f"max={max(buf_log):.2f} target={buf:.2f}s n={len(buf_log)}", flush=True)
            buf_log, buf_t = [], time.time()
        sleep_for = max(0.0, ahead - buf)
        cl = read_client_lead(session_id)                 # client's real Web-Audio queue (ms)
        if cl is not None and cl - buf * 1000.0 > 200:    # only drain a genuinely-deep queue (target+200ms)
            sleep_for = max(sleep_for, min((cl - buf * 1000.0) / 1000.0 * 0.3, 0.2))  # gentle + capped
        if sleep_for > 0:
            # Never sleep more than 0.5 s per chunk.
            time.sleep(min(sleep_for, 0.5))


# session_id -> {"audio": queue of (pcm_bytes, frame_ms) tuples, "stop": threading.Event}
_sessions = {}
# Held while entries are added to or removed from _sessions.
_slock = threading.Lock()


def _worker(session_id, audio_q, stop_ev):
    """Pump frames from ``gpu_stream`` into *audio_q* until *stop_ev* is set.

    ``gpu_stream`` is restarted whenever it finishes (~55s GPU grant; re-grants
    on loop) or raises an error whose text mentions "abort" or "duration". Any
    other error is printed and ends the worker. *stop_ev* is always set on exit.
    """
    def _offer(item):
        # Never block the producer: when the queue is full, discard the oldest
        # frame to make room, then enqueue without waiting.
        if audio_q.full():
            try:
                audio_q.get_nowait()
            except _q.Empty:
                pass
        try:
            audio_q.put_nowait(item)
        except _q.Full:
            pass

    try:
        while not stop_ev.is_set():
            try:
                for pcm, frame_ms in gpu_stream(session_id):
                    if stop_ev.is_set():
                        break
                    _offer((pcm, frame_ms))
            except Exception as err:
                reason = str(err).lower()
                if "abort" in reason or "duration" in reason:
                    continue                                    # grant expired -> re-grant
                print("[worker]", repr(err), flush=True)
                break
    finally:
        stop_ev.set()


@app.api(name="start")
def start(session_id: str = "", request: gr.Request = None) -> str:
    """Start (or restart) the background generation worker for a session.

    Any worker already registered under *session_id* is signalled to stop and
    replaced by a fresh queue / stop-event pair plus a new daemon thread running
    ``_worker``.

    Args:
        session_id: client-chosen identifier; an empty value is a no-op.
        request: the caller's gradio request; when omitted, the request stored
            in ``LocalContext`` (if any) is used instead.

    Returns:
        *session_id* on success, or ``""`` when no id was given.
    """
    if not session_id:
        return ""
    req = request or LocalContext.request.get(None)      # user's gr.Request -> carries X-IP-Token (ZeroGPU quota)
    audio_q = _q.Queue(maxsize=48)
    stop_ev = threading.Event()
    # Swap the registry entry in ONE critical section. Doing the pop and the
    # insert under separate lock acquisitions lets two concurrent start() calls
    # for the same id each miss the other's entry, leaving a worker that is
    # never told to stop.
    with _slock:
        prev = _sessions.get(session_id)
        _sessions[session_id] = {"audio": audio_q, "stop": stop_ev}
    if prev:
        prev["stop"].set()

    def run():
        if req is not None:
            LocalContext.request.set(req)                # graft user request so worker @spaces.GPU bills the USER
        _worker(session_id, audio_q, stop_ev)

    # Run inside a copy of the current context so the request set in run()
    # stays local to the worker thread.
    ctx = contextvars.copy_context()
    t = threading.Thread(target=ctx.run, args=(run,), daemon=True)
    t.start()
    return session_id


@app.websocket("/ws")
async def audio_ws(websocket: WebSocket, session_id: str = ""):
    """Bidirectional WebSocket for one session.

    Server -> client: binary messages made of a 4-byte little-endian float32
    (per-frame inference time in ms) followed by the PCM bytes taken from the
    session's audio queue; a final JSON ``{"type": "ended"}`` once the stop
    event is set.

    Client -> server: JSON control messages with ``type`` of ``"set"``
    (forwarded to ``_apply_set``) or ``"buffer"`` (forwarded to
    ``_apply_buffer``).

    The connection is closed with code 1008 when *session_id* has no
    registered session (i.e. ``start`` was not called first).
    """
    await websocket.accept()
    sess = _sessions.get(session_id)
    if not sess:
        await websocket.close(code=1008)
        return
    audio_q, stop_ev = sess["audio"], sess["stop"]
    loop = asyncio.get_running_loop()

    async def send_audio():                       # server -> client: binary audio frames
        try:
            while not stop_ev.is_set():
                try:
                    # Poll with a timeout in a worker thread so the stop event
                    # is re-checked at least every 0.5 s.
                    ab, fm = await loop.run_in_executor(None, lambda: audio_q.get(timeout=0.5))
                except _q.Empty:
                    # Only an empty queue means "try again"; any other error
                    # ends the sender instead of spinning forever.
                    continue
                await websocket.send_bytes(struct.pack("<f", fm) + ab)
            try:
                await websocket.send_json({"type": "ended"})   # worker ended (quota/grant)
            except Exception:
                pass                                            # client already gone
        finally:
            stop_ev.set()

    async def recv_control():                     # client -> server: low-latency steering (notes/params), no HTTP RTT
        try:
            while not stop_ev.is_set():
                msg = await websocket.receive_json()            # raises WebSocketDisconnect on close
                if not isinstance(msg, dict):
                    continue                                    # ignore malformed control messages
                t = msg.get("type")
                if t == "set":
                    try:
                        _apply_set(msg)
                    except Exception as e:
                        print("[ws set]", repr(e), flush=True)
                elif t == "buffer":
                    try:
                        _apply_buffer(msg.get("session_id", session_id), msg.get("lead", 0))
                    except Exception as e:
                        print("[ws buffer]", repr(e), flush=True)
        finally:
            stop_ev.set()

    try:
        await asyncio.gather(send_audio(), recv_control(), return_exceptions=True)
    finally:
        stop_ev.set()
        with _slock:
            # Only unregister OUR session. If start() has already replaced it
            # (client restarted with the same id), popping unconditionally would
            # delete the new session and the next /ws connect would get 1008.
            if _sessions.get(session_id) is sess:
                _sessions.pop(session_id, None)


@app.get("/")
async def index():
    """Serve the ``index.html`` that sits next to this file."""
    here = os.path.dirname(os.path.abspath(__file__))
    page_path = os.path.join(here, "index.html")
    with open(page_path) as page:
        markup = page.read()
    return HTMLResponse(markup)


# Start serving at import time; SSR stays disabled here as well as via GRADIO_SSR_MODE.
app.launch(show_error=True, ssr_mode=False)