# File size: 2,381 Bytes | 97a17c2
import os
import rasterio
import torch
from torchgeo.datasets import NonGeoDataset
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule
from methan_text_dataset import MethaneTextDataset
class MethaneTextDataModule(NonGeoDataModule):
    """DataModule that builds ``MethaneTextDataset`` splits and their loaders.

    The same ``data_root``, ``paths`` and ``captions`` are handed to the
    train, validation and test datasets; only the transform differs.

    NOTE(review): no split-specific path/caption lists are accepted, so every
    split appears to iterate over the same samples -- confirm this is intended.
    """

    def __init__(
        self,
        data_root: str,
        paths: list,
        captions: list,
        batch_size: int = 8,
        num_workers: int = 0,
        train_transform: callable = None,
        val_transform: callable = None,
        test_transform: callable = None,
        **kwargs
    ):
        """Store the dataset configuration.

        Args:
            data_root: Passed to ``MethaneTextDataset`` as ``root_dir``.
            paths: Passed to ``MethaneTextDataset`` as ``paths``.
            captions: Passed to ``MethaneTextDataset`` as ``captions``.
            batch_size: Forwarded to ``NonGeoDataModule.__init__``; the
                loaders read it back from ``self.batch_size``.
            num_workers: Forwarded to ``NonGeoDataModule.__init__``; the
                loaders read it back from ``self.num_workers``.
            train_transform: ``transform`` argument of the training dataset.
            val_transform: ``transform`` argument of the validation dataset.
            test_transform: ``transform`` argument of the test dataset.
            **kwargs: Extra keyword arguments forwarded to
                ``NonGeoDataModule.__init__``.
        """
        super().__init__(MethaneTextDataset, batch_size, num_workers, **kwargs)
        self.data_root = data_root
        self.paths = paths
        self.captions = captions
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform

    def _build_split_dataset(self, transform):
        """Return a ``MethaneTextDataset`` over the stored data using ``transform``."""
        return MethaneTextDataset(
            root_dir=self.data_root,
            paths=self.paths,
            captions=self.captions,
            transform=transform,
        )

    def _build_loader(self, dataset, shuffle: bool):
        """Wrap ``dataset`` in a ``DataLoader`` with this module's batch settings."""
        # NOTE(review): drop_last=True is kept for every split to preserve the
        # existing behaviour; for validation/test it discards the final partial
        # batch, so some samples are never evaluated -- confirm this is wanted.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=True,
        )

    def setup(self, stage: str = None):
        """Create the datasets required for ``stage``.

        Args:
            stage: ``"fit"`` builds the train and validation datasets,
                ``"train"`` only the train dataset, ``"validate"``/``"val"``
                only the validation dataset, and ``"test"``/``"predict"`` the
                test dataset. ``None`` (the default) builds all three, so a
                bare ``setup()`` call leaves every dataloader usable.
        """
        # Previously stage=None matched no branch and silently built nothing.
        every_split = stage is None

        if every_split or stage in ("fit", "train"):
            self.train_dataset = self._build_split_dataset(self.train_transform)
        if every_split or stage in ("fit", "validate", "val"):
            self.val_dataset = self._build_split_dataset(self.val_transform)
        if every_split or stage in ("test", "predict"):
            # "predict" populates the test dataset; no separate predict
            # dataset is created here.
            self.test_dataset = self._build_split_dataset(self.test_transform)

    def train_dataloader(self):
        """Return a shuffling loader over ``self.train_dataset``."""
        return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return a non-shuffling loader over ``self.val_dataset``."""
        return self._build_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return a non-shuffling loader over ``self.test_dataset``."""
        return self._build_loader(self.test_dataset, shuffle=False)