Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files- Dockerfile +16 -0
- app.py +127 -0
- best_full_train.pt +3 -0
- configs/ablation.yaml +25 -0
- configs/base.yaml +64 -0
- configs/full_train.yaml +26 -0
- configs/subset_60k.yaml +25 -0
- requirements.txt +11 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/__pycache__/ablation.cpython-310.pyc +0 -0
- src/__pycache__/attention_viz.cpython-310.pyc +0 -0
- src/__pycache__/baselines.cpython-310.pyc +0 -0
- src/__pycache__/baselines.cpython-312.pyc +0 -0
- src/__pycache__/dataset.cpython-310.pyc +0 -0
- src/__pycache__/evaluate_full.cpython-310.pyc +0 -0
- src/__pycache__/evaluate_full.cpython-312.pyc +0 -0
- src/__pycache__/loss.cpython-310.pyc +0 -0
- src/__pycache__/metrics.cpython-310.pyc +0 -0
- src/__pycache__/model.cpython-310.pyc +0 -0
- src/__pycache__/model.cpython-312.pyc +0 -0
- src/__pycache__/train.cpython-310.pyc +0 -0
- src/__pycache__/train_single.cpython-310.pyc +0 -0
- src/__pycache__/train_single.cpython-312.pyc +0 -0
- src/__pycache__/uncertainty_analysis.cpython-310.pyc +0 -0
- src/ablation.py +258 -0
- src/attention_viz.py +316 -0
- src/baselines.py +844 -0
- src/dataset.py +270 -0
- src/evaluate_full.py +619 -0
- src/loss.py +155 -0
- src/metrics.py +335 -0
- src/model.py +316 -0
- src/train.py +327 -0
- src/train_single.py +419 -0
- src/uncertainty_analysis.py +578 -0
Dockerfile
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Inference image for the Galaxy Zoo hierarchical-ViT FastAPI server.
# Deployed as a Hugging Face Space (serves on port 7860 by convention).
FROM python:3.10-slim

WORKDIR /code

# libgl1 / libglib2.0-0 are runtime deps of opencv-python-headless.
# FIXED: added --no-install-recommends to keep the image slim.
RUN apt-get update && apt-get install -y --no-install-recommends \
    libgl1 \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Install Python deps first so this layer is cached across code changes.
COPY ./requirements.txt /code/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

# NOTE(review): the repo contains src/__pycache__/ and a ~343 MB checkpoint;
# COPY . picks all of it up. Consider a .dockerignore for __pycache__/.
COPY . /code/

# Run the FastAPI server on port 7860 (Hugging Face Spaces default)
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
app.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI inference server for the Galaxy Zoo hierarchical ViT.

Exposes POST /api/predict which returns per-question probability vectors
and a base64-encoded attention-rollout heatmap for an uploaded image.
"""
import sys
import os
import io
import base64
import numpy as np
from PIL import Image
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Add current directory to path so the HF Space resolves the `src` package.
# FIXED: removed duplicate `import sys` / `import os` statements that
# appeared a second time here.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.append(current_dir)

from omegaconf import OmegaConf
from src.model import build_model
from src.attention_viz import attention_rollout_full, make_overlay
from src.dataset import QUESTION_GROUPS
from torchvision import transforms

app = FastAPI()

# Permissive CORS so the Space's static frontend (any origin) can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Populated once by the startup hook (load_model) below.
model = None
cfg = None
transform = None
|
| 40 |
+
|
| 41 |
+
@app.on_event("startup")
def load_model():
    """Build the model, load the trained checkpoint, and set up preprocessing.

    Runs once at server startup and populates the module-level globals
    ``model``, ``cfg`` and ``transform``. Missing checkpoint or experiment
    config are tolerated (with a warning) so the Space still boots.
    """
    global model, cfg, transform
    print("Loading configuration...")
    base_cfg = OmegaConf.load(os.path.join(current_dir, "configs/base.yaml"))

    # Overlay the full-train experiment config on top of the base config.
    # FIXED: was a bare `except:` which also swallowed SystemExit /
    # KeyboardInterrupt; narrowed to Exception and the reason is now logged.
    try:
        exp_cfg = OmegaConf.load(os.path.join(current_dir, "configs/full_train.yaml"))
        cfg = OmegaConf.merge(base_cfg, exp_cfg)
    except Exception as err:
        print(f"WARNING: could not load configs/full_train.yaml ({err}); using base config")
        cfg = base_cfg

    print("Building model...")
    model = build_model(cfg).to(device)

    ckpt_path = os.path.join(current_dir, "best_full_train.pt")
    if os.path.exists(ckpt_path):
        print(f"Loading checkpoint from {ckpt_path}")
        # weights_only=True: safe tensor-only deserialization (no pickle code).
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(ckpt["model_state"])
    else:
        print(f"WARNING: Checkpoint not found at {ckpt_path}")

    model.eval()

    # Galaxy Zoo image transform: resize, crop, center, normalize.
    # Assuming standard ImageNet + ViT transforms for 224x224 input.
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
|
| 75 |
+
|
| 76 |
+
@app.post("/api/predict")
async def predict(file: UploadFile = File(...)):
    """Run inference on one uploaded galaxy image.

    Returns a JSON object with:
      - "predictions": {question_name: [per-answer probabilities]}
        (softmax applied independently within each question group)
      - "heatmap": base64-encoded PNG attention-rollout overlay, or None
        when the model exposes no attention weights.
    """
    import torch.nn.functional as F  # hoisted from mid-body to function top

    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")

    # Preprocess to a [1, 3, 224, 224] batch on the serving device.
    img_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        # FIXED: autocast device type was hardcoded to "cuda" with
        # enabled=True even when the server falls back to CPU; now it
        # follows the actual device and is disabled off-GPU.
        with torch.amp.autocast(device.type, enabled=(device.type == "cuda")):
            logits = model(img_tensor)

        # Attention weights captured by the model during the forward pass.
        layers = model.get_all_attention_weights()

    results = {}

    # Hierarchical softmax is applied per question group, matching the
    # evaluation pipeline. (Removed an unused `predictions` local.)
    probs = logits.detach().cpu().clone()
    for q_name, (start, end) in QUESTION_GROUPS.items():
        probs[:, start:end] = F.softmax(probs[:, start:end], dim=-1)

    probs_np = probs[0].numpy()
    for q_name, (start, end) in QUESTION_GROUPS.items():
        results[q_name] = probs_np[start:end].tolist()

    # Generate the attention heatmap overlay, if attention is available.
    if layers is not None:
        # attention_rollout_full expects a list of [B, H, N+1, N+1] tensors.
        all_layer_attns = [l.cpu() for l in layers]
        rollout_map = attention_rollout_full(all_layer_attns, patch_size=16, image_size=224)[0]

        # Overlay is drawn on the (un-normalized) original image, resized
        # to the model's 224x224 input resolution.
        original_img_np = np.array(image.resize((224, 224)))
        overlay = make_overlay(original_img_np, rollout_map, alpha=0.5, colormap="inferno")

        # Encode the overlay PNG to base64 for the JSON response.
        overlay_img = Image.fromarray(overlay)
        buffered = io.BytesIO()
        overlay_img.save(buffered, format="PNG")
        heatmap_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
    else:
        heatmap_base64 = None

    return {
        "predictions": results,
        "heatmap": heatmap_base64
    }
|
best_full_train.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f31287bca388d29f3b144a9a407f041b4f02c6742ece7614751aab70cd6f04ea
|
| 3 |
+
size 343371682
|
configs/ablation.yaml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ─────────────────────────────────────────────────────────────
# configs/ablation.yaml
# Phase 1: lambda_kl ablation on a 10k subset.
# Run FIRST before any full training.
#
# NOTE(review): `defaults:` is a Hydra convention; src/ablation.py
# merges configs manually with OmegaConf.merge, which treats this key
# as plain data rather than resolving it — confirm it is intentional.
# ─────────────────────────────────────────────────────────────

defaults:
  - base

experiment_name : "ablation"

data:
  n_samples : 10000   # ablation uses 10k for speed

training:
  epochs : 15         # sufficient to converge on 10k

scheduler:
  T_max : 15          # matches the shortened epoch budget above

early_stopping:
  patience : 5

wandb:
  log_attention_every_n_epochs : 99   # disable attention in ablation
|
configs/base.yaml
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 2 |
+
# configs/base.yaml
|
| 3 |
+
# Base configuration for all experiments.
|
| 4 |
+
# All experiment configs inherit and override from this file.
|
| 5 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 6 |
+
|
| 7 |
+
project_name : "gz2-hierarchical-vit"
|
| 8 |
+
experiment_name : "base"
|
| 9 |
+
seed : 42
|
| 10 |
+
|
| 11 |
+
data:
|
| 12 |
+
parquet_path : "data/labels.parquet"
|
| 13 |
+
image_dir : "data/images"
|
| 14 |
+
image_id_col : "dr7objid"
|
| 15 |
+
image_size : 224
|
| 16 |
+
n_samples : null # null = full dataset
|
| 17 |
+
train_frac : 0.80
|
| 18 |
+
val_frac : 0.10
|
| 19 |
+
test_frac : 0.10
|
| 20 |
+
num_workers : 12
|
| 21 |
+
pin_memory : true
|
| 22 |
+
persistent_workers : true
|
| 23 |
+
prefetch_factor : 4
|
| 24 |
+
|
| 25 |
+
model:
|
| 26 |
+
backbone : "vit_base_patch16_224"
|
| 27 |
+
pretrained : true
|
| 28 |
+
# FIXED: increased from 0.1 β 0.3 to reduce overfitting on 86M-param model.
|
| 29 |
+
# Loss curves showed train/val divergence from epoch ~12 with dropout=0.1.
|
| 30 |
+
dropout : 0.3
|
| 31 |
+
|
| 32 |
+
loss:
|
| 33 |
+
lambda_kl : 0.5 # weight of KL divergence term
|
| 34 |
+
lambda_mse : 0.5 # weight of MSE term
|
| 35 |
+
epsilon : 1.0e-8 # numerical stability clamp
|
| 36 |
+
|
| 37 |
+
training:
|
| 38 |
+
epochs : 100
|
| 39 |
+
batch_size : 64
|
| 40 |
+
learning_rate : 1.0e-4
|
| 41 |
+
weight_decay : 1.0e-4
|
| 42 |
+
grad_clip : 1.0
|
| 43 |
+
mixed_precision : true
|
| 44 |
+
|
| 45 |
+
early_stopping:
|
| 46 |
+
patience : 10
|
| 47 |
+
min_delta : 1.0e-5
|
| 48 |
+
monitor : "val/loss_total"
|
| 49 |
+
|
| 50 |
+
scheduler:
|
| 51 |
+
name : "cosine"
|
| 52 |
+
T_max : 100
|
| 53 |
+
eta_min : 1.0e-6
|
| 54 |
+
|
| 55 |
+
outputs:
|
| 56 |
+
checkpoint_dir : "outputs/checkpoints"
|
| 57 |
+
figures_dir : "outputs/figures"
|
| 58 |
+
log_dir : "outputs/logs"
|
| 59 |
+
|
| 60 |
+
wandb:
|
| 61 |
+
enabled : true
|
| 62 |
+
project : "gz2-hierarchical-vit" # new project name
|
| 63 |
+
log_attention_every_n_epochs : 5
|
| 64 |
+
n_attention_samples : 8
|
configs/full_train.yaml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 2 |
+
# configs/full_train.yaml
|
| 3 |
+
# Phase 3: Full training on complete 239k dataset.
|
| 4 |
+
# Run after ablation confirms lambda_kl = 0.5 is optimal.
|
| 5 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 6 |
+
|
| 7 |
+
defaults:
|
| 8 |
+
- base
|
| 9 |
+
|
| 10 |
+
experiment_name : "full_train"
|
| 11 |
+
|
| 12 |
+
data:
|
| 13 |
+
n_samples : null # full 239k dataset
|
| 14 |
+
|
| 15 |
+
training:
|
| 16 |
+
epochs : 100
|
| 17 |
+
batch_size : 64
|
| 18 |
+
|
| 19 |
+
scheduler:
|
| 20 |
+
T_max : 100
|
| 21 |
+
|
| 22 |
+
early_stopping:
|
| 23 |
+
patience : 10
|
| 24 |
+
|
| 25 |
+
wandb:
|
| 26 |
+
log_attention_every_n_epochs : 5
|
configs/subset_60k.yaml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 2 |
+
# configs/subset_60k.yaml
|
| 3 |
+
# Phase 2: sanity check / quick prototype on 60k subset.
|
| 4 |
+
# Use for code verification before full training.
|
| 5 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 6 |
+
|
| 7 |
+
defaults:
|
| 8 |
+
- base
|
| 9 |
+
|
| 10 |
+
experiment_name : "subset_60k"
|
| 11 |
+
|
| 12 |
+
data:
|
| 13 |
+
n_samples : 60000 # 60k random galaxies
|
| 14 |
+
|
| 15 |
+
training:
|
| 16 |
+
epochs : 30
|
| 17 |
+
|
| 18 |
+
scheduler:
|
| 19 |
+
T_max : 30
|
| 20 |
+
|
| 21 |
+
early_stopping:
|
| 22 |
+
patience : 7
|
| 23 |
+
|
| 24 |
+
wandb:
|
| 25 |
+
log_attention_every_n_epochs : 5
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn
|
| 3 |
+
python-multipart
|
| 4 |
+
pydantic
|
| 5 |
+
torch
|
| 6 |
+
torchvision
|
| 7 |
+
numpy
|
| 8 |
+
Pillow
|
| 9 |
+
omegaconf
|
| 10 |
+
timm
|
| 11 |
+
opencv-python-headless
|
src/__init__.py
ADDED
|
File without changes
|
src/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (128 Bytes). View file
|
|
|
src/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (132 Bytes). View file
|
|
|
src/__pycache__/ablation.cpython-310.pyc
ADDED
|
Binary file (7.39 kB). View file
|
|
|
src/__pycache__/attention_viz.cpython-310.pyc
ADDED
|
Binary file (9.72 kB). View file
|
|
|
src/__pycache__/baselines.cpython-310.pyc
ADDED
|
Binary file (22.4 kB). View file
|
|
|
src/__pycache__/baselines.cpython-312.pyc
ADDED
|
Binary file (39.1 kB). View file
|
|
|
src/__pycache__/dataset.cpython-310.pyc
ADDED
|
Binary file (7.46 kB). View file
|
|
|
src/__pycache__/evaluate_full.cpython-310.pyc
ADDED
|
Binary file (18.2 kB). View file
|
|
|
src/__pycache__/evaluate_full.cpython-312.pyc
ADDED
|
Binary file (30.7 kB). View file
|
|
|
src/__pycache__/loss.cpython-310.pyc
ADDED
|
Binary file (5.32 kB). View file
|
|
|
src/__pycache__/metrics.cpython-310.pyc
ADDED
|
Binary file (8.73 kB). View file
|
|
|
src/__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
src/__pycache__/model.cpython-312.pyc
ADDED
|
Binary file (16 kB). View file
|
|
|
src/__pycache__/train.cpython-310.pyc
ADDED
|
Binary file (8.82 kB). View file
|
|
|
src/__pycache__/train_single.cpython-310.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|
src/__pycache__/train_single.cpython-312.pyc
ADDED
|
Binary file (18.3 kB). View file
|
|
|
src/__pycache__/uncertainty_analysis.cpython-310.pyc
ADDED
|
Binary file (17.1 kB). View file
|
|
|
src/ablation.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
src/ablation.py
|
| 3 |
+
---------------
|
| 4 |
+
Lambda ablation study for the hierarchical KL + MSE loss.
|
| 5 |
+
|
| 6 |
+
Sweeps lambda_kl over [0.0, 0.25, 0.50, 0.75, 1.0] on a 10k subset
|
| 7 |
+
to justify the choice of lambda_kl = 0.5 used in the proposed model.
|
| 8 |
+
|
| 9 |
+
This ablation is reported in the paper as justification for the
|
| 10 |
+
balanced KL + MSE formulation. It is run BEFORE full training.
|
| 11 |
+
|
| 12 |
+
Output
|
| 13 |
+
------
|
| 14 |
+
outputs/figures/ablation/table_lambda_ablation.csv
|
| 15 |
+
outputs/figures/ablation/fig_lambda_ablation.pdf
|
| 16 |
+
outputs/figures/ablation/fig_lambda_ablation.png
|
| 17 |
+
|
| 18 |
+
Usage
|
| 19 |
+
-----
|
| 20 |
+
cd ~/galaxy
|
| 21 |
+
nohup python -m src.ablation --config configs/ablation.yaml \
|
| 22 |
+
> outputs/logs/ablation.log 2>&1 &
|
| 23 |
+
echo "PID: $!"
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import argparse
|
| 27 |
+
import copy
|
| 28 |
+
import logging
|
| 29 |
+
import random
|
| 30 |
+
import sys
|
| 31 |
+
import gc
|
| 32 |
+
from pathlib import Path
|
| 33 |
+
|
| 34 |
+
import numpy as np
|
| 35 |
+
import pandas as pd
|
| 36 |
+
import torch
|
| 37 |
+
import matplotlib
|
| 38 |
+
matplotlib.use("Agg")
|
| 39 |
+
import matplotlib.pyplot as plt
|
| 40 |
+
from torch.amp import autocast, GradScaler
|
| 41 |
+
from omegaconf import OmegaConf, DictConfig
|
| 42 |
+
from tqdm import tqdm
|
| 43 |
+
|
| 44 |
+
from src.dataset import build_dataloaders
|
| 45 |
+
from src.model import build_model
|
| 46 |
+
from src.loss import HierarchicalLoss
|
| 47 |
+
from src.metrics import compute_metrics, predictions_to_numpy
|
| 48 |
+
|
| 49 |
+
logging.basicConfig(
|
| 50 |
+
format="%(asctime)s %(levelname)s %(message)s",
|
| 51 |
+
datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
|
| 52 |
+
)
|
| 53 |
+
log = logging.getLogger("ablation")
|
| 54 |
+
|
| 55 |
+
LAMBDA_VALUES = [0.0, 0.25, 0.50, 0.75, 1.0]
|
| 56 |
+
ABLATION_EPOCHS = 15 # sufficient to converge on 10k subset
|
| 57 |
+
ABLATION_SAMPLES = 10000
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _set_seed(seed: int):
|
| 61 |
+
random.seed(seed)
|
| 62 |
+
np.random.seed(seed)
|
| 63 |
+
torch.manual_seed(seed)
|
| 64 |
+
torch.cuda.manual_seed_all(seed)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def run_single(cfg: DictConfig, lambda_kl: float) -> dict:
    """
    Train one model with the given lambda_kl on a 10k subset and
    return test metrics. All other settings are identical across runs.

    Parameters
    ----------
    cfg : DictConfig
        Merged experiment config; deep-copied so the caller's copy is
        never mutated across sweep iterations.
    lambda_kl : float
        KL weight; the MSE weight is set to its complement (1 - lambda_kl).

    Returns
    -------
    dict with the swept lambdas, best validation loss, and weighted
    MAE / RMSE / mean-ECE test metrics (rounded for the results table).
    """
    _set_seed(cfg.seed)  # re-seed per run so every lambda sees identical data order

    cfg = copy.deepcopy(cfg)
    cfg.loss.lambda_kl = lambda_kl
    cfg.loss.lambda_mse = 1.0 - lambda_kl  # weights are complementary by design
    cfg.data.n_samples = ABLATION_SAMPLES
    cfg.training.epochs = ABLATION_EPOCHS

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, test_loader = build_dataloaders(cfg)

    model = build_model(cfg).to(device)
    loss_fn = HierarchicalLoss(cfg)

    # Discriminative LRs: pretrained backbone at 10% of the head LR.
    optimizer = torch.optim.AdamW(
        [
            {"params": model.backbone.parameters(),
             "lr": cfg.training.learning_rate * 0.1},
            {"params": model.head.parameters(),
             "lr": cfg.training.learning_rate},
        ],
        weight_decay=cfg.training.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=ABLATION_EPOCHS, eta_min=1e-6
    )
    scaler = GradScaler("cuda")

    best_val = float("inf")
    best_state = None

    for epoch in range((1), ABLATION_EPOCHS + 1):
        # ── train ──────────────────────────────────────────────
        model.train()
        for images, targets, weights, _ in tqdm(
            train_loader, desc=f"λ={lambda_kl:.2f} E{epoch}", leave=False
        ):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            with autocast("cuda", enabled=True):
                logits = model(images)
                loss, _ = loss_fn(logits, targets, weights)
            # AMP pattern: scale → backward → unscale → clip → step → update.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        # NOTE(review): stepped once per epoch, consistent with
        # T_max=ABLATION_EPOCHS above — confirm against original indentation.
        scheduler.step()

        # ── validate ───────────────────────────────────────────
        model.eval()
        val_loss = 0.0
        nb = 0
        with torch.no_grad():
            for images, targets, weights, _ in val_loader:
                images = images.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                weights = weights.to(device, non_blocking=True)
                with autocast("cuda", enabled=True):
                    logits = model(images)
                    loss, _ = loss_fn(logits, targets, weights)
                val_loss += loss.item()
                nb += 1
        val_loss /= nb
        log.info(" λ_kl=%.2f epoch=%d val_loss=%.5f", lambda_kl, epoch, val_loss)

        # Keep a CPU copy of the best weights so GPU memory is not pinned.
        if val_loss < best_val:
            best_val = val_loss
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    # ── test evaluation ────────────────────────────────────────
    model.load_state_dict(best_state)
    model.eval()

    all_preds, all_targets, all_weights = [], [], []
    with torch.no_grad():
        for images, targets, weights, _ in test_loader:
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)
            with autocast("cuda", enabled=True):
                logits = model(images)
            p, t, w = predictions_to_numpy(logits, targets, weights)
            all_preds.append(p)
            all_targets.append(t)
            all_weights.append(w)

    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    all_weights = np.concatenate(all_weights)
    metrics = compute_metrics(all_preds, all_targets, all_weights)

    return {
        "lambda_kl" : lambda_kl,
        "lambda_mse" : round(1.0 - lambda_kl, 2),
        "best_val_loss": round(best_val, 5),
        "mae_weighted" : round(metrics["mae/weighted_avg"], 5),
        "rmse_weighted": round(metrics["rmse/weighted_avg"], 5),
        "ece_mean" : round(metrics["ece/mean"], 5),
    }
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _plot_ablation(df: pd.DataFrame, save_dir: Path):
    """Plot weighted MAE / RMSE / mean ECE vs lambda_kl and save PDF + PNG.

    The best lambda (minimum weighted MAE) is marked with a dashed
    vertical line on each panel. Figures go to ``save_dir``.
    """
    # Row of the sweep with the lowest weighted MAE — used for annotations.
    best_row = df.loc[df["mae_weighted"].idxmin()]

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # (column name, y-axis label, line color) for each panel.
    metrics_cfg = [
        ("mae_weighted", "Weighted MAE", "#2980b9"),
        ("rmse_weighted", "Weighted RMSE", "#c0392b"),
        ("ece_mean", "Mean ECE", "#27ae60"),
    ]

    for ax, (col, ylabel, color) in zip(axes, metrics_cfg):
        ax.plot(df["lambda_kl"], df[col], "-o", color=color,
                linewidth=2, markersize=8)
        ax.axvline(best_row["lambda_kl"], color="#7f8c8d",
                   linestyle="--", alpha=0.8,
                   label=f"Best λ = {best_row['lambda_kl']:.2f}")
        ax.set_xlabel("$\\lambda_{\\mathrm{KL}}$ "
                      "(0 = pure MSE, 1 = pure KL)", fontsize=11)
        ax.set_ylabel(ylabel, fontsize=11)
        ax.set_title(f"Lambda ablation — {ylabel}", fontsize=10)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)
        # Tick exactly at the swept lambda values.
        ax.set_xticks(df["lambda_kl"].tolist())

    plt.suptitle(
        "Ablation study: effect of $\\lambda_{\\mathrm{KL}}$ in the hierarchical loss\n"
        f"10,000-sample subset, seed=42. Best: $\\lambda_{{\\mathrm{{KL}}}}$"
        f" = {best_row['lambda_kl']:.2f} (MAE = {best_row['mae_weighted']:.5f})",
        fontsize=11, y=1.02,
    )
    plt.tight_layout()
    fig.savefig(save_dir / "fig_lambda_ablation.pdf", dpi=300, bbox_inches="tight")
    fig.savefig(save_dir / "fig_lambda_ablation.png", dpi=300, bbox_inches="tight")
    plt.close(fig)  # release the figure to keep memory flat across calls
    log.info("Saved: fig_lambda_ablation")
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def main():
    """Run the full lambda_kl sweep, save the results table, and plot.

    Loads base + experiment configs (merged with OmegaConf), trains one
    model per value in LAMBDA_VALUES, writes a CSV table and the
    three-panel ablation figure under ``<figures_dir>/ablation``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()

    # Manual merge: experiment config overrides base config keys.
    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)

    save_dir = Path(cfg.outputs.figures_dir) / "ablation"
    save_dir.mkdir(parents=True, exist_ok=True)

    results = []
    for lam in LAMBDA_VALUES:
        log.info("=" * 55)
        log.info("Ablation: lambda_kl=%.2f lambda_mse=%.2f",
                 lam, 1.0 - lam)
        log.info("=" * 55)

        result = run_single(cfg, lam)
        results.append(result)
        log.info("Result: %s", result)

        # Free up RAM and GPU memory between sweep points.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    df = pd.DataFrame(results)
    df.to_csv(save_dir / "table_lambda_ablation.csv", index=False)
    log.info("Saved: table_lambda_ablation.csv")

    print()
    print(df.to_string(index=False))
    print()

    # Report the winning lambda by weighted MAE (primary metric).
    best = df.loc[df["mae_weighted"].idxmin()]
    log.info("Best: lambda_kl=%.2f MAE=%.5f RMSE=%.5f",
             best["lambda_kl"], best["mae_weighted"], best["rmse_weighted"])

    _plot_ablation(df, save_dir)


if __name__ == "__main__":
    main()
|
src/attention_viz.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
src/attention_viz.py
|
| 3 |
+
--------------------
|
| 4 |
+
Full multi-layer attention rollout for ViT explainability.
|
| 5 |
+
|
| 6 |
+
Theory β Abnar & Zuidema (2020)
|
| 7 |
+
--------------------------------
|
| 8 |
+
Each ViT transformer block l produces attention weights A_l of shape
|
| 9 |
+
[B, H, N+1, N+1], where H=12 heads and N+1=197 tokens (196 patches
|
| 10 |
+
+ 1 CLS token).
|
| 11 |
+
|
| 12 |
+
Full rollout algorithm:
|
| 13 |
+
1. Average over heads: A_l = mean_h(attn_l) [B, N+1, N+1]
|
| 14 |
+
2. Add residual: A_l = 0.5*A_l + 0.5*I [B, N+1, N+1]
|
| 15 |
+
3. Row-normalise so attention sums to 1 per token.
|
| 16 |
+
4. Chain layers: R = A_1 β A_2 β ... β A_12 [B, N+1, N+1]
|
| 17 |
+
5. CLS row, patch cols: rollout = R[:, 0, 1:] [B, 196]
|
| 18 |
+
6. Reshape 196 β 14Γ14, upsample to 224Γ224.
|
| 19 |
+
|
| 20 |
+
FIX applied vs. original
|
| 21 |
+
--------------------------
|
| 22 |
+
The original code used R = bmm(A, R) (left-multiplication) which
|
| 23 |
+
accumulates attention in reverse order. The correct propagation per
|
| 24 |
+
Abnar & Zuidema is R = bmm(R, A) (right-multiplication), which
|
| 25 |
+
tracks how information from the INPUT patches flows forward through
|
| 26 |
+
successive layers into the CLS token.
|
| 27 |
+
|
| 28 |
+
Entropy interpretation
|
| 29 |
+
-----------------------
|
| 30 |
+
CLS attention entropy INCREASES from early to late layers. This is
|
| 31 |
+
the expected and correct behaviour for ViT classification:
|
| 32 |
+
- Early layers (1β8): entropy is low and stable (~1.7β2.0 nats),
|
| 33 |
+
consistent with local morphological feature detection.
|
| 34 |
+
- Late layers (9β12): entropy rises sharply (~2.7β4.5 nats),
|
| 35 |
+
consistent with the CLS token performing global integration β
|
| 36 |
+
aggregating information from all patches before the regression head.
|
| 37 |
+
This pattern confirms that early layers specialise in local structure
|
| 38 |
+
while late layers globally aggregate morphological information for
|
| 39 |
+
the final prediction.
|
| 40 |
+
|
| 41 |
+
References
|
| 42 |
+
----------
|
| 43 |
+
Abnar & Zuidema (2020). Quantifying Attention Flow in Transformers.
|
| 44 |
+
ACL 2020. https://arxiv.org/abs/2005.00928
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
from __future__ import annotations
|
| 48 |
+
|
| 49 |
+
import numpy as np
|
| 50 |
+
import torch
|
| 51 |
+
import torch.nn.functional as F
|
| 52 |
+
import matplotlib
|
| 53 |
+
matplotlib.use("Agg")
|
| 54 |
+
import matplotlib.pyplot as plt
|
| 55 |
+
import matplotlib.cm as cm
|
| 56 |
+
from pathlib import Path
|
| 57 |
+
from typing import Optional, List
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 61 |
+
# Full multi-layer rollout (FIXED)
|
| 62 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 63 |
+
|
| 64 |
+
def attention_rollout_full(
    all_attn_weights: List[torch.Tensor],
    patch_size: int = 16,
    image_size: int = 224,
) -> np.ndarray:
    """
    Multi-layer attention rollout (Abnar & Zuidema, 2020).

    For each layer: average heads, mix in a 0.5 residual identity,
    row-normalise, then right-multiply into the running rollout so that
    information propagates forward from input patches to the CLS token.
    The CLS row over patch columns is reshaped to the patch grid,
    bilinearly upsampled to the image size, and min-max normalised
    per sample.

    Parameters
    ----------
    all_attn_weights : list of L tensors, each [B, H, N+1, N+1],
        ordered from the first to the last transformer layer.
    patch_size : ViT patch size (16 for ViT-Base/16).
    image_size : input image side length (224).

    Returns
    -------
    [B, image_size, image_size] float32 array with values in [0, 1].
    """
    assert len(all_attn_weights) > 0, "Need at least one attention layer"

    batch, _, n_tokens, _ = all_attn_weights[0].shape
    dev = all_attn_weights[0].device

    # R_0 = I, one copy per batch element.
    identity = torch.eye(n_tokens, device=dev)
    rollout = identity.unsqueeze(0).expand(batch, -1, -1).clone()

    for layer_attn in all_attn_weights:
        # Head-average, then blend with identity (residual connection).
        mixed = 0.5 * layer_attn.mean(dim=1) + 0.5 * identity.unsqueeze(0)
        # Row-normalise so each token's attention sums to 1.
        mixed = mixed / mixed.sum(dim=-1, keepdim=True).clamp(min=1e-8)
        # Right-multiplication: forward propagation input -> CLS.
        rollout = torch.bmm(rollout, mixed)

    # CLS row (token 0) attending to the patch tokens (1:).
    cls_to_patches = rollout[:, 0, 1:]

    grid = image_size // patch_size
    maps = F.interpolate(
        cls_to_patches.reshape(batch, 1, grid, grid),
        size=(image_size, image_size),
        mode="bilinear",
        align_corners=False,
    ).squeeze(1)

    # Per-sample min-max normalisation into [0, 1].
    out = maps.cpu().numpy()
    for b in range(batch):
        lo, hi = out[b].min(), out[b].max()
        out[b] = (out[b] - lo) / (hi - lo + 1e-8)

    return out.astype(np.float32)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def attention_rollout_single_layer(
    attn_weights: torch.Tensor,
    patch_size: int = 16,
    image_size: int = 224,
) -> np.ndarray:
    """Rollout from a single layer's weights; kept for backward
    compatibility. Prefer the full multi-layer rollout."""
    single = [attn_weights]
    return attention_rollout_full(
        single, patch_size=patch_size, image_size=image_size
    )
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 138 |
+
# Visualisation utilities
|
| 139 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 140 |
+
|
| 141 |
+
def denormalise_image(tensor: torch.Tensor) -> np.ndarray:
    """Undo ImageNet normalisation; return a uint8 [H, W, 3] image."""
    imagenet_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    imagenet_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    chw = tensor.cpu().numpy()
    hwc = np.transpose(chw, (1, 2, 0))
    restored = np.clip(hwc * imagenet_std + imagenet_mean, 0, 1)
    return (restored * 255).astype(np.uint8)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def make_overlay(
    image_np: np.ndarray,
    rollout: np.ndarray,
    alpha: float = 0.5,
    colormap: str = "inferno",
) -> np.ndarray:
    """
    Blend an attention heatmap onto a galaxy image.

    Parameters
    ----------
    image_np : [H, W, 3] uint8 image (as produced by denormalise_image)
    rollout : [H, W] attention map in [0, 1]
    alpha : heatmap opacity (0 = image only, 1 = heatmap only)
    colormap : matplotlib colormap name

    Returns
    -------
    [H, W, 3] uint8 blended overlay.
    """
    # FIX: matplotlib.cm.get_cmap was deprecated in 3.7 and removed in
    # 3.9; use the colormap registry, with a fallback for old versions.
    try:
        cmap = matplotlib.colormaps[colormap]
    except AttributeError:  # matplotlib < 3.5: registry not available
        cmap = cm.get_cmap(colormap)
    # Drop the alpha channel of the RGBA colormap output.
    heatmap = (cmap(rollout)[:, :, :3] * 255).astype(np.uint8)
    overlay = (
        (1 - alpha) * image_np.astype(np.float32) +
        alpha * heatmap.astype(np.float32)
    ).clip(0, 255).astype(np.uint8)
    return overlay
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def plot_attention_grid(
    images: torch.Tensor,
    attn_weights,
    image_ids: list,
    save_path: Optional[str] = None,
    alpha: float = 0.5,
    n_cols: int = 4,
    rollout_mode: str = "full",
) -> plt.Figure:
    """
    Publication-quality attention rollout gallery.

    Each galaxy occupies two stacked panels: the raw image on top and
    the attention overlay directly below it.

    Parameters
    ----------
    images : [N, 3, H, W] galaxy image tensors (ImageNet-normalised)
    attn_weights : list of L tensors [N, H, N+1, N+1] (full mode)
                   or single tensor [N, H, N+1, N+1] (single mode)
    image_ids : dr7objid list for panel titles
    save_path : optional file path to save the figure
    alpha : heatmap opacity (0 = image only, 1 = heatmap only)
    n_cols : number of columns in the grid
    rollout_mode : "full" for multi-layer rollout (recommended)
    """
    N = images.shape[0]

    if rollout_mode == "full" and isinstance(attn_weights, list):
        rollout_maps = attention_rollout_full(attn_weights)
    else:
        # Single-layer mode: when a per-layer list is supplied, fall
        # back to the last layer.
        if isinstance(attn_weights, list):
            attn_weights = attn_weights[-1]
        rollout_maps = attention_rollout_single_layer(attn_weights)

    n_rows = int(np.ceil(N / n_cols))
    fig, axes = plt.subplots(
        n_rows * 2, n_cols,
        figsize=(n_cols * 3, n_rows * 6),
        facecolor="black",
    )
    axes = np.asarray(axes).flatten()

    # FIX: switch all frames off up front. The original loop iterated
    # range(N, n_rows * n_cols), which indexes a single n_rows x n_cols
    # grid rather than the flattened 2*n_rows x n_cols axes array, so
    # the genuinely empty panels in the last row pair kept their frames.
    # imshow on an axis("off") axes still draws the image.
    for ax in axes:
        ax.axis("off")

    for i in range(N):
        img_np = denormalise_image(images[i])
        overlay = make_overlay(img_np, rollout_maps[i], alpha=alpha)

        # Image panel in an even row; its overlay directly below.
        row_base = (i // n_cols) * 2
        col = i % n_cols
        ax_img = axes[row_base * n_cols + col]
        ax_attn = axes[(row_base + 1) * n_cols + col]

        ax_img.imshow(img_np)
        ax_img.set_title(str(image_ids[i])[-6:], color="white",
                         fontsize=7, pad=2)
        ax_attn.imshow(overlay)

    mode_label = "Full 12-layer rollout" if rollout_mode == "full" else "Last-layer rollout"
    # FIX: mojibake dash in the title string replaced with "-".
    plt.suptitle(
        f"Galaxy attention rollout - {mode_label} (ViT-Base/16)",
        color="white", fontsize=10, y=1.01
    )
    plt.tight_layout(pad=0.3)

    if save_path is not None:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches="tight", facecolor="black")

    return fig
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 242 |
+
# Attention entropy per layer
|
| 243 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 244 |
+
|
| 245 |
+
def compute_attention_entropy_per_layer(
    all_attn_weights: List[torch.Tensor],
) -> np.ndarray:
    """
    Mean CLS attention entropy per transformer layer, in nats.

    Interpretation
    --------------
    Early layers show low, stable entropy, consistent with local
    morphological feature detection; late layers show sharply rising
    entropy as the CLS token attends broadly across all patches to
    aggregate global morphological evidence before the regression
    head. Higher entropy means broader (less peaked) attention.

    Parameters
    ----------
    all_attn_weights : list of L tensors, each [B, H, N+1, N+1]

    Returns
    -------
    [L] float32 array of mean entropies, one per layer.
    """
    def _mean_cls_entropy(attn: torch.Tensor) -> float:
        # CLS-token attention to the patch tokens: [B, H, N_patches].
        p = attn[:, :, 0, 1:].clamp(min=1e-9)
        return (-(p * p.log()).sum(dim=-1)).mean().item()

    per_layer = [_mean_cls_entropy(a) for a in all_attn_weights]
    return np.array(per_layer, dtype=np.float32)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def plot_attention_entropy(
    all_attn_weights: List[torch.Tensor],
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Plot mean CLS attention entropy against transformer depth.

    Parameters
    ----------
    all_attn_weights : list of L tensors [B, H, N+1, N+1], one per layer
    save_path : optional path; parent directories are created if needed

    Returns
    -------
    fig : the matplotlib Figure (also saved if save_path is given)
    """
    entropies = compute_attention_entropy_per_layer(all_attn_weights)
    L = len(entropies)

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(range(1, L + 1), entropies, "b-o", markersize=6, linewidth=2)

    # Shade regions for interpretation.
    # FIX: mojibake dashes in the legend labels ("1β8", "9β12")
    # replaced with plain "-".
    ax.axvspan(1, 8.5, alpha=0.07, color="blue",
               label="Local feature detection (layers 1-8)")
    ax.axvspan(8.5, L + 0.5, alpha=0.07, color="orange",
               label="Global integration (layers 9-12)")

    ax.set_xlabel("Transformer layer", fontsize=12)
    ax.set_ylabel("Mean CLS attention entropy (nats)", fontsize=12)
    ax.set_title(
        "CLS token attention entropy vs. transformer depth\n"
        "Early layers: local morphological detection | "
        "Late layers: global aggregation",
        fontsize=10,
    )
    ax.set_xticks(range(1, L + 1))
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()

    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches="tight")

    return fig
|
src/baselines.py
ADDED
|
@@ -0,0 +1,844 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
src/baselines.py
|
| 3 |
+
----------------
|
| 4 |
+
Consolidated baseline training for the GZ2 hierarchical probabilistic
|
| 5 |
+
regression paper. ALL baselines are trained from this single script.
|
| 6 |
+
|
| 7 |
+
Replaces the three separate scripts:
|
| 8 |
+
src/baselines.py (was: ResNet-18 MSE + ViT MSE)
|
| 9 |
+
src/run_resnet_kl.py (was: ResNet-18 KL+MSE β now merged here)
|
| 10 |
+
src/train_dirichlet.py (was: ViT Dirichlet β now merged here)
|
| 11 |
+
|
| 12 |
+
DELETE those three original files after switching to this one.
|
| 13 |
+
|
| 14 |
+
Baselines trained
|
| 15 |
+
-----------------
|
| 16 |
+
B1. ResNet-18 + independent MSE (sigmoid)
|
| 17 |
+
β CNN, no hierarchy, no KL. Demonstrates the cost of
|
| 18 |
+
ignoring the decision-tree structure.
|
| 19 |
+
|
| 20 |
+
B2. ResNet-18 + hierarchical KL+MSE
|
| 21 |
+
β Same loss as proposed, CNN backbone.
|
| 22 |
+
Isolates ViT vs. CNN contribution.
|
| 23 |
+
|
| 24 |
+
B3. ViT-Base + hierarchical MSE only (no KL)
|
| 25 |
+
β Same backbone as proposed, KL term removed.
|
| 26 |
+
Isolates contribution of the KL term.
|
| 27 |
+
|
| 28 |
+
B4. ViT-Base + Dirichlet NLL (Zoobot-style)
|
| 29 |
+
β Direct comparison with the established Zoobot approach
|
| 30 |
+
(Walmsley et al. 2022, MNRAS 509, 3966).
|
| 31 |
+
|
| 32 |
+
Proposed model (not trained here β trained via src/train.py):
|
| 33 |
+
ViT-Base + hierarchical KL+MSE β outputs/checkpoints/best_full_train.pt
|
| 34 |
+
|
| 35 |
+
Consistency guarantee
|
| 36 |
+
---------------------
|
| 37 |
+
All baselines use identical:
|
| 38 |
+
- Random seed, data split, batch size, epochs, early stopping
|
| 39 |
+
- AdamW optimiser, CosineAnnealingLR, gradient clipping
|
| 40 |
+
- Image transforms and evaluation metric (compute_metrics on same test split)
|
| 41 |
+
|
| 42 |
+
The ONLY differences between models are the backbone and/or loss function.
|
| 43 |
+
|
| 44 |
+
Usage
|
| 45 |
+
-----
|
| 46 |
+
cd ~/galaxy
|
| 47 |
+
nohup python -m src.baselines --config configs/full_train.yaml \
|
| 48 |
+
> outputs/logs/baselines.log 2>&1 &
|
| 49 |
+
echo "PID: $!"
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
import argparse
|
| 53 |
+
import logging
|
| 54 |
+
import random
|
| 55 |
+
import sys
|
| 56 |
+
from pathlib import Path
|
| 57 |
+
|
| 58 |
+
import numpy as np
|
| 59 |
+
import pandas as pd
|
| 60 |
+
import torch
|
| 61 |
+
import timm
|
| 62 |
+
import torch.nn as nn
|
| 63 |
+
import torch.nn.functional as F
|
| 64 |
+
import matplotlib
|
| 65 |
+
matplotlib.use("Agg")
|
| 66 |
+
import matplotlib.pyplot as plt
|
| 67 |
+
from torch.amp import autocast, GradScaler
|
| 68 |
+
from omegaconf import OmegaConf
|
| 69 |
+
from tqdm import tqdm
|
| 70 |
+
|
| 71 |
+
import wandb
|
| 72 |
+
|
| 73 |
+
from src.dataset import build_dataloaders, QUESTION_GROUPS
|
| 74 |
+
from src.loss import HierarchicalLoss, DirichletLoss, MSEOnlyLoss
|
| 75 |
+
from src.metrics import (compute_metrics, predictions_to_numpy,
|
| 76 |
+
dirichlet_predictions_to_numpy, simplex_violation_rate)
|
| 77 |
+
from src.model import build_model, build_dirichlet_model
|
| 78 |
+
|
| 79 |
+
# Log to stdout so nohup/redirected log files capture progress lines.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
)
log = logging.getLogger("baselines")

# Human-readable names for the GZ2 decision-tree questions t01-t11,
# used for reporting/plot labels.
QUESTION_LABELS = {
    "t01": "Smooth or features", "t02": "Edge-on disk",
    "t03": "Bar", "t04": "Spiral arms",
    "t05": "Bulge prominence", "t06": "Odd feature",
    "t07": "Roundedness", "t08": "Odd feature type",
    "t09": "Bulge shape", "t10": "Arms winding",
    "t11": "Arms number",
}
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 96 |
+
# Reproducibility
|
| 97 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 98 |
+
|
| 99 |
+
def set_seed(seed: int):
    """Seed every RNG source and force deterministic cuDNN behaviour."""
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 109 |
+
# Early stopping (mirrors train.py exactly)
|
| 110 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 111 |
+
|
| 112 |
+
class EarlyStopping:
    """Track validation loss, checkpoint improvements, and signal stop.

    An epoch counts as an improvement when val_loss drops below the
    running best by more than min_delta; the model state is then saved
    to checkpoint_path. After `patience` consecutive non-improving
    epochs, step() returns True.
    """

    def __init__(self, patience, min_delta, checkpoint_path):
        self.patience = patience
        self.min_delta = min_delta
        self.checkpoint_path = checkpoint_path
        self.best_loss = float("inf")
        self.counter = 0
        self.best_epoch = 0

    def step(self, val_loss, model, epoch) -> bool:
        """Record this epoch's val_loss; return True to stop training."""
        improved = val_loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
            state = {"epoch": epoch, "model_state": model.state_dict(),
                     "val_loss": val_loss}
            torch.save(state, self.checkpoint_path)
            log.info(" [ckpt] saved val_loss=%.6f epoch=%d", val_loss, epoch)
        else:
            self.counter += 1
            log.info(" [early_stop] %d/%d best=%.6f",
                     self.counter, self.patience, self.best_loss)
        return self.counter >= self.patience

    def restore_best(self, model) -> float:
        """Reload the best checkpoint into model; return its val_loss."""
        ckpt = torch.load(self.checkpoint_path, map_location="cpu",
                          weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        log.info("Restored best weights epoch=%d val_loss=%.6f",
                 ckpt["epoch"], ckpt["val_loss"])
        return ckpt["val_loss"]
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 148 |
+
# Baseline Model 1: ResNet-18 + independent MSE
|
| 149 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 150 |
+
|
| 151 |
+
class ResNet18Baseline(nn.Module):
    """ImageNet-pretrained ResNet-18 with a dropout + 37-way linear head.

    Shared backbone for both the sigmoid-MSE baseline and the
    hierarchical KL+MSE baseline.
    """

    def __init__(self, dropout: float = 0.3):
        super().__init__()
        # num_classes=0 strips timm's classifier; backbone emits features.
        self.backbone = timm.create_model(
            "resnet18", pretrained=True, num_classes=0
        )
        feature_dim = self.backbone.num_features
        self.head = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(feature_dim, 37),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.backbone(x)
        return self.head(features)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class IndependentMSELoss(nn.Module):
    """Plain MSE over all 37 targets, each treated independently.

    A sigmoid bounds predictions to [0, 1] before the MSE. There is no
    hierarchical weighting and no KL term, so predictions do NOT sum to
    1 within each question group by construction; the
    simplex_violation_rate metric quantifies that invalidity so the
    baseline can be compared fairly with the proposed method.
    """

    def forward(self, predictions, targets, weights):
        squashed = torch.sigmoid(predictions)
        mse = F.mse_loss(squashed, targets)
        return mse, {"loss/total": mse.detach().item()}
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 187 |
+
# Shared training loop
|
| 188 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 189 |
+
|
| 190 |
+
def _train_epoch(model, loader, loss_fn, optimizer, scaler,
                 device, cfg, epoch, label):
    """One mixed-precision training epoch; returns the mean batch loss.

    `loader` yields (images, targets, weights, _); the fourth element is
    ignored here (presumably a sample identifier — confirm in dataset.py).
    """
    model.train()
    total = 0.0
    nb = 0
    for images, targets, weights, _ in tqdm(
        loader, desc=f"{label} E{epoch}", leave=False
    ):
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        weights = weights.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        # Forward + loss under autocast; backward outside it.
        with autocast("cuda", enabled=cfg.training.mixed_precision):
            logits = model(images)
            loss, _ = loss_fn(logits, targets, weights)
        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to the true
        # gradient magnitudes, not the loss-scaled ones.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.training.grad_clip)
        scaler.step(optimizer)
        scaler.update()

        total += loss.item()
        nb += 1
    return total / nb
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def _train_epoch_dirichlet(model, loader, loss_fn, optimizer, scaler,
                           device, cfg, epoch, label):
    """Training epoch for the Dirichlet model, which outputs Dirichlet
    concentration parameters (alpha) instead of logits; otherwise
    identical to _train_epoch (AMP, grad clipping, mean batch loss)."""
    model.train()
    total = 0.0
    nb = 0
    for images, targets, weights, _ in tqdm(
        loader, desc=f"{label} E{epoch}", leave=False
    ):
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        weights = weights.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with autocast("cuda", enabled=cfg.training.mixed_precision):
            alpha = model(images)
            loss, _ = loss_fn(alpha, targets, weights)
        scaler.scale(loss).backward()
        # Unscale before clipping so the clip threshold applies to the
        # true gradient magnitudes.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.training.grad_clip)
        scaler.step(optimizer)
        scaler.update()

        total += loss.item()
        nb += 1
    return total / nb
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def _val_epoch(model, loader, loss_fn, device, cfg, epoch, label,
               use_sigmoid=False):
    """Validation epoch: returns (mean batch loss, compute_metrics dict).

    use_sigmoid=True matches the independent-MSE baseline (per-output
    sigmoid); otherwise probabilities are formed by a per-question
    softmax over each QUESTION_GROUPS slice, matching the hierarchical
    losses. Metrics are computed once over all concatenated batches.
    """
    model.eval()
    total = 0.0
    nb = 0
    all_preds, all_targets, all_weights = [], [], []

    with torch.no_grad():
        for images, targets, weights, _ in tqdm(
            loader, desc=f"{label} Val E{epoch}", leave=False
        ):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)

            with autocast("cuda", enabled=cfg.training.mixed_precision):
                logits = model(images)
                loss, _ = loss_fn(logits, targets, weights)

            total += loss.item()
            nb += 1

            if use_sigmoid:
                # Independent-MSE baseline: elementwise sigmoid.
                pred_prob = torch.sigmoid(logits).detach().cpu().numpy()
            else:
                # Hierarchical models: softmax within each question group.
                # clone() so the in-place slice writes don't touch logits.
                pred_cpu = logits.detach().cpu().clone()
                for q, (s, e) in QUESTION_GROUPS.items():
                    pred_cpu[:, s:e] = torch.softmax(pred_cpu[:, s:e], dim=-1)
                pred_prob = pred_cpu.numpy()

            all_preds.append(pred_prob)
            all_targets.append(targets.detach().cpu().numpy())
            all_weights.append(weights.detach().cpu().numpy())

    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    all_weights = np.concatenate(all_weights)
    metrics = compute_metrics(all_preds, all_targets, all_weights)

    return total / nb, metrics
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def _val_epoch_dirichlet(model, loader, loss_fn, device, cfg, epoch, label):
    """Validation epoch for the Dirichlet model.

    Same structure as _val_epoch, but the model emits Dirichlet
    concentration parameters (alpha) and predictions are converted to
    probabilities via dirichlet_predictions_to_numpy before metrics.
    Returns (mean batch loss, compute_metrics dict).
    """
    model.eval()
    total = 0.0
    nb = 0
    all_preds, all_targets, all_weights = [], [], []

    with torch.no_grad():
        for images, targets, weights, _ in tqdm(
            loader, desc=f"{label} Val E{epoch}", leave=False
        ):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)

            with autocast("cuda", enabled=cfg.training.mixed_precision):
                alpha = model(images)
                loss, _ = loss_fn(alpha, targets, weights)

            total += loss.item()
            nb += 1

            # Convert alpha -> expected probabilities (and move to numpy).
            p, t, w = dirichlet_predictions_to_numpy(alpha, targets, weights)
            all_preds.append(p)
            all_targets.append(t)
            all_weights.append(w)

    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    all_weights = np.concatenate(all_weights)
    metrics = compute_metrics(all_preds, all_targets, all_weights)

    return total / nb, metrics
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 322 |
+
# Generic train_and_evaluate (non-Dirichlet)
|
| 323 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 324 |
+
|
| 325 |
+
def train_and_evaluate(
    model, loss_fn, cfg, device,
    label, checkpoint_path,
    use_layerwise_lr=True,
    use_sigmoid=False,
):
    """
    Full training loop for non-Dirichlet baselines, consistent with
    train.py (AdamW, CosineAnnealingLR, AMP, early stopping, wandb).

    Returns (test_metrics, best_val_loss, best_epoch, history).
    If checkpoint_path already exists, training is skipped: the
    checkpoint is loaded and only the test-set evaluation runs (history
    is returned empty in that case).

    Parameters
    ----------
    use_layerwise_lr : if True and the model exposes `backbone`/`head`
        attributes, the backbone trains at 0.1x the head learning rate.
    use_sigmoid : forwarded to _val_epoch (independent-MSE baseline).
    """
    # Check if checkpoint exists - if so, skip training
    if Path(checkpoint_path).exists():
        log.info("%s: checkpoint found - loading and skipping training", label)
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        best_epoch = ckpt.get("epoch", 0)
        best_val = ckpt.get("val_loss", float("inf"))
        log.info("Restored: epoch=%d, val_loss=%.6f", best_epoch, best_val)

        # Evaluate on test set
        _, _, test_loader = build_dataloaders(cfg)
        _, test_metrics = _val_epoch(
            model, test_loader, loss_fn, device, cfg,
            epoch=0, label=f"{label}-test", use_sigmoid=use_sigmoid
        )
        return test_metrics, best_val, best_epoch, []

    train_loader, val_loader, test_loader = build_dataloaders(cfg)

    # Layer-wise LR: pretrained backbone at 0.1x, fresh head at full LR.
    if use_layerwise_lr and hasattr(model, "backbone") and hasattr(model, "head"):
        optimizer = torch.optim.AdamW(
            [
                {"params": model.backbone.parameters(),
                 "lr": cfg.training.learning_rate * 0.1},
                {"params": model.head.parameters(),
                 "lr": cfg.training.learning_rate},
            ],
            weight_decay=cfg.training.weight_decay,
        )
        log.info("%s: layer-wise lr β backbone=%.1e head=%.1e",
                 label, cfg.training.learning_rate * 0.1, cfg.training.learning_rate)
    else:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=cfg.training.learning_rate,
            weight_decay=cfg.training.weight_decay,
        )
        log.info("%s: single lr=%.1e", label, cfg.training.learning_rate)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
    )
    scaler = GradScaler("cuda")
    early_stop = EarlyStopping(
        patience=cfg.early_stopping.patience,
        min_delta=cfg.early_stopping.min_delta,
        checkpoint_path=checkpoint_path,
    )

    wandb.init(
        project=cfg.wandb.project,
        name=label,
        config={
            "model": label, "backbone": "resnet18" if "ResNet" in label else "vit_base_patch16_224",
            "batch_size": cfg.training.batch_size, "lr": cfg.training.learning_rate,
            "epochs": cfg.training.epochs, "seed": cfg.seed,
            "lambda_kl": cfg.loss.lambda_kl, "lambda_mse": cfg.loss.lambda_mse,
        },
        reinit=True,
    )

    history = []
    for epoch in range(1, cfg.training.epochs + 1):
        train_loss = _train_epoch(
            model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch, label
        )
        val_loss, val_metrics = _val_epoch(
            model, val_loader, loss_fn, device, cfg, epoch, label,
            use_sigmoid=use_sigmoid
        )
        # Epoch-level cosine schedule step.
        scheduler.step()
        lr = scheduler.get_last_lr()[0]

        val_mae = val_metrics.get("mae/weighted_avg", 0)
        log.info("%s epoch=%d train=%.4f val=%.4f mae=%.4f lr=%.2e",
                 label, epoch, train_loss, val_loss, val_mae, lr)
        history.append({
            "epoch": epoch, "train_loss": train_loss,
            "val_loss": val_loss, "val_mae": val_mae,
        })
        wandb.log({
            "train_loss": train_loss, "val_loss": val_loss,
            "val_mae": val_mae, "lr": lr,
        }, step=epoch)

        # EarlyStopping also checkpoints the best model internally.
        if early_stop.step(val_loss, model, epoch):
            log.info("%s: early stopping at epoch %d best=%d",
                     label, epoch, early_stop.best_epoch)
            break

    # Reload the best checkpoint before the final test evaluation.
    best_val = early_stop.restore_best(model)
    wandb.finish()

    log.info("%s: evaluating on test set...", label)
    _, test_metrics = _val_epoch(
        model, test_loader, loss_fn, device, cfg,
        epoch=0, label=f"{label}-test", use_sigmoid=use_sigmoid
    )
    return test_metrics, best_val, early_stop.best_epoch, history
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 438 |
+
# Dirichlet train_and_evaluate
|
| 439 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 440 |
+
|
| 441 |
+
def train_and_evaluate_dirichlet(model, loss_fn, cfg, device,
                                 label, checkpoint_path):
    """Training loop for the Dirichlet baseline. Skips training if checkpoint exists.

    Parameters
    ----------
    model : Dirichlet model with ``.backbone`` and ``.head`` submodules
        (backbone gets a 10x lower learning rate than the head).
    loss_fn : Dirichlet NLL criterion passed through to the epoch helpers.
    cfg : merged OmegaConf config (training / scheduler / early_stopping / wandb).
    device : torch.device used by the epoch helpers.
    label : run name used for logging and the wandb run.
    checkpoint_path : str path checked for resume and given to EarlyStopping.

    Returns
    -------
    (test_metrics, best_val_loss, best_epoch, history) — ``history`` is a list
    of per-epoch dicts, empty when training was skipped via checkpoint.
    """
    # Fast path: a finished checkpoint means we only evaluate, never retrain.
    if Path(checkpoint_path).exists():
        log.info("%s: checkpoint found - loading and skipping training", label)
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        best_epoch = ckpt.get("epoch", 0)
        best_val = ckpt.get("val_loss", float("inf"))
        log.info("Restored: epoch=%d, val_loss=%.6f", best_epoch, best_val)

        # Evaluate on test set only (train/val loaders are discarded).
        _, _, test_loader = build_dataloaders(cfg)
        _, test_metrics = _val_epoch_dirichlet(
            model, test_loader, loss_fn, device, cfg,
            epoch=0, label=f"{label}-test"
        )
        return test_metrics, best_val, best_epoch, []

    train_loader, val_loader, test_loader = build_dataloaders(cfg)

    # Two param groups: backbone fine-tuned at 0.1x the head's learning rate.
    optimizer = torch.optim.AdamW(
        [
            {"params": model.backbone.parameters(),
             "lr": cfg.training.learning_rate * 0.1},
            {"params": model.head.parameters(),
             "lr": cfg.training.learning_rate},
        ],
        weight_decay=cfg.training.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
    )
    # Mixed-precision gradient scaler (torch.amp API, "cuda" device string).
    scaler = GradScaler("cuda")
    early_stop = EarlyStopping(
        patience=cfg.early_stopping.patience,
        min_delta=cfg.early_stopping.min_delta,
        checkpoint_path=checkpoint_path,
    )

    # NOTE(review): ``reinit=True`` is deprecated in recent wandb releases —
    # confirm the pinned wandb version still accepts it.
    wandb.init(
        project=cfg.wandb.project, name=label,
        config={"model": label, "loss": "DirichletNLL",
                "seed": cfg.seed, "epochs": cfg.training.epochs},
        reinit=True,
    )

    history = []
    for epoch in range(1, cfg.training.epochs + 1):
        train_loss = _train_epoch_dirichlet(
            model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch, label
        )
        val_loss, val_metrics = _val_epoch_dirichlet(
            model, val_loader, loss_fn, device, cfg, epoch, label
        )
        # Scheduler steps once per epoch (after validation).
        scheduler.step()
        lr = scheduler.get_last_lr()[0]

        # Weighted-average MAE is the headline metric; default 0 if missing.
        val_mae = val_metrics.get("mae/weighted_avg", 0)
        log.info("%s epoch=%d train=%.4f val=%.4f mae=%.4f lr=%.2e",
                 label, epoch, train_loss, val_loss, val_mae, lr)
        history.append({
            "epoch": epoch, "train_loss": train_loss,
            "val_loss": val_loss, "val_mae": val_mae,
        })
        wandb.log({
            "train_loss": train_loss, "val_loss": val_loss,
            "val_mae": val_mae, "lr": lr,
        }, step=epoch)

        # EarlyStopping.step returns True when patience is exhausted; it also
        # snapshots the best model to checkpoint_path.
        if early_stop.step(val_loss, model, epoch):
            log.info("%s: early stopping at epoch %d", label, epoch)
            break

    # Roll the model back to the best validation checkpoint before testing.
    best_val = early_stop.restore_best(model)
    wandb.finish()

    log.info("%s: evaluating on test set...", label)
    _, test_metrics = _val_epoch_dirichlet(
        model, test_loader, loss_fn, device, cfg, epoch=0, label=f"{label}-test"
    )
    return test_metrics, best_val, early_stop.best_epoch, history
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 527 |
+
# Figures
|
| 528 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 529 |
+
|
| 530 |
+
def _save_comparison_figures(all_results, all_histories, save_dir):
    """
    Save the baseline-comparison figures (PDF + PNG, 300 dpi each):

    1. Per-question MAE + RMSE grouped bar chart
    2. Validation MAE learning curves for all models

    (The simplex-violation numbers are written to the CSV tables by the
    caller, not plotted here.)

    Parameters
    ----------
    all_results : list of per-model result dicts with keys like
        ``mae_t01`` / ``rmse_t01`` and a ``model`` display name.
    all_histories : dict mapping model name -> list of per-epoch dicts
        containing ``epoch`` and ``val_mae``.
    save_dir : pathlib.Path output directory (must already exist).

    All figure names follow IEEE journal conventions.
    """
    q_names = list(QUESTION_GROUPS.keys())
    n_models = len(all_results)
    x = np.arange(len(q_names))
    # Total group width 0.80 shared equally between models.
    width = 0.80 / n_models
    # Up to 5 models supported; zip() silently truncates beyond the palette.
    palette = ["#c0392b", "#e67e22", "#2980b9", "#27ae60", "#8e44ad"]

    # ── Figure 1: Per-question MAE and RMSE ──────────────────
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    for metric, ax, ylabel in [
        ("mae", axes[0], "Mean Absolute Error (MAE)"),
        ("rmse", axes[1], "Root Mean Squared Error (RMSE)"),
    ]:
        for i, (row_d, color) in enumerate(zip(all_results, palette)):
            # Missing per-question keys plot as NaN (empty bar).
            vals = [row_d.get(f"{metric}_{q}", np.nan) for q in q_names]
            ax.bar(x + i * width, vals, width,
                   label=row_d["model"], color=color,
                   alpha=0.85, edgecolor="white", linewidth=0.5)
        # Center the tick under the group of bars.
        ax.set_xticks(x + width * (n_models - 1) / 2)
        ax.set_xticklabels(
            [f"{q}\n({QUESTION_LABELS[q][:10]})" for q in q_names],
            rotation=45, ha="right", fontsize=7,
        )
        ax.set_ylabel(ylabel, fontsize=11)
        ax.set_title(f"Per-question {metric.upper()} β baseline comparison", fontsize=11)
        ax.legend(fontsize=7, loc="upper right")
        ax.grid(True, alpha=0.3, axis="y")
        ax.set_axisbelow(True)

    plt.suptitle(
        "Baseline comparison β GZ2 hierarchical probabilistic regression\n"
        "Full 239,267-sample dataset, identical seed/split/protocol",
        fontsize=12, y=1.02,
    )
    plt.tight_layout()
    fig.savefig(save_dir / "fig_baseline_comparison_mae_rmse.pdf",
                dpi=300, bbox_inches="tight")
    fig.savefig(save_dir / "fig_baseline_comparison_mae_rmse.png",
                dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_baseline_comparison_mae_rmse")

    # ── Figure 2: Validation MAE learning curves ─────────────
    fig2, ax2 = plt.subplots(figsize=(10, 5))
    styles = ["-", "--", "-.", ":", (0, (3, 1, 1, 1))]
    markers = ["o", "s", "^", "D", "v"]
    for (name, hist), ls, color, mk in zip(
        all_histories.items(), styles, palette, markers
    ):
        epochs_h = [h["epoch"] for h in hist]
        val_maes = [h["val_mae"] for h in hist]
        ax2.plot(epochs_h, val_maes, linestyle=ls, color=color, linewidth=1.8,
                 label=name, marker=mk, markersize=3, markevery=5)

    ax2.set_xlabel("Epoch", fontsize=11)
    ax2.set_ylabel("Validation MAE (weighted average)", fontsize=11)
    ax2.set_title("Validation MAE during training β all baseline models", fontsize=11)
    ax2.legend(fontsize=9)
    ax2.grid(True, alpha=0.3)
    plt.tight_layout()
    fig2.savefig(save_dir / "fig_baseline_val_mae_curves.pdf",
                 dpi=300, bbox_inches="tight")
    fig2.savefig(save_dir / "fig_baseline_val_mae_curves.png",
                 dpi=300, bbox_inches="tight")
    plt.close(fig2)
    log.info("Saved: fig_baseline_val_mae_curves")
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 606 |
+
# Main
|
| 607 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 608 |
+
|
| 609 |
+
def main():
    """Run the full baseline comparison (B1–B4 + proposed model).

    Trains/loads four baselines and, if its checkpoint exists, evaluates the
    proposed model, then writes comparison CSV tables and figures under
    ``<figures_dir>/comparison``. Usage: ``python -m ... --config <exp.yaml>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()

    # Experiment config overrides base config (OmegaConf deep merge).
    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)

    set_seed(cfg.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info("Device: %s Dataset: %s",
             device, "full 239k" if cfg.data.n_samples is None
             else f"{cfg.data.n_samples:,}")

    # TF32 matmuls for Ampere+ GPUs (throughput; slightly reduced precision).
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    save_dir = Path(cfg.outputs.figures_dir) / "comparison"
    ckpt_dir = Path(cfg.outputs.checkpoint_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    all_results = []   # one result-row dict per model, for the CSV tables
    all_histories = {} # model name -> per-epoch history, for learning curves

    # ─── B1: ResNet-18 + independent MSE (sigmoid) ───────────
    log.info("=" * 60)
    log.info("B1: ResNet-18 + independent MSE (sigmoid, no hierarchy)")
    log.info("=" * 60)
    # Re-seed before each baseline so all runs share identical init/splits.
    set_seed(cfg.seed)

    rn_mse_model = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
    rn_mse_loss = IndependentMSELoss()
    log.info("ResNet-18 params: %s", f"{sum(p.numel() for p in rn_mse_model.parameters()):,}")

    rn_mse_metrics, rn_mse_val, rn_mse_epoch, rn_mse_hist = train_and_evaluate(
        rn_mse_model, rn_mse_loss, cfg, device,
        label = "B1-ResNet18-MSE",
        checkpoint_path = str(ckpt_dir / "baseline_resnet18_mse.pt"),
        use_layerwise_lr = False,
        use_sigmoid = True,
    )

    # Simplex violation for this baseline: sigmoid outputs are independent,
    # so per-question answer fractions need not sum to 1.
    _, _, test_loader_tmp = build_dataloaders(cfg)
    rn_mse_model.eval()
    tmp_preds = []
    with torch.no_grad():
        for images, _, _, _ in test_loader_tmp:
            images = images.to(device, non_blocking=True)
            logits = rn_mse_model(images)
            tmp_preds.append(torch.sigmoid(logits).cpu().numpy())
    tmp_preds = np.concatenate(tmp_preds)
    svr = simplex_violation_rate(tmp_preds, tolerance=0.02)
    log.info("B1 simplex violation rate (mean): %.4f", svr["mean"])

    row = {
        "model": "ResNet-18 + MSE (sigmoid, no hierarchy)",
        "backbone": "ResNet-18", "loss": "Independent MSE",
        "hierarchy": "None",
        "best_epoch": rn_mse_epoch, "best_val_loss": round(rn_mse_val, 5),
        "mae_weighted" : round(rn_mse_metrics["mae/weighted_avg"], 5),
        "rmse_weighted": round(rn_mse_metrics["rmse/weighted_avg"], 5),
        "simplex_violation_mean": round(svr["mean"], 4),
    }
    for q in QUESTION_GROUPS:
        row[f"mae_{q}"] = round(rn_mse_metrics[f"mae/{q}"], 5)
        row[f"rmse_{q}"] = round(rn_mse_metrics[f"rmse/{q}"], 5)
    all_results.append(row)
    all_histories["ResNet-18 + MSE (sigmoid)"] = rn_mse_hist
    log.info("B1 done: MAE=%.5f RMSE=%.5f SimplexViol=%.4f",
             rn_mse_metrics["mae/weighted_avg"],
             rn_mse_metrics["rmse/weighted_avg"],
             svr["mean"])

    # ─── B2: ResNet-18 + hierarchical KL+MSE ─────────────────
    log.info("=" * 60)
    log.info("B2: ResNet-18 + hierarchical KL+MSE (same loss as proposed)")
    log.info("=" * 60)
    set_seed(cfg.seed)

    rn_kl_model = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
    rn_kl_loss = HierarchicalLoss(cfg)

    rn_kl_metrics, rn_kl_val, rn_kl_epoch, rn_kl_hist = train_and_evaluate(
        rn_kl_model, rn_kl_loss, cfg, device,
        label = "B2-ResNet18-KL+MSE",
        checkpoint_path = str(ckpt_dir / "baseline_resnet18_klmse.pt"),
        use_layerwise_lr = False,
        use_sigmoid = False,
    )
    row = {
        "model": "ResNet-18 + hierarchical KL+MSE",
        "backbone": "ResNet-18", "loss": "Hierarchical KL+MSE (Ξ»=0.5)",
        "hierarchy": "Full (weights + KL)",
        "best_epoch": rn_kl_epoch, "best_val_loss": round(rn_kl_val, 5),
        "mae_weighted" : round(rn_kl_metrics["mae/weighted_avg"], 5),
        "rmse_weighted": round(rn_kl_metrics["rmse/weighted_avg"], 5),
        "simplex_violation_mean": 0.0,  # softmax guarantees validity
    }
    for q in QUESTION_GROUPS:
        row[f"mae_{q}"] = round(rn_kl_metrics[f"mae/{q}"], 5)
        row[f"rmse_{q}"] = round(rn_kl_metrics[f"rmse/{q}"], 5)
    all_results.append(row)
    all_histories["ResNet-18 + KL+MSE"] = rn_kl_hist
    log.info("B2 done: MAE=%.5f RMSE=%.5f",
             rn_kl_metrics["mae/weighted_avg"],
             rn_kl_metrics["rmse/weighted_avg"])

    # ─── B3: ViT-Base + hierarchical MSE only ────────────────
    log.info("=" * 60)
    log.info("B3: ViT-Base + hierarchical MSE only (no KL term)")
    log.info("=" * 60)
    set_seed(cfg.seed)

    # Local alias avoids shadowing the module-level OmegaConf import style.
    from omegaconf import OmegaConf as OC
    # Ablate the KL term by zeroing its weight in a merged config copy.
    vit_mse_cfg = OC.merge(cfg, OC.create({"loss": {"lambda_kl": 0.0, "lambda_mse": 1.0}}))
    vit_mse_model = build_model(vit_mse_cfg).to(device)
    vit_mse_loss = MSEOnlyLoss(vit_mse_cfg)

    vit_mse_metrics, vit_mse_val, vit_mse_epoch, vit_mse_hist = train_and_evaluate(
        vit_mse_model, vit_mse_loss, vit_mse_cfg, device,
        label = "B3-ViT-MSE",
        checkpoint_path = str(ckpt_dir / "baseline_vit_mse.pt"),
        use_layerwise_lr = True,
        use_sigmoid = False,
    )
    row = {
        "model": "ViT-Base + hierarchical MSE (no KL)",
        "backbone": "ViT-Base/16", "loss": "Hierarchical MSE (Ξ»_KL=0)",
        "hierarchy": "Weights only",
        "best_epoch": vit_mse_epoch, "best_val_loss": round(vit_mse_val, 5),
        "mae_weighted" : round(vit_mse_metrics["mae/weighted_avg"], 5),
        "rmse_weighted": round(vit_mse_metrics["rmse/weighted_avg"], 5),
        "simplex_violation_mean": 0.0,
    }
    for q in QUESTION_GROUPS:
        row[f"mae_{q}"] = round(vit_mse_metrics[f"mae/{q}"], 5)
        row[f"rmse_{q}"] = round(vit_mse_metrics[f"rmse/{q}"], 5)
    all_results.append(row)
    all_histories["ViT-Base + MSE only"] = vit_mse_hist
    log.info("B3 done: MAE=%.5f RMSE=%.5f",
             vit_mse_metrics["mae/weighted_avg"],
             vit_mse_metrics["rmse/weighted_avg"])

    # ─── B4: ViT-Base + Dirichlet NLL (Zoobot-style) ─────────
    log.info("=" * 60)
    log.info("B4: ViT-Base + Dirichlet NLL (Walmsley et al. 2022)")
    log.info("=" * 60)
    set_seed(cfg.seed)

    vit_dir_model = build_dirichlet_model(cfg).to(device)
    vit_dir_loss = DirichletLoss(cfg)

    vit_dir_metrics, vit_dir_val, vit_dir_epoch, vit_dir_hist = train_and_evaluate_dirichlet(
        vit_dir_model, vit_dir_loss, cfg, device,
        label = "B4-ViT-Dirichlet",
        checkpoint_path = str(ckpt_dir / "baseline_vit_dirichlet.pt"),
    )
    row = {
        "model": "ViT-Base + Dirichlet NLL (Zoobot-style)",
        "backbone": "ViT-Base/16", "loss": "Dirichlet NLL",
        "hierarchy": "Full (weights + Dirichlet)",
        "best_epoch": vit_dir_epoch, "best_val_loss": round(vit_dir_val, 5),
        "mae_weighted" : round(vit_dir_metrics["mae/weighted_avg"], 5),
        "rmse_weighted": round(vit_dir_metrics["rmse/weighted_avg"], 5),
        "simplex_violation_mean": 0.0,
    }
    for q in QUESTION_GROUPS:
        row[f"mae_{q}"] = round(vit_dir_metrics[f"mae/{q}"], 5)
        row[f"rmse_{q}"] = round(vit_dir_metrics[f"rmse/{q}"], 5)
    all_results.append(row)
    all_histories["ViT-Base + Dirichlet"] = vit_dir_hist
    log.info("B4 done: MAE=%.5f RMSE=%.5f",
             vit_dir_metrics["mae/weighted_avg"],
             vit_dir_metrics["rmse/weighted_avg"])

    # ─── Proposed: load existing checkpoint for final table ──
    proposed_ckpt = ckpt_dir / "best_full_train.pt"
    if proposed_ckpt.exists():
        log.info("=" * 60)
        log.info("PROPOSED: Loading ViT-Base + hierarchical KL+MSE")
        log.info("=" * 60)
        proposed_model = build_model(cfg).to(device)
        proposed_model.load_state_dict(
            torch.load(proposed_ckpt, map_location="cpu", weights_only=True)["model_state"]
        )
        _, _, test_loader_p = build_dataloaders(cfg)
        _, proposed_metrics = _val_epoch(
            proposed_model, test_loader_p, HierarchicalLoss(cfg), device, cfg,
            epoch=0, label="Proposed-test", use_sigmoid=False
        )
        # NOTE(review): the checkpoint is deserialized a second time here just
        # for its metadata — reusing the first torch.load result would avoid it.
        ckpt_info = torch.load(proposed_ckpt, map_location="cpu", weights_only=True)
        row = {
            "model": "ViT-Base + hierarchical KL+MSE (proposed)",
            "backbone": "ViT-Base/16", "loss": "Hierarchical KL+MSE (Ξ»=0.5)",
            "hierarchy": "Full (weights + KL)",
            "best_epoch": ckpt_info["epoch"],
            "best_val_loss": round(ckpt_info["val_loss"], 5),
            "mae_weighted" : round(proposed_metrics["mae/weighted_avg"], 5),
            "rmse_weighted": round(proposed_metrics["rmse/weighted_avg"], 5),
            "simplex_violation_mean": 0.0,
        }
        for q in QUESTION_GROUPS:
            row[f"mae_{q}"] = round(proposed_metrics[f"mae/{q}"], 5)
            row[f"rmse_{q}"] = round(proposed_metrics[f"rmse/{q}"], 5)
        all_results.append(row)
        log.info("Proposed: MAE=%.5f RMSE=%.5f",
                 proposed_metrics["mae/weighted_avg"],
                 proposed_metrics["rmse/weighted_avg"])

    # ─── Save results ────────────────────────────────────────
    df = pd.DataFrame(all_results)
    df.to_csv(save_dir / "table_baseline_comparison.csv", index=False)

    # Compact summary table (per-question columns omitted).
    summary_cols = ["model", "loss", "hierarchy", "best_epoch",
                    "best_val_loss", "mae_weighted", "rmse_weighted",
                    "simplex_violation_mean"]
    summary = df[[c for c in summary_cols if c in df.columns]].copy()
    summary.to_csv(save_dir / "table_baseline_summary.csv", index=False)

    print()
    print("=" * 80)
    print("BASELINE COMPARISON β FINAL RESULTS")
    print("=" * 80)
    print(summary.to_string(index=False))
    print()

    # ─── Figures ─────────────────────────────────────────────
    _save_comparison_figures(all_results, all_histories, save_dir)

    log.info("All baseline outputs saved to: %s", save_dir)
|
| 841 |
+
|
| 842 |
+
|
| 843 |
+
# Script entry point — run the full baseline comparison.
if __name__ == "__main__":
    main()
|
src/dataset.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
src/dataset.py
|
| 3 |
+
--------------
|
| 4 |
+
Galaxy Zoo 2 dataset loader for hierarchical probabilistic regression.
|
| 5 |
+
|
| 6 |
+
The GZ2 decision tree has 11 questions (t01-t11) with 37 total answer
|
| 7 |
+
columns. Each question is a conditional probability vector β not
|
| 8 |
+
independent regression targets.
|
| 9 |
+
|
| 10 |
+
Hierarchy (parent answer -> child question):
|
| 11 |
+
t01_a02 (features/disk) -> t02, t03, t04, t05, t06
|
| 12 |
+
t02_a05 (not edge-on) -> t03, t04
|
| 13 |
+
t04_a08 (has spiral) -> t10, t11
|
| 14 |
+
t06_a14 (odd feature) -> t08
|
| 15 |
+
t01_a01 (smooth) -> t07
|
| 16 |
+
t02_a04 (edge-on) -> t09
|
| 17 |
+
|
| 18 |
+
References
|
| 19 |
+
----------
|
| 20 |
+
Willett et al. (2013), MNRAS 435, 2835
|
| 21 |
+
Hart et al. (2016), MNRAS 461, 3663
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import math
|
| 25 |
+
import logging
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
import numpy as np
|
| 29 |
+
import pandas as pd
|
| 30 |
+
import torch
|
| 31 |
+
from torch.utils.data import Dataset, DataLoader
|
| 32 |
+
from torchvision import transforms
|
| 33 |
+
from PIL import Image
|
| 34 |
+
from omegaconf import DictConfig
|
| 35 |
+
|
| 36 |
+
log = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 39 |
+
# GZ2 decision tree definition
|
| 40 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 41 |
+
|
| 42 |
+
LABEL_COLUMNS = [
|
| 43 |
+
# t01: smooth or features?
|
| 44 |
+
"t01_smooth_or_features_a01_smooth_debiased",
|
| 45 |
+
"t01_smooth_or_features_a02_features_or_disk_debiased",
|
| 46 |
+
"t01_smooth_or_features_a03_star_or_artifact_debiased",
|
| 47 |
+
# t02: edge-on?
|
| 48 |
+
"t02_edgeon_a04_yes_debiased",
|
| 49 |
+
"t02_edgeon_a05_no_debiased",
|
| 50 |
+
# t03: bar?
|
| 51 |
+
"t03_bar_a06_bar_debiased",
|
| 52 |
+
"t03_bar_a07_no_bar_debiased",
|
| 53 |
+
# t04: spiral?
|
| 54 |
+
"t04_spiral_a08_spiral_debiased",
|
| 55 |
+
"t04_spiral_a09_no_spiral_debiased",
|
| 56 |
+
# t05: bulge prominence
|
| 57 |
+
"t05_bulge_prominence_a10_no_bulge_debiased",
|
| 58 |
+
"t05_bulge_prominence_a11_just_noticeable_debiased",
|
| 59 |
+
"t05_bulge_prominence_a12_obvious_debiased",
|
| 60 |
+
"t05_bulge_prominence_a13_dominant_debiased",
|
| 61 |
+
# t06: odd feature?
|
| 62 |
+
"t06_odd_a14_yes_debiased",
|
| 63 |
+
"t06_odd_a15_no_debiased",
|
| 64 |
+
# t07: roundedness (smooth galaxies)
|
| 65 |
+
"t07_rounded_a16_completely_round_debiased",
|
| 66 |
+
"t07_rounded_a17_in_between_debiased",
|
| 67 |
+
"t07_rounded_a18_cigar_shaped_debiased",
|
| 68 |
+
# t08: odd feature type
|
| 69 |
+
"t08_odd_feature_a19_ring_debiased",
|
| 70 |
+
"t08_odd_feature_a20_lens_or_arc_debiased",
|
| 71 |
+
"t08_odd_feature_a21_disturbed_debiased",
|
| 72 |
+
"t08_odd_feature_a22_irregular_debiased",
|
| 73 |
+
"t08_odd_feature_a23_other_debiased",
|
| 74 |
+
"t08_odd_feature_a24_merger_debiased",
|
| 75 |
+
"t08_odd_feature_a38_dust_lane_debiased",
|
| 76 |
+
# t09: bulge shape (edge-on only)
|
| 77 |
+
"t09_bulge_shape_a25_rounded_debiased",
|
| 78 |
+
"t09_bulge_shape_a26_boxy_debiased",
|
| 79 |
+
"t09_bulge_shape_a27_no_bulge_debiased",
|
| 80 |
+
# t10: arms winding
|
| 81 |
+
"t10_arms_winding_a28_tight_debiased",
|
| 82 |
+
"t10_arms_winding_a29_medium_debiased",
|
| 83 |
+
"t10_arms_winding_a30_loose_debiased",
|
| 84 |
+
# t11: arms number
|
| 85 |
+
"t11_arms_number_a31_1_debiased",
|
| 86 |
+
"t11_arms_number_a32_2_debiased",
|
| 87 |
+
"t11_arms_number_a33_3_debiased",
|
| 88 |
+
"t11_arms_number_a34_4_debiased",
|
| 89 |
+
"t11_arms_number_a36_more_than_4_debiased",
|
| 90 |
+
"t11_arms_number_a37_cant_tell_debiased",
|
| 91 |
+
]
|
| 92 |
+
|
| 93 |
+
# Slice indices into LABEL_COLUMNS for each question group.
|
| 94 |
+
QUESTION_GROUPS = {
    "t01": (0, 3),    # smooth or features? (3 answers)
    "t02": (3, 5),    # edge-on? (2 answers)
    "t03": (5, 7),    # bar? (2 answers)
    "t04": (7, 9),    # spiral? (2 answers)
    "t05": (9, 13),   # bulge prominence (4 answers)
    "t06": (13, 15),  # odd feature? (2 answers)
    "t07": (15, 18),  # roundedness (3 answers)
    "t08": (18, 25),  # odd feature type (7 answers)
    "t09": (25, 28),  # bulge shape (3 answers)
    "t10": (28, 31),  # arms winding (3 answers)
    "t11": (31, 37),  # arms number (6 answers)
}

# Parent answer column for hierarchical branch weighting.
# w_q = vote fraction of the parent answer that unlocks question q.
# t01 is the root question; its weight is always 1.0.
QUESTION_PARENT_COL = {
    "t01": None,  # root — no parent, weight fixed at 1.0
    "t02": "t01_smooth_or_features_a02_features_or_disk_debiased",
    "t03": "t02_edgeon_a05_no_debiased",
    "t04": "t02_edgeon_a05_no_debiased",
    "t05": "t01_smooth_or_features_a02_features_or_disk_debiased",
    "t06": "t01_smooth_or_features_a02_features_or_disk_debiased",
    "t07": "t01_smooth_or_features_a01_smooth_debiased",
    "t08": "t06_odd_a14_yes_debiased",
    "t09": "t02_edgeon_a04_yes_debiased",
    "t10": "t04_spiral_a08_spiral_debiased",
    "t11": "t04_spiral_a08_spiral_debiased",
}

# Total number of answer columns across all 11 questions.
N_LABELS = len(LABEL_COLUMNS)  # 37
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 129 |
+
# Image transforms
|
| 130 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 131 |
+
|
| 132 |
+
def get_transforms(image_size: int, split: str) -> transforms.Compose:
    """Build the torchvision pipeline for one dataset split.

    Training applies geometric augmentation (random crop from a slightly
    larger resize, flips, full-circle rotation — galaxies have no preferred
    orientation) plus mild colour jitter for instrument variation.
    Validation/test only resize. Both end with ToTensor + ImageNet
    normalisation.

    Parameters
    ----------
    split : "train" for the augmented pipeline; anything else gets the
        deterministic resize-only pipeline.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    if split != "train":
        # Deterministic eval pipeline: resize, tensorise, normalise.
        return transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            normalize,
        ])

    # Train pipeline: resize 16px larger so RandomCrop adds translation jitter.
    return transforms.Compose([
        transforms.Resize((image_size + 16, image_size + 16)),
        transforms.RandomCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(180),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
        transforms.ToTensor(),
        normalize,
    ])
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 161 |
+
# Dataset
|
| 162 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 163 |
+
|
| 164 |
+
class GalaxyZoo2Dataset(Dataset):
    """
    PyTorch Dataset for Galaxy Zoo 2.

    Returns
    -------
    image : FloatTensor [3, H, W] normalised galaxy image
    targets : FloatTensor [37] vote fraction vector
    weights : FloatTensor [11] per-question hierarchical weights
    image_id : int dr7objid for traceability
    """

    def __init__(self, df: pd.DataFrame, image_dir: str, transform):
        """Cache labels/weights as numpy arrays up front; images load lazily.

        Parameters
        ----------
        df : DataFrame containing LABEL_COLUMNS, the parent-answer columns,
            and a ``dr7objid`` column.
        image_dir : directory holding one ``<dr7objid>.jpg`` per row.
        transform : callable applied to each PIL image in ``__getitem__``.
        """
        self.df = df.reset_index(drop=True)
        self.image_dir = Path(image_dir)
        self.transform = transform

        # Labels precomputed once as a float32 [N, 37] array.
        self.labels = self.df[LABEL_COLUMNS].values.astype(np.float32)
        self.weights = self._compute_weights()
        self.image_ids = self.df["dr7objid"].tolist()

    def _compute_weights(self) -> np.ndarray:
        """Return a float32 [N, 11] array of per-question branch weights.

        Each question's weight is its parent answer's vote fraction
        (QUESTION_PARENT_COL); the root question t01 keeps weight 1.0.
        """
        n = len(self.df)
        q_names = list(QUESTION_GROUPS.keys())
        weights = np.ones((n, len(q_names)), dtype=np.float32)
        for q_idx, q_name in enumerate(q_names):
            parent_col = QUESTION_PARENT_COL[q_name]
            if parent_col is not None:
                weights[:, q_idx] = self.df[parent_col].values.astype(np.float32)
        return weights

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        """Load, transform and return one sample (image, targets, weights, id)."""
        image_id = self.image_ids[idx]
        img_path = self.image_dir / f"{image_id}.jpg"
        try:
            image = Image.open(img_path).convert("RGB")
        except FileNotFoundError:
            # Re-raise with an actionable message including the offending id.
            raise FileNotFoundError(
                f"Image not found: {img_path}. "
                f"Check dr7objid {image_id} has a matching .jpg file."
            )
        image = self.transform(image)
        targets = torch.from_numpy(self.labels[idx])
        weights = torch.from_numpy(self.weights[idx])
        return image, targets, weights, image_id
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 215 |
+
# DataLoader factory
|
| 216 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 217 |
+
|
| 218 |
+
def build_dataloaders(cfg: DictConfig):
    """Build train / val / test DataLoaders from the labels parquet.

    Reads ``cfg.data.parquet_path``, optionally subsamples to
    ``cfg.data.n_samples`` rows, performs a seeded random train/val/test
    split using ``cfg.data.train_frac`` / ``cfg.data.val_frac`` (test gets
    the remainder), and wraps each split in a DataLoader.

    Parameters
    ----------
    cfg : merged OmegaConf config (``data``, ``training``, ``seed``).

    Returns
    -------
    (train_loader, val_loader, test_loader)

    Raises
    ------
    ValueError if any expected label column is missing from the parquet.
    """
    log.info("Loading parquet: %s", cfg.data.parquet_path)
    df = pd.read_parquet(cfg.data.parquet_path)

    missing = [c for c in LABEL_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"Missing columns in parquet: {missing}")

    if cfg.data.n_samples is not None:
        n = int(cfg.data.n_samples)
        log.info("Using subset of %d samples (full dataset: %d)", n, len(df))
        df = df.sample(n=n, random_state=cfg.seed).reset_index(drop=True)
    else:
        log.info("Using full dataset: %d samples", len(df))

    # Seeded permutation makes the split reproducible across runs/models.
    rng = np.random.default_rng(cfg.seed)
    idx = rng.permutation(len(df))
    n = len(df)
    n_train = math.floor(cfg.data.train_frac * n)
    n_val = math.floor(cfg.data.val_frac * n)

    train_idx = idx[:n_train]
    val_idx = idx[n_train : n_train + n_val]
    test_idx = idx[n_train + n_val :]  # remainder after train+val

    log.info("Split β train: %d val: %d test: %d",
             len(train_idx), len(val_idx), len(test_idx))

    image_size = cfg.data.image_size
    train_ds = GalaxyZoo2Dataset(
        df.iloc[train_idx], cfg.data.image_dir,
        get_transforms(image_size, "train"))
    val_ds = GalaxyZoo2Dataset(
        df.iloc[val_idx], cfg.data.image_dir,
        get_transforms(image_size, "val"))
    test_ds = GalaxyZoo2Dataset(
        df.iloc[test_idx], cfg.data.image_dir,
        get_transforms(image_size, "test"))

    num_workers = cfg.data.num_workers
    common = dict(
        batch_size = cfg.training.batch_size,
        num_workers = num_workers,
        pin_memory = cfg.data.pin_memory,
        drop_last = False,
    )
    # FIX: persistent_workers / prefetch_factor are only legal when
    # num_workers > 0 — DataLoader raises ValueError otherwise, which
    # previously crashed any config running with num_workers: 0.
    if num_workers > 0:
        common["persistent_workers"] = getattr(cfg.data, "persistent_workers", True)
        common["prefetch_factor"] = getattr(cfg.data, "prefetch_factor", 4)

    train_loader = DataLoader(train_ds, shuffle=True, **common)
    val_loader = DataLoader(val_ds, shuffle=False, **common)
    test_loader = DataLoader(test_ds, shuffle=False, **common)

    return train_loader, val_loader, test_loader
|
src/evaluate_full.py
ADDED
|
@@ -0,0 +1,619 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
src/evaluate_full.py
|
| 3 |
+
--------------------
|
| 4 |
+
Full evaluation of all trained models on the held-out test set.
|
| 5 |
+
|
| 6 |
+
Generates all paper figures and tables:
|
| 7 |
+
|
| 8 |
+
Tables
|
| 9 |
+
------
|
| 10 |
+
table_metrics_proposed.csv β MAE / RMSE / bias / ECE for proposed model
|
| 11 |
+
table_reached_branch_mae.csv β reached-branch MAE across all 5 models
|
| 12 |
+
table_simplex_violation.csv β simplex validity for sigmoid baseline
|
| 13 |
+
|
| 14 |
+
Figures (PDF + PNG, IEEE naming convention)
|
| 15 |
+
-------------------------------------------
|
| 16 |
+
fig_scatter_predicted_vs_true.pdf β predicted vs true vote fractions (proposed)
|
| 17 |
+
fig_calibration_reliability.pdf β reliability diagrams, all models
|
| 18 |
+
fig_ece_comparison.pdf β ECE bar chart, all models
|
| 19 |
+
fig_attention_rollout_gallery.pdf β full 12-layer attention rollout gallery
|
| 20 |
+
fig_attention_entropy_depth.pdf β CLS attention entropy vs. layer depth
|
| 21 |
+
|
| 22 |
+
Usage
|
| 23 |
+
-----
|
| 24 |
+
cd ~/galaxy
|
| 25 |
+
nohup python -m src.evaluate_full --config configs/full_train.yaml \
|
| 26 |
+
> outputs/logs/evaluate.log 2>&1 &
|
| 27 |
+
echo "PID: $!"
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
import argparse
|
| 31 |
+
import logging
|
| 32 |
+
import sys
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
|
| 35 |
+
import numpy as np
|
| 36 |
+
import pandas as pd
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn.functional as F
|
| 39 |
+
import matplotlib
|
| 40 |
+
matplotlib.use("Agg")
|
| 41 |
+
import matplotlib.pyplot as plt
|
| 42 |
+
from torch.amp import autocast
|
| 43 |
+
from omegaconf import OmegaConf
|
| 44 |
+
from tqdm import tqdm
|
| 45 |
+
|
| 46 |
+
from src.dataset import build_dataloaders, QUESTION_GROUPS
|
| 47 |
+
from src.model import build_model, build_dirichlet_model
|
| 48 |
+
from src.metrics import (compute_metrics, predictions_to_numpy,
|
| 49 |
+
compute_reached_branch_mae_table,
|
| 50 |
+
dirichlet_predictions_to_numpy,
|
| 51 |
+
simplex_violation_rate, _compute_ece)
|
| 52 |
+
from src.attention_viz import plot_attention_grid, plot_attention_entropy
|
| 53 |
+
from src.baselines import ResNet18Baseline
|
| 54 |
+
|
| 55 |
+
logging.basicConfig(
|
| 56 |
+
format="%(asctime)s %(levelname)s %(name)s %(message)s",
|
| 57 |
+
datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
|
| 58 |
+
)
|
| 59 |
+
log = logging.getLogger("evaluate_full")
|
| 60 |
+
|
| 61 |
+
# ── Global matplotlib style (applies to every figure in this module) ──
plt.rcParams.update({
    "figure.dpi"       : 150,
    "savefig.dpi"      : 300,
    "font.family"      : "serif",
    "font.size"        : 11,
    "axes.titlesize"   : 11,
    "axes.labelsize"   : 11,
    "xtick.labelsize"  : 9,
    "ytick.labelsize"  : 9,
    "legend.fontsize"  : 9,
    "figure.facecolor" : "white",
    "axes.facecolor"   : "white",
    "axes.grid"        : True,
    "grid.alpha"       : 0.3,
    "pdf.fonttype"     : 42,  # TrueType (42) keeps text editable in PDF output
    "ps.fonttype"      : 42,
})

# Short human-readable description for each Galaxy Zoo 2 decision-tree question.
QUESTION_LABELS = {
    "t01": "Smooth or features",
    "t02": "Edge-on disk",
    "t03": "Bar",
    "t04": "Spiral arms",
    "t05": "Bulge prominence",
    "t06": "Odd feature",
    "t07": "Roundedness",
    "t08": "Odd feature type",
    "t09": "Bulge shape",
    "t10": "Arms winding",
    "t11": "Arms number",
}

# Consistent colours and line styles for all models across all figures.
# Keys must match the display names used when building `model_results` in main().
MODEL_COLORS = {
    "ResNet-18 + MSE (sigmoid)"          : "#c0392b",
    "ResNet-18 + KL+MSE"                 : "#e67e22",
    "ViT-Base + MSE only"                : "#2980b9",
    "ViT-Base + KL+MSE (proposed)"       : "#27ae60",
    "ViT-Base + Dirichlet (Zoobot-style)": "#8e44ad",
}
MODEL_STYLES = {
    "ResNet-18 + MSE (sigmoid)"          : "-",
    "ResNet-18 + KL+MSE"                 : "-.",
    "ViT-Base + MSE only"                : "--",
    "ViT-Base + KL+MSE (proposed)"       : "-",
    "ViT-Base + Dirichlet (Zoobot-style)": ":",
}
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 112 |
+
# Inference helpers
|
| 113 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 114 |
+
|
| 115 |
+
def _infer_vit(model, loader, device, cfg,
               collect_attn=True, n_attn=16):
    """Run a ViT model over *loader* and stack predictions into numpy arrays.

    Optionally collects per-layer attention maps for the first *n_attn*
    images (used by the attention-rollout figures).

    Parameters
    ----------
    model : ViT with ``get_all_attention_weights()`` exposing per-layer maps.
    loader : DataLoader yielding (images, targets, weights, image_ids).
    device : torch device for inference.
    cfg : merged config; reads ``cfg.training.mixed_precision``.
    collect_attn : whether to capture attention for the first batches.
    n_attn : number of images to keep attention maps for.

    Returns
    -------
    (preds, targets, weights, attn_imgs_t, merged_layers, attn_ids)
        preds/targets/weights : np.ndarray over the whole loader.
        attn_imgs_t   : Tensor [<=n_attn, C, H, W] on CPU, or None.
        merged_layers : list of per-layer attention tensors (one entry per
                        transformer layer, each [<=n_attn, ...]), or None.
        attn_ids      : list[int] of image ids matching attn_imgs_t rows.
    """
    model.eval()
    all_preds, all_targets, all_weights = [], [], []
    attn_images, all_layer_attns, attn_ids = [], [], []
    attn_done = False  # set once n_attn images' attention has been captured

    with torch.no_grad():
        for images, targets, weights, image_ids in tqdm(loader, desc="ViT inference"):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)
            with autocast("cuda", enabled=cfg.training.mixed_precision):
                logits = model(images)
            p, t, w = predictions_to_numpy(logits, targets, weights)
            all_preds.append(p)
            all_targets.append(t)
            all_weights.append(w)

            # Capture attention only from the first batch(es), up to n_attn images.
            if collect_attn and not attn_done:
                layers = model.get_all_attention_weights()
                if layers is not None:
                    n = min(n_attn, images.shape[0])
                    attn_images.append(images[:n].cpu())
                    all_layer_attns.append([l[:n].cpu() for l in layers])
                    attn_ids.extend([int(i) for i in image_ids[:n]])
                    if len(attn_ids) >= n_attn:
                        attn_done = True

    preds = np.concatenate(all_preds)
    targets = np.concatenate(all_targets)
    weights = np.concatenate(all_weights)

    attn_imgs_t = torch.cat(attn_images, dim=0)[:n_attn] if attn_images else None
    merged_layers = None
    if all_layer_attns:
        # Re-group from per-batch lists of layers into per-layer tensors.
        merged_layers = [
            torch.cat([b[li] for b in all_layer_attns], dim=0)[:n_attn]
            for li in range(len(all_layer_attns[0]))
        ]

    return preds, targets, weights, attn_imgs_t, merged_layers, attn_ids
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _infer_resnet(model, loader, device, cfg, use_sigmoid: bool):
    """Run a ResNet baseline over *loader* and stack predictions into numpy.

    Parameters
    ----------
    model : trained ResNet-18 baseline that outputs raw logits.
    loader : DataLoader yielding (images, targets, weights, image_ids).
    device : torch device for inference.
    cfg : merged config; reads ``cfg.training.mixed_precision``.
    use_sigmoid : if True, apply an independent sigmoid per answer
        (MSE-sigmoid baseline, outputs may violate the simplex); otherwise
        apply a per-question softmax over each answer group (KL+MSE baseline).

    Returns
    -------
    (preds, targets, weights) : float32 numpy arrays over the whole loader.
    """
    model.eval()
    all_preds, all_targets, all_weights = [], [], []
    with torch.no_grad():
        for images, targets, weights, _ in tqdm(loader, desc="ResNet inference"):
            images = images.to(device, non_blocking=True)
            with autocast("cuda", enabled=cfg.training.mixed_precision):
                logits = model(images)
            if use_sigmoid:
                # Unconstrained per-answer probabilities.
                # .float() keeps the output float32 even under mixed precision,
                # so both branches return the same dtype.
                pred = torch.sigmoid(logits).float().cpu().numpy()
            else:
                # Per-question softmax so each answer group sums to 1.
                # .clone() guards the in-place slice assignment below when
                # .cpu() is a no-op (model already on CPU).
                pred = logits.detach().float().cpu().clone()
                for s, e in QUESTION_GROUPS.values():
                    pred[:, s:e] = F.softmax(pred[:, s:e], dim=-1)
                pred = pred.numpy()
            all_preds.append(pred)
            all_targets.append(targets.numpy())
            all_weights.append(weights.numpy())
    return (np.concatenate(all_preds),
            np.concatenate(all_targets),
            np.concatenate(all_weights))
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _infer_dirichlet(model, loader, device, cfg):
    """Run the Dirichlet-head model over *loader*; return stacked numpy arrays.

    The model outputs concentration parameters alpha; these are converted to
    expected vote fractions by ``dirichlet_predictions_to_numpy``.
    """
    model.eval()
    pred_chunks, target_chunks, weight_chunks = [], [], []
    with torch.no_grad():
        for images, targets, weights, _ in tqdm(loader, desc="Dirichlet inference"):
            images = images.to(device, non_blocking=True)
            with autocast("cuda", enabled=cfg.training.mixed_precision):
                alpha = model(images)
            batch_p, batch_t, batch_w = dirichlet_predictions_to_numpy(
                alpha, targets, weights
            )
            pred_chunks.append(batch_p)
            target_chunks.append(batch_t)
            weight_chunks.append(batch_w)
    return (np.concatenate(pred_chunks),
            np.concatenate(target_chunks),
            np.concatenate(weight_chunks))
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 200 |
+
# Figure 1: Predicted vs true scatter (proposed model)
|
| 201 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 202 |
+
|
| 203 |
+
def fig_scatter_predicted_vs_true(preds, targets, weights, save_dir):
    """Scatter of predicted vs. true vote fractions for the proposed model.

    One panel per GZ2 question (3x4 grid; the unused 12th axis is blanked),
    restricted to samples whose hierarchical weight for that question is
    >= 0.05 ("reached" branches). Writes PDF and PNG; skips if both exist.
    """
    path_pdf = save_dir / "fig_scatter_predicted_vs_true.pdf"
    path_png = save_dir / "fig_scatter_predicted_vs_true.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_scatter_predicted_vs_true"); return

    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    axes = axes.flatten()

    for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
        ax = axes[q_idx]
        # Keep only samples where this branch of the decision tree was reached.
        mask = weights[:, q_idx] >= 0.05
        pq = preds[mask, start:end].flatten()
        tq = targets[mask, start:end].flatten()

        ax.scatter(tq, pq, alpha=0.06, s=1, color="#2563eb", rasterized=True)
        ax.plot([0, 1], [0, 1], "r--", linewidth=1, alpha=0.8)  # identity line
        ax.set_xlim(0, 1); ax.set_ylim(0, 1)
        ax.set_xlabel("True vote fraction")
        ax.set_ylabel("Predicted vote fraction")
        ax.set_title(
            f"{q_name}: {QUESTION_LABELS[q_name]}\n"
            f"$n$ = {mask.sum():,} (w β₯ 0.05)",
            fontsize=9,
        )
        ax.set_aspect("equal")
        # Per-question MAE annotation (over the reached subset only).
        mae = np.abs(pq - tq).mean()
        ax.text(0.05, 0.92, f"MAE = {mae:.3f}",
                transform=ax.transAxes, fontsize=8,
                bbox=dict(boxstyle="round,pad=0.2", facecolor="white",
                          edgecolor="grey", alpha=0.85))

    # 11 questions in a 12-slot grid: blank the last axis.
    axes[-1].axis("off")
    plt.suptitle(
        "Predicted vs. true vote fractions β reached branches (w β₯ 0.05)\n"
        "ViT-Base/16 + hierarchical KL+MSE (proposed model, test set)",
        fontsize=12,
    )
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_scatter_predicted_vs_true")
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 249 |
+
# Figure 2: Calibration reliability diagrams
|
| 250 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 251 |
+
|
| 252 |
+
def fig_calibration_reliability(model_results, save_dir, n_bins=15):
    """Reliability diagrams (mean predicted vs. mean true) for all models.

    Uses adaptive equal-frequency binning consistent with the ECE
    computation. Eight representative questions are shown; panels are
    restricted to reached branches (weight >= 0.05) with at least 50 samples.
    """
    path_pdf = save_dir / "fig_calibration_reliability.pdf"
    path_png = save_dir / "fig_calibration_reliability.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_calibration_reliability"); return

    # Show 8 representative questions (skip t02 β bimodal, shown separately)
    q_show = ["t01", "t03", "t04", "t06", "t07", "t09", "t10", "t11"]
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    axes = axes.flatten()

    for ax_idx, q_name in enumerate(q_show):
        ax = axes[ax_idx]
        start, end = QUESTION_GROUPS[q_name]
        q_idx = list(QUESTION_GROUPS.keys()).index(q_name)

        for model_name, (preds, targets, weights) in model_results.items():
            mask = weights[:, q_idx] >= 0.05
            if mask.sum() < 50:
                # Too few reached samples for a meaningful curve.
                continue
            pf = preds[mask, start:end].flatten()
            tf = targets[mask, start:end].flatten()

            # Adaptive bins (equal-frequency) β consistent with ECE computation
            percentiles = np.linspace(0, 100, n_bins + 1)
            bin_edges = np.unique(np.percentile(pf, percentiles))
            if len(bin_edges) < 2:
                # Degenerate predictions (all identical) β nothing to bin.
                continue
            # digitize against interior edges; clip maps extremes into end bins.
            bin_ids = np.clip(
                np.digitize(pf, bin_edges[1:-1]), 0, len(bin_edges) - 2
            )
            # Per-bin mean predicted (mp) and mean true (mt) fractions.
            mp = np.array([
                pf[bin_ids == b].mean() if (bin_ids == b).any() else np.nan
                for b in range(len(bin_edges) - 1)
            ])
            mt = np.array([
                tf[bin_ids == b].mean() if (bin_ids == b).any() else np.nan
                for b in range(len(bin_edges) - 1)
            ])
            valid = ~np.isnan(mp) & ~np.isnan(mt)
            ax.plot(
                mp[valid], mt[valid],
                MODEL_STYLES.get(model_name, "-"),
                color=MODEL_COLORS.get(model_name, "#888888"),
                linewidth=1.8, marker="o", markersize=3.5,
                label=model_name, alpha=0.9,
            )

        ax.plot([0, 1], [0, 1], "k--", linewidth=1, alpha=0.5, label="Perfect")
        ax.set_xlim(0, 1); ax.set_ylim(0, 1)
        ax.set_xlabel("Mean predicted", fontsize=8)
        ax.set_ylabel("Mean true", fontsize=8)
        ax.set_title(f"{q_name}: {QUESTION_LABELS[q_name]}", fontsize=9)
        ax.set_aspect("equal")
        if ax_idx == 0:
            # Legend only on the first panel to avoid clutter.
            ax.legend(fontsize=6.5, loc="upper left")

    plt.suptitle(
        "Calibration reliability diagrams β all models (test set)\n"
        "Reached branches only (w β₯ 0.05). Adaptive equal-frequency bins. "
        "Closer to diagonal = better calibrated.",
        fontsize=11,
    )
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_calibration_reliability")
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 323 |
+
# Figure 3: ECE bar chart
|
| 324 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 325 |
+
|
| 326 |
+
def fig_ece_comparison(model_results, save_dir):
    """Grouped bar chart of per-question ECE for all models.

    Also writes the underlying numbers to ``table_ece_comparison.csv``.
    Questions with fewer than 50 reached samples get NaN (no bar).
    Skips entirely if both figure files already exist.
    """
    path_pdf = save_dir / "fig_ece_comparison.pdf"
    path_png = save_dir / "fig_ece_comparison.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_ece_comparison"); return

    q_names = list(QUESTION_GROUPS.keys())
    ece_rows = []
    for model_name, (preds, targets, weights) in model_results.items():
        row = {"model": model_name}
        for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
            mask = weights[:, q_idx] >= 0.05
            if mask.sum() < 50:
                row[q_name] = float("nan")
            else:
                row[q_name] = _compute_ece(
                    preds[mask, start:end].flatten(),
                    targets[mask, start:end].flatten(),
                    n_bins=15,
                )
        # NaN-aware mean so unreachable questions don't poison the summary.
        row["mean_ece"] = float(
            np.nanmean([row[q] for q in q_names])
        )
        ece_rows.append(row)

    df_ece = pd.DataFrame(ece_rows)
    df_ece.to_csv(save_dir / "table_ece_comparison.csv", index=False)

    x = np.arange(len(q_names))
    width = 0.80 / len(model_results)  # group of bars spans 0.8 of a slot
    palette = list(MODEL_COLORS.values())

    fig, ax = plt.subplots(figsize=(14, 5))
    for i, (model_name, _) in enumerate(model_results.items()):
        vals = [
            float(df_ece[df_ece["model"] == model_name][q].values[0])
            for q in q_names
        ]
        ax.bar(
            x + i * width, vals, width,
            label=model_name,
            color=MODEL_COLORS.get(model_name, palette[i % len(palette)]),
            alpha=0.85, edgecolor="white", linewidth=0.5,
        )

    # Center the x tick under each group of bars.
    ax.set_xticks(x + width * (len(model_results) - 1) / 2)
    ax.set_xticklabels(
        [f"{q}\n({QUESTION_LABELS[q][:12]})" for q in q_names],
        rotation=30, ha="right", fontsize=8,
    )
    ax.set_ylabel("Expected Calibration Error (ECE)", fontsize=11)
    ax.set_title(
        "Expected Calibration Error β all models (test set)\n"
        "Reached branches (w β₯ 0.05). Adaptive equal-frequency binning. "
        "Lower is better.",
        fontsize=11,
    )
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3, axis="y")
    ax.set_axisbelow(True)
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_ece_comparison")
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 394 |
+
# Figure 4: Attention rollout gallery
|
| 395 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 396 |
+
|
| 397 |
+
def fig_attention_rollout_gallery(attn_imgs, all_layers, attn_ids, save_dir):
    """Render the attention-rollout gallery (300-dpi PDF/PNG plus 600-dpi PNG).

    Each output is generated only if it does not already exist, so re-runs
    are cheap. Does nothing (with a warning) when no attention was captured.
    """
    if attn_imgs is None or all_layers is None:
        log.warning("No attention data β skipping gallery.")
        return

    pdf_target = save_dir / "fig_attention_rollout_gallery.pdf"
    png_target = save_dir / "fig_attention_rollout_gallery.png"

    if not pdf_target.exists():
        gallery = plot_attention_grid(
            attn_imgs, all_layers, attn_ids,
            save_path=str(png_target),
            n_cols=4, rollout_mode="full",
        )
        gallery.savefig(pdf_target, dpi=300, bbox_inches="tight", facecolor="black")
        plt.close(gallery)
        log.info("Saved: fig_attention_rollout_gallery")

    # High-resolution PNG for journal submission.
    hq_target = save_dir / "fig_attention_rollout_gallery_HQ.png"
    if not hq_target.exists():
        gallery_hq = plot_attention_grid(
            attn_imgs, all_layers, attn_ids,
            n_cols=4, rollout_mode="full",
        )
        gallery_hq.savefig(hq_target, dpi=600, bbox_inches="tight", facecolor="black")
        plt.close(gallery_hq)
        log.info("Saved: fig_attention_rollout_gallery_HQ (600 dpi)")
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 427 |
+
# Figure 5: Attention entropy vs. depth
|
| 428 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 429 |
+
|
| 430 |
+
def fig_attention_entropy_depth(all_layers, save_dir):
    """Plot CLS-token attention entropy as a function of transformer depth.

    Skips rendering when both output files already exist, or when no
    attention maps were collected.
    """
    if all_layers is None:
        log.warning("No attention layers β skipping entropy plot.")
        return

    out_pdf = save_dir / "fig_attention_entropy_depth.pdf"
    out_png = save_dir / "fig_attention_entropy_depth.png"
    if out_pdf.exists() and out_png.exists():
        log.info("Skip (exists): fig_attention_entropy_depth")
        return

    entropy_fig = plot_attention_entropy(all_layers, save_path=str(out_png))
    entropy_fig.savefig(out_pdf, dpi=300, bbox_inches="tight")
    plt.close(entropy_fig)
    log.info("Saved: fig_attention_entropy_depth")
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 446 |
+
# Table: metrics for proposed model
|
| 447 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 448 |
+
|
| 449 |
+
def table_metrics_proposed(preds, targets, weights, save_dir):
    """Compute test-set metrics for the proposed model and dump them to CSV.

    Writes ``table_metrics_proposed.csv`` with one row per question
    (MAE / RMSE / bias / ECE) plus a weighted-average summary row, logs
    the table, and returns the underlying metrics dict.
    """
    metrics = compute_metrics(preds, targets, weights)

    def _question_row(q):
        # One CSV row for a single GZ2 question.
        return {
            "question": q,
            "description": QUESTION_LABELS[q],
            "MAE": round(metrics[f"mae/{q}"], 5),
            "RMSE": round(metrics[f"rmse/{q}"], 5),
            "bias": round(metrics[f"bias/{q}"], 5),
            "ECE": round(metrics[f"ece/{q}"], 5),
        }

    # Summary row: weighted averages (no meaningful aggregate bias).
    summary = {
        "question": "weighted_avg", "description": "Weighted average",
        "MAE": round(metrics["mae/weighted_avg"], 5),
        "RMSE": round(metrics["rmse/weighted_avg"], 5),
        "bias": "",
        "ECE": round(metrics["ece/mean"], 5),
    }

    df = pd.DataFrame([_question_row(q) for q in QUESTION_GROUPS] + [summary])
    df.to_csv(save_dir / "table_metrics_proposed.csv", index=False)
    log.info("\n%s\n", df.to_string(index=False))
    return metrics
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
# βββββοΏ½οΏ½οΏ½βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 475 |
+
# Table: simplex violation for sigmoid baseline
|
| 476 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 477 |
+
|
| 478 |
+
def table_simplex_violation(model_results, save_dir):
    """Quantify how often per-question predictions leave the probability simplex.

    For every model, report the fraction of test samples whose per-question
    answer probabilities fail to sum to 1 within a tolerance of 0.02.
    Softmax-constrained models should score ~0; the unconstrained sigmoid
    baseline does not, which explains how it can post lower raw per-answer
    MAE while remaining scientifically invalid (each marginal is fitted
    independently). Writes ``table_simplex_violation.csv`` and returns
    the resulting DataFrame.
    """
    records = []
    for model_name, (preds, _, _) in model_results.items():
        violation = simplex_violation_rate(preds, tolerance=0.02)
        entry = {"model": model_name}
        for q in QUESTION_GROUPS:
            entry[q] = round(violation[q], 4)
        entry["mean"] = round(violation["mean"], 4)
        records.append(entry)

    df = pd.DataFrame(records)
    df.to_csv(save_dir / "table_simplex_violation.csv", index=False)
    log.info("Saved: table_simplex_violation.csv")
    log.info("\n%s\n", df[["model", "mean"]].to_string(index=False))
    return df
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 501 |
+
# Main
|
| 502 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 503 |
+
|
| 504 |
+
def main():
    """Evaluate all trained models on the test set; emit all tables/figures.

    Loads the base config merged with the --config experiment file, restores
    every checkpoint from the configured checkpoint directory, runs inference
    once per model, then writes CSV tables and figures (each skipped when its
    output files already exist).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()

    # Experiment config overrides base config.
    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    save_dir = Path(cfg.outputs.figures_dir) / "evaluation"
    save_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir = Path(cfg.outputs.checkpoint_dir)

    _, _, test_loader = build_dataloaders(cfg)

    # ── Load all models ──────────────────────────────────────
    log.info("Loading models from: %s", ckpt_dir)

    def _load(path, model):
        # Restore weights on CPU first; caller moves the model to `device`.
        ckpt = torch.load(path, map_location="cpu", weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        return model

    vit_proposed = _load(
        ckpt_dir / "best_full_train.pt", build_model(cfg)
    ).to(device)

    vit_mse = _load(
        ckpt_dir / "baseline_vit_mse.pt", build_model(cfg)
    ).to(device)

    rn_mse = _load(
        ckpt_dir / "baseline_resnet18_mse.pt",
        ResNet18Baseline(dropout=cfg.model.dropout)
    ).to(device)

    rn_kl = _load(
        ckpt_dir / "baseline_resnet18_klmse.pt",
        ResNet18Baseline(dropout=cfg.model.dropout)
    ).to(device)

    # The Dirichlet baseline is optional; evaluate it only if its checkpoint exists.
    vit_dirichlet = None
    dp = ckpt_dir / "baseline_vit_dirichlet.pt"
    if dp.exists():
        vit_dirichlet = _load(dp, build_dirichlet_model(cfg)).to(device)
        log.info("Loaded: ViT-Base + Dirichlet")

    # ── Run inference ────────────────────────────────────────
    log.info("Running inference on test set...")

    # Attention is collected only for the proposed model (used in Figs. 4-5).
    (p_proposed, t_proposed, w_proposed,
     attn_imgs, all_layers, attn_ids) = _infer_vit(
        vit_proposed, test_loader, device, cfg,
        collect_attn=True, n_attn=16,
    )

    p_vit_mse, t_vit_mse, w_vit_mse = _infer_vit(
        vit_mse, test_loader, device, cfg, collect_attn=False
    )[:3]

    p_rn_mse, t_rn_mse, w_rn_mse = _infer_resnet(
        rn_mse, test_loader, device, cfg, use_sigmoid=True
    )

    p_rn_kl, t_rn_kl, w_rn_kl = _infer_resnet(
        rn_kl, test_loader, device, cfg, use_sigmoid=False
    )

    # Build model_results dict (order determines legend order in figures)
    model_results = {
        "ResNet-18 + MSE (sigmoid)": (p_rn_mse, t_rn_mse, w_rn_mse),
        "ResNet-18 + KL+MSE": (p_rn_kl, t_rn_kl, w_rn_kl),
        "ViT-Base + MSE only": (p_vit_mse, t_vit_mse, w_vit_mse),
        "ViT-Base + KL+MSE (proposed)": (p_proposed, t_proposed, w_proposed),
    }

    if vit_dirichlet is not None:
        p_dir, t_dir, w_dir = _infer_dirichlet(
            vit_dirichlet, test_loader, device, cfg
        )
        model_results["ViT-Base + Dirichlet (Zoobot-style)"] = (p_dir, t_dir, w_dir)

    # ── Tables ───────────────────────────────────────────────
    log.info("Computing metrics...")
    table_metrics_proposed(p_proposed, t_proposed, w_proposed, save_dir)

    log.info("Computing reached-branch MAE table...")
    df_r = compute_reached_branch_mae_table(model_results)
    df_r.to_csv(save_dir / "table_reached_branch_mae.csv", index=False)
    log.info("Saved: table_reached_branch_mae.csv")

    log.info("Computing simplex violation table...")
    table_simplex_violation(model_results, save_dir)

    # ── Figures ──────────────────────────────────────────────
    log.info("Generating figures...")
    fig_scatter_predicted_vs_true(p_proposed, t_proposed, w_proposed, save_dir)
    fig_calibration_reliability(model_results, save_dir)
    fig_ece_comparison(model_results, save_dir)
    fig_attention_rollout_gallery(attn_imgs, all_layers, attn_ids, save_dir)
    fig_attention_entropy_depth(all_layers, save_dir)

    log.info("=" * 60)
    log.info("ALL OUTPUTS SAVED TO: %s", save_dir)
    log.info("=" * 60)

    # Headline numbers for the proposed model.
    metrics = compute_metrics(p_proposed, t_proposed, w_proposed)
    log.info("Proposed model β test set results:")
    log.info("  Weighted MAE  = %.5f", metrics["mae/weighted_avg"])
    log.info("  Weighted RMSE = %.5f", metrics["rmse/weighted_avg"])
    log.info("  Mean ECE      = %.5f", metrics["ece/mean"])
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
if __name__ == "__main__":
    # Entry point when run as `python -m src.evaluate_full --config ...`.
    main()
|
src/loss.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
src/loss.py
|
| 3 |
+
-----------
|
| 4 |
+
Loss functions for hierarchical probabilistic vote-fraction regression.
|
| 5 |
+
|
| 6 |
+
Two losses are implemented:
|
| 7 |
+
|
| 8 |
+
1. HierarchicalLoss β proposed method: weighted KL + MSE per question.
|
| 9 |
+
2. DirichletLoss β Zoobot-style comparison: weighted Dirichlet NLL.
|
| 10 |
+
3. MSEOnlyLoss β ablation baseline: hierarchical MSE, no KL term.
|
| 11 |
+
|
| 12 |
+
Both main losses use identical per-sample hierarchical weighting:
|
| 13 |
+
w_q = parent branch vote fraction (1.0 for root question t01)
|
| 14 |
+
|
| 15 |
+
Mathematical formulation
|
| 16 |
+
------------------------
|
| 17 |
+
HierarchicalLoss per question q:
|
| 18 |
+
L_q = w_q * [ Ξ»_kl * KL(p_q || Ε·_q) + Ξ»_mse * MSE(Ε·_q, p_q) ]
|
| 19 |
+
|
| 20 |
+
where p_q = ground-truth vote fractions [B, A_q]
|
| 21 |
+
Ε·_q = softmax(logits_q) [B, A_q]
|
| 22 |
+
w_q = hierarchical weight [B]
|
| 23 |
+
|
| 24 |
+
DirichletLoss per question q:
|
| 25 |
+
L_q = w_q * [ log B(Ξ±_q) β Ξ£_a (Ξ±_qa β 1) log(p_qa) ]
|
| 26 |
+
|
| 27 |
+
where Ξ±_q = 1 + softplus(logits_q) > 1 [B, A_q]
|
| 28 |
+
|
| 29 |
+
References
|
| 30 |
+
----------
|
| 31 |
+
Walmsley et al. (2022), MNRAS 509, 3966 (Zoobot β Dirichlet approach)
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
import torch
|
| 35 |
+
import torch.nn as nn
|
| 36 |
+
import torch.nn.functional as F
|
| 37 |
+
from omegaconf import DictConfig
|
| 38 |
+
from src.dataset import QUESTION_GROUPS
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class HierarchicalLoss(nn.Module):
    """Hierarchically weighted KL-divergence + MSE loss (proposed method).

    For each question group q the per-sample loss is

        w_q * (lambda_kl * KL(p_q || yhat_q) + lambda_mse * MSE(yhat_q, p_q))

    where p_q are the ground-truth vote fractions, yhat_q = softmax(logits_q),
    and w_q is the parent-branch vote fraction supplied by the dataset
    (1.0 for the root question t01).
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.lambda_kl = float(cfg.loss.lambda_kl)
        self.lambda_mse = float(cfg.loss.lambda_mse)
        self.epsilon = float(cfg.loss.epsilon)
        # Flatten QUESTION_GROUPS into (name, start, end) triples once.
        self.question_slices = [(q, s, e) for q, (s, e) in QUESTION_GROUPS.items()]

    def forward(self, predictions: torch.Tensor,
                targets: torch.Tensor, weights: torch.Tensor):
        """Return (total_loss, dict of detached per-question scalars).

        predictions : [B, 37] raw logits
        targets     : [B, 37] ground-truth vote fractions
        weights     : [B, 11] hierarchical per-question weights
        """
        total = torch.zeros(1, device=predictions.device, dtype=predictions.dtype)
        breakdown = {}

        for idx, (name, lo, hi) in enumerate(self.question_slices):
            probs = F.softmax(predictions[:, lo:hi], dim=-1)
            truth = targets[:, lo:hi]
            w = weights[:, idx]

            # Clamp both distributions away from 0 before taking logs.
            probs_safe = probs.clamp(self.epsilon, 1.0)
            truth_safe = truth.clamp(self.epsilon, 1.0)

            # KL(p || yhat), summed over answers, per sample.
            kl = (truth_safe * (truth_safe.log() - probs_safe.log())).sum(dim=-1)
            # Unclamped squared error, averaged over answers, per sample.
            mse = ((probs - truth) ** 2).mean(dim=-1)

            per_q = (w * (self.lambda_kl * kl + self.lambda_mse * mse)).mean()
            total = total + per_q
            breakdown[f"loss/{name}"] = per_q.detach().item()

        breakdown["loss/total"] = total.detach().item()
        return total, breakdown
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class DirichletLoss(nn.Module):
    """Weighted hierarchical Dirichlet negative log-likelihood.

    Trains GalaxyViTDirichlet for comparison with the proposed method,
    matching the Zoobot approach (Walmsley et al. 2022). Per question q:

        L_q = w_q * [ log B(alpha_q) - sum_a (alpha_qa - 1) log p_qa ]
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.epsilon = float(cfg.loss.epsilon)
        # Flatten QUESTION_GROUPS into (name, start, end) triples once.
        self.question_slices = [(q, s, e) for q, (s, e) in QUESTION_GROUPS.items()]

    def forward(self, alpha: torch.Tensor,
                targets: torch.Tensor, weights: torch.Tensor):
        """Return (total NLL, dict of detached per-question scalars).

        alpha   : [B, 37] Dirichlet concentrations (> 1 from the model head)
        targets : [B, 37] ground-truth vote fractions
        weights : [B, 11] hierarchical per-question weights
        """
        total = torch.zeros(1, device=alpha.device, dtype=alpha.dtype)
        breakdown = {}

        for idx, (name, lo, hi) in enumerate(self.question_slices):
            conc = alpha[:, lo:hi]
            # Clamp fractions away from 0 so log() stays finite.
            frac = targets[:, lo:hi].clamp(min=self.epsilon)
            w = weights[:, idx]

            # log B(alpha) = sum_a lgamma(alpha_a) - lgamma(sum_a alpha_a)
            log_norm = torch.lgamma(conc).sum(dim=-1) - torch.lgamma(conc.sum(dim=-1))
            # sum_a (alpha_a - 1) log p_a
            log_lik = ((conc - 1.0) * frac.log()).sum(dim=-1)

            per_q = (w * (log_norm - log_lik)).mean()
            total = total + per_q
            breakdown[f"loss/{name}"] = per_q.detach().item()

        breakdown["loss/total"] = total.detach().item()
        return total, breakdown
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class MSEOnlyLoss(nn.Module):
    """Hierarchical MSE loss without the KL term (ablation baseline).

    Equivalent to HierarchicalLoss configured with lambda_kl = 0.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        # Kept for interface parity with the other losses; MSE never logs,
        # so no clamping is required.
        self.epsilon = float(cfg.loss.epsilon)
        self.question_slices = [(q, s, e) for q, (s, e) in QUESTION_GROUPS.items()]

    def forward(self, predictions: torch.Tensor,
                targets: torch.Tensor, weights: torch.Tensor):
        """Return (total_loss, dict of detached per-question scalars)."""
        total = torch.zeros(1, device=predictions.device, dtype=predictions.dtype)
        breakdown = {}

        for idx, (name, lo, hi) in enumerate(self.question_slices):
            probs = F.softmax(predictions[:, lo:hi], dim=-1)
            sq_err = ((probs - targets[:, lo:hi]) ** 2).mean(dim=-1)
            per_q = (weights[:, idx] * sq_err).mean()
            total = total + per_q
            breakdown[f"loss/{name}"] = per_q.detach().item()

        breakdown["loss/total"] = total.detach().item()
        return total, breakdown
|
src/metrics.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
src/metrics.py
|
| 3 |
+
--------------
|
| 4 |
+
Evaluation metrics for hierarchical probabilistic vote-fraction regression
|
| 5 |
+
on Galaxy Zoo 2.
|
| 6 |
+
|
| 7 |
+
Three evaluation regimes
|
| 8 |
+
------------------------
|
| 9 |
+
1. GLOBAL β all test samples (dominated by root question t01).
|
| 10 |
+
2. REACHED-BRANCH β samples where branch was actually reached (w >= threshold).
|
| 11 |
+
This is the scientifically correct regime for conditional questions.
|
| 12 |
+
3. ECE β Expected Calibration Error using adaptive (equal-frequency) bins.
|
| 13 |
+
|
| 14 |
+
Fixes applied vs. original
|
| 15 |
+
---------------------------
|
| 16 |
+
- ECE uses adaptive binning (equal-frequency bins) instead of equal-width.
|
| 17 |
+
Equal-width bins saturate at 0.200 for bimodal questions (t02, t03, t04)
|
| 18 |
+
where predictions cluster near 0 and 1. Adaptive bins are unbiased for
|
| 19 |
+
any distribution shape.
|
| 20 |
+
- simplex_violation_rate() added: fraction of question groups where the
|
| 21 |
+
sigmoid baseline predictions do not sum to 1 Β± 0.02. Used to explain
|
| 22 |
+
why ResNet-18 + sigmoid achieves lower raw MAE despite predicting
|
| 23 |
+
invalid distributions.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
from src.dataset import QUESTION_GROUPS
|
| 30 |
+
|
| 31 |
+
WEIGHT_THRESHOLDS = [0.05, 0.50, 0.75]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 35 |
+
# Main metrics function
|
| 36 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 37 |
+
|
| 38 |
+
def compute_metrics(
    all_predictions: np.ndarray,  # [N, 37]
    all_targets: np.ndarray,      # [N, 37]
    all_weights: np.ndarray,      # [N, 11]
) -> dict:
    """Full metrics suite: global + reached-branch MAE/RMSE + bias + ECE.

    Returns a flat dict keyed "metric/question" plus aggregated entries
    ("mae/weighted_avg", "mae_w{tag}/conditional_avg", "ece/mean").
    """
    metrics = {}
    groups = list(QUESTION_GROUPS.items())

    # -- 1. Global metrics (all test samples) -------------------
    per_q_mae = []
    per_q_rmse = []
    per_q_wmean = []

    for q_idx, (q_name, (lo, hi)) in enumerate(groups):
        residual = all_predictions[:, lo:hi] - all_targets[:, lo:hi]
        mae_q = np.abs(residual).mean(axis=1).mean()
        rmse_q = np.sqrt((residual ** 2).mean(axis=1).mean())

        metrics[f"mae/{q_name}"] = float(mae_q)
        metrics[f"rmse/{q_name}"] = float(rmse_q)
        metrics[f"bias/{q_name}"] = float(residual.mean())

        per_q_mae.append(mae_q)
        per_q_rmse.append(rmse_q)
        per_q_wmean.append(all_weights[:, q_idx].mean())

    # Weight each question's score by its mean hierarchical weight.
    wm = np.array(per_q_wmean)
    metrics["mae/weighted_avg"] = float((wm * np.array(per_q_mae)).sum() / wm.sum())
    metrics["rmse/weighted_avg"] = float((wm * np.array(per_q_rmse)).sum() / wm.sum())

    # -- 2. Reached-branch metrics (w >= threshold) -------------
    for thresh in WEIGHT_THRESHOLDS:
        tag = str(thresh).replace(".", "")
        cond_maes = []
        cond_ws = []

        for q_idx, (q_name, (lo, hi)) in enumerate(groups):
            w_col = all_weights[:, q_idx]
            reached = w_col >= thresh
            metrics[f"n_reached_w{tag}/{q_name}"] = int(reached.sum())

            # Require at least 10 reached samples for a stable estimate.
            if reached.sum() >= 10:
                err = np.abs(
                    all_predictions[reached, lo:hi] - all_targets[reached, lo:hi]
                ).mean(axis=1).mean()
                metrics[f"mae_w{tag}/{q_name}"] = float(err)
                cond_maes.append(err)
                cond_ws.append(w_col[reached].mean())
            else:
                metrics[f"mae_w{tag}/{q_name}"] = float("nan")

        if cond_maes:
            cw = np.array(cond_ws)
            cm = np.array(cond_maes)
            metrics[f"mae_w{tag}/conditional_avg"] = float((cw * cm).sum() / cw.sum())

    # -- 3. ECE per question (adaptive binning) -----------------
    eces = []
    for q_name, (lo, hi) in groups:
        ece_q = _compute_ece(
            all_predictions[:, lo:hi].flatten(),
            all_targets[:, lo:hi].flatten(),
        )
        metrics[f"ece/{q_name}"] = float(ece_q)
        eces.append(ece_q)

    metrics["ece/mean"] = float(np.nanmean(eces))
    return metrics
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 126 |
+
# ECE β adaptive (equal-frequency) binning
|
| 127 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 128 |
+
|
| 129 |
+
def _compute_ece(pred: np.ndarray, target: np.ndarray,
|
| 130 |
+
n_bins: int = 15) -> float:
|
| 131 |
+
"""
|
| 132 |
+
Expected Calibration Error with adaptive (equal-frequency) binning.
|
| 133 |
+
|
| 134 |
+
Equal-width binning saturates for bimodal distributions (e.g. t02, t03,
|
| 135 |
+
t04 where predictions cluster at 0 and 1) because >95% of samples fall
|
| 136 |
+
into boundary bins. Adaptive binning places bin edges at percentiles of
|
| 137 |
+
the predicted distribution, giving each bin an equal number of samples
|
| 138 |
+
and making ECE meaningful regardless of the prediction distribution shape.
|
| 139 |
+
|
| 140 |
+
Parameters
|
| 141 |
+
----------
|
| 142 |
+
pred : [N] predicted vote fractions
|
| 143 |
+
target : [N] true vote fractions
|
| 144 |
+
n_bins : number of bins (default 15)
|
| 145 |
+
|
| 146 |
+
Returns
|
| 147 |
+
-------
|
| 148 |
+
ECE : float in [0, 1]
|
| 149 |
+
"""
|
| 150 |
+
if len(pred) < n_bins:
|
| 151 |
+
return float("nan")
|
| 152 |
+
|
| 153 |
+
# Build equal-frequency bin edges from percentiles of pred
|
| 154 |
+
percentiles = np.linspace(0, 100, n_bins + 1)
|
| 155 |
+
bin_edges = np.unique(np.percentile(pred, percentiles))
|
| 156 |
+
|
| 157 |
+
if len(bin_edges) < 2:
|
| 158 |
+
return float("nan")
|
| 159 |
+
|
| 160 |
+
# Assign samples to bins (digitize returns 1-indexed; clip to [0, n-2])
|
| 161 |
+
bin_ids = np.clip(np.digitize(pred, bin_edges[1:-1]), 0, len(bin_edges) - 2)
|
| 162 |
+
|
| 163 |
+
ece = 0.0
|
| 164 |
+
n = len(pred)
|
| 165 |
+
for b in np.unique(bin_ids):
|
| 166 |
+
mask = bin_ids == b
|
| 167 |
+
if not mask.any():
|
| 168 |
+
continue
|
| 169 |
+
ece += (mask.sum() / n) * abs(pred[mask].mean() - target[mask].mean())
|
| 170 |
+
|
| 171 |
+
return float(ece)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 175 |
+
# Simplex violation rate
|
| 176 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 177 |
+
|
| 178 |
+
def simplex_violation_rate(
    predictions: np.ndarray,  # [N, 37]
    tolerance: float = 0.02,
) -> dict:
    """Fraction of galaxies whose per-question predictions break the simplex.

    For each question group, counts the rows whose answers do NOT sum to
    1 +/- tolerance. A softmax-per-group model has rate ~0.0 by
    construction; a sigmoid baseline has nonzero rates, which explains why
    its raw per-answer MAE can be lower (unconstrained outputs fit each
    marginal independently while predicting invalid distributions).

    Parameters
    ----------
    predictions : [N, 37] array of predicted values
    tolerance   : acceptable deviation from 1.0 (default 0.02)

    Returns
    -------
    dict mapping question name -> violation rate in [0, 1], plus a
    "mean" entry averaged over the question rates.
    """
    rates = {}
    for q_name, (lo, hi) in QUESTION_GROUPS.items():
        deviation = np.abs(predictions[:, lo:hi].sum(axis=1) - 1.0)
        rates[q_name] = float((deviation > tolerance).mean())
    rates["mean"] = float(np.mean(list(rates.values())))
    return rates
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 212 |
+
# Reached-branch comparison table (for paper Table 2)
|
| 213 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 214 |
+
|
| 215 |
+
def compute_reached_branch_mae_table(
    model_results: dict,
) -> "pd.DataFrame":
    """
    Build the reached-branch MAE comparison table across all models.

    Fix vs. original: the weighted-average row previously used hard-coded
    column names ("mae_w050") that did not match the per-question columns
    generated from WEIGHT_THRESHOLDS ("mae_w05"), producing a DataFrame
    with duplicate, half-empty columns. Both paths now share one key
    formatter, and the hard-coded `range(11)` is replaced by the actual
    number of question groups.

    Parameters
    ----------
    model_results : dict mapping model_name -> (preds, targets, weights)
        preds/targets are [N, 37]; weights are [N, 11].

    Returns
    -------
    pd.DataFrame with columns:
        model, question, description, and for each threshold t in
        WEIGHT_THRESHOLDS the pair n_w{tag(t)}, mae_w{tag(t)}
        (e.g. n_w005/mae_w005, n_w05/mae_w05, n_w075/mae_w075).
    """
    import pandas as pd

    def _tag(thresh: float) -> str:
        # 0.05 -> "005", 0.5 -> "05", 0.75 -> "075" — same scheme as
        # compute_metrics, so downstream joins line up.
        return str(thresh).replace(".", "")

    QUESTION_DESCRIPTIONS = {
        "t01": "Smooth or features",
        "t02": "Edge-on disk",
        "t03": "Bar",
        "t04": "Spiral arms",
        "t05": "Bulge prominence",
        "t06": "Odd feature",
        "t07": "Roundedness (smooth)",
        "t08": "Odd feature type",
        "t09": "Bulge shape (edge-on)",
        "t10": "Arms winding",
        "t11": "Arms number",
    }

    rows = []
    for model_name, (preds, targets, weights) in model_results.items():
        # Per-question rows.
        for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
            pred_q = preds[:, start:end]
            target_q = targets[:, start:end]
            weight_q = weights[:, q_idx]

            row = {
                "model": model_name,
                "question": q_name,
                "description": QUESTION_DESCRIPTIONS[q_name],
            }
            for thresh in WEIGHT_THRESHOLDS:
                mask = weight_q >= thresh
                n = int(mask.sum())
                row[f"n_w{_tag(thresh)}"] = n
                # Require at least 10 reached samples for a stable estimate.
                row[f"mae_w{_tag(thresh)}"] = (
                    float(np.abs(pred_q[mask] - target_q[mask]).mean(axis=1).mean())
                    if n >= 10 else float("nan")
                )
            rows.append(row)

        # Weighted-average row for this model (w >= 0.05 regime only).
        branch_maes = []
        branch_ws = []
        for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
            weight_q = weights[:, q_idx]
            mask = weight_q >= 0.05
            if mask.sum() >= 10:
                branch_maes.append(
                    np.abs(preds[mask, start:end] - targets[mask, start:end])
                    .mean(axis=1).mean()
                )
                branch_ws.append(weight_q[mask].mean())

        bw = np.array(branch_ws)
        bm = np.array(branch_maes)

        avg_row = {
            "model": model_name,
            "question": "weighted_avg",
            "description": "Weighted average (w≥0.05)",
        }
        for thresh in WEIGHT_THRESHOLDS:
            # Count of (sample, question) pairs reaching each threshold.
            avg_row[f"n_w{_tag(thresh)}"] = int((weights >= thresh).sum())
            avg_row[f"mae_w{_tag(thresh)}"] = float("nan")
        # Only the w>=0.05 regime carries the weighted-average MAE.
        avg_row[f"mae_w{_tag(0.05)}"] = (
            float((bw * bm).sum() / bw.sum()) if bw.size > 0 else float("nan")
        )
        rows.append(avg_row)

    return pd.DataFrame(rows)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 302 |
+
# Tensor β numpy helpers
|
| 303 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 304 |
+
|
| 305 |
+
def predictions_to_numpy(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    weights: torch.Tensor,
) -> tuple:
    """Apply per-question softmax to raw logits and return numpy arrays.

    Returns (probs [N, 37], targets [N, 37], weights [N, 11]) with every
    question group normalized to a valid probability simplex.
    """
    probs = predictions.detach().cpu().clone()
    for _q_name, (lo, hi) in QUESTION_GROUPS.items():
        probs[:, lo:hi] = F.softmax(probs[:, lo:hi], dim=-1)
    return (
        probs.numpy(),
        targets.detach().cpu().numpy(),
        weights.detach().cpu().numpy(),
    )
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def dirichlet_predictions_to_numpy(
    alpha: torch.Tensor,
    targets: torch.Tensor,
    weights: torch.Tensor,
) -> tuple:
    """Convert Dirichlet concentrations to mean vote fractions (numpy).

    The Dirichlet mean for each question group is alpha / sum(alpha),
    normalized within the group. Returns (means, targets, weights).
    """
    means = torch.zeros_like(alpha)
    for _q_name, (lo, hi) in QUESTION_GROUPS.items():
        group = alpha[:, lo:hi]
        means[:, lo:hi] = group / group.sum(dim=-1, keepdim=True)
    return (
        means.detach().cpu().numpy(),
        targets.detach().cpu().numpy(),
        weights.detach().cpu().numpy(),
    )
|
src/model.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
src/model.py
|
| 3 |
+
------------
|
| 4 |
+
Vision Transformer (ViT-Base/16) backbone with three head variants:
|
| 5 |
+
|
| 6 |
+
1. GalaxyViT β linear regression head (37 logits). Proposed model.
|
| 7 |
+
2. GalaxyViTDirichlet β Dirichlet concentration head (Zoobot-style baseline).
|
| 8 |
+
3. mc_dropout_predict β MC Dropout uncertainty estimation wrapper.
|
| 9 |
+
|
| 10 |
+
Architecture
|
| 11 |
+
------------
|
| 12 |
+
Backbone : vit_base_patch16_224 from timm (pretrained ImageNet-21k)
|
| 13 |
+
12 transformer layers, 12 heads, embed_dim=768
|
| 14 |
+
Input : [B, 3, 224, 224]
|
| 15 |
+
CLS out: [B, 768]
|
| 16 |
+
Head : Dropout(p) β Linear(768, 37)
|
| 17 |
+
|
| 18 |
+
Full multi-layer attention rollout
|
| 19 |
+
------------------------------------
|
| 20 |
+
All 12 transformer blocks use fused_attn=False so forward hooks can
|
| 21 |
+
capture the post-softmax attention matrices. Rollout is computed in
|
| 22 |
+
attention_viz.py using the corrected right-multiplication order.
|
| 23 |
+
|
| 24 |
+
MC Dropout
|
| 25 |
+
-----------
|
| 26 |
+
enable_mc_dropout() keeps Dropout active at inference time.
|
| 27 |
+
Running N stochastic forward passes gives mean prediction and
|
| 28 |
+
per-answer std (epistemic uncertainty). N=30 is standard practice
|
| 29 |
+
per Gal & Ghahramani (2016).
|
| 30 |
+
|
| 31 |
+
Dirichlet head
|
| 32 |
+
--------------
|
| 33 |
+
Outputs Ξ± > 1 per answer via: Ξ± = 1 + softplus(linear(features))
|
| 34 |
+
Matches the Zoobot approach for a fair direct comparison.
|
| 35 |
+
Mean vote fraction: E[p_q] = Ξ±_q / sum(Ξ±_q).
|
| 36 |
+
|
| 37 |
+
References
|
| 38 |
+
----------
|
| 39 |
+
Gal & Ghahramani (2016). Dropout as a Bayesian Approximation.
|
| 40 |
+
ICML 2016. https://arxiv.org/abs/1506.02142
|
| 41 |
+
Walmsley et al. (2022). Towards Galaxy Foundation Models.
|
| 42 |
+
MNRAS 509, 3966. https://arxiv.org/abs/2110.12735
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
from __future__ import annotations
|
| 46 |
+
|
| 47 |
+
import torch
|
| 48 |
+
import torch.nn as nn
|
| 49 |
+
import torch.nn.functional as F
|
| 50 |
+
import timm
|
| 51 |
+
import numpy as np
|
| 52 |
+
from omegaconf import DictConfig
|
| 53 |
+
from typing import Optional, List, Tuple
|
| 54 |
+
|
| 55 |
+
from src.dataset import QUESTION_GROUPS
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 59 |
+
# Attention hook manager
|
| 60 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 61 |
+
|
| 62 |
+
class AttentionHookManager:
    """
    Captures post-softmax attention matrices from every transformer block
    so that full multi-layer attention rollout can be computed downstream.

    With fused_attn=False, timm's attention block executes
        attn = softmax(q @ k.T / scale)
        attn = attn_drop(attn)
        out  = attn @ v
    so a forward hook on attn_drop receives the post-softmax matrix as
    its INPUT.
    """

    def __init__(self, blocks):
        self.blocks = blocks
        self._attn_list: List[torch.Tensor] = []
        self._handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Disable fused attention and hook every block's attn_drop."""
        for blk in self.blocks:
            blk.attn.fused_attn = False

            def _build():
                def _capture(module, inputs, output):
                    # inputs[0] is the post-softmax attention [B, H, N+1, N+1]
                    self._attn_list.append(inputs[0].detach())
                return _capture

            handle = blk.attn.attn_drop.register_forward_hook(_build())
            self._handles.append(handle)

    def clear(self):
        """Drop captured maps; call before each forward pass."""
        self._attn_list.clear()

    def get_all_attentions(self) -> Optional[List[torch.Tensor]]:
        """All captured layers as a list of [B, H, N+1, N+1] tensors, or None."""
        return list(self._attn_list) if self._attn_list else None

    def get_last_attention(self) -> Optional[torch.Tensor]:
        """Final captured layer's attention map, or None if empty."""
        return self._attn_list[-1] if self._attn_list else None

    def remove_all(self):
        """Detach every registered hook."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 113 |
+
# GalaxyViT β proposed model
|
| 114 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 115 |
+
|
| 116 |
+
class GalaxyViT(nn.Module):
    """
    ViT-Base/16 backbone + linear regression head for GZ2 (proposed model).

    Outputs 37 raw logits; softmax is applied per question group during
    loss computation and metric evaluation. Attention hooks on all
    transformer blocks are registered at construction for rollout
    visualization.

    Fix vs. original: disable_mc_dropout() now also returns every Dropout
    module to eval mode. Previously it only cleared the flag, so dropout
    remained stochastic after MC sampling until an external model.eval().
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()

        self.backbone = timm.create_model(
            cfg.model.backbone,
            pretrained=cfg.model.pretrained,
            num_classes=0,      # strip the classifier; we attach our own head
        )
        embed_dim = self.backbone.embed_dim  # 768 for ViT-Base

        self.head = nn.Sequential(
            nn.Dropout(p=cfg.model.dropout),
            nn.Linear(embed_dim, 37),
        )

        self._hook_mgr = AttentionHookManager(self.backbone.blocks)
        self._mc_dropout = False

    def enable_mc_dropout(self):
        """Keep Dropout active at inference time for MC sampling."""
        self._mc_dropout = True
        for m in self.modules():
            if isinstance(m, nn.Dropout):
                m.train()

    def disable_mc_dropout(self):
        """Restore deterministic inference.

        Bug fix: also put every Dropout module back into eval mode.
        The original only cleared the flag, leaving dropout stochastic
        for any subsequent forward passes.
        """
        self._mc_dropout = False
        for m in self.modules():
            if isinstance(m, nn.Dropout):
                m.eval()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw logits [B, 37]; refreshes captured attention maps."""
        self._hook_mgr.clear()
        features = self.backbone(x)   # [B, 768] pooled embedding
        logits = self.head(features)  # [B, 37]
        return logits

    def get_attention_weights(self) -> Optional[torch.Tensor]:
        """Last layer's post-softmax attention from the most recent forward."""
        return self._hook_mgr.get_last_attention()

    def get_all_attention_weights(self) -> Optional[List[torch.Tensor]]:
        """All layers' attention maps from the most recent forward."""
        return self._hook_mgr.get_all_attentions()

    def remove_hooks(self):
        """Detach the attention hooks (e.g. before export/serialization)."""
        self._hook_mgr.remove_all()
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 171 |
+
# GalaxyViTDirichlet β Zoobot-style comparison baseline
|
| 172 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 173 |
+
|
| 174 |
+
class GalaxyViTDirichlet(nn.Module):
    """
    ViT-Base/16 + Dirichlet concentration head (Zoobot-style baseline).

    The head emits alpha = 1 + softplus(linear(features)), so every
    concentration exceeds 1 and each per-question Dirichlet is unimodal.
    The posterior mean vote fraction is alpha_q / sum(alpha_q); the total
    concentration sum(alpha_q) encodes prediction confidence.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()

        self.backbone = timm.create_model(
            cfg.model.backbone,
            pretrained=cfg.model.pretrained,
            num_classes=0,
        )

        self.head = nn.Sequential(
            nn.Dropout(p=cfg.model.dropout),
            nn.Linear(self.backbone.embed_dim, 37),
        )

        self._hook_mgr = AttentionHookManager(self.backbone.blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return alpha: [B, 37] Dirichlet concentration parameters > 1."""
        self._hook_mgr.clear()
        raw = self.head(self.backbone(x))
        return 1.0 + F.softplus(raw)  # guarantees alpha > 1

    def get_mean_prediction(self, alpha: torch.Tensor) -> torch.Tensor:
        """Dirichlet mean per question group: alpha normalized within group."""
        out = torch.zeros_like(alpha)
        for _q_name, (lo, hi) in QUESTION_GROUPS.items():
            group = alpha[:, lo:hi]
            out[:, lo:hi] = group / group.sum(dim=-1, keepdim=True)
        return out

    def get_attention_weights(self):
        """Final layer's post-softmax attention from the last forward."""
        return self._hook_mgr.get_last_attention()

    def get_all_attention_weights(self):
        """All layers' attention maps from the last forward."""
        return self._hook_mgr.get_all_attentions()
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 225 |
+
# MC Dropout inference
|
| 226 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 227 |
+
|
| 228 |
+
@torch.no_grad()
def mc_dropout_predict(
    model: GalaxyViT,
    images: torch.Tensor,
    n_passes: int = 30,
    device: torch.device = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    MC Dropout epistemic uncertainty estimation.

    Performs n_passes stochastic forward passes with dropout kept active
    and summarises the resulting prediction ensemble.

    Parameters
    ----------
    model : GalaxyViT instance
    images : [B, 3, H, W]
    n_passes : number of MC samples (30 is standard)
    device : inference device (defaults to the model's device)

    Returns
    -------
    mean_pred : [B, 37] mean softmax predictions
    std_pred : [B, 37] std across passes (epistemic uncertainty)
    per_q_uncertainty: [B, 11] mean std per question
    """
    if device is None:
        # Infer the device from the model's own parameters.
        device = next(model.parameters()).device

    model.eval()
    model.enable_mc_dropout()
    images = images.to(device)

    samples = []
    for _ in range(n_passes):
        raw = model(images)  # [B, 37] logits
        probs = torch.zeros_like(raw)
        # Softmax is applied per question group, never across all 37 answers.
        for start, end in QUESTION_GROUPS.values():
            probs[:, start:end] = F.softmax(raw[:, start:end], dim=-1)
        samples.append(probs.cpu().numpy())

    model.disable_mc_dropout()

    stacked = np.stack(samples, axis=0)   # [n_passes, B, 37]
    mean_pred = stacked.mean(axis=0)      # [B, 37]
    std_pred = stacked.std(axis=0)        # [B, 37]

    n_questions = len(QUESTION_GROUPS)
    per_q_unc = np.zeros((mean_pred.shape[0], n_questions), dtype=np.float32)
    for q_idx, (start, end) in enumerate(QUESTION_GROUPS.values()):
        # Average the per-answer std within each question's slice.
        per_q_unc[:, q_idx] = std_pred[:, start:end].mean(axis=1)

    return (
        mean_pred.astype(np.float32),
        std_pred.astype(np.float32),
        per_q_unc,
    )
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 290 |
+
# Factory functions
|
| 291 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 292 |
+
|
| 293 |
+
def build_model(cfg: DictConfig) -> GalaxyViT:
    """Instantiate the proposed regression model and print its summary."""
    net = GalaxyViT(cfg)
    _print_summary(net, cfg, "GalaxyViT (regression β proposed)")
    return net
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def build_dirichlet_model(cfg: DictConfig) -> GalaxyViTDirichlet:
    """Instantiate the Dirichlet (Zoobot-style) baseline and print its summary."""
    net = GalaxyViTDirichlet(cfg)
    _print_summary(net, cfg, "GalaxyViTDirichlet (Zoobot-style baseline)")
    return net
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def _print_summary(model, cfg, name: str):
|
| 306 |
+
total = sum(p.numel() for p in model.parameters())
|
| 307 |
+
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 308 |
+
n_hooks = len(model.backbone.blocks)
|
| 309 |
+
print(f"\n{'='*55}")
|
| 310 |
+
print(f"Model : {name}")
|
| 311 |
+
print(f"Backbone : {cfg.model.backbone}")
|
| 312 |
+
print(f"Pretrained : {cfg.model.pretrained}")
|
| 313 |
+
print(f"Dropout : {cfg.model.dropout}")
|
| 314 |
+
print(f"Parameters : {total:,} ({trainable:,} trainable)")
|
| 315 |
+
print(f"Attn hooks : {n_hooks} layers (full rollout enabled)")
|
| 316 |
+
print(f"{'='*55}\n")
|
src/train.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
src/train.py
|
| 3 |
+
------------
|
| 4 |
+
Main training loop for the proposed hierarchical probabilistic ViT
|
| 5 |
+
regression model on Galaxy Zoo 2.
|
| 6 |
+
|
| 7 |
+
Model : GalaxyViT (ViT-Base/16 + linear head)
|
| 8 |
+
Loss : HierarchicalLoss (KL + MSE, Ξ»=0.5 each)
|
| 9 |
+
Scheduler: CosineAnnealingLR
|
| 10 |
+
Dropout : 0.3 (increased from 0.1 β see base.yaml rationale)
|
| 11 |
+
|
| 12 |
+
Saves
|
| 13 |
+
-----
|
| 14 |
+
outputs/checkpoints/best_<experiment_name>.pt β best checkpoint
|
| 15 |
+
outputs/logs/training_<experiment_name>_history.csv β epoch history
|
| 16 |
+
|
| 17 |
+
Usage
|
| 18 |
+
-----
|
| 19 |
+
cd ~/galaxy
|
| 20 |
+
nohup python -m src.train --config configs/full_train.yaml \
|
| 21 |
+
> outputs/logs/train_full.log 2>&1 &
|
| 22 |
+
echo "PID: $!"
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import argparse
|
| 26 |
+
import logging
|
| 27 |
+
import random
|
| 28 |
+
import sys
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
|
| 31 |
+
import numpy as np
|
| 32 |
+
import torch
|
| 33 |
+
import torch.nn as nn
|
| 34 |
+
from torch.amp import autocast, GradScaler
|
| 35 |
+
from omegaconf import OmegaConf
|
| 36 |
+
import pandas as pd
|
| 37 |
+
|
| 38 |
+
import wandb
|
| 39 |
+
from tqdm import tqdm
|
| 40 |
+
|
| 41 |
+
from src.dataset import build_dataloaders
|
| 42 |
+
from src.loss import HierarchicalLoss
|
| 43 |
+
from src.metrics import compute_metrics, predictions_to_numpy
|
| 44 |
+
from src.model import build_model
|
| 45 |
+
from src.attention_viz import plot_attention_grid
|
| 46 |
+
|
| 47 |
+
logging.basicConfig(
|
| 48 |
+
format="%(asctime)s %(levelname)s %(name)s %(message)s",
|
| 49 |
+
datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
|
| 50 |
+
)
|
| 51 |
+
log = logging.getLogger("train")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 55 |
+
# Utilities
|
| 56 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 57 |
+
|
| 58 |
+
def set_seed(seed: int):
    """Seed every RNG in use and force deterministic cuDNN behaviour."""
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
    # Trade autotuning speed for bit-exact reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class EarlyStopping:
    """Track validation loss, checkpoint improvements, signal stagnation."""

    def __init__(self, patience, min_delta, checkpoint_path):
        self.patience = patience
        self.min_delta = min_delta
        self.checkpoint_path = checkpoint_path
        self.best_loss = float("inf")
        self.counter = 0
        self.best_epoch = 0

    def step(self, val_loss, model, epoch) -> bool:
        """Record one epoch's loss; return True when training should stop."""
        improved = val_loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
            payload = {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "val_loss": val_loss,
            }
            torch.save(payload, self.checkpoint_path)
            log.info(" [ckpt] saved epoch=%d val_loss=%.6f", epoch, val_loss)
        else:
            self.counter += 1
            log.info(" [early_stop] %d/%d best=%.6f",
                     self.counter, self.patience, self.best_loss)
        return self.counter >= self.patience

    def restore_best(self, model):
        """Load the best saved weights back into *model*."""
        ckpt = torch.load(self.checkpoint_path, map_location="cpu",
                          weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        log.info("Restored best weights epoch=%d val_loss=%.6f",
                 ckpt["epoch"], ckpt["val_loss"])
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 102 |
+
# Training / validation steps
|
| 103 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 104 |
+
|
| 105 |
+
def train_one_epoch(model, loader, loss_fn, optimizer,
                    scaler, device, cfg, epoch):
    """Run one training epoch and return the mean loss over batches.

    Uses mixed-precision autocast with a GradScaler. Gradients are
    unscaled before clipping so the clip threshold applies to the true
    gradient magnitudes, not the scaled ones.

    Note: loader is expected to yield (images, targets, weights, image_ids);
    image_ids are unused here.
    """
    model.train()
    total = 0.0
    nb = 0
    for images, targets, weights, _ in tqdm(
        loader, desc=f"Train E{epoch}", leave=False
    ):
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        weights = weights.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with autocast("cuda", enabled=cfg.training.mixed_precision):
            logits = model(images)
            # loss_fn returns (total_loss, per-component breakdown);
            # only the scalar total is used for backprop here.
            loss, _ = loss_fn(logits, targets, weights)
        # AMP protocol: scale -> backward -> unscale -> clip -> step -> update.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.training.grad_clip)
        scaler.step(optimizer)
        scaler.update()

        total += loss.item()
        nb += 1
    return total / nb
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def validate(model, loader, loss_fn, device, cfg,
             collect_attn=False, n_attn=8, epoch=0):
    """Evaluate on the validation loader; optionally collect attention maps.

    Returns
    -------
    val_logs : dict
        "val/loss_total" plus one "val/<metric>" entry per computed metric
        and "val/reached_mae_w050" (0 if that metric is absent).
    attn_data : tuple | None
        (images, per-layer attention tensors, image_ids) for up to n_attn
        samples when collect_attn is True and at least one batch yielded
        attention maps; otherwise None.
    """
    model.eval()
    total = 0.0
    nb = 0
    all_preds, all_targets, all_weights = [], [], []
    attn_imgs, all_layers_list, attn_ids = [], [], []
    attn_done = False  # stop collecting once n_attn samples are gathered

    with torch.no_grad():
        for images, targets, weights, image_ids in tqdm(
            loader, desc=f"Val E{epoch}", leave=False
        ):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)

            with autocast("cuda", enabled=cfg.training.mixed_precision):
                logits = model(images)
                loss, _ = loss_fn(logits, targets, weights)

            total += loss.item()
            nb += 1
            p, t, w = predictions_to_numpy(logits, targets, weights)
            all_preds.append(p)
            all_targets.append(t)
            all_weights.append(w)

            if collect_attn and not attn_done:
                # Attention maps were captured by the hooks during forward().
                all_layers = model.get_all_attention_weights()
                if all_layers is not None:
                    n = min(n_attn, images.shape[0])
                    attn_imgs.append(images[:n].cpu())
                    all_layers_list.append([l[:n].cpu() for l in all_layers])
                    attn_ids.extend([int(i) for i in image_ids[:n]])
                    if len(attn_ids) >= n_attn:
                        attn_done = True

    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    all_weights = np.concatenate(all_weights)
    metrics = compute_metrics(all_preds, all_targets, all_weights)

    val_logs = {"val/loss_total": total / nb}
    val_logs.update({f"val/{k}": v for k, v in metrics.items()})
    val_logs["val/reached_mae_w050"] = metrics.get("mae_w050/conditional_avg", 0)

    attn_data = None
    if collect_attn and attn_imgs:
        # Concatenate per-batch collections: images along batch dim, and
        # each layer's attention maps across batches (layer index li).
        attn_data = (
            torch.cat(attn_imgs, dim=0),
            [torch.cat([b[li] for b in all_layers_list], dim=0)
             for li in range(len(all_layers_list[0]))],
            attn_ids,
        )

    return val_logs, attn_data
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 192 |
+
# Main
|
| 193 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 194 |
+
|
| 195 |
+
def main():
    """Entry point: merge configs, build everything, run the training loop.

    Reads configs/base.yaml then overlays the --config experiment file.
    Side effects: creates output directories, optionally initialises
    wandb, writes the best checkpoint and a CSV epoch history.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()

    # Experiment config overrides base defaults key-by-key.
    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)

    set_seed(cfg.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info("Device: %s", device)

    # TF32 matmuls: faster on Ampere+ with negligible precision loss.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    Path(cfg.outputs.checkpoint_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.outputs.figures_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.outputs.log_dir).mkdir(parents=True, exist_ok=True)

    checkpoint_path = str(
        Path(cfg.outputs.checkpoint_dir) / f"best_{cfg.experiment_name}.pt"
    )
    history_path = str(
        Path(cfg.outputs.log_dir) / f"training_{cfg.experiment_name}_history.csv"
    )

    if cfg.wandb.enabled:
        wandb.init(
            project=cfg.wandb.project,
            name=cfg.experiment_name,
            config=OmegaConf.to_container(cfg, resolve=True),
        )

    log.info("Building dataloaders...")
    train_loader, val_loader, _ = build_dataloaders(cfg)

    log.info("Building model...")
    model = build_model(cfg).to(device)
    loss_fn = HierarchicalLoss(cfg)

    # Discriminative learning rates: the pretrained backbone gets 10x
    # smaller LR than the freshly initialised head.
    optimizer = torch.optim.AdamW(
        [
            {"params": model.backbone.parameters(),
             "lr": cfg.training.learning_rate * 0.1},
            {"params": model.head.parameters(),
             "lr": cfg.training.learning_rate},
        ],
        weight_decay=cfg.training.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
    )
    scaler = GradScaler("cuda")
    early_stop = EarlyStopping(
        patience = cfg.early_stopping.patience,
        min_delta = cfg.early_stopping.min_delta,
        checkpoint_path = checkpoint_path,
    )

    log.info("Starting training: %s", cfg.experiment_name)
    history = []

    for epoch in range(1, cfg.training.epochs + 1):
        train_loss = train_one_epoch(
            model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch
        )
        # Attention maps are collected only every N epochs to limit overhead.
        collect_attn = (epoch % cfg.wandb.log_attention_every_n_epochs == 0)
        val_logs, attn_data = validate(
            model, val_loader, loss_fn, device, cfg,
            collect_attn=collect_attn,
            n_attn=cfg.wandb.n_attention_samples,
            epoch=epoch,
        )
        scheduler.step()
        lr = scheduler.get_last_lr()[0]

        val_mae = val_logs.get("val/mae/weighted_avg", 0)
        val_loss = val_logs["val/loss_total"]
        reached = val_logs.get("val/reached_mae_w050", 0)

        log.info(
            "Epoch %d train=%.4f val=%.4f mae=%.4f reached_mae=%.4f lr=%.2e",
            epoch, train_loss, val_loss, val_mae, reached, lr,
        )

        history.append({
            "epoch" : epoch,
            "train_loss" : train_loss,
            "val_loss" : val_loss,
            "val_mae" : val_mae,
            "reached_mae": reached,
            "lr" : lr,
        })

        if cfg.wandb.enabled:
            log_dict = {
                "train/loss": train_loss,
                **val_logs,
                "lr": lr, "epoch": epoch,
            }
            if attn_data is not None:
                # Imported lazily: matplotlib is only needed when a figure
                # is actually logged.
                import matplotlib.pyplot as plt
                imgs, layers, ids = attn_data
                fig = plot_attention_grid(
                    imgs, layers, ids,
                    save_path=(
                        f"{cfg.outputs.figures_dir}/{cfg.experiment_name}/"
                        f"attn_epoch{epoch:03d}.png"
                    ),
                    n_cols=4, rollout_mode="full",
                )
                log_dict["attention/rollout_full"] = wandb.Image(fig)
                plt.close(fig)
            wandb.log(log_dict, step=epoch)

        if early_stop.step(val_loss, model, epoch):
            log.info("Early stopping at epoch %d best=%d loss=%.6f",
                     epoch, early_stop.best_epoch, early_stop.best_loss)
            break

    # Save history
    pd.DataFrame(history).to_csv(history_path, index=False)
    log.info("Saved history: %s", history_path)

    early_stop.restore_best(model)
    if cfg.wandb.enabled:
        wandb.finish()
    log.info("Done. Best checkpoint: %s", checkpoint_path)


if __name__ == "__main__":
    main()
|
src/train_single.py
ADDED
|
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
src/train_single.py
|
| 3 |
+
-------------------
|
| 4 |
+
Train any single model by name. Designed for running baselines
|
| 5 |
+
one at a time with breaks between them.
|
| 6 |
+
|
| 7 |
+
Available models
|
| 8 |
+
----------------
|
| 9 |
+
proposed β ViT-Base + hierarchical KL+MSE (main model)
|
| 10 |
+
b1_resnet_mse β ResNet-18 + independent MSE (sigmoid)
|
| 11 |
+
b2_resnet_kl β ResNet-18 + hierarchical KL+MSE
|
| 12 |
+
b3_vit_mse β ViT-Base + hierarchical MSE only (no KL)
|
| 13 |
+
b4_vit_dir β ViT-Base + Dirichlet NLL (Zoobot-style)
|
| 14 |
+
|
| 15 |
+
Usage
|
| 16 |
+
-----
|
| 17 |
+
# Train proposed model
|
| 18 |
+
python -m src.train_single --model proposed --config configs/full_train.yaml
|
| 19 |
+
|
| 20 |
+
# Train one baseline at a time
|
| 21 |
+
python -m src.train_single --model b1_resnet_mse --config configs/full_train.yaml
|
| 22 |
+
python -m src.train_single --model b2_resnet_kl --config configs/full_train.yaml
|
| 23 |
+
python -m src.train_single --model b3_vit_mse --config configs/full_train.yaml
|
| 24 |
+
python -m src.train_single --model b4_vit_dir --config configs/full_train.yaml
|
| 25 |
+
|
| 26 |
+
# With nohup (recommended)
|
| 27 |
+
nohup python -m src.train_single --model b3_vit_mse \\
|
| 28 |
+
--config configs/full_train.yaml \\
|
| 29 |
+
> outputs/logs/train_b3_vit_mse.log 2>&1 &
|
| 30 |
+
echo "PID: $!"
|
| 31 |
+
|
| 32 |
+
Each model saves its checkpoint independently, so you can run them
|
| 33 |
+
in any order and resume from any point. Already-trained models are
|
| 34 |
+
detected by their checkpoint file and skipped unless --force is passed.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import argparse
|
| 38 |
+
import logging
|
| 39 |
+
import sys
|
| 40 |
+
from pathlib import Path
|
| 41 |
+
|
| 42 |
+
import numpy as np
|
| 43 |
+
import torch
|
| 44 |
+
from omegaconf import OmegaConf
|
| 45 |
+
|
| 46 |
+
logging.basicConfig(
|
| 47 |
+
format="%(asctime)s %(levelname)s %(message)s",
|
| 48 |
+
datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
|
| 49 |
+
)
|
| 50 |
+
log = logging.getLogger("train_single")
|
| 51 |
+
|
| 52 |
+
# ββ Checkpoint paths per model βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 53 |
+
# ── Checkpoint filenames per model ──────────────────────────────────────────
# Filenames (within cfg.outputs.checkpoint_dir) used to save each model's
# best weights; per the module docstring, an existing checkpoint marks a
# model as already trained.
CHECKPOINT_NAMES = {
    "proposed" : "best_full_train.pt",
    "b1_resnet_mse" : "baseline_resnet18_mse.pt",
    "b2_resnet_kl" : "baseline_resnet18_klmse.pt",
    "b3_vit_mse" : "baseline_vit_mse.pt",
    "b4_vit_dir" : "baseline_vit_dirichlet.pt",
}

# ── Human-readable labels ───────────────────────────────────────────────────
# Used for log messages; keys must mirror CHECKPOINT_NAMES exactly.
MODEL_LABELS = {
    "proposed" : "ViT-Base + hierarchical KL+MSE (proposed)",
    "b1_resnet_mse" : "ResNet-18 + independent MSE (sigmoid, no hierarchy)",
    "b2_resnet_kl" : "ResNet-18 + hierarchical KL+MSE",
    "b3_vit_mse" : "ViT-Base + hierarchical MSE only (no KL)",
    "b4_vit_dir" : "ViT-Base + Dirichlet NLL (Zoobot-style)",
}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def train_proposed(cfg, device, ckpt_path):
    """Train the proposed ViT + hierarchical KL+MSE model.

    Mirrors src/train.py's main loop but takes a pre-resolved config,
    device, and checkpoint path. Imports are local so that importing
    this module stays cheap when other model keys are selected.
    """
    from src.train import (
        train_one_epoch, validate, EarlyStopping, set_seed
    )
    from src.dataset import build_dataloaders
    from src.model import build_model
    from src.loss import HierarchicalLoss
    from src.attention_viz import plot_attention_grid
    import pandas as pd
    import wandb
    from torch.amp import GradScaler
    import matplotlib.pyplot as plt

    set_seed(cfg.seed)
    log.info("Training: %s", MODEL_LABELS["proposed"])

    Path(cfg.outputs.checkpoint_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.outputs.figures_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.outputs.log_dir).mkdir(parents=True, exist_ok=True)

    # NOTE(review): history filename is hard-coded here, unlike src/train.py
    # which derives it from cfg.experiment_name.
    history_path = str(
        Path(cfg.outputs.log_dir) / "training_full_train_history.csv"
    )

    if cfg.wandb.enabled:
        wandb.init(
            project=cfg.wandb.project,
            name=cfg.experiment_name,
            config=OmegaConf.to_container(cfg, resolve=True),
        )

    train_loader, val_loader, _ = build_dataloaders(cfg)
    model = build_model(cfg).to(device)
    loss_fn = HierarchicalLoss(cfg)

    # Discriminative LRs: pretrained backbone at 10x lower rate than head.
    optimizer = torch.optim.AdamW(
        [
            {"params": model.backbone.parameters(),
             "lr": cfg.training.learning_rate * 0.1},
            {"params": model.head.parameters(),
             "lr": cfg.training.learning_rate},
        ],
        weight_decay=cfg.training.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
    )
    scaler = GradScaler("cuda")
    early_stop = EarlyStopping(
        patience=cfg.early_stopping.patience,
        min_delta=cfg.early_stopping.min_delta,
        checkpoint_path=ckpt_path,
    )

    history = []
    for epoch in range(1, cfg.training.epochs + 1):
        train_loss = train_one_epoch(
            model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch
        )
        # Attention maps are only collected every N epochs to limit overhead.
        collect_attn = (epoch % cfg.wandb.log_attention_every_n_epochs == 0)
        val_logs, attn_data = validate(
            model, val_loader, loss_fn, device, cfg,
            collect_attn=collect_attn,
            n_attn=cfg.wandb.n_attention_samples,
            epoch=epoch,
        )
        scheduler.step()
        lr = scheduler.get_last_lr()[0]

        val_mae = val_logs.get("val/mae/weighted_avg", 0)
        val_loss = val_logs["val/loss_total"]
        log.info("Epoch %d train=%.4f val=%.4f mae=%.4f lr=%.2e",
                 epoch, train_loss, val_loss, val_mae, lr)

        history.append({
            "epoch": epoch, "train_loss": train_loss,
            "val_loss": val_loss, "val_mae": val_mae, "lr": lr,
        })

        if cfg.wandb.enabled:
            log_dict = {"train/loss": train_loss, **val_logs,
                        "lr": lr, "epoch": epoch}
            if attn_data is not None:
                imgs, layers, ids = attn_data
                fig = plot_attention_grid(
                    imgs, layers, ids,
                    save_path=(f"{cfg.outputs.figures_dir}/{cfg.experiment_name}/"
                               f"attn_epoch{epoch:03d}.png"),
                    n_cols=4, rollout_mode="full",
                )
                log_dict["attention/rollout_full"] = wandb.Image(fig)
                plt.close(fig)
            wandb.log(log_dict, step=epoch)

        if early_stop.step(val_loss, model, epoch):
            log.info("Early stopping at epoch %d", epoch)
            break

    pd.DataFrame(history).to_csv(history_path, index=False)
    # Reload the best checkpoint so the in-memory model matches ckpt_path.
    early_stop.restore_best(model)
    if cfg.wandb.enabled:
        wandb.finish()
    log.info("Done. Checkpoint: %s", ckpt_path)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def train_baseline(cfg, device, ckpt_path, model_key):
    """Train one of the four baseline models and evaluate it on the test set.

    Parameters
    ----------
    cfg : OmegaConf
        Merged base + experiment configuration.
    device : torch.device
        Device used for training and evaluation.
    ckpt_path : str
        Path where the best checkpoint is written by early stopping.
    model_key : str
        One of ``b1_resnet_mse``, ``b2_resnet_kl``, ``b3_vit_mse``,
        ``b4_vit_dir`` — selects the architecture + loss combination.

    Returns
    -------
    tuple
        ``(test_metrics, best_val_loss, best_epoch)``.

    Raises
    ------
    ValueError
        If ``model_key`` is not a known baseline.
    """
    # Heavy deps are imported lazily so that importing this module stays cheap.
    import wandb
    from torch.amp import GradScaler
    from src.dataset import build_dataloaders
    from src.model import build_model, build_dirichlet_model
    from src.loss import HierarchicalLoss, DirichletLoss, MSEOnlyLoss
    from src.baselines import (
        ResNet18Baseline, IndependentMSELoss, EarlyStopping,
        set_seed, _train_epoch, _val_epoch,
        _train_epoch_dirichlet, _val_epoch_dirichlet,
    )
    import pandas as pd
    from omegaconf import OmegaConf as OC

    set_seed(cfg.seed)
    log.info("Training: %s", MODEL_LABELS[model_key])

    Path(cfg.outputs.checkpoint_dir).mkdir(parents=True, exist_ok=True)

    # -- Build model and loss -----------------------------------
    if model_key == "b1_resnet_mse":
        model = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
        loss_fn = IndependentMSELoss()
        use_sigmoid = True   # independent per-answer sigmoid outputs
        is_dirichlet = False
        use_layerwise_lr = False
        wandb_name = "B1-ResNet18-MSE"

    elif model_key == "b2_resnet_kl":
        model = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
        loss_fn = HierarchicalLoss(cfg)
        use_sigmoid = False
        is_dirichlet = False
        use_layerwise_lr = False
        wandb_name = "B2-ResNet18-KL+MSE"

    elif model_key == "b3_vit_mse":
        # Zero out the KL term so only the MSE component of the
        # hierarchical loss remains.
        vit_mse_cfg = OC.merge(
            cfg, OC.create({"loss": {"lambda_kl": 0.0, "lambda_mse": 1.0}})
        )
        model = build_model(vit_mse_cfg).to(device)
        loss_fn = MSEOnlyLoss(vit_mse_cfg)
        cfg = vit_mse_cfg  # use updated cfg for optimizer
        use_sigmoid = False
        is_dirichlet = False
        use_layerwise_lr = True
        wandb_name = "B3-ViT-MSE"

    elif model_key == "b4_vit_dir":
        model = build_dirichlet_model(cfg).to(device)
        loss_fn = DirichletLoss(cfg)
        use_sigmoid = False
        is_dirichlet = True
        use_layerwise_lr = True
        wandb_name = "B4-ViT-Dirichlet"

    else:
        raise ValueError(f"Unknown model key: {model_key}")

    total = sum(p.numel() for p in model.parameters())
    log.info("Parameters: %s", f"{total:,}")

    # -- Optimizer ----------------------------------------------
    # ViT baselines fine-tune the pretrained backbone with a 10x smaller
    # learning rate than the freshly initialized head.
    if use_layerwise_lr and hasattr(model, "backbone") and hasattr(model, "head"):
        optimizer = torch.optim.AdamW(
            [
                {"params": model.backbone.parameters(),
                 "lr": cfg.training.learning_rate * 0.1},
                {"params": model.head.parameters(),
                 "lr": cfg.training.learning_rate},
            ],
            weight_decay=cfg.training.weight_decay,
        )
    else:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=cfg.training.learning_rate,
            weight_decay=cfg.training.weight_decay,
        )

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
    )
    scaler = GradScaler("cuda")
    early_stop = EarlyStopping(
        patience=cfg.early_stopping.patience,
        min_delta=cfg.early_stopping.min_delta,
        checkpoint_path=ckpt_path,
    )

    train_loader, val_loader, test_loader = build_dataloaders(cfg)

    # NOTE(review): unlike train_proposed, wandb is used unconditionally here
    # (no cfg.wandb.enabled guard). The epoch helpers receive wandb_name and
    # may log to the active run, so guarding only wandb.init is unsafe —
    # confirm intended behavior before changing.
    wandb.init(
        project=cfg.wandb.project, name=wandb_name,
        config={"model": wandb_name, "seed": cfg.seed,
                "epochs": cfg.training.epochs,
                "lambda_kl": cfg.loss.lambda_kl},
        reinit=True,
    )

    # -- Training loop ------------------------------------------
    history = []
    for epoch in range(1, cfg.training.epochs + 1):
        # Dirichlet models need dedicated epoch loops (different targets).
        if is_dirichlet:
            train_loss = _train_epoch_dirichlet(
                model, train_loader, loss_fn, optimizer, scaler,
                device, cfg, epoch, wandb_name
            )
            val_loss, val_metrics = _val_epoch_dirichlet(
                model, val_loader, loss_fn, device, cfg, epoch, wandb_name
            )
        else:
            train_loss = _train_epoch(
                model, train_loader, loss_fn, optimizer, scaler,
                device, cfg, epoch, wandb_name
            )
            val_loss, val_metrics = _val_epoch(
                model, val_loader, loss_fn, device, cfg, epoch, wandb_name,
                use_sigmoid=use_sigmoid
            )

        scheduler.step()
        lr = scheduler.get_last_lr()[0]
        val_mae = val_metrics.get("mae/weighted_avg", 0)

        log.info("%s epoch=%d train=%.4f val=%.4f mae=%.4f lr=%.2e",
                 wandb_name, epoch, train_loss, val_loss, val_mae, lr)
        history.append({
            "epoch": epoch, "train_loss": train_loss,
            "val_loss": val_loss, "val_mae": val_mae,
        })
        wandb.log({
            "train_loss": train_loss, "val_loss": val_loss,
            "val_mae": val_mae, "lr": lr,
        }, step=epoch)

        if early_stop.step(val_loss, model, epoch):
            log.info("%s: early stopping at epoch %d", wandb_name, epoch)
            break

    # Reload the best (lowest-val-loss) weights before test evaluation.
    best_val = early_stop.restore_best(model)
    wandb.finish()

    # -- Test evaluation ----------------------------------------
    log.info("Evaluating on test set...")
    if is_dirichlet:
        _, test_metrics = _val_epoch_dirichlet(
            model, test_loader, loss_fn, device, cfg,
            epoch=0, label=f"{wandb_name}-test"
        )
    else:
        _, test_metrics = _val_epoch(
            model, test_loader, loss_fn, device, cfg,
            epoch=0, label=f"{wandb_name}-test", use_sigmoid=use_sigmoid
        )

    log.info("%s — Test MAE=%.5f RMSE=%.5f",
             wandb_name,
             test_metrics["mae/weighted_avg"],
             test_metrics["rmse/weighted_avg"])

    # -- Save per-model history ---------------------------------
    hist_path = Path(cfg.outputs.log_dir) / f"training_{model_key}_history.csv"
    pd.DataFrame(history).to_csv(hist_path, index=False)
    log.info("History saved: %s", hist_path)
    log.info("Done. Checkpoint: %s", ckpt_path)

    return test_metrics, best_val, early_stop.best_epoch
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 351 |
+
# Main
|
| 352 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 353 |
+
|
| 354 |
+
def main():
    """CLI entry point: parse arguments, merge configs, dispatch a trainer."""
    cli = argparse.ArgumentParser(
        description="Train a single model. Run multiple times to train "
                    "different models with breaks in between."
    )
    cli.add_argument(
        "--model",
        required=True,
        choices=list(CHECKPOINT_NAMES.keys()),
        help=(
            "Which model to train:\n"
            "  proposed      — ViT-Base + hierarchical KL+MSE (main)\n"
            "  b1_resnet_mse — ResNet-18 + independent MSE (sigmoid)\n"
            "  b2_resnet_kl  — ResNet-18 + hierarchical KL+MSE\n"
            "  b3_vit_mse    — ViT-Base + hierarchical MSE only\n"
            "  b4_vit_dir    — ViT-Base + Dirichlet NLL\n"
        ),
    )
    cli.add_argument("--config", required=True)
    cli.add_argument(
        "--force",
        action="store_true",
        help="Retrain even if checkpoint already exists.",
    )
    opts = cli.parse_args()

    # Experiment config overrides the base config.
    cfg = OmegaConf.merge(
        OmegaConf.load("configs/base.yaml"),
        OmegaConf.load(opts.config),
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ckpt_dir = Path(cfg.outputs.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    Path(cfg.outputs.log_dir).mkdir(parents=True, exist_ok=True)

    # TF32 matmuls trade a little precision for throughput on Ampere+.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    ckpt_path = str(ckpt_dir / CHECKPOINT_NAMES[opts.model])

    # -- Skip if already done -----------------------------------
    if Path(ckpt_path).exists() and not opts.force:
        log.info("Checkpoint already exists: %s", ckpt_path)
        log.info("Model '%s' is already trained. Skipping.", opts.model)
        log.info("Use --force to retrain.")
        return

    banner = "=" * 60
    log.info(banner)
    log.info("Training: %s", MODEL_LABELS[opts.model])
    log.info("Device : %s", device)
    log.info("Config : %s", opts.config)
    log.info("Ckpt : %s", ckpt_path)
    log.info(banner)

    if opts.model == "proposed":
        train_proposed(cfg, device, ckpt_path)
    else:
        train_baseline(cfg, device, ckpt_path, opts.model)

    log.info(banner)
    log.info("FINISHED: %s", MODEL_LABELS[opts.model])
    log.info(banner)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
# Script entry point — each invocation trains exactly one model.
if __name__ == "__main__":
    main()
|
src/uncertainty_analysis.py
ADDED
|
@@ -0,0 +1,578 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
src/uncertainty_analysis.py
|
| 3 |
+
----------------------------
|
| 4 |
+
MC Dropout epistemic uncertainty analysis for the proposed model.
|
| 5 |
+
|
| 6 |
+
MC Dropout (Gal & Ghahramani 2016) is used as a post-hoc uncertainty
|
| 7 |
+
estimator. At inference time, dropout is kept active and N=30 stochastic
|
| 8 |
+
forward passes are run per batch. The standard deviation across passes
|
| 9 |
+
is used as the epistemic uncertainty estimate per galaxy per question.
|
| 10 |
+
|
| 11 |
+
Key findings reported
|
| 12 |
+
---------------------
|
| 13 |
+
1. Uncertainty distributions: right-skewed, well-separated means across
|
| 14 |
+
questions reflecting the conditional nature of the decision tree.
|
| 15 |
+
|
| 16 |
+
2. Uncertainty vs. error correlation: Spearman Ο reported per question.
|
| 17 |
+
Strong positive correlation for root and shallow-branch questions
|
| 18 |
+
(t01, t02, t04, t07) indicates the model is well-calibrated in
|
| 19 |
+
uncertainty. Weak or near-zero correlation for deep conditional
|
| 20 |
+
branches (t03, t05, t08, t09, t10, t11) is expected β these branches
|
| 21 |
+
have small effective sample sizes and aleatoric uncertainty dominates.
|
| 22 |
+
|
| 23 |
+
3. Morphology selection benchmark: F1 score at threshold Ο for downstream
|
| 24 |
+
binary morphology classification tasks.
|
| 25 |
+
|
| 26 |
+
Output files
|
| 27 |
+
------------
|
| 28 |
+
outputs/figures/uncertainty/
|
| 29 |
+
fig_uncertainty_distributions.pdf
|
| 30 |
+
fig_uncertainty_vs_error.pdf
|
| 31 |
+
fig_morphology_f1_comparison.pdf
|
| 32 |
+
table_uncertainty_summary.csv
|
| 33 |
+
table_morphology_selection_benchmark.csv
|
| 34 |
+
mc_cache/ β cached numpy arrays (crash-safe)
|
| 35 |
+
|
| 36 |
+
Usage
|
| 37 |
+
-----
|
| 38 |
+
cd ~/galaxy
|
| 39 |
+
nohup python -m src.uncertainty_analysis \
|
| 40 |
+
--config configs/full_train.yaml --n_passes 30 \
|
| 41 |
+
> outputs/logs/uncertainty.log 2>&1 &
|
| 42 |
+
echo "PID: $!"
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
import argparse
|
| 46 |
+
import logging
|
| 47 |
+
import sys
|
| 48 |
+
from pathlib import Path
|
| 49 |
+
|
| 50 |
+
import numpy as np
|
| 51 |
+
import pandas as pd
|
| 52 |
+
import torch
|
| 53 |
+
import torch.nn.functional as F
|
| 54 |
+
import matplotlib
|
| 55 |
+
matplotlib.use("Agg")
|
| 56 |
+
import matplotlib.pyplot as plt
|
| 57 |
+
from scipy import stats as scipy_stats
|
| 58 |
+
from torch.amp import autocast
|
| 59 |
+
from omegaconf import OmegaConf
|
| 60 |
+
from tqdm import tqdm
|
| 61 |
+
|
| 62 |
+
from src.dataset import build_dataloaders, QUESTION_GROUPS
|
| 63 |
+
from src.model import build_model, build_dirichlet_model
|
| 64 |
+
from src.baselines import ResNet18Baseline
|
| 65 |
+
from src.metrics import predictions_to_numpy, dirichlet_predictions_to_numpy
|
| 66 |
+
|
| 67 |
+
# Console logging with short HH:MM:SS timestamps, streamed to stdout so a
# `nohup ... > log 2>&1` invocation captures one ordered stream.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
)
log = logging.getLogger("uncertainty")

# Publication-style matplotlib defaults; fonttype 42 embeds TrueType fonts
# so text in the exported PDF/PS remains editable in vector editors.
plt.rcParams.update({
    "figure.dpi": 150, "savefig.dpi": 300,
    "font.family": "serif", "font.size": 11,
    "axes.titlesize": 10, "axes.labelsize": 10,
    "xtick.labelsize": 8, "ytick.labelsize": 8,
    "legend.fontsize": 8,
    "figure.facecolor": "white", "axes.facecolor": "white",
    "axes.grid": True, "grid.alpha": 0.3,
    "pdf.fonttype": 42, "ps.fonttype": 42,
})

# Short human-readable description for each decision-tree question t01–t11.
QUESTION_LABELS = {
    "t01": "Smooth or features", "t02": "Edge-on disk",
    "t03": "Bar", "t04": "Spiral arms",
    "t05": "Bulge prominence", "t06": "Odd feature",
    "t07": "Roundedness", "t08": "Odd feature type",
    "t09": "Bulge shape", "t10": "Arms winding",
    "t11": "Arms number",
}

# Fixed color per model so figures stay visually comparable across plots.
MODEL_COLORS = {
    "ViT-Base + KL+MSE (proposed)" : "#27ae60",
    "ViT-Base + Dirichlet (Zoobot-style)": "#8e44ad",
    "ResNet-18 + MSE (sigmoid)" : "#c0392b",
    "ResNet-18 + KL+MSE" : "#e67e22",
}

# Probability thresholds tau swept in the downstream selection benchmark.
SELECTION_THRESHOLDS = [0.5, 0.7, 0.8, 0.9]
# question -> (answer index within that question's answer slice, label);
# these answers define the binary morphology-selection tasks.
SELECTION_ANSWERS = {
    "t01": (0, "smooth"),
    "t02": (0, "edge-on"),
    "t03": (0, "bar"),
    "t04": (0, "spiral"),
    "t06": (0, "odd feature"),
}
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 111 |
+
# MC Dropout inference β Welford online algorithm, crash-safe
|
| 112 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 113 |
+
|
| 114 |
+
def run_mc_inference(model, loader, device, cfg,
                     n_passes=30, cache_dir=None):
    """
    Fast batched MC Dropout inference.

    Uses Welford's online algorithm to compute mean and std
    per batch without storing all n_passes x N predictions.
    Memory usage: O(N x 37) regardless of n_passes.

    Parameters
    ----------
    model : GalaxyViT with enable_mc_dropout() available
    loader : test DataLoader
    device : inference device
    cfg : OmegaConf config
    n_passes : number of stochastic forward passes (default 30)
    cache_dir : if given, saves .npy files and skips if they exist

    Returns
    -------
    mean_all, std_all : [N, 37] float32
    targets_all : [N, 37] float32
    weights_all : [N, 11] float32
    """
    # Crash-safe caching: if a previous run completed, reuse its arrays.
    if cache_dir is not None:
        cache_dir = Path(cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)
        fp_mean = cache_dir / "mc_mean.npy"
        fp_std = cache_dir / "mc_std.npy"
        fp_targets = cache_dir / "mc_targets.npy"
        fp_weights = cache_dir / "mc_weights.npy"

        if all(p.exists() for p in [fp_mean, fp_std, fp_targets, fp_weights]):
            log.info("MC cache found — loading from disk (skipping inference).")
            return (np.load(fp_mean), np.load(fp_std),
                    np.load(fp_targets), np.load(fp_weights))

    model.eval()                # eval mode (frozen norm-layer statistics)
    model.enable_mc_dropout()   # but keep dropout stochastic for MC sampling

    all_means, all_stds, all_targets, all_weights = [], [], [], []
    log.info("MC Dropout: %d passes x %d-image batches = %d total forward passes",
             n_passes, loader.batch_size, n_passes * len(loader))

    for images, targets, weights, _ in tqdm(loader, desc="MC Dropout"):
        images_dev = images.to(device, non_blocking=True)

        # Welford online mean and M2 (sum of squared deviations)
        mean_acc = None
        M2_acc = None
        count = 0

        for _ in range(n_passes):
            with torch.no_grad():
                with autocast("cuda", enabled=cfg.training.mixed_precision):
                    logits = model(images_dev)

            # Per-question softmax so each answer group is a probability
            # distribution over that question's answers.
            pred = torch.zeros_like(logits)
            for q, (s, e) in QUESTION_GROUPS.items():
                pred[:, s:e] = F.softmax(logits[:, s:e], dim=-1)
            pred_np = pred.cpu().float().numpy()  # [B, 37]

            count += 1
            if mean_acc is None:
                # First pass initializes the accumulators.
                mean_acc = pred_np.copy()
                M2_acc = np.zeros_like(pred_np)
            else:
                # Welford update: incremental mean, then M2 via the
                # pre- and post-update deltas.
                delta = pred_np - mean_acc
                mean_acc += delta / count
                M2_acc += delta * (pred_np - mean_acc)

        # Sample std (ddof=1); degenerates to zeros for a single pass.
        std_acc = np.sqrt(M2_acc / (count - 1) if count > 1
                          else np.zeros_like(M2_acc))

        all_means.append(mean_acc)
        all_stds.append(std_acc)
        all_targets.append(targets.numpy())
        all_weights.append(weights.numpy())

    model.disable_mc_dropout()  # restore deterministic evaluation behavior

    mean_all = np.concatenate(all_means)
    std_all = np.concatenate(all_stds)
    targets_all = np.concatenate(all_targets)
    weights_all = np.concatenate(all_weights)

    if cache_dir is not None:
        np.save(fp_mean, mean_all)
        np.save(fp_std, std_all)
        np.save(fp_targets, targets_all)
        np.save(fp_weights, weights_all)
        log.info("MC results cached: %s", cache_dir)

    return mean_all, std_all, targets_all, weights_all
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 211 |
+
# Figure 1: Uncertainty distributions
|
| 212 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 213 |
+
|
| 214 |
+
def fig_uncertainty_distributions(mean_preds, std_preds,
                                  targets, weights, save_dir):
    """Histogram of per-galaxy epistemic uncertainty for each question.

    For every question, the MC Dropout std is averaged over that question's
    answer slots to give one uncertainty value per galaxy, restricted to
    galaxies whose decision path reaches the question (weight >= 0.05).
    Writes fig_uncertainty_distributions.{pdf,png}; skips if both exist.
    """
    path_pdf = save_dir / "fig_uncertainty_distributions.pdf"
    path_png = save_dir / "fig_uncertainty_distributions.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_uncertainty_distributions"); return

    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    axes = axes.flatten()

    for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
        ax = axes[q_idx]
        # Only galaxies that actually reach this question in the tree.
        mask = weights[:, q_idx] >= 0.05
        # One uncertainty value per galaxy: mean std over the answer slice.
        std_q = std_preds[mask, start:end].mean(axis=1)

        ax.hist(std_q, bins=50, color="#6366f1", alpha=0.85,
                edgecolor="none", density=True)
        ax.axvline(std_q.mean(), color="#c0392b", linewidth=1.8,
                   linestyle="--", label=f"Mean = {std_q.mean():.4f}")
        ax.set_xlabel("MC Dropout std (epistemic uncertainty)")
        ax.set_ylabel("Density")
        ax.set_title(
            f"{q_name}: {QUESTION_LABELS[q_name]}\n"
            f"$n$ = {mask.sum():,} (w ≥ 0.05)",
            fontsize=9,
        )
        ax.legend(fontsize=7)

    # 11 questions on a 3x4 grid — blank out the unused 12th panel.
    axes[-1].axis("off")
    plt.suptitle(
        "Epistemic uncertainty distributions — MC Dropout (30 passes)\n"
        "Proposed model (ViT-Base/16 + hierarchical KL+MSE), test set",
        fontsize=12,
    )
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_uncertainty_distributions")
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 256 |
+
# Figure 2: Uncertainty vs. error (Spearman Ο)
|
| 257 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 258 |
+
|
| 259 |
+
def fig_uncertainty_vs_error(mean_preds, std_preds,
                             targets, weights, save_dir):
    """Scatter + binned trend of uncertainty vs. absolute error per question.

    Reports Spearman rank correlation per question. Uses percentile-based
    (adaptive) bins so each trend point averages a similar number of
    galaxies. Writes fig_uncertainty_vs_error.{pdf,png}; skips if both exist.
    """
    path_pdf = save_dir / "fig_uncertainty_vs_error.pdf"
    path_png = save_dir / "fig_uncertainty_vs_error.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_uncertainty_vs_error"); return

    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    axes = axes.flatten()

    for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
        ax = axes[q_idx]
        # Restrict to galaxies that reach this question (weight >= 0.05).
        mask = weights[:, q_idx] >= 0.05
        unc = std_preds[mask, start:end].mean(axis=1)
        err = np.abs(mean_preds[mask, start:end] -
                     targets[mask, start:end]).mean(axis=1)

        # Adaptive bin means for trend line
        # (np.unique guards against duplicate percentile edges).
        n_bins = 15
        unc_bins = np.unique(np.percentile(unc, np.linspace(0, 100, n_bins + 1)))
        bin_ids = np.clip(np.digitize(unc, unc_bins) - 1, 0, len(unc_bins) - 2)
        bn_unc = [unc[bin_ids == b].mean() for b in range(len(unc_bins) - 1)
                  if (bin_ids == b).any()]
        bn_err = [err[bin_ids == b].mean() for b in range(len(unc_bins) - 1)
                  if (bin_ids == b).any()]

        # Rasterize the dense scatter to keep the PDF small.
        ax.scatter(unc, err, alpha=0.04, s=1, color="#94a3b8", rasterized=True)
        ax.plot(bn_unc, bn_err, "r-o", markersize=4, linewidth=2,
                label="Bin mean")

        # Spearman rank correlation (more robust than Pearson for this data)
        rho, pval = scipy_stats.spearmanr(unc, err)
        p_str = f"p < 0.001" if pval < 0.001 else f"p = {pval:.3f}"
        ax.text(0.05, 0.90,
                f"Spearman ρ = {rho:.3f}\n{p_str}",
                transform=ax.transAxes, fontsize=7.5,
                bbox=dict(boxstyle="round,pad=0.25", facecolor="white",
                          edgecolor="grey", alpha=0.85))

        ax.set_xlabel("Uncertainty (MC std)")
        ax.set_ylabel("Absolute error")
        ax.set_title(f"{q_name}: {QUESTION_LABELS[q_name]}", fontsize=9)
        ax.legend(fontsize=7)

    # 11 questions on a 3x4 grid — blank out the unused 12th panel.
    axes[-1].axis("off")
    plt.suptitle(
        "Epistemic uncertainty vs. absolute prediction error — per morphological question\n"
        "Strong Spearman ρ for root/shallow questions; weak ρ for deep conditional branches "
        "(expected: aleatoric uncertainty dominates when branch is rarely reached)",
        fontsize=10,
    )
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_uncertainty_vs_error")
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 318 |
+
# Table: uncertainty summary
|
| 319 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 320 |
+
|
| 321 |
+
def table_uncertainty_summary(mean_preds, std_preds,
                              targets, weights, save_dir):
    """Write table_uncertainty_summary.csv with per-question statistics.

    For each question: number of galaxies that reach it, mean/std of the
    MC Dropout uncertainty, mean absolute error, and the Spearman
    correlation between uncertainty and error.

    Parameters
    ----------
    mean_preds, std_preds, targets : ndarray [N, 37]
    weights : ndarray [N, 11]
    save_dir : Path output directory

    Returns
    -------
    pandas.DataFrame
        The summary table. If the CSV already exists it is loaded and
        returned (previously this path returned None, inconsistent with
        morphology_selection_benchmark's skip path).
    """
    path = save_dir / "table_uncertainty_summary.csv"
    if path.exists():
        log.info("Skip (exists): table_uncertainty_summary")
        return pd.read_csv(path)

    rows = []
    for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
        # Restrict to galaxies whose decision path reaches this question.
        mask = weights[:, q_idx] >= 0.05
        unc = std_preds[mask, start:end].mean(axis=1)
        err = np.abs(mean_preds[mask, start:end] -
                     targets[mask, start:end]).mean(axis=1)

        # Spearman is meaningless for tiny samples — report NaN instead.
        if mask.sum() > 10:
            rho, pval = scipy_stats.spearmanr(unc, err)
        else:
            rho, pval = float("nan"), float("nan")

        rows.append({
            "question"        : q_name,
            "description"     : QUESTION_LABELS[q_name],
            "n_reached"       : int(mask.sum()),
            "mean_uncertainty": round(float(unc.mean()), 5),
            "std_uncertainty" : round(float(unc.std()), 5),
            "mean_mae"        : round(float(err.mean()), 5),
            "spearman_rho"    : round(float(rho), 4),
            "spearman_pval"   : round(float(pval), 4),
        })

    df = pd.DataFrame(rows)
    df.to_csv(path, index=False)
    log.info("Saved: table_uncertainty_summary.csv")
    print("\n" + df.to_string(index=False) + "\n")
    return df
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 358 |
+
# Figure 3 + Table: Morphology selection benchmark
|
| 359 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 360 |
+
|
| 361 |
+
def morphology_selection_benchmark(model_results, save_dir):
    """Benchmark threshold-based morphology sample selection per model.

    For every model and every (question, answer) pair in SELECTION_ANSWERS,
    computes precision/recall/F1 of selecting galaxies whose predicted
    answer fraction exceeds each threshold in SELECTION_THRESHOLDS, using
    the same threshold on the volunteer vote fractions as ground truth.
    Results are written to table_morphology_selection_benchmark.csv and the
    F1 comparison figure is (re)drawn. Reuses the CSV if it already exists.
    """
    csv_path = save_dir / "table_morphology_selection_benchmark.csv"
    if csv_path.exists():
        log.info("Loading existing morphology benchmark...")
        df = pd.read_csv(csv_path)
        _fig_morphology_f1(df, save_dir)
        return df

    question_order = list(QUESTION_GROUPS.keys())
    records = []
    for model_name, (preds, targets, weights) in model_results.items():
        for q_name, (ans_idx, ans_label) in SELECTION_ANSWERS.items():
            start, _end = QUESTION_GROUPS[q_name]
            reached = weights[:, question_order.index(q_name)] >= 0.05
            col = start + ans_idx
            pred_a = preds[reached, col]
            true_a = targets[reached, col]

            for thresh in SELECTION_THRESHOLDS:
                selected = pred_a >= thresh
                positive = true_a >= thresh
                n_selected = int(selected.sum())
                n_positive = int(positive.sum())
                hits = int((selected & positive).sum())
                precision = hits / n_selected if n_selected else 0.0
                recall = hits / n_positive if n_positive else 0.0
                denom = precision + recall
                f1 = 2.0 * precision * recall / denom if denom else 0.0
                records.append({
                    "model"     : model_name,
                    "question"  : q_name,
                    "answer"    : ans_label,
                    "threshold" : thresh,
                    "n_selected": n_selected,
                    "n_true_pos": n_positive,
                    "precision" : round(precision, 4),
                    "recall"    : round(recall, 4),
                    "f1"        : round(f1, 4),
                })

    df = pd.DataFrame(records)
    df.to_csv(csv_path, index=False)
    log.info("Saved: table_morphology_selection_benchmark.csv")
    _fig_morphology_f1(df, save_dir)
    return df
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def _fig_morphology_f1(df, save_dir):
    """Grouped bar chart of F1 per (question, model) at threshold 0.8.

    Writes fig_morphology_f1_comparison.{pdf,png}; skips if both exist.
    """
    path_pdf = save_dir / "fig_morphology_f1_comparison.pdf"
    path_png = save_dir / "fig_morphology_f1_comparison.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_morphology_f1_comparison"); return

    thresh = 0.8  # single representative operating point for the figure
    sub = df[df["threshold"] == thresh]
    q_list = list(SELECTION_ANSWERS.keys())
    models = list(df["model"].unique())

    x = np.arange(len(q_list))
    width = 0.80 / len(models)  # keep total group width at 0.8
    palette = list(MODEL_COLORS.values())

    fig, ax = plt.subplots(figsize=(12, 5))
    for i, model in enumerate(models):
        f1s = []
        for q in q_list:
            row = sub[(sub["model"] == model) & (sub["question"] == q)]
            # Missing (model, question) combinations plot as zero-height bars.
            f1s.append(float(row["f1"].values[0]) if len(row) > 0 else 0.0)
        ax.bar(
            x + i * width, f1s, width,
            label=model,
            color=MODEL_COLORS.get(model, palette[i % len(palette)]),
            alpha=0.85, edgecolor="white", linewidth=0.5,
        )

    # Center the x tick labels under each group of bars.
    ax.set_xticks(x + width * (len(models) - 1) / 2)
    ax.set_xticklabels(
        [f"{q}\n({SELECTION_ANSWERS[q][1]})" for q in q_list], fontsize=9
    )
    ax.set_ylabel("F$_1$ score", fontsize=11)
    ax.set_title(
        f"Downstream morphology selection — F$_1$ at threshold $\\tau$ = {thresh}\n"
        "Higher F$_1$ indicates cleaner, more complete morphological sample selection.",
        fontsize=11,
    )
    ax.legend(fontsize=8)
    ax.set_ylim(0, 1)
    ax.grid(True, alpha=0.3, axis="y")
    ax.set_axisbelow(True)
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_morphology_f1_comparison")
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 457 |
+
# Main
|
| 458 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 459 |
+
|
| 460 |
+
def main():
    """CLI entry point for the uncertainty-analysis pipeline.

    Steps:
      1. Run MC-Dropout inference with the proposed ViT model on the test set.
      2. Produce uncertainty figures and a summary table.
      3. Run the downstream morphology-selection benchmark across the
         proposed model and any baseline checkpoints found on disk.

    Outputs (figures, tables, MC cache) are written under
    ``<cfg.outputs.figures_dir>/uncertainty``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)          # experiment YAML, merged over configs/base.yaml
    parser.add_argument("--n_passes", type=int, default=30)  # number of stochastic forward passes for MC Dropout
    args = parser.parse_args()

    # Experiment config overrides the base config (OmegaConf deep merge).
    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    save_dir = Path(cfg.outputs.figures_dir) / "uncertainty"
    save_dir.mkdir(parents=True, exist_ok=True)
    cache_dir = save_dir / "mc_cache"  # run_mc_inference caches passes here
    ckpt_dir = Path(cfg.outputs.checkpoint_dir)

    # Only the test split is needed; train/val loaders are discarded.
    _, _, test_loader = build_dataloaders(cfg)

    # ── 1. MC Dropout on proposed model ───────────────────────
    log.info("Loading proposed model...")
    proposed = build_model(cfg).to(device)
    # weights_only=True restricts torch.load to tensor data (no pickled code).
    proposed.load_state_dict(
        torch.load(ckpt_dir / "best_full_train.pt",
                   map_location="cpu", weights_only=True)["model_state"]
    )

    mean_preds, std_preds, targets, weights = run_mc_inference(
        proposed, test_loader, device, cfg,
        n_passes=args.n_passes, cache_dir=cache_dir,
    )
    log.info("MC Dropout complete: %d galaxies, %d passes.",
             len(mean_preds), args.n_passes)

    # ── 2. Uncertainty figures and table ──────────────────────
    fig_uncertainty_distributions(mean_preds, std_preds, targets, weights, save_dir)
    fig_uncertainty_vs_error(mean_preds, std_preds, targets, weights, save_dir)
    table_uncertainty_summary(mean_preds, std_preds, targets, weights, save_dir)

    # ── 3. Downstream benchmark across all models ─────────────
    log.info("Building model_results for downstream benchmark...")

    # model name -> (predictions, targets, weights), all numpy arrays.
    model_results = {
        "ViT-Base + KL+MSE (proposed)": (mean_preds, targets, weights),
    }

    def _load_resnet(ckpt_name: str, use_sigmoid: bool):
        """Load a ResNet-18 baseline checkpoint and run deterministic test-set
        inference, returning (preds, targets, weights) as concatenated arrays.

        use_sigmoid selects the output head convention: True applies a global
        sigmoid (MSE-trained baseline); False applies a per-question softmax
        over the answer groups (KL+MSE-trained baseline).
        """
        m = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
        m.load_state_dict(
            torch.load(ckpt_dir / ckpt_name, map_location="cpu",
                       weights_only=True)["model_state"]
        )
        m.eval()
        preds_l, tgts_l, wgts_l = [], [], []
        with torch.no_grad():
            # Loader yields (images, targets, weights, _) — 4th item unused here.
            for images, tgts, wgts, _ in tqdm(test_loader, desc=f"ResNet {ckpt_name}"):
                images = images.to(device, non_blocking=True)
                with autocast("cuda", enabled=cfg.training.mixed_precision):
                    logits = m(images)
                if use_sigmoid:
                    p = torch.sigmoid(logits).cpu().numpy()
                else:
                    # clone() so the per-group softmax writes don't touch `logits`.
                    p = logits.detach().cpu().clone()
                    for q, (s, e) in QUESTION_GROUPS.items():
                        p[:, s:e] = F.softmax(p[:, s:e], dim=-1)
                    p = p.numpy()
                preds_l.append(p)
                tgts_l.append(tgts.numpy())
                wgts_l.append(wgts.numpy())
        return (np.concatenate(preds_l),
                np.concatenate(tgts_l),
                np.concatenate(wgts_l))

    rn_mse_ckpt = "baseline_resnet18_mse.pt"
    rn_klm_ckpt = "baseline_resnet18_klmse.pt"

    # Baselines are optional: only evaluated if their checkpoint exists.
    if (ckpt_dir / rn_mse_ckpt).exists():
        model_results["ResNet-18 + MSE (sigmoid)"] = _load_resnet(
            rn_mse_ckpt, use_sigmoid=True
        )
    if (ckpt_dir / rn_klm_ckpt).exists():
        model_results["ResNet-18 + KL+MSE"] = _load_resnet(
            rn_klm_ckpt, use_sigmoid=False
        )

    # Optional Zoobot-style Dirichlet ViT baseline.
    dp = ckpt_dir / "baseline_vit_dirichlet.pt"
    if dp.exists():
        vit_dir = build_dirichlet_model(cfg).to(device)
        vit_dir.load_state_dict(
            torch.load(dp, map_location="cpu", weights_only=True)["model_state"]
        )
        vit_dir.eval()
        d_p, d_t, d_w = [], [], []
        with torch.no_grad():
            for images, tgts, wgts, _ in tqdm(test_loader, desc="Dirichlet"):
                images = images.to(device, non_blocking=True)
                with autocast("cuda", enabled=cfg.training.mixed_precision):
                    alpha = vit_dir(images)  # Dirichlet concentration parameters
                # Converts alpha to expected answer fractions (numpy) alongside
                # the batch targets/weights — see dirichlet_predictions_to_numpy.
                p, t, w = dirichlet_predictions_to_numpy(alpha, tgts, wgts)
                d_p.append(p); d_t.append(t); d_w.append(w)
        model_results["ViT-Base + Dirichlet (Zoobot-style)"] = (
            np.concatenate(d_p),
            np.concatenate(d_t),
            np.concatenate(d_w),
        )

    # Selection benchmark writes its own outputs and returns a tidy DataFrame
    # with one row per (model, question, threshold).
    df_sel = morphology_selection_benchmark(model_results, save_dir)

    log.info("=" * 60)
    log.info("DOWNSTREAM F1 @ Ο = 0.8")
    log.info("=" * 60)
    summary = df_sel[df_sel["threshold"] == 0.8][
        ["model", "question", "answer", "precision", "recall", "f1"]
    ]
    log.info("\n%s\n", summary.to_string(index=False))
    log.info("All outputs saved to: %s", save_dir)
+
|
| 577 |
+
# Script entry point: run the full uncertainty-analysis pipeline.
if __name__ == "__main__":
    main()