File size: 1,441 Bytes
7b615ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import os
import torch
import numpy as np
from PIL import Image
import scripts.config as config
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class SegmentationDataset(Dataset):
    """Binary segmentation dataset pairing ``.jpg`` images with ``_mask.png`` masks.

    Every file in ``image_dir`` whose name ends in ``.jpg`` (matched
    case-insensitively) is a sample; its mask is expected at
    ``<mask_dir>/<stem>_mask.png``, where ``<stem>`` is the image file name
    without its extension.

    Args:
        transform: Optional callable applied to the grayscale PIL image only.
            The mask is never passed through it, so a spatial transform
            (resize, crop, flip, ...) would desynchronise image and mask.
            NOTE(review): confirm the transforms used are non-spatial or that
            masks already match the transformed image geometry.
        image_dir: Directory containing the input images. Defaults to
            ``config.images`` when omitted.
        mask_dir: Directory containing the mask images. Defaults to
            ``config.masks`` when omitted.
    """

    def __init__(self, transform=None, image_dir=None, mask_dir=None):
        # Read the project config only when no explicit directory is given,
        # so callers can point the dataset elsewhere without touching config.
        self.image_dir = config.images if image_dir is None else image_dir
        self.mask_dir = config.masks if mask_dir is None else mask_dir
        self.transform = transform
        # os.listdir order is arbitrary; sort so each index refers to the
        # same file across runs and machines.
        self.image_files = sorted(
            name
            for name in os.listdir(self.image_dir)
            if name.lower().endswith(".jpg")
        )

    def __len__(self):
        """Return the number of images found in ``image_dir``."""
        return len(self.image_files)

    def _mask_path(self, img_name):
        """Return the expected mask path for the image file ``img_name``."""
        # splitext strips only the final extension regardless of its case;
        # str.replace('.jpg', ...) missed '.JPG' names (which the filter in
        # __init__ accepts) and rewrote every '.jpg' occurrence in the name.
        stem = os.path.splitext(img_name)[0]
        return os.path.join(self.mask_dir, stem + "_mask.png")

    def __getitem__(self, idx):
        """Load the image/mask pair at position ``idx``.

        Returns:
            ``(image, mask)``: ``image`` is the grayscale (mode ``"L"``) PIL
            image, or whatever ``transform`` returns for it when a transform
            is set; ``mask`` is a ``torch.long`` tensor shaped like the mask
            file (height, width) holding 1 where the grayscale mask pixel is
            > 127 and 0 elsewhere.

        Raises:
            FileNotFoundError: If the mask file for the image does not exist.
        """
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = self._mask_path(img_name)
        if not os.path.exists(mask_path):
            raise FileNotFoundError(f"Mask not found for: {img_name}")

        # Context managers close the underlying files right away instead of
        # leaving them open until the Image objects are garbage-collected;
        # convert() returns an independent, fully loaded image.
        with Image.open(img_path) as img_file:
            image = img_file.convert("L")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        if self.transform:
            image = self.transform(image)

        # Threshold at mid-grey; the result can only contain 0 and 1, so no
        # separate value check is needed afterwards.
        mask = torch.from_numpy((np.asarray(mask) > 127).astype(np.int64))
        return image, mask
|