Mohamed-ENNHIRI
Add Tab 7: resolution study (segformer_b0 + U-Net at 192/256/512)
a3200e4
Raw
History Blame Contribute Delete
1.89 kB
"""
SolarPanelDataset for the resolution study. Identical to the clean baseline's
dataset apart from taking image_size as an argument (which the trainer varies).
"""
from pathlib import Path
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
class SolarPanelDataset(Dataset):
    """Paired image / binary-mask dataset for solar-panel segmentation.

    Samples are the ``*.jpg`` files found directly inside ``image_dir``
    (suffix match is case-sensitive), ordered by file name.  The mask for
    ``<stem>.jpg`` is read from ``mask_dir / "<stem>_mask.png"``.

    Args:
        image_dir: Directory holding the ``.jpg`` input images.
        mask_dir: Directory holding the matching ``<stem>_mask.png`` masks.
        image_size: Side length, in pixels, that both image and mask are
            resized to (the output is square).
        augment: When true, apply one shared random flip / rotation to the
            image and its mask.
    """

    def __init__(self, image_dir, mask_dir, image_size=128, augment=False):
        super().__init__()
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.image_size = image_size
        self.augment = augment

        # Sorted so that index -> sample is stable across runs and platforms.
        self.image_files = sorted(
            p.name for p in self.image_dir.iterdir() if p.suffix == ".jpg"
        )

        self.image_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])
        # The mask goes through the same resize (library-default
        # interpolation), which can produce intermediate values; it is
        # re-binarised in __getitem__.
        self.mask_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])
        # Geometric-only augmentations: these are valid for both the image
        # and its mask as long as both receive the same random parameters.
        self.augment_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(15),
        ])

    def __len__(self):
        """Return the number of ``.jpg`` images discovered in ``image_dir``."""
        return len(self.image_files)

    def _mask_path(self, img_name):
        """Return the mask path for ``img_name`` (``<stem>_mask.png``).

        Built from the stem rather than ``str.replace`` so that a ".jpg"
        occurring earlier in the file name is left untouched.
        """
        return self.mask_dir / f"{Path(img_name).stem}_mask.png"

    def _augment_pair(self, image, mask):
        """Apply ``augment_transform`` to image and mask with identical draws.

        The CPU RNG state is snapshotted before augmenting the image and
        restored before augmenting the mask, so both calls see the same
        random numbers.  Unlike reseeding with ``torch.manual_seed``, this
        leaves the global generator advancing normally and does not touch
        the CUDA generators.
        """
        rng_state = torch.get_rng_state()
        image = self.augment_transform(image)
        torch.set_rng_state(rng_state)
        mask = self.augment_transform(mask)
        return image, mask

    def __getitem__(self, idx):
        """Load, resize and (optionally) augment the sample at ``idx``.

        Returns:
            ``(image, mask)`` tensors; the mask holds only 0.0 / 1.0 values.
        """
        img_name = self.image_files[idx]

        # Context managers release the underlying file handles promptly.
        with Image.open(self.image_dir / img_name) as raw_image:
            image = self.image_transform(raw_image.convert("RGB"))
        with Image.open(self._mask_path(img_name)) as raw_mask:
            mask = self.mask_transform(raw_mask.convert("L"))

        if self.augment:
            image, mask = self._augment_pair(image, mask)

        # Resizing can leave non-binary values; threshold back to {0, 1}.
        mask = (mask > 0.5).float()
        return image, mask