import glob
import os
from pathlib import Path
from typing import Any, Optional

import albumentations as A
import numpy as np
import pandas as pd
import torch
from torch import Tensor
from torchgeo.datasets import NonGeoDataset
from sklearn.model_selection import train_test_split


class HyperviewNonGeo(NonGeoDataset):
    """Hyperview soil-parameter dataset with configurable band selection.

    Modified dataset that can load either 6, 12, or all 150 channels.

    - For ``num_bands=12``, it loads the standard Sentinel-2 set (skipping
      B10), with B11/B12 either filled from band #150 or zeroed out
      (depending on the ``fill_last_with_150`` flag).
    - For ``num_bands=6``, it loads an HLS-like subset.
    - For ``num_bands=150``, it simply loads the entire hyperspectral cube.

    The ``mask`` parameter can take one of three values:

    - ``"none"``: No mask is used (arrays loaded as ``np.ndarray``).
    - ``"og"``: A ``MaskedArray`` is created for each file (original mask),
      but no cropping is performed.
    - ``"square"``: A ``MaskedArray`` is created, and we find the largest
      square region in which all pixels are unmasked (mask=False), cropping
      the data to that region.
    """

    # 1-based band indices into the source hyperspectral cube approximating
    # the listed sensor's bands. ``None`` marks bands with no counterpart;
    # at load time they are filled from band #150 (or #180 for the
    # intuition split) or zero-filled — see ``fill_last_with_150``.
    _S2_TO_HYPER_12 = [1, 10, 32, 64, 77, 88, 101, 120, 127, 150, None, None]
    _S2_TO_ENMAP_12 = [6, 16, 30, 48, 54, 59, 64, 72, 75, 90, 150, 191]
    _S2_TO_INTUITON_12 = [5, 14, 35, 85, 101, 112, 134, 155, 165, 179, None, None]
    _HLS_TO_HYPER_6 = [10, 32, 64, 127, None, None]

    # Per-target (P, K, Mg, pH) statistics used for label scaling.
    _LABEL_MIN = np.array([20.3000, 21.1000, 26.8000, 5.8000], dtype=np.float32)
    _LABEL_MAX = np.array([325.0000, 625.0000, 400.0000, 7.8000], dtype=np.float32)
    _LABEL_MEAN = np.array([70.4617, 226.8499, 159.3915, 6.7789], dtype=np.float32)
    _LABEL_STD = np.array([30.1490, 60.5661, 39.7610, 0.2593], dtype=np.float32)

    # Global divisor applied to every image in __getitem__.
    # NOTE(review): presumably the dataset-wide maximum reflectance — confirm.
    _DATA_DIVISOR = 5419.0

    splits = {
        "train": "train",
        "val": "val",
        "test": "test",
        "test_dat": "test_dat",
        "test_enmap": "test_enmap",
        "test_intuition": "test_intuition",
    }

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        label_path: Optional[str] = None,
        transform: Optional[A.Compose] = None,
        target_index: Optional[list[int]] = None,
        cropped_load: bool = False,
        val_ratio: float = 0.2,
        random_state: int = 42,
        mask: str = "none",
        fill_last_with_150: bool = True,
        exclude_11x11: bool = False,
        only_11x11: bool = False,
        num_bands: int = 12,
        label_scaling: str = "std",
        aoi: Optional[str] = None,
    ) -> None:
        """
        Args:
            data_root (str): Path to the root directory of the data.
            split (str): One of 'train', 'val', 'test', 'test_dat',
                'test_enmap' or 'test_intuition'.
            label_path (str, optional): Path to CSV with labels.
            transform (A.Compose, optional): Albumentations transforms.
            target_index (list[int], optional): Indices of columns to select
                as labels. Defaults to ``[0, 1, 2, 3]`` (all four targets).
                A ``None`` sentinel replaces the previous mutable-list
                default; behavior for callers is unchanged.
            cropped_load (bool): Not directly used here (legacy param).
            val_ratio (float): Fraction of data to use as validation
                (when split is train/val).
            random_state (int): Random seed for train/val split.
            mask (str): Controls how we handle the mask. One of:
                - "none" -> no mask
                - "og" -> original mask
                - "square" -> original mask + crop to largest unmasked square
            fill_last_with_150 (bool): If True, B11/B12 are replaced by band
                #150 (in 12- or 6-band scenario) otherwise zero-filled.
                This has no effect when num_bands=150.
            exclude_11x11 (bool): If True, images of size (11,11) are
                filtered out.
            only_11x11 (bool): If True, only images of size (11,11) are kept.
            num_bands (int): Number of bands to use. One of {6, 12, 150}.
            label_scaling (str): Label scaling mode. One of
                {"none", "std", "norm_max", "norm_min_max"}.
            aoi (str, optional): Sub-directory, required for 'test_enmap'.

        Raises:
            ValueError: For invalid ``split``, ``mask``, ``num_bands`` or
                ``label_scaling`` values, or a missing ``aoi`` when split
                is 'test_enmap'.
        """
        super().__init__()

        if split not in self.splits:
            raise ValueError(f"split must be one of {list(self.splits.keys())}, got '{split}'")
        if mask not in ("none", "og", "square"):
            raise ValueError(f"mask must be one of ['none', 'og', 'square'], got '{mask}'")
        if num_bands not in (6, 12, 150):
            raise ValueError(f"num_bands must be one of [6, 12, 150], got {num_bands}")
        if label_scaling not in ("none", "std", "norm_max", "norm_min_max"):
            raise ValueError(
                "label_scaling must be one of ['none', 'std', 'norm_max', 'norm_min_max'], "
                f"got '{label_scaling}'"
            )

        self.split = split
        self.mask = mask
        self.fill_last_with_150 = fill_last_with_150
        self.label_scaling = label_scaling
        self.data_root = Path(data_root)
        self.exclude_11x11 = exclude_11x11
        self.only_11x11 = only_11x11
        self.num_bands = num_bands

        self.data_dir = self._resolve_data_dir(aoi)

        # Files are named '<sample_index>.npz'; sort numerically so that
        # e.g. 2.npz precedes 10.npz.
        self.files = sorted(
            glob.glob(os.path.join(self.data_dir, "*.npz")),
            key=lambda x: int(os.path.splitext(os.path.basename(x))[0]),
        )

        # Optional shape-based filtering. If both flags are set, exclusion
        # runs first and the result is empty (kept for parity with the
        # original behavior).
        if self.exclude_11x11:
            self.files = [fp for fp in self.files if not self._is_11x11(fp)]
        if self.only_11x11:
            self.files = [fp for fp in self.files if self._is_11x11(fp)]

        if self.split in ["train", "val"]:
            # Deterministic train/val partition over file positions.
            indices = np.arange(len(self.files))
            train_idx, val_idx = train_test_split(
                indices, test_size=val_ratio, random_state=random_state, shuffle=True
            )
            chosen = train_idx if self.split == "train" else val_idx
            self.files = [self.files[i] for i in chosen]

        self.labels = None
        if label_path is not None and os.path.exists(label_path):
            self.labels = self._scale_labels(self._load_labels(label_path))

        # Default to all four targets; avoids a shared mutable default.
        self.target_index = [0, 1, 2, 3] if target_index is None else target_index
        self.transform = transform
        self.cropped_load = cropped_load

        self.s2_zero_based = self._build_band_indices()

    def _resolve_data_dir(self, aoi: Optional[str]) -> Path:
        """Map the configured split to its data directory under data_root."""
        if self.split in ["train", "val"]:
            return self.data_root / "train_data"
        if self.split == "test_dat":
            return self.data_root / "test_dat"
        if self.split == "test_enmap":
            if aoi is None:
                raise ValueError("aoi must be provided when split is 'test_enmap'")
            return self.data_root / "test_enmap" / aoi
        if self.split == "test_intuition":
            return self.data_root / "test_intuition"
        # 'test' (and any remaining split) falls through to test_data.
        return self.data_root / "test_data"

    @staticmethod
    def _is_11x11(file_path: str) -> bool:
        """Return True when the stored cube has spatial size (11, 11)."""
        with np.load(file_path) as npz:
            data = npz["data"]
            return data.shape[1] == 11 and data.shape[2] == 11

    def _build_band_indices(self) -> list[int]:
        """Translate the configured band mapping into 0-based indices.

        ``None`` entries become 149 (band #150) — or 179 (band #180) for
        the intuition test split — when ``fill_last_with_150`` is set, and
        the sentinel -1 (zero-filled at load time) otherwise.
        """
        if self.num_bands == 12 and self.split == "test_enmap":
            mapping = self._S2_TO_ENMAP_12
        elif self.num_bands == 12 and self.split == "test_intuition":
            mapping = self._S2_TO_INTUITON_12
        elif self.num_bands == 12:
            mapping = self._S2_TO_HYPER_12
        elif self.num_bands == 6:
            mapping = self._HLS_TO_HYPER_6
        else:
            # num_bands == 150: the whole cube, bands #1..#150.
            mapping = list(range(1, 151))

        zero_based: list[int] = []
        for band in mapping:
            if band is None:
                if self.fill_last_with_150:
                    zero_based.append(179 if self.split == "test_intuition" else 149)
                else:
                    zero_based.append(-1)
            else:
                zero_based.append(band - 1)
        return zero_based

    def __len__(self) -> int:
        """Return dataset size."""
        return len(self.files)

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Load one sample with optional masking, scaling and transforms."""
        file_path = self.files[index]

        with np.load(file_path) as npz:
            if self.mask == "none":
                # Plain array; the enmap split stores its cube under a
                # different archive key.
                arr = npz["enmap"] if self.split == "test_enmap" else npz["data"]
            else:
                # Expects 'data' and 'mask' entries in the archive.
                arr = np.ma.MaskedArray(**npz)
            data_arr = self._select_channels(arr)

        if isinstance(data_arr, np.ma.MaskedArray) and self.mask == "square":
            data_arr = self._crop_to_largest_square_unmasked(data_arr)
        if isinstance(data_arr, np.ma.MaskedArray):
            # Masked pixels become 0 in the dense array handed to transforms.
            data_arr = data_arr.filled(0)

        # Normalize, then move channels last (H, W, C) for albumentations.
        data_arr = (data_arr / self._DATA_DIVISOR).astype(np.float32)
        data_arr = np.transpose(data_arr, (1, 2, 0))

        if self.labels is not None:
            # The file stem IS the sample index into the label table.
            sample_id = int(os.path.basename(file_path).replace(".npz", ""))
            label_row = self.labels[sample_id][self.target_index]
        else:
            label_row = np.zeros(len(self.target_index), dtype=np.float32)

        output = {"image": data_arr, "S2L2A": data_arr, "label": label_row}
        if self.transform is not None:
            transformed = self.transform(image=output["image"])
            output["image"] = transformed["image"]
            output["S2L2A"] = output["image"]
        output["label"] = torch.tensor(output["label"], dtype=torch.float32)
        return output

    def _select_channels(self, arr: Any) -> Any:
        """Stack the configured channels from ``arr`` along a new axis 0.

        Sentinel index -1 produces an all-zero channel (with an all-False
        mask in the MaskedArray case). Returns a MaskedArray iff ``arr``
        is one.
        """
        h, w = arr.shape[-2], arr.shape[-1]
        channels = []
        if isinstance(arr, np.ma.MaskedArray):
            for band_idx in self.s2_zero_based:
                if band_idx == -1:
                    # Fully-valid zero channel: data 0, mask False everywhere.
                    channels.append(
                        np.ma.MaskedArray(
                            data=np.zeros((h, w), dtype=arr.dtype),
                            mask=np.zeros((h, w), dtype=bool),
                        )
                    )
                else:
                    channels.append(arr[band_idx])
            return np.ma.stack(channels, axis=0)
        for band_idx in self.s2_zero_based:
            if band_idx == -1:
                channels.append(np.zeros((h, w), dtype=arr.dtype))
            else:
                channels.append(arr[band_idx])
        return np.stack(channels, axis=0)

    @staticmethod
    def _load_labels(label_path: str) -> np.ndarray:
        """Load the labels CSV into a dense (max_index + 1, 4) array.

        Rows are indexed by the CSV's 'sample_index' column; indices absent
        from the CSV keep all-zero labels.
        """
        df = pd.read_csv(label_path)
        max_idx = int(np.asarray(df["sample_index"].max()).item())
        label_array = np.zeros((max_idx + 1, 4), dtype=np.float32)
        for row in df.itertuples():
            sample_index = int(np.asarray(row.sample_index).item())
            label_array[sample_index] = np.array(
                [row.P, row.K, row.Mg, row.pH], dtype=np.float32
            )
        return label_array

    def _scale_labels(self, labels: np.ndarray) -> np.ndarray:
        """Scale labels according to the configured mode."""
        labels = labels.astype(np.float32, copy=False)
        if self.label_scaling == "none":
            return labels
        if self.label_scaling == "std":
            return (labels - self._LABEL_MEAN) / self._LABEL_STD
        if self.label_scaling == "norm_max":
            return labels / self._LABEL_MAX
        if self.label_scaling == "norm_min_max":
            # Guard against zero range so the division is always safe.
            denom = np.maximum(self._LABEL_MAX - self._LABEL_MIN, 1e-8)
            return (labels - self._LABEL_MIN) / denom
        raise ValueError(f"Unknown label_scaling mode: {self.label_scaling}")

    @staticmethod
    def _crop_to_largest_square_unmasked(masked_data: np.ma.MaskedArray) -> np.ma.MaskedArray:
        """Return the largest square region containing only unmasked pixels.

        A pixel counts as masked if it is masked in ANY channel. If no
        unmasked pixel exists, the crop is empty.
        """
        combined_mask = np.asarray(masked_data.mask.any(axis=0), dtype=bool)
        top, left, size = HyperviewNonGeo._find_largest_square_false(combined_mask)
        return masked_data[:, top : top + size, left : left + size]

    @staticmethod
    def _find_largest_square_false(mask_2d: np.ndarray) -> tuple[int, int, int]:
        """Find the largest False-valued square; return ``(top, left, size)``.

        Classic dynamic-programming formulation: dp[i, j] is the side of
        the largest all-False square whose bottom-right corner is (i, j).

        Fixed: a fully-masked input previously yielded (1, 1, 0), i.e.
        out-of-range coordinates for the degenerate empty square; it now
        yields (0, 0, 0). Both describe an empty crop, but the new value is
        a valid coordinate.
        """
        H, W = mask_2d.shape
        dp = np.zeros((H, W), dtype=np.int32)
        max_size = 0
        max_pos = (0, 0)
        for i in range(H):
            for j in range(W):
                if not mask_2d[i, j]:
                    if i == 0 or j == 0:
                        dp[i, j] = 1
                    else:
                        dp[i, j] = min(dp[i - 1, j], dp[i, j - 1], dp[i - 1, j - 1]) + 1
                    if dp[i, j] > max_size:
                        max_size = dp[i, j]
                        max_pos = (i, j)
        if max_size == 0:
            # No unmasked pixel anywhere: empty region anchored at origin.
            return 0, 0, 0
        best_i, best_j = max_pos
        return best_i - max_size + 1, best_j - max_size + 1, max_size