brats-segmentation / src /dataset.py
farahhamad's picture
Add BraTS2020 segmentation pipeline - UNet3D, FastAPI backend, React frontend, 110 epochs Mean Dice 0.557
2f33c28
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import nibabel as nib
from pathlib import Path
# Modality names used to build each case's file names
# ("<case_id>_<modality>.nii"); the dataset stacks the volumes in this order
# along the first axis of the returned image tensor.
MODALITIES = ["flair", "t1", "t1ce", "t2"]
# Default spatial size that resize_volume resamples a volume to.
TARGET_SHAPE = (128, 128, 128)
# ─── Preprocessing Functions ──────────────────────────────────────────────────
def normalize_modality(vol: np.ndarray) -> np.ndarray:
    """Z-score normalize the strictly positive voxels of a volume.

    Voxels with value > 0 are treated as foreground. Their mean and standard
    deviation are computed and used to standardize them; every other voxel
    is set to 0 in the output.

    Args:
        vol: Input volume of any numeric dtype and shape.

    Returns:
        A new float32 array with the same shape as ``vol``. If ``vol`` has no
        positive voxel, its values are returned unchanged (cast to float32).
    """
    brain_mask = vol > 0
    if not brain_mask.any():
        # Nothing to normalize; still honour the float32 return contract.
        return vol.astype(np.float32)

    brain_voxels = vol[brain_mask]
    mu = brain_voxels.mean()
    std = brain_voxels.std() + 1e-8  # epsilon guards against constant input

    # Allocate the output as float32 explicitly: np.zeros_like(vol) would
    # inherit an integer dtype from integer inputs and truncate the z-scores.
    normalized = np.zeros(vol.shape, dtype=np.float32)
    normalized[brain_mask] = (brain_voxels - mu) / std
    return normalized
def crop_to_brain(vol: np.ndarray) -> np.ndarray:
    """Crop a 3-D volume to the bounding box of its voxels greater than zero.

    Args:
        vol: Volume to crop; the first three axes are cropped.

    Returns:
        The sub-volume spanning the smallest axis-aligned box that contains
        every voxel > 0, or ``vol`` itself when no voxel is positive.
    """
    foreground = np.nonzero(vol > 0)
    if foreground[0].size == 0:
        return vol

    # Per-axis inclusive start and exclusive stop of the foreground extent.
    starts = [axis_idx.min() for axis_idx in foreground]
    stops = [axis_idx.max() + 1 for axis_idx in foreground]

    x_span = slice(starts[0], stops[0])
    y_span = slice(starts[1], stops[1])
    z_span = slice(starts[2], stops[2])
    return vol[x_span, y_span, z_span]
def resize_volume(vol: np.ndarray, target=TARGET_SHAPE,
                  mode="trilinear") -> np.ndarray:
    """Resample a 3-D volume to ``target`` with ``F.interpolate``.

    Args:
        vol: 3-D array to resample; it is converted to float32 first.
        target: Output spatial size as a 3-tuple.
        mode: Interpolation mode passed to ``F.interpolate``. Use
            ``"nearest"`` for label maps so no new label values are created.

    Returns:
        A float32 array whose shape is exactly ``target``.
    """
    # from_numpy rejects negative-stride views, so force a contiguous layout.
    # interpolate expects (batch, channel, D, H, W), hence the two new axes.
    tensor = torch.from_numpy(np.ascontiguousarray(vol)).float()[None, None]
    # align_corners is only accepted by the linear family of modes.
    kwargs = {"align_corners": True} if mode == "trilinear" else {}
    resized = F.interpolate(tensor, size=target, mode=mode, **kwargs)
    # Drop only the batch and channel axes; a bare squeeze() would also
    # remove any spatial axis of size 1 requested in ``target``.
    return resized[0, 0].numpy()
# ─── Dataset ──────────────────────────────────────────────────────────────────
class BraTSDataset(Dataset):
    """
    PyTorch Dataset for BraTS2020 training data.

    Case directories under ``root_dir`` are sorted, shuffled with a seeded
    generator, and divided into a training part (the first ``train_ratio``
    fraction) and a validation part (the remainder).

    All four modalities and the segmentation of a case are cropped with ONE
    shared bounding box so that channels and labels stay voxel-aligned.

    Returns per sample:
        images : float32 tensor (4, 128, 128, 128) — 4 modalities
        mask   : long tensor    (128, 128, 128)    — labels {0,1,2,3}
    """

    def __init__(self, root_dir: str, split: str = "train",
                 train_ratio: float = 0.8, seed: int = 42):
        """Collect the case directories and keep the requested split.

        Args:
            root_dir: Directory containing one sub-directory per case.
            split: ``"train"`` selects the training part; any other value
                selects the validation part.
            train_ratio: Fraction of cases assigned to the training part.
            seed: Seed of the shuffle, so both splits are reproducible and
                disjoint when built with the same arguments.
        """
        super().__init__()
        root = Path(root_dir)
        # Sort first so the seeded shuffle does not depend on listing order.
        cases = sorted([d for d in root.iterdir() if d.is_dir()])
        rng = np.random.default_rng(seed)
        rng.shuffle(cases)

        # Split into train / val
        n_train = int(len(cases) * train_ratio)
        if split == "train":
            self.cases = cases[:n_train]
        else:
            self.cases = cases[n_train:]
        self.split = split

    def __len__(self):
        """Return the number of cases in this split."""
        return len(self.cases)

    @staticmethod
    def _foreground_bbox(volumes):
        """Return one slice per axis bounding the voxels > 0 of any volume.

        All volumes must share the same shape. When no volume has a positive
        voxel, the slices cover the full extent (i.e. no cropping).
        """
        foreground = np.zeros(volumes[0].shape, dtype=bool)
        for vol in volumes:
            foreground |= vol > 0
        if not foreground.any():
            return tuple(slice(0, n) for n in foreground.shape)

        box = []
        for axis in range(foreground.ndim):
            other_axes = tuple(a for a in range(foreground.ndim) if a != axis)
            hits = np.flatnonzero(foreground.any(axis=other_axes))
            box.append(slice(int(hits[0]), int(hits[-1]) + 1))
        return tuple(box)

    def __getitem__(self, idx: int):
        """Load, crop, normalize and resize one case.

        Returns:
            ``(images, mask)`` as described in the class docstring.
        """
        case_dir = self.cases[idx]
        case_id = case_dir.name

        # ── Load all 4 modalities ────────────────────────────────────────────
        raw = []
        for mod in MODALITIES:
            path = case_dir / f"{case_id}_{mod}.nii"
            raw.append(nib.load(str(path)).get_fdata().astype(np.float32))

        # The box is computed on the RAW intensities (after z-scoring, "> 0"
        # would mean "above the mean", not "inside the brain") and is shared
        # by every modality and by the mask so they remain spatially aligned.
        bbox = self._foreground_bbox(raw)

        # ── Crop, normalize and resize each modality ─────────────────────────
        volumes = []
        for vol in raw:
            vol = np.ascontiguousarray(vol[bbox])
            # The box contains every positive voxel, so the normalization
            # statistics are the same as on the uncropped volume.
            vol = normalize_modality(vol)
            vol = resize_volume(vol, mode="trilinear")
            volumes.append(vol)

        # Stack: list of 4 × (128,128,128) → (4, 128, 128, 128)
        stacked = np.stack(volumes, axis=0)

        # ── Load and preprocess segmentation mask ────────────────────────────
        seg_path = case_dir / f"{case_id}_seg.nii"
        seg = nib.load(str(seg_path)).get_fdata().astype(np.uint8)
        seg[seg == 4] = 3                       # remap label 4 → 3
        seg = np.ascontiguousarray(seg[bbox])   # same crop as the images
        seg = resize_volume(seg, mode="nearest")  # nearest for labels
        seg = seg.astype(np.int64)

        # ── Convert to tensors ───────────────────────────────────────────────
        images = torch.from_numpy(stacked).float()   # (4, 128, 128, 128)
        mask = torch.from_numpy(seg).long()          # (128, 128, 128)
        return images, mask