"""FastAPI backend for the piano tutorial transcription pipeline.""" import json import shutil import sys import tempfile import threading import traceback import uuid from pathlib import Path import pretty_midi from fastapi import FastAPI, UploadFile, File, HTTPException from fastapi.responses import FileResponse, JSONResponse from fastapi.staticfiles import StaticFiles from fastapi.middleware.cors import CORSMiddleware # Add transcriber to path TRANSCRIBER_DIR = Path(__file__).resolve().parent.parent / "transcriber" sys.path.insert(0, str(TRANSCRIBER_DIR)) app = FastAPI(title="Piano Tutorial API") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) # Directory for temporary processing files WORK_DIR = Path(tempfile.gettempdir()) / "piano-tutorial" WORK_DIR.mkdir(exist_ok=True) @app.post("/api/transcribe") async def transcribe( file: UploadFile = File(...), ): """Transcribe an uploaded audio file to MIDI. Accepts a file upload (MP3, M4A, WAV, OGG, FLAC). Returns JSON with a job_id, MIDI download URL, and chord data. 
""" job_id = str(uuid.uuid4())[:8] job_dir = WORK_DIR / job_id job_dir.mkdir(exist_ok=True) try: suffix = Path(file.filename).suffix or ".m4a" audio_path = job_dir / f"upload{suffix}" content = await file.read() audio_path.write_bytes(content) # Run transcription from transcribe import transcribe as run_transcribe raw_midi_path = job_dir / "transcription_raw.mid" run_transcribe(str(audio_path), str(raw_midi_path)) # Run optimization (also runs chord detection as Step 10) from optimize import optimize optimized_path = job_dir / "transcription.mid" optimize(str(audio_path), str(raw_midi_path), str(optimized_path)) if not optimized_path.exists(): raise HTTPException(500, "Optimization failed to produce output") # Load chord data if available chords_path = job_dir / "transcription_chords.json" chord_data = None if chords_path.exists(): with open(chords_path) as f: chord_data = json.load(f) return JSONResponse({ "job_id": job_id, "midi_url": f"/api/jobs/{job_id}/midi", "chords_url": f"/api/jobs/{job_id}/chords", "audio_url": f"/api/jobs/{job_id}/audio", "chords": chord_data, }) except HTTPException: raise except Exception as e: raise HTTPException(500, f"Transcription failed: {str(e)}") @app.get("/api/jobs/{job_id}/midi") async def get_midi(job_id: str): """Download the optimized MIDI file for a completed job.""" midi_path = WORK_DIR / job_id / "transcription.mid" if not midi_path.exists(): raise HTTPException(404, f"No MIDI file found for job {job_id}") return FileResponse( midi_path, media_type="audio/midi", filename="transcription.mid", ) @app.get("/api/jobs/{job_id}/chords") async def get_chords(job_id: str): """Get the detected chord data for a completed job.""" chords_path = WORK_DIR / job_id / "transcription_chords.json" if not chords_path.exists(): raise HTTPException(404, f"No chord data found for job {job_id}") with open(chords_path) as f: chord_data = json.load(f) return JSONResponse(chord_data) # ── Full-song mode (Demucs source separation) 
# ──────────────────────────

# In-memory job status for async full-song transcription.
# Maps job_id -> status dict; written by the worker thread, read by the
# status endpoint.
job_status = {}

# Content types served for the original upload, keyed by lowercase suffix.
_AUDIO_MEDIA_TYPES = {
    ".mp3": "audio/mpeg",
    ".m4a": "audio/mp4",
    ".wav": "audio/wav",
    ".ogg": "audio/ogg",
    ".flac": "audio/flac",
}


def _job_dir(job_id):
    """Return the working directory for *job_id*, or ``None`` if it is unsafe.

    The id comes straight from the request URL.  Anything that is not a single
    plain path component (empty, ``.``, ``..`` or containing a separator) is
    rejected so a request can never address files outside WORK_DIR.
    """
    if job_id in ("", ".", "..") or Path(job_id).name != job_id:
        return None
    return WORK_DIR / job_id


def _json_artifact_response(job_id, filename, not_found_detail):
    """Serve the JSON file *filename* of job *job_id*.

    Raises:
        HTTPException: 404 with *not_found_detail* if the job id is malformed
            or the file does not exist.
    """
    job_dir = _job_dir(job_id)
    if job_dir is None or not (job_dir / filename).exists():
        raise HTTPException(404, not_found_detail)
    with open(job_dir / filename) as f:
        return JSONResponse(json.load(f))


def _set_progress(job_id, step, label):
    """Publish an in-progress status entry for *job_id*."""
    job_status[job_id] = {"step": step, "label": label, "done": False}


def merge_stems(piano_midi_path, bass_midi_path, output_path):
    """Merge piano and bass MIDI into a single multi-track file.

    All notes of every instrument in the piano file go to track 0
    (program 0, "Piano"); all notes of the bass file go to track 1
    (program 33, "Bass").  The result is written to *output_path*.
    """
    piano = pretty_midi.PrettyMIDI(str(piano_midi_path))
    bass = pretty_midi.PrettyMIDI(str(bass_midi_path))

    merged = pretty_midi.PrettyMIDI()
    for source, program, name in ((piano, 0, "Piano"), (bass, 33, "Bass")):
        track = pretty_midi.Instrument(program=program, name=name)
        for inst in source.instruments:
            track.notes.extend(inst.notes)
        merged.instruments.append(track)

    merged.write(str(output_path))


def run_full_transcription(job_id, audio_path, job_dir):
    """Background worker for full-song transcription with Demucs.

    Progress is reported through ``job_status[job_id]``: steps 1-6 while
    running, step 7 with a ``result`` payload on success, step -1 with an
    ``error`` message on failure.  Exceptions never propagate to the caller.
    """
    try:
        # Step 1: Demucs separation
        _set_progress(job_id, 1, "Separating instruments with AI...")
        from separate import separate
        stems = separate(str(audio_path), str(job_dir / "stems"))

        # Step 2: Transcribe melodic + bass stems
        _set_progress(job_id, 2, "Transcribing instruments...")
        from transcribe import transcribe as run_transcribe

        piano_raw = job_dir / "piano_raw.mid"
        run_transcribe(stems["other"], str(piano_raw))

        bass_raw = job_dir / "bass_raw.mid"
        run_transcribe(stems["bass"], str(bass_raw))

        # Step 3: Optimize transcriptions
        # Use the full solo piano optimizer for the melodic stem — it produces
        # much better rhythm, playability, and note accuracy. Also runs chord
        # detection and spectral analysis internally.
        _set_progress(job_id, 3, "Optimizing note accuracy...")
        from optimize import optimize as optimize_piano
        from optimize_bass import optimize_bass

        # Write to a temporary name: "transcription.mid" is reserved for the
        # merged file produced in step 6.
        piano_tmp = job_dir / "transcription.tmp.mid"
        optimize_piano(stems["other"], str(piano_raw), str(piano_tmp))

        # Solo optimizer writes chords to {stem}_chords.json next to the output
        auto_chords = job_dir / "transcription.tmp_chords.json"
        chords_path = job_dir / "transcription_chords.json"
        if auto_chords.exists():
            auto_chords.rename(chords_path)

        # Rename to final path
        piano_opt = job_dir / "piano_optimized.mid"
        piano_tmp.rename(piano_opt)

        bass_opt = job_dir / "bass_optimized.mid"
        optimize_bass(stems["bass"], str(bass_raw), str(bass_opt))

        # Load chord data
        chord_data = None
        if chords_path.exists():
            with open(chords_path) as f:
                chord_data = json.load(f)

        # Step 4: Transcribe drums
        _set_progress(job_id, 4, "Transcribing drums...")
        from drums import transcribe_drums
        drum_tab_path = job_dir / "drum_tab.json"
        transcribe_drums(stems["drums"], str(drum_tab_path))

        # Step 5: Generate guitar and bass tabs
        _set_progress(job_id, 5, "Generating tabs...")
        from tabs import midi_to_guitar_tab, midi_to_bass_tab

        guitar_tab = midi_to_guitar_tab(str(piano_opt), str(chords_path))
        with open(job_dir / "guitar_tab.json", 'w') as f:
            json.dump(guitar_tab, f)

        bass_tab = midi_to_bass_tab(str(bass_opt))
        with open(job_dir / "bass_tab.json", 'w') as f:
            json.dump(bass_tab, f)

        # Step 6: Merge melodic + bass into final MIDI
        _set_progress(job_id, 6, "Assembling final result...")
        merged_path = job_dir / "transcription.mid"
        merge_stems(str(piano_opt), str(bass_opt), str(merged_path))

        # Clean up large stem files and intermediates
        stems_dir = job_dir / "stems"
        if stems_dir.exists():
            shutil.rmtree(stems_dir)
        for leftover in (piano_raw, bass_raw, piano_opt, bass_opt):
            leftover.unlink(missing_ok=True)

        job_status[job_id] = {
            "step": 7,
            "label": "Done!",
            "done": True,
            "result": {
                "job_id": job_id,
                "midi_url": f"/api/jobs/{job_id}/midi",
                "chords_url": f"/api/jobs/{job_id}/chords",
                "audio_url": f"/api/jobs/{job_id}/audio",
                "guitar_tab_url": f"/api/jobs/{job_id}/guitar-tab",
                "bass_tab_url": f"/api/jobs/{job_id}/bass-tab",
                "drum_tab_url": f"/api/jobs/{job_id}/drum-tab",
                "chords": chord_data,
                "mode": "full",
            },
        }

    except Exception as e:
        # Top-level boundary of the worker thread: log, then surface the
        # (truncated) message to the polling client instead of dying silently.
        traceback.print_exc()
        job_status[job_id] = {
            "step": -1,
            "label": str(e)[:200],
            "done": True,
            "error": str(e)[:200],
        }


@app.post("/api/transcribe-full")
async def transcribe_full(file: UploadFile = File(...)):
    """Start full-song transcription with Demucs source separation.

    Returns immediately with a job_id. Poll /api/jobs/{job_id}/status.
    """
    job_id = str(uuid.uuid4())[:8]
    job_dir = WORK_DIR / job_id
    job_dir.mkdir(exist_ok=True)

    # The client may omit the filename; fall back to the default suffix
    # instead of crashing on Path(None).
    suffix = Path(file.filename or "").suffix or ".m4a"
    audio_path = job_dir / f"upload{suffix}"
    audio_path.write_bytes(await file.read())

    job_status[job_id] = {"step": 0, "label": "Starting...", "done": False}

    worker = threading.Thread(
        target=run_full_transcription,
        args=(job_id, audio_path, job_dir),
        daemon=True,
    )
    worker.start()

    return JSONResponse({"job_id": job_id})


@app.get("/api/jobs/{job_id}/status")
async def get_job_status(job_id: str):
    """Get the current status of a full-song transcription job.

    Raises:
        HTTPException: 404 if no status has been recorded for *job_id*.
    """
    status = job_status.get(job_id)
    if status is None:
        raise HTTPException(404, f"No job found with id {job_id}")
    return JSONResponse(status)


@app.get("/api/jobs/{job_id}/guitar-tab")
async def get_guitar_tab(job_id: str):
    """Get the guitar tab data for a completed full-song job."""
    return _json_artifact_response(
        job_id, "guitar_tab.json", f"No guitar tab data for job {job_id}"
    )


@app.get("/api/jobs/{job_id}/bass-tab")
async def get_bass_tab(job_id: str):
    """Get the bass tab data for a completed full-song job."""
    return _json_artifact_response(
        job_id, "bass_tab.json", f"No bass tab data for job {job_id}"
    )


@app.get("/api/jobs/{job_id}/drum-tab")
async def get_drum_tab(job_id: str):
    """Get the drum tab data for a completed full-song job."""
    return _json_artifact_response(
        job_id, "drum_tab.json", f"No drum tab data for job {job_id}"
    )


@app.get("/api/jobs/{job_id}/audio")
async def get_audio(job_id: str):
    """Serve the original uploaded audio file back for playback.

    Raises:
        HTTPException: 404 if the job id is malformed, the job directory does
            not exist, or it contains no upload file.
    """
    job_dir = _job_dir(job_id)
    if job_dir is None or not job_dir.is_dir():
        raise HTTPException(404, f"No job found with id {job_id}")

    # Find the upload file (upload.mp3, upload.m4a, upload.wav, etc.)
    for entry in job_dir.iterdir():
        if entry.name.startswith("upload"):
            media_type = _AUDIO_MEDIA_TYPES.get(entry.suffix.lower(), "audio/mpeg")
            return FileResponse(
                entry, media_type=media_type, filename=f"original{entry.suffix}"
            )

    raise HTTPException(404, f"No audio file found for job {job_id}")


@app.get("/api/health")
async def health():
    """Liveness probe; always reports ``{"status": "ok"}``."""
    return {"status": "ok"}


# Serve the built React frontend (in production)
DIST_DIR = Path(__file__).resolve().parent.parent / "app" / "dist"

if DIST_DIR.exists():
    # Serve static assets
    app.mount("/assets", StaticFiles(directory=str(DIST_DIR / "assets")), name="assets")

    # Serve MIDI files if they exist
    midi_dir = DIST_DIR / "midi"
    if midi_dir.exists():
        app.mount("/midi", StaticFiles(directory=str(midi_dir)), name="midi")

    # Catch-all: serve index.html for SPA routing
    @app.get("/{path:path}")
    async def serve_spa(path: str):
        """Serve a file from the built frontend, or index.html as fallback.

        *path* is attacker-controlled and may contain ``..`` segments or be
        absolute, so it is resolved and served only when the result is still
        inside DIST_DIR; everything else gets the SPA entry page.
        """
        index_file = DIST_DIR / "index.html"
        dist_root = DIST_DIR.resolve()
        try:
            requested = (dist_root / path).resolve()
            # Raises ValueError when the resolved path escapes the dist root.
            requested.relative_to(dist_root)
        except (OSError, ValueError, RuntimeError):
            return FileResponse(index_file)
        if requested.is_file():
            return FileResponse(requested)
        return FileResponse(index_file)