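"""Lightning datamodule for the Hyperview non-geospatial hyperspectral dataset."""
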
from typing import Any, cast

import albumentations as A
from albumentations.core.composition import TransformType
from albumentations.pytorch import ToTensorV2
from torch import Tensor
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule

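# Prefer the package-qualified import; fall back to a flat import so the module
# also works when it is used outside the ``datasets`` package (e.g. as a script).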
try:
    from datasets.hyperview_dataset import HyperviewNonGeo
except ModuleNotFoundError:
    from hyperview_dataset import HyperviewNonGeo

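# Training augmentations: resize to a fixed 224x224 input, add light noise and
# geometric transforms, then convert to a channels-first tensor (ToTensorV2).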
train_transforms = [
    A.Resize(224, 224),
    A.GaussNoise(std_range=(0.0, 0.005), mean_range=(0.0, 0.0), p=0.5),
    A.ElasticTransform(p=0.25),
    A.RandomRotate90(p=0.5),
    A.VerticalFlip(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(
        rotate_limit=90,
        shift_limit_x=0.05,
        shift_limit_y=0.05,
        p=0.5,
    ),
    ToTensorV2(),
]
train_full_transform = A.Compose(cast(list[TransformType], train_transforms))

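# Evaluation pipeline: deterministic resize only, no augmentation.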
eval_transforms = [A.Resize(224, 224), ToTensorV2()]
eval_transform = A.Compose(cast(list[TransformType], eval_transforms))


class HyperviewNonGeoDataModule(NonGeoDataModule):
    """DataModule for Hyperview non-geospatial hyperspectral cubes."""

    def __init__(
        self,
        data_root: str,
        label_train_path: str | None = None,
        label_test_path: str | None = None,
        batch_size: int = 4,
        num_workers: int = 0,
        train_transform: A.Compose | None = None,
        val_transform: A.Compose | None = None,
        test_transform: A.Compose | None = None,
        target_index: list[int] | None = None,
        cropped_load: bool = False,
        val_ratio: float = 0.2,
        random_state: int = 42,
        fill_last_with_150: bool = True,
        drop_last: bool = False,
        exclude_11x11: bool = False,
        only_11x11: bool = False,
        mask: str = "none",
        num_bands: int = 12,
        label_scaling: str = "std",
        split: str = "test",
        aoi: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize datamodule with paths, transforms and split configuration."""
        super().__init__(HyperviewNonGeo, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.label_train_path = label_train_path
        self.label_test_path = label_test_path
        self.target_index = [0, 1, 2, 3] if target_index is None else target_index
        self.cropped_load = cropped_load
        self.val_ratio = val_ratio
        self.random_state = random_state

        self.train_transform = train_transform if train_transform is not None else train_full_transform
        self.val_transform = val_transform if val_transform is not None else eval_transform
        self.test_transform = test_transform if test_transform is not None else eval_transform

        self.fill_last_with_150 = fill_last_with_150
        self.drop_last = drop_last
        self.exclude_11x11 = exclude_11x11
        self.only_11x11 = only_11x11
        self.mask = mask
        self.num_bands = num_bands
        self.label_scaling = label_scaling
        self.split = split
        self.aoi = aoi

    def _build_dataset(
        self,
        split: str,
        label_path: str | None,
        transform: A.Compose | None,
        aoi: str | None,
    ) -> HyperviewNonGeo:
        """Create dataset instance with shared keyword arguments."""
        return HyperviewNonGeo(
            data_root=self.data_root,
            split=split,
            label_path=label_path,
            transform=transform,
            target_index=self.target_index,
            cropped_load=self.cropped_load,
            val_ratio=self.val_ratio,
            random_state=self.random_state,
            fill_last_with_150=self.fill_last_with_150,
            exclude_11x11=self.exclude_11x11,
            only_11x11=self.only_11x11,
            mask=self.mask,
            num_bands=self.num_bands,
            label_scaling=self.label_scaling,
            aoi=aoi,
        )

    def setup(self, stage: str | None = None) -> None:
        """Instantiate datasets for requested stage."""
        if stage in {None, "fit", "train"}:
            self.train_dataset = self._build_dataset(
                split="train",
                label_path=self.label_train_path,
                transform=self.train_transform,
                aoi=None,
            )

        if stage in {None, "fit", "validate", "val"}:
            self.val_dataset = self._build_dataset(
                split="val",
                label_path=self.label_train_path,
                transform=self.val_transform,
                aoi=None,
            )

        if stage in {None, "test", "predict"}:
            self.test_dataset = self._build_dataset(
                split=self.split,
                label_path=self.label_test_path,
                transform=self.test_transform,
                aoi=self.aoi,
            )

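    # torchgeo's NonGeoDataModule builds its train/val/test dataloaders through
    # this factory, so overriding it is enough to control shuffling and drop_last.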
    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Build split-specific DataLoader."""
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")

        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=(split == "train"),
            num_workers=self.num_workers,
            drop_last=(split == "train" and self.drop_last),
            collate_fn=self.collate_fn,
        )
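

# Illustrative usage sketch (not exercised by the library itself): the paths and
# batch settings below are placeholders, and the keys of the returned batch
# depend on what HyperviewNonGeo yields.
if __name__ == "__main__":
    dm = HyperviewNonGeoDataModule(
        data_root="path/to/hyperview",            # placeholder path
        label_train_path="path/to/train_gt.csv",  # placeholder path
        batch_size=8,
        num_workers=4,
    )
    dm.setup("fit")
    # train_dataloader() is inherited from NonGeoDataModule and delegates to
    # _dataloader_factory("train") defined above.
    batch = next(iter(dm.train_dataloader()))
    print({key: getattr(value, "shape", value) for key, value in batch.items()})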