eyov committed on
Commit
4686d19
·
verified Β·
1 Parent(s): 01c86f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +419 -116
app.py CHANGED
@@ -1,135 +1,438 @@
1
- import gradio as gr
2
  import os
3
- import tempfile
4
- from pathlib import Path
5
- from typing import List, Tuple, Optional
6
- from concurrent.futures import ThreadPoolExecutor
7
  import logging
8
- import soundfile as sf
 
 
9
  import numpy as np
10
- import shutil
11
- from validators import AudioValidator
12
- from demucs_handler import DemucsProcessor
13
- from basic_pitch_handler import BasicPitchConverter
14
 
15
- # Suppress TF logging
16
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
17
- logging.getLogger('tensorflow').setLevel(logging.ERROR)
 
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  logger = logging.getLogger(__name__)
20
 
21
- # Create a persistent directory for outputs
 
 
 
 
22
  OUTPUT_DIR = Path("/tmp/audio_processor")
23
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
24
 
25
- def process_single_audio(audio_path: str, stem_type: str, convert_midi: bool) -> Tuple[Tuple[int, np.ndarray], Optional[str]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  try:
27
- # Create unique subdirectory for this processing
28
- process_dir = OUTPUT_DIR / str(hash(audio_path))
29
- process_dir.mkdir(parents=True, exist_ok=True)
30
-
31
- processor = DemucsProcessor()
32
- converter = BasicPitchConverter()
33
-
34
- print(f"Starting processing of file: {audio_path}")
35
-
36
- # Process stems
37
- sources, sample_rate = processor.separate_stems(audio_path)
38
- print(f"Number of sources returned: {sources.shape}")
39
- print(f"Stem type requested: {stem_type}")
40
-
41
- # Get the requested stem
42
- stem_index = ["drums", "bass", "other", "vocals"].index(stem_type)
43
- selected_stem = sources[0, stem_index]
44
-
45
- # Save stem
46
- stem_path = process_dir / f"{stem_type}.wav"
47
- processor.save_stem(selected_stem, stem_type, str(process_dir), sample_rate)
48
- print(f"Saved stem to: {stem_path}")
49
-
50
- # Load the saved audio file for Gradio
51
- audio_data, sr = sf.read(str(stem_path))
52
- if len(audio_data.shape) > 1:
53
- audio_data = audio_data.mean(axis=1) # Convert to mono if stereo
54
-
55
- # Convert to int16 format
56
- audio_data = (audio_data * 32767).astype(np.int16)
57
-
58
- # Convert to MIDI if requested
59
- midi_path = None
60
- if convert_midi:
61
- midi_path = process_dir / f"{stem_type}.mid"
62
- converter.convert_to_midi(str(stem_path), str(midi_path))
63
- print(f"Saved MIDI to: {midi_path}")
64
-
65
- return (sr, audio_data), str(midi_path) if midi_path else None
66
- except Exception as e:
67
- print(f"Error in process_single_audio: {str(e)}")
68
- raise
69
 
70
- def create_interface():
71
- processor = DemucsProcessor()
72
- converter = BasicPitchConverter()
73
- validator = AudioValidator()
74
-
75
- def process_audio(
76
- audio_files: List[str],
77
- stem_type: str,
78
- convert_midi: bool = True,
79
- progress=gr.Progress()
80
- ) -> Tuple[Tuple[int, np.ndarray], Optional[str]]:
81
- try:
82
- print(f"Starting processing of {len(audio_files)} files")
83
- print(f"Selected stem type: {stem_type}")
 
 
 
 
 
 
84
 
85
- # Process single file for now
86
- if len(audio_files) > 0:
87
- audio_path = audio_files[0] # Take first file
88
- print(f"Processing file: {audio_path}")
89
- return process_single_audio(audio_path, stem_type, convert_midi)
90
- else:
91
- raise ValueError("No audio files provided")
92
 
93
- except Exception as e:
94
- print(f"Error in audio processing: {str(e)}")
95
- raise gr.Error(str(e))
96
-
97
- interface = gr.Interface(
98
- fn=process_audio,
99
- inputs=[
100
- gr.File(
101
- file_count="multiple",
102
- file_types=AudioValidator.SUPPORTED_FORMATS,
103
- label="Upload Audio Files"
104
- ),
105
- gr.Dropdown(
106
- choices=["vocals", "drums", "bass", "other"],
107
- label="Select Stem",
108
- value="vocals"
109
- ),
110
- gr.Checkbox(label="Convert to MIDI", value=True)
111
- ],
112
- outputs=[
113
- gr.Audio(label="Separated Stems", type="numpy"),
114
- gr.File(label="MIDI Files")
115
- ],
116
- title="Audio Stem Separator & MIDI Converter",
117
- description="Upload audio files to separate stems and convert to MIDI\n\n" +
118
- "Created by Ever Olivares - Looking for Summer 2025 Internship Opportunities\n" +
119
- "Connect with me: [LinkedIn](https://www.linkedin.com/in/everolivares/)",
120
- cache_examples=True,
121
- allow_flagging="never"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  )
123
-
124
- return interface
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  if __name__ == "__main__":
127
- interface = create_interface()
128
- interface.launch(
129
- share=False,
 
 
 
 
 
130
  server_name="0.0.0.0",
131
  server_port=7860,
132
- auth=None,
133
- ssl_keyfile=None,
134
- ssl_certfile=None
135
- )
 
 
1
  import os
2
+ import uuid
 
 
 
3
  import logging
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
  import numpy as np
8
+ import soundfile as sf
 
 
 
9
 
10
+ # Non-interactive Matplotlib backend β€” must be set before pyplot is imported
11
+ import matplotlib
12
+ matplotlib.use("Agg")
13
+ import matplotlib.pyplot as plt
14
+ import matplotlib.patches as patches
15
 
16
+ import librosa
17
+ import pyrubberband as pyrb
18
+
19
+ import pretty_midi
20
+ import gradio as gr
21
+
22
+ # ── Environment variables ────────────────────────────────────────────────────
23
+ # Suppress verbose TF / Metal logs
24
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
25
+ # Allow PyTorch MPS to fall back to CPU for any unsupported ops instead of
26
+ # raising an error. Must be set before torch is imported (which happens
27
+ # inside demucs_handler).
28
+ os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
29
+
30
+ # torchaudio is imported inside the handlers; audio loading is done via
31
+ # soundfile directly (TorchCodec is not available on Apple Silicon).
32
+
33
+ logging.getLogger("tensorflow").setLevel(logging.ERROR)
34
+ logging.basicConfig(level=logging.INFO)
35
  logger = logging.getLogger(__name__)
36
 
37
+ from validators import AudioValidator # noqa: E402
38
+ from demucs_handler import DemucsProcessor # noqa: E402
39
+ from basic_pitch_handler import BasicPitchConverter # noqa: E402
40
+
41
+ # ── Output directory ─────────────────────────────────────────────────────────
42
  OUTPUT_DIR = Path("/tmp/audio_processor")
43
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
44
 
45
+ # ── Singleton model instances (loaded once, reused across requests) ──────────
46
+ _processor: Optional[DemucsProcessor] = None
47
+ _converter: Optional[BasicPitchConverter] = None
48
+
49
+
50
def get_processor() -> DemucsProcessor:
    """Return the process-wide DemucsProcessor, constructing it on first call."""
    global _processor
    if _processor is not None:
        return _processor
    _processor = DemucsProcessor()
    return _processor
55
+
56
+
57
def get_converter() -> BasicPitchConverter:
    """Return the process-wide BasicPitchConverter, constructing it on first call."""
    global _converter
    if _converter is not None:
        return _converter
    _converter = BasicPitchConverter()
    return _converter
62
+
63
+
64
+ # ── Piano roll renderer ───────────────────────────────────────────────────────
65
+ _NOTE_NAMES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
66
+ _BLACK_KEYS = {1, 3, 6, 8, 10}
67
+
68
+
69
def render_piano_roll(midi_path: str) -> np.ndarray:
    """
    Render a dark-themed piano-roll image from a MIDI file.

    Returns an (H, W, 3) uint8 RGB array suitable for gr.Image(type='numpy').
    """
    midi = pretty_midi.PrettyMIDI(midi_path)
    # Flatten notes across every instrument track for range/statistics.
    all_notes = [note for inst in midi.instruments for note in inst.notes]

    # GitHub-dark palette: figure #0d1117, axes #161b22.
    fig, ax = plt.subplots(figsize=(18, 6), dpi=100, facecolor="#0d1117")
    ax.set_facecolor("#161b22")

    if not all_notes:
        # Empty transcription: render a hint instead of a blank plot.
        ax.text(
            0.5, 0.5,
            "No notes detected — try lowering Onset or Frame threshold",
            transform=ax.transAxes,
            color="#8b949e", ha="center", va="center", fontsize=12,
        )
    else:
        # Time extent and pitch range, padded by 3 semitones on each side
        # (clamped to the 0..127 MIDI pitch range).
        t_max = max(n.end for n in all_notes)
        p_min = max(0, min(n.pitch for n in all_notes) - 3)
        p_max = min(127, max(n.pitch for n in all_notes) + 3)

        # ── Black-key shading ────────────────────────────────────────────
        for p in range(p_min, p_max + 1):
            if p % 12 in _BLACK_KEYS:
                ax.axhspan(p - 0.5, p + 0.5, alpha=0.08, color="white", linewidth=0)

        # ── Octave separator lines (at every C, i.e. pitch % 12 == 0) ────
        for p in range(p_min, p_max + 1):
            if p % 12 == 0:
                ax.axhline(p - 0.5, color="#21262d", linewidth=0.8, zorder=1)

        # ── Notes, colour-coded by instrument track ───────────────────────
        n_inst = max(1, len(midi.instruments))  # avoid division by zero
        cmap = plt.cm.cool
        for i, inst in enumerate(midi.instruments):
            colour = cmap(i / n_inst)
            for note in inst.notes:
                dur = max(note.end - note.start, 0.015)  # minimum visual width
                # Louder notes render more opaque (velocity 0..127 → 0.45..1.0).
                alpha = 0.45 + 0.55 * (note.velocity / 127.0)
                ax.add_patch(
                    patches.FancyBboxPatch(
                        (note.start, note.pitch - 0.45), dur, 0.90,
                        boxstyle="round,pad=0.01",
                        linewidth=0,
                        facecolor=colour,
                        alpha=alpha,
                        zorder=2,
                    )
                )

        ax.set_xlim(0, t_max)
        ax.set_ylim(p_min - 1, p_max + 1)

        # Y-axis: label only C notes (one per octave); pitch 60 → "C4".
        c_ticks = [p for p in range(p_min, p_max + 1) if p % 12 == 0]
        ax.set_yticks(c_ticks)
        ax.set_yticklabels(
            [f"{_NOTE_NAMES[p % 12]}{p // 12 - 1}" for p in c_ticks],
            color="#8b949e", fontsize=8,
        )
        ax.tick_params(axis="x", colors="#8b949e", labelsize=8)
        ax.tick_params(axis="y", length=0)
        ax.set_xlabel("Time (s)", color="#8b949e", fontsize=9)

        for spine in ax.spines.values():
            spine.set_edgecolor("#30363d")

        n = len(all_notes)
        ax.set_title(
            f"Piano Roll · {n} note{'s' if n != 1 else ''} · {t_max:.1f} s",
            color="#e6edf3", fontsize=11, pad=8,
        )

    plt.tight_layout(pad=0.4)
    fig.canvas.draw()
    # buffer_rgba() → RGBA array; drop alpha channel for gr.Image
    img = np.asarray(fig.canvas.buffer_rgba())[..., :3]
    plt.close(fig)
    return img
151
+
152
+
153
+ # ── Core processing function ──────────────────────────────────────────────────
154
def process_audio(
    audio_file,
    stem_type: str,
    target_bpm: float,
    convert_midi: bool,
    onset_threshold: float,
    frame_threshold: float,
    min_note_length: float,
    multiple_pitch_bends: bool,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Gradio Blocks handler: separate one stem, optionally BPM-quantize it,
    and optionally transcribe it to MIDI with a piano-roll preview.

    Inputs (must match the order in run_btn.click(inputs=[...])):
        audio_file, stem_type, target_bpm, convert_midi,
        onset_threshold, frame_threshold, min_note_length, multiple_pitch_bends

    Outputs → [stem_audio, midi_file, piano_roll]

    Raises:
        gr.Error: on missing/invalid input or any processing failure.
    """
    # ── Validate input ────────────────────────────────────────────────────
    if audio_file is None:
        raise gr.Error("Upload an audio file first.")

    # gr.File may return a string path or an object with a .name attribute
    file_path = audio_file.name if hasattr(audio_file, "name") else str(audio_file)

    valid, msg = AudioValidator.validate_audio_file(file_path)
    if not valid:
        raise gr.Error(f"File validation failed: {msg}")

    # ── Work directory (UUID prevents collisions) ─────────────────────────
    work_dir = OUTPUT_DIR / uuid.uuid4().hex
    work_dir.mkdir(parents=True, exist_ok=True)

    try:
        # ── Stage 1: stem separation ──────────────────────────────────────
        progress(0.05, desc="Loading Demucs model…")
        processor = get_processor()

        progress(0.10, desc="Separating stems (this takes ~30-90 s on first run)…")
        sources, sr = processor.separate_stems(file_path)

        # Use model.sources for robust stem index lookup
        stem_index = processor.model.sources.index(stem_type)
        selected_stem = sources[0, stem_index]  # shape: (2, time)

        processor.save_stem(selected_stem, stem_type, str(work_dir))
        stem_path = work_dir / f"{stem_type}.wav"

        progress(0.55, desc="Stem extracted.")

        # Read stem back for processing and preview; sr=None preserves the
        # file's native sample rate (orig_sr).
        y, orig_sr = librosa.load(str(stem_path), sr=None)

        # ── Polymath Integration: BPM Quantization ───────────────────────
        # If Target BPM > 0, time-stretch the audio onto a fixed grid
        # before MIDI conversion so the output notes lock onto the grid.
        if target_bpm > 0:
            progress(0.56, desc=f"Quantizing stem to {target_bpm} BPM…")

            # Beat-track on the percussive component only (more stable).
            _, y_percussive = librosa.effects.hpss(y)
            tempo, beats = librosa.beat.beat_track(
                sr=orig_sr,
                onset_envelope=librosa.onset.onset_strength(y=y_percussive, sr=orig_sr),
                trim=False,
            )
            beat_samples = librosa.frames_to_samples(beats)

            # Target metronome grid, one point every 120/target_bpm seconds.
            # NOTE(review): 120 s spaces grid points two beats apart at
            # target_bpm; if a one-beat grid is intended this should be 60
            # — confirm the intended spacing.
            fixed_beat_times = [i * 120 / target_bpm for i in range(len(beat_samples))]
            # BUGFIX: pass sr=orig_sr. The previous call used librosa's
            # default sr (22050), silently mis-scaling the target grid
            # whenever the stem's native rate differed.
            fixed_beat_samples = librosa.time_to_samples(fixed_beat_times, sr=orig_sr)

            # Map each detected beat sample onto its grid position.
            time_map = list(zip(beat_samples, fixed_beat_samples))

            # Extend the map past the last beat so the clip keeps its tail.
            if len(beat_samples) > 0 and len(y) > beat_samples[-1]:
                orig_end_diff = len(y) - beat_samples[-1]
                # tempo may be an ndarray; extract the scalar for the math
                tempo_val = tempo[0] if isinstance(tempo, np.ndarray) else tempo
                new_ending = int(round(fixed_beat_samples[-1] + orig_end_diff * (tempo_val / target_bpm)))
                time_map.append((len(y), new_ending))

            if time_map:
                y = pyrb.timemap_stretch(y, orig_sr, time_map)
                # Re-save the stretched stem so Basic Pitch transcribes it
                sf.write(str(stem_path), y, orig_sr)
                progress(0.59, desc="Quantization complete.")
            else:
                # No beats detected: an empty time map would make
                # pyrubberband fail, so skip quantization instead.
                logger.warning("No beats detected; skipping BPM quantization")

        # Preview formatting: mono int16 for gr.Audio(type="numpy")
        if y.ndim > 1:
            y = y.mean(axis=1)
        audio_out = (orig_sr, (y * 32767).astype(np.int16))

        # ── Early exit if MIDI not requested ─────────────────────────────
        if not convert_midi:
            progress(1.0, desc="Done.")
            return audio_out, None, gr.update(value=None, visible=False)

        # ── Stage 2: MIDI conversion ──────────────────────────────────────
        progress(0.60, desc="Running Basic Pitch (TFLite inference)…")
        converter = get_converter()
        converter.set_process_options(
            onset_threshold=onset_threshold,
            frame_threshold=frame_threshold,
            minimum_note_length=min_note_length,
            multiple_pitch_bends=multiple_pitch_bends,
        )

        midi_path = work_dir / f"{stem_type}.mid"
        converter.convert_to_midi(str(stem_path), str(midi_path))

        # ── Stage 3: piano roll render ────────────────────────────────────
        progress(0.90, desc="Rendering piano roll…")
        roll_img = render_piano_roll(str(midi_path))

        progress(1.0, desc="Done.")
        return audio_out, str(midi_path), gr.update(value=roll_img, visible=True)

    except gr.Error:
        raise
    except Exception as exc:
        logger.exception("Processing failed")
        raise gr.Error(str(exc)) from exc
281
+
282
+
283
+ # ── Direct-path processing (used by Quick Test UI) ───────────────────────────
284
def process_audio_path(
    file_path: str,
    stem_type: str,
    target_bpm: float,
    convert_midi: bool,
    onset_threshold: float,
    frame_threshold: float,
    min_note_length: float,
    multiple_pitch_bends: bool,
    progress=gr.Progress(track_tqdm=True),
):
    """Same as process_audio but accepts a plain file-system path string."""

    # process_audio reads `.name` when present, so wrap the raw path in a
    # minimal object exposing that attribute.
    class _PathHolder:
        def __init__(self, raw_path):
            self.name = raw_path

    holder = _PathHolder(file_path)
    return process_audio(
        holder,
        stem_type,
        target_bpm,
        convert_midi,
        onset_threshold,
        frame_threshold,
        min_note_length,
        multiple_pitch_bends,
        progress,
    )
 
 
307
 
308
+
309
+ # Discover any audio files in the repo's mp3/ folder for the Quick Test picker
310
# Discover any audio files in the repo's mp3/ folder for the Quick Test picker
_MP3_DIR = Path(__file__).parent / "mp3"

def get_test_files():
    """Return a sorted list of bundled demo audio paths (.mp3/.wav/.flac)."""
    if not _MP3_DIR.is_dir():
        return []
    audio_exts = (".mp3", ".wav", ".flac")
    found = [str(p) for p in _MP3_DIR.glob("*") if p.suffix.lower() in audio_exts]
    found.sort()
    return found
319
+
320
+ # ── Gradio Blocks UI ──────────────────────────────────────────────────────────
321
def build_interface() -> gr.Blocks:
    """Build and return the Gradio Blocks UI (does not launch the server)."""
    with gr.Blocks(
        title="Aud2Stm2Mdi",
        theme=gr.themes.Base(primary_hue="indigo", neutral_hue="slate"),
    ) as demo:

        gr.Markdown(
            "## Aud2Stm2Mdi\n"
            "Separate audio into stems with **Demucs** `htdemucs`, "
            "then transcribe to **MIDI** with **Basic Pitch**."
        )

        with gr.Row():

            # ── Left column: controls ─────────────────────────────────────
            with gr.Column(scale=1, min_width=300):

                audio_input = gr.File(
                    label="Audio File (.mp3 / .wav / .flac)",
                    file_types=[".mp3", ".wav", ".flac"],
                )
                stem_dd = gr.Dropdown(
                    choices=["vocals", "drums", "bass", "other"],
                    value="vocals",
                    label="Stem to extract",
                )
                midi_cb = gr.Checkbox(label="Convert to MIDI", value=True)

                with gr.Accordion("BPM Quantization (Polymath Core)", open=False):
                    # 0 disables quantization (see process_audio's target_bpm > 0 check)
                    bpm_sl = gr.Slider(
                        0, 200, value=0, step=1,
                        label="Target BPM",
                        info="Time-stretches the stem so MIDI falls perfectly on the beat grid. Set to 0 to disable."
                    )

                with gr.Accordion("MIDI Parameters", open=False):
                    onset_sl = gr.Slider(
                        0.10, 0.95, value=0.50, step=0.05,
                        label="Onset Threshold",
                        info="Higher → fewer but more confident note onsets",
                    )
                    frame_sl = gr.Slider(
                        0.10, 0.95, value=0.40, step=0.05,
                        label="Frame Threshold",
                        info="Higher → shorter notes, less legato smear",
                    )
                    minlen_sl = gr.Slider(
                        50, 500, value=150, step=10,
                        label="Min Note Length (ms)",
                        info="Increase to filter ghost / glitch notes",
                    )
                    bends_cb = gr.Checkbox(
                        label="Multiple Pitch Bends",
                        value=False,
                        info="Keep OFF for cleaner Ableton import",
                    )

                run_btn = gr.Button("Process", variant="primary", size="lg", elem_id="run_btn")

                # Quick Test is only shown when bundled demo files exist.
                test_files = get_test_files()
                with gr.Accordion("🧪 Quick Test (pre-loaded files)", open=bool(test_files), elem_id="quick_test", visible=bool(test_files)):
                    test_dd = gr.Dropdown(
                        choices=test_files if test_files else ["No files found"],
                        value=test_files[0] if test_files else None,
                        label="Select test file",
                        elem_id="test_file_dd",
                    )
                    test_btn = gr.Button(
                        "Run Quick Test", variant="secondary", size="sm",
                        elem_id="test_btn",
                    )

            # ── Right column: results ─────────────────────────────────────
            with gr.Column(scale=2):
                stem_audio = gr.Audio(label="Separated Stem", type="numpy")
                midi_file = gr.File(label="MIDI Download")
                piano_roll = gr.Image(
                    label="Piano Roll Preview",
                    type="numpy",
                    visible=False,  # hidden until MIDI is produced
                )

        # Input order here must match the parameter order of process_audio.
        run_btn.click(
            fn=process_audio,
            inputs=[
                audio_input, stem_dd, bpm_sl, midi_cb,
                onset_sl, frame_sl, minlen_sl, bends_cb,
            ],
            outputs=[stem_audio, midi_file, piano_roll],
        )

        # Same wiring, but the dropdown supplies a plain path string.
        test_btn.click(
            fn=process_audio_path,
            inputs=[
                test_dd, stem_dd, bpm_sl, midi_cb,
                onset_sl, frame_sl, minlen_sl, bends_cb,
            ],
            outputs=[stem_audio, midi_file, piano_roll],
        )

    return demo
422
+
423
+
424
+ # ── Entry point ───────────────────────────────────────────────────────────────
425
if __name__ == "__main__":
    # Warm both models before serving so the first request does not pay the
    # full model-load penalty.
    print("Loading models at startup…")
    get_processor()
    get_converter()
    print("Models ready — launching server.")

    app = build_interface()
    app.launch(
        server_port=7860,
        server_name="0.0.0.0",
        share=False,
        allowed_paths=[str(OUTPUT_DIR)],  # let Gradio serve files from OUTPUT_DIR
    )