from typing import Any, cast

import albumentations as A
from albumentations.core.composition import TransformType
from albumentations.pytorch import ToTensorV2
from torch import Tensor
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule

# Support both package-style and flat imports of the dataset module.
try:
    from datasets.hyperview_dataset import HyperviewNonGeo
except ModuleNotFoundError:
    from hyperview_dataset import HyperviewNonGeo

# Training-time augmentations applied to each hyperspectral cube.
train_transforms = [
    A.Resize(224, 224),
    A.GaussNoise(std_range=(0.0, 0.005), mean_range=(0.0, 0.0), p=0.5),
    A.ElasticTransform(p=0.25),
    A.RandomRotate90(p=0.5),
    A.VerticalFlip(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(
        rotate_limit=90,
        shift_limit_x=0.05,
        shift_limit_y=0.05,
        p=0.5,
    ),
    ToTensorV2(),
]
train_full_transform = A.Compose(cast(list[TransformType], train_transforms))

# Deterministic transforms for validation and test splits.
eval_transforms = [A.Resize(224, 224), ToTensorV2()]
eval_transform = A.Compose(cast(list[TransformType], eval_transforms))


class HyperviewNonGeoDataModule(NonGeoDataModule):
    """DataModule for Hyperview non-geospatial hyperspectral cubes."""

    def __init__(
        self,
        data_root: str,
        label_train_path: str | None = None,
        label_test_path: str | None = None,
        batch_size: int = 4,
        num_workers: int = 0,
        train_transform: A.Compose | None = None,
        val_transform: A.Compose | None = None,
        test_transform: A.Compose | None = None,
        target_index: list[int] | None = None,
        cropped_load: bool = False,
        val_ratio: float = 0.2,
        random_state: int = 42,
        fill_last_with_150: bool = True,
        drop_last: bool = False,
        exclude_11x11: bool = False,
        only_11x11: bool = False,
        mask: str = "none",
        num_bands: int = 12,
        label_scaling: str = "std",
        split: str = "test",
        aoi: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize datamodule with paths, transforms and split configuration."""
        super().__init__(HyperviewNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root
        self.label_train_path = label_train_path
        self.label_test_path = label_test_path
        self.target_index = [0, 1, 2, 3] if target_index is None else target_index
        self.cropped_load = cropped_load
        self.val_ratio = val_ratio
        self.random_state = random_state
        self.train_transform = (
            train_transform if train_transform is not None else train_full_transform
        )
        self.val_transform = val_transform if val_transform is not None else eval_transform
        self.test_transform = test_transform if test_transform is not None else eval_transform
        self.fill_last_with_150 = fill_last_with_150
        self.drop_last = drop_last
        self.exclude_11x11 = exclude_11x11
        self.only_11x11 = only_11x11
        self.mask = mask
        self.num_bands = num_bands
        self.label_scaling = label_scaling
        self.split = split
        self.aoi = aoi

    def _build_dataset(
        self,
        split: str,
        label_path: str | None,
        transform: A.Compose | None,
        aoi: str | None,
    ) -> HyperviewNonGeo:
        """Create dataset instance with shared keyword arguments."""
        return HyperviewNonGeo(
            data_root=self.data_root,
            split=split,
            label_path=label_path,
            transform=transform,
            target_index=self.target_index,
            cropped_load=self.cropped_load,
            val_ratio=self.val_ratio,
            random_state=self.random_state,
            fill_last_with_150=self.fill_last_with_150,
            exclude_11x11=self.exclude_11x11,
            only_11x11=self.only_11x11,
            mask=self.mask,
            num_bands=self.num_bands,
            label_scaling=self.label_scaling,
            aoi=aoi,
        )

    def setup(self, stage: str | None = None) -> None:
        """Instantiate datasets for requested stage."""
        if stage in {None, "fit", "train"}:
            self.train_dataset = self._build_dataset(
                split="train",
                label_path=self.label_train_path,
                transform=self.train_transform,
                aoi=None,
            )
        if stage in {None, "fit", "validate", "val"}:
            self.val_dataset = self._build_dataset(
                split="val",
                label_path=self.label_train_path,
                transform=self.val_transform,
                aoi=None,
            )
        if stage in {None, "test", "predict"}:
            self.test_dataset = self._build_dataset(
                split=self.split,
                label_path=self.label_test_path,
                transform=self.test_transform,
                aoi=self.aoi,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Build split-specific DataLoader."""
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")
        # Only the training loader is shuffled; drop_last applies to training only.
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=(split == "train"),
            num_workers=self.num_workers,
            drop_last=(split == "train" and self.drop_last),
            collate_fn=self.collate_fn,
        )
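

# Minimal usage sketch (illustrative only): the data_root and label path below
# are hypothetical placeholders, not paths shipped with this repository; adjust
# them to your local dataset layout before running.
if __name__ == "__main__":
    datamodule = HyperviewNonGeoDataModule(
        data_root="data/hyperview",  # assumed dataset root
        label_train_path="data/hyperview/train_gt.csv",  # assumed label file
        batch_size=4,
        num_workers=0,
    )
    # Build train/val datasets and pull one training batch as a sanity check.
    datamodule.setup("fit")
    batch = next(iter(datamodule.train_dataloader()))
    print({k: v.shape for k, v in batch.items() if hasattr(v, "shape")})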