File size: 13,156 Bytes
87904b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 |
import glob
import os
from pathlib import Path
from typing import Any, Optional
import albumentations as A
import numpy as np
import pandas as pd
import torch
from torch import Tensor
from torchgeo.datasets import NonGeoDataset
from sklearn.model_selection import train_test_split
class HyperviewNonGeo(NonGeoDataset):
    """Hyperview dataset that can load either 6, 12, or all 150 channels.

    - For ``num_bands=12``, it loads the standard Sentinel-2 set (skipping B10),
      with B11/B12 either filled from band #150 or zeroed out (depending on the
      ``fill_last_with_150`` flag).
    - For ``num_bands=6``, it loads an HLS-like subset.
    - For ``num_bands=150``, it simply loads the entire hyperspectral cube.

    The ``mask`` parameter can take one of three values:

    - ``"none"``: No mask is used (arrays loaded as np.ndarray).
    - ``"og"``: A MaskedArray is created for each file (original mask),
      but no cropping is performed.
    - ``"square"``: A MaskedArray is created, and we find the largest
      square region in which all pixels are unmasked (mask=False),
      cropping the data to that region.
    """

    # 1-based source-band numbers for each output channel. ``None`` marks a
    # B11/B12-analogue slot that has no native band; it is resolved in
    # ``__init__`` according to ``fill_last_with_150``.
    _S2_TO_HYPER_12 = [1, 10, 32, 64, 77, 88, 101, 120, 127, 150, None, None]
    _S2_TO_ENMAP_12 = [6, 16, 30, 48, 54, 59, 64, 72, 75, 90, 150, 191]
    _S2_TO_INTUITON_12 = [5, 14, 35, 85, 101, 112, 134, 155, 165, 179, None, None]
    _HLS_TO_HYPER_6 = [10, 32, 64, 127, None, None]

    # Per-target (P, K, Mg, pH) label statistics consumed by ``_scale_labels``.
    _LABEL_MIN = np.array([20.3000, 21.1000, 26.8000, 5.8000], dtype=np.float32)
    _LABEL_MAX = np.array([325.0000, 625.0000, 400.0000, 7.8000], dtype=np.float32)
    _LABEL_MEAN = np.array([70.4617, 226.8499, 159.3915, 6.7789], dtype=np.float32)
    _LABEL_STD = np.array([30.1490, 60.5661, 39.7610, 0.2593], dtype=np.float32)

    # Accepted split names (key == value; kept as a dict for membership
    # checks and error messages).
    splits = {
        "train": "train",
        "val": "val",
        "test": "test",
        "test_dat": "test_dat",
        "test_enmap": "test_enmap",
        "test_intuition": "test_intuition",
    }

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        label_path: Optional[str] = None,
        transform: Optional[A.Compose] = None,
        target_index: Optional[list[int]] = None,
        cropped_load: bool = False,
        val_ratio: float = 0.2,
        random_state: int = 42,
        mask: str = "none",
        fill_last_with_150: bool = True,
        exclude_11x11: bool = False,
        only_11x11: bool = False,
        num_bands: int = 12,
        label_scaling: str = "std",
        aoi: Optional[str] = None,
    ) -> None:
        """
        Args:
            data_root (str): Path to the root directory of the data.
            split (str): One of 'train', 'val', 'test', 'test_dat',
                'test_enmap', or 'test_intuition'.
            label_path (str, optional): Path to CSV with labels. A missing
                file is tolerated; samples then get all-zero labels.
            transform (A.Compose, optional): Albumentations transforms.
            target_index (list[int], optional): Indices of columns to select
                as labels. Defaults to ``[0, 1, 2, 3]`` (all four targets).
            cropped_load (bool): Not directly used here (legacy param).
            val_ratio (float): Fraction of data to use as validation (when split is train/val).
            random_state (int): Random seed for train/val split.
            mask (str): Controls how we handle mask. One of:
                - "none"   -> no mask
                - "og"     -> original mask
                - "square" -> original mask + crop to largest unmasked square
            fill_last_with_150 (bool): If True, B11/B12 are replaced by band #150
                (in 12- or 6-band scenario) otherwise zero-filled.
                This has no effect when num_bands=150.
            exclude_11x11 (bool): If True, images of size (11,11) are filtered out.
            only_11x11 (bool): If True, only images of size (11,11) are kept.
            num_bands (int): Number of bands to use. One of {6, 12, 150}.
            label_scaling (str): Label scaling mode. One of
                {"none", "std", "norm_max", "norm_min_max"}.
            aoi (str, optional): Sub-directory (area of interest) to read
                from when ``split == "test_enmap"``; required for that split.

        Raises:
            ValueError: If ``split``, ``mask``, ``num_bands`` or
                ``label_scaling`` is not an accepted value, or if ``aoi``
                is missing for the 'test_enmap' split.
        """
        super().__init__()
        if split not in self.splits:
            raise ValueError(f"split must be one of {list(self.splits.keys())}, got '{split}'")
        if mask not in ("none", "og", "square"):
            raise ValueError(f"mask must be one of ['none', 'og', 'square'], got '{mask}'")
        if num_bands not in (6, 12, 150):
            raise ValueError(f"num_bands must be one of [6, 12, 150], got {num_bands}")
        if label_scaling not in ("none", "std", "norm_max", "norm_min_max"):
            raise ValueError(
                "label_scaling must be one of ['none', 'std', 'norm_max', 'norm_min_max'], "
                f"got '{label_scaling}'"
            )
        self.split = split
        self.mask = mask
        self.fill_last_with_150 = fill_last_with_150
        self.label_scaling = label_scaling
        self.data_root = Path(data_root)
        self.exclude_11x11 = exclude_11x11
        self.only_11x11 = only_11x11
        self.num_bands = num_bands

        # Resolve the directory holding this split's .npz archives.
        if self.split in ["train", "val"]:
            self.data_dir = self.data_root / "train_data"
        elif self.split == "test_dat":
            self.data_dir = self.data_root / "test_dat"
        elif self.split == "test_enmap":
            if aoi is None:
                raise ValueError("aoi must be provided when split is 'test_enmap'")
            self.data_dir = self.data_root / "test_enmap" / aoi
        elif self.split == "test_intuition":
            self.data_dir = self.data_root / "test_intuition"
        else:
            self.data_dir = self.data_root / "test_data"

        # Files are named "<sample_index>.npz"; sort numerically so sample
        # positions are stable (lexicographic glob order would not be).
        self.files = sorted(
            glob.glob(os.path.join(self.data_dir, "*.npz")),
            key=lambda x: int(os.path.splitext(os.path.basename(x))[0])
        )

        # Optional size-based filtering. Both flags may be set, in which
        # case the dataset ends up empty (parity with legacy behavior).
        if self.exclude_11x11:
            self.files = self._filter_by_size(self.files, keep_11x11=False)
        if self.only_11x11:
            self.files = self._filter_by_size(self.files, keep_11x11=True)

        # Deterministic train/val partition of the shared "train_data" pool.
        if self.split in ["train", "val"]:
            indices = np.arange(len(self.files))
            train_idx, val_idx = train_test_split(
                indices, test_size=val_ratio, random_state=random_state, shuffle=True
            )
            if self.split == "train":
                self.files = [self.files[i] for i in train_idx]
            else:
                self.files = [self.files[i] for i in val_idx]

        # Labels are optional (test splits have none); a missing label file
        # is tolerated and yields zero labels in __getitem__.
        self.labels = None
        if label_path is not None and os.path.exists(label_path):
            self.labels = self._scale_labels(self._load_labels(label_path))

        # Avoid the shared-mutable-default pitfall: the default list is
        # created per instance here, not in the signature.
        self.target_index = [0, 1, 2, 3] if target_index is None else target_index
        self.transform = transform
        self.cropped_load = cropped_load

        # Pick the 1-based band mapping matching the sensor and band count.
        if self.num_bands == 12 and self.split == "test_enmap":
            band_mapping = self._S2_TO_ENMAP_12
        elif self.num_bands == 12 and self.split == "test_intuition":
            band_mapping = self._S2_TO_INTUITON_12
        elif self.num_bands == 12:
            band_mapping = self._S2_TO_HYPER_12
        elif self.num_bands == 6:
            band_mapping = self._HLS_TO_HYPER_6
        else:
            band_mapping = list(range(1, 151))

        # Convert to 0-based indices; -1 is a sentinel meaning "zero-fill".
        self.s2_zero_based = []
        for b_1 in band_mapping:
            if b_1 is None:
                if self.fill_last_with_150 and self.split != "test_intuition":
                    self.s2_zero_based.append(149)  # band #150, 0-based
                elif self.fill_last_with_150 and self.split == "test_intuition":
                    self.s2_zero_based.append(179)  # band #180, 0-based
                else:
                    self.s2_zero_based.append(-1)
            else:
                self.s2_zero_based.append(b_1 - 1)

    @staticmethod
    def _filter_by_size(files: list[str], keep_11x11: bool) -> list[str]:
        """Filter .npz paths by the spatial size of their "data" array.

        Args:
            files: Paths to .npz archives whose "data" array is (C, H, W).
            keep_11x11: If True, keep only (H, W) == (11, 11) samples;
                if False, drop exactly those samples.

        Returns:
            The filtered list, preserving input order.
        """
        kept = []
        for fp in files:
            with np.load(fp) as npz:
                shape = npz["data"].shape
            is_11x11 = shape[1] == 11 and shape[2] == 11
            if is_11x11 == keep_11x11:
                kept.append(fp)
        return kept

    def __len__(self) -> int:
        """Return dataset size."""
        return len(self.files)

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Load one sample with optional masking, scaling and transforms.

        Returns:
            Dict with keys "image" and "S2L2A" (the same (H, W, C) float32
            array) and "label" (float32 tensor of the selected targets, or
            zeros when no labels were loaded).
        """
        file_path = self.files[index]
        with np.load(file_path) as npz:
            if self.mask == "none":
                # enmap archives store the cube under a different key.
                if self.split == "test_enmap":
                    arr = npz["enmap"]
                else:
                    arr = npz["data"]
            else:
                # assumes the archive holds MaskedArray kwargs (data/mask)
                # -- TODO confirm for the test_enmap split, whose archives
                # use the "enmap" key above.
                arr = np.ma.MaskedArray(**npz)

            # Gather the requested channels; -1 means an all-zero channel.
            channels = []
            if isinstance(arr, np.ma.MaskedArray):
                for band_idx in self.s2_zero_based:
                    if band_idx == -1:
                        h, w = arr.shape[-2], arr.shape[-1]
                        zeros_data = np.zeros((h, w), dtype=arr.dtype)
                        zeros_mask = np.zeros((h, w), dtype=bool)
                        channel_masked = np.ma.MaskedArray(data=zeros_data, mask=zeros_mask)
                        channels.append(channel_masked)
                    else:
                        channels.append(arr[band_idx])
                data_arr = np.ma.stack(channels, axis=0)
            else:
                for band_idx in self.s2_zero_based:
                    if band_idx == -1:
                        h, w = arr.shape[-2], arr.shape[-1]
                        channels.append(np.zeros((h, w), dtype=arr.dtype))
                    else:
                        channels.append(arr[band_idx])
                data_arr = np.stack(channels, axis=0)

        if isinstance(data_arr, np.ma.MaskedArray) and self.mask == "square":
            data_arr = self._crop_to_largest_square_unmasked(data_arr)
        if isinstance(data_arr, np.ma.MaskedArray):
            # Replace any remaining masked pixels with 0.
            data_arr = data_arr.filled(0)

        # 5419.0 is the fixed intensity normalizer used throughout;
        # presumably a dataset-wide reflectance maximum -- verify upstream.
        data_arr = (data_arr / 5419.0).astype(np.float32)
        # (C, H, W) -> (H, W, C), the layout albumentations expects.
        data_arr = np.transpose(data_arr, (1, 2, 0))

        if self.labels is not None:
            # File stem is the sample index into the dense label array.
            base = os.path.basename(file_path).replace(".npz", "")
            sample_id = int(base)
            label_row = self.labels[sample_id][self.target_index]
        else:
            label_row = np.zeros(len(self.target_index), dtype=np.float32)

        output = {"image": data_arr, "S2L2A": data_arr, "label": label_row}
        if self.transform is not None:
            transformed = self.transform(image=output["image"])
            output["image"] = transformed["image"]
            output["S2L2A"] = output["image"]
        output["label"] = torch.tensor(output["label"], dtype=torch.float32)
        return output

    @staticmethod
    def _load_labels(label_path: str) -> np.ndarray:
        """Load labels CSV into a dense (max_index + 1, 4) float32 array.

        Rows absent from the CSV remain all-zero. Column order is
        (P, K, Mg, pH), matching the ``_LABEL_*`` statistics.
        """
        df = pd.read_csv(label_path)
        max_idx = int(np.asarray(df["sample_index"].max()).item())
        label_array = np.zeros((max_idx + 1, 4), dtype=np.float32)
        for row in df.itertuples():
            sample_index = int(np.asarray(row.sample_index).item())
            label_array[sample_index] = np.array([row.P, row.K, row.Mg, row.pH], dtype=np.float32)
        return label_array

    def _scale_labels(self, labels: np.ndarray) -> np.ndarray:
        """Scale labels according to the configured ``label_scaling`` mode.

        "none" returns the input unchanged, "std" z-scores it, "norm_max"
        divides by the per-target maximum, and "norm_min_max" maps the
        nominal range onto [0, 1].
        """
        labels = labels.astype(np.float32, copy=False)
        if self.label_scaling == "none":
            return labels
        if self.label_scaling == "std":
            return (labels - self._LABEL_MEAN) / self._LABEL_STD
        if self.label_scaling == "norm_max":
            return labels / self._LABEL_MAX
        if self.label_scaling == "norm_min_max":
            # Guard against a zero denominator for degenerate statistics.
            denom = np.maximum(self._LABEL_MAX - self._LABEL_MIN, 1e-8)
            return (labels - self._LABEL_MIN) / denom
        raise ValueError(f"Unknown label_scaling mode: {self.label_scaling}")

    @staticmethod
    def _crop_to_largest_square_unmasked(masked_data: np.ma.MaskedArray) -> np.ma.MaskedArray:
        """Return the largest square region containing only unmasked pixels.

        A pixel counts as masked if it is masked in ANY channel.
        """
        combined_mask = np.asarray(masked_data.mask.any(axis=0), dtype=bool)
        top, left, size = HyperviewNonGeo._find_largest_square_false(combined_mask)
        cropped = masked_data[:, top : top + size, left : left + size]
        return cropped

    @staticmethod
    def _find_largest_square_false(mask_2d: np.ndarray) -> tuple[int, int, int]:
        """Find the largest all-False square and return ``(top, left, size)``.

        Classic dynamic program: dp[i, j] is the side of the largest False
        square whose bottom-right corner is (i, j). Returns (0, 0, 0) when
        every pixel is True (fully masked input).
        """
        H, W = mask_2d.shape
        dp = np.zeros((H, W), dtype=np.int32)
        max_size = 0
        max_pos = (0, 0)
        for i in range(H):
            for j in range(W):
                if not mask_2d[i, j]:
                    if i == 0 or j == 0:
                        dp[i, j] = 1
                    else:
                        dp[i, j] = min(dp[i-1, j], dp[i, j-1], dp[i-1, j-1]) + 1
                    if dp[i, j] > max_size:
                        max_size = dp[i, j]
                        max_pos = (i, j)
        if max_size == 0:
            # Fully masked input: the legacy code derived (1, 1, 0) from the
            # dummy max_pos; anchor the empty crop at the origin instead.
            return 0, 0, 0
        best_i, best_j = max_pos
        top = best_i - max_size + 1
        left = best_j - max_size + 1
        return top, left, max_size
|