Spaces:
Running
Running
stem_render: accept max_iter param, use it for refinement iterations
Browse files- stem_render.py +33 -36
stem_render.py
CHANGED
|
@@ -10,19 +10,15 @@ import gradio as gr
|
|
| 10 |
logger = logging.getLogger("dj_engine")
|
| 11 |
|
| 12 |
|
| 13 |
-
def render_full_set_with_stems(app_state, progress=None):
|
| 14 |
-
"""Render the DJ set using demucs stem separation."""
|
| 15 |
-
if progress is None:
|
| 16 |
-
progress = gr.Progress()
|
| 17 |
|
|
|
|
|
|
|
| 18 |
if not app_state.transitions:
|
| 19 |
return None, "⚠️ Generate a set plan first"
|
| 20 |
|
| 21 |
-
|
| 22 |
-
progress(0.02, desc="Starting stem-based render...")
|
| 23 |
-
except:
|
| 24 |
-
pass # progress might not be callable in all contexts
|
| 25 |
-
|
| 26 |
sr = 44100
|
| 27 |
|
| 28 |
try:
|
|
@@ -32,10 +28,7 @@ def render_full_set_with_stems(app_state, progress=None):
|
|
| 32 |
from stem_mixer import mix_stems
|
| 33 |
|
| 34 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 35 |
-
|
| 36 |
-
progress(0.03, desc=f"Loading demucs htdemucs ({device})...")
|
| 37 |
-
except:
|
| 38 |
-
pass
|
| 39 |
model = get_model("htdemucs")
|
| 40 |
model.eval().to(device)
|
| 41 |
|
|
@@ -44,11 +37,8 @@ def render_full_set_with_stems(app_state, progress=None):
|
|
| 44 |
n_tracks = len(app_state.set_order)
|
| 45 |
for i, tidx in enumerate(app_state.set_order):
|
| 46 |
track = app_state.analyses[tidx]
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
desc=f"Separating stems ({i+1}/{n_tracks}): {track.filename[:30]}...")
|
| 50 |
-
except:
|
| 51 |
-
pass
|
| 52 |
|
| 53 |
y, _ = librosa.load(track.path, sr=model.samplerate, mono=False)
|
| 54 |
if y.ndim == 1:
|
|
@@ -76,46 +66,45 @@ def render_full_set_with_stems(app_state, progress=None):
|
|
| 76 |
logger.info(f"Separated {track.filename}: {list(stems.keys())}")
|
| 77 |
|
| 78 |
# Mix using stem mixer
|
| 79 |
-
|
| 80 |
-
progress(0.55, desc="Mixing with stems...")
|
| 81 |
-
except:
|
| 82 |
-
pass
|
| 83 |
set_audio, set_info = mix_stems(
|
| 84 |
all_stems, app_state.analyses, app_state.set_order,
|
| 85 |
-
progress_cb=lambda p, m:
|
| 86 |
)
|
| 87 |
method = "Demucs htdemucs → surgical drum/bass swap on downbeats"
|
| 88 |
|
| 89 |
except Exception as e:
|
| 90 |
-
logger.warning(f"Stem separation failed: {e}
|
| 91 |
import traceback
|
| 92 |
traceback.print_exc()
|
|
|
|
| 93 |
from mixer import mix_set
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
app_state.
|
| 100 |
-
|
|
|
|
|
|
|
| 101 |
)
|
| 102 |
-
method = f"Filter-based (demucs failed: {e})"
|
| 103 |
|
| 104 |
app_state.rendered_set = set_audio
|
| 105 |
|
| 106 |
-
|
| 107 |
-
progress(0.92, desc="Saving audio...")
|
| 108 |
-
except:
|
| 109 |
-
pass
|
| 110 |
tmp = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
|
| 111 |
audio_int16 = (set_audio.T * 32767).astype(np.int16)
|
| 112 |
scipy.io.wavfile.write(tmp.name, sr, audio_int16)
|
| 113 |
|
|
|
|
| 114 |
summary = f"# ✅ DJ Set Rendered\n\n"
|
| 115 |
summary += f"- **Total duration:** {set_info.get('total_duration', 0):.1f}s "
|
| 116 |
summary += f"({set_info.get('total_duration', 0)/60:.1f} min)\n"
|
| 117 |
summary += f"- **Tracks:** {len(set_info.get('tracks', []))}\n"
|
| 118 |
summary += f"- **Method:** {method}\n\n"
|
|
|
|
| 119 |
summary += "## Tracklist\n"
|
| 120 |
for i, t in enumerate(set_info.get('tracks', [])):
|
| 121 |
fn = t.get('filename', '?')
|
|
@@ -124,4 +113,12 @@ def render_full_set_with_stems(app_state, progress=None):
|
|
| 124 |
extra = f" (×{stretch:.3f})" if abs(stretch - 1.0) > 0.003 else ""
|
| 125 |
summary += f"{i+1}. **{fn}** — starts at {tl:.0f}s{extra}\n"
|
| 126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
return tmp.name, summary
|
|
|
|
| 10 |
logger = logging.getLogger("dj_engine")
|
| 11 |
|
| 12 |
|
| 13 |
+
def render_full_set_with_stems(app_state, max_iter=20, progress=gr.Progress()):
|
| 14 |
+
"""Render the DJ set using demucs stem separation.
|
|
|
|
|
|
|
| 15 |
|
| 16 |
+
max_iter: number of refinement iterations (from the UI slider)
|
| 17 |
+
"""
|
| 18 |
if not app_state.transitions:
|
| 19 |
return None, "⚠️ Generate a set plan first"
|
| 20 |
|
| 21 |
+
progress(0.02, desc="Starting stem-based render...")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
sr = 44100
|
| 23 |
|
| 24 |
try:
|
|
|
|
| 28 |
from stem_mixer import mix_stems
|
| 29 |
|
| 30 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 31 |
+
progress(0.03, desc=f"Loading demucs htdemucs ({device})...")
|
|
|
|
|
|
|
|
|
|
| 32 |
model = get_model("htdemucs")
|
| 33 |
model.eval().to(device)
|
| 34 |
|
|
|
|
| 37 |
n_tracks = len(app_state.set_order)
|
| 38 |
for i, tidx in enumerate(app_state.set_order):
|
| 39 |
track = app_state.analyses[tidx]
|
| 40 |
+
progress(0.03 + (i / n_tracks) * 0.50,
|
| 41 |
+
desc=f"Separating stems ({i+1}/{n_tracks}): {track.filename[:30]}...")
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
y, _ = librosa.load(track.path, sr=model.samplerate, mono=False)
|
| 44 |
if y.ndim == 1:
|
|
|
|
| 66 |
logger.info(f"Separated {track.filename}: {list(stems.keys())}")
|
| 67 |
|
| 68 |
# Mix using stem mixer
|
| 69 |
+
progress(0.55, desc="Mixing with stems (surgical drum/bass swap)...")
|
|
|
|
|
|
|
|
|
|
| 70 |
set_audio, set_info = mix_stems(
|
| 71 |
all_stems, app_state.analyses, app_state.set_order,
|
| 72 |
+
progress_cb=lambda p, m: progress(0.55 + p * 0.35, desc=m)
|
| 73 |
)
|
| 74 |
method = "Demucs htdemucs → surgical drum/bass swap on downbeats"
|
| 75 |
|
| 76 |
except Exception as e:
|
| 77 |
+
logger.warning(f"Stem separation failed: {e}")
|
| 78 |
import traceback
|
| 79 |
traceback.print_exc()
|
| 80 |
+
# Fallback to the original filter-based mixer with refinement loop
|
| 81 |
from mixer import mix_set
|
| 82 |
+
from quality_analyzer import run_refinement_loop, format_analysis_log
|
| 83 |
+
progress(0.10, desc="Fallback: filter-based mixing with refinement...")
|
| 84 |
+
set_audio, set_info, _ = run_refinement_loop(
|
| 85 |
+
mix_fn=mix_set,
|
| 86 |
+
tracks=app_state.analyses,
|
| 87 |
+
order=app_state.set_order,
|
| 88 |
+
transitions=app_state.transitions,
|
| 89 |
+
max_iter=int(max_iter),
|
| 90 |
+
progress_cb=lambda p, m: progress(0.10 + p * 0.80, desc=m)
|
| 91 |
)
|
| 92 |
+
method = f"Filter-based with {int(max_iter)} refinement iterations (demucs failed: {e})"
|
| 93 |
|
| 94 |
app_state.rendered_set = set_audio
|
| 95 |
|
| 96 |
+
progress(0.92, desc="Saving audio...")
|
|
|
|
|
|
|
|
|
|
| 97 |
tmp = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
|
| 98 |
audio_int16 = (set_audio.T * 32767).astype(np.int16)
|
| 99 |
scipy.io.wavfile.write(tmp.name, sr, audio_int16)
|
| 100 |
|
| 101 |
+
# Summary
|
| 102 |
summary = f"# ✅ DJ Set Rendered\n\n"
|
| 103 |
summary += f"- **Total duration:** {set_info.get('total_duration', 0):.1f}s "
|
| 104 |
summary += f"({set_info.get('total_duration', 0)/60:.1f} min)\n"
|
| 105 |
summary += f"- **Tracks:** {len(set_info.get('tracks', []))}\n"
|
| 106 |
summary += f"- **Method:** {method}\n\n"
|
| 107 |
+
|
| 108 |
summary += "## Tracklist\n"
|
| 109 |
for i, t in enumerate(set_info.get('tracks', [])):
|
| 110 |
fn = t.get('filename', '?')
|
|
|
|
| 113 |
extra = f" (×{stretch:.3f})" if abs(stretch - 1.0) > 0.003 else ""
|
| 114 |
summary += f"{i+1}. **{fn}** — starts at {tl:.0f}s{extra}\n"
|
| 115 |
|
| 116 |
+
if set_info.get('transitions'):
|
| 117 |
+
summary += "\n## Transitions\n"
|
| 118 |
+
for t in set_info['transitions']:
|
| 119 |
+
if isinstance(t, dict):
|
| 120 |
+
summary += f"- {t}\n"
|
| 121 |
+
else:
|
| 122 |
+
summary += f"- {t}\n"
|
| 123 |
+
|
| 124 |
return tmp.name, summary
|