File size: 8,513 Bytes
2f33c28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5821a9e
2f33c28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
"""
inference.py β€” FastAPI Backend for BraTS Segmentation
=======================================================
Loads the trained UNet3D checkpoint and serves predictions via HTTP.

Endpoints:
    GET  /health              β€” model status
    POST /segment             β€” run segmentation on uploaded NIfTI files
    POST /segment/demo        β€” run on a synthetic volume (no upload needed)

Run:
    cd src
    uvicorn inference:app --host 0.0.0.0 --port 8000 --reload
"""

import io
import sys
import numpy as np
import torch
from pathlib import Path
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from dotenv import load_dotenv
import os
load_dotenv()
# Make sure src/ is on the path when running from project root
sys.path.append(str(Path(__file__).parent))

from model   import UNet3D
from dataset import normalize_modality, crop_to_brain, resize_volume, MODALITIES


# ─── App Setup ────────────────────────────────────────────────────────────────
# CORSMiddleware allows the React frontend (running on localhost:5173)
# to call this API without being blocked by the browser's same-origin policy.
# NOTE(review): allow_origins=["*"] accepts requests from ANY origin — fine
# for local development, but should be restricted to the actual frontend
# origin before deploying anywhere public.

app = FastAPI(
    title="BraTS Segmentation API",
    description="3D U-Net brain tumor segmentation β€” BraTS2020",
    version="1.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# ─── Model Loading ────────────────────────────────────────────────────────────
# Model is loaded once at startup and reused for every request.
# Loading per-request would be ~5 seconds of overhead each time.

# Inference device: GPU when available, otherwise CPU.
DEVICE     = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint path resolves relative to the project root (one level above
# this file); overridable via the CHECKPOINT_PATH environment variable.
CHECKPOINT = Path(__file__).parent.parent / os.getenv("CHECKPOINT_PATH", "checkpoints/best_model.pth")
# Spatial size every modality is resized to before stacking into the model input.
TARGET     = (128, 128, 128)

# Populated by load_model() on startup; None until then (checked by endpoints).
model: UNet3D | None = None


@app.on_event("startup")
def load_model():
    """Build the UNet3D and restore weights from CHECKPOINT if it exists.

    Runs once when the server starts; the network stays resident in the
    module-level `model` for the life of the process. Falls back to
    randomly-initialized weights when no checkpoint file is found.
    """
    global model
    net = UNet3D(in_channels=4, out_channels=4,
                 base_filters=32, depth=4)
    model = net.to(DEVICE)

    if not CHECKPOINT.exists():
        print("⚠️  No checkpoint found β€” using random weights")
    else:
        ckpt = torch.load(str(CHECKPOINT), map_location=DEVICE)
        model.load_state_dict(ckpt["model_state_dict"])
        print(f"βœ… Loaded checkpoint from epoch {ckpt['epoch']}  "
              f"best Dice: {ckpt['best_dice']:.4f}")

    model.eval()


# ─── Helpers ──────────────────────────────────────────────────────────────────
# Converts a raw NIfTI bytes object into a preprocessed numpy array.
# Supports .nii and .nii.gz β€” nibabel detects format from the header.

def load_nifti_bytes(content: bytes, filename: str) -> np.ndarray:
    """Decode uploaded NIfTI bytes into a float32 numpy volume.

    nibabel needs a real file path, so the payload is written to a
    temporary file first. Supports .nii and .nii.gz β€” nibabel detects
    the format from the header. The temp file is always removed, even
    when loading fails (the original leaked it on error).

    Raises:
        HTTPException: 400 with the offending filename when the payload
            cannot be parsed as NIfTI.
    """
    try:
        import nibabel as nib
        import tempfile, os
        suffix = ".nii.gz" if filename.endswith(".gz") else ".nii"
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp.write(content)
            tmp_path = tmp.name
        try:
            vol = nib.load(tmp_path).get_fdata().astype(np.float32)
        finally:
            os.unlink(tmp_path)  # clean up even if nib.load raises
        return vol
    except Exception as e:
        # Bug fix: the original message was the literal "(unknown)" β€”
        # the filename was clearly meant to be interpolated here.
        raise HTTPException(status_code=400, detail=f"Failed to load {filename}: {e}")


def preprocess_volume(volumes: list[np.ndarray]) -> torch.Tensor:
    """Run the training-time pipeline on each modality and batch the result.

    Each volume goes through normalize -> crop-to-brain -> resize to TARGET,
    then the four modalities are stacked and a batch dimension is added,
    yielding a (1, 4, 128, 128, 128) float tensor.
    """
    def pipeline(vol: np.ndarray) -> np.ndarray:
        return resize_volume(crop_to_brain(normalize_modality(vol)),
                             target=TARGET, mode="trilinear")

    stacked = np.stack([pipeline(v) for v in volumes], axis=0)  # (4, 128, 128, 128)
    return torch.from_numpy(stacked).float().unsqueeze(0)       # (1, 4, 128, 128, 128)


def run_inference(input_tensor: torch.Tensor) -> np.ndarray:
    """Forward pass + per-voxel argmax.

    Takes a (1, 4, 128, 128, 128) input and returns a (128, 128, 128)
    uint8 label map with values in {0, 1, 2, 3}.
    """
    with torch.no_grad():  # inference only β€” no autograd bookkeeping
        logits = model(input_tensor.to(DEVICE))      # (1, 4, 128, 128, 128)
        labels = logits.argmax(dim=1).squeeze(0)     # (128, 128, 128)
    return labels.cpu().numpy().astype(np.uint8)


def build_response(pred: np.ndarray, volumes: list[np.ndarray] | None = None, demo: bool = False) -> dict:
    total = pred.size
    classes = {}
    class_names  = {0: "Background", 1: "Necrotic Core", 2: "Edema", 3: "Enhancing Tumor"}
    class_colors = {0: [0,0,0,0], 1: [255,50,20,200], 2: [0,220,80,200], 3: [255,220,0,200]}

    for label in range(4):
        count = int((pred == label).sum())
        classes[str(label)] = {
            "name":       class_names[label],
            "voxels":     count,
            "percentage": round(100 * count / total, 2),
            "color":      class_colors[label],
        }

    regions = {
        "WT": int((pred > 0).sum()),
        "TC": int(np.isin(pred, [1, 3]).sum()),
        "ET": int((pred == 3).sum()),
    }

    h, w, d = pred.shape

    # Segmentation slices
    slices = {
        "axial":    pred[:, :, d // 2].tolist(),
        "coronal":  pred[:, w // 2, :].tolist(),
        "sagittal": pred[h // 2, :, :].tolist(),
    }

    # MRI slices β€” normalize each modality to 0-255 for display
    # FLAIR (index 0) is best for showing tumor context
    mri_slices = {}
    if volumes is not None:
        flair = volumes[0]   # FLAIR is most informative for tumor visualization
        # Normalize to 0–255 for frontend rendering
        flair_min, flair_max = flair.min(), flair.max()
        flair_norm = ((flair - flair_min) / (flair_max - flair_min + 1e-8) * 255).astype(np.uint8)
        mri_slices = {
            "axial":    flair_norm[:, :, d // 2].tolist(),
            "coronal":  flair_norm[:, w // 2, :].tolist(),
            "sagittal": flair_norm[h // 2, :, :].tolist(),
        }

    return {
        "success":        True,
        "demo":           demo,
        "shape":          list(pred.shape),
        "tumor_burden_%": round(100 * (pred > 0).sum() / total, 3),
        "classes":        classes,
        "regions":        regions,
        "slices":         slices,
        "mri_slices":     mri_slices,
    }

# ─── Endpoints ────────────────────────────────────────────────────────────────

@app.get("/health")
def health():
    """Readiness probe β€” polled by the frontend on page load."""
    return dict(
        status="ok",
        device=str(DEVICE),
        model_loaded=model is not None,
        checkpoint_found=CHECKPOINT.exists(),
    )


@app.post("/segment")
async def segment(
    flair: UploadFile = File(...),
    t1:    UploadFile = File(...),
    t1ce:  UploadFile = File(...),
    t2:    UploadFile = File(...),
):
    """Segment a full 4-modality BraTS study from uploaded NIfTI files.

    Expects one upload per modality in FLAIR, T1, T1ce, T2 order (matching
    the channel order the model was trained on). Returns the JSON payload
    from build_response, including mid-slice views of the preprocessed FLAIR.

    Raises:
        HTTPException: 503 when the model has not finished loading,
            400 when an upload cannot be parsed as NIfTI.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    volumes = []
    for upload in (flair, t1, t1ce, t2):
        content = await upload.read()
        volumes.append(load_nifti_bytes(content, upload.filename))

    # Fix: run the normalize -> crop -> resize pipeline ONCE per modality
    # and reuse the result for both the model input and the visualization.
    # (Previously preprocess_volume ran the full pipeline, then the same
    # three steps were repeated per volume just to build the display data.)
    preprocessed_vols = []
    for vol in volumes:
        v = normalize_modality(vol)
        v = crop_to_brain(v)
        v = resize_volume(v, target=TARGET, mode="trilinear")
        preprocessed_vols.append(v)

    # Stack to (1, 4, 128, 128, 128) β€” same layout preprocess_volume builds
    tensor = torch.from_numpy(np.stack(preprocessed_vols, axis=0)).float().unsqueeze(0)

    pred = run_inference(tensor)
    return JSONResponse(build_response(pred, volumes=preprocessed_vols, demo=False))


@app.post("/segment/demo")
def segment_demo():
    """Segment a synthetic random volume.

    Lets the frontend be exercised end-to-end with no file upload and
    no real patient data.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    pred = run_inference(torch.randn(1, 4, 128, 128, 128))
    return JSONResponse(build_response(pred, volumes=None, demo=True))