File size: 5,693 Bytes
e36eee4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
src/loss.py
-----------
Loss functions for hierarchical probabilistic vote-fraction regression.

Three losses are implemented:

1. HierarchicalLoss  — proposed method: weighted KL + MSE per question.
2. DirichletLoss     — Zoobot-style comparison: weighted Dirichlet NLL.
3. MSEOnlyLoss       — ablation baseline: hierarchical MSE, no KL term.

All three losses use identical per-sample hierarchical weighting:
    w_q = parent branch vote fraction  (1.0 for root question t01)

Mathematical formulation
------------------------
HierarchicalLoss per question q:
    L_q = w_q * [ λ_kl * KL(p_q || ŷ_q) + λ_mse * MSE(ŷ_q, p_q) ]

    where  p_q = ground-truth vote fractions  [B, A_q]
           ŷ_q = softmax(logits_q)            [B, A_q]
           w_q = hierarchical weight           [B]

DirichletLoss per question q:
    L_q = w_q * [ log B(α_q) − Σ_a (α_qa − 1) log(p_qa) ]

    where  α_q = 1 + softplus(logits_q)  > 1   [B, A_q]

References
----------
Walmsley et al. (2022), MNRAS 509, 3966  (Zoobot — Dirichlet approach)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig
from src.dataset import QUESTION_GROUPS


class HierarchicalLoss(nn.Module):
    """Weighted hierarchical KL + MSE loss (proposed method).

    For each question q the per-sample loss is
        w_q * (lambda_kl * KL(p_q || softmax(logits_q)) + lambda_mse * MSE),
    where w_q is the hierarchical weight column for that question.  The
    per-question losses are averaged over the batch and summed across
    questions.  Returns ``(total_loss, breakdown_dict)``.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.lambda_kl = float(cfg.loss.lambda_kl)
        self.lambda_mse = float(cfg.loss.lambda_mse)
        self.epsilon = float(cfg.loss.epsilon)
        # Flatten {name: (start, end)} into an ordered list of slices.
        self.question_slices = [
            (name, lo, hi) for name, (lo, hi) in QUESTION_GROUPS.items()
        ]

    def forward(self, predictions: torch.Tensor,
                targets: torch.Tensor, weights: torch.Tensor):
        """Compute total loss plus a per-question scalar breakdown.

        predictions: raw logits [B, total_answers]
        targets:     ground-truth vote fractions [B, total_answers]
        weights:     hierarchical weights, one column per question [B, Q]
        """
        total = torch.zeros(1, device=predictions.device, dtype=predictions.dtype)
        breakdown = {}

        for col, (name, lo, hi) in enumerate(self.question_slices):
            logits = predictions[:, lo:hi]
            truth = targets[:, lo:hi]
            w = weights[:, col]

            probs = F.softmax(logits, dim=-1)
            # Clamp only the log arguments; MSE below uses raw tensors.
            probs_safe = probs.clamp(min=self.epsilon, max=1.0)
            truth_safe = truth.clamp(min=self.epsilon, max=1.0)

            # KL(p || q) per sample, summed over the answer dimension.
            kl = (truth_safe * (truth_safe.log() - probs_safe.log())).sum(dim=-1)
            # MSE per sample, averaged over the answer dimension.
            mse = F.mse_loss(probs, truth, reduction="none").mean(dim=-1)

            per_sample = self.lambda_kl * kl + self.lambda_mse * mse
            q_loss = (w * per_sample).mean()

            total = total + q_loss
            breakdown[f"loss/{name}"] = q_loss.detach().item()

        breakdown["loss/total"] = total.detach().item()
        return total, breakdown


class DirichletLoss(nn.Module):
    """
    Weighted hierarchical Dirichlet negative log-likelihood.
    Used to train GalaxyViTDirichlet for comparison with the proposed method.
    Matches the Zoobot approach (Walmsley et al. 2022).

    Per question the NLL is  log B(α) − Σ_a (α_a − 1) log(p_a),  scaled by
    the question's hierarchical weight and averaged over the batch.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.epsilon = float(cfg.loss.epsilon)
        # Flatten {name: (start, end)} into an ordered list of slices.
        self.question_slices = [
            (name, lo, hi) for name, (lo, hi) in QUESTION_GROUPS.items()
        ]

    def forward(self, alpha: torch.Tensor,
                targets: torch.Tensor, weights: torch.Tensor):
        """Compute total Dirichlet NLL plus a per-question breakdown.

        alpha:   Dirichlet concentrations (> 1 per model head) [B, total_answers]
        targets: ground-truth vote fractions [B, total_answers]
        weights: hierarchical weights, one column per question [B, Q]
        """
        total = torch.zeros(1, device=alpha.device, dtype=alpha.dtype)
        breakdown = {}

        for col, (name, lo, hi) in enumerate(self.question_slices):
            conc = alpha[:, lo:hi]
            # Clamp so the log below never sees exact zeros.
            truth = targets[:, lo:hi].clamp(min=self.epsilon)
            w = weights[:, col]

            # log B(α) = Σ_a lgamma(α_a) − lgamma(Σ_a α_a)
            log_norm = torch.lgamma(conc).sum(dim=-1) - torch.lgamma(conc.sum(dim=-1))
            # Σ_a (α_a − 1) log(p_a)
            log_lik = ((conc - 1.0) * truth.log()).sum(dim=-1)

            q_loss = (w * (log_norm - log_lik)).mean()

            total = total + q_loss
            breakdown[f"loss/{name}"] = q_loss.detach().item()

        breakdown["loss/total"] = total.detach().item()
        return total, breakdown


class MSEOnlyLoss(nn.Module):
    """
    Hierarchical MSE loss without KL term. Used as ablation baseline.
    Equivalent to HierarchicalLoss with lambda_kl=0.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        # Kept for config parity with the sibling losses; not used in forward.
        self.epsilon = float(cfg.loss.epsilon)
        # Flatten {name: (start, end)} into an ordered list of slices.
        self.question_slices = [
            (name, lo, hi) for name, (lo, hi) in QUESTION_GROUPS.items()
        ]

    def forward(self, predictions: torch.Tensor,
                targets: torch.Tensor, weights: torch.Tensor):
        """Compute weighted per-question MSE plus a scalar breakdown dict.

        predictions: raw logits [B, total_answers]
        targets:     ground-truth vote fractions [B, total_answers]
        weights:     hierarchical weights, one column per question [B, Q]
        """
        total = torch.zeros(1, device=predictions.device, dtype=predictions.dtype)
        breakdown = {}

        for col, (name, lo, hi) in enumerate(self.question_slices):
            probs = F.softmax(predictions[:, lo:hi], dim=-1)
            truth = targets[:, lo:hi]
            w = weights[:, col]

            mse = F.mse_loss(probs, truth, reduction="none").mean(dim=-1)
            q_loss = (w * mse).mean()

            total = total + q_loss
            breakdown[f"loss/{name}"] = q_loss.detach().item()

        breakdown["loss/total"] = total.detach().item()
        return total, breakdown