# mrblackdev's picture
# Update app.py
# 05960c8 verified
# app.py - Demucs + Basic-Pitch pipeline -> multi-track MIDI (Gradio)
# Author: MrblackDev
# WARNING: heavy deps (demucs, basic-pitch, torch, tensorflow). Use a beefy Space or local env.
import os
import tempfile
import shutil
import subprocess
import traceback
import numpy as np
import librosa
import pretty_midi
import gradio as gr
# Try imports for basic-pitch (tensorflow) if available
# Capability probes: detect at startup which heavy back-ends are importable.
# The pipeline functions below consult these flags to choose an engine.
HAS_DEMUCS = False
HAS_BASIC_PITCH = False
DEMucs_MODEL_NAME = "htdemucs_ft" # reasonable default

# Demucs: only importability matters here — separation itself runs via the CLI.
try:
    import demucs # noqa: F401
except Exception:
    HAS_DEMUCS = False
else:
    HAS_DEMUCS = True

# Basic-Pitch rides on TensorFlow; the saved model is loaded once up front
# because loading is expensive and every request reuses it.
try:
    import tensorflow as tf # basic-pitch uses TF saved_model
    from basic_pitch.inference import predict
    from basic_pitch import ICASSP_2022_MODEL_PATH
except Exception as e:
    print("basic-pitch/TensorFlow not available:", e)
    HAS_BASIC_PITCH = False
else:
    # Imports succeeded; now attempt the (heavy) model load separately so a
    # load failure is reported distinctly from a missing package.
    try:
        BASIC_PITCH_MODEL = tf.saved_model.load(str(ICASSP_2022_MODEL_PATH))
    except Exception as e:
        print("Could not load Basic-Pitch saved model:", e)
        HAS_BASIC_PITCH = False
    else:
        HAS_BASIC_PITCH = True
# Fallback simple pipeline (librosa-based) in case heavy libs missing
def librosa_mono_pitch_to_midi(audio_path, hop_length=256, frame_length=2048, bpm=120, quantize=True, division=4):
    """Fallback monophonic transcription: librosa pYIN -> single-track MIDI.

    Tracks a single f0 contour with pYIN, groups consecutive frames that
    round to the same MIDI pitch into notes, and writes a temp .mid file.

    Parameters
    ----------
    audio_path : path to any audio file librosa can decode.
    hop_length, frame_length : pYIN analysis window parameters (samples).
    bpm, quantize, division : accepted for interface compatibility; currently
        unused (no tempo-grid quantization is applied).

    Returns
    -------
    (midi_path, summary_dict) where summary_dict reports the engine used and
    the number of notes produced.
    """
    y, sr = librosa.load(audio_path, sr=None, mono=True)
    # Peak-normalize; guard empty input (np.max on an empty array raises).
    peak = float(np.max(np.abs(y))) if y.size else 0.0
    if peak > 0:
        y = y / peak
    f0, voiced_flag, _ = librosa.pyin(y, fmin=librosa.note_to_hz('C2'), fmax=librosa.note_to_hz('C7'),
                                      sr=sr, frame_length=frame_length, hop_length=hop_length)
    f0[~voiced_flag] = np.nan  # mask unvoiced frames so they break notes
    times = np.arange(len(f0)) * hop_length / sr
    # Vectorized Hz -> fractional MIDI number; NaN marks unpitched frames.
    midi_vals = np.full(len(f0), np.nan)
    pitched = np.isfinite(f0) & (f0 > 0)
    midi_vals[pitched] = 69.0 + 12.0 * np.log2(f0[pitched] / 440.0)
    # Greedy segmentation: a run of frames rounding to the same pitch becomes
    # one note spanning [first frame, last frame + one hop].
    notes = []
    i = 0
    while i < len(midi_vals):
        if np.isnan(midi_vals[i]):
            i += 1
            continue
        v = int(round(midi_vals[i]))
        start = i
        j = i + 1
        while j < len(midi_vals) and not np.isnan(midi_vals[j]) and int(round(midi_vals[j])) == v:
            j += 1
        t0 = times[start]
        t1 = times[j - 1] + hop_length / sr
        notes.append((v, float(t0), float(t1)))
        i = j
    pm = pretty_midi.PrettyMIDI()
    inst = pretty_midi.Instrument(program=0)
    for m, t0, t1 in notes:
        inst.notes.append(pretty_midi.Note(velocity=90, pitch=int(m), start=t0, end=t1))
    pm.instruments.append(inst)
    tmpdir = tempfile.mkdtemp()
    out = os.path.join(tmpdir, "fallback.mid")
    pm.write(out)
    return out, {"engine":"librosa_pyin","notes":len(notes)}
# Utility: run demucs CLI to separate stems
def demucs_separate_cli(audio_path, model_name=DEMucs_MODEL_NAME):
    """Run the Demucs CLI to split *audio_path* into stems.

    Invokes ``demucs -n <model> -o <tmpdir> <audio>`` and returns the first
    directory under the output root that contains .wav files. The exact
    nesting (<root>/<model>/<track>/<stem>.wav) varies across Demucs
    versions, hence the recursive search.

    Raises RuntimeError if the CLI is missing or separation fails.
    """
    out_root = tempfile.mkdtemp()
    cmd = ["demucs", "-n", model_name, "-o", out_root, audio_path]
    try:
        # Output is captured only so it can be surfaced in the error message.
        subprocess.run(cmd, capture_output=True, text=True, check=True)
    except FileNotFoundError:
        # demucs not installed
        raise RuntimeError("demucs CLI not found. Please install demucs in the environment.")
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"Demucs separation failed: {e.stderr or e.stdout}")
    # Locate the stem directory regardless of the nesting the CLI produced.
    for root, _dirs, files in os.walk(out_root):
        if any(f.endswith(".wav") for f in files):
            return root
    raise RuntimeError(f"demucs did not produce stems under {out_root}")
# Utility: run Basic Pitch inference on a given WAV file
def basic_pitch_transcribe(wav_path, model_obj=None):
    """Transcribe one WAV with Basic-Pitch into a pretty_midi Instrument.

    Per the basic-pitch README, ``basic_pitch.inference.predict(audio_path,
    model_or_model_path)`` takes the audio path FIRST and returns a tuple
    ``(model_output, midi_data, note_events)``, where ``note_events`` is a
    list of ``(start_s, end_s, pitch_midi, amplitude, pitch_bends)`` tuples.
    (The previous code inverted the argument order and treated the result as
    a dict, which matches no released API.)

    Returns (instrument, info_dict); raises RuntimeError when basic-pitch is
    unavailable or prediction fails.
    """
    if not HAS_BASIC_PITCH:
        raise RuntimeError("basic-pitch is not available in this environment.")
    try:
        model = model_obj if model_obj is not None else BASIC_PITCH_MODEL
        result = predict(wav_path, model)
        inst = pretty_midi.Instrument(program=0)
        if isinstance(result, tuple) and len(result) == 3:
            _model_output, _midi_data, note_events = result
            for event in note_events:
                start_s, end_s, pitch_midi, amplitude = event[0], event[1], event[2], event[3]
                # Map amplitude (0..1) to a MIDI velocity clamped to 1..127.
                vel = max(1, min(127, int(round(float(amplitude) * 127))))
                inst.notes.append(pretty_midi.Note(velocity=vel, pitch=int(pitch_midi),
                                                   start=float(start_s), end=float(end_s)))
        else:
            # Defensive fallback for hypothetical dict-style returns from
            # other library versions.
            notes = (result.get("notes") or result.get("pred_notes") or []) if hasattr(result, "get") else []
            for n in notes:
                start = float(n.get("start", n.get("onset", 0.0)))
                end = float(n.get("end", n.get("offset", start + 0.1)))
                pitch = int(round(n.get("pitch", n.get("midi_pitch", 60))))
                vel = int(n.get("velocity", 90)) if n.get("velocity") else 90
                inst.notes.append(pretty_midi.Note(velocity=vel, pitch=pitch, start=start, end=end))
        return inst, {"notes_count": len(inst.notes)}
    except Exception as e:
        # Normalize any library failure into a RuntimeError for callers.
        raise RuntimeError(f"basic_pitch prediction failed: {e}")
# Merge stems transcriptions into a single PrettyMIDI object
def merge_stems_to_midi(stem_paths, use_basic_pitch=True):
    """Transcribe each stem and merge all tracks into one multi-track MIDI.

    stem_paths: dict {stem_name: path_wav}
    For each stem:
    - If basic-pitch is available (and requested): polyphonic transcription.
    - Else fall back to monophonic librosa pYIN per stem.
    Per-stem failures are recorded in the summary and do not abort the run.

    Returns (path_to_midi, summary_dict).
    """
    pm = pretty_midi.PrettyMIDI()
    summary = {"stems": {}, "engine": "mixed"}
    for stem_name, path in stem_paths.items():
        lname = stem_name.lower()  # hoisted: used by every heuristic below
        try:
            if use_basic_pitch and HAS_BASIC_PITCH:
                inst, info = basic_pitch_transcribe(path)
                if lname.startswith("drum"):
                    # Drums get a dedicated drum instrument (is_drum=True);
                    # notes are copied over as drum-map hits.
                    drum_inst = pretty_midi.Instrument(program=0, is_drum=True)
                    for n in inst.notes:
                        drum_inst.notes.append(pretty_midi.Note(velocity=n.velocity, pitch=n.pitch, start=n.start, end=n.end))
                    pm.instruments.append(drum_inst)
                else:
                    # Simple GM-program heuristics keyed on the stem name.
                    program = 0
                    if "bass" in lname:
                        program = 32 # acoustic bass
                    elif "voc" in lname:  # matches "vocals"/"vocal"/"voc"
                        program = 54 # synth lead (as example)
                    inst.program = int(program)
                    pm.instruments.append(inst)
                summary["stems"][stem_name] = {"notes": info.get("notes_count", 0), "engine":"basic_pitch"}
            else:
                # Fallback: monophonic pYIN per stem, then merge its tracks.
                out, info = librosa_mono_pitch_to_midi(path)
                midi = pretty_midi.PrettyMIDI(out)
                for inst in midi.instruments:
                    if "drum" in lname:
                        inst.is_drum = True
                    if "bass" in lname:
                        inst.program = 32
                    pm.instruments.append(inst)
                summary["stems"][stem_name] = {"notes": info.get("notes", 0), "engine": "librosa_fallback"}
        except Exception as e:
            # Record the failure but keep processing the remaining stems.
            summary["stems"][stem_name] = {"error": str(e)}
    tmpdir = tempfile.mkdtemp()
    out_midi = os.path.join(tmpdir, "separated_multi.mid")
    pm.write(out_midi)
    summary["instruments"] = len(pm.instruments)
    summary["notes_total"] = sum(len(inst.notes) for inst in pm.instruments)
    return out_midi, summary
# High-level pipeline: separate -> transcribe each stem -> merge
def _transcribe_full_mix(audio_filepath, use_basic_pitch=True):
    """Single-track fallback: Basic-Pitch on the whole mix when available,
    otherwise the monophonic librosa pYIN transcriber."""
    if use_basic_pitch and HAS_BASIC_PITCH:
        try:
            inst, info = basic_pitch_transcribe(audio_filepath)
            pm = pretty_midi.PrettyMIDI()
            inst.program = 0
            pm.instruments.append(inst)
            tmpdir = tempfile.mkdtemp()
            out = os.path.join(tmpdir, "basicpitch_full.mid")
            pm.write(out)
            return out, {"engine":"basic_pitch_full","notes":info.get("notes_count",0)}
        except Exception as e:
            print("basic-pitch on full mix failed:", e)
    # final fallback
    return librosa_mono_pitch_to_midi(audio_filepath)

def full_pipeline(audio_filepath, demucs_model=DEMucs_MODEL_NAME, use_basic_pitch=True):
    """Separate with Demucs, transcribe each stem, merge into one MIDI.

    Returns (midi_path, summary_dict). Falls back to whole-mix transcription
    (_transcribe_full_mix) when Demucs is unavailable or its pipeline fails —
    both failure paths now share the same fallback, so Basic-Pitch is still
    preferred over pYIN when present.
    """
    if not HAS_DEMUCS:
        return _transcribe_full_mix(audio_filepath, use_basic_pitch)
    try:
        # 1) Demucs separation
        stems_dir = demucs_separate_cli(audio_filepath, model_name=demucs_model)
        # Collect stems found directly in the returned directory.
        available = {}
        for name in os.listdir(stems_dir):
            if name.endswith(".wav"):
                available[os.path.splitext(name)[0]] = os.path.join(stems_dir, name)
        if not available:
            # Some Demucs layouts nest deeper; search recursively.
            for root, _dirs, files in os.walk(stems_dir):
                for f in files:
                    if f.endswith(".wav"):
                        available[os.path.splitext(f)[0]] = os.path.join(root, f)
        if not available:
            raise RuntimeError("No stems found after Demucs separation.")
        # 2) Transcribe each stem and merge.
        midi_path, summary = merge_stems_to_midi(available, use_basic_pitch=use_basic_pitch)
        return midi_path, {"demucs_model":demucs_model, **summary}
    except Exception as e:
        traceback.print_exc()
        print("Demucs pipeline failed, falling back to full-mix transcription:", e)
        return _transcribe_full_mix(audio_filepath, use_basic_pitch)
# ---------- Gradio UI ----------
# Minimal CSS overrides for the Gradio UI: a large bold title and a slightly
# dimmed subtitle (element ids are assigned in the Markdown header).
CSS = """
#app_title {font-size: 26px; font-weight: 800}
#app_subtitle {opacity: .8}
"""
# ---------- Gradio UI layout and wiring ----------
with gr.Blocks(css=CSS, title="Demucs + BasicPitch -> Multi-MIDI") as demo:
    gr.Markdown("<div id='app_title'>🔊 Separate & Transcribe → Multi-track MIDI</div>"
                "<div id='app_subtitle'>Demucs (stems) + Basic-Pitch (polyphonic) pipeline. Fallbacks included.</div>")
    with gr.Row():
        with gr.Column(scale=2):
            audio_in = gr.Audio(sources=["upload"], type="filepath", label="Audio (mix) - WAV/MP3")
            demucs_model = gr.Dropdown(["htdemucs_ft","htdemucs","htdemucs_6s","mdx","mdx_extra"], value=DEMucs_MODEL_NAME, label="Demucs model")
            use_basic = gr.Checkbox(value=True, label="Use Basic-Pitch for stems (if available)")
            run_btn = gr.Button("🚀 Run pipeline")
        with gr.Column(scale=1):
            midi_out = gr.File(label="MIDI output")
            log_out = gr.Textbox(label="Summary / Log", lines=12)

    def run_pipeline(audio_path, demucs_model_name, use_basic_bool):
        """Gradio callback: run the pipeline and surface errors in the log box."""
        try:
            midi_path, summary = full_pipeline(audio_path, demucs_model=demucs_model_name, use_basic_pitch=use_basic_bool)
            return midi_path, str(summary)
        except Exception as e:
            tb = traceback.format_exc()
            # Fixed: the escapes were doubled ("\\n"), so the textbox showed
            # literal backslash-n sequences instead of line breaks.
            return None, f"Error: {e}\n\nTrace:\n{tb}"

    run_btn.click(run_pipeline, inputs=[audio_in, demucs_model, use_basic], outputs=[midi_out, log_out])
# Launch the Gradio server only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()