rikhoffbauer2 commited on
Commit
e2c6ce1
·
verified ·
1 Parent(s): 9a04bdf

stem_render: accept max_iter param, use it for refinement iterations"

Browse files
Files changed (1) hide show
  1. stem_render.py +33 -36
stem_render.py CHANGED
@@ -10,19 +10,15 @@ import gradio as gr
10
  logger = logging.getLogger("dj_engine")
11
 
12
 
13
- def render_full_set_with_stems(app_state, progress=None):
14
- """Render the DJ set using demucs stem separation."""
15
- if progress is None:
16
- progress = gr.Progress()
17
 
 
 
18
  if not app_state.transitions:
19
  return None, "⚠️ Generate a set plan first"
20
 
21
- try:
22
- progress(0.02, desc="Starting stem-based render...")
23
- except:
24
- pass # progress might not be callable in all contexts
25
-
26
  sr = 44100
27
 
28
  try:
@@ -32,10 +28,7 @@ def render_full_set_with_stems(app_state, progress=None):
32
  from stem_mixer import mix_stems
33
 
34
  device = "cuda" if torch.cuda.is_available() else "cpu"
35
- try:
36
- progress(0.03, desc=f"Loading demucs htdemucs ({device})...")
37
- except:
38
- pass
39
  model = get_model("htdemucs")
40
  model.eval().to(device)
41
 
@@ -44,11 +37,8 @@ def render_full_set_with_stems(app_state, progress=None):
44
  n_tracks = len(app_state.set_order)
45
  for i, tidx in enumerate(app_state.set_order):
46
  track = app_state.analyses[tidx]
47
- try:
48
- progress(0.03 + (i / n_tracks) * 0.50,
49
- desc=f"Separating stems ({i+1}/{n_tracks}): {track.filename[:30]}...")
50
- except:
51
- pass
52
 
53
  y, _ = librosa.load(track.path, sr=model.samplerate, mono=False)
54
  if y.ndim == 1:
@@ -76,46 +66,45 @@ def render_full_set_with_stems(app_state, progress=None):
76
  logger.info(f"Separated {track.filename}: {list(stems.keys())}")
77
 
78
  # Mix using stem mixer
79
- try:
80
- progress(0.55, desc="Mixing with stems...")
81
- except:
82
- pass
83
  set_audio, set_info = mix_stems(
84
  all_stems, app_state.analyses, app_state.set_order,
85
- progress_cb=lambda p, m: None # stem_mixer has its own progress
86
  )
87
  method = "Demucs htdemucs → surgical drum/bass swap on downbeats"
88
 
89
  except Exception as e:
90
- logger.warning(f"Stem separation failed: {e}.")
91
  import traceback
92
  traceback.print_exc()
 
93
  from mixer import mix_set
94
- try:
95
- progress(0.10, desc="Fallback: filter-based mixing...")
96
- except:
97
- pass
98
- set_audio, set_info = mix_set(
99
- app_state.analyses, app_state.set_order, app_state.transitions,
100
- progress_cb=lambda p, m: None
 
 
101
  )
102
- method = f"Filter-based (demucs failed: {e})"
103
 
104
  app_state.rendered_set = set_audio
105
 
106
- try:
107
- progress(0.92, desc="Saving audio...")
108
- except:
109
- pass
110
  tmp = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
111
  audio_int16 = (set_audio.T * 32767).astype(np.int16)
112
  scipy.io.wavfile.write(tmp.name, sr, audio_int16)
113
 
 
114
  summary = f"# ✅ DJ Set Rendered\n\n"
115
  summary += f"- **Total duration:** {set_info.get('total_duration', 0):.1f}s "
116
  summary += f"({set_info.get('total_duration', 0)/60:.1f} min)\n"
117
  summary += f"- **Tracks:** {len(set_info.get('tracks', []))}\n"
118
  summary += f"- **Method:** {method}\n\n"
 
119
  summary += "## Tracklist\n"
120
  for i, t in enumerate(set_info.get('tracks', [])):
121
  fn = t.get('filename', '?')
@@ -124,4 +113,12 @@ def render_full_set_with_stems(app_state, progress=None):
124
  extra = f" (×{stretch:.3f})" if abs(stretch - 1.0) > 0.003 else ""
125
  summary += f"{i+1}. **{fn}** — starts at {tl:.0f}s{extra}\n"
126
 
 
 
 
 
 
 
 
 
127
  return tmp.name, summary
 
10
  logger = logging.getLogger("dj_engine")
11
 
12
 
13
+ def render_full_set_with_stems(app_state, max_iter=20, progress=gr.Progress()):
14
+ """Render the DJ set using demucs stem separation.
 
 
15
 
16
+ max_iter: number of refinement iterations (from the UI slider)
17
+ """
18
  if not app_state.transitions:
19
  return None, "⚠️ Generate a set plan first"
20
 
21
+ progress(0.02, desc="Starting stem-based render...")
 
 
 
 
22
  sr = 44100
23
 
24
  try:
 
28
  from stem_mixer import mix_stems
29
 
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ progress(0.03, desc=f"Loading demucs htdemucs ({device})...")
 
 
 
32
  model = get_model("htdemucs")
33
  model.eval().to(device)
34
 
 
37
  n_tracks = len(app_state.set_order)
38
  for i, tidx in enumerate(app_state.set_order):
39
  track = app_state.analyses[tidx]
40
+ progress(0.03 + (i / n_tracks) * 0.50,
41
+ desc=f"Separating stems ({i+1}/{n_tracks}): {track.filename[:30]}...")
 
 
 
42
 
43
  y, _ = librosa.load(track.path, sr=model.samplerate, mono=False)
44
  if y.ndim == 1:
 
66
  logger.info(f"Separated {track.filename}: {list(stems.keys())}")
67
 
68
  # Mix using stem mixer
69
+ progress(0.55, desc="Mixing with stems (surgical drum/bass swap)...")
 
 
 
70
  set_audio, set_info = mix_stems(
71
  all_stems, app_state.analyses, app_state.set_order,
72
+ progress_cb=lambda p, m: progress(0.55 + p * 0.35, desc=m)
73
  )
74
  method = "Demucs htdemucs → surgical drum/bass swap on downbeats"
75
 
76
  except Exception as e:
77
+ logger.warning(f"Stem separation failed: {e}")
78
  import traceback
79
  traceback.print_exc()
80
+ # Fallback to the original filter-based mixer with refinement loop
81
  from mixer import mix_set
82
+ from quality_analyzer import run_refinement_loop, format_analysis_log
83
+ progress(0.10, desc="Fallback: filter-based mixing with refinement...")
84
+ set_audio, set_info, _ = run_refinement_loop(
85
+ mix_fn=mix_set,
86
+ tracks=app_state.analyses,
87
+ order=app_state.set_order,
88
+ transitions=app_state.transitions,
89
+ max_iter=int(max_iter),
90
+ progress_cb=lambda p, m: progress(0.10 + p * 0.80, desc=m)
91
  )
92
+ method = f"Filter-based with {int(max_iter)} refinement iterations (demucs failed: {e})"
93
 
94
  app_state.rendered_set = set_audio
95
 
96
+ progress(0.92, desc="Saving audio...")
 
 
 
97
  tmp = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
98
  audio_int16 = (set_audio.T * 32767).astype(np.int16)
99
  scipy.io.wavfile.write(tmp.name, sr, audio_int16)
100
 
101
+ # Summary
102
  summary = f"# ✅ DJ Set Rendered\n\n"
103
  summary += f"- **Total duration:** {set_info.get('total_duration', 0):.1f}s "
104
  summary += f"({set_info.get('total_duration', 0)/60:.1f} min)\n"
105
  summary += f"- **Tracks:** {len(set_info.get('tracks', []))}\n"
106
  summary += f"- **Method:** {method}\n\n"
107
+
108
  summary += "## Tracklist\n"
109
  for i, t in enumerate(set_info.get('tracks', [])):
110
  fn = t.get('filename', '?')
 
113
  extra = f" (×{stretch:.3f})" if abs(stretch - 1.0) > 0.003 else ""
114
  summary += f"{i+1}. **{fn}** — starts at {tl:.0f}s{extra}\n"
115
 
116
+ if set_info.get('transitions'):
117
+ summary += "\n## Transitions\n"
118
+ for t in set_info['transitions']:
119
+ if isinstance(t, dict):
120
+ summary += f"- {t}\n"
121
+ else:
122
+ summary += f"- {t}\n"
123
+
124
  return tmp.name, summary