模型并行出错并给出修改方案
#54
by
yuanzhoulvpi
- opened
- modeling_chatglm.py +1 -1
modeling_chatglm.py
CHANGED
|
@@ -952,7 +952,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 952 |
|
| 953 |
# Shift so that tokens < n predict n
|
| 954 |
shift_logits = lm_logits[..., :-1, :].contiguous()
|
| 955 |
-
shift_labels = labels[..., 1:].contiguous()
|
| 956 |
# Flatten the tokens
|
| 957 |
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
| 958 |
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
|
|
|
| 952 |
|
| 953 |
# Shift so that tokens < n predict n
|
| 954 |
shift_logits = lm_logits[..., :-1, :].contiguous()
|
| 955 |
+
shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
|
| 956 |
# Flatten the tokens
|
| 957 |
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
| 958 |
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|