"""Utilities for loading Hyperview data, postprocessing and evaluation.""" from glob import glob import os from typing import Dict import numpy as np import pandas as pd from sklearn.metrics import ( accuracy_score, balanced_accuracy_score, cohen_kappa_score, f1_score, matthews_corrcoef, ) CLASSES = ["P", "K", "Mg", "pH"] class_metrics = { "avg_acc": (balanced_accuracy_score, {}), "acc": (accuracy_score, {}), "mcc": (matthews_corrcoef, {}), "f1": (f1_score, {"average": "macro"}), } ph_classes_names = ["acidic", "strongly acidic", "slightly acidic", "neutral", "alkaline"] classes_names = ["very low", "low", "medium", "high", "very high"] ph_thresholds = [4.6, 5.6, 6.6, 7.3] phosphorus_thresholds = [ [50, 110, 186, 262], [49, 103, 158, 215], [47, 99, 152, 207], [27, 54, 75, 99], [27, 54, 75, 99], ] potassium_thresholds = [ [32, 75, 119, 162], [52, 99, 145, 191], [98, 139, 200, 241], [126, 174, 270, 318], ] magnesium_thresholds = [ [7, 21, 51, 80], [31, 43, 67, 93], [48, 77, 106, 135], [69, 93, 142, 191], ] def element_classification(result: float, thresholds: list[float]) -> int: """Classify numeric value into threshold-defined bucket.""" class_id = 0 for i, threshold in enumerate(thresholds): if result > threshold: class_id = i + 1 else: break return class_id class BaselineRegressor: """Baseline regressor predicting target-wise mean from training labels.""" def __init__(self) -> None: """Initialize baseline regressor state.""" self.mean = 0 self.classes_count = 0 def fit(self, X_train: np.ndarray, y_train: np.ndarray) -> "BaselineRegressor": """Fit baseline statistics from training labels.""" _ = X_train self.mean = np.mean(y_train, axis=0) self.classes_count = y_train.shape[1] return self def predict(self, X_test: np.ndarray) -> np.ndarray: """Predict constant mean vector for each sample.""" return np.full((len(X_test), self.classes_count), self.mean) class SpectralCurveFiltering: """Convert a 3D hyperspectral cube into 1D spectral curve.""" def __init__(self, merge_function=np.mean) -> None: """Store aggregation function used for spectral compression.""" self.merge_function = merge_function def __call__(self, sample: np.ndarray) -> np.ndarray: """Aggregate each band over spatial dimensions.""" return self.merge_function(sample, axis=(1, 2)) def load_data(directory: str, split: str | None = None, mask: str = "none") -> np.ndarray: """Load and transform all `.npz` cubes from a directory. Args: directory: Directory with `.npz` files. split: Optional split hint (e.g. `test_enmap`) used to prefer proper key. mask: Mask mode. For `"none"` load dense arrays; otherwise prefer masked arrays. """ filtering = SpectralCurveFiltering() data = [] if split is None: split = "test_enmap" if "test_enmap" in directory else "test" all_files = np.array( sorted( glob(os.path.join(directory, "*.npz")), key=lambda path: int(os.path.basename(path).replace(".npz", "")), ) ) for file_name in all_files: with np.load(file_name) as npz: keys = set(npz.files) if mask == "none": if split == "test_enmap" and "enmap" in keys: arr = npz["enmap"] elif "data" in keys: arr = npz["data"] elif "dat" in keys: arr = npz["dat"] elif "enmap" in keys: arr = npz["enmap"] else: raise ValueError( f"Unsupported .npz format in {file_name}. Found keys: {sorted(keys)}" ) else: if {"data", "mask"}.issubset(keys): arr = np.ma.MaskedArray(data=npz["data"], mask=npz["mask"]) elif {"dat", "mask"}.issubset(keys): arr = np.ma.MaskedArray(data=npz["dat"], mask=npz["mask"]) elif "enmap" in keys: arr = npz["enmap"] elif "data" in keys: arr = npz["data"] elif "dat" in keys: arr = npz["dat"] else: raise ValueError( f"Unsupported .npz format in {file_name}. Found keys: {sorted(keys)}" ) data.append(filtering(arr)) return np.array(data) def load_gt(file_path: str) -> np.ndarray: """Load target labels from CSV file.""" gt_file = pd.read_csv(file_path) return gt_file[["P", "K", "Mg", "pH"]].values def load_hyperview_data(): """Load default train/test arrays and labels for Hyperview.""" X_train = load_data("hyperview_data/train_data") y_train = load_gt("hyperview_data/train_gt.csv") X_test = load_data("hyperview_data/test_data") y_test = load_gt("hyperview_data/test_gt.csv") return X_train, y_train, X_test, y_test def calculate_metrics( y_pred: pd.DataFrame, y_true: pd.DataFrame, soil_class: int = 3, ) -> Dict[str, float | list[float]]: """Calculate per-class classification metrics and aggregated stats.""" _ = soil_class out: Dict[str, float | list[float]] = {} for metric_name, (func, kwargs) in class_metrics.items(): metric_scores = [] for class_name in CLASSES: score = [func(y_pred=y_pred[class_name], y_true=y_true[class_name], **kwargs)] out[f"{class_name}_{metric_name}"] = score metric_scores.append(score) out[f"{class_name}_kappa"] = [ cohen_kappa_score(y_pred[class_name], y_true[class_name]) ] out[f"mean_{metric_name}"] = [np.mean(metric_scores)] out[f"std_{metric_name}"] = [np.std(metric_scores)] return out def ph_classification(result: float) -> int: """Classify pH level.""" return element_classification(result, ph_thresholds) def phosphorus_classification(result: float, ph_class: int) -> int: """Classify phosphorus based on pH class thresholds.""" return element_classification(result, phosphorus_thresholds[int(ph_class)]) def potassium_classification(result: float, soil_class: int) -> int: """Classify potassium based on soil class thresholds.""" return element_classification(result, potassium_thresholds[int(soil_class)]) def magnesium_classification(result: float, soil_class: int) -> int: """Classify magnesium based on soil class thresholds.""" return element_classification(result, magnesium_thresholds[int(soil_class)]) def get_classes(y: pd.DataFrame, soil_class: int = 3) -> pd.DataFrame: """Convert continuous predictions into discrete nutrient classes.""" y_classes = {k: [] for k in CLASSES} for _, row in y.iterrows(): y_classes["pH"].append(ph_classification(row["pH"])) y_classes["P"].append(phosphorus_classification(row["P"], y_classes["pH"][-1])) y_classes["K"].append(potassium_classification(row["K"], soil_class)) y_classes["Mg"].append(magnesium_classification(row["Mg"], soil_class)) return pd.DataFrame.from_dict(y_classes)