linoyts HF Staff committed on
Commit
dfd4eb6
·
1 Parent(s): a8fed3a

Revert "jam: single-dispatch CUDA-graph stepping in the WS worker (eager fallback)"

Browse files

This reverts commit a8fed3aa371c9b349e9ebf31ba5b4c5cc378644f.

app.py CHANGED
@@ -162,9 +162,6 @@ async def banks(session_id: str = ""):
162
  def gpu_stream(session_id):
163
  """Continuous gen; switches model live when the dropdown changes."""
164
  from magenta_rt.torch.system import make_sampler, discretize_cfg, _float_to_int16, FRAME_SAMPLES
165
- from magenta_rt.torch.modeling_magenta_rt2 import CudaGraphStreamer
166
- USE_CG = os.environ.get('MRT_CUDAGRAPH', '1') == '1' # single-dispatch CUDA-graph stepping (eager fallback)
167
- cg_ok = True
168
  if style_model.device != "cuda":
169
  style_model.to("cuda")
170
  dev, dt = "cuda", torch.bfloat16
@@ -186,7 +183,6 @@ def gpu_stream(session_id):
186
  print("[warmup]", repr(_e), flush=True)
187
  notes, drums = [-1] * 128, [-1]
188
  cur_name = model = dstate = source = None
189
- streamer = last_src_for_graph = None
190
  decode_state = {}
191
  emitted_samples = 0
192
  t0 = time.time()
@@ -211,14 +207,12 @@ def gpu_stream(session_id):
211
  decode_state = model.init_decode_state()
212
  emitted_samples, source, prev_active = 0, None, set()
213
  cur_style_sig = cur_note_sig = cur_tokens = None; had_onsets = False
214
- streamer = last_src_for_graph = None # rebuild CUDA graph on model switch / reset
215
  seed = int(c.get("seed", 0))
216
  if seed != cur_seed:
217
  cur_seed = seed
218
  gen = torch.Generator(device=dev).manual_seed(seed)
219
- streamer = None # re-seed => re-capture (graph RNG fixed at capture)
220
  bop = c.get("bank_op")
221
- if bop and int(bop.get("ver", 0)) != cur_bank_ver and not USE_CG: # save/recall (eager only; cudagraph KV is static)
222
  cur_bank_ver = int(bop.get("ver", 0))
223
  bpath = os.path.join(SESSION_DIR, f"{os.path.basename(session_id)}_bank{int(bop.get('idx', 0))}.pt")
224
  try:
@@ -297,26 +291,9 @@ def gpu_stream(session_id):
297
  cond = model._conditioning((list(cur_tokens) + [-1] * model.num_musiccoca)[:model.num_musiccoca],
298
  nvec, drm, cfgs)
299
  source = model.model.encode(cond).to(dt)
300
- temp = c.get("temperature", 1.1); topk = int(c.get("top_k", 50))
301
- ok = False
302
- if USE_CG and cg_ok: # single-dispatch CUDA-graph step
303
- try:
304
- if streamer is None: # build + capture on first frame (~2-3s warmup)
305
- streamer = CudaGraphStreamer(model.model.decoder, source, dt,
306
- temperature=temp, top_k=topk, seed=cur_seed)
307
- last_src_for_graph = source
308
- elif source is not last_src_for_graph: # conditioning changed -> update static buffer
309
- streamer.set_source(source); last_src_for_graph = source
310
- streamer.set_temperature(temp)
311
- toks.append(streamer.step()); ok = True
312
- except Exception as _cge:
313
- print("[cudagraph] fallback to eager:", repr(_cge), flush=True)
314
- cg_ok = False; streamer = None
315
- dstate = model.model.decoder.init_streaming_f(1, dev, dt)
316
- if not ok: # eager fallback path
317
- sampler = make_sampler(temp, topk, gen)
318
- toks.append(model.model.decoder.step_f(dstate, source, sampler=sampler,
319
- temporal_step=model._temporal_step, depth_step=model._depth_step))
320
  new_codes = torch.cat(toks, dim=1)
321
  audio = model.decode_stream(new_codes, decode_state) # FLOP-optimal stateful streaming decode
322
  emitted_samples += audio.shape[1]
 
162
  def gpu_stream(session_id):
163
  """Continuous gen; switches model live when the dropdown changes."""
164
  from magenta_rt.torch.system import make_sampler, discretize_cfg, _float_to_int16, FRAME_SAMPLES
 
 
 
165
  if style_model.device != "cuda":
166
  style_model.to("cuda")
167
  dev, dt = "cuda", torch.bfloat16
 
183
  print("[warmup]", repr(_e), flush=True)
184
  notes, drums = [-1] * 128, [-1]
185
  cur_name = model = dstate = source = None
 
186
  decode_state = {}
187
  emitted_samples = 0
188
  t0 = time.time()
 
207
  decode_state = model.init_decode_state()
208
  emitted_samples, source, prev_active = 0, None, set()
209
  cur_style_sig = cur_note_sig = cur_tokens = None; had_onsets = False
 
210
  seed = int(c.get("seed", 0))
211
  if seed != cur_seed:
212
  cur_seed = seed
213
  gen = torch.Generator(device=dev).manual_seed(seed)
 
214
  bop = c.get("bank_op")
215
+ if bop and int(bop.get("ver", 0)) != cur_bank_ver: # save/recall generation state
216
  cur_bank_ver = int(bop.get("ver", 0))
217
  bpath = os.path.join(SESSION_DIR, f"{os.path.basename(session_id)}_bank{int(bop.get('idx', 0))}.pt")
218
  try:
 
291
  cond = model._conditioning((list(cur_tokens) + [-1] * model.num_musiccoca)[:model.num_musiccoca],
292
  nvec, drm, cfgs)
293
  source = model.model.encode(cond).to(dt)
294
+ sampler = make_sampler(c.get("temperature", 1.1), c.get("top_k", 50), gen)
295
+ toks.append(model.model.decoder.step_f(dstate, source, sampler=sampler,
296
+ temporal_step=model._temporal_step, depth_step=model._depth_step))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  new_codes = torch.cat(toks, dim=1)
298
  audio = model.decode_stream(new_codes, decode_state) # FLOP-optimal stateful streaming decode
299
  emitted_samples += audio.shape[1]
magenta_rt/torch/configuration_magenta_rt2.py DELETED
@@ -1,115 +0,0 @@
1
- # Copyright 2026 Google LLC
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """HF config for the Magenta RealTime 2 PyTorch model."""
16
-
17
- from transformers import PretrainedConfig
18
-
19
-
20
- class MagentaRT2Config(PretrainedConfig):
21
- """Config for `MagentaRT2ForConditionalGeneration`.
22
-
23
- `temporal` / `depth` are [num_layers, model_dims, hidden_dims, num_heads,
24
- dim_per_head] for the two Depthformer transformer stacks.
25
- """
26
-
27
- model_type = "magenta_rt2"
28
-
29
- def __init__(
30
- self,
31
- size="mrt2_small",
32
- encoder_model_dims=256,
33
- temporal=(12, 1024, 4096, 8, 128),
34
- depth=(2, 768, 3072, 6, 128),
35
- temporal_max_past=41,
36
- depth_max_past=12,
37
- musiccoca_rvq=12,
38
- musiccoca_per_rvq_vocab=1031,
39
- musiccoca_embed_dim=768,
40
- regular_num_embeddings_per_channel=None,
41
- regular_num_channels=132,
42
- num_sinks=1,
43
- num_codebooks=12,
44
- codebook_size=1024,
45
- num_reserved_tokens=6,
46
- vocab_size=12294,
47
- soft_cap_logits=30.0,
48
- temperature=1.3,
49
- top_k=40,
50
- cfg_musiccoca=3.0,
51
- cfg_notes=1.0,
52
- cfg_drums=1.0,
53
- num_notes=128,
54
- num_drums=1,
55
- sample_rate=48000,
56
- frame_samples=1920,
57
- codec_param_shapes=None,
58
- **kwargs,
59
- ):
60
- self.size = size
61
- self.codec_param_shapes = codec_param_shapes
62
- self.encoder_model_dims = encoder_model_dims
63
- self.temporal = list(temporal)
64
- self.depth = list(depth)
65
- self.temporal_max_past = temporal_max_past
66
- self.depth_max_past = depth_max_past
67
- self.musiccoca_rvq = musiccoca_rvq
68
- self.musiccoca_per_rvq_vocab = musiccoca_per_rvq_vocab
69
- self.musiccoca_embed_dim = musiccoca_embed_dim
70
- self.regular_num_embeddings_per_channel = regular_num_embeddings_per_channel
71
- self.regular_num_channels = regular_num_channels
72
- self.num_sinks = num_sinks
73
- self.num_codebooks = num_codebooks
74
- self.codebook_size = codebook_size
75
- self.num_reserved_tokens = num_reserved_tokens
76
- self.vocab_size = vocab_size
77
- self.soft_cap_logits = soft_cap_logits
78
- self.temperature = temperature
79
- self.top_k = top_k
80
- self.cfg_musiccoca = cfg_musiccoca
81
- self.cfg_notes = cfg_notes
82
- self.cfg_drums = cfg_drums
83
- self.num_notes = num_notes
84
- self.num_drums = num_drums
85
- self.sample_rate = sample_rate
86
- self.frame_samples = frame_samples
87
- super().__init__(**kwargs)
88
-
89
- @classmethod
90
- def from_size(cls, size):
91
- from .depthformer import config_for
92
- from dataclasses import astuple
93
- c = config_for(size)
94
- return cls(
95
- size=size,
96
- encoder_model_dims=c.encoder_model_dims,
97
- temporal=list(astuple(c.temporal)),
98
- depth=list(astuple(c.depth)),
99
- temporal_max_past=c.temporal_max_past,
100
- depth_max_past=c.depth_max_past,
101
- musiccoca_rvq=c.musiccoca_rvq,
102
- musiccoca_per_rvq_vocab=c.musiccoca_per_rvq_vocab,
103
- musiccoca_embed_dim=c.musiccoca_embed_dim,
104
- regular_num_embeddings_per_channel=list(c.regular_num_embeddings_per_channel),
105
- regular_num_channels=c.regular_num_channels,
106
- num_sinks=c.num_sinks,
107
- num_codebooks=c.num_codebooks,
108
- codebook_size=c.codebook_size,
109
- num_reserved_tokens=c.num_reserved_tokens,
110
- vocab_size=c.vocab_size,
111
- soft_cap_logits=c.soft_cap_logits,
112
- )
113
-
114
-
115
- __all__ = ["MagentaRT2Config"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
magenta_rt/torch/depthformer.py CHANGED
@@ -256,15 +256,9 @@ class MultivariateDecoder(nn.Module):
256
  }
257
 
258
  def step_f(self, state, source_frame, sampler=None, forced=None,
259
- temporal_step=None, depth_step=None, cfg_scales=None):
260
  """One functional frame. temporal_step/depth_step override the eager fns
261
- (e.g. with AOTI-compiled callables). Updates state in place; returns [b,1,Q].
262
-
263
- cfg_scales: optional tuple of classifier-free-guidance scales. When set,
264
- `source_frame`/`state` are batched as [positive, neg_1, ...] with
265
- arity = 1 + len(cfg_scales); per-codebook logits are combined as
266
- ``cond + sum_i scale_i*(cond - neg_i)`` before sampling (the native
267
- MLX/.mlxfn path). The single sampled token is broadcast to all rows."""
268
  cfg = self.cfg
269
  tstep = temporal_step or self.temporal_step_fn
270
  dstep = depth_step or self.depth_step_fn
@@ -281,21 +275,10 @@ class MultivariateDecoder(nn.Module):
281
  logits, depth_kv = dstep(depth_input, depth_kv)
282
  lo = cfg.num_reserved_tokens + q * cfg.codebook_size
283
  hi = lo + cfg.codebook_size
284
- if cfg_scales is not None: # classifier-free guidance combine
285
- cond = logits[0:1]
286
- comb = cond
287
- for i, s in enumerate(cfg_scales, start=1):
288
- comb = comb + s * (cond - logits[i:i + 1])
289
- tok = forced[..., q] if forced is not None else sampler(comb.float(), q, lo, hi)
290
- depth_input = self.embed(tok.expand(logits.shape[0], -1))
291
- else:
292
- tok = forced[..., q] if forced is not None else sampler(logits.float(), q, lo, hi)
293
- depth_input = self.embed(tok)
294
  samples.append(tok)
 
295
  frame = torch.stack(samples, dim=-1)
296
- if cfg_scales is not None:
297
- state["prev"] = frame.expand(to.shape[0], -1, -1)
298
- return frame[:1]
299
  state["prev"] = frame
300
  return frame
301
 
 
256
  }
257
 
258
  def step_f(self, state, source_frame, sampler=None, forced=None,
259
+ temporal_step=None, depth_step=None):
260
  """One functional frame. temporal_step/depth_step override the eager fns
261
+ (e.g. with AOTI-compiled callables). Updates state in place; returns [b,1,Q]."""
 
 
 
 
 
 
262
  cfg = self.cfg
263
  tstep = temporal_step or self.temporal_step_fn
264
  dstep = depth_step or self.depth_step_fn
 
275
  logits, depth_kv = dstep(depth_input, depth_kv)
276
  lo = cfg.num_reserved_tokens + q * cfg.codebook_size
277
  hi = lo + cfg.codebook_size
278
+ tok = forced[..., q] if forced is not None else sampler(logits.float(), q, lo, hi)
 
 
 
 
 
 
 
 
 
279
  samples.append(tok)
280
+ depth_input = self.embed(tok)
281
  frame = torch.stack(samples, dim=-1)
 
 
 
282
  state["prev"] = frame
283
  return frame
284
 
magenta_rt/torch/modeling_magenta_rt2.py CHANGED
@@ -24,7 +24,6 @@ loop is a single token stream). MusicCoCa style encoding is a separate
24
 
25
  import json
26
  import os
27
- import warnings
28
 
29
  import numpy as np
30
  import torch
@@ -239,32 +238,6 @@ class MagentaRT2ForConditionalGeneration(MagentaRT2PreTrainedModel):
239
  ]
240
  return self._conditioning(style_tokens, notes, drums, cfgs)
241
 
242
- def _guidance_source(self, style, notes, drums, cfg_musiccoca, cfg_notes):
243
- """OPTIONAL classifier-free-guidance conditioning (the native MLX/.mlxfn path).
244
- Builds a 3-row batch [positive, neg_musiccoca, neg_notes] + per-component scales.
245
- cfg tokens are neutralized (guidance replaces them); negatives mask style / notes.
246
- Returns (source[3,Tc,enc], (cfg_musiccoca, cfg_notes))."""
247
- c = self.config
248
- if style is None:
249
- st = [-1] * self.num_musiccoca
250
- elif isinstance(style, (list, np.ndarray)) and np.asarray(style).ndim == 1 \
251
- and np.asarray(style).dtype.kind in "iu" and len(style) == self.num_musiccoca:
252
- st = list(style)
253
- else:
254
- st = self._tokenize_style(style)
255
- st = (list(st) + [-1] * self.num_musiccoca)[:self.num_musiccoca]
256
- notes = notes if notes is not None else [-1] * self.num_notes
257
- drums = drums if drums is not None else [-1] * self.num_drums
258
- CM = [-1, -1, -1] # neutralized cfg tokens
259
- cond = self._conditioning(st, notes, drums, CM)
260
- neg_mc = self._conditioning([-1] * self.num_musiccoca, notes, drums, CM)
261
- neg_n = self._conditioning(st, [-1] * self.num_notes, drums, CM)
262
- source = self.depthformer.encode(torch.cat([cond, neg_mc, neg_n], 0)).to(self._dt)
263
- cfg_mc = c.cfg_musiccoca if cfg_musiccoca is None else cfg_musiccoca
264
- cfg_n = c.cfg_notes if cfg_notes is None else cfg_notes
265
- _warn_high_cfg(cfg_mc, cfg_n)
266
- return source, (float(cfg_mc), float(cfg_n))
267
-
268
  # ---- codec ----
269
  def _decode_stream(self, history, emitted, context=STREAM_DECODE_CONTEXT,
270
  margin=STREAM_DECODE_MARGIN, flush=False):
@@ -287,25 +260,6 @@ class MagentaRT2ForConditionalGeneration(MagentaRT2PreTrainedModel):
287
  """Fresh state dict for streaming decode (decode_stream)."""
288
  return {}
289
 
290
- @torch.no_grad()
291
- def prefill_f(self, dstate, source_frame, seed_codes):
292
- """Teacher-force seed_codes [1,N,Q] (raw 0..codebook_size-1) through the
293
- temporal transformer to populate its KV cache (native mlx_engine prefill
294
- parity), so generation CONTINUES from the seed. Advances `dstate` in place.
295
- Returns unique-code frames [1,N,Q] for the codec decoder."""
296
- dec = self.depthformer.decoder
297
- Q = self.config.num_codebooks
298
- per_cb = (torch.arange(Q, device=seed_codes.device) * self.codebook_size
299
- + self.num_reserved_tokens).view(1, 1, Q)
300
- unique = seed_codes.to(torch.long) + per_cb
301
- N = unique.shape[1]
302
- for step in range(max(0, N - 1)):
303
- dec.step_f(dstate, source_frame, forced=unique[:, step:step + 1, :],
304
- temporal_step=self._temporal_step, depth_step=self._depth_step)
305
- if N > 0:
306
- dstate["prev"] = unique[:, N - 1:N, :]
307
- return unique
308
-
309
  def decode_stream(self, new_codes, state):
310
  """Incremental codec decode of new token frames [b, t_new, Q] -> audio [b, N, 2].
311
  FLOP-optimal stateful streaming (no overlap-save re-decode); bf16-equivalent to
@@ -328,35 +282,26 @@ class MagentaRT2ForConditionalGeneration(MagentaRT2PreTrainedModel):
328
  @torch.no_grad()
329
  def generate(self, style=None, notes=None, drums=None, cfg_musiccoca=None,
330
  cfg_notes=None, cfg_drums=None, temperature=None, top_k=None,
331
- frames=25, seed=0, state=None, flush=False, return_int16=False,
332
- guidance=False):
333
- """`guidance=False` (default): cfg_* are discretized conditioning tokens — the
334
- validated in-process/JAX path, unchanged. `guidance=True`: cfg_musiccoca/cfg_notes
335
- become classifier-free-guidance scales (negatives + per-codebook logit combine),
336
- matching the native MLX/Mac-app path. Guidance uses eager steps (batch>1)."""
337
  c = self.config
338
  temperature = c.temperature if temperature is None else temperature
339
  top_k = c.top_k if top_k is None else top_k
340
- if guidance:
341
- source, cfg_scales = self._guidance_source(style, notes, drums, cfg_musiccoca, cfg_notes)
342
- arity = len(cfg_scales) + 1
343
- else:
344
- cond = self._resolve_conditioning(style, notes, drums, cfg_musiccoca, cfg_notes, cfg_drums)
345
- source = self.depthformer.encode(cond).to(self._dt)
346
- cfg_scales, arity = None, 1
347
  if state is None:
348
- dstate = self.depthformer.decoder.init_streaming_f(arity, self._dev, self._dt)
349
  gen = torch.Generator(device=self._dev).manual_seed(seed)
350
- decode_state = self.init_decode_state()
 
351
  else:
352
- dstate, gen, decode_state = state["dstate"], state["gen"], state["decode_state"]
353
  sampler = make_sampler(temperature, top_k, gen)
354
- # dynamic-batch AOTI (or eager fallback) handles guidance B>1 and no-guidance B=1 alike.
355
  toks = [self.depthformer.decoder.step_f(
356
- dstate, source, sampler=sampler, cfg_scales=cfg_scales,
357
  temporal_step=self._temporal_step, depth_step=self._depth_step) for _ in range(frames)]
358
- audio = self.decode_stream(torch.cat(toks, dim=1), decode_state) # stateful per-frame streaming decode (40ms frames)
359
- new_state = {"dstate": dstate, "gen": gen, "decode_state": decode_state}
 
360
  wav = audio[0].float().cpu().numpy()
361
  i16 = _float_to_int16(wav)
362
  out = i16 if return_int16 else i16.astype(np.float32) / 32768.0
@@ -364,19 +309,9 @@ class MagentaRT2ForConditionalGeneration(MagentaRT2PreTrainedModel):
364
 
365
  @torch.no_grad()
366
  def stream(self, control, chunk_frames=10, max_seconds=55.0, seed=0,
367
- time_fn=None, sleep_fn=None, notes=None, drums=None, guidance=False,
368
- cudagraph=False):
369
  """Continuous generation. `control()` returns {style_tokens, temperature,
370
- top_k, cfg_*} read every chunk for mid-stream steering. Yields int16 [N,2].
371
-
372
- guidance=False (default): cfg_* are conditioning tokens (validated token path,
373
- unchanged). guidance=True: cfg_musiccoca/cfg_notes are classifier-free-guidance
374
- scales read live every chunk. cudagraph=True: single-dispatch CUDA-graph stepping
375
- (one capture at start, ~4-5x faster), steered via static input buffers."""
376
- if cudagraph:
377
- yield from self._stream_cudagraph(control, chunk_frames, max_seconds, seed,
378
- time_fn, sleep_fn, notes, drums, guidance)
379
- return
380
  import time as _time
381
  time_fn = time_fn or _time.time
382
  sleep_fn = sleep_fn or _time.sleep
@@ -384,11 +319,10 @@ class MagentaRT2ForConditionalGeneration(MagentaRT2PreTrainedModel):
384
  dev, dt = self._dev, self._dt
385
  notes = notes if notes is not None else [-1] * self.num_notes
386
  drums = drums if drums is not None else [-1] * self.num_drums
387
- arity = 3 if guidance else 1
388
- dstate = self.depthformer.decoder.init_streaming_f(arity, dev, dt)
389
  gen = torch.Generator(device=dev).manual_seed(seed)
390
- decode_state = self.init_decode_state()
391
- emitted_samples = 0
392
  cur_tokens = None
393
  source = None
394
  t0 = time_fn()
@@ -400,269 +334,23 @@ class MagentaRT2ForConditionalGeneration(MagentaRT2PreTrainedModel):
400
  tokens = ctl["style_tokens"]
401
  if tokens != cur_tokens:
402
  cur_tokens = tokens
403
- st = (list(tokens) + [-1] * self.num_musiccoca)[:self.num_musiccoca]
404
- if guidance: # [pos, neg_mc, neg_n]; cfg tokens neutralized
405
- source, _ = self._guidance_source(st, notes, drums, None, None)
406
- else:
407
- cfgs = [discretize_cfg(ctl.get("cfg_musiccoca", c.cfg_musiccoca), 0.2, 40),
408
- discretize_cfg(ctl.get("cfg_notes", c.cfg_notes), 0.2, 40),
409
- discretize_cfg(ctl.get("cfg_drums", c.cfg_drums), 1.0, 8)]
410
- source = self.depthformer.encode(self._conditioning(st, notes, drums, cfgs)).to(dt)
411
- cfg_scales = ((float(ctl.get("cfg_musiccoca", c.cfg_musiccoca)), # live scales (unclamped)
412
- float(ctl.get("cfg_notes", c.cfg_notes))) if guidance else None)
413
  sampler = make_sampler(ctl.get("temperature", c.temperature), ctl.get("top_k", c.top_k), gen)
414
  toks = [self.depthformer.decoder.step_f(
415
- dstate, source, sampler=sampler, cfg_scales=cfg_scales,
416
  temporal_step=self._temporal_step, depth_step=self._depth_step) for _ in range(chunk_frames)]
417
- audio = self.decode_stream(torch.cat(toks, dim=1), decode_state)
418
- emitted_samples += audio.shape[1]
419
  if audio.shape[1] > 0:
420
  yield _float_to_int16(audio[0].float().cpu().numpy())
421
- ahead = (emitted_samples / SR) - (time_fn() - t0)
422
  if ahead > 1.0:
423
  sleep_fn(min(ahead - 1.0, 0.5))
424
 
425
- @torch.no_grad()
426
- def _stream_cudagraph(self, control, chunk_frames, max_seconds, seed,
427
- time_fn, sleep_fn, notes, drums, guidance):
428
- """CUDA-graph backend for stream(cudagraph=True): one capture at start
429
- (warmup ~KEEP frames), then single-dispatch replay per frame. Steering
430
- goes through the streamer's static input buffers — cfg/temperature are
431
- buffer writes; a style change re-encodes + set_source (windowed ramp)."""
432
- import time as _time
433
- time_fn = time_fn or _time.time
434
- sleep_fn = sleep_fn or _time.sleep
435
- c = self.config
436
- dt = self._dt
437
- notes = notes if notes is not None else [-1] * self.num_notes
438
- drums = drums if drums is not None else [-1] * self.num_drums
439
-
440
- def encode_src(tokens, cfg_mc, cfg_n):
441
- st = (list(tokens) + [-1] * self.num_musiccoca)[:self.num_musiccoca]
442
- if guidance:
443
- return self._guidance_source(st, notes, drums, cfg_mc, cfg_n)[0]
444
- cfgs = [discretize_cfg(cfg_mc, 0.2, 40), discretize_cfg(cfg_n, 0.2, 40),
445
- discretize_cfg(c.cfg_drums, 1.0, 8)]
446
- return self.depthformer.encode(self._conditioning(st, notes, drums, cfgs)).to(dt)
447
 
448
- # bounded wait for the first conditioning, then build + capture the graph
449
- t0 = time_fn()
450
- ctl = control()
451
- while ctl is None and time_fn() - t0 < max_seconds:
452
- sleep_fn(0.02); ctl = control()
453
- if ctl is None:
454
- return
455
- cur_tokens = ctl["style_tokens"]
456
- cur_cfg = (float(ctl.get("cfg_musiccoca", c.cfg_musiccoca)),
457
- float(ctl.get("cfg_notes", c.cfg_notes)))
458
- streamer = self.make_cudagraph_streamer(
459
- style=cur_tokens, notes=notes, drums=drums,
460
- cfg_musiccoca=cur_cfg[0], cfg_notes=cur_cfg[1],
461
- temperature=ctl.get("temperature", c.temperature),
462
- top_k=ctl.get("top_k", c.top_k), seed=seed, guidance=guidance)
463
- decode_state = self.init_decode_state()
464
- emitted_samples = 0
465
- t0 = time_fn()
466
- while time_fn() - t0 < max_seconds:
467
- ctl = control()
468
- if ctl is None:
469
- sleep_fn(0.005); continue
470
- tokens = ctl["style_tokens"]
471
- cfg_mc = float(ctl.get("cfg_musiccoca", c.cfg_musiccoca))
472
- cfg_n = float(ctl.get("cfg_notes", c.cfg_notes))
473
- if guidance:
474
- if (cfg_mc, cfg_n) != cur_cfg:
475
- streamer.set_cfg([cfg_mc, cfg_n]); cur_cfg = (cfg_mc, cfg_n)
476
- if tokens != cur_tokens:
477
- streamer.set_source(encode_src(tokens, cfg_mc, cfg_n)); cur_tokens = tokens
478
- elif tokens != cur_tokens or (cfg_mc, cfg_n) != cur_cfg: # token path: cfg lives in source
479
- streamer.set_source(encode_src(tokens, cfg_mc, cfg_n))
480
- cur_tokens, cur_cfg = tokens, (cfg_mc, cfg_n)
481
- streamer.set_temperature(ctl.get("temperature", c.temperature))
482
- toks = [streamer.step() for _ in range(chunk_frames)]
483
- audio = self.decode_stream(torch.cat(toks, dim=1), decode_state)
484
- emitted_samples += audio.shape[1]
485
- if audio.shape[1] > 0:
486
- yield _float_to_int16(audio[0].float().cpu().numpy())
487
- ahead = (emitted_samples / SR) - (time_fn() - t0)
488
- if ahead > 1.0:
489
- sleep_fn(min(ahead - 1.0, 0.5))
490
-
491
- @torch.no_grad()
492
- def make_cudagraph_streamer(self, style=None, notes=None, drums=None,
493
- cfg_musiccoca=None, cfg_notes=None, cfg_drums=None,
494
- temperature=None, top_k=None, seed=0, guidance=False,
495
- warmup=None):
496
- """One-dispatch-per-frame CUDA-graph streaming: captures the whole frame
497
- (temporal + N-codebook depth + in-graph sampler + optional CFG) as a single
498
- `torch.cuda.graph` replay over fixed-size static KV buffers — ~MLX `.mlxfn`.
499
- Returns a `CudaGraphStreamer`; call `.step()` for the next frame [1,1,Q]
500
- (decode with `decode_stream`), and `.set_cfg/.set_temperature/.set_source`
501
- for live steering (no re-capture). `top_k` is fixed at capture time."""
502
- if guidance:
503
- source, scales = self._guidance_source(style, notes, drums, cfg_musiccoca, cfg_notes)
504
- num_neg = len(scales)
505
- else:
506
- cond = self._resolve_conditioning(style, notes, drums, cfg_musiccoca, cfg_notes, cfg_drums)
507
- source = self.depthformer.encode(cond).to(self._dt)
508
- scales, num_neg = None, 0
509
- temperature = self.config.temperature if temperature is None else temperature
510
- top_k = self.config.top_k if top_k is None else top_k
511
- return CudaGraphStreamer(self.depthformer.decoder, source, self._dt, num_neg, scales,
512
- temperature, top_k, seed, warmup)
513
-
514
-
515
- # Classifier-free guidance scales above this can run away / collapse the output
516
- # to silence under *sustained constant* conditioning over long runs (the native
517
- # UI uses a 0-5 slider, default 2.4). We don't clamp — values pass through to
518
- # match the native range — but we warn once so the caller knows the risk.
519
- GUIDANCE_CFG_WARN = 3.5
520
-
521
-
522
- def _warn_high_cfg(*scales):
523
- hi = [round(float(s), 2) for s in scales if float(s) > GUIDANCE_CFG_WARN]
524
- if hi:
525
- warnings.warn(
526
- f"CFG guidance scale(s) {hi} exceed ~{GUIDANCE_CFG_WARN}; sustained high "
527
- "guidance on constant conditioning can make the output run away / collapse "
528
- "to silence over long runs. (Changing notes/style during play avoids this.)",
529
- stacklevel=3)
530
- return True
531
- return False
532
-
533
-
534
- class CudaGraphStreamer:
535
- """Single-dispatch CUDA-graph frame stepper over fixed-size static KV buffers.
536
-
537
- Warms `KEEP` frames eagerly to fill the temporal/cross KV to steady state,
538
- snapshots them into static buffers, then captures one frame (temporal + depth +
539
- sampler) with `torch.cuda.graph`. `.step()` replays it (one GPU dispatch) and
540
- returns the new frame tokens. Live steering writes into static input buffers
541
- (`source`, `cfg`, `temperature`) — the captured graph reads them, no re-capture.
542
- Conditioning changes ramp in via the windowed cross-KV (optional hard flush)."""
543
-
544
- def __init__(self, decoder, source, decode_dtype, num_neg=0, cfg_scales=None,
545
- temperature=1.1, top_k=50, seed=0, warmup=None):
546
- """decoder: a MultivariateDecoder (`model.depthformer.decoder` for the
547
- modeling class, `model.model.decoder` for the system class). `source` is the
548
- pre-encoded conditioning [B, Tc, enc] (B = 1 + num_neg); `decode_dtype` the
549
- compute dtype. Class-agnostic so both model wrappers can build it."""
550
- dec = decoder
551
- c = dec.cfg
552
- self.dec = dec
553
- self.Q, self.CB, self.NR = c.num_codebooks, c.codebook_size, c.num_reserved_tokens
554
- self.KEEP = c.temporal_max_past + 1
555
- self.num_neg = num_neg
556
- self.top_k = int(top_k)
557
- dev, dt = source.device, decode_dtype
558
- B = source.shape[0]; self.B = B
559
- # live-steering static inputs
560
- self.source = source.clone()
561
- self.cfg = (torch.zeros(0, device=dev, dtype=torch.float32) if not num_neg
562
- else torch.tensor([float(s) for s in cfg_scales], device=dev, dtype=torch.float32))
563
- self.temp = torch.tensor(float(temperature), device=dev, dtype=torch.float32)
564
- torch.manual_seed(seed)
565
- # 1) prime to steady state (KV == KEEP on every layer)
566
- st = dec.init_streaming_f(B, dev, dt)
567
- K = self.KEEP
568
- for _ in range(K + 8 if warmup is None else warmup):
569
- to, ns, nc = dec.temporal_step_fn(st["prev"], st["self"], st["cross"], self.source)
570
- st["self"] = [(k[:, -K:], v[:, -K:]) for k, v in ns]
571
- st["cross"] = [(k[:, -K:], v[:, -K:]) for k, v in nc]
572
- frame = self._depth_sample(to)
573
- st["prev"] = frame.expand(B, -1, -1)
574
- # 2) static KV + state buffers
575
- L = len(st["self"]); self.L = L
576
- self.SK = [st["self"][i][0].clone() for i in range(L)]; self.SV = [st["self"][i][1].clone() for i in range(L)]
577
- self.CK = [st["cross"][i][0].clone() for i in range(L)]; self.CV = [st["cross"][i][1].clone() for i in range(L)]
578
- self.prev = st["prev"].clone()
579
- self.out = torch.zeros(1, 1, self.Q, dtype=torch.long, device=dev)
580
- # 3) capture (side-stream warmup is required before graph capture)
581
- s = torch.cuda.Stream(); s.wait_stream(torch.cuda.current_stream())
582
- with torch.cuda.stream(s):
583
- for _ in range(3):
584
- self._frame_static()
585
- torch.cuda.current_stream().wait_stream(s)
586
- self.graph = torch.cuda.CUDAGraph()
587
- with torch.cuda.graph(self.graph):
588
- self._frame_static()
589
-
590
- def _depth_sample(self, to):
591
- dec = self.dec; B = self.B; Q, CB, NR = self.Q, self.CB, self.NR
592
- dd = dec.cfg.depth
593
- z = torch.zeros(B, 0, dd.num_heads, dd.dim_per_head, device=to.device, dtype=to.dtype)
594
- dk = [(z, z) for _ in range(dd.num_layers)]
595
- di = to; toks = []
596
- for q in range(Q):
597
- logits, dk = dec.depth_step_fn(di, dk) # [B,1,V]
598
- lo = NR + q * CB
599
- ls = logits[..., lo:lo + CB]
600
- cond = ls[0:1]; comb = cond
601
- for i in range(self.num_neg): # classifier-free guidance combine
602
- comb = comb + self.cfg[i] * (cond - ls[i + 1:i + 2])
603
- kth = torch.topk(comb, self.top_k, dim=-1).values[..., -1:]
604
- comb = torch.where(comb >= kth, comb, torch.full_like(comb, -1e9))
605
- u = torch.rand(1, 1, CB, device=to.device, dtype=torch.float32) # graph-safe RNG
606
- g = -torch.log(-torch.log(u.clamp(1e-10, 1 - 1e-7)))
607
- tok = (comb + g * self.temp).argmax(-1) + lo
608
- toks.append(tok)
609
- di = dec.embed(tok.expand(B, -1))
610
- return torch.stack(toks, dim=-1) # [1,1,Q]
611
-
612
- def _frame_static(self):
613
- dec = self.dec; K = self.KEEP; L = self.L
614
- to, ns, nc = dec.temporal_step_fn(
615
- self.prev, [(self.SK[i], self.SV[i]) for i in range(L)],
616
- [(self.CK[i], self.CV[i]) for i in range(L)], self.source)
617
- for i in range(L):
618
- self.SK[i].copy_(ns[i][0][:, -K:]); self.SV[i].copy_(ns[i][1][:, -K:])
619
- self.CK[i].copy_(nc[i][0][:, -K:]); self.CV[i].copy_(nc[i][1][:, -K:])
620
- frame = self._depth_sample(to)
621
- self.out.copy_(frame)
622
- self.prev.copy_(frame.expand(self.B, -1, -1))
623
-
624
- # ---- live steering (no re-capture) ----
625
- def set_cfg(self, scales):
626
- if self.num_neg:
627
- if not getattr(self, "_cfg_warned", False):
628
- self._cfg_warned = _warn_high_cfg(*scales)
629
- self.cfg.copy_(torch.tensor([float(s) for s in scales],
630
- device=self.cfg.device, dtype=torch.float32))
631
-
632
- def set_temperature(self, t):
633
- self.temp.fill_(float(t))
634
-
635
- def set_source(self, source, flush=False):
636
- """Update conditioning. Ramps in via the windowed cross-KV; flush=True
637
- overwrites all cross-KV slots for an immediate change."""
638
- self.source.copy_(source if source.shape[0] == self.B else source.expand(self.B, -1, -1))
639
- if flush:
640
- for i in range(self.L):
641
- sk, sv = self.dec.temporal_body.layers[i]["cross_attention"]._kv(self.source)
642
- self.CK[i].copy_(sk[:, -self.KEEP:]); self.CV[i].copy_(sv[:, -self.KEEP:])
643
-
644
- def step(self):
645
- """Advance one frame (single CUDA-graph dispatch). Returns tokens [1,1,Q]."""
646
- self.graph.replay()
647
- return self.out.clone()
648
-
649
def close(self):
    """Release the captured CUDA graph and its private memory pool.

    Idempotent: works when ``graph`` is missing or already ``None``. Meant to
    be called at session end (the WS worker should). Errors raised by
    ``reset()`` are deliberately swallowed so that teardown-ordering problems
    during interpreter shutdown cannot propagate.
    """
    graph = getattr(self, "graph", None)
    try:
        if graph is not None:
            graph.reset()
    except Exception:
        # Best effort only: a failed reset must not break teardown.
        pass
    self.graph = None
660
-
661
def __del__(self):
    """Finalizer: attempt ``close()`` and never let an exception escape."""
    try:
        self.close()
    except Exception:
        # Finalizers may run during interpreter shutdown; stay silent.
        return
666
-
667
-
668
- __all__ = ["MagentaRT2ForConditionalGeneration", "MagentaRT2PreTrainedModel", "CudaGraphStreamer"]
 
24
 
25
  import json
26
  import os
 
27
 
28
  import numpy as np
29
  import torch
 
238
  ]
239
  return self._conditioning(style_tokens, notes, drums, cfgs)
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  # ---- codec ----
242
  def _decode_stream(self, history, emitted, context=STREAM_DECODE_CONTEXT,
243
  margin=STREAM_DECODE_MARGIN, flush=False):
 
260
  """Fresh state dict for streaming decode (decode_stream)."""
261
  return {}
262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  def decode_stream(self, new_codes, state):
264
  """Incremental codec decode of new token frames [b, t_new, Q] -> audio [b, N, 2].
265
  FLOP-optimal stateful streaming (no overlap-save re-decode); bf16-equivalent to
 
282
  @torch.no_grad()
283
  def generate(self, style=None, notes=None, drums=None, cfg_musiccoca=None,
284
  cfg_notes=None, cfg_drums=None, temperature=None, top_k=None,
285
+ frames=25, seed=0, state=None, flush=False, return_int16=False):
 
 
 
 
 
286
  c = self.config
287
  temperature = c.temperature if temperature is None else temperature
288
  top_k = c.top_k if top_k is None else top_k
289
+ cond = self._resolve_conditioning(style, notes, drums, cfg_musiccoca, cfg_notes, cfg_drums)
290
+ source = self.depthformer.encode(cond).to(self._dt)
 
 
 
 
 
291
  if state is None:
292
+ dstate = self.depthformer.decoder.init_streaming_f(1, self._dev, self._dt)
293
  gen = torch.Generator(device=self._dev).manual_seed(seed)
294
+ history = torch.zeros((1, 0, c.num_codebooks), dtype=torch.long, device=self._dev)
295
+ emitted = 0
296
  else:
297
+ dstate, gen, history, emitted = state["dstate"], state["gen"], state["history"], state["emitted"]
298
  sampler = make_sampler(temperature, top_k, gen)
 
299
  toks = [self.depthformer.decoder.step_f(
300
+ dstate, source, sampler=sampler,
301
  temporal_step=self._temporal_step, depth_step=self._depth_step) for _ in range(frames)]
302
+ history = torch.cat([history] + toks, dim=1)
303
+ audio, emitted = self._decode_stream(history, emitted, flush=flush)
304
+ new_state = {"dstate": dstate, "gen": gen, "history": history, "emitted": emitted}
305
  wav = audio[0].float().cpu().numpy()
306
  i16 = _float_to_int16(wav)
307
  out = i16 if return_int16 else i16.astype(np.float32) / 32768.0
 
309
 
310
  @torch.no_grad()
311
  def stream(self, control, chunk_frames=10, max_seconds=55.0, seed=0,
312
+ time_fn=None, sleep_fn=None, notes=None, drums=None):
 
313
  """Continuous generation. `control()` returns {style_tokens, temperature,
314
+ top_k, cfg_*} read every chunk for mid-stream steering. Yields int16 [N,2]."""
 
 
 
 
 
 
 
 
 
315
  import time as _time
316
  time_fn = time_fn or _time.time
317
  sleep_fn = sleep_fn or _time.sleep
 
319
  dev, dt = self._dev, self._dt
320
  notes = notes if notes is not None else [-1] * self.num_notes
321
  drums = drums if drums is not None else [-1] * self.num_drums
322
+ dstate = self.depthformer.decoder.init_streaming_f(1, dev, dt)
 
323
  gen = torch.Generator(device=dev).manual_seed(seed)
324
+ history = torch.zeros((1, 0, c.num_codebooks), dtype=torch.long, device=dev)
325
+ emitted = 0
326
  cur_tokens = None
327
  source = None
328
  t0 = time_fn()
 
334
  tokens = ctl["style_tokens"]
335
  if tokens != cur_tokens:
336
  cur_tokens = tokens
337
+ cfgs = [discretize_cfg(ctl.get("cfg_musiccoca", c.cfg_musiccoca), 0.2, 40),
338
+ discretize_cfg(ctl.get("cfg_notes", c.cfg_notes), 0.2, 40),
339
+ discretize_cfg(ctl.get("cfg_drums", c.cfg_drums), 1.0, 8)]
340
+ cond = self._conditioning((list(tokens) + [-1] * self.num_musiccoca)[:self.num_musiccoca],
341
+ notes, drums, cfgs)
342
+ source = self.depthformer.encode(cond).to(dt)
 
 
 
 
343
  sampler = make_sampler(ctl.get("temperature", c.temperature), ctl.get("top_k", c.top_k), gen)
344
  toks = [self.depthformer.decoder.step_f(
345
+ dstate, source, sampler=sampler,
346
  temporal_step=self._temporal_step, depth_step=self._depth_step) for _ in range(chunk_frames)]
347
+ history = torch.cat([history] + toks, dim=1)
348
+ audio, emitted = self._decode_stream(history, emitted)
349
  if audio.shape[1] > 0:
350
  yield _float_to_int16(audio[0].float().cpu().numpy())
351
+ ahead = (emitted * FRAME_SAMPLES / SR) - (time_fn() - t0)
352
  if ahead > 1.0:
353
  sleep_fn(min(ahead - 1.0, 0.5))
354
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
 
356
+ __all__ = ["MagentaRT2ForConditionalGeneration", "MagentaRT2PreTrainedModel"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
magenta_rt/torch/processing_musiccoca.py DELETED
@@ -1,77 +0,0 @@
1
- # Copyright 2026 Google LLC
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """MusicCoCa style processor for Magenta RealTime 2.
16
-
17
- A processor-style component (like a feature extractor / tokenizer): turns a text
18
- prompt OR an audio clip into 12 RVQ style tokens that condition the model. Pure
19
- torch + sentencepiece (text tower, audio tower, RVQ all torch-native).
20
- """
21
-
22
- import os
23
-
24
- import numpy as np
25
-
26
-
27
class MusicCoCaProcessor:
    """Text/audio -> 12 RVQ style tokens (and 768-d embeddings, for layering).

    Thin wrapper around a MusicCoCa instance: every operation is delegated to
    the wrapped object's ``embed``, ``tokenize``, ``to`` and ``device``.
    """

    def __init__(self, musiccoca):
        """Wrap an already constructed MusicCoCa-like object."""
        self._mc = musiccoca

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, device="cpu", **kwargs):
        """Build a processor from a local resource directory or a hub repo id.

        A path is treated as a local resource directory only when it is an
        existing directory that contains ``text_encoder.pt``. Any other
        non-empty value is passed on as ``repo_id``; an empty value or ``None``
        uses MusicCoCa's own default. Extra ``kwargs`` are accepted and ignored.
        """
        # Imported lazily so importing this module does not pull in the towers.
        from .musiccoca import MusicCoCa
        p = pretrained_model_name_or_path
        is_local = (
            p is not None
            and os.path.isdir(p)
            and os.path.exists(os.path.join(p, "text_encoder.pt"))
        )
        if is_local:
            mc = MusicCoCa(resource_dir=p, device=device)
        elif p:
            mc = MusicCoCa(repo_id=p, device=device)
        else:
            mc = MusicCoCa(device=device)
        return cls(mc)

    def save_pretrained(self, save_directory, **kwargs):
        """Create ``save_directory``; there is nothing else to serialize."""
        # Artifacts live in the MusicCoCa hub repo; nothing extra to serialize here.
        os.makedirs(save_directory, exist_ok=True)

    @property
    def device(self):
        """Device reported by the wrapped MusicCoCa object."""
        return self._mc.device

    def to(self, device):
        """Move the wrapped MusicCoCa object to ``device``; returns ``self``."""
        self._mc.to(device)
        return self

    def embed(self, text_or_audio):
        """Text str / audio (Waveform | (samples, sr) | np@16kHz) -> [768] torch."""
        return self._mc.embed(text_or_audio)

    def tokenize(self, embedding):
        """[768] embedding -> [12] int RVQ tokens (np.int64)."""
        return self._mc.tokenize(embedding)

    def layer(self, prompts, weights=None):
        """Blend several prompts (text/audio) into one set of style tokens.

        Each prompt is embedded, the embeddings are combined as
        ``sum(w_i * e_i)`` and the result is tokenized.

        Args:
            prompts: iterable of prompts accepted by ``embed``; must not be
                empty.
            weights: one weight per prompt (list, tuple, array, ...). ``None``
                or an empty sequence means uniform ``1 / len(prompts)``.
                Weights are used exactly as given and are NOT normalised; pass
                weights that sum to 1 for a true weighted mean.

        Returns:
            The style tokens as a plain list.

        Raises:
            ValueError: if ``prompts`` is empty, or the number of weights does
                not match the number of prompts.
        """
        prompts = list(prompts)
        if not prompts:
            raise ValueError("layer() requires at least one prompt")
        # list() instead of a truth test: `weights or ...` raises on NumPy
        # arrays, whose truth value is ambiguous.
        w = [] if weights is None else list(weights)
        if not w:
            w = [1.0 / len(prompts)] * len(prompts)
        elif len(w) != len(prompts):
            # zip() would silently drop the surplus prompts or weights.
            raise ValueError(
                f"got {len(w)} weights for {len(prompts)} prompts"
            )
        embs = [self.embed(p) for p in prompts]
        emb = sum(wi * e for wi, e in zip(w, embs))
        return self.tokenize(emb).tolist()

    def __call__(self, text_or_audio, return_tokens=True):
        """-> 12 style tokens (list[int]) by default, or the [768] embedding."""
        emb = self.embed(text_or_audio)
        return self.tokenize(emb).tolist() if return_tokens else emb
75
-
76
-
77
- __all__ = ["MusicCoCaProcessor"]