| """ | |
| src/uncertainty_analysis.py | |
| ---------------------------- | |
| MC Dropout epistemic uncertainty analysis for the proposed model. | |
| MC Dropout (Gal & Ghahramani 2016) is used as a post-hoc uncertainty | |
| estimator. At inference time, dropout is kept active and N=30 stochastic | |
| forward passes are run per batch. The standard deviation across passes | |
| is used as the epistemic uncertainty estimate per galaxy per question. | |
| Key findings reported | |
| --------------------- | |
| 1. Uncertainty distributions: right-skewed, well-separated means across | |
| questions reflecting the conditional nature of the decision tree. | |
| 2. Uncertainty vs. error correlation: Spearman Ο reported per question. | |
| Strong positive correlation for root and shallow-branch questions | |
| (t01, t02, t04, t07) indicates the model is well-calibrated in | |
| uncertainty. Weak or near-zero correlation for deep conditional | |
| branches (t03, t05, t08, t09, t10, t11) is expected β these branches | |
| have small effective sample sizes and aleatoric uncertainty dominates. | |
| 3. Morphology selection benchmark: F1 score at threshold Ο for downstream | |
| binary morphology classification tasks. | |
| Output files | |
| ------------ | |
| outputs/figures/uncertainty/ | |
| fig_uncertainty_distributions.pdf | |
| fig_uncertainty_vs_error.pdf | |
| fig_morphology_f1_comparison.pdf | |
| table_uncertainty_summary.csv | |
| table_morphology_selection_benchmark.csv | |
| mc_cache/ β cached numpy arrays (crash-safe) | |
| Usage | |
| ----- | |
| cd ~/galaxy | |
| nohup python -m src.uncertainty_analysis \ | |
| --config configs/full_train.yaml --n_passes 30 \ | |
| > outputs/logs/uncertainty.log 2>&1 & | |
| echo "PID: $!" | |
| """ | |
| import argparse | |
| import logging | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torch.nn.functional as F | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from scipy import stats as scipy_stats | |
| from torch.amp import autocast | |
| from omegaconf import OmegaConf | |
| from tqdm import tqdm | |
| from src.dataset import build_dataloaders, QUESTION_GROUPS | |
| from src.model import build_model, build_dirichlet_model | |
| from src.baselines import ResNet18Baseline | |
| from src.metrics import predictions_to_numpy, dirichlet_predictions_to_numpy | |
# Log to stdout so `nohup ... > uncertainty.log` captures everything (see usage).
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
)
log = logging.getLogger("uncertainty")

# Publication-quality defaults. fonttype 42 embeds TrueType fonts in
# PDF/PS output (required by many journals). Applied after the Agg
# backend is selected at import time above.
plt.rcParams.update({
    "figure.dpi": 150, "savefig.dpi": 300,
    "font.family": "serif", "font.size": 11,
    "axes.titlesize": 10, "axes.labelsize": 10,
    "xtick.labelsize": 8, "ytick.labelsize": 8,
    "legend.fontsize": 8,
    "figure.facecolor": "white", "axes.facecolor": "white",
    "axes.grid": True, "grid.alpha": 0.3,
    "pdf.fonttype": 42, "ps.fonttype": 42,
})

# Human-readable label for each Galaxy Zoo decision-tree question.
QUESTION_LABELS = {
    "t01": "Smooth or features", "t02": "Edge-on disk",
    "t03": "Bar", "t04": "Spiral arms",
    "t05": "Bulge prominence", "t06": "Odd feature",
    "t07": "Roundedness", "t08": "Odd feature type",
    "t09": "Bulge shape", "t10": "Arms winding",
    "t11": "Arms number",
}

# Fixed color per model so bars are comparable across figures; keys must
# match the model_results keys assembled in main().
MODEL_COLORS = {
    "ViT-Base + KL+MSE (proposed)" : "#27ae60",
    "ViT-Base + Dirichlet (Zoobot-style)": "#8e44ad",
    "ResNet-18 + MSE (sigmoid)" : "#c0392b",
    "ResNet-18 + KL+MSE" : "#e67e22",
}

# Probability thresholds swept in the downstream selection benchmark.
SELECTION_THRESHOLDS = [0.5, 0.7, 0.8, 0.9]
# Question -> (answer index within the question's slice, display label)
# for the binary morphology-selection tasks.
SELECTION_ANSWERS = {
    "t01": (0, "smooth"),
    "t02": (0, "edge-on"),
    "t03": (0, "bar"),
    "t04": (0, "spiral"),
    "t06": (0, "odd feature"),
}
# ─────────────────────────────────────────────────────────────
# MC Dropout inference — Welford online algorithm, crash-safe
# ─────────────────────────────────────────────────────────────
def run_mc_inference(model, loader, device, cfg,
                     n_passes=30, cache_dir=None):
    """
    Fast batched MC Dropout inference.

    Dropout stays active in eval mode (via ``enable_mc_dropout``) and
    ``n_passes`` stochastic forward passes are run per batch. Welford's
    online algorithm accumulates the running mean and M2 (sum of squared
    deviations) so memory stays O(N × 37) regardless of ``n_passes``.

    Parameters
    ----------
    model : module exposing ``enable_mc_dropout()`` / ``disable_mc_dropout()``
    loader : DataLoader yielding (images, targets, weights, _) batches
    device : inference device
    cfg : OmegaConf config (reads ``training.mixed_precision``)
    n_passes : number of stochastic forward passes (default 30)
    cache_dir : if given, saves .npy files and skips inference when all
        four cache files already exist (crash-safe resume)

    Returns
    -------
    mean_all, std_all : [N, 37] float32 MC mean / std per answer
    targets_all : [N, 37] float32
    weights_all : [N, 11] float32

    Raises
    ------
    ValueError
        If ``n_passes`` < 1 (the accumulators would never initialize).
    """
    if n_passes < 1:
        raise ValueError(f"n_passes must be >= 1, got {n_passes}")
    if cache_dir is not None:
        cache_dir = Path(cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)
        fp_mean = cache_dir / "mc_mean.npy"
        fp_std = cache_dir / "mc_std.npy"
        fp_targets = cache_dir / "mc_targets.npy"
        fp_weights = cache_dir / "mc_weights.npy"
        if all(p.exists() for p in [fp_mean, fp_std, fp_targets, fp_weights]):
            log.info("MC cache found — loading from disk (skipping inference).")
            return (np.load(fp_mean), np.load(fp_std),
                    np.load(fp_targets), np.load(fp_weights))
    model.eval()
    model.enable_mc_dropout()  # keep dropout stochastic despite eval()
    all_means, all_stds, all_targets, all_weights = [], [], [], []
    log.info("MC Dropout: %d passes × %d-image batches = %d total forward passes",
             n_passes, loader.batch_size, n_passes * len(loader))
    for images, targets, weights, _ in tqdm(loader, desc="MC Dropout"):
        images_dev = images.to(device, non_blocking=True)
        # Welford online mean and M2 — numerically stable single pass.
        mean_acc = None
        M2_acc = None
        count = 0
        for _ in range(n_passes):
            with torch.no_grad():
                with autocast("cuda", enabled=cfg.training.mixed_precision):
                    logits = model(images_dev)
                # Per-question softmax over each answer group so outputs
                # are vote-fraction distributions.
                pred = torch.zeros_like(logits)
                for q, (s, e) in QUESTION_GROUPS.items():
                    pred[:, s:e] = F.softmax(logits[:, s:e], dim=-1)
            pred_np = pred.cpu().float().numpy()  # [B, 37]
            count += 1
            if mean_acc is None:
                # First sample: mean = x, M2 = 0.
                mean_acc = pred_np.copy()
                M2_acc = np.zeros_like(pred_np)
            else:
                delta = pred_np - mean_acc
                mean_acc += delta / count
                M2_acc += delta * (pred_np - mean_acc)
        # Sample std with Bessel's correction; zero when only one pass ran.
        std_acc = np.sqrt(M2_acc / (count - 1) if count > 1
                          else np.zeros_like(M2_acc))
        all_means.append(mean_acc)
        all_stds.append(std_acc)
        all_targets.append(targets.numpy())
        all_weights.append(weights.numpy())
    model.disable_mc_dropout()
    mean_all = np.concatenate(all_means)
    std_all = np.concatenate(all_stds)
    targets_all = np.concatenate(all_targets)
    weights_all = np.concatenate(all_weights)
    if cache_dir is not None:
        np.save(fp_mean, mean_all)
        np.save(fp_std, std_all)
        np.save(fp_targets, targets_all)
        np.save(fp_weights, weights_all)
        log.info("MC results cached: %s", cache_dir)
    return mean_all, std_all, targets_all, weights_all
# ─────────────────────────────────────────────────────────────
# Figure 1: Uncertainty distributions
# ─────────────────────────────────────────────────────────────
def fig_uncertainty_distributions(mean_preds, std_preds,
                                  targets, weights, save_dir,
                                  n_passes=30):
    """
    Figure 1: per-question histograms of MC Dropout epistemic uncertainty.

    One panel per morphological question; a galaxy contributes only when
    the question was actually reached (vote weight >= 0.05). Skips work
    if both output files already exist (crash-safe re-runs).

    Parameters
    ----------
    mean_preds, std_preds : [N, 37] MC mean / std per answer
    targets : [N, 37] vote fractions (unused; kept for a uniform
        plotting-function signature)
    weights : [N, 11] per-question vote weights
    save_dir : output directory (pathlib.Path)
    n_passes : number of MC passes, shown in the title (default 30;
        previously hard-coded in the title text)
    """
    path_pdf = save_dir / "fig_uncertainty_distributions.pdf"
    path_png = save_dir / "fig_uncertainty_distributions.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_uncertainty_distributions"); return
    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    axes = axes.flatten()
    for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
        ax = axes[q_idx]
        mask = weights[:, q_idx] >= 0.05
        # Mean std across the question's answer columns -> one scalar per galaxy.
        std_q = std_preds[mask, start:end].mean(axis=1)
        ax.hist(std_q, bins=50, color="#6366f1", alpha=0.85,
                edgecolor="none", density=True)
        ax.axvline(std_q.mean(), color="#c0392b", linewidth=1.8,
                   linestyle="--", label=f"Mean = {std_q.mean():.4f}")
        ax.set_xlabel("MC Dropout std (epistemic uncertainty)")
        ax.set_ylabel("Density")
        ax.set_title(
            f"{q_name}: {QUESTION_LABELS[q_name]}\n"
            f"$n$ = {mask.sum():,} (w ≥ 0.05)",
            fontsize=9,
        )
        ax.legend(fontsize=7)
    axes[-1].axis("off")  # 11 questions in a 3x4 grid: blank the 12th panel
    plt.suptitle(
        f"Epistemic uncertainty distributions — MC Dropout ({n_passes} passes)\n"
        "Proposed model (ViT-Base/16 + hierarchical KL+MSE), test set",
        fontsize=12,
    )
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_uncertainty_distributions")
# ─────────────────────────────────────────────────────────────
# Figure 2: Uncertainty vs. error (Spearman ρ)
# ─────────────────────────────────────────────────────────────
def fig_uncertainty_vs_error(mean_preds, std_preds,
                             targets, weights, save_dir):
    """
    Figure 2: epistemic uncertainty vs. absolute prediction error, per question.

    Each panel scatters per-galaxy MC std against per-galaxy mean absolute
    error, overlays percentile-binned means as a trend line, and annotates
    the Spearman rank correlation. Skips work if both output files exist.

    Parameters
    ----------
    mean_preds, std_preds : [N, 37] MC mean / std per answer
    targets : [N, 37] volunteer vote fractions
    weights : [N, 11] per-question vote weights
    save_dir : output directory (pathlib.Path)
    """
    path_pdf = save_dir / "fig_uncertainty_vs_error.pdf"
    path_png = save_dir / "fig_uncertainty_vs_error.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_uncertainty_vs_error"); return
    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    axes = axes.flatten()
    for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
        ax = axes[q_idx]
        mask = weights[:, q_idx] >= 0.05
        unc = std_preds[mask, start:end].mean(axis=1)
        err = np.abs(mean_preds[mask, start:end] -
                     targets[mask, start:end]).mean(axis=1)
        # Percentile-based edges adapt bins to the right-skewed uncertainty
        # distribution; np.unique collapses duplicate edges.
        n_bins = 15
        unc_bins = np.unique(np.percentile(unc, np.linspace(0, 100, n_bins + 1)))
        bin_ids = np.clip(np.digitize(unc, unc_bins) - 1, 0, len(unc_bins) - 2)
        bn_unc = [unc[bin_ids == b].mean() for b in range(len(unc_bins) - 1)
                  if (bin_ids == b).any()]
        bn_err = [err[bin_ids == b].mean() for b in range(len(unc_bins) - 1)
                  if (bin_ids == b).any()]
        ax.scatter(unc, err, alpha=0.04, s=1, color="#94a3b8", rasterized=True)
        ax.plot(bn_unc, bn_err, "r-o", markersize=4, linewidth=2,
                label="Bin mean")
        # Spearman rank correlation (more robust than Pearson for this data).
        rho, pval = scipy_stats.spearmanr(unc, err)
        p_str = "p < 0.001" if pval < 0.001 else f"p = {pval:.3f}"
        ax.text(0.05, 0.90,
                f"Spearman ρ = {rho:.3f}\n{p_str}",
                transform=ax.transAxes, fontsize=7.5,
                bbox=dict(boxstyle="round,pad=0.25", facecolor="white",
                          edgecolor="grey", alpha=0.85))
        ax.set_xlabel("Uncertainty (MC std)")
        ax.set_ylabel("Absolute error")
        ax.set_title(f"{q_name}: {QUESTION_LABELS[q_name]}", fontsize=9)
        ax.legend(fontsize=7)
    axes[-1].axis("off")  # 11 questions in a 3x4 grid: blank the 12th panel
    plt.suptitle(
        "Epistemic uncertainty vs. absolute prediction error — per morphological question\n"
        "Strong Spearman ρ for root/shallow questions; weak ρ for deep conditional branches "
        "(expected: aleatoric uncertainty dominates when branch is rarely reached)",
        fontsize=10,
    )
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_uncertainty_vs_error")
# ─────────────────────────────────────────────────────────────
# Table: uncertainty summary
# ─────────────────────────────────────────────────────────────
def table_uncertainty_summary(mean_preds, std_preds,
                              targets, weights, save_dir):
    """
    Write the per-question uncertainty summary CSV and return it.

    Per question: number of galaxies that reached it (weight >= 0.05),
    mean/std of MC uncertainty, mean absolute error, and the Spearman
    correlation between uncertainty and error.

    Returns
    -------
    pd.DataFrame
        The summary table. If the CSV already exists it is loaded and
        returned (previously the skip path returned None, inconsistent
        with the build path and breaking callers using the result).
    """
    path = save_dir / "table_uncertainty_summary.csv"
    if path.exists():
        log.info("Skip (exists): table_uncertainty_summary")
        return pd.read_csv(path)
    rows = []
    for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
        mask = weights[:, q_idx] >= 0.05
        unc = std_preds[mask, start:end].mean(axis=1)
        err = np.abs(mean_preds[mask, start:end] -
                     targets[mask, start:end]).mean(axis=1)
        # Spearman needs enough points to be meaningful; NaN otherwise.
        if mask.sum() > 10:
            rho, pval = scipy_stats.spearmanr(unc, err)
        else:
            rho, pval = float("nan"), float("nan")
        rows.append({
            "question" : q_name,
            "description" : QUESTION_LABELS[q_name],
            "n_reached" : int(mask.sum()),
            "mean_uncertainty": round(float(unc.mean()), 5),
            "std_uncertainty" : round(float(unc.std()), 5),
            "mean_mae" : round(float(err.mean()), 5),
            "spearman_rho" : round(float(rho), 4),
            "spearman_pval" : round(float(pval), 4),
        })
    df = pd.DataFrame(rows)
    df.to_csv(path, index=False)
    log.info("Saved: table_uncertainty_summary.csv")
    print("\n" + df.to_string(index=False) + "\n")
    return df
# ─────────────────────────────────────────────────────────────
# Figure 3 + Table: Morphology selection benchmark
# ─────────────────────────────────────────────────────────────
def morphology_selection_benchmark(model_results, save_dir):
    """
    Benchmark downstream binary morphology selection for every model.

    For each (model, selection answer) pair, galaxies whose predicted
    answer fraction clears a threshold are "selected"; galaxies whose
    true vote fraction clears the same threshold are the positives.
    Precision, recall, and F1 are tabulated over all thresholds, written
    to CSV, and plotted. An existing CSV is reused instead of recomputed.

    Parameters
    ----------
    model_results : dict mapping model name -> (preds, targets, weights)
    save_dir : output directory (pathlib.Path)

    Returns
    -------
    pd.DataFrame of the benchmark rows.
    """
    csv_path = save_dir / "table_morphology_selection_benchmark.csv"
    if csv_path.exists():
        log.info("Loading existing morphology benchmark...")
        df = pd.read_csv(csv_path)
        _fig_morphology_f1(df, save_dir)
        return df
    question_order = list(QUESTION_GROUPS.keys())  # name -> weight column
    records = []
    for model_name, (preds, targets, weights) in model_results.items():
        for q_name, (ans_idx, ans_label) in SELECTION_ANSWERS.items():
            col = QUESTION_GROUPS[q_name][0] + ans_idx
            reached = weights[:, question_order.index(q_name)] >= 0.05
            pred_a = preds[reached, col]
            true_a = targets[reached, col]
            for thresh in SELECTION_THRESHOLDS:
                selected = pred_a >= thresh
                positive = true_a >= thresh
                n_selected = int(selected.sum())
                n_positive = int(positive.sum())
                n_hit = int((selected & positive).sum())
                prec = n_hit / n_selected if n_selected > 0 else 0.0
                rec = n_hit / n_positive if n_positive > 0 else 0.0
                denom = prec + rec
                f1 = 2 * prec * rec / denom if denom > 0 else 0.0
                records.append({
                    "model" : model_name,
                    "question" : q_name,
                    "answer" : ans_label,
                    "threshold" : thresh,
                    "n_selected": n_selected,
                    "n_true_pos": n_positive,
                    "precision" : round(float(prec), 4),
                    "recall" : round(float(rec), 4),
                    "f1" : round(float(f1), 4),
                })
    df = pd.DataFrame(records)
    df.to_csv(csv_path, index=False)
    log.info("Saved: table_morphology_selection_benchmark.csv")
    _fig_morphology_f1(df, save_dir)
    return df
def _fig_morphology_f1(df, save_dir, thresh=0.8):
    """
    Grouped bar chart of F1 per question and model at one threshold.

    Parameters
    ----------
    df : benchmark DataFrame from morphology_selection_benchmark
    save_dir : output directory (pathlib.Path)
    thresh : selection threshold to plot (default 0.8; previously
        hard-coded). Must be one of SELECTION_THRESHOLDS so the
        equality filter below matches rows.
    """
    path_pdf = save_dir / "fig_morphology_f1_comparison.pdf"
    path_png = save_dir / "fig_morphology_f1_comparison.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_morphology_f1_comparison"); return
    sub = df[df["threshold"] == thresh]
    q_list = list(SELECTION_ANSWERS.keys())
    models = list(df["model"].unique())
    x = np.arange(len(q_list))
    width = 0.80 / len(models)
    palette = list(MODEL_COLORS.values())  # fallback colors for unknown models
    fig, ax = plt.subplots(figsize=(12, 5))
    for i, model in enumerate(models):
        f1s = []
        for q in q_list:
            row = sub[(sub["model"] == model) & (sub["question"] == q)]
            f1s.append(float(row["f1"].values[0]) if len(row) > 0 else 0.0)
        ax.bar(
            x + i * width, f1s, width,
            label=model,
            color=MODEL_COLORS.get(model, palette[i % len(palette)]),
            alpha=0.85, edgecolor="white", linewidth=0.5,
        )
    # Center each tick under its group of bars.
    ax.set_xticks(x + width * (len(models) - 1) / 2)
    ax.set_xticklabels(
        [f"{q}\n({SELECTION_ANSWERS[q][1]})" for q in q_list], fontsize=9
    )
    ax.set_ylabel("F$_1$ score", fontsize=11)
    ax.set_title(
        f"Downstream morphology selection — F$_1$ at threshold $\\tau$ = {thresh}\n"
        "Higher F$_1$ indicates cleaner, more complete morphological sample selection.",
        fontsize=11,
    )
    ax.legend(fontsize=8)
    ax.set_ylim(0, 1)
    ax.grid(True, alpha=0.3, axis="y")
    ax.set_axisbelow(True)
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_morphology_f1_comparison")
# ─────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────
def main():
    """
    Entry point for the uncertainty analysis pipeline.

    1. MC Dropout inference on the proposed model (cached to mc_cache/).
    2. Uncertainty figures and summary table.
    3. Downstream morphology-selection benchmark across all models whose
       checkpoints exist (missing baselines are skipped gracefully).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    parser.add_argument("--n_passes", type=int, default=30)
    args = parser.parse_args()
    # Experiment config overrides base config.
    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    save_dir = Path(cfg.outputs.figures_dir) / "uncertainty"
    save_dir.mkdir(parents=True, exist_ok=True)
    cache_dir = save_dir / "mc_cache"
    ckpt_dir = Path(cfg.outputs.checkpoint_dir)
    _, _, test_loader = build_dataloaders(cfg)
    # ── 1. MC Dropout on proposed model ───────────────────────
    log.info("Loading proposed model...")
    proposed = build_model(cfg).to(device)
    proposed.load_state_dict(
        torch.load(ckpt_dir / "best_full_train.pt",
                   map_location="cpu", weights_only=True)["model_state"]
    )
    mean_preds, std_preds, targets, weights = run_mc_inference(
        proposed, test_loader, device, cfg,
        n_passes=args.n_passes, cache_dir=cache_dir,
    )
    log.info("MC Dropout complete: %d galaxies, %d passes.",
             len(mean_preds), args.n_passes)
    # ── 2. Uncertainty figures and table ──────────────────────
    fig_uncertainty_distributions(mean_preds, std_preds, targets, weights, save_dir)
    fig_uncertainty_vs_error(mean_preds, std_preds, targets, weights, save_dir)
    table_uncertainty_summary(mean_preds, std_preds, targets, weights, save_dir)
    # ── 3. Downstream benchmark across all models ─────────────
    log.info("Building model_results for downstream benchmark...")
    # MC mean predictions double as the proposed model's point estimates.
    model_results = {
        "ViT-Base + KL+MSE (proposed)": (mean_preds, targets, weights),
    }

    def _load_resnet(ckpt_name, use_sigmoid):
        # Load a ResNet baseline checkpoint and run deterministic (non-MC)
        # inference over the test set. use_sigmoid selects the MSE-trained
        # head (independent sigmoids) vs. per-question softmax.
        m = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
        m.load_state_dict(
            torch.load(ckpt_dir / ckpt_name, map_location="cpu",
                       weights_only=True)["model_state"]
        )
        m.eval()
        preds_l, tgts_l, wgts_l = [], [], []
        with torch.no_grad():
            for images, tgts, wgts, _ in tqdm(test_loader, desc=f"ResNet {ckpt_name}"):
                images = images.to(device, non_blocking=True)
                with autocast("cuda", enabled=cfg.training.mixed_precision):
                    logits = m(images)
                if use_sigmoid:
                    p = torch.sigmoid(logits).cpu().numpy()
                else:
                    p = logits.detach().cpu().clone()
                    for q, (s, e) in QUESTION_GROUPS.items():
                        p[:, s:e] = F.softmax(p[:, s:e], dim=-1)
                    p = p.numpy()
                preds_l.append(p)
                tgts_l.append(tgts.numpy())
                wgts_l.append(wgts.numpy())
        return (np.concatenate(preds_l),
                np.concatenate(tgts_l),
                np.concatenate(wgts_l))

    rn_mse_ckpt = "baseline_resnet18_mse.pt"
    rn_klm_ckpt = "baseline_resnet18_klmse.pt"
    if (ckpt_dir / rn_mse_ckpt).exists():
        model_results["ResNet-18 + MSE (sigmoid)"] = _load_resnet(
            rn_mse_ckpt, use_sigmoid=True
        )
    if (ckpt_dir / rn_klm_ckpt).exists():
        model_results["ResNet-18 + KL+MSE"] = _load_resnet(
            rn_klm_ckpt, use_sigmoid=False
        )
    dp = ckpt_dir / "baseline_vit_dirichlet.pt"
    if dp.exists():
        vit_dir = build_dirichlet_model(cfg).to(device)
        vit_dir.load_state_dict(
            torch.load(dp, map_location="cpu", weights_only=True)["model_state"]
        )
        vit_dir.eval()
        d_p, d_t, d_w = [], [], []
        with torch.no_grad():
            for images, tgts, wgts, _ in tqdm(test_loader, desc="Dirichlet"):
                images = images.to(device, non_blocking=True)
                with autocast("cuda", enabled=cfg.training.mixed_precision):
                    alpha = vit_dir(images)
                # Convert Dirichlet concentrations to expected vote fractions.
                p, t, w = dirichlet_predictions_to_numpy(alpha, tgts, wgts)
                d_p.append(p); d_t.append(t); d_w.append(w)
        model_results["ViT-Base + Dirichlet (Zoobot-style)"] = (
            np.concatenate(d_p),
            np.concatenate(d_t),
            np.concatenate(d_w),
        )
    df_sel = morphology_selection_benchmark(model_results, save_dir)
    log.info("=" * 60)
    log.info("DOWNSTREAM F1 @ τ = 0.8")
    log.info("=" * 60)
    summary = df_sel[df_sel["threshold"] == 0.8][
        ["model", "question", "answer", "precision", "recall", "f1"]
    ]
    log.info("\n%s\n", summary.to_string(index=False))
    log.info("All outputs saved to: %s", save_dir)