Spaces:
Runtime error
Runtime error
| # app.py - Demucs + Basic-Pitch pipeline -> multi-track MIDI (Gradio) | |
| # Author: MrblackDev | |
| # WARNING: heavy deps (demucs, basic-pitch, torch, tensorflow). Use a beefy Space or local env. | |
| import os | |
| import tempfile | |
| import shutil | |
| import subprocess | |
| import traceback | |
| import numpy as np | |
| import librosa | |
| import pretty_midi | |
| import gradio as gr | |
# Feature detection: probe which heavy backends are importable at startup so the
# pipeline can degrade gracefully (demucs -> stem separation, basic-pitch ->
# polyphonic transcription, else a librosa-only mono fallback).
HAS_DEMUCS = False
HAS_BASIC_PITCH = False
DEMucs_MODEL_NAME = "htdemucs_ft"  # reasonable default
try:
    import demucs  # noqa: F401  # presence check only; separation itself runs via the CLI
    HAS_DEMUCS = True
except Exception:
    HAS_DEMUCS = False
try:
    # basic_pitch usage per README: import predict + load saved model
    import tensorflow as tf  # basic-pitch uses TF saved_model
    from basic_pitch.inference import predict
    from basic_pitch import ICASSP_2022_MODEL_PATH
    # load model once at import time (this may be heavy); reused by every request
    try:
        BASIC_PITCH_MODEL = tf.saved_model.load(str(ICASSP_2022_MODEL_PATH))
        HAS_BASIC_PITCH = True
    except Exception as e:
        # imports succeeded but the saved model could not be loaded (e.g. bad path)
        print("Could not load Basic-Pitch saved model:", e)
        HAS_BASIC_PITCH = False
except Exception as e:
    print("basic-pitch/TensorFlow not available:", e)
    HAS_BASIC_PITCH = False
| # Fallback simple pipeline (librosa-based) in case heavy libs missing | |
def librosa_mono_pitch_to_midi(audio_path, hop_length=256, frame_length=2048, bpm=120, quantize=True, division=4):
    """Transcribe a (mostly monophonic) audio file to a single-track MIDI file.

    Fallback engine used when demucs / basic-pitch are unavailable: runs
    librosa's pYIN f0 tracker, rounds each voiced frame to the nearest MIDI
    pitch, and merges consecutive frames of equal pitch into notes.

    Args:
        audio_path: path to any audio file readable by librosa.
        hop_length: analysis hop in samples (frame spacing).
        frame_length: analysis window in samples.
        bpm, quantize, division: reserved for future tempo quantization;
            currently unused, kept only for interface compatibility.

    Returns:
        (midi_path, info) — path to a temp .mid file and a summary dict.

    Raises:
        RuntimeError: if the decoded file contains no audio samples.
    """
    y, sr = librosa.load(audio_path, sr=None, mono=True)
    if y.size == 0:
        # np.max on an empty array would raise an opaque ValueError; fail clearly.
        raise RuntimeError(f"No audio samples found in {audio_path}")
    peak = np.max(np.abs(y))
    if peak > 0:
        y = y / peak  # normalize so pYIN sees a consistent level
    f0, voiced_flag, _ = librosa.pyin(
        y,
        fmin=librosa.note_to_hz('C2'),
        fmax=librosa.note_to_hz('C7'),
        sr=sr,
        frame_length=frame_length,
        hop_length=hop_length,
    )
    f0 = np.asarray(f0, dtype=float)
    f0[~voiced_flag] = np.nan  # mask unvoiced frames
    times = np.arange(len(f0)) * hop_length / sr
    # Vectorized Hz -> MIDI conversion; invalid (nan / non-positive) frames stay nan.
    midi_vals = np.full(len(f0), np.nan)
    valid = ~np.isnan(f0) & (f0 > 0)
    midi_vals[valid] = 69.0 + 12.0 * np.log2(f0[valid] / 440.0)
    # Group consecutive frames with the same rounded pitch into notes.
    notes = []
    i = 0
    n_frames = len(midi_vals)
    while i < n_frames:
        if np.isnan(midi_vals[i]):
            i += 1
            continue
        pitch = int(round(midi_vals[i]))
        start = i
        j = i + 1
        while j < n_frames and not np.isnan(midi_vals[j]) and int(round(midi_vals[j])) == pitch:
            j += 1
        t0 = times[start]
        t1 = times[j - 1] + hop_length / sr  # extend through the last frame
        notes.append((pitch, float(t0), float(t1)))
        i = j
    pm = pretty_midi.PrettyMIDI()
    inst = pretty_midi.Instrument(program=0)
    for m, t0, t1 in notes:
        inst.notes.append(pretty_midi.Note(velocity=90, pitch=int(m), start=t0, end=t1))
    pm.instruments.append(inst)
    tmpdir = tempfile.mkdtemp()
    out = os.path.join(tmpdir, "fallback.mid")
    pm.write(out)
    return out, {"engine": "librosa_pyin", "notes": len(notes)}
| # Utility: run demucs CLI to separate stems | |
def demucs_separate_cli(audio_path, model_name=DEMucs_MODEL_NAME):
    """Run the demucs CLI to split *audio_path* into stem WAV files.

    Invokes `demucs -n <model> -o <tmpdir> <audio>` and returns the first
    directory under the output root that contains .wav files (demucs writes
    out_root/<model_name>/<track_basename>/<stem>.wav).

    Args:
        audio_path: path to the mix to separate.
        model_name: demucs model identifier (e.g. "htdemucs_ft").

    Returns:
        Absolute path of the directory containing the stem .wav files.

    Raises:
        RuntimeError: if the demucs CLI is missing, the separation fails,
            or no stems are produced.
    """
    out_root = tempfile.mkdtemp()
    cmd = ["demucs", "-n", model_name, "-o", out_root, audio_path]
    try:
        subprocess.run(cmd, capture_output=True, text=True, check=True)
    except FileNotFoundError:
        # demucs not installed
        raise RuntimeError("demucs CLI not found. Please install demucs in the environment.")
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"Demucs separation failed: {e.stderr or e.stdout}")
    # Walk the output tree for the first directory holding .wav stems.
    for root, dirs, files in os.walk(out_root):
        dirs.sort()  # make traversal (and thus the chosen dir) deterministic
        if any(f.endswith(".wav") for f in files):
            # expected stem names: vocals.wav, drums.wav, bass.wav, other.wav (model-dependent)
            return root
    raise RuntimeError(f"demucs did not produce stems under {out_root}")
| # Utility: run Basic Pitch inference on a given WAV file | |
def basic_pitch_transcribe(wav_path, model_obj=None):
    """Transcribe one WAV file to notes using Spotify's Basic Pitch.

    Per the basic-pitch inference API, ``predict(audio_path, model_or_model_path,
    ...)`` takes the AUDIO PATH first and returns a 3-tuple
    ``(model_output, midi_data, note_events)`` where ``midi_data`` is a
    pretty_midi.PrettyMIDI and ``note_events`` is a list of
    ``(start_s, end_s, pitch_midi, amplitude, [pitch_bends])`` tuples.
    (The previous implementation passed the model first and expected a dict
    with ``midi=``/``piano_roll=`` kwargs, which is not the documented API
    and made every transcription fail.)

    Args:
        wav_path: path to the WAV file to transcribe.
        model_obj: optional preloaded model; defaults to the module-level
            BASIC_PITCH_MODEL.

    Returns:
        (instrument, info) — a pretty_midi.Instrument with the transcribed
        notes and a dict with "notes_count".

    Raises:
        RuntimeError: if basic-pitch is unavailable or prediction fails.
    """
    if not HAS_BASIC_PITCH:
        raise RuntimeError("basic-pitch is not available in this environment.")
    model = model_obj if model_obj is not None else BASIC_PITCH_MODEL
    try:
        result = predict(wav_path, model)
        inst = pretty_midi.Instrument(program=0)
        if isinstance(result, tuple) and len(result) >= 3:
            _, midi_data, note_events = result[0], result[1], result[2]
            if midi_data is not None and getattr(midi_data, "instruments", None):
                # Preferred: flatten the PrettyMIDI produced by basic-pitch.
                for src in midi_data.instruments:
                    for n in src.notes:
                        inst.notes.append(pretty_midi.Note(
                            velocity=n.velocity, pitch=n.pitch, start=n.start, end=n.end))
            else:
                # Fallback: build notes from the raw note_events tuples.
                for ev in note_events or []:
                    start, end, pitch, amplitude = ev[0], ev[1], ev[2], ev[3]
                    vel = int(round(max(1, min(127, amplitude * 127))))
                    inst.notes.append(pretty_midi.Note(
                        velocity=vel, pitch=int(pitch), start=float(start), end=float(end)))
        elif isinstance(result, dict):
            # Defensive path for any version returning dict-style note lists.
            notes = result.get("notes") or result.get("pred_notes") or []
            for n in notes:
                start = float(n.get("start", n.get("onset", 0.0)))
                end = float(n.get("end", n.get("offset", start + 0.1)))
                pitch = int(round(n.get("pitch", n.get("midi_pitch", 60))))
                vel = int(n.get("velocity", 90)) if n.get("velocity") else 90
                inst.notes.append(pretty_midi.Note(velocity=vel, pitch=pitch, start=start, end=end))
        return inst, {"notes_count": len(inst.notes)}
    except Exception as e:
        # surface a uniform error type to callers
        raise RuntimeError(f"basic_pitch prediction failed: {e}") from e
| # Merge stems transcriptions into a single PrettyMIDI object | |
def merge_stems_to_midi(stem_paths, use_basic_pitch=True):
    """Transcribe each separated stem and merge all tracks into one MIDI file.

    Args:
        stem_paths: dict {stem_name: path_to_wav}.
        use_basic_pitch: when True and basic-pitch is available, use it for
            polyphonic transcription; otherwise fall back to the librosa
            mono pipeline per stem.

    Returns:
        (midi_path, summary) — path to the merged multi-track .mid file and a
        per-stem summary dict (errors are recorded per stem, not raised).
    """
    pm = pretty_midi.PrettyMIDI()
    summary = {"stems": {}, "engine": "mixed"}
    for stem_name, path in stem_paths.items():
        lower = stem_name.lower()
        try:
            if use_basic_pitch and HAS_BASIC_PITCH:
                inst, info = basic_pitch_transcribe(path)
                if lower.startswith("drum"):
                    # Drums go on a dedicated is_drum instrument (GM drum channel).
                    drum_inst = pretty_midi.Instrument(program=0, is_drum=True)
                    for n in inst.notes:
                        drum_inst.notes.append(pretty_midi.Note(
                            velocity=n.velocity, pitch=n.pitch, start=n.start, end=n.end))
                    pm.instruments.append(drum_inst)
                else:
                    # Simple program heuristics per stem name.
                    program = 0
                    if "bass" in lower:
                        program = 32  # acoustic bass
                    elif "voc" in lower:  # matches "vocals"/"vocal"/...
                        program = 54  # synth lead (as example)
                    inst.program = int(program)
                    pm.instruments.append(inst)
                summary["stems"][stem_name] = {"notes": info.get("notes_count", 0), "engine": "basic_pitch"}
            else:
                # Fallback per-stem: librosa pyin, then import the produced MIDI.
                out, info = librosa_mono_pitch_to_midi(path)
                midi = pretty_midi.PrettyMIDI(out)
                for inst in midi.instruments:
                    if "drum" in lower:
                        inst.is_drum = True
                    if "bass" in lower:
                        inst.program = 32
                    pm.instruments.append(inst)
                summary["stems"][stem_name] = {"notes": info.get("notes", 0), "engine": "librosa_fallback"}
        except Exception as e:
            # Record the error for this stem but keep processing the rest.
            summary["stems"][stem_name] = {"error": str(e)}
    tmpdir = tempfile.mkdtemp()
    out_midi = os.path.join(tmpdir, "separated_multi.mid")
    pm.write(out_midi)
    summary["instruments"] = len(pm.instruments)
    summary["notes_total"] = sum(len(inst.notes) for inst in pm.instruments)
    return out_midi, summary
| # High-level pipeline: separate -> transcribe each stem -> merge | |
def full_pipeline(audio_filepath, demucs_model=DEMucs_MODEL_NAME, use_basic_pitch=True):
    """Separate the mix into stems, transcribe each, and merge into one MIDI.

    Degrades gracefully: without demucs, basic-pitch transcribes the full mix;
    if that also fails or is unavailable, the librosa mono fallback runs.

    Returns:
        (midi_path, summary) in every branch.
    """
    # Guard clause: no demucs -> single-track transcription of the full mix.
    if not HAS_DEMUCS:
        if use_basic_pitch and HAS_BASIC_PITCH:
            try:
                # basic-pitch on full mix
                inst, info = basic_pitch_transcribe(audio_filepath)
                pm = pretty_midi.PrettyMIDI()
                inst.program = 0
                pm.instruments.append(inst)
                out = os.path.join(tempfile.mkdtemp(), "basicpitch_full.mid")
                pm.write(out)
                return out, {"engine": "basic_pitch_full", "notes": info.get("notes_count", 0)}
            except Exception as e:
                print("basic-pitch on full mix failed:", e)
        # final fallback
        return librosa_mono_pitch_to_midi(audio_filepath)

    try:
        # 1) Demucs separation
        stems_dir = demucs_separate_cli(audio_filepath, model_name=demucs_model)
        # Collect stem WAVs directly in the returned directory.
        available = {
            os.path.splitext(name)[0]: os.path.join(stems_dir, name)
            for name in os.listdir(stems_dir)
            if name.endswith(".wav")
        }
        if not available:
            # Stems may sit one level deeper (e.g. <model>/<basename>/*.wav).
            for root, dirs, files in os.walk(stems_dir):
                for f in files:
                    if f.endswith(".wav"):
                        available[os.path.splitext(f)[0]] = os.path.join(root, f)
        if not available:
            raise RuntimeError("No stems found after Demucs separation.")
        # 2) Transcribe every stem and merge.
        midi_path, summary = merge_stems_to_midi(available, use_basic_pitch=use_basic_pitch)
        return midi_path, {"demucs_model": demucs_model, **summary}
    except Exception as e:
        traceback.print_exc()
        print("Demucs pipeline failed, falling back to librosa mono pipeline:", e)
        return librosa_mono_pitch_to_midi(audio_filepath)
# ---------- Gradio UI ----------
CSS = """
#app_title {font-size: 26px; font-weight: 800}
#app_subtitle {opacity: .8}
"""

with gr.Blocks(css=CSS, title="Demucs + BasicPitch -> Multi-MIDI") as demo:
    gr.Markdown("<div id='app_title'>🔊 Separate & Transcribe → Multi-track MIDI</div>"
                "<div id='app_subtitle'>Demucs (stems) + Basic-Pitch (polyphonic) pipeline. Fallbacks included.</div>")
    with gr.Row():
        with gr.Column(scale=2):
            audio_in = gr.Audio(sources=["upload"], type="filepath", label="Audio (mix) - WAV/MP3")
            demucs_model = gr.Dropdown(["htdemucs_ft", "htdemucs", "htdemucs_6s", "mdx", "mdx_extra"],
                                       value=DEMucs_MODEL_NAME, label="Demucs model")
            use_basic = gr.Checkbox(value=True, label="Use Basic-Pitch for stems (if available)")
            run_btn = gr.Button("🚀 Run pipeline")
        with gr.Column(scale=1):
            midi_out = gr.File(label="MIDI output")
            log_out = gr.Textbox(label="Summary / Log", lines=12)

    def run_pipeline(audio_path, demucs_model_name, use_basic_bool):
        """Gradio callback: run the full pipeline, surfacing errors in the log box."""
        try:
            midi_path, summary = full_pipeline(audio_path, demucs_model=demucs_model_name,
                                               use_basic_pitch=use_basic_bool)
            return midi_path, str(summary)
        except Exception as e:
            tb = traceback.format_exc()
            # Fix: use real newlines ("\n"); the previous "\\n" escapes rendered
            # a literal backslash-n in the textbox instead of line breaks.
            return None, f"Error: {e}\n\nTrace:\n{tb}"

    run_btn.click(run_pipeline, inputs=[audio_in, demucs_model, use_basic], outputs=[midi_out, log_out])

if __name__ == "__main__":
    demo.launch()