File size: 2,223 Bytes
52efd90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
Subset-aware dataset for the clean data-scaling study.

Reads from the deduplicated dataset at final_data_clean/. Training side reads
its filename list from subsets/subset_{25,50,100}.txt; validation side reads
the full cleaned val directory.
"""
from pathlib import Path
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms


class SubsetSolarPanelDataset(Dataset):
    """Image/mask segmentation dataset, optionally restricted to a file list.

    Each item is an ``(image, mask)`` pair of tensors resized to
    ``image_size x image_size``. The image is loaded as RGB, the mask as a
    single-channel image that is binarized at 0.5 after all transforms. The
    mask filename is derived from the image filename by replacing ``.jpg``
    with ``_mask.png`` and is looked up inside ``mask_dir``.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the mask images.
        file_list: Optional path to a text file with one image filename per
            line (blank lines are ignored). When ``None``, every file in
            ``image_dir`` whose suffix is exactly ``.jpg`` is used, sorted
            by name.
        image_size: Side length, in pixels, that images and masks are
            resized to.
        augment: When true, random horizontal/vertical flips and a random
            rotation of up to 15 degrees are applied identically to the
            image and its mask.
    """

    def __init__(self, image_dir, mask_dir, file_list=None, image_size=128, augment=False):
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.image_size = image_size
        self.augment = augment

        if file_list is not None:
            # Explicit subset: one filename per line, blank lines skipped.
            with open(file_list) as f:
                self.image_files = [line.strip() for line in f if line.strip()]
        else:
            # No subset given: use every .jpg in the image directory.
            self.image_files = sorted(
                p.name for p in self.image_dir.iterdir() if p.suffix == ".jpg"
            )

        self.image_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])
        self.mask_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])
        self.augment_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(15),
        ])

    def __len__(self) -> int:
        """Return the number of image files in this dataset."""
        return len(self.image_files)

    def _augment_pair(self, image, mask):
        """Apply a single random augmentation draw to both image and mask.

        The two tensors are concatenated along the channel axis and passed
        through ``self.augment_transform`` once, so they receive the same
        geometric transform by construction. This avoids reseeding the
        global torch RNG per sample, which would otherwise disturb every
        other consumer of that RNG in the process.

        Args:
            image: Tensor of shape ``(C, H, W)``.
            mask: Tensor of shape ``(M, H, W)`` with the same ``H``/``W``
                and dtype as ``image``.

        Returns:
            Tuple ``(image, mask)`` of the augmented tensors, with the same
            channel counts as the inputs.
        """
        image_channels = image.shape[0]
        stacked = torch.cat([image, mask], dim=0)
        stacked = self.augment_transform(stacked)
        return stacked[:image_channels], stacked[image_channels:]

    def __getitem__(self, idx):
        """Load, resize and (optionally) augment the sample at ``idx``.

        Returns:
            Tuple ``(image, mask)``: the transformed RGB image tensor and a
            float mask tensor whose values are exactly 0.0 or 1.0.
        """
        img_name = self.image_files[idx]
        img_path = self.image_dir / img_name
        mask_path = self.mask_dir / img_name.replace(".jpg", "_mask.png")

        # convert() returns an independent, fully loaded image, so the file
        # handles can be released as soon as the with-blocks exit.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert("L")

        image = self.image_transform(image)
        mask = self.mask_transform(mask)

        if self.augment:
            image, mask = self._augment_pair(image, mask)

        # Binarize last so values blended by resizing snap back to {0, 1}.
        mask = (mask > 0.5).float()
        return image, mask