thecollabagepatch committed on
Commit
989061c
·
1 Parent(s): 1c7440e

updating web tester and simplifying websockets route

Browse files
Files changed (2) hide show
  1. app.py +196 -326
  2. magentaRT_rt_tester.html +31 -79
app.py CHANGED
@@ -1497,19 +1497,24 @@ async def log_requests(request: Request, call_next):
1497
 
1498
 
1499
  # ----------------------------
1500
- # websockets route
1501
  # ----------------------------
1502
 
1503
 
1504
  @app.websocket("/ws/jam")
1505
  async def ws_jam(websocket: WebSocket):
 
 
 
 
 
 
 
 
 
1506
  await websocket.accept()
1507
- sid = None
1508
- worker = None
1509
  binary_audio = False
1510
- mode = "rt" # or "bar"
1511
 
1512
- # NEW: capture ws in closure
1513
  async def send_json(obj):
1514
  return await send_json_safe(websocket, obj)
1515
 
@@ -1519,335 +1524,193 @@ async def ws_jam(websocket: WebSocket):
1519
  msg = json.loads(raw)
1520
  mtype = msg.get("type")
1521
 
1522
- # --- START ---
1523
  if mtype == "start":
1524
  binary_audio = bool(msg.get("binary_audio", False))
1525
- mode = msg.get("mode", "rt")
1526
  params = msg.get("params", {}) or {}
1527
- sid = msg.get("session_id")
1528
-
1529
- # attach or create
1530
- if sid:
1531
- with jam_lock:
1532
- worker = jam_registry.get(sid)
1533
- if worker is None or not worker.is_alive():
1534
- await send_json({"type":"error","error":"Session not found"})
1535
- continue
1536
- else:
1537
- # optionally accept base64 loop and start a new worker (bar-mode)
1538
- if mode == "bar":
1539
- loop_b64 = msg.get("loop_audio_b64")
1540
- if not loop_b64:
1541
- await send_json({"type":"error","error":"loop_audio_b64 required for mode=bar when no session_id"})
1542
- continue
1543
- loop_bytes = base64.b64decode(loop_b64)
1544
- # mimic /jam/start
1545
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
1546
- tmp.write(loop_bytes); tmp_path = tmp.name
1547
- # build JamParams similar to /jam/start
1548
- mrt = get_mrt()
1549
- model_sr = int(mrt.sample_rate) # typically 48000
1550
- # Defaults for WS: raw loudness @ model SR, unless overridden by client:
1551
- target_sr = int(params.get("target_sr", model_sr))
1552
- loudness_mode = params.get("loudness_mode", "none")
1553
- headroom_db = float(params.get("headroom_db", 1.0))
1554
- loop = au.Waveform.from_file(tmp_path).resample(mrt.sample_rate).as_stereo()
1555
-
1556
- codec_fps = float(mrt.codec.frame_rate)
1557
- ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
1558
- bpm = float(params.get("bpm", 120.0))
1559
- bpb = int(params.get("beats_per_bar", 4))
1560
- loop_tail = take_bar_aligned_tail(loop, bpm, bpb, ctx_seconds)
1561
-
1562
- # style vector (loop + extra styles)
1563
- embeds, weights = [mrt.embed_style(loop_tail)], [float(params.get("loop_weight", 1.0))]
1564
- extra = [s for s in (params.get("styles","").split(",")) if s.strip()]
1565
- sw = [float(x) for x in params.get("style_weights","").split(",") if x.strip()]
1566
- for i, s in enumerate(extra):
1567
- embeds.append(mrt.embed_style(s.strip()))
1568
- weights.append(sw[i] if i < len(sw) else 1.0)
1569
- wsum = sum(weights) or 1.0
1570
- weights = [w/wsum for w in weights]
1571
- style_vec = np.sum([w*e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
1572
-
1573
- # target SR fallback: input SR
1574
- inp_info = sf.info(tmp_path)
1575
- target_sr = int(params.get("target_sr", int(inp_info.samplerate)))
1576
-
1577
- # Build JamParams for WS bar-mode
1578
- jp = JamParams(
1579
- bpm=bpm, beats_per_bar=bpb, bars_per_chunk=int(params.get("bars_per_chunk", 8)),
1580
- target_sr=target_sr,
1581
- loudness_mode=loudness_mode, headroom_db=headroom_db,
1582
- style_vec=style_vec,
1583
- ref_loop=None if loudness_mode == "none" else loop_tail, # disable match by default
1584
- combined_loop=loop,
1585
- guidance_weight=float(params.get("guidance_weight", 1.1)),
1586
- temperature=float(params.get("temperature", 1.1)),
1587
- topk=int(params.get("topk", 40)),
1588
- )
1589
- worker = JamWorker(get_mrt(), jp)
1590
- sid = str(uuid.uuid4())
1591
- with jam_lock:
1592
- # single active jam per GPU, mirroring /jam/start
1593
- for _sid, w in list(jam_registry.items()):
1594
- if w.is_alive():
1595
- await send_json({"type":"error","error":"A jam is already running"})
1596
- worker = None; sid = None
1597
- break
1598
- if worker is not None:
1599
- jam_registry[sid] = worker
1600
- worker.start()
1601
-
1602
- else:
1603
- # mode == "rt" (Colab-style, no loop context)
1604
- mrt = get_mrt()
1605
- state = mrt.init_state()
1606
-
1607
- # Build silent context (10s) tokens
1608
- codec_fps = float(mrt.codec.frame_rate)
1609
- ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
1610
- sr = int(mrt.sample_rate)
1611
- samples = int(max(1, round(ctx_seconds * sr)))
1612
- silent = au.Waveform(np.zeros((samples, 2), np.float32), sr)
1613
- tokens = mrt.codec.encode(silent).astype(np.int32)[:, :mrt.config.decoder_codec_rvq_depth]
1614
- state.context_tokens = tokens
1615
-
1616
- # Parse params (including steering)
1617
- asset_manager.ensure_assets_loaded(get_mrt())
1618
- styles_str = params.get("styles", "warmup") or ""
1619
- style_weights_str = params.get("style_weights", "") or ""
1620
- mean_w = float(params.get("mean", 0.0) or 0.0)
1621
- cw_str = str(params.get("centroid_weights", "") or "")
1622
-
1623
- text_list = [s.strip() for s in styles_str.split(",") if s.strip()]
1624
- try:
1625
- text_w = [float(x) for x in style_weights_str.split(",")] if style_weights_str else []
1626
- except ValueError:
1627
- text_w = []
1628
- try:
1629
- cw = [float(x) for x in cw_str.split(",") if x.strip() != ""]
1630
- except ValueError:
1631
- cw = []
1632
-
1633
- # Clamp centroid weights to available centroids
1634
- if _CENTROIDS is not None and len(cw) > int(_CENTROIDS.shape[0]):
1635
- cw = cw[: int(_CENTROIDS.shape[0])]
1636
-
1637
- # Build initial style vector (no loop_embed in rt mode)
1638
- style_vec = build_style_vector(
1639
- mrt,
1640
- text_styles=text_list,
1641
- text_weights=text_w,
1642
- loop_embed=None,
1643
- loop_weight=None,
1644
- mean_weight=mean_w,
1645
- centroid_weights=cw,
1646
- )
1647
-
1648
- # Stash rt session fields
1649
- websocket._mrt = mrt
1650
- websocket._state = state
1651
- websocket._style_cur = style_vec
1652
- websocket._style_tgt = style_vec
1653
- websocket._style_ramp_s = float(params.get("style_ramp_seconds", 0.0))
1654
-
1655
- websocket._rt_mean = mean_w
1656
- websocket._rt_centroid_weights = cw
1657
- websocket._rt_running = True
1658
- websocket._rt_sr = sr
1659
- websocket._rt_topk = int(params.get("topk", 40))
1660
- websocket._rt_temp = float(params.get("temperature", 1.1))
1661
- websocket._rt_guid = float(params.get("guidance_weight", 1.1))
1662
- websocket._pace = params.get("pace", "asap") # "realtime" | "asap"
1663
-
1664
- # (Optional) report whether steering assets were loaded
1665
- assets_ok = (_MEAN_EMBED is not None) or (_CENTROIDS is not None)
1666
- await send_json({"type": "started", "mode": "rt", "steering_assets": "loaded" if assets_ok else "none"})
1667
-
1668
- # kick off the ~2s streaming loop
1669
- async def _rt_loop():
1670
- try:
1671
- mrt = websocket._mrt
1672
- chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
1673
- target_next = time.perf_counter()
1674
- while websocket._rt_running:
1675
- mrt.guidance_weight = websocket._rt_guid
1676
- mrt.temperature = websocket._rt_temp
1677
- mrt.topk = websocket._rt_topk
1678
-
1679
- # ramp style
1680
- ramp = float(getattr(websocket, "_style_ramp_s", 0.0) or 0.0)
1681
- if ramp <= 0.0:
1682
- websocket._style_cur = websocket._style_tgt
1683
- else:
1684
- step = min(1.0, chunk_secs / ramp)
1685
- websocket._style_cur = websocket._style_cur + step * (websocket._style_tgt - websocket._style_cur)
1686
-
1687
- wav, new_state = mrt.generate_chunk(state=websocket._state, style=websocket._style_cur)
1688
- websocket._state = new_state
1689
-
1690
- x = wav.samples.astype(np.float32, copy=False)
1691
- buf = io.BytesIO()
1692
- sf.write(buf, x, mrt.sample_rate, subtype="FLOAT", format="WAV")
1693
-
1694
- ok = True
1695
- if binary_audio:
1696
- try:
1697
- await websocket.send_bytes(buf.getvalue())
1698
- ok = await send_json({"type": "chunk_meta", "metadata": {"sample_rate": mrt.sample_rate}})
1699
- except Exception:
1700
- ok = False
1701
- else:
1702
- b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
1703
- ok = await send_json({"type": "chunk", "audio_base64": b64,
1704
- "metadata": {"sample_rate": mrt.sample_rate}})
1705
-
1706
- if not ok:
1707
- break
1708
-
1709
- if getattr(websocket, "_pace", "asap") == "realtime":
1710
- t1 = time.perf_counter()
1711
- target_next += chunk_secs
1712
- sleep_s = max(0.0, target_next - t1 - 0.02)
1713
- if sleep_s > 0:
1714
- await asyncio.sleep(sleep_s)
1715
- except asyncio.CancelledError:
1716
- pass
1717
- except Exception:
1718
- pass
1719
-
1720
- websocket._rt_task = asyncio.create_task(_rt_loop())
1721
- continue # skip the “bar-mode started” message below
1722
-
1723
- await send_json({"type":"started","session_id": sid, "mode": mode})
1724
-
1725
- # if we’re in bar-mode, begin pushing chunks as they arrive
1726
- if mode == "bar" and worker is not None:
1727
- async def _pump():
1728
- while True:
1729
- if not worker.is_alive():
1730
- break
1731
- chunk = worker.get_next_chunk(timeout=60.0)
1732
- if chunk is None:
1733
- continue
1734
- audio_base64 = base64.b64encode(chunk.audio_bytes).decode("utf-8")
1735
  if binary_audio:
1736
- await websocket.send_bytes(chunk.audio_bytes)
1737
- await send_json({"type":"chunk_meta","index":chunk.index,"metadata":chunk.metadata})
 
 
 
1738
  else:
1739
- await send_json({"type":"chunk","index":chunk.index,
1740
- "audio_base64":audio_base64,"metadata":chunk.metadata})
1741
- asyncio.create_task(_pump())
1742
 
1743
- # --- UPDATES (bar or rt) ---
1744
- elif mtype == "update":
1745
- if mode == "bar":
1746
- if not sid:
1747
- await send_json({"type":"error","error":"No session_id yet"}); return
1748
- # fan values straight into your existing HTTP handler:
1749
- res = jam_update(
1750
- session_id=sid,
1751
- guidance_weight=msg.get("guidance_weight"),
1752
- temperature=msg.get("temperature"),
1753
- topk=msg.get("topk"),
1754
- styles=msg.get("styles",""),
1755
- style_weights=msg.get("style_weights",""),
1756
- loop_weight=msg.get("loop_weight"),
1757
- use_current_mix_as_style=bool(msg.get("use_current_mix_as_style", False)),
1758
- )
1759
- await send_json({"type":"status", **res}) # {"ok": True}
1760
- else:
1761
- # rt-mode: there’s no JamWorker; update the local knobs/state
1762
- websocket._rt_temp = float(msg.get("temperature", websocket._rt_temp))
1763
- websocket._rt_topk = int(msg.get("topk", websocket._rt_topk))
1764
- websocket._rt_guid = float(msg.get("guidance_weight", websocket._rt_guid))
1765
-
1766
- # NEW steering fields
1767
- if "mean" in msg and msg["mean"] is not None:
1768
- try: websocket._rt_mean = float(msg["mean"])
1769
- except: websocket._rt_mean = 0.0
1770
-
1771
- if "centroid_weights" in msg:
1772
- cw = [w.strip() for w in str(msg["centroid_weights"]).split(",") if w.strip() != ""]
1773
- try:
1774
- websocket._rt_centroid_weights = [float(x) for x in cw]
1775
- except:
1776
- websocket._rt_centroid_weights = []
1777
-
1778
- # styles / text weights (optional, comma-separated)
1779
- styles_str = msg.get("styles", None)
1780
- style_weights_str = msg.get("style_weights", "")
1781
-
1782
- text_list = [s for s in (styles_str.split(",") if styles_str else []) if s.strip()]
1783
- text_w = [float(x) for x in style_weights_str.split(",")] if style_weights_str else []
1784
 
1785
- asset_manager.ensure_assets_loaded(get_mrt())
1786
- websocket._style_tgt = build_style_vector(
1787
- websocket._mrt,
1788
- text_styles=text_list,
1789
- text_weights=text_w,
1790
- loop_embed=None,
1791
- loop_weight=None,
1792
- mean_weight=float(websocket._rt_mean),
1793
- centroid_weights=websocket._rt_centroid_weights,
1794
- )
1795
- # optionally allow live changes to ramp:
1796
- if "style_ramp_seconds" in msg:
1797
- try: websocket._style_ramp_s = float(msg["style_ramp_seconds"])
1798
- except: pass
1799
- await send_json({"type":"status","updated":"rt-knobs+style"})
1800
-
1801
- elif mtype == "consume" and mode == "bar":
1802
- with jam_lock:
1803
- worker = jam_registry.get(msg.get("session_id"))
1804
- if worker is not None:
1805
- worker.mark_chunk_consumed(int(msg.get("chunk_index", -1)))
1806
-
1807
- elif mtype == "reseed" and mode == "bar":
1808
- with jam_lock:
1809
- worker = jam_registry.get(msg.get("session_id"))
1810
- if worker is None or not worker.is_alive():
1811
- await send_json({"type":"error","error":"Session not found"}); continue
1812
- loop_b64 = msg.get("loop_audio_b64")
1813
- if not loop_b64:
1814
- await send_json({"type":"error","error":"loop_audio_b64 required"}); continue
1815
- loop_bytes = base64.b64decode(loop_b64)
1816
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
1817
- tmp.write(loop_bytes); path = tmp.name
1818
- wav = au.Waveform.from_file(path).resample(worker.mrt.sample_rate).as_stereo()
1819
- worker.reseed_from_waveform(wav)
1820
- await send_json({"type":"status","reseeded":True})
1821
-
1822
- elif mtype == "reseed_splice" and mode == "bar":
1823
- with jam_lock:
1824
- worker = jam_registry.get(msg.get("session_id"))
1825
- if worker is None or not worker.is_alive():
1826
- await send_json({"type":"error","error":"Session not found"}); continue
1827
- anchor = float(msg.get("anchor_bars", 2.0))
1828
- b64 = msg.get("combined_audio_b64")
1829
- if b64:
1830
- data = base64.b64decode(b64)
1831
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
1832
- tmp.write(data); path = tmp.name
1833
- wav = au.Waveform.from_file(path).resample(worker.mrt.sample_rate).as_stereo()
1834
- worker.reseed_splice(wav, anchor_bars=anchor)
1835
- else:
1836
- # fallback: model-side stream splice
1837
- worker.reseed_splice(worker.params.combined_loop, anchor_bars=anchor)
1838
- await send_json({"type":"status","splice":anchor})
1839
 
1840
- elif mtype == "stop":
1841
- if mode == "rt":
1842
- websocket._rt_running = False
1843
- task = getattr(websocket, "_rt_task", None)
1844
- if task is not None:
1845
- task.cancel()
1846
- try: await task
1847
- except asyncio.CancelledError: pass
1848
- await send_json({"type":"stopped"})
1849
- break # <- add this if you want to end the socket after stop
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1850
 
 
 
 
 
 
 
 
 
 
 
 
 
1851
  elif mtype == "ping":
1852
  await send_json({"type":"pong"})
1853
 
@@ -1855,7 +1718,6 @@ async def ws_jam(websocket: WebSocket):
1855
  await send_json({"type":"error","error":f"Unknown type {mtype}"})
1856
 
1857
  except WebSocketDisconnect:
1858
- # best-effort cleanup for bar-mode sessions started within this socket (optional)
1859
  pass
1860
  except Exception as e:
1861
  try:
@@ -1863,6 +1725,14 @@ async def ws_jam(websocket: WebSocket):
1863
  except Exception:
1864
  pass
1865
  finally:
 
 
 
 
 
 
 
 
1866
  try:
1867
  if websocket.client_state != WebSocketState.DISCONNECTED:
1868
  await websocket.close()
 
1497
 
1498
 
1499
  # ----------------------------
1500
+ # websockets route (rt-mode only)
1501
  # ----------------------------
1502
 
1503
 
1504
  @app.websocket("/ws/jam")
1505
  async def ws_jam(websocket: WebSocket):
1506
+ """
1507
+ Real-time streaming WebSocket endpoint for MagentaRT.
1508
+
1509
+ This route operates in 'rt' mode only - for bar-aligned jam sessions,
1510
+ use the HTTP endpoints (/jam/start, /jam/chunk, etc.) instead.
1511
+
1512
+ The server handles crossfading internally via MagentaRTState, so clients
1513
+ can simply play back chunks sequentially without additional crossfade logic.
1514
+ """
1515
  await websocket.accept()
 
 
1516
  binary_audio = False
 
1517
 
 
1518
  async def send_json(obj):
1519
  return await send_json_safe(websocket, obj)
1520
 
 
1524
  msg = json.loads(raw)
1525
  mtype = msg.get("type")
1526
 
1527
+ # --- START (rt-mode only) ---
1528
  if mtype == "start":
1529
  binary_audio = bool(msg.get("binary_audio", False))
 
1530
  params = msg.get("params", {}) or {}
1531
+
1532
+ # Initialize MagentaRT state
1533
+ mrt = get_mrt()
1534
+ state = mrt.init_state()
1535
+
1536
+ # Build silent context (10s) tokens
1537
+ codec_fps = float(mrt.codec.frame_rate)
1538
+ ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
1539
+ sr = int(mrt.sample_rate)
1540
+ samples = int(max(1, round(ctx_seconds * sr)))
1541
+ silent = au.Waveform(np.zeros((samples, 2), np.float32), sr)
1542
+ tokens = mrt.codec.encode(silent).astype(np.int32)[:, :mrt.config.decoder_codec_rvq_depth]
1543
+ state.context_tokens = tokens
1544
+
1545
+ # Parse params (including steering)
1546
+ asset_manager.ensure_assets_loaded(get_mrt())
1547
+ styles_str = params.get("styles", "warmup") or ""
1548
+ style_weights_str = params.get("style_weights", "") or ""
1549
+ mean_w = float(params.get("mean", 0.0) or 0.0)
1550
+ cw_str = str(params.get("centroid_weights", "") or "")
1551
+
1552
+ text_list = [s.strip() for s in styles_str.split(",") if s.strip()]
1553
+ try:
1554
+ text_w = [float(x) for x in style_weights_str.split(",")] if style_weights_str else []
1555
+ except ValueError:
1556
+ text_w = []
1557
+ try:
1558
+ cw = [float(x) for x in cw_str.split(",") if x.strip() != ""]
1559
+ except ValueError:
1560
+ cw = []
1561
+
1562
+ # Clamp centroid weights to available centroids
1563
+ if _CENTROIDS is not None and len(cw) > int(_CENTROIDS.shape[0]):
1564
+ cw = cw[: int(_CENTROIDS.shape[0])]
1565
+
1566
+ # Build initial style vector (no loop_embed in rt mode)
1567
+ style_vec = build_style_vector(
1568
+ mrt,
1569
+ text_styles=text_list,
1570
+ text_weights=text_w,
1571
+ loop_embed=None,
1572
+ loop_weight=None,
1573
+ mean_weight=mean_w,
1574
+ centroid_weights=cw,
1575
+ )
1576
+
1577
+ # Stash rt session fields on the websocket object
1578
+ websocket._mrt = mrt
1579
+ websocket._state = state
1580
+ websocket._style_cur = style_vec
1581
+ websocket._style_tgt = style_vec
1582
+ websocket._style_ramp_s = float(params.get("style_ramp_seconds", 0.0))
1583
+
1584
+ websocket._rt_mean = mean_w
1585
+ websocket._rt_centroid_weights = cw
1586
+ websocket._rt_running = True
1587
+ websocket._rt_sr = sr
1588
+ websocket._rt_topk = int(params.get("topk", 40))
1589
+ websocket._rt_temp = float(params.get("temperature", 1.1))
1590
+ websocket._rt_guid = float(params.get("guidance_weight", 1.1))
1591
+ websocket._pace = params.get("pace", "asap") # "realtime" | "asap"
1592
+
1593
+ # Report whether steering assets were loaded
1594
+ assets_ok = (_MEAN_EMBED is not None) or (_CENTROIDS is not None)
1595
+ await send_json({"type": "started", "mode": "rt", "steering_assets": "loaded" if assets_ok else "none"})
1596
+
1597
+ # Kick off the ~2s streaming loop
1598
+ async def _rt_loop():
1599
+ try:
1600
+ mrt = websocket._mrt
1601
+ chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
1602
+ target_next = time.perf_counter()
1603
+ while websocket._rt_running:
1604
+ mrt.guidance_weight = websocket._rt_guid
1605
+ mrt.temperature = websocket._rt_temp
1606
+ mrt.topk = websocket._rt_topk
1607
+
1608
+ # Ramp style toward target
1609
+ ramp = float(getattr(websocket, "_style_ramp_s", 0.0) or 0.0)
1610
+ if ramp <= 0.0:
1611
+ websocket._style_cur = websocket._style_tgt
1612
+ else:
1613
+ step = min(1.0, chunk_secs / ramp)
1614
+ websocket._style_cur = websocket._style_cur + step * (websocket._style_tgt - websocket._style_cur)
1615
+
1616
+ # Generate chunk (crossfading handled internally by MagentaRT)
1617
+ wav, new_state = mrt.generate_chunk(state=websocket._state, style=websocket._style_cur)
1618
+ websocket._state = new_state
1619
+
1620
+ x = wav.samples.astype(np.float32, copy=False)
1621
+ buf = io.BytesIO()
1622
+ sf.write(buf, x, mrt.sample_rate, subtype="FLOAT", format="WAV")
1623
+
1624
+ ok = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1625
  if binary_audio:
1626
+ try:
1627
+ await websocket.send_bytes(buf.getvalue())
1628
+ ok = await send_json({"type": "chunk_meta", "metadata": {"sample_rate": mrt.sample_rate}})
1629
+ except Exception:
1630
+ ok = False
1631
  else:
1632
+ b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
1633
+ ok = await send_json({"type": "chunk", "audio_base64": b64,
1634
+ "metadata": {"sample_rate": mrt.sample_rate}})
1635
 
1636
+ if not ok:
1637
+ break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1638
 
1639
+ if getattr(websocket, "_pace", "asap") == "realtime":
1640
+ t1 = time.perf_counter()
1641
+ target_next += chunk_secs
1642
+ sleep_s = max(0.0, target_next - t1 - 0.02)
1643
+ if sleep_s > 0:
1644
+ await asyncio.sleep(sleep_s)
1645
+ except asyncio.CancelledError:
1646
+ pass
1647
+ except Exception:
1648
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1649
 
1650
+ websocket._rt_task = asyncio.create_task(_rt_loop())
1651
+
1652
+ # --- UPDATE (rt-mode knobs & style) ---
1653
+ elif mtype == "update":
1654
+ if not hasattr(websocket, "_rt_running"):
1655
+ await send_json({"type":"error","error":"Session not started"}); continue
1656
+
1657
+ websocket._rt_temp = float(msg.get("temperature", websocket._rt_temp))
1658
+ websocket._rt_topk = int(msg.get("topk", websocket._rt_topk))
1659
+ websocket._rt_guid = float(msg.get("guidance_weight", websocket._rt_guid))
1660
+
1661
+ # Steering fields
1662
+ if "mean" in msg and msg["mean"] is not None:
1663
+ try: websocket._rt_mean = float(msg["mean"])
1664
+ except: websocket._rt_mean = 0.0
1665
+
1666
+ if "centroid_weights" in msg:
1667
+ cw = [w.strip() for w in str(msg["centroid_weights"]).split(",") if w.strip() != ""]
1668
+ try:
1669
+ websocket._rt_centroid_weights = [float(x) for x in cw]
1670
+ except:
1671
+ websocket._rt_centroid_weights = []
1672
+
1673
+ # Styles / text weights (optional, comma-separated)
1674
+ styles_str = msg.get("styles", None)
1675
+ style_weights_str = msg.get("style_weights", "")
1676
+
1677
+ text_list = [s for s in (styles_str.split(",") if styles_str else []) if s.strip()]
1678
+ text_w = [float(x) for x in style_weights_str.split(",")] if style_weights_str else []
1679
+
1680
+ asset_manager.ensure_assets_loaded(get_mrt())
1681
+ websocket._style_tgt = build_style_vector(
1682
+ websocket._mrt,
1683
+ text_styles=text_list,
1684
+ text_weights=text_w,
1685
+ loop_embed=None,
1686
+ loop_weight=None,
1687
+ mean_weight=float(websocket._rt_mean),
1688
+ centroid_weights=websocket._rt_centroid_weights,
1689
+ )
1690
+ # Optionally allow live changes to ramp:
1691
+ if "style_ramp_seconds" in msg:
1692
+ try: websocket._style_ramp_s = float(msg["style_ramp_seconds"])
1693
+ except: pass
1694
+ await send_json({"type":"status","updated":"rt-knobs+style"})
1695
+
1696
+ # --- BUFFER STATUS (from frontend for adaptive pacing - acknowledged silently) ---
1697
+ elif mtype == "buffer_status":
1698
+ # Frontend reports its buffer level; could be used for adaptive pacing
1699
+ # For now we just acknowledge receipt without action
1700
+ pass
1701
 
1702
+ # --- STOP ---
1703
+ elif mtype == "stop":
1704
+ websocket._rt_running = False
1705
+ task = getattr(websocket, "_rt_task", None)
1706
+ if task is not None:
1707
+ task.cancel()
1708
+ try: await task
1709
+ except asyncio.CancelledError: pass
1710
+ await send_json({"type":"stopped"})
1711
+ break
1712
+
1713
+ # --- PING/PONG ---
1714
  elif mtype == "ping":
1715
  await send_json({"type":"pong"})
1716
 
 
1718
  await send_json({"type":"error","error":f"Unknown type {mtype}"})
1719
 
1720
  except WebSocketDisconnect:
 
1721
  pass
1722
  except Exception as e:
1723
  try:
 
1725
  except Exception:
1726
  pass
1727
  finally:
1728
+ # Ensure streaming loop is stopped
1729
+ if hasattr(websocket, "_rt_running"):
1730
+ websocket._rt_running = False
1731
+ task = getattr(websocket, "_rt_task", None)
1732
+ if task is not None:
1733
+ task.cancel()
1734
+ try: await asyncio.wait_for(task, timeout=1.0)
1735
+ except: pass
1736
  try:
1737
  if websocket.client_state != WebSocketState.DISCONNECTED:
1738
  await websocket.close()
magentaRT_rt_tester.html CHANGED
@@ -256,93 +256,45 @@
256
  const rngC4 = $("rngC4"), numC4 = $("numC4");
257
  const rngC5 = $("rngC5"), numC5 = $("numC5");
258
 
259
- const XFADE_MS = 40; // crossfade length
 
260
 
261
  let pending = []; // decoded AudioBuffers waiting to be scheduled
262
- let playing = false; // have we started playback?
263
- const START_CUSHION = 0.12; // already used
264
-
265
- const fade = XFADE_MS / 1000;
266
-
267
- // Equal-power crossfading functions
268
- function equalPowerFadeOut(t) {
269
- // cos²(t * π/2) where t goes from 0 to 1
270
- return Math.cos(t * Math.PI / 2) ** 2;
271
- }
272
-
273
- function equalPowerFadeIn(t) {
274
- // sin²(t * π/2) where t goes from 0 to 1
275
- return Math.sin(t * Math.PI / 2) ** 2;
276
- }
277
-
278
- function scheduleAudioBuffer(abuf) {
279
- // Equal-power crossfade scheduling
280
- const src = ctx.createBufferSource();
281
- const g = ctx.createGain();
282
- src.buffer = abuf;
283
- src.connect(g); g.connect(gain);
284
-
285
- if (nextTime < ctx.currentTime + 0.05) nextTime = ctx.currentTime + START_CUSHION;
286
- const startAt = nextTime;
287
- const dur = abuf.duration;
288
 
289
- // Overlap by 'fade' so there's no dip
290
- nextTime = startAt + Math.max(0, dur - fade);
291
 
292
- // Equal-power crossfading using custom curves
293
- const numPoints = 64; // More points for smoother curves
294
- const times = [];
295
- const values = [];
296
-
297
- // Fade in from 0 to 1 over fade duration
298
- for (let i = 0; i <= numPoints; i++) {
299
- const t = i / numPoints;
300
- const time = startAt + t * fade;
301
- const value = equalPowerFadeIn(t);
302
- times.push(time);
303
- values.push(value);
304
- }
305
-
306
- // Hold at 1.0 during the main portion
307
- const holdStart = startAt + fade;
308
- const holdEnd = startAt + Math.max(0, dur - fade);
309
- if (holdEnd > holdStart) {
310
- times.push(holdStart);
311
- values.push(1.0);
312
- times.push(holdEnd);
313
- values.push(1.0);
314
- }
315
-
316
- // Fade out from 1 to 0 over fade duration
317
- for (let i = 0; i <= numPoints; i++) {
318
- const t = i / numPoints;
319
- const time = startAt + Math.max(0, dur - fade) + t * fade;
320
- const value = equalPowerFadeOut(t);
321
- times.push(time);
322
- values.push(value);
323
- }
324
 
325
- // Apply the envelope
326
- g.gain.setValueAtTime(values[0], times[0]);
327
- for (let i = 1; i < times.length; i++) {
328
- g.gain.linearRampToValueAtTime(values[i], times[i]);
329
  }
330
 
331
- src.start(startAt);
332
- scheduled.push({ src, when: startAt, dur });
333
- updateQueueUI();
334
- src.onended = () => { scheduled = scheduled.filter(s => s.src !== src); updateQueueUI(); };
335
- }
336
-
337
- function beginPlaybackFromPending() {
338
- if (playing) return;
339
- playing = true;
340
- nextTime = ctx.currentTime + START_CUSHION;
341
- while (pending.length) {
342
- const abuf = pending.shift();
343
- scheduleAudioBuffer(abuf);
344
  }
345
- }
346
 
347
  // Audio chain
348
  let AudioCtx = window.AudioContext || window.webkitAudioContext;
 
256
  const rngC4 = $("rngC4"), numC4 = $("numC4");
257
  const rngC5 = $("rngC5"), numC5 = $("numC5");
258
 
259
+ // Simplified playback - server handles crossfading via MagentaRTState
260
+ // Chunks arrive pre-crossfaded; we just schedule them back-to-back
261
 
262
  let pending = []; // decoded AudioBuffers waiting to be scheduled
263
+ let playing = false; // have we started playback?
264
+ const START_CUSHION = 0.12; // initial buffer before first playback
265
+
266
+ function scheduleAudioBuffer(abuf) {
267
+ // Simple back-to-back scheduling (no client-side crossfade needed)
268
+ const src = ctx.createBufferSource();
269
+ src.buffer = abuf;
270
+ src.connect(gain);
271
+
272
+ // Ensure we don't schedule in the past
273
+ if (nextTime < ctx.currentTime + 0.05) {
274
+ nextTime = ctx.currentTime + START_CUSHION;
275
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
276
 
277
+ const startAt = nextTime;
278
+ const dur = abuf.duration;
279
 
280
+ // Schedule next chunk right after this one ends
281
+ nextTime = startAt + dur;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
+ src.start(startAt);
284
+ scheduled.push({ src, when: startAt, dur });
285
+ updateQueueUI();
286
+ src.onended = () => { scheduled = scheduled.filter(s => s.src !== src); updateQueueUI(); };
287
  }
288
 
289
+ function beginPlaybackFromPending() {
290
+ if (playing) return;
291
+ playing = true;
292
+ nextTime = ctx.currentTime + START_CUSHION;
293
+ while (pending.length) {
294
+ const abuf = pending.shift();
295
+ scheduleAudioBuffer(abuf);
296
+ }
 
 
 
 
 
297
  }
 
298
 
299
  // Audio chain
300
  let AudioCtx = window.AudioContext || window.webkitAudioContext;