Spaces:
Runtime error
Runtime error
Revert "jam: single-dispatch CUDA-graph stepping in the WS worker (eager fallback)"
Browse filesThis reverts commit a8fed3aa371c9b349e9ebf31ba5b4c5cc378644f.
- app.py +4 -27
- magenta_rt/torch/configuration_magenta_rt2.py +0 -115
- magenta_rt/torch/depthformer.py +4 -21
- magenta_rt/torch/modeling_magenta_rt2.py +27 -339
- magenta_rt/torch/processing_musiccoca.py +0 -77
app.py
CHANGED
|
@@ -162,9 +162,6 @@ async def banks(session_id: str = ""):
|
|
| 162 |
def gpu_stream(session_id):
|
| 163 |
"""Continuous gen; switches model live when the dropdown changes."""
|
| 164 |
from magenta_rt.torch.system import make_sampler, discretize_cfg, _float_to_int16, FRAME_SAMPLES
|
| 165 |
-
from magenta_rt.torch.modeling_magenta_rt2 import CudaGraphStreamer
|
| 166 |
-
USE_CG = os.environ.get('MRT_CUDAGRAPH', '1') == '1' # single-dispatch CUDA-graph stepping (eager fallback)
|
| 167 |
-
cg_ok = True
|
| 168 |
if style_model.device != "cuda":
|
| 169 |
style_model.to("cuda")
|
| 170 |
dev, dt = "cuda", torch.bfloat16
|
|
@@ -186,7 +183,6 @@ def gpu_stream(session_id):
|
|
| 186 |
print("[warmup]", repr(_e), flush=True)
|
| 187 |
notes, drums = [-1] * 128, [-1]
|
| 188 |
cur_name = model = dstate = source = None
|
| 189 |
-
streamer = last_src_for_graph = None
|
| 190 |
decode_state = {}
|
| 191 |
emitted_samples = 0
|
| 192 |
t0 = time.time()
|
|
@@ -211,14 +207,12 @@ def gpu_stream(session_id):
|
|
| 211 |
decode_state = model.init_decode_state()
|
| 212 |
emitted_samples, source, prev_active = 0, None, set()
|
| 213 |
cur_style_sig = cur_note_sig = cur_tokens = None; had_onsets = False
|
| 214 |
-
streamer = last_src_for_graph = None # rebuild CUDA graph on model switch / reset
|
| 215 |
seed = int(c.get("seed", 0))
|
| 216 |
if seed != cur_seed:
|
| 217 |
cur_seed = seed
|
| 218 |
gen = torch.Generator(device=dev).manual_seed(seed)
|
| 219 |
-
streamer = None # re-seed => re-capture (graph RNG fixed at capture)
|
| 220 |
bop = c.get("bank_op")
|
| 221 |
-
if bop and int(bop.get("ver", 0)) != cur_bank_ver
|
| 222 |
cur_bank_ver = int(bop.get("ver", 0))
|
| 223 |
bpath = os.path.join(SESSION_DIR, f"{os.path.basename(session_id)}_bank{int(bop.get('idx', 0))}.pt")
|
| 224 |
try:
|
|
@@ -297,26 +291,9 @@ def gpu_stream(session_id):
|
|
| 297 |
cond = model._conditioning((list(cur_tokens) + [-1] * model.num_musiccoca)[:model.num_musiccoca],
|
| 298 |
nvec, drm, cfgs)
|
| 299 |
source = model.model.encode(cond).to(dt)
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
try:
|
| 304 |
-
if streamer is None: # build + capture on first frame (~2-3s warmup)
|
| 305 |
-
streamer = CudaGraphStreamer(model.model.decoder, source, dt,
|
| 306 |
-
temperature=temp, top_k=topk, seed=cur_seed)
|
| 307 |
-
last_src_for_graph = source
|
| 308 |
-
elif source is not last_src_for_graph: # conditioning changed -> update static buffer
|
| 309 |
-
streamer.set_source(source); last_src_for_graph = source
|
| 310 |
-
streamer.set_temperature(temp)
|
| 311 |
-
toks.append(streamer.step()); ok = True
|
| 312 |
-
except Exception as _cge:
|
| 313 |
-
print("[cudagraph] fallback to eager:", repr(_cge), flush=True)
|
| 314 |
-
cg_ok = False; streamer = None
|
| 315 |
-
dstate = model.model.decoder.init_streaming_f(1, dev, dt)
|
| 316 |
-
if not ok: # eager fallback path
|
| 317 |
-
sampler = make_sampler(temp, topk, gen)
|
| 318 |
-
toks.append(model.model.decoder.step_f(dstate, source, sampler=sampler,
|
| 319 |
-
temporal_step=model._temporal_step, depth_step=model._depth_step))
|
| 320 |
new_codes = torch.cat(toks, dim=1)
|
| 321 |
audio = model.decode_stream(new_codes, decode_state) # FLOP-optimal stateful streaming decode
|
| 322 |
emitted_samples += audio.shape[1]
|
|
|
|
| 162 |
def gpu_stream(session_id):
|
| 163 |
"""Continuous gen; switches model live when the dropdown changes."""
|
| 164 |
from magenta_rt.torch.system import make_sampler, discretize_cfg, _float_to_int16, FRAME_SAMPLES
|
|
|
|
|
|
|
|
|
|
| 165 |
if style_model.device != "cuda":
|
| 166 |
style_model.to("cuda")
|
| 167 |
dev, dt = "cuda", torch.bfloat16
|
|
|
|
| 183 |
print("[warmup]", repr(_e), flush=True)
|
| 184 |
notes, drums = [-1] * 128, [-1]
|
| 185 |
cur_name = model = dstate = source = None
|
|
|
|
| 186 |
decode_state = {}
|
| 187 |
emitted_samples = 0
|
| 188 |
t0 = time.time()
|
|
|
|
| 207 |
decode_state = model.init_decode_state()
|
| 208 |
emitted_samples, source, prev_active = 0, None, set()
|
| 209 |
cur_style_sig = cur_note_sig = cur_tokens = None; had_onsets = False
|
|
|
|
| 210 |
seed = int(c.get("seed", 0))
|
| 211 |
if seed != cur_seed:
|
| 212 |
cur_seed = seed
|
| 213 |
gen = torch.Generator(device=dev).manual_seed(seed)
|
|
|
|
| 214 |
bop = c.get("bank_op")
|
| 215 |
+
if bop and int(bop.get("ver", 0)) != cur_bank_ver: # save/recall generation state
|
| 216 |
cur_bank_ver = int(bop.get("ver", 0))
|
| 217 |
bpath = os.path.join(SESSION_DIR, f"{os.path.basename(session_id)}_bank{int(bop.get('idx', 0))}.pt")
|
| 218 |
try:
|
|
|
|
| 291 |
cond = model._conditioning((list(cur_tokens) + [-1] * model.num_musiccoca)[:model.num_musiccoca],
|
| 292 |
nvec, drm, cfgs)
|
| 293 |
source = model.model.encode(cond).to(dt)
|
| 294 |
+
sampler = make_sampler(c.get("temperature", 1.1), c.get("top_k", 50), gen)
|
| 295 |
+
toks.append(model.model.decoder.step_f(dstate, source, sampler=sampler,
|
| 296 |
+
temporal_step=model._temporal_step, depth_step=model._depth_step))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
new_codes = torch.cat(toks, dim=1)
|
| 298 |
audio = model.decode_stream(new_codes, decode_state) # FLOP-optimal stateful streaming decode
|
| 299 |
emitted_samples += audio.shape[1]
|
magenta_rt/torch/configuration_magenta_rt2.py
DELETED
|
@@ -1,115 +0,0 @@
|
|
| 1 |
-
# Copyright 2026 Google LLC
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
"""HF config for the Magenta RealTime 2 PyTorch model."""
|
| 16 |
-
|
| 17 |
-
from transformers import PretrainedConfig
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class MagentaRT2Config(PretrainedConfig):
|
| 21 |
-
"""Config for `MagentaRT2ForConditionalGeneration`.
|
| 22 |
-
|
| 23 |
-
`temporal` / `depth` are [num_layers, model_dims, hidden_dims, num_heads,
|
| 24 |
-
dim_per_head] for the two Depthformer transformer stacks.
|
| 25 |
-
"""
|
| 26 |
-
|
| 27 |
-
model_type = "magenta_rt2"
|
| 28 |
-
|
| 29 |
-
def __init__(
|
| 30 |
-
self,
|
| 31 |
-
size="mrt2_small",
|
| 32 |
-
encoder_model_dims=256,
|
| 33 |
-
temporal=(12, 1024, 4096, 8, 128),
|
| 34 |
-
depth=(2, 768, 3072, 6, 128),
|
| 35 |
-
temporal_max_past=41,
|
| 36 |
-
depth_max_past=12,
|
| 37 |
-
musiccoca_rvq=12,
|
| 38 |
-
musiccoca_per_rvq_vocab=1031,
|
| 39 |
-
musiccoca_embed_dim=768,
|
| 40 |
-
regular_num_embeddings_per_channel=None,
|
| 41 |
-
regular_num_channels=132,
|
| 42 |
-
num_sinks=1,
|
| 43 |
-
num_codebooks=12,
|
| 44 |
-
codebook_size=1024,
|
| 45 |
-
num_reserved_tokens=6,
|
| 46 |
-
vocab_size=12294,
|
| 47 |
-
soft_cap_logits=30.0,
|
| 48 |
-
temperature=1.3,
|
| 49 |
-
top_k=40,
|
| 50 |
-
cfg_musiccoca=3.0,
|
| 51 |
-
cfg_notes=1.0,
|
| 52 |
-
cfg_drums=1.0,
|
| 53 |
-
num_notes=128,
|
| 54 |
-
num_drums=1,
|
| 55 |
-
sample_rate=48000,
|
| 56 |
-
frame_samples=1920,
|
| 57 |
-
codec_param_shapes=None,
|
| 58 |
-
**kwargs,
|
| 59 |
-
):
|
| 60 |
-
self.size = size
|
| 61 |
-
self.codec_param_shapes = codec_param_shapes
|
| 62 |
-
self.encoder_model_dims = encoder_model_dims
|
| 63 |
-
self.temporal = list(temporal)
|
| 64 |
-
self.depth = list(depth)
|
| 65 |
-
self.temporal_max_past = temporal_max_past
|
| 66 |
-
self.depth_max_past = depth_max_past
|
| 67 |
-
self.musiccoca_rvq = musiccoca_rvq
|
| 68 |
-
self.musiccoca_per_rvq_vocab = musiccoca_per_rvq_vocab
|
| 69 |
-
self.musiccoca_embed_dim = musiccoca_embed_dim
|
| 70 |
-
self.regular_num_embeddings_per_channel = regular_num_embeddings_per_channel
|
| 71 |
-
self.regular_num_channels = regular_num_channels
|
| 72 |
-
self.num_sinks = num_sinks
|
| 73 |
-
self.num_codebooks = num_codebooks
|
| 74 |
-
self.codebook_size = codebook_size
|
| 75 |
-
self.num_reserved_tokens = num_reserved_tokens
|
| 76 |
-
self.vocab_size = vocab_size
|
| 77 |
-
self.soft_cap_logits = soft_cap_logits
|
| 78 |
-
self.temperature = temperature
|
| 79 |
-
self.top_k = top_k
|
| 80 |
-
self.cfg_musiccoca = cfg_musiccoca
|
| 81 |
-
self.cfg_notes = cfg_notes
|
| 82 |
-
self.cfg_drums = cfg_drums
|
| 83 |
-
self.num_notes = num_notes
|
| 84 |
-
self.num_drums = num_drums
|
| 85 |
-
self.sample_rate = sample_rate
|
| 86 |
-
self.frame_samples = frame_samples
|
| 87 |
-
super().__init__(**kwargs)
|
| 88 |
-
|
| 89 |
-
@classmethod
|
| 90 |
-
def from_size(cls, size):
|
| 91 |
-
from .depthformer import config_for
|
| 92 |
-
from dataclasses import astuple
|
| 93 |
-
c = config_for(size)
|
| 94 |
-
return cls(
|
| 95 |
-
size=size,
|
| 96 |
-
encoder_model_dims=c.encoder_model_dims,
|
| 97 |
-
temporal=list(astuple(c.temporal)),
|
| 98 |
-
depth=list(astuple(c.depth)),
|
| 99 |
-
temporal_max_past=c.temporal_max_past,
|
| 100 |
-
depth_max_past=c.depth_max_past,
|
| 101 |
-
musiccoca_rvq=c.musiccoca_rvq,
|
| 102 |
-
musiccoca_per_rvq_vocab=c.musiccoca_per_rvq_vocab,
|
| 103 |
-
musiccoca_embed_dim=c.musiccoca_embed_dim,
|
| 104 |
-
regular_num_embeddings_per_channel=list(c.regular_num_embeddings_per_channel),
|
| 105 |
-
regular_num_channels=c.regular_num_channels,
|
| 106 |
-
num_sinks=c.num_sinks,
|
| 107 |
-
num_codebooks=c.num_codebooks,
|
| 108 |
-
codebook_size=c.codebook_size,
|
| 109 |
-
num_reserved_tokens=c.num_reserved_tokens,
|
| 110 |
-
vocab_size=c.vocab_size,
|
| 111 |
-
soft_cap_logits=c.soft_cap_logits,
|
| 112 |
-
)
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
__all__ = ["MagentaRT2Config"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
magenta_rt/torch/depthformer.py
CHANGED
|
@@ -256,15 +256,9 @@ class MultivariateDecoder(nn.Module):
|
|
| 256 |
}
|
| 257 |
|
| 258 |
def step_f(self, state, source_frame, sampler=None, forced=None,
|
| 259 |
-
temporal_step=None, depth_step=None
|
| 260 |
"""One functional frame. temporal_step/depth_step override the eager fns
|
| 261 |
-
(e.g. with AOTI-compiled callables). Updates state in place; returns [b,1,Q].
|
| 262 |
-
|
| 263 |
-
cfg_scales: optional tuple of classifier-free-guidance scales. When set,
|
| 264 |
-
`source_frame`/`state` are batched as [positive, neg_1, ...] with
|
| 265 |
-
arity = 1 + len(cfg_scales); per-codebook logits are combined as
|
| 266 |
-
``cond + sum_i scale_i*(cond - neg_i)`` before sampling (the native
|
| 267 |
-
MLX/.mlxfn path). The single sampled token is broadcast to all rows."""
|
| 268 |
cfg = self.cfg
|
| 269 |
tstep = temporal_step or self.temporal_step_fn
|
| 270 |
dstep = depth_step or self.depth_step_fn
|
|
@@ -281,21 +275,10 @@ class MultivariateDecoder(nn.Module):
|
|
| 281 |
logits, depth_kv = dstep(depth_input, depth_kv)
|
| 282 |
lo = cfg.num_reserved_tokens + q * cfg.codebook_size
|
| 283 |
hi = lo + cfg.codebook_size
|
| 284 |
-
if
|
| 285 |
-
cond = logits[0:1]
|
| 286 |
-
comb = cond
|
| 287 |
-
for i, s in enumerate(cfg_scales, start=1):
|
| 288 |
-
comb = comb + s * (cond - logits[i:i + 1])
|
| 289 |
-
tok = forced[..., q] if forced is not None else sampler(comb.float(), q, lo, hi)
|
| 290 |
-
depth_input = self.embed(tok.expand(logits.shape[0], -1))
|
| 291 |
-
else:
|
| 292 |
-
tok = forced[..., q] if forced is not None else sampler(logits.float(), q, lo, hi)
|
| 293 |
-
depth_input = self.embed(tok)
|
| 294 |
samples.append(tok)
|
|
|
|
| 295 |
frame = torch.stack(samples, dim=-1)
|
| 296 |
-
if cfg_scales is not None:
|
| 297 |
-
state["prev"] = frame.expand(to.shape[0], -1, -1)
|
| 298 |
-
return frame[:1]
|
| 299 |
state["prev"] = frame
|
| 300 |
return frame
|
| 301 |
|
|
|
|
| 256 |
}
|
| 257 |
|
| 258 |
def step_f(self, state, source_frame, sampler=None, forced=None,
|
| 259 |
+
temporal_step=None, depth_step=None):
|
| 260 |
"""One functional frame. temporal_step/depth_step override the eager fns
|
| 261 |
+
(e.g. with AOTI-compiled callables). Updates state in place; returns [b,1,Q]."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
cfg = self.cfg
|
| 263 |
tstep = temporal_step or self.temporal_step_fn
|
| 264 |
dstep = depth_step or self.depth_step_fn
|
|
|
|
| 275 |
logits, depth_kv = dstep(depth_input, depth_kv)
|
| 276 |
lo = cfg.num_reserved_tokens + q * cfg.codebook_size
|
| 277 |
hi = lo + cfg.codebook_size
|
| 278 |
+
tok = forced[..., q] if forced is not None else sampler(logits.float(), q, lo, hi)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
samples.append(tok)
|
| 280 |
+
depth_input = self.embed(tok)
|
| 281 |
frame = torch.stack(samples, dim=-1)
|
|
|
|
|
|
|
|
|
|
| 282 |
state["prev"] = frame
|
| 283 |
return frame
|
| 284 |
|
magenta_rt/torch/modeling_magenta_rt2.py
CHANGED
|
@@ -24,7 +24,6 @@ loop is a single token stream). MusicCoCa style encoding is a separate
|
|
| 24 |
|
| 25 |
import json
|
| 26 |
import os
|
| 27 |
-
import warnings
|
| 28 |
|
| 29 |
import numpy as np
|
| 30 |
import torch
|
|
@@ -239,32 +238,6 @@ class MagentaRT2ForConditionalGeneration(MagentaRT2PreTrainedModel):
|
|
| 239 |
]
|
| 240 |
return self._conditioning(style_tokens, notes, drums, cfgs)
|
| 241 |
|
| 242 |
-
def _guidance_source(self, style, notes, drums, cfg_musiccoca, cfg_notes):
|
| 243 |
-
"""OPTIONAL classifier-free-guidance conditioning (the native MLX/.mlxfn path).
|
| 244 |
-
Builds a 3-row batch [positive, neg_musiccoca, neg_notes] + per-component scales.
|
| 245 |
-
cfg tokens are neutralized (guidance replaces them); negatives mask style / notes.
|
| 246 |
-
Returns (source[3,Tc,enc], (cfg_musiccoca, cfg_notes))."""
|
| 247 |
-
c = self.config
|
| 248 |
-
if style is None:
|
| 249 |
-
st = [-1] * self.num_musiccoca
|
| 250 |
-
elif isinstance(style, (list, np.ndarray)) and np.asarray(style).ndim == 1 \
|
| 251 |
-
and np.asarray(style).dtype.kind in "iu" and len(style) == self.num_musiccoca:
|
| 252 |
-
st = list(style)
|
| 253 |
-
else:
|
| 254 |
-
st = self._tokenize_style(style)
|
| 255 |
-
st = (list(st) + [-1] * self.num_musiccoca)[:self.num_musiccoca]
|
| 256 |
-
notes = notes if notes is not None else [-1] * self.num_notes
|
| 257 |
-
drums = drums if drums is not None else [-1] * self.num_drums
|
| 258 |
-
CM = [-1, -1, -1] # neutralized cfg tokens
|
| 259 |
-
cond = self._conditioning(st, notes, drums, CM)
|
| 260 |
-
neg_mc = self._conditioning([-1] * self.num_musiccoca, notes, drums, CM)
|
| 261 |
-
neg_n = self._conditioning(st, [-1] * self.num_notes, drums, CM)
|
| 262 |
-
source = self.depthformer.encode(torch.cat([cond, neg_mc, neg_n], 0)).to(self._dt)
|
| 263 |
-
cfg_mc = c.cfg_musiccoca if cfg_musiccoca is None else cfg_musiccoca
|
| 264 |
-
cfg_n = c.cfg_notes if cfg_notes is None else cfg_notes
|
| 265 |
-
_warn_high_cfg(cfg_mc, cfg_n)
|
| 266 |
-
return source, (float(cfg_mc), float(cfg_n))
|
| 267 |
-
|
| 268 |
# ---- codec ----
|
| 269 |
def _decode_stream(self, history, emitted, context=STREAM_DECODE_CONTEXT,
|
| 270 |
margin=STREAM_DECODE_MARGIN, flush=False):
|
|
@@ -287,25 +260,6 @@ class MagentaRT2ForConditionalGeneration(MagentaRT2PreTrainedModel):
|
|
| 287 |
"""Fresh state dict for streaming decode (decode_stream)."""
|
| 288 |
return {}
|
| 289 |
|
| 290 |
-
@torch.no_grad()
|
| 291 |
-
def prefill_f(self, dstate, source_frame, seed_codes):
|
| 292 |
-
"""Teacher-force seed_codes [1,N,Q] (raw 0..codebook_size-1) through the
|
| 293 |
-
temporal transformer to populate its KV cache (native mlx_engine prefill
|
| 294 |
-
parity), so generation CONTINUES from the seed. Advances `dstate` in place.
|
| 295 |
-
Returns unique-code frames [1,N,Q] for the codec decoder."""
|
| 296 |
-
dec = self.depthformer.decoder
|
| 297 |
-
Q = self.config.num_codebooks
|
| 298 |
-
per_cb = (torch.arange(Q, device=seed_codes.device) * self.codebook_size
|
| 299 |
-
+ self.num_reserved_tokens).view(1, 1, Q)
|
| 300 |
-
unique = seed_codes.to(torch.long) + per_cb
|
| 301 |
-
N = unique.shape[1]
|
| 302 |
-
for step in range(max(0, N - 1)):
|
| 303 |
-
dec.step_f(dstate, source_frame, forced=unique[:, step:step + 1, :],
|
| 304 |
-
temporal_step=self._temporal_step, depth_step=self._depth_step)
|
| 305 |
-
if N > 0:
|
| 306 |
-
dstate["prev"] = unique[:, N - 1:N, :]
|
| 307 |
-
return unique
|
| 308 |
-
|
| 309 |
def decode_stream(self, new_codes, state):
|
| 310 |
"""Incremental codec decode of new token frames [b, t_new, Q] -> audio [b, N, 2].
|
| 311 |
FLOP-optimal stateful streaming (no overlap-save re-decode); bf16-equivalent to
|
|
@@ -328,35 +282,26 @@ class MagentaRT2ForConditionalGeneration(MagentaRT2PreTrainedModel):
|
|
| 328 |
@torch.no_grad()
|
| 329 |
def generate(self, style=None, notes=None, drums=None, cfg_musiccoca=None,
|
| 330 |
cfg_notes=None, cfg_drums=None, temperature=None, top_k=None,
|
| 331 |
-
frames=25, seed=0, state=None, flush=False, return_int16=False
|
| 332 |
-
guidance=False):
|
| 333 |
-
"""`guidance=False` (default): cfg_* are discretized conditioning tokens — the
|
| 334 |
-
validated in-process/JAX path, unchanged. `guidance=True`: cfg_musiccoca/cfg_notes
|
| 335 |
-
become classifier-free-guidance scales (negatives + per-codebook logit combine),
|
| 336 |
-
matching the native MLX/Mac-app path. Guidance uses eager steps (batch>1)."""
|
| 337 |
c = self.config
|
| 338 |
temperature = c.temperature if temperature is None else temperature
|
| 339 |
top_k = c.top_k if top_k is None else top_k
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
arity = len(cfg_scales) + 1
|
| 343 |
-
else:
|
| 344 |
-
cond = self._resolve_conditioning(style, notes, drums, cfg_musiccoca, cfg_notes, cfg_drums)
|
| 345 |
-
source = self.depthformer.encode(cond).to(self._dt)
|
| 346 |
-
cfg_scales, arity = None, 1
|
| 347 |
if state is None:
|
| 348 |
-
dstate = self.depthformer.decoder.init_streaming_f(
|
| 349 |
gen = torch.Generator(device=self._dev).manual_seed(seed)
|
| 350 |
-
|
|
|
|
| 351 |
else:
|
| 352 |
-
dstate, gen,
|
| 353 |
sampler = make_sampler(temperature, top_k, gen)
|
| 354 |
-
# dynamic-batch AOTI (or eager fallback) handles guidance B>1 and no-guidance B=1 alike.
|
| 355 |
toks = [self.depthformer.decoder.step_f(
|
| 356 |
-
dstate, source, sampler=sampler,
|
| 357 |
temporal_step=self._temporal_step, depth_step=self._depth_step) for _ in range(frames)]
|
| 358 |
-
|
| 359 |
-
|
|
|
|
| 360 |
wav = audio[0].float().cpu().numpy()
|
| 361 |
i16 = _float_to_int16(wav)
|
| 362 |
out = i16 if return_int16 else i16.astype(np.float32) / 32768.0
|
|
@@ -364,19 +309,9 @@ class MagentaRT2ForConditionalGeneration(MagentaRT2PreTrainedModel):
|
|
| 364 |
|
| 365 |
@torch.no_grad()
|
| 366 |
def stream(self, control, chunk_frames=10, max_seconds=55.0, seed=0,
|
| 367 |
-
time_fn=None, sleep_fn=None, notes=None, drums=None
|
| 368 |
-
cudagraph=False):
|
| 369 |
"""Continuous generation. `control()` returns {style_tokens, temperature,
|
| 370 |
-
top_k, cfg_*} read every chunk for mid-stream steering. Yields int16 [N,2].
|
| 371 |
-
|
| 372 |
-
guidance=False (default): cfg_* are conditioning tokens (validated token path,
|
| 373 |
-
unchanged). guidance=True: cfg_musiccoca/cfg_notes are classifier-free-guidance
|
| 374 |
-
scales read live every chunk. cudagraph=True: single-dispatch CUDA-graph stepping
|
| 375 |
-
(one capture at start, ~4-5x faster), steered via static input buffers."""
|
| 376 |
-
if cudagraph:
|
| 377 |
-
yield from self._stream_cudagraph(control, chunk_frames, max_seconds, seed,
|
| 378 |
-
time_fn, sleep_fn, notes, drums, guidance)
|
| 379 |
-
return
|
| 380 |
import time as _time
|
| 381 |
time_fn = time_fn or _time.time
|
| 382 |
sleep_fn = sleep_fn or _time.sleep
|
|
@@ -384,11 +319,10 @@ class MagentaRT2ForConditionalGeneration(MagentaRT2PreTrainedModel):
|
|
| 384 |
dev, dt = self._dev, self._dt
|
| 385 |
notes = notes if notes is not None else [-1] * self.num_notes
|
| 386 |
drums = drums if drums is not None else [-1] * self.num_drums
|
| 387 |
-
|
| 388 |
-
dstate = self.depthformer.decoder.init_streaming_f(arity, dev, dt)
|
| 389 |
gen = torch.Generator(device=dev).manual_seed(seed)
|
| 390 |
-
|
| 391 |
-
|
| 392 |
cur_tokens = None
|
| 393 |
source = None
|
| 394 |
t0 = time_fn()
|
|
@@ -400,269 +334,23 @@ class MagentaRT2ForConditionalGeneration(MagentaRT2PreTrainedModel):
|
|
| 400 |
tokens = ctl["style_tokens"]
|
| 401 |
if tokens != cur_tokens:
|
| 402 |
cur_tokens = tokens
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
discretize_cfg(ctl.get("cfg_drums", c.cfg_drums), 1.0, 8)]
|
| 410 |
-
source = self.depthformer.encode(self._conditioning(st, notes, drums, cfgs)).to(dt)
|
| 411 |
-
cfg_scales = ((float(ctl.get("cfg_musiccoca", c.cfg_musiccoca)), # live scales (unclamped)
|
| 412 |
-
float(ctl.get("cfg_notes", c.cfg_notes))) if guidance else None)
|
| 413 |
sampler = make_sampler(ctl.get("temperature", c.temperature), ctl.get("top_k", c.top_k), gen)
|
| 414 |
toks = [self.depthformer.decoder.step_f(
|
| 415 |
-
dstate, source, sampler=sampler,
|
| 416 |
temporal_step=self._temporal_step, depth_step=self._depth_step) for _ in range(chunk_frames)]
|
| 417 |
-
|
| 418 |
-
|
| 419 |
if audio.shape[1] > 0:
|
| 420 |
yield _float_to_int16(audio[0].float().cpu().numpy())
|
| 421 |
-
ahead = (
|
| 422 |
if ahead > 1.0:
|
| 423 |
sleep_fn(min(ahead - 1.0, 0.5))
|
| 424 |
|
| 425 |
-
@torch.no_grad()
|
| 426 |
-
def _stream_cudagraph(self, control, chunk_frames, max_seconds, seed,
|
| 427 |
-
time_fn, sleep_fn, notes, drums, guidance):
|
| 428 |
-
"""CUDA-graph backend for stream(cudagraph=True): one capture at start
|
| 429 |
-
(warmup ~KEEP frames), then single-dispatch replay per frame. Steering
|
| 430 |
-
goes through the streamer's static input buffers — cfg/temperature are
|
| 431 |
-
buffer writes; a style change re-encodes + set_source (windowed ramp)."""
|
| 432 |
-
import time as _time
|
| 433 |
-
time_fn = time_fn or _time.time
|
| 434 |
-
sleep_fn = sleep_fn or _time.sleep
|
| 435 |
-
c = self.config
|
| 436 |
-
dt = self._dt
|
| 437 |
-
notes = notes if notes is not None else [-1] * self.num_notes
|
| 438 |
-
drums = drums if drums is not None else [-1] * self.num_drums
|
| 439 |
-
|
| 440 |
-
def encode_src(tokens, cfg_mc, cfg_n):
|
| 441 |
-
st = (list(tokens) + [-1] * self.num_musiccoca)[:self.num_musiccoca]
|
| 442 |
-
if guidance:
|
| 443 |
-
return self._guidance_source(st, notes, drums, cfg_mc, cfg_n)[0]
|
| 444 |
-
cfgs = [discretize_cfg(cfg_mc, 0.2, 40), discretize_cfg(cfg_n, 0.2, 40),
|
| 445 |
-
discretize_cfg(c.cfg_drums, 1.0, 8)]
|
| 446 |
-
return self.depthformer.encode(self._conditioning(st, notes, drums, cfgs)).to(dt)
|
| 447 |
|
| 448 |
-
|
| 449 |
-
t0 = time_fn()
|
| 450 |
-
ctl = control()
|
| 451 |
-
while ctl is None and time_fn() - t0 < max_seconds:
|
| 452 |
-
sleep_fn(0.02); ctl = control()
|
| 453 |
-
if ctl is None:
|
| 454 |
-
return
|
| 455 |
-
cur_tokens = ctl["style_tokens"]
|
| 456 |
-
cur_cfg = (float(ctl.get("cfg_musiccoca", c.cfg_musiccoca)),
|
| 457 |
-
float(ctl.get("cfg_notes", c.cfg_notes)))
|
| 458 |
-
streamer = self.make_cudagraph_streamer(
|
| 459 |
-
style=cur_tokens, notes=notes, drums=drums,
|
| 460 |
-
cfg_musiccoca=cur_cfg[0], cfg_notes=cur_cfg[1],
|
| 461 |
-
temperature=ctl.get("temperature", c.temperature),
|
| 462 |
-
top_k=ctl.get("top_k", c.top_k), seed=seed, guidance=guidance)
|
| 463 |
-
decode_state = self.init_decode_state()
|
| 464 |
-
emitted_samples = 0
|
| 465 |
-
t0 = time_fn()
|
| 466 |
-
while time_fn() - t0 < max_seconds:
|
| 467 |
-
ctl = control()
|
| 468 |
-
if ctl is None:
|
| 469 |
-
sleep_fn(0.005); continue
|
| 470 |
-
tokens = ctl["style_tokens"]
|
| 471 |
-
cfg_mc = float(ctl.get("cfg_musiccoca", c.cfg_musiccoca))
|
| 472 |
-
cfg_n = float(ctl.get("cfg_notes", c.cfg_notes))
|
| 473 |
-
if guidance:
|
| 474 |
-
if (cfg_mc, cfg_n) != cur_cfg:
|
| 475 |
-
streamer.set_cfg([cfg_mc, cfg_n]); cur_cfg = (cfg_mc, cfg_n)
|
| 476 |
-
if tokens != cur_tokens:
|
| 477 |
-
streamer.set_source(encode_src(tokens, cfg_mc, cfg_n)); cur_tokens = tokens
|
| 478 |
-
elif tokens != cur_tokens or (cfg_mc, cfg_n) != cur_cfg: # token path: cfg lives in source
|
| 479 |
-
streamer.set_source(encode_src(tokens, cfg_mc, cfg_n))
|
| 480 |
-
cur_tokens, cur_cfg = tokens, (cfg_mc, cfg_n)
|
| 481 |
-
streamer.set_temperature(ctl.get("temperature", c.temperature))
|
| 482 |
-
toks = [streamer.step() for _ in range(chunk_frames)]
|
| 483 |
-
audio = self.decode_stream(torch.cat(toks, dim=1), decode_state)
|
| 484 |
-
emitted_samples += audio.shape[1]
|
| 485 |
-
if audio.shape[1] > 0:
|
| 486 |
-
yield _float_to_int16(audio[0].float().cpu().numpy())
|
| 487 |
-
ahead = (emitted_samples / SR) - (time_fn() - t0)
|
| 488 |
-
if ahead > 1.0:
|
| 489 |
-
sleep_fn(min(ahead - 1.0, 0.5))
|
| 490 |
-
|
| 491 |
-
@torch.no_grad()
|
| 492 |
-
def make_cudagraph_streamer(self, style=None, notes=None, drums=None,
|
| 493 |
-
cfg_musiccoca=None, cfg_notes=None, cfg_drums=None,
|
| 494 |
-
temperature=None, top_k=None, seed=0, guidance=False,
|
| 495 |
-
warmup=None):
|
| 496 |
-
"""One-dispatch-per-frame CUDA-graph streaming: captures the whole frame
|
| 497 |
-
(temporal + N-codebook depth + in-graph sampler + optional CFG) as a single
|
| 498 |
-
`torch.cuda.graph` replay over fixed-size static KV buffers — ~MLX `.mlxfn`.
|
| 499 |
-
Returns a `CudaGraphStreamer`; call `.step()` for the next frame [1,1,Q]
|
| 500 |
-
(decode with `decode_stream`), and `.set_cfg/.set_temperature/.set_source`
|
| 501 |
-
for live steering (no re-capture). `top_k` is fixed at capture time."""
|
| 502 |
-
if guidance:
|
| 503 |
-
source, scales = self._guidance_source(style, notes, drums, cfg_musiccoca, cfg_notes)
|
| 504 |
-
num_neg = len(scales)
|
| 505 |
-
else:
|
| 506 |
-
cond = self._resolve_conditioning(style, notes, drums, cfg_musiccoca, cfg_notes, cfg_drums)
|
| 507 |
-
source = self.depthformer.encode(cond).to(self._dt)
|
| 508 |
-
scales, num_neg = None, 0
|
| 509 |
-
temperature = self.config.temperature if temperature is None else temperature
|
| 510 |
-
top_k = self.config.top_k if top_k is None else top_k
|
| 511 |
-
return CudaGraphStreamer(self.depthformer.decoder, source, self._dt, num_neg, scales,
|
| 512 |
-
temperature, top_k, seed, warmup)
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
# Classifier-free guidance scales above this can run away / collapse the output
|
| 516 |
-
# to silence under *sustained constant* conditioning over long runs (the native
|
| 517 |
-
# UI uses a 0-5 slider, default 2.4). We don't clamp — values pass through to
|
| 518 |
-
# match the native range — but we warn once so the caller knows the risk.
|
| 519 |
-
GUIDANCE_CFG_WARN = 3.5
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
def _warn_high_cfg(*scales):
|
| 523 |
-
hi = [round(float(s), 2) for s in scales if float(s) > GUIDANCE_CFG_WARN]
|
| 524 |
-
if hi:
|
| 525 |
-
warnings.warn(
|
| 526 |
-
f"CFG guidance scale(s) {hi} exceed ~{GUIDANCE_CFG_WARN}; sustained high "
|
| 527 |
-
"guidance on constant conditioning can make the output run away / collapse "
|
| 528 |
-
"to silence over long runs. (Changing notes/style during play avoids this.)",
|
| 529 |
-
stacklevel=3)
|
| 530 |
-
return True
|
| 531 |
-
return False
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
class CudaGraphStreamer:
|
| 535 |
-
"""Single-dispatch CUDA-graph frame stepper over fixed-size static KV buffers.
|
| 536 |
-
|
| 537 |
-
Warms `KEEP` frames eagerly to fill the temporal/cross KV to steady state,
|
| 538 |
-
snapshots them into static buffers, then captures one frame (temporal + depth +
|
| 539 |
-
sampler) with `torch.cuda.graph`. `.step()` replays it (one GPU dispatch) and
|
| 540 |
-
returns the new frame tokens. Live steering writes into static input buffers
|
| 541 |
-
(`source`, `cfg`, `temperature`) — the captured graph reads them, no re-capture.
|
| 542 |
-
Conditioning changes ramp in via the windowed cross-KV (optional hard flush)."""
|
| 543 |
-
|
| 544 |
-
def __init__(self, decoder, source, decode_dtype, num_neg=0, cfg_scales=None,
|
| 545 |
-
temperature=1.1, top_k=50, seed=0, warmup=None):
|
| 546 |
-
"""decoder: a MultivariateDecoder (`model.depthformer.decoder` for the
|
| 547 |
-
modeling class, `model.model.decoder` for the system class). `source` is the
|
| 548 |
-
pre-encoded conditioning [B, Tc, enc] (B = 1 + num_neg); `decode_dtype` the
|
| 549 |
-
compute dtype. Class-agnostic so both model wrappers can build it."""
|
| 550 |
-
dec = decoder
|
| 551 |
-
c = dec.cfg
|
| 552 |
-
self.dec = dec
|
| 553 |
-
self.Q, self.CB, self.NR = c.num_codebooks, c.codebook_size, c.num_reserved_tokens
|
| 554 |
-
self.KEEP = c.temporal_max_past + 1
|
| 555 |
-
self.num_neg = num_neg
|
| 556 |
-
self.top_k = int(top_k)
|
| 557 |
-
dev, dt = source.device, decode_dtype
|
| 558 |
-
B = source.shape[0]; self.B = B
|
| 559 |
-
# live-steering static inputs
|
| 560 |
-
self.source = source.clone()
|
| 561 |
-
self.cfg = (torch.zeros(0, device=dev, dtype=torch.float32) if not num_neg
|
| 562 |
-
else torch.tensor([float(s) for s in cfg_scales], device=dev, dtype=torch.float32))
|
| 563 |
-
self.temp = torch.tensor(float(temperature), device=dev, dtype=torch.float32)
|
| 564 |
-
torch.manual_seed(seed)
|
| 565 |
-
# 1) prime to steady state (KV == KEEP on every layer)
|
| 566 |
-
st = dec.init_streaming_f(B, dev, dt)
|
| 567 |
-
K = self.KEEP
|
| 568 |
-
for _ in range(K + 8 if warmup is None else warmup):
|
| 569 |
-
to, ns, nc = dec.temporal_step_fn(st["prev"], st["self"], st["cross"], self.source)
|
| 570 |
-
st["self"] = [(k[:, -K:], v[:, -K:]) for k, v in ns]
|
| 571 |
-
st["cross"] = [(k[:, -K:], v[:, -K:]) for k, v in nc]
|
| 572 |
-
frame = self._depth_sample(to)
|
| 573 |
-
st["prev"] = frame.expand(B, -1, -1)
|
| 574 |
-
# 2) static KV + state buffers
|
| 575 |
-
L = len(st["self"]); self.L = L
|
| 576 |
-
self.SK = [st["self"][i][0].clone() for i in range(L)]; self.SV = [st["self"][i][1].clone() for i in range(L)]
|
| 577 |
-
self.CK = [st["cross"][i][0].clone() for i in range(L)]; self.CV = [st["cross"][i][1].clone() for i in range(L)]
|
| 578 |
-
self.prev = st["prev"].clone()
|
| 579 |
-
self.out = torch.zeros(1, 1, self.Q, dtype=torch.long, device=dev)
|
| 580 |
-
# 3) capture (side-stream warmup is required before graph capture)
|
| 581 |
-
s = torch.cuda.Stream(); s.wait_stream(torch.cuda.current_stream())
|
| 582 |
-
with torch.cuda.stream(s):
|
| 583 |
-
for _ in range(3):
|
| 584 |
-
self._frame_static()
|
| 585 |
-
torch.cuda.current_stream().wait_stream(s)
|
| 586 |
-
self.graph = torch.cuda.CUDAGraph()
|
| 587 |
-
with torch.cuda.graph(self.graph):
|
| 588 |
-
self._frame_static()
|
| 589 |
-
|
| 590 |
-
def _depth_sample(self, to):
|
| 591 |
-
dec = self.dec; B = self.B; Q, CB, NR = self.Q, self.CB, self.NR
|
| 592 |
-
dd = dec.cfg.depth
|
| 593 |
-
z = torch.zeros(B, 0, dd.num_heads, dd.dim_per_head, device=to.device, dtype=to.dtype)
|
| 594 |
-
dk = [(z, z) for _ in range(dd.num_layers)]
|
| 595 |
-
di = to; toks = []
|
| 596 |
-
for q in range(Q):
|
| 597 |
-
logits, dk = dec.depth_step_fn(di, dk) # [B,1,V]
|
| 598 |
-
lo = NR + q * CB
|
| 599 |
-
ls = logits[..., lo:lo + CB]
|
| 600 |
-
cond = ls[0:1]; comb = cond
|
| 601 |
-
for i in range(self.num_neg): # classifier-free guidance combine
|
| 602 |
-
comb = comb + self.cfg[i] * (cond - ls[i + 1:i + 2])
|
| 603 |
-
kth = torch.topk(comb, self.top_k, dim=-1).values[..., -1:]
|
| 604 |
-
comb = torch.where(comb >= kth, comb, torch.full_like(comb, -1e9))
|
| 605 |
-
u = torch.rand(1, 1, CB, device=to.device, dtype=torch.float32) # graph-safe RNG
|
| 606 |
-
g = -torch.log(-torch.log(u.clamp(1e-10, 1 - 1e-7)))
|
| 607 |
-
tok = (comb + g * self.temp).argmax(-1) + lo
|
| 608 |
-
toks.append(tok)
|
| 609 |
-
di = dec.embed(tok.expand(B, -1))
|
| 610 |
-
return torch.stack(toks, dim=-1) # [1,1,Q]
|
| 611 |
-
|
| 612 |
-
def _frame_static(self):
|
| 613 |
-
dec = self.dec; K = self.KEEP; L = self.L
|
| 614 |
-
to, ns, nc = dec.temporal_step_fn(
|
| 615 |
-
self.prev, [(self.SK[i], self.SV[i]) for i in range(L)],
|
| 616 |
-
[(self.CK[i], self.CV[i]) for i in range(L)], self.source)
|
| 617 |
-
for i in range(L):
|
| 618 |
-
self.SK[i].copy_(ns[i][0][:, -K:]); self.SV[i].copy_(ns[i][1][:, -K:])
|
| 619 |
-
self.CK[i].copy_(nc[i][0][:, -K:]); self.CV[i].copy_(nc[i][1][:, -K:])
|
| 620 |
-
frame = self._depth_sample(to)
|
| 621 |
-
self.out.copy_(frame)
|
| 622 |
-
self.prev.copy_(frame.expand(self.B, -1, -1))
|
| 623 |
-
|
| 624 |
-
# ---- live steering (no re-capture) ----
|
| 625 |
-
def set_cfg(self, scales):
|
| 626 |
-
if self.num_neg:
|
| 627 |
-
if not getattr(self, "_cfg_warned", False):
|
| 628 |
-
self._cfg_warned = _warn_high_cfg(*scales)
|
| 629 |
-
self.cfg.copy_(torch.tensor([float(s) for s in scales],
|
| 630 |
-
device=self.cfg.device, dtype=torch.float32))
|
| 631 |
-
|
| 632 |
-
def set_temperature(self, t):
|
| 633 |
-
self.temp.fill_(float(t))
|
| 634 |
-
|
| 635 |
-
def set_source(self, source, flush=False):
|
| 636 |
-
"""Update conditioning. Ramps in via the windowed cross-KV; flush=True
|
| 637 |
-
overwrites all cross-KV slots for an immediate change."""
|
| 638 |
-
self.source.copy_(source if source.shape[0] == self.B else source.expand(self.B, -1, -1))
|
| 639 |
-
if flush:
|
| 640 |
-
for i in range(self.L):
|
| 641 |
-
sk, sv = self.dec.temporal_body.layers[i]["cross_attention"]._kv(self.source)
|
| 642 |
-
self.CK[i].copy_(sk[:, -self.KEEP:]); self.CV[i].copy_(sv[:, -self.KEEP:])
|
| 643 |
-
|
| 644 |
-
def step(self):
|
| 645 |
-
"""Advance one frame (single CUDA-graph dispatch). Returns tokens [1,1,Q]."""
|
| 646 |
-
self.graph.replay()
|
| 647 |
-
return self.out.clone()
|
| 648 |
-
|
| 649 |
-
def close(self):
|
| 650 |
-
"""Free the captured CUDA graph + its private memory pool. Idempotent;
|
| 651 |
-
call at session end (the WS worker should). Safe during interpreter
|
| 652 |
-
shutdown — swallows teardown-ordering errors."""
|
| 653 |
-
g = getattr(self, "graph", None)
|
| 654 |
-
if g is not None:
|
| 655 |
-
try:
|
| 656 |
-
g.reset()
|
| 657 |
-
except Exception:
|
| 658 |
-
pass
|
| 659 |
-
self.graph = None
|
| 660 |
-
|
| 661 |
-
def __del__(self):
|
| 662 |
-
try:
|
| 663 |
-
self.close()
|
| 664 |
-
except Exception:
|
| 665 |
-
pass
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
__all__ = ["MagentaRT2ForConditionalGeneration", "MagentaRT2PreTrainedModel", "CudaGraphStreamer"]
|
|
|
|
| 24 |
|
| 25 |
import json
|
| 26 |
import os
|
|
|
|
| 27 |
|
| 28 |
import numpy as np
|
| 29 |
import torch
|
|
|
|
| 238 |
]
|
| 239 |
return self._conditioning(style_tokens, notes, drums, cfgs)
|
| 240 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
# ---- codec ----
|
| 242 |
def _decode_stream(self, history, emitted, context=STREAM_DECODE_CONTEXT,
|
| 243 |
margin=STREAM_DECODE_MARGIN, flush=False):
|
|
|
|
| 260 |
"""Fresh state dict for streaming decode (decode_stream)."""
|
| 261 |
return {}
|
| 262 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
def decode_stream(self, new_codes, state):
|
| 264 |
"""Incremental codec decode of new token frames [b, t_new, Q] -> audio [b, N, 2].
|
| 265 |
FLOP-optimal stateful streaming (no overlap-save re-decode); bf16-equivalent to
|
|
|
|
| 282 |
@torch.no_grad()
|
| 283 |
def generate(self, style=None, notes=None, drums=None, cfg_musiccoca=None,
|
| 284 |
cfg_notes=None, cfg_drums=None, temperature=None, top_k=None,
|
| 285 |
+
frames=25, seed=0, state=None, flush=False, return_int16=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
c = self.config
|
| 287 |
temperature = c.temperature if temperature is None else temperature
|
| 288 |
top_k = c.top_k if top_k is None else top_k
|
| 289 |
+
cond = self._resolve_conditioning(style, notes, drums, cfg_musiccoca, cfg_notes, cfg_drums)
|
| 290 |
+
source = self.depthformer.encode(cond).to(self._dt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
if state is None:
|
| 292 |
+
dstate = self.depthformer.decoder.init_streaming_f(1, self._dev, self._dt)
|
| 293 |
gen = torch.Generator(device=self._dev).manual_seed(seed)
|
| 294 |
+
history = torch.zeros((1, 0, c.num_codebooks), dtype=torch.long, device=self._dev)
|
| 295 |
+
emitted = 0
|
| 296 |
else:
|
| 297 |
+
dstate, gen, history, emitted = state["dstate"], state["gen"], state["history"], state["emitted"]
|
| 298 |
sampler = make_sampler(temperature, top_k, gen)
|
|
|
|
| 299 |
toks = [self.depthformer.decoder.step_f(
|
| 300 |
+
dstate, source, sampler=sampler,
|
| 301 |
temporal_step=self._temporal_step, depth_step=self._depth_step) for _ in range(frames)]
|
| 302 |
+
history = torch.cat([history] + toks, dim=1)
|
| 303 |
+
audio, emitted = self._decode_stream(history, emitted, flush=flush)
|
| 304 |
+
new_state = {"dstate": dstate, "gen": gen, "history": history, "emitted": emitted}
|
| 305 |
wav = audio[0].float().cpu().numpy()
|
| 306 |
i16 = _float_to_int16(wav)
|
| 307 |
out = i16 if return_int16 else i16.astype(np.float32) / 32768.0
|
|
|
|
| 309 |
|
| 310 |
@torch.no_grad()
|
| 311 |
def stream(self, control, chunk_frames=10, max_seconds=55.0, seed=0,
|
| 312 |
+
time_fn=None, sleep_fn=None, notes=None, drums=None):
|
|
|
|
| 313 |
"""Continuous generation. `control()` returns {style_tokens, temperature,
|
| 314 |
+
top_k, cfg_*} read every chunk for mid-stream steering. Yields int16 [N,2]."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
import time as _time
|
| 316 |
time_fn = time_fn or _time.time
|
| 317 |
sleep_fn = sleep_fn or _time.sleep
|
|
|
|
| 319 |
dev, dt = self._dev, self._dt
|
| 320 |
notes = notes if notes is not None else [-1] * self.num_notes
|
| 321 |
drums = drums if drums is not None else [-1] * self.num_drums
|
| 322 |
+
dstate = self.depthformer.decoder.init_streaming_f(1, dev, dt)
|
|
|
|
| 323 |
gen = torch.Generator(device=dev).manual_seed(seed)
|
| 324 |
+
history = torch.zeros((1, 0, c.num_codebooks), dtype=torch.long, device=dev)
|
| 325 |
+
emitted = 0
|
| 326 |
cur_tokens = None
|
| 327 |
source = None
|
| 328 |
t0 = time_fn()
|
|
|
|
| 334 |
tokens = ctl["style_tokens"]
|
| 335 |
if tokens != cur_tokens:
|
| 336 |
cur_tokens = tokens
|
| 337 |
+
cfgs = [discretize_cfg(ctl.get("cfg_musiccoca", c.cfg_musiccoca), 0.2, 40),
|
| 338 |
+
discretize_cfg(ctl.get("cfg_notes", c.cfg_notes), 0.2, 40),
|
| 339 |
+
discretize_cfg(ctl.get("cfg_drums", c.cfg_drums), 1.0, 8)]
|
| 340 |
+
cond = self._conditioning((list(tokens) + [-1] * self.num_musiccoca)[:self.num_musiccoca],
|
| 341 |
+
notes, drums, cfgs)
|
| 342 |
+
source = self.depthformer.encode(cond).to(dt)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
sampler = make_sampler(ctl.get("temperature", c.temperature), ctl.get("top_k", c.top_k), gen)
|
| 344 |
toks = [self.depthformer.decoder.step_f(
|
| 345 |
+
dstate, source, sampler=sampler,
|
| 346 |
temporal_step=self._temporal_step, depth_step=self._depth_step) for _ in range(chunk_frames)]
|
| 347 |
+
history = torch.cat([history] + toks, dim=1)
|
| 348 |
+
audio, emitted = self._decode_stream(history, emitted)
|
| 349 |
if audio.shape[1] > 0:
|
| 350 |
yield _float_to_int16(audio[0].float().cpu().numpy())
|
| 351 |
+
ahead = (emitted * FRAME_SAMPLES / SR) - (time_fn() - t0)
|
| 352 |
if ahead > 1.0:
|
| 353 |
sleep_fn(min(ahead - 1.0, 0.5))
|
| 354 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
|
| 356 |
+
__all__ = ["MagentaRT2ForConditionalGeneration", "MagentaRT2PreTrainedModel"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
magenta_rt/torch/processing_musiccoca.py
DELETED
|
@@ -1,77 +0,0 @@
|
|
| 1 |
-
# Copyright 2026 Google LLC
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
"""MusicCoCa style processor for Magenta RealTime 2.
|
| 16 |
-
|
| 17 |
-
A processor-style component (like a feature extractor / tokenizer): turns a text
|
| 18 |
-
prompt OR an audio clip into 12 RVQ style tokens that condition the model. Pure
|
| 19 |
-
torch + sentencepiece (text tower, audio tower, RVQ all torch-native).
|
| 20 |
-
"""
|
| 21 |
-
|
| 22 |
-
import os
|
| 23 |
-
|
| 24 |
-
import numpy as np
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
class MusicCoCaProcessor:
|
| 28 |
-
"""Text/audio -> 12 RVQ style tokens (and 768-d embeddings, for layering)."""
|
| 29 |
-
|
| 30 |
-
def __init__(self, musiccoca):
|
| 31 |
-
self._mc = musiccoca
|
| 32 |
-
|
| 33 |
-
@classmethod
|
| 34 |
-
def from_pretrained(cls, pretrained_model_name_or_path, device="cpu", **kwargs):
|
| 35 |
-
from .musiccoca import MusicCoCa
|
| 36 |
-
p = pretrained_model_name_or_path
|
| 37 |
-
if p is not None and os.path.isdir(p) and os.path.exists(os.path.join(p, "text_encoder.pt")):
|
| 38 |
-
mc = MusicCoCa(resource_dir=p, device=device)
|
| 39 |
-
else:
|
| 40 |
-
mc = MusicCoCa(repo_id=p, device=device) if p else MusicCoCa(device=device)
|
| 41 |
-
return cls(mc)
|
| 42 |
-
|
| 43 |
-
def save_pretrained(self, save_directory, **kwargs):
|
| 44 |
-
# Artifacts live in the MusicCoCa hub repo; nothing extra to serialize here.
|
| 45 |
-
os.makedirs(save_directory, exist_ok=True)
|
| 46 |
-
|
| 47 |
-
@property
|
| 48 |
-
def device(self):
|
| 49 |
-
return self._mc.device
|
| 50 |
-
|
| 51 |
-
def to(self, device):
|
| 52 |
-
self._mc.to(device)
|
| 53 |
-
return self
|
| 54 |
-
|
| 55 |
-
def embed(self, text_or_audio):
|
| 56 |
-
"""Text str / audio (Waveform | (samples, sr) | np@16kHz) -> [768] torch."""
|
| 57 |
-
return self._mc.embed(text_or_audio)
|
| 58 |
-
|
| 59 |
-
def tokenize(self, embedding):
|
| 60 |
-
"""[768] embedding -> [12] int RVQ tokens (np.int64)."""
|
| 61 |
-
return self._mc.tokenize(embedding)
|
| 62 |
-
|
| 63 |
-
def layer(self, prompts, weights=None):
|
| 64 |
-
"""Blend several prompts (text/audio) by weighted-mean of embeddings,
|
| 65 |
-
then tokenize. `prompts` is a list; `weights` defaults to uniform."""
|
| 66 |
-
embs = [self.embed(p) for p in prompts]
|
| 67 |
-
w = weights or [1.0 / len(embs)] * len(embs)
|
| 68 |
-
emb = sum(wi * e for wi, e in zip(w, embs))
|
| 69 |
-
return self.tokenize(emb).tolist()
|
| 70 |
-
|
| 71 |
-
def __call__(self, text_or_audio, return_tokens=True):
|
| 72 |
-
"""-> 12 style tokens (list[int]) by default, or the [768] embedding."""
|
| 73 |
-
emb = self.embed(text_or_audio)
|
| 74 |
-
return self.tokenize(emb).tolist() if return_tokens else emb
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
__all__ = ["MusicCoCaProcessor"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|