Adding files from hf_modeling_btm_log_prob_mixing
Browse files
- modeling.py +2 -5
modeling.py
CHANGED
|
@@ -151,19 +151,16 @@ class MoLM(PreTrainedModel):
|
|
| 151 |
weighted_log_probs = log_probs + log_weights_exp # (E, B, T, V)
|
| 152 |
|
| 153 |
combined_log_probs = torch.logsumexp(weighted_log_probs, dim=0) # (B, T, V)
|
| 154 |
-
combined_logits = combined_log_probs # because loss works with log-probs if used properly
|
| 155 |
|
| 156 |
else:
|
| 157 |
# Unweighted average in log-prob space across active experts (equal weights)
|
| 158 |
log_weights = torch.log(1.0 / active_experts_count.float().clamp(min=1.0)).view(1, -1, 1, 1) # (1, B, 1, 1)
|
| 159 |
weighted_log_probs = log_probs + log_weights
|
| 160 |
combined_log_probs = torch.logsumexp(weighted_log_probs, dim=0) # (B, T, V)
|
| 161 |
-
combined_logits = combined_log_probs # because loss works with log-probs if used properly
|
| 162 |
|
| 163 |
# Calculate the loss if targets are provided
|
| 164 |
if targets is not None:
|
| 165 |
-
|
| 166 |
-
loss = F.nll_loss(combined_logits.view(-1, combined_logits.size(-1)), targets.view(-1), ignore_index=-1)
|
| 167 |
loss_to_log = loss.item()
|
| 168 |
|
| 169 |
# Add auxiliary router losses (only if routing is used and we're training)
|
|
@@ -188,7 +185,7 @@ class MoLM(PreTrainedModel):
|
|
| 188 |
loss_to_log = None
|
| 189 |
|
| 190 |
return Output(
|
| 191 |
-
logits=
|
| 192 |
loss=loss,
|
| 193 |
combined_log_probs=combined_log_probs,
|
| 194 |
loss_to_log=loss_to_log,
|
|
|
|
| 151 |
weighted_log_probs = log_probs + log_weights_exp # (E, B, T, V)
|
| 152 |
|
| 153 |
combined_log_probs = torch.logsumexp(weighted_log_probs, dim=0) # (B, T, V)
|
|
|
|
| 154 |
|
| 155 |
else:
|
| 156 |
# Unweighted average in log-prob space across active experts (equal weights)
|
| 157 |
log_weights = torch.log(1.0 / active_experts_count.float().clamp(min=1.0)).view(1, -1, 1, 1) # (1, B, 1, 1)
|
| 158 |
weighted_log_probs = log_probs + log_weights
|
| 159 |
combined_log_probs = torch.logsumexp(weighted_log_probs, dim=0) # (B, T, V)
|
|
|
|
| 160 |
|
| 161 |
# Calculate the loss if targets are provided
|
| 162 |
if targets is not None:
|
| 163 |
+
loss = F.nll_loss(combined_log_probs.view(-1, combined_log_probs.size(-1)), targets.view(-1), ignore_index=-1)
|
|
|
|
| 164 |
loss_to_log = loss.item()
|
| 165 |
|
| 166 |
# Add auxiliary router losses (only if routing is used and we're training)
|
|
|
|
| 185 |
loss_to_log = None
|
| 186 |
|
| 187 |
return Output(
|
| 188 |
+
logits=torch.Tensor([expert_output for expert_output in expert_outputs]),
|
| 189 |
loss=loss,
|
| 190 |
combined_log_probs=combined_log_probs,
|
| 191 |
loss_to_log=loss_to_log,
|