# Source: ngo/main.py, uploaded by fomext (commit e1b863f, 13.2 kB).
"""
NoteGrabber API Server
======================
FastAPI backend for studio.cloudfom.org's AI Note Grabber feature.
Powered by Spotify's basic-pitch (the same engine as NeuralNote).
Accepts audio uploads, runs transcription, and returns piano-roll-ready
note data as JSON plus an optional MIDI file download.
Install dependencies:
pip install fastapi uvicorn python-multipart basic-pitch pretty-midi
Run:
uvicorn main:app --host 0.0.0.0 --port 8000 --reload
"""
import asyncio
import base64
import bisect
import io
import logging
import os
import tempfile
import threading
from pathlib import Path
from typing import Optional

import pretty_midi
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

# ── basic-pitch imports ──────────────────────────────────────────────────────
from basic_pitch import ICASSP_2022_MODEL_PATH
from basic_pitch.inference import Model, predict
# ── Logging ──────────────────────────────────────────────────────────────────
# Configure the root logger at INFO. basicConfig is a no-op when the root
# logger already has handlers (e.g. when a server installed its own).
logging.basicConfig(level="INFO")
# Named logger used by the model-loading code and the /transcribe route.
logger = logging.getLogger("notegrabber")
# ── App setup ────────────────────────────────────────────────────────────────
app = FastAPI(
    title="NoteGrabber API",
    description="Audio-to-MIDI transcription service for studio.cloudfom.org",
    version="1.0.0",
)

# Browser origins allowed to call this API. Lock this down in production by
# setting NOTEGRABBER_ALLOWED_ORIGINS to a comma-separated list (e.g. just the
# studio domain). When the variable is unset or blank, the defaults below are
# used, which include the localhost dev servers.
_DEFAULT_ALLOWED_ORIGINS = [
    "https://studio.cloudfom.org",
    "http://localhost:3000",  # local dev
    "http://localhost:5173",  # Vite dev
]
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("NOTEGRABBER_ALLOWED_ORIGINS", "").split(",")
    if origin.strip()
] or _DEFAULT_ALLOWED_ORIGINS

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    # Credentials are allowed, so every listed origin is fully trusted.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── Load model once at startup (not per-request) ─────────────────────────────
# Runs at import time, so importing this module blocks until the model is
# loaded. The same instance is handed to basic-pitch's predict() for every
# /transcribe request instead of reloading it per call.
logger.info("Loading basic-pitch model…")
_MODEL = Model(ICASSP_2022_MODEL_PATH)
logger.info("Model loaded ✓")
# ── Supported input formats ───────────────────────────────────────────────────
# Upload file suffixes accepted by /transcribe (matched after lower-casing).
SUPPORTED_EXTENSIONS = set((".aiff", ".flac", ".m4a", ".mp3", ".ogg", ".wav"))
# Upload size cap in MB, where 1 MB = 1024 * 1024 bytes.
MAX_FILE_SIZE_MB = 50
# ── Response schemas ──────────────────────────────────────────────────────────
class NoteEvent(BaseModel):
    """A single transcribed note, ready for your piano roll."""
    # MIDI note number of the detected pitch.
    pitch: int = Field(..., description="MIDI note number (0–127)")
    # Name produced by pretty_midi.note_number_to_name for `pitch`.
    pitch_name: str = Field(..., description="Human-readable name, e.g. 'C4'")
    start_time: float = Field(..., description="Note start in seconds")
    end_time: float = Field(..., description="Note end in seconds")
    # end_time - start_time, in seconds.
    duration: float = Field(..., description="Duration in seconds")
    velocity: int = Field(..., description="MIDI velocity (0–127)")
    # NOTE(review): not a separate model output — the converter in this module
    # fills it with velocity / 127, i.e. a rescaled velocity.
    confidence: float = Field(..., description="Model confidence (0.0–1.0)")
    # Pitch-bend values for the MIDI pitch-bend events whose timestamps fall
    # inside this note's [start, end) window; None when there are none.
    pitch_bend: Optional[list[float]] = Field(
        None, description="Sub-semitone pitch bend offsets if detected"
    )
# Response body returned by POST /transcribe (its `response_model`).
class TranscriptionResult(BaseModel):
    # Number of entries in `notes`.
    note_count: int
    # End time of the transcribed MIDI (pretty_midi get_end_time), in seconds —
    # presumably the last note-off rather than the audio file's full length.
    duration_seconds: float
    # First tempo found in the generated MIDI, or None.
    # NOTE(review): basic-pitch writes the tempo passed as its `midi_tempo`
    # argument instead of estimating one — confirm before presenting it as detected.
    tempo_bpm: Optional[float]
    # Transcribed notes, sorted by start time.
    notes: list[NoteEvent]
    # base64-encoded .mid file for direct download / import
    midi_base64: str
    # Settings echoed back so the client can cache them (values after clamping)
    settings: dict
# ── Helper: convert pretty_midi → NoteEvent list ─────────────────────────────
def midi_to_note_events(midi_data: pretty_midi.PrettyMIDI) -> list[NoteEvent]:
    """Flatten every note of every instrument in *midi_data* into NoteEvents.

    Args:
        midi_data: The MIDI object returned by basic-pitch's ``predict``.

    Returns:
        NoteEvents sorted by start time. Times are rounded to 4 decimals.
        ``pitch_bend`` holds, in semitones, the instrument's bend values whose
        timestamps fall in ``[note.start, note.end)``, or None if there are none.
    """
    events: list[NoteEvent] = []
    for instrument in midi_data.instruments:
        # Pitch bends live on the instrument, not on individual notes. Sort them
        # by time once so each note's window is located by binary search rather
        # than rescanning every bend for every note (O(notes * bends)).
        bends = sorted(instrument.pitch_bends, key=lambda pb: pb.time)
        bend_times = [pb.time for pb in bends]
        for note in instrument.notes:
            first = bisect.bisect_left(bend_times, note.start)  # first bend >= start
            stop = bisect.bisect_left(bend_times, note.end)  # first bend >= end
            # basic-pitch encodes bends for a ±2-semitone range (4096 ticks per
            # semitone), so convert raw -8192..8191 values with semitone_range=2.
            # Dividing by 8192 alone would report only half the bend.
            bends_in_window = [
                pretty_midi.pitch_bend_to_semitones(pb.pitch, semitone_range=2.0)
                for pb in bends[first:stop]
            ]
            events.append(
                NoteEvent(
                    pitch=note.pitch,
                    pitch_name=pretty_midi.note_number_to_name(note.pitch),
                    start_time=round(note.start, 4),
                    end_time=round(note.end, 4),
                    duration=round(note.end - note.start, 4),
                    velocity=note.velocity,
                    # pretty_midi stores no per-note confidence; basic-pitch
                    # derives velocity (0–127) from the note's amplitude, so the
                    # normalised velocity serves as a 0–1 confidence proxy.
                    confidence=round(note.velocity / 127.0, 3),
                    pitch_bend=bends_in_window or None,
                )
            )
    # Sort chronologically (stable: simultaneous notes keep instrument order).
    events.sort(key=lambda n: n.start_time)
    return events
# ── Helper: MIDI → base64 string ─────────────────────────────────────────────
def midi_to_base64(midi_data: pretty_midi.PrettyMIDI) -> str:
    """Serialise *midi_data* as .mid bytes in memory and return them base64-encoded.

    The result is plain text, suitable for embedding in a JSON response.
    """
    with io.BytesIO() as sink:
        midi_data.write(sink)
        mid_bytes = sink.getvalue()
    encoded = base64.b64encode(mid_bytes)
    return encoded.decode("utf-8")
# ── Helper: clamp and validate user params ────────────────────────────────────
def _clamp(value: float, lo: float, hi: float) -> float:
return max(lo, min(hi, value))
# ── Routes ────────────────────────────────────────────────────────────────────
@app.get("/health")
async def health():
    """Liveness probe: confirms the process is serving and names the model it uses."""
    return dict(status="ok", model="basic-pitch (ICASSP 2022)")
# ── Inference helpers ─────────────────────────────────────────────────────────
# /transcribe runs inference on worker threads, and the one shared _MODEL is
# used by every request. Not every basic-pitch backend is documented as
# thread-safe, so calls into it are serialised with this lock.
_PREDICT_LOCK = threading.Lock()


def _run_basic_pitch(audio_bytes: bytes, suffix: str, **predict_kwargs):
    """Run basic-pitch on an in-memory upload. Blocking — call it off the event loop.

    Args:
        audio_bytes: Raw bytes of the uploaded audio file.
        suffix: Extension (e.g. ".wav") for the temp file name, presumably so
            the audio loader can infer the container format from it.
        **predict_kwargs: Keyword arguments forwarded to basic-pitch ``predict``.

    Returns:
        The ``(model_output, midi_data, note_events)`` tuple from ``predict``.
    """
    # basic-pitch reads audio from a path, so the upload has to hit disk.
    # delete=False lets the loader reopen the file by name after it is closed.
    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    try:
        with tmp:
            tmp.write(audio_bytes)
        with _PREDICT_LOCK:
            return predict(tmp.name, _MODEL, **predict_kwargs)
    finally:
        # Always remove the temp file, even if the write or inference failed.
        os.unlink(tmp.name)


@app.post("/transcribe", response_model=TranscriptionResult)
async def transcribe(
    audio: UploadFile = File(..., description="Audio file to transcribe"),
    # ── Transcription parameters (all optional, sensible defaults) ──
    onset_threshold: float = Form(
        0.5,
        description="Sensitivity for detecting note onsets (0.0–1.0). "
        "Lower = more notes detected, higher = only confident onsets.",
    ),
    frame_threshold: float = Form(
        0.3,
        description="Minimum frame-level activation to sustain a note (0.0–1.0).",
    ),
    min_note_length: float = Form(
        0.058,
        description="Minimum note duration in seconds. Shorter notes are filtered out.",
    ),
    min_frequency: Optional[float] = Form(
        None,
        description="Lowest frequency to transcribe in Hz (e.g. 80 for bass guitar). "
        "Leave empty for no lower limit.",
    ),
    max_frequency: Optional[float] = Form(
        None,
        description="Highest frequency to transcribe in Hz (e.g. 2000 for voice). "
        "Leave empty for no upper limit.",
    ),
    multiple_pitch_bends: bool = Form(
        False,
        description="Allow multiple simultaneous pitch bends (polyphonic pitch bend). "
        "Set True for instruments like guitar; False for monophonic sources.",
    ),
    melodia_trick: bool = Form(
        True,
        description="Apply the Melodia post-processing trick to reduce false positives "
        "on sustained notes.",
    ),
):
    """
    Transcribe an audio file to MIDI notes.

    Returns JSON with all detected notes (pitch, timing, velocity, pitch-bend)
    plus a base64-encoded .mid file for direct download or import.

    **Frontend usage:**

    1. POST the audio file + settings as multipart/form-data.
    2. Parse `notes` array directly into your piano-roll note objects.
    3. Optionally decode `midi_base64` and offer a "Download MIDI" button.

    **Errors:** 400 empty upload · 413 file over the size limit ·
    415 unsupported extension · 422 non-positive frequency bound ·
    500 transcription failure.
    """
    # ── Validate file extension ───────────────────────────────────────────────
    filename = audio.filename or "audio"
    ext = Path(filename).suffix.lower()
    if ext not in SUPPORTED_EXTENSIONS:
        raise HTTPException(
            status_code=415,
            detail=f"Unsupported file type '{ext}'. "
            f"Supported: {', '.join(sorted(SUPPORTED_EXTENSIONS))}",
        )

    # ── Read and size-check ───────────────────────────────────────────────────
    max_bytes = MAX_FILE_SIZE_MB * 1024 * 1024
    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling an arbitrarily large body into memory.
    # (Request-body limits themselves belong at the reverse proxy.)
    audio_bytes = await audio.read(max_bytes + 1)
    if len(audio_bytes) > max_bytes:
        raise HTTPException(
            status_code=413,
            detail=f"File too large. Maximum is {MAX_FILE_SIZE_MB} MB.",
        )
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    size_mb = len(audio_bytes) / (1024 * 1024)
    logger.info("Received '%s' (%.2f MB), running transcription…", filename, size_mb)

    # ── Validate & clamp user settings ────────────────────────────────────────
    # A zero/negative/NaN bound cannot be mapped to a pitch and would otherwise
    # crash inside basic-pitch as a 500; reject it as a client error instead.
    for field_name, hz in (("min_frequency", min_frequency), ("max_frequency", max_frequency)):
        if hz is not None and not hz > 0:
            raise HTTPException(
                status_code=422,
                detail=f"{field_name} must be a positive frequency in Hz.",
            )
    onset_threshold = _clamp(onset_threshold, 0.05, 0.95)
    frame_threshold = _clamp(frame_threshold, 0.05, 0.95)
    min_note_length = max(0.01, min_note_length)
    min_note_length_ms = min_note_length * 1000.0  # basic-pitch wants ms, not s

    # ── Run basic-pitch ───────────────────────────────────────────────────────
    # Inference is CPU-bound and synchronous. Running it on a worker thread keeps
    # the event loop free for other requests (e.g. /health) while it works.
    # Keyword arguments match basic-pitch >= 0.3 predict(); note that
    # minimum_note_length is in MILLISECONDS.
    try:
        _model_output, midi_data, _note_events = await asyncio.to_thread(
            _run_basic_pitch,
            audio_bytes,
            ext,
            onset_threshold=onset_threshold,
            frame_threshold=frame_threshold,
            minimum_note_length=min_note_length_ms,
            minimum_frequency=min_frequency,  # None = no lower bound
            maximum_frequency=max_frequency,  # None = no upper bound
            multiple_pitch_bends=multiple_pitch_bends,
            melodia_trick=melodia_trick,
        )
    except Exception as exc:
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=f"Transcription error: {exc}") from exc

    # ── Build response ────────────────────────────────────────────────────────
    notes = midi_to_note_events(midi_data)
    midi_b64 = midi_to_base64(midi_data)

    # Best-effort tempo lookup; report None if the MIDI carries no usable tempo.
    try:
        _tempo_times, tempi = midi_data.get_tempo_changes()
        tempo_bpm = float(tempi[0]) if len(tempi) > 0 else None
    except Exception:
        logger.debug("Could not read tempo from transcribed MIDI", exc_info=True)
        tempo_bpm = None

    result = TranscriptionResult(
        note_count=len(notes),
        duration_seconds=round(midi_data.get_end_time(), 3),
        tempo_bpm=round(tempo_bpm, 2) if tempo_bpm else None,
        notes=notes,
        midi_base64=midi_b64,
        settings={
            "onset_threshold": onset_threshold,
            "frame_threshold": frame_threshold,
            "min_note_length_seconds": min_note_length,
            "min_note_length_ms": min_note_length_ms,
            "min_frequency": min_frequency,
            "max_frequency": max_frequency,
            "multiple_pitch_bends": multiple_pitch_bends,
            "melodia_trick": melodia_trick,
        },
    )
    logger.info("Transcription complete: %d notes detected.", len(notes))
    return result
# ── Dev entrypoint ────────────────────────────────────────────────────────────
# `python main.py` starts uvicorn on all interfaces (0.0.0.0), port 8000, with
# auto-reload — intended for development only. The app is passed as the import
# string "main:app", which assumes this file is saved as main.py.
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)