# eshwar-gz2-api / src/loss.py
# NOTE(review): the Hugging Face Hub web-page header that was pasted here
# ("...'s picture / Upload folder using huggingface_hub / e36eee4 verified")
# was not valid Python and broke the module; it is preserved as a comment.
"""
src/loss.py
-----------
Loss functions for hierarchical probabilistic vote-fraction regression.
Three losses are implemented:
1. HierarchicalLoss — proposed method: weighted KL + MSE per question.
2. DirichletLoss — Zoobot-style comparison: weighted Dirichlet NLL.
3. MSEOnlyLoss — ablation baseline: hierarchical MSE, no KL term.
Both main losses use identical per-sample hierarchical weighting:
w_q = parent branch vote fraction (1.0 for root question t01)
Mathematical formulation
------------------------
HierarchicalLoss per question q:
L_q = w_q * [ λ_kl * KL(p_q || ŷ_q) + λ_mse * MSE(ŷ_q, p_q) ]
where p_q = ground-truth vote fractions [B, A_q]
ŷ_q = softmax(logits_q) [B, A_q]
w_q = hierarchical weight [B]
DirichletLoss per question q:
L_q = w_q * [ log B(α_q) − Σ_a (α_qa − 1) log(p_qa) ]
where α_q = 1 + softplus(logits_q) > 1 [B, A_q]
References
----------
Walmsley et al. (2022), MNRAS 509, 3966 (Zoobot — Dirichlet approach)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig
from src.dataset import QUESTION_GROUPS
class HierarchicalLoss(nn.Module):
    """Weighted hierarchical KL + MSE loss (the proposed method).

    For each question q the per-sample loss is
        L_q = w_q * (lambda_kl * KL(p_q || softmax(logits_q)) + lambda_mse * MSE),
    where w_q is the hierarchical weight of question q.  Per-question losses
    are averaged over the batch and summed over questions.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        # Mixing coefficients for the KL and MSE terms.
        self.lambda_kl = float(cfg.loss.lambda_kl)
        self.lambda_mse = float(cfg.loss.lambda_mse)
        # Numerical floor applied before taking logarithms.
        self.epsilon = float(cfg.loss.epsilon)
        # Flatten QUESTION_GROUPS into (question name, start col, end col) triples.
        self.question_slices = [(name, lo, hi) for name, (lo, hi) in QUESTION_GROUPS.items()]

    def forward(self, predictions: torch.Tensor,
                targets: torch.Tensor, weights: torch.Tensor):
        """Compute the weighted hierarchical KL + MSE loss.

        Args:
            predictions: raw logits, shape [B, total_answers].
            targets: ground-truth vote fractions, shape [B, total_answers].
            weights: hierarchical question weights, shape [B, num_questions].

        Returns:
            (total_loss, dict of detached per-question and total loss values).
        """
        loss_dict = {}
        total_loss = torch.zeros(1, device=predictions.device, dtype=predictions.dtype)
        for col, (name, lo, hi) in enumerate(self.question_slices):
            probs = F.softmax(predictions[:, lo:hi], dim=-1)
            fracs = targets[:, lo:hi]
            # Clamp both distributions away from zero so the logs stay finite.
            safe_probs = probs.clamp(self.epsilon, 1.0)
            safe_fracs = fracs.clamp(self.epsilon, 1.0)
            kl = (safe_fracs * (safe_fracs.log() - safe_probs.log())).sum(dim=-1)
            # Elementwise squared error on the unclamped distributions.
            mse = ((probs - fracs) ** 2).mean(dim=-1)
            blended = self.lambda_kl * kl + self.lambda_mse * mse
            per_q = (weights[:, col] * blended).mean()
            total_loss = total_loss + per_q
            loss_dict[f"loss/{name}"] = per_q.detach().item()
        loss_dict["loss/total"] = total_loss.detach().item()
        return total_loss, loss_dict
class DirichletLoss(nn.Module):
    """
    Weighted hierarchical Dirichlet negative log-likelihood.

    Trains GalaxyViTDirichlet for comparison with the proposed method,
    following the Zoobot approach (Walmsley et al. 2022).  Per question:
        NLL_q = log B(alpha_q) - sum_a (alpha_qa - 1) * log(p_qa)
    weighted by w_q, averaged over the batch and summed over questions.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        # Floor for target fractions before taking logarithms.
        self.epsilon = float(cfg.loss.epsilon)
        # Flatten QUESTION_GROUPS into (question name, start col, end col) triples.
        self.question_slices = [(name, lo, hi) for name, (lo, hi) in QUESTION_GROUPS.items()]

    def forward(self, alpha: torch.Tensor,
                targets: torch.Tensor, weights: torch.Tensor):
        """Compute the weighted Dirichlet NLL.

        Args:
            alpha: Dirichlet concentrations (expected > 1), shape [B, total_answers].
            targets: ground-truth vote fractions, shape [B, total_answers].
            weights: hierarchical question weights, shape [B, num_questions].

        Returns:
            (total_loss, dict of detached per-question and total loss values).
        """
        loss_dict = {}
        total_loss = torch.zeros(1, device=alpha.device, dtype=alpha.dtype)
        for col, (name, lo, hi) in enumerate(self.question_slices):
            conc = alpha[:, lo:hi]
            fracs = targets[:, lo:hi].clamp(min=self.epsilon)
            # log B(alpha) = sum_a lgamma(alpha_a) - lgamma(sum_a alpha_a)
            log_norm = torch.lgamma(conc).sum(dim=-1) - torch.lgamma(conc.sum(dim=-1))
            # sum_a (alpha_a - 1) * log(p_a)
            log_lik = ((conc - 1.0) * fracs.log()).sum(dim=-1)
            per_q = (weights[:, col] * (log_norm - log_lik)).mean()
            total_loss = total_loss + per_q
            loss_dict[f"loss/{name}"] = per_q.detach().item()
        loss_dict["loss/total"] = total_loss.detach().item()
        return total_loss, loss_dict
class MSEOnlyLoss(nn.Module):
    """
    Hierarchical MSE loss without the KL term (ablation baseline).
    Equivalent to HierarchicalLoss with lambda_kl=0.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        # Stored for interface parity with the other losses; not used below.
        self.epsilon = float(cfg.loss.epsilon)
        # Flatten QUESTION_GROUPS into (question name, start col, end col) triples.
        self.question_slices = [(name, lo, hi) for name, (lo, hi) in QUESTION_GROUPS.items()]

    def forward(self, predictions: torch.Tensor,
                targets: torch.Tensor, weights: torch.Tensor):
        """Compute the weighted hierarchical MSE loss.

        Args:
            predictions: raw logits, shape [B, total_answers].
            targets: ground-truth vote fractions, shape [B, total_answers].
            weights: hierarchical question weights, shape [B, num_questions].

        Returns:
            (total_loss, dict of detached per-question and total loss values).
        """
        loss_dict = {}
        total_loss = torch.zeros(1, device=predictions.device, dtype=predictions.dtype)
        for col, (name, lo, hi) in enumerate(self.question_slices):
            probs = F.softmax(predictions[:, lo:hi], dim=-1)
            sq_err = ((probs - targets[:, lo:hi]) ** 2).mean(dim=-1)
            per_q = (weights[:, col] * sq_err).mean()
            total_loss = total_loss + per_q
            loss_dict[f"loss/{name}"] = per_q.detach().item()
        loss_dict["loss/total"] = total_loss.detach().item()
        return total_loss, loss_dict