| """FastAPI inference server for the Cardiomegaly classifier. |
| |
| Loads the multi-seed ensemble trained in ``model_training/`` and exposes a |
| single ``POST /predict`` endpoint that the frontend (`src/services/predict.ts`) |
| already knows how to consume. |
| |
| Nothing inside ``model_training/`` is modified — we only *import* the model |
| factory (``src.model.build_model``) to rebuild the exact architecture that was |
| saved to disk, then load the weights on top. |
| |
| Run locally |
| ----------- |
| cd inference_server |
| pip install -r requirements.txt |
| uvicorn server:app --host 0.0.0.0 --port 8000 |
| |
Environment overrides (optional)
--------------------------------
    MODEL_BACKBONE        default: backbone auto-detected from the first
                          checkpoint (falls back to CFG.backbone)
    MODEL_IMG_SIZE        default: 518 for "rad-dino", otherwise 224
    MODEL_THRESHOLD       default: the threshold stored in val_metrics_final.json,
                          otherwise 0.5 (binary cut-off for the label)
    MODEL_USE_TTA         default: "true" (6-pass TTA per image; set "false"
                          for a single pass)
    MODEL_FAST_BACKBONE / MODEL_FAST_IMG_SIZE / MODEL_FAST_THRESHOLD
                          fast-mode equivalents of the three settings above
    HF_MODEL_REPO_ID      Hugging Face model repo used to fetch manifests and
                          checkpoints that are not present locally
    ALLOWED_ORIGINS       comma-separated CORS origins (exact match)
    ALLOWED_ORIGIN_REGEX  regex origin whitelist (e.g. Lovable preview URLs:
                          "https://.*\\.lovable\\.app")
    LOG_LEVEL             default: INFO
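
Example request
---------------
Quick smoke test once the server is up (host and port match the ``uvicorn``
command above; the image file name is a placeholder):

    curl -X POST "http://localhost:8000/predict?mode=accurate" \
         -F "image=@chest_xray.png"

The JSON response includes ``prediction``, ``prediction_binary``, ``confidence``,
``threshold``, ``mode``, ``ensemble_size`` and the checkpoint names used.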
| """ |
|
|
| from __future__ import annotations |
|
|
| import io |
| import json |
| import logging |
| import os |
| import sys |
| from pathlib import Path |
| from typing import List, Literal |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| import torchvision.transforms as T |
| from fastapi import FastAPI, File, HTTPException, Query, UploadFile |
| from fastapi.middleware.cors import CORSMiddleware |
| from huggingface_hub import hf_hub_download |
| from PIL import Image |
|
|
| |
| |
| |
| |
| |
# Repository paths and Hugging Face Hub configuration.
REPO_ROOT = Path(__file__).resolve().parent.parent
TRAINING_DIR = REPO_ROOT / "model_training"
NOTEBOOKS_DIR = TRAINING_DIR / "notebooks"
RESULTS_DIR = NOTEBOOKS_DIR / "results"
HF_MODEL_REPO_ID = os.environ.get("HF_MODEL_REPO_ID", "").strip()
HF_MODEL_REVISION = os.environ.get("HF_MODEL_REVISION", "main")
HF_HUB_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
HF_MODEL_CACHE_DIR = os.environ.get("HF_MODEL_CACHE_DIR", str(REPO_ROOT / ".hf-model-cache"))
ACCURATE_MANIFEST = os.environ.get("ACCURATE_MANIFEST_FILE", "ensemble_manifest.csv")
FAST_MANIFEST = os.environ.get("FAST_MANIFEST_FILE", "fast/fast_ensemble_manifest.csv")
ACCURATE_METRICS = os.environ.get("ACCURATE_METRICS_FILE", "val_metrics_final.json")
FAST_METRICS = os.environ.get("FAST_METRICS_FILE", "fast/fast_val_metrics_final.json")

# Make the training package (``src.config``, ``src.model``, ``src.dataset``) importable.
if str(TRAINING_DIR) not in sys.path:
    sys.path.insert(0, str(TRAINING_DIR))

# Keep torch's download cache inside the repository instead of the user home.
os.environ.setdefault("TORCH_HOME", str(REPO_ROOT / ".torch-cache"))

# Configure logging before the module-level manifest / checkpoint resolution below,
# which may need to log while the module is being imported.
logging.basicConfig(
    level=os.environ.get("LOG_LEVEL", "INFO"),
    format="%(asctime)s %(levelname)-5s %(message)s",
)
log = logging.getLogger("inference")

# Prevent any pretrained-weight downloads while rebuilding architectures: the real
# weights come from the ensemble checkpoints loaded later.
import torchvision.models as _tvm
import torchxrayvision as _xrv

# Force ``weights=None`` on the torchvision backbone builders.
for _fn_name in ("efficientnet_b0", "efficientnet_b3", "mobilenet_v3_large"):
    _orig = getattr(_tvm, _fn_name, None)
    if _orig is None:
        continue

    def _no_download_builder(*args, __orig=_orig, **kwargs):
        kwargs["weights"] = None
        return __orig(*args, **kwargs)

    setattr(_tvm, _fn_name, _no_download_builder)


# Same idea for torchxrayvision's DenseNet: skip the weight download but keep the
# pathology labels the requested weight set would have carried.
_orig_xrv_densenet_init = _xrv.models.DenseNet.__init__


def _xrv_densenet_init_no_download(self, *args, **kwargs):
    requested_weights = kwargs.get("weights")
    kwargs["weights"] = None
    _orig_xrv_densenet_init(self, *args, **kwargs)
    if requested_weights and requested_weights in _xrv.models.model_urls:
        labels = _xrv.models.model_urls[requested_weights]["labels"]
        self.targets = labels
        self.pathologies = labels


_xrv.models.DenseNet.__init__ = _xrv_densenet_init_no_download

from src.config import CFG
from src.model import build_model, cardio_logit
from src.dataset import get_normalize_fn


def _detect_backbone_from_checkpoint(ckpt_path: Path) -> str:
    """Inspect a saved state_dict and guess which backbone produced it.

    Rules:
      * torchxrayvision DenseNet-121 → has ``features.denseblockN.*`` keys
      * torchvision EfficientNet → top-level ``features.0.0.weight`` (stem conv)
        and depth ≥ 9 feature groups
      * torchvision MobileNetV3-Large → ``features.0.0.weight`` with depth ~17
      * microsoft/rad-dino → keys under ``features.embeddings`` /
        ``features.encoder.layer.``
    Defaults to ``CFG.backbone`` if no signature matches.
    """
    state = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    keys = list(state.keys())

    if any("denseblock" in k for k in keys):
        return "densenet121"
    if any(k.startswith("features.embeddings.") for k in keys) or any(
        k.startswith("features.encoder.layer.") for k in keys
    ):
        return "rad-dino"

    feature_indices = {
        int(k.split(".")[1])
        for k in keys
        if k.startswith("features.") and k.split(".")[1].isdigit()
    }
    if feature_indices:
        # MobileNetV3-Large exposes 17 top-level feature blocks (indices 0-16),
        # EfficientNet-B0 only 9 (indices 0-8), hence the two cut-offs below.
        if max(feature_indices) >= 12:
            return "mobilenet_v3_large"
        if max(feature_indices) >= 7:
            return "efficientnet_b0"
    return CFG.backbone


def _hf_download(filename: str) -> Path:
    """Download a file from HF model repo and return local cached path."""
    if not HF_MODEL_REPO_ID:
        raise FileNotFoundError(
            f"File {filename!r} not found locally and HF_MODEL_REPO_ID is not set."
        )
    path = hf_hub_download(
        repo_id=HF_MODEL_REPO_ID,
        filename=filename,
        revision=HF_MODEL_REVISION,
        token=HF_HUB_TOKEN,
        cache_dir=HF_MODEL_CACHE_DIR,
    )
    return Path(path)


def _manifest_candidates(mode: Literal["accurate", "fast"]) -> tuple[list[Path], list[str]]:
    if mode == "fast":
        local = [
            RESULTS_DIR / "fast_model" / "fast_ensemble_manifest.csv",
            RESULTS_DIR / "fast_ensemble_manifest.csv",
        ]
        remote = [FAST_MANIFEST, "fast_ensemble_manifest.csv"]
    else:
        local = [RESULTS_DIR / "ensemble_manifest.csv"]
        remote = [ACCURATE_MANIFEST, "ensemble_manifest.csv"]
    return local, remote


def _metrics_candidates(mode: Literal["accurate", "fast"]) -> tuple[list[Path], list[str]]:
    if mode == "fast":
        local = [
            RESULTS_DIR / "fast_model" / "fast_val_metrics_final.json",
            RESULTS_DIR / "fast_val_metrics_final.json",
        ]
        remote = [FAST_METRICS, "fast_val_metrics_final.json"]
    else:
        local = [RESULTS_DIR / "val_metrics_final.json"]
        remote = [ACCURATE_METRICS, "val_metrics_final.json"]
    return local, remote


def _resolve_manifest_path(mode: Literal["accurate", "fast"] = "accurate") -> Path:
    """Find mode-specific manifest locally first, else download from HF."""
    local_candidates, remote_candidates = _manifest_candidates(mode)
    for local in local_candidates:
        if local.exists():
            return local
    log.info(
        "Local %s manifest not found under %s; downloading from HF repo %s",
        mode,
        RESULTS_DIR,
        HF_MODEL_REPO_ID or "<unset>",
    )
    for filename in remote_candidates:
        try:
            return _hf_download(filename)
        except Exception:
            continue
    raise FileNotFoundError(f"Could not resolve {mode} manifest from local paths or HF")
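

# The manifest is a small CSV with a "checkpoint" column listing one weight file per
# ensemble member; entries may be absolute paths, paths relative to notebooks/, or
# bare file names stored in the HF model repo. Illustrative contents only:
#
#     checkpoint
#     results/seed0_best_model.pth
#     results/seed1_best_model.pth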


def _resolve_optional_support_file(
    name: str,
    mode: Literal["accurate", "fast"] = "accurate",
) -> Path | None:
    """Find optional support file locally; if missing try HF model repo."""
    if name == "val_metrics_final.json":
        local_candidates, remote_candidates = _metrics_candidates(mode)
    else:
        local_candidates = [RESULTS_DIR / name]
        remote_candidates = [name]
    for local in local_candidates:
        if local.exists():
            return local
    for remote in remote_candidates:
        try:
            return _hf_download(remote)
        except Exception:
            continue
    return None


def _first_checkpoint_path(mode: Literal["accurate", "fast"] = "accurate") -> Path:
    try:
        manifest = _resolve_manifest_path(mode)
        df = pd.read_csv(manifest)
        first = df["checkpoint"].iloc[0]
        p = Path(first)
        if p.is_absolute() and p.exists():
            return p
        # Common local locations for a relative manifest entry.
        for candidate in (
            NOTEBOOKS_DIR / first,
            RESULTS_DIR / Path(first).name,
            RESULTS_DIR / "fast_model" / Path(first).name,
        ):
            if candidate.exists():
                return candidate
        # Try the HF repo with the manifest path and with the bare file name.
        for name in (first, Path(first).name):
            try:
                return _hf_download(name)
            except Exception:
                continue
        raise FileNotFoundError(f"Could not resolve first checkpoint from manifest entry: {first!r}")
    except Exception:
        # Any failure above falls through to the single-checkpoint fallback below.
        pass
    fallback = RESULTS_DIR / ("fast_best_model.pth" if mode == "fast" else "best_model.pth")
    if fallback.exists():
        return fallback
    try:
        return _hf_download("fast_best_model.pth" if mode == "fast" else "best_model.pth")
    except Exception as exc:
        raise FileNotFoundError(
            "No checkpoints found locally and could not download from HF. "
            "Set HF_MODEL_REPO_ID and upload ensemble_manifest.csv + *.pth."
        ) from exc


_DETECTED_BACKBONE = _detect_backbone_from_checkpoint(_first_checkpoint_path("accurate"))
try:
    _DETECTED_FAST_BACKBONE = _detect_backbone_from_checkpoint(_first_checkpoint_path("fast"))
except Exception as exc:
    # If the fast checkpoints are missing, reuse the accurate backbone.
    _DETECTED_FAST_BACKBONE = _DETECTED_BACKBONE
    log.warning(
        "Fast backbone auto-detect failed at import (%s). Falling back to accurate backbone %s.",
        exc,
        _DETECTED_BACKBONE,
    )

_DEFAULT_IMG_SIZE = 518 if _DETECTED_BACKBONE == "rad-dino" else 224
_DEFAULT_FAST_IMG_SIZE = 518 if _DETECTED_FAST_BACKBONE == "rad-dino" else 224

BACKBONE: str = os.environ.get("MODEL_BACKBONE", _DETECTED_BACKBONE)
IMG_SIZE: int = int(os.environ.get("MODEL_IMG_SIZE", str(_DEFAULT_IMG_SIZE)))
FAST_BACKBONE: str = os.environ.get("MODEL_FAST_BACKBONE", _DETECTED_FAST_BACKBONE)
FAST_IMG_SIZE: int = int(os.environ.get("MODEL_FAST_IMG_SIZE", str(_DEFAULT_FAST_IMG_SIZE)))
USE_TTA: bool = os.environ.get("MODEL_USE_TTA", "true").lower() in {"1", "true", "yes"}


def _default_threshold() -> float:
    """Use the training-selected threshold when available."""
    metrics_path = _resolve_optional_support_file("val_metrics_final.json", mode="accurate")
    if metrics_path is not None:
        try:
            with open(metrics_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            thr = float(data.get("threshold", 0.5))
            if 0.0 <= thr <= 1.0:
                return thr
        except Exception:
            pass
    return 0.5
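
# For reference: only a top-level "threshold" key is read from the metrics JSON,
# e.g. (illustrative) {"threshold": 0.41, ...}.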


DECISION_THRESHOLD: float = float(os.environ.get("MODEL_THRESHOLD", str(_default_threshold())))


def _default_fast_threshold() -> float:
    metrics_path = _resolve_optional_support_file("val_metrics_final.json", mode="fast")
    if metrics_path is not None:
        try:
            with open(metrics_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            thr = float(data.get("threshold", DECISION_THRESHOLD))
            if 0.0 <= thr <= 1.0:
                return thr
        except Exception:
            pass
    return DECISION_THRESHOLD


FAST_DECISION_THRESHOLD: float = float(
    os.environ.get("MODEL_FAST_THRESHOLD", str(_default_fast_threshold()))
)

_DEFAULT_ORIGINS = (
    "http://localhost:3000,"
    "http://localhost:5173,"
    "http://localhost:8080,"
    "http://127.0.0.1:3000,"
    "http://127.0.0.1:5173,"
    "http://127.0.0.1:8080"
)
ALLOWED_ORIGINS: list[str] = [
    o.strip()
    for o in os.environ.get("ALLOWED_ORIGINS", _DEFAULT_ORIGINS).split(",")
    if o.strip()
]

# The default regex admits Lovable preview, ngrok tunnel and private-LAN origins so
# local and preview frontends work without extra configuration.
_DEFAULT_ORIGIN_REGEX = (
    r"https://([a-z0-9-]+\.)*lovable\.app"
    r"|https://([a-z0-9-]+\.)*lovableproject\.com"
    r"|https://([a-z0-9-]+\.)*ngrok-free\.app"
    r"|https://([a-z0-9-]+\.)*ngrok\.app"
    r"|https://([a-z0-9-]+\.)*ngrok\.io"
    r"|http://(192\.168\.\d{1,3}\.\d{1,3}|10\.\d{1,3}\.\d{1,3}\.\d{1,3}):\d+"
)
_ORIGIN_REGEX: str | None = os.environ.get("ALLOWED_ORIGIN_REGEX", _DEFAULT_ORIGIN_REGEX) or None
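
# Examples: with the defaults above, "https://preview.lovable.app" (a hypothetical
# preview URL) and "http://192.168.1.50:5173" are accepted via the regex, while a
# fixed production frontend would be listed explicitly, e.g.
#     ALLOWED_ORIGINS="https://cardioscan.example.com"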

DEVICE: torch.device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

POSITIVE_LABEL = "Cardiomegaly"
NEGATIVE_LABEL = "No Cardiomegaly indication"

# Normalization must match the backbone the checkpoints were trained with.
_normalize_fn = get_normalize_fn(BACKBONE)
_fast_normalize_fn = get_normalize_fn(FAST_BACKBONE)


def _pil_hflip(img: Image.Image) -> Image.Image:
    return img.transpose(Image.FLIP_LEFT_RIGHT)


def _tta_pipelines(size: int) -> List[T.Compose]:
    """Match `src.transforms.make_tta_transforms` (6 deterministic passes)."""
    s = (size, size)
    return [
        T.Compose([T.Resize(s)]),
        T.Compose([T.Resize(s), T.Lambda(_pil_hflip)]),
        T.Compose([T.Resize((size + 20, size + 20)), T.CenterCrop(s)]),
        T.Compose([T.Resize((size - 20, size - 20)), T.Pad(10, fill=0), T.CenterCrop(s)]),
        T.Compose([T.Resize(s), T.RandomAffine(degrees=(6, 6), fill=0)]),
        T.Compose([T.Resize(s), T.RandomAffine(degrees=(-6, -6), fill=0)]),
    ]
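
# The six passes above: identity, horizontal flip, zoom-in (resize up + center crop),
# zoom-out (resize down + pad + crop), and fixed +6 / -6 degree rotations.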


def _single_eval_pipeline(size: int) -> T.Compose:
    return T.Compose([T.Resize((size, size))])


def _resolve_checkpoint(p: str, mode: Literal["accurate", "fast"] = "accurate") -> Path:
    """Resolve checkpoint locally first, else download from HF model repo."""
    path = Path(p)
    if path.is_absolute() and path.exists():
        return path
    for candidate in (
        NOTEBOOKS_DIR / p,
        RESULTS_DIR / Path(p).name,
        RESULTS_DIR / "fast_model" / Path(p).name,
    ):
        if candidate.exists():
            return candidate
    # Fall back to the HF repo, trying the manifest path, the bare file name and
    # the fast-mode subfolders.
    tried = [p]
    if Path(p).name != p:
        tried.append(Path(p).name)
    if mode == "fast":
        tried.extend([f"fast/{Path(p).name}", f"fast_model/{Path(p).name}"])
    for name in tried:
        try:
            downloaded = _hf_download(name)
            log.info(" → downloaded %s from HF repo %s", name, HF_MODEL_REPO_ID)
            return downloaded
        except Exception:
            continue
    raise FileNotFoundError(
        f"Checkpoint not found locally and not downloadable from HF repo: {p!r}"
    )


def _load_ensemble(mode: Literal["accurate", "fast"] = "accurate") -> tuple[List[nn.Module], list[str]]:
    # Mode-specific settings: the fast ensemble may use a different backbone / image size.
    mode_backbone = FAST_BACKBONE if mode == "fast" else BACKBONE
    mode_img_size = FAST_IMG_SIZE if mode == "fast" else IMG_SIZE
    mode_norm_name = _fast_normalize_fn.__name__ if mode == "fast" else _normalize_fn.__name__

    # Keep the shared training CFG in sync with the mode being loaded.
    CFG.backbone = mode_backbone
    CFG.img_size = mode_img_size

    try:
        manifest = _resolve_manifest_path(mode)
        df = pd.read_csv(manifest)
        checkpoint_paths = [_resolve_checkpoint(p, mode=mode) for p in df["checkpoint"].tolist()]
        log.info(
            "Loading %s ensemble of %d models from %s",
            mode,
            len(checkpoint_paths),
            manifest,
        )
    except Exception:
        fallback_name = "fast_best_model.pth" if mode == "fast" else "best_model.pth"
        best = RESULTS_DIR / fallback_name
        if best.exists():
            checkpoint_paths = [best]
            log.info("No %s manifest found, falling back to local checkpoint: %s", mode, best.name)
        else:
            checkpoint_paths = [_resolve_checkpoint(fallback_name, mode=mode)]
            log.info("No %s manifest found, falling back to HF checkpoint: %s", mode, fallback_name)

    models: list[nn.Module] = []
    for ckpt_path in checkpoint_paths:
        log.info(" → loading %s (%s)", ckpt_path.name, ckpt_path.resolve())
        model = build_model(mode_backbone)
        state = torch.load(ckpt_path, map_location=DEVICE)
        if isinstance(state, dict) and "state_dict" in state:
            state = state["state_dict"]
        missing, unexpected = model.load_state_dict(state, strict=False)
        if missing or unexpected:
            raise RuntimeError(
                "Checkpoint architecture mismatch. "
                f"backbone={mode_backbone!r}, checkpoint={ckpt_path.name!r}, "
                f"missing_keys={len(missing)}, unexpected_keys={len(unexpected)}. "
                "Use the correct mode-specific backbone / img_size and ensure "
                "ensemble_manifest.csv points to checkpoints from that training run."
            )
        model.to(DEVICE).eval()
        models.append(model)

    loaded_checkpoints = [p.name for p in checkpoint_paths]
    mode_thr = FAST_DECISION_THRESHOLD if mode == "fast" else DECISION_THRESHOLD
    log.info(
        "%s ensemble ready — %d model(s) · device=%s · backbone=%s · "
        "normalize=%s · img_size=%d · tta=%s · threshold=%.4f",
        mode, len(models), DEVICE, mode_backbone,
        mode_norm_name, mode_img_size, USE_TTA, mode_thr,
    )
    return models, loaded_checkpoints


app = FastAPI(title="CardioScan inference", version="1.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_origin_regex=_ORIGIN_REGEX,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

_ensemble: list[nn.Module] = []
_loaded_checkpoints: list[str] = []
_fast_ensemble: list[nn.Module] = []
_fast_loaded_checkpoints: list[str] = []


@app.on_event("startup")
def _startup() -> None:
    global _ensemble, _loaded_checkpoints, _fast_ensemble, _fast_loaded_checkpoints
    _ensemble, _loaded_checkpoints = _load_ensemble("accurate")
    try:
        _fast_ensemble, _fast_loaded_checkpoints = _load_ensemble("fast")
    except Exception:
        log.warning("Fast ensemble unavailable; using accurate ensemble for fast mode fallback.")
        _fast_ensemble, _fast_loaded_checkpoints = _ensemble, _loaded_checkpoints


@app.get("/health")
def health() -> dict:
    return {
        "ok": bool(_ensemble),
        "models": len(_ensemble),
        "checkpoints": _loaded_checkpoints,
        "backbone": BACKBONE,
        "detected_backbone": _DETECTED_BACKBONE,
        "normalization": _normalize_fn.__name__,
        "img_size": IMG_SIZE,
        "device": str(DEVICE),
        "use_tta": USE_TTA,
        "threshold": DECISION_THRESHOLD,
        "fast_backbone": FAST_BACKBONE,
        "fast_detected_backbone": _DETECTED_FAST_BACKBONE,
        "fast_normalization": _fast_normalize_fn.__name__,
        "fast_img_size": FAST_IMG_SIZE,
        "fast_models": len(_fast_ensemble),
        "fast_checkpoints": _fast_loaded_checkpoints,
        "fast_threshold": FAST_DECISION_THRESHOLD,
    }
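
# Quick liveness check (illustrative; assumes the local uvicorn command from the
# module docstring):
#     curl http://localhost:8000/health
# returns the loaded checkpoints, backbones, image sizes, device and thresholds.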


@torch.no_grad()
def _predict_probability_detailed(
    pil_gray: Image.Image,
    use_tta: bool,
    ensemble: list[nn.Module],
    checkpoints: list[str],
    img_size: int,
    normalize_fn,
    max_models: int | None = None,
) -> dict:
    """Run ensemble (+ optional TTA) on a single PIL image.

    Returns a dict with per-model / per-TTA logits for transparency.
    Matches `tta_predict` / `tta_predict_ensemble` in ``src.train`` exactly:
    average logits across TTA (per model), then average across models,
    then sigmoid.
    """
    pipelines = _tta_pipelines(img_size) if use_tta else [_single_eval_pipeline(img_size)]

    tensors = [normalize_fn(pipeline(pil_gray)) for pipeline in pipelines]
    batch = torch.stack(tensors, dim=0).to(DEVICE)

    active_model_count = len(ensemble) if max_models is None else max(1, min(max_models, len(ensemble)))
    active_models = ensemble[:active_model_count]
    active_checkpoints = checkpoints[:active_model_count]

    per_model_tta_logits: list[np.ndarray] = []
    per_model_mean_logit: list[float] = []
    for model in active_models:
        logit_vec = cardio_logit(model, batch).float().cpu().numpy()
        per_model_tta_logits.append(logit_vec)
        per_model_mean_logit.append(float(np.mean(logit_vec)))

    ensemble_mean_logit = float(np.mean(per_model_mean_logit))
    probability = float(1.0 / (1.0 + np.exp(-ensemble_mean_logit)))

    return {
        "probability": probability,
        "ensemble_mean_logit": ensemble_mean_logit,
        "per_model_mean_logit": {
            name: lg for name, lg in zip(active_checkpoints, per_model_mean_logit)
        },
        "per_model_tta_logits": {
            name: lg.tolist() for name, lg in zip(active_checkpoints, per_model_tta_logits)
        },
        "num_tta_passes": batch.shape[0],
        "models_used": active_model_count,
        "checkpoints_used": active_checkpoints,
    }


@app.post("/predict")
async def predict(
    image: UploadFile = File(...),
    mode: Literal["accurate", "fast"] = Query(default="accurate", description="Inference mode"),
    use_tta: bool | None = Query(default=None, description="Override TTA for this request."),
    max_models: int | None = Query(default=None, ge=1, description="Use only first N models for speed."),
) -> dict:
    if not _ensemble:
        raise HTTPException(status_code=503, detail="Model not ready")

    raw = await image.read()
    if not raw:
        raise HTTPException(status_code=400, detail="Empty upload")

    try:
        pil = Image.open(io.BytesIO(raw)).convert("L")
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Could not decode image: {exc}") from exc

    selected_ensemble = _fast_ensemble if mode == "fast" else _ensemble
    selected_checkpoints = _fast_loaded_checkpoints if mode == "fast" else _loaded_checkpoints
    selected_img_size = FAST_IMG_SIZE if mode == "fast" else IMG_SIZE
    selected_normalize_fn = _fast_normalize_fn if mode == "fast" else _normalize_fn
    if not selected_ensemble:
        raise HTTPException(status_code=503, detail=f"{mode} model not ready")
    # Fast mode skips TTA by default; accurate mode follows MODEL_USE_TTA unless
    # overridden per request.
    effective_use_tta = (False if mode == "fast" else USE_TTA) if use_tta is None else use_tta

    try:
        details = _predict_probability_detailed(
            pil,
            use_tta=effective_use_tta,
            ensemble=selected_ensemble,
            checkpoints=selected_checkpoints,
            img_size=selected_img_size,
            normalize_fn=selected_normalize_fn,
            max_models=max_models,
        )
    except Exception as exc:
        log.exception("Inference failed")
        raise HTTPException(status_code=500, detail=f"Inference error: {exc}") from exc

    probability = details["probability"]
    active_threshold = FAST_DECISION_THRESHOLD if mode == "fast" else DECISION_THRESHOLD
    is_positive = probability >= active_threshold

    log.info(
        "/predict file=%s size=%d prob=%.4f thr=%.4f -> %s (per-model=%s, tta=%d)",
        image.filename,
        len(raw),
        probability,
        active_threshold,
        "Cardiomegaly" if is_positive else "Negative",
        {k: round(v, 4) for k, v in details["per_model_mean_logit"].items()},
        details["num_tta_passes"],
    )

    return {
        "prediction": POSITIVE_LABEL if is_positive else NEGATIVE_LABEL,
        "prediction_binary": 1 if is_positive else 0,
        "confidence": probability,
        "heatmap_url": None,
        "source": "model",
        "threshold": active_threshold,
        "ensemble_size": details["models_used"],
        "use_tta": effective_use_tta,
        "checkpoints": details["checkpoints_used"],
        "mode": mode,
    }


@app.post("/debug/predict")
async def debug_predict(image: UploadFile = File(...)) -> dict:
    """Same as /predict but returns per-model and per-TTA raw logits for
    verification against the training notebook's val/test CSVs."""
    if not _ensemble:
        raise HTTPException(status_code=503, detail="Model not ready")

    raw = await image.read()
    if not raw:
        raise HTTPException(status_code=400, detail="Empty upload")

    try:
        pil = Image.open(io.BytesIO(raw)).convert("L")
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Could not decode image: {exc}") from exc

    details = _predict_probability_detailed(
        pil,
        use_tta=USE_TTA,
        ensemble=_ensemble,
        checkpoints=_loaded_checkpoints,
        img_size=IMG_SIZE,
        normalize_fn=_normalize_fn,
    )
    details["prediction"] = (
        POSITIVE_LABEL if details["probability"] >= DECISION_THRESHOLD else NEGATIVE_LABEL
    )
    details["prediction_binary"] = 1 if details["probability"] >= DECISION_THRESHOLD else 0
    details["threshold"] = DECISION_THRESHOLD
    details["use_tta"] = USE_TTA
    details["checkpoints"] = _loaded_checkpoints
    return details
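

# Optional convenience entry point; a sketch assuming uvicorn is importable (the
# docstring's ``pip install -r requirements.txt`` step provides it). The documented
# way to run the server remains the uvicorn CLI command at the top of this file.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)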