import pandas as pd
import albumentations as A
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule

from methane_simulated_dataset import MethaneSimulatedDataset


class MethaneSimulatedDataModule(NonGeoDataModule):
    """DataModule for the simulated methane-plume dataset.

    Filenames and fold assignments are read from an Excel sheet. The
    held-out ``test_fold`` is excluded and the remaining pool is split
    into train/validation subsets. ``sim_tag`` selects which simulated
    product ('toarefl' or 'boarefl') the filenames are rewritten to.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        batch_size: int = 8,
        num_workers: int = 0,
        val_split: float = 0.2,
        seed: int = 42,
        test_fold: int = 4,
        num_folds: int = 5,
        sim_tag: str = "toarefl",  # <--- New arg for 'toarefl'/'boarefl'
        **kwargs,
    ) -> None:
        """Store configuration; datasets are built lazily in ``setup``.

        Args:
            data_root: Root directory containing the image files.
            excel_file: Excel sheet with 'Filename' and 'Fold' columns.
            batch_size: Samples per batch for both dataloaders.
            num_workers: Number of DataLoader worker processes.
            val_split: Fraction of the non-test pool used for validation.
            seed: Random state for the train/val split (reproducibility).
            test_fold: 1-based fold number held out entirely from train/val.
            num_folds: Total number of folds present in the Excel sheet.
            sim_tag: Simulation tag substituted into filenames
                ('toarefl' or 'boarefl').
            **kwargs: Forwarded to ``NonGeoDataModule``.
        """
        super().__init__(MethaneSimulatedDataset, batch_size, num_workers, **kwargs)
        self.data_root = data_root
        self.excel_file = excel_file
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.seed = seed
        self.test_fold = test_fold
        self.num_folds = num_folds
        self.sim_tag = sim_tag
        # Populated by setup(); kept as attributes so they can be inspected.
        self.train_paths: list = []
        self.val_paths: list = []

    def _get_training_transforms(self):
        """Return the albumentations augmentation pipeline for training.

        Validation uses no augmentation (transform=None in setup()).
        """
        return A.Compose([
            A.ElasticTransform(p=0.25),
            A.RandomRotate90(p=0.5),
            A.Flip(p=0.5),
            A.ShiftScaleRotate(
                rotate_limit=90,
                shift_limit_x=0.05,
                shift_limit_y=0.05,
                p=0.5,
            ),
        ])

    def _get_simulated_paths(self, paths):
        """Rewrite filenames so they point at the ``sim_tag`` product.

        Assumes filenames tokenize as ``{ID}_{t1}_{t2}_{Coord1}_{Coord2}``
        and replaces the middle tokens with ``sim_tag`` — TODO confirm the
        original token layout against the files on disk. Entries that do
        not match (too few tokens, or non-string values such as NaN from
        Excel) are passed through unchanged.
        """
        simulated_paths = []
        for path in paths:
            try:
                tokens = path.split('_')
                if len(tokens) >= 5:
                    # Logic: {ID}_{tag}_{Coord1}_{Coord2}
                    simulated_path = (
                        f"{tokens[0]}_{self.sim_tag}_{tokens[3]}_{tokens[4]}"
                    )
                    simulated_paths.append(simulated_path)
                else:
                    simulated_paths.append(path)
            except (AttributeError, IndexError):
                # Non-string cell (e.g. NaN) or unexpected token layout:
                # keep the original value rather than dropping the sample.
                simulated_paths.append(path)
        return simulated_paths

    def setup(self, stage: str = None):
        """Build train/val datasets from the Excel fold assignments.

        A ``stage`` of None (Lightning's "set everything up") builds both
        datasets. The split is deterministic for a fixed ``seed``.

        Raises:
            RuntimeError: If the Excel sheet cannot be read.
        """
        # 1. Read Excel
        try:
            df = pd.read_excel(self.excel_file)
        except Exception as e:
            raise RuntimeError(f"Failed to load excel: {e}")

        # 2. Filter Folds (Exclude test_fold)
        all_folds = list(range(1, self.num_folds + 1))
        train_pool_folds = [f for f in all_folds if f != self.test_fold]
        df_filtered = df[df['Fold'].isin(train_pool_folds)]
        raw_paths = df_filtered['Filename'].tolist()

        # 3. Apply Path Renaming Logic
        paths = self._get_simulated_paths(raw_paths)

        # 4. Train/Val Split (deterministic for a fixed seed)
        self.train_paths, self.val_paths = train_test_split(
            paths,
            test_size=self.val_split,
            random_state=self.seed,
        )

        # 5. Instantiate Datasets (stage=None builds both)
        if stage in (None, "fit", "train"):
            self.train_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.train_paths,
                transform=self._get_training_transforms(),
            )
        if stage in (None, "fit", "validate", "val"):
            self.val_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,
            )

    def train_dataloader(self):
        """Shuffled training loader; drops the last partial batch so every
        step sees a full batch (helps e.g. BatchNorm stability)."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
        )

    def val_dataloader(self):
        """Validation loader.

        drop_last is False so validation metrics cover the entire
        validation set; the previous drop_last=True silently skipped up
        to batch_size-1 samples every epoch.
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
        )

    def on_after_batch_transfer(self, batch, dataloader_idx):
        """Re-key the batch for TerraMind after TorchGeo's default hook.

        TorchGeo's NonGeoDataModule expects a flat 'image' tensor; the
        downstream model expects it nested under the 'S2L2A' modality key.
        """
        # 1. Run TorchGeo default (expects 'image')
        batch = super().on_after_batch_transfer(batch, dataloader_idx)

        # 2. Wrap into TerraMind format {'S2L2A': ...}
        if 'image' in batch:
            s2_data = batch['image']
            batch['image'] = {'S2L2A': s2_data}
        return batch