robinfaro commited on
Commit
8ffbc00
·
verified ·
1 Parent(s): 0aa6558

Adding files from hf_modeling_btm_log_prob_mixing

Browse files
Files changed (1) hide show
  1. modeling.py +2 -5
modeling.py CHANGED
@@ -151,19 +151,16 @@ class MoLM(PreTrainedModel):
151
  weighted_log_probs = log_probs + log_weights_exp # (E, B, T, V)
152
 
153
  combined_log_probs = torch.logsumexp(weighted_log_probs, dim=0) # (B, T, V)
154
- combined_logits = combined_log_probs # because loss works with log-probs if used properly
155
 
156
  else:
157
  # Unweighted average in log-prob space across active experts (equal weights)
158
  log_weights = torch.log(1.0 / active_experts_count.float().clamp(min=1.0)).view(1, -1, 1, 1) # (1, B, 1, 1)
159
  weighted_log_probs = log_probs + log_weights
160
  combined_log_probs = torch.logsumexp(weighted_log_probs, dim=0) # (B, T, V)
161
- combined_logits = combined_log_probs # because loss works with log-probs if used properly
162
 
163
  # Calculate the loss if targets are provided
164
  if targets is not None:
165
- #loss = F.cross_entropy(combined_logits.view(-1, combined_logits.size(-1)), targets.view(-1), ignore_index=-1)
166
- loss = F.nll_loss(combined_logits.view(-1, combined_logits.size(-1)), targets.view(-1), ignore_index=-1)
167
  loss_to_log = loss.item()
168
 
169
  # Add auxiliary router losses (only if routing is used and we're training)
@@ -188,7 +185,7 @@ class MoLM(PreTrainedModel):
188
  loss_to_log = None
189
 
190
  return Output(
191
- logits=combined_logits,
192
  loss=loss,
193
  combined_log_probs=combined_log_probs,
194
  loss_to_log=loss_to_log,
 
151
  weighted_log_probs = log_probs + log_weights_exp # (E, B, T, V)
152
 
153
  combined_log_probs = torch.logsumexp(weighted_log_probs, dim=0) # (B, T, V)
 
154
 
155
  else:
156
  # Unweighted average in log-prob space across active experts (equal weights)
157
  log_weights = torch.log(1.0 / active_experts_count.float().clamp(min=1.0)).view(1, -1, 1, 1) # (1, B, 1, 1)
158
  weighted_log_probs = log_probs + log_weights
159
  combined_log_probs = torch.logsumexp(weighted_log_probs, dim=0) # (B, T, V)
 
160
 
161
  # Calculate the loss if targets are provided
162
  if targets is not None:
163
+ loss = F.nll_loss(combined_log_probs.view(-1, combined_log_probs.size(-1)), targets.view(-1), ignore_index=-1)
 
164
  loss_to_log = loss.item()
165
 
166
  # Add auxiliary router losses (only if routing is used and we're training)
 
185
  loss_to_log = None
186
 
187
  return Output(
188
+ logits=torch.Tensor([expert_output for expert_output in expert_outputs]),
189
  loss=loss,
190
  combined_log_probs=combined_log_probs,
191
  loss_to_log=loss_to_log,