| """ |
| dataset.py |
| ---------- |
| PyTorch Dataset class for the OSF Ti-64 SEM fractography dataset. |
| Use this after running inspect_dataset.py to confirm your mask format. |
| |
| Key decisions you may need to make after inspection: |
| - If masks are binary (0/255): set NUM_CLASSES=2, update MASK_SCALE |
| - If masks are RGB color: set COLOR_MASK=True and define COLOR_TO_LABEL |
| - If masks are integer labels (0..N): use as-is (ideal case) |
| |
| Usage: |
| from dataset import FractographyDataset |
| ds = FractographyDataset("data/", split="train") |
| img, mask = ds[0] |
| """ |
|
|
| from pathlib import Path |
| from typing import Callable, Optional |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from torch.utils.data import Dataset, DataLoader, random_split |
| import torchvision.transforms.functional as TF |
| import random |
|
|
|
|
| |
NUM_CLASSES = 2          # binary segmentation: background (0) vs. fracture feature (1)
IMAGE_SIZE = (512, 512)  # (H, W) resize target applied to both images and masks
MASK_SCALE = 255         # grayscale masks stored as 0/255; floor-div collapses to {0, 1}

COLOR_MASK = False       # set True if inspect_dataset.py shows RGB color-coded masks
# Maps an (R, G, B) tuple to an integer class index; only consulted when COLOR_MASK.
COLOR_TO_LABEL: dict[tuple, int] = {
    # e.g. (255, 0, 0): 1,
}
| |
|
|
|
|
def rgb_mask_to_label(mask_rgb: np.ndarray, color_to_label: dict) -> np.ndarray:
    """Convert an H×W×3 RGB mask into an H×W integer label map.

    Pixels whose color matches no entry in ``color_to_label`` remain 0.
    """
    label_map = np.zeros(mask_rgb.shape[:2], dtype=np.int64)
    for rgb, cls_idx in color_to_label.items():
        hit = (mask_rgb == np.asarray(rgb)).all(axis=-1)
        label_map[hit] = cls_idx
    return label_map
|
|
|
|
class FractographyDataset(Dataset):
    """
    OSF Ti-64 SEM Fractography Dataset.

    Expects a layout of ``<root>/**/images_8bit/*`` with a sibling
    ``masks_8bit/`` directory holding a same-named mask for every image.

    Args:
        data_dir: Root of downloaded data (contains subfolders with images/ + masks/).
        split: "train", "val", or "all" (no splitting, returns everything).
        transform: Optional callable applied to both image and mask (augmentation).
        image_size: Resize target (H, W).

    Raises:
        FileNotFoundError: if no image/mask pairs are found under ``data_dir``.
    """

    # Accepted raster formats (extensions matched case-insensitively).
    IMAGE_EXTS = {".png", ".tif", ".tiff", ".jpg", ".jpeg"}

    def __init__(
        self,
        data_dir: str | Path,
        split: str = "all",
        transform: Optional[Callable] = None,
        image_size: tuple[int, int] = IMAGE_SIZE,
    ):
        self.data_dir = Path(data_dir)
        self.split = split
        self.transform = transform
        self.image_size = image_size
        self.pairs = self._find_pairs()

        if not self.pairs:
            raise FileNotFoundError(
                f"No image/mask pairs found in {self.data_dir}. "
                "Run inspect_dataset.py to diagnose."
            )

    def _find_pairs(self) -> list[tuple[Path, Path]]:
        """Collect (image_path, mask_path) pairs; sorted for deterministic order."""
        pairs = []
        for images_dir in sorted(self.data_dir.rglob("images_8bit")):
            if not images_dir.is_dir():
                continue
            masks_dir = images_dir.parent / "masks_8bit"
            if not masks_dir.exists():
                continue
            for img_path in sorted(images_dir.iterdir()):
                if img_path.suffix.lower() not in self.IMAGE_EXTS:
                    continue
                # Mask must share the image's exact filename.
                mask_path = masks_dir / img_path.name
                if mask_path.exists():
                    pairs.append((img_path, mask_path))
        return pairs

    def _load_image(self, path: Path) -> torch.Tensor:
        """Load, resize, and ImageNet-normalize an image -> float32 CxHxW tensor."""
        # Context manager closes the file handle promptly; a bare Image.open()
        # keeps the descriptor alive and can leak handles across many
        # __getitem__ calls (Pillow keeps files lazily open).
        with Image.open(path) as pil_img:
            img = pil_img.convert("RGB")
            # PIL's resize takes (W, H); image_size is stored as (H, W).
            img = img.resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)
            arr = np.array(img, dtype=np.float32) / 255.0
        # ImageNet statistics — the usual choice for pretrained encoders.
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        arr = (arr - mean) / std
        return torch.from_numpy(arr).permute(2, 0, 1).float()

    def _load_mask(self, path: Path) -> torch.Tensor:
        """Load a mask as an HxW int64 label tensor, resized with NEAREST."""
        with Image.open(path) as mask_pil:
            if COLOR_MASK:
                mask_arr = np.array(mask_pil.convert("RGB"))
                mask_arr = rgb_mask_to_label(mask_arr, COLOR_TO_LABEL)
            else:
                mask_arr = np.array(mask_pil.convert("L"), dtype=np.int64)
                if MASK_SCALE > 1:
                    # Binary masks stored as 0/255 collapse to {0, 1}.
                    # NOTE(review): assumes pixels are exactly 0 or MASK_SCALE;
                    # intermediate gray values floor to 0 — confirm via inspect_dataset.py.
                    mask_arr = mask_arr // MASK_SCALE

        # NEAREST keeps labels integral (no interpolation between class ids).
        # The uint8 round-trip assumes class indices fit in 0..255.
        mask_pil_resized = Image.fromarray(mask_arr.astype(np.uint8)).resize(
            (self.image_size[1], self.image_size[0]), Image.NEAREST
        )
        mask_arr = np.array(mask_pil_resized, dtype=np.int64)
        return torch.from_numpy(mask_arr).long()

    def _augment(self, image: torch.Tensor, mask: torch.Tensor):
        """Shared spatial augmentations (applied identically to image and mask)."""
        # Random horizontal flip.
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask.unsqueeze(0)).squeeze(0)

        # Random vertical flip.
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask.unsqueeze(0)).squeeze(0)

        # Random 90-degree rotation (k quarter-turns).
        k = random.choice([0, 1, 2, 3])
        if k:
            image = torch.rot90(image, k, dims=[1, 2])
            mask = torch.rot90(mask, k, dims=[0, 1])

        return image, mask

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (image, mask): normalized float CxHxW image, int64 HxW mask."""
        img_path, mask_path = self.pairs[idx]
        image = self._load_image(img_path)
        mask = self._load_mask(mask_path)

        # Built-in augmentation only for the train split when no custom
        # transform is supplied; a user transform replaces it entirely and is
        # applied regardless of split.
        if self.split == "train" and self.transform is None:
            image, mask = self._augment(image, mask)
        elif self.transform is not None:
            image, mask = self.transform(image, mask)

        return image, mask

    def __repr__(self) -> str:
        return (
            f"FractographyDataset("
            f"n={len(self)}, split='{self.split}', "
            f"image_size={self.image_size}, classes={NUM_CLASSES})"
        )
|
|
|
|
def get_dataloaders(
    data_dir: str | Path,
    batch_size: int = 4,
    train_frac: float = 0.8,
    num_workers: int = 2,
) -> tuple[DataLoader, DataLoader]:
    """
    Return (train_loader, val_loader) with a random train/val split.

    Two independent dataset instances are created — one with split="train"
    (augmentation on), one with split="val" (augmentation off) — and sliced
    by a shared index permutation. The previous approach of random_split on a
    single instance followed by ``train_ds.dataset.split = "train"`` mutated
    the SHARED underlying dataset, so validation samples were augmented too.

    Args:
        data_dir: Root of the downloaded data.
        batch_size: Samples per batch for both loaders.
        train_frac: Fraction of samples assigned to the train split.
        num_workers: DataLoader worker processes.
    """
    from torch.utils.data import Subset  # local import: not in module header

    train_base = FractographyDataset(data_dir, split="train")
    val_base = FractographyDataset(data_dir, split="val")

    n_total = len(train_base)
    n_train = int(n_total * train_frac)

    # One shared permutation keeps the two subsets disjoint.
    perm = torch.randperm(n_total).tolist()
    train_ds = Subset(train_base, perm[:n_train])
    val_ds = Subset(val_base, perm[n_train:])

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )
    print(f"Train: {len(train_ds)} samples | Val: {len(val_ds)} samples")
    return train_loader, val_loader
|
|
|
|
| |
if __name__ == "__main__":
    # Smoke test: load the first sample and print tensor stats.
    import sys

    data_dir = sys.argv[1] if len(sys.argv) > 1 else "data"

    try:
        ds = FractographyDataset(data_dir)
        print(ds)
        img, mask = ds[0]
        print(f"Image tensor: {img.shape} dtype={img.dtype} range=[{img.min():.2f}, {img.max():.2f}]")
        print(f"Mask tensor: {mask.shape} dtype={mask.dtype} unique={mask.unique().tolist()}")
        # Restored: the success/failure markers were mojibake-garbled, which
        # split this string literal across two lines (a syntax error).
        print("\n✅ Dataset loads correctly.")
    except FileNotFoundError as e:
        print(f"❌ {e}")
|
|