# TerraMind-Methane-Classification/sentinel2_classification_finetuning/script/methane_simulated_datamodule.py
import os
from typing import Callable, Optional

import rasterio
import torch
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule
from torchgeo.datasets import NonGeoDataset

from methane_simulated_dataset import MethaneSimulatedDataset
# from methane_classification_dataset import MethaneClassificationDataset
class MethaneSimulatedDataModule(NonGeoDataModule):
    """DataModule that serves ``MethaneSimulatedDataset`` splits via DataLoaders.

    ``setup`` builds one ``MethaneSimulatedDataset`` per requested stage. Every
    dataset is constructed from the same ``data_root`` / ``excel_file`` /
    ``paths``; they differ only in the transform applied.

    Args:
        data_root: Passed to the dataset as ``root_dir``.
        excel_file: Passed to the dataset as ``excel_file``.
        paths: Passed to the dataset as ``paths``.
        batch_size: Batch size used by every dataloader.
        num_workers: Number of DataLoader worker processes for every loader.
        train_transform: Transform for the training dataset (may be ``None``).
        val_transform: Transform for the validation dataset (may be ``None``).
        test_transform: Transform for the test/predict dataset (may be ``None``).
        **kwargs: Forwarded to ``NonGeoDataModule.__init__``. They are NOT
            passed to the datasets built in ``setup``.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        paths: list,
        batch_size: int = 8,
        num_workers: int = 0,
        train_transform: Optional[Callable] = None,
        val_transform: Optional[Callable] = None,
        test_transform: Optional[Callable] = None,
        **kwargs,
    ):
        super().__init__(MethaneSimulatedDataset, batch_size, num_workers, **kwargs)
        self.data_root = data_root
        self.excel_file = excel_file
        self.paths = paths
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform

    def _build_dataset(self, transform: Optional[Callable]) -> MethaneSimulatedDataset:
        """Create a dataset over the configured data with ``transform`` applied."""
        return MethaneSimulatedDataset(
            root_dir=self.data_root,
            excel_file=self.excel_file,
            paths=self.paths,
            transform=transform,
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Instantiate the dataset(s) needed for ``stage``.

        Args:
            stage: Selects which datasets are built:
                - ``"fit"`` / ``"train"`` build the training dataset.
                - ``"fit"`` / ``"validate"`` / ``"val"`` build the validation dataset.
                - ``"test"`` / ``"predict"`` build the test dataset.
                - ``None`` builds all three, so a manual ``setup()`` call leaves
                  every dataloader usable.
        """
        # NOTE(review): every split is built from identical root_dir /
        # excel_file / paths and differs only by transform. Unless
        # MethaneSimulatedDataset selects a subset some other way, train, val
        # and test contain the same samples. Confirm that a held-out split is
        # not required.
        if stage in (None, "fit", "train"):
            self.train_dataset = self._build_dataset(self.train_transform)
        if stage in (None, "fit", "validate", "val"):
            self.val_dataset = self._build_dataset(self.val_transform)
        if stage in (None, "test", "predict"):
            self.test_dataset = self._build_dataset(self.test_transform)

    def _eval_dataloader(self, dataset) -> DataLoader:
        """Build an ordered loader that keeps every sample of ``dataset``."""
        # drop_last=False: dropping the final partial batch would silently
        # exclude samples from evaluation. It would also yield no batches at
        # all when the dataset is smaller than batch_size.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
        )

    def train_dataloader(self) -> DataLoader:
        """Return a shuffled training loader that drops the last incomplete batch."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Return an ordered validation loader covering every sample."""
        return self._eval_dataloader(self.val_dataset)

    def test_dataloader(self) -> DataLoader:
        """Return an ordered test loader covering every sample."""
        return self._eval_dataloader(self.test_dataset)

    def predict_dataloader(self) -> DataLoader:
        """Return an ordered loader over ``test_dataset`` for prediction.

        ``setup("predict")`` builds only ``test_dataset``, so prediction
        reuses it.
        """
        return self._eval_dataloader(self.test_dataset)