sreshwarprasad committed on
Commit
e36eee4
·
verified ·
1 Parent(s): 0ebfc32

Upload folder using huggingface_hub

Browse files
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Lightweight CPU image for the FastAPI inference service.
FROM python:3.10-slim

WORKDIR /code

# System libraries needed by OpenCV at runtime (opencv-python-headless
# still links against libGL/glib); purge apt lists to keep the image small.
RUN apt-get update && apt-get install -y \
    libgl1 \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies before copying the source so this layer
# stays cached unless requirements.txt changes.
COPY ./requirements.txt /code/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

# Copy the application source (app.py, src/, configs/, checkpoint).
COPY . /code/

# Run the FastAPI server on port 7860 (Hugging Face Spaces default)
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys
import os
import io
import base64
import numpy as np
from PIL import Image
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Add current directory to path so HF Space finds it regardless of the
# working directory uvicorn was started from.
# FIX: `sys` and `os` were imported a second time here; the duplicates
# have been removed (they are already imported above).
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.append(current_dir)

from omegaconf import OmegaConf
from src.model import build_model
from src.attention_viz import attention_rollout_full, make_overlay
from src.dataset import QUESTION_GROUPS
from torchvision import transforms

app = FastAPI()

# Wide-open CORS: the API is consumed by a browser front-end hosted on a
# different origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Populated once by the startup hook (load_model) below.
model = None
cfg = None
transform = None
40
+
41
@app.on_event("startup")
def load_model():
    """Build the model, load the trained checkpoint and set up the
    inference transform. Runs once when the server starts; results are
    stored in the module-level ``model`` / ``cfg`` / ``transform``."""
    global model, cfg, transform
    print("Loading configuration...")
    base_cfg = OmegaConf.load(os.path.join(current_dir, "configs/base.yaml"))

    # Overlay the full-training experiment config when present, otherwise
    # fall back to the base config alone.
    # FIX: the original used a bare `except:`, which also swallows
    # SystemExit/KeyboardInterrupt and hides the reason for the fallback.
    try:
        exp_cfg = OmegaConf.load(os.path.join(current_dir, "configs/full_train.yaml"))
        cfg = OmegaConf.merge(base_cfg, exp_cfg)
    except Exception as e:
        print(f"WARNING: could not load configs/full_train.yaml ({e}); using base config")
        cfg = base_cfg

    print("Building model...")
    model = build_model(cfg).to(device)

    ckpt_path = os.path.join(current_dir, "best_full_train.pt")
    if os.path.exists(ckpt_path):
        print(f"Loading checkpoint from {ckpt_path}")
        # weights_only=True: never unpickle arbitrary objects from the file.
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(ckpt["model_state"])
    else:
        # Keep serving (predictions will come from the untrained model)
        # but make the misconfiguration loud in the logs.
        print(f"WARNING: Checkpoint not found at {ckpt_path}")

    model.eval()

    # Galaxy Zoo image transform: resize, crop, center, normalize
    # Assuming standard Imagenet + ViT transforms for 224x224
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
75
+
76
@app.post("/api/predict")
async def predict(file: UploadFile = File(...)):
    """Run one uploaded galaxy image through the model.

    Returns a JSON payload with per-question probability vectors
    (hierarchical softmax per QUESTION_GROUPS) and a base64-encoded PNG
    attention-rollout overlay, or ``heatmap: None`` when the model
    exposes no attention weights.
    """
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")

    # Transform image
    img_tensor = transform(image).unsqueeze(0).to(device)

    # FIX: autocast was hard-coded to "cuda", which fails or misbehaves on
    # CPU-only deployments (the slim Docker image ships no GPU). Use the
    # actual device type and only enable mixed precision on GPU.
    with torch.no_grad():
        with torch.amp.autocast(device.type, enabled=device.type == "cuda"):
            logits = model(img_tensor)

    # Attention weights captured during the forward pass (may be None).
    layers = model.get_all_attention_weights()

    # Hierarchical softmax is applied independently per question group.
    import torch.nn.functional as F
    probs = logits.detach().float().cpu().clone()
    for q_name, (start, end) in QUESTION_GROUPS.items():
        probs[:, start:end] = F.softmax(probs[:, start:end], dim=-1)

    probs_np = probs[0].numpy()
    results = {q_name: probs_np[start:end].tolist()
               for q_name, (start, end) in QUESTION_GROUPS.items()}

    # Generate Attention Heatmap Overlay
    if layers is not None:
        # attention_rollout_full expects list of [B, H, N+1, N+1]
        all_layer_attns = [l.cpu() for l in layers]
        rollout_map = attention_rollout_full(all_layer_attns, patch_size=16, image_size=224)[0]

        # original image numpy for overlay (denormalised size)
        original_img_np = np.array(image.resize((224, 224)))
        overlay = make_overlay(original_img_np, rollout_map, alpha=0.5, colormap="inferno")

        # Encode to base64
        overlay_img = Image.fromarray(overlay)
        buffered = io.BytesIO()
        overlay_img.save(buffered, format="PNG")
        heatmap_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
    else:
        heatmap_base64 = None

    return {
        "predictions": results,
        "heatmap": heatmap_base64
    }
best_full_train.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f31287bca388d29f3b144a9a407f041b4f02c6742ece7614751aab70cd6f04ea
3
+ size 343371682
configs/ablation.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ─────────────────────────────────────────────────────────────
2
+ # configs/ablation.yaml
3
+ # Phase 1: lambda_kl ablation on a 10k subset.
4
+ # Run FIRST before any full training.
5
+ # ─────────────────────────────────────────────────────────────
6
+
7
+ defaults:
8
+ - base
9
+
10
+ experiment_name : "ablation"
11
+
12
+ data:
13
+ n_samples : 10000 # ablation uses 10k for speed
14
+
15
+ training:
16
+ epochs : 15 # sufficient to converge on 10k
17
+
18
+ scheduler:
19
+ T_max : 15
20
+
21
+ early_stopping:
22
+ patience : 5
23
+
24
+ wandb:
25
+ log_attention_every_n_epochs : 99 # disable attention in ablation
configs/base.yaml ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ─────────────────────────────────────────────────────────────
2
+ # configs/base.yaml
3
+ # Base configuration for all experiments.
4
+ # All experiment configs inherit and override from this file.
5
+ # ─────────────────────────────────────────────────────────────
6
+
7
+ project_name : "gz2-hierarchical-vit"
8
+ experiment_name : "base"
9
+ seed : 42
10
+
11
+ data:
12
+ parquet_path : "data/labels.parquet"
13
+ image_dir : "data/images"
14
+ image_id_col : "dr7objid"
15
+ image_size : 224
16
+ n_samples : null # null = full dataset
17
+ train_frac : 0.80
18
+ val_frac : 0.10
19
+ test_frac : 0.10
20
+ num_workers : 12
21
+ pin_memory : true
22
+ persistent_workers : true
23
+ prefetch_factor : 4
24
+
25
+ model:
26
+ backbone : "vit_base_patch16_224"
27
+ pretrained : true
28
+ # FIXED: increased from 0.1 β†’ 0.3 to reduce overfitting on 86M-param model.
29
+ # Loss curves showed train/val divergence from epoch ~12 with dropout=0.1.
30
+ dropout : 0.3
31
+
32
+ loss:
33
+ lambda_kl : 0.5 # weight of KL divergence term
34
+ lambda_mse : 0.5 # weight of MSE term
35
+ epsilon : 1.0e-8 # numerical stability clamp
36
+
37
+ training:
38
+ epochs : 100
39
+ batch_size : 64
40
+ learning_rate : 1.0e-4
41
+ weight_decay : 1.0e-4
42
+ grad_clip : 1.0
43
+ mixed_precision : true
44
+
45
+ early_stopping:
46
+ patience : 10
47
+ min_delta : 1.0e-5
48
+ monitor : "val/loss_total"
49
+
50
+ scheduler:
51
+ name : "cosine"
52
+ T_max : 100
53
+ eta_min : 1.0e-6
54
+
55
+ outputs:
56
+ checkpoint_dir : "outputs/checkpoints"
57
+ figures_dir : "outputs/figures"
58
+ log_dir : "outputs/logs"
59
+
60
+ wandb:
61
+ enabled : true
62
+ project : "gz2-hierarchical-vit" # new project name
63
+ log_attention_every_n_epochs : 5
64
+ n_attention_samples : 8
configs/full_train.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ─────────────────────────────────────────────────────────────
2
+ # configs/full_train.yaml
3
+ # Phase 3: Full training on complete 239k dataset.
4
+ # Run after ablation confirms lambda_kl = 0.5 is optimal.
5
+ # ─────────────────────────────────────────────────────────────
6
+
7
+ defaults:
8
+ - base
9
+
10
+ experiment_name : "full_train"
11
+
12
+ data:
13
+ n_samples : null # full 239k dataset
14
+
15
+ training:
16
+ epochs : 100
17
+ batch_size : 64
18
+
19
+ scheduler:
20
+ T_max : 100
21
+
22
+ early_stopping:
23
+ patience : 10
24
+
25
+ wandb:
26
+ log_attention_every_n_epochs : 5
configs/subset_60k.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ─────────────────────────────────────────────────────────────
2
+ # configs/subset_60k.yaml
3
+ # Phase 2: sanity check / quick prototype on 60k subset.
4
+ # Use for code verification before full training.
5
+ # ─────────────────────────────────────────────────────────────
6
+
7
+ defaults:
8
+ - base
9
+
10
+ experiment_name : "subset_60k"
11
+
12
+ data:
13
+ n_samples : 60000 # 60k random galaxies
14
+
15
+ training:
16
+ epochs : 30
17
+
18
+ scheduler:
19
+ T_max : 30
20
+
21
+ early_stopping:
22
+ patience : 7
23
+
24
+ wandb:
25
+ log_attention_every_n_epochs : 5
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ python-multipart
4
+ pydantic
5
+ torch
6
+ torchvision
7
+ numpy
8
+ Pillow
9
+ omegaconf
10
+ timm
11
+ opencv-python-headless
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (128 Bytes). View file
 
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (132 Bytes). View file
 
src/__pycache__/ablation.cpython-310.pyc ADDED
Binary file (7.39 kB). View file
 
src/__pycache__/attention_viz.cpython-310.pyc ADDED
Binary file (9.72 kB). View file
 
src/__pycache__/baselines.cpython-310.pyc ADDED
Binary file (22.4 kB). View file
 
src/__pycache__/baselines.cpython-312.pyc ADDED
Binary file (39.1 kB). View file
 
src/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (7.46 kB). View file
 
src/__pycache__/evaluate_full.cpython-310.pyc ADDED
Binary file (18.2 kB). View file
 
src/__pycache__/evaluate_full.cpython-312.pyc ADDED
Binary file (30.7 kB). View file
 
src/__pycache__/loss.cpython-310.pyc ADDED
Binary file (5.32 kB). View file
 
src/__pycache__/metrics.cpython-310.pyc ADDED
Binary file (8.73 kB). View file
 
src/__pycache__/model.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
src/__pycache__/model.cpython-312.pyc ADDED
Binary file (16 kB). View file
 
src/__pycache__/train.cpython-310.pyc ADDED
Binary file (8.82 kB). View file
 
src/__pycache__/train_single.cpython-310.pyc ADDED
Binary file (10.4 kB). View file
 
src/__pycache__/train_single.cpython-312.pyc ADDED
Binary file (18.3 kB). View file
 
src/__pycache__/uncertainty_analysis.cpython-310.pyc ADDED
Binary file (17.1 kB). View file
 
src/ablation.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/ablation.py
3
+ ---------------
4
+ Lambda ablation study for the hierarchical KL + MSE loss.
5
+
6
+ Sweeps lambda_kl over [0.0, 0.25, 0.50, 0.75, 1.0] on a 10k subset
7
+ to justify the choice of lambda_kl = 0.5 used in the proposed model.
8
+
9
+ This ablation is reported in the paper as justification for the
10
+ balanced KL + MSE formulation. It is run BEFORE full training.
11
+
12
+ Output
13
+ ------
14
+ outputs/figures/ablation/table_lambda_ablation.csv
15
+ outputs/figures/ablation/fig_lambda_ablation.pdf
16
+ outputs/figures/ablation/fig_lambda_ablation.png
17
+
18
+ Usage
19
+ -----
20
+ cd ~/galaxy
21
+ nohup python -m src.ablation --config configs/ablation.yaml \
22
+ > outputs/logs/ablation.log 2>&1 &
23
+ echo "PID: $!"
24
+ """
25
+
26
+ import argparse
27
+ import copy
28
+ import logging
29
+ import random
30
+ import sys
31
+ import gc
32
+ from pathlib import Path
33
+
34
+ import numpy as np
35
+ import pandas as pd
36
+ import torch
37
+ import matplotlib
38
+ matplotlib.use("Agg")
39
+ import matplotlib.pyplot as plt
40
+ from torch.amp import autocast, GradScaler
41
+ from omegaconf import OmegaConf, DictConfig
42
+ from tqdm import tqdm
43
+
44
+ from src.dataset import build_dataloaders
45
+ from src.model import build_model
46
+ from src.loss import HierarchicalLoss
47
+ from src.metrics import compute_metrics, predictions_to_numpy
48
+
49
+ logging.basicConfig(
50
+ format="%(asctime)s %(levelname)s %(message)s",
51
+ datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
52
+ )
53
+ log = logging.getLogger("ablation")
54
+
55
+ LAMBDA_VALUES = [0.0, 0.25, 0.50, 0.75, 1.0]
56
+ ABLATION_EPOCHS = 15 # sufficient to converge on 10k subset
57
+ ABLATION_SAMPLES = 10000
58
+
59
+
60
+ def _set_seed(seed: int):
61
+ random.seed(seed)
62
+ np.random.seed(seed)
63
+ torch.manual_seed(seed)
64
+ torch.cuda.manual_seed_all(seed)
65
+
66
+
67
def run_single(cfg: DictConfig, lambda_kl: float) -> dict:
    """
    Train one model with the given lambda_kl on a 10k subset and
    return test metrics. All other settings are identical across runs.

    Parameters
    ----------
    cfg       : merged experiment config (seed, data, loss, training, ...)
    lambda_kl : KL weight for this run; MSE weight is set to 1 - lambda_kl.

    Returns
    -------
    dict with the lambda pair, best validation loss, and weighted
    MAE / RMSE / mean-ECE on the test split (rounded for the table).
    """
    _set_seed(cfg.seed)

    # Deep-copy so mutating loss/data/training settings for this run
    # never leaks into the config shared across the sweep.
    cfg = copy.deepcopy(cfg)
    cfg.loss.lambda_kl = lambda_kl
    cfg.loss.lambda_mse = 1.0 - lambda_kl  # the two weights always sum to 1
    cfg.data.n_samples = ABLATION_SAMPLES
    cfg.training.epochs = ABLATION_EPOCHS

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, test_loader = build_dataloaders(cfg)

    model = build_model(cfg).to(device)
    loss_fn = HierarchicalLoss(cfg)

    # Discriminative learning rates: pretrained backbone at 0.1x the head LR.
    optimizer = torch.optim.AdamW(
        [
            {"params": model.backbone.parameters(),
             "lr": cfg.training.learning_rate * 0.1},
            {"params": model.head.parameters(),
             "lr": cfg.training.learning_rate},
        ],
        weight_decay=cfg.training.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=ABLATION_EPOCHS, eta_min=1e-6
    )
    scaler = GradScaler("cuda")

    best_val = float("inf")
    best_state = None  # CPU snapshot of the best-so-far weights

    for epoch in range(1, ABLATION_EPOCHS + 1):
        # ── train ──────────────────────────────────────────────
        model.train()
        for images, targets, weights, _ in tqdm(
            train_loader, desc=f"λ={lambda_kl:.2f} E{epoch}", leave=False
        ):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            with autocast("cuda", enabled=True):
                logits = model(images)
                loss, _ = loss_fn(logits, targets, weights)
            # Mixed-precision step: unscale before clipping so the norm
            # threshold applies to the true (unscaled) gradients.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        # One cosine step per epoch (T_max is expressed in epochs).
        # NOTE(review): indentation was lost in this copy; per-epoch
        # stepping is assumed from T_max=ABLATION_EPOCHS — confirm.
        scheduler.step()

        # ── validate ───────────────────────────────────────────
        model.eval()
        val_loss = 0.0
        nb = 0  # number of validation batches seen
        with torch.no_grad():
            for images, targets, weights, _ in val_loader:
                images = images.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                weights = weights.to(device, non_blocking=True)
                with autocast("cuda", enabled=True):
                    logits = model(images)
                    loss, _ = loss_fn(logits, targets, weights)
                val_loss += loss.item()
                nb += 1
        val_loss /= nb
        log.info(" λ_kl=%.2f epoch=%d val_loss=%.5f", lambda_kl, epoch, val_loss)

        # Keep a CPU copy of the best weights (early-stopping style).
        if val_loss < best_val:
            best_val = val_loss
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    # ── test evaluation ────────────────────────────────────────
    model.load_state_dict(best_state)
    model.eval()

    all_preds, all_targets, all_weights = [], [], []
    with torch.no_grad():
        for images, targets, weights, _ in test_loader:
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)
            with autocast("cuda", enabled=True):
                logits = model(images)
            p, t, w = predictions_to_numpy(logits, targets, weights)
            all_preds.append(p)
            all_targets.append(t)
            all_weights.append(w)

    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    all_weights = np.concatenate(all_weights)
    metrics = compute_metrics(all_preds, all_targets, all_weights)

    return {
        "lambda_kl" : lambda_kl,
        "lambda_mse" : round(1.0 - lambda_kl, 2),
        "best_val_loss": round(best_val, 5),
        "mae_weighted" : round(metrics["mae/weighted_avg"], 5),
        "rmse_weighted": round(metrics["rmse/weighted_avg"], 5),
        "ece_mean" : round(metrics["ece/mean"], 5),
    }
174
+
175
+
176
def _plot_ablation(df: pd.DataFrame, save_dir: Path):
    """Three-panel figure (MAE / RMSE / ECE vs lambda_kl), saved as
    PDF + PNG in *save_dir*. The winning lambda (minimum weighted MAE)
    is marked with a dashed vertical line in every panel."""
    best_row = df.loc[df["mae_weighted"].idxmin()]

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # (df column, y-axis label, line colour) for each panel.
    metrics_cfg = [
        ("mae_weighted", "Weighted MAE", "#2980b9"),
        ("rmse_weighted", "Weighted RMSE", "#c0392b"),
        ("ece_mean", "Mean ECE", "#27ae60"),
    ]

    for ax, (col, ylabel, color) in zip(axes, metrics_cfg):
        ax.plot(df["lambda_kl"], df[col], "-o", color=color,
                linewidth=2, markersize=8)
        # Mark the overall winner (chosen by MAE) in every panel.
        ax.axvline(best_row["lambda_kl"], color="#7f8c8d",
                   linestyle="--", alpha=0.8,
                   label=f"Best λ = {best_row['lambda_kl']:.2f}")
        ax.set_xlabel("$\\lambda_{\\mathrm{KL}}$ "
                      "(0 = pure MSE, 1 = pure KL)", fontsize=11)
        ax.set_ylabel(ylabel, fontsize=11)
        ax.set_title(f"Lambda ablation — {ylabel}", fontsize=10)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)
        # Ticks exactly at the swept lambda values.
        ax.set_xticks(df["lambda_kl"].tolist())

    plt.suptitle(
        "Ablation study: effect of $\\lambda_{\\mathrm{KL}}$ in the hierarchical loss\n"
        f"10,000-sample subset, seed=42. Best: $\\lambda_{{\\mathrm{{KL}}}}$"
        f" = {best_row['lambda_kl']:.2f} (MAE = {best_row['mae_weighted']:.5f})",
        fontsize=11, y=1.02,
    )
    plt.tight_layout()
    fig.savefig(save_dir / "fig_lambda_ablation.pdf", dpi=300, bbox_inches="tight")
    fig.savefig(save_dir / "fig_lambda_ablation.png", dpi=300, bbox_inches="tight")
    plt.close(fig)  # free memory; the module runs headless (Agg backend)
    log.info("Saved: fig_lambda_ablation")
212
+
213
+
214
def main():
    """Run the full lambda_kl sweep and write the results table + figure."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()

    # Experiment config overrides the base config (same merge scheme as
    # regular training).
    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)

    save_dir = Path(cfg.outputs.figures_dir) / "ablation"
    save_dir.mkdir(parents=True, exist_ok=True)

    results = []
    for lam in LAMBDA_VALUES:
        log.info("=" * 55)
        log.info("Ablation: lambda_kl=%.2f lambda_mse=%.2f",
                 lam, 1.0 - lam)
        log.info("=" * 55)

        result = run_single(cfg, lam)
        results.append(result)
        log.info("Result: %s", result)

        # Free up RAM and GPU memory between runs so the sweep fits on
        # one device.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    df = pd.DataFrame(results)
    df.to_csv(save_dir / "table_lambda_ablation.csv", index=False)
    log.info("Saved: table_lambda_ablation.csv")

    print()
    print(df.to_string(index=False))
    print()

    # Winner is chosen by weighted MAE (same criterion as the figure).
    best = df.loc[df["mae_weighted"].idxmin()]
    log.info("Best: lambda_kl=%.2f MAE=%.5f RMSE=%.5f",
             best["lambda_kl"], best["mae_weighted"], best["rmse_weighted"])

    _plot_ablation(df, save_dir)


if __name__ == "__main__":
    main()
src/attention_viz.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/attention_viz.py
3
+ --------------------
4
+ Full multi-layer attention rollout for ViT explainability.
5
+
6
+ Theory β€” Abnar & Zuidema (2020)
7
+ --------------------------------
8
+ Each ViT transformer block l produces attention weights A_l of shape
9
+ [B, H, N+1, N+1], where H=12 heads and N+1=197 tokens (196 patches
10
+ + 1 CLS token).
11
+
12
+ Full rollout algorithm:
13
+ 1. Average over heads: A_l = mean_h(attn_l) [B, N+1, N+1]
14
+ 2. Add residual: A_l = 0.5*A_l + 0.5*I [B, N+1, N+1]
15
+ 3. Row-normalise so attention sums to 1 per token.
16
+ 4. Chain layers: R = A_1 βŠ— A_2 βŠ— ... βŠ— A_12 [B, N+1, N+1]
17
+ 5. CLS row, patch cols: rollout = R[:, 0, 1:] [B, 196]
18
+ 6. Reshape 196 β†’ 14Γ—14, upsample to 224Γ—224.
19
+
20
+ FIX applied vs. original
21
+ --------------------------
22
+ The original code used R = bmm(A, R) (left-multiplication) which
23
+ accumulates attention in reverse order. The correct propagation per
24
+ Abnar & Zuidema is R = bmm(R, A) (right-multiplication), which
25
+ tracks how information from the INPUT patches flows forward through
26
+ successive layers into the CLS token.
27
+
28
+ Entropy interpretation
29
+ -----------------------
30
+ CLS attention entropy INCREASES from early to late layers. This is
31
+ the expected and correct behaviour for ViT classification:
32
+ - Early layers (1–8): entropy is low and stable (~1.7–2.0 nats),
33
+ consistent with local morphological feature detection.
34
+ - Late layers (9–12): entropy rises sharply (~2.7–4.5 nats),
35
+ consistent with the CLS token performing global integration β€”
36
+ aggregating information from all patches before the regression head.
37
+ This pattern confirms that early layers specialise in local structure
38
+ while late layers globally aggregate morphological information for
39
+ the final prediction.
40
+
41
+ References
42
+ ----------
43
+ Abnar & Zuidema (2020). Quantifying Attention Flow in Transformers.
44
+ ACL 2020. https://arxiv.org/abs/2005.00928
45
+ """
46
+
47
+ from __future__ import annotations
48
+
49
+ import numpy as np
50
+ import torch
51
+ import torch.nn.functional as F
52
+ import matplotlib
53
+ matplotlib.use("Agg")
54
+ import matplotlib.pyplot as plt
55
+ import matplotlib.cm as cm
56
+ from pathlib import Path
57
+ from typing import Optional, List
58
+
59
+
60
+ # ─────────────────────────────────────────────────────────────
61
+ # Full multi-layer rollout (FIXED)
62
+ # ─────────────────────────────────────────────────────────────
63
+
64
+ def attention_rollout_full(
65
+ all_attn_weights: List[torch.Tensor],
66
+ patch_size: int = 16,
67
+ image_size: int = 224,
68
+ ) -> np.ndarray:
69
+ """
70
+ Full multi-layer attention rollout per Abnar & Zuidema (2020).
71
+
72
+ Parameters
73
+ ----------
74
+ all_attn_weights : list of L tensors, each [B, H, N+1, N+1]
75
+ One tensor per transformer layer, in order 1 β†’ L.
76
+ patch_size : ViT patch size (16 for ViT-Base/16)
77
+ image_size : input image size (224)
78
+
79
+ Returns
80
+ -------
81
+ rollout_maps : [B, image_size, image_size] float32 in [0, 1]
82
+ """
83
+ assert len(all_attn_weights) > 0, "Need at least one attention layer"
84
+
85
+ B, H, N1, _ = all_attn_weights[0].shape
86
+ device = all_attn_weights[0].device
87
+
88
+ # Identity matrix: R_0 = I
89
+ R = torch.eye(N1, device=device).unsqueeze(0).expand(B, -1, -1).clone()
90
+
91
+ for attn in all_attn_weights:
92
+ # Step 1: average over heads β†’ [B, N+1, N+1]
93
+ A = attn.mean(dim=1)
94
+
95
+ # Step 2: residual connection
96
+ I = torch.eye(N1, device=device).unsqueeze(0)
97
+ A = 0.5 * A + 0.5 * I
98
+
99
+ # Step 3: row-normalise
100
+ A = A / A.sum(dim=-1, keepdim=True).clamp(min=1e-8)
101
+
102
+ # Step 4: chain rollout β€” FIXED: R = R @ A (right-multiply)
103
+ # This propagates information forward from input to CLS.
104
+ # Original had R = A @ R (left-multiply) which is incorrect.
105
+ R = torch.bmm(R, A)
106
+
107
+ # Step 5: CLS row (index 0), patch columns (1 onwards)
108
+ cls_attn = R[:, 0, 1:] # [B, 196]
109
+
110
+ # Step 6: reshape and upsample to image size
111
+ grid_size = image_size // patch_size # 14
112
+ cls_attn = cls_attn.reshape(B, 1, grid_size, grid_size)
113
+ rollout = F.interpolate(
114
+ cls_attn, size=(image_size, image_size),
115
+ mode="bilinear", align_corners=False,
116
+ ).squeeze(1) # [B, 224, 224]
117
+
118
+ rollout_np = rollout.cpu().numpy()
119
+ for i in range(B):
120
+ mn, mx = rollout_np[i].min(), rollout_np[i].max()
121
+ rollout_np[i] = (rollout_np[i] - mn) / (mx - mn + 1e-8)
122
+
123
+ return rollout_np.astype(np.float32)
124
+
125
+
126
def attention_rollout_single_layer(
    attn_weights: torch.Tensor,
    patch_size: int = 16,
    image_size: int = 224,
) -> np.ndarray:
    """Rollout over one layer's attention only (kept for backward
    compatibility); the full multi-layer rollout is preferred."""
    single_layer = [attn_weights]
    return attention_rollout_full(
        single_layer,
        patch_size=patch_size,
        image_size=image_size,
    )
+ )
135
+
136
+
137
+ # ─────────────────────────────────────────────────────────────
138
+ # Visualisation utilities
139
+ # ─────────────────────────────────────────────────────────────
140
+
141
def denormalise_image(tensor: torch.Tensor) -> np.ndarray:
    """Undo ImageNet normalisation on a [3, H, W] tensor → uint8 [H, W, 3]."""
    imagenet_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    imagenet_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    # Channels-last for plotting, then reverse (x - mean) / std and clip.
    chw = tensor.cpu().numpy()
    hwc = np.transpose(chw, (1, 2, 0))
    restored = np.clip(hwc * imagenet_std + imagenet_mean, 0, 1)
    return (restored * 255).astype(np.uint8)
148
+
149
+
150
def make_overlay(
    image_np: np.ndarray,
    rollout: np.ndarray,
    alpha: float = 0.5,
    colormap: str = "inferno",
) -> np.ndarray:
    """Blend an attention heatmap onto a galaxy image.

    Parameters
    ----------
    image_np : [H, W, 3] uint8 galaxy image
    rollout  : [H, W] attention map, expected in [0, 1]
    alpha    : heatmap opacity (0 = image only, 1 = heatmap only)
    colormap : matplotlib colormap name

    Returns
    -------
    [H, W, 3] uint8 blended overlay.
    """
    # FIX: ``cm.get_cmap`` was deprecated in matplotlib 3.7 and removed
    # in 3.9; the ``matplotlib.colormaps`` registry is the supported API.
    cmap = matplotlib.colormaps[colormap]
    heatmap = (cmap(rollout)[:, :, :3] * 255).astype(np.uint8)
    overlay = (
        (1 - alpha) * image_np.astype(np.float32) +
        alpha * heatmap.astype(np.float32)
    ).clip(0, 255).astype(np.uint8)
    return overlay
+ return overlay
164
+
165
+
166
def plot_attention_grid(
    images: torch.Tensor,
    attn_weights,
    image_ids: list,
    save_path: Optional[str] = None,
    alpha: float = 0.5,
    n_cols: int = 4,
    rollout_mode: str = "full",
) -> plt.Figure:
    """
    Publication-quality attention rollout gallery.

    Each galaxy occupies two stacked panels: the raw (denormalised)
    image on top and the attention overlay directly below it.

    Parameters
    ----------
    images : [N, 3, H, W] galaxy image tensors
    attn_weights : list of L tensors [N, H, N+1, N+1] (full mode)
                   or single tensor [N, H, N+1, N+1] (single mode)
    image_ids : dr7objid list for panel titles
    save_path : optional file path to save the figure
    alpha : heatmap opacity (0 = image only, 1 = heatmap only)
    n_cols : number of columns in the grid
    rollout_mode : "full" for 12-layer rollout (recommended)
    """
    N = images.shape[0]

    # Full rollout needs the per-layer list; otherwise fall back to the
    # last layer only.
    if rollout_mode == "full" and isinstance(attn_weights, list):
        rollout_maps = attention_rollout_full(attn_weights)
    else:
        if isinstance(attn_weights, list):
            attn_weights = attn_weights[-1]
        rollout_maps = attention_rollout_single_layer(attn_weights)

    # Two figure rows (image + overlay) per grid row of galaxies.
    n_rows = int(np.ceil(N / n_cols))
    fig, axes = plt.subplots(
        n_rows * 2, n_cols,
        figsize=(n_cols * 3, n_rows * 6),
        facecolor="black",
    )
    axes = axes.flatten()

    for i in range(N):
        img_np = denormalise_image(images[i])
        overlay = make_overlay(img_np, rollout_maps[i], alpha=alpha)

        # Galaxy i's image goes on flattened row ``row_base``; its overlay
        # sits one figure-row below.
        row_base = (i // n_cols) * 2
        col = i % n_cols
        ax_img = axes[row_base * n_cols + col]
        ax_attn = axes[(row_base + 1) * n_cols + col]

        ax_img.imshow(img_np)
        ax_img.axis("off")
        # Last 6 digits of the object id keep the titles compact.
        ax_img.set_title(str(image_ids[i])[-6:], color="white",
                         fontsize=7, pad=2)
        ax_attn.imshow(overlay)
        ax_attn.axis("off")

    # Hide empty panels
    # NOTE(review): this indexes the flattened axes with grid-cell indices
    # [N, n_rows*n_cols) and so may not hide the matching overlay-row
    # panels of unused cells — confirm against a non-full final row.
    for j in range(N, n_rows * n_cols):
        if j < len(axes):
            axes[j].axis("off")

    mode_label = "Full 12-layer rollout" if rollout_mode == "full" else "Last-layer rollout"
    plt.suptitle(
        f"Galaxy attention rollout — {mode_label} (ViT-Base/16)",
        color="white", fontsize=10, y=1.01
    )
    plt.tight_layout(pad=0.3)

    if save_path is not None:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches="tight", facecolor="black")

    return fig
+ return fig
239
+
240
+
241
+ # ─────────────────────────────────────────────────────────────
242
+ # Attention entropy per layer
243
+ # ─────────────────────────────────────────────────────────────
244
+
245
def compute_attention_entropy_per_layer(
    all_attn_weights: List[torch.Tensor],
) -> np.ndarray:
    """
    Mean Shannon entropy (nats) of the CLS token's attention over the
    patch tokens, computed separately for each transformer layer.

    Low entropy means the CLS token focuses on few patches (local
    feature detection); high entropy means it attends broadly (global
    aggregation). High entropy in late layers is expected for ViT
    models and does not imply the attention is uninformative.

    Returns
    -------
    entropies : [L] float32, one mean entropy per layer in nats
    """
    per_layer = []
    for layer_attn in all_attn_weights:
        # CLS (token 0) attention over patch tokens, floored so that
        # log() stays finite: [B, H, N_patches]
        cls_to_patches = layer_attn[:, :, 0, 1:].clamp(min=1e-9)
        entropy = -(cls_to_patches * cls_to_patches.log()).sum(dim=-1)  # [B, H]
        per_layer.append(entropy.mean().item())
    return np.array(per_layer, dtype=np.float32)
+ return np.array(entropies, dtype=np.float32)
278
+
279
+
280
def plot_attention_entropy(
    all_attn_weights: List[torch.Tensor],
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Plot CLS attention entropy per transformer layer.

    Shaded bands split the depth into an early "local feature
    detection" regime and a late "global integration" regime; the 8.5
    boundary is a fixed heuristic, not derived from the data.
    """
    entropies = compute_attention_entropy_per_layer(all_attn_weights)
    L = len(entropies)

    fig, ax = plt.subplots(figsize=(8, 4))
    # 1-indexed layer axis to match the usual "layer 1..12" convention.
    ax.plot(range(1, L + 1), entropies, "b-o", markersize=6, linewidth=2)

    # Shade regions for interpretation
    ax.axvspan(1, 8.5, alpha=0.07, color="blue",
               label="Local feature detection (layers 1–8)")
    ax.axvspan(8.5, L + 0.5, alpha=0.07, color="orange",
               label="Global integration (layers 9–12)")

    ax.set_xlabel("Transformer layer", fontsize=12)
    ax.set_ylabel("Mean CLS attention entropy (nats)", fontsize=12)
    ax.set_title(
        "CLS token attention entropy vs. transformer depth\n"
        "Early layers: local morphological detection | "
        "Late layers: global aggregation",
        fontsize=10,
    )
    ax.set_xticks(range(1, L + 1))
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()

    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches="tight")

    return fig
src/baselines.py ADDED
@@ -0,0 +1,844 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/baselines.py
3
+ ----------------
4
+ Consolidated baseline training for the GZ2 hierarchical probabilistic
5
+ regression paper. ALL baselines are trained from this single script.
6
+
7
+ Replaces the three separate scripts:
8
+ src/baselines.py (was: ResNet-18 MSE + ViT MSE)
9
+ src/run_resnet_kl.py (was: ResNet-18 KL+MSE β€” now merged here)
10
+ src/train_dirichlet.py (was: ViT Dirichlet β€” now merged here)
11
+
12
+ DELETE those three original files after switching to this one.
13
+
14
+ Baselines trained
15
+ -----------------
16
+ B1. ResNet-18 + independent MSE (sigmoid)
17
+ β€” CNN, no hierarchy, no KL. Demonstrates the cost of
18
+ ignoring the decision-tree structure.
19
+
20
+ B2. ResNet-18 + hierarchical KL+MSE
21
+ β€” Same loss as proposed, CNN backbone.
22
+ Isolates ViT vs. CNN contribution.
23
+
24
+ B3. ViT-Base + hierarchical MSE only (no KL)
25
+ β€” Same backbone as proposed, KL term removed.
26
+ Isolates contribution of the KL term.
27
+
28
+ B4. ViT-Base + Dirichlet NLL (Zoobot-style)
29
+ β€” Direct comparison with the established Zoobot approach
30
+ (Walmsley et al. 2022, MNRAS 509, 3966).
31
+
32
+ Proposed model (not trained here β€” trained via src/train.py):
33
+ ViT-Base + hierarchical KL+MSE β†’ outputs/checkpoints/best_full_train.pt
34
+
35
+ Consistency guarantee
36
+ ---------------------
37
+ All baselines use identical:
38
+ - Random seed, data split, batch size, epochs, early stopping
39
+ - AdamW optimiser, CosineAnnealingLR, gradient clipping
40
+ - Image transforms and evaluation metric (compute_metrics on same test split)
41
+
42
+ The ONLY differences between models are the backbone and/or loss function.
43
+
44
+ Usage
45
+ -----
46
+ cd ~/galaxy
47
+ nohup python -m src.baselines --config configs/full_train.yaml \
48
+ > outputs/logs/baselines.log 2>&1 &
49
+ echo "PID: $!"
50
+ """
51
+
52
+ import argparse
53
+ import logging
54
+ import random
55
+ import sys
56
+ from pathlib import Path
57
+
58
+ import numpy as np
59
+ import pandas as pd
60
+ import torch
61
+ import timm
62
+ import torch.nn as nn
63
+ import torch.nn.functional as F
64
+ import matplotlib
65
+ matplotlib.use("Agg")
66
+ import matplotlib.pyplot as plt
67
+ from torch.amp import autocast, GradScaler
68
+ from omegaconf import OmegaConf
69
+ from tqdm import tqdm
70
+
71
+ import wandb
72
+
73
+ from src.dataset import build_dataloaders, QUESTION_GROUPS
74
+ from src.loss import HierarchicalLoss, DirichletLoss, MSEOnlyLoss
75
+ from src.metrics import (compute_metrics, predictions_to_numpy,
76
+ dirichlet_predictions_to_numpy, simplex_violation_rate)
77
+ from src.model import build_model, build_dirichlet_model
78
+
79
+ logging.basicConfig(
80
+ format="%(asctime)s %(levelname)s %(message)s",
81
+ datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
82
+ )
83
+ log = logging.getLogger("baselines")
84
+
85
# Human-readable names for the 11 GZ2 decision-tree questions (t01-t11);
# used for tick labels in the comparison figures below.
QUESTION_LABELS = {
    "t01": "Smooth or features", "t02": "Edge-on disk",
    "t03": "Bar", "t04": "Spiral arms",
    "t05": "Bulge prominence", "t06": "Odd feature",
    "t07": "Roundedness", "t08": "Odd feature type",
    "t09": "Bulge shape", "t10": "Arms winding",
    "t11": "Arms number",
}
93
+
94
+
95
+ # ─────────────────────────────────────────────────────────────
96
+ # Reproducibility
97
+ # ─────────────────────────────────────────────────────────────
98
+
99
def set_seed(seed: int):
    """Seed every RNG in use (Python, NumPy, PyTorch CPU/CUDA) and force
    deterministic cuDNN kernel selection for reproducible runs."""
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
    # Trade autotuned-kernel speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
106
+
107
+
108
+ # ─────────────────────────────────────────────────────────────
109
+ # Early stopping (mirrors train.py exactly)
110
+ # ─────────────────────────────────────────────────────────────
111
+
112
class EarlyStopping:
    """
    Early stopping on validation loss with best-checkpoint saving.

    Mirrors train.py: an epoch counts as an improvement when
    ``val_loss < best_loss - min_delta``; each non-improving epoch
    increments a counter and ``step`` returns True once the counter
    reaches ``patience``.
    """
    def __init__(self, patience, min_delta, checkpoint_path):
        # patience: consecutive non-improving epochs tolerated before stopping.
        # min_delta: minimum decrease in val loss that counts as improvement.
        # checkpoint_path: file the best model state_dict is written to.
        self.patience = patience
        self.min_delta = min_delta
        self.checkpoint_path = checkpoint_path
        self.best_loss = float("inf")
        self.counter = 0
        self.best_epoch = 0

    def step(self, val_loss, model, epoch) -> bool:
        """Record one epoch; checkpoint on improvement.

        Returns True when training should stop (patience exhausted).
        """
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.best_epoch = epoch
            # Persist epoch + weights + loss so restore_best can recover
            # both the state and its provenance.
            torch.save(
                {"epoch": epoch, "model_state": model.state_dict(),
                 "val_loss": val_loss},
                self.checkpoint_path,
            )
            log.info(" [ckpt] saved val_loss=%.6f epoch=%d", val_loss, epoch)
        else:
            self.counter += 1
            log.info(" [early_stop] %d/%d best=%.6f",
                     self.counter, self.patience, self.best_loss)
        return self.counter >= self.patience

    def restore_best(self, model) -> float:
        """Reload the best checkpoint into ``model``; return its val loss."""
        ckpt = torch.load(self.checkpoint_path, map_location="cpu",
                          weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        log.info("Restored best weights epoch=%d val_loss=%.6f",
                 ckpt["epoch"], ckpt["val_loss"])
        return ckpt["val_loss"]
145
+
146
+
147
+ # ─────────────────────────────────────────────────────────────
148
+ # Baseline Model 1: ResNet-18 + independent MSE
149
+ # ─────────────────────────────────────────────────────────────
150
+
151
class ResNet18Baseline(nn.Module):
    """
    ResNet-18 pretrained on ImageNet with a dropout + linear head.
    Used for both the sigmoid-MSE baseline and the KL+MSE baseline.
    """

    def __init__(self, dropout: float = 0.3):
        super().__init__()
        # num_classes=0 strips timm's classifier so the backbone emits
        # pooled features only.
        self.backbone = timm.create_model(
            "resnet18", pretrained=True, num_classes=0
        )
        feat_dim = self.backbone.num_features
        # Regularised linear head mapping features -> 37 GZ2 answer logits.
        self.head = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(feat_dim, 37),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw 37-dim logits for a batch of images."""
        features = self.backbone(x)
        return self.head(features)
168
+
169
+
170
class IndependentMSELoss(nn.Module):
    """
    Plain MSE over all 37 targets independently.
    No hierarchical weighting, no KL divergence.
    Sigmoid applied to predictions before MSE to constrain range [0,1].

    Note: predictions do NOT sum to 1 per question group by construction.
    This is documented and the simplex_violation_rate metric quantifies
    this invalidity to allow fair comparison with the proposed method.
    """
    def forward(self, predictions, targets, weights):
        """Compute the loss.

        Parameters
        ----------
        predictions : raw logits, same shape as ``targets``.
        targets : ground-truth vote fractions in [0, 1].
        weights : accepted for interface parity with the hierarchical
            losses but intentionally unused here.

        Returns
        -------
        (loss, log_dict) : scalar loss tensor and a dict of floats
            for logging, matching the other loss classes' contract.
        """
        pred_prob = torch.sigmoid(predictions)
        loss = F.mse_loss(pred_prob, targets)
        # .item() already returns a detached Python float; the previous
        # .detach().item() was redundant.
        return loss, {"loss/total": loss.item()}
184
+
185
+
186
+ # ─────────────────────────────────────────────────────────────
187
+ # Shared training loop
188
+ # ─────────────────────────────────────────────────────────────
189
+
190
def _train_epoch(model, loader, loss_fn, optimizer, scaler,
                 device, cfg, epoch, label):
    """Run one mixed-precision training epoch; return the mean batch loss.

    ``loader`` yields (images, targets, weights, _); ``loss_fn`` takes
    (logits, targets, weights) and returns (loss, log_dict).
    """
    model.train()
    total = 0.0
    nb = 0
    for images, targets, weights, _ in tqdm(
        loader, desc=f"{label} E{epoch}", leave=False
    ):
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        weights = weights.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with autocast("cuda", enabled=cfg.training.mixed_precision):
            logits = model(images)
            loss, _ = loss_fn(logits, targets, weights)
        # AMP ordering matters: scale->backward, then unscale_ so that
        # clip_grad_norm_ sees true-magnitude gradients, then step/update.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.training.grad_clip)
        scaler.step(optimizer)
        scaler.update()

        total += loss.item()
        nb += 1
    return total / nb
215
+
216
+
217
def _train_epoch_dirichlet(model, loader, loss_fn, optimizer, scaler,
                           device, cfg, epoch, label):
    """Training epoch for Dirichlet model (outputs alpha, not logits).

    Identical structure to ``_train_epoch``; only the model output
    semantics differ (Dirichlet concentration parameters fed to a
    Dirichlet NLL loss). Returns the mean batch loss.
    """
    model.train()
    total = 0.0
    nb = 0
    for images, targets, weights, _ in tqdm(
        loader, desc=f"{label} E{epoch}", leave=False
    ):
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        weights = weights.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with autocast("cuda", enabled=cfg.training.mixed_precision):
            alpha = model(images)
            loss, _ = loss_fn(alpha, targets, weights)
        # Unscale before clipping so the gradient norm is measured in
        # true (unscaled) units.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.training.grad_clip)
        scaler.step(optimizer)
        scaler.update()

        total += loss.item()
        nb += 1
    return total / nb
243
+
244
+
245
def _val_epoch(model, loader, loss_fn, device, cfg, epoch, label,
               use_sigmoid=False):
    """Run one validation epoch; return (mean loss, metrics dict).

    Predictions are converted to probabilities before metrics:
    - use_sigmoid=True: element-wise sigmoid (B1 baseline, no simplex
      guarantee);
    - otherwise: softmax applied independently within each question's
      column slice from QUESTION_GROUPS, so each group sums to 1.
    Metrics come from ``compute_metrics`` on the concatenated arrays.
    """
    model.eval()
    total = 0.0
    nb = 0
    all_preds, all_targets, all_weights = [], [], []

    with torch.no_grad():
        for images, targets, weights, _ in tqdm(
            loader, desc=f"{label} Val E{epoch}", leave=False
        ):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)

            with autocast("cuda", enabled=cfg.training.mixed_precision):
                logits = model(images)
                loss, _ = loss_fn(logits, targets, weights)

            total += loss.item()
            nb += 1

            if use_sigmoid:
                pred_prob = torch.sigmoid(logits).detach().cpu().numpy()
            else:
                # Clone before the in-place per-group softmax writes.
                pred_cpu = logits.detach().cpu().clone()
                for q, (s, e) in QUESTION_GROUPS.items():
                    pred_cpu[:, s:e] = torch.softmax(pred_cpu[:, s:e], dim=-1)
                pred_prob = pred_cpu.numpy()

            all_preds.append(pred_prob)
            all_targets.append(targets.detach().cpu().numpy())
            all_weights.append(weights.detach().cpu().numpy())

    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    all_weights = np.concatenate(all_weights)
    metrics = compute_metrics(all_preds, all_targets, all_weights)

    return total / nb, metrics
285
+
286
+
287
def _val_epoch_dirichlet(model, loader, loss_fn, device, cfg, epoch, label):
    """Validation epoch for the Dirichlet model; return (mean loss, metrics).

    Conversion from alpha concentrations to probability predictions is
    delegated to ``dirichlet_predictions_to_numpy``; metrics then come
    from the same ``compute_metrics`` used by the other baselines.
    """
    model.eval()
    total = 0.0
    nb = 0
    all_preds, all_targets, all_weights = [], [], []

    with torch.no_grad():
        for images, targets, weights, _ in tqdm(
            loader, desc=f"{label} Val E{epoch}", leave=False
        ):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)

            with autocast("cuda", enabled=cfg.training.mixed_precision):
                alpha = model(images)
                loss, _ = loss_fn(alpha, targets, weights)

            total += loss.item()
            nb += 1

            p, t, w = dirichlet_predictions_to_numpy(alpha, targets, weights)
            all_preds.append(p)
            all_targets.append(t)
            all_weights.append(w)

    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    all_weights = np.concatenate(all_weights)
    metrics = compute_metrics(all_preds, all_targets, all_weights)

    return total / nb, metrics
319
+
320
+
321
+ # ─────────────────────────────────────────────────────────────
322
+ # Generic train_and_evaluate (non-Dirichlet)
323
+ # ─────────────────────────────────────────────────────────────
324
+
325
def train_and_evaluate(
    model, loss_fn, cfg, device,
    label, checkpoint_path,
    use_layerwise_lr=True,
    use_sigmoid=False,
):
    """
    Full training loop consistent with train.py.

    Parameters
    ----------
    model, loss_fn : backbone and loss (the only things that vary
        between baselines).
    cfg : merged OmegaConf config (training/scheduler/early_stopping/
        wandb sections are read).
    label : run name used for logging, tqdm and wandb.
    checkpoint_path : best-checkpoint file; if it already exists the
        loop is skipped entirely and the checkpoint is evaluated.
    use_layerwise_lr : give the backbone 10x lower lr than the head
        (only when the model exposes .backbone and .head).
    use_sigmoid : forwarded to _val_epoch's prediction conversion.

    Returns (test_metrics, best_val_loss, best_epoch, history); history
    is [] when training was skipped via an existing checkpoint.
    """
    # Check if checkpoint exists - if so, skip training
    if Path(checkpoint_path).exists():
        log.info("%s: checkpoint found - loading and skipping training", label)
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        best_epoch = ckpt.get("epoch", 0)
        best_val = ckpt.get("val_loss", float("inf"))
        log.info("Restored: epoch=%d, val_loss=%.6f", best_epoch, best_val)

        # Evaluate on test set
        _, _, test_loader = build_dataloaders(cfg)
        _, test_metrics = _val_epoch(
            model, test_loader, loss_fn, device, cfg,
            epoch=0, label=f"{label}-test", use_sigmoid=use_sigmoid
        )
        return test_metrics, best_val, best_epoch, []

    train_loader, val_loader, test_loader = build_dataloaders(cfg)

    # Optimiser: optional 10x-lower backbone lr for pretrained backbones.
    if use_layerwise_lr and hasattr(model, "backbone") and hasattr(model, "head"):
        optimizer = torch.optim.AdamW(
            [
                {"params": model.backbone.parameters(),
                 "lr": cfg.training.learning_rate * 0.1},
                {"params": model.head.parameters(),
                 "lr": cfg.training.learning_rate},
            ],
            weight_decay=cfg.training.weight_decay,
        )
        log.info("%s: layer-wise lr — backbone=%.1e head=%.1e",
                 label, cfg.training.learning_rate * 0.1, cfg.training.learning_rate)
    else:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=cfg.training.learning_rate,
            weight_decay=cfg.training.weight_decay,
        )
        log.info("%s: single lr=%.1e", label, cfg.training.learning_rate)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
    )
    scaler = GradScaler("cuda")
    early_stop = EarlyStopping(
        patience=cfg.early_stopping.patience,
        min_delta=cfg.early_stopping.min_delta,
        checkpoint_path=checkpoint_path,
    )

    wandb.init(
        project=cfg.wandb.project,
        name=label,
        config={
            "model": label, "backbone": "resnet18" if "ResNet" in label else "vit_base_patch16_224",
            "batch_size": cfg.training.batch_size, "lr": cfg.training.learning_rate,
            "epochs": cfg.training.epochs, "seed": cfg.seed,
            "lambda_kl": cfg.loss.lambda_kl, "lambda_mse": cfg.loss.lambda_mse,
        },
        reinit=True,
    )

    history = []
    for epoch in range(1, cfg.training.epochs + 1):
        train_loss = _train_epoch(
            model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch, label
        )
        val_loss, val_metrics = _val_epoch(
            model, val_loader, loss_fn, device, cfg, epoch, label,
            use_sigmoid=use_sigmoid
        )
        scheduler.step()
        lr = scheduler.get_last_lr()[0]

        val_mae = val_metrics.get("mae/weighted_avg", 0)
        log.info("%s epoch=%d train=%.4f val=%.4f mae=%.4f lr=%.2e",
                 label, epoch, train_loss, val_loss, val_mae, lr)
        history.append({
            "epoch": epoch, "train_loss": train_loss,
            "val_loss": val_loss, "val_mae": val_mae,
        })
        wandb.log({
            "train_loss": train_loss, "val_loss": val_loss,
            "val_mae": val_mae, "lr": lr,
        }, step=epoch)

        # EarlyStopping also checkpoints on improvement; True means stop.
        if early_stop.step(val_loss, model, epoch):
            log.info("%s: early stopping at epoch %d best=%d",
                     label, epoch, early_stop.best_epoch)
            break

    best_val = early_stop.restore_best(model)
    wandb.finish()

    log.info("%s: evaluating on test set...", label)
    _, test_metrics = _val_epoch(
        model, test_loader, loss_fn, device, cfg,
        epoch=0, label=f"{label}-test", use_sigmoid=use_sigmoid
    )
    return test_metrics, best_val, early_stop.best_epoch, history
435
+
436
+
437
+ # ─────────────────────────────────────────────────────────────
438
+ # Dirichlet train_and_evaluate
439
+ # ─────────────────────────────────────────────────────────────
440
+
441
def train_and_evaluate_dirichlet(model, loss_fn, cfg, device,
                                 label, checkpoint_path):
    """Training loop for Dirichlet model. Skips training if checkpoint exists.

    Same protocol as ``train_and_evaluate`` (AdamW with layer-wise lr,
    cosine schedule, AMP, early stopping, wandb logging) but routed
    through the Dirichlet-specific epoch helpers, since the model emits
    alpha concentrations rather than logits.
    Returns (test_metrics, best_val_loss, best_epoch, history).
    """
    # Check if checkpoint exists - if so, skip training
    if Path(checkpoint_path).exists():
        log.info("%s: checkpoint found - loading and skipping training", label)
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        best_epoch = ckpt.get("epoch", 0)
        best_val = ckpt.get("val_loss", float("inf"))
        log.info("Restored: epoch=%d, val_loss=%.6f", best_epoch, best_val)

        # Evaluate on test set
        _, _, test_loader = build_dataloaders(cfg)
        _, test_metrics = _val_epoch_dirichlet(
            model, test_loader, loss_fn, device, cfg,
            epoch=0, label=f"{label}-test"
        )
        return test_metrics, best_val, best_epoch, []

    train_loader, val_loader, test_loader = build_dataloaders(cfg)

    # Backbone gets 10x lower lr than the head (pretrained ViT).
    optimizer = torch.optim.AdamW(
        [
            {"params": model.backbone.parameters(),
             "lr": cfg.training.learning_rate * 0.1},
            {"params": model.head.parameters(),
             "lr": cfg.training.learning_rate},
        ],
        weight_decay=cfg.training.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
    )
    scaler = GradScaler("cuda")
    early_stop = EarlyStopping(
        patience=cfg.early_stopping.patience,
        min_delta=cfg.early_stopping.min_delta,
        checkpoint_path=checkpoint_path,
    )

    wandb.init(
        project=cfg.wandb.project, name=label,
        config={"model": label, "loss": "DirichletNLL",
                "seed": cfg.seed, "epochs": cfg.training.epochs},
        reinit=True,
    )

    history = []
    for epoch in range(1, cfg.training.epochs + 1):
        train_loss = _train_epoch_dirichlet(
            model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch, label
        )
        val_loss, val_metrics = _val_epoch_dirichlet(
            model, val_loader, loss_fn, device, cfg, epoch, label
        )
        scheduler.step()
        lr = scheduler.get_last_lr()[0]

        val_mae = val_metrics.get("mae/weighted_avg", 0)
        log.info("%s epoch=%d train=%.4f val=%.4f mae=%.4f lr=%.2e",
                 label, epoch, train_loss, val_loss, val_mae, lr)
        history.append({
            "epoch": epoch, "train_loss": train_loss,
            "val_loss": val_loss, "val_mae": val_mae,
        })
        wandb.log({
            "train_loss": train_loss, "val_loss": val_loss,
            "val_mae": val_mae, "lr": lr,
        }, step=epoch)

        if early_stop.step(val_loss, model, epoch):
            log.info("%s: early stopping at epoch %d", label, epoch)
            break

    best_val = early_stop.restore_best(model)
    wandb.finish()

    log.info("%s: evaluating on test set...", label)
    _, test_metrics = _val_epoch_dirichlet(
        model, test_loader, loss_fn, device, cfg, epoch=0, label=f"{label}-test"
    )
    return test_metrics, best_val, early_stop.best_epoch, history
524
+
525
+
526
+ # ─────────────────────────────────────────────────────────────
527
+ # Figures
528
+ # ─────────────────────────────────────────────────────────────
529
+
530
def _save_comparison_figures(all_results, all_histories, save_dir):
    """
    Saves:
        1. Per-question MAE + RMSE bar chart
        2. Validation MAE learning curves
    (The simplex-violation numbers are written to the CSV tables in
    main(), not plotted here.)
    All figure names follow IEEE journal conventions.

    Parameters
    ----------
    all_results : list of per-model result dicts containing "model" and
        "mae_<q>"/"rmse_<q>" keys for each question in QUESTION_GROUPS.
    all_histories : {model name: [{"epoch", "val_mae", ...}, ...]}.
    save_dir : Path; figures are written there as both .pdf and .png.
    """
    q_names = list(QUESTION_GROUPS.keys())
    n_models = len(all_results)
    x = np.arange(len(q_names))
    # Total group width 0.80, split evenly across the models.
    width = 0.80 / n_models
    palette = ["#c0392b", "#e67e22", "#2980b9", "#27ae60", "#8e44ad"]

    # ── Figure 1: Per-question MAE and RMSE ───────────────────
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    for metric, ax, ylabel in [
        ("mae", axes[0], "Mean Absolute Error (MAE)"),
        ("rmse", axes[1], "Root Mean Squared Error (RMSE)"),
    ]:
        for i, (row_d, color) in enumerate(zip(all_results, palette)):
            # Missing per-question entries plot as NaN (empty bar).
            vals = [row_d.get(f"{metric}_{q}", np.nan) for q in q_names]
            ax.bar(x + i * width, vals, width,
                   label=row_d["model"], color=color,
                   alpha=0.85, edgecolor="white", linewidth=0.5)
        # Centre tick labels under each group of bars.
        ax.set_xticks(x + width * (n_models - 1) / 2)
        ax.set_xticklabels(
            [f"{q}\n({QUESTION_LABELS[q][:10]})" for q in q_names],
            rotation=45, ha="right", fontsize=7,
        )
        ax.set_ylabel(ylabel, fontsize=11)
        ax.set_title(f"Per-question {metric.upper()} — baseline comparison", fontsize=11)
        ax.legend(fontsize=7, loc="upper right")
        ax.grid(True, alpha=0.3, axis="y")
        ax.set_axisbelow(True)

    plt.suptitle(
        "Baseline comparison — GZ2 hierarchical probabilistic regression\n"
        "Full 239,267-sample dataset, identical seed/split/protocol",
        fontsize=12, y=1.02,
    )
    plt.tight_layout()
    fig.savefig(save_dir / "fig_baseline_comparison_mae_rmse.pdf",
                dpi=300, bbox_inches="tight")
    fig.savefig(save_dir / "fig_baseline_comparison_mae_rmse.png",
                dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_baseline_comparison_mae_rmse")

    # ── Figure 2: Validation MAE learning curves ───────────────
    fig2, ax2 = plt.subplots(figsize=(10, 5))
    styles = ["-", "--", "-.", ":", (0, (3, 1, 1, 1))]
    markers = ["o", "s", "^", "D", "v"]
    for (name, hist), ls, color, mk in zip(
        all_histories.items(), styles, palette, markers
    ):
        epochs_h = [h["epoch"] for h in hist]
        val_maes = [h["val_mae"] for h in hist]
        ax2.plot(epochs_h, val_maes, linestyle=ls, color=color, linewidth=1.8,
                 label=name, marker=mk, markersize=3, markevery=5)

    ax2.set_xlabel("Epoch", fontsize=11)
    ax2.set_ylabel("Validation MAE (weighted average)", fontsize=11)
    ax2.set_title("Validation MAE during training — all baseline models", fontsize=11)
    ax2.legend(fontsize=9)
    ax2.grid(True, alpha=0.3)
    plt.tight_layout()
    fig2.savefig(save_dir / "fig_baseline_val_mae_curves.pdf",
                 dpi=300, bbox_inches="tight")
    fig2.savefig(save_dir / "fig_baseline_val_mae_curves.png",
                 dpi=300, bbox_inches="tight")
    plt.close(fig2)
    log.info("Saved: fig_baseline_val_mae_curves")
603
+
604
+
605
+ # ─────────────────────────────────────────────────────────────
606
+ # Main
607
+ # ─────────────────────────────────────────────────────────────
608
+
609
def main():
    """Train/evaluate baselines B1-B4, then (if its checkpoint exists)
    evaluate the proposed model, and write comparison tables + figures.

    CLI: --config  experiment YAML merged over configs/base.yaml.
    Each baseline re-seeds with cfg.seed so all runs see identical
    splits and initial RNG state; train_and_evaluate* skip training
    when their checkpoint already exists.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()

    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)

    set_seed(cfg.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info("Device: %s Dataset: %s",
             device, "full 239k" if cfg.data.n_samples is None
             else f"{cfg.data.n_samples:,}")

    # TF32 speeds up matmul/conv on Ampere+ GPUs with negligible accuracy cost.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    save_dir = Path(cfg.outputs.figures_dir) / "comparison"
    ckpt_dir = Path(cfg.outputs.checkpoint_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    all_results = []
    all_histories = {}

    # ─── B1: ResNet-18 + independent MSE (sigmoid) ────────────
    log.info("=" * 60)
    log.info("B1: ResNet-18 + independent MSE (sigmoid, no hierarchy)")
    log.info("=" * 60)
    set_seed(cfg.seed)

    rn_mse_model = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
    rn_mse_loss = IndependentMSELoss()
    log.info("ResNet-18 params: %s", f"{sum(p.numel() for p in rn_mse_model.parameters()):,}")

    rn_mse_metrics, rn_mse_val, rn_mse_epoch, rn_mse_hist = train_and_evaluate(
        rn_mse_model, rn_mse_loss, cfg, device,
        label = "B1-ResNet18-MSE",
        checkpoint_path = str(ckpt_dir / "baseline_resnet18_mse.pt"),
        use_layerwise_lr = False,
        use_sigmoid = True,
    )

    # Simplex violation for this baseline
    # (only B1 can violate the per-question simplex: sigmoid outputs are
    # independent, unlike the softmax/Dirichlet heads of B2-B4).
    _, _, test_loader_tmp = build_dataloaders(cfg)
    rn_mse_model.eval()
    tmp_preds = []
    with torch.no_grad():
        for images, _, _, _ in test_loader_tmp:
            images = images.to(device, non_blocking=True)
            logits = rn_mse_model(images)
            tmp_preds.append(torch.sigmoid(logits).cpu().numpy())
    tmp_preds = np.concatenate(tmp_preds)
    svr = simplex_violation_rate(tmp_preds, tolerance=0.02)
    log.info("B1 simplex violation rate (mean): %.4f", svr["mean"])

    row = {
        "model": "ResNet-18 + MSE (sigmoid, no hierarchy)",
        "backbone": "ResNet-18", "loss": "Independent MSE",
        "hierarchy": "None",
        "best_epoch": rn_mse_epoch, "best_val_loss": round(rn_mse_val, 5),
        "mae_weighted" : round(rn_mse_metrics["mae/weighted_avg"], 5),
        "rmse_weighted": round(rn_mse_metrics["rmse/weighted_avg"], 5),
        "simplex_violation_mean": round(svr["mean"], 4),
    }
    for q in QUESTION_GROUPS:
        row[f"mae_{q}"] = round(rn_mse_metrics[f"mae/{q}"], 5)
        row[f"rmse_{q}"] = round(rn_mse_metrics[f"rmse/{q}"], 5)
    all_results.append(row)
    all_histories["ResNet-18 + MSE (sigmoid)"] = rn_mse_hist
    log.info("B1 done: MAE=%.5f RMSE=%.5f SimplexViol=%.4f",
             rn_mse_metrics["mae/weighted_avg"],
             rn_mse_metrics["rmse/weighted_avg"],
             svr["mean"])

    # ─── B2: ResNet-18 + hierarchical KL+MSE ──────────────────
    log.info("=" * 60)
    log.info("B2: ResNet-18 + hierarchical KL+MSE (same loss as proposed)")
    log.info("=" * 60)
    set_seed(cfg.seed)

    rn_kl_model = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
    rn_kl_loss = HierarchicalLoss(cfg)

    rn_kl_metrics, rn_kl_val, rn_kl_epoch, rn_kl_hist = train_and_evaluate(
        rn_kl_model, rn_kl_loss, cfg, device,
        label = "B2-ResNet18-KL+MSE",
        checkpoint_path = str(ckpt_dir / "baseline_resnet18_klmse.pt"),
        use_layerwise_lr = False,
        use_sigmoid = False,
    )
    row = {
        "model": "ResNet-18 + hierarchical KL+MSE",
        "backbone": "ResNet-18", "loss": "Hierarchical KL+MSE (λ=0.5)",
        "hierarchy": "Full (weights + KL)",
        "best_epoch": rn_kl_epoch, "best_val_loss": round(rn_kl_val, 5),
        "mae_weighted" : round(rn_kl_metrics["mae/weighted_avg"], 5),
        "rmse_weighted": round(rn_kl_metrics["rmse/weighted_avg"], 5),
        "simplex_violation_mean": 0.0,  # softmax guarantees validity
    }
    for q in QUESTION_GROUPS:
        row[f"mae_{q}"] = round(rn_kl_metrics[f"mae/{q}"], 5)
        row[f"rmse_{q}"] = round(rn_kl_metrics[f"rmse/{q}"], 5)
    all_results.append(row)
    all_histories["ResNet-18 + KL+MSE"] = rn_kl_hist
    log.info("B2 done: MAE=%.5f RMSE=%.5f",
             rn_kl_metrics["mae/weighted_avg"],
             rn_kl_metrics["rmse/weighted_avg"])

    # ─── B3: ViT-Base + hierarchical MSE only ─────────────────
    log.info("=" * 60)
    log.info("B3: ViT-Base + hierarchical MSE only (no KL term)")
    log.info("=" * 60)
    set_seed(cfg.seed)

    # Ablate the KL term by overriding the loss weights in a merged copy.
    from omegaconf import OmegaConf as OC
    vit_mse_cfg = OC.merge(cfg, OC.create({"loss": {"lambda_kl": 0.0, "lambda_mse": 1.0}}))
    vit_mse_model = build_model(vit_mse_cfg).to(device)
    vit_mse_loss = MSEOnlyLoss(vit_mse_cfg)

    vit_mse_metrics, vit_mse_val, vit_mse_epoch, vit_mse_hist = train_and_evaluate(
        vit_mse_model, vit_mse_loss, vit_mse_cfg, device,
        label = "B3-ViT-MSE",
        checkpoint_path = str(ckpt_dir / "baseline_vit_mse.pt"),
        use_layerwise_lr = True,
        use_sigmoid = False,
    )
    row = {
        "model": "ViT-Base + hierarchical MSE (no KL)",
        "backbone": "ViT-Base/16", "loss": "Hierarchical MSE (λ_KL=0)",
        "hierarchy": "Weights only",
        "best_epoch": vit_mse_epoch, "best_val_loss": round(vit_mse_val, 5),
        "mae_weighted" : round(vit_mse_metrics["mae/weighted_avg"], 5),
        "rmse_weighted": round(vit_mse_metrics["rmse/weighted_avg"], 5),
        "simplex_violation_mean": 0.0,
    }
    for q in QUESTION_GROUPS:
        row[f"mae_{q}"] = round(vit_mse_metrics[f"mae/{q}"], 5)
        row[f"rmse_{q}"] = round(vit_mse_metrics[f"rmse/{q}"], 5)
    all_results.append(row)
    all_histories["ViT-Base + MSE only"] = vit_mse_hist
    log.info("B3 done: MAE=%.5f RMSE=%.5f",
             vit_mse_metrics["mae/weighted_avg"],
             vit_mse_metrics["rmse/weighted_avg"])

    # ─── B4: ViT-Base + Dirichlet NLL (Zoobot-style) ──────────
    log.info("=" * 60)
    log.info("B4: ViT-Base + Dirichlet NLL (Walmsley et al. 2022)")
    log.info("=" * 60)
    set_seed(cfg.seed)

    vit_dir_model = build_dirichlet_model(cfg).to(device)
    vit_dir_loss = DirichletLoss(cfg)

    vit_dir_metrics, vit_dir_val, vit_dir_epoch, vit_dir_hist = train_and_evaluate_dirichlet(
        vit_dir_model, vit_dir_loss, cfg, device,
        label = "B4-ViT-Dirichlet",
        checkpoint_path = str(ckpt_dir / "baseline_vit_dirichlet.pt"),
    )
    row = {
        "model": "ViT-Base + Dirichlet NLL (Zoobot-style)",
        "backbone": "ViT-Base/16", "loss": "Dirichlet NLL",
        "hierarchy": "Full (weights + Dirichlet)",
        "best_epoch": vit_dir_epoch, "best_val_loss": round(vit_dir_val, 5),
        "mae_weighted" : round(vit_dir_metrics["mae/weighted_avg"], 5),
        "rmse_weighted": round(vit_dir_metrics["rmse/weighted_avg"], 5),
        "simplex_violation_mean": 0.0,
    }
    for q in QUESTION_GROUPS:
        row[f"mae_{q}"] = round(vit_dir_metrics[f"mae/{q}"], 5)
        row[f"rmse_{q}"] = round(vit_dir_metrics[f"rmse/{q}"], 5)
    all_results.append(row)
    all_histories["ViT-Base + Dirichlet"] = vit_dir_hist
    log.info("B4 done: MAE=%.5f RMSE=%.5f",
             vit_dir_metrics["mae/weighted_avg"],
             vit_dir_metrics["rmse/weighted_avg"])

    # ─── Proposed: load existing checkpoint for final table ────
    # The proposed model is trained by src/train.py; here it is only
    # evaluated on the same test split for the comparison table.
    proposed_ckpt = ckpt_dir / "best_full_train.pt"
    if proposed_ckpt.exists():
        log.info("=" * 60)
        log.info("PROPOSED: Loading ViT-Base + hierarchical KL+MSE")
        log.info("=" * 60)
        proposed_model = build_model(cfg).to(device)
        proposed_model.load_state_dict(
            torch.load(proposed_ckpt, map_location="cpu", weights_only=True)["model_state"]
        )
        _, _, test_loader_p = build_dataloaders(cfg)
        _, proposed_metrics = _val_epoch(
            proposed_model, test_loader_p, HierarchicalLoss(cfg), device, cfg,
            epoch=0, label="Proposed-test", use_sigmoid=False
        )
        ckpt_info = torch.load(proposed_ckpt, map_location="cpu", weights_only=True)
        row = {
            "model": "ViT-Base + hierarchical KL+MSE (proposed)",
            "backbone": "ViT-Base/16", "loss": "Hierarchical KL+MSE (λ=0.5)",
            "hierarchy": "Full (weights + KL)",
            "best_epoch": ckpt_info["epoch"],
            "best_val_loss": round(ckpt_info["val_loss"], 5),
            "mae_weighted" : round(proposed_metrics["mae/weighted_avg"], 5),
            "rmse_weighted": round(proposed_metrics["rmse/weighted_avg"], 5),
            "simplex_violation_mean": 0.0,
        }
        for q in QUESTION_GROUPS:
            row[f"mae_{q}"] = round(proposed_metrics[f"mae/{q}"], 5)
            row[f"rmse_{q}"] = round(proposed_metrics[f"rmse/{q}"], 5)
        all_results.append(row)
        log.info("Proposed: MAE=%.5f RMSE=%.5f",
                 proposed_metrics["mae/weighted_avg"],
                 proposed_metrics["rmse/weighted_avg"])

    # ─── Save results ──────────────────────────────────────────
    df = pd.DataFrame(all_results)
    df.to_csv(save_dir / "table_baseline_comparison.csv", index=False)

    summary_cols = ["model", "loss", "hierarchy", "best_epoch",
                    "best_val_loss", "mae_weighted", "rmse_weighted",
                    "simplex_violation_mean"]
    summary = df[[c for c in summary_cols if c in df.columns]].copy()
    summary.to_csv(save_dir / "table_baseline_summary.csv", index=False)

    print()
    print("=" * 80)
    print("BASELINE COMPARISON — FINAL RESULTS")
    print("=" * 80)
    print(summary.to_string(index=False))
    print()

    # ─── Figures ───────────────────────────────────────────────
    _save_comparison_figures(all_results, all_histories, save_dir)

    log.info("All baseline outputs saved to: %s", save_dir)
841
+
842
+
843
+ if __name__ == "__main__":
844
+ main()
src/dataset.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/dataset.py
3
+ --------------
4
+ Galaxy Zoo 2 dataset loader for hierarchical probabilistic regression.
5
+
6
+ The GZ2 decision tree has 11 questions (t01-t11) with 37 total answer
7
+ columns. Each question is a conditional probability vector β€” not
8
+ independent regression targets.
9
+
10
+ Hierarchy (parent answer -> child question):
11
+ t01_a02 (features/disk) -> t02, t03, t04, t05, t06
12
+ t02_a05 (not edge-on) -> t03, t04
13
+ t04_a08 (has spiral) -> t10, t11
14
+ t06_a14 (odd feature) -> t08
15
+ t01_a01 (smooth) -> t07
16
+ t02_a04 (edge-on) -> t09
17
+
18
+ References
19
+ ----------
20
+ Willett et al. (2013), MNRAS 435, 2835
21
+ Hart et al. (2016), MNRAS 461, 3663
22
+ """
23
+
24
+ import math
25
+ import logging
26
+ from pathlib import Path
27
+
28
+ import numpy as np
29
+ import pandas as pd
30
+ import torch
31
+ from torch.utils.data import Dataset, DataLoader
32
+ from torchvision import transforms
33
+ from PIL import Image
34
+ from omegaconf import DictConfig
35
+
36
+ log = logging.getLogger(__name__)
37
+
38
# ─────────────────────────────────────────────────────────────
# GZ2 decision tree definition
# ─────────────────────────────────────────────────────────────

# The 37 debiased vote-fraction columns, grouped by question t01–t11.
# ORDER MATTERS: QUESTION_GROUPS below slices into this list by index,
# so entries must stay grouped per question and in this exact order.
LABEL_COLUMNS = [
    # t01: smooth or features?
    "t01_smooth_or_features_a01_smooth_debiased",
    "t01_smooth_or_features_a02_features_or_disk_debiased",
    "t01_smooth_or_features_a03_star_or_artifact_debiased",
    # t02: edge-on?
    "t02_edgeon_a04_yes_debiased",
    "t02_edgeon_a05_no_debiased",
    # t03: bar?
    "t03_bar_a06_bar_debiased",
    "t03_bar_a07_no_bar_debiased",
    # t04: spiral?
    "t04_spiral_a08_spiral_debiased",
    "t04_spiral_a09_no_spiral_debiased",
    # t05: bulge prominence
    "t05_bulge_prominence_a10_no_bulge_debiased",
    "t05_bulge_prominence_a11_just_noticeable_debiased",
    "t05_bulge_prominence_a12_obvious_debiased",
    "t05_bulge_prominence_a13_dominant_debiased",
    # t06: odd feature?
    "t06_odd_a14_yes_debiased",
    "t06_odd_a15_no_debiased",
    # t07: roundedness (smooth galaxies)
    "t07_rounded_a16_completely_round_debiased",
    "t07_rounded_a17_in_between_debiased",
    "t07_rounded_a18_cigar_shaped_debiased",
    # t08: odd feature type
    "t08_odd_feature_a19_ring_debiased",
    "t08_odd_feature_a20_lens_or_arc_debiased",
    "t08_odd_feature_a21_disturbed_debiased",
    "t08_odd_feature_a22_irregular_debiased",
    "t08_odd_feature_a23_other_debiased",
    "t08_odd_feature_a24_merger_debiased",
    "t08_odd_feature_a38_dust_lane_debiased",
    # t09: bulge shape (edge-on only)
    "t09_bulge_shape_a25_rounded_debiased",
    "t09_bulge_shape_a26_boxy_debiased",
    "t09_bulge_shape_a27_no_bulge_debiased",
    # t10: arms winding
    "t10_arms_winding_a28_tight_debiased",
    "t10_arms_winding_a29_medium_debiased",
    "t10_arms_winding_a30_loose_debiased",
    # t11: arms number
    "t11_arms_number_a31_1_debiased",
    "t11_arms_number_a32_2_debiased",
    "t11_arms_number_a33_3_debiased",
    "t11_arms_number_a34_4_debiased",
    "t11_arms_number_a36_more_than_4_debiased",
    "t11_arms_number_a37_cant_tell_debiased",
]

# Slice indices into LABEL_COLUMNS for each question group.
# (start, end) half-open ranges; insertion order also defines the column
# order of the [11]-wide per-question weight matrix built by the dataset.
QUESTION_GROUPS = {
    "t01": (0, 3),
    "t02": (3, 5),
    "t03": (5, 7),
    "t04": (7, 9),
    "t05": (9, 13),
    "t06": (13, 15),
    "t07": (15, 18),
    "t08": (18, 25),
    "t09": (25, 28),
    "t10": (28, 31),
    "t11": (31, 37),
}

# Parent answer column for hierarchical branch weighting.
# w_q = vote fraction of the parent answer that unlocks question q.
# t01 is the root question; its weight is always 1.0 (parent is None).
QUESTION_PARENT_COL = {
    "t01": None,
    "t02": "t01_smooth_or_features_a02_features_or_disk_debiased",
    "t03": "t02_edgeon_a05_no_debiased",
    "t04": "t02_edgeon_a05_no_debiased",
    "t05": "t01_smooth_or_features_a02_features_or_disk_debiased",
    "t06": "t01_smooth_or_features_a02_features_or_disk_debiased",
    "t07": "t01_smooth_or_features_a01_smooth_debiased",
    "t08": "t06_odd_a14_yes_debiased",
    "t09": "t02_edgeon_a04_yes_debiased",
    "t10": "t04_spiral_a08_spiral_debiased",
    "t11": "t04_spiral_a08_spiral_debiased",
}

# Total number of answer columns across all 11 questions.
N_LABELS = len(LABEL_COLUMNS)  # 37
126
+
127
+
128
+ # ─────────────────────────────────────────────────────────────
129
+ # Image transforms
130
+ # ─────────────────────────────────────────────────────────────
131
+
132
def get_transforms(image_size: int, split: str) -> transforms.Compose:
    """Return the torchvision preprocessing pipeline for one split.

    Training adds orientation-invariant augmentation (galaxies have no
    preferred orientation) plus mild colour jitter for instrument
    variation; every split ends with ImageNet normalisation.

    Parameters
    ----------
    split : "train" for the augmented pipeline; anything else gets the
            deterministic resize-only pipeline.
    """
    # ImageNet statistics; Normalize is stateless so one instance is shared.
    normalise = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    if split != "train":
        # Deterministic val / test pipeline.
        return transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            normalise,
        ])

    # Slightly over-resize so RandomCrop can translate the galaxy a little.
    return transforms.Compose([
        transforms.Resize((image_size + 16, image_size + 16)),
        transforms.RandomCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(180),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
        transforms.ToTensor(),
        normalise,
    ])
158
+
159
+
160
+ # ─────────────────────────────────────────────────────────────
161
+ # Dataset
162
+ # ─────────────────────────────────────────────────────────────
163
+
164
class GalaxyZoo2Dataset(Dataset):
    """PyTorch Dataset for Galaxy Zoo 2 vote-fraction regression.

    Each item is a 4-tuple:
        image    : FloatTensor [3, H, W] — normalised galaxy image
        targets  : FloatTensor [37]      — vote-fraction vector
        weights  : FloatTensor [11]      — per-question hierarchical weights
        image_id : int                   — dr7objid for traceability
    """

    def __init__(self, df: pd.DataFrame, image_dir: str, transform):
        self.df = df.reset_index(drop=True)
        self.image_dir = Path(image_dir)
        self.transform = transform

        # Pre-extract labels and weights once so __getitem__ is cheap.
        self.labels = self.df[LABEL_COLUMNS].values.astype(np.float32)
        self.weights = self._compute_weights()
        self.image_ids = self.df["dr7objid"].tolist()

    def _compute_weights(self) -> np.ndarray:
        """Return [n, 11] weights: parent-answer vote fraction per question."""
        question_names = list(QUESTION_GROUPS)
        w = np.ones((len(self.df), len(question_names)), dtype=np.float32)
        for col, q_name in enumerate(question_names):
            parent = QUESTION_PARENT_COL[q_name]
            if parent is None:
                continue  # root question (t01) keeps weight 1.0
            w[:, col] = self.df[parent].values.astype(np.float32)
        return w

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        image_id = self.image_ids[idx]
        path = self.image_dir / f"{image_id}.jpg"
        try:
            img = Image.open(path).convert("RGB")
        except FileNotFoundError:
            # Re-raise with a message that names the offending object id.
            raise FileNotFoundError(
                f"Image not found: {path}. "
                f"Check dr7objid {image_id} has a matching .jpg file."
            )
        image = self.transform(img)
        return (image,
                torch.from_numpy(self.labels[idx]),
                torch.from_numpy(self.weights[idx]),
                image_id)
212
+
213
+
214
+ # ─────────────────────────────────────────────────────────────
215
+ # DataLoader factory
216
+ # ─────────────────────────────────────────────────────────────
217
+
218
def build_dataloaders(cfg: DictConfig):
    """Build train / val / test DataLoaders from the labels parquet.

    Parameters
    ----------
    cfg : DictConfig
        Needs data.{parquet_path, image_dir, image_size, n_samples,
        train_frac, val_frac, num_workers, pin_memory},
        training.batch_size and seed. data.persistent_workers /
        data.prefetch_factor are optional tuning knobs.

    Returns
    -------
    (train_loader, val_loader, test_loader)

    Raises
    ------
    ValueError
        If any expected GZ2 answer column is missing from the parquet.
    """
    log.info("Loading parquet: %s", cfg.data.parquet_path)
    df = pd.read_parquet(cfg.data.parquet_path)

    # Fail fast if the parquet lacks any of the 37 answer columns.
    missing = [c for c in LABEL_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"Missing columns in parquet: {missing}")

    if cfg.data.n_samples is not None:
        # Clamp to the dataset size: df.sample(n > len(df)) raises ValueError.
        n = min(int(cfg.data.n_samples), len(df))
        log.info("Using subset of %d samples (full dataset: %d)", n, len(df))
        df = df.sample(n=n, random_state=cfg.seed).reset_index(drop=True)
    else:
        log.info("Using full dataset: %d samples", len(df))

    # Deterministic (seeded) shuffled split into train / val / test.
    rng = np.random.default_rng(cfg.seed)
    idx = rng.permutation(len(df))
    n = len(df)
    n_train = math.floor(cfg.data.train_frac * n)
    n_val = math.floor(cfg.data.val_frac * n)

    train_idx = idx[:n_train]
    val_idx = idx[n_train : n_train + n_val]
    test_idx = idx[n_train + n_val :]

    log.info("Split — train: %d val: %d test: %d",
             len(train_idx), len(val_idx), len(test_idx))

    image_size = cfg.data.image_size
    train_ds = GalaxyZoo2Dataset(
        df.iloc[train_idx], cfg.data.image_dir,
        get_transforms(image_size, "train"))
    val_ds = GalaxyZoo2Dataset(
        df.iloc[val_idx], cfg.data.image_dir,
        get_transforms(image_size, "val"))
    test_ds = GalaxyZoo2Dataset(
        df.iloc[test_idx], cfg.data.image_dir,
        get_transforms(image_size, "test"))

    common = dict(
        batch_size  = cfg.training.batch_size,
        num_workers = cfg.data.num_workers,
        pin_memory  = cfg.data.pin_memory,
        drop_last   = False,
    )
    # BUGFIX: persistent_workers / prefetch_factor are only valid with
    # worker processes — passing them when num_workers == 0 makes
    # DataLoader raise ValueError. Only forward them when workers exist.
    if common["num_workers"] > 0:
        common["persistent_workers"] = getattr(cfg.data, "persistent_workers", True)
        common["prefetch_factor"]    = getattr(cfg.data, "prefetch_factor", 4)

    train_loader = DataLoader(train_ds, shuffle=True, **common)
    val_loader   = DataLoader(val_ds,   shuffle=False, **common)
    test_loader  = DataLoader(test_ds,  shuffle=False, **common)

    return train_loader, val_loader, test_loader
src/evaluate_full.py ADDED
@@ -0,0 +1,619 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/evaluate_full.py
3
+ --------------------
4
+ Full evaluation of all trained models on the held-out test set.
5
+
6
+ Generates all paper figures and tables:
7
+
8
+ Tables
9
+ ------
10
+ table_metrics_proposed.csv β€” MAE / RMSE / bias / ECE for proposed model
11
+ table_reached_branch_mae.csv β€” reached-branch MAE across all 5 models
12
+ table_simplex_violation.csv β€” simplex validity for sigmoid baseline
13
+
14
+ Figures (PDF + PNG, IEEE naming convention)
15
+ -------------------------------------------
16
+ fig_scatter_predicted_vs_true.pdf β€” predicted vs true vote fractions (proposed)
17
+ fig_calibration_reliability.pdf β€” reliability diagrams, all models
18
+ fig_ece_comparison.pdf β€” ECE bar chart, all models
19
+ fig_attention_rollout_gallery.pdf β€” full 12-layer attention rollout gallery
20
+ fig_attention_entropy_depth.pdf β€” CLS attention entropy vs. layer depth
21
+
22
+ Usage
23
+ -----
24
+ cd ~/galaxy
25
+ nohup python -m src.evaluate_full --config configs/full_train.yaml \
26
+ > outputs/logs/evaluate.log 2>&1 &
27
+ echo "PID: $!"
28
+ """
29
+
30
+ import argparse
31
+ import logging
32
+ import sys
33
+ from pathlib import Path
34
+
35
+ import numpy as np
36
+ import pandas as pd
37
+ import torch
38
+ import torch.nn.functional as F
39
+ import matplotlib
40
+ matplotlib.use("Agg")
41
+ import matplotlib.pyplot as plt
42
+ from torch.amp import autocast
43
+ from omegaconf import OmegaConf
44
+ from tqdm import tqdm
45
+
46
+ from src.dataset import build_dataloaders, QUESTION_GROUPS
47
+ from src.model import build_model, build_dirichlet_model
48
+ from src.metrics import (compute_metrics, predictions_to_numpy,
49
+ compute_reached_branch_mae_table,
50
+ dirichlet_predictions_to_numpy,
51
+ simplex_violation_rate, _compute_ece)
52
+ from src.attention_viz import plot_attention_grid, plot_attention_entropy
53
+ from src.baselines import ResNet18Baseline
54
+
55
# Root logger config: timestamped lines to stdout so nohup logs are readable.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
    datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
)
log = logging.getLogger("evaluate_full")

# ── Global matplotlib style ────────────────────────────────────────────────────
# Applied once at import time; every figure in this module inherits it.
plt.rcParams.update({
    "figure.dpi"       : 150,
    "savefig.dpi"      : 300,
    "font.family"      : "serif",
    "font.size"        : 11,
    "axes.titlesize"   : 11,
    "axes.labelsize"   : 11,
    "xtick.labelsize"  : 9,
    "ytick.labelsize"  : 9,
    "legend.fontsize"  : 9,
    "figure.facecolor" : "white",
    "axes.facecolor"   : "white",
    "axes.grid"        : True,
    "grid.alpha"       : 0.3,
    "pdf.fonttype"     : 42,  # editable (TrueType) text in PDF
    "ps.fonttype"      : 42,
})

# Human-readable description of each GZ2 question, used in plot titles.
QUESTION_LABELS = {
    "t01": "Smooth or features",
    "t02": "Edge-on disk",
    "t03": "Bar",
    "t04": "Spiral arms",
    "t05": "Bulge prominence",
    "t06": "Odd feature",
    "t07": "Roundedness",
    "t08": "Odd feature type",
    "t09": "Bulge shape",
    "t10": "Arms winding",
    "t11": "Arms number",
}

# Consistent colours and line styles for all models across all figures.
# Keys must match the model names used in main()'s model_results dict.
MODEL_COLORS = {
    "ResNet-18 + MSE (sigmoid)"          : "#c0392b",
    "ResNet-18 + KL+MSE"                 : "#e67e22",
    "ViT-Base + MSE only"                : "#2980b9",
    "ViT-Base + KL+MSE (proposed)"       : "#27ae60",
    "ViT-Base + Dirichlet (Zoobot-style)": "#8e44ad",
}
MODEL_STYLES = {
    "ResNet-18 + MSE (sigmoid)"          : "-",
    "ResNet-18 + KL+MSE"                 : "-.",
    "ViT-Base + MSE only"                : "--",
    "ViT-Base + KL+MSE (proposed)"       : "-",
    "ViT-Base + Dirichlet (Zoobot-style)": ":",
}
109
+
110
+
111
+ # ─────────────────────────────────────────────────────────────
112
+ # Inference helpers
113
+ # ─────────────────────────────────────────────────────────────
114
+
115
def _infer_vit(model, loader, device, cfg,
               collect_attn=True, n_attn=16):
    """Run a ViT model over *loader*; optionally capture attention maps.

    Returns
    -------
    preds, targets, weights : np.ndarray — stacked over all batches
    attn_imgs_t   : Tensor or None — first n_attn input images (CPU)
    merged_layers : list[Tensor] or None — per-layer attention for those
                    images, merged across the batches that supplied them
    attn_ids      : list[int] — dr7objids of the captured images
    """
    model.eval()
    all_preds, all_targets, all_weights = [], [], []
    attn_images, all_layer_attns, attn_ids = [], [], []
    attn_done = False  # stop collecting once n_attn images are captured

    with torch.no_grad():
        for images, targets, weights, image_ids in tqdm(loader, desc="ViT inference"):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)
            with autocast("cuda", enabled=cfg.training.mixed_precision):
                logits = model(images)
            p, t, w = predictions_to_numpy(logits, targets, weights)
            all_preds.append(p)
            all_targets.append(t)
            all_weights.append(w)

            if collect_attn and not attn_done:
                # NOTE(review): assumes the model caches attention from the
                # forward pass just executed — confirm in src.model.
                layers = model.get_all_attention_weights()
                if layers is not None:
                    n = min(n_attn, images.shape[0])
                    attn_images.append(images[:n].cpu())
                    all_layer_attns.append([l[:n].cpu() for l in layers])
                    attn_ids.extend([int(i) for i in image_ids[:n]])
                    if len(attn_ids) >= n_attn:
                        attn_done = True

    preds = np.concatenate(all_preds)
    targets = np.concatenate(all_targets)
    weights = np.concatenate(all_weights)

    # Merge per-batch attention captures, trimming to exactly n_attn images.
    attn_imgs_t = torch.cat(attn_images, dim=0)[:n_attn] if attn_images else None
    merged_layers = None
    if all_layer_attns:
        merged_layers = [
            torch.cat([b[li] for b in all_layer_attns], dim=0)[:n_attn]
            for li in range(len(all_layer_attns[0]))
        ]

    return preds, targets, weights, attn_imgs_t, merged_layers, attn_ids
157
+
158
+
159
def _infer_resnet(model, loader, device, cfg, use_sigmoid: bool):
    """Run a ResNet baseline over *loader*; return stacked numpy arrays.

    use_sigmoid=True  -> independent per-answer sigmoid (MSE baseline,
                         outputs need not form simplices).
    use_sigmoid=False -> per-question softmax over each answer group.

    Returns (preds, targets, weights) as numpy arrays.
    """
    model.eval()
    preds_acc, targets_acc, weights_acc = [], [], []

    def _to_probs(raw):
        # Convert raw logits to per-answer probabilities on the CPU.
        if use_sigmoid:
            return torch.sigmoid(raw).cpu().numpy()
        out = raw.detach().cpu().clone()
        for _q, (lo, hi) in QUESTION_GROUPS.items():
            out[:, lo:hi] = F.softmax(out[:, lo:hi], dim=-1)
        return out.numpy()

    with torch.no_grad():
        for images, targets, weights, _ in tqdm(loader, desc="ResNet inference"):
            images = images.to(device, non_blocking=True)
            with autocast("cuda", enabled=cfg.training.mixed_precision):
                logits = model(images)
            preds_acc.append(_to_probs(logits))
            targets_acc.append(targets.numpy())
            weights_acc.append(weights.numpy())

    return (np.concatenate(preds_acc),
            np.concatenate(targets_acc),
            np.concatenate(weights_acc))
180
+
181
+
182
def _infer_dirichlet(model, loader, device, cfg):
    """Run the Dirichlet (Zoobot-style) model over *loader*.

    Concentration parameters from the model are converted to point
    predictions by dirichlet_predictions_to_numpy.
    Returns (preds, targets, weights) as numpy arrays.
    """
    model.eval()
    preds_acc, targets_acc, weights_acc = [], [], []

    with torch.no_grad():
        for images, targets, weights, _ in tqdm(loader, desc="Dirichlet inference"):
            images = images.to(device, non_blocking=True)
            with autocast("cuda", enabled=cfg.training.mixed_precision):
                alpha = model(images)
            batch = dirichlet_predictions_to_numpy(alpha, targets, weights)
            for acc, arr in zip((preds_acc, targets_acc, weights_acc), batch):
                acc.append(arr)

    return (np.concatenate(preds_acc),
            np.concatenate(targets_acc),
            np.concatenate(weights_acc))
197
+
198
+
199
+ # ─────────────────────────────────────────────────────────────
200
+ # Figure 1: Predicted vs true scatter (proposed model)
201
+ # ─────────────────────────────────────────────────────────────
202
+
203
def fig_scatter_predicted_vs_true(preds, targets, weights, save_dir):
    """3×4 grid of predicted-vs-true vote-fraction scatter plots, one per
    question, for the proposed model on reached branches (w ≥ 0.05).
    Saves both PDF and PNG; skips work if both already exist.
    """
    path_pdf = save_dir / "fig_scatter_predicted_vs_true.pdf"
    path_png = save_dir / "fig_scatter_predicted_vs_true.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_scatter_predicted_vs_true"); return

    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    axes = axes.flatten()

    for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
        ax = axes[q_idx]
        # Only samples where the question's branch was actually reached.
        mask = weights[:, q_idx] >= 0.05
        pq = preds[mask, start:end].flatten()
        tq = targets[mask, start:end].flatten()

        # rasterized=True keeps the PDF small despite many scatter points.
        ax.scatter(tq, pq, alpha=0.06, s=1, color="#2563eb", rasterized=True)
        ax.plot([0, 1], [0, 1], "r--", linewidth=1, alpha=0.8)
        ax.set_xlim(0, 1); ax.set_ylim(0, 1)
        ax.set_xlabel("True vote fraction")
        ax.set_ylabel("Predicted vote fraction")
        ax.set_title(
            f"{q_name}: {QUESTION_LABELS[q_name]}\n"
            f"$n$ = {mask.sum():,} (w ≥ 0.05)",
            fontsize=9,
        )
        ax.set_aspect("equal")
        mae = np.abs(pq - tq).mean()
        ax.text(0.05, 0.92, f"MAE = {mae:.3f}",
                transform=ax.transAxes, fontsize=8,
                bbox=dict(boxstyle="round,pad=0.2", facecolor="white",
                          edgecolor="grey", alpha=0.85))

    # 11 questions in a 12-cell grid — blank the unused last axis.
    axes[-1].axis("off")
    plt.suptitle(
        "Predicted vs. true vote fractions — reached branches (w ≥ 0.05)\n"
        "ViT-Base/16 + hierarchical KL+MSE (proposed model, test set)",
        fontsize=12,
    )
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_scatter_predicted_vs_true")
246
+
247
+
248
+ # ─────────────────────────────────────────────────────────────
249
+ # Figure 2: Calibration reliability diagrams
250
+ # ─────────────────────────────────────────────────────────────
251
+
252
def fig_calibration_reliability(model_results, save_dir, n_bins=15):
    """Reliability diagrams (mean predicted vs. mean true per bin) for all
    models on 8 representative questions, using adaptive equal-frequency
    bins so the binning matches the ECE computation.

    model_results maps model name -> (preds, targets, weights) arrays.
    """
    path_pdf = save_dir / "fig_calibration_reliability.pdf"
    path_png = save_dir / "fig_calibration_reliability.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_calibration_reliability"); return

    # Show 8 representative questions (skip t02 — bimodal, shown separately)
    q_show = ["t01", "t03", "t04", "t06", "t07", "t09", "t10", "t11"]
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    axes = axes.flatten()

    for ax_idx, q_name in enumerate(q_show):
        ax = axes[ax_idx]
        start, end = QUESTION_GROUPS[q_name]
        q_idx = list(QUESTION_GROUPS.keys()).index(q_name)

        for model_name, (preds, targets, weights) in model_results.items():
            mask = weights[:, q_idx] >= 0.05
            if mask.sum() < 50:
                # Too few reached samples for a meaningful curve.
                continue
            pf = preds[mask, start:end].flatten()
            tf = targets[mask, start:end].flatten()

            # Adaptive bins (equal-frequency) — consistent with ECE computation
            percentiles = np.linspace(0, 100, n_bins + 1)
            bin_edges = np.unique(np.percentile(pf, percentiles))
            if len(bin_edges) < 2:
                continue
            # Interior edges only; clip keeps every point in a valid bin.
            bin_ids = np.clip(
                np.digitize(pf, bin_edges[1:-1]), 0, len(bin_edges) - 2
            )
            mp = np.array([
                pf[bin_ids == b].mean() if (bin_ids == b).any() else np.nan
                for b in range(len(bin_edges) - 1)
            ])
            mt = np.array([
                tf[bin_ids == b].mean() if (bin_ids == b).any() else np.nan
                for b in range(len(bin_edges) - 1)
            ])
            valid = ~np.isnan(mp) & ~np.isnan(mt)
            ax.plot(
                mp[valid], mt[valid],
                MODEL_STYLES.get(model_name, "-"),
                color=MODEL_COLORS.get(model_name, "#888888"),
                linewidth=1.8, marker="o", markersize=3.5,
                label=model_name, alpha=0.9,
            )

        ax.plot([0, 1], [0, 1], "k--", linewidth=1, alpha=0.5, label="Perfect")
        ax.set_xlim(0, 1); ax.set_ylim(0, 1)
        ax.set_xlabel("Mean predicted", fontsize=8)
        ax.set_ylabel("Mean true", fontsize=8)
        ax.set_title(f"{q_name}: {QUESTION_LABELS[q_name]}", fontsize=9)
        ax.set_aspect("equal")
        if ax_idx == 0:
            # Legend once, on the first panel only.
            ax.legend(fontsize=6.5, loc="upper left")

    plt.suptitle(
        "Calibration reliability diagrams — all models (test set)\n"
        "Reached branches only (w ≥ 0.05). Adaptive equal-frequency bins. "
        "Closer to diagonal = better calibrated.",
        fontsize=11,
    )
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_calibration_reliability")
320
+
321
+
322
+ # ─────────────────────────────────────────────────────────────
323
+ # Figure 3: ECE bar chart
324
+ # ─────────────────────────────────────────────────────────────
325
+
326
def fig_ece_comparison(model_results, save_dir):
    """Grouped bar chart of per-question ECE for every model, plus a CSV
    (table_ece_comparison.csv) with the same numbers. Lower is better.
    """
    path_pdf = save_dir / "fig_ece_comparison.pdf"
    path_png = save_dir / "fig_ece_comparison.png"
    if path_pdf.exists() and path_png.exists():
        log.info("Skip (exists): fig_ece_comparison"); return

    q_names = list(QUESTION_GROUPS.keys())
    ece_rows = []
    for model_name, (preds, targets, weights) in model_results.items():
        row = {"model": model_name}
        for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
            mask = weights[:, q_idx] >= 0.05
            if mask.sum() < 50:
                # Too few reached samples — ECE undefined for this question.
                row[q_name] = float("nan")
            else:
                row[q_name] = _compute_ece(
                    preds[mask, start:end].flatten(),
                    targets[mask, start:end].flatten(),
                    n_bins=15,
                )
        # nanmean so questions with undefined ECE don't poison the average.
        row["mean_ece"] = float(
            np.nanmean([row[q] for q in q_names])
        )
        ece_rows.append(row)

    df_ece = pd.DataFrame(ece_rows)
    df_ece.to_csv(save_dir / "table_ece_comparison.csv", index=False)

    x = np.arange(len(q_names))
    width = 0.80 / len(model_results)  # bars share 80% of each slot
    palette = list(MODEL_COLORS.values())

    fig, ax = plt.subplots(figsize=(14, 5))
    for i, (model_name, _) in enumerate(model_results.items()):
        vals = [
            float(df_ece[df_ece["model"] == model_name][q].values[0])
            for q in q_names
        ]
        ax.bar(
            x + i * width, vals, width,
            label=model_name,
            color=MODEL_COLORS.get(model_name, palette[i % len(palette)]),
            alpha=0.85, edgecolor="white", linewidth=0.5,
        )

    # Centre the tick under each group of bars.
    ax.set_xticks(x + width * (len(model_results) - 1) / 2)
    ax.set_xticklabels(
        [f"{q}\n({QUESTION_LABELS[q][:12]})" for q in q_names],
        rotation=30, ha="right", fontsize=8,
    )
    ax.set_ylabel("Expected Calibration Error (ECE)", fontsize=11)
    ax.set_title(
        "Expected Calibration Error — all models (test set)\n"
        "Reached branches (w ≥ 0.05). Adaptive equal-frequency binning. "
        "Lower is better.",
        fontsize=11,
    )
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3, axis="y")
    ax.set_axisbelow(True)
    plt.tight_layout()
    fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
    fig.savefig(path_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    log.info("Saved: fig_ece_comparison")
391
+
392
+
393
+ # ─────────────────────────────────────────────────────────────
394
+ # Figure 4: Attention rollout gallery
395
+ # ─────────────────────────────────────────────────────────────
396
+
397
def fig_attention_rollout_gallery(attn_imgs, all_layers, attn_ids, save_dir):
    """Render the attention-rollout gallery as PDF + PNG, plus a 600-dpi
    PNG for journal submission. Existing outputs are not regenerated.
    """
    if attn_imgs is None or all_layers is None:
        log.warning("No attention data — skipping gallery."); return

    pdf_path = save_dir / "fig_attention_rollout_gallery.pdf"
    png_path = save_dir / "fig_attention_rollout_gallery.png"
    hq_path = save_dir / "fig_attention_rollout_gallery_HQ.png"

    if not pdf_path.exists():
        # plot_attention_grid writes the PNG itself via save_path.
        gallery = plot_attention_grid(
            attn_imgs, all_layers, attn_ids,
            save_path=str(png_path),
            n_cols=4, rollout_mode="full",
        )
        gallery.savefig(pdf_path, dpi=300, bbox_inches="tight", facecolor="black")
        plt.close(gallery)
        log.info("Saved: fig_attention_rollout_gallery")

    # High-resolution PNG for journal submission.
    if not hq_path.exists():
        gallery_hq = plot_attention_grid(
            attn_imgs, all_layers, attn_ids,
            n_cols=4, rollout_mode="full",
        )
        gallery_hq.savefig(hq_path, dpi=600, bbox_inches="tight", facecolor="black")
        plt.close(gallery_hq)
        log.info("Saved: fig_attention_rollout_gallery_HQ (600 dpi)")
424
+
425
+
426
+ # ─────────────────────────────────────────────────────────────
427
+ # Figure 5: Attention entropy vs. depth
428
+ # ─────────────────────────────────────────────────────────────
429
+
430
def fig_attention_entropy_depth(all_layers, save_dir):
    """Save the CLS-attention-entropy-vs-depth figure (PDF + PNG)."""
    if all_layers is None:
        log.warning("No attention layers — skipping entropy plot."); return

    pdf_path = save_dir / "fig_attention_entropy_depth.pdf"
    png_path = save_dir / "fig_attention_entropy_depth.png"
    if pdf_path.exists() and png_path.exists():
        log.info("Skip (exists): fig_attention_entropy_depth")
        return

    # plot_attention_entropy writes the PNG itself; we add the PDF copy.
    entropy_fig = plot_attention_entropy(all_layers, save_path=str(png_path))
    entropy_fig.savefig(pdf_path, dpi=300, bbox_inches="tight")
    plt.close(entropy_fig)
    log.info("Saved: fig_attention_entropy_depth")
443
+
444
+
445
+ # ─────────────────────────────────────────────────────────────
446
+ # Table: metrics for proposed model
447
+ # ─────────────────────────────────────────────────────────────
448
+
449
def table_metrics_proposed(preds, targets, weights, save_dir):
    """Write per-question MAE / RMSE / bias / ECE for the proposed model
    to table_metrics_proposed.csv, log the table, and return the raw
    metrics dict from compute_metrics.
    """
    metrics = compute_metrics(preds, targets, weights)

    # One row per question, then a weighted-average summary row.
    rows = [
        {
            "question"   : q,
            "description": QUESTION_LABELS[q],
            "MAE"        : round(metrics[f"mae/{q}"], 5),
            "RMSE"       : round(metrics[f"rmse/{q}"], 5),
            "bias"       : round(metrics[f"bias/{q}"], 5),
            "ECE"        : round(metrics[f"ece/{q}"], 5),
        }
        for q in QUESTION_GROUPS
    ]
    rows.append({
        "question": "weighted_avg", "description": "Weighted average",
        "MAE" : round(metrics["mae/weighted_avg"], 5),
        "RMSE": round(metrics["rmse/weighted_avg"], 5),
        "bias": "",  # no meaningful weighted-average bias
        "ECE" : round(metrics["ece/mean"], 5),
    })

    table = pd.DataFrame(rows)
    table.to_csv(save_dir / "table_metrics_proposed.csv", index=False)
    log.info("\n%s\n", table.to_string(index=False))
    return metrics
472
+
473
+
474
# ─────────────────────────────────────────────────────────────
475
+ # Table: simplex violation for sigmoid baseline
476
+ # ─────────────────────────────────────────────────────────────
477
+
478
def table_simplex_violation(model_results, save_dir):
    """
    For each model, report the fraction of test samples whose per-question
    predictions do not sum to 1 ± 0.02. Softmax models should be ~0; the
    sigmoid baseline is expected to be nonzero. This table explains why
    the sigmoid baseline can reach lower raw per-answer MAE while being
    scientifically invalid: unconstrained sigmoid outputs fit each
    marginal independently.
    """
    rows = []
    for name, (preds, _, _) in model_results.items():
        svr = simplex_violation_rate(preds, tolerance=0.02)
        entry = {"model": name}
        for q in QUESTION_GROUPS:
            entry[q] = round(svr[q], 4)
        entry["mean"] = round(svr["mean"], 4)
        rows.append(entry)

    table = pd.DataFrame(rows)
    table.to_csv(save_dir / "table_simplex_violation.csv", index=False)
    log.info("Saved: table_simplex_violation.csv")
    log.info("\n%s\n", table[["model", "mean"]].to_string(index=False))
    return table
498
+
499
+
500
+ # ─────────────────────────────────────────────────────────────
501
+ # Main
502
+ # ─────────────────────────────────────────────────────────────
503
+
504
def main():
    """Load all trained checkpoints, run test-set inference, and emit
    every paper table and figure into <figures_dir>/evaluation.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()

    # Experiment config overrides the base config.
    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    save_dir = Path(cfg.outputs.figures_dir) / "evaluation"
    save_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir = Path(cfg.outputs.checkpoint_dir)

    # Only the test split is needed here.
    _, _, test_loader = build_dataloaders(cfg)

    # ── Load all models ────────────────────────────────────────
    log.info("Loading models from: %s", ckpt_dir)

    def _load(path, model):
        # weights_only=True: safe load, no arbitrary pickled objects.
        ckpt = torch.load(path, map_location="cpu", weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        return model

    vit_proposed = _load(
        ckpt_dir / "best_full_train.pt", build_model(cfg)
    ).to(device)

    vit_mse = _load(
        ckpt_dir / "baseline_vit_mse.pt", build_model(cfg)
    ).to(device)

    rn_mse = _load(
        ckpt_dir / "baseline_resnet18_mse.pt",
        ResNet18Baseline(dropout=cfg.model.dropout)
    ).to(device)

    rn_kl = _load(
        ckpt_dir / "baseline_resnet18_klmse.pt",
        ResNet18Baseline(dropout=cfg.model.dropout)
    ).to(device)

    # Dirichlet baseline is optional — only loaded if its checkpoint exists.
    vit_dirichlet = None
    dp = ckpt_dir / "baseline_vit_dirichlet.pt"
    if dp.exists():
        vit_dirichlet = _load(dp, build_dirichlet_model(cfg)).to(device)
        log.info("Loaded: ViT-Base + Dirichlet")

    # ── Run inference ──────────────────────────────────────────
    log.info("Running inference on test set...")

    # Attention maps are captured only for the proposed model.
    (p_proposed, t_proposed, w_proposed,
     attn_imgs, all_layers, attn_ids) = _infer_vit(
        vit_proposed, test_loader, device, cfg,
        collect_attn=True, n_attn=16,
    )

    # [:3] drops the (unused) attention outputs of the second ViT run.
    p_vit_mse, t_vit_mse, w_vit_mse = _infer_vit(
        vit_mse, test_loader, device, cfg, collect_attn=False
    )[:3]

    p_rn_mse, t_rn_mse, w_rn_mse = _infer_resnet(
        rn_mse, test_loader, device, cfg, use_sigmoid=True
    )

    p_rn_kl, t_rn_kl, w_rn_kl = _infer_resnet(
        rn_kl, test_loader, device, cfg, use_sigmoid=False
    )

    # Build model_results dict (order determines legend order in figures)
    model_results = {
        "ResNet-18 + MSE (sigmoid)"    : (p_rn_mse, t_rn_mse, w_rn_mse),
        "ResNet-18 + KL+MSE"           : (p_rn_kl, t_rn_kl, w_rn_kl),
        "ViT-Base + MSE only"          : (p_vit_mse, t_vit_mse, w_vit_mse),
        "ViT-Base + KL+MSE (proposed)" : (p_proposed, t_proposed, w_proposed),
    }

    if vit_dirichlet is not None:
        p_dir, t_dir, w_dir = _infer_dirichlet(
            vit_dirichlet, test_loader, device, cfg
        )
        model_results["ViT-Base + Dirichlet (Zoobot-style)"] = (p_dir, t_dir, w_dir)

    # ── Tables ─────────────────────────────────────────────────
    log.info("Computing metrics...")
    table_metrics_proposed(p_proposed, t_proposed, w_proposed, save_dir)

    log.info("Computing reached-branch MAE table...")
    df_r = compute_reached_branch_mae_table(model_results)
    df_r.to_csv(save_dir / "table_reached_branch_mae.csv", index=False)
    log.info("Saved: table_reached_branch_mae.csv")

    log.info("Computing simplex violation table...")
    table_simplex_violation(model_results, save_dir)

    # ── Figures ────────────────────────────────────────────────
    log.info("Generating figures...")
    fig_scatter_predicted_vs_true(p_proposed, t_proposed, w_proposed, save_dir)
    fig_calibration_reliability(model_results, save_dir)
    fig_ece_comparison(model_results, save_dir)
    fig_attention_rollout_gallery(attn_imgs, all_layers, attn_ids, save_dir)
    fig_attention_entropy_depth(all_layers, save_dir)

    log.info("=" * 60)
    log.info("ALL OUTPUTS SAVED TO: %s", save_dir)
    log.info("=" * 60)

    # Headline numbers for the proposed model, logged last for convenience.
    metrics = compute_metrics(p_proposed, t_proposed, w_proposed)
    log.info("Proposed model — test set results:")
    log.info("  Weighted MAE  = %.5f", metrics["mae/weighted_avg"])
    log.info("  Weighted RMSE = %.5f", metrics["rmse/weighted_avg"])
    log.info("  Mean ECE      = %.5f", metrics["ece/mean"])


if __name__ == "__main__":
    main()
src/loss.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/loss.py
3
+ -----------
4
+ Loss functions for hierarchical probabilistic vote-fraction regression.
5
+
6
+ Three losses are implemented:
7
+
8
+ 1. HierarchicalLoss β€” proposed method: weighted KL + MSE per question.
9
+ 2. DirichletLoss β€” Zoobot-style comparison: weighted Dirichlet NLL.
10
+ 3. MSEOnlyLoss β€” ablation baseline: hierarchical MSE, no KL term.
11
+
12
+ Both main losses use identical per-sample hierarchical weighting:
13
+ w_q = parent branch vote fraction (1.0 for root question t01)
14
+
15
+ Mathematical formulation
16
+ ------------------------
17
+ HierarchicalLoss per question q:
18
+ L_q = w_q * [ Ξ»_kl * KL(p_q || Ε·_q) + Ξ»_mse * MSE(Ε·_q, p_q) ]
19
+
20
+ where p_q = ground-truth vote fractions [B, A_q]
21
+ Ε·_q = softmax(logits_q) [B, A_q]
22
+ w_q = hierarchical weight [B]
23
+
24
+ DirichletLoss per question q:
25
+ L_q = w_q * [ log B(Ξ±_q) βˆ’ Ξ£_a (Ξ±_qa βˆ’ 1) log(p_qa) ]
26
+
27
+ where Ξ±_q = 1 + softplus(logits_q) > 1 [B, A_q]
28
+
29
+ References
30
+ ----------
31
+ Walmsley et al. (2022), MNRAS 509, 3966 (Zoobot β€” Dirichlet approach)
32
+ """
33
+
34
+ import torch
35
+ import torch.nn as nn
36
+ import torch.nn.functional as F
37
+ from omegaconf import DictConfig
38
+ from src.dataset import QUESTION_GROUPS
39
+
40
+
41
class HierarchicalLoss(nn.Module):
    """Proposed loss: weighted per-question KL divergence plus MSE.

    For each of the 11 GZ2 question groups the logits are normalised with a
    softmax, then combined per sample as

        L_q = w_q * (lambda_kl * KL(p_q || yhat_q) + lambda_mse * MSE(yhat_q, p_q))

    where w_q is the hierarchical weight (parent branch vote fraction, 1.0
    for the root question). Per-question batch means are summed into the
    total loss.
    """

    def __init__(self, cfg: "DictConfig"):
        super().__init__()
        # Mixing coefficients for the two terms and the clamp floor used to
        # keep the logs inside the KL term finite.
        self.lambda_kl = float(cfg.loss.lambda_kl)
        self.lambda_mse = float(cfg.loss.lambda_mse)
        self.epsilon = float(cfg.loss.epsilon)
        # Flattened (name, start, end) column slices, one per question group.
        self.question_slices = [(q, s, e) for q, (s, e) in QUESTION_GROUPS.items()]

    def forward(self, predictions: torch.Tensor,
                targets: torch.Tensor, weights: torch.Tensor):
        """Return (total_loss, per-question loss dict).

        predictions : [B, 37] raw logits
        targets     : [B, 37] ground-truth vote fractions
        weights     : [B, 11] hierarchical weight per question
        """
        total = torch.zeros(1, device=predictions.device, dtype=predictions.dtype)
        parts = {}

        for idx, (name, lo, hi) in enumerate(self.question_slices):
            probs = F.softmax(predictions[:, lo:hi], dim=-1)
            truth = targets[:, lo:hi]
            w = weights[:, idx]

            # Clamp both distributions before taking logs so KL stays finite.
            p_safe = probs.clamp(self.epsilon, 1.0)
            t_safe = truth.clamp(self.epsilon, 1.0)
            kl = (t_safe * (t_safe.log() - p_safe.log())).sum(dim=-1)

            # MSE uses the unclamped softmax output, averaged over answers.
            mse = ((probs - truth) ** 2).mean(dim=-1)

            per_sample = self.lambda_kl * kl + self.lambda_mse * mse
            q_mean = (w * per_sample).mean()

            total = total + q_mean
            parts[f"loss/{name}"] = q_mean.detach().item()

        parts["loss/total"] = total.detach().item()
        return total, parts
82
+
83
+
84
class DirichletLoss(nn.Module):
    """
    Weighted hierarchical Dirichlet negative log-likelihood.

    Trains GalaxyViTDirichlet as the Zoobot-style comparison baseline
    (Walmsley et al. 2022). Per question the NLL of the target fractions
    under Dirichlet(alpha) is

        log B(alpha) - sum_a (alpha_a - 1) * log(p_a)

    weighted by the hierarchical branch weight and averaged over the batch.
    """

    def __init__(self, cfg: "DictConfig"):
        super().__init__()
        # Floor applied to target fractions before the log.
        self.epsilon = float(cfg.loss.epsilon)
        # Flattened (name, start, end) column slices, one per question group.
        self.question_slices = [(q, s, e) for q, (s, e) in QUESTION_GROUPS.items()]

    def forward(self, alpha: torch.Tensor,
                targets: torch.Tensor, weights: torch.Tensor):
        """Return (total_loss, per-question loss dict).

        alpha   : [B, 37] Dirichlet concentrations (> 1 from the model head)
        targets : [B, 37] ground-truth vote fractions
        weights : [B, 11] hierarchical weight per question
        """
        total = torch.zeros(1, device=alpha.device, dtype=alpha.dtype)
        parts = {}

        for idx, (name, lo, hi) in enumerate(self.question_slices):
            a = alpha[:, lo:hi]
            t = targets[:, lo:hi].clamp(min=self.epsilon)
            w = weights[:, idx]

            # Normalising constant: log B(a) = sum lgamma(a_i) - lgamma(sum a_i).
            log_norm = torch.lgamma(a).sum(dim=-1) - torch.lgamma(a.sum(dim=-1))
            # Data term of the Dirichlet log-density at the target fractions.
            data_term = ((a - 1.0) * t.log()).sum(dim=-1)

            nll = log_norm - data_term
            q_mean = (w * nll).mean()

            total = total + q_mean
            parts[f"loss/{name}"] = q_mean.detach().item()

        parts["loss/total"] = total.detach().item()
        return total, parts
124
+
125
+
126
class MSEOnlyLoss(nn.Module):
    """
    Ablation baseline: hierarchical per-question MSE with no KL term.
    Numerically equivalent to HierarchicalLoss with lambda_kl = 0 and
    lambda_mse = 1.
    """

    def __init__(self, cfg: "DictConfig"):
        super().__init__()
        # Kept for interface parity with the other losses (unused in forward).
        self.epsilon = float(cfg.loss.epsilon)
        # Flattened (name, start, end) column slices, one per question group.
        self.question_slices = [(q, s, e) for q, (s, e) in QUESTION_GROUPS.items()]

    def forward(self, predictions: torch.Tensor,
                targets: torch.Tensor, weights: torch.Tensor):
        """Return (total_loss, per-question loss dict).

        predictions : [B, 37] raw logits
        targets     : [B, 37] ground-truth vote fractions
        weights     : [B, 11] hierarchical weight per question
        """
        total = torch.zeros(1, device=predictions.device, dtype=predictions.dtype)
        parts = {}

        for idx, (name, lo, hi) in enumerate(self.question_slices):
            probs = F.softmax(predictions[:, lo:hi], dim=-1)
            # Squared error averaged over the answers of this question.
            err = ((probs - targets[:, lo:hi]) ** 2).mean(dim=-1)
            q_mean = (weights[:, idx] * err).mean()

            total = total + q_mean
            parts[f"loss/{name}"] = q_mean.detach().item()

        parts["loss/total"] = total.detach().item()
        return total, parts
src/metrics.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/metrics.py
3
+ --------------
4
+ Evaluation metrics for hierarchical probabilistic vote-fraction regression
5
+ on Galaxy Zoo 2.
6
+
7
+ Three evaluation regimes
8
+ ------------------------
9
+ 1. GLOBAL β€” all test samples (dominated by root question t01).
10
+ 2. REACHED-BRANCH β€” samples where branch was actually reached (w >= threshold).
11
+ This is the scientifically correct regime for conditional questions.
12
+ 3. ECE β€” Expected Calibration Error using adaptive (equal-frequency) bins.
13
+
14
+ Fixes applied vs. original
15
+ ---------------------------
16
+ - ECE uses adaptive binning (equal-frequency bins) instead of equal-width.
17
+ Equal-width bins saturate at 0.200 for bimodal questions (t02, t03, t04)
18
+ where predictions cluster near 0 and 1. Adaptive bins are unbiased for
19
+ any distribution shape.
20
+ - simplex_violation_rate() added: fraction of question groups where the
21
+ sigmoid baseline predictions do not sum to 1 Β± 0.02. Used to explain
22
+ why ResNet-18 + sigmoid achieves lower raw MAE despite predicting
23
+ invalid distributions.
24
+ """
25
+
26
+ import numpy as np
27
+ import torch
28
+ import torch.nn.functional as F
29
+ from src.dataset import QUESTION_GROUPS
30
+
31
+ WEIGHT_THRESHOLDS = [0.05, 0.50, 0.75]
32
+
33
+
34
+ # ─────────────────────────────────────────────────────────────
35
+ # Main metrics function
36
+ # ─────────────────────────────────────────────────────────────
37
+
38
def compute_metrics(
    all_predictions: np.ndarray,  # [N, 37] per-answer predicted vote fractions
    all_targets: np.ndarray,      # [N, 37] ground-truth vote fractions
    all_weights: np.ndarray,      # [N, 11] hierarchical weight per question
) -> dict:
    """
    Full metrics suite: global + reached-branch MAE/RMSE + bias + ECE.

    Returns a flat dict keyed "<metric>/<question>" plus aggregate keys
    ("mae/weighted_avg", "rmse/weighted_avg", "ece/mean",
    "mae_w<thresh>/conditional_avg", "n_reached_w<thresh>/<question>").
    All values are plain Python floats/ints so the dict serialises cleanly.
    """
    metrics = {}
    # NOTE(review): q_names is never used below — candidate for removal.
    q_names = list(QUESTION_GROUPS.keys())

    # ── 1. Global metrics ──────────────────────────────────────
    # Per-question MAE/RMSE over ALL samples, plus a signed bias term.
    mae_values = []
    rmse_values = []
    weight_means = []

    for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
        pred_q = all_predictions[:, start:end]
        target_q = all_targets[:, start:end]
        weight_q = all_weights[:, q_idx]

        # Mean over the question's answers first, then over samples.
        mae_q = np.abs(pred_q - target_q).mean(axis=1).mean()
        rmse_q = np.sqrt(((pred_q - target_q) ** 2).mean(axis=1).mean())
        w_mean = weight_q.mean()

        metrics[f"mae/{q_name}"] = float(mae_q)
        metrics[f"rmse/{q_name}"] = float(rmse_q)
        # Signed mean error: positive means systematic over-prediction.
        metrics[f"bias/{q_name}"] = float(
            (all_predictions[:, start:end] - all_targets[:, start:end]).mean()
        )

        mae_values.append(mae_q)
        rmse_values.append(rmse_q)
        weight_means.append(w_mean)

    # Aggregate: each question weighted by its mean hierarchical weight, so
    # the always-asked root question contributes more than rare branches.
    weight_means = np.array(weight_means)
    weight_sum = weight_means.sum()
    metrics["mae/weighted_avg"] = float(
        (weight_means * np.array(mae_values)).sum() / weight_sum
    )
    metrics["rmse/weighted_avg"] = float(
        (weight_means * np.array(rmse_values)).sum() / weight_sum
    )

    # ── 2. Reached-branch metrics ──────────────────────────────
    # Restrict each question to samples whose branch weight clears the
    # threshold; questions with fewer than 10 reached samples get NaN.
    for thresh in WEIGHT_THRESHOLDS:
        thresh_key = str(thresh).replace(".", "")  # e.g. 0.05 -> "005", 0.5 -> "05"
        branch_maes = []
        branch_ws = []

        for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
            pred_q = all_predictions[:, start:end]
            target_q = all_targets[:, start:end]
            weight_q = all_weights[:, q_idx]
            mask = weight_q >= thresh
            n_reached = mask.sum()
            metrics[f"n_reached_w{thresh_key}/{q_name}"] = int(n_reached)

            if n_reached >= 10:
                mae_q = np.abs(pred_q[mask] - target_q[mask]).mean(axis=1).mean()
                metrics[f"mae_w{thresh_key}/{q_name}"] = float(mae_q)
                branch_maes.append(mae_q)
                branch_ws.append(weight_q[mask].mean())
            else:
                metrics[f"mae_w{thresh_key}/{q_name}"] = float("nan")

        # Weighted average over the questions that had enough reached samples.
        if branch_maes:
            bw = np.array(branch_ws)
            bm = np.array(branch_maes)
            metrics[f"mae_w{thresh_key}/conditional_avg"] = float(
                (bw * bm).sum() / bw.sum()
            )

    # ── 3. ECE per question (adaptive binning) ─────────────────
    # Calibration is measured on the flattened (sample, answer) values.
    ece_values = []
    for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
        pred_flat = all_predictions[:, start:end].flatten()
        target_flat = all_targets[:, start:end].flatten()
        ece = _compute_ece(pred_flat, target_flat)
        metrics[f"ece/{q_name}"] = float(ece)
        ece_values.append(ece)

    # nanmean: questions whose ECE is undefined (degenerate bins) are skipped.
    metrics["ece/mean"] = float(np.nanmean(ece_values))

    return metrics
123
+
124
+
125
+ # ─────────────────────────────────────────────────────────────
126
+ # ECE β€” adaptive (equal-frequency) binning
127
+ # ─────────────────────────────────────────────────────────────
128
+
129
+ def _compute_ece(pred: np.ndarray, target: np.ndarray,
130
+ n_bins: int = 15) -> float:
131
+ """
132
+ Expected Calibration Error with adaptive (equal-frequency) binning.
133
+
134
+ Equal-width binning saturates for bimodal distributions (e.g. t02, t03,
135
+ t04 where predictions cluster at 0 and 1) because >95% of samples fall
136
+ into boundary bins. Adaptive binning places bin edges at percentiles of
137
+ the predicted distribution, giving each bin an equal number of samples
138
+ and making ECE meaningful regardless of the prediction distribution shape.
139
+
140
+ Parameters
141
+ ----------
142
+ pred : [N] predicted vote fractions
143
+ target : [N] true vote fractions
144
+ n_bins : number of bins (default 15)
145
+
146
+ Returns
147
+ -------
148
+ ECE : float in [0, 1]
149
+ """
150
+ if len(pred) < n_bins:
151
+ return float("nan")
152
+
153
+ # Build equal-frequency bin edges from percentiles of pred
154
+ percentiles = np.linspace(0, 100, n_bins + 1)
155
+ bin_edges = np.unique(np.percentile(pred, percentiles))
156
+
157
+ if len(bin_edges) < 2:
158
+ return float("nan")
159
+
160
+ # Assign samples to bins (digitize returns 1-indexed; clip to [0, n-2])
161
+ bin_ids = np.clip(np.digitize(pred, bin_edges[1:-1]), 0, len(bin_edges) - 2)
162
+
163
+ ece = 0.0
164
+ n = len(pred)
165
+ for b in np.unique(bin_ids):
166
+ mask = bin_ids == b
167
+ if not mask.any():
168
+ continue
169
+ ece += (mask.sum() / n) * abs(pred[mask].mean() - target[mask].mean())
170
+
171
+ return float(ece)
172
+
173
+
174
+ # ─────────────────────────────────────────────────────────────
175
+ # Simplex violation rate
176
+ # ─────────────────────────────────────────────────────────────
177
+
178
def simplex_violation_rate(
    predictions: np.ndarray,  # [N, 37]
    tolerance: float = 0.02,
) -> dict:
    """
    Fraction of galaxies whose per-question predictions do NOT sum to
    1 Β± tolerance.

    A softmax-per-question model satisfies the simplex constraint by
    construction (rate β‰ˆ 0), whereas an unconstrained sigmoid baseline does
    not — explaining its lower raw per-answer MAE despite producing invalid
    probability distributions.

    Parameters
    ----------
    predictions : [N, 37] predicted values
    tolerance   : acceptable deviation from 1.0 (default 0.02)

    Returns
    -------
    dict mapping question name -> violation rate in [0, 1], plus a "mean"
    entry averaging the per-question rates.
    """
    rates = {}
    for q_name, (start, end) in QUESTION_GROUPS.items():
        sums = predictions[:, start:end].sum(axis=1)
        rates[q_name] = float((np.abs(sums - 1.0) > tolerance).mean())
    # "mean" is appended after the loop so it averages only question rates.
    rates["mean"] = float(np.mean(list(rates.values())))
    return rates
209
+
210
+
211
+ # ─────────────────────────────────────────────────────────────
212
+ # Reached-branch comparison table (for paper Table 2)
213
+ # ─────────────────────────────────────────────────────────────
214
+
215
def compute_reached_branch_mae_table(
    model_results: dict,
) -> "pd.DataFrame":
    """
    Build the reached-branch MAE comparison table across all models.

    Parameters
    ----------
    model_results : dict mapping model_name β†’ (preds, targets, weights)
        preds/targets are [N, 37] vote-fraction arrays; weights is [N, 11].

    Returns
    -------
    pd.DataFrame with columns:
        model, question, description, n_w005, mae_w005, mae_w050, mae_w075
    plus one "weighted_avg" summary row per model.
    """
    # Imported lazily so this module has no hard pandas dependency.
    import pandas as pd

    # Human-readable labels for the 11 GZ2 decision-tree questions.
    QUESTION_DESCRIPTIONS = {
        "t01": "Smooth or features",
        "t02": "Edge-on disk",
        "t03": "Bar",
        "t04": "Spiral arms",
        "t05": "Bulge prominence",
        "t06": "Odd feature",
        "t07": "Roundedness (smooth)",
        "t08": "Odd feature type",
        "t09": "Bulge shape (edge-on)",
        "t10": "Arms winding",
        "t11": "Arms number",
    }

    rows = []
    for model_name, (preds, targets, weights) in model_results.items():
        # One row per (model, question) with a count/MAE pair per threshold.
        for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
            pred_q = preds[:, start:end]
            target_q = targets[:, start:end]
            weight_q = weights[:, q_idx]

            row = {
                "model" : model_name,
                "question" : q_name,
                "description": QUESTION_DESCRIPTIONS[q_name],
            }
            # MAE is NaN when fewer than 10 galaxies reached the branch.
            for thresh in WEIGHT_THRESHOLDS:
                mask = weight_q >= thresh
                n = mask.sum()
                key = f"n_w{str(thresh).replace('.','')}"
                mkey = f"mae_w{str(thresh).replace('.','')}"
                row[key] = int(n)
                row[mkey] = (
                    float(np.abs(pred_q[mask] - target_q[mask]).mean(axis=1).mean())
                    if n >= 10 else float("nan")
                )
            rows.append(row)

        # Weighted-average summary row for this model (w >= 0.05 regime only;
        # the other threshold columns are filled with NaN by design).
        branch_maes = []
        branch_ws = []
        for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
            weight_q = weights[:, q_idx]
            pred_q = preds[:, start:end]
            target_q = targets[:, start:end]
            mask = weight_q >= 0.05
            if mask.sum() >= 10:
                branch_maes.append(
                    np.abs(pred_q[mask] - target_q[mask]).mean(axis=1).mean()
                )
                branch_ws.append(weight_q[mask].mean())

        bw = np.array(branch_ws)
        bm = np.array(branch_maes)
        # NOTE(review): the n_w005 expression below counts reached
        # (galaxy, question) PAIRS, not galaxies, and hardcodes 11 questions
        # instead of len(QUESTION_GROUPS) — confirm this is intended.
        rows.append({
            "model" : model_name,
            "question" : "weighted_avg",
            "description": "Weighted average (wβ‰₯0.05)",
            "n_w005" : int(sum(weights[:, q] >= 0.05 for q in range(11)).sum()
                           if hasattr(weights, "__len__") else 0),
            "mae_w005" : float((bw * bm).sum() / bw.sum()) if len(bw) > 0 else float("nan"),
            "mae_w050" : float("nan"),
            "mae_w075" : float("nan"),
        })

    return pd.DataFrame(rows)
299
+
300
+
301
+ # ─────────────────────────────────────────────────────────────
302
+ # Tensor β†’ numpy helpers
303
+ # ─────────────────────────────────────────────────────────────
304
+
305
def predictions_to_numpy(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    weights: torch.Tensor,
) -> tuple:
    """Softmax each question group's logits and return (preds, targets, weights)
    as numpy arrays. The input logits tensor is not modified (a clone is used)."""
    out = predictions.detach().cpu().clone()
    for _, (lo, hi) in QUESTION_GROUPS.items():
        out[:, lo:hi] = F.softmax(out[:, lo:hi], dim=-1)
    return (
        out.numpy(),
        targets.detach().cpu().numpy(),
        weights.detach().cpu().numpy(),
    )
319
+
320
+
321
def dirichlet_predictions_to_numpy(
    alpha: torch.Tensor,
    targets: torch.Tensor,
    weights: torch.Tensor,
) -> tuple:
    """Convert Dirichlet concentrations to mean vote fractions per question
    (alpha / sum(alpha)) and return (means, targets, weights) as numpy arrays."""
    means = torch.zeros_like(alpha)
    for _, (lo, hi) in QUESTION_GROUPS.items():
        grp = alpha[:, lo:hi]
        means[:, lo:hi] = grp / grp.sum(dim=-1, keepdim=True)
    return (
        means.detach().cpu().numpy(),
        targets.detach().cpu().numpy(),
        weights.detach().cpu().numpy(),
    )
src/model.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/model.py
3
+ ------------
4
+ Vision Transformer (ViT-Base/16) backbone with three head variants:
5
+
6
+ 1. GalaxyViT β€” linear regression head (37 logits). Proposed model.
7
+ 2. GalaxyViTDirichlet β€” Dirichlet concentration head (Zoobot-style baseline).
8
+ 3. mc_dropout_predict β€” MC Dropout uncertainty estimation wrapper.
9
+
10
+ Architecture
11
+ ------------
12
+ Backbone : vit_base_patch16_224 from timm (pretrained ImageNet-21k)
13
+ 12 transformer layers, 12 heads, embed_dim=768
14
+ Input : [B, 3, 224, 224]
15
+ CLS out: [B, 768]
16
+ Head : Dropout(p) β†’ Linear(768, 37)
17
+
18
+ Full multi-layer attention rollout
19
+ ------------------------------------
20
+ All 12 transformer blocks use fused_attn=False so forward hooks can
21
+ capture the post-softmax attention matrices. Rollout is computed in
22
+ attention_viz.py using the corrected right-multiplication order.
23
+
24
+ MC Dropout
25
+ -----------
26
+ enable_mc_dropout() keeps Dropout active at inference time.
27
+ Running N stochastic forward passes gives mean prediction and
28
+ per-answer std (epistemic uncertainty). N=30 is standard practice
29
+ per Gal & Ghahramani (2016).
30
+
31
+ Dirichlet head
32
+ --------------
33
+ Outputs Ξ± > 1 per answer via: Ξ± = 1 + softplus(linear(features))
34
+ Matches the Zoobot approach for a fair direct comparison.
35
+ Mean vote fraction: E[p_q] = Ξ±_q / sum(Ξ±_q).
36
+
37
+ References
38
+ ----------
39
+ Gal & Ghahramani (2016). Dropout as a Bayesian Approximation.
40
+ ICML 2016. https://arxiv.org/abs/1506.02142
41
+ Walmsley et al. (2022). Towards Galaxy Foundation Models.
42
+ MNRAS 509, 3966. https://arxiv.org/abs/2110.12735
43
+ """
44
+
45
+ from __future__ import annotations
46
+
47
+ import torch
48
+ import torch.nn as nn
49
+ import torch.nn.functional as F
50
+ import timm
51
+ import numpy as np
52
+ from omegaconf import DictConfig
53
+ from typing import Optional, List, Tuple
54
+
55
+ from src.dataset import QUESTION_GROUPS
56
+
57
+
58
+ # ─────────────────────────────────────────────────────────────
59
+ # Attention hook manager
60
+ # ─────────────────────────────────────────────────────────────
61
+
62
class AttentionHookManager:
    """
    Captures post-softmax attention maps from every transformer block.

    Setting fused_attn=False makes timm's attention block materialise the
    softmax output and pass it through attn_drop:
        attn = softmax(q @ k.T / scale)   # [B, H, N+1, N+1]
        attn = attn_drop(attn)            # hook fires; input = post-softmax
        out  = attn @ v
    so a forward hook on attn_drop sees exactly the attention matrix needed
    for multi-layer rollout.
    """

    def __init__(self, blocks):
        self.blocks = blocks
        self._attn_list: List[torch.Tensor] = []
        self._handles = []
        self._register_hooks()

    def _register_hooks(self):
        # One shared closure suffices: the hook only appends to self's list.
        def capture(module, inputs, output):
            # inputs[0] is the post-softmax attention tensor [B, H, N+1, N+1].
            self._attn_list.append(inputs[0].detach())

        for blk in self.blocks:
            # Disable fused attention so attn_drop actually receives the
            # attention matrix instead of being skipped.
            blk.attn.fused_attn = False
            self._handles.append(
                blk.attn.attn_drop.register_forward_hook(capture)
            )

    def clear(self):
        """Drop attentions captured by the previous forward pass."""
        self._attn_list.clear()

    def get_all_attentions(self) -> Optional[List[torch.Tensor]]:
        """List of L tensors [B, H, N+1, N+1], or None if nothing captured."""
        return list(self._attn_list) if self._attn_list else None

    def get_last_attention(self) -> Optional[torch.Tensor]:
        """Deepest captured attention map, or None if nothing captured."""
        return self._attn_list[-1] if self._attn_list else None

    def remove_all(self):
        """Detach every registered hook and forget the handles."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()
110
+
111
+
112
+ # ─────────────────────────────────────────────────────────────
113
+ # GalaxyViT β€” proposed model
114
+ # ─────────────────────────────────────────────────────────────
115
+
116
class GalaxyViT(nn.Module):
    """
    ViT-Base/16 backbone + linear regression head for GZ2.

    Outputs 37 raw logits; softmax is applied per question group during
    loss computation and metric evaluation. Full multi-layer attention
    hooks are registered at construction for rollout visualisation.
    """

    def __init__(self, cfg: "DictConfig"):
        super().__init__()

        # num_classes=0 strips timm's classifier; backbone returns features.
        self.backbone = timm.create_model(
            cfg.model.backbone,
            pretrained=cfg.model.pretrained,
            num_classes=0,
        )
        embed_dim = self.backbone.embed_dim  # 768 for ViT-Base

        self.head = nn.Sequential(
            nn.Dropout(p=cfg.model.dropout),
            nn.Linear(embed_dim, 37),
        )

        self._hook_mgr = AttentionHookManager(self.backbone.blocks)
        self._mc_dropout = False

    def enable_mc_dropout(self):
        """Keep Dropout active at inference time for MC sampling."""
        self._mc_dropout = True
        for m in self.modules():
            if isinstance(m, nn.Dropout):
                m.train()

    def disable_mc_dropout(self):
        """Turn MC sampling off AND restore Dropout layers to eval mode.

        Bug fix: previously only the flag was cleared, leaving every
        nn.Dropout in train mode after enable_mc_dropout(), so subsequent
        'deterministic' inference was still stochastic.
        """
        self._mc_dropout = False
        for m in self.modules():
            if isinstance(m, nn.Dropout):
                m.eval()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [B, 3, 224, 224] -> logits: [B, 37] (raw, pre-softmax)."""
        self._hook_mgr.clear()  # drop attentions from the previous pass
        features = self.backbone(x)   # [B, 768] CLS features
        logits = self.head(features)  # [B, 37]
        return logits

    def get_attention_weights(self) -> Optional[torch.Tensor]:
        """Last block's post-softmax attention, or None before any forward."""
        return self._hook_mgr.get_last_attention()

    def get_all_attention_weights(self) -> Optional[List[torch.Tensor]]:
        """All captured per-layer attentions for full rollout, or None."""
        return self._hook_mgr.get_all_attentions()

    def remove_hooks(self):
        """Detach all forward hooks (e.g. before export/serialisation)."""
        self._hook_mgr.remove_all()
168
+
169
+
170
+ # ─────────────────────────────────────────────────────────────
171
+ # GalaxyViTDirichlet β€” Zoobot-style comparison baseline
172
+ # ─────────────────────────────────────────────────────────────
173
+
174
class GalaxyViTDirichlet(nn.Module):
    """
    ViT-Base/16 + Dirichlet concentration head (Zoobot-style baseline).

    The head emits alpha = 1 + softplus(linear(features)) > 1 per answer,
    which keeps every Dirichlet unimodal. The posterior mean vote fraction
    is alpha_q / sum(alpha_q); the total concentration per question encodes
    the model's confidence.
    """

    def __init__(self, cfg: "DictConfig"):
        super().__init__()

        # num_classes=0 strips timm's classifier; backbone returns features.
        self.backbone = timm.create_model(
            cfg.model.backbone,
            pretrained=cfg.model.pretrained,
            num_classes=0,
        )

        self.head = nn.Sequential(
            nn.Dropout(p=cfg.model.dropout),
            nn.Linear(self.backbone.embed_dim, 37),
        )

        self._hook_mgr = AttentionHookManager(self.backbone.blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return alpha: [B, 37] Dirichlet concentration parameters > 1."""
        self._hook_mgr.clear()
        feats = self.backbone(x)
        return 1.0 + F.softplus(self.head(feats))  # alpha > 1

    def get_mean_prediction(self, alpha: torch.Tensor) -> torch.Tensor:
        """Dirichlet mean per question: alpha_q / sum(alpha_q)."""
        out = torch.zeros_like(alpha)
        for _, (lo, hi) in QUESTION_GROUPS.items():
            grp = alpha[:, lo:hi]
            out[:, lo:hi] = grp / grp.sum(dim=-1, keepdim=True)
        return out

    def get_attention_weights(self):
        """Last block's captured attention map, or None."""
        return self._hook_mgr.get_last_attention()

    def get_all_attention_weights(self):
        """All captured per-layer attention maps, or None."""
        return self._hook_mgr.get_all_attentions()
222
+
223
+
224
+ # ─────────────────────────────────────────────────────────────
225
+ # MC Dropout inference
226
+ # ─────────────────────────────────────────────────────────────
227
+
228
@torch.no_grad()
def mc_dropout_predict(
    model: "GalaxyViT",
    images: torch.Tensor,
    n_passes: int = 30,
    device: torch.device = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    MC Dropout epistemic uncertainty estimation (Gal & Ghahramani 2016).

    Runs n_passes stochastic forward passes with dropout active and
    aggregates per-question softmax predictions.

    Parameters
    ----------
    model    : GalaxyViT instance
    images   : [B, 3, H, W]
    n_passes : number of MC samples (30 is standard)
    device   : inference device (defaults to the model's device)

    Returns
    -------
    mean_pred        : [B, 37] mean softmax predictions
    std_pred         : [B, 37] std across passes (epistemic uncertainty)
    per_q_uncertainty: [B, 11] mean std per question

    Note: the model is restored to fully deterministic eval mode on return.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    model.enable_mc_dropout()  # re-activate only the Dropout layers
    images = images.to(device)
    all_preds = []

    for _ in range(n_passes):
        logits = model(images)  # [B, 37]
        preds = torch.zeros_like(logits)
        # Normalise each question group independently.
        for q_name, (start, end) in QUESTION_GROUPS.items():
            preds[:, start:end] = F.softmax(logits[:, start:end], dim=-1)
        all_preds.append(preds.cpu().numpy())

    model.disable_mc_dropout()
    # Bug fix: explicitly restore eval mode so the Dropout layers switched
    # on by enable_mc_dropout() are deterministic for any later inference,
    # even if disable_mc_dropout() only clears the internal flag.
    model.eval()

    all_preds = np.stack(all_preds, axis=0)  # [n_passes, B, 37]
    mean_pred = all_preds.mean(axis=0)       # [B, 37]
    std_pred = all_preds.std(axis=0)         # [B, 37]

    q_names = list(QUESTION_GROUPS.keys())
    per_q_unc = np.zeros(
        (mean_pred.shape[0], len(q_names)), dtype=np.float32
    )
    # Collapse per-answer stds to one uncertainty value per question.
    for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
        per_q_unc[:, q_idx] = std_pred[:, start:end].mean(axis=1)

    return (
        mean_pred.astype(np.float32),
        std_pred.astype(np.float32),
        per_q_unc,
    )
287
+
288
+
289
+ # ─────────────────────────────────────────────────────────────
290
+ # Factory functions
291
+ # ─────────────────────────────────────────────────────────────
292
+
293
def build_model(cfg: "DictConfig") -> "GalaxyViT":
    """Construct the proposed regression model and print its summary."""
    net = GalaxyViT(cfg)
    _print_summary(net, cfg, "GalaxyViT (regression β€” proposed)")
    return net
297
+
298
+
299
def build_dirichlet_model(cfg: "DictConfig") -> "GalaxyViTDirichlet":
    """Construct the Dirichlet comparison baseline and print its summary."""
    net = GalaxyViTDirichlet(cfg)
    _print_summary(net, cfg, "GalaxyViTDirichlet (Zoobot-style baseline)")
    return net
303
+
304
+
305
+ def _print_summary(model, cfg, name: str):
306
+ total = sum(p.numel() for p in model.parameters())
307
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
308
+ n_hooks = len(model.backbone.blocks)
309
+ print(f"\n{'='*55}")
310
+ print(f"Model : {name}")
311
+ print(f"Backbone : {cfg.model.backbone}")
312
+ print(f"Pretrained : {cfg.model.pretrained}")
313
+ print(f"Dropout : {cfg.model.dropout}")
314
+ print(f"Parameters : {total:,} ({trainable:,} trainable)")
315
+ print(f"Attn hooks : {n_hooks} layers (full rollout enabled)")
316
+ print(f"{'='*55}\n")
src/train.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/train.py
3
+ ------------
4
+ Main training loop for the proposed hierarchical probabilistic ViT
5
+ regression model on Galaxy Zoo 2.
6
+
7
+ Model : GalaxyViT (ViT-Base/16 + linear head)
8
+ Loss : HierarchicalLoss (KL + MSE, Ξ»=0.5 each)
9
+ Scheduler: CosineAnnealingLR
10
+ Dropout : 0.3 (increased from 0.1 β€” see base.yaml rationale)
11
+
12
+ Saves
13
+ -----
14
+ outputs/checkpoints/best_<experiment_name>.pt β€” best checkpoint
15
+ outputs/logs/training_<experiment_name>_history.csv β€” epoch history
16
+
17
+ Usage
18
+ -----
19
+ cd ~/galaxy
20
+ nohup python -m src.train --config configs/full_train.yaml \
21
+ > outputs/logs/train_full.log 2>&1 &
22
+ echo "PID: $!"
23
+ """
24
+
25
+ import argparse
26
+ import logging
27
+ import random
28
+ import sys
29
+ from pathlib import Path
30
+
31
+ import numpy as np
32
+ import torch
33
+ import torch.nn as nn
34
+ from torch.amp import autocast, GradScaler
35
+ from omegaconf import OmegaConf
36
+ import pandas as pd
37
+
38
+ import wandb
39
+ from tqdm import tqdm
40
+
41
+ from src.dataset import build_dataloaders
42
+ from src.loss import HierarchicalLoss
43
+ from src.metrics import compute_metrics, predictions_to_numpy
44
+ from src.model import build_model
45
+ from src.attention_viz import plot_attention_grid
46
+
47
+ logging.basicConfig(
48
+ format="%(asctime)s %(levelname)s %(name)s %(message)s",
49
+ datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
50
+ )
51
+ log = logging.getLogger("train")
52
+
53
+
54
+ # ─────────────────────────────────────────────────────────────
55
+ # Utilities
56
+ # ─────────────────────────────────────────────────────────────
57
+
58
def set_seed(seed: int):
    """Seed python, numpy and torch RNGs and force deterministic cuDNN."""
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Trade cuDNN autotuning speed for reproducible kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
65
+
66
+
67
class EarlyStopping:
    """Track best validation loss, checkpoint on improvement, and signal
    a stop once `patience` consecutive epochs fail to improve by more
    than `min_delta`."""

    def __init__(self, patience, min_delta, checkpoint_path):
        self.patience = patience
        self.min_delta = min_delta
        self.checkpoint_path = checkpoint_path
        self.best_loss = float("inf")
        self.counter = 0
        self.best_epoch = 0

    def step(self, val_loss, model, epoch) -> bool:
        """Record one epoch's result; return True when patience runs out."""
        improved = val_loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
            payload = {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "val_loss": val_loss,
            }
            torch.save(payload, self.checkpoint_path)
            log.info(" [ckpt] saved epoch=%d val_loss=%.6f", epoch, val_loss)
        else:
            self.counter += 1
            log.info(" [early_stop] %d/%d best=%.6f",
                     self.counter, self.patience, self.best_loss)
        return self.counter >= self.patience

    def restore_best(self, model):
        """Reload the best checkpoint's weights into `model` (in place)."""
        ckpt = torch.load(self.checkpoint_path, map_location="cpu",
                          weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        log.info("Restored best weights epoch=%d val_loss=%.6f",
                 ckpt["epoch"], ckpt["val_loss"])
99
+
100
+
101
+ # ─────────────────────────────────────────────────────────────
102
+ # Training / validation steps
103
+ # ─────────────────────────────────────────────────────────────
104
+
105
def train_one_epoch(model, loader, loss_fn, optimizer,
                    scaler, device, cfg, epoch):
    """Run one training epoch and return the mean batch loss.

    The AMP sequence — scale -> backward -> unscale -> clip -> step ->
    update — is order-critical; do not rearrange these calls.
    """
    model.train()
    total = 0.0
    nb = 0
    for images, targets, weights, _ in tqdm(
        loader, desc=f"Train E{epoch}", leave=False
    ):
        # Async host->device copies (effective with pinned-memory loaders).
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        weights = weights.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)  # free grads instead of zeroing
        with autocast("cuda", enabled=cfg.training.mixed_precision):
            logits = model(images)
            # loss_fn returns (total_loss, breakdown); breakdown unused here.
            loss, _ = loss_fn(logits, targets, weights)
        scaler.scale(loss).backward()
        # Unscale before clipping so the norm threshold sees true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.training.grad_clip)
        scaler.step(optimizer)
        scaler.update()

        total += loss.item()
        nb += 1
    return total / nb
130
+
131
+
132
def validate(model, loader, loss_fn, device, cfg,
             collect_attn=False, n_attn=8, epoch=0):
    """Evaluate one epoch; return (val_logs dict, optional attention data).

    When `collect_attn` is True, up to `n_attn` samples' per-layer
    attention maps are gathered from the first batches for visualization.
    """
    model.eval()
    total = 0.0
    nb = 0
    all_preds, all_targets, all_weights = [], [], []
    attn_imgs, all_layers_list, attn_ids = [], [], []
    attn_done = False  # set once n_attn samples have been collected

    with torch.no_grad():
        for images, targets, weights, image_ids in tqdm(
            loader, desc=f"Val E{epoch}", leave=False
        ):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)

            with autocast("cuda", enabled=cfg.training.mixed_precision):
                logits = model(images)
                loss, _ = loss_fn(logits, targets, weights)

            total += loss.item()
            nb += 1
            p, t, w = predictions_to_numpy(logits, targets, weights)
            all_preds.append(p)
            all_targets.append(t)
            all_weights.append(w)

            if collect_attn and not attn_done:
                # Per-layer attention maps captured by forward hooks.
                all_layers = model.get_all_attention_weights()
                if all_layers is not None:
                    n = min(n_attn, images.shape[0])
                    attn_imgs.append(images[:n].cpu())
                    all_layers_list.append([l[:n].cpu() for l in all_layers])
                    attn_ids.extend([int(i) for i in image_ids[:n]])
                    if len(attn_ids) >= n_attn:
                        attn_done = True

    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    all_weights = np.concatenate(all_weights)
    metrics = compute_metrics(all_preds, all_targets, all_weights)

    val_logs = {"val/loss_total": total / nb}
    val_logs.update({f"val/{k}": v for k, v in metrics.items()})
    val_logs["val/reached_mae_w050"] = metrics.get("mae_w050/conditional_avg", 0)

    attn_data = None
    if collect_attn and attn_imgs:
        # Concatenate images, then per-layer maps across collected batches.
        attn_data = (
            torch.cat(attn_imgs, dim=0),
            [torch.cat([b[li] for b in all_layers_list], dim=0)
             for li in range(len(all_layers_list[0]))],
            attn_ids,
        )

    return val_logs, attn_data
189
+
190
+
191
+ # ─────────────────────────────────────────────────────────────
192
+ # Main
193
+ # ─────────────────────────────────────────────────────────────
194
+
195
def main():
    """CLI entry point: merge configs, train with early stopping, save
    history CSV and the best checkpoint, and optionally log to wandb."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()

    # Experiment config overrides base config.
    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)

    set_seed(cfg.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info("Device: %s", device)

    # TF32 speeds up matmuls/convs on Ampere+ with minimal precision loss.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    Path(cfg.outputs.checkpoint_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.outputs.figures_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.outputs.log_dir).mkdir(parents=True, exist_ok=True)

    checkpoint_path = str(
        Path(cfg.outputs.checkpoint_dir) / f"best_{cfg.experiment_name}.pt"
    )
    history_path = str(
        Path(cfg.outputs.log_dir) / f"training_{cfg.experiment_name}_history.csv"
    )

    if cfg.wandb.enabled:
        wandb.init(
            project=cfg.wandb.project,
            name=cfg.experiment_name,
            config=OmegaConf.to_container(cfg, resolve=True),
        )

    log.info("Building dataloaders...")
    train_loader, val_loader, _ = build_dataloaders(cfg)

    log.info("Building model...")
    model = build_model(cfg).to(device)
    loss_fn = HierarchicalLoss(cfg)

    # Discriminative LR: pretrained backbone gets 10x lower LR than head.
    optimizer = torch.optim.AdamW(
        [
            {"params": model.backbone.parameters(),
             "lr": cfg.training.learning_rate * 0.1},
            {"params": model.head.parameters(),
             "lr": cfg.training.learning_rate},
        ],
        weight_decay=cfg.training.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
    )
    scaler = GradScaler("cuda")
    early_stop = EarlyStopping(
        patience = cfg.early_stopping.patience,
        min_delta = cfg.early_stopping.min_delta,
        checkpoint_path = checkpoint_path,
    )

    log.info("Starting training: %s", cfg.experiment_name)
    history = []

    for epoch in range(1, cfg.training.epochs + 1):
        train_loss = train_one_epoch(
            model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch
        )
        # Attention maps are collected only every N epochs (expensive).
        collect_attn = (epoch % cfg.wandb.log_attention_every_n_epochs == 0)
        val_logs, attn_data = validate(
            model, val_loader, loss_fn, device, cfg,
            collect_attn=collect_attn,
            n_attn=cfg.wandb.n_attention_samples,
            epoch=epoch,
        )
        scheduler.step()
        lr = scheduler.get_last_lr()[0]

        val_mae = val_logs.get("val/mae/weighted_avg", 0)
        val_loss = val_logs["val/loss_total"]
        reached = val_logs.get("val/reached_mae_w050", 0)

        log.info(
            "Epoch %d train=%.4f val=%.4f mae=%.4f reached_mae=%.4f lr=%.2e",
            epoch, train_loss, val_loss, val_mae, reached, lr,
        )

        history.append({
            "epoch"      : epoch,
            "train_loss" : train_loss,
            "val_loss"   : val_loss,
            "val_mae"    : val_mae,
            "reached_mae": reached,
            "lr"         : lr,
        })

        if cfg.wandb.enabled:
            log_dict = {
                "train/loss": train_loss,
                **val_logs,
                "lr": lr, "epoch": epoch,
            }
            if attn_data is not None:
                import matplotlib.pyplot as plt
                imgs, layers, ids = attn_data
                fig = plot_attention_grid(
                    imgs, layers, ids,
                    save_path=(
                        f"{cfg.outputs.figures_dir}/{cfg.experiment_name}/"
                        f"attn_epoch{epoch:03d}.png"
                    ),
                    n_cols=4, rollout_mode="full",
                )
                log_dict["attention/rollout_full"] = wandb.Image(fig)
                plt.close(fig)  # free figure memory each epoch
            wandb.log(log_dict, step=epoch)

        # step() checkpoints on improvement and returns True to stop.
        if early_stop.step(val_loss, model, epoch):
            log.info("Early stopping at epoch %d best=%d loss=%.6f",
                     epoch, early_stop.best_epoch, early_stop.best_loss)
            break

    # Save history
    pd.DataFrame(history).to_csv(history_path, index=False)
    log.info("Saved history: %s", history_path)

    early_stop.restore_best(model)
    if cfg.wandb.enabled:
        wandb.finish()
    log.info("Done. Best checkpoint: %s", checkpoint_path)
324
+
325
+
326
+ if __name__ == "__main__":
327
+ main()
src/train_single.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/train_single.py
3
+ -------------------
4
+ Train any single model by name. Designed for running baselines
5
+ one at a time with breaks between them.
6
+
7
+ Available models
8
+ ----------------
9
+ proposed β€” ViT-Base + hierarchical KL+MSE (main model)
10
+ b1_resnet_mse β€” ResNet-18 + independent MSE (sigmoid)
11
+ b2_resnet_kl β€” ResNet-18 + hierarchical KL+MSE
12
+ b3_vit_mse β€” ViT-Base + hierarchical MSE only (no KL)
13
+ b4_vit_dir β€” ViT-Base + Dirichlet NLL (Zoobot-style)
14
+
15
+ Usage
16
+ -----
17
+ # Train proposed model
18
+ python -m src.train_single --model proposed --config configs/full_train.yaml
19
+
20
+ # Train one baseline at a time
21
+ python -m src.train_single --model b1_resnet_mse --config configs/full_train.yaml
22
+ python -m src.train_single --model b2_resnet_kl --config configs/full_train.yaml
23
+ python -m src.train_single --model b3_vit_mse --config configs/full_train.yaml
24
+ python -m src.train_single --model b4_vit_dir --config configs/full_train.yaml
25
+
26
+ # With nohup (recommended)
27
+ nohup python -m src.train_single --model b3_vit_mse \\
28
+ --config configs/full_train.yaml \\
29
+ > outputs/logs/train_b3_vit_mse.log 2>&1 &
30
+ echo "PID: $!"
31
+
32
+ Each model saves its checkpoint independently, so you can run them
33
+ in any order and resume from any point. Already-trained models are
34
+ detected by their checkpoint file and skipped unless --force is passed.
35
+ """
36
+
37
+ import argparse
38
+ import logging
39
+ import sys
40
+ from pathlib import Path
41
+
42
+ import numpy as np
43
+ import torch
44
+ from omegaconf import OmegaConf
45
+
46
+ logging.basicConfig(
47
+ format="%(asctime)s %(levelname)s %(message)s",
48
+ datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
49
+ )
50
+ log = logging.getLogger("train_single")
51
+
52
+ # ── Checkpoint paths per model ─────────────────────────────────────────────────
53
+ CHECKPOINT_NAMES = {
54
+ "proposed" : "best_full_train.pt",
55
+ "b1_resnet_mse" : "baseline_resnet18_mse.pt",
56
+ "b2_resnet_kl" : "baseline_resnet18_klmse.pt",
57
+ "b3_vit_mse" : "baseline_vit_mse.pt",
58
+ "b4_vit_dir" : "baseline_vit_dirichlet.pt",
59
+ }
60
+
61
+ # ── Human-readable labels ──────────────────────────────────────────────────────
62
+ MODEL_LABELS = {
63
+ "proposed" : "ViT-Base + hierarchical KL+MSE (proposed)",
64
+ "b1_resnet_mse" : "ResNet-18 + independent MSE (sigmoid, no hierarchy)",
65
+ "b2_resnet_kl" : "ResNet-18 + hierarchical KL+MSE",
66
+ "b3_vit_mse" : "ViT-Base + hierarchical MSE only (no KL)",
67
+ "b4_vit_dir" : "ViT-Base + Dirichlet NLL (Zoobot-style)",
68
+ }
69
+
70
+
71
def train_proposed(cfg, device, ckpt_path):
    """Train the proposed ViT + hierarchical KL+MSE model.

    Mirrors src/train.py's main() but saves to the given `ckpt_path` and
    writes history to a fixed 'full_train' CSV. Imports are deferred so
    other model choices don't pay for these dependencies.
    """
    from src.train import (
        train_one_epoch, validate, EarlyStopping, set_seed
    )
    from src.dataset import build_dataloaders
    from src.model import build_model
    from src.loss import HierarchicalLoss
    from src.attention_viz import plot_attention_grid
    import pandas as pd
    import wandb
    from torch.amp import GradScaler
    import matplotlib.pyplot as plt

    set_seed(cfg.seed)
    log.info("Training: %s", MODEL_LABELS["proposed"])

    Path(cfg.outputs.checkpoint_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.outputs.figures_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.outputs.log_dir).mkdir(parents=True, exist_ok=True)

    history_path = str(
        Path(cfg.outputs.log_dir) / "training_full_train_history.csv"
    )

    if cfg.wandb.enabled:
        wandb.init(
            project=cfg.wandb.project,
            name=cfg.experiment_name,
            config=OmegaConf.to_container(cfg, resolve=True),
        )

    train_loader, val_loader, _ = build_dataloaders(cfg)
    model = build_model(cfg).to(device)
    loss_fn = HierarchicalLoss(cfg)

    # Discriminative LR: backbone 10x lower than the regression head.
    optimizer = torch.optim.AdamW(
        [
            {"params": model.backbone.parameters(),
             "lr": cfg.training.learning_rate * 0.1},
            {"params": model.head.parameters(),
             "lr": cfg.training.learning_rate},
        ],
        weight_decay=cfg.training.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
    )
    scaler = GradScaler("cuda")
    early_stop = EarlyStopping(
        patience=cfg.early_stopping.patience,
        min_delta=cfg.early_stopping.min_delta,
        checkpoint_path=ckpt_path,
    )

    history = []
    for epoch in range(1, cfg.training.epochs + 1):
        train_loss = train_one_epoch(
            model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch
        )
        # Collect attention maps only every N epochs (expensive).
        collect_attn = (epoch % cfg.wandb.log_attention_every_n_epochs == 0)
        val_logs, attn_data = validate(
            model, val_loader, loss_fn, device, cfg,
            collect_attn=collect_attn,
            n_attn=cfg.wandb.n_attention_samples,
            epoch=epoch,
        )
        scheduler.step()
        lr = scheduler.get_last_lr()[0]

        val_mae = val_logs.get("val/mae/weighted_avg", 0)
        val_loss = val_logs["val/loss_total"]
        log.info("Epoch %d train=%.4f val=%.4f mae=%.4f lr=%.2e",
                 epoch, train_loss, val_loss, val_mae, lr)

        history.append({
            "epoch": epoch, "train_loss": train_loss,
            "val_loss": val_loss, "val_mae": val_mae, "lr": lr,
        })

        if cfg.wandb.enabled:
            log_dict = {"train/loss": train_loss, **val_logs,
                        "lr": lr, "epoch": epoch}
            if attn_data is not None:
                imgs, layers, ids = attn_data
                fig = plot_attention_grid(
                    imgs, layers, ids,
                    save_path=(f"{cfg.outputs.figures_dir}/{cfg.experiment_name}/"
                               f"attn_epoch{epoch:03d}.png"),
                    n_cols=4, rollout_mode="full",
                )
                log_dict["attention/rollout_full"] = wandb.Image(fig)
                plt.close(fig)  # free the figure each epoch
            wandb.log(log_dict, step=epoch)

        # step() checkpoints on improvement and returns True to stop.
        if early_stop.step(val_loss, model, epoch):
            log.info("Early stopping at epoch %d", epoch)
            break

    pd.DataFrame(history).to_csv(history_path, index=False)
    early_stop.restore_best(model)
    if cfg.wandb.enabled:
        wandb.finish()
    log.info("Done. Checkpoint: %s", ckpt_path)
175
+
176
+
177
def train_baseline(cfg, device, ckpt_path, model_key):
    """Train any of the four baselines.

    Returns
    -------
    (test_metrics, best_val, best_epoch)

    NOTE(review): `best_val` is whatever this EarlyStopping's
    restore_best() returns — that class comes from src.baselines (not
    visible here); confirm it actually returns the best val loss.
    """
    import wandb
    from torch.amp import GradScaler
    from src.dataset import build_dataloaders
    from src.model import build_model, build_dirichlet_model
    from src.loss import HierarchicalLoss, DirichletLoss, MSEOnlyLoss
    from src.metrics import (compute_metrics, predictions_to_numpy,
                             dirichlet_predictions_to_numpy)
    from src.baselines import (
        ResNet18Baseline, IndependentMSELoss, EarlyStopping,
        set_seed, _train_epoch, _val_epoch,
        _train_epoch_dirichlet, _val_epoch_dirichlet,
    )
    import pandas as pd
    from omegaconf import OmegaConf as OC

    set_seed(cfg.seed)
    log.info("Training: %s", MODEL_LABELS[model_key])

    Path(cfg.outputs.checkpoint_dir).mkdir(parents=True, exist_ok=True)

    # -- Build model and loss per baseline key ------------------
    if model_key == "b1_resnet_mse":
        model = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
        loss_fn = IndependentMSELoss()
        use_sigmoid = True
        is_dirichlet = False
        use_layerwise_lr = False
        wandb_name = "B1-ResNet18-MSE"

    elif model_key == "b2_resnet_kl":
        model = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
        loss_fn = HierarchicalLoss(cfg)
        use_sigmoid = False
        is_dirichlet = False
        use_layerwise_lr = False
        wandb_name = "B2-ResNet18-KL+MSE"

    elif model_key == "b3_vit_mse":
        # Ablation: disable the KL term, keep the MSE term.
        vit_mse_cfg = OC.merge(
            cfg, OC.create({"loss": {"lambda_kl": 0.0, "lambda_mse": 1.0}})
        )
        model = build_model(vit_mse_cfg).to(device)
        loss_fn = MSEOnlyLoss(vit_mse_cfg)
        cfg = vit_mse_cfg  # use updated cfg for optimizer
        use_sigmoid = False
        is_dirichlet = False
        use_layerwise_lr = True
        wandb_name = "B3-ViT-MSE"

    elif model_key == "b4_vit_dir":
        model = build_dirichlet_model(cfg).to(device)
        loss_fn = DirichletLoss(cfg)
        use_sigmoid = False
        is_dirichlet = True
        use_layerwise_lr = True
        wandb_name = "B4-ViT-Dirichlet"

    else:
        raise ValueError(f"Unknown model key: {model_key}")

    total = sum(p.numel() for p in model.parameters())
    log.info("Parameters: %s", f"{total:,}")

    # -- Optimizer: ViT models get a discriminative (layerwise) LR --
    if use_layerwise_lr and hasattr(model, "backbone") and hasattr(model, "head"):
        optimizer = torch.optim.AdamW(
            [
                {"params": model.backbone.parameters(),
                 "lr": cfg.training.learning_rate * 0.1},
                {"params": model.head.parameters(),
                 "lr": cfg.training.learning_rate},
            ],
            weight_decay=cfg.training.weight_decay,
        )
    else:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=cfg.training.learning_rate,
            weight_decay=cfg.training.weight_decay,
        )

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
    )
    scaler = GradScaler("cuda")
    early_stop = EarlyStopping(
        patience=cfg.early_stopping.patience,
        min_delta=cfg.early_stopping.min_delta,
        checkpoint_path=ckpt_path,
    )

    train_loader, val_loader, test_loader = build_dataloaders(cfg)

    wandb.init(
        project=cfg.wandb.project, name=wandb_name,
        config={"model": wandb_name, "seed": cfg.seed,
                "epochs": cfg.training.epochs,
                "lambda_kl": cfg.loss.lambda_kl},
        reinit=True,  # allow multiple runs in one process
    )

    # -- Training loop: Dirichlet baselines use their own epoch fns --
    history = []
    for epoch in range(1, cfg.training.epochs + 1):
        if is_dirichlet:
            train_loss = _train_epoch_dirichlet(
                model, train_loader, loss_fn, optimizer, scaler,
                device, cfg, epoch, wandb_name
            )
            val_loss, val_metrics = _val_epoch_dirichlet(
                model, val_loader, loss_fn, device, cfg, epoch, wandb_name
            )
        else:
            train_loss = _train_epoch(
                model, train_loader, loss_fn, optimizer, scaler,
                device, cfg, epoch, wandb_name
            )
            val_loss, val_metrics = _val_epoch(
                model, val_loader, loss_fn, device, cfg, epoch, wandb_name,
                use_sigmoid=use_sigmoid
            )

        scheduler.step()
        lr = scheduler.get_last_lr()[0]
        val_mae = val_metrics.get("mae/weighted_avg", 0)

        log.info("%s epoch=%d train=%.4f val=%.4f mae=%.4f lr=%.2e",
                 wandb_name, epoch, train_loss, val_loss, val_mae, lr)
        history.append({
            "epoch": epoch, "train_loss": train_loss,
            "val_loss": val_loss, "val_mae": val_mae,
        })
        wandb.log({
            "train_loss": train_loss, "val_loss": val_loss,
            "val_mae": val_mae, "lr": lr,
        }, step=epoch)

        if early_stop.step(val_loss, model, epoch):
            log.info("%s: early stopping at epoch %d", wandb_name, epoch)
            break

    best_val = early_stop.restore_best(model)
    wandb.finish()

    # -- Test evaluation with restored best weights -------------
    log.info("Evaluating on test set...")
    if is_dirichlet:
        _, test_metrics = _val_epoch_dirichlet(
            model, test_loader, loss_fn, device, cfg,
            epoch=0, label=f"{wandb_name}-test"
        )
    else:
        _, test_metrics = _val_epoch(
            model, test_loader, loss_fn, device, cfg,
            epoch=0, label=f"{wandb_name}-test", use_sigmoid=use_sigmoid
        )

    log.info("%s β€” Test MAE=%.5f RMSE=%.5f",
             wandb_name,
             test_metrics["mae/weighted_avg"],
             test_metrics["rmse/weighted_avg"])

    # -- Save per-model history ---------------------------------
    hist_path = Path(cfg.outputs.log_dir) / f"training_{model_key}_history.csv"
    pd.DataFrame(history).to_csv(hist_path, index=False)
    log.info("History saved: %s", hist_path)
    log.info("Done. Checkpoint: %s", ckpt_path)

    return test_metrics, best_val, early_stop.best_epoch
348
+
349
+
350
+ # ─────────────────────────────────────────────────────────────
351
+ # Main
352
+ # ─────────────────────────────────────────────────────────────
353
+
354
def main():
    """CLI entry point: pick one model by key, skip if its checkpoint
    already exists (unless --force), then dispatch to the trainer."""
    parser = argparse.ArgumentParser(
        description="Train a single model. Run multiple times to train "
                    "different models with breaks in between."
    )
    parser.add_argument(
        "--model",
        required=True,
        choices=list(CHECKPOINT_NAMES.keys()),
        help=(
            "Which model to train:\n"
            "  proposed       β€” ViT-Base + hierarchical KL+MSE (main)\n"
            "  b1_resnet_mse  β€” ResNet-18 + independent MSE (sigmoid)\n"
            "  b2_resnet_kl   β€” ResNet-18 + hierarchical KL+MSE\n"
            "  b3_vit_mse     β€” ViT-Base + hierarchical MSE only\n"
            "  b4_vit_dir     β€” ViT-Base + Dirichlet NLL\n"
        ),
    )
    parser.add_argument("--config", required=True)
    parser.add_argument(
        "--force",
        action="store_true",
        help="Retrain even if checkpoint already exists.",
    )
    args = parser.parse_args()

    # Experiment config overrides base config.
    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt_dir = Path(cfg.outputs.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    Path(cfg.outputs.log_dir).mkdir(parents=True, exist_ok=True)

    # TF32 speeds up matmuls/convs on Ampere+ GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    ckpt_path = str(ckpt_dir / CHECKPOINT_NAMES[args.model])

    # -- Skip if already done: checkpoint file marks completion --
    if Path(ckpt_path).exists() and not args.force:
        log.info("Checkpoint already exists: %s", ckpt_path)
        log.info("Model '%s' is already trained. Skipping.", args.model)
        log.info("Use --force to retrain.")
        return

    log.info("=" * 60)
    log.info("Training: %s", MODEL_LABELS[args.model])
    log.info("Device  : %s", device)
    log.info("Config  : %s", args.config)
    log.info("Ckpt    : %s", ckpt_path)
    log.info("=" * 60)

    # Dispatch: the proposed model has its own trainer; all baselines
    # share train_baseline().
    if args.model == "proposed":
        train_proposed(cfg, device, ckpt_path)
    else:
        train_baseline(cfg, device, ckpt_path, args.model)

    log.info("=" * 60)
    log.info("FINISHED: %s", MODEL_LABELS[args.model])
    log.info("=" * 60)
417
+
418
+ if __name__ == "__main__":
419
+ main()
src/uncertainty_analysis.py ADDED
@@ -0,0 +1,578 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/uncertainty_analysis.py
3
+ ----------------------------
4
+ MC Dropout epistemic uncertainty analysis for the proposed model.
5
+
6
+ MC Dropout (Gal & Ghahramani 2016) is used as a post-hoc uncertainty
7
+ estimator. At inference time, dropout is kept active and N=30 stochastic
8
+ forward passes are run per batch. The standard deviation across passes
9
+ is used as the epistemic uncertainty estimate per galaxy per question.
10
+
11
+ Key findings reported
12
+ ---------------------
13
+ 1. Uncertainty distributions: right-skewed, well-separated means across
14
+ questions reflecting the conditional nature of the decision tree.
15
+
16
+ 2. Uncertainty vs. error correlation: Spearman ρ reported per question.
17
+ Strong positive correlation for root and shallow-branch questions
18
+ (t01, t02, t04, t07) indicates the model is well-calibrated in
19
+ uncertainty. Weak or near-zero correlation for deep conditional
20
+ branches (t03, t05, t08, t09, t10, t11) is expected β€” these branches
21
+ have small effective sample sizes and aleatoric uncertainty dominates.
22
+
23
+ 3. Morphology selection benchmark: F1 score at threshold Ο„ for downstream
24
+ binary morphology classification tasks.
25
+
26
+ Output files
27
+ ------------
28
+ outputs/figures/uncertainty/
29
+ fig_uncertainty_distributions.pdf
30
+ fig_uncertainty_vs_error.pdf
31
+ fig_morphology_f1_comparison.pdf
32
+ table_uncertainty_summary.csv
33
+ table_morphology_selection_benchmark.csv
34
+ mc_cache/ β€” cached numpy arrays (crash-safe)
35
+
36
+ Usage
37
+ -----
38
+ cd ~/galaxy
39
+ nohup python -m src.uncertainty_analysis \
40
+ --config configs/full_train.yaml --n_passes 30 \
41
+ > outputs/logs/uncertainty.log 2>&1 &
42
+ echo "PID: $!"
43
+ """
44
+
45
+ import argparse
46
+ import logging
47
+ import sys
48
+ from pathlib import Path
49
+
50
+ import numpy as np
51
+ import pandas as pd
52
+ import torch
53
+ import torch.nn.functional as F
54
+ import matplotlib
55
+ matplotlib.use("Agg")
56
+ import matplotlib.pyplot as plt
57
+ from scipy import stats as scipy_stats
58
+ from torch.amp import autocast
59
+ from omegaconf import OmegaConf
60
+ from tqdm import tqdm
61
+
62
+ from src.dataset import build_dataloaders, QUESTION_GROUPS
63
+ from src.model import build_model, build_dirichlet_model
64
+ from src.baselines import ResNet18Baseline
65
+ from src.metrics import predictions_to_numpy, dirichlet_predictions_to_numpy
66
+
67
+ logging.basicConfig(
68
+ format="%(asctime)s %(levelname)s %(message)s",
69
+ datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
70
+ )
71
+ log = logging.getLogger("uncertainty")
72
+
73
+ plt.rcParams.update({
74
+ "figure.dpi": 150, "savefig.dpi": 300,
75
+ "font.family": "serif", "font.size": 11,
76
+ "axes.titlesize": 10, "axes.labelsize": 10,
77
+ "xtick.labelsize": 8, "ytick.labelsize": 8,
78
+ "legend.fontsize": 8,
79
+ "figure.facecolor": "white", "axes.facecolor": "white",
80
+ "axes.grid": True, "grid.alpha": 0.3,
81
+ "pdf.fonttype": 42, "ps.fonttype": 42,
82
+ })
83
+
84
+ QUESTION_LABELS = {
85
+ "t01": "Smooth or features", "t02": "Edge-on disk",
86
+ "t03": "Bar", "t04": "Spiral arms",
87
+ "t05": "Bulge prominence", "t06": "Odd feature",
88
+ "t07": "Roundedness", "t08": "Odd feature type",
89
+ "t09": "Bulge shape", "t10": "Arms winding",
90
+ "t11": "Arms number",
91
+ }
92
+
93
+ MODEL_COLORS = {
94
+ "ViT-Base + KL+MSE (proposed)" : "#27ae60",
95
+ "ViT-Base + Dirichlet (Zoobot-style)": "#8e44ad",
96
+ "ResNet-18 + MSE (sigmoid)" : "#c0392b",
97
+ "ResNet-18 + KL+MSE" : "#e67e22",
98
+ }
99
+
100
+ SELECTION_THRESHOLDS = [0.5, 0.7, 0.8, 0.9]
101
+ SELECTION_ANSWERS = {
102
+ "t01": (0, "smooth"),
103
+ "t02": (0, "edge-on"),
104
+ "t03": (0, "bar"),
105
+ "t04": (0, "spiral"),
106
+ "t06": (0, "odd feature"),
107
+ }
108
+
109
+
110
+ # ─────────────────────────────────────────────────────────────
111
+ # MC Dropout inference β€” Welford online algorithm, crash-safe
112
+ # ─────────────────────────────────────────────────────────────
113
+
114
+ def run_mc_inference(model, loader, device, cfg,
115
+ n_passes=30, cache_dir=None):
116
+ """
117
+ Fast batched MC Dropout inference.
118
+
119
+ Uses Welford's online algorithm to compute mean and std
120
+ per batch without storing all n_passes Γ— N predictions.
121
+ Memory usage: O(N Γ— 37) regardless of n_passes.
122
+
123
+ Parameters
124
+ ----------
125
+ model : GalaxyViT with enable_mc_dropout() available
126
+ loader : test DataLoader
127
+ device : inference device
128
+ cfg : OmegaConf config
129
+ n_passes : number of stochastic forward passes (default 30)
130
+ cache_dir : if given, saves .npy files and skips if they exist
131
+
132
+ Returns
133
+ -------
134
+ mean_all, std_all : [N, 37] float32
135
+ targets_all : [N, 37] float32
136
+ weights_all : [N, 11] float32
137
+ """
138
+ if cache_dir is not None:
139
+ cache_dir = Path(cache_dir)
140
+ cache_dir.mkdir(parents=True, exist_ok=True)
141
+ fp_mean = cache_dir / "mc_mean.npy"
142
+ fp_std = cache_dir / "mc_std.npy"
143
+ fp_targets = cache_dir / "mc_targets.npy"
144
+ fp_weights = cache_dir / "mc_weights.npy"
145
+
146
+ if all(p.exists() for p in [fp_mean, fp_std, fp_targets, fp_weights]):
147
+ log.info("MC cache found β€” loading from disk (skipping inference).")
148
+ return (np.load(fp_mean), np.load(fp_std),
149
+ np.load(fp_targets), np.load(fp_weights))
150
+
151
+ model.eval()
152
+ model.enable_mc_dropout()
153
+
154
+ all_means, all_stds, all_targets, all_weights = [], [], [], []
155
+ log.info("MC Dropout: %d passes Γ— %d-image batches = %d total forward passes",
156
+ n_passes, loader.batch_size, n_passes * len(loader))
157
+
158
+ for images, targets, weights, _ in tqdm(loader, desc="MC Dropout"):
159
+ images_dev = images.to(device, non_blocking=True)
160
+
161
+ # Welford online mean and M2
162
+ mean_acc = None
163
+ M2_acc = None
164
+ count = 0
165
+
166
+ for _ in range(n_passes):
167
+ with torch.no_grad():
168
+ with autocast("cuda", enabled=cfg.training.mixed_precision):
169
+ logits = model(images_dev)
170
+
171
+ pred = torch.zeros_like(logits)
172
+ for q, (s, e) in QUESTION_GROUPS.items():
173
+ pred[:, s:e] = F.softmax(logits[:, s:e], dim=-1)
174
+ pred_np = pred.cpu().float().numpy() # [B, 37]
175
+
176
+ count += 1
177
+ if mean_acc is None:
178
+ mean_acc = pred_np.copy()
179
+ M2_acc = np.zeros_like(pred_np)
180
+ else:
181
+ delta = pred_np - mean_acc
182
+ mean_acc += delta / count
183
+ M2_acc += delta * (pred_np - mean_acc)
184
+
185
+ std_acc = np.sqrt(M2_acc / (count - 1) if count > 1
186
+ else np.zeros_like(M2_acc))
187
+
188
+ all_means.append(mean_acc)
189
+ all_stds.append(std_acc)
190
+ all_targets.append(targets.numpy())
191
+ all_weights.append(weights.numpy())
192
+
193
+ model.disable_mc_dropout()
194
+
195
+ mean_all = np.concatenate(all_means)
196
+ std_all = np.concatenate(all_stds)
197
+ targets_all = np.concatenate(all_targets)
198
+ weights_all = np.concatenate(all_weights)
199
+
200
+ if cache_dir is not None:
201
+ np.save(fp_mean, mean_all)
202
+ np.save(fp_std, std_all)
203
+ np.save(fp_targets, targets_all)
204
+ np.save(fp_weights, weights_all)
205
+ log.info("MC results cached: %s", cache_dir)
206
+
207
+ return mean_all, std_all, targets_all, weights_all
208
+
209
+
210
+ # ─────────────────────────────────────────────────────────────
211
+ # Figure 1: Uncertainty distributions
212
+ # ─────────────────────────────────────────────────────────────
213
+
214
+ def fig_uncertainty_distributions(mean_preds, std_preds,
215
+ targets, weights, save_dir):
216
+ path_pdf = save_dir / "fig_uncertainty_distributions.pdf"
217
+ path_png = save_dir / "fig_uncertainty_distributions.png"
218
+ if path_pdf.exists() and path_png.exists():
219
+ log.info("Skip (exists): fig_uncertainty_distributions"); return
220
+
221
+ fig, axes = plt.subplots(3, 4, figsize=(16, 12))
222
+ axes = axes.flatten()
223
+
224
+ for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
225
+ ax = axes[q_idx]
226
+ mask = weights[:, q_idx] >= 0.05
227
+ std_q = std_preds[mask, start:end].mean(axis=1)
228
+
229
+ ax.hist(std_q, bins=50, color="#6366f1", alpha=0.85,
230
+ edgecolor="none", density=True)
231
+ ax.axvline(std_q.mean(), color="#c0392b", linewidth=1.8,
232
+ linestyle="--", label=f"Mean = {std_q.mean():.4f}")
233
+ ax.set_xlabel("MC Dropout std (epistemic uncertainty)")
234
+ ax.set_ylabel("Density")
235
+ ax.set_title(
236
+ f"{q_name}: {QUESTION_LABELS[q_name]}\n"
237
+ f"$n$ = {mask.sum():,} (w β‰₯ 0.05)",
238
+ fontsize=9,
239
+ )
240
+ ax.legend(fontsize=7)
241
+
242
+ axes[-1].axis("off")
243
+ plt.suptitle(
244
+ "Epistemic uncertainty distributions β€” MC Dropout (30 passes)\n"
245
+ "Proposed model (ViT-Base/16 + hierarchical KL+MSE), test set",
246
+ fontsize=12,
247
+ )
248
+ plt.tight_layout()
249
+ fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
250
+ fig.savefig(path_png, dpi=300, bbox_inches="tight")
251
+ plt.close(fig)
252
+ log.info("Saved: fig_uncertainty_distributions")
253
+
254
+
255
+ # ─────────────────────────────────────────────────────────────
256
+ # Figure 2: Uncertainty vs. error (Spearman ρ)
257
+ # ─────────────────────────────────────────────────────────────
258
+
259
+ def fig_uncertainty_vs_error(mean_preds, std_preds,
260
+ targets, weights, save_dir):
261
+ path_pdf = save_dir / "fig_uncertainty_vs_error.pdf"
262
+ path_png = save_dir / "fig_uncertainty_vs_error.png"
263
+ if path_pdf.exists() and path_png.exists():
264
+ log.info("Skip (exists): fig_uncertainty_vs_error"); return
265
+
266
+ fig, axes = plt.subplots(3, 4, figsize=(16, 12))
267
+ axes = axes.flatten()
268
+
269
+ for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
270
+ ax = axes[q_idx]
271
+ mask = weights[:, q_idx] >= 0.05
272
+ unc = std_preds[mask, start:end].mean(axis=1)
273
+ err = np.abs(mean_preds[mask, start:end] -
274
+ targets[mask, start:end]).mean(axis=1)
275
+
276
+ # Adaptive bin means for trend line
277
+ n_bins = 15
278
+ unc_bins = np.unique(np.percentile(unc, np.linspace(0, 100, n_bins + 1)))
279
+ bin_ids = np.clip(np.digitize(unc, unc_bins) - 1, 0, len(unc_bins) - 2)
280
+ bn_unc = [unc[bin_ids == b].mean() for b in range(len(unc_bins) - 1)
281
+ if (bin_ids == b).any()]
282
+ bn_err = [err[bin_ids == b].mean() for b in range(len(unc_bins) - 1)
283
+ if (bin_ids == b).any()]
284
+
285
+ ax.scatter(unc, err, alpha=0.04, s=1, color="#94a3b8", rasterized=True)
286
+ ax.plot(bn_unc, bn_err, "r-o", markersize=4, linewidth=2,
287
+ label="Bin mean")
288
+
289
+ # Spearman rank correlation (more robust than Pearson for this data)
290
+ rho, pval = scipy_stats.spearmanr(unc, err)
291
+ p_str = f"p < 0.001" if pval < 0.001 else f"p = {pval:.3f}"
292
+ ax.text(0.05, 0.90,
293
+ f"Spearman ρ = {rho:.3f}\n{p_str}",
294
+ transform=ax.transAxes, fontsize=7.5,
295
+ bbox=dict(boxstyle="round,pad=0.25", facecolor="white",
296
+ edgecolor="grey", alpha=0.85))
297
+
298
+ ax.set_xlabel("Uncertainty (MC std)")
299
+ ax.set_ylabel("Absolute error")
300
+ ax.set_title(f"{q_name}: {QUESTION_LABELS[q_name]}", fontsize=9)
301
+ ax.legend(fontsize=7)
302
+
303
+ axes[-1].axis("off")
304
+ plt.suptitle(
305
+ "Epistemic uncertainty vs. absolute prediction error β€” per morphological question\n"
306
+ "Strong Spearman ρ for root/shallow questions; weak ρ for deep conditional branches "
307
+ "(expected: aleatoric uncertainty dominates when branch is rarely reached)",
308
+ fontsize=10,
309
+ )
310
+ plt.tight_layout()
311
+ fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
312
+ fig.savefig(path_png, dpi=300, bbox_inches="tight")
313
+ plt.close(fig)
314
+ log.info("Saved: fig_uncertainty_vs_error")
315
+
316
+
317
+ # ─────────────────────────────────────────────────────────────
318
+ # Table: uncertainty summary
319
+ # ─────────────────────────────────────────────────────────────
320
+
321
+ def table_uncertainty_summary(mean_preds, std_preds,
322
+ targets, weights, save_dir):
323
+ path = save_dir / "table_uncertainty_summary.csv"
324
+ if path.exists():
325
+ log.info("Skip (exists): table_uncertainty_summary"); return
326
+
327
+ rows = []
328
+ for q_idx, (q_name, (start, end)) in enumerate(QUESTION_GROUPS.items()):
329
+ mask = weights[:, q_idx] >= 0.05
330
+ unc = std_preds[mask, start:end].mean(axis=1)
331
+ err = np.abs(mean_preds[mask, start:end] -
332
+ targets[mask, start:end]).mean(axis=1)
333
+
334
+ if mask.sum() > 10:
335
+ rho, pval = scipy_stats.spearmanr(unc, err)
336
+ else:
337
+ rho, pval = float("nan"), float("nan")
338
+
339
+ rows.append({
340
+ "question" : q_name,
341
+ "description" : QUESTION_LABELS[q_name],
342
+ "n_reached" : int(mask.sum()),
343
+ "mean_uncertainty": round(float(unc.mean()), 5),
344
+ "std_uncertainty" : round(float(unc.std()), 5),
345
+ "mean_mae" : round(float(err.mean()), 5),
346
+ "spearman_rho" : round(float(rho), 4),
347
+ "spearman_pval" : round(float(pval), 4),
348
+ })
349
+
350
+ df = pd.DataFrame(rows)
351
+ df.to_csv(path, index=False)
352
+ log.info("Saved: table_uncertainty_summary.csv")
353
+ print("\n" + df.to_string(index=False) + "\n")
354
+ return df
355
+
356
+
357
+ # ─────────────────────────────────────────────────────────────
358
+ # Figure 3 + Table: Morphology selection benchmark
359
+ # ─────────────────────────────────────────────────────────────
360
+
361
+ def morphology_selection_benchmark(model_results, save_dir):
362
+ csv_path = save_dir / "table_morphology_selection_benchmark.csv"
363
+ if csv_path.exists():
364
+ log.info("Loading existing morphology benchmark...")
365
+ df = pd.read_csv(csv_path)
366
+ _fig_morphology_f1(df, save_dir)
367
+ return df
368
+
369
+ rows = []
370
+ for model_name, (preds, targets, weights) in model_results.items():
371
+ for q_name, (ans_idx, ans_label) in SELECTION_ANSWERS.items():
372
+ start, end = QUESTION_GROUPS[q_name]
373
+ q_idx = list(QUESTION_GROUPS.keys()).index(q_name)
374
+ mask = weights[:, q_idx] >= 0.05
375
+ pred_a = preds[mask, start + ans_idx]
376
+ true_a = targets[mask, start + ans_idx]
377
+
378
+ for thresh in SELECTION_THRESHOLDS:
379
+ sel = pred_a >= thresh
380
+ true_pos = true_a >= thresh
381
+ n_sel = sel.sum()
382
+ n_tp_all = true_pos.sum()
383
+ n_tp = (sel & true_pos).sum()
384
+ prec = n_tp / n_sel if n_sel > 0 else 0.0
385
+ rec = n_tp / n_tp_all if n_tp_all > 0 else 0.0
386
+ f1 = (2 * prec * rec / (prec + rec)
387
+ if (prec + rec) > 0 else 0.0)
388
+ rows.append({
389
+ "model" : model_name,
390
+ "question" : q_name,
391
+ "answer" : ans_label,
392
+ "threshold" : thresh,
393
+ "n_selected": int(n_sel),
394
+ "n_true_pos": int(n_tp_all),
395
+ "precision" : round(float(prec), 4),
396
+ "recall" : round(float(rec), 4),
397
+ "f1" : round(float(f1), 4),
398
+ })
399
+
400
+ df = pd.DataFrame(rows)
401
+ df.to_csv(csv_path, index=False)
402
+ log.info("Saved: table_morphology_selection_benchmark.csv")
403
+ _fig_morphology_f1(df, save_dir)
404
+ return df
405
+
406
+
407
+ def _fig_morphology_f1(df, save_dir):
408
+ path_pdf = save_dir / "fig_morphology_f1_comparison.pdf"
409
+ path_png = save_dir / "fig_morphology_f1_comparison.png"
410
+ if path_pdf.exists() and path_png.exists():
411
+ log.info("Skip (exists): fig_morphology_f1_comparison"); return
412
+
413
+ thresh = 0.8
414
+ sub = df[df["threshold"] == thresh]
415
+ q_list = list(SELECTION_ANSWERS.keys())
416
+ models = list(df["model"].unique())
417
+
418
+ x = np.arange(len(q_list))
419
+ width = 0.80 / len(models)
420
+ palette = list(MODEL_COLORS.values())
421
+
422
+ fig, ax = plt.subplots(figsize=(12, 5))
423
+ for i, model in enumerate(models):
424
+ f1s = []
425
+ for q in q_list:
426
+ row = sub[(sub["model"] == model) & (sub["question"] == q)]
427
+ f1s.append(float(row["f1"].values[0]) if len(row) > 0 else 0.0)
428
+ ax.bar(
429
+ x + i * width, f1s, width,
430
+ label=model,
431
+ color=MODEL_COLORS.get(model, palette[i % len(palette)]),
432
+ alpha=0.85, edgecolor="white", linewidth=0.5,
433
+ )
434
+
435
+ ax.set_xticks(x + width * (len(models) - 1) / 2)
436
+ ax.set_xticklabels(
437
+ [f"{q}\n({SELECTION_ANSWERS[q][1]})" for q in q_list], fontsize=9
438
+ )
439
+ ax.set_ylabel("F$_1$ score", fontsize=11)
440
+ ax.set_title(
441
+ f"Downstream morphology selection β€” F$_1$ at threshold $\\tau$ = {thresh}\n"
442
+ "Higher F$_1$ indicates cleaner, more complete morphological sample selection.",
443
+ fontsize=11,
444
+ )
445
+ ax.legend(fontsize=8)
446
+ ax.set_ylim(0, 1)
447
+ ax.grid(True, alpha=0.3, axis="y")
448
+ ax.set_axisbelow(True)
449
+ plt.tight_layout()
450
+ fig.savefig(path_pdf, dpi=300, bbox_inches="tight")
451
+ fig.savefig(path_png, dpi=300, bbox_inches="tight")
452
+ plt.close(fig)
453
+ log.info("Saved: fig_morphology_f1_comparison")
454
+
455
+
456
+ # ─────────────────────────────────────────────────────────────
457
+ # Main
458
+ # ─────────────────────────────────────────────────────────────
459
+
460
+ def main():
461
+ parser = argparse.ArgumentParser()
462
+ parser.add_argument("--config", required=True)
463
+ parser.add_argument("--n_passes", type=int, default=30)
464
+ args = parser.parse_args()
465
+
466
+ base_cfg = OmegaConf.load("configs/base.yaml")
467
+ exp_cfg = OmegaConf.load(args.config)
468
+ cfg = OmegaConf.merge(base_cfg, exp_cfg)
469
+
470
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
471
+ save_dir = Path(cfg.outputs.figures_dir) / "uncertainty"
472
+ save_dir.mkdir(parents=True, exist_ok=True)
473
+ cache_dir = save_dir / "mc_cache"
474
+ ckpt_dir = Path(cfg.outputs.checkpoint_dir)
475
+
476
+ _, _, test_loader = build_dataloaders(cfg)
477
+
478
+ # ── 1. MC Dropout on proposed model ───────────────────────
479
+ log.info("Loading proposed model...")
480
+ proposed = build_model(cfg).to(device)
481
+ proposed.load_state_dict(
482
+ torch.load(ckpt_dir / "best_full_train.pt",
483
+ map_location="cpu", weights_only=True)["model_state"]
484
+ )
485
+
486
+ mean_preds, std_preds, targets, weights = run_mc_inference(
487
+ proposed, test_loader, device, cfg,
488
+ n_passes=args.n_passes, cache_dir=cache_dir,
489
+ )
490
+ log.info("MC Dropout complete: %d galaxies, %d passes.",
491
+ len(mean_preds), args.n_passes)
492
+
493
+ # ── 2. Uncertainty figures and table ──────────────────────
494
+ fig_uncertainty_distributions(mean_preds, std_preds, targets, weights, save_dir)
495
+ fig_uncertainty_vs_error(mean_preds, std_preds, targets, weights, save_dir)
496
+ table_uncertainty_summary(mean_preds, std_preds, targets, weights, save_dir)
497
+
498
+ # ── 3. Downstream benchmark across all models ─────────────
499
+ log.info("Building model_results for downstream benchmark...")
500
+
501
+ model_results = {
502
+ "ViT-Base + KL+MSE (proposed)": (mean_preds, targets, weights),
503
+ }
504
+
505
+ def _load_resnet(ckpt_name, use_sigmoid):
506
+ m = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
507
+ m.load_state_dict(
508
+ torch.load(ckpt_dir / ckpt_name, map_location="cpu",
509
+ weights_only=True)["model_state"]
510
+ )
511
+ m.eval()
512
+ preds_l, tgts_l, wgts_l = [], [], []
513
+ with torch.no_grad():
514
+ for images, tgts, wgts, _ in tqdm(test_loader, desc=f"ResNet {ckpt_name}"):
515
+ images = images.to(device, non_blocking=True)
516
+ with autocast("cuda", enabled=cfg.training.mixed_precision):
517
+ logits = m(images)
518
+ if use_sigmoid:
519
+ p = torch.sigmoid(logits).cpu().numpy()
520
+ else:
521
+ p = logits.detach().cpu().clone()
522
+ for q, (s, e) in QUESTION_GROUPS.items():
523
+ p[:, s:e] = F.softmax(p[:, s:e], dim=-1)
524
+ p = p.numpy()
525
+ preds_l.append(p)
526
+ tgts_l.append(tgts.numpy())
527
+ wgts_l.append(wgts.numpy())
528
+ return (np.concatenate(preds_l),
529
+ np.concatenate(tgts_l),
530
+ np.concatenate(wgts_l))
531
+
532
+ rn_mse_ckpt = "baseline_resnet18_mse.pt"
533
+ rn_klm_ckpt = "baseline_resnet18_klmse.pt"
534
+
535
+ if (ckpt_dir / rn_mse_ckpt).exists():
536
+ model_results["ResNet-18 + MSE (sigmoid)"] = _load_resnet(
537
+ rn_mse_ckpt, use_sigmoid=True
538
+ )
539
+ if (ckpt_dir / rn_klm_ckpt).exists():
540
+ model_results["ResNet-18 + KL+MSE"] = _load_resnet(
541
+ rn_klm_ckpt, use_sigmoid=False
542
+ )
543
+
544
+ dp = ckpt_dir / "baseline_vit_dirichlet.pt"
545
+ if dp.exists():
546
+ vit_dir = build_dirichlet_model(cfg).to(device)
547
+ vit_dir.load_state_dict(
548
+ torch.load(dp, map_location="cpu", weights_only=True)["model_state"]
549
+ )
550
+ vit_dir.eval()
551
+ d_p, d_t, d_w = [], [], []
552
+ with torch.no_grad():
553
+ for images, tgts, wgts, _ in tqdm(test_loader, desc="Dirichlet"):
554
+ images = images.to(device, non_blocking=True)
555
+ with autocast("cuda", enabled=cfg.training.mixed_precision):
556
+ alpha = vit_dir(images)
557
+ p, t, w = dirichlet_predictions_to_numpy(alpha, tgts, wgts)
558
+ d_p.append(p); d_t.append(t); d_w.append(w)
559
+ model_results["ViT-Base + Dirichlet (Zoobot-style)"] = (
560
+ np.concatenate(d_p),
561
+ np.concatenate(d_t),
562
+ np.concatenate(d_w),
563
+ )
564
+
565
+ df_sel = morphology_selection_benchmark(model_results, save_dir)
566
+
567
+ log.info("=" * 60)
568
+ log.info("DOWNSTREAM F1 @ Ο„ = 0.8")
569
+ log.info("=" * 60)
570
+ summary = df_sel[df_sel["threshold"] == 0.8][
571
+ ["model", "question", "answer", "precision", "recall", "f1"]
572
+ ]
573
+ log.info("\n%s\n", summary.to_string(index=False))
574
+ log.info("All outputs saved to: %s", save_dir)
575
+
576
+
577
+ if __name__ == "__main__":
578
+ main()