| from pathlib import Path |
| import cv2 |
| import numpy as np |
| import torch |
| from torch.utils.data import Dataset |
|
|
|
|
class CrackSegDataset(Dataset):
    """Binary crack-segmentation dataset pairing images with mask PNGs.

    Every regular file directly inside ``images_dir`` whose suffix is
    ``.jpg``, ``.jpeg`` or ``.png`` (case-insensitive) is one sample; its mask
    is expected at ``masks_dir / (image_stem + ".png")``.

    Args:
        images_dir: Directory containing the input images (not searched
            recursively).
        masks_dir: Directory containing one grayscale ``.png`` mask per image,
            named after the image's stem.
        transform: Optional callable invoked as
            ``transform(image=image, mask=mask)``; it must return a mapping with
            ``"image"`` and ``"mask"`` keys (the albumentations calling
            convention). It receives the RGB image and the binarized
            ``{0, 1}`` uint8 mask as NumPy arrays.
    """

    # Lower-cased file suffixes accepted as input images.
    _IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png"})

    def __init__(self, images_dir: str, masks_dir: str, transform=None):
        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir)
        self.transform = transform
        # is_file() keeps a sub-directory named e.g. "x.png" from being treated
        # as an image; sorting makes the index -> file mapping deterministic.
        self.image_paths = sorted(
            p
            for p in self.images_dir.glob("*")
            if p.is_file() and p.suffix.lower() in self._IMAGE_SUFFIXES
        )

    def __len__(self) -> int:
        """Return the number of images found in ``images_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as an ``(image, mask)`` tensor pair.

        Returns:
            image: float32 tensor, channels-first (the HxWxC array is permuted
                to CxHxW), with pixel values divided by 255.
            mask: float32 tensor with a leading channel dimension of size 1
                (``(1, H, W)`` for a 2-D mask).

        Raises:
            FileNotFoundError: If the mask file for the image does not exist.
            OSError: If OpenCV cannot read the image or the mask.
            TypeError: If the transform returns something other than a
                ``numpy.ndarray`` for the image or the mask.
        """
        img_path = self.image_paths[idx]
        mask_path = self.masks_dir / (img_path.stem + ".png")
        if not mask_path.exists():
            raise FileNotFoundError(f"Mask not found for {img_path.name}: {mask_path}")

        image = self._read_image(img_path)
        mask = self._read_mask(mask_path)

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented["image"], augmented["mask"]

        # NOTE(review): the /255 scaling assumes the transform (if any) leaves
        # the image in the 0-255 range, i.e. does not normalize it itself -
        # confirm against the transforms used in training.
        image = self._as_contiguous_array(image, "image")
        mask = self._as_contiguous_array(mask, "mask")
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        mask = torch.from_numpy(mask).unsqueeze(0).float()
        return image, mask

    @staticmethod
    def _read_image(path):
        """Read ``path`` as a 3-channel colour image and convert BGR -> RGB."""
        image = cv2.imread(str(path), cv2.IMREAD_COLOR)
        # cv2.imread reports failure by returning None instead of raising.
        if image is None:
            raise OSError(f"Failed to read image: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _read_mask(path):
        """Read ``path`` as grayscale and binarize it (pixels > 127 -> 1, else 0)."""
        mask = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f"Failed to read mask: {path}")
        return (mask > 127).astype(np.uint8)

    @staticmethod
    def _as_contiguous_array(array, what):
        """Return ``array`` as a C-contiguous ndarray; reject non-ndarrays.

        torch.from_numpy refuses negative-stride views (e.g. flips written as
        ``arr[:, ::-1]``), so a contiguous copy is made when needed; already
        contiguous arrays are returned without copying.
        """
        if not isinstance(array, np.ndarray):
            raise TypeError(
                f"Expected {what} as numpy.ndarray, got {type(array).__name__}"
            )
        return np.ascontiguousarray(array)
|
|