File size: 4,021 Bytes
97a17c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import pandas as pd
import albumentations as A
from typing import Optional, List
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule
from methane_classification_dataset import MethaneClassificationDataset

class MethaneClassificationDataModule(NonGeoDataModule):
    """Data module that splits a summary table into train/val methane datasets.

    The summary file (CSV or Excel) must contain a ``Filename`` column. Its
    entries are split into train and validation subsets with a seeded
    ``train_test_split`` (so repeated ``setup`` calls reproduce the same
    split), and each subset is wrapped in a ``MethaneClassificationDataset``.
    The test/predict split currently reuses the validation paths.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        batch_size: int = 8,
        num_workers: int = 0,
        val_split: float = 0.2,
        seed: int = 42,
        **kwargs
    ):
        """Store configuration; datasets are created in ``setup``.

        Args:
            data_root: Directory passed to each dataset as ``root_dir``.
            excel_file: Path to the summary table. Read with ``pd.read_csv``
                when the name ends in ``.csv`` (case-insensitive), otherwise
                with ``pd.read_excel``.
            batch_size: Batch size used by every dataloader.
            num_workers: Number of worker processes used by every dataloader.
            val_split: Fraction of rows held out for validation (forwarded to
                ``train_test_split`` as ``test_size``).
            seed: ``random_state`` for the split, keeping it reproducible.
            **kwargs: Forwarded unchanged to ``NonGeoDataModule.__init__``.
        """
        # The dataset class is passed only to satisfy the parent constructor;
        # the per-split datasets are instantiated explicitly in setup().
        super().__init__(MethaneClassificationDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.excel_file = excel_file
        self.val_split = val_split
        self.seed = seed
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Populated by setup(); empty until then.
        self.train_paths: List[str] = []
        self.val_paths: List[str] = []

    def _get_training_transforms(self):
        """Return the albumentations pipeline applied to training samples only."""
        # NOTE(review): A.Flip and A.ShiftScaleRotate are deprecated in recent
        # albumentations releases — confirm against the pinned version.
        return A.Compose([
            A.ElasticTransform(p=0.25),
            A.RandomRotate90(p=0.5),
            A.Flip(p=0.5),
            A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5)
        ])

    def _load_all_paths(self) -> List[str]:
        """Read the summary table and return its ``Filename`` column as a list.

        Raises:
            RuntimeError: If the file cannot be read by pandas.
            KeyError: If the table has no ``Filename`` column.
        """
        # str() lets pathlib.Path inputs work; lower() accepts ".CSV" too.
        is_csv = str(self.excel_file).lower().endswith(".csv")
        try:
            df = pd.read_csv(self.excel_file) if is_csv else pd.read_excel(self.excel_file)
        except Exception as e:
            raise RuntimeError(f"Failed to load summary file: {e}") from e

        if "Filename" not in df.columns:
            raise KeyError(
                f"Summary file {self.excel_file!r} has no 'Filename' column; "
                f"found {list(df.columns)}"
            )
        # All rows are used; no Fold-based filtering is applied here.
        return df["Filename"].tolist()

    def _make_dataset(self, paths: List[str], transform) -> MethaneClassificationDataset:
        """Build a dataset over ``paths`` with the shared root and summary file."""
        return MethaneClassificationDataset(
            root_dir=self.data_root,
            excel_file=self.excel_file,
            paths=paths,
            transform=transform,
        )

    def setup(self, stage: Optional[str] = None):
        """Split the summary paths and build the datasets needed for ``stage``.

        Args:
            stage: One of ``"fit"``/``"train"``, ``"validate"``/``"val"``,
                ``"test"`` or ``"predict"``. ``None`` (e.g. a manual
                ``dm.setup()`` call) builds every dataset.
        """
        all_paths = self._load_all_paths()

        # Seeded, so every setup() call reproduces the same partition.
        self.train_paths, self.val_paths = train_test_split(
            all_paths,
            test_size=self.val_split,
            random_state=self.seed,
        )

        if stage is None or stage in ("fit", "train"):
            self.train_dataset = self._make_dataset(
                self.train_paths, self._get_training_transforms()
            )

        if stage is None or stage in ("fit", "validate", "val"):
            # No augmentation for validation.
            self.val_dataset = self._make_dataset(self.val_paths, None)

        if stage is None or stage in ("test", "predict"):
            # No dedicated hold-out set yet: test/predict reuse val_paths.
            self.test_dataset = self._make_dataset(self.val_paths, None)
            self.predict_dataset = self.test_dataset

    def _make_loader(self, dataset, shuffle: bool, drop_last: bool) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using the module's batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=drop_last,
        )

    def train_dataloader(self):
        """Shuffled training loader; drops the last incomplete batch."""
        return self._make_loader(self.train_dataset, shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Validation loader in fixed order.

        Keeps every sample, including a final partial batch. Dropping it would
        skew metrics, and would yield no batches at all when the validation
        set is smaller than ``batch_size``.
        """
        return self._make_loader(self.val_dataset, shuffle=False, drop_last=False)

    def test_dataloader(self):
        """Test loader in fixed order; keeps every sample."""
        return self._make_loader(self.test_dataset, shuffle=False, drop_last=False)

    def predict_dataloader(self):
        """Prediction loader over the dataset built by ``setup("predict")``."""
        return self._make_loader(self.predict_dataset, shuffle=False, drop_last=False)