import os
from typing import Callable, Optional

import rasterio
import torch
from torchgeo.datasets import NonGeoDataset
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule

from methane_simulated_dataset import MethaneSimulatedDataset
# from methane_classification_dataset import MethaneClassificationDataset


class MethaneSimulatedDataModule(NonGeoDataModule):
    """Data module that builds ``MethaneSimulatedDataset`` splits and loaders.

    Every split is constructed from the same ``data_root``, ``excel_file`` and
    ``paths``; only the transform differs between train, validation and test.

    NOTE(review): because all splits receive the same ``paths``, train/val/test
    appear to iterate over the same samples. Confirm whether the dataset does
    its own splitting or whether callers are expected to create one data
    module per split.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        paths: list,
        batch_size: int = 8,
        num_workers: int = 0,
        train_transform: Optional[Callable] = None,
        val_transform: Optional[Callable] = None,
        test_transform: Optional[Callable] = None,
        **kwargs,
    ):
        """Store the dataset configuration and initialise the base module.

        Args:
            data_root: Passed to the dataset as ``root_dir``.
            excel_file: Passed to the dataset as ``excel_file``.
            paths: Passed to the dataset as ``paths``.
            batch_size: Batch size used by every dataloader.
            num_workers: Number of worker processes used by every dataloader.
            train_transform: Transform handed to the training dataset.
            val_transform: Transform handed to the validation dataset.
            test_transform: Transform handed to the test dataset.
            **kwargs: Forwarded unchanged to ``NonGeoDataModule.__init__``.
        """
        super().__init__(MethaneSimulatedDataset, batch_size, num_workers, **kwargs)
        self.data_root = data_root
        self.excel_file = excel_file
        self.paths = paths
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform

    def _build_dataset(self, transform: Optional[Callable]) -> MethaneSimulatedDataset:
        """Create a dataset from the stored configuration with ``transform``."""
        return MethaneSimulatedDataset(
            root_dir=self.data_root,
            excel_file=self.excel_file,
            paths=self.paths,
            transform=transform,
        )

    def _build_loader(self, dataset, *, shuffle: bool, drop_last: bool) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using the module's batch settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=drop_last,
        )

    def setup(self, stage: Optional[str] = None):
        """Instantiate the datasets required for ``stage``.

        Args:
            stage: One of ``"fit"``/``"train"``, ``"validate"``/``"val"``,
                ``"test"`` or ``"predict"``. ``None`` (the default) prepares
                every split, so a bare ``setup()`` call leaves the module
                fully usable instead of creating nothing.
        """
        if stage in (None, "fit", "train"):
            self.train_dataset = self._build_dataset(self.train_transform)
        if stage in (None, "fit", "validate", "val"):
            self.val_dataset = self._build_dataset(self.val_transform)
        if stage in (None, "test", "predict"):
            # Prediction reuses the test dataset; see predict_dataloader().
            self.test_dataset = self._build_dataset(self.test_transform)

    def train_dataloader(self):
        """Return the shuffled training loader; incomplete last batch is dropped."""
        return self._build_loader(self.train_dataset, shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Return the validation loader, in order and covering every sample."""
        # drop_last is training-only: dropping the tail here would silently
        # skip up to batch_size - 1 samples (or all of them for a dataset
        # smaller than one batch) when computing validation metrics.
        return self._build_loader(self.val_dataset, shuffle=False, drop_last=False)

    def test_dataloader(self):
        """Return the test loader, in order and covering every sample."""
        # Same reasoning as val_dataloader(): evaluation must see all samples.
        return self._build_loader(self.test_dataset, shuffle=False, drop_last=False)

    def predict_dataloader(self):
        """Return the prediction loader, backed by the test dataset.

        ``setup("predict")`` populates ``self.test_dataset``, so prediction is
        served from it, in order and without dropping any sample.
        """
        return self._build_loader(self.test_dataset, shuffle=False, drop_last=False)