# eshwar-gz2-api / src/model.py
# Uploaded via huggingface_hub (revision e36eee4, verified).
"""
src/model.py
------------
Vision Transformer (ViT-Base/16) backbone with three head variants:
1. GalaxyViT — linear regression head (37 logits). Proposed model.
2. GalaxyViTDirichlet — Dirichlet concentration head (Zoobot-style baseline).
3. mc_dropout_predict — MC Dropout uncertainty estimation wrapper.
Architecture
------------
Backbone : vit_base_patch16_224 from timm (pretrained ImageNet-21k)
12 transformer layers, 12 heads, embed_dim=768
Input : [B, 3, 224, 224]
CLS out: [B, 768]
Head : Dropout(p) β†’ Linear(768, 37)
Full multi-layer attention rollout
------------------------------------
All 12 transformer blocks use fused_attn=False so forward hooks can
capture the post-softmax attention matrices. Rollout is computed in
attention_viz.py using the corrected right-multiplication order.
MC Dropout
-----------
enable_mc_dropout() keeps Dropout active at inference time.
Running N stochastic forward passes gives mean prediction and
per-answer std (epistemic uncertainty). N=30 is standard practice
per Gal & Ghahramani (2016).
Dirichlet head
--------------
Outputs α > 1 per answer via: α = 1 + softplus(linear(features))
Matches the Zoobot approach for a fair direct comparison.
Mean vote fraction: E[p_q] = α_q / sum(α_q).
References
----------
Gal & Ghahramani (2016). Dropout as a Bayesian Approximation.
ICML 2016. https://arxiv.org/abs/1506.02142
Walmsley et al. (2022). Towards Galaxy Foundation Models.
MNRAS 509, 3966. https://arxiv.org/abs/2110.12735
"""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import numpy as np
from omegaconf import DictConfig
from typing import Optional, List, Tuple
from src.dataset import QUESTION_GROUPS
# ─────────────────────────────────────────────────────────────
# Attention hook manager
# ─────────────────────────────────────────────────────────────
class AttentionHookManager:
    """
    Capture post-softmax attention maps from every transformer block.

    Setting ``fused_attn = False`` forces timm's attention module down the
    unfused path:
        attn = softmax(q @ k.T / scale)   # [B, H, N+1, N+1]
        attn = attn_drop(attn)            # forward hook input = post-softmax
        out  = attn @ v
    so a forward hook on ``attn_drop`` sees the post-softmax matrix as its
    input, enabling full attention rollout downstream.
    """

    def __init__(self, blocks):
        self.blocks = blocks
        self._attn_list: List[torch.Tensor] = []
        self._handles = []
        self._register_hooks()

    def _register_hooks(self):
        # One shared closure is enough: it captures only `self`, never
        # per-block state, so every block can reuse it.
        def capture(module, inputs, output):
            # inputs[0] is the post-softmax attention tensor.
            self._attn_list.append(inputs[0].detach())

        for blk in self.blocks:
            blk.attn.fused_attn = False  # force the hookable unfused path
            handle = blk.attn.attn_drop.register_forward_hook(capture)
            self._handles.append(handle)

    def clear(self):
        """Drop all captured attention maps (in place)."""
        del self._attn_list[:]

    def get_all_attentions(self) -> Optional[List[torch.Tensor]]:
        """Returns list of L tensors, each [B, H, N+1, N+1], or None if empty."""
        return list(self._attn_list) or None

    def get_last_attention(self) -> Optional[torch.Tensor]:
        """Most recent block's attention map, or None if nothing captured."""
        return self._attn_list[-1] if self._attn_list else None

    def remove_all(self):
        """Detach every registered hook handle."""
        while self._handles:
            self._handles.pop().remove()
# ─────────────────────────────────────────────────────────────
# GalaxyViT β€” proposed model
# ─────────────────────────────────────────────────────────────
class GalaxyViT(nn.Module):
    """
    ViT-Base/16 backbone + linear regression head for GZ2.

    Outputs 37 raw logits; softmax is applied per question group
    during loss computation and metric evaluation.
    Full 12-layer attention hooks are registered at construction.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        # num_classes=0 removes timm's classifier so the backbone
        # returns pooled features instead of class logits.
        self.backbone = timm.create_model(
            cfg.model.backbone,
            pretrained=cfg.model.pretrained,
            num_classes=0,
        )
        embed_dim = self.backbone.embed_dim  # 768 for ViT-Base
        self.head = nn.Sequential(
            nn.Dropout(p=cfg.model.dropout),
            nn.Linear(embed_dim, 37),
        )
        self._hook_mgr = AttentionHookManager(self.backbone.blocks)
        self._mc_dropout = False

    def enable_mc_dropout(self):
        """Keep Dropout active at inference time for MC sampling."""
        self._mc_dropout = True
        for m in self.modules():
            if isinstance(m, nn.Dropout):
                m.train()

    def disable_mc_dropout(self):
        """Turn off MC sampling and restore Dropout to the model's mode.

        BUGFIX: this previously only flipped the flag, leaving every
        nn.Dropout in train mode — "deterministic" inference after MC
        sampling was still stochastic until model.eval() was called again.
        """
        self._mc_dropout = False
        for m in self.modules():
            if isinstance(m, nn.Dropout):
                # Re-align with the model's overall train/eval state.
                m.train(self.training)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [B, 3, 224, 224] -> raw logits [B, 37]."""
        self._hook_mgr.clear()           # drop maps from the previous pass
        features = self.backbone(x)      # [B, 768]
        logits = self.head(features)     # [B, 37]
        return logits

    def get_attention_weights(self) -> Optional[torch.Tensor]:
        """Last block's post-softmax attention from the latest forward."""
        return self._hook_mgr.get_last_attention()

    def get_all_attention_weights(self) -> Optional[List[torch.Tensor]]:
        """Attention maps from every block (for full rollout)."""
        return self._hook_mgr.get_all_attentions()

    def remove_hooks(self):
        """Detach all attention hooks (e.g. before export)."""
        self._hook_mgr.remove_all()
# ─────────────────────────────────────────────────────────────
# GalaxyViTDirichlet β€” Zoobot-style comparison baseline
# ─────────────────────────────────────────────────────────────
class GalaxyViTDirichlet(nn.Module):
    """
    ViT-Base/16 + Dirichlet concentration head (Zoobot-style baseline).

    Outputs α > 1 per answer via α = 1 + softplus(linear(features)).
    Enforcing α > 1 ensures unimodal Dirichlet distributions.
    Mean vote fraction: E[p_q] = α_q / sum(α_q) (same as softmax mean).
    Total concentration sum(α_q) encodes prediction confidence.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        # Headless backbone (num_classes=0) -> pooled feature vector.
        self.backbone = timm.create_model(
            cfg.model.backbone,
            pretrained=cfg.model.pretrained,
            num_classes=0,
        )
        feat_dim = self.backbone.embed_dim
        self.head = nn.Sequential(
            nn.Dropout(p=cfg.model.dropout),
            nn.Linear(feat_dim, 37),
        )
        self._hook_mgr = AttentionHookManager(self.backbone.blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Returns α: [B, 37] Dirichlet concentration parameters > 1."""
        self._hook_mgr.clear()            # reset attention capture
        feats = self.backbone(x)
        raw = self.head(feats)
        # softplus keeps (α - 1) strictly positive, hence α > 1.
        return 1.0 + F.softplus(raw)

    def get_mean_prediction(self, alpha: torch.Tensor) -> torch.Tensor:
        """Dirichlet mean per answer: α normalised within each question group."""
        means = torch.zeros_like(alpha)
        for start, end in QUESTION_GROUPS.values():
            group = alpha[:, start:end]
            means[:, start:end] = group / group.sum(dim=-1, keepdim=True)
        return means

    def get_attention_weights(self):
        """Last block's post-softmax attention from the latest forward."""
        return self._hook_mgr.get_last_attention()

    def get_all_attention_weights(self):
        """Attention maps from every block (for full rollout)."""
        return self._hook_mgr.get_all_attentions()
# ─────────────────────────────────────────────────────────────
# MC Dropout inference
# ─────────────────────────────────────────────────────────────
@torch.no_grad()
def mc_dropout_predict(
    model: GalaxyViT,
    images: torch.Tensor,
    n_passes: int = 30,
    device: Optional[torch.device] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    MC Dropout epistemic uncertainty estimation.

    Runs n_passes stochastic forward passes with dropout active,
    returning mean prediction and per-answer std.

    Parameters
    ----------
    model : GalaxyViT instance
    images : [B, 3, H, W]
    n_passes : number of MC samples (30 is standard per Gal & Ghahramani 2016)
    device : inference device; defaults to the model's own device

    Returns
    -------
    mean_pred : [B, 37] mean softmax predictions
    std_pred : [B, 37] std across passes (epistemic uncertainty)
    per_q_uncertainty : [B, len(QUESTION_GROUPS)] mean std per question
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    model.enable_mc_dropout()  # keep Dropout stochastic during inference
    images = images.to(device)

    all_preds = []
    for _ in range(n_passes):
        logits = model(images)                      # [B, 37]
        preds = torch.zeros_like(logits)
        # Normalise each question group independently (per-question softmax).
        for start, end in QUESTION_GROUPS.values():
            preds[:, start:end] = F.softmax(logits[:, start:end], dim=-1)
        all_preds.append(preds.cpu().numpy())

    model.disable_mc_dropout()
    # BUGFIX (defensive): disable_mc_dropout may only flip an internal flag,
    # leaving Dropout modules in train mode; re-enter eval mode so any
    # subsequent inference with this model is deterministic.
    model.eval()

    stacked = np.stack(all_preds, axis=0)           # [n_passes, B, 37]
    mean_pred = stacked.mean(axis=0)                # [B, 37]
    std_pred = stacked.std(axis=0)                  # [B, 37]

    # Average the per-answer stds within each question group.
    per_q_unc = np.zeros(
        (mean_pred.shape[0], len(QUESTION_GROUPS)), dtype=np.float32
    )
    for q_idx, (start, end) in enumerate(QUESTION_GROUPS.values()):
        per_q_unc[:, q_idx] = std_pred[:, start:end].mean(axis=1)

    return (
        mean_pred.astype(np.float32),
        std_pred.astype(np.float32),
        per_q_unc,
    )
# ─────────────────────────────────────────────────────────────
# Factory functions
# ─────────────────────────────────────────────────────────────
def build_model(cfg: DictConfig) -> GalaxyViT:
    """Construct the proposed regression model and print its summary."""
    net = GalaxyViT(cfg)
    _print_summary(net, cfg, "GalaxyViT (regression — proposed)")
    return net
def build_dirichlet_model(cfg: DictConfig) -> GalaxyViTDirichlet:
    """Construct the Zoobot-style Dirichlet baseline and print its summary."""
    net = GalaxyViTDirichlet(cfg)
    _print_summary(net, cfg, "GalaxyViTDirichlet (Zoobot-style baseline)")
    return net
def _print_summary(model, cfg, name: str):
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_hooks = len(model.backbone.blocks)
print(f"\n{'='*55}")
print(f"Model : {name}")
print(f"Backbone : {cfg.model.backbone}")
print(f"Pretrained : {cfg.model.pretrained}")
print(f"Dropout : {cfg.model.dropout}")
print(f"Parameters : {total:,} ({trainable:,} trainable)")
print(f"Attn hooks : {n_hooks} layers (full rollout enabled)")
print(f"{'='*55}\n")