# eshwar-gz2-api / src/metrics.py
# Uploaded via huggingface_hub by sreshwarprasad (commit e36eee4, verified).
"""
src/metrics.py
--------------
Evaluation metrics for hierarchical probabilistic vote-fraction regression
on Galaxy Zoo 2.
Three evaluation regimes
------------------------
1. GLOBAL β€” all test samples (dominated by root question t01).
2. REACHED-BRANCH β€” samples where branch was actually reached (w >= threshold).
This is the scientifically correct regime for conditional questions.
3. ECE β€” Expected Calibration Error using adaptive (equal-frequency) bins.
Fixes applied vs. original
---------------------------
- ECE uses adaptive binning (equal-frequency bins) instead of equal-width.
Equal-width bins saturate at 0.200 for bimodal questions (t02, t03, t04)
where predictions cluster near 0 and 1. Adaptive bins are unbiased for
any distribution shape.
- simplex_violation_rate() added: fraction of question groups where the
sigmoid baseline predictions do not sum to 1 Β± 0.02. Used to explain
why ResNet-18 + sigmoid achieves lower raw MAE despite predicting
invalid distributions.
"""
import numpy as np
import torch
import torch.nn.functional as F
from src.dataset import QUESTION_GROUPS
# Branch-weight cut-offs defining the "reached-branch" evaluation regimes
# (a question counts as reached for a galaxy when w >= threshold).
WEIGHT_THRESHOLDS = [0.05, 0.50, 0.75]
# ─────────────────────────────────────────────────────────────
# Main metrics function
# ─────────────────────────────────────────────────────────────
def compute_metrics(
    all_predictions: np.ndarray,  # [N, 37]
    all_targets: np.ndarray,      # [N, 37]
    all_weights: np.ndarray,      # [N, 11]
) -> dict:
    """
    Full metrics suite: global + reached-branch MAE/RMSE + bias + ECE.

    Parameters
    ----------
    all_predictions : [N, 37] predicted vote fractions (one column per answer)
    all_targets     : [N, 37] true vote fractions
    all_weights     : [N, 11] per-question branch weights

    Returns
    -------
    dict mapping metric name to value, with keys
        mae/<q>, rmse/<q>, bias/<q>, mae/weighted_avg, rmse/weighted_avg,
        n_reached_w<t>/<q>, mae_w<t>/<q>, mae_w<t>/conditional_avg,
        ece/<q>, ece/mean
    where <t> is str(threshold) with the dot removed ("005", "05", "075").

    Fixes vs. original: dropped the unused ``q_names`` local and computes
    the prediction-target difference once per question instead of
    re-slicing both arrays a second time for the bias metric.
    """
    metrics = {}
    # ── 1. Global metrics ──────────────────────────────────────
    mae_values = []
    rmse_values = []
    weight_means = []
    for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
        # Signed per-answer error, computed once and reused for MAE/RMSE/bias.
        diff = all_predictions[:, start:end] - all_targets[:, start:end]
        mae_q = np.abs(diff).mean(axis=1).mean()
        rmse_q = np.sqrt((diff ** 2).mean(axis=1).mean())
        metrics[f"mae/{q_name}"] = float(mae_q)
        metrics[f"rmse/{q_name}"] = float(rmse_q)
        # Bias > 0 means systematic over-prediction for this question.
        metrics[f"bias/{q_name}"] = float(diff.mean())
        mae_values.append(mae_q)
        rmse_values.append(rmse_q)
        weight_means.append(all_weights[:, q_idx].mean())
    weight_means = np.array(weight_means)
    weight_sum = weight_means.sum()
    metrics["mae/weighted_avg"] = float(
        (weight_means * np.array(mae_values)).sum() / weight_sum
    )
    metrics["rmse/weighted_avg"] = float(
        (weight_means * np.array(rmse_values)).sum() / weight_sum
    )
    # ── 2. Reached-branch metrics ──────────────────────────────
    for thresh in WEIGHT_THRESHOLDS:
        thresh_key = str(thresh).replace(".", "")  # 0.05 -> "005", 0.5 -> "05"
        branch_maes = []
        branch_ws = []
        for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
            pred_q = all_predictions[:, start:end]
            target_q = all_targets[:, start:end]
            weight_q = all_weights[:, q_idx]
            mask = weight_q >= thresh
            n_reached = mask.sum()
            metrics[f"n_reached_w{thresh_key}/{q_name}"] = int(n_reached)
            if n_reached >= 10:
                # Require >= 10 reached samples for a stable MAE estimate.
                mae_q = np.abs(pred_q[mask] - target_q[mask]).mean(axis=1).mean()
                metrics[f"mae_w{thresh_key}/{q_name}"] = float(mae_q)
                branch_maes.append(mae_q)
                branch_ws.append(weight_q[mask].mean())
            else:
                metrics[f"mae_w{thresh_key}/{q_name}"] = float("nan")
        if branch_maes:
            bw = np.array(branch_ws)
            bm = np.array(branch_maes)
            metrics[f"mae_w{thresh_key}/conditional_avg"] = float(
                (bw * bm).sum() / bw.sum()
            )
    # ── 3. ECE per question (adaptive binning) ─────────────────
    ece_values = []
    for q_name, (start, end) in QUESTION_GROUPS.items():
        ece = _compute_ece(
            all_predictions[:, start:end].flatten(),
            all_targets[:, start:end].flatten(),
        )
        metrics[f"ece/{q_name}"] = float(ece)
        ece_values.append(ece)
    # nanmean: questions with undefined ECE (too few samples) are skipped.
    metrics["ece/mean"] = float(np.nanmean(ece_values))
    return metrics
# ─────────────────────────────────────────────────────────────
# ECE β€” adaptive (equal-frequency) binning
# ─────────────────────────────────────────────────────────────
def _compute_ece(pred: np.ndarray, target: np.ndarray,
n_bins: int = 15) -> float:
"""
Expected Calibration Error with adaptive (equal-frequency) binning.
Equal-width binning saturates for bimodal distributions (e.g. t02, t03,
t04 where predictions cluster at 0 and 1) because >95% of samples fall
into boundary bins. Adaptive binning places bin edges at percentiles of
the predicted distribution, giving each bin an equal number of samples
and making ECE meaningful regardless of the prediction distribution shape.
Parameters
----------
pred : [N] predicted vote fractions
target : [N] true vote fractions
n_bins : number of bins (default 15)
Returns
-------
ECE : float in [0, 1]
"""
if len(pred) < n_bins:
return float("nan")
# Build equal-frequency bin edges from percentiles of pred
percentiles = np.linspace(0, 100, n_bins + 1)
bin_edges = np.unique(np.percentile(pred, percentiles))
if len(bin_edges) < 2:
return float("nan")
# Assign samples to bins (digitize returns 1-indexed; clip to [0, n-2])
bin_ids = np.clip(np.digitize(pred, bin_edges[1:-1]), 0, len(bin_edges) - 2)
ece = 0.0
n = len(pred)
for b in np.unique(bin_ids):
mask = bin_ids == b
if not mask.any():
continue
ece += (mask.sum() / n) * abs(pred[mask].mean() - target[mask].mean())
return float(ece)
# ─────────────────────────────────────────────────────────────
# Simplex violation rate
# ─────────────────────────────────────────────────────────────
def simplex_violation_rate(
    predictions: np.ndarray,  # [N, 37]
    tolerance: float = 0.02,
) -> dict:
    """
    Per-question rate of galaxies whose predicted answers do not sum to
    1 +/- ``tolerance``.

    A model with a softmax per question group satisfies the simplex
    constraint by construction (rate ~ 0). A sigmoid baseline does not,
    which explains how its raw per-answer MAE can be lower even though its
    outputs are not valid probability distributions.

    Parameters
    ----------
    predictions : [N, 37] array of predicted values
    tolerance   : acceptable deviation from 1.0 (default 0.02)

    Returns
    -------
    dict mapping question name (plus "mean") to a violation rate in [0, 1]
    """
    violation = {}
    for question, (lo, hi) in QUESTION_GROUPS.items():
        group_totals = predictions[:, lo:hi].sum(axis=1)
        off_simplex = np.abs(group_totals - 1.0) > tolerance
        violation[question] = float(off_simplex.mean())
    # Unweighted mean over the per-question rates (added after the loop, so
    # it does not feed back into itself).
    violation["mean"] = float(np.mean(list(violation.values())))
    return violation
# ─────────────────────────────────────────────────────────────
# Reached-branch comparison table (for paper Table 2)
# ─────────────────────────────────────────────────────────────
def compute_reached_branch_mae_table(
    model_results: dict,
) -> "pd.DataFrame":
    """
    Build the reached-branch MAE comparison table across all models.

    Parameters
    ----------
    model_results : dict mapping model_name -> (preds, targets, weights)
        preds/targets are [N, 37]; weights is [N, 11].

    Returns
    -------
    pd.DataFrame with columns:
        model, question, description,
        n_w005, mae_w005, n_w050, mae_w050, n_w075, mae_w075
    One weighted-average row per model is appended after its questions.

    Fix vs. original: threshold keys are zero-padded ("050", not "05"), so
    the per-question columns line up with the weighted-average row and the
    documented column names. Previously ``str(0.50).replace(".", "")``
    produced "mae_w05" per question while the average row wrote "mae_w050",
    yielding two half-empty DataFrame columns. The dead
    ``hasattr(weights, "__len__")`` guard was also removed.
    """
    import pandas as pd

    def _wkey(thresh: float) -> str:
        # 0.05 -> "005", 0.50 -> "050", 0.75 -> "075"
        return f"{int(round(thresh * 100)):03d}"

    QUESTION_DESCRIPTIONS = {
        "t01": "Smooth or features",
        "t02": "Edge-on disk",
        "t03": "Bar",
        "t04": "Spiral arms",
        "t05": "Bulge prominence",
        "t06": "Odd feature",
        "t07": "Roundedness (smooth)",
        "t08": "Odd feature type",
        "t09": "Bulge shape (edge-on)",
        "t10": "Arms winding",
        "t11": "Arms number",
    }
    rows = []
    for model_name, (preds, targets, weights) in model_results.items():
        for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
            pred_q = preds[:, start:end]
            target_q = targets[:, start:end]
            weight_q = weights[:, q_idx]
            row = {
                "model": model_name,
                "question": q_name,
                "description": QUESTION_DESCRIPTIONS[q_name],
            }
            for thresh in WEIGHT_THRESHOLDS:
                mask = weight_q >= thresh
                n = int(mask.sum())
                row[f"n_w{_wkey(thresh)}"] = n
                # Require >= 10 reached samples for a stable MAE estimate.
                row[f"mae_w{_wkey(thresh)}"] = (
                    float(np.abs(pred_q[mask] - target_q[mask]).mean(axis=1).mean())
                    if n >= 10 else float("nan")
                )
            rows.append(row)
        # Weighted-average row for this model (w >= 0.05 regime only).
        branch_maes = []
        branch_ws = []
        for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
            weight_q = weights[:, q_idx]
            pred_q = preds[:, start:end]
            target_q = targets[:, start:end]
            mask = weight_q >= 0.05
            if mask.sum() >= 10:
                branch_maes.append(
                    np.abs(pred_q[mask] - target_q[mask]).mean(axis=1).mean()
                )
                branch_ws.append(weight_q[mask].mean())
        bw = np.array(branch_ws)
        bm = np.array(branch_maes)
        rows.append({
            "model": model_name,
            "question": "weighted_avg",
            "description": "Weighted average (w≥0.05)",
            # Total count of (galaxy, question) pairs that reached a branch.
            "n_w005": int((weights >= 0.05).sum()),
            "mae_w005": float((bw * bm).sum() / bw.sum()) if bw.size else float("nan"),
            "mae_w050": float("nan"),
            "mae_w075": float("nan"),
        })
    return pd.DataFrame(rows)
# ─────────────────────────────────────────────────────────────
# Tensor β†’ numpy helpers
# ─────────────────────────────────────────────────────────────
def predictions_to_numpy(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    weights: torch.Tensor,
) -> tuple:
    """
    Normalize logits with a softmax over each question group and convert
    everything to numpy.

    Returns
    -------
    (predictions, targets, weights) as numpy arrays, where each question
    group of the predictions has been passed through a softmax.
    """
    # Clone so the caller's tensor is untouched by the in-place slice writes.
    probs = predictions.detach().cpu().clone()
    for _, (lo, hi) in QUESTION_GROUPS.items():
        probs[:, lo:hi] = F.softmax(probs[:, lo:hi], dim=-1)
    return (
        probs.numpy(),
        targets.detach().cpu().numpy(),
        weights.detach().cpu().numpy(),
    )
def dirichlet_predictions_to_numpy(
    alpha: torch.Tensor,
    targets: torch.Tensor,
    weights: torch.Tensor,
) -> tuple:
    """
    Convert Dirichlet concentration parameters to mean predictions.

    For each question group, the Dirichlet mean is alpha_i / sum(alpha)
    taken over that group's answers only.

    Returns
    -------
    (means, targets, weights) as numpy arrays.
    """
    mean_pred = torch.zeros_like(alpha)
    for _, (lo, hi) in QUESTION_GROUPS.items():
        group = alpha[:, lo:hi]
        mean_pred[:, lo:hi] = group / group.sum(dim=-1, keepdim=True)
    return (
        mean_pred.detach().cpu().numpy(),
        targets.detach().cpu().numpy(),
        weights.detach().cpu().numpy(),
    )