File size: 4,109 Bytes
97a17c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import pandas as pd
import albumentations as A
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule
from methane_simulated_dataset import MethaneSimulatedDataset

class MethaneSimulatedDataModule(NonGeoDataModule):
    """Lightning/TorchGeo datamodule for the simulated-methane dataset.

    Reads an Excel sheet listing sample filenames and their cross-validation
    fold, drops the held-out ``test_fold``, rewrites each filename into the
    simulated TOA-reflectance naming scheme, and splits the remaining pool
    into train/validation subsets.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        batch_size: int = 8,
        num_workers: int = 0,
        val_split: float = 0.2,
        seed: int = 42,
        test_fold: int = 4,   # Default test fold from your script
        num_folds: int = 5,
        **kwargs
    ):
        """
        Args:
            data_root: Directory containing the sample files.
            excel_file: Excel sheet with (at least) 'Filename' and 'Fold' columns.
            batch_size: Batch size used by both dataloaders.
            num_workers: Number of DataLoader worker processes.
            val_split: Fraction of the non-test pool used for validation.
            seed: Random state for the train/val split (reproducibility).
            test_fold: 1-based fold index excluded entirely from train/val.
            num_folds: Total number of folds present in the Excel sheet.
            **kwargs: Forwarded to ``NonGeoDataModule``.
        """
        super().__init__(MethaneSimulatedDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.excel_file = excel_file
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.seed = seed
        self.test_fold = test_fold
        self.num_folds = num_folds

        # Populated by setup(); kept as attributes so they can be inspected.
        self.train_paths: list = []
        self.val_paths: list = []

    def _get_training_transforms(self):
        """Geometric augmentations applied to the training split only."""
        return A.Compose([
            A.ElasticTransform(p=0.25),
            A.RandomRotate90(p=0.5),
            A.Flip(p=0.5),
            A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5)
        ])

    def _get_simulated_paths(self, paths):
        """Logic to rename files to I1/TOA format.

        A name with >= 5 underscore-separated tokens is rewritten to
        ``<tok0>_toarefl_<tok3>_<tok4>``; any other name (or a value that
        cannot be split, e.g. a non-string) is passed through unchanged.
        """
        simulated_paths = []
        for path in paths:
            try:
                tokens = path.split('_')
                if len(tokens) >= 5:
                    simulated_path = f"{tokens[0]}_toarefl_{tokens[3]}_{tokens[4]}"
                    simulated_paths.append(simulated_path)
                else:
                    simulated_paths.append(path)
            except Exception:
                # Best-effort: keep the raw entry rather than abort the whole list.
                simulated_paths.append(path)
        return simulated_paths

    def setup(self, stage: str = None):
        """Build the train/val datasets for the requested Lightning stage.

        Fix: Lightning may invoke ``setup(stage=None)`` (e.g. manual calls or
        older trainer flows); previously that created *neither* dataset and the
        dataloader hooks then failed with AttributeError. ``None`` now sets up
        both splits.
        """
        # 1. Read Excel
        try:
            df = pd.read_excel(self.excel_file)
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise RuntimeError(f"Failed to load excel: {e}") from e

        # 2. Filter Folds (Exclude test_fold)
        all_folds = list(range(1, self.num_folds + 1))
        train_pool_folds = [f for f in all_folds if f != self.test_fold]

        df_filtered = df[df['Fold'].isin(train_pool_folds)]
        raw_paths = df_filtered['Filename'].tolist()

        # 3. Apply Path Renaming Logic
        paths = self._get_simulated_paths(raw_paths)

        # 4. Train/Val Split (seeded for reproducibility across calls)
        self.train_paths, self.val_paths = train_test_split(
            paths,
            test_size=self.val_split,
            random_state=self.seed
        )

        # 5. Instantiate Datasets
        if stage in (None, "fit", "train"):
            self.train_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.train_paths,
                transform=self._get_training_transforms(),
            )

        if stage in (None, "fit", "validate", "val"):
            self.val_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,  # no augmentation on the validation split
            )

    def train_dataloader(self):
        """Shuffled training loader; drop_last keeps batch shapes uniform."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, drop_last=True)

    def val_dataloader(self):
        """Deterministic (unshuffled) validation loader.

        NOTE(review): drop_last=True here silently discards up to
        batch_size-1 validation samples each epoch — confirm this is
        intentional (kept as-is to preserve existing metric behavior).
        """
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, drop_last=True)

    # def on_after_batch_transfer(self, batch, dataloader_idx):
    #     # 1. Run TorchGeo default (expects 'image')
    #     batch = super().on_after_batch_transfer(batch, dataloader_idx)
        
    #     # 2. Wrap into TerraMind format {'S2L2A': ...}
    #     if 'image' in batch:
    #         s2_data = batch['image']
    #         batch['image'] = {'S2L2A': s2_data}
            
    #     return batch