Spaces:
Running
Running
| """ | |
| SolarPanelDataset for the resolution study. Identical to the clean baseline's | |
| dataset apart from taking image_size as an argument (which the trainer varies). | |
| """ | |
| from pathlib import Path | |
| import torch | |
| from torch.utils.data import Dataset | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
class SolarPanelDataset(Dataset):
    """Paired (image, mask) dataset for binary solar-panel segmentation.

    Images are the regular files with a (lowercase) ``.jpg`` suffix in
    ``image_dir``; the mask for ``<stem>.jpg`` is ``<stem>_mask.png`` in
    ``mask_dir``. Both are resized to ``image_size x image_size`` and
    converted to float tensors; the mask is binarised to {0., 1.} with a
    0.5 threshold.

    Args:
        image_dir: Directory containing the ``.jpg`` input images.
        mask_dir: Directory containing the ``<stem>_mask.png`` masks.
        image_size: Side length (pixels) both image and mask are resized to.
        augment: If True, apply one random horizontal flip (p=0.5), vertical
            flip (p=0.5) and rotation (RandomRotation(15)) jointly to the
            image and its mask, so both receive identical geometry.
    """

    def __init__(self, image_dir, mask_dir, image_size=128, augment=False):
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.image_size = image_size
        self.augment = augment
        # Sorted for a deterministic index -> file mapping. Only regular files
        # are listed so a sub-directory named "*.jpg" is never picked up.
        self.image_files = sorted(
            p.name
            for p in self.image_dir.iterdir()
            if p.suffix == ".jpg" and p.is_file()
        )
        self.image_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])
        # NOTE(review): no interpolation is passed to Resize, so torchvision's
        # default (bilinear) blends mask edges before the 0.5 threshold in
        # __getitem__. Nearest is the usual choice for label maps; kept as-is
        # to match the baseline pipeline.
        self.mask_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])
        self.augment_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(15),
        ])

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _mask_filename(img_name):
        """Return the mask filename for ``img_name`` (``x.jpg`` -> ``x_mask.png``)."""
        # Swap only the trailing extension; str.replace would also rewrite a
        # ".jpg" occurring earlier in the file name.
        return Path(img_name).stem + "_mask.png"

    def _augment_pair(self, image, mask):
        """Apply a single draw of ``augment_transform`` to image and mask together.

        The mask is stacked onto the image as extra channel(s), so the random
        flip/rotation parameters are sampled once and applied identically to
        both. Unlike reseeding with ``torch.manual_seed``, this leaves the
        global (CPU and CUDA) RNG state untouched apart from the draws the
        transforms themselves make.

        Returns:
            The augmented ``(image, mask)`` pair, split back apart.
        """
        n_image_channels = image.shape[0]
        stacked = self.augment_transform(torch.cat([image, mask], dim=0))
        return stacked[:n_image_channels], stacked[n_image_channels:]

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        # Context managers close the underlying file handles deterministically.
        with Image.open(self.image_dir / img_name) as img_file:
            image = self.image_transform(img_file.convert("RGB"))
        with Image.open(self.mask_dir / self._mask_filename(img_name)) as mask_file:
            # assumes foreground pixels are stored as high grey values (e.g.
            # 255); a 0/1-valued PNG would binarise to all zeros — TODO confirm
            mask = self.mask_transform(mask_file.convert("L"))
        if self.augment:
            image, mask = self._augment_pair(image, mask)
        # Resizing can leave fractional mask values; snap back to {0., 1.}.
        mask = (mask > 0.5).float()
        return image, mask