# NOTE: extraction artifacts (file-size banner, commit hash, flattened line-number index) removed.
from typing import Any, cast
import albumentations as A
from albumentations.core.composition import TransformType
from albumentations.pytorch import ToTensorV2
from torch import Tensor
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule
try:
from datasets.hyperview_dataset import HyperviewNonGeo
except ModuleNotFoundError:
from hyperview_dataset import HyperviewNonGeo
def _pipeline(steps: list[Any]) -> A.Compose:
    """Bundle a list of albumentations steps into a single Compose pipeline."""
    return A.Compose(cast(list[TransformType], steps))


# Training augmentations: resize, then stochastic noise / geometric
# perturbations, finishing with conversion to a torch tensor.
train_transforms = [
    A.Resize(224, 224),
    A.GaussNoise(std_range=(0.0, 0.005), mean_range=(0.0, 0.0), p=0.5),
    A.ElasticTransform(p=0.25),
    A.RandomRotate90(p=0.5),
    A.VerticalFlip(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5),
    ToTensorV2(),
]
train_full_transform = _pipeline(train_transforms)

# Evaluation pipeline: deterministic resize + tensor conversion only.
eval_transforms = [A.Resize(224, 224), ToTensorV2()]
eval_transform = _pipeline(eval_transforms)
class HyperviewNonGeoDataModule(NonGeoDataModule):
    """DataModule for Hyperview non-geospatial hyperspectral cubes.

    Wires `HyperviewNonGeo` datasets into Lightning-style train/val/test
    loaders, with configurable transforms and split behavior.
    """

    def __init__(
        self,
        data_root: str,
        label_train_path: str | None = None,
        label_test_path: str | None = None,
        batch_size: int = 4,
        num_workers: int = 0,
        train_transform: A.Compose | None = None,
        val_transform: A.Compose | None = None,
        test_transform: A.Compose | None = None,
        target_index: list[int] | None = None,
        cropped_load: bool = False,
        val_ratio: float = 0.2,
        random_state: int = 42,
        fill_last_with_150: bool = True,
        drop_last: bool = False,
        exclude_11x11: bool = False,
        only_11x11: bool = False,
        mask: str = "none",
        num_bands: int = 12,
        label_scaling: str = "std",
        split: str = "test",
        aoi: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize datamodule with paths, transforms and split configuration."""
        super().__init__(HyperviewNonGeo, batch_size, num_workers, **kwargs)
        # Data location and label CSV paths (val reuses the train labels).
        self.data_root = data_root
        self.label_train_path = label_train_path
        self.label_test_path = label_test_path
        # Fall back to the module-level pipelines when no transform is given.
        self.train_transform = train_full_transform if train_transform is None else train_transform
        self.val_transform = eval_transform if val_transform is None else val_transform
        self.test_transform = eval_transform if test_transform is None else test_transform
        # Default selects indices 0-3; presumably all dataset targets — confirm
        # against HyperviewNonGeo.
        self.target_index = [0, 1, 2, 3] if target_index is None else target_index
        # Dataset construction options forwarded verbatim to HyperviewNonGeo.
        self.cropped_load = cropped_load
        self.val_ratio = val_ratio
        self.random_state = random_state
        self.fill_last_with_150 = fill_last_with_150
        self.exclude_11x11 = exclude_11x11
        self.only_11x11 = only_11x11
        self.mask = mask
        self.num_bands = num_bands
        self.label_scaling = label_scaling
        # Test/predict stage configuration.
        self.split = split
        self.aoi = aoi
        # DataLoader behavior: drop incomplete final batch on the train split only.
        self.drop_last = drop_last

    def _build_dataset(
        self,
        split: str,
        label_path: str | None,
        transform: A.Compose | None,
        aoi: str | None,
    ) -> HyperviewNonGeo:
        """Create dataset instance with shared keyword arguments."""
        # Options common to every split, pulled from the datamodule config.
        shared: dict[str, Any] = dict(
            data_root=self.data_root,
            target_index=self.target_index,
            cropped_load=self.cropped_load,
            val_ratio=self.val_ratio,
            random_state=self.random_state,
            fill_last_with_150=self.fill_last_with_150,
            exclude_11x11=self.exclude_11x11,
            only_11x11=self.only_11x11,
            mask=self.mask,
            num_bands=self.num_bands,
            label_scaling=self.label_scaling,
        )
        return HyperviewNonGeo(
            split=split,
            label_path=label_path,
            transform=transform,
            aoi=aoi,
            **shared,
        )

    def setup(self, stage: str | None = None) -> None:
        """Instantiate datasets for requested stage (None sets up all of them)."""
        build = self._build_dataset
        if stage is None or stage in ("fit", "train"):
            self.train_dataset = build("train", self.label_train_path, self.train_transform, None)
        if stage is None or stage in ("fit", "validate", "val"):
            # Validation is carved from the training data, so it shares labels.
            self.val_dataset = build("val", self.label_train_path, self.val_transform, None)
        if stage is None or stage in ("test", "predict"):
            # The test split name and AOI filter are user-configurable.
            self.test_dataset = build(self.split, self.label_test_path, self.test_transform, self.aoi)

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Build split-specific DataLoader; shuffling/drop_last apply to train only."""
        is_train = split == "train"
        return DataLoader(
            dataset=self._valid_attribute(f"{split}_dataset", "dataset"),
            batch_size=self._valid_attribute(f"{split}_batch_size", "batch_size"),
            shuffle=is_train,
            num_workers=self.num_workers,
            drop_last=is_train and self.drop_last,
            collate_fn=self.collate_fn,
        )