Spaces:
Sleeping
Sleeping
"""
inference.py — FastAPI Backend for BraTS Segmentation
=====================================================

Loads the trained UNet3D checkpoint and serves predictions via HTTP.

Endpoints:
    GET  /health        — model status
    POST /segment       — run segmentation on uploaded NIfTI files
    POST /segment/demo  — run on a synthetic volume (no upload needed)

Run:
    cd src
    uvicorn inference:app --host 0.0.0.0 --port 8000 --reload
"""
| import io | |
| import sys | |
| import numpy as np | |
| import torch | |
| from pathlib import Path | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from dotenv import load_dotenv | |
| import os | |
| load_dotenv() | |
| # Make sure src/ is on the path when running from project root | |
| sys.path.append(str(Path(__file__).parent)) | |
| from model import UNet3D | |
| from dataset import normalize_modality, crop_to_brain, resize_volume, MODALITIES | |
# --- App Setup ---------------------------------------------------------------
# CORSMiddleware allows the React frontend (running on localhost:5173)
# to call this API without being blocked by the browser's same-origin policy.
app = FastAPI(
    title="BraTS Segmentation API",
    description="3D U-Net brain tumor segmentation β BraTS2020",
    version="1.0.0",
)
# NOTE(review): allow_origins=["*"] is fine for local development but should
# be narrowed to the real frontend origin before any public deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Model Loading -----------------------------------------------------------
# Model is loaded once at startup and reused for every request.
# Loading per-request would be ~5 seconds of overhead each time.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint path resolves relative to the project root (one level above src/).
# Override with the CHECKPOINT_PATH env var (read from .env by load_dotenv()).
CHECKPOINT = Path(__file__).parent.parent / os.getenv("CHECKPOINT_PATH", "checkpoints/best_model.pth")
# Every uploaded volume is resampled to this cubic shape before inference.
TARGET = (128, 128, 128)
# Populated by load_model(); stays None until then so /health can report readiness.
model: UNet3D | None = None
@app.on_event("startup")
def load_model() -> None:
    """Build the UNet3D and load the best checkpoint into the module-global `model`.

    Registered as a FastAPI startup hook so the model is ready before the
    first request; previously nothing ever called this function, so `model`
    stayed None and every endpoint returned 503 "Model not loaded".

    Falls back to randomly initialized weights (with a warning) when no
    checkpoint file exists, so the API still starts for smoke testing.
    """
    global model
    model = UNet3D(in_channels=4, out_channels=4,
                   base_filters=32, depth=4).to(DEVICE)
    if CHECKPOINT.exists():
        ckpt = torch.load(str(CHECKPOINT), map_location=DEVICE)
        model.load_state_dict(ckpt["model_state_dict"])
        print(f"Loaded checkpoint from epoch {ckpt['epoch']} "
              f"best Dice: {ckpt['best_dice']:.4f}")
    else:
        print("WARNING: No checkpoint found - using random weights")
    # eval() disables dropout/batch-norm updates for deterministic inference.
    model.eval()
# --- Helpers -----------------------------------------------------------------
def load_nifti_bytes(content: bytes, filename: str) -> np.ndarray:
    """Decode raw NIfTI bytes into a float32 numpy volume.

    Supports .nii and .nii.gz — nibabel detects the format from the header.
    nibabel needs a real file path, so the bytes are spilled to a temp file,
    which is now always removed (finally), even when parsing fails — the
    original leaked the temp file whenever nib.load raised.

    Raises:
        HTTPException(400): the bytes are not a readable NIfTI volume.
    """
    try:
        import nibabel as nib
        import tempfile, os
        suffix = ".nii.gz" if filename.endswith(".gz") else ".nii"
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp.write(content)
            tmp_path = tmp.name
        try:
            # get_fdata() fully materializes the array, so the temp file can
            # be deleted immediately afterwards.
            return nib.load(tmp_path).get_fdata().astype(np.float32)
        finally:
            os.unlink(tmp_path)
    except Exception as e:
        # Report the actual filename — was hard-coded to "(unknown)".
        raise HTTPException(status_code=400, detail=f"Failed to load {filename}: {e}")
def preprocess_volume(volumes: list[np.ndarray]) -> torch.Tensor:
    """Run the normalize -> crop -> resize pipeline on each modality and
    stack the results into a single (1, 4, 128, 128, 128) float batch."""
    processed = [
        resize_volume(crop_to_brain(normalize_modality(v)), target=TARGET, mode="trilinear")
        for v in volumes
    ]
    batch = np.stack(processed, axis=0)                   # (4, 128, 128, 128)
    return torch.from_numpy(batch).float().unsqueeze(0)   # (1, 4, 128, 128, 128)
def run_inference(input_tensor: torch.Tensor) -> np.ndarray:
    """Forward a (1, 4, D, H, W) batch through the global model and return
    a (D, H, W) uint8 label map with values in {0, 1, 2, 3}."""
    batch = input_tensor.to(DEVICE)
    with torch.no_grad():
        logits = model(batch)          # (1, 4, D, H, W)
        labels = logits.argmax(dim=1)[0]  # drop batch dim -> (D, H, W)
    return labels.cpu().numpy().astype(np.uint8)
| def build_response(pred: np.ndarray, volumes: list[np.ndarray] | None = None, demo: bool = False) -> dict: | |
| total = pred.size | |
| classes = {} | |
| class_names = {0: "Background", 1: "Necrotic Core", 2: "Edema", 3: "Enhancing Tumor"} | |
| class_colors = {0: [0,0,0,0], 1: [255,50,20,200], 2: [0,220,80,200], 3: [255,220,0,200]} | |
| for label in range(4): | |
| count = int((pred == label).sum()) | |
| classes[str(label)] = { | |
| "name": class_names[label], | |
| "voxels": count, | |
| "percentage": round(100 * count / total, 2), | |
| "color": class_colors[label], | |
| } | |
| regions = { | |
| "WT": int((pred > 0).sum()), | |
| "TC": int(np.isin(pred, [1, 3]).sum()), | |
| "ET": int((pred == 3).sum()), | |
| } | |
| h, w, d = pred.shape | |
| # Segmentation slices | |
| slices = { | |
| "axial": pred[:, :, d // 2].tolist(), | |
| "coronal": pred[:, w // 2, :].tolist(), | |
| "sagittal": pred[h // 2, :, :].tolist(), | |
| } | |
| # MRI slices β normalize each modality to 0-255 for display | |
| # FLAIR (index 0) is best for showing tumor context | |
| mri_slices = {} | |
| if volumes is not None: | |
| flair = volumes[0] # FLAIR is most informative for tumor visualization | |
| # Normalize to 0β255 for frontend rendering | |
| flair_min, flair_max = flair.min(), flair.max() | |
| flair_norm = ((flair - flair_min) / (flair_max - flair_min + 1e-8) * 255).astype(np.uint8) | |
| mri_slices = { | |
| "axial": flair_norm[:, :, d // 2].tolist(), | |
| "coronal": flair_norm[:, w // 2, :].tolist(), | |
| "sagittal": flair_norm[h // 2, :, :].tolist(), | |
| } | |
| return { | |
| "success": True, | |
| "demo": demo, | |
| "shape": list(pred.shape), | |
| "tumor_burden_%": round(100 * (pred > 0).sum() / total, 3), | |
| "classes": classes, | |
| "regions": regions, | |
| "slices": slices, | |
| "mri_slices": mri_slices, | |
| } | |
# --- Endpoints ---------------------------------------------------------------
@app.get("/health")
def health() -> dict:
    """Readiness probe — the frontend calls this on load to check the model.

    The route decorator was missing, so the GET /health endpoint advertised
    in the module docstring was never actually registered.
    """
    return {
        "status": "ok",
        "device": str(DEVICE),
        "model_loaded": model is not None,
        "checkpoint_found": CHECKPOINT.exists(),
    }
@app.post("/segment")
async def segment(
    flair: UploadFile = File(...),
    t1: UploadFile = File(...),
    t1ce: UploadFile = File(...),
    t2: UploadFile = File(...),
):
    """Segment a study from its four uploaded NIfTI modalities.

    Fixes:
      * the route decorator was missing, so POST /segment was unreachable;
      * the normalize -> crop -> resize pipeline previously ran twice per
        modality (once inside preprocess_volume, once for visualization) —
        it now runs exactly once and the result is reused for both.

    Raises:
        HTTPException(503): model not loaded yet.
        HTTPException(400): an upload is not a readable NIfTI volume
            (propagated from load_nifti_bytes).
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # Upload order matches the channel order the model was trained with.
    volumes = []
    for upload in (flair, t1, t1ce, t2):
        content = await upload.read()
        volumes.append(load_nifti_bytes(content, upload.filename))
    # Preprocess each modality once; reuse for both the model input tensor
    # and the visualization slices in the response.
    preprocessed_vols = []
    for vol in volumes:
        v = normalize_modality(vol)
        v = crop_to_brain(v)
        v = resize_volume(v, target=TARGET, mode="trilinear")
        preprocessed_vols.append(v)
    tensor = torch.from_numpy(np.stack(preprocessed_vols, axis=0)).float().unsqueeze(0)
    pred = run_inference(tensor)
    return JSONResponse(build_response(pred, volumes=preprocessed_vols, demo=False))
@app.post("/segment/demo")
def segment_demo():
    """Run inference on a synthetic random volume — no file upload needed.

    Useful for exercising the frontend without real patient data. The route
    decorator was missing, so the POST /segment/demo endpoint advertised in
    the module docstring did not previously exist.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    synthetic = torch.randn(1, 4, 128, 128, 128)
    pred = run_inference(synthetic)
    return JSONResponse(build_response(pred, volumes=None, demo=True))