"""
FastAPI backend for the MedMosaic showcase UI.

Endpoints
  GET  /health
  GET  /qa_types                 -> dropdown options
  GET  /models                   -> selectable prediction models
  GET  /questions/{qa_type}      -> the 2 questions (+ task, difficulty)
  GET  /audio/{qa_type}/{index}  -> audio file stream (for the player)
  POST /predict                  -> run prediction + evaluation for one item

Run:
  uvicorn backend.server:app --host 0.0.0.0 --port 8000
  (from the medmosaic-benchmark-demo/ directory)
"""
import contextlib
import logging
import os
import shutil
import subprocess
import uuid
from typing import Optional, Tuple

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel

from . import config, data_store
from .benchmark import Benchmark

logger = logging.getLogger(__name__)

# Lightweight, cached playback transcodes so the browser never pulls the full
# multi-minute wav (a 42MB 44.1kHz clip saturates the connection and stalls
# /predict). Plays the ENTIRE clip -- same duration, just compressed.
_PLAYBACK_DIR = config.DATA_DIR / "audio_cache"


def _playback_file(qa_type: str, index: int, src: str) -> Tuple[str, str]:
    """Return ``(path, media_type)`` to serve for browser playback of one clip.

    Transcodes *src* once into a small mono mp3 (22050 Hz, 64 kbps) of the
    full clip and caches it as ``_PLAYBACK_DIR / "{qa_type}_{index}.mp3"``.
    Falls back to ``(src, "audio/wav")`` when ffmpeg is not on PATH or the
    transcode fails.

    ffmpeg writes to a uniquely named temporary file that is atomically
    renamed onto the cache path only after a successful run. A failed or
    interrupted transcode therefore never leaves a truncated mp3 that later
    requests would serve. A concurrent request for the same clip also never
    picks up a file that another request is still writing.
    """
    if not shutil.which("ffmpeg"):
        return src, "audio/wav"
    _PLAYBACK_DIR.mkdir(parents=True, exist_ok=True)
    cache = _PLAYBACK_DIR / f"{qa_type}_{index}.mp3"
    if cache.exists():
        return str(cache), "audio/mpeg"

    # Unique per call, so parallel transcodes of the same clip don't collide.
    # The suffix is not .mp3, so the output format is forced with -f.
    tmp = cache.with_name(f"{cache.name}.{uuid.uuid4().hex}.part")
    try:
        subprocess.run(
            ["ffmpeg", "-v", "error", "-y", "-i", src,
             "-ac", "1", "-ar", "22050", "-b:a", "64k",
             "-f", "mp3", str(tmp)],
            stdin=subprocess.DEVNULL, check=True)
        # Atomic within one directory: readers see either no file or a full one.
        tmp.replace(cache)
    except Exception:  # noqa: BLE001 -- best effort; the original still plays
        logger.warning("playback transcode failed for %s; serving original",
                       src, exc_info=True)
        with contextlib.suppress(OSError):
            tmp.unlink()
        return src, "audio/wav"
    return str(cache), "audio/mpeg"


logging.basicConfig(level=logging.INFO)

app = FastAPI(title="MedMosaic Showcase Backend", version="1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

_bench: Optional[Benchmark] = None


def bench() -> Benchmark:
    """Return the process-wide Benchmark, constructing it on first use.

    NOTE(review): not guarded by a lock -- if requests are served
    concurrently, the first few callers may each construct a Benchmark.
    """
    global _bench
    if _bench is None:
        _bench = Benchmark()
    return _bench


# Body of POST /predict: which item (qa_type + index) to run and with which
# model key (must be a key of config.PREDICTION_MODELS).
class PredictRequest(BaseModel):
    qa_type: str
    index: int = 0
    model: str = "gemini-2.5-flash"


@app.get("/health")
def health():
    """Liveness probe plus the number of entries in the data catalog."""
    return {"status": "ok", "samples": len(data_store.catalog())}


@app.get("/qa_types")
def qa_types():
    """List each QA type with its task kind (default "mcq") and question count."""
    return [
        {
            "qa_type": qt,
            "task": config.TASK_BY_QA_TYPE.get(qt, "mcq"),
            "n_questions": len(data_store.questions(qt)),
        }
        for qt in data_store.qa_types()
    ]


@app.get("/models")
def models():
    """List the selectable prediction models from config.PREDICTION_MODELS."""
    return [
        {"key": k, "provider": v["provider"], "target": v["target"]}
        for k, v in config.PREDICTION_MODELS.items()
    ]


@app.get("/questions/{qa_type}")
def questions(qa_type: str):
    """Return the questions for *qa_type*, with audio exposed as a URL.

    Raises HTTP 404 if data_store does not know *qa_type*.
    """
    try:
        qs = data_store.questions(qa_type)
    except KeyError:
        raise HTTPException(404, f"unknown qa_type: {qa_type}") from None
    # Do not stream audio bytes here; expose a URL the player can hit.
    out = []
    for q in qs:
        item = {k: v for k, v in q.items() if k != "audio_file"}
        item["audio_url"] = f"/audio/{qa_type}/{q['index']}"
        out.append(item)
    return out


@app.get("/audio/{qa_type}/{index}")
def audio(qa_type: str, index: int):
    """Serve the playback version (cached mp3 or original) of one clip.

    Raises HTTP 404 if the record is unknown or there is no file to serve.
    """
    try:
        rec = data_store.get_record(qa_type, index)
    except (KeyError, IndexError):
        raise HTTPException(404, "audio not found") from None
    path, media = _playback_file(qa_type, index, rec["audio_file"])
    # A missing source makes the transcode fail and fall back to a path that
    # doesn't exist; FileResponse would only fail at send time (HTTP 500).
    if not os.path.isfile(path):
        raise HTTPException(404, "audio not found")
    return FileResponse(path, media_type=media)


@app.post("/predict")
def predict(req: PredictRequest):
    """Run prediction + evaluation for one item with the requested model.

    Raises HTTP 400 for an unknown model key, and HTTP 404 when run_one
    raises KeyError/IndexError.
    NOTE(review): that also maps KeyError/IndexError raised deep inside
    run_one to 404 -- confirm that only lookup failures can raise them.
    """
    if req.model not in config.PREDICTION_MODELS:
        raise HTTPException(400, f"unknown model: {req.model}")
    try:
        return bench().run_one(req.qa_type, req.index, req.model)
    except (KeyError, IndexError) as e:
        raise HTTPException(404, str(e)) from e