Commit
·
c29a250
1
Parent(s):
3d79c33
sometimes a claude yolo
Browse files- jam_worker.py +167 -226
jam_worker.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# jam_worker.py -
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
import os
|
|
@@ -20,7 +20,6 @@ from utils import (
|
|
| 20 |
)
|
| 21 |
|
| 22 |
def _dbg_rms_dbfs(x: np.ndarray) -> float:
|
| 23 |
-
|
| 24 |
if x.ndim == 2:
|
| 25 |
x = x.mean(axis=1)
|
| 26 |
r = float(np.sqrt(np.mean(x * x) + 1e-12))
|
|
@@ -28,7 +27,6 @@ def _dbg_rms_dbfs(x: np.ndarray) -> float:
|
|
| 28 |
|
| 29 |
def _dbg_rms_dbfs_model(x: np.ndarray) -> float:
|
| 30 |
# x is model-rate, shape [S,C] or [S]
|
| 31 |
-
|
| 32 |
if x.ndim == 2:
|
| 33 |
x = x.mean(axis=1)
|
| 34 |
r = float(np.sqrt(np.mean(x * x) + 1e-12))
|
|
@@ -37,6 +35,19 @@ def _dbg_rms_dbfs_model(x: np.ndarray) -> float:
|
|
| 37 |
def _dbg_shape(x):
|
| 38 |
return tuple(x.shape) if hasattr(x, "shape") else ("-",)
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
# -----------------------------
|
| 41 |
# Data classes
|
| 42 |
# -----------------------------
|
|
@@ -55,7 +66,7 @@ class JamParams:
|
|
| 55 |
guidance_weight: float = 1.1
|
| 56 |
temperature: float = 1.1
|
| 57 |
topk: int = 40
|
| 58 |
-
style_ramp_seconds: float = 8.0
|
| 59 |
|
| 60 |
|
| 61 |
@dataclass
|
|
@@ -110,8 +121,6 @@ class JamWorker(threading.Thread):
|
|
| 110 |
self.mrt.temperature = float(self.params.temperature)
|
| 111 |
self.mrt.topk = int(self.params.topk)
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
| 115 |
# codec/setup
|
| 116 |
self._codec_fps = float(self.mrt.codec.frame_rate)
|
| 117 |
JamWorker.FRAMES_PER_SECOND = self._codec_fps
|
|
@@ -137,8 +146,9 @@ class JamWorker(threading.Thread):
|
|
| 137 |
self._spool = np.zeros((0, 2), dtype=np.float32) # (S,2) target SR
|
| 138 |
self._spool_written = 0 # absolute frames written into spool
|
| 139 |
|
| 140 |
-
|
| 141 |
-
self.
|
|
|
|
| 142 |
|
| 143 |
# bar clock: start with offset 0; if you have a downbeat estimator, set base later
|
| 144 |
self._bar_clock = BarClock(self.params.target_sr, self.params.bpm, self.params.beats_per_bar, base_offset_samples=0)
|
|
@@ -163,6 +173,47 @@ class JamWorker(threading.Thread):
|
|
| 163 |
# Prepare initial context from combined loop (best musical alignment)
|
| 164 |
if self.params.combined_loop is not None:
|
| 165 |
self._install_context_from_loop(self.params.combined_loop)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
# ---------- lifecycle ----------
|
| 168 |
|
|
@@ -248,13 +299,7 @@ class JamWorker(threading.Thread):
|
|
| 248 |
return toks
|
| 249 |
|
| 250 |
def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
|
| 251 |
-
"""Build *exactly* context_length_frames worth of tokens
|
| 252 |
-
while ensuring the *end* of the audio lands on a bar boundary.
|
| 253 |
-
Strategy: take the largest integer number of bars <= ctx_seconds as the tail,
|
| 254 |
-
then left-fill from just before that tail (wrapping if needed) to reach exactly
|
| 255 |
-
ctx_seconds; finally, pad/trim to exact samples and, as a last resort, pad/trim
|
| 256 |
-
tokens to the expected frame count.
|
| 257 |
-
"""
|
| 258 |
wav = loop.as_stereo().resample(self._model_sr)
|
| 259 |
data = wav.samples.astype(np.float32, copy=False)
|
| 260 |
if data.ndim == 1:
|
|
@@ -289,8 +334,14 @@ class JamWorker(threading.Thread):
|
|
| 289 |
|
| 290 |
# final snap to *exact* ctx samples
|
| 291 |
if ctx.shape[0] < ctx_samps:
|
| 292 |
-
|
| 293 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
elif ctx.shape[0] > ctx_samps:
|
| 295 |
ctx = ctx[-ctx_samps:]
|
| 296 |
|
|
@@ -301,79 +352,20 @@ class JamWorker(threading.Thread):
|
|
| 301 |
|
| 302 |
# Force expected (F,D) at *return time*
|
| 303 |
tokens = self._coerce_tokens(tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
return tokens
|
| 305 |
|
| 306 |
-
def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
|
| 307 |
-
"""Build *exactly* context_length_frames worth of tokens (e.g., 250 @ 25fps),
|
| 308 |
-
while ensuring the *end* of the audio lands on a bar boundary.
|
| 309 |
-
Strategy: take the largest integer number of bars <= ctx_seconds as the tail,
|
| 310 |
-
then left-fill from just before that tail (wrapping if needed) to reach exactly
|
| 311 |
-
ctx_seconds; finally, pad/trim to exact samples and, as a last resort, pad/trim
|
| 312 |
-
tokens to the expected frame count.
|
| 313 |
-
"""
|
| 314 |
-
wav = loop.as_stereo().resample(self._model_sr)
|
| 315 |
-
data = wav.samples.astype(np.float32, copy=False)
|
| 316 |
-
if data.ndim == 1:
|
| 317 |
-
data = data[:, None]
|
| 318 |
-
|
| 319 |
-
spb = self._bar_clock.seconds_per_bar()
|
| 320 |
-
ctx_sec = float(self._ctx_seconds)
|
| 321 |
-
sr = int(self._model_sr)
|
| 322 |
-
|
| 323 |
-
# bars that fit fully inside ctx_sec (at least 1)
|
| 324 |
-
bars_fit = max(1, int(ctx_sec // spb))
|
| 325 |
-
tail_len_samps = int(round(bars_fit * spb * sr))
|
| 326 |
-
|
| 327 |
-
# ensure we have enough source by tiling
|
| 328 |
-
need = int(round(ctx_sec * sr)) + tail_len_samps
|
| 329 |
-
if data.shape[0] == 0:
|
| 330 |
-
data = np.zeros((1, 2), dtype=np.float32)
|
| 331 |
-
reps = int(np.ceil(need / float(data.shape[0])))
|
| 332 |
-
tiled = np.tile(data, (reps, 1))
|
| 333 |
-
|
| 334 |
-
end = tiled.shape[0]
|
| 335 |
-
tail = tiled[end - tail_len_samps:end]
|
| 336 |
-
|
| 337 |
-
# left-fill to reach exact ctx samples (keeps end-of-bar alignment)
|
| 338 |
-
ctx_samps = int(round(ctx_sec * sr))
|
| 339 |
-
pad_len = ctx_samps - tail.shape[0]
|
| 340 |
-
if pad_len > 0:
|
| 341 |
-
pre = tiled[end - tail_len_samps - pad_len:end - tail_len_samps]
|
| 342 |
-
ctx = np.concatenate([pre, tail], axis=0)
|
| 343 |
-
else:
|
| 344 |
-
ctx = tail[-ctx_samps:]
|
| 345 |
-
|
| 346 |
-
# final snap to *exact* ctx samples
|
| 347 |
-
if ctx.shape[0] < ctx_samps:
|
| 348 |
-
pad = np.zeros((ctx_samps - ctx.shape[0], ctx.shape[1]), dtype=np.float32)
|
| 349 |
-
ctx = np.concatenate([pad, ctx], axis=0)
|
| 350 |
-
elif ctx.shape[0] > ctx_samps:
|
| 351 |
-
ctx = ctx[-ctx_samps:]
|
| 352 |
-
|
| 353 |
-
exact = au.Waveform(ctx, sr)
|
| 354 |
-
tokens_full = self.mrt.codec.encode(exact).astype(np.int32)
|
| 355 |
-
depth = int(self.mrt.config.decoder_codec_rvq_depth)
|
| 356 |
-
tokens = tokens_full[:, :depth]
|
| 357 |
-
|
| 358 |
-
# Last defense: force expected frame count
|
| 359 |
-
frames = tokens.shape[0]
|
| 360 |
-
exp = int(self._ctx_frames)
|
| 361 |
-
if frames < exp:
|
| 362 |
-
# repeat last frame
|
| 363 |
-
pad = np.repeat(tokens[-1:, :], exp - frames, axis=0)
|
| 364 |
-
tokens = np.concatenate([pad, tokens], axis=0)
|
| 365 |
-
elif frames > exp:
|
| 366 |
-
tokens = tokens[-exp:, :]
|
| 367 |
-
return tokens
|
| 368 |
-
|
| 369 |
-
|
| 370 |
def _install_context_from_loop(self, loop: au.Waveform):
|
| 371 |
# Build exact-length, bar-locked context tokens
|
| 372 |
context_tokens = self._encode_exact_context_tokens(loop)
|
| 373 |
s = self.mrt.init_state()
|
| 374 |
s.context_tokens = context_tokens
|
| 375 |
self.state = s
|
| 376 |
-
self.
|
| 377 |
|
| 378 |
def reseed_from_waveform(self, wav: au.Waveform):
|
| 379 |
"""Immediate reseed: replace context from provided wave (bar-locked, exact length)."""
|
|
@@ -383,14 +375,11 @@ class JamWorker(threading.Thread):
|
|
| 383 |
s.context_tokens = context_tokens
|
| 384 |
self.state = s
|
| 385 |
self._model_stream = None # drop model-domain continuity so next chunk starts cleanly
|
| 386 |
-
self.
|
|
|
|
| 387 |
|
| 388 |
def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
|
| 389 |
-
"""Queue a *seamless* reseed by token splicing instead of full restart.
|
| 390 |
-
We compute a fresh, bar-locked context token tensor of exact length
|
| 391 |
-
(e.g., 250 frames), then splice only the *tail* corresponding to
|
| 392 |
-
`anchor_bars` so generation continues smoothly without resetting state.
|
| 393 |
-
"""
|
| 394 |
new_ctx = self._encode_exact_context_tokens(recent_wav) # coerce to (F,D)
|
| 395 |
F, D = self._expected_token_shape()
|
| 396 |
|
|
@@ -419,44 +408,20 @@ class JamWorker(threading.Thread):
|
|
| 419 |
"tokens": spliced,
|
| 420 |
"debug": {"F": F, "D": D, "splice_frames": splice_frames, "frames_per_bar": frames_per_bar}
|
| 421 |
}
|
| 422 |
-
|
| 423 |
|
| 424 |
-
|
| 425 |
-
def reseed_from_waveform(self, wav: au.Waveform):
|
| 426 |
-
"""Immediate reseed: replace context from provided wave (bar-aligned tail)."""
|
| 427 |
-
wav = wav.as_stereo().resample(self._model_sr)
|
| 428 |
-
tail = take_bar_aligned_tail(wav, self.params.bpm, self.params.beats_per_bar, self._ctx_seconds)
|
| 429 |
-
tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
|
| 430 |
-
depth = int(self.mrt.config.decoder_codec_rvq_depth)
|
| 431 |
-
context_tokens = tokens_full[:, :depth]
|
| 432 |
-
|
| 433 |
-
s = self.mrt.init_state()
|
| 434 |
-
s.context_tokens = context_tokens
|
| 435 |
-
self.state = s
|
| 436 |
-
# reset model stream so next generate starts cleanly
|
| 437 |
-
self._model_stream = None
|
| 438 |
-
|
| 439 |
-
# optional loudness match will be applied per-chunk on emission
|
| 440 |
-
|
| 441 |
-
# also remember this as new "original"
|
| 442 |
-
self._original_context_tokens = np.copy(context_tokens)
|
| 443 |
-
|
| 444 |
-
# ---------- core streaming helpers ----------
|
| 445 |
|
| 446 |
def _append_model_chunk_and_spool(self, wav: au.Waveform) -> None:
|
| 447 |
"""
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
This keeps external timing and bar alignment identical, but removes the audible
|
| 456 |
-
fade-to-zero at chunk ends.
|
| 457 |
"""
|
| 458 |
-
|
| 459 |
-
# ---- unpack model-rate samples ----
|
| 460 |
s = wav.samples.astype(np.float32, copy=False)
|
| 461 |
if s.ndim == 1:
|
| 462 |
s = s[:, None]
|
|
@@ -464,119 +429,90 @@ class JamWorker(threading.Thread):
|
|
| 464 |
if n_samps == 0:
|
| 465 |
return
|
| 466 |
|
| 467 |
-
#
|
|
|
|
|
|
|
|
|
|
| 468 |
try:
|
| 469 |
xfade_s = float(self.mrt.config.crossfade_length)
|
| 470 |
except Exception:
|
| 471 |
xfade_s = 0.0
|
| 472 |
xfade_n = int(round(max(0.0, xfade_s) * float(self._model_sr)))
|
| 473 |
|
| 474 |
-
|
|
|
|
|
|
|
| 475 |
def to_target(y: np.ndarray) -> np.ndarray:
|
| 476 |
return y if self._rs is None else self._rs.process(y, final=False)
|
| 477 |
|
| 478 |
-
#
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
# DEBUG: corrected overlap RMS (what we intend to hear at the boundary)
|
| 495 |
-
if y_mixed.size:
|
| 496 |
-
print(f"[append] mixedOverlap len={y_mixed.shape[0]} rms={_dbg_rms_dbfs(y_mixed):+.1f} dBFS")
|
| 497 |
-
|
| 498 |
-
# Overwrite the last `_pending_tail_target_len` samples of the spool with `y_mixed`.
|
| 499 |
-
# Use the *smaller* of the two lengths to be safe.
|
| 500 |
-
Lpop = min(self._pending_tail_target_len, self._spool.shape[0], Lcorr)
|
| 501 |
-
if Lpop > 0 and self._spool.size:
|
| 502 |
-
# Trim last Lpop samples
|
| 503 |
-
self._spool = self._spool[:-Lpop, :]
|
| 504 |
-
self._spool_written -= Lpop
|
| 505 |
-
# Append corrected overlap (trim/pad to Lpop to avoid drift)
|
| 506 |
-
if Lcorr != Lpop:
|
| 507 |
-
if Lcorr > Lpop:
|
| 508 |
-
y_m = y_mixed[-Lpop:, :]
|
| 509 |
-
else:
|
| 510 |
-
pad = np.zeros((Lpop - Lcorr, y_mixed.shape[1]), dtype=np.float32)
|
| 511 |
-
y_m = np.concatenate([y_mixed, pad], axis=0)
|
| 512 |
-
else:
|
| 513 |
-
y_m = y_mixed
|
| 514 |
-
self._spool = np.concatenate([self._spool, y_m], axis=0) if self._spool.size else y_m
|
| 515 |
-
self._spool_written += y_m.shape[0]
|
| 516 |
-
|
| 517 |
-
# For internal continuity, update _model_stream like before
|
| 518 |
-
if self._model_stream is None or self._model_stream.shape[0] < xfade_n:
|
| 519 |
-
self._model_stream = s[xfade_n:].copy()
|
| 520 |
-
else:
|
| 521 |
-
self._model_stream = np.concatenate([self._model_stream[:-xfade_n], mixed_model, s[xfade_n:]], axis=0)
|
| 522 |
-
else:
|
| 523 |
-
# First-ever call or too-short to mix: maintain _model_stream minimally
|
| 524 |
-
if xfade_n > 0 and n_samps > xfade_n:
|
| 525 |
-
self._model_stream = s[xfade_n:].copy() if self._model_stream is None else np.concatenate([self._model_stream, s[xfade_n:]], axis=0)
|
| 526 |
-
else:
|
| 527 |
-
self._model_stream = s.copy() if self._model_stream is None else np.concatenate([self._model_stream, s], axis=0)
|
| 528 |
-
|
| 529 |
-
# ------------------------------------------
|
| 530 |
-
# (B) Emit THIS chunk's body and tail (same external behavior)
|
| 531 |
-
# ------------------------------------------
|
| 532 |
-
if xfade_n > 0 and n_samps >= (2 * xfade_n):
|
| 533 |
-
body = s[xfade_n:-xfade_n, :]
|
| 534 |
-
print(f"[model] body len={body.shape[0]} rms={_dbg_rms_dbfs_model(body):+.1f} dBFS")
|
| 535 |
-
if body.size:
|
| 536 |
-
y_body = to_target(body.astype(np.float32))
|
| 537 |
-
if y_body.size:
|
| 538 |
-
# DEBUG: body RMS we are actually appending
|
| 539 |
-
print(f"[append] body len={y_body.shape[0]} rms={_dbg_rms_dbfs(y_body):+.1f} dBFS")
|
| 540 |
-
self._spool = np.concatenate([self._spool, y_body], axis=0) if self._spool.size else y_body
|
| 541 |
-
self._spool_written += y_body.shape[0]
|
| 542 |
else:
|
| 543 |
-
#
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 559 |
if xfade_n > 0 and n_samps >= xfade_n:
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
y_tail = to_target(tail.astype(np.float32))
|
| 563 |
-
Ltail = int(y_tail.shape[0])
|
| 564 |
-
if Ltail:
|
| 565 |
-
# DEBUG: tail RMS we are appending now (to be corrected next call)
|
| 566 |
-
print(f"[append] tail len={y_tail.shape[0]} rms={_dbg_rms_dbfs(y_tail):+.1f} dBFS")
|
| 567 |
-
self._spool = np.concatenate([self._spool, y_tail], axis=0) if self._spool.size else y_tail
|
| 568 |
-
self._spool_written += Ltail
|
| 569 |
-
self._pending_tail_model = tail.copy()
|
| 570 |
-
self._pending_tail_target_len = Ltail
|
| 571 |
-
else:
|
| 572 |
-
# Nothing appended (resampler returned nothing yet) — keep model tail but mark zero target len
|
| 573 |
-
self._pending_tail_model = tail.copy()
|
| 574 |
-
self._pending_tail_target_len = 0
|
| 575 |
else:
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 580 |
|
| 581 |
def _should_generate_next_chunk(self) -> bool:
|
| 582 |
# Allow running ahead relative to whichever is larger: last *consumed*
|
|
@@ -613,6 +549,7 @@ class JamWorker(threading.Thread):
|
|
| 613 |
"guidance_weight": float(self.params.guidance_weight),
|
| 614 |
"temperature": float(self.params.temperature),
|
| 615 |
"topk": int(self.params.topk),
|
|
|
|
| 616 |
}
|
| 617 |
chunk = JamChunk(index=self.idx, audio_base64=audio_b64, metadata=meta)
|
| 618 |
|
|
@@ -637,6 +574,7 @@ class JamWorker(threading.Thread):
|
|
| 637 |
# inplace update (no reset)
|
| 638 |
self.state.context_tokens = spliced
|
| 639 |
self._pending_token_splice = None
|
|
|
|
| 640 |
except Exception:
|
| 641 |
# fallback: full reseed using spliced tokens
|
| 642 |
new_state = self.mrt.init_state()
|
|
@@ -644,6 +582,7 @@ class JamWorker(threading.Thread):
|
|
| 644 |
self.state = new_state
|
| 645 |
self._model_stream = None
|
| 646 |
self._pending_token_splice = None
|
|
|
|
| 647 |
elif self._pending_reseed is not None:
|
| 648 |
ctx = self._coerce_tokens(self._pending_reseed["ctx"])
|
| 649 |
new_state = self.mrt.init_state()
|
|
@@ -651,6 +590,7 @@ class JamWorker(threading.Thread):
|
|
| 651 |
self.state = new_state
|
| 652 |
self._model_stream = None
|
| 653 |
self._pending_reseed = None
|
|
|
|
| 654 |
|
| 655 |
# ---------- main loop ----------
|
| 656 |
|
|
@@ -687,9 +627,10 @@ class JamWorker(threading.Thread):
|
|
| 687 |
self._emit_ready()
|
| 688 |
|
| 689 |
# finalize resampler (flush) — not strictly necessary here
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
|
|
|
| 694 |
# one last emit attempt
|
| 695 |
-
self._emit_ready()
|
|
|
|
| 1 |
+
# jam_worker.py - Updated with robust silence handling
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
import os
|
|
|
|
| 20 |
)
|
| 21 |
|
| 22 |
def _dbg_rms_dbfs(x: np.ndarray) -> float:
|
|
|
|
| 23 |
if x.ndim == 2:
|
| 24 |
x = x.mean(axis=1)
|
| 25 |
r = float(np.sqrt(np.mean(x * x) + 1e-12))
|
|
|
|
| 27 |
|
| 28 |
def _dbg_rms_dbfs_model(x: np.ndarray) -> float:
|
| 29 |
# x is model-rate, shape [S,C] or [S]
|
|
|
|
| 30 |
if x.ndim == 2:
|
| 31 |
x = x.mean(axis=1)
|
| 32 |
r = float(np.sqrt(np.mean(x * x) + 1e-12))
|
|
|
|
| 35 |
def _dbg_shape(x):
|
| 36 |
return tuple(x.shape) if hasattr(x, "shape") else ("-",)
|
| 37 |
|
| 38 |
+
def _is_silent(audio: np.ndarray, threshold_db: float = -60.0) -> bool:
    """Return True when *audio* carries no usable signal.

    The level is the RMS of the samples expressed in dB relative to 1.0
    (floored at 1e-12 so the logarithm is always defined). A 2-D input is
    averaged across axis 1 before the RMS is taken.

    Args:
        audio: Sample array, 1-D or 2-D. Any numeric dtype is accepted; the
            samples are promoted to float64 so squaring cannot overflow.
        threshold_db: Level in dB below which the audio counts as silent.

    Returns:
        True for empty input, for input containing NaN/Inf, or when the RMS
        level is below ``threshold_db``; False otherwise.
    """
    samples = np.asarray(audio, dtype=np.float64)
    if samples.size == 0:
        return True
    # NaN/Inf would propagate into the RMS and make the comparison below
    # evaluate to False, i.e. corrupt audio would be reported as "not silent".
    # There is no usable signal in that case, so report it as silent.
    if not np.all(np.isfinite(samples)):
        return True
    if samples.ndim == 2:
        samples = samples.mean(axis=1)
    rms = float(np.sqrt(np.mean(np.square(samples))))
    level_db = 20.0 * np.log10(max(rms, 1e-12))
    # Cast so callers get a plain bool rather than a numpy scalar.
    return bool(level_db < threshold_db)


def _has_energy(audio: np.ndarray, threshold_db: float = -40.0) -> bool:
    """Return True when *audio* is at or above ``threshold_db``.

    This is the negation of ``_is_silent`` evaluated with a higher default
    threshold (-40 dB instead of -60 dB), so it is a stricter test than
    merely "not silent".
    """
    return not _is_silent(audio, threshold_db)
|
| 50 |
+
|
| 51 |
# -----------------------------
|
| 52 |
# Data classes
|
| 53 |
# -----------------------------
|
|
|
|
| 66 |
guidance_weight: float = 1.1
|
| 67 |
temperature: float = 1.1
|
| 68 |
topk: int = 40
|
| 69 |
+
style_ramp_seconds: float = 8.0
|
| 70 |
|
| 71 |
|
| 72 |
@dataclass
|
|
|
|
| 121 |
self.mrt.temperature = float(self.params.temperature)
|
| 122 |
self.mrt.topk = int(self.params.topk)
|
| 123 |
|
|
|
|
|
|
|
| 124 |
# codec/setup
|
| 125 |
self._codec_fps = float(self.mrt.codec.frame_rate)
|
| 126 |
JamWorker.FRAMES_PER_SECOND = self._codec_fps
|
|
|
|
| 146 |
self._spool = np.zeros((0, 2), dtype=np.float32) # (S,2) target SR
|
| 147 |
self._spool_written = 0 # absolute frames written into spool
|
| 148 |
|
| 149 |
+
# Health monitoring
|
| 150 |
+
self._silence_streak = 0 # consecutive silent chunks
|
| 151 |
+
self._last_good_context_tokens = None # backup of last known good context
|
| 152 |
|
| 153 |
# bar clock: start with offset 0; if you have a downbeat estimator, set base later
|
| 154 |
self._bar_clock = BarClock(self.params.target_sr, self.params.bpm, self.params.beats_per_bar, base_offset_samples=0)
|
|
|
|
| 173 |
# Prepare initial context from combined loop (best musical alignment)
|
| 174 |
if self.params.combined_loop is not None:
|
| 175 |
self._install_context_from_loop(self.params.combined_loop)
|
| 176 |
+
# Save this as our "good" context backup
|
| 177 |
+
if hasattr(self.state, 'context_tokens') and self.state.context_tokens is not None:
|
| 178 |
+
self._last_good_context_tokens = np.copy(self.state.context_tokens)
|
| 179 |
+
|
| 180 |
+
# ---------- NEW: Health monitoring methods ----------
|
| 181 |
+
|
| 182 |
+
def _check_model_health(self, new_chunk: np.ndarray) -> bool:
    """Classify a freshly generated chunk and update the silence streak.

    A chunk is unhealthy when it contains NaN/Inf samples or when
    ``_is_silent`` reports it below -80 dB. Each unhealthy chunk increments
    ``self._silence_streak``; the first healthy chunk resets it to zero.

    Args:
        new_chunk: Samples of the chunk that was just generated.

    Returns:
        True if the chunk looks healthy, False otherwise.
    """
    samples = np.asarray(new_chunk)
    # Non-finite output must be caught explicitly: a NaN RMS compares False
    # against any threshold, so a level check alone would call it healthy.
    corrupt = samples.size > 0 and not bool(np.all(np.isfinite(samples)))

    if corrupt or _is_silent(samples, threshold_db=-80.0):
        self._silence_streak += 1
        label = "Non-finite" if corrupt else "Silent"
        print(f"⚠️ {label} chunk detected (streak: {self._silence_streak})")
        return False

    if self._silence_streak > 0:
        print(f"✅ Audio resumed after {self._silence_streak} silent chunks")
    self._silence_streak = 0
    return True
|
| 193 |
+
|
| 194 |
+
def _recover_from_silence(self):
    """Try to get generation going again after a run of silent chunks.

    Recovery is best-effort and never raises:

    1. If a backup of the last good context tokens exists, install a copy
       of it into a freshly initialised model state.
    2. Otherwise, or if step 1 fails, rebuild the context from
       ``self.params.combined_loop`` when one is available.

    On success ``self._model_stream`` is set to None.
    """
    print("🔧 Attempting recovery from silence...")

    if self._last_good_context_tokens is not None:
        try:
            new_state = self.mrt.init_state()
            new_state.context_tokens = np.copy(self._last_good_context_tokens)
        except Exception as e:
            # Do not give up here: fall through to the loop-based rebuild.
            print(f" Context restoration failed: {e}")
        else:
            self.state = new_state
            self._model_stream = None  # Reset stream to start fresh
            print(" Restored last good context")
            return

    # No backup, or restoring it failed: rebuild from the original loop.
    if self.params.combined_loop is not None:
        try:
            self._install_context_from_loop(self.params.combined_loop)
        except Exception as e:
            print(f" Context rebuild failed: {e}")
        else:
            self._model_stream = None
            print(" Rebuilt context from original loop")
|
| 217 |
|
| 218 |
# ---------- lifecycle ----------
|
| 219 |
|
|
|
|
| 299 |
return toks
|
| 300 |
|
| 301 |
def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
|
| 302 |
+
"""Build *exactly* context_length_frames worth of tokens, ensuring bar alignment."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
wav = loop.as_stereo().resample(self._model_sr)
|
| 304 |
data = wav.samples.astype(np.float32, copy=False)
|
| 305 |
if data.ndim == 1:
|
|
|
|
| 334 |
|
| 335 |
# final snap to *exact* ctx samples
|
| 336 |
if ctx.shape[0] < ctx_samps:
|
| 337 |
+
# Instead of zero padding, repeat the audio to fill
|
| 338 |
+
shortfall = ctx_samps - ctx.shape[0]
|
| 339 |
+
if ctx.shape[0] > 0:
|
| 340 |
+
fill = np.tile(ctx, (int(np.ceil(shortfall / ctx.shape[0])) + 1, 1))[:shortfall]
|
| 341 |
+
ctx = np.concatenate([fill, ctx], axis=0)
|
| 342 |
+
else:
|
| 343 |
+
print("⚠️ Zero-length context, using fallback")
|
| 344 |
+
ctx = np.zeros((ctx_samps, 2), dtype=np.float32)
|
| 345 |
elif ctx.shape[0] > ctx_samps:
|
| 346 |
ctx = ctx[-ctx_samps:]
|
| 347 |
|
|
|
|
| 352 |
|
| 353 |
# Force expected (F,D) at *return time*
|
| 354 |
tokens = self._coerce_tokens(tokens)
|
| 355 |
+
|
| 356 |
+
# Validate that we don't have a silent context
|
| 357 |
+
if _is_silent(ctx, threshold_db=-80.0):
|
| 358 |
+
print("⚠️ Generated silent context - this may cause issues")
|
| 359 |
+
|
| 360 |
return tokens
|
| 361 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
def _install_context_from_loop(self, loop: au.Waveform):
|
| 363 |
# Build exact-length, bar-locked context tokens
|
| 364 |
context_tokens = self._encode_exact_context_tokens(loop)
|
| 365 |
s = self.mrt.init_state()
|
| 366 |
s.context_tokens = context_tokens
|
| 367 |
self.state = s
|
| 368 |
+
self._last_good_context_tokens = np.copy(context_tokens)
|
| 369 |
|
| 370 |
def reseed_from_waveform(self, wav: au.Waveform):
|
| 371 |
"""Immediate reseed: replace context from provided wave (bar-locked, exact length)."""
|
|
|
|
| 375 |
s.context_tokens = context_tokens
|
| 376 |
self.state = s
|
| 377 |
self._model_stream = None # drop model-domain continuity so next chunk starts cleanly
|
| 378 |
+
self._last_good_context_tokens = np.copy(context_tokens)
|
| 379 |
+
self._silence_streak = 0 # Reset health monitoring
|
| 380 |
|
| 381 |
def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
|
| 382 |
+
"""Queue a *seamless* reseed by token splicing instead of full restart."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
new_ctx = self._encode_exact_context_tokens(recent_wav) # coerce to (F,D)
|
| 384 |
F, D = self._expected_token_shape()
|
| 385 |
|
|
|
|
| 408 |
"tokens": spliced,
|
| 409 |
"debug": {"F": F, "D": D, "splice_frames": splice_frames, "frames_per_bar": frames_per_bar}
|
| 410 |
}
|
|
|
|
| 411 |
|
| 412 |
+
# ---------- REWRITTEN: core streaming helpers ----------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
|
| 414 |
def _append_model_chunk_and_spool(self, wav: au.Waveform) -> None:
|
| 415 |
"""
|
| 416 |
+
REWRITTEN: Robust audio processing with silence detection and health monitoring.
|
| 417 |
+
|
| 418 |
+
Strategy:
|
| 419 |
+
1. Validate input chunk for silence/issues
|
| 420 |
+
2. Use simpler crossfading that handles silence gracefully
|
| 421 |
+
3. Update model stream with health checks
|
| 422 |
+
4. Convert to target SR and append to spool
|
|
|
|
|
|
|
| 423 |
"""
|
| 424 |
+
# Unpack model-rate samples
|
|
|
|
| 425 |
s = wav.samples.astype(np.float32, copy=False)
|
| 426 |
if s.ndim == 1:
|
| 427 |
s = s[:, None]
|
|
|
|
| 429 |
if n_samps == 0:
|
| 430 |
return
|
| 431 |
|
| 432 |
+
# Health check on new chunk
|
| 433 |
+
is_healthy = self._check_model_health(s)
|
| 434 |
+
|
| 435 |
+
# Get crossfade params
|
| 436 |
try:
|
| 437 |
xfade_s = float(self.mrt.config.crossfade_length)
|
| 438 |
except Exception:
|
| 439 |
xfade_s = 0.0
|
| 440 |
xfade_n = int(round(max(0.0, xfade_s) * float(self._model_sr)))
|
| 441 |
|
| 442 |
+
print(f"[model] chunk len={n_samps} rms={_dbg_rms_dbfs_model(s):+.1f} dBFS healthy={is_healthy}")
|
| 443 |
+
|
| 444 |
+
# Helper: resample to target SR
|
| 445 |
def to_target(y: np.ndarray) -> np.ndarray:
|
| 446 |
return y if self._rs is None else self._rs.process(y, final=False)
|
| 447 |
|
| 448 |
+
# --- SIMPLIFIED CROSSFADE LOGIC ---
|
| 449 |
+
|
| 450 |
+
if self._model_stream is None:
|
| 451 |
+
# First chunk - no crossfading needed
|
| 452 |
+
self._model_stream = s.copy()
|
| 453 |
+
|
| 454 |
+
elif xfade_n <= 0 or n_samps < xfade_n:
|
| 455 |
+
# No crossfade configured or chunk too short - simple append
|
| 456 |
+
self._model_stream = np.concatenate([self._model_stream, s], axis=0)
|
| 457 |
+
|
| 458 |
+
elif _is_silent(self._model_stream[-xfade_n:]) or _is_silent(s[:xfade_n]):
|
| 459 |
+
# One side is silent - don't crossfade, just append
|
| 460 |
+
print(f"[crossfade] Skipping crossfade due to silence")
|
| 461 |
+
self._model_stream = np.concatenate([self._model_stream, s], axis=0)
|
| 462 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 463 |
else:
|
| 464 |
+
# Normal crossfade between non-silent audio
|
| 465 |
+
tail = self._model_stream[-xfade_n:]
|
| 466 |
+
head = s[:xfade_n]
|
| 467 |
+
body = s[xfade_n:] if n_samps > xfade_n else np.zeros((0, s.shape[1]), dtype=np.float32)
|
| 468 |
+
|
| 469 |
+
# Equal power crossfade
|
| 470 |
+
t = np.linspace(0.0, 1.0, xfade_n, dtype=np.float32)[:, None]
|
| 471 |
+
fade_out = np.cos(t * np.pi / 2.0)
|
| 472 |
+
fade_in = np.sin(t * np.pi / 2.0)
|
| 473 |
+
|
| 474 |
+
mixed = tail * fade_out + head * fade_in
|
| 475 |
+
|
| 476 |
+
print(f"[crossfade] tail rms={_dbg_rms_dbfs_model(tail):+.1f} head rms={_dbg_rms_dbfs_model(head):+.1f} mixed rms={_dbg_rms_dbfs_model(mixed):+.1f}")
|
| 477 |
+
|
| 478 |
+
# Update model stream: remove old tail, add mixed section, add body
|
| 479 |
+
self._model_stream = np.concatenate([
|
| 480 |
+
self._model_stream[:-xfade_n],
|
| 481 |
+
mixed,
|
| 482 |
+
body
|
| 483 |
+
], axis=0)
|
| 484 |
+
|
| 485 |
+
# --- CONVERT AND APPEND TO SPOOL ---
|
| 486 |
+
|
| 487 |
+
# Take the new audio from this iteration (avoid reprocessing old audio)
|
| 488 |
if xfade_n > 0 and n_samps >= xfade_n:
|
| 489 |
+
# Normal case: body after crossfade region
|
| 490 |
+
new_audio = s[xfade_n:] if n_samps > xfade_n else s
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 491 |
else:
|
| 492 |
+
# Short chunk or no crossfade: use entire chunk
|
| 493 |
+
new_audio = s
|
| 494 |
+
|
| 495 |
+
if new_audio.shape[0] > 0:
|
| 496 |
+
target_audio = to_target(new_audio)
|
| 497 |
+
if target_audio.shape[0] > 0:
|
| 498 |
+
print(f"[append] body len={target_audio.shape[0]} rms={_dbg_rms_dbfs(target_audio):+.1f} dBFS")
|
| 499 |
+
self._spool = np.concatenate([self._spool, target_audio], axis=0) if self._spool.size else target_audio
|
| 500 |
+
self._spool_written += target_audio.shape[0]
|
| 501 |
+
|
| 502 |
+
# --- HEALTH MONITORING ---
|
| 503 |
+
|
| 504 |
+
if not is_healthy:
|
| 505 |
+
if self._silence_streak >= 3: # After 3 silent chunks, try to recover
|
| 506 |
+
self._recover_from_silence()
|
| 507 |
+
else:
|
| 508 |
+
# Save current context as "good" backup
|
| 509 |
+
if hasattr(self.state, 'context_tokens') and self.state.context_tokens is not None:
|
| 510 |
+
self._last_good_context_tokens = np.copy(self.state.context_tokens)
|
| 511 |
|
| 512 |
+
# Trim model stream to reasonable length (keep ~30 seconds)
|
| 513 |
+
max_model_samples = int(30.0 * self._model_sr)
|
| 514 |
+
if self._model_stream.shape[0] > max_model_samples:
|
| 515 |
+
self._model_stream = self._model_stream[-max_model_samples:]
|
| 516 |
|
| 517 |
def _should_generate_next_chunk(self) -> bool:
|
| 518 |
# Allow running ahead relative to whichever is larger: last *consumed*
|
|
|
|
| 549 |
"guidance_weight": float(self.params.guidance_weight),
|
| 550 |
"temperature": float(self.params.temperature),
|
| 551 |
"topk": int(self.params.topk),
|
| 552 |
+
"silence_streak": self._silence_streak, # Add health info
|
| 553 |
}
|
| 554 |
chunk = JamChunk(index=self.idx, audio_base64=audio_b64, metadata=meta)
|
| 555 |
|
|
|
|
| 574 |
# inplace update (no reset)
|
| 575 |
self.state.context_tokens = spliced
|
| 576 |
self._pending_token_splice = None
|
| 577 |
+
print("[reseed] Token splice applied")
|
| 578 |
except Exception:
|
| 579 |
# fallback: full reseed using spliced tokens
|
| 580 |
new_state = self.mrt.init_state()
|
|
|
|
| 582 |
self.state = new_state
|
| 583 |
self._model_stream = None
|
| 584 |
self._pending_token_splice = None
|
| 585 |
+
print("[reseed] Token splice fallback to full reset")
|
| 586 |
elif self._pending_reseed is not None:
|
| 587 |
ctx = self._coerce_tokens(self._pending_reseed["ctx"])
|
| 588 |
new_state = self.mrt.init_state()
|
|
|
|
| 590 |
self.state = new_state
|
| 591 |
self._model_stream = None
|
| 592 |
self._pending_reseed = None
|
| 593 |
+
print("[reseed] Full reseed applied")
|
| 594 |
|
| 595 |
# ---------- main loop ----------
|
| 596 |
|
|
|
|
| 627 |
self._emit_ready()
|
| 628 |
|
| 629 |
# finalize resampler (flush) — not strictly necessary here
|
| 630 |
+
if self._rs is not None:
|
| 631 |
+
tail = self._rs.process(np.zeros((0,2), np.float32), final=True)
|
| 632 |
+
if tail.size:
|
| 633 |
+
self._spool = np.concatenate([self._spool, tail], axis=0)
|
| 634 |
+
self._spool_written += tail.shape[0]
|
| 635 |
# one last emit attempt
|
| 636 |
+
self._emit_ready()
|