| """ |
| Voice-to-MIDI FastAPI Service |
| Powered by Spotify Basic Pitch |
| Optimized for Hugging Face Spaces CPU deployment |
| """ |
|
|
| import os |
| import uuid |
| import time |
| import asyncio |
| import logging |
| import tempfile |
| import threading |
| from pathlib import Path |
| from typing import Optional |
| from contextlib import asynccontextmanager |
| from collections import defaultdict |
|
|
| import numpy as np |
| import mido |
| import soundfile as sf |
| import librosa |
| import uvicorn |
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks |
| from fastapi.responses import FileResponse, JSONResponse |
| from fastapi.middleware.cors import CORSMiddleware |
|
|
| |
| |
| |
# Process-wide logging: INFO and above, each line prefixed with timestamp,
# level and logger name.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
log = logging.getLogger(name="voice-to-midi")
|
|
| |
| |
| |
# MIME types the upload endpoint recognises (the file extension is what is
# actually enforced).
SUPPORTED_TYPES = set(
    "audio/wav audio/x-wav audio/mpeg audio/mp3 "
    "audio/ogg audio/x-m4a audio/mp4 audio/aac "
    "audio/flac application/octet-stream".split()
)
SUPPORTED_EXTENSIONS = {f".{ext}" for ext in ("wav", "mp3", "ogg", "m4a", "flac", "aac")}

# Generated MIDI files live in the system temp dir and expire after 30 minutes.
OUTPUT_DIR = Path(tempfile.gettempdir()).joinpath("voice_midi_outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
MAX_FILE_AGE_SECONDS = 30 * 60

MAX_UPLOAD_MB = 50
# A0 .. C8 - the 88-key piano range.
MIDI_NOTE_MIN, MIDI_NOTE_MAX = 21, 108
DEFAULT_SAMPLE_RATE = 22_050
|
|
| |
| |
| |
# Serialises the one-time basic_pitch import done by _ensure_model_loaded().
_model_lock = threading.Lock()
# Becomes True once that import has succeeded; also reported by the "/" endpoint.
_model_loaded = False


def _ensure_model_loaded() -> None:
    """Import ``basic_pitch.inference`` once so later requests skip that cost.

    The flag is re-checked while holding the lock, so when several threads
    call this concurrently only one of them performs the import. If the
    import raises, the exception propagates and the flag stays False, so a
    later call will try again.
    """
    global _model_loaded
    if _model_loaded:
        return
    with _model_lock:
        if _model_loaded:
            return
        log.info("Loading Basic Pitch model — this happens once at cold start …")
        # Imported only for its side effect (loading basic_pitch and whatever
        # backend it pulls in); the name is not used here.
        from basic_pitch.inference import predict  # noqa: F401
        _model_loaded = True
        log.info("Basic Pitch model ready.")
|
|
|
|
| |
| |
| |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: preload the model, then run periodic cleanup.

    The model import runs in the default thread-pool executor so it does not
    block the event loop. On shutdown the cleanup task is cancelled *and
    awaited*, so it is not left pending, and any error it died with is logged
    instead of being lost.
    """
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, _ensure_model_loaded)

    cleanup_task = asyncio.create_task(_periodic_cleanup())
    try:
        yield
    finally:
        cleanup_task.cancel()
        try:
            await cleanup_task
        except asyncio.CancelledError:
            pass
        except Exception:
            log.exception("Cleanup task had failed before shutdown.")
|
|
|
|
app = FastAPI(
    lifespan=lifespan,
    title="Voice-to-MIDI",
    version="1.0.0",
    description="Convert voice / humming / whistling to MIDI via Spotify Basic Pitch",
)

# Fully open CORS: browsers on any origin may call every method with any header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
|
|
|
|
| |
| |
| |
async def _periodic_cleanup(interval_sec: float = 600) -> None:
    """Run ``_cleanup_old_files`` every *interval_sec* seconds (default 10 min).

    Loops until cancelled. The sweep runs in a worker thread so its
    filesystem calls do not block the event loop. A failing sweep is logged
    and the loop continues, so one filesystem error cannot stop cleanup for
    the rest of the process lifetime.
    """
    while True:
        await asyncio.sleep(interval_sec)
        try:
            await asyncio.to_thread(_cleanup_old_files)
        except Exception:
            log.exception("Periodic cleanup failed; retrying next interval.")
|
|
|
|
def _cleanup_old_files() -> int:
    """Delete ``*.mid`` files in OUTPUT_DIR older than MAX_FILE_AGE_SECONDS.

    Best-effort: a file that disappears or cannot be removed during the
    sweep is skipped rather than aborting the whole sweep.

    Returns:
        The number of files removed.
    """
    cutoff = time.time() - MAX_FILE_AGE_SECONDS
    removed = 0
    for midi_file in OUTPUT_DIR.glob("*.mid"):
        try:
            # stat() belongs inside the try: the file may vanish between
            # glob() yielding it and this call.
            if midi_file.stat().st_mtime < cutoff:
                midi_file.unlink()
                removed += 1
        except OSError:
            continue
    if removed:
        log.info("Cleaned up %d old MIDI file(s).", removed)
    return removed
|
|
|
|
| |
| |
| |
|
|
NOTE_NAMES = "C C# D D# E F F# G G# A A# B".split()


def midi_to_note_name(midi_pitch: int) -> str:
    """Return the scientific-pitch name of a MIDI note number (60 → "C4")."""
    octave_index, pitch_class = divmod(midi_pitch, 12)
    # MIDI octave numbering starts at -1: note 0 is C-1, note 12 is C0.
    return NOTE_NAMES[pitch_class] + str(octave_index - 1)
|
|
|
|
def load_audio(path: Path) -> tuple[np.ndarray, int]:
    """Decode an audio file to mono samples resampled to DEFAULT_SAMPLE_RATE.

    Args:
        path: File to decode.

    Returns:
        ``(samples, sample_rate)`` as returned by ``librosa.load``.

    Raises:
        HTTPException: 422 if the file cannot be decoded. The original error
            is chained as ``__cause__`` so it shows up in server tracebacks.
    """
    try:
        audio, sr = librosa.load(str(path), sr=DEFAULT_SAMPLE_RATE, mono=True)
    except Exception as exc:
        raise HTTPException(status_code=422, detail=f"Cannot decode audio: {exc}") from exc
    return audio, sr
|
|
|
|
def estimate_bpm(audio: np.ndarray, sr: int) -> Optional[float]:
    """Estimate tempo with librosa's beat tracker.

    Returns the tempo rounded to 0.1 BPM, or None when it falls outside the
    open range (30, 300) or when beat tracking fails for any reason.
    """
    try:
        tempo_estimate, _beats = librosa.beat.beat_track(y=audio, sr=sr)
        bpm = float(np.atleast_1d(tempo_estimate)[0])
    except Exception:
        return None
    if not 30 < bpm < 300:
        return None
    return round(bpm, 1)
|
|
|
|
| |
| |
| |
|
|
def clamp_pitch(pitch: int) -> int:
    """Limit *pitch* to the inclusive range [MIDI_NOTE_MIN, MIDI_NOTE_MAX]."""
    capped = min(MIDI_NOTE_MAX, pitch)
    return max(MIDI_NOTE_MIN, capped)
|
|
|
|
def remove_duplicate_overlaps(notes: list[dict]) -> list[dict]:
    """Merge same-pitch notes whose time ranges overlap.

    Within each pitch, a note that starts before the previous kept note
    ends is folded into it. This covers exact duplicates, fully contained
    notes and partial overlaps. The merged note spans the union of the two
    ranges and keeps the higher velocity. Notes that merely touch
    (``start == previous end``) stay separate.

    Args:
        notes: Dicts with at least ``pitch``, ``start``, ``end`` and
            ``velocity``.

    Returns:
        A new list sorted by ``start``. Kept notes are shallow copies, so the
        caller's dicts are never modified.
    """
    merged_by_pitch: dict[int, list[dict]] = defaultdict(list)
    # One sort by (pitch, start) gives each pitch's notes in time order.
    for note in sorted(notes, key=lambda n: (n["pitch"], n["start"])):
        kept = merged_by_pitch[note["pitch"]]
        if kept and note["start"] < kept[-1]["end"]:
            prev = kept[-1]
            prev["end"] = max(prev["end"], note["end"])
            prev["velocity"] = max(prev["velocity"], note["velocity"])
        else:
            kept.append(dict(note))
    result = [n for group in merged_by_pitch.values() for n in group]
    return sorted(result, key=lambda n: n["start"])
|
|
|
|
def merge_tiny_notes(notes: list[dict], min_ms: float = 40) -> list[dict]:
    """Return the notes lasting at least *min_ms* milliseconds, in input order.

    Despite the name, shorter notes are discarded rather than merged into
    their neighbours. Neither the input list nor its dicts are modified.
    """
    threshold_sec = min_ms / 1000.0
    return [note for note in notes if note["end"] - note["start"] >= threshold_sec]
|
|
|
|
def quantize_notes(notes: list[dict], bpm: float, subdivisions: int = 16) -> list[dict]:
    """Snap note start/end times to a rhythmic grid.

    Grid spacing is one beat divided by ``subdivisions / 4``. The default of
    16 therefore gives a 16th-note grid when the beat is a quarter note.
    Snapped times are rounded to 4 decimals. A note whose snapped end is not
    after its snapped start is extended to one grid step.

    Args:
        notes: Dicts with ``start`` and ``end`` in seconds; other keys are
            carried over unchanged.
        bpm: Tempo in beats per minute.
        subdivisions: Grid steps per four beats.

    Returns:
        New note dicts with snapped times; the inputs are not modified. If
        ``bpm`` is not a positive number (including NaN) or ``subdivisions``
        is not positive, the input list itself is returned unchanged.
    """
    # ``not bpm > 0`` also rejects NaN, which ``bpm <= 0`` lets through.
    if not bpm > 0 or subdivisions <= 0:
        return notes
    beat_sec = 60.0 / bpm
    grid = beat_sec / (subdivisions / 4)

    def snap(t: float) -> float:
        return round(round(t / grid) * grid, 4)

    quantized = []
    for note in notes:
        start = snap(note["start"])
        end = snap(note["end"])
        if end <= start:
            # Keep the note audible: give it one grid step, rounded like the rest.
            end = round(start + grid, 4)
        quantized.append({**note, "start": start, "end": end})
    return quantized
|
|
|
|
def notes_to_midi(
    notes: list[dict],
    bpm: float,
    instrument_name: str = "Acoustic Grand Piano",
) -> mido.MidiFile:
    """Build a single-track (type 0) MIDI file from note dicts.

    Args:
        notes: Dicts with ``pitch``, ``velocity``, and ``start`` / ``end`` in
            seconds.
        bpm: Tempo written to the file and used to convert seconds to ticks.
        instrument_name: One of the names in ``gm_programs`` below; unknown
            names fall back to program 0 (Acoustic Grand Piano).

    Returns:
        A ``mido.MidiFile`` at 480 ticks per beat.
    """
    mid = mido.MidiFile(type=0, ticks_per_beat=480)
    track = mido.MidiTrack()
    mid.tracks.append(track)

    tempo = mido.bpm2tempo(bpm)
    track.append(mido.MetaMessage("set_tempo", tempo=tempo, time=0))

    gm_programs = {
        "Acoustic Grand Piano": 0, "Electric Piano": 4,
        "Violin": 40, "Flute": 73, "Synth Lead": 80,
    }
    program = gm_programs.get(instrument_name, 0)
    track.append(mido.Message("program_change", program=program, time=0))

    events = []
    for n in notes:
        events.append((n["start"], "note_on", n["pitch"], n["velocity"]))
        events.append((n["end"], "note_off", n["pitch"], 0))
    # At equal times note_off sorts first, so a note ending exactly when the
    # next same-pitch note starts does not silence the new one.
    events.sort(key=lambda e: (e[0], 0 if e[1] == "note_off" else 1))

    def sec_to_ticks(t: float) -> int:
        # Round instead of truncating: float error (e.g. 479.9999 for 480)
        # would otherwise shift events one tick early. round() is also safe
        # if the installed mido already returns an int here.
        return int(round(mido.second2tick(t, mid.ticks_per_beat, tempo)))

    # Deltas come from absolute tick positions, so rounding never accumulates.
    prev_ticks = 0
    for abs_sec, msg_type, pitch, vel in events:
        abs_ticks = sec_to_ticks(abs_sec)
        delta = max(0, abs_ticks - prev_ticks)
        track.append(mido.Message(msg_type, note=pitch, velocity=vel, time=delta))
        prev_ticks = abs_ticks

    return mid
|
|
|
|
def group_chords(notes: list[dict], window_sec: float = 0.05) -> list[list[dict]]:
    """Cluster notes into chords by onset time.

    Notes are visited in start order. Each note joins the current cluster
    when it starts within *window_sec* of that cluster's first note;
    otherwise it opens a new cluster. An empty input yields an empty list.
    """
    clusters: list[list[dict]] = []
    for note in sorted(notes, key=lambda n: n["start"]):
        if clusters and abs(note["start"] - clusters[-1][0]["start"]) <= window_sec:
            clusters[-1].append(note)
        else:
            clusters.append([note])
    return clusters
|
|
|
|
| |
| |
| |
|
|
def transcribe_audio(
    audio: np.ndarray,
    sr: int,
    bpm_hint: Optional[float],
    quantize: bool,
    min_note_length_ms: float,
    onset_sensitivity: float,
) -> tuple[list[dict], float, Optional[float]]:
    """Run Basic Pitch on *audio* and post-process the detected notes.

    Args:
        audio: Mono samples (as produced by ``load_audio``).
        sr: Sample rate of *audio* in Hz.
        bpm_hint: Caller-supplied tempo; if falsy, tempo is estimated.
        quantize: Snap notes to the ``quantize_notes`` grid when a tempo is
            known.
        min_note_length_ms: Minimum note duration in milliseconds.
        onset_sensitivity: Passed to Basic Pitch as ``onset_threshold``.

    Returns:
        ``(notes, duration_sec, bpm)``:
        - ``notes`` are dicts with ``pitch``, ``note_name``, ``start``,
          ``end``, ``velocity`` and ``confidence``, sorted by start;
        - ``bpm`` is the hint, the estimate, or None.
    """
    from basic_pitch.inference import predict

    duration = len(audio) / sr

    # predict() takes a file path, so round-trip through a temporary WAV.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        sf.write(tmp_path, audio, sr)
        # NOTE(review): predict() is called without a preloaded model; depending
        # on the installed basic_pitch version it may load the model from disk
        # on every call. Consider passing a cached model (TODO confirm).
        _model_output, _midi_data, note_events = predict(
            tmp_path,
            onset_threshold=onset_sensitivity,
            frame_threshold=0.3,
            # basic_pitch takes this in milliseconds, not seconds.
            minimum_note_length=min_note_length_ms,
            minimum_frequency=librosa.midi_to_hz(MIDI_NOTE_MIN),
            maximum_frequency=librosa.midi_to_hz(MIDI_NOTE_MAX),
            melodia_trick=True,
            midi_tempo=bpm_hint or 120,
        )
    finally:
        os.unlink(tmp_path)

    notes = []
    for start, end, pitch, amplitude, _ in note_events:
        midi_pitch = clamp_pitch(int(round(pitch)))
        notes.append({
            "pitch": midi_pitch,
            "note_name": midi_to_note_name(midi_pitch),
            "start": round(float(start), 4),
            "end": round(float(end), 4),
            # Amplitude is scaled to a MIDI velocity, never 0 (0 would mean note-off).
            "velocity": int(np.clip(amplitude * 127, 1, 127)),
            "confidence": round(float(amplitude), 4),
        })

    # Enforce the exact millisecond threshold, then merge overlapping
    # same-pitch notes.
    notes = merge_tiny_notes(notes, min_ms=min_note_length_ms)
    notes = remove_duplicate_overlaps(notes)

    detected_bpm = bpm_hint or estimate_bpm(audio, sr)
    if quantize and detected_bpm:
        # Snapping can make same-pitch notes overlap or coincide, which would
        # produce clashing note_on/note_off pairs in the MIDI file, so merge again.
        notes = remove_duplicate_overlaps(quantize_notes(notes, detected_bpm))

    notes = sorted(notes, key=lambda n: n["start"])
    return notes, duration, detected_bpm
|
|
|
|
| |
| |
| |
|
|
@app.get("/", summary="Health check")
async def root():
    """Report service identity and whether the Basic Pitch import has completed."""
    payload = dict(
        status="online",
        service="voice-to-midi",
        engine="basic-pitch",
        model_ready=_model_loaded,
    )
    return payload
|
|
|
|
@app.get("/health")
async def health():
    """Minimal liveness probe: responds ``ok`` whenever the app is serving."""
    status = "ok"
    return {"status": status}
|
|
|
|
def _decode_upload(raw: bytes, suffix: str) -> tuple[np.ndarray, int]:
    """Write upload bytes to a temp file with *suffix* and decode it via load_audio.

    The temp file is always removed, even if writing or decoding fails.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    tmp_path = Path(tmp.name)
    try:
        with tmp:
            tmp.write(raw)
        return load_audio(tmp_path)
    finally:
        tmp_path.unlink(missing_ok=True)


@app.post("/transcribe", summary="Convert audio to MIDI")
async def transcribe(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file (wav/mp3/ogg/m4a/flac)"),
    bpm: Optional[float] = Form(None, description="Hint BPM (auto-detected if omitted)"),
    quantize: bool = Form(False, description="Snap notes to nearest beat grid"),
    instrument_name: str = Form("Acoustic Grand Piano"),
    min_note_length_ms: float = Form(40.0, description="Minimum note duration in ms"),
    onset_sensitivity: float = Form(0.5, description="Onset detection threshold 0–1"),
):
    """Transcribe an uploaded audio clip to a MIDI file.

    Steps:
    1. Validate the upload.
    2. Decode it and run Basic Pitch, both in worker threads.
    3. Save the MIDI file to OUTPUT_DIR.
    4. Return the notes plus a ``/download`` URL.

    Raises:
        HTTPException:
        - 400: missing filename, out-of-range bpm hint, or a file too small
          to be audio;
        - 413: oversized upload;
        - 415: unsupported extension;
        - 422: undecodable or too-short audio, or no notes detected;
        - 500: inference failed.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided.")

    suffix = Path(file.filename).suffix.lower()
    if suffix not in SUPPORTED_EXTENSIONS:
        raise HTTPException(
            status_code=415,
            detail=f"Unsupported file type '{suffix}'. Supported: {sorted(SUPPORTED_EXTENSIONS)}",
        )

    content_type = (file.content_type or "").split(";")[0].strip().lower()
    if content_type and content_type not in SUPPORTED_TYPES:
        log.warning("Unexpected content-type: %s — proceeding anyway.", content_type)

    # Tuning knobs are clamped into safe ranges rather than rejected.
    onset_sensitivity = float(np.clip(onset_sensitivity, 0.1, 0.9))
    min_note_length_ms = max(10.0, min(500.0, min_note_length_ms))

    # A hint of 0 means "auto-detect". Anything else must be a tempo MIDI can
    # encode; out-of-range values would otherwise fail later as a 500. The
    # chained comparison also rejects NaN and infinity.
    if not bpm:
        bpm = None
    elif not 10.0 <= bpm <= 500.0:
        raise HTTPException(status_code=400, detail="bpm must be between 10 and 500.")

    max_bytes = MAX_UPLOAD_MB * 1_000_000
    # Read at most one byte past the limit. That is enough to detect an
    # oversized upload without pulling an arbitrarily large body into memory.
    raw = await file.read(max_bytes + 1)
    if len(raw) > max_bytes:
        raise HTTPException(status_code=413, detail=f"File exceeds {MAX_UPLOAD_MB} MB limit.")
    if len(raw) < 1024:
        raise HTTPException(status_code=400, detail="File too small — likely empty or corrupt.")

    loop = asyncio.get_running_loop()
    # Decoding and resampling are blocking; keep them off the event loop.
    audio, sr = await loop.run_in_executor(None, _decode_upload, raw, suffix)
    del raw  # release the upload bytes before inference

    if len(audio) / sr < 0.2:
        raise HTTPException(status_code=422, detail="Audio too short (< 0.2 s).")

    # Normally a no-op because the lifespan hook preloads the model. Run it in
    # a worker anyway so a cold load can never stall the event loop.
    await loop.run_in_executor(None, _ensure_model_loaded)

    t0 = time.perf_counter()
    try:
        notes, duration, detected_bpm = await loop.run_in_executor(
            None,
            transcribe_audio,
            audio, sr, bpm, quantize, min_note_length_ms, onset_sensitivity,
        )
    except HTTPException:
        raise
    except Exception as exc:
        log.exception("Inference failed")
        raise HTTPException(status_code=500, detail=f"Inference error: {exc}") from exc

    elapsed = time.perf_counter() - t0
    log.info("Transcribed %.1fs audio → %d notes in %.2fs", duration, len(notes), elapsed)

    if not notes:
        raise HTTPException(
            status_code=422,
            detail="No notes detected. Try adjusting onset_sensitivity or check audio quality.",
        )

    effective_bpm = detected_bpm or 120.0
    midi_obj = notes_to_midi(notes, effective_bpm, instrument_name)
    midi_id = uuid.uuid4().hex
    midi_path = OUTPUT_DIR / f"{midi_id}.mid"
    midi_obj.save(str(midi_path))

    chords = group_chords(notes)
    chord_count = sum(1 for g in chords if len(g) > 1)

    return JSONResponse({
        "success": True,
        "notes": notes,
        "midi_url": f"/download/{midi_id}.mid",
        "note_count": len(notes),
        "duration": round(duration, 3),
        "bpm": effective_bpm,
        "bpm_source": "hint" if bpm else ("detected" if detected_bpm else "default"),
        "chord_count": chord_count,
        "processing_time_sec": round(elapsed, 3),
        "quantized": quantize,
        "instrument": instrument_name,
    })
|
|
|
|
@app.get("/download/{filename}", summary="Download generated MIDI file")
async def download(filename: str):
    """Serve a MIDI file previously written to OUTPUT_DIR by /transcribe.

    Only names of the form ``<32 lowercase hex digits>.mid`` are accepted,
    which is exactly what /transcribe generates (``uuid4().hex``). This
    rules out path traversal and access to any other file in the directory.

    Raises:
        HTTPException:
        - 400: malformed name;
        - 404: the file does not exist (never generated, or already removed
          by cleanup).
    """
    if "/" in filename or "\\" in filename or ".." in filename:
        raise HTTPException(status_code=400, detail="Invalid filename.")
    if not filename.endswith(".mid"):
        raise HTTPException(status_code=400, detail="Only .mid files available.")

    stem = filename[: -len(".mid")]
    try:
        # Canonical form only: UUID() also accepts hyphens/braces/"urn:uuid:".
        is_generated_name = uuid.UUID(hex=stem).hex == stem
    except ValueError:
        is_generated_name = False
    if not is_generated_name:
        raise HTTPException(status_code=400, detail="Invalid filename.")

    path = OUTPUT_DIR / filename
    if not path.is_file():
        raise HTTPException(status_code=404, detail="File not found or expired.")

    return FileResponse(
        path=str(path),
        media_type="audio/midi",
        filename=filename,
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
|
|
|
|
@app.get("/info", summary="Service info & supported formats")
async def info():
    """Describe upload limits, supported formats, and every public endpoint."""
    return {
        "supported_formats": sorted(SUPPORTED_EXTENSIONS),
        "max_upload_mb": MAX_UPLOAD_MB,
        "midi_note_range": {"min": MIDI_NOTE_MIN, "max": MIDI_NOTE_MAX},
        "default_sample_rate": DEFAULT_SAMPLE_RATE,
        "file_expiry_minutes": MAX_FILE_AGE_SECONDS // 60,
        "endpoints": [
            {"path": "/", "method": "GET", "description": "Health check"},
            {"path": "/transcribe", "method": "POST", "description": "Audio → MIDI"},
            {"path": "/download/{filename}", "method": "GET", "description": "MIDI download"},
            {"path": "/health", "method": "GET", "description": "Lightweight liveness check"},
            {"path": "/info", "method": "GET", "description": "Service info & supported formats"},
        ],
    }
|
|
|
|
| |
| |
| |
if __name__ == "__main__":
    # Pass the app object rather than the "app:app" import string. The string
    # form makes uvicorn import this file a second time under the module name
    # "app", re-running all module-level setup, and it breaks if the file is
    # not named app.py. An object is allowed here because there is a single
    # worker and no auto-reload. PORT can override the default Spaces port.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "7860")),
        workers=1,
        log_level="info",
    )
|
|