# TerraMind-HYPERVIEW / datasets/hyperview_dataset.py
# Uploaded by KPLabs via huggingface_hub (commit 87904b0, verified).
import glob
import os
from pathlib import Path
from typing import Any, Optional
import albumentations as A
import numpy as np
import pandas as pd
import torch
from torch import Tensor
from torchgeo.datasets import NonGeoDataset
from sklearn.model_selection import train_test_split
class HyperviewNonGeo(NonGeoDataset):
    """Hyperview soil-parameter dataset that can load 6, 12, or all 150 channels.

    - For ``num_bands=12``, it loads the standard Sentinel-2 set (skipping B10),
      with B11/B12 either filled from band #150 or zeroed out (depending on the
      ``fill_last_with_150`` flag).
    - For ``num_bands=6``, it loads an HLS-like subset.
    - For ``num_bands=150``, it simply loads the entire hyperspectral cube.

    The 'mask' parameter can take one of three values:
    - "none": No mask is used (arrays loaded as np.ndarray).
    - "og": A MaskedArray is created for each file (original mask),
      but no cropping is performed.
    - "square": A MaskedArray is created, and we find the largest
      square region in which all pixels are unmasked (mask=False),
      cropping the data to that region.
    """

    # 1-based band indices into the source cube for each output band slot.
    # ``None`` marks a slot (B11/B12 or the HLS SWIR slots) with no source
    # band; it is filled according to ``fill_last_with_150`` in __init__.
    _S2_TO_HYPER_12 = [1, 10, 32, 64, 77, 88, 101, 120, 127, 150, None, None]
    _S2_TO_ENMAP_12 = [6, 16, 30, 48, 54, 59, 64, 72, 75, 90, 150, 191]
    _S2_TO_INTUITON_12 = [5, 14, 35, 85, 101, 112, 134, 155, 165, 179, None, None]
    _HLS_TO_HYPER_6 = [10, 32, 64, 127, None, None]

    # Per-target statistics in column order (P, K, Mg, pH), consumed by
    # ``_scale_labels``.
    _LABEL_MIN = np.array([20.3000, 21.1000, 26.8000, 5.8000], dtype=np.float32)
    _LABEL_MAX = np.array([325.0000, 625.0000, 400.0000, 7.8000], dtype=np.float32)
    _LABEL_MEAN = np.array([70.4617, 226.8499, 159.3915, 6.7789], dtype=np.float32)
    _LABEL_STD = np.array([30.1490, 60.5661, 39.7610, 0.2593], dtype=np.float32)

    # Accepted values for the ``split`` argument.
    splits = {
        "train": "train",
        "val": "val",
        "test": "test",
        "test_dat": "test_dat",
        "test_enmap": "test_enmap",
        "test_intuition": "test_intuition",
    }

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        label_path: Optional[str] = None,
        transform: Optional[A.Compose] = None,
        target_index: Optional[list[int]] = None,
        cropped_load: bool = False,
        val_ratio: float = 0.2,
        random_state: int = 42,
        mask: str = "none",
        fill_last_with_150: bool = True,
        exclude_11x11: bool = False,
        only_11x11: bool = False,
        num_bands: int = 12,
        label_scaling: str = "std",
        aoi: Optional[str] = None,
    ) -> None:
        """
        Args:
            data_root (str): Path to the root directory of the data.
            split (str): One of 'train', 'val', 'test', 'test_dat',
                'test_enmap', or 'test_intuition'.
            label_path (str, optional): Path to CSV with labels.
            transform (A.Compose, optional): Albumentations transforms.
            target_index (list[int], optional): Indices of columns to select as
                labels. Defaults to all four targets ``[0, 1, 2, 3]``.
            cropped_load (bool): Not directly used here (legacy param).
            val_ratio (float): Fraction of data to use as validation (when split is train/val).
            random_state (int): Random seed for train/val split.
            mask (str): Controls how we handle mask. One of:
                - "none" -> no mask
                - "og" -> original mask
                - "square" -> original mask + crop to largest unmasked square
            fill_last_with_150 (bool): If True, B11/B12 are replaced by band #150
                (in 12- or 6-band scenario) otherwise zero-filled.
                This has no effect when num_bands=150.
            exclude_11x11 (bool): If True, images of size (11,11) are filtered out.
            only_11x11 (bool): If True, only images of size (11,11) are kept.
            num_bands (int): Number of bands to use. One of {6, 12, 150}.
            label_scaling (str): Label scaling mode. One of
                {"none", "std", "norm_max", "norm_min_max"}.
            aoi (str, optional): Sub-directory of ``test_enmap`` to load.
                Required when ``split == "test_enmap"``, ignored otherwise.

        Raises:
            ValueError: On an unknown ``split``, ``mask``, ``num_bands`` or
                ``label_scaling`` value, or a missing ``aoi`` for 'test_enmap'.
        """
        super().__init__()
        # Validate configuration early so misconfigurations fail loudly.
        if split not in self.splits:
            raise ValueError(f"split must be one of {list(self.splits.keys())}, got '{split}'")
        if mask not in ("none", "og", "square"):
            raise ValueError(f"mask must be one of ['none', 'og', 'square'], got '{mask}'")
        if num_bands not in (6, 12, 150):
            raise ValueError(f"num_bands must be one of [6, 12, 150], got {num_bands}")
        if label_scaling not in ("none", "std", "norm_max", "norm_min_max"):
            raise ValueError(
                "label_scaling must be one of ['none', 'std', 'norm_max', 'norm_min_max'], "
                f"got '{label_scaling}'"
            )
        self.split = split
        self.mask = mask
        self.fill_last_with_150 = fill_last_with_150
        self.label_scaling = label_scaling
        self.data_root = Path(data_root)
        self.exclude_11x11 = exclude_11x11
        self.only_11x11 = only_11x11
        self.num_bands = num_bands

        # Resolve the on-disk directory for the requested split.
        # 'train' and 'val' share one directory and are separated further below.
        if self.split in ["train", "val"]:
            self.data_dir = self.data_root / "train_data"
        elif self.split == "test_dat":
            self.data_dir = self.data_root / "test_dat"
        elif self.split == "test_enmap":
            if aoi is None:
                raise ValueError("aoi must be provided when split is 'test_enmap'")
            self.data_dir = self.data_root / "test_enmap" / aoi
        elif self.split == "test_intuition":
            self.data_dir = self.data_root / "test_intuition"
        else:
            self.data_dir = self.data_root / "test_data"

        # Sort numerically by the integer file stem (e.g. "12.npz" < "100.npz")
        # so indices align with the label table.
        self.files = sorted(
            glob.glob(os.path.join(self.data_dir, "*.npz")),
            key=lambda x: int(os.path.splitext(os.path.basename(x))[0])
        )

        # Optional size-based filtering; both flags inspect the stored 'data'
        # array, whose layout is (bands, H, W).
        if self.exclude_11x11:
            filtered_files = []
            for fp in self.files:
                with np.load(fp) as npz:
                    data = npz["data"]
                    if not (data.shape[1] == 11 and data.shape[2] == 11):
                        filtered_files.append(fp)
            self.files = filtered_files
        if self.only_11x11:
            filtered_files = []
            for fp in self.files:
                with np.load(fp) as npz:
                    data = npz["data"]
                    if data.shape[1] == 11 and data.shape[2] == 11:
                        filtered_files.append(fp)
            self.files = filtered_files

        # Deterministic train/val partition of the shared 'train_data' pool.
        if self.split in ["train", "val"]:
            indices = np.arange(len(self.files))
            train_idx, val_idx = train_test_split(
                indices, test_size=val_ratio, random_state=random_state, shuffle=True
            )
            if self.split == "train":
                self.files = [self.files[i] for i in train_idx]
            else:
                self.files = [self.files[i] for i in val_idx]

        # Labels are optional (test splits typically have none); when present
        # they are scaled once, up front.
        self.labels = None
        if label_path is not None and os.path.exists(label_path):
            self.labels = self._scale_labels(self._load_labels(label_path))
        # Avoid the mutable-default-argument pitfall: the list default is
        # created per instance, not shared across calls.
        self.target_index = [0, 1, 2, 3] if target_index is None else target_index
        self.transform = transform
        self.cropped_load = cropped_load

        # Pick the 1-based band mapping for the configured sensor/band count.
        if self.num_bands == 12 and self.split == "test_enmap":
            band_mapping = self._S2_TO_ENMAP_12
        elif self.num_bands == 12 and self.split == "test_intuition":
            band_mapping = self._S2_TO_INTUITON_12
        elif self.num_bands == 12:
            band_mapping = self._S2_TO_HYPER_12
        elif self.num_bands == 6:
            band_mapping = self._HLS_TO_HYPER_6
        else:
            band_mapping = list(range(1, 151))

        # Convert to 0-based indices; ``None`` slots become either a fill band
        # or the sentinel -1 (meaning "zero-fill this channel" in __getitem__).
        # NOTE(review): for 'test_intuition' the fill index 179 is 0-based
        # band #180, while the last explicit band in _S2_TO_INTUITON_12 is
        # 1-based 179 (0-based 178) — confirm the intended band.
        self.s2_zero_based = []
        for b_1 in band_mapping:
            if b_1 is None:
                if self.fill_last_with_150 and self.split != "test_intuition":
                    self.s2_zero_based.append(149)
                elif self.fill_last_with_150 and self.split == "test_intuition":
                    self.s2_zero_based.append(179)
                else:
                    self.s2_zero_based.append(-1)
            else:
                self.s2_zero_based.append(b_1 - 1)

    def __len__(self) -> int:
        """Return dataset size."""
        return len(self.files)

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Load one sample with optional masking, scaling and transforms.

        Returns:
            dict with keys 'image' and 'S2L2A' (both the (H, W, C) float32
            array, identical references) and 'label' (float32 tensor of the
            selected targets, zeros when no labels were loaded).
        """
        file_path = self.files[index]
        with np.load(file_path) as npz:
            if self.mask == "none":
                if self.split == "test_enmap":
                    arr = npz["enmap"]
                else:
                    arr = npz["data"]
            else:
                # Re-hydrate the stored MaskedArray from the npz keys.
                # NOTE(review): assumes the npz holds exactly the MaskedArray
                # constructor kwargs (e.g. 'data' and 'mask'); the 'test_enmap'
                # files use an 'enmap' key above, so masked loading for that
                # split looks unsupported — confirm against the data files.
                arr = np.ma.MaskedArray(**npz)

            # Select/synthesize channels; -1 entries become zero-filled bands.
            channels = []
            if isinstance(arr, np.ma.MaskedArray):
                for band_idx in self.s2_zero_based:
                    if band_idx == -1:
                        h, w = arr.shape[-2], arr.shape[-1]
                        zeros_data = np.zeros((h, w), dtype=arr.dtype)
                        # Fully-unmasked zero channel so it survives cropping.
                        zeros_mask = np.zeros((h, w), dtype=bool)
                        channel_masked = np.ma.MaskedArray(data=zeros_data, mask=zeros_mask)
                        channels.append(channel_masked)
                    else:
                        channels.append(arr[band_idx])
                data_arr = np.ma.stack(channels, axis=0)
            else:
                for band_idx in self.s2_zero_based:
                    if band_idx == -1:
                        h, w = arr.shape[-2], arr.shape[-1]
                        channels.append(np.zeros((h, w), dtype=arr.dtype))
                    else:
                        channels.append(arr[band_idx])
                data_arr = np.stack(channels, axis=0)

        if isinstance(data_arr, np.ma.MaskedArray) and self.mask == "square":
            data_arr = self._crop_to_largest_square_unmasked(data_arr)
        if isinstance(data_arr, np.ma.MaskedArray):
            # Any remaining masked pixels are zero-filled before normalization.
            data_arr = data_arr.filled(0)

        # NOTE(review): 5419.0 appears to be a dataset-wide normalization
        # constant (presumably the global reflectance maximum) — confirm.
        data_arr = (data_arr / 5419.0).astype(np.float32)
        # Albumentations expects channels-last (H, W, C).
        data_arr = np.transpose(data_arr, (1, 2, 0))

        if self.labels is not None:
            base = os.path.basename(file_path).replace(".npz", "")
            sample_id = int(base)
            label_row = self.labels[sample_id][self.target_index]
        else:
            label_row = np.zeros(len(self.target_index), dtype=np.float32)

        output = {"image": data_arr, "S2L2A": data_arr, "label": label_row}
        if self.transform is not None:
            transformed = self.transform(image=output["image"])
            output["image"] = transformed["image"]
            output["S2L2A"] = output["image"]
        output["label"] = torch.tensor(output["label"], dtype=torch.float32)
        return output

    @staticmethod
    def _load_labels(label_path: str) -> np.ndarray:
        """Load labels CSV into a dense array indexed by sample_index.

        Rows absent from the CSV remain all-zero in the returned
        ``(max_index + 1, 4)`` float32 array, column order (P, K, Mg, pH).
        """
        df = pd.read_csv(label_path)
        max_idx = int(np.asarray(df["sample_index"].max()).item())
        label_array = np.zeros((max_idx + 1, 4), dtype=np.float32)
        for row in df.itertuples():
            sample_index = int(np.asarray(row.sample_index).item())
            label_array[sample_index] = np.array([row.P, row.K, row.Mg, row.pH], dtype=np.float32)
        return label_array

    def _scale_labels(self, labels: np.ndarray) -> np.ndarray:
        """Scale labels according to the configured ``label_scaling`` mode."""
        labels = labels.astype(np.float32, copy=False)
        if self.label_scaling == "none":
            return labels
        if self.label_scaling == "std":
            return (labels - self._LABEL_MEAN) / self._LABEL_STD
        if self.label_scaling == "norm_max":
            return labels / self._LABEL_MAX
        if self.label_scaling == "norm_min_max":
            # Guard against a zero denominator for degenerate min==max stats.
            denom = np.maximum(self._LABEL_MAX - self._LABEL_MIN, 1e-8)
            return (labels - self._LABEL_MIN) / denom
        raise ValueError(f"Unknown label_scaling mode: {self.label_scaling}")

    @staticmethod
    def _crop_to_largest_square_unmasked(masked_data: np.ma.MaskedArray) -> np.ma.MaskedArray:
        """Return the largest square region containing only unmasked pixels.

        A pixel counts as masked if it is masked in ANY channel. When no
        unmasked pixel exists the crop is empty (size 0).
        """
        combined_mask = np.asarray(masked_data.mask.any(axis=0), dtype=bool)
        top, left, size = HyperviewNonGeo._find_largest_square_false(combined_mask)
        cropped = masked_data[:, top : top + size, left : left + size]
        return cropped

    @staticmethod
    def _find_largest_square_false(mask_2d: np.ndarray) -> tuple[int, int, int]:
        """Find the largest False-valued square and return ``(top, left, size)``.

        Classic dynamic-programming formulation: dp[i, j] is the side length of
        the largest all-False square whose bottom-right corner is (i, j).
        Returns ``(0, 0, 0)`` when the mask contains no False pixel.
        """
        H, W = mask_2d.shape
        dp = np.zeros((H, W), dtype=np.int32)
        max_size = 0
        max_pos = (0, 0)
        for i in range(H):
            for j in range(W):
                if not mask_2d[i, j]:
                    if i == 0 or j == 0:
                        dp[i, j] = 1
                    else:
                        dp[i, j] = min(dp[i-1, j], dp[i, j-1], dp[i-1, j-1]) + 1
                    if dp[i, j] > max_size:
                        max_size = dp[i, j]
                        max_pos = (i, j)
        if max_size == 0:
            # Fully-masked input: anchor the empty region at the origin instead
            # of the out-of-range (1, 1) the naive arithmetic below would give.
            return 0, 0, 0
        (best_i, best_j) = max_pos
        top = best_i - max_size + 1
        left = best_j - max_size + 1
        return top, left, max_size