""" Voice-to-MIDI FastAPI Service Powered by Spotify Basic Pitch Optimized for Hugging Face Spaces CPU deployment """ import os import uuid import time import asyncio import logging import tempfile import threading from pathlib import Path from typing import Optional from contextlib import asynccontextmanager from collections import defaultdict import numpy as np import mido import soundfile as sf import librosa import uvicorn from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks from fastapi.responses import FileResponse, JSONResponse from fastapi.middleware.cors import CORSMiddleware # ───────────────────────────────────────────── # Logging # ───────────────────────────────────────────── logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", ) log = logging.getLogger("voice-to-midi") # ───────────────────────────────────────────── # Constants # ───────────────────────────────────────────── SUPPORTED_TYPES = { "audio/wav", "audio/x-wav", "audio/mpeg", "audio/mp3", "audio/ogg", "audio/x-m4a", "audio/mp4", "audio/aac", "audio/flac", "application/octet-stream", } SUPPORTED_EXTENSIONS = {".wav", ".mp3", ".ogg", ".m4a", ".flac", ".aac"} OUTPUT_DIR = Path(tempfile.gettempdir()) / "voice_midi_outputs" OUTPUT_DIR.mkdir(parents=True, exist_ok=True) MAX_FILE_AGE_SECONDS = 1800 # 30 min – auto-cleanup MAX_UPLOAD_MB = 50 MIDI_NOTE_MIN = 21 # A0 MIDI_NOTE_MAX = 108 # C8 DEFAULT_SAMPLE_RATE = 22050 # ───────────────────────────────────────────── # Global model state (loaded once at startup) # ───────────────────────────────────────────── _model_lock = threading.Lock() _model_loaded = False def _ensure_model_loaded() -> None: """Import basic_pitch and warm up TensorFlow graph on first call.""" global _model_loaded if _model_loaded: return with _model_lock: if _model_loaded: return log.info("Loading Basic Pitch model – this happens once at cold start …") # Importing triggers TF/tflite graph compilation 
from basic_pitch.inference import predict # noqa: F401 (warm up) _model_loaded = True log.info("Basic Pitch model ready.") # ───────────────────────────────────────────── # App lifespan (startup / shutdown) # ───────────────────────────────────────────── @asynccontextmanager async def lifespan(app: FastAPI): # Warm the model in a thread so the event-loop isn't blocked loop = asyncio.get_event_loop() await loop.run_in_executor(None, _ensure_model_loaded) # Schedule background cleanup cleanup_task = asyncio.create_task(_periodic_cleanup()) yield cleanup_task.cancel() app = FastAPI( title="Voice-to-MIDI", description="Convert voice / humming / whistling to MIDI via Spotify Basic Pitch", version="1.0.0", lifespan=lifespan, ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) # ───────────────────────────────────────────── # Background cleanup # ───────────────────────────────────────────── async def _periodic_cleanup(): """Delete MIDI files older than MAX_FILE_AGE_SECONDS every 10 minutes.""" while True: await asyncio.sleep(600) _cleanup_old_files() def _cleanup_old_files(): now = time.time() removed = 0 for f in OUTPUT_DIR.glob("*.mid"): if now - f.stat().st_mtime > MAX_FILE_AGE_SECONDS: try: f.unlink() removed += 1 except OSError: pass if removed: log.info(f"Cleaned up {removed} old MIDI file(s).") # ───────────────────────────────────────────── # Utility helpers # ───────────────────────────────────────────── NOTE_NAMES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"] def midi_to_note_name(midi_pitch: int) -> str: """Convert MIDI pitch integer to scientific notation, e.g. 
60 → C4.""" octave = (midi_pitch // 12) - 1 name = NOTE_NAMES[midi_pitch % 12] return f"{name}{octave}" def load_audio(path: Path) -> tuple[np.ndarray, int]: """Load any audio file to mono float32, resampled to DEFAULT_SAMPLE_RATE.""" try: audio, sr = librosa.load(str(path), sr=DEFAULT_SAMPLE_RATE, mono=True) return audio, sr except Exception as exc: raise HTTPException(status_code=422, detail=f"Cannot decode audio: {exc}") def estimate_bpm(audio: np.ndarray, sr: int) -> Optional[float]: """Use librosa onset/beat tracking to estimate tempo.""" try: tempo, _ = librosa.beat.beat_track(y=audio, sr=sr) val = float(np.atleast_1d(tempo)[0]) return round(val, 1) if 30 < val < 300 else None except Exception: return None # ───────────────────────────────────────────── # MIDI quality post-processing # ───────────────────────────────────────────── def clamp_pitch(pitch: int) -> int: return max(MIDI_NOTE_MIN, min(MIDI_NOTE_MAX, pitch)) def remove_duplicate_overlaps(notes: list[dict]) -> list[dict]: """ Remove notes that are exact duplicates or where one note completely contains another at the same pitch. 
""" # Sort by pitch then start time notes = sorted(notes, key=lambda n: (n["pitch"], n["start"])) result = [] by_pitch: dict[int, list[dict]] = defaultdict(list) for n in notes: by_pitch[n["pitch"]].append(n) for pitch, group in by_pitch.items(): group = sorted(group, key=lambda n: n["start"]) merged = [group[0]] for note in group[1:]: prev = merged[-1] # If new note starts before previous ends – merge or skip if note["start"] < prev["end"]: # Extend previous note if new one reaches further if note["end"] > prev["end"]: prev["end"] = note["end"] prev["velocity"] = max(prev["velocity"], note["velocity"]) # else: new note is fully contained – skip it else: merged.append(note) result.extend(merged) return sorted(result, key=lambda n: n["start"]) def merge_tiny_notes(notes: list[dict], min_ms: float = 40) -> list[dict]: """Drop notes shorter than min_ms milliseconds.""" min_sec = min_ms / 1000.0 kept = [] for n in notes: if (n["end"] - n["start"]) >= min_sec: kept.append(n) return kept def quantize_notes(notes: list[dict], bpm: float, subdivisions: int = 16) -> list[dict]: """Snap note start/end times to the nearest subdivision grid.""" if bpm <= 0: return notes beat_sec = 60.0 / bpm grid = beat_sec / (subdivisions / 4) # e.g. 
16th note def snap(t: float) -> float: return round(round(t / grid) * grid, 4) for n in notes: n["start"] = snap(n["start"]) n["end"] = snap(n["end"]) if n["end"] <= n["start"]: n["end"] = n["start"] + grid return notes def notes_to_midi( notes: list[dict], bpm: float, instrument_name: str = "Acoustic Grand Piano", ) -> mido.MidiFile: """Convert list of note dicts to a mido MidiFile object.""" mid = mido.MidiFile(type=0, ticks_per_beat=480) track = mido.MidiTrack() mid.tracks.append(track) # Tempo tempo = mido.bpm2tempo(bpm) track.append(mido.MetaMessage("set_tempo", tempo=tempo, time=0)) # Program change (General MIDI instrument) gm_programs = { "Acoustic Grand Piano": 0, "Electric Piano": 4, "Violin": 40, "Flute": 73, "Synth Lead": 80, } program = gm_programs.get(instrument_name, 0) track.append(mido.Message("program_change", program=program, time=0)) # Build flat event list: (abs_time_sec, type, pitch, velocity) events = [] for n in notes: events.append((n["start"], "note_on", n["pitch"], n["velocity"])) events.append((n["end"], "note_off", n["pitch"], 0)) events.sort(key=lambda e: (e[0], 0 if e[1] == "note_off" else 1)) def sec_to_ticks(t: float) -> int: return int(mido.second2tick(t, mid.ticks_per_beat, tempo)) prev_ticks = 0 for abs_sec, msg_type, pitch, vel in events: abs_ticks = sec_to_ticks(abs_sec) delta = max(0, abs_ticks - prev_ticks) track.append(mido.Message(msg_type, note=pitch, velocity=vel, time=delta)) prev_ticks = abs_ticks return mid def group_chords(notes: list[dict], window_sec: float = 0.05) -> list[list[dict]]: """Group notes that start within window_sec of each other into chords.""" if not notes: return [] sorted_notes = sorted(notes, key=lambda n: n["start"]) groups = [[sorted_notes[0]]] for note in sorted_notes[1:]: if abs(note["start"] - groups[-1][0]["start"]) <= window_sec: groups[-1].append(note) else: groups.append([note]) return groups # ───────────────────────────────────────────── # Core transcription logic # 
# ─────────────────────────────────────────────

# Accepted range for a client-supplied BPM hint. MIDI stores tempo as a
# 24-bit microseconds-per-beat value, so tempos below roughly 3.6 BPM cannot
# be represented; the upper bound keeps the tempo value well above zero.
BPM_HINT_MIN = 4.0
BPM_HINT_MAX = 1000.0

# Characters allowed in a downloadable file name (besides ASCII alphanumerics).
_SAFE_FILENAME_EXTRA = "-_."


def transcribe_audio(
    audio: np.ndarray,
    sr: int,
    bpm_hint: Optional[float],
    quantize: bool,
    min_note_length_ms: float,
    onset_sensitivity: float,
) -> tuple[list[dict], float, Optional[float]]:
    """
    Run Basic Pitch inference and return:
      - list of note dicts
      - audio duration (seconds)
      - detected bpm (or None)

    Args:
        audio: Mono samples.
        sr: Sample rate of ``audio``.
        bpm_hint: Client-supplied tempo; when falsy, tempo is estimated.
        quantize: Snap notes to a grid when a tempo is available.
        min_note_length_ms: Minimum note duration in milliseconds.
        onset_sensitivity: Passed to Basic Pitch as ``onset_threshold``.
    """
    from basic_pitch.inference import predict

    duration = len(audio) / sr

    # Write temp wav for basic_pitch (it expects a file path)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        sf.write(tmp_path, audio, sr)
        _model_output, _midi_data, note_events = predict(
            tmp_path,
            onset_threshold=onset_sensitivity,
            frame_threshold=0.3,
            # Basic Pitch documents minimum_note_length in milliseconds, so
            # the value is passed through without converting to seconds.
            minimum_note_length=min_note_length_ms,
            minimum_frequency=librosa.midi_to_hz(MIDI_NOTE_MIN),
            maximum_frequency=librosa.midi_to_hz(MIDI_NOTE_MAX),
            melodia_trick=True,
            midi_tempo=bpm_hint or 120,
        )
    finally:
        Path(tmp_path).unlink(missing_ok=True)

    # note_events: list of (start_time, end_time, pitch_midi, amplitude, pitch_bends)
    notes = []
    for start, end, pitch, amplitude, _ in note_events:
        pitch = clamp_pitch(int(round(pitch)))
        velocity = int(np.clip(amplitude * 127, 1, 127))
        notes.append({
            "pitch": pitch,
            "note_name": midi_to_note_name(pitch),
            "start": round(float(start), 4),
            "end": round(float(end), 4),
            "velocity": velocity,
            "confidence": round(float(amplitude), 4),
        })

    # Post-process
    notes = merge_tiny_notes(notes, min_ms=min_note_length_ms)
    notes = remove_duplicate_overlaps(notes)

    # BPM
    detected_bpm = bpm_hint or estimate_bpm(audio, sr)

    if quantize and detected_bpm:
        notes = quantize_notes(notes, detected_bpm)

    notes = sorted(notes, key=lambda n: n["start"])
    return notes, duration, detected_bpm


# ─────────────────────────────────────────────
# Routes
# ─────────────────────────────────────────────
@app.get("/", summary="Health check")
async def root():
    """Report service identity and whether the model import has completed."""
    return {
        "status": "online",
        "service": "voice-to-midi",
        "engine": "basic-pitch",
        "model_ready": _model_loaded,
    }


@app.get("/health")
async def health():
    """Minimal liveness probe."""
    return {"status": "ok"}


@app.post("/transcribe", summary="Convert audio to MIDI")
async def transcribe(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file (wav/mp3/ogg/m4a/flac)"),
    bpm: Optional[float] = Form(None, description="Hint BPM (auto-detected if omitted)"),
    quantize: bool = Form(False, description="Snap notes to nearest beat grid"),
    instrument_name: str = Form("Acoustic Grand Piano"),
    min_note_length_ms: float = Form(40.0, description="Minimum note duration in ms"),
    onset_sensitivity: float = Form(0.5, description="Onset detection threshold 0–1"),
):
    """Transcribe an uploaded audio file and store the result as a MIDI file.

    Returns a JSON document with the detected notes, tempo information and a
    ``midi_url`` pointing at the ``/download`` route.

    Raises:
        HTTPException: 400 (missing name / tiny file), 413 (too large),
            415 (unsupported extension), 422 (invalid parameters, undecodable
            or too-short audio, no notes), 500 (inference failure).
    """
    # ── Validate file ──────────────────────────────────────────────────────
    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided.")

    suffix = Path(file.filename).suffix.lower()
    if suffix not in SUPPORTED_EXTENSIONS:
        raise HTTPException(
            status_code=415,
            detail=f"Unsupported file type '{suffix}'. Supported: {sorted(SUPPORTED_EXTENSIONS)}",
        )

    content_type = (file.content_type or "").split(";")[0].strip().lower()
    if content_type and content_type not in SUPPORTED_TYPES:
        log.warning("Unexpected content-type: %s – proceeding anyway.", content_type)

    # ── Validate numeric parameters ────────────────────────────────────────
    if not (np.isfinite(onset_sensitivity) and np.isfinite(min_note_length_ms)):
        raise HTTPException(
            status_code=422,
            detail="onset_sensitivity and min_note_length_ms must be finite numbers.",
        )
    onset_sensitivity = float(np.clip(onset_sensitivity, 0.1, 0.9))
    min_note_length_ms = max(10.0, min(500.0, min_note_length_ms))

    if bpm is not None:
        if bpm == 0:
            # A zero hint has always been treated as "not provided".
            bpm = None
        elif not np.isfinite(bpm) or not (BPM_HINT_MIN <= bpm <= BPM_HINT_MAX):
            # Out-of-range tempos cannot be written to a MIDI file and would
            # otherwise surface as an unhandled server error.
            raise HTTPException(
                status_code=422,
                detail=f"bpm must be between {BPM_HINT_MIN:g} and {BPM_HINT_MAX:g}.",
            )

    # ── Read & size-check upload ───────────────────────────────────────────
    raw = await file.read()
    if len(raw) > MAX_UPLOAD_MB * 1_000_000:
        raise HTTPException(status_code=413, detail=f"File exceeds {MAX_UPLOAD_MB} MB limit.")
    if len(raw) < 1024:
        raise HTTPException(status_code=400, detail="File too small – likely empty or corrupt.")

    loop = asyncio.get_running_loop()

    # ── Save to temp, load audio ───────────────────────────────────────────
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp_input_path = Path(tmp.name)
        tmp.write(raw)

    try:
        # Decoding and resampling are CPU-bound; keep them off the event loop.
        audio, sr = await loop.run_in_executor(None, load_audio, tmp_input_path)
    finally:
        tmp_input_path.unlink(missing_ok=True)

    if len(audio) / sr < 0.2:
        raise HTTPException(status_code=422, detail="Audio too short (< 0.2 s).")

    # ── Ensure model loaded ────────────────────────────────────────────────
    # Normally a no-op after startup; run in a worker thread in case the
    # import still has to happen.
    await loop.run_in_executor(None, _ensure_model_loaded)

    # ── Run inference ──────────────────────────────────────────────────────
    t0 = time.perf_counter()
    try:
        notes, duration, detected_bpm = await loop.run_in_executor(
            None,
            transcribe_audio,
            audio,
            sr,
            bpm,
            quantize,
            min_note_length_ms,
            onset_sensitivity,
        )
    except HTTPException:
        raise
    except Exception as exc:
        log.exception("Inference failed")
        raise HTTPException(status_code=500, detail=f"Inference error: {exc}") from exc

    elapsed = time.perf_counter() - t0
    log.info(
        "Transcribed %.1fs audio → %d notes in %.2fs", duration, len(notes), elapsed
    )

    if not notes:
        raise HTTPException(
            status_code=422,
            detail="No notes detected. Try adjusting onset_sensitivity or check audio quality.",
        )

    # ── Build MIDI ─────────────────────────────────────────────────────────
    effective_bpm = detected_bpm or 120.0
    midi_obj = notes_to_midi(notes, effective_bpm, instrument_name)

    midi_id = uuid.uuid4().hex
    midi_path = OUTPUT_DIR / f"{midi_id}.mid"
    midi_obj.save(str(midi_path))

    # ── Chord grouping meta ────────────────────────────────────────────────
    chords = group_chords(notes)
    chord_count = sum(1 for g in chords if len(g) > 1)

    return JSONResponse({
        "success": True,
        "notes": notes,
        "midi_url": f"/download/{midi_id}.mid",
        "note_count": len(notes),
        "duration": round(duration, 3),
        "bpm": effective_bpm,
        "bpm_source": "hint" if bpm else ("detected" if detected_bpm else "default"),
        "chord_count": chord_count,
        "processing_time_sec": round(elapsed, 3),
        "quantized": quantize,
        "instrument": instrument_name,
    })


@app.get("/download/{filename}", summary="Download generated MIDI file")
async def download(filename: str):
    """Serve a previously generated MIDI file from OUTPUT_DIR.

    Raises:
        HTTPException: 400 for unsafe or non-``.mid`` names, 404 when the
            file does not exist (or has been cleaned up).
    """
    # Security: disallow path traversal
    if "/" in filename or "\\" in filename or ".." in filename:
        raise HTTPException(status_code=400, detail="Invalid filename.")
    # The name ends up in a response header, so restrict it to a conservative
    # ASCII character set (no quotes, control characters or whitespace).
    if not filename.isascii() or not all(
        ch.isalnum() or ch in _SAFE_FILENAME_EXTRA for ch in filename
    ):
        raise HTTPException(status_code=400, detail="Invalid filename.")
    if not filename.endswith(".mid"):
        raise HTTPException(status_code=400, detail="Only .mid files available.")

    path = OUTPUT_DIR / filename
    # is_file() rather than exists(): a directory must not be served.
    if not path.is_file():
        raise HTTPException(status_code=404, detail="File not found or expired.")

    # FileResponse derives the attachment Content-Disposition header from
    # `filename`, so it is not set by hand.
    return FileResponse(
        path=str(path),
        media_type="audio/midi",
        filename=filename,
    )


@app.get("/info", summary="Service info & supported formats")
async def info():
    """Describe supported formats, limits and available endpoints."""
    return {
        "supported_formats": sorted(SUPPORTED_EXTENSIONS),
        "max_upload_mb": MAX_UPLOAD_MB,
        "midi_note_range": {"min": MIDI_NOTE_MIN, "max": MIDI_NOTE_MAX},
        "default_sample_rate": DEFAULT_SAMPLE_RATE,
        "file_expiry_minutes": MAX_FILE_AGE_SECONDS // 60,
        "endpoints": [
            {"path": "/", "method": "GET", "description": "Health check"},
            {"path": "/transcribe", "method": "POST", "description": "Audio → MIDI"},
            {"path": "/download/{filename}", "method": "GET", "description": "MIDI download"},
        ],
    }


# ─────────────────────────────────────────────
# Entry point (local dev)
# ─────────────────────────────────────────────
if __name__ == "__main__":
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        workers=1,  # Single worker on HF Spaces CPU; model is global
        log_level="info",
    )