"""
src/loss.py
-----------
Loss functions for hierarchical probabilistic vote-fraction regression.
Three losses are implemented:
1. HierarchicalLoss — proposed method: weighted KL + MSE per question.
2. DirichletLoss — Zoobot-style comparison: weighted Dirichlet NLL.
3. MSEOnlyLoss — ablation baseline: hierarchical MSE, no KL term.
All three losses use identical per-sample hierarchical weighting:
w_q = parent branch vote fraction (1.0 for root question t01)
Mathematical formulation
------------------------
HierarchicalLoss per question q:
L_q = w_q * [ λ_kl * KL(p_q || ŷ_q) + λ_mse * MSE(ŷ_q, p_q) ]
where p_q = ground-truth vote fractions [B, A_q]
ŷ_q = softmax(logits_q) [B, A_q]
w_q = hierarchical weight [B]
DirichletLoss per question q:
L_q = w_q * [ log B(α_q) − Σ_a (α_qa − 1) log(p_qa) ]
where α_q = 1 + softplus(logits_q) > 1 [B, A_q]
References
----------
Walmsley et al. (2022), MNRAS 509, 3966 (Zoobot — Dirichlet approach)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig
from src.dataset import QUESTION_GROUPS
class HierarchicalLoss(nn.Module):
    """Weighted hierarchical KL + MSE loss. Proposed method.

    For each question q the per-sample loss is
        w_q * (lambda_kl * KL(p_q || softmax(logits_q)) + lambda_mse * MSE),
    averaged over the batch and summed over questions, where w_q is the
    hierarchical weight supplied by the caller (parent branch vote fraction).
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.lambda_kl = float(cfg.loss.lambda_kl)
        self.lambda_mse = float(cfg.loss.lambda_mse)
        self.epsilon = float(cfg.loss.epsilon)
        # [(question_name, start_col, end_col), ...] over the flat answer axis.
        self.question_slices = [
            (name, lo, hi) for name, (lo, hi) in QUESTION_GROUPS.items()
        ]

    def forward(self, predictions: torch.Tensor,
                targets: torch.Tensor, weights: torch.Tensor):
        """Return (total_loss, per-question metric dict).

        predictions: [B, A] raw logits over all answers (all questions).
        targets:     [B, A] ground-truth vote fractions.
        weights:     [B, Q] hierarchical weight per question.
        """
        running = torch.zeros(1, device=predictions.device,
                              dtype=predictions.dtype)
        metrics = {}
        for idx, (name, lo, hi) in enumerate(self.question_slices):
            probs = F.softmax(predictions[:, lo:hi], dim=-1)
            truth = targets[:, lo:hi]
            w = weights[:, idx]
            # Clamp both distributions away from zero before taking logs
            # so KL stays finite when a vote fraction is exactly 0.
            probs_safe = probs.clamp(self.epsilon, 1.0)
            truth_safe = truth.clamp(self.epsilon, 1.0)
            kl = (truth_safe * (truth_safe.log() - probs_safe.log())).sum(dim=-1)
            mse = F.mse_loss(probs, truth, reduction="none").mean(dim=-1)
            per_sample = self.lambda_kl * kl + self.lambda_mse * mse
            q_term = (w * per_sample).mean()
            running = running + q_term
            metrics[f"loss/{name}"] = q_term.detach().item()
        metrics["loss/total"] = running.detach().item()
        return running, metrics
class DirichletLoss(nn.Module):
    """
    Weighted hierarchical Dirichlet negative log-likelihood.

    Used to train GalaxyViTDirichlet for comparison with the proposed method.
    Matches the Zoobot approach (Walmsley et al. 2022):
        NLL_q = log B(alpha_q) - sum_a (alpha_qa - 1) * log(p_qa)
    where alpha are the concentration parameters produced by the model.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.epsilon = float(cfg.loss.epsilon)
        # [(question_name, start_col, end_col), ...] over the flat answer axis.
        self.question_slices = [
            (name, lo, hi) for name, (lo, hi) in QUESTION_GROUPS.items()
        ]

    def forward(self, alpha: torch.Tensor,
                targets: torch.Tensor, weights: torch.Tensor):
        """Return (total_loss, per-question metric dict).

        alpha:   [B, A] Dirichlet concentrations (model output, > 1).
        targets: [B, A] ground-truth vote fractions.
        weights: [B, Q] hierarchical weight per question.
        """
        running = torch.zeros(1, device=alpha.device, dtype=alpha.dtype)
        metrics = {}
        for idx, (name, lo, hi) in enumerate(self.question_slices):
            a = alpha[:, lo:hi]
            # Clamp fractions away from zero so log(p) stays finite.
            p = targets[:, lo:hi].clamp(min=self.epsilon)
            w = weights[:, idx]
            # log B(a) = sum_i lgamma(a_i) - lgamma(sum_i a_i)
            log_norm = torch.lgamma(a).sum(dim=-1) - torch.lgamma(a.sum(dim=-1))
            log_lik = ((a - 1.0) * p.log()).sum(dim=-1)
            nll = log_norm - log_lik
            q_term = (w * nll).mean()
            running = running + q_term
            metrics[f"loss/{name}"] = q_term.detach().item()
        metrics["loss/total"] = running.detach().item()
        return running, metrics
class MSEOnlyLoss(nn.Module):
    """
    Hierarchical MSE loss without KL term. Used as ablation baseline.
    Equivalent to HierarchicalLoss with lambda_kl=0.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        # Read for config parity with the other losses; not used in forward.
        self.epsilon = float(cfg.loss.epsilon)
        # [(question_name, start_col, end_col), ...] over the flat answer axis.
        self.question_slices = [
            (name, lo, hi) for name, (lo, hi) in QUESTION_GROUPS.items()
        ]

    def forward(self, predictions: torch.Tensor,
                targets: torch.Tensor, weights: torch.Tensor):
        """Return (total_loss, per-question metric dict).

        predictions: [B, A] raw logits over all answers.
        targets:     [B, A] ground-truth vote fractions.
        weights:     [B, Q] hierarchical weight per question.
        """
        running = torch.zeros(1, device=predictions.device,
                              dtype=predictions.dtype)
        metrics = {}
        for idx, (name, lo, hi) in enumerate(self.question_slices):
            probs = F.softmax(predictions[:, lo:hi], dim=-1)
            mse = F.mse_loss(
                probs, targets[:, lo:hi], reduction="none"
            ).mean(dim=-1)
            q_term = (weights[:, idx] * mse).mean()
            running = running + q_term
            metrics[f"loss/{name}"] = q_term.detach().item()
        metrics["loss/total"] = running.detach().item()
        return running, metrics
|