| import os | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import scripts.config as config | |
| from torch.utils.data import Dataset | |
| import torchvision.transforms as transforms | |
class SegmentationDataset(Dataset):
    """Binary segmentation dataset pairing ``.jpg`` images with ``_mask.png`` masks.

    Every file in ``image_dir`` whose name ends in ``.jpg`` (case-insensitive)
    is one sample. Its mask is expected at ``mask_dir/<stem>_mask.png``, where
    ``<stem>`` is the image file name without its extension.

    Args:
        transform: Optional callable applied to the grayscale PIL image only.
            The mask is never passed through it, so a spatial transform
            (resize, crop, flip, ...) would misalign image and mask.
        image_dir: Directory of input images. Defaults to ``config.images``.
        mask_dir: Directory of mask images. Defaults to ``config.masks``.
    """

    # Mask pixels strictly above this grayscale value are foreground (1).
    _MASK_THRESHOLD = 127

    def __init__(self, transform=None, image_dir=None, mask_dir=None):
        self.image_dir = config.images if image_dir is None else image_dir
        self.mask_dir = config.masks if mask_dir is None else mask_dir
        self.transform = transform
        # Sort so the index -> file mapping is reproducible; os.listdir
        # returns entries in arbitrary order.
        self.image_files = sorted(
            name
            for name in os.listdir(self.image_dir)
            if name.lower().endswith(".jpg")
        )

    def __len__(self) -> int:
        """Return the number of ``.jpg`` images found in ``image_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx: int):
        """Load sample ``idx``.

        Returns:
            ``(image, mask)``: ``image`` is the grayscale ("L") PIL image, or
            whatever ``transform`` returns for it; ``mask`` is a
            ``torch.long`` tensor of shape (H, W) containing only 0 and 1.

        Raises:
            FileNotFoundError: If the image has no matching mask file.
        """
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, self._mask_name(img_name))
        if not os.path.exists(mask_path):
            raise FileNotFoundError(f"Mask not found for: {img_name}")

        image = self._load_grayscale(img_path)
        mask = self._load_grayscale(mask_path)
        if self.transform is not None:
            image = self.transform(image)
        return image, self._binarize(mask)

    @staticmethod
    def _mask_name(img_name: str) -> str:
        """Map an image file name to its mask file name (``a.JPG`` -> ``a_mask.png``)."""
        # splitext (rather than str.replace('.jpg', ...)) handles upper-case
        # extensions and only touches the final suffix.
        stem, _ = os.path.splitext(img_name)
        return f"{stem}_mask.png"

    @staticmethod
    def _load_grayscale(path):
        """Open ``path`` as a single-channel ("L") PIL image, closing the file."""
        with Image.open(path) as img:
            # convert() returns a new, fully loaded image, so the result stays
            # valid after the underlying file handle is closed.
            return img.convert("L")

    @classmethod
    def _binarize(cls, mask) -> torch.Tensor:
        """Threshold a grayscale mask into a ``torch.long`` tensor of 0/1 values."""
        # A boolean comparison can only produce 0/1, so no value check is needed.
        binary = (np.array(mask) > cls._MASK_THRESHOLD).astype(np.uint8)
        return torch.from_numpy(binary).long()