TerraMind-HYPERVIEW / hyperview_subimssion.py
KPLabs's picture
Upload folder using huggingface_hub
87904b0 verified
"""Utilities for loading Hyperview data, postprocessing and evaluation."""
from glob import glob
import os
from typing import Dict
import numpy as np
import pandas as pd
from sklearn.metrics import (
accuracy_score,
balanced_accuracy_score,
cohen_kappa_score,
f1_score,
matthews_corrcoef,
)
# Names of the target columns iterated by the metric and class-conversion
# helpers in this module.
CLASSES = ["P", "K", "Mg", "pH"]

# Classification metrics evaluated per target column:
# short metric name -> (scoring callable, extra keyword arguments for it).
class_metrics = {
    "avg_acc": (balanced_accuracy_score, {}),
    "acc": (accuracy_score, {}),
    "mcc": (matthews_corrcoef, {}),
    "f1": (f1_score, {"average": "macro"}),
}
# Human-readable pH class labels, indexed by the class id derived from
# `ph_thresholds`: id 0 is the lowest-pH bucket and id 4 the highest, so the
# labels must run from most acidic to alkaline.
ph_classes_names = ["strongly acidic", "acidic", "slightly acidic", "neutral", "alkaline"]

# Labels for the five nutrient-content classes (ids 0..4).
classes_names = ["very low", "low", "medium", "high", "very high"]

# Ascending pH cut points; four cut points delimit five classes.
ph_thresholds = [4.6, 5.6, 6.6, 7.3]

# Phosphorus cut points; the row is selected by the sample's pH class id
# (see `phosphorus_classification`), hence one row per pH class.
phosphorus_thresholds = [
    [50, 110, 186, 262],
    [49, 103, 158, 215],
    [47, 99, 152, 207],
    [27, 54, 75, 99],
    [27, 54, 75, 99],
]

# Potassium cut points; the row is selected by the soil class id
# (see `potassium_classification`).
potassium_thresholds = [
    [32, 75, 119, 162],
    [52, 99, 145, 191],
    [98, 139, 200, 241],
    [126, 174, 270, 318],
]

# Magnesium cut points; the row is selected by the soil class id
# (see `magnesium_classification`).
magnesium_thresholds = [
    [7, 21, 51, 80],
    [31, 43, 67, 93],
    [48, 77, 106, 135],
    [69, 93, 142, 191],
]
def element_classification(result: float, thresholds: list[float]) -> int:
    """Return the id of the bucket that `result` falls into.

    Thresholds are scanned in the order given and the scan stops at the first
    one that `result` does not strictly exceed. The returned id is therefore
    the number of leading thresholds strictly below `result` (0 when the very
    first threshold is not exceeded or when `thresholds` is empty).
    """
    exceeded = 0
    for threshold in thresholds:
        # A value equal to the cut point stays in the lower bucket.
        if not result > threshold:
            return exceeded
        exceeded += 1
    return exceeded
class BaselineRegressor:
    """Trivial regressor that always answers with the per-target training mean."""

    def __init__(self) -> None:
        """Start unfitted: a scalar zero mean and zero output targets."""
        self.mean = 0
        self.classes_count = 0

    def fit(self, X_train: np.ndarray, y_train: np.ndarray) -> "BaselineRegressor":
        """Memorise the column-wise mean of `y_train` and its number of columns.

        `X_train` is accepted for API symmetry but is not used. Returns `self`
        so calls can be chained.
        """
        del X_train  # features play no role in this baseline
        self.mean = np.mean(y_train, axis=0)
        self.classes_count = y_train.shape[1]
        return self

    def predict(self, X_test: np.ndarray) -> np.ndarray:
        """Return one copy of the stored mean vector per row of `X_test`."""
        output_shape = (len(X_test), self.classes_count)
        return np.full(output_shape, self.mean)
class SpectralCurveFiltering:
    """Collapse a 3D cube to a 1D curve by reducing over axes 1 and 2."""

    def __init__(self, merge_function=np.mean) -> None:
        """Remember the reduction callable (must accept an `axis` keyword)."""
        self.merge_function = merge_function

    def __call__(self, sample: np.ndarray) -> np.ndarray:
        """Apply the reduction over axes 1 and 2, keeping axis 0 intact."""
        spatial_axes = (1, 2)
        reduce_fn = self.merge_function
        return reduce_fn(sample, axis=spatial_axes)
def load_data(directory: str, split: str | None = None, mask: str = "none") -> np.ndarray:
    """Read every `.npz` cube in `directory` and reduce each to a spectral curve.

    Files are processed in ascending order of their numeric file name, so the
    name without `.npz` must parse as an integer.

    Args:
        directory: Directory containing the `.npz` files.
        split: Split hint. `"test_enmap"` gives the `enmap` key top priority
            in dense mode. When omitted it is inferred from whether
            `directory` contains the text `test_enmap`.
        mask: `"none"` loads plain arrays. Any other value prefers a masked
            array built from `data` (or `dat`) together with `mask`, falling
            back to a plain `enmap`, `data` or `dat` array.

    Returns:
        Array built from one reduced curve per file, in file order.

    Raises:
        ValueError: If a file holds none of the supported keys, or a file
            name is not an integer.
    """
    if split is None:
        split = "test_enmap" if "test_enmap" in directory else "test"

    def numeric_stem(path: str) -> int:
        return int(os.path.basename(path).replace(".npz", ""))

    def first_present(candidates, present):
        return next((name for name in candidates if name in present), None)

    # Key priority for plain (unmasked) loading.
    dense_order = ("data", "dat", "enmap")
    if split == "test_enmap":
        dense_order = ("enmap",) + dense_order
    # Key priority in masked mode when no usable data/mask pair exists.
    # NOTE(review): `split` has no influence in masked mode - confirm intended.
    fallback_order = ("enmap", "data", "dat")

    to_curve = SpectralCurveFiltering()
    curves = []
    for file_name in sorted(glob(os.path.join(directory, "*.npz")), key=numeric_stem):
        with np.load(file_name) as npz:
            keys = set(npz.files)

            masked_source = None
            if mask != "none" and "mask" in keys:
                masked_source = first_present(("data", "dat"), keys)

            if masked_source is not None:
                arr = np.ma.MaskedArray(data=npz[masked_source], mask=npz["mask"])
            else:
                order = dense_order if mask == "none" else fallback_order
                chosen = first_present(order, keys)
                if chosen is None:
                    raise ValueError(
                        f"Unsupported .npz format in {file_name}. Found keys: {sorted(keys)}"
                    )
                arr = npz[chosen]
            curves.append(to_curve(arr))
    return np.array(curves)
def load_gt(file_path: str) -> np.ndarray:
    """Read the CSV at `file_path` and return its P, K, Mg and pH columns.

    The columns are returned as a 2D array in exactly that order; any other
    columns present in the file are ignored.
    """
    target_columns = ["P", "K", "Mg", "pH"]
    table = pd.read_csv(file_path)
    return table.loc[:, target_columns].values
def load_hyperview_data():
    """Load the default Hyperview train and test sets.

    Reads cubes and label files from fixed paths under `hyperview_data/`
    (relative to the current working directory) and returns the 4-tuple
    `(X_train, y_train, X_test, y_test)`.
    """
    train_pair = (
        load_data("hyperview_data/train_data"),
        load_gt("hyperview_data/train_gt.csv"),
    )
    test_pair = (
        load_data("hyperview_data/test_data"),
        load_gt("hyperview_data/test_gt.csv"),
    )
    return (*train_pair, *test_pair)
def calculate_metrics(
    y_pred: pd.DataFrame,
    y_true: pd.DataFrame,
    soil_class: int = 3,
) -> Dict[str, float | list[float]]:
    """Calculate per-target classification metrics and their aggregates.

    For every metric in `class_metrics` and every target in `CLASSES` the
    score is stored under `"<target>_<metric>"`. Cohen's kappa is stored under
    `"<target>_kappa"`, and the mean / standard deviation of each metric
    across targets under `"mean_<metric>"` / `"std_<metric>"`.

    Args:
        y_pred: Predicted class ids, one column per target in `CLASSES`.
        y_true: Reference class ids with the same columns.
        soil_class: Unused; kept so existing callers keep working.

    Returns:
        Mapping from result name to a single-element list holding the value.
    """
    _ = soil_class

    # Kappa does not depend on the metric being iterated, so it is computed
    # once per target here rather than once per (metric, target) pair.
    kappa_by_class = {
        class_name: cohen_kappa_score(y_pred[class_name], y_true[class_name])
        for class_name in CLASSES
    }

    out: Dict[str, float | list[float]] = {}
    for metric_name, (func, kwargs) in class_metrics.items():
        metric_scores = []
        for class_name in CLASSES:
            value = func(y_pred=y_pred[class_name], y_true=y_true[class_name], **kwargs)
            out[f"{class_name}_{metric_name}"] = [value]
            metric_scores.append(value)
            # Assigned inside the loop so the key order of the result stays
            # the same as before the kappa computation was hoisted.
            out[f"{class_name}_kappa"] = [kappa_by_class[class_name]]
        out[f"mean_{metric_name}"] = [np.mean(metric_scores)]
        out[f"std_{metric_name}"] = [np.std(metric_scores)]
    return out
def ph_classification(result: float) -> int:
    """Return the pH class id of `result` according to `ph_thresholds`."""
    ph_class = element_classification(result, ph_thresholds)
    return ph_class
def phosphorus_classification(result: float, ph_class: int) -> int:
    """Return the phosphorus class id of `result`.

    The threshold row is picked from `phosphorus_thresholds` by the sample's
    pH class id, which is converted to `int` before indexing.
    """
    row_index = int(ph_class)
    thresholds = phosphorus_thresholds[row_index]
    return element_classification(result, thresholds)
def potassium_classification(result: float, soil_class: int) -> int:
    """Return the potassium class id of `result`.

    The threshold row is picked from `potassium_thresholds` by the soil class
    id, which is converted to `int` before indexing.
    """
    row_index = int(soil_class)
    thresholds = potassium_thresholds[row_index]
    return element_classification(result, thresholds)
def magnesium_classification(result: float, soil_class: int) -> int:
    """Return the magnesium class id of `result`.

    The threshold row is picked from `magnesium_thresholds` by the soil class
    id, which is converted to `int` before indexing.
    """
    row_index = int(soil_class)
    thresholds = magnesium_thresholds[row_index]
    return element_classification(result, thresholds)
def get_classes(y: pd.DataFrame, soil_class: int = 3) -> pd.DataFrame:
    """Turn continuous P, K, Mg and pH values into discrete class ids.

    Each row of `y` is classified independently. The pH class is derived
    first because it selects the phosphorus thresholds; potassium and
    magnesium use the thresholds of the given `soil_class`.

    Returns:
        DataFrame with one column per entry of `CLASSES` and one row per row
        of `y`, on a fresh default index.
    """
    columns = {name: [] for name in CLASSES}
    for _, sample in y.iterrows():
        ph_class = ph_classification(sample["pH"])
        columns["pH"].append(ph_class)
        columns["P"].append(phosphorus_classification(sample["P"], ph_class))
        columns["K"].append(potassium_classification(sample["K"], soil_class))
        columns["Mg"].append(magnesium_classification(sample["Mg"], soil_class))
    return pd.DataFrame(columns)