"""
inference.py — FastAPI Backend for BraTS Segmentation
=======================================================
Loads the trained UNet3D checkpoint and serves predictions via HTTP.

Endpoints:
    GET  /health        -> model status
    POST /segment       -> run segmentation on uploaded NIfTI files
    POST /segment/demo  -> run on a synthetic volume (no upload needed)

Run:
    cd src
    uvicorn inference:app --host 0.0.0.0 --port 8000 --reload
"""
import io
import sys
import numpy as np
import torch
from pathlib import Path
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from dotenv import load_dotenv
import os
load_dotenv()
# Make sure src/ is on the path when running from project root
sys.path.append(str(Path(__file__).parent))
from model import UNet3D
from dataset import normalize_modality, crop_to_brain, resize_volume, MODALITIES
# ─── App Setup ────────────────────────────────────────────────────────────────
# CORSMiddleware allows the React frontend (running on localhost:5173)
# to call this API without being blocked by the browser's same-origin policy.
app = FastAPI(
    title="BraTS Segmentation API",
    description="3D U-Net brain tumor segmentation β BraTS2020",
    version="1.0.0",
)
# NOTE(review): wildcard origins/methods/headers is wide open — fine for local
# development, but should be restricted before any production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ─── Model Loading ────────────────────────────────────────────────────────────
# Model is loaded once at startup and reused for every request.
# Loading per-request would be ~5 seconds of overhead each time.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint path resolves relative to the project root (one level above src/);
# override the relative location with the CHECKPOINT_PATH env var (.env supported).
CHECKPOINT = Path(__file__).parent.parent / os.getenv("CHECKPOINT_PATH", "checkpoints/best_model.pth")
# Spatial size fed to the network — presumably the training resolution; confirm.
TARGET = (128, 128, 128)
# Populated by the startup hook below; stays None until the model is built.
model: UNet3D | None = None
@app.on_event("startup")
def load_model() -> None:
    """Build the UNet3D and load the checkpoint once at process startup.

    Runs on FastAPI's startup event so every request reuses the same model
    instance. If no checkpoint file exists, the server still starts but
    serves predictions from randomly initialized weights (logged loudly).
    """
    global model
    model = UNet3D(in_channels=4, out_channels=4,
                   base_filters=32, depth=4).to(DEVICE)
    if CHECKPOINT.exists():
        ckpt = torch.load(str(CHECKPOINT), map_location=DEVICE)
        model.load_state_dict(ckpt["model_state_dict"])
        # Original log message was corrupted by a mojibake'd emoji that split
        # the f-string across lines; restored as plain ASCII.
        print(f"Loaded checkpoint from epoch {ckpt['epoch']} "
              f"best Dice: {ckpt['best_dice']:.4f}")
    else:
        print("WARNING: No checkpoint found - using random weights")
    # eval() disables dropout/batch-norm updates for deterministic inference.
    model.eval()
# ─── Helpers ──────────────────────────────────────────────────────────────────
def load_nifti_bytes(content: bytes, filename: str) -> np.ndarray:
    """Decode raw NIfTI bytes into a float32 numpy volume.

    Supports .nii and .nii.gz — the suffix is chosen from *filename* so
    nibabel picks the right codec. The payload is written to a temp file
    because nibabel loads from paths, not buffers.

    Raises:
        HTTPException(400): if the bytes cannot be parsed as NIfTI.
    """
    try:
        import nibabel as nib
        import os
        import tempfile

        suffix = ".nii.gz" if filename.endswith(".gz") else ".nii"
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp.write(content)
            tmp_path = tmp.name
        try:
            vol = nib.load(tmp_path).get_fdata().astype(np.float32)
        finally:
            # Always remove the temp file — previously it leaked whenever
            # nib.load raised on a malformed upload.
            os.unlink(tmp_path)
        return vol
    except Exception as e:
        # Report which upload failed instead of the former "(unknown)" placeholder.
        raise HTTPException(status_code=400, detail=f"Failed to load {filename}: {e}")
def preprocess_volume(volumes: list[np.ndarray]) -> torch.Tensor:
    """Run the full preprocessing pipeline on each modality and batch them.

    Each volume goes through normalize -> crop-to-brain -> trilinear resize,
    then the four modalities are stacked and a batch dimension is prepended.

    Returns:
        Float tensor of shape (1, 4, 128, 128, 128).
    """
    prepped = [
        resize_volume(crop_to_brain(normalize_modality(vol)), target=TARGET, mode="trilinear")
        for vol in volumes
    ]
    batch = np.stack(prepped, axis=0)  # (4, 128, 128, 128)
    return torch.from_numpy(batch).float().unsqueeze(0)  # (1, 4, 128, 128, 128)
def run_inference(input_tensor: torch.Tensor) -> np.ndarray:
    """Forward a (1, 4, D, H, W) batch through the model.

    Returns:
        (D, H, W) uint8 label map with values in {0, 1, 2, 3}.
    """
    batch = input_tensor.to(DEVICE)
    with torch.no_grad():
        logits = model(batch)  # (1, 4, D, H, W)
    labels = logits.argmax(dim=1).squeeze(0)  # drop batch dim -> (D, H, W)
    return labels.cpu().numpy().astype(np.uint8)
def build_response(pred: np.ndarray, volumes: list[np.ndarray] | None = None, demo: bool = False) -> dict:
total = pred.size
classes = {}
class_names = {0: "Background", 1: "Necrotic Core", 2: "Edema", 3: "Enhancing Tumor"}
class_colors = {0: [0,0,0,0], 1: [255,50,20,200], 2: [0,220,80,200], 3: [255,220,0,200]}
for label in range(4):
count = int((pred == label).sum())
classes[str(label)] = {
"name": class_names[label],
"voxels": count,
"percentage": round(100 * count / total, 2),
"color": class_colors[label],
}
regions = {
"WT": int((pred > 0).sum()),
"TC": int(np.isin(pred, [1, 3]).sum()),
"ET": int((pred == 3).sum()),
}
h, w, d = pred.shape
# Segmentation slices
slices = {
"axial": pred[:, :, d // 2].tolist(),
"coronal": pred[:, w // 2, :].tolist(),
"sagittal": pred[h // 2, :, :].tolist(),
}
# MRI slices β normalize each modality to 0-255 for display
# FLAIR (index 0) is best for showing tumor context
mri_slices = {}
if volumes is not None:
flair = volumes[0] # FLAIR is most informative for tumor visualization
# Normalize to 0β255 for frontend rendering
flair_min, flair_max = flair.min(), flair.max()
flair_norm = ((flair - flair_min) / (flair_max - flair_min + 1e-8) * 255).astype(np.uint8)
mri_slices = {
"axial": flair_norm[:, :, d // 2].tolist(),
"coronal": flair_norm[:, w // 2, :].tolist(),
"sagittal": flair_norm[h // 2, :, :].tolist(),
}
return {
"success": True,
"demo": demo,
"shape": list(pred.shape),
"tumor_burden_%": round(100 * (pred > 0).sum() / total, 3),
"classes": classes,
"regions": regions,
"slices": slices,
"mri_slices": mri_slices,
}
# ─── Endpoints ────────────────────────────────────────────────────────────────
@app.get("/health")
def health():
    """Readiness probe — the frontend polls this on load to confirm the model is up."""
    return dict(
        status="ok",
        device=str(DEVICE),
        model_loaded=model is not None,
        checkpoint_found=CHECKPOINT.exists(),
    )
@app.post("/segment")
async def segment(
    flair: UploadFile = File(...),
    t1: UploadFile = File(...),
    t1ce: UploadFile = File(...),
    t2: UploadFile = File(...),
):
    """Segment a brain volume from four uploaded NIfTI modalities.

    Expects FLAIR, T1, T1ce, and T2 as multipart file uploads (.nii/.nii.gz).
    Returns the JSON payload produced by build_response.

    Raises:
        HTTPException(503): model not yet loaded.
        HTTPException(400): an upload could not be parsed as NIfTI.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Decode each upload into a raw float32 volume. Order matters: it must
    # match the channel order the network was trained with.
    uploads = [flair, t1, t1ce, t2]
    volumes = []
    for upload in uploads:
        content = await upload.read()
        volumes.append(load_nifti_bytes(content, upload.filename))

    # Preprocess each modality exactly once (normalize -> crop -> resize) and
    # reuse the result for both the input tensor and the visualization slices.
    # Previously this pipeline ran twice per modality — once inside
    # preprocess_volume and again for the viz copies — doubling work and memory.
    preprocessed_vols = []
    for vol in volumes:
        v = normalize_modality(vol)
        v = crop_to_brain(v)
        v = resize_volume(v, target=TARGET, mode="trilinear")
        preprocessed_vols.append(v)

    tensor = torch.from_numpy(np.stack(preprocessed_vols, axis=0)).float().unsqueeze(0)
    pred = run_inference(tensor)
    return JSONResponse(build_response(pred, volumes=preprocessed_vols, demo=False))
@app.post("/segment/demo")
def segment_demo():
    """Run inference on a synthetic random volume — no file upload needed.

    Useful for exercising the frontend without real patient data.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    demo_input = torch.randn(1, 4, 128, 128, 128)
    return JSONResponse(build_response(run_inference(demo_input), volumes=None, demo=True))