# vtm / app.py - uploaded by fomext ("Upload 4 files", commit 24e6a09, verified), 20.1 kB
"""
Voice-to-MIDI FastAPI Service
Powered by Spotify Basic Pitch
Optimized for Hugging Face Spaces CPU deployment
"""
import os
import uuid
import time
import asyncio
import logging
import tempfile
import threading
from pathlib import Path
from typing import Optional
from contextlib import asynccontextmanager
from collections import defaultdict
import numpy as np
import mido
import soundfile as sf
import librosa
import uvicorn
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
# ─────────────────────────────────────────────
# Logging
# ─────────────────────────────────────────────
# Configure root logging once at import time: INFO level, with timestamp,
# level and logger name on every record.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
# Module-wide logger used by the helpers and routes in this file.
log = logging.getLogger("voice-to-midi")
# ─────────────────────────────────────────────
# Constants
# ─────────────────────────────────────────────
# MIME types considered normal for uploads. The /transcribe route only logs a
# warning when the content-type is outside this set; it does not reject.
SUPPORTED_TYPES = {
    "audio/wav", "audio/x-wav", "audio/mpeg", "audio/mp3",
    "audio/ogg", "audio/x-m4a", "audio/mp4", "audio/aac",
    "audio/flac", "application/octet-stream",
}
# File suffixes accepted by /transcribe (compared after lower-casing).
SUPPORTED_EXTENSIONS = {".wav", ".mp3", ".ogg", ".m4a", ".flac", ".aac"}
# Directory under the system temp dir where generated .mid files are stored;
# created at import time if missing.
OUTPUT_DIR = Path(tempfile.gettempdir()) / "voice_midi_outputs"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
MAX_FILE_AGE_SECONDS = 1800  # 30 min – auto-cleanup
# Upload limit; compared against MAX_UPLOAD_MB * 1_000_000 bytes (decimal MB).
MAX_UPLOAD_MB = 50
MIDI_NOTE_MIN = 21  # A0
MIDI_NOTE_MAX = 108  # C8
# Sample rate (Hz) that uploaded audio is resampled to on load.
DEFAULT_SAMPLE_RATE = 22050
# ─────────────────────────────────────────────
# Global model state (loaded once at startup)
# ─────────────────────────────────────────────
# Serializes the one-time Basic Pitch import across threads.
_model_lock = threading.Lock()
# Flipped to True after the Basic Pitch inference module has been imported.
_model_loaded = False


def _ensure_model_loaded() -> None:
    """Import Basic Pitch's inference module at most once per process.

    The flag is tested both before and after acquiring the lock, so later
    calls return without locking and concurrent first calls perform a single
    import.
    """
    global _model_loaded
    if not _model_loaded:
        with _model_lock:
            # Re-test under the lock: another thread may have finished the
            # import while this one was waiting.
            if not _model_loaded:
                log.info("Loading Basic Pitch model – this happens once at cold start …")
                # The import is the warm-up step; presumably it pulls in the
                # TF/tflite backend (not verifiable from this file).
                from basic_pitch.inference import predict  # noqa: F401 (warm up)
                _model_loaded = True
                log.info("Basic Pitch model ready.")
# ─────────────────────────────────────────────
# App lifespan (startup / shutdown)
# ─────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup and shutdown work around the application's lifetime.

    Startup: warm the model in the default executor (so the event loop is not
    blocked) and start the periodic cleanup task.
    Shutdown: cancel the cleanup task and wait for it to finish, even when the
    application body raised.
    """
    # get_running_loop() is the supported way to obtain the loop from inside
    # a coroutine; get_event_loop() is deprecated for this use.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, _ensure_model_loaded)
    # Schedule background cleanup
    cleanup_task = asyncio.create_task(_periodic_cleanup())
    try:
        yield
    finally:
        cleanup_task.cancel()
        # Await the task so its cancellation is actually processed before the
        # loop shuts down; return_exceptions=True absorbs the task's own
        # CancelledError (or any error it died with) instead of re-raising.
        await asyncio.gather(cleanup_task, return_exceptions=True)
# The ASGI application object; `lifespan` supplies the startup/shutdown hooks.
app = FastAPI(
    title="Voice-to-MIDI",
    description="Convert voice / humming / whistling to MIDI via Spotify Basic Pitch",
    version="1.0.0",
    lifespan=lifespan,
)
# CORS: every origin, method and header is allowed.
# NOTE(review): fully open CORS — confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ─────────────────────────────────────────────
# Background cleanup
# ─────────────────────────────────────────────
async def _periodic_cleanup():
    """Delete MIDI files older than MAX_FILE_AGE_SECONDS every 10 minutes.

    Runs until cancelled. A failure in one sweep is logged and the loop keeps
    going, so a transient filesystem error cannot stop cleanup for the rest of
    the process lifetime.
    """
    while True:
        await asyncio.sleep(600)
        try:
            _cleanup_old_files()
        except Exception:
            # CancelledError derives from BaseException, so cancellation still
            # propagates; only genuine sweep errors are absorbed here.
            log.exception("Periodic MIDI cleanup failed; retrying on the next cycle.")
def _cleanup_old_files():
    """Remove every ``*.mid`` file in OUTPUT_DIR older than MAX_FILE_AGE_SECONDS.

    Best effort: a file that cannot be inspected or removed (for example
    because it disappeared between listing and stat) is skipped and the sweep
    continues with the next file.
    """
    now = time.time()
    removed = 0
    for f in OUTPUT_DIR.glob("*.mid"):
        try:
            # stat() is inside the try: the file may vanish after glob()
            # listed it, which raises FileNotFoundError (an OSError).
            if now - f.stat().st_mtime > MAX_FILE_AGE_SECONDS:
                f.unlink()
                removed += 1
        except OSError:
            continue
    if removed:
        log.info(f"Cleaned up {removed} old MIDI file(s).")
# ─────────────────────────────────────────────
# Utility helpers
# ─────────────────────────────────────────────
# Pitch-class names indexed by semitone above C (sharps only).
NOTE_NAMES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]


def midi_to_note_name(midi_pitch: int) -> str:
    """Convert a MIDI pitch integer to scientific pitch notation, e.g. 60 -> C4."""
    # divmod yields the 12-semitone block and the pitch class in one step;
    # MIDI octave numbering places block 0 at octave -1.
    block, pitch_class = divmod(midi_pitch, 12)
    return f"{NOTE_NAMES[pitch_class]}{block - 1}"
def load_audio(path: Path) -> tuple[np.ndarray, int]:
    """Decode the audio file at ``path`` with librosa.

    The file is loaded as mono and resampled to DEFAULT_SAMPLE_RATE.

    Returns:
        The ``(samples, sample_rate)`` pair produced by ``librosa.load``.

    Raises:
        HTTPException: status 422 when librosa fails to decode the file.
    """
    try:
        samples, rate = librosa.load(str(path), sr=DEFAULT_SAMPLE_RATE, mono=True)
    except Exception as exc:
        raise HTTPException(status_code=422, detail=f"Cannot decode audio: {exc}")
    return samples, rate
def estimate_bpm(audio: np.ndarray, sr: int) -> Optional[float]:
    """Estimate the tempo of ``audio`` with librosa's beat tracker.

    Returns:
        The tempo rounded to one decimal when it lies strictly between 30 and
        300 BPM; ``None`` when it is outside that window or tracking fails.
    """
    try:
        tempo, _ = librosa.beat.beat_track(y=audio, sr=sr)
        # beat_track may hand back a scalar or an array; take the first value.
        bpm = float(np.atleast_1d(tempo)[0])
    except Exception:
        return None
    if 30 < bpm < 300:
        return round(bpm, 1)
    return None
# ─────────────────────────────────────────────
# MIDI quality post-processing
# ─────────────────────────────────────────────
def clamp_pitch(pitch: int) -> int:
    """Limit ``pitch`` to the inclusive range MIDI_NOTE_MIN..MIDI_NOTE_MAX."""
    if pitch < MIDI_NOTE_MIN:
        return MIDI_NOTE_MIN
    if pitch > MIDI_NOTE_MAX:
        return MIDI_NOTE_MAX
    return pitch
def remove_duplicate_overlaps(notes: list[dict]) -> list[dict]:
    """Collapse overlapping notes that share the same pitch.

    Within one pitch, a note that starts before the previously kept note ends
    is either dropped (fully contained) or absorbed: the kept note's ``end``
    is extended and its ``velocity`` becomes the larger of the two.

    The input list and its dicts are left untouched; kept notes are shallow
    copies.

    Args:
        notes: Note dicts with at least ``pitch``, ``start``, ``end`` and
            ``velocity`` keys.

    Returns:
        A new list of merged notes ordered by ``start`` (ties keep ascending
        pitch order).
    """
    merged: list[dict] = []
    # Sorting by (pitch, start) makes same-pitch notes adjacent and
    # time-ordered, so one pass comparing against the last kept note suffices.
    for note in sorted(notes, key=lambda n: (n["pitch"], n["start"])):
        prev = merged[-1] if merged else None
        overlaps_prev = (
            prev is not None
            and prev["pitch"] == note["pitch"]
            and note["start"] < prev["end"]
        )
        if not overlaps_prev:
            # Copy so later extensions never leak into the caller's dicts.
            merged.append(dict(note))
            continue
        # Overlap: extend the kept note if this one reaches further;
        # otherwise it is fully contained and simply skipped.
        if note["end"] > prev["end"]:
            prev["end"] = note["end"]
            prev["velocity"] = max(prev["velocity"], note["velocity"])
    # sorted() is stable, so equal start times stay in ascending pitch order.
    return sorted(merged, key=lambda n: n["start"])
def merge_tiny_notes(notes: list[dict], min_ms: float = 40) -> list[dict]:
    """Return only the notes lasting at least ``min_ms`` milliseconds.

    Despite the name nothing is merged: shorter notes are dropped. The input
    order is preserved and the note dicts are not copied.
    """
    threshold_sec = min_ms / 1000.0
    return [note for note in notes if (note["end"] - note["start"]) >= threshold_sec]
def quantize_notes(notes: list[dict], bpm: float, subdivisions: int = 16) -> list[dict]:
    """Snap each note's ``start`` and ``end`` to the nearest grid line.

    The grid spacing is one beat (``60 / bpm`` seconds) divided by
    ``subdivisions / 4`` (16 gives sixteenth notes). Notes are modified in
    place and the same list is returned. A note whose snapped end does not
    exceed its snapped start is given a length of one grid step.

    A non-positive ``bpm`` returns ``notes`` unchanged.
    """
    if bpm <= 0:
        return notes
    seconds_per_beat = 60.0 / bpm
    step = seconds_per_beat / (subdivisions / 4)  # e.g. 16th note
    for note in notes:
        snapped_start = round(round(note["start"] / step) * step, 4)
        note["start"] = snapped_start
        snapped_end = round(round(note["end"] / step) * step, 4)
        # Never let quantization collapse a note to zero or negative length.
        note["end"] = snapped_end if snapped_end > snapped_start else snapped_start + step
    return notes
def notes_to_midi(
    notes: list[dict],
    bpm: float,
    instrument_name: str = "Acoustic Grand Piano",
) -> mido.MidiFile:
    """Build a single-track (type 0) MIDI file from note dicts.

    Args:
        notes: Dicts with ``start``/``end`` (seconds), ``pitch`` and
            ``velocity``.
        bpm: Tempo written to the track and used to convert seconds to ticks.
        instrument_name: One of the names in the program table below; any
            other name falls back to program 0.

    Returns:
        A ``mido.MidiFile`` with 480 ticks per beat.
    """
    midi_file = mido.MidiFile(type=0, ticks_per_beat=480)
    track = mido.MidiTrack()
    midi_file.tracks.append(track)

    # Tempo meta event
    usec_per_beat = mido.bpm2tempo(bpm)
    track.append(mido.MetaMessage("set_tempo", tempo=usec_per_beat, time=0))

    # Program change (General MIDI instrument)
    gm_programs = {
        "Acoustic Grand Piano": 0, "Electric Piano": 4,
        "Violin": 40, "Flute": 73, "Synth Lead": 80,
    }
    track.append(
        mido.Message("program_change", program=gm_programs.get(instrument_name, 0), time=0)
    )

    # Flat timeline of (abs_time_sec, type, pitch, velocity), two per note.
    timeline = [
        event
        for n in notes
        for event in (
            (n["start"], "note_on", n["pitch"], n["velocity"]),
            (n["end"], "note_off", n["pitch"], 0),
        )
    ]
    # Order by time; at equal times note_off (False) sorts before note_on (True).
    timeline.sort(key=lambda event: (event[0], event[1] != "note_off"))

    # MIDI messages carry delta times, so track the previous absolute tick.
    last_tick = 0
    for when_sec, kind, pitch, velocity in timeline:
        tick = int(mido.second2tick(when_sec, midi_file.ticks_per_beat, usec_per_beat))
        track.append(
            mido.Message(kind, note=pitch, velocity=velocity, time=max(0, tick - last_tick))
        )
        last_tick = tick
    return midi_file
def group_chords(notes: list[dict], window_sec: float = 0.05) -> list[list[dict]]:
    """Cluster notes whose starts fall within ``window_sec`` of a group's first note.

    Notes are visited in order of ``start``. Each note joins the current group
    when its start is within ``window_sec`` of that group's first note,
    otherwise it opens a new group. Empty input yields an empty list.
    """
    clusters: list[list[dict]] = []
    for current in sorted(notes, key=lambda n: n["start"]):
        if clusters and abs(current["start"] - clusters[-1][0]["start"]) <= window_sec:
            clusters[-1].append(current)
        else:
            clusters.append([current])
    return clusters
# ─────────────────────────────────────────────
# Core transcription logic
# ─────────────────────────────────────────────
def transcribe_audio(
    audio: np.ndarray,
    sr: int,
    bpm_hint: Optional[float],
    quantize: bool,
    min_note_length_ms: float,
    onset_sensitivity: float,
) -> tuple[list[dict], float, Optional[float]]:
    """Run Basic Pitch inference on ``audio`` and post-process the notes.

    Args:
        audio: Mono samples.
        sr: Sample rate of ``audio`` in Hz.
        bpm_hint: Caller-supplied tempo; when falsy the tempo is estimated.
        quantize: Snap notes to a sixteenth-note grid when a tempo is known.
        min_note_length_ms: Minimum note duration in milliseconds.
        onset_sensitivity: Passed to Basic Pitch as ``onset_threshold``.

    Returns:
        A tuple of
        - list of note dicts (pitch, note_name, start, end, velocity,
          confidence) sorted by start time
        - audio duration (seconds)
        - the tempo used (hint or estimate), or None when neither is available
    """
    from basic_pitch.inference import predict

    duration = len(audio) / sr

    # Write temp wav for basic_pitch (it expects a file path). The handle is
    # closed right away so the file can be reopened by name.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        sf.write(tmp_path, audio, sr)
        _model_output, _midi_data, note_events = predict(
            tmp_path,
            onset_threshold=onset_sensitivity,
            frame_threshold=0.3,
            # Basic Pitch documents this parameter in milliseconds; passing
            # seconds (ms / 1000) effectively disabled its length filter.
            minimum_note_length=min_note_length_ms,
            minimum_frequency=librosa.midi_to_hz(MIDI_NOTE_MIN),
            maximum_frequency=librosa.midi_to_hz(MIDI_NOTE_MAX),
            melodia_trick=True,
            midi_tempo=bpm_hint or 120,
        )
    finally:
        os.unlink(tmp_path)

    # note_events: list of (start_time, end_time, pitch_midi, amplitude, pitch_bends)
    notes = []
    for start, end, pitch, amplitude, _ in note_events:
        pitch = clamp_pitch(int(round(pitch)))
        # Scale amplitude to a MIDI velocity, never below 1 or above 127.
        velocity = int(np.clip(amplitude * 127, 1, 127))
        notes.append({
            "pitch": pitch,
            "note_name": midi_to_note_name(pitch),
            "start": round(float(start), 4),
            "end": round(float(end), 4),
            "velocity": velocity,
            "confidence": round(float(amplitude), 4),
        })

    # Post-process: drop very short notes, then collapse same-pitch overlaps.
    notes = merge_tiny_notes(notes, min_ms=min_note_length_ms)
    notes = remove_duplicate_overlaps(notes)

    # BPM: the hint wins; otherwise fall back to the estimate (may be None).
    detected_bpm = bpm_hint or estimate_bpm(audio, sr)
    if quantize and detected_bpm:
        notes = quantize_notes(notes, detected_bpm)

    notes = sorted(notes, key=lambda n: n["start"])
    return notes, duration, detected_bpm
# ─────────────────────────────────────────────
# Routes
# ─────────────────────────────────────────────
@app.get("/", summary="Health check")
async def root():
    """Report service identity and whether the model import has completed."""
    status_report = dict(
        status="online",
        service="voice-to-midi",
        engine="basic-pitch",
        model_ready=_model_loaded,
    )
    return status_report
@app.get("/health")
async def health():
    """Minimal liveness endpoint; always answers with a fixed OK payload."""
    return dict(status="ok")
@app.post("/transcribe", summary="Convert audio to MIDI")
async def transcribe(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file (wav/mp3/ogg/m4a/flac)"),
    bpm: Optional[float] = Form(None, description="Hint BPM (auto-detected if omitted)"),
    quantize: bool = Form(False, description="Snap notes to nearest beat grid"),
    instrument_name: str = Form("Acoustic Grand Piano"),
    min_note_length_ms: float = Form(40.0, description="Minimum note duration in ms"),
    onset_sensitivity: float = Form(0.5, description="Onset detection threshold 0–1"),
):
    """Transcribe an uploaded audio file to notes and a downloadable MIDI file.

    Steps: validate the upload and parameters, decode the audio, run Basic
    Pitch in the default executor, write the MIDI file to OUTPUT_DIR and
    return the notes plus metadata as JSON.

    Raises:
        HTTPException: 400 (missing filename / tiny file), 413 (too large),
            415 (unsupported extension), 422 (bad bpm hint, undecodable or too
            short audio, no notes detected), 500 (inference failure).
    """
    # ── Validate file ──────────────────────────────────────────────────────
    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided.")
    suffix = Path(file.filename).suffix.lower()
    if suffix not in SUPPORTED_EXTENSIONS:
        raise HTTPException(
            status_code=415,
            detail=f"Unsupported file type '{suffix}'. Supported: {sorted(SUPPORTED_EXTENSIONS)}",
        )
    content_type = (file.content_type or "").split(";")[0].strip().lower()
    if content_type and content_type not in SUPPORTED_TYPES:
        log.warning(f"Unexpected content-type: {content_type} – proceeding anyway.")

    # ── Validate parameters ────────────────────────────────────────────────
    onset_sensitivity = float(np.clip(onset_sensitivity, 0.1, 0.9))
    min_note_length_ms = max(10.0, min(500.0, min_note_length_ms))
    if not bpm:
        # A hint of 0 has always meant "no hint": fall back to detection.
        bpm = None
    else:
        # Reject hints that would otherwise surface as a 500 while building
        # the MIDI: non-finite or negative values, and tempos whose
        # microseconds-per-beat do not fit the 24-bit MIDI set_tempo field.
        bpm_ok = bool(np.isfinite(bpm)) and bpm > 0
        if bpm_ok:
            bpm_ok = 1 <= mido.bpm2tempo(bpm) <= 0xFFFFFF
        if not bpm_ok:
            raise HTTPException(
                status_code=422,
                detail="Invalid bpm hint: must be a positive, finite tempo representable in MIDI.",
            )

    # ── Read & size-check upload ───────────────────────────────────────────
    # Read at most one byte past the limit: enough to detect an oversized
    # upload without copying an arbitrarily large body into this handler.
    max_bytes = MAX_UPLOAD_MB * 1_000_000
    raw = await file.read(max_bytes + 1)
    if len(raw) > max_bytes:
        raise HTTPException(status_code=413, detail=f"File exceeds {MAX_UPLOAD_MB} MB limit.")
    if len(raw) < 1024:
        raise HTTPException(status_code=400, detail="File too small – likely empty or corrupt.")

    # ── Save to temp, load audio ───────────────────────────────────────────
    loop = asyncio.get_running_loop()
    tmp_input_path: Optional[Path] = None
    try:
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp_input_path = Path(tmp.name)
            tmp.write(raw)
        # Decoding/resampling is CPU-bound; keep it off the event loop.
        audio, sr = await loop.run_in_executor(None, load_audio, tmp_input_path)
    finally:
        # Also covers a failed write, so no temp file is left behind.
        if tmp_input_path is not None:
            tmp_input_path.unlink(missing_ok=True)
    if len(audio) / sr < 0.2:
        raise HTTPException(status_code=422, detail="Audio too short (< 0.2 s).")

    # ── Ensure model loaded ────────────────────────────────────────────────
    # Normally a no-op after startup; run in the executor so a cold load
    # cannot stall other requests.
    await loop.run_in_executor(None, _ensure_model_loaded)

    # ── Run inference ──────────────────────────────────────────────────────
    t0 = time.perf_counter()
    try:
        notes, duration, detected_bpm = await loop.run_in_executor(
            None,
            transcribe_audio,
            audio, sr, bpm, quantize, min_note_length_ms, onset_sensitivity,
        )
    except HTTPException:
        raise
    except Exception as exc:
        log.exception("Inference failed")
        raise HTTPException(status_code=500, detail=f"Inference error: {exc}")
    elapsed = time.perf_counter() - t0
    log.info(f"Transcribed {duration:.1f}s audio → {len(notes)} notes in {elapsed:.2f}s")
    if not notes:
        raise HTTPException(
            status_code=422,
            detail="No notes detected. Try adjusting onset_sensitivity or check audio quality.",
        )

    # ── Build MIDI ─────────────────────────────────────────────────────────
    effective_bpm = detected_bpm or 120.0
    midi_obj = notes_to_midi(notes, effective_bpm, instrument_name)
    midi_id = uuid.uuid4().hex
    midi_path = OUTPUT_DIR / f"{midi_id}.mid"
    midi_obj.save(str(midi_path))

    # ── Chord grouping meta ────────────────────────────────────────────────
    chords = group_chords(notes)
    chord_count = sum(1 for g in chords if len(g) > 1)

    return JSONResponse({
        "success": True,
        "notes": notes,
        "midi_url": f"/download/{midi_id}.mid",
        "note_count": len(notes),
        "duration": round(duration, 3),
        "bpm": effective_bpm,
        "bpm_source": "hint" if bpm else ("detected" if detected_bpm else "default"),
        "chord_count": chord_count,
        "processing_time_sec": round(elapsed, 3),
        "quantized": quantize,
        "instrument": instrument_name,
    })
@app.get("/download/{filename}", summary="Download generated MIDI file")
async def download(filename: str):
    """Serve a previously generated MIDI file from OUTPUT_DIR as an attachment.

    Raises:
        HTTPException: 400 for names containing path separators or ``..`` or
            not ending in ``.mid``; 404 when no such regular file exists.
    """
    # Security: disallow path traversal
    if "/" in filename or "\\" in filename or ".." in filename:
        raise HTTPException(status_code=400, detail="Invalid filename.")
    if not filename.endswith(".mid"):
        raise HTTPException(status_code=400, detail="Only .mid files available.")
    path = OUTPUT_DIR / filename
    # is_file() rather than exists(): a directory must not be handed to
    # FileResponse, which can only stream regular files.
    if not path.is_file():
        raise HTTPException(status_code=404, detail="File not found or expired.")
    # Passing `filename` makes FileResponse emit the attachment
    # Content-Disposition header itself (with proper quoting/encoding).
    return FileResponse(
        path=str(path),
        media_type="audio/midi",
        filename=filename,
    )
@app.get("/info", summary="Service info & supported formats")
async def info():
    """Describe the service limits, accepted formats and available endpoints."""
    return {
        "supported_formats": sorted(SUPPORTED_EXTENSIONS),
        "max_upload_mb": MAX_UPLOAD_MB,
        "midi_note_range": {"min": MIDI_NOTE_MIN, "max": MIDI_NOTE_MAX},
        "default_sample_rate": DEFAULT_SAMPLE_RATE,
        "file_expiry_minutes": MAX_FILE_AGE_SECONDS // 60,
        # The original three entries keep their positions; the routes that
        # were missing from the listing are appended after them.
        "endpoints": [
            {"path": "/", "method": "GET", "description": "Health check"},
            {"path": "/transcribe", "method": "POST", "description": "Audio → MIDI"},
            {"path": "/download/{filename}", "method": "GET", "description": "MIDI download"},
            {"path": "/health", "method": "GET", "description": "Liveness probe"},
            {"path": "/info", "method": "GET", "description": "Service info & supported formats"},
        ],
    }
# ─────────────────────────────────────────────
# Entry point (local dev)
# ─────────────────────────────────────────────
if __name__ == "__main__":
    # Local development entry point: serve this module's `app` with uvicorn.
    _server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "workers": 1,  # Single worker on HF Spaces CPU; model is global
        "log_level": "info",
    }
    uvicorn.run("app:app", **_server_options)