Convert aux loss to torch tensor
Browse files
myolmoe/modeling_myolmoe.py
CHANGED
|
@@ -1037,7 +1037,8 @@ class MyOlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
|
|
| 1037 |
if labels is not None:
|
| 1038 |
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
| 1039 |
#
|
| 1040 |
-
total_aux_loss = 0
|
|
|
|
| 1041 |
if output_router_logits and outputs.router_logits is not None:
|
| 1042 |
# Regular load balancing loss
|
| 1043 |
total_aux_loss += load_balancing_loss_func(
|
|
|
|
| 1037 |
if labels is not None:
|
| 1038 |
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
| 1039 |
#
|
| 1040 |
+
total_aux_loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
|
| 1041 |
+
|
| 1042 |
if output_router_logits and outputs.router_logits is not None:
|
| 1043 |
# Regular load balancing loss
|
| 1044 |
total_aux_loss += load_balancing_loss_func(
|