| from typing import Any, cast | |
| import albumentations as A | |
| from albumentations.core.composition import TransformType | |
| from albumentations.pytorch import ToTensorV2 | |
| from torch import Tensor | |
| from torch.utils.data import DataLoader | |
| from torchgeo.datamodules import NonGeoDataModule | |
| try: | |
| from datasets.hyperview_dataset import HyperviewNonGeo | |
| except ModuleNotFoundError: | |
| from hyperview_dataset import HyperviewNonGeo | |
# Default training augmentation pipeline: resize to the 224x224 input size,
# then apply light photometric/geometric augmentations before converting the
# HWC numpy cube to a CHW torch tensor.
train_transforms = [
    A.Resize(224, 224),
    # Very low-amplitude Gaussian noise (std in [0, 0.005]) — presumably the
    # cubes are scaled to roughly [0, 1]; TODO confirm against dataset scaling.
    A.GaussNoise(std_range=(0.0, 0.005), mean_range=(0.0, 0.0), p=0.5),
    A.ElasticTransform(p=0.25),
    A.RandomRotate90(p=0.5),
    A.VerticalFlip(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(
        rotate_limit=90,
        shift_limit_x=0.05,
        shift_limit_y=0.05,
        p=0.5,
    ),
    ToTensorV2(),
]
# ``cast`` only narrows the element type for the type checker; no runtime effect.
train_full_transform = A.Compose(cast(list[TransformType], train_transforms))
# Deterministic evaluation pipeline: resize + tensor conversion only (used as
# the default for both the val and test splits below).
eval_transforms = [A.Resize(224, 224), ToTensorV2()]
eval_transform = A.Compose(cast(list[TransformType], eval_transforms))
class HyperviewNonGeoDataModule(NonGeoDataModule):
    """DataModule for Hyperview non-geospatial hyperspectral cubes.

    Wraps :class:`HyperviewNonGeo` train/val/test splits behind the standard
    Lightning datamodule interface and builds split-specific DataLoaders with
    the configured transforms.
    """

    def __init__(
        self,
        data_root: str,
        label_train_path: str | None = None,
        label_test_path: str | None = None,
        batch_size: int = 4,
        num_workers: int = 0,
        train_transform: A.Compose | None = None,
        val_transform: A.Compose | None = None,
        test_transform: A.Compose | None = None,
        target_index: list[int] | None = None,
        cropped_load: bool = False,
        val_ratio: float = 0.2,
        random_state: int = 42,
        fill_last_with_150: bool = True,
        drop_last: bool = False,
        exclude_11x11: bool = False,
        only_11x11: bool = False,
        mask: str = "none",
        num_bands: int = 12,
        label_scaling: str = "std",
        split: str = "test",
        aoi: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the datamodule.

        Args:
            data_root: Root directory of the Hyperview data.
            label_train_path: Path to the training label file, forwarded to the
                train/val dataset builds.
            label_test_path: Path to the test label file, forwarded to the
                test/predict dataset build.
            batch_size: Batch size shared by all splits (via the base class).
            num_workers: Number of DataLoader worker processes.
            train_transform: Augmentations for the train split; defaults to the
                module-level ``train_full_transform``.
            val_transform: Transforms for the val split; defaults to the
                module-level ``eval_transform``.
            test_transform: Transforms for the test split; defaults to the
                module-level ``eval_transform``.
            target_index: Indices of regression targets to load; defaults to
                ``[0, 1, 2, 3]`` (all four Hyperview targets).
            cropped_load: Dataset-specific loading flag — see
                :class:`HyperviewNonGeo`.
            val_ratio: Fraction of the training data held out for validation.
            random_state: Seed controlling the train/val split.
            fill_last_with_150: Dataset-specific padding flag — see
                :class:`HyperviewNonGeo`.
            drop_last: Whether the train DataLoader drops the last incomplete
                batch.
            exclude_11x11: Skip 11x11 patches (dataset-specific filter).
            only_11x11: Keep only 11x11 patches (dataset-specific filter).
            mask: Masking mode forwarded to the dataset (``"none"`` disables).
            num_bands: Number of spectral bands to load.
            label_scaling: Label normalization scheme (e.g. ``"std"``).
            split: Which dataset split the test/predict stage instantiates.
            aoi: Optional area-of-interest identifier for the test split.
            **kwargs: Extra arguments forwarded to :class:`NonGeoDataModule`.
        """
        super().__init__(HyperviewNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root
        self.label_train_path = label_train_path
        self.label_test_path = label_test_path
        self.target_index = [0, 1, 2, 3] if target_index is None else target_index
        self.cropped_load = cropped_load
        self.val_ratio = val_ratio
        self.random_state = random_state
        self.train_transform = train_transform if train_transform is not None else train_full_transform
        self.val_transform = val_transform if val_transform is not None else eval_transform
        self.test_transform = test_transform if test_transform is not None else eval_transform
        self.fill_last_with_150 = fill_last_with_150
        self.drop_last = drop_last
        self.exclude_11x11 = exclude_11x11
        self.only_11x11 = only_11x11
        self.mask = mask
        self.num_bands = num_bands
        self.label_scaling = label_scaling
        self.split = split
        self.aoi = aoi

    def _build_dataset(
        self,
        split: str,
        label_path: str | None,
        transform: A.Compose | None,
        aoi: str | None,
    ) -> HyperviewNonGeo:
        """Create a dataset instance with the shared keyword arguments.

        Args:
            split: Dataset split name (``"train"``, ``"val"``, ``"test"``, …).
            label_path: Label file for this split (train/val share one, test
                uses its own).
            transform: Albumentations pipeline applied to each sample.
            aoi: Optional area-of-interest identifier (test split only).

        Returns:
            A configured :class:`HyperviewNonGeo` instance.
        """
        return HyperviewNonGeo(
            data_root=self.data_root,
            split=split,
            label_path=label_path,
            transform=transform,
            target_index=self.target_index,
            cropped_load=self.cropped_load,
            val_ratio=self.val_ratio,
            random_state=self.random_state,
            fill_last_with_150=self.fill_last_with_150,
            exclude_11x11=self.exclude_11x11,
            only_11x11=self.only_11x11,
            mask=self.mask,
            num_bands=self.num_bands,
            label_scaling=self.label_scaling,
            aoi=aoi,
        )

    def setup(self, stage: str | None = None) -> None:
        """Instantiate the datasets needed for the requested stage.

        Args:
            stage: Lightning stage (``"fit"``, ``"validate"``, ``"test"``,
                ``"predict"``) or ``None`` to set up everything. The extra
                ``"train"``/``"val"`` aliases are accepted for convenience.
        """
        if stage in {None, "fit", "train"}:
            self.train_dataset = self._build_dataset(
                split="train",
                label_path=self.label_train_path,
                transform=self.train_transform,
                aoi=None,
            )
        if stage in {None, "fit", "validate", "val"}:
            self.val_dataset = self._build_dataset(
                split="val",
                label_path=self.label_train_path,
                transform=self.val_transform,
                aoi=None,
            )
        if stage in {None, "test", "predict"}:
            self.test_dataset = self._build_dataset(
                split=self.split,
                label_path=self.label_test_path,
                transform=self.test_transform,
                aoi=self.aoi,
            )
            # Fix: NonGeoDataModule.predict_dataloader() resolves
            # ``predict_dataset`` via _dataloader_factory("predict"); without
            # this alias the predict stage fails even though setup ran.
            self.predict_dataset = self.test_dataset

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Build a split-specific DataLoader.

        Args:
            split: Split prefix used to look up ``{split}_dataset`` and
                ``{split}_batch_size`` on the datamodule.

        Returns:
            A DataLoader that shuffles only the train split and honors
            ``drop_last`` only for training.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=(split == "train"),
            num_workers=self.num_workers,
            drop_last=(split == "train" and self.drop_last),
            collate_fn=self.collate_fn,
        )