""" SolarPanelDataset for the resolution study. Identical to the clean baseline's dataset apart from taking image_size as an argument (which the trainer varies). """ from pathlib import Path import torch from torch.utils.data import Dataset from PIL import Image import torchvision.transforms as transforms class SolarPanelDataset(Dataset): def __init__(self, image_dir, mask_dir, image_size=128, augment=False): self.image_dir = Path(image_dir) self.mask_dir = Path(mask_dir) self.image_size = image_size self.augment = augment self.image_files = sorted(p.name for p in self.image_dir.iterdir() if p.suffix == ".jpg") self.image_transform = transforms.Compose([ transforms.Resize((image_size, image_size)), transforms.ToTensor(), ]) self.mask_transform = transforms.Compose([ transforms.Resize((image_size, image_size)), transforms.ToTensor(), ]) self.augment_transform = transforms.Compose([ transforms.RandomHorizontalFlip(p=0.5), transforms.RandomVerticalFlip(p=0.5), transforms.RandomRotation(15), ]) def __len__(self): return len(self.image_files) def __getitem__(self, idx): img_name = self.image_files[idx] image = Image.open(self.image_dir / img_name).convert("RGB") mask = Image.open(self.mask_dir / img_name.replace(".jpg", "_mask.png")).convert("L") image = self.image_transform(image) mask = self.mask_transform(mask) if self.augment: seed = torch.randint(0, 2**32, (1,)).item() torch.manual_seed(seed) image = self.augment_transform(image) torch.manual_seed(seed) mask = self.augment_transform(mask) mask = (mask > 0.5).float() return image, mask