"""Lightning data module for methane-plume classification.

Reads a summary spreadsheet (CSV or Excel) listing sample filenames,
splits them into train/validation subsets, and exposes the matching
datasets and dataloaders.
"""

import pandas as pd
import albumentations as A
from typing import Optional, List
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule

from methane_classification_dataset import MethaneClassificationDataset


class MethaneClassificationDataModule(NonGeoDataModule):
    """Data module wrapping :class:`MethaneClassificationDataset`.

    Args:
        data_root: Directory containing the sample files.
        excel_file: Path to a ``.csv`` or Excel summary file with a
            ``Filename`` column listing the samples to use.
        batch_size: Samples per batch for all dataloaders.
        num_workers: Worker processes per dataloader.
        val_split: Fraction of samples held out for validation.
        seed: Random state for the train/val split (reproducibility).
        **kwargs: Forwarded to :class:`NonGeoDataModule`.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        batch_size: int = 8,
        num_workers: int = 0,
        val_split: float = 0.2,
        seed: int = 42,
        **kwargs,
    ) -> None:
        # The dataset class is passed to satisfy the parent constructor;
        # concrete dataset instances are built in setup().
        super().__init__(MethaneClassificationDataset, batch_size, num_workers, **kwargs)
        self.data_root = data_root
        self.excel_file = excel_file
        self.val_split = val_split
        self.seed = seed
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Populated by setup(); lists of sample filenames per split.
        self.train_paths: List[str] = []
        self.val_paths: List[str] = []

    def _get_training_transforms(self) -> A.Compose:
        """Return the augmentation pipeline applied to training samples only."""
        return A.Compose([
            A.ElasticTransform(p=0.25),
            A.RandomRotate90(p=0.5),
            A.Flip(p=0.5),
            A.ShiftScaleRotate(
                rotate_limit=90,
                shift_limit_x=0.05,
                shift_limit_y=0.05,
                p=0.5,
            ),
        ])

    def setup(self, stage: Optional[str] = None) -> None:
        """Load the summary file, split filenames, and build datasets.

        Args:
            stage: One of ``"fit"``, ``"validate"``, ``"test"``,
                ``"predict"`` (or legacy aliases). ``None`` means
                "set up everything".

        Raises:
            RuntimeError: If the summary file cannot be read.
            ValueError: If the summary file lacks a ``Filename`` column.
        """
        # 1. Read the summary file (CSV or Excel, chosen by extension).
        try:
            if self.excel_file.endswith('.csv'):
                df = pd.read_csv(self.excel_file)
            else:
                df = pd.read_excel(self.excel_file)
        except Exception as e:
            raise RuntimeError(f"Failed to load summary file: {e}")

        # 2. Collect sample filenames. Fail loudly with a clear message if
        # the expected column is absent instead of a bare KeyError.
        if 'Filename' not in df.columns:
            raise ValueError(
                f"Summary file {self.excel_file!r} has no 'Filename' column"
            )
        all_paths = df['Filename'].tolist()

        # 3. Deterministic train/val split. If you need fold-specific
        # filtering, add that logic before this call.
        self.train_paths, self.val_paths = train_test_split(
            all_paths,
            test_size=self.val_split,
            random_state=self.seed,
        )

        # 4. Instantiate datasets. stage=None means "build everything"
        # (Lightning's convention); previously no dataset was created in
        # that case because None matched no branch.
        build_all = stage is None

        if build_all or stage in ("fit", "train"):
            self.train_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.train_paths,
                transform=self._get_training_transforms(),
            )

        if build_all or stage in ("fit", "validate", "val"):
            self.val_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,  # No augmentation during validation
            )

        if build_all or stage in ("test", "predict"):
            # No dedicated hold-out set yet: reuse the validation paths.
            # Add fold-based test selection here if/when it exists.
            self.test_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,
            )

    def train_dataloader(self) -> DataLoader:
        """Training loader: shuffled; drops the last partial batch so every
        gradient step sees a full batch."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Validation loader. drop_last=False so metrics are computed over
        every sample (drop_last=True would silently skip up to
        batch_size-1 samples and bias evaluation)."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
        )

    def test_dataloader(self) -> DataLoader:
        """Test loader. drop_last=False for the same reason as validation:
        evaluation must cover the full dataset."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
        )