"""
dataset.py
----------
PyTorch Dataset class for the OSF Ti-64 SEM fractography dataset.
Use this after running inspect_dataset.py to confirm your mask format.

Key decisions you may need to make after inspection:
  - If masks are binary (0/255): set NUM_CLASSES=2, update MASK_SCALE
  - If masks are RGB color: set COLOR_MASK=True and define COLOR_TO_LABEL
  - If masks are integer labels (0..N): use as-is (ideal case; set MASK_SCALE=1)

Usage:
    from dataset import FractographyDataset
    ds = FractographyDataset("data/", split="train")
    img, mask = ds[0]
"""
from pathlib import Path
from typing import Callable, Optional
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms.functional as TF
import random
# ── Config — update after running inspect_dataset.py ─────────────────────────
# Number of segmentation classes. In this file it is only reported by
# FractographyDataset.__repr__; it does not change how masks are loaded.
NUM_CLASSES = 2  # update once you know how many classes are in your masks
IMAGE_SIZE = (512, 512)  # resize target (H, W); SegFormer-b0 default input
# Grayscale mask values are floor-divided by this when it is > 1
# (e.g. 0/255 -> 0/1). Set to 1 to use integer-label masks unchanged.
MASK_SCALE = 255
# If masks use RGB color encoding instead of integer labels, set this to True
# and populate COLOR_TO_LABEL below.
COLOR_MASK = False
COLOR_TO_LABEL: dict[tuple, int] = {
    # (R, G, B): class_index
    # e.g. (255, 0, 0): 1,
}
# ─────────────────────────────────────────────────────────────────────────────
def rgb_mask_to_label(mask_rgb: np.ndarray, color_to_label: dict) -> np.ndarray:
    """Convert an H×W×3 RGB mask to an H×W integer label mask.

    Args:
        mask_rgb: Array whose last axis holds the colour channels.
        color_to_label: Mapping from an (R, G, B) tuple to a class index.

    Returns:
        An int64 array with the spatial shape of ``mask_rgb``. Pixels whose
        colour is not a key of ``color_to_label`` keep the value 0.
    """
    spatial_shape = mask_rgb.shape[:2]
    labels = np.zeros(spatial_shape, dtype=np.int64)
    for rgb, class_id in color_to_label.items():
        # A pixel is a hit only when every channel equals the reference colour.
        hits = (mask_rgb == np.array(rgb)).all(axis=-1)
        labels[hits] = class_id
    return labels
class FractographyDataset(Dataset):
    """
    OSF Ti-64 SEM Fractography Dataset.

    Image/mask pairs are discovered by searching ``data_dir`` recursively for
    directories named ``images_8bit`` that have a sibling ``masks_8bit``
    directory. An image is kept only when a mask with the same file name
    exists in that sibling directory.

    Args:
        data_dir: Root of downloaded data (searched recursively for
            images_8bit/ + masks_8bit/ sibling folders).
        split: "train", "val", or "all". This is a tag only: every discovered
            pair is returned whatever its value. "train" enables the built-in
            random flips/rotations when no ``transform`` is given.
        transform: Optional callable ``transform(image, mask) -> (image, mask)``
            applied to both tensors. When given, it is used instead of the
            built-in augmentation.
        image_size: Resize target (H, W).

    Raises:
        FileNotFoundError: If no image/mask pair is found under ``data_dir``.
    """

    IMAGE_EXTS = {".png", ".tif", ".tiff", ".jpg", ".jpeg"}

    # Per-channel normalization statistics (the standard ImageNet mean/std),
    # built once instead of on every image load.
    _NORM_MEAN = np.array([0.485, 0.456, 0.406])
    _NORM_STD = np.array([0.229, 0.224, 0.225])

    def __init__(
        self,
        data_dir: str | Path,
        split: str = "all",
        transform: Optional[Callable] = None,
        image_size: tuple[int, int] = IMAGE_SIZE,
    ):
        self.data_dir = Path(data_dir)
        self.split = split
        self.transform = transform
        self.image_size = image_size
        self.pairs = self._find_pairs()
        if not self.pairs:
            raise FileNotFoundError(
                f"No image/mask pairs found in {self.data_dir}. "
                "Run inspect_dataset.py to diagnose."
            )

    def _find_pairs(self) -> list[tuple[Path, Path]]:
        """Return sorted (image_path, mask_path) pairs found under ``data_dir``."""
        pairs = []
        for images_dir in sorted(self.data_dir.rglob("images_8bit")):
            if not images_dir.is_dir():
                continue
            masks_dir = images_dir.parent / "masks_8bit"
            if not masks_dir.is_dir():
                continue
            for img_path in sorted(images_dir.iterdir()):
                if img_path.suffix.lower() not in self.IMAGE_EXTS:
                    continue
                # The mask must carry exactly the same file name as the image.
                mask_path = masks_dir / img_path.name
                if mask_path.exists():
                    pairs.append((img_path, mask_path))
        return pairs

    def _load_image(self, path: Path) -> torch.Tensor:
        """Load an image as a normalized float32 tensor of shape (3, H, W)."""
        target_wh = (self.image_size[1], self.image_size[0])  # PIL wants (W, H)
        # The context manager closes the file handle deterministically;
        # convert() returns an in-memory copy that outlives the open file.
        with Image.open(path) as src:
            img = src.convert("RGB").resize(target_wh, Image.BILINEAR)
        arr = np.array(img, dtype=np.float32) / 255.0
        arr = (arr - self._NORM_MEAN) / self._NORM_STD
        return torch.from_numpy(arr).permute(2, 0, 1).float()

    def _load_mask(self, path: Path) -> torch.Tensor:
        """Load a mask as an int64 label tensor of shape (H, W)."""
        with Image.open(path) as src:
            if COLOR_MASK:
                mask_arr = rgb_mask_to_label(
                    np.array(src.convert("RGB")), COLOR_TO_LABEL
                )
            else:
                mask_arr = np.array(src.convert("L"), dtype=np.int64)
                # Scaling applies to grayscale masks only; colour masks are
                # already class indices at this point.
                if MASK_SCALE > 1:
                    mask_arr = mask_arr // MASK_SCALE  # e.g. 0/255 -> 0/1
        # NOTE: the uint8 round-trip through PIL limits labels to 0..255.
        mask_pil_resized = Image.fromarray(mask_arr.astype(np.uint8)).resize(
            (self.image_size[1], self.image_size[0]),
            Image.NEAREST,  # NEAREST preserves labels
        )
        mask_arr = np.array(mask_pil_resized, dtype=np.int64)
        return torch.from_numpy(mask_arr).long()  # H×W

    def _augment(self, image: torch.Tensor, mask: torch.Tensor):
        """Shared spatial augmentations (applied identically to image and mask)."""
        # Random horizontal flip
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask.unsqueeze(0)).squeeze(0)
        # Random vertical flip
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask.unsqueeze(0)).squeeze(0)
        # Random 90° rotation (k quarter-turns; k == 0 leaves both untouched)
        k = random.choice([0, 1, 2, 3])
        if k:
            image = torch.rot90(image, k, dims=[1, 2])
            mask = torch.rot90(mask, k, dims=[0, 1])
        return image, mask

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the (image, mask) tensors for pair ``idx``."""
        img_path, mask_path = self.pairs[idx]
        image = self._load_image(img_path)
        mask = self._load_mask(mask_path)
        if self.transform is not None:
            # A user-supplied transform takes the place of the built-in augmentation.
            image, mask = self.transform(image, mask)
        elif self.split == "train":
            image, mask = self._augment(image, mask)
        return image, mask

    def __repr__(self) -> str:
        return (
            f"FractographyDataset("
            f"n={len(self)}, split='{self.split}', "
            f"image_size={self.image_size}, classes={NUM_CLASSES})"
        )
def get_dataloaders(
    data_dir: str | Path,
    batch_size: int = 4,
    train_frac: float = 0.8,
    num_workers: int = 2,
    seed: Optional[int] = None,
) -> tuple[DataLoader, DataLoader]:
    """
    Returns (train_loader, val_loader) from a random train/val split.

    Args:
        data_dir: Root directory passed to FractographyDataset.
        batch_size: Batch size used by both loaders.
        train_frac: Fraction of samples assigned to training (default 0.8,
            i.e. an 80/20 split). The training count is rounded down.
        num_workers: Worker processes used by each loader.
        seed: Optional seed for the split. When None the split uses torch's
            default generator and differs from run to run.

    Raises:
        ValueError: If ``train_frac`` is outside [0, 1].
        FileNotFoundError: If no image/mask pairs are found (raised by the
            dataset constructor).
    """
    import copy  # function-scope import: only needed here

    if not 0.0 <= train_frac <= 1.0:
        raise ValueError(f"train_frac must be in [0, 1], got {train_frac}")

    val_base = FractographyDataset(data_dir, split="val")
    n_train = int(len(val_base) * train_frac)
    n_val = len(val_base) - n_train

    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_part, val_ds = random_split(val_base, [n_train, n_val], **split_kwargs)

    # Both subsets returned by random_split wrap the SAME dataset object, so
    # flipping its split tag would switch augmentation on for validation too.
    # Give the training subset its own shallow copy (the pair list is shared,
    # only the tag differs) so augmentation fires for train only.
    train_base = copy.copy(val_base)
    train_base.split = "train"
    train_ds = torch.utils.data.Subset(train_base, train_part.indices)

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True,
    )
    print(f"Train: {len(train_ds)} samples | Val: {len(val_ds)} samples")
    return train_loader, val_loader
# ── Quick sanity check ───────────────────────────────────────────────────────
# Usage: python dataset.py [data_dir]   (data_dir defaults to "data")
if __name__ == "__main__":
    import sys

    data_dir = sys.argv[1] if len(sys.argv) > 1 else "data"
    try:
        ds = FractographyDataset(data_dir)
        print(ds)
        img, mask = ds[0]
        print(f"Image tensor: {img.shape} dtype={img.dtype} range=[{img.min():.2f}, {img.max():.2f}]")
        print(f"Mask tensor: {mask.shape} dtype={mask.dtype} unique={mask.unique().tolist()}")
        # The success message was split across two lines by a mis-decoded
        # character, leaving an unterminated string literal; it is one line again.
        print("\n✅ Dataset loads correctly.")
    except FileNotFoundError as e:
        # Raised by the dataset constructor when no image/mask pairs are found.
        print(f"❌ {e}")
|