# fractex2D_tuto / src / dataset_benchm.py
# Ayoub
# add metrics computation
# ce5153c
from pathlib import Path
import random
from typing import List, Optional, Sequence, Tuple
import numpy as np
import torch
import torchvision.transforms.v2 as t
import torchvision.transforms.v2.functional as TF
from skimage import io
from skimage.filters.rank import maximum
from skimage.measure import label
from skimage.morphology import binary_dilation, dilation, disk
from skimage.segmentation import expand_labels
from torch.utils.data import ConcatDataset, DataLoader, Dataset
# -------------------------
# Label pre-processing
# -------------------------
def expand_wide_fractures_gt(
    img: np.ndarray,
    gt: np.ndarray,
    disk_size: int = 2,
    thresh: int = 30,
    gt_thresh: int = 100,
    gt_ext: str = "png",
) -> np.ndarray:
    """
    Expand a ground-truth mask so that it also covers nearby wide/dark fractures.

    Method:
    - Use the green channel (index 1) as a grayscale proxy.
    - Apply a maximum filter, threshold the result with `thresh` and dilate
      it to obtain a candidate mask of dark regions.
    - Merge the candidate with the thresholded GT (`gt > gt_thresh`).
    - Label the merged mask (connectivity=1) and keep only the connected
      components that contain at least one thresholded GT pixel.
    - Bitwise-OR the surviving region (as 0/255) into the GT.

    Args:
        img: HxWxC image with at least 2 channels (the green channel is used).
        gt: HxW ground-truth mask.
        disk_size: radius of the disk footprint used by the maximum filter
            and by the dilation.
        thresh: pixels whose maximum-filtered gray value is below this value
            are treated as dark.
        gt_thresh: a pixel counts as original GT when `gt > gt_thresh`.
        gt_ext: file extension of the GT. When it contains "tif", `gt` is
            multiplied by 255 before being cast to uint8; otherwise it is
            cast as-is.

    Returns:
        np.uint8 array shaped like `gt`: the (cast) GT OR'ed with 255 on every
        surviving pixel. Values other than 0/255 can remain where the GT
        itself holds intermediate values.

    Raises:
        ValueError: if `img` has fewer than 3 dimensions or fewer than 2 channels.
    """
    if img.ndim < 3 or img.shape[2] < 2:
        raise ValueError("img must have at least 2 channels (uses green channel).")

    footprint = disk(disk_size)

    # Green channel as grayscale proxy; the rank filter works on uint8 input.
    gray = img[..., 1].astype(np.uint8)

    # A pixel stays below `thresh` after the maximum filter only if its whole
    # neighbourhood is dark; the dilation then grows those regions back.
    imax = maximum(gray, footprint)
    candidate = binary_dilation(imax < thresh, footprint)

    gt_bool = gt > gt_thresh
    combined = np.logical_or(candidate, gt_bool)

    # Keep only the components touching the original GT. A lookup table indexed
    # by label id replaces a per-component scan of the full image.
    labeled, num = label(combined, connectivity=1, return_num=True)
    keep = np.zeros(num + 1, dtype=bool)
    keep[labeled[gt_bool]] = True
    keep[0] = False  # label 0 is background and must never be kept
    combined = keep[labeled]

    region = combined.astype(np.uint8) * 255
    if "tif" in gt_ext:
        # Scaling kept from the original script: tif GT is multiplied by 255.
        new_gt = np.array(gt * 255, dtype=np.uint8) | region
    else:
        new_gt = np.array(gt, dtype=np.uint8) | region
    return new_gt
def dilate_labels(image: np.ndarray) -> np.ndarray:
    """
    Soften label boundaries by blending three dilation bands of decreasing weight.

    Steps:
    - Grow the labels by a distance of 2 with `expand_labels`.
    - Dilate the grown map with disks of radius 2, 5 and 7 and XOR the
      results so each band excludes the previous ones (for a binary mask
      these are successive rings around the labels).
    - Add the bands to the grown map, weighted 1/3, 1/5 and 1/9.

    Args:
        image: HxW integer-labeled image or binary mask.

    Returns:
        HxW np.uint8 array holding the blended result.
    """
    grown = expand_labels(image, distance=2)

    # Dilations at the three scales used for the bands.
    grown_r2 = dilation(grown, disk(2))
    grown_r5 = dilation(grown, disk(5))
    grown_r7 = dilation(grown, disk(7))

    # XOR chains strip what earlier bands (and the core) already cover.
    band_near = grown_r2 ^ grown
    band_mid = grown_r5 ^ band_near ^ grown
    band_far = grown_r7 ^ band_mid ^ band_near ^ grown

    soft = grown + band_near / 3.0 + band_mid / 5.0 + band_far / 9.0
    return soft.astype(np.uint8)
# -------------------------
# Augmentation helpers
# -------------------------
def _apply_random_flips(image: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Flip image and mask together, horizontally and/or vertically.

    Each flip is decided by its own `random.random() > 0.5` draw (horizontal
    first, then vertical) and is always applied to both tensors so they stay
    aligned.
    """
    flip_horizontally = random.random() > 0.5
    if flip_horizontally:
        image = TF.hflip(image)
        mask = TF.hflip(mask)

    flip_vertically = random.random() > 0.5
    if flip_vertically:
        image = TF.vflip(image)
        mask = TF.vflip(mask)

    return image, mask
def _apply_random_photometric_augmentations(image: torch.Tensor, prob_config: Optional[dict] = None) -> torch.Tensor:
    """
    Randomly apply blur, brightness, contrast and saturation changes.

    Each augmentation is triggered independently with the probability stored
    under its key in `prob_config`; when triggered, its strength is drawn
    uniformly from a fixed range. The order is: gaussian blur, darken,
    brighten, contrast, saturation.

    When `image` has exactly 4 channels, only the first three are augmented
    and the remaining channel is concatenated back unchanged.

    Args:
        image: C,H,W tensor.
        prob_config: mapping with the keys "gaussian_blur", "darken_low",
            "brighten", "contrast" and "saturation"; defaults are used when None.

    Returns:
        The augmented tensor with the same number of channels as `image`.
    """
    if prob_config is None:
        prob_config = {
            "gaussian_blur": 0.05,
            "darken_low": 0.05,
            "brighten": 0.15,
            "contrast": 0.05,
            "saturation": 0.05,
        }

    carries_extra = image.shape[0] == 4
    colour = image[:3] if carries_extra else image

    # (probability key, lower bound, upper bound, operation). Brightness
    # factors below 1 darken, above 1 brighten.
    pipeline = (
        ("gaussian_blur", 0.1, 2.0, lambda x, s: TF.gaussian_blur(x, kernel_size=5, sigma=s)),
        ("darken_low", 0.7, 0.9, lambda x, f: TF.adjust_brightness(x, f)),
        ("brighten", 1.1, 1.7, lambda x, f: TF.adjust_brightness(x, f)),
        ("contrast", 0.7, 1.5, lambda x, f: TF.adjust_contrast(x, f)),
        ("saturation", 0.7, 1.5, lambda x, f: TF.adjust_saturation(x, f)),
    )
    for key, low, high, operation in pipeline:
        if random.random() < prob_config[key]:
            colour = operation(colour, random.uniform(low, high))

    if carries_extra:
        return torch.cat([colour, image[3:]], dim=0)
    return colour
# -------------------------
# Base dataset utilities
# -------------------------
def _read_image(path: Path) -> np.ndarray:
    """
    Load an image with skimage.io and return it as a uint8 array.

    Arrays that are not already uint8 are converted with a plain `astype`
    cast; no rescaling is performed.
    """
    pixels = io.imread(str(path))
    if pixels.dtype == np.uint8:
        return pixels
    return pixels.astype(np.uint8)
def _read_mask(path: Path) -> np.ndarray:
    """
    Load a mask with skimage.io and return it as a uint8 array.

    - uint8 input is returned unchanged.
    - Other dtypes whose maximum is <= 1.0 are multiplied by 255 before the cast.
    - Anything else is cast to uint8 directly.
    """
    mask = io.imread(str(path))
    if mask.dtype == np.uint8:
        return mask
    if mask.max() <= 1.0:
        # Values look like a [0,1] mask: stretch to the 0..255 range.
        return (mask * 255).astype(np.uint8)
    return mask.astype(np.uint8)
# -------------------------
# Dataset classes
# -------------------------
class BaseCrackDataset(Dataset):
    """
    Shared loading logic for the dataset wrappers defined in this module.

    Holds parallel lists of image, mask and (optionally) DEM paths and turns
    an index into an `(image, mask)` pair of float tensors.

    Args:
        images: paths of the input images.
        masks: paths of the ground-truth masks, index-aligned with `images`.
        dem_paths: optional paths of separate DEM files, index-aligned with
            `images`. Only read when `in_channels == 4` and the image itself
            has fewer than four channels.
        topo: stored on the instance; not used by this class.
        transform: enable the photometric augmentations.
        expand: apply `expand_wide_fractures_gt` to the mask.
        dilate: apply `dilate_labels` to the mask.
        in_channels: 3 for RGB only, 4 to append a DEM channel.
    """

    def __init__(
        self,
        images: Sequence[Path],
        masks: Sequence[Path],
        dem_paths: Optional[Sequence[Path]] = None,
        topo: bool = False,
        transform: bool = False,
        expand: bool = True,
        dilate: bool = True,
        in_channels: int = 3,
    ):
        self.images = list(images)
        self.masks = list(masks)
        self.dems = list(dem_paths) if dem_paths is not None else None
        self.topo = topo
        self.transform = transform
        self.expand = expand
        self.dilate = dilate
        self.in_channels = in_channels

    def __len__(self) -> int:
        return len(self.images)

    @staticmethod
    def _stack_channels(rgb_np: np.ndarray, dem_np: Optional[np.ndarray] = None) -> torch.Tensor:
        """
        Build the C,H,W float input tensor from an RGB array and an optional DEM.

        The RGB channels are divided by 255. The DEM is min-max normalised to
        [0,1] on its own and appended afterwards, so it is not divided by 255
        a second time.

        Args:
            rgb_np: HxWx3 array with values in 0..255.
            dem_np: optional HxW array of elevation values (any numeric dtype).

        Returns:
            Float tensor of shape (3, H, W), or (4, H, W) when a DEM is given.
        """
        stacked = torch.from_numpy(np.ascontiguousarray(rgb_np)).float() / 255.0
        if dem_np is not None:
            dem = torch.from_numpy(np.ascontiguousarray(dem_np, dtype=np.float32))
            # The epsilon avoids a division by zero on a perfectly flat DEM.
            dem = (dem - dem.min()) / (dem.max() - dem.min() + 1e-8)
            stacked = torch.cat((stacked, dem.unsqueeze(2)), dim=2)
        return stacked.permute(2, 0, 1)

    def _load_pair(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Load one image/mask pair and return it as float tensors.

        The mask optionally goes through expand/dilate preprocessing, the DEM
        channel is appended when `in_channels == 4`, then random flips and
        (when `self.transform` is set) photometric augmentations are applied.

        Returns:
            (image, mask): image is C,H,W, mask is 1,H,W; the mask is divided
            by 255.

        Raises:
            RuntimeError: if 4 channels are requested but neither the image
                nor `self.dems` provides a DEM.
        """
        img_np = _read_image(Path(self.images[idx]))
        gt_np = _read_mask(Path(self.masks[idx]))

        if self.expand:
            gt_np = expand_wide_fractures_gt(img_np[:, :, :3].astype(np.uint8), gt_np)
        if self.dilate:
            gt_np = dilate_labels(gt_np)

        dem_np = None
        if self.in_channels == 4:
            if img_np.shape[2] >= 4:
                # DEM stored as the 4th channel of the image itself.
                dem_np = img_np[:, :, 3]
            elif self.dems is not None:
                # Read the DEM at its native dtype: going through the uint8
                # cast of `_read_image` would truncate or wrap elevation
                # values before normalisation.
                dem_np = io.imread(str(self.dems[idx]))
                if dem_np.ndim == 3:
                    dem_np = dem_np[..., 0]
            else:
                raise RuntimeError("Requested 4 input channels but no DEM found.")

        img_tensor = self._stack_channels(img_np[:, :, :3], dem_np)
        mask_tensor = torch.from_numpy(gt_np).unsqueeze(0).float() / 255.0

        # NOTE(review): flips run regardless of `self.transform`, so validation
        # and test samples are randomly flipped too — confirm this is intended.
        img_tensor, mask_tensor = _apply_random_flips(img_tensor, mask_tensor)

        if self.transform:
            img_tensor = _apply_random_photometric_augmentations(img_tensor)
        return img_tensor.float(), mask_tensor.float()

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Indices wrap around the number of images.
        idx = index % len(self.images)
        return self._load_pair(idx)
# -------------------------
# Concrete dataset wrappers
# -------------------------
def _read_list_file(list_path: Path) -> List[str]:
    """Return the stripped, non-blank lines of the list file at `list_path`."""
    entries: List[str] = []
    with list_path.open("r") as handle:
        for raw_line in handle:
            cleaned = raw_line.strip()
            if cleaned:
                entries.append(cleaned)
    return entries
class OVAS(BaseCrackDataset):
    """
    OVAS dataset wrapper rooted at data/ovaskainen23_/<subset>.

    Images are PNG files under `image/`; masks and DEMs are TIF files under
    `gt/` and `dem/`. When `list_file` is given, only the entries ending in
    ".tif" are used and the image path is derived by swapping the suffix to
    ".png". Otherwise the three folders are scanned and sorted (suffix match
    is case-insensitive).
    """

    def __init__(
        self,
        subset: str,
        list_file: Optional[str] = "list.txt",
        topo: bool = False,
        transform: bool = False,
        expand: bool = True,
        dilate: bool = True,
        in_channels: int = 3,
    ):
        base_dir = Path("data/ovaskainen23_") / subset
        img_ext = "png"
        gt_ext = "tif"

        if list_file:
            # The list holds GT file names; everything else is derived from them.
            selected = [
                entry
                for entry in _read_list_file(base_dir / list_file)
                if entry.endswith("." + gt_ext)
            ]
            images = [(base_dir / "image" / entry).with_suffix("." + img_ext) for entry in selected]
            masks = [base_dir / "gt" / entry for entry in selected]
            dems = [base_dir / "dem" / entry for entry in selected]
        else:
            def scan(folder: str, wanted_ext: str) -> List[Path]:
                return sorted(
                    candidate
                    for candidate in (base_dir / folder).iterdir()
                    if candidate.suffix.lower().lstrip(".") == wanted_ext
                )

            images = scan("image", img_ext)
            masks = scan("gt", gt_ext)
            dems = scan("dem", gt_ext)

        super().__init__(
            images=images,
            masks=masks,
            dem_paths=dems,
            topo=topo,
            transform=transform,
            expand=expand,
            dilate=dilate,
            in_channels=in_channels,
        )
class MATTEO(BaseCrackDataset):
    """
    MATTEO dataset wrapper rooted at data/matteo21/<subset>.

    Image and mask share the same file name under `image/` and `gt/`. No
    separate DEM paths are passed to the base class.
    """

    def __init__(
        self,
        subset: str,
        list_file: Optional[str] = "list.txt",
        topo: bool = False,
        transform: bool = False,
        expand: bool = True,
        dilate: bool = True,
        in_channels: int = 3,
    ):
        base_dir = Path("data/matteo21") / subset
        image_dir = base_dir / "image"
        mask_dir = base_dir / "gt"
        wanted_ext = "tif"

        if list_file:
            file_names = _read_list_file(base_dir / list_file)
        else:
            # No list: take every .tif found in the image folder.
            file_names = [
                entry.name
                for entry in image_dir.iterdir()
                if entry.suffix.lstrip(".") == wanted_ext
            ]

        images = sorted(image_dir / file_name for file_name in file_names)
        masks = sorted(mask_dir / file_name for file_name in file_names)

        super().__init__(
            images=images,
            masks=masks,
            dem_paths=None,
            topo=topo,
            transform=transform,
            expand=expand,
            dilate=dilate,
            in_channels=in_channels,
        )
class SAMSU(BaseCrackDataset):
    """
    SAMSU dataset wrapper rooted at data/samsu19/<subset>.

    PNG images live under `image/`, TIF masks under `gt/` and TIF DEMs under
    `dem/`. With a `list_file`, entries ending in ".tif" select the samples;
    without one, the folders are scanned and sorted (case-sensitive suffix
    match).
    """

    def __init__(
        self,
        subset: str,
        list_file: Optional[str] = "list.txt",
        topo: bool = False,
        transform: bool = False,
        expand: bool = True,
        dilate: bool = True,
        in_channels: int = 3,
    ):
        base_dir = Path("data/samsu19") / subset
        image_dir, mask_dir, dem_dir = base_dir / "image", base_dir / "gt", base_dir / "dem"
        img_ext = "png"
        gt_ext = "tif"

        if list_file:
            gt_names = [
                line
                for line in _read_list_file(base_dir / list_file)
                if line.endswith("." + gt_ext)
            ]
            # Image names mirror the GT names with a .png suffix.
            images = [(image_dir / gt_name).with_suffix("." + img_ext) for gt_name in gt_names]
            masks = [mask_dir / gt_name for gt_name in gt_names]
            dems = [dem_dir / gt_name for gt_name in gt_names]
        else:
            images = sorted(f for f in image_dir.iterdir() if f.suffix.lstrip(".") == img_ext)
            masks = sorted(f for f in mask_dir.iterdir() if f.suffix.lstrip(".") == gt_ext)
            dems = sorted(f for f in dem_dir.iterdir() if f.suffix.lstrip(".") == gt_ext)

        super().__init__(
            images=images,
            masks=masks,
            dem_paths=dems,
            topo=topo,
            transform=transform,
            expand=expand,
            dilate=dilate,
            in_channels=in_channels,
        )
class GeoCrack(BaseCrackDataset):
    """
    GeoCrack dataset wrapper rooted at data/GeoCrack_/<subset>.

    Uses the sorted PNG files of `image/` and `gt/`; no DEM paths are passed
    to the base class. Samples are resized with `Resize(256)` on retrieval.
    """

    def __init__(
        self,
        subset: str,
        topo: bool = False,
        transform: bool = False,
        expand: bool = True,
        dilate: bool = True,
        in_channels: int = 3,
    ):
        base_dir = Path("data/GeoCrack_") / subset
        wanted_ext = "png"

        def is_png(entry: Path) -> bool:
            return entry.suffix.lstrip(".") == wanted_ext

        images = sorted(filter(is_png, (base_dir / "image").iterdir()))
        masks = sorted(filter(is_png, (base_dir / "gt").iterdir()))

        super().__init__(
            images=images,
            masks=masks,
            dem_paths=None,
            topo=topo,
            transform=transform,
            expand=expand,
            dilate=dilate,
            in_channels=in_channels,
        )

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the base-class sample with image and mask resized by `Resize(256)`."""
        image, target = super().__getitem__(index)
        resize = t.Resize(256)
        return resize(image).float(), resize(target).float()
class DIC(BaseCrackDataset):
    """
    DIC dataset wrapper rooted at data/DIC/<subset>.

    Images are TIF files under `image/` and masks are PNG files under `gt/`.
    Samples are loaded as single-channel tensors by an overridden
    `_load_pair`, which does not run the expand/dilate preprocessing.
    """

    def __init__(
        self,
        subset: str,
        topo: bool = False,
        transform: bool = False,
        expand: bool = False,
        dilate: bool = False,
        in_channels: int = 1,
    ):
        base_dir = Path("data/DIC") / subset
        image_ext = "tif"
        mask_ext = "png"

        images = sorted(
            entry for entry in (base_dir / "image").iterdir() if entry.suffix.lstrip(".") == image_ext
        )
        masks = sorted(
            entry for entry in (base_dir / "gt").iterdir() if entry.suffix.lstrip(".") == mask_ext
        )

        super().__init__(
            images=images,
            masks=masks,
            dem_paths=None,
            topo=topo,
            transform=transform,
            expand=expand,
            dilate=dilate,
            in_channels=in_channels,
        )

    def _load_pair(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Load one single-channel image/mask pair.

        A 3-D image array is reduced to its first channel. Image and mask are
        divided by 255, randomly flipped, optionally augmented (when
        `self.transform` is set) and finally resized with `Resize(256)`.
        """
        image_arr = _read_image(Path(self.images[idx]))
        mask_arr = _read_mask(Path(self.masks[idx]))

        # Collapse multi-channel files to the first channel.
        if image_arr.ndim == 3:
            image_arr = image_arr[..., 0]

        image = torch.from_numpy(image_arr).unsqueeze(0).float() / 255.0
        target = torch.from_numpy(mask_arr).unsqueeze(0).float() / 255.0

        image, target = _apply_random_flips(image, target)
        if self.transform:
            image = _apply_random_photometric_augmentations(image)

        resize = t.Resize(256)
        image = resize(image)
        target = resize(target)
        return image.float(), target.float()
# -------------------------
# Dataset registry & loader builder
# -------------------------
# Registry mapping each dataset key to the wrapper class that loads it.
DATASETS = dict(
    ovaskainen23=OVAS,
    matteo21=MATTEO,
    samsu19=SAMSU,
    geocrack=GeoCrack,
    dic=DIC,
)
def all_datasets(
    batch_size: int = 32,
    datasets: str = "samsu19-matteo21-ovaskainen23",
    in_channels: int = 4,
    out_channels: int = 1,
    shape: int = 256,
    expand: bool = True,
    dilate: bool = True,
    shuffle_train: bool = True,
    do_transform: bool = True,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Build train/validation/test DataLoaders over one or more registered datasets.

    For every requested key, the "train", "valid" and "test" subsets are
    instantiated and the per-split datasets are concatenated.

    Args:
        batch_size: batch size used by all three loaders.
        datasets: dash-separated keys of the DATASETS registry.
        in_channels: number of input channels forwarded to every dataset.
        out_channels: accepted but not used by this function.
        shape: accepted but not used by this function.
        expand: forwarded to every dataset.
        dilate: forwarded to every dataset.
        shuffle_train: whether the training loader shuffles.
        do_transform: enables augmentations on the training subset only.

    Returns:
        (train_loader, val_loader, test_loader)

    Raises:
        KeyError: if a requested key is not in DATASETS.
    """
    requested = [part.strip() for part in datasets.split("-") if part.strip()]

    # One bucket per subset name, filled in train -> valid -> test order.
    per_split = {"train": [], "valid": [], "test": []}
    for key in requested:
        if key not in DATASETS:
            raise KeyError(f"Unknown dataset key: {key}")
        dataset_cls = DATASETS[key]
        for split, bucket in per_split.items():
            bucket.append(
                dataset_cls(
                    subset=split,
                    # Augmentations are never enabled for validation/test.
                    transform=do_transform if split == "train" else False,
                    expand=expand,
                    dilate=dilate,
                    in_channels=in_channels,
                )
            )

    train_set = ConcatDataset(per_split["train"])
    val_set = ConcatDataset(per_split["valid"])
    test_set = ConcatDataset(per_split["test"])

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=shuffle_train)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader