Spaces:
Running
Running
| """ | |
| Subset-aware dataset for the clean data-scaling study. | |
| Reads from the deduplicated dataset at final_data_clean/. Training side reads | |
| its filename list from subsets/subset_{25,50,100}.txt; validation side reads | |
| the full cleaned val directory. | |
| """ | |
| from pathlib import Path | |
| import torch | |
| from torch.utils.data import Dataset | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
class SubsetSolarPanelDataset(Dataset):
    """Paired image / binary-mask dataset, optionally limited to a file list.

    Images are read from ``image_dir`` and masks from ``mask_dir``; the mask
    paired with ``foo.jpg`` is looked up as ``foo_mask.png``. Both are resized
    to ``image_size`` x ``image_size`` and converted to tensors, and the mask
    is binarised with a 0.5 threshold.

    Args:
        image_dir: Directory containing the ``.jpg`` images.
        mask_dir: Directory containing the ``*_mask.png`` masks.
        file_list: Optional path to a text file with one image filename per
            line (blank lines are skipped). When ``None``, every entry in
            ``image_dir`` whose suffix is exactly ``.jpg`` is used, sorted by
            name.
        image_size: Side length in pixels of the square output.
        augment: When true, apply random horizontal/vertical flips and a
            random rotation of up to 15 degrees to image and mask alike.
    """

    def __init__(self, image_dir, mask_dir, file_list=None, image_size: int = 128,
                 augment: bool = False):
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.image_size = image_size
        self.augment = augment

        if file_list is not None:
            # Explicit encoding keeps the subset lists portable across
            # platforms whose default locale encoding is not UTF-8.
            with open(file_list, encoding="utf-8") as f:
                self.image_files = [line.strip() for line in f if line.strip()]
        else:
            self.image_files = sorted(
                p.name for p in self.image_dir.iterdir() if p.suffix == ".jpg"
            )

        self.image_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])
        self.mask_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])
        self.augment_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(15),
        ])

    @staticmethod
    def _mask_name(img_name):
        """Return the mask filename that pairs with ``img_name``."""
        suffix = ".jpg"
        if img_name.endswith(suffix):
            # Swap only the trailing extension; a blanket str.replace would
            # also rewrite any ".jpg" that appears earlier in the name.
            return img_name[: -len(suffix)] + "_mask.png"
        # Names without a trailing ".jpg" keep the previous substitution.
        return img_name.replace(suffix, "_mask.png")

    def __len__(self) -> int:
        """Return the number of image files in this dataset."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, resize and (optionally) augment the sample at ``idx``.

        Returns:
            An ``(image, mask)`` pair of tensors. ``mask`` holds only 0.0 and
            1.0 after thresholding at 0.5.
        """
        img_name = self.image_files[idx]
        img_path = self.image_dir / img_name
        mask_path = self.mask_dir / self._mask_name(img_name)

        # Context managers release the underlying file handles promptly,
        # which matters when many samples are loaded in a long-lived worker.
        with Image.open(img_path) as img:
            image = self.image_transform(img.convert("RGB"))
        with Image.open(mask_path) as msk:
            mask = self.mask_transform(msk.convert("L"))

        if self.augment:
            # Stack image and mask along the channel axis and transform them
            # in a single call: both receive exactly the same flips and
            # rotation without reseeding the process-wide torch RNG.
            n_img_channels = image.shape[0]
            stacked = self.augment_transform(torch.cat([image, mask], dim=0))
            image = stacked[:n_img_channels]
            mask = stacked[n_img_channels:]

        # Resizing produces intermediate grey values; snap back to binary.
        mask = (mask > 0.5).float()
        return image, mask