FailSafe / dataset.py
rcrane4's picture
Upload 10 files
7538d69 verified
"""
dataset.py
----------
PyTorch Dataset class for the OSF Ti-64 SEM fractography dataset.
Use this after running inspect_dataset.py to confirm your mask format.
Key decisions you may need to make after inspection:
- If masks are binary (0/255): set NUM_CLASSES=2, update MASK_SCALE
- If masks are RGB color: set COLOR_MASK=True and define COLOR_TO_LABEL
- If masks are integer labels (0..N): use as-is (ideal case)
Usage:
from dataset import FractographyDataset
ds = FractographyDataset("data/", split="train")
img, mask = ds[0]
"""
from pathlib import Path
from typing import Callable, Optional
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms.functional as TF
import random
# ── Config β€” update after running inspect_dataset.py ────────────────────────
NUM_CLASSES = 2 # update once you know how many classes are in your masks
IMAGE_SIZE = (512, 512) # resize target; SegFormer-b0 default input
MASK_SCALE = 255
# If masks use RGB color encoding instead of integer labels, set this to True
# and populate COLOR_TO_LABEL below.
COLOR_MASK = False
COLOR_TO_LABEL: dict[tuple, int] = {
# (R, G, B): class_index
# e.g. (255, 0, 0): 1,
}
# ─────────────────────────────────────────────────────────────────────────────
def rgb_mask_to_label(mask_rgb: np.ndarray, color_to_label: dict) -> np.ndarray:
"""Convert an HΓ—WΓ—3 RGB mask to an HΓ—W integer label mask."""
label = np.zeros(mask_rgb.shape[:2], dtype=np.int64)
for color, cls_idx in color_to_label.items():
match = np.all(mask_rgb == np.array(color), axis=-1)
label[match] = cls_idx
return label
class FractographyDataset(Dataset):
"""
OSF Ti-64 SEM Fractography Dataset.
Args:
data_dir: Root of downloaded data (contains subfolders with images/ + masks/).
split: "train", "val", or "all" (no splitting, returns everything).
transform: Optional callable applied to both image and mask (augmentation).
image_size: Resize target (H, W).
"""
IMAGE_EXTS = {".png", ".tif", ".tiff", ".jpg", ".jpeg"}
def __init__(
self,
data_dir: str | Path,
split: str = "all",
transform: Optional[Callable] = None,
image_size: tuple[int, int] = IMAGE_SIZE,
):
self.data_dir = Path(data_dir)
self.split = split
self.transform = transform
self.image_size = image_size
self.pairs = self._find_pairs()
if not self.pairs:
raise FileNotFoundError(
f"No image/mask pairs found in {self.data_dir}. "
"Run inspect_dataset.py to diagnose."
)
def _find_pairs(self) -> list[tuple[Path, Path]]:
pairs = []
for images_dir in sorted(self.data_dir.rglob("images_8bit")):
if not images_dir.is_dir():
continue
masks_dir = images_dir.parent / "masks_8bit"
if not masks_dir.exists():
continue
for img_path in sorted(images_dir.iterdir()):
if img_path.suffix.lower() not in self.IMAGE_EXTS:
continue
mask_path = masks_dir / img_path.name
if mask_path.exists():
pairs.append((img_path, mask_path))
return pairs
def _load_image(self, path: Path) -> torch.Tensor:
img = Image.open(path).convert("RGB")
img = img.resize((self.image_size[1], self.image_size[0]), Image.BILINEAR)
arr = np.array(img, dtype=np.float32) / 255.0
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
arr = (arr - mean) / std
return torch.from_numpy(arr).permute(2, 0, 1).float()
def _load_mask(self, path: Path) -> torch.Tensor:
mask_pil = Image.open(path)
if COLOR_MASK:
mask_arr = np.array(mask_pil.convert("RGB"))
mask_arr = rgb_mask_to_label(mask_arr, COLOR_TO_LABEL)
else:
mask_arr = np.array(mask_pil.convert("L"), dtype=np.int64)
if MASK_SCALE > 1:
mask_arr = mask_arr // MASK_SCALE # e.g. 0/255 β†’ 0/1
mask_pil_resized = Image.fromarray(mask_arr.astype(np.uint8)).resize(
(self.image_size[1], self.image_size[0]), Image.NEAREST # NEAREST preserves labels
)
mask_arr = np.array(mask_pil_resized, dtype=np.int64)
return torch.from_numpy(mask_arr).long() # HΓ—W
def _augment(self, image: torch.Tensor, mask: torch.Tensor):
"""Shared spatial augmentations (applied identically to image and mask)."""
# Random horizontal flip
if random.random() > 0.5:
image = TF.hflip(image)
mask = TF.hflip(mask.unsqueeze(0)).squeeze(0)
# Random vertical flip
if random.random() > 0.5:
image = TF.vflip(image)
mask = TF.vflip(mask.unsqueeze(0)).squeeze(0)
# Random 90Β° rotation
k = random.choice([0, 1, 2, 3])
if k:
image = torch.rot90(image, k, dims=[1, 2])
mask = torch.rot90(mask, k, dims=[0, 1])
return image, mask
def __len__(self) -> int:
return len(self.pairs)
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
img_path, mask_path = self.pairs[idx]
image = self._load_image(img_path)
mask = self._load_mask(mask_path)
if self.split == "train" and self.transform is None:
image, mask = self._augment(image, mask)
elif self.transform is not None:
image, mask = self.transform(image, mask)
return image, mask
def __repr__(self) -> str:
return (
f"FractographyDataset("
f"n={len(self)}, split='{self.split}', "
f"image_size={self.image_size}, classes={NUM_CLASSES})"
)
def get_dataloaders(
data_dir: str | Path,
batch_size: int = 4,
train_frac: float = 0.8,
num_workers: int = 2,
) -> tuple[DataLoader, DataLoader]:
"""
Returns (train_loader, val_loader) with 80/20 split.
"""
full_dataset = FractographyDataset(data_dir, split="all")
n_train = int(len(full_dataset) * train_frac)
n_val = len(full_dataset) - n_train
train_ds, val_ds = random_split(full_dataset, [n_train, n_val])
# Override split tag so augmentation fires for train only
train_ds.dataset.split = "train"
train_loader = DataLoader(
train_ds, batch_size=batch_size, shuffle=True,
num_workers=num_workers, pin_memory=True
)
val_loader = DataLoader(
val_ds, batch_size=batch_size, shuffle=False,
num_workers=num_workers, pin_memory=True
)
print(f"Train: {len(train_ds)} samples | Val: {len(val_ds)} samples")
return train_loader, val_loader
# ── Quick sanity check ────────────────────────────────────────────────────────
if __name__ == "__main__":
import sys
data_dir = sys.argv[1] if len(sys.argv) > 1 else "data"
try:
ds = FractographyDataset(data_dir)
print(ds)
img, mask = ds[0]
print(f"Image tensor: {img.shape} dtype={img.dtype} range=[{img.min():.2f}, {img.max():.2f}]")
print(f"Mask tensor: {mask.shape} dtype={mask.dtype} unique={mask.unique().tolist()}")
print("\nβœ… Dataset loads correctly.")
except FileNotFoundError as e:
print(f"❌ {e}")