testing / src /backend /server.py
lokesh1's picture
MedMosaic benchmark demo
b4d0173
Raw
History Blame Contribute Delete
4.01 kB
"""
FastAPI backend for the MedMosaic showcase UI.
Endpoints
GET /health
GET /qa_types -> dropdown options
GET /models -> selectable prediction models
GET /questions/{qa_type} -> all questions for that type (+ task, difficulty, audio_url)
GET /audio/{qa_type}/{index} -> audio file stream (for the player)
POST /predict -> run prediction + evaluation for one item
Run: uvicorn backend.server:app --host 0.0.0.0 --port 8000
(from the medmosaic-benchmark-demo/ directory)
"""
import logging
import os
import shutil
import subprocess
import tempfile
import threading
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel

from . import config, data_store
from .benchmark import Benchmark
# Lightweight, cached playback transcodes so the browser never pulls the full
# multi-minute wav (a 42MB 44.1kHz clip saturates the connection and stalls
# /predict). Plays the ENTIRE clip — same duration, just compressed.
_PLAYBACK_DIR = config.DATA_DIR / "audio_cache"
_log = logging.getLogger(__name__)


def _playback_file(qa_type: str, index: int, src: str, timeout: float = 600.0):
    """Return ``(path, media_type)`` for browser playback of one clip.

    On first use, transcodes *src* with ffmpeg into a mono mp3
    (22.05 kHz, 64 kbit/s) covering the whole clip. The result is cached as
    ``_PLAYBACK_DIR / f"{qa_type}_{index}.mp3"`` and served from there on
    later calls.

    Falls back to ``(src, "audio/wav")`` in three cases: ffmpeg is not on
    PATH, the cache directory cannot be created, or the transcode fails or
    runs longer than *timeout* seconds.

    The transcode writes to a unique temporary file in the cache directory.
    It is renamed into place atomically only after ffmpeg succeeds. A
    failed, killed or concurrent transcode therefore never leaves a truncated
    mp3 at the cache path, where it would be served on every later request.
    """
    if not shutil.which("ffmpeg"):
        return src, "audio/wav"
    cache = _PLAYBACK_DIR / f"{qa_type}_{index}.mp3"
    if cache.exists():
        return str(cache), "audio/mpeg"
    tmp_path = None
    try:
        _PLAYBACK_DIR.mkdir(parents=True, exist_ok=True)
        fd, tmp_path = tempfile.mkstemp(
            dir=_PLAYBACK_DIR, prefix=f"{cache.stem}.", suffix=".part")
        os.close(fd)  # ffmpeg reopens (and overwrites, via -y) the path itself
        # "-f mp3" is required: ffmpeg cannot infer the format from ".part".
        subprocess.run(
            ["ffmpeg", "-v", "error", "-y", "-i", src,
             "-ac", "1", "-ar", "22050", "-b:a", "64k", "-f", "mp3", tmp_path],
            stdin=subprocess.DEVNULL, check=True, timeout=timeout)
        os.replace(tmp_path, cache)  # atomic publish of the finished file
        tmp_path = None  # now owned by the cache; nothing left to clean up
    except (OSError, subprocess.SubprocessError) as err:
        _log.warning(
            "playback transcode of %s failed; serving original: %s", src, err)
        return src, "audio/wav"
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best-effort removal of the partial output
    return str(cache), "audio/mpeg"
logging.basicConfig(level=logging.INFO)

app = FastAPI(version="1.0", title="MedMosaic Showcase Backend")

# Wide-open CORS: every origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
_bench: Optional[Benchmark] = None
# FastAPI runs plain ``def`` endpoints (such as /predict) in a worker
# threadpool, so several first requests can reach bench() concurrently.
_bench_lock = threading.Lock()


def bench() -> Benchmark:
    """Return the process-wide Benchmark, constructing it on first use.

    Construction happens on the first call rather than at import time. It is
    serialized by ``_bench_lock``, so concurrent first calls build exactly
    one instance. Once built, the instance is returned without taking the lock.
    """
    global _bench
    if _bench is None:
        with _bench_lock:
            # Re-check: another thread may have finished construction while
            # this one was waiting for the lock.
            if _bench is None:
                _bench = Benchmark()
    return _bench
class PredictRequest(BaseModel):
    """Request body of ``POST /predict``: which item to run, with which model."""

    # QA type of the item (presumably one of the values listed by /qa_types).
    qa_type: str
    # Index of the question within that QA type; defaults to 0.
    index: int = 0
    # Key into config.PREDICTION_MODELS; /predict rejects unknown keys with 400.
    model: str = "gemini-2.5-flash"
@app.get("/health")
def health():
    """Liveness check that also reports ``len(data_store.catalog())``."""
    sample_count = len(data_store.catalog())
    return {"status": "ok", "samples": sample_count}
@app.get("/qa_types")
def qa_types():
    """List every QA type for the UI dropdown.

    Each entry has three fields:
    - ``qa_type``: the type's name.
    - ``task``: looked up in ``config.TASK_BY_QA_TYPE``, defaulting to ``"mcq"``.
    - ``n_questions``: how many questions ``data_store`` returns for the type.
    """
    options = []
    for name in data_store.qa_types():
        task = config.TASK_BY_QA_TYPE.get(name, "mcq")
        count = len(data_store.questions(name))
        options.append({"qa_type": name, "task": task, "n_questions": count})
    return options
@app.get("/models")
def models():
    """List the selectable prediction models from ``config.PREDICTION_MODELS``.

    One ``{"key", "provider", "target"}`` entry per configured model, in the
    mapping's iteration order.
    """
    registry = config.PREDICTION_MODELS
    return [
        {"key": name, "provider": spec["provider"], "target": spec["target"]}
        for name, spec in registry.items()
    ]
@app.get("/questions/{qa_type}")
def questions(qa_type: str):
    """Return the questions of *qa_type*, ready for the UI.

    Each question record is copied without its ``audio_file`` entry. An
    ``audio_url`` pointing at ``/audio/{qa_type}/{index}`` is added, so the
    player fetches the audio separately instead of receiving bytes here.

    Raises:
        HTTPException(404): ``data_store.questions`` raised KeyError.
    """
    try:
        records = data_store.questions(qa_type)
    except KeyError:
        raise HTTPException(404, f"unknown qa_type: {qa_type}")
    return [
        {
            **{key: val for key, val in rec.items() if key != "audio_file"},
            "audio_url": f"/audio/{qa_type}/{rec['index']}",
        }
        for rec in records
    ]
@app.get("/audio/{qa_type}/{index}")
def audio(qa_type: str, index: int):
    """Serve the audio clip of question *index* of *qa_type* to the player.

    The file and media type come from ``_playback_file``: a cached compressed
    transcode when one can be produced, otherwise the original file.

    Raises:
        HTTPException(404): any of the following:
            - the record lookup fails (KeyError/IndexError);
            - the record has no ``audio_file`` entry;
            - the file to serve does not exist on disk.
    """
    try:
        rec = data_store.get_record(qa_type, index)
        src = rec["audio_file"]
    except (KeyError, IndexError):
        raise HTTPException(404, "audio not found")
    path, media = _playback_file(qa_type, index, src)
    # Check up front: FileResponse only discovers a missing file while
    # sending, which surfaces as a 500 instead of a clean 404.
    if not os.path.isfile(path):
        raise HTTPException(404, "audio not found")
    return FileResponse(path, media_type=media)
@app.post("/predict")
def predict(req: PredictRequest):
    """Run prediction + evaluation for one benchmark item.

    Returns whatever ``Benchmark.run_one(req.qa_type, req.index, req.model)``
    returns.

    Raises:
        HTTPException(400): ``req.model`` is not in ``config.PREDICTION_MODELS``.
        HTTPException(404): ``run_one`` raised KeyError/IndexError. This
            presumably means an unknown qa_type or out-of-range index;
            confirm against Benchmark.
    """
    if req.model not in config.PREDICTION_MODELS:
        raise HTTPException(400, f"unknown model: {req.model}")
    # Resolve the lazily built benchmark outside the try. A KeyError or
    # IndexError raised while constructing it is a server fault, not a
    # missing item, and must not be reported to the client as a 404.
    runner = bench()
    try:
        return runner.run_one(req.qa_type, req.index, req.model)
    except (KeyError, IndexError) as e:
        raise HTTPException(404, str(e))