from pathlib import Path

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset


class CrackSegDataset(Dataset):
    """Paired image / binary-mask dataset for crack segmentation.

    Images are discovered in ``images_dir`` by file extension (``.jpg``,
    ``.jpeg``, ``.png``, case-insensitive).  The mask for an image is looked
    up in ``masks_dir`` under the image's stem with a ``.png`` extension.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the grayscale ``.png`` masks.
        transform: Optional callable invoked as
            ``transform(image=image, mask=mask)`` that returns a mapping with
            ``"image"`` and ``"mask"`` entries (albumentations-style).  Both
            entries must be NumPy arrays, with the image in HWC layout.
    """

    # File extensions (lower-case) that are treated as input images.
    _IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png"})

    def __init__(self, images_dir: str, masks_dir: str, transform=None):
        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir)
        self.transform = transform
        # Sorted so that index -> sample mapping is stable across runs.
        self.image_paths = sorted(
            p
            for p in self.images_dir.glob("*")
            if p.suffix.lower() in self._IMAGE_SUFFIXES
        )

    def __len__(self) -> int:
        """Return the number of images found in ``images_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, optionally transform, and tensorize the sample at ``idx``.

        Returns:
            A ``(image, mask)`` tuple of float tensors: ``image`` is RGB in
            CHW layout divided by 255, ``mask`` has shape ``[1, H, W]`` and is
            binarized to 0/1 (threshold 127) before the transform is applied.

        Raises:
            FileNotFoundError: If no mask file exists for the image.
            OSError: If OpenCV cannot decode the image or the mask.
            TypeError: If the transform returns something other than NumPy
                arrays for the image or the mask.
        """
        img_path = self.image_paths[idx]
        mask_path = self.masks_dir / (img_path.stem + ".png")
        if not mask_path.exists():
            raise FileNotFoundError(f"Mask not found for {img_path.name}: {mask_path}")

        # cv2.imread signals failure by returning None instead of raising, so
        # check explicitly to report the offending path.
        image = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        if image is None:
            raise OSError(f"Failed to read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f"Failed to read mask: {mask_path}")
        mask = (mask > 127).astype(np.uint8)  # binarize

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented["image"], augmented["mask"]

        if not isinstance(image, np.ndarray) or not isinstance(mask, np.ndarray):
            raise TypeError(
                "transform must return NumPy arrays for 'image' and 'mask', got "
                f"{type(image).__name__} and {type(mask).__name__}"
            )

        # Transforms that flip/rotate by slicing can yield negatively-strided
        # views, which torch.from_numpy rejects; ascontiguousarray is a no-op
        # for arrays that are already contiguous.
        image = np.ascontiguousarray(image)
        mask = np.ascontiguousarray(mask)

        # albumentations returns HWC image; convert to CHW float tensor.
        # NOTE(review): the division by 255 is unconditional, so a transform
        # that already rescales/normalizes the image gets scaled a second
        # time -- confirm the pipeline leaves pixel values in the 0-255 range.
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        mask = torch.from_numpy(mask).unsqueeze(0).float()  # [1,H,W]
        return image, mask