| """ |
| NoteGrabber API Server |
| ====================== |
| FastAPI backend for studio.cloudfom.org's AI Note Grabber feature. |
| |
| Powered by Spotify's basic-pitch (the same engine as NeuralNote). |
| Accepts audio uploads, runs transcription, and returns piano-roll-ready |
| note data as JSON plus an optional MIDI file download. |
| |
| Install dependencies: |
| pip install fastapi uvicorn python-multipart basic-pitch pretty-midi |
| |
| Run: |
| uvicorn main:app --host 0.0.0.0 --port 8000 --reload |
| """ |
|
|
import asyncio
import base64
import bisect
import io
import logging
import os
import tempfile
import threading
from pathlib import Path
from typing import Optional

import pretty_midi
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from basic_pitch.inference import predict, Model
from basic_pitch import ICASSP_2022_MODEL_PATH
|
|
| |
# Named logger for this service. basicConfig attaches an INFO-level root
# handler (only if the root logger has no handlers yet) so that this logger's
# INFO messages are actually emitted.
logger = logging.getLogger("notegrabber")
logging.basicConfig(level=logging.INFO)
|
|
| |
# Cross-origin policy for browser clients: the production studio site plus
# two localhost origins (presumably local front-end dev servers).
_CORS_SETTINGS = dict(
    allow_origins=[
        "https://studio.cloudfom.org",
        "http://localhost:3000",
        "http://localhost:5173",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app = FastAPI(
    title="NoteGrabber API",
    version="1.0.0",
    description="Audio-to-MIDI transcription service for studio.cloudfom.org",
)
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
|
|
| |
# Load the basic-pitch model once, at import time, so every request reuses the
# same `_MODEL` instance instead of loading it per call. Importing this module
# therefore blocks until the model is ready; the two log lines bracket that wait.
logger.info("Loading basic-pitch model…")
_MODEL = Model(ICASSP_2022_MODEL_PATH)
logger.info("Model loaded ✓")
|
|
| |
# Lower-case file extensions accepted by /transcribe (checked against the
# upload's filename suffix). ".aif" is the common short spelling of ".aiff"
# and names the same AIFF format.
SUPPORTED_EXTENSIONS = {
    ".mp3",
    ".wav",
    ".ogg",
    ".flac",
    ".m4a",
    ".aiff",
    ".aif",
}

# Upload size limit in megabytes (1 MB = 1024 * 1024 bytes).
MAX_FILE_SIZE_MB = 50
|
|
|
|
| |
class NoteEvent(BaseModel):
    """A single transcribed note, ready for your piano roll."""

    # MIDI note number of the detected pitch.
    pitch: int = Field(..., description="MIDI note number (0–127)")
    # Produced by pretty_midi.note_number_to_name in midi_to_note_events.
    pitch_name: str = Field(..., description="Human-readable name, e.g. 'C4'")
    # The three timing fields are rounded to 4 decimal places by midi_to_note_events.
    start_time: float = Field(..., description="Note start in seconds")
    end_time: float = Field(..., description="Note end in seconds")
    # end_time - start_time, computed before rounding.
    duration: float = Field(..., description="Duration in seconds")
    velocity: int = Field(..., description="MIDI velocity (0–127)")
    # Not a raw model probability: midi_to_note_events derives it as velocity / 127.
    confidence: float = Field(..., description="Model confidence (0.0–1.0)")

    # Pitch-bend values falling inside the note's [start, end) window, each
    # divided by 8192 in midi_to_note_events; None when the note has no bends.
    pitch_bend: Optional[list[float]] = Field(
        None, description="Sub-semitone pitch bend offsets if detected"
    )
|
|
|
|
class TranscriptionResult(BaseModel):
    """Response body of POST /transcribe."""

    # Number of entries in `notes`.
    note_count: int
    # End time of the transcribed MIDI (pretty_midi get_end_time), rounded to 3 places.
    duration_seconds: float
    # First tempo from the MIDI's tempo changes, or None when unavailable.
    tempo_bpm: Optional[float]
    # Transcribed notes, sorted by start time.
    notes: list[NoteEvent]

    # Base64-encoded .mid file; decode it to offer a "Download MIDI" button.
    midi_base64: str

    # The transcription parameters actually used (after server-side clamping).
    settings: dict
|
|
|
|
| |
def midi_to_note_events(midi_data: pretty_midi.PrettyMIDI) -> list[NoteEvent]:
    """Convert every note in *midi_data* into a NoteEvent, sorted by start time.

    Notes from all instruments are merged into one list. Each note carries the
    pitch-bend events of its own instrument whose time falls in the half-open
    window ``[note.start, note.end)``.

    Args:
        midi_data: MIDI produced by basic-pitch's ``predict``.

    Returns:
        NoteEvents ordered by ``start_time``. The sort is stable, so notes with
        equal start times keep their instrument/insertion order.
    """
    events: list[NoteEvent] = []

    for instrument in midi_data.instruments:
        # Sort this instrument's bends by time once, then binary-search each
        # note's window. Rescanning every bend for every note was
        # O(notes x bends) per instrument, which is slow on long recordings.
        bends = sorted(instrument.pitch_bends, key=lambda pb: pb.time)
        bend_times = [pb.time for pb in bends]

        for note in instrument.notes:
            # bisect_left on both ends selects exactly start <= t < end.
            first = bisect.bisect_left(bend_times, note.start)
            stop = bisect.bisect_left(bend_times, note.end)
            # Raw MIDI bend values are scaled by 1/8192, i.e. expressed as a
            # fraction of the synthesizer's bend range.
            # NOTE(review): confirm the bend range basic-pitch writes if
            # clients need these values in semitones.
            bends_in_window = [pb.pitch / 8192.0 for pb in bends[first:stop]]

            events.append(
                NoteEvent(
                    pitch=note.pitch,
                    pitch_name=pretty_midi.note_number_to_name(note.pitch),
                    start_time=round(note.start, 4),
                    end_time=round(note.end, 4),
                    duration=round(note.end - note.start, 4),
                    velocity=note.velocity,
                    # Velocity scaled to 0..1 serves as a confidence proxy;
                    # it is not a probability reported by the model.
                    confidence=round(note.velocity / 127.0, 3),
                    pitch_bend=bends_in_window or None,
                )
            )

    events.sort(key=lambda n: n.start_time)
    return events
|
|
|
|
| |
def midi_to_base64(midi_data: pretty_midi.PrettyMIDI) -> str:
    """Serialize *midi_data* to .mid bytes in memory and return them as base64 text."""
    with io.BytesIO() as sink:
        midi_data.write(sink)
        encoded = base64.b64encode(sink.getvalue())
    return encoded.decode("utf-8")
|
|
|
|
| |
| def _clamp(value: float, lo: float, hi: float) -> float: |
| return max(lo, min(hi, value)) |
|
|
|
|
| |
|
|
@app.get("/health")
async def health():
    """Simple liveness probe."""
    # Static payload: reports that the process is serving and names the model.
    payload = dict(status="ok", model="basic-pitch (ICASSP 2022)")
    return payload
|
|
|
|
# Serializes use of the shared `_MODEL`. Inference now runs in a worker thread
# (so it no longer blocks the event loop); this lock preserves the previous
# one-inference-at-a-time behaviour, since not every basic-pitch backend may be
# safe to call from several threads concurrently.
_PREDICT_LOCK = threading.Lock()


def _validate_frequency_range(
    min_frequency: Optional[float], max_frequency: Optional[float]
) -> None:
    """Raise HTTP 422 unless the optional Hz limits are positive, finite and ordered."""
    for name, value in (("min_frequency", min_frequency), ("max_frequency", max_frequency)):
        # The chained comparison is also False for NaN, so NaN is rejected too.
        if value is not None and not (0.0 < value < float("inf")):
            raise HTTPException(
                status_code=422,
                detail=f"{name} must be a positive, finite frequency in Hz.",
            )
    if min_frequency is not None and max_frequency is not None and min_frequency >= max_frequency:
        raise HTTPException(
            status_code=422,
            detail="min_frequency must be lower than max_frequency.",
        )


def _run_basic_pitch(
    audio_bytes: bytes,
    suffix: str,
    onset_threshold: float,
    frame_threshold: float,
    min_note_length_ms: float,
    min_frequency: Optional[float],
    max_frequency: Optional[float],
    multiple_pitch_bends: bool,
    melodia_trick: bool,
) -> pretty_midi.PrettyMIDI:
    """Run basic-pitch on *audio_bytes* (blocking) and return the transcribed MIDI.

    ``predict`` takes a file path, so the bytes are written to a temporary file
    that keeps the upload's extension (presumably so the decoder can recognise
    the format). The file is removed even if writing or inference fails.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    tmp_path = Path(tmp.name)
    try:
        # Close the handle before inference so the path can be reopened by the decoder.
        with tmp:
            tmp.write(audio_bytes)
        with _PREDICT_LOCK:
            _model_output, midi_data, _note_events = predict(
                str(tmp_path),
                _MODEL,
                onset_threshold,
                frame_threshold,
                min_note_length_ms,
                min_frequency,
                max_frequency,
                multiple_pitch_bends,
                melodia_trick,
            )
        return midi_data
    finally:
        # missing_ok: a cleanup failure must not mask the original exception.
        tmp_path.unlink(missing_ok=True)


def _first_tempo_bpm(midi_data: pretty_midi.PrettyMIDI) -> Optional[float]:
    """Return the first tempo in *midi_data* in BPM, or None if it cannot be read."""
    try:
        _change_times, tempi = midi_data.get_tempo_changes()
        return float(tempi[0]) if len(tempi) > 0 else None
    except Exception:
        # Best-effort metadata: a tempo lookup failure must not fail the request.
        logger.debug("Could not read tempo changes", exc_info=True)
        return None


@app.post("/transcribe", response_model=TranscriptionResult)
async def transcribe(
    audio: UploadFile = File(..., description="Audio file to transcribe"),
    onset_threshold: float = Form(
        0.5,
        description="Sensitivity for detecting note onsets (0.0–1.0). "
        "Lower = more notes detected, higher = only confident onsets.",
    ),
    frame_threshold: float = Form(
        0.3,
        description="Minimum frame-level activation to sustain a note (0.0–1.0).",
    ),
    min_note_length: float = Form(
        0.058,
        description="Minimum note duration in seconds. Shorter notes are filtered out.",
    ),
    min_frequency: Optional[float] = Form(
        None,
        description="Lowest frequency to transcribe in Hz (e.g. 80 for bass guitar). "
        "Leave empty for no lower limit.",
    ),
    max_frequency: Optional[float] = Form(
        None,
        description="Highest frequency to transcribe in Hz (e.g. 2000 for voice). "
        "Leave empty for no upper limit.",
    ),
    multiple_pitch_bends: bool = Form(
        False,
        description="Allow multiple simultaneous pitch bends (polyphonic pitch bend). "
        "Set True for instruments like guitar; False for monophonic sources.",
    ),
    melodia_trick: bool = Form(
        True,
        description="Apply the Melodia post-processing trick to reduce false positives "
        "on sustained notes.",
    ),
):
    """
    Transcribe an audio file to MIDI notes.

    Returns JSON with all detected notes (pitch, timing, velocity, pitch-bend)
    plus a base64-encoded .mid file for direct download or import.

    **Frontend usage:**
    1. POST the audio file + settings as multipart/form-data.
    2. Parse `notes` array directly into your piano-roll note objects.
    3. Optionally decode `midi_base64` and offer a "Download MIDI" button.

    Thresholds are clamped to [0.05, 0.95] and `min_note_length` to at least
    0.01 s; the values actually used are echoed back in `settings`.

    **Errors:** 400 empty upload · 413 upload over the size limit ·
    415 unsupported extension · 422 invalid frequency limits ·
    500 transcription failure.
    """
    # Reject unsupported types by extension before reading any bytes.
    filename = audio.filename or "audio"
    ext = Path(filename).suffix.lower()
    if ext not in SUPPORTED_EXTENSIONS:
        raise HTTPException(
            status_code=415,
            detail=f"Unsupported file type '{ext}'. "
            f"Supported: {', '.join(sorted(SUPPORTED_EXTENSIONS))}",
        )

    _validate_frequency_range(min_frequency, max_frequency)

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling an arbitrarily large file into memory.
    max_bytes = MAX_FILE_SIZE_MB * 1024 * 1024
    audio_bytes = await audio.read(max_bytes + 1)
    if len(audio_bytes) > max_bytes:
        raise HTTPException(
            status_code=413,
            detail=f"File too large. Maximum is {MAX_FILE_SIZE_MB} MB.",
        )
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    size_mb = len(audio_bytes) / (1024 * 1024)

    # Keep parameters inside ranges basic-pitch handles sensibly.
    onset_threshold = _clamp(onset_threshold, 0.05, 0.95)
    frame_threshold = _clamp(frame_threshold, 0.05, 0.95)
    min_note_length = max(0.01, min_note_length)
    min_note_length_ms = min_note_length * 1000.0

    logger.info("Received '%s' (%.2f MB), running transcription…", filename, size_mb)

    try:
        # Inference is CPU-bound and blocking; run it off the event loop so other
        # requests (including /health) keep being served meanwhile.
        midi_data = await asyncio.to_thread(
            _run_basic_pitch,
            audio_bytes,
            ext,
            onset_threshold,
            frame_threshold,
            min_note_length_ms,
            min_frequency,
            max_frequency,
            multiple_pitch_bends,
            melodia_trick,
        )
    except Exception as exc:
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=f"Transcription error: {exc}") from exc

    notes = midi_to_note_events(midi_data)
    midi_b64 = midi_to_base64(midi_data)
    tempo_bpm = _first_tempo_bpm(midi_data)

    result = TranscriptionResult(
        note_count=len(notes),
        duration_seconds=round(midi_data.get_end_time(), 3),
        tempo_bpm=round(tempo_bpm, 2) if tempo_bpm else None,
        notes=notes,
        midi_base64=midi_b64,
        settings={
            "onset_threshold": onset_threshold,
            "frame_threshold": frame_threshold,
            "min_note_length_seconds": min_note_length,
            "min_note_length_ms": min_note_length_ms,
            "min_frequency": min_frequency,
            "max_frequency": max_frequency,
            "multiple_pitch_bends": multiple_pitch_bends,
            "melodia_trick": melodia_trick,
        },
    )

    logger.info("Transcription complete: %d notes detected.", len(notes))
    return result
|
|
|
|
| |
if __name__ == "__main__":
    # Derive the "module:app" import string from this file's own name, so that
    # `python <this file>` works even when it is not saved as main.py.
    # reload=True matches the development Run command in the module docstring;
    # disable it for production deployments.
    uvicorn.run(f"{Path(__file__).stem}:app", host="0.0.0.0", port=8000, reload=True)
|
|