Charlie81 committed on
Commit
1ec67ec
·
1 Parent(s): c306fa9

aux to torch tensor

Browse files
Files changed (1) hide show
  1. myolmoe/modeling_myolmoe.py +2 -1
myolmoe/modeling_myolmoe.py CHANGED
@@ -1037,7 +1037,8 @@ class MyOlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
1037
  if labels is not None:
1038
  loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
1039
  #
1040
- total_aux_loss = 0
 
1041
  if output_router_logits and outputs.router_logits is not None:
1042
  # Regular load balancing loss
1043
  total_aux_loss += load_balancing_loss_func(
 
1037
  if labels is not None:
1038
  loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
1039
  #
1040
+ total_aux_loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
1041
+
1042
  if output_router_logits and outputs.router_logits is not None:
1043
  # Regular load balancing loss
1044
  total_aux_loss += load_balancing_loss_func(