Spaces:
Running
Running
| """ | |
| src/model.py | |
| ------------ | |
| Vision Transformer (ViT-Base/16) backbone with three head variants: | |
| 1. GalaxyViT β linear regression head (37 logits). Proposed model. | |
| 2. GalaxyViTDirichlet β Dirichlet concentration head (Zoobot-style baseline). | |
| 3. mc_dropout_predict β MC Dropout uncertainty estimation wrapper. | |
| Architecture | |
| ------------ | |
| Backbone : vit_base_patch16_224 from timm (pretrained ImageNet-21k) | |
| 12 transformer layers, 12 heads, embed_dim=768 | |
| Input : [B, 3, 224, 224] | |
| CLS out: [B, 768] | |
| Head : Dropout(p) β Linear(768, 37) | |
| Full multi-layer attention rollout | |
| ------------------------------------ | |
| All 12 transformer blocks use fused_attn=False so forward hooks can | |
| capture the post-softmax attention matrices. Rollout is computed in | |
| attention_viz.py using the corrected right-multiplication order. | |
| MC Dropout | |
| ----------- | |
| enable_mc_dropout() keeps Dropout active at inference time. | |
| Running N stochastic forward passes gives mean prediction and | |
| per-answer std (epistemic uncertainty). N=30 is standard practice | |
| per Gal & Ghahramani (2016). | |
| Dirichlet head | |
| -------------- | |
| Outputs Ξ± > 1 per answer via: Ξ± = 1 + softplus(linear(features)) | |
| Matches the Zoobot approach for a fair direct comparison. | |
| Mean vote fraction: E[p_q] = Ξ±_q / sum(Ξ±_q). | |
| References | |
| ---------- | |
| Gal & Ghahramani (2016). Dropout as a Bayesian Approximation. | |
| ICML 2016. https://arxiv.org/abs/1506.02142 | |
| Walmsley et al. (2022). Towards Galaxy Foundation Models. | |
| MNRAS 509, 3966. https://arxiv.org/abs/2110.12735 | |
| """ | |
from __future__ import annotations

from typing import List, Optional, Tuple

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig

from src.dataset import QUESTION_GROUPS
# ─────────────────────────────────────────────────────────────
# Attention hook manager
# ─────────────────────────────────────────────────────────────
class AttentionHookManager:
    """
    Captures post-softmax attention matrices from every transformer
    block via forward hooks, enabling full attention rollout.

    With fused_attn=False, timm's attention block executes:
        attn = softmax(q @ k.T / scale)   # [B, H, N+1, N+1]
        attn = attn_drop(attn)            # hook fires on INPUT = post-softmax
        out  = attn @ v
    """

    def __init__(self, blocks):
        self.blocks = blocks
        self._attn_list: List[torch.Tensor] = []
        self._handles = []
        self._register_hooks()

    def _register_hooks(self):
        for blk in self.blocks:
            # Force the unfused path so the softmaxed matrix is materialized.
            blk.attn.fused_attn = False

            def make_hook():
                def hook(module, args, output):
                    # args[0] is the post-softmax attention tensor.
                    self._attn_list.append(args[0].detach())
                return hook

            handle = blk.attn.attn_drop.register_forward_hook(make_hook())
            self._handles.append(handle)

    def clear(self):
        """Drop attentions captured during the previous forward pass."""
        self._attn_list.clear()

    def get_all_attentions(self) -> Optional[List[torch.Tensor]]:
        """Return the L captured tensors, each [B, H, N+1, N+1], or None."""
        return list(self._attn_list) if self._attn_list else None

    def get_last_attention(self) -> Optional[torch.Tensor]:
        """Return the final layer's attention tensor, or None if empty."""
        return self._attn_list[-1] if self._attn_list else None

    def remove_all(self):
        """Detach every registered hook handle."""
        while self._handles:
            self._handles.pop().remove()
# ─────────────────────────────────────────────────────────────
# GalaxyViT — proposed model
# ─────────────────────────────────────────────────────────────
class GalaxyViT(nn.Module):
    """
    ViT-Base/16 backbone + linear regression head for GZ2.

    Outputs 37 raw logits; softmax is applied per question group
    during loss computation and metric evaluation.
    Full 12-layer attention hooks are registered at construction.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        # num_classes=0 strips timm's classifier head; the backbone
        # then returns pooled features directly.
        self.backbone = timm.create_model(
            cfg.model.backbone,
            pretrained=cfg.model.pretrained,
            num_classes=0,
        )
        embed_dim = self.backbone.embed_dim  # 768 for ViT-Base
        self.head = nn.Sequential(
            nn.Dropout(p=cfg.model.dropout),
            nn.Linear(embed_dim, 37),
        )
        self._hook_mgr = AttentionHookManager(self.backbone.blocks)
        self._mc_dropout = False

    def enable_mc_dropout(self):
        """Keep Dropout active at inference time for MC sampling."""
        self._mc_dropout = True
        for m in self.modules():
            if isinstance(m, nn.Dropout):
                m.train()

    def disable_mc_dropout(self):
        """Turn MC sampling off and restore Dropout modules.

        Bug fix: the original only cleared the flag, leaving every
        nn.Dropout in train mode after sampling on an eval() model —
        subsequent "deterministic" predictions stayed stochastic until
        model.eval() was called again. Sync each Dropout back to the
        model's current training mode.
        """
        self._mc_dropout = False
        for m in self.modules():
            if isinstance(m, nn.Dropout):
                m.train(self.training)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [B, 3, 224, 224] -> logits [B, 37] (raw, no softmax)."""
        self._hook_mgr.clear()  # drop attentions from the previous pass
        features = self.backbone(x)   # [B, 768]
        logits = self.head(features)  # [B, 37]
        return logits

    def get_attention_weights(self) -> Optional[torch.Tensor]:
        """Last layer's post-softmax attention from the latest forward."""
        return self._hook_mgr.get_last_attention()

    def get_all_attention_weights(self) -> Optional[List[torch.Tensor]]:
        """All 12 layers' attentions (for full rollout)."""
        return self._hook_mgr.get_all_attentions()

    def remove_hooks(self):
        """Detach attention hooks (e.g. before export or deletion)."""
        self._hook_mgr.remove_all()
# ─────────────────────────────────────────────────────────────
# GalaxyViTDirichlet — Zoobot-style comparison baseline
# ─────────────────────────────────────────────────────────────
class GalaxyViTDirichlet(nn.Module):
    """
    ViT-Base/16 + Dirichlet concentration head (Zoobot-style baseline).

    Outputs α > 1 per answer via α = 1 + softplus(linear(features)).
    Enforcing α > 1 ensures unimodal Dirichlet distributions.
    Mean vote fraction: E[p_q] = α_q / sum(α_q) (same as softmax mean).
    Total concentration sum(α_q) encodes prediction confidence.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.backbone = timm.create_model(
            cfg.model.backbone,
            pretrained=cfg.model.pretrained,
            num_classes=0,
        )
        embed_dim = self.backbone.embed_dim
        self.head = nn.Sequential(
            nn.Dropout(p=cfg.model.dropout),
            nn.Linear(embed_dim, 37),
        )
        self._hook_mgr = AttentionHookManager(self.backbone.blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Returns α: [B, 37] Dirichlet concentration parameters > 1."""
        self._hook_mgr.clear()
        features = self.backbone(x)
        logits = self.head(features)
        # softplus keeps the map smooth/differentiable; the +1 offset
        # guarantees α > 1 (unimodal Dirichlet).
        alpha = 1.0 + F.softplus(logits)
        return alpha

    def get_mean_prediction(self, alpha: torch.Tensor) -> torch.Tensor:
        """Per-question Dirichlet mean: E[p_q] = α_q / Σ α_q."""
        means = torch.zeros_like(alpha)
        for start, end in QUESTION_GROUPS.values():
            a_q = alpha[:, start:end]
            means[:, start:end] = a_q / a_q.sum(dim=-1, keepdim=True)
        return means

    def get_attention_weights(self):
        """Last layer's post-softmax attention from the latest forward."""
        return self._hook_mgr.get_last_attention()

    def get_all_attention_weights(self):
        """All layers' attentions (for full rollout)."""
        return self._hook_mgr.get_all_attentions()

    def remove_hooks(self):
        """Detach attention hooks.

        Fix: added for parity with GalaxyViT.remove_hooks() — the
        original registered hooks but offered no way to remove the
        handles, so they could never be detached.
        """
        self._hook_mgr.remove_all()
# ─────────────────────────────────────────────────────────────
# MC Dropout inference
# ─────────────────────────────────────────────────────────────
def mc_dropout_predict(
    model: GalaxyViT,
    images: torch.Tensor,
    n_passes: int = 30,
    device: Optional[torch.device] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    MC Dropout epistemic uncertainty estimation.

    Runs n_passes stochastic forward passes with dropout active,
    returning mean prediction and per-answer std.

    Parameters
    ----------
    model : GalaxyViT instance
    images : [B, 3, H, W]
    n_passes : number of MC samples (30 is standard per Gal & Ghahramani 2016)
    device : inference device; defaults to the model's own device

    Returns
    -------
    mean_pred : [B, 37] mean softmax predictions
    std_pred : [B, 37] std across passes (epistemic uncertainty)
    per_q_uncertainty : [B, n_questions] mean std per question
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    model.enable_mc_dropout()
    images = images.to(device)
    all_preds = []
    # Bug fix: wrap sampling in no_grad() — the original built and kept
    # autograd graphs for all n_passes forward passes, wasting memory
    # at inference time.
    with torch.no_grad():
        for _ in range(n_passes):
            logits = model(images)  # [B, 37]
            preds = torch.zeros_like(logits)
            # Softmax independently within each question's answer group.
            for start, end in QUESTION_GROUPS.values():
                preds[:, start:end] = F.softmax(logits[:, start:end], dim=-1)
            all_preds.append(preds.cpu().numpy())
    model.disable_mc_dropout()
    stacked = np.stack(all_preds, axis=0)  # [n_passes, B, 37]
    mean_pred = stacked.mean(axis=0)       # [B, 37]
    std_pred = stacked.std(axis=0)         # [B, 37]
    per_q_unc = np.zeros(
        (mean_pred.shape[0], len(QUESTION_GROUPS)), dtype=np.float32
    )
    for q_idx, (start, end) in enumerate(QUESTION_GROUPS.values()):
        per_q_unc[:, q_idx] = std_pred[:, start:end].mean(axis=1)
    return (
        mean_pred.astype(np.float32),
        std_pred.astype(np.float32),
        per_q_unc,
    )
# ─────────────────────────────────────────────────────────────
# Factory functions
# ─────────────────────────────────────────────────────────────
def build_model(cfg: DictConfig) -> GalaxyViT:
    """Construct the proposed regression model and print its summary."""
    net = GalaxyViT(cfg)
    _print_summary(net, cfg, "GalaxyViT (regression β proposed)")
    return net
def build_dirichlet_model(cfg: DictConfig) -> GalaxyViTDirichlet:
    """Construct the Dirichlet-head baseline and print its summary."""
    net = GalaxyViTDirichlet(cfg)
    _print_summary(net, cfg, "GalaxyViTDirichlet (Zoobot-style baseline)")
    return net
| def _print_summary(model, cfg, name: str): | |
| total = sum(p.numel() for p in model.parameters()) | |
| trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) | |
| n_hooks = len(model.backbone.blocks) | |
| print(f"\n{'='*55}") | |
| print(f"Model : {name}") | |
| print(f"Backbone : {cfg.model.backbone}") | |
| print(f"Pretrained : {cfg.model.pretrained}") | |
| print(f"Dropout : {cfg.model.dropout}") | |
| print(f"Parameters : {total:,} ({trainable:,} trainable)") | |
| print(f"Attn hooks : {n_hooks} layers (full rollout enabled)") | |
| print(f"{'='*55}\n") | |