# stem_render.py — Stem-based rendering with demucs
import logging
import tempfile

import gradio as gr
import librosa
import numpy as np
import scipy.io.wavfile

logger = logging.getLogger("dj_engine")


def _separate_track(model, apply_model, path, device, target_sr):
    """Split one audio file into demucs stems, resampled to ``target_sr``.

    Args:
        model: Loaded demucs model; its ``samplerate`` and ``sources``
            attributes are read here.
        apply_model: The ``demucs.apply.apply_model`` callable (passed in so
            the caller controls when demucs is imported).
        path: Path of the audio file to separate.
        device: Torch device string the model lives on.
        target_sr: Sample rate the returned stems should have.

    Returns:
        Dict mapping each name in ``model.sources`` to a numpy array holding
        that stem (presumably ``(channels, samples)`` — confirm against
        ``stem_mixer.mix_stems``).
    """
    # torch is already loaded by the caller; importing it here keeps this
    # module importable on machines without torch installed.
    import torch

    y, _ = librosa.load(path, sr=model.samplerate, mono=False)
    if y.ndim == 1:
        # Mono file: duplicate the single channel so the model gets two.
        y = np.stack([y, y])

    wav = torch.from_numpy(y).float().unsqueeze(0).to(device)

    # Normalise with the statistics of the channel-averaged signal and undo
    # the normalisation on the separated sources afterwards.
    ref = wav.mean(1, keepdim=True)
    ref_mean = ref.mean()
    ref_std = ref.std()
    wav_norm = (wav - ref_mean) / (ref_std + 1e-8)

    with torch.no_grad():
        sources = apply_model(
            model, wav_norm, device=device, split=True, overlap=0.25
        )
    sources = sources[0] * ref_std + ref_mean

    stems = {}
    for name, source in zip(model.sources, sources):
        stem_np = source.cpu().numpy()
        if model.samplerate != target_sr:
            # Resample channel by channel to the output rate.
            stem_np = np.stack([
                librosa.resample(
                    channel, orig_sr=model.samplerate, target_sr=target_sr
                )
                for channel in stem_np
            ])
        stems[name] = stem_np
    return stems


def _write_wav(audio, sr):
    """Write float audio to a new temporary 16-bit WAV file; return its path.

    ``audio`` is transposed before writing, as in the original code
    (presumably it is ``(channels, samples)`` — confirm against the mixers).
    """
    # Clip before scaling: samples outside [-1, 1] would otherwise wrap
    # around when cast to int16 and come out as loud distortion.
    pcm = (np.clip(audio, -1.0, 1.0).T * 32767).astype(np.int16)

    # Only the unique name is needed. Close the handle before writing so no
    # descriptor leaks and the path can be reopened on every platform.
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
        out_path = tmp.name
    scipy.io.wavfile.write(out_path, sr, pcm)
    return out_path


def _format_summary(set_info, method):
    """Build the Markdown summary shown next to the rendered audio."""
    total = set_info.get('total_duration', 0)
    tracks = set_info.get('tracks', [])

    parts = [
        "# ✅ DJ Set Rendered\n\n",
        f"- **Total duration:** {total:.1f}s ",
        f"({total/60:.1f} min)\n",
        f"- **Tracks:** {len(tracks)}\n",
        f"- **Method:** {method}\n\n",
        "## Tracklist\n",
    ]

    for i, t in enumerate(tracks):
        fn = t.get('filename', '?')
        tl = t.get('tl_start', 0)
        stretch = t.get('stretch', 1.0)
        # Only mention the stretch factor when it is noticeably off 1.0.
        extra = f" (×{stretch:.3f})" if abs(stretch - 1.0) > 0.003 else ""
        parts.append(f"{i+1}. **{fn}** — starts at {tl:.0f}s{extra}\n")

    if set_info.get('transitions'):
        parts.append("\n## Transitions\n")
        parts.extend(f"- {t}\n" for t in set_info['transitions'])

    return "".join(parts)


def render_full_set_with_stems(app_state, max_iter=20, progress=gr.Progress()):
    """Render the planned DJ set to a WAV file, preferring demucs stem mixing.

    Every track in ``app_state.set_order`` is separated into stems with the
    demucs ``htdemucs`` model and mixed by ``stem_mixer.mix_stems``. If any
    step of that path raises (including a missing dependency), the function
    falls back to the filter-based ``mixer.mix_set`` driven by
    ``quality_analyzer.run_refinement_loop``.

    Args:
        app_state: Application state; ``transitions``, ``set_order`` and
            ``analyses`` are read, and ``rendered_set`` is assigned the
            mixed audio.
        max_iter: Number of refinement iterations (from the UI slider).
            Only used by the fallback path.
        progress: Gradio progress callback, called as
            ``progress(fraction, desc=...)``.

    Returns:
        ``(wav_path, summary_markdown)`` on success, or
        ``(None, warning_message)`` when no set plan has been generated.
    """
    if not app_state.transitions:
        return None, "⚠️ Generate a set plan first"

    progress(0.02, desc="Starting stem-based render...")
    sr = 44100

    try:
        # Heavy / optional dependencies are imported lazily so that a missing
        # package triggers the fallback below instead of breaking the module.
        import torch
        from demucs.pretrained import get_model
        from demucs.apply import apply_model
        from stem_mixer import mix_stems

        device = "cuda" if torch.cuda.is_available() else "cpu"
        progress(0.03, desc=f"Loading demucs htdemucs ({device})...")
        model = get_model("htdemucs")
        model.eval().to(device)

        # Separate each track into stems
        all_stems = {}
        n_tracks = len(app_state.set_order)
        for i, tidx in enumerate(app_state.set_order):
            track = app_state.analyses[tidx]
            progress(
                0.03 + (i / n_tracks) * 0.50,
                desc=f"Separating stems ({i+1}/{n_tracks}): {track.filename[:30]}...",
            )
            stems = _separate_track(model, apply_model, track.path, device, sr)
            all_stems[tidx] = stems
            logger.info(f"Separated {track.filename}: {list(stems.keys())}")

        # Mix using stem mixer
        progress(0.55, desc="Mixing with stems (surgical drum/bass swap)...")
        set_audio, set_info = mix_stems(
            all_stems,
            app_state.analyses,
            app_state.set_order,
            progress_cb=lambda p, m: progress(0.55 + p * 0.35, desc=m),
        )
        method = "Demucs htdemucs → surgical drum/bass swap on downbeats"

    except Exception as e:
        # Deliberate best-effort boundary: any failure in the stem path is
        # logged with its traceback and the simpler mixer takes over.
        logger.warning(f"Stem separation failed: {e}", exc_info=True)

        # Fallback to the original filter-based mixer with refinement loop
        from mixer import mix_set
        from quality_analyzer import run_refinement_loop

        iterations = int(max_iter)
        progress(0.10, desc="Fallback: filter-based mixing with refinement...")
        set_audio, set_info, _ = run_refinement_loop(
            mix_fn=mix_set,
            tracks=app_state.analyses,
            order=app_state.set_order,
            transitions=app_state.transitions,
            max_iter=iterations,
            progress_cb=lambda p, m: progress(0.10 + p * 0.80, desc=m),
        )
        method = (
            f"Filter-based with {iterations} refinement iterations "
            f"(demucs failed: {e})"
        )

    app_state.rendered_set = set_audio

    progress(0.92, desc="Saving audio...")
    out_path = _write_wav(set_audio, sr)

    return out_path, _format_summary(set_info, method)