mrblackdev committed
Commit a2446b2 · verified · 1 Parent(s): bf213f8

Update app.py

Files changed (1):
  app.py +226 -328

app.py CHANGED
@@ -1,367 +1,265 @@
- # app.py - Audio -> Multi-track MIDI (HPSS + Multi-pitch + Clustering)
- # Designed for Hugging Face Spaces (Gradio).
- # Author: AlexGPT (responding to your request)

  import os
  import tempfile
  import traceback
  import numpy as np
  import librosa
  import pretty_midi
  import gradio as gr
- from sklearn.cluster import AgglomerativeClustering

- # ---------- Config ----------
- A440 = 440.0

- # ---------- Utilities ----------
- def hz_to_midi(f):
-     """Return float MIDI number or np.nan for invalid f."""
      try:
-         if f is None or np.isnan(f) or f <= 0:
-             return np.nan
-         return 69 + 12 * np.log2(f / A440)
-     except Exception:
-         return np.nan
-
- def safe_median_filter(data, size=3):
-     """Median filter forcing float64 to avoid scipy errors; fallback to identity."""
-     try:
-         from scipy.ndimage import median_filter
-         arr = np.asarray(data)
-         if arr.dtype != np.float64:
-             arr = arr.astype(np.float64)
-         return median_filter(arr, size=size)
-     except Exception as e:
-         print("median_filter fallback:", e)
-         return np.asarray(data, dtype=np.float64)
-
- def round_to_grid(seconds, bpm, division=4):
-     if bpm <= 0:
-         return seconds
-     beat = 60.0 / bpm
-     grid = beat / division
-     ticks = np.round(seconds / grid)
-     return ticks * grid
-
- # ---------- Signal separation & percussive detection ----------
- def separate_harmonic_percussive(y):
-     """HPSS separation; returns (harmonic, percussive). If it fails, return (y, zeros)."""
-     try:
-         y_h, y_p = librosa.effects.hpss(y)
-         return y_h, y_p
      except Exception as e:
-         print("HPSS fallback:", e)
-         return y, np.zeros_like(y)

- def detect_percussive_hits(y_p, sr, backtrack=False):
-     """
-     Detect percussive onsets and map them to simple drum MIDI notes.
-     Returns list of (time_seconds, midi_note).
-     Heuristics: use spectral centroid & onset energy to classify kick/snare/hihat.
-     """
-     try:
-         onset_env = librosa.onset.onset_strength(y=y_p, sr=sr)
-         onsets = librosa.onset.onset_detect(onset_envelope=onset_env, sr=sr, backtrack=backtrack)
-         hits = []
-         if len(onsets) == 0:
-             return hits
-         S = np.abs(librosa.stft(y_p, n_fft=2048))
-         for fr in onsets:
-             t = float(librosa.frames_to_time(fr, sr=sr))
-             # spectral centroid around the frame (safe slicing)
-             start = max(0, fr - 2)
-             end = min(fr + 3, S.shape[1] - 1)
-             try:
-                 centroid = np.mean(librosa.feature.spectral_centroid(S=S[:, start:end+1], sr=sr))
-             except Exception:
-                 centroid = 0.0
-             # Simple heuristic:
-             # low centroid -> kick, medium -> snare, high -> hihat
-             if centroid < 1500:
-                 midi_note = 36  # Kick
-             elif centroid < 3500:
-                 midi_note = 38  # Acoustic snare
-             else:
-                 midi_note = 42  # Closed hi-hat
-             hits.append((t, midi_note))
-         return hits
-     except Exception as e:
-         print("Percussive detection error:", e)
-         return []

- # ---------- Multi-pitch extraction ----------
- def extract_multi_pitches(y_h, sr, hop_length=256, top_n=3, min_confidence=0.08):
-     """
-     Use piptrack to extract candidate pitches per frame.
-     Returns list of (time_seconds, freq_hz).
-     """
      try:
-         S = np.abs(librosa.stft(y_h, n_fft=2048, hop_length=hop_length))
-         pitches, mags = librosa.piptrack(S=S, sr=sr, hop_length=hop_length)
-         times = librosa.frames_to_time(np.arange(pitches.shape[1]), sr=sr, hop_length=hop_length)
-         candidates = []
-         for i in range(pitches.shape[1]):
-             col_p = pitches[:, i]
-             col_m = mags[:, i]
-             if np.max(col_m) <= 0:
-                 continue
-             # pick top_n bins by magnitude
-             idx = np.argsort(col_m)[-top_n:]
-             max_col = np.max(col_m)
-             for k in idx:
-                 if col_m[k] > 0 and col_m[k] >= min_confidence * max_col:
-                     candidates.append((times[i], float(col_p[k])))
-         # filter zeros & NaNs
-         candidates = [(t, p) for (t, p) in candidates if p is not None and p > 0 and not np.isnan(p)]
-         return candidates
-     except Exception as e:
-         print("extract_multi_pitches error:", e)
-         return []

- # ---------- Clustering / track formation ----------
- def cluster_pitch_trajectories(candidates, max_voices=4):
      """
-     Cluster candidate (time, pitch) pairs into trajectories representing voices/instruments.
-     Returns list of tracks; each track is a sorted list of (time, freq_hz).
      """
-     if not candidates:
-         return []
      try:
-         X = np.array([[t, hz_to_midi(h)] for (t, h) in candidates], dtype=np.float64)
-         # Normalize columns
-         Xn = X.copy()
-         if Xn[:, 0].ptp() > 1e-9:
-             Xn[:, 0] = (Xn[:, 0] - Xn[:, 0].min()) / (Xn[:, 0].ptp())
-         else:
-             Xn[:, 0] = 0.0
-         if Xn[:, 1].ptp() > 1e-9:
-             Xn[:, 1] = (Xn[:, 1] - Xn[:, 1].min()) / (Xn[:, 1].ptp())
-         else:
-             Xn[:, 1] = 0.0
-         n_clusters = min(max_voices, max(1, int(np.unique(np.round(Xn, 3), axis=0).shape[0])))
-         if n_clusters <= 1:
-             labels = np.zeros(len(Xn), dtype=int)
-         else:
-             clustering = AgglomerativeClustering(n_clusters=n_clusters).fit(Xn)
-             labels = clustering.labels_
-         tracks = []
-         for lab in range(int(labels.max()) + 1):
-             idxs = np.where(labels == lab)[0]
-             if len(idxs) == 0:
-                 continue
-             pts = [(float(X[i, 0]), float(X[i, 1])) for i in idxs]
-             # convert midi values back to hz for smoothing/processing (midi -> hz)
-             pts_hz = [(t, A440 * (2 ** ((m - 69) / 12))) for (t, m) in pts]
-             pts_sorted = sorted(pts_hz, key=lambda x: x[0])
-             tracks.append(pts_sorted)
-         return tracks
      except Exception as e:
-         print("cluster_pitch_trajectories error:", e)
-         return []

- def trajectories_to_notes(tracks, hop_length, sr, min_note_ms=80):
      """
-     Convert each trajectory (time, freq) to notes (midi_int, start, end).
-     Groups consecutive equal rounded-midis and enforces minimum duration.
      """
-     notes = []
-     for tr in tracks:
-         if not tr:
-             continue
-         times = np.array([t for t, _ in tr])
-         freqs = np.array([f for _, f in tr])
-         # Smooth frequencies
-         freqs_s = safe_median_filter(freqs.astype(np.float64), size=3)
-         midis = np.round([hz_to_midi(f) for f in freqs_s])
-         # Group consecutive equal midis
-         i = 0
-         n = len(midis)
-         frame_ms = 1000.0 * hop_length / sr
-         min_frames = max(1, int(np.ceil(min_note_ms / frame_ms)))
-         while i < n:
-             j = i + 1
-             while j < n and midis[j] == midis[i]:
-                 j += 1
-             if (j - i) >= min_frames and not np.isnan(midis[i]):
-                 t0 = float(times[i])
-                 t1 = float(times[j - 1] + hop_length / sr)
-                 notes.append((int(midis[i]), t0, t1))
-             i = j
-     return notes
-
- # ---------- Main multi-instrument conversion ----------
- def audio_to_midi_multi(
-     audio,
-     hop_length=256,
-     frame_length=2048,
-     max_voices=3,
-     percussive=True,
-     bpm=120,
-     quantize=True,
-     division=4,
-     velocity=100,
-     program_map=None,
-     top_n=4,
-     min_confidence=0.10,
-     min_note_ms=80,
- ):
-     """
-     Full pipeline:
-     - load audio
-     - HPSS
-     - detect percussive hits -> drum track
-     - extract multi-pitch candidates from harmonic part
-     - cluster candidates into tracks (voices)
-     - convert tracks to MIDI notes and split into separate instruments by pitch ranges
-     """
-     try:
-         # Load audio
-         if isinstance(audio, tuple):
-             sr, y = audio
-             y = np.array(y, dtype=np.float32)
-         else:
-             y, sr = librosa.load(audio, sr=None, mono=True)
-         if y.size == 0:
-             raise ValueError("Empty audio")
-         # normalize
-         if np.max(np.abs(y)) > 0:
-             y = y / np.max(np.abs(y))
-
-         # HPSS
-         y_h, y_p = separate_harmonic_percussive(y)
-
-         pm = pretty_midi.PrettyMIDI()
-
-         # Percussion track
-         if percussive:
-             hits = detect_percussive_hits(y_p, sr)
-             if hits:
-                 drum_inst = pretty_midi.Instrument(program=0, is_drum=True)
-                 for t, midi_note in hits:
-                     # tiny duration for hits
-                     drum_inst.notes.append(pretty_midi.Note(velocity=int(velocity), pitch=int(midi_note),
-                                                             start=float(t), end=float(t + 0.05)))
-                 pm.instruments.append(drum_inst)
-
-         # Harmonic: multi-pitch extraction
-         candidates = extract_multi_pitches(y_h, sr, hop_length=hop_length, top_n=top_n, min_confidence=min_confidence)
-         tracks = cluster_pitch_trajectories(candidates, max_voices=max_voices)
-         notes = trajectories_to_notes(tracks, hop_length=hop_length, sr=sr, min_note_ms=min_note_ms)
-
-         # If we have notes, split by pitch quantiles into up to max_voices instrument tracks.
-         if notes:
-             midi_vals = np.array([n[0] for n in notes])
-             unique = np.unique(midi_vals)
-             groups = int(min(max_voices, max(1, len(unique))))
-             edges = np.quantile(midi_vals, np.linspace(0, 1, groups + 1))
-             for g in range(groups):
-                 program = program_map[g] if (program_map and g < len(program_map)) else 0
-                 inst = pretty_midi.Instrument(program=int(program))
-                 low = edges[g]
-                 high = edges[g + 1]
-                 for m, t0, t1 in notes:
-                     if m >= low - 0.0001 and m <= high + 0.0001:
-                         inst.notes.append(pretty_midi.Note(velocity=int(velocity), pitch=int(m), start=float(t0),
-                                                            end=float(t1)))
-                 # Only append instruments that have notes
-                 if len(inst.notes) > 0:
                      pm.instruments.append(inst)

-         # Quantize to grid if requested
-         if quantize and bpm > 0:
-             for instr in pm.instruments:
-                 for note in instr.notes:
-                     note.start = float(round_to_grid(note.start, bpm, division))
-                     note.end = float(round_to_grid(note.end, bpm, division))
-                     if note.end <= note.start:
-                         note.end = note.start + (60.0 / bpm) / division
-
-         # Save MIDI
-         tmpdir = tempfile.mkdtemp()
-         midi_path = os.path.join(tmpdir, "multi_output.mid")
-         pm.write(midi_path)
-
-         summary = {
-             "duration_s": round(len(y) / sr, 3),
-             "instruments": len(pm.instruments),
-             "notes_total": sum(len(i.notes) for i in pm.instruments),
-             "bpm": bpm,
-             "voices_requested": max_voices,
-         }
-         return midi_path, summary
-
-     except Exception as e:
-         traceback.print_exc()
-         raise

  # ---------- Gradio UI ----------
  CSS = """
- #app_title {font-size: 28px; font-weight: 800}
  #app_subtitle {opacity: .8}
  """

- with gr.Blocks(css=CSS, title="Audio Multi-MIDI (AlexGPT)") as demo:
-     gr.Markdown("<div id='app_title'>🎤 Audio 🎹 MIDI (Polyphonic & Multi-instrument)</div>"
-                 "<div id='app_subtitle'>HPSS + Multi-pitch + Clustering multi-track MIDI</div>")
-
      with gr.Row():
          with gr.Column(scale=2):
-             audio_in = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Input audio (mono/mix)")
-             with gr.Accordion("Extraction / Separation", open=False):
-                 hop = gr.Slider(128, 1024, value=256, step=64, label="Hop length (samples)")
-                 frame = gr.Slider(1024, 4096, value=2048, step=256, label="Frame length (samples)")
-                 max_voices = gr.Slider(1, 6, value=3, step=1, label="Max voices (clusters)")
-                 percussive = gr.Checkbox(value=True, label="Detect percussion (HPSS)")
-                 topn = gr.Slider(1, 8, value=4, step=1, label="Peaks per frame (top N)")
-                 min_conf = gr.Slider(0.01, 0.5, value=0.1, step=0.01, label="Relative confidence threshold")
-                 min_note_ms = gr.Slider(10, 500, value=80, step=10, label="Minimum note duration (ms)")
-
-             with gr.Accordion("MIDI output", open=True):
-                 do_quant = gr.Checkbox(value=True, label="Quantize to grid")
-                 bpm = gr.Slider(40, 220, value=120, step=1, label="BPM")
-                 division = gr.Dropdown([1, 2, 4, 8, 16], value=4, label="Divisions per quarter note (1=quarter, 4=sixteenth)")
-                 velocity = gr.Slider(1, 127, value=100, step=1, label="Velocity (1-127)")
-                 # program_map not editable in the UI for simplicity; advanced: add dynamic inputs
-
-             run_btn = gr.Button("🔄 Convert to MIDI", variant="primary")
-
          with gr.Column(scale=1):
-             midi_out = gr.File(label="Generated MIDI file")
-             summary_out = gr.JSON(label="Summary")
-             gr.Markdown(
-                 "**Tips**\n\n"
-                 "- This method is heuristic: the best results come from mixes with clear instruments and little reverb.\n"
-                 "- To separate real tracks (vocals, synth, bass), run a source-separation model (Demucs/Spleeter) before the analysis.\n"
-                 "- Set `Max voices` to the approximate number of melodic instruments.\n"
-             )
-
-     def _convert(audio_path, hop_length, frame_length, max_voices_val, percussive_val, topn_val,
-                  do_quantize, bpm_val, division_val, velocity_val, min_conf_val, min_note_ms_val):
          try:
-             midi_path, summary = audio_to_midi_multi(
-                 audio=audio_path,
-                 hop_length=int(hop_length),
-                 frame_length=int(frame_length),
-                 max_voices=int(max_voices_val),
-                 percussive=bool(percussive_val),
-                 bpm=float(bpm_val),
-                 quantize=bool(do_quantize),
-                 division=int(division_val),
-                 velocity=int(velocity_val),
-                 top_n=int(topn_val),
-                 min_confidence=float(min_conf_val),
-                 min_note_ms=int(min_note_ms_val),
-             )
-             return midi_path, summary
          except Exception as e:
-             return gr.update(value=None), {"error": str(e)}
-
-     run_btn.click(
-         _convert,
-         inputs=[audio_in, hop, frame, max_voices, percussive, topn, do_quant, bpm, division, velocity, min_conf, min_note_ms],
-         outputs=[midi_out, summary_out],
-     )

  if __name__ == "__main__":
      demo.launch()

+ # app.py - Demucs + Basic-Pitch pipeline -> multi-track MIDI (Gradio)
+ # Author: AlexGPT
+ # WARNING: heavy deps (demucs, basic-pitch, torch, tensorflow). Use a beefy Space or local env.
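+ # Suggested packages (an assumption; pin versions as needed):
+ #   demucs, basic-pitch, tensorflow, torch, gradio, librosa, pretty_midi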

  import os
  import tempfile
+ import shutil
+ import subprocess
  import traceback
  import numpy as np
  import librosa
  import pretty_midi
  import gradio as gr

+ # Optional heavy imports: set feature flags for demucs and basic-pitch availability
+ HAS_DEMUCS = False
+ HAS_BASIC_PITCH = False
+ DEMUCS_MODEL_NAME = "htdemucs_ft"  # reasonable default
+ try:
+     import demucs  # noqa: F401
+     HAS_DEMUCS = True
+ except Exception:
+     HAS_DEMUCS = False

+ try:
+     # basic_pitch usage per README: import predict + load saved model
+     import tensorflow as tf  # basic-pitch uses a TF saved_model
+     from basic_pitch.inference import predict
+     from basic_pitch import ICASSP_2022_MODEL_PATH
+     # load model once (this may be heavy)
      try:
+         BASIC_PITCH_MODEL = tf.saved_model.load(str(ICASSP_2022_MODEL_PATH))
+         HAS_BASIC_PITCH = True
      except Exception as e:
+         print("Could not load Basic-Pitch saved model:", e)
+         HAS_BASIC_PITCH = False
+ except Exception as e:
+     print("basic-pitch/TensorFlow not available:", e)
+     HAS_BASIC_PITCH = False

+ # Fallback simple pipeline (librosa-based) in case the heavy libs are missing
+ def librosa_mono_pitch_to_midi(audio_path, hop_length=256, frame_length=2048, bpm=120, quantize=True, division=4):
+     y, sr = librosa.load(audio_path, sr=None, mono=True)
+     if np.max(np.abs(y)) > 0:
+         y = y / np.max(np.abs(y))
+     f0, voiced_flag, _ = librosa.pyin(y, fmin=librosa.note_to_hz('C2'), fmax=librosa.note_to_hz('C7'),
+                                       sr=sr, frame_length=frame_length, hop_length=hop_length)
+     f0[~voiced_flag] = np.nan
+     # group frames into notes (simple run-length grouping)
+     times = np.arange(len(f0)) * hop_length / sr
+     midi_vals = np.array([69 + 12 * np.log2(v / 440.0) if (v is not None and not np.isnan(v) and v > 0) else np.nan for v in f0])
+     notes = []
+     i = 0
+     while i < len(midi_vals):
+         if np.isnan(midi_vals[i]):
+             i += 1
+             continue
+         v = int(round(midi_vals[i]))
+         start = i
+         j = i + 1
+         while j < len(midi_vals) and not np.isnan(midi_vals[j]) and int(round(midi_vals[j])) == v:
+             j += 1
+         t0 = times[start]
+         t1 = times[j - 1] + hop_length / sr
+         notes.append((v, float(t0), float(t1)))
+         i = j
+     pm = pretty_midi.PrettyMIDI()
+     inst = pretty_midi.Instrument(program=0)
+     for m, t0, t1 in notes:
+         inst.notes.append(pretty_midi.Note(velocity=90, pitch=int(m), start=t0, end=t1))
+     pm.instruments.append(inst)
+     tmpdir = tempfile.mkdtemp()
+     out = os.path.join(tmpdir, "fallback.mid")
+     pm.write(out)
+     return out, {"engine": "librosa_pyin", "notes": len(notes)}

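+ # Note: pyin is monophonic - on a polyphonic stem this fallback keeps only the
+ # dominant melodic line, so expect far fewer notes than with Basic-Pitch.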

+ # Utility: run the demucs CLI to separate stems
+ def demucs_separate_cli(audio_path, model_name=DEMUCS_MODEL_NAME):
+     # demucs CLI: demucs -n <model> -o <output_dir> audio.wav
+     out_root = tempfile.mkdtemp()
+     cmd = ["demucs", "-n", model_name, "-o", out_root, audio_path]
      try:
+         subprocess.run(cmd, capture_output=True, text=True, check=True)
+     except FileNotFoundError:
+         # demucs not installed
+         raise RuntimeError("demucs CLI not found. Please install demucs in the environment.")
+     except subprocess.CalledProcessError as e:
+         raise RuntimeError(f"Demucs separation failed: {e.stderr or e.stdout}")
+     # demucs writes stems under <out_root>/<model_name>/<track_name>/;
+     # walk the tree and take the first directory that contains WAV files
+     stems_dir = None
+     for root, dirs, files in os.walk(out_root):
+         if any(f.endswith(".wav") for f in files):
+             stems_dir = root
+             break
+     if stems_dir is None:
+         raise RuntimeError(f"demucs did not produce stems under {out_root}")
+     # expected stem names: vocals.wav, drums.wav, bass.wav, other.wav (depending on model)
+     return stems_dir

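+ # For reference, a typical output tree (an assumption - the exact layout varies
+ # by demucs version, which is why we walk the tree instead of hard-coding it):
+ #   <out_root>/htdemucs_ft/<track_name>/vocals.wav
+ #   <out_root>/htdemucs_ft/<track_name>/drums.wav
+ #   <out_root>/htdemucs_ft/<track_name>/bass.wav
+ #   <out_root>/htdemucs_ft/<track_name>/other.wav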

+ # Utility: run Basic-Pitch inference on a given WAV file
+ def basic_pitch_transcribe(wav_path, model_obj=None):
      """
+     Uses basic_pitch.inference.predict(audio_path, model_or_model_path), which in
+     current basic-pitch releases returns a (model_output, midi_data, note_events)
+     tuple; note_events is a list of (start_s, end_s, midi_pitch, amplitude, pitch_bends).
      """
+     if not HAS_BASIC_PITCH:
+         raise RuntimeError("basic-pitch is not available in this environment.")
+     # default thresholds: see the basic-pitch inference API
      try:
+         # a filesystem path also works as model_or_model_path
+         model = model_obj if model_obj is not None else BASIC_PITCH_MODEL
+         _, _, note_events = predict(wav_path, model)
+         # convert note events into a pretty_midi instrument, mapping amplitude to velocity
+         inst = pretty_midi.Instrument(program=0)
+         for ev in note_events:
+             start, end, pitch, amplitude = ev[0], ev[1], ev[2], ev[3]
+             vel = int(np.clip(round(amplitude * 127), 1, 127))
+             inst.notes.append(pretty_midi.Note(velocity=vel, pitch=int(pitch), start=float(start), end=float(end)))
+         return inst, {"notes_count": len(inst.notes)}
      except Exception as e:
+         # surface the failure with context
+         raise RuntimeError(f"basic_pitch prediction failed: {e}")

+ # Merge per-stem transcriptions into a single PrettyMIDI object
+ def merge_stems_to_midi(stem_paths, use_basic_pitch=True):
      """
+     stem_paths: dict {stem_name: path_wav}
+     For each stem:
+       - if basic-pitch is available, transcribe with it (polyphonic)
+       - else fall back to librosa pyin per stem
+     Returns (path_to_midi, summary).
      """
+     pm = pretty_midi.PrettyMIDI()
+     summary = {"stems": {}, "engine": "mixed"}
+     for stem_name, path in stem_paths.items():
+         try:
+             if use_basic_pitch and HAS_BASIC_PITCH:
+                 inst, info = basic_pitch_transcribe(path)
+                 # assign instrument programs heuristically (bass -> 32, vocals -> 54, drums -> drum channel)
+                 if stem_name.lower().startswith("drum"):
+                     # drums: pretty_midi drum notes are normal notes with is_drum set at the instrument level
+                     drum_inst = pretty_midi.Instrument(program=0, is_drum=True)
+                     # copy notes from inst as hits
+                     for n in inst.notes:
+                         drum_inst.notes.append(pretty_midi.Note(velocity=n.velocity, pitch=n.pitch, start=n.start, end=n.end))
+                     pm.instruments.append(drum_inst)
+                 else:
+                     # set the program per stem (simple heuristics)
+                     program = 0
+                     if "bass" in stem_name.lower():
+                         program = 32  # acoustic bass
+                     elif "voc" in stem_name.lower():
+                         program = 54  # synth voice lead (as an example)
+                     inst.program = int(program)
                      pm.instruments.append(inst)
+                 summary["stems"][stem_name] = {"notes": info.get("notes_count", 0), "engine": "basic_pitch"}
+             else:
+                 # fallback per stem: librosa pyin, then load the resulting MIDI and append its tracks
+                 out, info = librosa_mono_pitch_to_midi(path)
+                 midi = pretty_midi.PrettyMIDI(out)
+                 # program heuristics
+                 for inst in midi.instruments:
+                     if "drum" in stem_name.lower():
+                         inst.is_drum = True
+                     if "bass" in stem_name.lower():
+                         inst.program = 32
+                     pm.instruments.append(inst)
+                 summary["stems"][stem_name] = {"notes": info.get("notes", 0), "engine": "librosa_fallback"}
+         except Exception as e:
+             # record the error but continue with the remaining stems
+             summary["stems"][stem_name] = {"error": str(e)}
+     # write the combined MIDI
+     tmpdir = tempfile.mkdtemp()
+     out_midi = os.path.join(tmpdir, "separated_multi.mid")
+     pm.write(out_midi)
+     summary["instruments"] = len(pm.instruments)
+     summary["notes_total"] = sum(len(inst.notes) for inst in pm.instruments)
+     return out_midi, summary

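+ # Example summary shape (illustrative values):
+ #   {"stems": {"vocals": {"notes": 120, "engine": "basic_pitch"}, ...},
+ #    "engine": "mixed", "instruments": 4, "notes_total": 480}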

+ # High-level pipeline: separate -> transcribe each stem -> merge
+ def full_pipeline(audio_filepath, demucs_model=DEMUCS_MODEL_NAME, use_basic_pitch=True):
+     # 1) Demucs separation
+     if HAS_DEMUCS:
+         try:
+             stems_dir = demucs_separate_cli(audio_filepath, model_name=demucs_model)
+             # collect the typical stems
+             available = {}
+             for name in os.listdir(stems_dir):
+                 if name.endswith(".wav"):
+                     stem_name = os.path.splitext(name)[0]
+                     available[stem_name] = os.path.join(stems_dir, name)
+             # if demucs nested the stems one level deeper, search recursively
+             if not available:
+                 for root, dirs, files in os.walk(stems_dir):
+                     for f in files:
+                         if f.endswith(".wav"):
+                             available[os.path.splitext(f)[0]] = os.path.join(root, f)
+             if not available:
+                 raise RuntimeError("No stems found after Demucs separation.")
+             # 2) transcribe each stem and merge
+             midi_path, summary = merge_stems_to_midi(available, use_basic_pitch=use_basic_pitch)
+             return midi_path, {"demucs_model": demucs_model, **summary}
+         except Exception as e:
+             traceback.print_exc()
+             # fall back to the mono approach
+             print("Demucs pipeline failed, falling back to librosa mono pipeline:", e)
+             return librosa_mono_pitch_to_midi(audio_filepath)
+     else:
+         # demucs not available: transcribe the full mix (basic-pitch if available)
+         if use_basic_pitch and HAS_BASIC_PITCH:
+             try:
+                 inst, info = basic_pitch_transcribe(audio_filepath)
+                 pm = pretty_midi.PrettyMIDI()
+                 inst.program = 0
+                 pm.instruments.append(inst)
+                 tmpdir = tempfile.mkdtemp()
+                 out = os.path.join(tmpdir, "basicpitch_full.mid")
+                 pm.write(out)
+                 return out, {"engine": "basic_pitch_full", "notes": info.get("notes_count", 0)}
+             except Exception as e:
+                 print("basic-pitch on full mix failed:", e)
+         # final fallback
+         return librosa_mono_pitch_to_midi(audio_filepath)

  # ---------- Gradio UI ----------
  CSS = """
+ #app_title {font-size: 26px; font-weight: 800}
  #app_subtitle {opacity: .8}
  """

+ with gr.Blocks(css=CSS, title="Demucs + BasicPitch -> Multi-MIDI") as demo:
+     gr.Markdown("<div id='app_title'>🔊 Separate & Transcribe Multi-track MIDI</div>"
+                 "<div id='app_subtitle'>Demucs (stems) + Basic-Pitch (polyphonic) pipeline. Fallbacks included.</div>")

      with gr.Row():
          with gr.Column(scale=2):
+             audio_in = gr.Audio(sources=["upload"], type="filepath", label="Audio (mix) - WAV/MP3")
+             demucs_model = gr.Dropdown(["htdemucs_ft", "htdemucs", "htdemucs_6s", "mdx", "mdx_extra"], value=DEMUCS_MODEL_NAME, label="Demucs model")
+             use_basic = gr.Checkbox(value=True, label="Use Basic-Pitch for stems (if available)")
+             run_btn = gr.Button("🚀 Run pipeline")
          with gr.Column(scale=1):
+             midi_out = gr.File(label="MIDI output")
+             log_out = gr.Textbox(label="Summary / Log", lines=12)

+     def run_pipeline(audio_path, demucs_model_name, use_basic_bool):
          try:
+             midi_path, summary = full_pipeline(audio_path, demucs_model=demucs_model_name, use_basic_pitch=use_basic_bool)
+             return midi_path, str(summary)
          except Exception as e:
+             tb = traceback.format_exc()
+             return None, f"Error: {e}\n\nTrace:\n{tb}"

+     run_btn.click(run_pipeline, inputs=[audio_in, demucs_model, use_basic], outputs=[midi_out, log_out])

  if __name__ == "__main__":
      demo.launch()
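
For local testing outside the Space, a minimal smoke test of the new pipeline might look like this (a sketch; local_test.py and test.wav are hypothetical, and the fallbacks inside full_pipeline cover a machine without demucs or basic-pitch installed):

# local_test.py (hypothetical helper, assumes app.py above is importable as `app`)
from app import full_pipeline

midi_path, summary = full_pipeline("test.wav", demucs_model="htdemucs_ft", use_basic_pitch=True)
print("MIDI written to:", midi_path)
print("Summary:", summary)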