duzx16 committed
Commit 08bc851 · Parent(s): 4b7ffbf

Fix attention mask for prefix prompt

Files changed: modeling_chatglm.py (+6 -5)
modeling_chatglm.py CHANGED

@@ -919,11 +919,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 device=input_ids.device
             )
 
-        if self.pre_seq_len is not None:
-            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
-                attention_mask.device)
-            prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
-            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
 
         if position_ids is None:
             MASK, gMASK = 150000, 150001
@@ -938,6 +933,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 gmask=use_gmask
             )
 
+        if self.pre_seq_len is not None and attention_mask is not None:
+            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
+                attention_mask.device)
+            prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
+            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
+
         # [seq_len, batch, hidden_size]
         hidden_states = inputs_embeds.transpose(0, 1)
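For context, a minimal standalone sketch of what the relocated block computes. The sizes here (batch_size, seq_len, pre_seq_len) are illustrative assumptions, and the causal mask below merely stands in for whatever get_masks() returns in the real model. ChatGLM's attention mask is boolean with True meaning "masked out", and the learned prefix from P-tuning v2 contributes pre_seq_len extra key/value positions, so the mask must be widened along its key dimension with all-False (always visible) entries:

import torch

# Hypothetical sizes for illustration only; the model derives these from its inputs.
batch_size, seq_len, pre_seq_len = 2, 4, 3

# Stand-in for the mask built earlier in forward(): boolean,
# shaped [batch, 1, query_len, key_len], True = position is masked out.
attention_mask = torch.triu(
    torch.ones(batch_size, 1, seq_len, seq_len), diagonal=1
).bool()

# Same trick as in the commit: (ones < 0.5) yields an all-False bool tensor,
# i.e. the learned prefix positions stay visible to every query token.
prefix_attention_mask = (
    torch.ones(batch_size, 1, seq_len, pre_seq_len) < 0.5
)

# Concatenate along the key dimension (dim=3), prefix first, matching the
# order in which the prefix key/value states are prepended to the cache.
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
assert attention_mask.shape == (batch_size, 1, seq_len, seq_len + pre_seq_len)

Judging from the diff, the added `attention_mask is not None` guard keeps the old code from dereferencing `attention_mask.device` on code paths where no mask was built, and moving the block after the mask and position-id setup ensures the prefix columns are concatenated onto the fully constructed mask.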