AndreCosta's picture
Initial clean commit with LFS configured
7b615ae
import os
import torch
import numpy as np
from PIL import Image
import scripts.config as config
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class SegmentationDataset(Dataset):
def __init__(self, transform=None):
self.image_dir = config.images
self.mask_dir = config.masks
self.transform = transform
paths = [os.path.join(self.image_dir, f) for f in os.listdir(self.image_dir) if f.lower().endswith('.jpg')]
self.image_files = [os.path.basename(f) for f in paths]
def __len__(self):
return len(self.image_files)
def __getitem__(self, idx):
img_name = self.image_files[idx]
img_path = os.path.join(self.image_dir, img_name)
mask_path = os.path.join(self.mask_dir, img_name.replace('.jpg', '_mask.png'))
if not os.path.exists(mask_path):
raise FileNotFoundError(f"Mask not found for: {img_name}")
image = Image.open(img_path).convert("L")
mask = Image.open(mask_path).convert("L")
if self.transform:
image = self.transform(image)
mask = np.array(mask)
mask = (mask > 127).astype(np.uint8)
mask = torch.from_numpy(mask).long()
unique_vals = np.unique(mask)
if not set(unique_vals).issubset({0, 1}):
raise ValueError(f"Mask contains invalid values: {unique_vals}")
return image, mask