# app.py - Demucs + Basic-Pitch pipeline -> multi-track MIDI (Gradio)
# Author: MrblackDev
# WARNING: heavy deps (demucs, basic-pitch, torch, tensorflow). Use a beefy Space or local env.

import os
import subprocess
import tempfile
import traceback

import gradio as gr
import librosa
import numpy as np
import pretty_midi

# Try imports for demucs and basic-pitch (tensorflow) if available
HAS_DEMUCS = False
HAS_BASIC_PITCH = False
DEMUCS_MODEL_NAME = "htdemucs_ft"  # reasonable default

try:
    import demucs  # noqa: F401

    HAS_DEMUCS = True
except Exception:
    HAS_DEMUCS = False

try:
    # basic_pitch usage per README: predict(audio_path, model_or_model_path)
    import tensorflow as tf  # basic-pitch (TF builds) ships a TF SavedModel
    from basic_pitch.inference import predict
    from basic_pitch import ICASSP_2022_MODEL_PATH

    # Load the model once up front (this may be heavy); recent basic-pitch
    # versions also accept the model path directly in predict().
    try:
        BASIC_PITCH_MODEL = tf.saved_model.load(str(ICASSP_2022_MODEL_PATH))
        HAS_BASIC_PITCH = True
    except Exception as e:
        print("Could not load Basic-Pitch saved model:", e)
        HAS_BASIC_PITCH = False
except Exception as e:
    print("basic-pitch/TensorFlow not available:", e)
    HAS_BASIC_PITCH = False


# Fallback simple pipeline (librosa-based) in case the heavy libs are missing
def librosa_mono_pitch_to_midi(audio_path, hop_length=256, frame_length=2048):
    """Monophonic fallback: pYIN f0 tracking -> frame grouping -> MIDI file."""
    y, sr = librosa.load(audio_path, sr=None, mono=True)
    if np.max(np.abs(y)) > 0:
        y = y / np.max(np.abs(y))  # peak-normalize
    f0, voiced_flag, _ = librosa.pyin(
        y,
        fmin=librosa.note_to_hz("C2"),
        fmax=librosa.note_to_hz("C7"),
        sr=sr,
        frame_length=frame_length,
        hop_length=hop_length,
    )
    f0[~voiced_flag] = np.nan

    # Per-frame f0 (Hz) -> fractional MIDI numbers; NaN marks unvoiced frames.
    times = np.arange(len(f0)) * hop_length / sr
    with np.errstate(invalid="ignore", divide="ignore"):
        midi_vals = np.where(f0 > 0, 69 + 12 * np.log2(f0 / 440.0), np.nan)

    # Group consecutive frames with the same rounded pitch into (pitch, start, end) notes.
    notes = []
    i = 0
    while i < len(midi_vals):
        if np.isnan(midi_vals[i]):
            i += 1
            continue
        v = int(round(midi_vals[i]))
        start = i
        j = i + 1
        while j < len(midi_vals) and not np.isnan(midi_vals[j]) and int(round(midi_vals[j])) == v:
            j += 1
        t0 = times[start]
        t1 = times[j - 1] + hop_length / sr
        notes.append((v, float(t0), float(t1)))
        i = j

    pm = pretty_midi.PrettyMIDI()
    inst = pretty_midi.Instrument(program=0)
    for m, t0, t1 in notes:
        inst.notes.append(pretty_midi.Note(velocity=90, pitch=int(m), start=t0, end=t1))
    pm.instruments.append(inst)

    tmpdir = tempfile.mkdtemp()
    out = os.path.join(tmpdir, "fallback.mid")
    pm.write(out)
    return out, {"engine": "librosa_pyin", "notes": len(notes)}
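# Sanity check for the Hz -> MIDI mapping above (midi = 69 + 12*log2(f/440)).
# Illustrative arithmetic only, not executed by the pipeline:
#
#   69 + 12 * np.log2(440.0 / 440.0)    -> 69.0   (A4)
#   69 + 12 * np.log2(880.0 / 440.0)    -> 81.0   (A5, one octave up)
#   69 + 12 * np.log2(261.626 / 440.0)  -> ~60.0  (middle C)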
# Utility: run the demucs CLI to separate stems
def demucs_separate_cli(audio_path, model_name=DEMUCS_MODEL_NAME):
    # demucs CLI: demucs -n <model> -o <output_dir> audio.wav
    out_root = tempfile.mkdtemp()
    cmd = ["demucs", "-n", model_name, "-o", out_root, audio_path]
    try:
        subprocess.run(cmd, capture_output=True, text=True, check=True)
    except FileNotFoundError:
        # demucs not installed
        raise RuntimeError("demucs CLI not found. Please install demucs in the environment.")
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"Demucs separation failed: {e.stderr or e.stdout}")

    # Demucs writes stems under out_root/<model>/<track>/; walk until we find WAVs.
    stems_dir = None
    for root, _dirs, files in os.walk(out_root):
        if any(f.endswith(".wav") for f in files):
            stems_dir = root
            break
    if stems_dir is None:
        raise RuntimeError(f"demucs did not produce stems under {out_root}")
    # Expected stem names: vocals.wav, drums.wav, bass.wav, other.wav (model-dependent).
    return stems_dir


# Utility: run Basic-Pitch inference on a given WAV file
def basic_pitch_transcribe(wav_path, model_obj=None):
    """
    Run basic_pitch.inference.predict on a WAV file and flatten the result
    into a single pretty_midi.Instrument.

    Per the basic-pitch README, predict(audio_path, model_or_model_path)
    returns (model_output, midi_data, note_events), where midi_data is a
    pretty_midi.PrettyMIDI object.
    """
    if not HAS_BASIC_PITCH:
        raise RuntimeError("basic-pitch is not available in this environment.")
    try:
        _, midi_data, _ = predict(wav_path, model_obj if model_obj is not None else BASIC_PITCH_MODEL)
        # Flatten all transcribed instruments into one track for this stem.
        inst = pretty_midi.Instrument(program=0)
        for src in midi_data.instruments:
            inst.notes.extend(src.notes)
        inst.notes.sort(key=lambda n: n.start)
        return inst, {"notes_count": len(inst.notes)}
    except Exception as e:
        raise RuntimeError(f"basic_pitch prediction failed: {e}")


# Merge per-stem transcriptions into a single multi-track PrettyMIDI object
def merge_stems_to_midi(stem_paths, use_basic_pitch=True):
    """
    stem_paths: dict {stem_name: path_to_wav}
    For each stem:
      - if Basic-Pitch is available, transcribe with it (polyphonic)
      - else fall back to librosa pYIN per stem
    Returns (path_to_midi, summary).
    """
    pm = pretty_midi.PrettyMIDI()
    summary = {"stems": {}, "engine": "mixed"}
    for stem_name, path in stem_paths.items():
        try:
            if use_basic_pitch and HAS_BASIC_PITCH:
                inst, info = basic_pitch_transcribe(path)
                # Assign programs heuristically (bass -> 32, drums -> drum channel).
                if stem_name.lower().startswith("drum"):
                    # Drums: copy notes onto an instrument flagged is_drum=True so
                    # the hits land on the MIDI percussion channel.
                    drum_inst = pretty_midi.Instrument(program=0, is_drum=True)
                    for n in inst.notes:
                        drum_inst.notes.append(
                            pretty_midi.Note(velocity=n.velocity, pitch=n.pitch, start=n.start, end=n.end)
                        )
                    pm.instruments.append(drum_inst)
                else:
                    program = 0
                    if "bass" in stem_name.lower():
                        program = 32  # acoustic bass
                    elif "voc" in stem_name.lower():
                        program = 54  # synth voice (an arbitrary example choice)
                    inst.program = int(program)
                    pm.instruments.append(inst)
                summary["stems"][stem_name] = {"notes": info.get("notes_count", 0), "engine": "basic_pitch"}
            else:
                # Fallback per stem: librosa pYIN, then merge that MIDI's tracks.
                out, info = librosa_mono_pitch_to_midi(path)
                midi = pretty_midi.PrettyMIDI(out)
                for inst in midi.instruments:
                    if "drum" in stem_name.lower():
                        inst.is_drum = True
                    if "bass" in stem_name.lower():
                        inst.program = 32
                    pm.instruments.append(inst)
                summary["stems"][stem_name] = {"notes": info.get("notes", 0), "engine": "librosa_fallback"}
        except Exception as e:
            # Record the error for this stem but keep processing the others.
            summary["stems"][stem_name] = {"error": str(e)}

    tmpdir = tempfile.mkdtemp()
    out_midi = os.path.join(tmpdir, "separated_multi.mid")
    pm.write(out_midi)
    summary["instruments"] = len(pm.instruments)
    summary["notes_total"] = sum(len(inst.notes) for inst in pm.instruments)
    return out_midi, summary
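# Minimal usage sketch for merge_stems_to_midi; the stem paths below are
# hypothetical (normally they come from demucs_separate_cli):
#
#   stems = {
#       "vocals": "/tmp/sep/vocals.wav",
#       "drums":  "/tmp/sep/drums.wav",
#       "bass":   "/tmp/sep/bass.wav",
#       "other":  "/tmp/sep/other.wav",
#   }
#   midi_path, summary = merge_stems_to_midi(stems, use_basic_pitch=True)
#   print(summary["notes_total"], "notes across", summary["instruments"], "instruments")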
# High-level pipeline: separate -> transcribe each stem -> merge
def full_pipeline(audio_filepath, demucs_model=DEMUCS_MODEL_NAME, use_basic_pitch=True):
    # 1) Demucs separation
    if HAS_DEMUCS:
        try:
            stems_dir = demucs_separate_cli(audio_filepath, model_name=demucs_model)
            # Collect stems found directly in stems_dir...
            available = {}
            for name in os.listdir(stems_dir):
                if name.endswith(".wav"):
                    available[os.path.splitext(name)[0]] = os.path.join(stems_dir, name)
            # ...and if demucs nested them deeper (e.g. <model>/<track>/<stem>.wav),
            # search recursively.
            if not available:
                for root, _dirs, files in os.walk(stems_dir):
                    for f in files:
                        if f.endswith(".wav"):
                            available[os.path.splitext(f)[0]] = os.path.join(root, f)
            if not available:
                raise RuntimeError("No stems found after Demucs separation.")
            # 2) Transcribe each stem and merge into one multi-track MIDI
            midi_path, summary = merge_stems_to_midi(available, use_basic_pitch=use_basic_pitch)
            return midi_path, {"demucs_model": demucs_model, **summary}
        except Exception as e:
            traceback.print_exc()
            print("Demucs pipeline failed, falling back to librosa mono pipeline:", e)
            return librosa_mono_pitch_to_midi(audio_filepath)
    else:
        # No Demucs: try Basic-Pitch on the full mix, else the librosa fallback.
        if use_basic_pitch and HAS_BASIC_PITCH:
            try:
                inst, info = basic_pitch_transcribe(audio_filepath)
                pm = pretty_midi.PrettyMIDI()
                inst.program = 0
                pm.instruments.append(inst)
                tmpdir = tempfile.mkdtemp()
                out = os.path.join(tmpdir, "basicpitch_full.mid")
                pm.write(out)
                return out, {"engine": "basic_pitch_full", "notes": info.get("notes_count", 0)}
            except Exception as e:
                print("basic-pitch on full mix failed:", e)
        # final fallback
        return librosa_mono_pitch_to_midi(audio_filepath)
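# Headless usage sketch (no Gradio UI); "song.wav" is a hypothetical local file:
#
#   midi_path, summary = full_pipeline("song.wav", demucs_model="htdemucs_ft")
#   print(midi_path)  # e.g. /tmp/tmpXXXX/separated_multi.mid
#   print(summary)    # per-stem engines and note counts, or fallback info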
# ---------- Gradio UI ----------
CSS = """
#app_title {font-size: 26px; font-weight: 800}
#app_subtitle {opacity: .8}
"""

with gr.Blocks(css=CSS, title="Demucs + BasicPitch -> Multi-MIDI") as demo:
    gr.Markdown(
        "<div id='app_title'>🔊 Separate & Transcribe → Multi-track MIDI</div>"
        "<div id='app_subtitle'>Demucs (stems) + Basic-Pitch (polyphonic) pipeline. Fallbacks included.</div>"
    )
    with gr.Row():
        with gr.Column(scale=2):
            audio_in = gr.Audio(sources=["upload"], type="filepath", label="Audio (mix) - WAV/MP3")
            demucs_model = gr.Dropdown(
                ["htdemucs_ft", "htdemucs", "htdemucs_6s", "mdx", "mdx_extra"],
                value=DEMUCS_MODEL_NAME,
                label="Demucs model",
            )
            use_basic = gr.Checkbox(value=True, label="Use Basic-Pitch for stems (if available)")
            run_btn = gr.Button("🚀 Run pipeline")
        with gr.Column(scale=1):
            midi_out = gr.File(label="MIDI output")
            log_out = gr.Textbox(label="Summary / Log", lines=12)

    def run_pipeline(audio_path, demucs_model_name, use_basic_bool):
        try:
            midi_path, summary = full_pipeline(
                audio_path, demucs_model=demucs_model_name, use_basic_pitch=use_basic_bool
            )
            return midi_path, str(summary)
        except Exception as e:
            tb = traceback.format_exc()
            return None, f"Error: {e}\n\nTrace:\n{tb}"

    run_btn.click(run_pipeline, inputs=[audio_in, demucs_model, use_basic], outputs=[midi_out, log_out])

if __name__ == "__main__":
    demo.launch()