# TerraMind-Methane-Classification/intuition1_classification_finetuning/config/methane_simulated_datamodule.py
from typing import Optional

import albumentations as A
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule

from methane_simulated_dataset import MethaneSimulatedDataset
class MethaneSimulatedDataModule(NonGeoDataModule):
    """Data module producing train/validation ``MethaneSimulatedDataset`` splits.

    File names and fold assignments are read from an Excel sheet that must
    contain ``Fold`` and ``Filename`` columns. Rows whose fold equals
    ``test_fold`` are excluded; the remaining file names are rewritten by
    ``_get_simulated_paths`` and then randomly split into a training and a
    validation list.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        batch_size: int = 8,
        num_workers: int = 0,
        val_split: float = 0.2,
        seed: int = 42,
        test_fold: int = 4,  # Default test fold from the original script
        num_folds: int = 5,
        **kwargs,
    ):
        """Store the configuration; no file is read until ``setup`` runs.

        Args:
            data_root: Passed to the dataset as ``root_dir``.
            excel_file: Path of the Excel sheet holding ``Fold``/``Filename``.
            batch_size: Batch size used by both dataloaders.
            num_workers: Worker process count used by both dataloaders.
            val_split: Forwarded to ``train_test_split`` as ``test_size``.
            seed: Forwarded to ``train_test_split`` as ``random_state``.
            test_fold: Fold number left out of the train/validation pool.
            num_folds: Folds are assumed to be numbered ``1..num_folds``.
            **kwargs: Forwarded unchanged to ``NonGeoDataModule.__init__``.
        """
        super().__init__(MethaneSimulatedDataset, batch_size, num_workers, **kwargs)
        self.data_root = data_root
        self.excel_file = excel_file
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.seed = seed
        self.test_fold = test_fold
        self.num_folds = num_folds
        # Filled in by setup().
        self.train_paths = []
        self.val_paths = []

    def _get_training_transforms(self):
        """Return the albumentations pipeline applied to training samples."""
        return A.Compose([
            A.ElasticTransform(p=0.25),
            A.RandomRotate90(p=0.5),
            A.Flip(p=0.5),
            A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5),
        ])

    def _get_simulated_paths(self, paths):
        """Rename file names to the I1/TOA format.

        A string with at least five ``_``-separated tokens becomes
        ``"<tok0>_toarefl_<tok3>_<tok4>"``; tokens 1 and 2 and everything
        after token 4 are dropped. Shorter names, and entries that are not
        strings (for example an empty spreadsheet cell), are passed through
        unchanged.

        Args:
            paths: Iterable of file names taken from the Excel sheet.

        Returns:
            A new list with one entry per input entry, in the same order.
        """
        simulated_paths = []
        for path in paths:
            # Explicit type check instead of a blanket ``except Exception``:
            # the only thing that could fail here is ``.split`` on a
            # non-string entry, and those are kept as-is.
            if not isinstance(path, str):
                simulated_paths.append(path)
                continue
            tokens = path.split('_')
            if len(tokens) >= 5:
                simulated_paths.append(f"{tokens[0]}_toarefl_{tokens[3]}_{tokens[4]}")
            else:
                simulated_paths.append(path)
        return simulated_paths

    def setup(self, stage: Optional[str] = None):
        """Read the Excel sheet, compute the split and build the datasets.

        Args:
            stage: ``"fit"``/``"train"`` builds the training dataset and
                ``"fit"``/``"validate"``/``"val"`` builds the validation
                dataset. ``None`` (the default) builds both, so that a manual
                ``setup()`` call leaves the module ready to use. Any other
                value only refreshes ``train_paths``/``val_paths``.

        Raises:
            RuntimeError: If the Excel file cannot be read.
        """
        # 1. Read Excel
        try:
            df = pd.read_excel(self.excel_file)
        except Exception as e:
            raise RuntimeError(f"Failed to load excel: {e}") from e

        # 2. Filter folds (exclude test_fold)
        all_folds = list(range(1, self.num_folds + 1))
        train_pool_folds = [f for f in all_folds if f != self.test_fold]
        df_filtered = df[df['Fold'].isin(train_pool_folds)]
        raw_paths = df_filtered['Filename'].tolist()

        # 3. Apply path renaming logic
        paths = self._get_simulated_paths(raw_paths)

        # 4. Train/val split (fixed random_state, so repeated calls with the
        #    same inputs produce the same split)
        self.train_paths, self.val_paths = train_test_split(
            paths,
            test_size=self.val_split,
            random_state=self.seed,
        )

        # 5. Instantiate datasets
        if stage is None or stage in ("fit", "train"):
            self.train_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.train_paths,
                transform=self._get_training_transforms(),
            )
        if stage is None or stage in ("fit", "validate", "val"):
            self.val_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,
            )

    def train_dataloader(self) -> DataLoader:
        """Return a shuffled loader over the training dataset."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Return an unshuffled loader over the validation dataset."""
        # NOTE(review): drop_last=True discards the final partial batch, so up
        # to batch_size - 1 validation samples are never evaluated. Kept as in
        # the original; confirm this is intended.
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=True,
        )

    # Disabled hook, kept for reference:
    # def on_after_batch_transfer(self, batch, dataloader_idx):
    #     # 1. Run TorchGeo default (expects 'image')
    #     batch = super().on_after_batch_transfer(batch, dataloader_idx)
    #     # 2. Wrap into TerraMind format {'S2L2A': ...}
    #     if 'image' in batch:
    #         s2_data = batch['image']
    #         batch['image'] = {'S2L2A': s2_data}
    #     return batch