File size: 13,156 Bytes
87904b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
import glob
import os
from pathlib import Path
from typing import Any, Optional

import albumentations as A
import numpy as np
import pandas as pd
import torch
from torch import Tensor
from torchgeo.datasets import NonGeoDataset
from sklearn.model_selection import train_test_split


class HyperviewNonGeo(NonGeoDataset):
    """
    Modified dataset that can load either 6, 12, or all 150 channels.

    - For `num_bands=12`, it loads the standard Sentinel-2 set (skipping B10),
      with B11/B12 either filled from band #150 or zeroed out (depending on the
      `fill_last_with_150` flag).
    - For `num_bands=6`, it loads an HLS-like subset.
    - For `num_bands=150`, it simply loads the entire hyperspectral cube.

    The 'mask' parameter can take one of three values:
      - "none": No mask is used (arrays loaded as np.ndarray).
      - "og": A MaskedArray is created for each file (original mask),
        but no cropping is performed.
      - "square": A MaskedArray is created, and we find the largest
        square region in which all pixels are unmasked (mask=False),
        cropping the data to that region.
    """

    # 1-based source-band indices approximating each Sentinel-2 band for the
    # different sensors; ``None`` entries have no source band and are either
    # filled from a high-index band or zeroed (see ``__init__``).
    _S2_TO_HYPER_12 = [1, 10, 32, 64, 77, 88, 101, 120, 127, 150, None, None]
    _S2_TO_ENMAP_12 = [6, 16, 30, 48, 54, 59, 64, 72, 75, 90, 150, 191]
    _S2_TO_INTUITON_12 = [5, 14, 35, 85, 101, 112, 134, 155, 165, 179, None, None]
    _HLS_TO_HYPER_6 = [10, 32, 64, 127, None, None]

    # Per-target (P, K, Mg, pH) statistics consumed by ``_scale_labels``.
    _LABEL_MIN = np.array([20.3000, 21.1000, 26.8000, 5.8000], dtype=np.float32)
    _LABEL_MAX = np.array([325.0000, 625.0000, 400.0000, 7.8000], dtype=np.float32)
    _LABEL_MEAN = np.array([70.4617, 226.8499, 159.3915, 6.7789], dtype=np.float32)
    _LABEL_STD = np.array([30.1490, 60.5661, 39.7610, 0.2593], dtype=np.float32)

    splits = {
        "train": "train",
        "val": "val",
        "test": "test",
        "test_dat": "test_dat",
        "test_enmap": "test_enmap",
        "test_intuition": "test_intuition",
    }

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        label_path: Optional[str] = None,
        transform: Optional[A.Compose] = None,
        target_index: Optional[list[int]] = None,
        cropped_load: bool = False,
        val_ratio: float = 0.2,
        random_state: int = 42,
        mask: str = "none",
        fill_last_with_150: bool = True,
        exclude_11x11: bool = False,
        only_11x11: bool = False,
        num_bands: int = 12,
        label_scaling: str = "std",
        aoi: Optional[str] = None,
    ) -> None:
        """
        Args:
            data_root (str): Path to the root directory of the data.
            split (str): One of 'train', 'val', or 'test'.
            label_path (str, optional): Path to CSV with labels.
            transform (A.Compose, optional): Albumentations transforms.
            target_index (list[int], optional): Indices of columns to select as
                labels. Defaults to ``[0, 1, 2, 3]`` (all of P, K, Mg, pH).
            cropped_load (bool): Not directly used here (legacy param).
            val_ratio (float): Fraction of data to use as validation (when split is train/val).
            random_state (int): Random seed for train/val split.
            mask (str): Controls how we handle mask. One of:
                - "none"   -> no mask
                - "og"     -> original mask
                - "square" -> original mask + crop to largest unmasked square
            fill_last_with_150 (bool): If True, B11/B12 are replaced by band #150
                                       (in 12- or 6-band scenario) otherwise zero-filled.
                                       This has no effect when num_bands=150.
            exclude_11x11 (bool): If True, images of size (11,11) are filtered out.
            only_11x11 (bool): If True, only images of size (11,11) are kept.
            num_bands (int): Number of bands to use. One of {6, 12, 150}.
            label_scaling (str): Label scaling mode. One of
                {"none", "std", "norm_max", "norm_min_max"}.
            aoi (str, optional): Sub-directory of ``test_enmap`` to read from.
                Required when ``split == "test_enmap"``; ignored otherwise.

        Raises:
            ValueError: If ``split``, ``mask``, ``num_bands`` or
                ``label_scaling`` has an unsupported value, or if ``aoi`` is
                missing for the 'test_enmap' split.
        """
        super().__init__()

        if split not in self.splits:
            raise ValueError(f"split must be one of {list(self.splits.keys())}, got '{split}'")

        if mask not in ("none", "og", "square"):
            raise ValueError(f"mask must be one of ['none', 'og', 'square'], got '{mask}'")

        if num_bands not in (6, 12, 150):
            raise ValueError(f"num_bands must be one of [6, 12, 150], got {num_bands}")

        if label_scaling not in ("none", "std", "norm_max", "norm_min_max"):
            raise ValueError(
                "label_scaling must be one of ['none', 'std', 'norm_max', 'norm_min_max'], "
                f"got '{label_scaling}'"
            )

        self.split = split
        self.mask = mask
        self.fill_last_with_150 = fill_last_with_150
        self.label_scaling = label_scaling
        self.data_root = Path(data_root)
        self.exclude_11x11 = exclude_11x11
        self.only_11x11 = only_11x11
        self.num_bands = num_bands

        # Resolve the data directory for the requested split.
        if self.split in ["train", "val"]:
            self.data_dir = self.data_root / "train_data"
        elif self.split == "test_dat":
            self.data_dir = self.data_root / "test_dat"
        elif self.split == "test_enmap":
            if aoi is None:
                raise ValueError("aoi must be provided when split is 'test_enmap'")
            self.data_dir = self.data_root / "test_enmap" / aoi
        elif self.split == "test_intuition":
            self.data_dir = self.data_root / "test_intuition"
        else:
            self.data_dir = self.data_root / "test_data"

        # Files are named "<sample_index>.npz"; sort numerically so positions
        # stay stable for the deterministic train/val split below.
        self.files = sorted(
            glob.glob(os.path.join(self.data_dir, "*.npz")),
            key=lambda x: int(os.path.splitext(os.path.basename(x))[0])
        )

        # Optional spatial-size filtering. Setting both flags yields an empty
        # file list (exclude runs first, then only keeps nothing).
        if self.exclude_11x11:
            self.files = [fp for fp in self.files if self._spatial_size(fp) != (11, 11)]
        if self.only_11x11:
            self.files = [fp for fp in self.files if self._spatial_size(fp) == (11, 11)]

        # Deterministic train/val partition of the (filtered, sorted) files.
        if self.split in ["train", "val"]:
            indices = np.arange(len(self.files))
            train_idx, val_idx = train_test_split(
                indices, test_size=val_ratio, random_state=random_state, shuffle=True
            )
            if self.split == "train":
                self.files = [self.files[i] for i in train_idx]
            else:
                self.files = [self.files[i] for i in val_idx]

        self.labels = None
        if label_path is not None and os.path.exists(label_path):
            self.labels = self._scale_labels(self._load_labels(label_path))

        # Avoid the mutable-default-argument pitfall: copy the caller's list
        # (or build a fresh default) so instances never share label indices.
        self.target_index = list(target_index) if target_index is not None else [0, 1, 2, 3]
        self.transform = transform
        self.cropped_load = cropped_load

        # Pick the sensor-specific 1-based band mapping.
        if self.num_bands == 12 and self.split == "test_enmap":
            band_mapping = self._S2_TO_ENMAP_12
        elif self.num_bands == 12 and self.split == "test_intuition":
            band_mapping = self._S2_TO_INTUITON_12
        elif self.num_bands == 12:
            band_mapping = self._S2_TO_HYPER_12
        elif self.num_bands == 6:
            band_mapping = self._HLS_TO_HYPER_6
        else:
            band_mapping = list(range(1, 151))

        # Translate to 0-based indices; -1 marks "fill this channel with zeros"
        # in __getitem__. When fill_last_with_150 is set, missing channels are
        # sourced from band #150 (index 149), or #180 (index 179) for the
        # intuition sensor — presumably the last band of that cube; confirm.
        self.s2_zero_based = []
        for b_1 in band_mapping:
            if b_1 is None:
                if self.fill_last_with_150 and self.split != "test_intuition":
                    self.s2_zero_based.append(149)
                elif self.fill_last_with_150 and self.split == "test_intuition":
                    self.s2_zero_based.append(179)
                else:
                    self.s2_zero_based.append(-1)
            else:
                self.s2_zero_based.append(b_1 - 1)

    @staticmethod
    def _spatial_size(file_path: str) -> tuple[int, int]:
        """Return the (H, W) spatial size of the 'data' cube stored in *file_path*."""
        with np.load(file_path) as npz:
            data = npz["data"]
        return data.shape[1], data.shape[2]

    def __len__(self) -> int:
        """Return dataset size."""
        return len(self.files)

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Load one sample with optional masking, scaling and transforms.

        Returns a dict with keys 'image' and 'S2L2A' (both H x W x C float32,
        identical content) and 'label' (float32 tensor of the selected targets,
        zeros when no label table was loaded).
        """
        file_path = self.files[index]

        with np.load(file_path) as npz:
            if self.mask == "none":
                # The enmap test split stores its cube under a different key.
                if self.split == "test_enmap":
                    arr = npz["enmap"]
                else:
                    arr = npz["data"]
            else:
                # The .npz is expected to hold exactly the MaskedArray
                # constructor kwargs ('data' and 'mask').
                arr = np.ma.MaskedArray(**npz)

        # Assemble the requested channels; -1 entries become all-zero channels.
        channels = []
        if isinstance(arr, np.ma.MaskedArray):
            for band_idx in self.s2_zero_based:
                if band_idx == -1:
                    h, w = arr.shape[-2], arr.shape[-1]
                    zeros_data = np.zeros((h, w), dtype=arr.dtype)
                    zeros_mask = np.zeros((h, w), dtype=bool)
                    channel_masked = np.ma.MaskedArray(data=zeros_data, mask=zeros_mask)
                    channels.append(channel_masked)
                else:
                    channels.append(arr[band_idx])
            data_arr = np.ma.stack(channels, axis=0)
        else:
            for band_idx in self.s2_zero_based:
                if band_idx == -1:
                    h, w = arr.shape[-2], arr.shape[-1]
                    channels.append(np.zeros((h, w), dtype=arr.dtype))
                else:
                    channels.append(arr[band_idx])
            data_arr = np.stack(channels, axis=0)

        if isinstance(data_arr, np.ma.MaskedArray) and self.mask == "square":
            data_arr = self._crop_to_largest_square_unmasked(data_arr)

        # Downstream transforms expect a plain ndarray; masked pixels become 0.
        if isinstance(data_arr, np.ma.MaskedArray):
            data_arr = data_arr.filled(0)

        # 5419.0 appears to be a dataset-wide normalization constant
        # (presumably the global reflectance maximum) — confirm against the
        # data preparation pipeline.
        data_arr = (data_arr / 5419.0).astype(np.float32)
        # CHW -> HWC, the layout albumentations transforms operate on.
        data_arr = np.transpose(data_arr, (1, 2, 0))

        if self.labels is not None:
            # File stem is the sample index into the dense label table.
            base = os.path.basename(file_path).replace(".npz", "")
            sample_id = int(base)
            label_row = self.labels[sample_id][self.target_index]
        else:
            label_row = np.zeros(len(self.target_index), dtype=np.float32)

        output = {"image": data_arr, "S2L2A": data_arr, "label": label_row}

        if self.transform is not None:
            transformed = self.transform(image=output["image"])
            output["image"] = transformed["image"]
            output["S2L2A"] = output["image"]

        output["label"] = torch.tensor(output["label"], dtype=torch.float32)
        return output

    @staticmethod
    def _load_labels(label_path: str) -> np.ndarray:
        """Load labels CSV into a dense array indexed by sample_index.

        Rows absent from the CSV stay all-zero; columns are (P, K, Mg, pH).
        """
        df = pd.read_csv(label_path)
        max_idx = int(np.asarray(df["sample_index"].max()).item())
        label_array = np.zeros((max_idx + 1, 4), dtype=np.float32)
        for row in df.itertuples():
            sample_index = int(np.asarray(row.sample_index).item())
            label_array[sample_index] = np.array([row.P, row.K, row.Mg, row.pH], dtype=np.float32)
        return label_array

    def _scale_labels(self, labels: np.ndarray) -> np.ndarray:
        """Scale labels according to configured mode.

        Modes: 'none' (identity), 'std' (z-score), 'norm_max' (divide by max),
        'norm_min_max' (min-max to [0, 1], denominator clamped to 1e-8).
        """
        labels = labels.astype(np.float32, copy=False)
        if self.label_scaling == "none":
            return labels
        if self.label_scaling == "std":
            return (labels - self._LABEL_MEAN) / self._LABEL_STD
        if self.label_scaling == "norm_max":
            return labels / self._LABEL_MAX
        if self.label_scaling == "norm_min_max":
            denom = np.maximum(self._LABEL_MAX - self._LABEL_MIN, 1e-8)
            return (labels - self._LABEL_MIN) / denom
        raise ValueError(f"Unknown label_scaling mode: {self.label_scaling}")

    @staticmethod
    def _crop_to_largest_square_unmasked(masked_data: np.ma.MaskedArray) -> np.ma.MaskedArray:
        """Return the largest square region containing only unmasked pixels.

        A pixel counts as masked if it is masked in ANY channel. When no
        unmasked pixel exists the result is an empty (C, 0, 0) crop.
        """
        combined_mask = np.asarray(masked_data.mask.any(axis=0), dtype=bool)
        top, left, size = HyperviewNonGeo._find_largest_square_false(combined_mask)
        cropped = masked_data[:, top : top + size, left : left + size]
        return cropped

    @staticmethod
    def _find_largest_square_false(mask_2d: np.ndarray) -> tuple[int, int, int]:
        """Find the largest False-valued square and return `(top, left, size)`.

        Classic dynamic program: dp[i, j] is the side of the largest all-False
        square whose bottom-right corner is (i, j).
        """
        H, W = mask_2d.shape
        dp = np.zeros((H, W), dtype=np.int32)

        max_size = 0
        max_pos = (0, 0)

        for i in range(H):
            for j in range(W):
                if not mask_2d[i, j]:
                    if i == 0 or j == 0:
                        dp[i, j] = 1
                    else:
                        dp[i, j] = min(dp[i-1, j], dp[i, j-1], dp[i-1, j-1]) + 1

                    if dp[i, j] > max_size:
                        max_size = dp[i, j]
                        max_pos = (i, j)

        if max_size == 0:
            # Fully-masked input: anchor the (empty) region at the origin
            # instead of the out-of-range (1, 1) the arithmetic below yields.
            return 0, 0, 0

        (best_i, best_j) = max_pos
        top = best_i - max_size + 1
        left = best_j - max_size + 1
        return top, left, max_size