Spaces:
Sleeping
Sleeping
Add BraTS2020 segmentation pipeline - UNet3D, FastAPI backend, React frontend, 110 epochs Mean Dice 0.557
2f33c28 | import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset | |
| import nibabel as nib | |
| from pathlib import Path | |
| MODALITIES = ["flair", "t1", "t1ce", "t2"] | |
| TARGET_SHAPE = (128, 128, 128) | |
| # βββ Preprocessing Functions ββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def normalize_modality(vol: np.ndarray) -> np.ndarray: | |
| brain_mask = vol > 0 | |
| if brain_mask.sum() == 0: | |
| return vol | |
| mu = vol[brain_mask].mean() | |
| std = vol[brain_mask].std() + 1e-8 | |
| normalized = np.zeros_like(vol) | |
| normalized[brain_mask] = (vol[brain_mask] - mu) / std | |
| return normalized.astype(np.float32) | |
| def crop_to_brain(vol: np.ndarray) -> np.ndarray: | |
| coords = np.array(np.where(vol > 0)) | |
| if coords.shape[1] == 0: | |
| return vol | |
| mins = coords.min(axis=1) | |
| maxs = coords.max(axis=1) + 1 | |
| return vol[mins[0]:maxs[0], | |
| mins[1]:maxs[1], | |
| mins[2]:maxs[2]] | |
| def resize_volume(vol: np.ndarray, target=TARGET_SHAPE, | |
| mode="trilinear") -> np.ndarray: | |
| tensor = torch.from_numpy(vol).float().unsqueeze(0).unsqueeze(0) | |
| kwargs = {"align_corners": True} if mode == "trilinear" else {} | |
| resized = F.interpolate(tensor, size=target, mode=mode, **kwargs) | |
| return resized.squeeze().numpy() | |
| # βββ Dataset ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class BraTSDataset(Dataset): | |
| """ | |
| PyTorch Dataset for BraTS2020 training data. | |
| Returns per sample: | |
| images : float32 tensor (4, 128, 128, 128) β 4 modalities | |
| mask : long tensor (128, 128, 128) β labels {0,1,2,3} | |
| """ | |
| def __init__(self, root_dir: str, split: str = "train", | |
| train_ratio: float = 0.8, seed: int = 42): | |
| root = Path(root_dir) | |
| cases = sorted([d for d in root.iterdir() if d.is_dir()]) | |
| rng = np.random.default_rng(seed) | |
| rng.shuffle(cases) | |
| # Split into train / val | |
| n_train = int(len(cases) * train_ratio) | |
| if split == "train": | |
| self.cases = cases[:n_train] | |
| else: | |
| self.cases = cases[n_train:] | |
| self.split = split | |
| def __len__(self): | |
| # DataLoader calls this to know how many batches to produce per epoch | |
| return len(self.cases) | |
| def __getitem__(self, idx: int): | |
| case_dir = self.cases[idx] | |
| case_id = case_dir.name | |
| # ββ Load and preprocess all 4 modalities βββββββββββββββββββββββββββββ | |
| volumes = [] | |
| for mod in MODALITIES: | |
| path = case_dir / f"{case_id}_{mod}.nii" | |
| vol = nib.load(str(path)).get_fdata().astype(np.float32) | |
| vol = normalize_modality(vol) | |
| vol = crop_to_brain(vol) | |
| vol = resize_volume(vol, mode="trilinear") | |
| volumes.append(vol) | |
| # Stack: list of 4 Γ (128,128,128) β (4, 128, 128, 128) | |
| stacked = np.stack(volumes, axis=0) | |
| # ββ Load and preprocess segmentation mask βββββββββββββββββββββββββββββ | |
| seg_path = case_dir / f"{case_id}_seg.nii" | |
| seg = nib.load(str(seg_path)).get_fdata().astype(np.uint8) | |
| seg[seg == 4] = 3 # remap label 4 β 3 | |
| seg = resize_volume(seg, mode="nearest") # nearest for labels | |
| seg = seg.astype(np.int64) | |
| # ββ Convert to tensors ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| images = torch.from_numpy(stacked).float() # (4, 128, 128, 128) | |
| mask = torch.from_numpy(seg).long() # (128, 128, 128) | |
| return images, mask |