| """ | |
| src/baselines.py | |
| ---------------- | |
| Consolidated baseline training for the GZ2 hierarchical probabilistic | |
| regression paper. ALL baselines are trained from this single script. | |
| Replaces the three separate scripts: | |
| src/baselines.py (was: ResNet-18 MSE + ViT MSE) | |
| src/run_resnet_kl.py (was: ResNet-18 KL+MSE β now merged here) | |
| src/train_dirichlet.py (was: ViT Dirichlet β now merged here) | |
| DELETE those three original files after switching to this one. | |
| Baselines trained | |
| ----------------- | |
| B1. ResNet-18 + independent MSE (sigmoid) | |
| β CNN, no hierarchy, no KL. Demonstrates the cost of | |
| ignoring the decision-tree structure. | |
| B2. ResNet-18 + hierarchical KL+MSE | |
| β Same loss as proposed, CNN backbone. | |
| Isolates ViT vs. CNN contribution. | |
| B3. ViT-Base + hierarchical MSE only (no KL) | |
| β Same backbone as proposed, KL term removed. | |
| Isolates contribution of the KL term. | |
| B4. ViT-Base + Dirichlet NLL (Zoobot-style) | |
| β Direct comparison with the established Zoobot approach | |
| (Walmsley et al. 2022, MNRAS 509, 3966). | |
| Proposed model (not trained here β trained via src/train.py): | |
| ViT-Base + hierarchical KL+MSE β outputs/checkpoints/best_full_train.pt | |
| Consistency guarantee | |
| --------------------- | |
| All baselines use identical: | |
| - Random seed, data split, batch size, epochs, early stopping | |
| - AdamW optimiser, CosineAnnealingLR, gradient clipping | |
| - Image transforms and evaluation metric (compute_metrics on same test split) | |
| The ONLY differences between models are the backbone and/or loss function. | |
| Usage | |
| ----- | |
| cd ~/galaxy | |
| nohup python -m src.baselines --config configs/full_train.yaml \ | |
| > outputs/logs/baselines.log 2>&1 & | |
| echo "PID: $!" | |
| """ | |
import argparse
import logging
import random
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import timm
import torch.nn as nn
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from torch.amp import autocast, GradScaler
from omegaconf import OmegaConf
from tqdm import tqdm
import wandb

from src.dataset import build_dataloaders, QUESTION_GROUPS
from src.loss import HierarchicalLoss, DirichletLoss, MSEOnlyLoss
from src.metrics import (compute_metrics, predictions_to_numpy,
                         dirichlet_predictions_to_numpy, simplex_violation_rate)
from src.model import build_model, build_dirichlet_model
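
# The merged config (configs/base.yaml overridden by --config) must provide the
# fields read in this script. The sketch below is for orientation only; it is an
# assumption about the YAML layout, and the actual keys/defaults live in the
# repo's config files, not here.
#
#   seed: 42
#   data:           {n_samples: null}            # null -> full 239k-sample set
#   model:          {dropout: 0.3}
#   training:       {batch_size: ..., epochs: ..., learning_rate: ...,
#                    weight_decay: ..., grad_clip: ..., mixed_precision: true}
#   scheduler:      {T_max: ..., eta_min: ...}
#   early_stopping: {patience: ..., min_delta: ...}
#   loss:           {lambda_kl: 0.5, lambda_mse: ...}
#   wandb:          {project: ...}
#   outputs:        {figures_dir: ..., checkpoint_dir: ...}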
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
)
log = logging.getLogger("baselines")

QUESTION_LABELS = {
    "t01": "Smooth or features", "t02": "Edge-on disk",
    "t03": "Bar", "t04": "Spiral arms",
    "t05": "Bulge prominence", "t06": "Odd feature",
    "t07": "Roundedness", "t08": "Odd feature type",
    "t09": "Bulge shape", "t10": "Arms winding",
    "t11": "Arms number",
}
# ─────────────────────────────────────────────────────────────
# Reproducibility
# ─────────────────────────────────────────────────────────────
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# ─────────────────────────────────────────────────────────────
# Early stopping (mirrors train.py exactly)
# ─────────────────────────────────────────────────────────────
class EarlyStopping:
    def __init__(self, patience, min_delta, checkpoint_path):
        self.patience = patience
        self.min_delta = min_delta
        self.checkpoint_path = checkpoint_path
        self.best_loss = float("inf")
        self.counter = 0
        self.best_epoch = 0

    def step(self, val_loss, model, epoch) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.best_epoch = epoch
            torch.save(
                {"epoch": epoch, "model_state": model.state_dict(),
                 "val_loss": val_loss},
                self.checkpoint_path,
            )
            log.info("  [ckpt] saved val_loss=%.6f epoch=%d", val_loss, epoch)
        else:
            self.counter += 1
            log.info("  [early_stop] %d/%d best=%.6f",
                     self.counter, self.patience, self.best_loss)
        return self.counter >= self.patience

    def restore_best(self, model) -> float:
        ckpt = torch.load(self.checkpoint_path, map_location="cpu",
                          weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        log.info("Restored best weights epoch=%d val_loss=%.6f",
                 ckpt["epoch"], ckpt["val_loss"])
        return ckpt["val_loss"]
# ─────────────────────────────────────────────────────────────
# Baseline Model 1: ResNet-18 + independent MSE
# ─────────────────────────────────────────────────────────────
class ResNet18Baseline(nn.Module):
    """
    ResNet-18 pretrained on ImageNet with a dropout + linear head.
    Used for both the sigmoid-MSE baseline and the KL+MSE baseline.
    """

    def __init__(self, dropout: float = 0.3):
        super().__init__()
        self.backbone = timm.create_model(
            "resnet18", pretrained=True, num_classes=0
        )
        self.head = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(self.backbone.num_features, 37),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.backbone(x))


class IndependentMSELoss(nn.Module):
    """
    Plain MSE over all 37 targets independently.
    No hierarchical weighting, no KL divergence.
    A sigmoid is applied to the predictions before the MSE to constrain
    them to [0, 1].

    Note: the predictions do NOT sum to 1 per question group by construction.
    This is intentional, and the simplex_violation_rate metric quantifies the
    invalidity so the baseline can be compared fairly with the proposed method.
    """

    def forward(self, predictions, targets, weights):
        pred_prob = torch.sigmoid(predictions)
        loss = F.mse_loss(pred_prob, targets)
        return loss, {"loss/total": loss.detach().item()}
# ─────────────────────────────────────────────────────────────
# Shared training loop
# ─────────────────────────────────────────────────────────────
def _train_epoch(model, loader, loss_fn, optimizer, scaler,
                 device, cfg, epoch, label):
    model.train()
    total = 0.0
    nb = 0
    for images, targets, weights, _ in tqdm(
        loader, desc=f"{label} E{epoch}", leave=False
    ):
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        weights = weights.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with autocast("cuda", enabled=cfg.training.mixed_precision):
            logits = model(images)
            loss, _ = loss_fn(logits, targets, weights)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.training.grad_clip)
        scaler.step(optimizer)
        scaler.update()
        total += loss.item()
        nb += 1
    return total / nb


def _train_epoch_dirichlet(model, loader, loss_fn, optimizer, scaler,
                           device, cfg, epoch, label):
    """Training epoch for the Dirichlet model (outputs alpha, not logits)."""
    model.train()
    total = 0.0
    nb = 0
    for images, targets, weights, _ in tqdm(
        loader, desc=f"{label} E{epoch}", leave=False
    ):
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        weights = weights.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with autocast("cuda", enabled=cfg.training.mixed_precision):
            alpha = model(images)
            loss, _ = loss_fn(alpha, targets, weights)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.training.grad_clip)
        scaler.step(optimizer)
        scaler.update()
        total += loss.item()
        nb += 1
    return total / nb


def _val_epoch(model, loader, loss_fn, device, cfg, epoch, label,
               use_sigmoid=False):
    model.eval()
    total = 0.0
    nb = 0
    all_preds, all_targets, all_weights = [], [], []
    with torch.no_grad():
        for images, targets, weights, _ in tqdm(
            loader, desc=f"{label} Val E{epoch}", leave=False
        ):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)
            with autocast("cuda", enabled=cfg.training.mixed_precision):
                logits = model(images)
                loss, _ = loss_fn(logits, targets, weights)
            total += loss.item()
            nb += 1
            if use_sigmoid:
                pred_prob = torch.sigmoid(logits).detach().float().cpu().numpy()
            else:
                # cast to float32 so the per-question softmax and numpy export
                # are safe even when autocast produced half-precision logits
                pred_cpu = logits.detach().float().cpu().clone()
                for q, (s, e) in QUESTION_GROUPS.items():
                    pred_cpu[:, s:e] = torch.softmax(pred_cpu[:, s:e], dim=-1)
                pred_prob = pred_cpu.numpy()
            all_preds.append(pred_prob)
            all_targets.append(targets.detach().cpu().numpy())
            all_weights.append(weights.detach().cpu().numpy())
    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    all_weights = np.concatenate(all_weights)
    metrics = compute_metrics(all_preds, all_targets, all_weights)
    return total / nb, metrics
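
# Note on QUESTION_GROUPS (imported from src.dataset; the exact values are
# assumed here): each entry maps a question id to the (start, end) column span
# of its answers in the 37-dim target vector, e.g. roughly
# {"t01": (0, 3), "t02": (3, 5), ...}. The per-question softmax above therefore
# normalises each slice independently:
#
#   logits = torch.randn(4, 37)
#   probs = logits.clone()
#   for q, (s, e) in QUESTION_GROUPS.items():
#       probs[:, s:e] = torch.softmax(probs[:, s:e], dim=-1)
#   # each probs[:, s:e] row now sums to 1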
def _val_epoch_dirichlet(model, loader, loss_fn, device, cfg, epoch, label):
    model.eval()
    total = 0.0
    nb = 0
    all_preds, all_targets, all_weights = [], [], []
    with torch.no_grad():
        for images, targets, weights, _ in tqdm(
            loader, desc=f"{label} Val E{epoch}", leave=False
        ):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)
            with autocast("cuda", enabled=cfg.training.mixed_precision):
                alpha = model(images)
                loss, _ = loss_fn(alpha, targets, weights)
            total += loss.item()
            nb += 1
            p, t, w = dirichlet_predictions_to_numpy(alpha, targets, weights)
            all_preds.append(p)
            all_targets.append(t)
            all_weights.append(w)
    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    all_weights = np.concatenate(all_weights)
    metrics = compute_metrics(all_preds, all_targets, all_weights)
    return total / nb, metrics
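
# Note (illustrative): for a Dirichlet head the expected vote fraction of each
# answer within a question is alpha_i / sum(alpha) over that question's slice,
# which is what dirichlet_predictions_to_numpy is assumed to return so that
# compute_metrics sees probabilities comparable with the other baselines, e.g.
#   alpha_q = torch.tensor([2.0, 5.0, 3.0])
#   alpha_q / alpha_q.sum()        # tensor([0.2000, 0.5000, 0.3000])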
# ─────────────────────────────────────────────────────────────
# Generic train_and_evaluate (non-Dirichlet)
# ─────────────────────────────────────────────────────────────
def train_and_evaluate(
    model, loss_fn, cfg, device,
    label, checkpoint_path,
    use_layerwise_lr=True,
    use_sigmoid=False,
):
    """
    Full training loop consistent with train.py.
    Returns (test_metrics, best_val_loss, best_epoch, history).
    If the checkpoint exists, loads it and skips training.
    """
    # Check if checkpoint exists - if so, skip training
    if Path(checkpoint_path).exists():
        log.info("%s: checkpoint found - loading and skipping training", label)
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        best_epoch = ckpt.get("epoch", 0)
        best_val = ckpt.get("val_loss", float("inf"))
        log.info("Restored: epoch=%d, val_loss=%.6f", best_epoch, best_val)
        # Evaluate on test set
        _, _, test_loader = build_dataloaders(cfg)
        _, test_metrics = _val_epoch(
            model, test_loader, loss_fn, device, cfg,
            epoch=0, label=f"{label}-test", use_sigmoid=use_sigmoid
        )
        return test_metrics, best_val, best_epoch, []

    train_loader, val_loader, test_loader = build_dataloaders(cfg)
    if use_layerwise_lr and hasattr(model, "backbone") and hasattr(model, "head"):
        optimizer = torch.optim.AdamW(
            [
                {"params": model.backbone.parameters(),
                 "lr": cfg.training.learning_rate * 0.1},
                {"params": model.head.parameters(),
                 "lr": cfg.training.learning_rate},
            ],
            weight_decay=cfg.training.weight_decay,
        )
        log.info("%s: layer-wise lr → backbone=%.1e head=%.1e",
                 label, cfg.training.learning_rate * 0.1, cfg.training.learning_rate)
    else:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=cfg.training.learning_rate,
            weight_decay=cfg.training.weight_decay,
        )
        log.info("%s: single lr=%.1e", label, cfg.training.learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
    )
    scaler = GradScaler("cuda")
    early_stop = EarlyStopping(
        patience=cfg.early_stopping.patience,
        min_delta=cfg.early_stopping.min_delta,
        checkpoint_path=checkpoint_path,
    )
    wandb.init(
        project=cfg.wandb.project,
        name=label,
        config={
            "model": label,
            "backbone": "resnet18" if "ResNet" in label else "vit_base_patch16_224",
            "batch_size": cfg.training.batch_size, "lr": cfg.training.learning_rate,
            "epochs": cfg.training.epochs, "seed": cfg.seed,
            "lambda_kl": cfg.loss.lambda_kl, "lambda_mse": cfg.loss.lambda_mse,
        },
        reinit=True,
    )
    history = []
    for epoch in range(1, cfg.training.epochs + 1):
        train_loss = _train_epoch(
            model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch, label
        )
        val_loss, val_metrics = _val_epoch(
            model, val_loader, loss_fn, device, cfg, epoch, label,
            use_sigmoid=use_sigmoid
        )
        scheduler.step()
        lr = scheduler.get_last_lr()[0]
        val_mae = val_metrics.get("mae/weighted_avg", 0)
        log.info("%s epoch=%d train=%.4f val=%.4f mae=%.4f lr=%.2e",
                 label, epoch, train_loss, val_loss, val_mae, lr)
        history.append({
            "epoch": epoch, "train_loss": train_loss,
            "val_loss": val_loss, "val_mae": val_mae,
        })
        wandb.log({
            "train_loss": train_loss, "val_loss": val_loss,
            "val_mae": val_mae, "lr": lr,
        }, step=epoch)
        if early_stop.step(val_loss, model, epoch):
            log.info("%s: early stopping at epoch %d best=%d",
                     label, epoch, early_stop.best_epoch)
            break
    best_val = early_stop.restore_best(model)
    wandb.finish()
    log.info("%s: evaluating on test set...", label)
    _, test_metrics = _val_epoch(
        model, test_loader, loss_fn, device, cfg,
        epoch=0, label=f"{label}-test", use_sigmoid=use_sigmoid
    )
    return test_metrics, best_val, early_stop.best_epoch, history
# ─────────────────────────────────────────────────────────────
# Dirichlet train_and_evaluate
# ─────────────────────────────────────────────────────────────
def train_and_evaluate_dirichlet(model, loss_fn, cfg, device,
                                 label, checkpoint_path):
    """Training loop for the Dirichlet model. Skips training if the checkpoint exists."""
    # Check if checkpoint exists - if so, skip training
    if Path(checkpoint_path).exists():
        log.info("%s: checkpoint found - loading and skipping training", label)
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        best_epoch = ckpt.get("epoch", 0)
        best_val = ckpt.get("val_loss", float("inf"))
        log.info("Restored: epoch=%d, val_loss=%.6f", best_epoch, best_val)
        # Evaluate on test set
        _, _, test_loader = build_dataloaders(cfg)
        _, test_metrics = _val_epoch_dirichlet(
            model, test_loader, loss_fn, device, cfg,
            epoch=0, label=f"{label}-test"
        )
        return test_metrics, best_val, best_epoch, []

    train_loader, val_loader, test_loader = build_dataloaders(cfg)
    optimizer = torch.optim.AdamW(
        [
            {"params": model.backbone.parameters(),
             "lr": cfg.training.learning_rate * 0.1},
            {"params": model.head.parameters(),
             "lr": cfg.training.learning_rate},
        ],
        weight_decay=cfg.training.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
    )
    scaler = GradScaler("cuda")
    early_stop = EarlyStopping(
        patience=cfg.early_stopping.patience,
        min_delta=cfg.early_stopping.min_delta,
        checkpoint_path=checkpoint_path,
    )
    wandb.init(
        project=cfg.wandb.project, name=label,
        config={"model": label, "loss": "DirichletNLL",
                "seed": cfg.seed, "epochs": cfg.training.epochs},
        reinit=True,
    )
    history = []
    for epoch in range(1, cfg.training.epochs + 1):
        train_loss = _train_epoch_dirichlet(
            model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch, label
        )
        val_loss, val_metrics = _val_epoch_dirichlet(
            model, val_loader, loss_fn, device, cfg, epoch, label
        )
        scheduler.step()
        lr = scheduler.get_last_lr()[0]
        val_mae = val_metrics.get("mae/weighted_avg", 0)
        log.info("%s epoch=%d train=%.4f val=%.4f mae=%.4f lr=%.2e",
                 label, epoch, train_loss, val_loss, val_mae, lr)
        history.append({
            "epoch": epoch, "train_loss": train_loss,
            "val_loss": val_loss, "val_mae": val_mae,
        })
        wandb.log({
            "train_loss": train_loss, "val_loss": val_loss,
            "val_mae": val_mae, "lr": lr,
        }, step=epoch)
        if early_stop.step(val_loss, model, epoch):
            log.info("%s: early stopping at epoch %d", label, epoch)
            break
    best_val = early_stop.restore_best(model)
    wandb.finish()
    log.info("%s: evaluating on test set...", label)
    _, test_metrics = _val_epoch_dirichlet(
        model, test_loader, loss_fn, device, cfg, epoch=0, label=f"{label}-test"
    )
    return test_metrics, best_val, early_stop.best_epoch, history
# ─────────────────────────────────────────────────────────────
# Figures
# ─────────────────────────────────────────────────────────────
def _save_comparison_figures(all_results, all_histories, save_dir):
    """
    Saves:
        1. Per-question MAE + RMSE bar chart
        2. Validation MAE learning curves
    (The simplex-violation numbers for the sigmoid baseline are written to the
    summary CSV in main(), not to a figure here.)
    All figure names follow IEEE journal conventions.
    """
    q_names = list(QUESTION_GROUPS.keys())
    n_models = len(all_results)
    x = np.arange(len(q_names))
    width = 0.80 / n_models
    palette = ["#c0392b", "#e67e22", "#2980b9", "#27ae60", "#8e44ad"]

    # ── Figure 1: Per-question MAE and RMSE ───────────────────
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    for metric, ax, ylabel in [
        ("mae", axes[0], "Mean Absolute Error (MAE)"),
        ("rmse", axes[1], "Root Mean Squared Error (RMSE)"),
    ]:
        for i, (row_d, color) in enumerate(zip(all_results, palette)):
            vals = [row_d.get(f"{metric}_{q}", np.nan) for q in q_names]
            ax.bar(x + i * width, vals, width,
                   label=row_d["model"], color=color,
                   alpha=0.85, edgecolor="white", linewidth=0.5)
        ax.set_xticks(x + width * (n_models - 1) / 2)
        ax.set_xticklabels(
            [f"{q}\n({QUESTION_LABELS[q][:10]})" for q in q_names],
            rotation=45, ha="right", fontsize=7,
        )
        ax.set_ylabel(ylabel, fontsize=11)
        ax.set_title(f"Per-question {metric.upper()} – baseline comparison", fontsize=11)
        ax.legend(fontsize=7, loc="upper right")
        ax.grid(True, alpha=0.3, axis="y")
        ax.set_axisbelow(True)
    plt.suptitle(
        "Baseline comparison – GZ2 hierarchical probabilistic regression\n"
        "Full 239,267-sample dataset, identical seed/split/protocol",
        fontsize=12, y=1.02,
    )
    plt.tight_layout()
    fig.savefig(save_dir / "fig_baseline_comparison_mae_rmse.pdf",
                dpi=300, bbox_inches="tight")
    fig.savefig(save_dir / "fig_baseline_comparison_mae_rmse.png",
                dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_baseline_comparison_mae_rmse")

    # ── Figure 2: Validation MAE learning curves ───────────────
    fig2, ax2 = plt.subplots(figsize=(10, 5))
    styles = ["-", "--", "-.", ":", (0, (3, 1, 1, 1))]
    markers = ["o", "s", "^", "D", "v"]
    for (name, hist), ls, color, mk in zip(
        all_histories.items(), styles, palette, markers
    ):
        epochs_h = [h["epoch"] for h in hist]
        val_maes = [h["val_mae"] for h in hist]
        ax2.plot(epochs_h, val_maes, linestyle=ls, color=color, linewidth=1.8,
                 label=name, marker=mk, markersize=3, markevery=5)
    ax2.set_xlabel("Epoch", fontsize=11)
    ax2.set_ylabel("Validation MAE (weighted average)", fontsize=11)
    ax2.set_title("Validation MAE during training – all baseline models", fontsize=11)
    ax2.legend(fontsize=9)
    ax2.grid(True, alpha=0.3)
    plt.tight_layout()
    fig2.savefig(save_dir / "fig_baseline_val_mae_curves.pdf",
                 dpi=300, bbox_inches="tight")
    fig2.savefig(save_dir / "fig_baseline_val_mae_curves.png",
                 dpi=300, bbox_inches="tight")
    plt.close(fig2)
    log.info("Saved: fig_baseline_val_mae_curves")
# ─────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()
    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)
    set_seed(cfg.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info("Device: %s Dataset: %s",
             device, "full 239k" if cfg.data.n_samples is None
             else f"{cfg.data.n_samples:,}")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    save_dir = Path(cfg.outputs.figures_dir) / "comparison"
    ckpt_dir = Path(cfg.outputs.checkpoint_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    all_results = []
    all_histories = {}
    # ─── B1: ResNet-18 + independent MSE (sigmoid) ────────────
    log.info("=" * 60)
    log.info("B1: ResNet-18 + independent MSE (sigmoid, no hierarchy)")
    log.info("=" * 60)
    set_seed(cfg.seed)
    rn_mse_model = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
    rn_mse_loss = IndependentMSELoss()
    log.info("ResNet-18 params: %s", f"{sum(p.numel() for p in rn_mse_model.parameters()):,}")
    rn_mse_metrics, rn_mse_val, rn_mse_epoch, rn_mse_hist = train_and_evaluate(
        rn_mse_model, rn_mse_loss, cfg, device,
        label="B1-ResNet18-MSE",
        checkpoint_path=str(ckpt_dir / "baseline_resnet18_mse.pt"),
        use_layerwise_lr=False,
        use_sigmoid=True,
    )
    # Simplex violation for this baseline
    _, _, test_loader_tmp = build_dataloaders(cfg)
    rn_mse_model.eval()
    tmp_preds = []
    with torch.no_grad():
        for images, _, _, _ in test_loader_tmp:
            images = images.to(device, non_blocking=True)
            logits = rn_mse_model(images)
            tmp_preds.append(torch.sigmoid(logits).cpu().numpy())
    tmp_preds = np.concatenate(tmp_preds)
    svr = simplex_violation_rate(tmp_preds, tolerance=0.02)
    log.info("B1 simplex violation rate (mean): %.4f", svr["mean"])
    row = {
        "model": "ResNet-18 + MSE (sigmoid, no hierarchy)",
        "backbone": "ResNet-18", "loss": "Independent MSE",
        "hierarchy": "None",
        "best_epoch": rn_mse_epoch, "best_val_loss": round(rn_mse_val, 5),
        "mae_weighted": round(rn_mse_metrics["mae/weighted_avg"], 5),
        "rmse_weighted": round(rn_mse_metrics["rmse/weighted_avg"], 5),
        "simplex_violation_mean": round(svr["mean"], 4),
    }
    for q in QUESTION_GROUPS:
        row[f"mae_{q}"] = round(rn_mse_metrics[f"mae/{q}"], 5)
        row[f"rmse_{q}"] = round(rn_mse_metrics[f"rmse/{q}"], 5)
    all_results.append(row)
    all_histories["ResNet-18 + MSE (sigmoid)"] = rn_mse_hist
    log.info("B1 done: MAE=%.5f RMSE=%.5f SimplexViol=%.4f",
             rn_mse_metrics["mae/weighted_avg"],
             rn_mse_metrics["rmse/weighted_avg"],
             svr["mean"])
    # ─── B2: ResNet-18 + hierarchical KL+MSE ──────────────────
    log.info("=" * 60)
    log.info("B2: ResNet-18 + hierarchical KL+MSE (same loss as proposed)")
    log.info("=" * 60)
    set_seed(cfg.seed)
    rn_kl_model = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
    rn_kl_loss = HierarchicalLoss(cfg)
    rn_kl_metrics, rn_kl_val, rn_kl_epoch, rn_kl_hist = train_and_evaluate(
        rn_kl_model, rn_kl_loss, cfg, device,
        label="B2-ResNet18-KL+MSE",
        checkpoint_path=str(ckpt_dir / "baseline_resnet18_klmse.pt"),
        use_layerwise_lr=False,
        use_sigmoid=False,
    )
    row = {
        "model": "ResNet-18 + hierarchical KL+MSE",
        "backbone": "ResNet-18", "loss": "Hierarchical KL+MSE (λ=0.5)",
        "hierarchy": "Full (weights + KL)",
        "best_epoch": rn_kl_epoch, "best_val_loss": round(rn_kl_val, 5),
        "mae_weighted": round(rn_kl_metrics["mae/weighted_avg"], 5),
        "rmse_weighted": round(rn_kl_metrics["rmse/weighted_avg"], 5),
        "simplex_violation_mean": 0.0,  # softmax guarantees validity
    }
    for q in QUESTION_GROUPS:
        row[f"mae_{q}"] = round(rn_kl_metrics[f"mae/{q}"], 5)
        row[f"rmse_{q}"] = round(rn_kl_metrics[f"rmse/{q}"], 5)
    all_results.append(row)
    all_histories["ResNet-18 + KL+MSE"] = rn_kl_hist
    log.info("B2 done: MAE=%.5f RMSE=%.5f",
             rn_kl_metrics["mae/weighted_avg"],
             rn_kl_metrics["rmse/weighted_avg"])
    # ─── B3: ViT-Base + hierarchical MSE only ─────────────────
    log.info("=" * 60)
    log.info("B3: ViT-Base + hierarchical MSE only (no KL term)")
    log.info("=" * 60)
    set_seed(cfg.seed)
    vit_mse_cfg = OmegaConf.merge(
        cfg, OmegaConf.create({"loss": {"lambda_kl": 0.0, "lambda_mse": 1.0}})
    )
    vit_mse_model = build_model(vit_mse_cfg).to(device)
    vit_mse_loss = MSEOnlyLoss(vit_mse_cfg)
    vit_mse_metrics, vit_mse_val, vit_mse_epoch, vit_mse_hist = train_and_evaluate(
        vit_mse_model, vit_mse_loss, vit_mse_cfg, device,
        label="B3-ViT-MSE",
        checkpoint_path=str(ckpt_dir / "baseline_vit_mse.pt"),
        use_layerwise_lr=True,
        use_sigmoid=False,
    )
    row = {
        "model": "ViT-Base + hierarchical MSE (no KL)",
        "backbone": "ViT-Base/16", "loss": "Hierarchical MSE (λ_KL=0)",
        "hierarchy": "Weights only",
        "best_epoch": vit_mse_epoch, "best_val_loss": round(vit_mse_val, 5),
        "mae_weighted": round(vit_mse_metrics["mae/weighted_avg"], 5),
        "rmse_weighted": round(vit_mse_metrics["rmse/weighted_avg"], 5),
        "simplex_violation_mean": 0.0,
    }
    for q in QUESTION_GROUPS:
        row[f"mae_{q}"] = round(vit_mse_metrics[f"mae/{q}"], 5)
        row[f"rmse_{q}"] = round(vit_mse_metrics[f"rmse/{q}"], 5)
    all_results.append(row)
    all_histories["ViT-Base + MSE only"] = vit_mse_hist
    log.info("B3 done: MAE=%.5f RMSE=%.5f",
             vit_mse_metrics["mae/weighted_avg"],
             vit_mse_metrics["rmse/weighted_avg"])
    # ─── B4: ViT-Base + Dirichlet NLL (Zoobot-style) ──────────
    log.info("=" * 60)
    log.info("B4: ViT-Base + Dirichlet NLL (Walmsley et al. 2022)")
    log.info("=" * 60)
    set_seed(cfg.seed)
    vit_dir_model = build_dirichlet_model(cfg).to(device)
    vit_dir_loss = DirichletLoss(cfg)
    vit_dir_metrics, vit_dir_val, vit_dir_epoch, vit_dir_hist = train_and_evaluate_dirichlet(
        vit_dir_model, vit_dir_loss, cfg, device,
        label="B4-ViT-Dirichlet",
        checkpoint_path=str(ckpt_dir / "baseline_vit_dirichlet.pt"),
    )
    row = {
        "model": "ViT-Base + Dirichlet NLL (Zoobot-style)",
        "backbone": "ViT-Base/16", "loss": "Dirichlet NLL",
        "hierarchy": "Full (weights + Dirichlet)",
        "best_epoch": vit_dir_epoch, "best_val_loss": round(vit_dir_val, 5),
        "mae_weighted": round(vit_dir_metrics["mae/weighted_avg"], 5),
        "rmse_weighted": round(vit_dir_metrics["rmse/weighted_avg"], 5),
        "simplex_violation_mean": 0.0,
    }
    for q in QUESTION_GROUPS:
        row[f"mae_{q}"] = round(vit_dir_metrics[f"mae/{q}"], 5)
        row[f"rmse_{q}"] = round(vit_dir_metrics[f"rmse/{q}"], 5)
    all_results.append(row)
    all_histories["ViT-Base + Dirichlet"] = vit_dir_hist
    log.info("B4 done: MAE=%.5f RMSE=%.5f",
             vit_dir_metrics["mae/weighted_avg"],
             vit_dir_metrics["rmse/weighted_avg"])
    # ─── Proposed: load existing checkpoint for final table ────
    proposed_ckpt = ckpt_dir / "best_full_train.pt"
    if proposed_ckpt.exists():
        log.info("=" * 60)
        log.info("PROPOSED: Loading ViT-Base + hierarchical KL+MSE")
        log.info("=" * 60)
        proposed_model = build_model(cfg).to(device)
        ckpt_info = torch.load(proposed_ckpt, map_location="cpu", weights_only=True)
        proposed_model.load_state_dict(ckpt_info["model_state"])
        _, _, test_loader_p = build_dataloaders(cfg)
        _, proposed_metrics = _val_epoch(
            proposed_model, test_loader_p, HierarchicalLoss(cfg), device, cfg,
            epoch=0, label="Proposed-test", use_sigmoid=False
        )
        row = {
            "model": "ViT-Base + hierarchical KL+MSE (proposed)",
            "backbone": "ViT-Base/16", "loss": "Hierarchical KL+MSE (λ=0.5)",
            "hierarchy": "Full (weights + KL)",
            "best_epoch": ckpt_info["epoch"],
            "best_val_loss": round(ckpt_info["val_loss"], 5),
            "mae_weighted": round(proposed_metrics["mae/weighted_avg"], 5),
            "rmse_weighted": round(proposed_metrics["rmse/weighted_avg"], 5),
            "simplex_violation_mean": 0.0,
        }
        for q in QUESTION_GROUPS:
            row[f"mae_{q}"] = round(proposed_metrics[f"mae/{q}"], 5)
            row[f"rmse_{q}"] = round(proposed_metrics[f"rmse/{q}"], 5)
        all_results.append(row)
        log.info("Proposed: MAE=%.5f RMSE=%.5f",
                 proposed_metrics["mae/weighted_avg"],
                 proposed_metrics["rmse/weighted_avg"])
    # ─── Save results ──────────────────────────────────────────
    df = pd.DataFrame(all_results)
    df.to_csv(save_dir / "table_baseline_comparison.csv", index=False)
    summary_cols = ["model", "loss", "hierarchy", "best_epoch",
                    "best_val_loss", "mae_weighted", "rmse_weighted",
                    "simplex_violation_mean"]
    summary = df[[c for c in summary_cols if c in df.columns]].copy()
    summary.to_csv(save_dir / "table_baseline_summary.csv", index=False)
    print()
    print("=" * 80)
    print("BASELINE COMPARISON – FINAL RESULTS")
    print("=" * 80)
    print(summary.to_string(index=False))
    print()

    # ─── Figures ───────────────────────────────────────────────
    _save_comparison_figures(all_results, all_histories, save_dir)
    log.info("All baseline outputs saved to: %s", save_dir)


if __name__ == "__main__":
    main()