""" src/loss.py ----------- Loss functions for hierarchical probabilistic vote-fraction regression. Two losses are implemented: 1. HierarchicalLoss — proposed method: weighted KL + MSE per question. 2. DirichletLoss — Zoobot-style comparison: weighted Dirichlet NLL. 3. MSEOnlyLoss — ablation baseline: hierarchical MSE, no KL term. Both main losses use identical per-sample hierarchical weighting: w_q = parent branch vote fraction (1.0 for root question t01) Mathematical formulation ------------------------ HierarchicalLoss per question q: L_q = w_q * [ λ_kl * KL(p_q || ŷ_q) + λ_mse * MSE(ŷ_q, p_q) ] where p_q = ground-truth vote fractions [B, A_q] ŷ_q = softmax(logits_q) [B, A_q] w_q = hierarchical weight [B] DirichletLoss per question q: L_q = w_q * [ log B(α_q) − Σ_a (α_qa − 1) log(p_qa) ] where α_q = 1 + softplus(logits_q) > 1 [B, A_q] References ---------- Walmsley et al. (2022), MNRAS 509, 3966 (Zoobot — Dirichlet approach) """ import torch import torch.nn as nn import torch.nn.functional as F from omegaconf import DictConfig from src.dataset import QUESTION_GROUPS class HierarchicalLoss(nn.Module): """Weighted hierarchical KL + MSE loss. Proposed method.""" def __init__(self, cfg: DictConfig): super().__init__() self.lambda_kl = float(cfg.loss.lambda_kl) self.lambda_mse = float(cfg.loss.lambda_mse) self.epsilon = float(cfg.loss.epsilon) self.question_slices = [(q, s, e) for q, (s, e) in QUESTION_GROUPS.items()] def forward(self, predictions: torch.Tensor, targets: torch.Tensor, weights: torch.Tensor): total_loss = torch.zeros(1, device=predictions.device, dtype=predictions.dtype) loss_dict = {} for q_idx, (q_name, start, end) in enumerate(self.question_slices): logits_q = predictions[:, start:end] target_q = targets[:, start:end] weight_q = weights[:, q_idx] pred_q = F.softmax(logits_q, dim=-1) pred_q_c = pred_q.clamp(min=self.epsilon, max=1.0) target_q_c = target_q.clamp(min=self.epsilon, max=1.0) kl_per_sample = ( target_q_c * (target_q_c.log() - pred_q_c.log()) ).sum(dim=-1) mse_per_sample = F.mse_loss( pred_q, target_q, reduction="none" ).mean(dim=-1) combined = (self.lambda_kl * kl_per_sample + self.lambda_mse * mse_per_sample) q_loss = (weight_q * combined).mean() total_loss = total_loss + q_loss loss_dict[f"loss/{q_name}"] = q_loss.detach().item() loss_dict["loss/total"] = total_loss.detach().item() return total_loss, loss_dict class DirichletLoss(nn.Module): """ Weighted hierarchical Dirichlet negative log-likelihood. Used to train GalaxyViTDirichlet for comparison with the proposed method. Matches the Zoobot approach (Walmsley et al. 2022). """ def __init__(self, cfg: DictConfig): super().__init__() self.epsilon = float(cfg.loss.epsilon) self.question_slices = [(q, s, e) for q, (s, e) in QUESTION_GROUPS.items()] def forward(self, alpha: torch.Tensor, targets: torch.Tensor, weights: torch.Tensor): total_loss = torch.zeros(1, device=alpha.device, dtype=alpha.dtype) loss_dict = {} for q_idx, (q_name, start, end) in enumerate(self.question_slices): alpha_q = alpha[:, start:end] target_q = targets[:, start:end] weight_q = weights[:, q_idx] target_q_c = target_q.clamp(min=self.epsilon) # log B(α) = Σ lgamma(α_a) − lgamma(Σ α_a) log_beta = ( torch.lgamma(alpha_q).sum(dim=-1) - torch.lgamma(alpha_q.sum(dim=-1)) ) # −Σ (α_a − 1) log(p_a) log_likelihood = ((alpha_q - 1.0) * target_q_c.log()).sum(dim=-1) nll_per_sample = log_beta - log_likelihood q_loss = (weight_q * nll_per_sample).mean() total_loss = total_loss + q_loss loss_dict[f"loss/{q_name}"] = q_loss.detach().item() loss_dict["loss/total"] = total_loss.detach().item() return total_loss, loss_dict class MSEOnlyLoss(nn.Module): """ Hierarchical MSE loss without KL term. Used as ablation baseline. Equivalent to HierarchicalLoss with lambda_kl=0. """ def __init__(self, cfg: DictConfig): super().__init__() self.epsilon = float(cfg.loss.epsilon) self.question_slices = [(q, s, e) for q, (s, e) in QUESTION_GROUPS.items()] def forward(self, predictions: torch.Tensor, targets: torch.Tensor, weights: torch.Tensor): total_loss = torch.zeros(1, device=predictions.device, dtype=predictions.dtype) loss_dict = {} for q_idx, (q_name, start, end) in enumerate(self.question_slices): logits_q = predictions[:, start:end] target_q = targets[:, start:end] weight_q = weights[:, q_idx] pred_q = F.softmax(logits_q, dim=-1) mse_per_sample = F.mse_loss(pred_q, target_q, reduction="none").mean(dim=-1) q_loss = (weight_q * mse_per_sample).mean() total_loss = total_loss + q_loss loss_dict[f"loss/{q_name}"] = q_loss.detach().item() loss_dict["loss/total"] = total_loss.detach().item() return total_loss, loss_dict