ai-techno-dj / render_set.py
rikhoffbauer2's picture
Add stem-based render_full_set using demucs + stem_mixer
eee1783 verified
raw
history blame
4.16 kB
# Sample rate (Hz) of the rendered set: demucs stems are resampled to it and
# the output WAV is written at it.
_RENDER_SR = 44100


def _write_set_wav(set_audio, sr=_RENDER_SR):
    """Write ``set_audio`` to a new 16-bit PCM temp WAV file and return its path.

    ``set_audio`` is transposed before writing, so a channel-first
    ``(channels, samples)`` array is assumed (scipy expects samples first).
    Samples are clipped to [-1, 1] before scaling: summed stems/tracks can
    exceed full scale, and casting out-of-range floats to int16 yields
    wrapped/undefined values instead of saturating. The temp file handle is
    closed right away (``delete=False`` keeps the file on disk), so no file
    descriptor is leaked per render; the returned path is owned by the caller.
    """
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
        path = tmp.name
    pcm = (np.clip(set_audio, -1.0, 1.0).T * 32767).astype(np.int16)
    scipy.io.wavfile.write(path, sr, pcm)
    return path


def _separate_track_stems(model, track, device):
    """Separate one track into demucs stems resampled to ``_RENDER_SR``.

    Args:
        model: A loaded demucs model (reads ``samplerate`` and ``sources``).
        track: Analysis object exposing ``path`` (audio file to load).
        device: Torch device string passed to ``apply_model``.

    Returns:
        dict mapping each name in ``model.sources`` to a numpy array with one
        row per channel.
    """
    import torch
    from demucs.apply import apply_model

    y, _ = librosa.load(track.path, sr=model.samplerate, mono=False)
    if y.ndim == 1:
        # Mono file: duplicate into two identical channels.
        y = np.stack([y, y])
    wav = torch.from_numpy(y).float().unsqueeze(0).to(device)
    # Normalise by the channel-averaged signal's mean/std and undo it on the
    # separated outputs; the epsilon guards against a silent (std == 0) file.
    ref = wav.mean(1, keepdim=True)
    wav_norm = (wav - ref.mean()) / (ref.std() + 1e-8)
    with torch.no_grad():
        sources = apply_model(model, wav_norm, device=device, split=True, overlap=0.25)
    sources = sources[0] * ref.std() + ref.mean()

    stems = {}
    for name, source in zip(model.sources, sources):
        stem_audio = source.cpu().numpy()
        if model.samplerate != _RENDER_SR:
            stem_audio = np.stack([
                librosa.resample(channel, orig_sr=model.samplerate, target_sr=_RENDER_SR)
                for channel in stem_audio
            ])
        stems[name] = stem_audio
    return stems


def _render_nonstem_fallback(progress, progress_cb, error):
    """Render the set with the filter-based ``mixer.mix_set`` after demucs failed.

    Stores the mix in ``app_state.rendered_set`` and returns
    ``(wav_path, markdown_summary)``; ``error`` is quoted in the summary.
    """
    from mixer import mix_set

    set_audio, set_info = mix_set(
        app_state.analyses, app_state.set_order, app_state.transitions,
        progress_cb=progress_cb,
    )
    app_state.rendered_set = set_audio
    progress(0.95, desc="Saving (non-stem fallback)...")
    path = _write_set_wav(set_audio)
    summary = "# ✅ DJ Set Rendered (non-stem fallback)\n\n"
    summary += f"- **Duration:** {set_info.get('total_duration', 0):.0f}s\n"
    summary += f"- **Note:** demucs failed ({error}), used filter-based mixing\n"
    return path, summary


def render_full_set(progress=gr.Progress()):
    """Render the DJ set using stem separation (demucs) for surgical mixing.

    Each track in ``app_state.set_order`` is split into stems with the
    ``htdemucs`` model and mixed with ``stem_mixer.mix_stems``. If torch/demucs
    cannot be imported or the model fails to load, the set is rendered with the
    filter-based ``mixer.mix_set`` instead. The mixed audio is also stored in
    ``app_state.rendered_set``.

    Args:
        progress: Gradio progress tracker, called as ``progress(frac, desc=...)``.

    Returns:
        ``(wav_path, markdown_summary)``, or ``(None, warning)`` when no set
        plan (``app_state.transitions``) exists yet.
    """
    if not app_state.transitions:
        return None, "⚠️ Generate a set plan first"
    progress(0.02, desc="Starting stem-based render...")

    def progress_cb(p, msg):
        progress(0.02 + p * 0.90, desc=msg)

    from stem_mixer import mix_stems

    # The torch/demucs imports live inside the try so that a missing or broken
    # demucs install also takes the non-stem fallback instead of crashing.
    try:
        import torch
        from demucs.pretrained import get_model

        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Loading demucs model on {device}...")
        progress(0.03, desc="Loading demucs model...")
        model = get_model("htdemucs")
        model.eval().to(device)
    except Exception as e:
        logger.exception(f"Failed to load demucs: {e}")
        return _render_nonstem_fallback(progress, progress_cb, e)

    # Separate each track into stems (progress 0.03 -> 0.53).
    all_stems = {}
    n_tracks = len(app_state.set_order)
    for i, tidx in enumerate(app_state.set_order):
        track = app_state.analyses[tidx]
        progress(0.03 + (i / n_tracks) * 0.50,
                 desc=f"Separating stems: {track.filename[:30]}...")
        stems = _separate_track_stems(model, track, device)
        all_stems[tidx] = stems
        logger.info(f"Separated {track.filename}: {list(stems.keys())}")

    # Mix using the stem mixer (progress 0.55 -> 0.90).
    progress(0.55, desc="Mixing with stems...")
    set_audio, set_info = mix_stems(
        all_stems, app_state.analyses, app_state.set_order,
        progress_cb=lambda p, m: progress(0.55 + p * 0.35, desc=m),
    )
    app_state.rendered_set = set_audio
    progress(0.92, desc="Saving...")
    path = _write_set_wav(set_audio)

    # Markdown summary for the UI.
    total_duration = set_info.get('total_duration', 0)
    tracks = set_info.get('tracks', [])
    parts = [
        "# ✅ DJ Set Rendered (stem-based mixing)\n\n",
        f"- **Total duration:** {total_duration:.1f}s ",
        f"({total_duration/60:.1f} min)\n",
        f"- **Tracks:** {len(tracks)}\n",
        "- **Method:** Demucs stem separation → surgical drum/bass swap\n\n",
        "## Tracklist\n",
    ]
    parts.extend(
        f"{i+1}. **{t['filename']}** — tl={t['tl_start']:.0f}s, stretch=×{t['stretch']:.3f}\n"
        for i, t in enumerate(tracks)
    )
    return path, "".join(parts)