"""
src/loss.py
-----------
Loss functions for hierarchical probabilistic vote-fraction regression.

Three losses are implemented:

1. HierarchicalLoss — proposed method: weighted KL + MSE per question.
2. DirichletLoss — Zoobot-style comparison: weighted Dirichlet NLL.
3. MSEOnlyLoss — ablation baseline: hierarchical MSE, no KL term.

All three losses use identical per-sample hierarchical weighting:

    w_q = parent branch vote fraction (1.0 for root question t01)

Mathematical formulation
------------------------
HierarchicalLoss per question q:

    L_q = w_q * [ λ_kl * KL(p_q || ŷ_q) + λ_mse * MSE(ŷ_q, p_q) ]

where p_q = ground-truth vote fractions  [B, A_q]
      ŷ_q = softmax(logits_q)            [B, A_q]
      w_q = hierarchical weight          [B]

DirichletLoss per question q:

    L_q = w_q * [ log B(α_q) − Σ_a (α_qa − 1) log(p_qa) ]

where α_q = 1 + softplus(logits_q) > 1   [B, A_q]

References
----------
Walmsley et al. (2022), MNRAS 509, 3966 (Zoobot — Dirichlet approach)
"""
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from omegaconf import DictConfig | |
| from src.dataset import QUESTION_GROUPS | |
class HierarchicalLoss(nn.Module):
    """Hierarchically weighted KL-divergence + MSE loss (proposed method).

    For every question q the per-sample loss is

        w_q * [ lambda_kl * KL(p_q || softmax(logits_q)) + lambda_mse * MSE ]

    where w_q is the hierarchical weight for that question (parent branch
    vote fraction; 1.0 for the root question).
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        # Mixing coefficients for the two loss terms and the numerical
        # floor used before taking logarithms, all read from the config.
        self.lambda_kl = float(cfg.loss.lambda_kl)
        self.lambda_mse = float(cfg.loss.lambda_mse)
        self.epsilon = float(cfg.loss.epsilon)
        # (question name, start column, end column) per answer group.
        self.question_slices = [
            (name, lo, hi) for name, (lo, hi) in QUESTION_GROUPS.items()
        ]

    def forward(self, predictions: torch.Tensor,
                targets: torch.Tensor, weights: torch.Tensor):
        """Compute the total loss and a per-question breakdown.

        Args:
            predictions: raw logits, [B, total_answers].
            targets: ground-truth vote fractions, [B, total_answers].
            weights: hierarchical weights, [B, n_questions].

        Returns:
            (total_loss, dict of detached per-question scalar losses).
        """
        total = torch.zeros(1, device=predictions.device, dtype=predictions.dtype)
        logs = {}
        for q_idx, (name, lo, hi) in enumerate(self.question_slices):
            probs = F.softmax(predictions[:, lo:hi], dim=-1)
            frac = targets[:, lo:hi]
            w = weights[:, q_idx]
            # Clamp both distributions away from zero so the logs are finite.
            p_safe = probs.clamp(min=self.epsilon, max=1.0)
            t_safe = frac.clamp(min=self.epsilon, max=1.0)
            kl = (t_safe * (t_safe.log() - p_safe.log())).sum(dim=-1)
            # MSE uses the unclamped tensors, matching the true distributions.
            mse = F.mse_loss(probs, frac, reduction="none").mean(dim=-1)
            per_q = (w * (self.lambda_kl * kl + self.lambda_mse * mse)).mean()
            total = total + per_q
            logs[f"loss/{name}"] = per_q.detach().item()
        logs["loss/total"] = total.detach().item()
        return total, logs
class DirichletLoss(nn.Module):
    """Hierarchically weighted Dirichlet negative log-likelihood.

    Trains GalaxyViTDirichlet for comparison with the proposed method,
    following the Zoobot approach (Walmsley et al. 2022). Per question:

        NLL = log B(alpha) - sum_a (alpha_a - 1) * log(p_a)

    with log B(alpha) = sum_a lgamma(alpha_a) - lgamma(sum_a alpha_a).
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        # Floor applied to target fractions before the log.
        self.epsilon = float(cfg.loss.epsilon)
        # (question name, start column, end column) per answer group.
        self.question_slices = [
            (name, lo, hi) for name, (lo, hi) in QUESTION_GROUPS.items()
        ]

    def forward(self, alpha: torch.Tensor,
                targets: torch.Tensor, weights: torch.Tensor):
        """Compute the total weighted NLL and a per-question breakdown.

        Args:
            alpha: Dirichlet concentrations (> 1), [B, total_answers].
            targets: ground-truth vote fractions, [B, total_answers].
            weights: hierarchical weights, [B, n_questions].

        Returns:
            (total_loss, dict of detached per-question scalar losses).
        """
        total = torch.zeros(1, device=alpha.device, dtype=alpha.dtype)
        logs = {}
        for q_idx, (name, lo, hi) in enumerate(self.question_slices):
            conc = alpha[:, lo:hi]
            frac = targets[:, lo:hi].clamp(min=self.epsilon)
            w = weights[:, q_idx]
            # Normalizing constant: log B(α) = Σ lgamma(α_a) − lgamma(Σ α_a).
            log_norm = torch.lgamma(conc).sum(dim=-1) - torch.lgamma(conc.sum(dim=-1))
            # Unnormalized log-density: Σ (α_a − 1) log(p_a).
            log_dens = ((conc - 1.0) * frac.log()).sum(dim=-1)
            per_q = (w * (log_norm - log_dens)).mean()
            total = total + per_q
            logs[f"loss/{name}"] = per_q.detach().item()
        logs["loss/total"] = total.detach().item()
        return total, logs
class MSEOnlyLoss(nn.Module):
    """Hierarchically weighted MSE loss without the KL term.

    Ablation baseline; equivalent to HierarchicalLoss with lambda_kl = 0.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        # Kept for interface parity with the other losses (unused here).
        self.epsilon = float(cfg.loss.epsilon)
        # (question name, start column, end column) per answer group.
        self.question_slices = [
            (name, lo, hi) for name, (lo, hi) in QUESTION_GROUPS.items()
        ]

    def forward(self, predictions: torch.Tensor,
                targets: torch.Tensor, weights: torch.Tensor):
        """Compute the total weighted MSE and a per-question breakdown.

        Args:
            predictions: raw logits, [B, total_answers].
            targets: ground-truth vote fractions, [B, total_answers].
            weights: hierarchical weights, [B, n_questions].

        Returns:
            (total_loss, dict of detached per-question scalar losses).
        """
        total = torch.zeros(1, device=predictions.device, dtype=predictions.dtype)
        logs = {}
        for q_idx, (name, lo, hi) in enumerate(self.question_slices):
            probs = F.softmax(predictions[:, lo:hi], dim=-1)
            mse = F.mse_loss(probs, targets[:, lo:hi], reduction="none").mean(dim=-1)
            per_q = (weights[:, q_idx] * mse).mean()
            total = total + per_q
            logs[f"loss/{name}"] = per_q.detach().item()
        logs["loss/total"] = total.detach().item()
        return total, logs