adityas129 commited on
Commit
0aab971
·
verified ·
1 Parent(s): d3b9e54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +177 -68
app.py CHANGED
@@ -143,14 +143,18 @@ def _sync_assets_globals_from_manager():
143
 
144
  def _any_jam_running() -> bool:
145
  with jam_lock:
146
- return any(w.is_alive() for w in jam_registry.values())
147
 
148
  def _stop_all_jams(timeout: float = 5.0):
149
  with jam_lock:
150
- for sid, w in list(jam_registry.items()):
 
151
  if w.is_alive():
152
  w.stop()
153
  w.join(timeout=timeout)
 
 
 
154
  jam_registry.pop(sid, None)
155
 
156
 
@@ -202,7 +206,7 @@ def _patch_t5x_for_gpu_coords():
202
  # Call the patch immediately at import time (before MagentaRT init)
203
  _patch_t5x_for_gpu_coords()
204
 
205
- jam_registry: dict[str, JamWorker] = {}
206
  jam_lock = threading.Lock()
207
 
208
  # ============================================================================
@@ -1157,13 +1161,16 @@ def jam_stop(session_id: str = Body(..., embed=True)):
1157
  mrt_index = session_info['mrt_index']
1158
 
1159
  worker.stop()
1160
- worker.join(timeout=5.0)
1161
  if worker.is_alive():
1162
- # It's daemon=True, so it won't block process exit, but report it
1163
- print(f"⚠️ JamWorker {session_id} did not stop within timeout")
 
 
1164
 
1165
- # Release MRT back to pool
1166
- release_mrt(mrt_index)
 
1167
 
1168
  with jam_lock:
1169
  jam_registry.pop(session_id, None)
@@ -1187,9 +1194,63 @@ def jam_stop_all():
1187
  # Release MRT back to pool
1188
  release_mrt(mrt_index)
1189
  jam_registry.pop(session_id, None)
1190
-
1191
  return {"stopped_sessions": stopped_sessions, "count": len(stopped_sessions)}
1192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1193
  @app.post("/jam/update")
1194
  def jam_update(
1195
  session_id: str = Form(...),
@@ -1394,6 +1455,91 @@ def jam_status(session_id: str):
1394
  "last_chunk_completed_at": worker.last_chunk_completed_at,
1395
  }
1396
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1397
 
1398
  @app.get("/health")
1399
  def health():
@@ -1426,7 +1572,7 @@ def health():
1426
  # 3) Ready; include operational hints
1427
  warmed = bool(_WARMED)
1428
  with jam_lock:
1429
- active_jams = sum(1 for w in jam_registry.values() if w.is_alive())
1430
  return {
1431
  "ok": True,
1432
  "status": "ready" if warmed else "initializing",
@@ -1550,66 +1696,28 @@ async def ws_jam(websocket: WebSocket):
1550
  sid = str(uuid.uuid4())
1551
  with jam_lock:
1552
  # single active jam per GPU, mirroring /jam/start
1553
- for _sid, w in list(jam_registry.items()):
1554
- if w.is_alive():
1555
  await send_json({"type":"error","error":"A jam is already running"})
1556
  worker = None; sid = None
1557
  break
1558
  if worker is not None:
1559
- jam_registry[sid] = worker
1560
  worker.start()
1561
 
1562
  else:
1563
- # mode == "rt" (with optional loop context)
1564
  mrt = get_mrt()
1565
  state = mrt.init_state()
1566
 
1567
- # Build context tokens (silent or from loop)
1568
  codec_fps = float(mrt.codec.frame_rate)
1569
  ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
1570
  sr = int(mrt.sample_rate)
1571
-
1572
- # Check for optional loop audio
1573
- loop_b64 = msg.get("loop_audio_b64")
1574
- loop_embed = None
1575
-
1576
- if loop_b64:
1577
- try:
1578
- # Decode and load loop (similar to bar-mode)
1579
- loop_bytes = base64.b64decode(loop_b64)
1580
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
1581
- tmp.write(loop_bytes)
1582
- tmp_path = tmp.name
1583
-
1584
- loop = au.Waveform.from_file(tmp_path).resample(mrt.sample_rate).as_stereo()
1585
-
1586
- # Extract bar-aligned tail for embedding
1587
- bpm = float(params.get("bpm", 120.0))
1588
- bpb = int(params.get("beats_per_bar", 4))
1589
- loop_tail = take_bar_aligned_tail(loop, bpm, bpb, ctx_seconds)
1590
-
1591
- # Embed the loop audio
1592
- loop_embed = mrt.embed_style(loop_tail)
1593
-
1594
- # Use loop audio for context tokens instead of silent
1595
- tokens = mrt.codec.encode(loop_tail).astype(np.int32)[:, :mrt.config.decoder_codec_rvq_depth]
1596
- state.context_tokens = tokens
1597
-
1598
- except Exception as e:
1599
- # Log error but continue with silent context
1600
- print(f"Loop audio processing failed: {e}")
1601
- loop_embed = None
1602
- # Fall back to silent context
1603
- samples = int(max(1, round(ctx_seconds * sr)))
1604
- silent = au.Waveform(np.zeros((samples, 2), np.float32), sr)
1605
- tokens = mrt.codec.encode(silent).astype(np.int32)[:, :mrt.config.decoder_codec_rvq_depth]
1606
- state.context_tokens = tokens
1607
- else:
1608
- # No loop provided - use silent context (original behavior)
1609
- samples = int(max(1, round(ctx_seconds * sr)))
1610
- silent = au.Waveform(np.zeros((samples, 2), np.float32), sr)
1611
- tokens = mrt.codec.encode(silent).astype(np.int32)[:, :mrt.config.decoder_codec_rvq_depth]
1612
- state.context_tokens = tokens
1613
 
1614
  # Parse params (including steering)
1615
  asset_manager.ensure_assets_loaded(get_mrt())
@@ -1617,7 +1725,6 @@ async def ws_jam(websocket: WebSocket):
1617
  style_weights_str = params.get("style_weights", "") or ""
1618
  mean_w = float(params.get("mean", 0.0) or 0.0)
1619
  cw_str = str(params.get("centroid_weights", "") or "")
1620
- loop_weight = float(params.get("loop_weight", 1.0) or 1.0)
1621
 
1622
  text_list = [s.strip() for s in styles_str.split(",") if s.strip()]
1623
  try:
@@ -1633,13 +1740,13 @@ async def ws_jam(websocket: WebSocket):
1633
  if _CENTROIDS is not None and len(cw) > int(_CENTROIDS.shape[0]):
1634
  cw = cw[: int(_CENTROIDS.shape[0])]
1635
 
1636
- # Build initial style vector WITH optional loop_embed
1637
  style_vec = build_style_vector(
1638
  mrt,
1639
  text_styles=text_list,
1640
  text_weights=text_w,
1641
- loop_embed=loop_embed,
1642
- loop_weight=loop_weight,
1643
  mean_weight=mean_w,
1644
  centroid_weights=cw,
1645
  )
@@ -1798,15 +1905,16 @@ async def ws_jam(websocket: WebSocket):
1798
 
1799
  elif mtype == "consume" and mode == "bar":
1800
  with jam_lock:
1801
- worker = jam_registry.get(msg.get("session_id"))
1802
- if worker is not None:
1803
- worker.mark_chunk_consumed(int(msg.get("chunk_index", -1)))
1804
 
1805
  elif mtype == "reseed" and mode == "bar":
1806
  with jam_lock:
1807
- worker = jam_registry.get(msg.get("session_id"))
1808
- if worker is None or not worker.is_alive():
1809
  await send_json({"type":"error","error":"Session not found"}); continue
 
1810
  loop_b64 = msg.get("loop_audio_b64")
1811
  if not loop_b64:
1812
  await send_json({"type":"error","error":"loop_audio_b64 required"}); continue
@@ -1819,9 +1927,10 @@ async def ws_jam(websocket: WebSocket):
1819
 
1820
  elif mtype == "reseed_splice" and mode == "bar":
1821
  with jam_lock:
1822
- worker = jam_registry.get(msg.get("session_id"))
1823
- if worker is None or not worker.is_alive():
1824
  await send_json({"type":"error","error":"Session not found"}); continue
 
1825
  anchor = float(msg.get("anchor_bars", 2.0))
1826
  b64 = msg.get("combined_audio_b64")
1827
  if b64:
 
143
 
144
  def _any_jam_running() -> bool:
145
  with jam_lock:
146
+ return any(info['worker'].is_alive() for info in jam_registry.values())
147
 
148
  def _stop_all_jams(timeout: float = 5.0):
149
  with jam_lock:
150
+ for sid, info in list(jam_registry.items()):
151
+ w = info['worker']
152
  if w.is_alive():
153
  w.stop()
154
  w.join(timeout=timeout)
155
+ # Release MRT slot
156
+ if info.get('mrt_index') is not None:
157
+ release_mrt(info['mrt_index'])
158
  jam_registry.pop(sid, None)
159
 
160
 
 
206
  # Call the patch immediately at import time (before MagentaRT init)
207
  _patch_t5x_for_gpu_coords()
208
 
209
+ jam_registry: dict[str, dict] = {} # Now stores {'worker': JamWorker, 'mrt_index': int}
210
  jam_lock = threading.Lock()
211
 
212
  # ============================================================================
 
1161
  mrt_index = session_info['mrt_index']
1162
 
1163
  worker.stop()
1164
+ worker.join(timeout=10.0) # Increased from 5s to 10s to allow chunk generation to finish
1165
  if worker.is_alive():
1166
+ # Worker still running - don't release MRT slot to avoid corruption
1167
+ print(f"⚠️ JamWorker {session_id} did not stop within timeout - keeping MRT slot reserved")
1168
+ # Keep in registry so we can try to stop it again later
1169
+ return {"stopped": False, "timeout": True, "message": "Worker did not stop in time, retry /jam/stop"}
1170
 
1171
+ # Only release MRT if worker actually stopped
1172
+ if mrt_index is not None:
1173
+ release_mrt(mrt_index)
1174
 
1175
  with jam_lock:
1176
  jam_registry.pop(session_id, None)
 
1194
  # Release MRT back to pool
1195
  release_mrt(mrt_index)
1196
  jam_registry.pop(session_id, None)
1197
+
1198
  return {"stopped_sessions": stopped_sessions, "count": len(stopped_sessions)}
1199
 
1200
+ @app.post("/jam/cleanup")
1201
+ def jam_cleanup(force: bool = Form(False), idle_threshold_seconds: float = Form(300.0)):
1202
+ """
1203
+ Enhanced cleanup endpoint for stopping stale/orphaned sessions.
1204
+
1205
+ - force=False: Only stops sessions idle for > idle_threshold_seconds (default 5 min)
1206
+ - force=True: Stops ALL sessions regardless of activity (nuclear option)
1207
+ """
1208
+ stopped_sessions = []
1209
+ kept_sessions = []
1210
+ current_time = time.time()
1211
+
1212
+ with jam_lock:
1213
+ for session_id, session_info in list(jam_registry.items()):
1214
+ worker = session_info['worker']
1215
+ mrt_index = session_info['mrt_index']
1216
+
1217
+ # Determine if session should be stopped
1218
+ should_stop = force
1219
+ idle_time = 0
1220
+
1221
+ if not force and hasattr(worker, 'last_activity_at'):
1222
+ idle_time = current_time - worker.last_activity_at
1223
+ should_stop = idle_time > idle_threshold_seconds
1224
+
1225
+ if should_stop:
1226
+ if worker.is_alive():
1227
+ worker.stop()
1228
+ worker.join(timeout=10.0)
1229
+
1230
+ # Release MRT regardless of whether worker stopped
1231
+ release_mrt(mrt_index)
1232
+ jam_registry.pop(session_id, None)
1233
+ stopped_sessions.append({
1234
+ "session_id": session_id,
1235
+ "idle_seconds": round(idle_time, 1),
1236
+ "slot": mrt_index
1237
+ })
1238
+ else:
1239
+ kept_sessions.append({
1240
+ "session_id": session_id,
1241
+ "idle_seconds": round(idle_time, 1),
1242
+ "slot": mrt_index
1243
+ })
1244
+
1245
+ return {
1246
+ "stopped": stopped_sessions,
1247
+ "kept": kept_sessions,
1248
+ "stopped_count": len(stopped_sessions),
1249
+ "kept_count": len(kept_sessions),
1250
+ "force": force,
1251
+ "idle_threshold_seconds": idle_threshold_seconds
1252
+ }
1253
+
1254
  @app.post("/jam/update")
1255
  def jam_update(
1256
  session_id: str = Form(...),
 
1455
  "last_chunk_completed_at": worker.last_chunk_completed_at,
1456
  }
1457
 
1458
+ @app.get("/jam/sessions")
1459
+ def jam_sessions():
1460
+ """List all active JAM sessions with metadata for monitoring"""
1461
+ sessions = []
1462
+ current_time = time.time()
1463
+
1464
+ with jam_lock:
1465
+ for session_id, session_info in jam_registry.items():
1466
+ worker = session_info['worker']
1467
+ mrt_index = session_info['mrt_index']
1468
+
1469
+ # Calculate uptime and idle time
1470
+ uptime = current_time - worker.created_at if hasattr(worker, 'created_at') else 0
1471
+ last_activity = worker.last_activity_at if hasattr(worker, 'last_activity_at') else worker.created_at if hasattr(worker, 'created_at') else current_time
1472
+ idle_time = current_time - last_activity
1473
+
1474
+ # Get generation stats
1475
+ with worker._lock:
1476
+ last_generated = int(worker.idx)
1477
+ last_delivered = int(worker._last_delivered_index)
1478
+ queued = len(worker.outbox)
1479
+
1480
+ sessions.append({
1481
+ "session_id": session_id,
1482
+ "mrt_slot": mrt_index,
1483
+ "running": worker.is_alive(),
1484
+ "uptime_seconds": round(uptime, 1),
1485
+ "idle_seconds": round(idle_time, 1),
1486
+ "chunks_generated": last_generated,
1487
+ "chunks_delivered": last_delivered,
1488
+ "chunks_queued": queued,
1489
+ "bpm": worker.params.bpm,
1490
+ "bars_per_chunk": worker.params.bars_per_chunk,
1491
+ "created_at": worker.created_at if hasattr(worker, 'created_at') else None,
1492
+ "last_activity_at": last_activity,
1493
+ })
1494
+
1495
+ return {
1496
+ "sessions": sessions,
1497
+ "total_active": len(sessions),
1498
+ "mrt_pool_size": len(_MRT_POOL),
1499
+ }
1500
+
1501
+ @app.get("/jam/sessions")
1502
+ def jam_sessions():
1503
+ """List all active JAM sessions with metadata for monitoring"""
1504
+ sessions = []
1505
+ current_time = time.time()
1506
+
1507
+ with jam_lock:
1508
+ for session_id, session_info in jam_registry.items():
1509
+ worker = session_info['worker']
1510
+ mrt_index = session_info['mrt_index']
1511
+
1512
+ # Calculate uptime and idle time
1513
+ uptime = current_time - worker.created_at if hasattr(worker, 'created_at') else 0
1514
+ last_activity = worker.last_activity_at if hasattr(worker, 'last_activity_at') else (worker.created_at if hasattr(worker, 'created_at') else current_time)
1515
+ idle_time = current_time - last_activity
1516
+
1517
+ # Get generation stats
1518
+ with worker._lock:
1519
+ last_generated = int(worker.idx)
1520
+ last_delivered = int(worker._next_to_deliver) - 1
1521
+ queued = len(worker._outbox)
1522
+
1523
+ sessions.append({
1524
+ "session_id": session_id,
1525
+ "mrt_slot": mrt_index,
1526
+ "running": worker.is_alive(),
1527
+ "uptime_seconds": round(uptime, 1),
1528
+ "idle_seconds": round(idle_time, 1),
1529
+ "chunks_generated": last_generated,
1530
+ "chunks_delivered": last_delivered,
1531
+ "chunks_queued": queued,
1532
+ "bpm": worker.params.bpm,
1533
+ "bars_per_chunk": worker.params.bars_per_chunk,
1534
+ "created_at": worker.created_at if hasattr(worker, 'created_at') else None,
1535
+ "last_activity_at": last_activity,
1536
+ })
1537
+
1538
+ return {
1539
+ "sessions": sessions,
1540
+ "total_active": len(sessions),
1541
+ "mrt_pool_size": len(_MRT_POOL),
1542
+ }
1543
 
1544
  @app.get("/health")
1545
  def health():
 
1572
  # 3) Ready; include operational hints
1573
  warmed = bool(_WARMED)
1574
  with jam_lock:
1575
+ active_jams = sum(1 for info in jam_registry.values() if info['worker'].is_alive())
1576
  return {
1577
  "ok": True,
1578
  "status": "ready" if warmed else "initializing",
 
1696
  sid = str(uuid.uuid4())
1697
  with jam_lock:
1698
  # single active jam per GPU, mirroring /jam/start
1699
+ for _sid, info in list(jam_registry.items()):
1700
+ if info['worker'].is_alive():
1701
  await send_json({"type":"error","error":"A jam is already running"})
1702
  worker = None; sid = None
1703
  break
1704
  if worker is not None:
1705
+ jam_registry[sid] = {'worker': worker, 'mrt_index': None} # WebSocket mode doesn't use MRT pool
1706
  worker.start()
1707
 
1708
  else:
1709
+ # mode == "rt" (Colab-style, no loop context)
1710
  mrt = get_mrt()
1711
  state = mrt.init_state()
1712
 
1713
+ # Build silent context (10s) tokens
1714
  codec_fps = float(mrt.codec.frame_rate)
1715
  ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
1716
  sr = int(mrt.sample_rate)
1717
+ samples = int(max(1, round(ctx_seconds * sr)))
1718
+ silent = au.Waveform(np.zeros((samples, 2), np.float32), sr)
1719
+ tokens = mrt.codec.encode(silent).astype(np.int32)[:, :mrt.config.decoder_codec_rvq_depth]
1720
+ state.context_tokens = tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1721
 
1722
  # Parse params (including steering)
1723
  asset_manager.ensure_assets_loaded(get_mrt())
 
1725
  style_weights_str = params.get("style_weights", "") or ""
1726
  mean_w = float(params.get("mean", 0.0) or 0.0)
1727
  cw_str = str(params.get("centroid_weights", "") or "")
 
1728
 
1729
  text_list = [s.strip() for s in styles_str.split(",") if s.strip()]
1730
  try:
 
1740
  if _CENTROIDS is not None and len(cw) > int(_CENTROIDS.shape[0]):
1741
  cw = cw[: int(_CENTROIDS.shape[0])]
1742
 
1743
+ # Build initial style vector (no loop_embed in rt mode)
1744
  style_vec = build_style_vector(
1745
  mrt,
1746
  text_styles=text_list,
1747
  text_weights=text_w,
1748
+ loop_embed=None,
1749
+ loop_weight=None,
1750
  mean_weight=mean_w,
1751
  centroid_weights=cw,
1752
  )
 
1905
 
1906
  elif mtype == "consume" and mode == "bar":
1907
  with jam_lock:
1908
+ session_info = jam_registry.get(msg.get("session_id"))
1909
+ if session_info is not None:
1910
+ session_info['worker'].mark_chunk_consumed(int(msg.get("chunk_index", -1)))
1911
 
1912
  elif mtype == "reseed" and mode == "bar":
1913
  with jam_lock:
1914
+ session_info = jam_registry.get(msg.get("session_id"))
1915
+ if session_info is None or not session_info['worker'].is_alive():
1916
  await send_json({"type":"error","error":"Session not found"}); continue
1917
+ worker = session_info['worker']
1918
  loop_b64 = msg.get("loop_audio_b64")
1919
  if not loop_b64:
1920
  await send_json({"type":"error","error":"loop_audio_b64 required"}); continue
 
1927
 
1928
  elif mtype == "reseed_splice" and mode == "bar":
1929
  with jam_lock:
1930
+ session_info = jam_registry.get(msg.get("session_id"))
1931
+ if session_info is None or not session_info['worker'].is_alive():
1932
  await send_json({"type":"error","error":"Session not found"}); continue
1933
+ worker = session_info['worker']
1934
  anchor = float(msg.get("anchor_bars", 2.0))
1935
  b64 = msg.get("combined_audio_b64")
1936
  if b64: