import os
from typing import Callable, Optional

# NOTE(review): os, rasterio, torch and NonGeoDataset are not referenced in
# this module; kept in case they are relied on elsewhere or for import-time
# side effects — confirm before removing.
import rasterio
import torch
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule
from torchgeo.datasets import NonGeoDataset

from methan_text_dataset import MethaneTextDataset


class MethaneTextDataModule(NonGeoDataModule):
    """DataModule that builds train/val/test splits of ``MethaneTextDataset``.

    By default every split is built from the same ``paths``/``captions``
    lists, so training, validation and testing all see identical samples.
    Pass ``val_paths``/``val_captions`` and/or ``test_paths``/``test_captions``
    to give those splits their own data.

    Args:
        data_root: Directory passed to ``MethaneTextDataset`` as ``root_dir``.
        paths: Sample paths for the training split; also the default for
            the validation and test splits.
        captions: Captions for the training split, passed alongside
            ``paths`` (presumably aligned one-to-one — see
            ``MethaneTextDataset``); also the default for the other splits.
        batch_size: Batch size used by every dataloader.
        num_workers: Number of ``DataLoader`` worker processes.
        train_transform: Optional transform for the training dataset.
        val_transform: Optional transform for the validation dataset.
        test_transform: Optional transform for the test/predict dataset.
        val_paths: Optional validation-split paths (must be given together
            with ``val_captions``).
        val_captions: Optional validation-split captions.
        test_paths: Optional test/predict-split paths (must be given
            together with ``test_captions``).
        test_captions: Optional test/predict-split captions.
        **kwargs: Forwarded unchanged to ``NonGeoDataModule.__init__``.

    Raises:
        ValueError: If only one of a ``*_paths``/``*_captions`` pair is given.
    """

    def __init__(
        self,
        data_root: str,
        paths: list,
        captions: list,
        batch_size: int = 8,
        num_workers: int = 0,
        train_transform: Optional[Callable] = None,
        val_transform: Optional[Callable] = None,
        test_transform: Optional[Callable] = None,
        *,
        val_paths: Optional[list] = None,
        val_captions: Optional[list] = None,
        test_paths: Optional[list] = None,
        test_captions: Optional[list] = None,
        **kwargs,
    ):
        super().__init__(MethaneTextDataset, batch_size, num_workers, **kwargs)
        self.data_root = data_root
        self.paths = paths
        self.captions = captions
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform
        # Per-split data; each falls back to the training lists when not
        # supplied, which reproduces the original shared-data behaviour.
        self.val_paths, self.val_captions = self._resolve_split(
            "val", val_paths, val_captions
        )
        self.test_paths, self.test_captions = self._resolve_split(
            "test", test_paths, test_captions
        )

    def _resolve_split(
        self, name: str, paths: Optional[list], captions: Optional[list]
    ) -> tuple:
        """Return ``(paths, captions)`` for a split, defaulting to the training lists."""
        if (paths is None) != (captions is None):
            raise ValueError(
                f"{name}_paths and {name}_captions must be given together"
            )
        if paths is None:
            return self.paths, self.captions
        return paths, captions

    def _build_split_dataset(
        self, paths: list, captions: list, transform: Optional[Callable]
    ) -> MethaneTextDataset:
        """Instantiate ``MethaneTextDataset`` for one split."""
        return MethaneTextDataset(
            root_dir=self.data_root,
            paths=paths,
            captions=captions,
            transform=transform,
        )

    def _build_loader(self, dataset, *, shuffle: bool, drop_last: bool) -> DataLoader:
        """Wrap ``dataset`` in a ``DataLoader`` using this module's batch settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=drop_last,
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the datasets needed for ``stage``.

        Args:
            stage: ``"fit"``/``"train"`` builds train (``"fit"`` also builds
                val); ``"validate"``/``"val"`` builds val; ``"test"``/
                ``"predict"`` builds the test dataset, which is also exposed
                as ``predict_dataset``. ``None`` builds every split.
        """
        build_all = stage is None
        if build_all or stage in ("fit", "train"):
            self.train_dataset = self._build_split_dataset(
                self.paths, self.captions, self.train_transform
            )
        if build_all or stage in ("fit", "validate", "val"):
            self.val_dataset = self._build_split_dataset(
                self.val_paths, self.val_captions, self.val_transform
            )
        if build_all or stage in ("test", "predict"):
            self.test_dataset = self._build_split_dataset(
                self.test_paths, self.test_captions, self.test_transform
            )
            # Prediction reuses the test split; expose it under the name
            # predict_dataloader reads.
            self.predict_dataset = self.test_dataset

    def train_dataloader(self) -> DataLoader:
        """Shuffled training loader; drops the last partial batch as before."""
        return self._build_loader(self.train_dataset, shuffle=True, drop_last=True)

    def val_dataloader(self) -> DataLoader:
        """Validation loader in dataset order.

        The final partial batch is kept: dropping it would silently exclude
        samples from validation metrics, and a split smaller than
        ``batch_size`` would yield no batches at all.
        """
        return self._build_loader(self.val_dataset, shuffle=False, drop_last=False)

    def test_dataloader(self) -> DataLoader:
        """Test loader in dataset order; keeps the final partial batch."""
        return self._build_loader(self.test_dataset, shuffle=False, drop_last=False)

    def predict_dataloader(self) -> DataLoader:
        """Prediction loader over the test split, in order, with no samples dropped."""
        return self._build_loader(self.predict_dataset, shuffle=False, drop_last=False)