"""
NoteGrabber API Server
======================
FastAPI backend for studio.cloudfom.org's AI Note Grabber feature.

Powered by Spotify's basic-pitch (the same engine as NeuralNote). Accepts
audio uploads, runs transcription, and returns piano-roll-ready note data as
JSON together with a base64-encoded MIDI file for optional download.

Install dependencies:
    pip install fastapi uvicorn python-multipart basic-pitch pretty-midi

Run:
    uvicorn main:app --host 0.0.0.0 --port 8000 --reload
"""

import asyncio
import base64
import bisect
import io
import logging
import os
import tempfile
import threading
from pathlib import Path
from typing import Optional

import pretty_midi
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

# ── basic-pitch imports ──────────────────────────────────────────────────────
from basic_pitch.inference import predict, Model
from basic_pitch import ICASSP_2022_MODEL_PATH

# ── Logging ──────────────────────────────────────────────────────────────────
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("notegrabber")

# ── App setup ────────────────────────────────────────────────────────────────
app = FastAPI(
    title="NoteGrabber API",
    description="Audio-to-MIDI transcription service for studio.cloudfom.org",
    version="1.0.0",
)

app.add_middleware(
    CORSMiddleware,
    # Lock this down to your actual frontend domain in production
    allow_origins=[
        "https://studio.cloudfom.org",
        "http://localhost:3000",  # local dev
        "http://localhost:5173",  # Vite dev
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── Load model once at startup (not per-request) ─────────────────────────────
logger.info("Loading basic-pitch model…")
_MODEL = Model(ICASSP_2022_MODEL_PATH)
logger.info("Model loaded ✓")

# Inference runs in a worker thread (see `transcribe`) so it does not block the
# event loop. The backend behind the shared model is not guaranteed to be safe
# for concurrent calls, so every predict() call is serialised through this lock.
_MODEL_LOCK = threading.Lock()

# ── Supported input formats ──────────────────────────────────────────────────
SUPPORTED_EXTENSIONS = {".mp3", ".wav", ".ogg", ".flac", ".m4a", ".aiff"}
MAX_FILE_SIZE_MB = 50
_MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024


# ── Response schemas ─────────────────────────────────────────────────────────
class NoteEvent(BaseModel):
    """A single transcribed note, ready for your piano roll."""

    pitch: int = Field(..., description="MIDI note number (0–127)")
    pitch_name: str = Field(..., description="Human-readable name, e.g. 'C4'")
    start_time: float = Field(..., description="Note start in seconds")
    end_time: float = Field(..., description="Note end in seconds")
    duration: float = Field(..., description="Duration in seconds")
    velocity: int = Field(..., description="MIDI velocity (0–127)")
    confidence: float = Field(..., description="Model confidence (0.0–1.0)")
    # Pitch-bend offsets in semitones, one per bend event inside this note
    pitch_bend: Optional[list[float]] = Field(
        None, description="Pitch-bend offsets in semitones, if detected"
    )


class TranscriptionResult(BaseModel):
    """Response body returned by POST /transcribe."""

    note_count: int
    duration_seconds: float
    tempo_bpm: Optional[float]
    notes: list[NoteEvent]
    # base64-encoded .mid file for direct download / import
    midi_base64: str
    # Settings echoed back so the client can cache them
    settings: dict


# ── Helper: convert pretty_midi → NoteEvent list ─────────────────────────────
def midi_to_note_events(midi_data: pretty_midi.PrettyMIDI) -> list[NoteEvent]:
    """Flatten every instrument's notes into NoteEvents sorted by start time.

    Pitch bends live on the instrument, not on individual notes, so each note
    is given the bends whose timestamps fall in ``[note.start, note.end)``,
    converted to semitones.
    """
    events: list[NoteEvent] = []
    for instrument in midi_data.instruments:
        # Sort bends once and binary-search each note's window instead of
        # rescanning every bend for every note (O(notes × bends)).
        bends = sorted(instrument.pitch_bends, key=lambda pb: pb.time)
        bend_times = [pb.time for pb in bends]

        for note in instrument.notes:
            first = bisect.bisect_left(bend_times, note.start)
            last = bisect.bisect_left(bend_times, note.end)
            # pitch_bend_to_semitones maps the 14-bit range [-8192, 8191] onto
            # ±2 semitones (pretty_midi's default bend range). Dividing by 8192
            # alone would report only half the real bend.
            bends_in_window = [
                pretty_midi.pitch_bend_to_semitones(pb.pitch)
                for pb in bends[first:last]
            ]
            events.append(
                NoteEvent(
                    pitch=note.pitch,
                    pitch_name=pretty_midi.note_number_to_name(note.pitch),
                    start_time=round(note.start, 4),
                    end_time=round(note.end, 4),
                    duration=round(note.end - note.start, 4),
                    velocity=note.velocity,
                    # pretty_midi stores no per-note confidence, so velocity/127
                    # is used as a proxy. Presumably basic-pitch derives velocity
                    # from the note's estimated amplitude (confirm).
                    confidence=round(note.velocity / 127.0, 3),
                    pitch_bend=bends_in_window or None,
                )
            )

    # list.sort is stable: notes sharing a start time keep instrument order.
    events.sort(key=lambda n: n.start_time)
    return events


# ── Helper: MIDI → base64 string ─────────────────────────────────────────────
def midi_to_base64(midi_data: pretty_midi.PrettyMIDI) -> str:
    """Serialise ``midi_data`` to .mid bytes and return them base64-encoded."""
    buf = io.BytesIO()
    midi_data.write(buf)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


# ── Helper: clamp and validate user params ───────────────────────────────────
def _clamp(value: float, lo: float, hi: float) -> float:
    """Return ``value`` limited to the closed interval ``[lo, hi]``."""
    return max(lo, min(hi, value))


def _initial_tempo_bpm(midi_data: pretty_midi.PrettyMIDI) -> Optional[float]:
    """Return the first tempo in ``midi_data`` (rounded to 2 dp), or None.

    Best effort: a missing or unreadable tempo must not fail the request.
    """
    try:
        _change_times, tempi = midi_data.get_tempo_changes()
    except Exception:
        logger.warning("Could not read tempo from transcribed MIDI", exc_info=True)
        return None
    if len(tempi) == 0 or not tempi[0]:
        return None
    return round(float(tempi[0]), 2)


def _run_transcription(
    audio_bytes: bytes,
    suffix: str,
    onset_threshold: float,
    frame_threshold: float,
    minimum_note_length_ms: float,
    minimum_frequency: Optional[float],
    maximum_frequency: Optional[float],
    multiple_pitch_bends: bool,
    melodia_trick: bool,
) -> pretty_midi.PrettyMIDI:
    """Run basic-pitch on ``audio_bytes`` and return the transcribed MIDI.

    This does blocking file I/O and model inference, so call it through
    ``asyncio.to_thread``. basic-pitch needs a file path, so the audio is
    written to a temporary file. The file keeps the upload's extension,
    presumably so the audio loader can infer the format. It is always
    deleted afterwards, even if the write or the inference fails.
    """
    fd, tmp_path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(audio_bytes)
        # Keyword arguments protect against positional drift between
        # basic-pitch releases. Notes on the arguments:
        #   - minimum_note_length is in MILLISECONDS, not seconds;
        #   - None frequency bounds mean "no limit";
        #   - `include_pitch_bends` was removed from predict(), and pitch bends
        #     are now always detected and stored in the MIDI output.
        with _MODEL_LOCK:
            _model_output, midi_data, _note_events = predict(
                audio_path=tmp_path,
                model_or_model_path=_MODEL,  # pre-loaded — no cold start per request
                onset_threshold=onset_threshold,
                frame_threshold=frame_threshold,
                minimum_note_length=minimum_note_length_ms,
                minimum_frequency=minimum_frequency,
                maximum_frequency=maximum_frequency,
                multiple_pitch_bends=multiple_pitch_bends,
                melodia_trick=melodia_trick,
            )
    finally:
        os.unlink(tmp_path)  # always clean up the temp file
    return midi_data


# ── Routes ───────────────────────────────────────────────────────────────────
@app.get("/health")
async def health():
    """Simple liveness probe."""
    return {"status": "ok", "model": "basic-pitch (ICASSP 2022)"}


@app.post("/transcribe", response_model=TranscriptionResult)
async def transcribe(
    audio: UploadFile = File(..., description="Audio file to transcribe"),
    # ── Transcription parameters (all optional, sensible defaults) ──
    onset_threshold: float = Form(
        0.5,
        description="Sensitivity for detecting note onsets (0.0–1.0). "
        "Lower = more notes detected, higher = only confident onsets.",
    ),
    frame_threshold: float = Form(
        0.3,
        description="Minimum frame-level activation to sustain a note (0.0–1.0).",
    ),
    min_note_length: float = Form(
        0.058,
        description="Minimum note duration in seconds. Shorter notes are filtered out.",
    ),
    min_frequency: Optional[float] = Form(
        None,
        description="Lowest frequency to transcribe in Hz (e.g. 80 for bass guitar). "
        "Leave empty for no lower limit.",
    ),
    max_frequency: Optional[float] = Form(
        None,
        description="Highest frequency to transcribe in Hz (e.g. 2000 for voice). "
        "Leave empty for no upper limit.",
    ),
    multiple_pitch_bends: bool = Form(
        False,
        description="Allow multiple simultaneous pitch bends (polyphonic pitch bend). "
        "Set True for instruments like guitar; False for monophonic sources.",
    ),
    melodia_trick: bool = Form(
        True,
        description="Apply the Melodia post-processing trick to reduce false positives "
        "on sustained notes.",
    ),
):
    """
    Transcribe an audio file to MIDI notes.

    Returns JSON with all detected notes (pitch, timing, velocity, pitch-bend)
    plus a base64-encoded .mid file for direct download or import.

    **Frontend usage:**
    1. POST the audio file + settings as multipart/form-data.
    2. Parse `notes` array directly into your piano-roll note objects.
    3. Optionally decode `midi_base64` and offer a "Download MIDI" button.
    """
    # ── Validate file extension ──────────────────────────────────────────────
    filename = audio.filename or "audio"
    ext = Path(filename).suffix.lower()
    if ext not in SUPPORTED_EXTENSIONS:
        raise HTTPException(
            status_code=415,
            detail=f"Unsupported file type '{ext}'. "
            f"Supported: {', '.join(sorted(SUPPORTED_EXTENSIONS))}",
        )

    # ── Validate frequency bounds (bad values would otherwise surface as 500s) ─
    for label, hz in (("min_frequency", min_frequency), ("max_frequency", max_frequency)):
        if hz is not None and hz <= 0:
            raise HTTPException(
                status_code=400, detail=f"{label} must be greater than 0 Hz."
            )
    if (
        min_frequency is not None
        and max_frequency is not None
        and min_frequency >= max_frequency
    ):
        raise HTTPException(
            status_code=400,
            detail="min_frequency must be lower than max_frequency.",
        )

    # ── Read with a size cap ─────────────────────────────────────────────────
    # Read at most one byte past the limit, so an oversized upload is rejected
    # without pulling the whole body into memory.
    audio_bytes = await audio.read(_MAX_FILE_SIZE_BYTES + 1)
    if len(audio_bytes) > _MAX_FILE_SIZE_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"File too large. Maximum is {MAX_FILE_SIZE_MB} MB.",
        )
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    size_mb = len(audio_bytes) / (1024 * 1024)
    logger.info("Received '%s' (%.2f MB), running transcription…", filename, size_mb)

    # ── Clamp user settings to safe ranges ───────────────────────────────────
    onset_threshold = _clamp(onset_threshold, 0.05, 0.95)
    frame_threshold = _clamp(frame_threshold, 0.05, 0.95)
    min_note_length = max(0.01, min_note_length)
    min_note_length_ms = min_note_length * 1000.0  # basic-pitch wants ms, not s

    # ── Run basic-pitch off the event loop ───────────────────────────────────
    try:
        midi_data = await asyncio.to_thread(
            _run_transcription,
            audio_bytes,
            ext,
            onset_threshold,
            frame_threshold,
            min_note_length_ms,
            min_frequency,
            max_frequency,
            multiple_pitch_bends,
            melodia_trick,
        )
    except Exception as exc:
        logger.exception("Transcription failed")
        raise HTTPException(
            status_code=500, detail=f"Transcription error: {exc}"
        ) from exc

    # ── Build response ───────────────────────────────────────────────────────
    notes = midi_to_note_events(midi_data)

    result = TranscriptionResult(
        note_count=len(notes),
        duration_seconds=round(midi_data.get_end_time(), 3),
        tempo_bpm=_initial_tempo_bpm(midi_data),
        notes=notes,
        midi_base64=midi_to_base64(midi_data),
        settings={
            "onset_threshold": onset_threshold,
            "frame_threshold": frame_threshold,
            "min_note_length_seconds": min_note_length,
            "min_note_length_ms": min_note_length_ms,
            "min_frequency": min_frequency,
            "max_frequency": max_frequency,
            "multiple_pitch_bends": multiple_pitch_bends,
            "melodia_trick": melodia_trick,
        },
    )

    logger.info("Transcription complete: %d notes detected.", len(notes))
    return result


# ── Dev entrypoint ───────────────────────────────────────────────────────────
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)