magenta-retry / one_shot_generation.py
thecollabagepatch's picture
one shot generations start failing after a few successful ones...
2446a8b
raw
history blame
11.7 kB
"""
One-shot music generation functions for MagentaRT.
This module contains the core generation functions extracted from the main app
that can be used independently for single-shot music generation tasks.
"""
import math
import numpy as np
from magenta_rt import audio as au
from utils import (
match_loudness_to_reference,
stitch_generated,
hard_trim_seconds,
apply_micro_fades,
make_bar_aligned_context,
take_bar_aligned_tail
)
def generate_loop_continuation_with_mrt(
    mrt,
    input_wav_path: str,
    bpm: float,
    extra_styles=None,
    style_weights=None,
    bars: int = 8,
    beats_per_bar: int = 4,
    loop_weight: float = 1.0,
    loudness_mode: str = "auto",
    loudness_headroom_db: float = 1.0,
    intro_bars_to_drop: int = 0,
    progress_cb=None
):
    """
    Generate a continuation of an input loop using MagentaRT.

    Args:
        mrt: MagentaRT wrapper exposing ``codec``, ``config``, ``sample_rate``,
            ``init_state``, ``embed_style`` and ``generate_chunk`` (project type).
        input_wav_path: Path to the loop audio file to continue.
        bpm: Tempo of the loop in beats per minute; drives all bar-length math.
        extra_styles: Optional sequence of style text prompts blended with the
            loop's own style embedding.
        style_weights: Optional per-prompt weights aligned by index with
            ``extra_styles``; missing entries default to 1.0.
        bars: Number of bars of audio to return.
        beats_per_bar: Time-signature numerator used for bar-length math.
        loop_weight: Blend weight of the loop's own style embedding.
        loudness_mode: Matching method forwarded to
            ``apply_barwise_loudness_match`` (e.g. "auto").
        loudness_headroom_db: Headroom (dB) forwarded to the loudness matcher.
        intro_bars_to_drop: Bars generated and then discarded from the front,
            letting the model settle before the kept region.
        progress_cb: Optional ``callable(step, total_steps)`` for progress UI.

    Returns:
        Tuple ``(out, loud_stats)``: the generated ``au.Waveform`` trimmed to
        exactly ``bars`` bars, and the per-bar gain stats dict produced by
        ``apply_barwise_loudness_match``.
    """
    # ===== NEW: Force codec/model reset before generation =====
    # Clear any accumulated state in the codec that might cause silence issues
    # (see commit: "one shot generations start failing after a few successful
    # ones"). All of this is best-effort and duck-typed via hasattr.
    try:
        # Option 1: If codec has explicit reset
        if hasattr(mrt.codec, 'reset') and callable(mrt.codec.reset):
            mrt.codec.reset()
        # Option 2: Force clear any cached codec state
        if hasattr(mrt.codec, '_encode_cache'):
            mrt.codec._encode_cache = None
        if hasattr(mrt.codec, '_decode_cache'):
            mrt.codec._decode_cache = None
        # Option 3: Clear JAX compilation caches (nuclear but effective)
        # Uncomment if issues persist:
        # import jax
        # jax.clear_caches()
    except Exception as e:
        # Deliberately non-fatal: a failed reset should not block generation.
        import logging
        logging.warning(f"Codec reset attempt failed (non-fatal): {e}")
    # ============================================================
    # Load & prep (unchanged): resample to the model rate, force stereo.
    loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
    # Use only the tail of the loop as context, sized to the model's context
    # window (context_length_frames at the codec frame rate).
    codec_fps = float(mrt.codec.frame_rate)
    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
    loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
    # ===== NEW: Force fresh token copies =====
    # copy=True / .copy() ensure we never hand the model a view into codec-owned
    # buffers — aliasing here was a suspected cause of the silent-output bug.
    tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32, copy=True)  # ← Added copy=True
    tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth].copy()  # ← Added .copy()
    # ==========================================
    # Bar-aligned token window
    context_tokens = make_bar_aligned_context(
        tokens, bpm=bpm, fps=float(mrt.codec.frame_rate),
        ctx_frames=mrt.config.context_length_frames, beats_per_bar=beats_per_bar
    )
    # ===== NEW: More aggressive state initialization =====
    state = mrt.init_state()
    # Ensure context_tokens is a fresh array, not a view
    state.context_tokens = np.array(context_tokens, dtype=np.int32, copy=True)
    # If there's any internal model state cache, clear it
    if hasattr(state, '_cache'):
        state._cache = None
    # =====================================================
    # STYLE embed (unchanged but ensure fresh embedding): blend the loop's own
    # embedding with any extra text-prompt embeddings, weights normalized to 1.
    loop_embed = mrt.embed_style(loop_for_context)
    embeds, weights = [loop_embed.copy()], [float(loop_weight)]  # ← Added .copy()
    if extra_styles:
        for i, s in enumerate(extra_styles):
            if s.strip():
                embeds.append(mrt.embed_style(s.strip()).copy())  # ← Added .copy()
                w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
                weights.append(float(w))
    wsum = float(sum(weights)) or 1.0  # `or 1.0` guards the all-zero-weights case
    weights = [w / wsum for w in weights]
    combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(loop_embed.dtype, copy=True)  # ← Added copy=True
    # --- Length math (unchanged) ---
    seconds_per_bar = beats_per_bar * (60.0 / bpm)
    total_secs = bars * seconds_per_bar
    drop_bars = max(0, int(intro_bars_to_drop))
    drop_secs = min(drop_bars, bars) * seconds_per_bar
    gen_total_secs = total_secs + drop_secs  # generate extra, trim the intro later
    chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate
    # +1 pad chunk gives stitch_generated crossfade headroom.
    steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1
    if progress_cb:
        progress_cb(0, steps)
    # ===== NEW: Generation loop with explicit state refresh =====
    chunks = []
    for i in range(steps):
        # Generate chunk with current state
        wav, new_state = mrt.generate_chunk(state=state, style=combined_style)
        chunks.append(wav)
        # CRITICAL: Replace state, don't mutate it
        # This ensures we're not accumulating corrupted state
        state = new_state
        if progress_cb:
            progress_cb(i + 1, steps)
    # ============================================================
    # Rest of the function unchanged...
    # Crossfade-stitch chunks, trim to the generated length, drop the intro
    # bars, then trim to the exact requested musical length.
    stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
    stitched = hard_trim_seconds(stitched, gen_total_secs)
    if drop_secs > 0:
        n_drop = int(round(drop_secs * stitched.sample_rate))
        stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)
    out = hard_trim_seconds(stitched, total_secs)
    # Bar-locked loudness match against the original loop (helper below).
    out, loud_stats = apply_barwise_loudness_match(
        out=out,
        ref_loop=loop,
        bpm=bpm,
        beats_per_bar=beats_per_bar,
        method=loudness_mode,
        headroom_db=loudness_headroom_db,
        smooth_ms=50,
    )
    # 5 ms micro-fades at the edges to avoid clicks (mutates `out` in place).
    apply_micro_fades(out, 5)
    return out, loud_stats
def generate_style_only_with_mrt(
    mrt,
    bpm: float,
    bars: int = 8,
    beats_per_bar: int = 4,
    styles: str = "warmup",
    style_weights: str = "",
    intro_bars_to_drop: int = 0,
):
    """
    Style-only, bar-aligned generation using a silent context (no input audio).

    Args:
        mrt: MagentaRT wrapper exposing ``codec``, ``config``, ``sample_rate``,
            ``init_state``, ``embed_style`` and ``generate_chunk`` (project type).
        bpm: Target tempo in beats per minute; drives all bar-length math.
        bars: Number of bars of audio to return.
        beats_per_bar: Time-signature numerator used for bar-length math.
        styles: Comma-separated style text prompts; falls back to "warmup"
            when empty.
        style_weights: Comma-separated weights aligned by position with the
            prompts; missing entries default to 1.0, blanks are skipped.
        intro_bars_to_drop: Bars generated and then discarded from the front.

    Returns:
        Tuple ``(out, None)``: the generated ``au.Waveform`` trimmed to
        exactly ``bars`` bars and peak-normalized; loudness stats are not
        applicable because there is no reference audio.
    """
    # ---- Build a silent context sized to the model's window, tokenized ----
    codec_fps = float(mrt.codec.frame_rate)
    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
    sr = int(mrt.sample_rate)
    silent = au.Waveform(np.zeros((int(round(ctx_seconds * sr)), 2), np.float32), sr)
    # Fresh token copies — same state-hygiene fix already applied to
    # generate_loop_continuation_with_mrt: never hand the model a view into
    # codec-owned buffers (aliasing was a suspected cause of one-shot
    # generations going silent after a few runs).
    tokens_full = mrt.codec.encode(silent).astype(np.int32, copy=True)
    tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth].copy()
    state = mrt.init_state()
    # Ensure context_tokens is a fresh array, not a view.
    state.context_tokens = np.array(tokens, dtype=np.int32, copy=True)
    # ---- Style vector (text prompts only, normalized weights) ----
    prompts = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
    if not prompts:
        prompts = ["warmup"]
    # Robust weight parsing: skip blank entries so "1,,2" or a trailing comma
    # doesn't raise ValueError (the old list comprehension crashed on those).
    sw = []
    if style_weights:
        for tok in style_weights.split(","):
            tok = tok.strip()
            if tok:
                sw.append(float(tok))
    embeds, weights = [], []
    for i, p in enumerate(prompts):
        embeds.append(mrt.embed_style(p))
        weights.append(sw[i] if i < len(sw) else 1.0)
    wsum = float(sum(weights)) or 1.0  # `or 1.0` guards the all-zero case
    weights = [w / wsum for w in weights]
    style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
    # ---- Target length math ----
    seconds_per_bar = beats_per_bar * (60.0 / bpm)
    total_secs = bars * seconds_per_bar
    drop_bars = max(0, int(intro_bars_to_drop))
    drop_secs = min(drop_bars, bars) * seconds_per_bar
    gen_total_secs = total_secs + drop_secs
    # ~2.0s chunk length from model config
    chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
    # Generate enough chunks to cover total, plus a pad chunk for crossfade headroom
    steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1
    chunks = []
    for _ in range(steps):
        wav, state = mrt.generate_chunk(state=state, style=style_vec)
        chunks.append(wav)
    # Stitch & trim to exact musical length
    stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
    stitched = hard_trim_seconds(stitched, gen_total_secs)
    if drop_secs > 0:
        n_drop = int(round(drop_secs * stitched.sample_rate))
        stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)
    out = hard_trim_seconds(stitched, total_secs)
    out = out.peak_normalize(0.95)
    # 5 ms micro-fades at the edges to avoid clicks (mutates `out` in place).
    apply_micro_fades(out, 5)
    return out, None  # loudness stats not applicable (no reference)
# loudness matching helper for /generate:
def apply_barwise_loudness_match(
    out: au.Waveform,
    ref_loop: au.Waveform,
    *,
    bpm: float,
    beats_per_bar: int,
    method: str = "auto",
    headroom_db: float = 1.0,
    smooth_ms: int = 50,
) -> tuple[au.Waveform, dict]:
    """
    Bar-locked loudness matching that establishes the correct starting level
    then maintains consistency. Only the first bar is matched to the reference;
    subsequent bars use the same gain to maintain relative dynamics.

    Args:
        out: Generated waveform to adjust (mutated in place and returned).
        ref_loop: Reference loop whose first bar sets the target loudness.
        bpm: Tempo in beats per minute; defines the bar grid.
        beats_per_bar: Time-signature numerator; defines the bar grid.
        method: Matching method forwarded to match_loudness_to_reference;
            bars shorter than ~400 ms fall back to "rms" (LUFS needs more audio).
        headroom_db: Headroom forwarded to match_loudness_to_reference.
        smooth_ms: Length of the gain-transition ramp at bar boundaries.

    Returns:
        Tuple (out, stats) where stats is {"per_bar_gain_db": [float, ...]},
        the measured linear-RMS gain of each bar expressed in dB.
    """
    sr = int(out.sample_rate)
    spb = (60.0 / float(bpm)) * int(beats_per_bar)
    bar_len = int(round(spb * sr))
    y = out.samples.astype(np.float32, copy=False)
    if y.ndim == 1: y = y[:, None]
    # Reference resampled to the output rate and forced to 2 channels.
    if ref_loop.sample_rate != sr:
        ref = ref_loop.resample(sr).as_stereo().samples.astype(np.float32, copy=False)
    else:
        ref = ref_loop.as_stereo().samples.astype(np.float32, copy=False)
    if ref.ndim == 1: ref = ref[:, None]
    if ref.shape[1] == 1: ref = np.repeat(ref, 2, axis=1)
    # (match_loudness_to_reference comes from the module-level utils import.)
    # Measure reference loudness once, from its first bar.
    ref_bar_len = min(ref.shape[0], bar_len)
    ref_bar = au.Waveform(ref[:ref_bar_len], sr)
    gains_db = []
    out_adj = y.copy()
    need = y.shape[0]
    n_bars = max(1, int(np.ceil(need / float(bar_len))))
    ramp = int(max(0, round(smooth_ms * sr / 1000.0)))
    min_lufs_samples = int(0.4 * sr)  # LUFS needs >= ~400 ms of audio
    eps = 1e-12
    # Gain established by matching bar 0; reused verbatim for all later bars.
    first_bar_gain_linear = 1.0
    # Gain actually applied to the previous bar (for boundary ramps).
    prev_gain = 1.0
    for i in range(n_bars):
        s = i * bar_len
        e = min(need, s + bar_len)
        if e <= s:
            break
        bar_samples = e - s
        tgt_bar = au.Waveform(y[s:e], sr)  # Always read from ORIGINAL
        # First bar: match to reference to establish gain
        if i == 0:
            effective_method = "rms" if bar_samples < min_lufs_samples else method
            matched_bar, stats = match_loudness_to_reference(
                ref_bar, tgt_bar, method=effective_method, headroom_db=headroom_db
            )
            # Calculate the linear gain that was applied (RMS ratio)
            first_bar_gain_linear = float(np.sqrt(
                (np.mean(matched_bar.samples**2) + eps) /
                (np.mean(tgt_bar.samples**2) + eps)
            ))
            g = matched_bar.samples.astype(np.float32, copy=False)
        else:
            # Subsequent bars: apply the same gain from bar 0
            g = (tgt_bar.samples * first_bar_gain_linear).astype(np.float32, copy=False)
        # Calculate gain in dB for stats
        if tgt_bar.samples.size > 0:
            g_lin = float(np.sqrt((np.mean(g**2) + eps) / (np.mean(tgt_bar.samples**2) + eps)))
        else:
            g_lin = 1.0
        gains_db.append(20.0 * np.log10(max(g_lin, 1e-6)))
        # Apply with ramp for smoothness.
        # BUG FIX: the previous version blended from the UNGAINED original at
        # every bar boundary, so whenever the matched gain != 1 the output
        # dipped back to the unmatched level at each bar start (a periodic
        # loudness wobble). Ramp from the previous bar's applied gain instead;
        # when consecutive bars share a gain (the normal case here) the ramp
        # is a no-op.
        if i > 0 and ramp > 0:
            ramp_len = min(ramp, e - s)
            t = np.linspace(0.0, 1.0, ramp_len, dtype=np.float32)[:, None]
            prev = (y[s:s + ramp_len] * prev_gain).astype(np.float32, copy=False)
            out_adj[s:s + ramp_len] = (1.0 - t) * prev + t * g[:ramp_len]
            out_adj[s + ramp_len:e] = g[ramp_len:e - s]
        else:
            out_adj[s:e] = g
        prev_gain = g_lin
    out.samples = out_adj.astype(np.float32, copy=False)
    return out, {"per_bar_gain_db": gains_db}