ai-techno-dj / stem_render.py
Rik Hoffbauer
Implement disk-based caching for stem separation
b198075
# stem_render.py — Stem-based rendering (demucs / spleeter) with a filter-mixer fallback
import numpy as np
import scipy.io.wavfile
import tempfile
import logging
import gradio as gr
# Module logger. Uses the fixed name "dj_engine" rather than __name__, presumably so
# that all DJ-engine modules log to one shared channel. NOTE(review): confirm.
logger = logging.getLogger("dj_engine")
def _write_wav_tempfile(set_audio, sr):
    """Write ``set_audio`` to a persistent temporary 16-bit PCM WAV file.

    The array is transposed before writing because scipy expects
    ``(samples, channels)``. That suggests the mixers return channel-first
    audio. NOTE(review): this is inferred from the transpose; confirm it
    against the mixers. Mono 1-D audio is unaffected.

    Non-finite samples are first replaced: NaN becomes 0 and +/-inf becomes
    +/-1. Everything is then clipped to [-1, 1] before scaling to int16. A
    float-to-int16 cast does not saturate, so an unclipped hot mix would
    wrap around into loud clicks.

    Returns:
        Path of the written ``.wav`` file. The file is created with
        ``delete=False``, so it persists after the handle is closed.
        Removing it is up to whoever consumes the path.
    """
    audio = np.nan_to_num(np.asarray(set_audio), nan=0.0, posinf=1.0, neginf=-1.0)
    audio_int16 = (np.clip(audio, -1.0, 1.0).T * 32767).astype(np.int16)
    # Write through the open handle rather than reopening ``tmp.name``.
    # Opening a file that is still open fails on Windows. The ``with`` block
    # guarantees the handle is closed.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        scipy.io.wavfile.write(tmp, sr, audio_int16)
    return tmp.name


def _format_summary(set_info, method):
    """Return the Markdown summary shown after a render.

    Reads these keys from ``set_info`` (the info dict returned by the mixer):
    ``total_duration``, ``tracks`` (each item read via ``filename``,
    ``tl_start`` and ``stretch``) and ``transitions``. Missing keys fall back
    to neutral defaults.
    """
    total = set_info.get("total_duration", 0)
    tracks = set_info.get("tracks", [])
    parts = [
        "# ✅ DJ Set Rendered\n\n",
        f"- **Total duration:** {total:.1f}s ({total / 60:.1f} min)\n",
        f"- **Tracks:** {len(tracks)}\n",
        f"- **Method:** {method}\n\n",
        "## Tracklist\n",
    ]
    for i, t in enumerate(tracks, start=1):
        fn = t.get("filename", "?")
        tl = t.get("tl_start", 0)
        stretch = t.get("stretch", 1.0)
        # Only show the time-stretch ratio when it differs from 1.0 by more than 0.3%.
        extra = f" (×{stretch:.3f})" if abs(stretch - 1.0) > 0.003 else ""
        parts.append(f"{i}. **{fn}** — starts at {tl:.0f}s{extra}\n")
    transitions = set_info.get("transitions")
    if transitions:
        parts.append("\n## Transitions\n")
        parts.extend(f"- {t}\n" for t in transitions)
    return "".join(parts)


def render_full_set_with_stems(app_state, max_iter=20, stem_backend="spleeter", progress=gr.Progress()):
    """Render the DJ set using stem separation, falling back to the filter mixer.

    The stem path runs as follows:

    1. Each track in ``app_state.set_order`` is separated into stems.
    2. The set is assembled with ``stem_mixer.mix_stems``.

    If anything in the stem path raises (an import, separation or mixing),
    the set is instead rendered with ``mixer.mix_set``, driven by
    ``quality_analyzer.run_refinement_loop``.

    Args:
        app_state: Session state. Reads ``transitions``, ``set_order`` and
            ``analyses``, and stores the rendered audio in ``rendered_set``.
        max_iter: Number of refinement iterations for the fallback mixer
            (from the UI slider).
        stem_backend: ``'demucs'``, ``'demucs-mlx'`` or ``'spleeter'``
            (default ``'spleeter'``). A falsy value (``None``/``''``) is
            resolved with ``preferred_stem_backend()`` for the progress text,
            but is passed through unchanged to ``separate_stems_with_backend``.
        progress: Gradio progress tracker. Gradio injects it when it is
            declared as a default argument.

    Returns:
        ``(wav_path, markdown_summary)`` after a render, or
        ``(None, warning_message)`` when no set plan has been generated.
    """
    if not app_state.transitions:
        return None, "⚠️ Generate a set plan first"
    progress(0.02, desc="Starting stem-based render...")
    sr = 44100
    try:
        from stem_mixer import mix_stems
        from stem_separator import preferred_stem_backend, separate_stems_with_backend

        actual_backend = stem_backend if stem_backend else preferred_stem_backend()
        progress(0.03, desc=f"Loading stem separator ({actual_backend})...")
        # Separate each track into stems (cache-aware — skips separation on cache hit).
        all_stems = {}
        backends_used = []  # distinct backends reported by the separator, in first-use order
        n_tracks = len(app_state.set_order)
        for i, tidx in enumerate(app_state.set_order):
            track = app_state.analyses[tidx]
            progress(0.03 + (i / n_tracks) * 0.50,
                     desc=f"Separating stems ({actual_backend}, {i+1}/{n_tracks}): {track.filename[:30]}...")
            # NOTE(review): this passes the caller's raw ``stem_backend``
            # (possibly falsy), not ``actual_backend``. Presumably the
            # separator resolves its own default/fallback; confirm.
            stems, backend_used = separate_stems_with_backend(
                track.path, 0.0, None, sr, backend=stem_backend)
            all_stems[tidx] = stems
            if backend_used not in backends_used:
                backends_used.append(backend_used)
            logger.info("Stems for %s via %s: %s", track.filename, backend_used, list(stems.keys()))
        # Mix using the stem mixer.
        progress(0.55, desc="Mixing with stems (surgical drum/bass swap)...")
        set_audio, set_info = mix_stems(
            all_stems, app_state.analyses, app_state.set_order,
            progress_cb=lambda p, m: progress(0.55 + p * 0.35, desc=m),
        )
        # Report every backend actually used, not a hard-coded model name.
        # If no tracks were processed, fall back to the requested backend.
        backend_label = ", ".join(map(str, backends_used)) or actual_backend
        method = f"{backend_label} stem separation → surgical drum/bass swap on downbeats"
    except Exception as e:
        # Deliberate best-effort: any failure in the stem path falls back to
        # the filter-based mixer instead of aborting the render.
        logger.warning("Stem separation failed: %s", e, exc_info=True)
        from mixer import mix_set
        from quality_analyzer import run_refinement_loop

        progress(0.10, desc="Fallback: filter-based mixing with refinement...")
        set_audio, set_info, _ = run_refinement_loop(
            mix_fn=mix_set,
            tracks=app_state.analyses,
            order=app_state.set_order,
            transitions=app_state.transitions,
            max_iter=int(max_iter),
            progress_cb=lambda p, m: progress(0.10 + p * 0.80, desc=m),
        )
        method = f"Filter-based with {int(max_iter)} refinement iterations (stem separation failed: {e})"
    app_state.rendered_set = set_audio
    progress(0.92, desc="Saving audio...")
    wav_path = _write_wav_tempfile(set_audio, sr)
    return wav_path, _format_summary(set_info, method)