Spaces:
Running
Running
| """ | |
| src/metrics.py | |
| -------------- | |
| Evaluation metrics for hierarchical probabilistic vote-fraction regression | |
| on Galaxy Zoo 2. | |
| Three evaluation regimes | |
| ------------------------ | |
| 1. GLOBAL β all test samples (dominated by root question t01). | |
| 2. REACHED-BRANCH β samples where branch was actually reached (w >= threshold). | |
| This is the scientifically correct regime for conditional questions. | |
| 3. ECE β Expected Calibration Error using adaptive (equal-frequency) bins. | |
| Fixes applied vs. original | |
| --------------------------- | |
| - ECE uses adaptive binning (equal-frequency bins) instead of equal-width. | |
| Equal-width bins saturate at 0.200 for bimodal questions (t02, t03, t04) | |
| where predictions cluster near 0 and 1. Adaptive bins are unbiased for | |
| any distribution shape. | |
| - simplex_violation_rate() added: fraction of question groups where the | |
| sigmoid baseline predictions do not sum to 1 Β± 0.02. Used to explain | |
| why ResNet-18 + sigmoid achieves lower raw MAE despite predicting | |
| invalid distributions. | |
| """ | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from src.dataset import QUESTION_GROUPS | |
| WEIGHT_THRESHOLDS = [0.05, 0.50, 0.75] | |
# ─────────────────────────────────────────────────────────────
# Main metrics function
# ─────────────────────────────────────────────────────────────
def compute_metrics(
    all_predictions: np.ndarray,  # [N, 37] predicted vote fractions
    all_targets: np.ndarray,      # [N, 37] ground-truth vote fractions
    all_weights: np.ndarray,      # [N, 11] per-question reach weights
) -> dict:
    """
    Full metrics suite: global + reached-branch MAE/RMSE + bias + ECE.

    Parameters
    ----------
    all_predictions : [N, 37] predicted vote fractions.
    all_targets     : [N, 37] ground-truth vote fractions.
    all_weights     : [N, 11] fraction of voters that reached each question,
                      one column per question in QUESTION_GROUPS order.

    Returns
    -------
    dict mapping metric name -> value. Keys follow the patterns
    "mae/<q>", "rmse/<q>", "bias/<q>", "n_reached_w<t>/<q>",
    "mae_w<t>/<q>", "ece/<q>" plus aggregate entries
    ("mae/weighted_avg", "mae_w<t>/conditional_avg", "ece/mean", ...).
    """
    metrics = {}

    # ── 1. Global metrics: every test sample contributes ──────
    mae_values = []
    rmse_values = []
    weight_means = []
    for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
        pred_q = all_predictions[:, start:end]
        target_q = all_targets[:, start:end]
        weight_q = all_weights[:, q_idx]
        diff_q = pred_q - target_q
        mae_q = np.abs(diff_q).mean(axis=1).mean()
        rmse_q = np.sqrt((diff_q ** 2).mean(axis=1).mean())
        metrics[f"mae/{q_name}"] = float(mae_q)
        metrics[f"rmse/{q_name}"] = float(rmse_q)
        # Signed bias: > 0 means systematic over-prediction for this question.
        metrics[f"bias/{q_name}"] = float(diff_q.mean())
        mae_values.append(mae_q)
        rmse_values.append(rmse_q)
        weight_means.append(weight_q.mean())
    weight_means = np.array(weight_means)
    weight_sum = weight_means.sum()
    # Weight each question's error by how often it is reached, so rarely
    # asked branches do not dominate the aggregate.
    metrics["mae/weighted_avg"] = float(
        (weight_means * np.array(mae_values)).sum() / weight_sum
    )
    metrics["rmse/weighted_avg"] = float(
        (weight_means * np.array(rmse_values)).sum() / weight_sum
    )

    # ── 2. Reached-branch metrics: only samples with w >= threshold ──
    for thresh in WEIGHT_THRESHOLDS:
        # 0.05 -> "005", 0.50 -> "050". str(thresh) would render 0.50 as
        # "0.5" and yield the inconsistent key "05"; fixed-width keeps key
        # names uniform across thresholds.
        thresh_key = f"{thresh:.2f}".replace(".", "")
        branch_maes = []
        branch_ws = []
        for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
            pred_q = all_predictions[:, start:end]
            target_q = all_targets[:, start:end]
            weight_q = all_weights[:, q_idx]
            mask = weight_q >= thresh
            n_reached = mask.sum()
            metrics[f"n_reached_w{thresh_key}/{q_name}"] = int(n_reached)
            if n_reached >= 10:  # fewer than 10 samples -> MAE is noise
                mae_q = np.abs(pred_q[mask] - target_q[mask]).mean(axis=1).mean()
                metrics[f"mae_w{thresh_key}/{q_name}"] = float(mae_q)
                branch_maes.append(mae_q)
                branch_ws.append(weight_q[mask].mean())
            else:
                metrics[f"mae_w{thresh_key}/{q_name}"] = float("nan")
        if branch_maes:
            bw = np.array(branch_ws)
            bm = np.array(branch_maes)
            metrics[f"mae_w{thresh_key}/conditional_avg"] = float(
                (bw * bm).sum() / bw.sum()
            )

    # ── 3. ECE per question (adaptive binning; see _compute_ece) ──
    ece_values = []
    for q_name, (start, end) in QUESTION_GROUPS.items():
        ece = _compute_ece(
            all_predictions[:, start:end].flatten(),
            all_targets[:, start:end].flatten(),
        )
        metrics[f"ece/{q_name}"] = float(ece)
        ece_values.append(ece)
    # nanmean: degenerate questions (too few samples) contribute NaN.
    metrics["ece/mean"] = float(np.nanmean(ece_values))
    return metrics
# ─────────────────────────────────────────────────────────────
# ECE — adaptive (equal-frequency) binning
# ─────────────────────────────────────────────────────────────
| def _compute_ece(pred: np.ndarray, target: np.ndarray, | |
| n_bins: int = 15) -> float: | |
| """ | |
| Expected Calibration Error with adaptive (equal-frequency) binning. | |
| Equal-width binning saturates for bimodal distributions (e.g. t02, t03, | |
| t04 where predictions cluster at 0 and 1) because >95% of samples fall | |
| into boundary bins. Adaptive binning places bin edges at percentiles of | |
| the predicted distribution, giving each bin an equal number of samples | |
| and making ECE meaningful regardless of the prediction distribution shape. | |
| Parameters | |
| ---------- | |
| pred : [N] predicted vote fractions | |
| target : [N] true vote fractions | |
| n_bins : number of bins (default 15) | |
| Returns | |
| ------- | |
| ECE : float in [0, 1] | |
| """ | |
| if len(pred) < n_bins: | |
| return float("nan") | |
| # Build equal-frequency bin edges from percentiles of pred | |
| percentiles = np.linspace(0, 100, n_bins + 1) | |
| bin_edges = np.unique(np.percentile(pred, percentiles)) | |
| if len(bin_edges) < 2: | |
| return float("nan") | |
| # Assign samples to bins (digitize returns 1-indexed; clip to [0, n-2]) | |
| bin_ids = np.clip(np.digitize(pred, bin_edges[1:-1]), 0, len(bin_edges) - 2) | |
| ece = 0.0 | |
| n = len(pred) | |
| for b in np.unique(bin_ids): | |
| mask = bin_ids == b | |
| if not mask.any(): | |
| continue | |
| ece += (mask.sum() / n) * abs(pred[mask].mean() - target[mask].mean()) | |
| return float(ece) | |
# ─────────────────────────────────────────────────────────────
# Simplex violation rate
# ─────────────────────────────────────────────────────────────
def simplex_violation_rate(
    predictions: np.ndarray,  # [N, 37]
    tolerance: float = 0.02,
) -> dict:
    """
    Fraction of galaxies, per question, whose predicted answer fractions
    do NOT sum to 1 ± tolerance.

    Used to demonstrate that the sigmoid baseline produces invalid
    probability distributions: a softmax-per-group model has a violation
    rate of ~0.0 by construction, while an unconstrained sigmoid baseline
    does not — which also explains its lower raw per-answer MAE (each
    marginal is fitted independently).

    Parameters
    ----------
    predictions : [N, 37] array of predicted values
    tolerance : acceptable deviation from 1.0 (default 0.02)

    Returns
    -------
    dict mapping question name to violation rate in [0, 1], plus a
    "mean" entry averaging the per-question rates.
    """
    rates = {}
    for q_name, (start, end) in QUESTION_GROUPS.items():
        group_sums = predictions[:, start:end].sum(axis=1)
        off_simplex = np.abs(group_sums - 1.0) > tolerance
        rates[q_name] = float(off_simplex.mean())
    rates["mean"] = float(np.mean(list(rates.values())))
    return rates
# ─────────────────────────────────────────────────────────────
# Reached-branch comparison table (for paper Table 2)
# ─────────────────────────────────────────────────────────────
def compute_reached_branch_mae_table(
    model_results: dict,
) -> "pd.DataFrame":
    """
    Build the reached-branch MAE comparison table across all models.

    Parameters
    ----------
    model_results : dict mapping model_name -> (preds, targets, weights).
        preds/targets are [N, 37]; weights is [N, 11].

    Returns
    -------
    pd.DataFrame with columns:
        model, question, description,
        n_w005, mae_w005, n_w050, mae_w050, n_w075, mae_w075
    """
    import pandas as pd

    QUESTION_DESCRIPTIONS = {
        "t01": "Smooth or features",
        "t02": "Edge-on disk",
        "t03": "Bar",
        "t04": "Spiral arms",
        "t05": "Bulge prominence",
        "t06": "Odd feature",
        "t07": "Roundedness (smooth)",
        "t08": "Odd feature type",
        "t09": "Bulge shape (edge-on)",
        "t10": "Arms winding",
        "t11": "Arms number",
    }

    def _thresh_key(thresh: float) -> str:
        # 0.05 -> "005", 0.50 -> "050". The original str(thresh) dropped
        # the trailing zero of 0.50 ("05"), so per-question rows produced a
        # "mae_w05" column while the weighted-average row hard-codes
        # "mae_w050" — yielding two mismatched, NaN-padded columns.
        return f"{thresh:.2f}".replace(".", "")

    rows = []
    for model_name, (preds, targets, weights) in model_results.items():
        branch_maes = []  # per-question reached-branch MAE at w >= 0.05
        branch_ws = []    # matching mean reach weights (weighted-avg weights)
        for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
            pred_q = preds[:, start:end]
            target_q = targets[:, start:end]
            weight_q = weights[:, q_idx]
            row = {
                "model": model_name,
                "question": q_name,
                "description": QUESTION_DESCRIPTIONS[q_name],
            }
            for thresh in WEIGHT_THRESHOLDS:
                mask = weight_q >= thresh
                n = int(mask.sum())
                key = _thresh_key(thresh)
                row[f"n_w{key}"] = n
                # Require >= 10 reached samples for a meaningful MAE.
                mae = (
                    float(np.abs(pred_q[mask] - target_q[mask]).mean(axis=1).mean())
                    if n >= 10 else float("nan")
                )
                row[f"mae_w{key}"] = mae
                # Accumulate the w >= 0.05 regime for the summary row.
                if thresh == 0.05 and n >= 10:
                    branch_maes.append(mae)
                    branch_ws.append(weight_q[mask].mean())
            rows.append(row)

        # Weighted-average summary row for this model.
        bw = np.array(branch_ws)
        bm = np.array(branch_maes)
        rows.append({
            "model": model_name,
            "question": "weighted_avg",
            "description": "Weighted average (w≥0.05)",
            # Total reached (galaxy, question) pairs at w >= 0.05.
            "n_w005": int((weights >= 0.05).sum()),
            "mae_w005": float((bw * bm).sum() / bw.sum()) if bw.size else float("nan"),
            "mae_w050": float("nan"),
            "mae_w075": float("nan"),
        })
    return pd.DataFrame(rows)
# ─────────────────────────────────────────────────────────────
# Tensor → numpy helpers
# ─────────────────────────────────────────────────────────────
def predictions_to_numpy(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    weights: torch.Tensor,
) -> tuple:
    """
    Normalize raw outputs with a softmax over each question group and
    return (predictions, targets, weights) as numpy arrays.
    """
    # Work on a detached CPU copy so the caller's tensor is untouched.
    normalized = predictions.detach().cpu().clone()
    for start, end in QUESTION_GROUPS.values():
        normalized[:, start:end] = F.softmax(normalized[:, start:end], dim=-1)
    return (
        normalized.numpy(),
        targets.detach().cpu().numpy(),
        weights.detach().cpu().numpy(),
    )
def dirichlet_predictions_to_numpy(
    alpha: torch.Tensor,
    targets: torch.Tensor,
    weights: torch.Tensor,
) -> tuple:
    """
    Convert Dirichlet concentration parameters to mean predictions
    (alpha_i / sum(alpha) within each question group) and return
    (means, targets, weights) as numpy arrays.
    """
    means = torch.zeros_like(alpha)
    for start, end in QUESTION_GROUPS.values():
        group = alpha[:, start:end]
        means[:, start:end] = group / group.sum(dim=-1, keepdim=True)
    return (
        means.detach().cpu().numpy(),
        targets.detach().cpu().numpy(),
        weights.detach().cpu().numpy(),
    )