|
|
import glob |
|
|
import os |
|
|
from pathlib import Path |
|
|
from typing import Any, Optional |
|
|
|
|
|
import albumentations as A |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import torch |
|
|
from torch import Tensor |
|
|
from torchgeo.datasets import NonGeoDataset |
|
|
from sklearn.model_selection import train_test_split |
|
|
|
|
|
|
|
|
class HyperviewNonGeo(NonGeoDataset):
    """Hyperview soil-parameter dataset that can load 6, 12, or all 150 channels.

    - For `num_bands=12`, it loads the standard Sentinel-2 set (skipping B10),
      with B11/B12 either filled from band #150 or zeroed out (depending on the
      `fill_last_with_150` flag).
    - For `num_bands=6`, it loads an HLS-like subset.
    - For `num_bands=150`, it simply loads the entire hyperspectral cube.

    The 'mask' parameter can take one of three values:

    - "none": No mask is used (arrays loaded as np.ndarray).
    - "og": A MaskedArray is created for each file (original mask),
      but no cropping is performed.
    - "square": A MaskedArray is created, and we find the largest
      square region in which all pixels are unmasked (mask=False),
      cropping the data to that region.
    """

    # 1-based band indices mapping Sentinel-2 (or HLS) band positions onto the
    # source hyperspectral cube of each sensor.  `None` marks a target band with
    # no direct counterpart; it is resolved in `__init__` to either a fallback
    # band (#150, or #180 for the "intuition" sensor) or a zero-filled channel,
    # depending on `fill_last_with_150`.
    _S2_TO_HYPER_12 = [1, 10, 32, 64, 77, 88, 101, 120, 127, 150, None, None]
    _S2_TO_ENMAP_12 = [6, 16, 30, 48, 54, 59, 64, 72, 75, 90, 150, 191]
    _S2_TO_INTUITON_12 = [5, 14, 35, 85, 101, 112, 134, 155, 165, 179, None, None]
    _HLS_TO_HYPER_6 = [10, 32, 64, 127, None, None]

    # Per-target label statistics, in column order (P, K, Mg, pH), used by
    # `_scale_labels` for the "std" / "norm_max" / "norm_min_max" modes.
    _LABEL_MIN = np.array([20.3000, 21.1000, 26.8000, 5.8000], dtype=np.float32)
    _LABEL_MAX = np.array([325.0000, 625.0000, 400.0000, 7.8000], dtype=np.float32)
    _LABEL_MEAN = np.array([70.4617, 226.8499, 159.3915, 6.7789], dtype=np.float32)
    _LABEL_STD = np.array([30.1490, 60.5661, 39.7610, 0.2593], dtype=np.float32)

    # Accepted values for `split`.
    splits = {
        "train": "train",
        "val": "val",
        "test": "test",
        "test_dat": "test_dat",
        "test_enmap": "test_enmap",
        "test_intuition": "test_intuition",
    }

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        label_path: Optional[str] = None,
        transform: Optional[A.Compose] = None,
        target_index: Optional[list[int]] = None,
        cropped_load: bool = False,
        val_ratio: float = 0.2,
        random_state: int = 42,
        mask: str = "none",
        fill_last_with_150: bool = True,
        exclude_11x11: bool = False,
        only_11x11: bool = False,
        num_bands: int = 12,
        label_scaling: str = "std",
        aoi: Optional[str] = None,
    ) -> None:
        """
        Args:
            data_root (str): Path to the root directory of the data.
            split (str): One of 'train', 'val', 'test', 'test_dat',
                'test_enmap', or 'test_intuition'.
            label_path (str, optional): Path to CSV with labels.
            transform (A.Compose, optional): Albumentations transforms.
            target_index (list[int], optional): Indices of columns to select as
                labels. Defaults to [0, 1, 2, 3] (all four targets); `None` is
                used as the default to avoid a shared mutable default argument.
            cropped_load (bool): Not directly used here (legacy param).
            val_ratio (float): Fraction of data to use as validation (when split is train/val).
            random_state (int): Random seed for train/val split.
            mask (str): Controls how we handle mask. One of:
                - "none" -> no mask
                - "og" -> original mask
                - "square" -> original mask + crop to largest unmasked square
            fill_last_with_150 (bool): If True, B11/B12 are replaced by band #150
                (in 12- or 6-band scenario) otherwise zero-filled.
                This has no effect when num_bands=150.
            exclude_11x11 (bool): If True, images of size (11,11) are filtered out.
            only_11x11 (bool): If True, only images of size (11,11) are kept.
            num_bands (int): Number of bands to use. One of {6, 12, 150}.
            label_scaling (str): Label scaling mode. One of
                {"none", "std", "norm_max", "norm_min_max"}.
            aoi (str, optional): Sub-directory of the EnMAP test data to use.
                Required when (and only used when) split is 'test_enmap'.

        Raises:
            ValueError: If `split`, `mask`, `num_bands`, or `label_scaling`
                takes a value outside its accepted set, or if `aoi` is missing
                for the 'test_enmap' split.
        """
        super().__init__()

        # --- Validate configuration early so errors surface at construction. ---
        if split not in self.splits:
            raise ValueError(f"split must be one of {list(self.splits.keys())}, got '{split}'")

        if mask not in ("none", "og", "square"):
            raise ValueError(f"mask must be one of ['none', 'og', 'square'], got '{mask}'")

        if num_bands not in (6, 12, 150):
            raise ValueError(f"num_bands must be one of [6, 12, 150], got {num_bands}")

        if label_scaling not in ("none", "std", "norm_max", "norm_min_max"):
            raise ValueError(
                "label_scaling must be one of ['none', 'std', 'norm_max', 'norm_min_max'], "
                f"got '{label_scaling}'"
            )

        self.split = split
        self.mask = mask
        self.fill_last_with_150 = fill_last_with_150
        self.label_scaling = label_scaling
        self.data_root = Path(data_root)
        self.exclude_11x11 = exclude_11x11
        self.only_11x11 = only_11x11
        self.num_bands = num_bands

        # --- Resolve the directory holding the .npz files for this split. ---
        # 'train' and 'val' share one directory and are separated further below.
        if self.split in ["train", "val"]:
            self.data_dir = self.data_root / "train_data"
        elif self.split == "test_dat":
            self.data_dir = self.data_root / "test_dat"
        elif self.split == "test_enmap":
            if aoi is None:
                raise ValueError("aoi must be provided when split is 'test_enmap'")
            self.data_dir = self.data_root / "test_enmap" / aoi
        elif self.split == "test_intuition":
            self.data_dir = self.data_root / "test_intuition"
        else:
            self.data_dir = self.data_root / "test_data"

        # Files are named "<sample_index>.npz"; sort numerically, not lexically.
        self.files = sorted(
            glob.glob(os.path.join(self.data_dir, "*.npz")),
            key=lambda x: int(os.path.splitext(os.path.basename(x))[0])
        )

        # Optionally restrict the file list by spatial size (11x11 or not).
        if self.exclude_11x11:
            self.files = self._filter_files_by_shape(self.files, keep_11x11=False)

        if self.only_11x11:
            self.files = self._filter_files_by_shape(self.files, keep_11x11=True)

        # Deterministically split the shared 'train_data' pool into train/val.
        if self.split in ["train", "val"]:
            indices = np.arange(len(self.files))
            train_idx, val_idx = train_test_split(
                indices, test_size=val_ratio, random_state=random_state, shuffle=True
            )
            if self.split == "train":
                self.files = [self.files[i] for i in train_idx]
            else:
                self.files = [self.files[i] for i in val_idx]

        # Labels are optional (e.g. test splits have none); when present they
        # are loaded into a dense array indexed by sample id and pre-scaled.
        self.labels = None
        if label_path is not None and os.path.exists(label_path):
            self.labels = self._scale_labels(self._load_labels(label_path))

        # Default to all four soil targets when the caller did not choose.
        self.target_index = [0, 1, 2, 3] if target_index is None else target_index
        self.transform = transform
        self.cropped_load = cropped_load

        # --- Pick the 1-based band mapping for the configured sensor/bands. ---
        if self.num_bands == 12 and self.split == "test_enmap":
            band_mapping = self._S2_TO_ENMAP_12
        elif self.num_bands == 12 and self.split == "test_intuition":
            band_mapping = self._S2_TO_INTUITON_12
        elif self.num_bands == 12:
            band_mapping = self._S2_TO_HYPER_12
        elif self.num_bands == 6:
            band_mapping = self._HLS_TO_HYPER_6
        else:
            # num_bands == 150: take the full hyperspectral cube.
            band_mapping = list(range(1, 151))

        # Convert to 0-based indices; `None` entries become either the fallback
        # band (149, or 179 for the intuition sensor) or the sentinel -1, which
        # __getitem__ turns into a zero-filled channel.
        self.s2_zero_based = []
        for b_1 in band_mapping:
            if b_1 is None:
                if self.fill_last_with_150 and self.split != "test_intuition":
                    self.s2_zero_based.append(149)
                elif self.fill_last_with_150 and self.split == "test_intuition":
                    self.s2_zero_based.append(179)
                else:
                    self.s2_zero_based.append(-1)
            else:
                self.s2_zero_based.append(b_1 - 1)

    @staticmethod
    def _filter_files_by_shape(files: list[str], keep_11x11: bool) -> list[str]:
        """Keep only files whose 'data' array is (or is not) spatially 11x11.

        Args:
            files: Paths to .npz files, each containing a 'data' array of
                shape (bands, H, W).
            keep_11x11: If True, keep exactly the (H, W) == (11, 11) files;
                if False, keep everything else.

        Returns:
            The filtered list of file paths, order preserved.
        """
        kept = []
        for fp in files:
            with np.load(fp) as npz:
                shape = npz["data"].shape
            is_11x11 = shape[1] == 11 and shape[2] == 11
            if is_11x11 == keep_11x11:
                kept.append(fp)
        return kept

    def __len__(self) -> int:
        """Return dataset size."""
        return len(self.files)

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Load one sample with optional masking, scaling and transforms.

        Returns:
            A dict with keys 'image' and 'S2L2A' (the same HWC float32 array,
            scaled to reflectance-like units) and 'label' (a float32 tensor;
            zeros when no labels were loaded).
        """
        file_path = self.files[index]

        with np.load(file_path) as npz:
            if self.mask == "none":
                # EnMAP archives store the cube under a different key.
                if self.split == "test_enmap":
                    arr = npz["enmap"]
                else:
                    arr = npz["data"]
            else:
                # The archive's keys (data/mask) match MaskedArray's kwargs.
                arr = np.ma.MaskedArray(**npz)

        # Select the configured bands; a -1 sentinel yields an all-zero
        # (and, for masked input, fully unmasked) channel.
        channels = []
        if isinstance(arr, np.ma.MaskedArray):
            for band_idx in self.s2_zero_based:
                if band_idx == -1:
                    h, w = arr.shape[-2], arr.shape[-1]
                    zeros_data = np.zeros((h, w), dtype=arr.dtype)
                    zeros_mask = np.zeros((h, w), dtype=bool)
                    channel_masked = np.ma.MaskedArray(data=zeros_data, mask=zeros_mask)
                    channels.append(channel_masked)
                else:
                    channels.append(arr[band_idx])
            data_arr = np.ma.stack(channels, axis=0)
        else:
            for band_idx in self.s2_zero_based:
                if band_idx == -1:
                    h, w = arr.shape[-2], arr.shape[-1]
                    channels.append(np.zeros((h, w), dtype=arr.dtype))
                else:
                    channels.append(arr[band_idx])
            data_arr = np.stack(channels, axis=0)

        # "square" mode: crop to the largest fully-unmasked square region.
        if isinstance(data_arr, np.ma.MaskedArray) and self.mask == "square":
            data_arr = self._crop_to_largest_square_unmasked(data_arr)

        # Downstream transforms expect a plain ndarray; masked pixels become 0.
        if isinstance(data_arr, np.ma.MaskedArray):
            data_arr = data_arr.filled(0)

        # Normalize by a fixed global scale and convert CHW -> HWC for
        # albumentations-style transforms.
        data_arr = (data_arr / 5419.0).astype(np.float32)
        data_arr = np.transpose(data_arr, (1, 2, 0))

        if self.labels is not None:
            # The file stem is the sample id used to index the label array.
            base = os.path.basename(file_path).replace(".npz", "")
            sample_id = int(base)
            label_row = self.labels[sample_id][self.target_index]
        else:
            label_row = np.zeros(len(self.target_index), dtype=np.float32)

        output = {"image": data_arr, "S2L2A": data_arr, "label": label_row}

        if self.transform is not None:
            transformed = self.transform(image=output["image"])
            output["image"] = transformed["image"]
            # Keep the alias key in sync with the transformed image.
            output["S2L2A"] = output["image"]

        output["label"] = torch.tensor(output["label"], dtype=torch.float32)
        return output

    @staticmethod
    def _load_labels(label_path: str) -> np.ndarray:
        """Load labels CSV into a dense array indexed by sample_index.

        The CSV must provide 'sample_index', 'P', 'K', 'Mg', and 'pH' columns.
        Rows absent from the CSV stay all-zero in the returned array.
        """
        df = pd.read_csv(label_path)
        max_idx = int(np.asarray(df["sample_index"].max()).item())
        label_array = np.zeros((max_idx + 1, 4), dtype=np.float32)
        for row in df.itertuples():
            sample_index = int(np.asarray(row.sample_index).item())
            label_array[sample_index] = np.array([row.P, row.K, row.Mg, row.pH], dtype=np.float32)
        return label_array

    def _scale_labels(self, labels: np.ndarray) -> np.ndarray:
        """Scale labels according to configured mode.

        Modes: "none" (identity), "std" (standardize with dataset mean/std),
        "norm_max" (divide by per-target max), "norm_min_max" (min-max to [0, 1]).
        """
        labels = labels.astype(np.float32, copy=False)
        if self.label_scaling == "none":
            return labels
        if self.label_scaling == "std":
            return (labels - self._LABEL_MEAN) / self._LABEL_STD
        if self.label_scaling == "norm_max":
            return labels / self._LABEL_MAX
        if self.label_scaling == "norm_min_max":
            # Guard against a zero denominator for (near-)constant targets.
            denom = np.maximum(self._LABEL_MAX - self._LABEL_MIN, 1e-8)
            return (labels - self._LABEL_MIN) / denom
        raise ValueError(f"Unknown label_scaling mode: {self.label_scaling}")

    @staticmethod
    def _crop_to_largest_square_unmasked(masked_data: np.ma.MaskedArray) -> np.ma.MaskedArray:
        """Return the largest square region containing only unmasked pixels.

        A pixel counts as masked if it is masked in ANY channel. If every
        pixel is masked, the returned crop is empty (size 0).
        """
        combined_mask = np.asarray(masked_data.mask.any(axis=0), dtype=bool)
        top, left, size = HyperviewNonGeo._find_largest_square_false(combined_mask)
        cropped = masked_data[:, top : top + size, left : left + size]
        return cropped

    @staticmethod
    def _find_largest_square_false(mask_2d: np.ndarray) -> tuple[int, int, int]:
        """Find the largest False-valued square and return `(top, left, size)`.

        Standard dynamic-programming solution: dp[i, j] is the side length of
        the largest all-False square whose bottom-right corner is (i, j).
        """
        H, W = mask_2d.shape
        dp = np.zeros((H, W), dtype=np.int32)

        max_size = 0
        max_pos = (0, 0)

        for i in range(H):
            for j in range(W):
                if not mask_2d[i, j]:
                    if i == 0 or j == 0:
                        dp[i, j] = 1
                    else:
                        dp[i, j] = min(dp[i-1, j], dp[i, j-1], dp[i-1, j-1]) + 1

                    if dp[i, j] > max_size:
                        max_size = dp[i, j]
                        max_pos = (i, j)

        # No unmasked pixel at all: anchor the (empty) region at the origin
        # instead of the out-of-range (1, 1) the arithmetic below would give.
        if max_size == 0:
            return 0, 0, 0

        (best_i, best_j) = max_pos
        top = best_i - max_size + 1
        left = best_j - max_size + 1
        return top, left, max_size
|
|
|