Alexandru Gherghescu
committed on
Fix modeling_gpt1.py
Fix an issue with the attention mask, where its size would not be
correct during training and inference.
- modeling_gpt1.py +19 -9
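The change caches the causal mask as a non-persistent buffer sized to config.max_position_embeddings and records that size in mask_cache_len; whenever a batch arrives whose sequence length exceeds the cached size, the buffer is rebuilt at the larger size before being moved to the input's dtype and device. A minimal sketch of that pattern in isolation (the buffer name causal_mask, the attribute mask_cache_len, and register_buffer(..., persistent=False) follow the diff below; the CausalMaskCache module and its get() helper are illustrative assumptions, not code from the repository):

import torch
import torch.nn as nn

class CausalMaskCache(nn.Module):
    # Illustrative stand-in for the mask handling added to GPT1Model.

    def __init__(self, max_positions: int):
        super().__init__()
        # The cache starts at the model's maximum context length.
        self.mask_cache_len = max_positions
        causal_mask = torch.full((max_positions, max_positions), fill_value=float('-inf'))
        # -inf strictly above the diagonal blocks attention to future positions.
        self.register_buffer('causal_mask',
                             torch.triu(causal_mask, diagonal=1),
                             persistent=False)

    def get(self, seq_len: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
        # Rebuild the cached buffer only when a sequence outgrows it,
        # then hand back a copy on the caller's dtype and device.
        if seq_len > self.mask_cache_len:
            self.mask_cache_len = seq_len
            causal_mask = torch.full((seq_len, seq_len), fill_value=float('-inf'))
            self.register_buffer('causal_mask',
                                 torch.triu(causal_mask, diagonal=1),
                                 persistent=False)
        return self.causal_mask.to(dtype=dtype, device=device)

For example, a cache built for 512 positions keeps serving a 512x512 mask until a longer batch shows up, at which point get(1024, ...) replaces the buffer with a 1024x1024 mask; this appears to be the size mismatch the commit description refers to.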
modeling_gpt1.py CHANGED

@@ -154,6 +154,7 @@ class GPT1Model(GPT1PreTrainedModel):
         self.register_buffer('causal_mask',
                              torch.triu(causal_mask, diagonal=1),
                              persistent=False)
+        self.mask_cache_len = config.max_position_embeddings
 
         self.post_init()
 
@@ -172,12 +173,18 @@ class GPT1Model(GPT1PreTrainedModel):
         position_embeds = self.pos_emb(position_ids)
         hidden_state = self.embs_dropout(input_embeds) + position_embeds
 
-        if attention_mask is not None:
-            …
-            …
-            …
-            causal_mask = …
-            …
+        if attention_mask is not None and attention_mask.size(1) > self.mask_cache_len:
+            seq_len = attention_mask.size(1)
+            self.mask_cache_len = seq_len
+
+            causal_mask = torch.full((seq_len, seq_len),
+                                     fill_value=float('-inf'))
+            self.register_buffer('causal_mask',
+                                 torch.triu(causal_mask, diagonal=1),
+                                 persistent=False)
+
+        causal_mask = self.causal_mask.to(dtype=input_embeds.dtype,
+                                          device=input_embeds.device)
 
         for layer in self.layers:
             hidden_state = layer(hidden_state, attn_mask=causal_mask)
@@ -240,10 +247,13 @@ class GPT1ForCausalLM(GPT1PreTrainedModel):
             logits=logits
         )
 
-    def prepare_inputs_for_generation(self, input_ids,
-                                      …
-        attn_mask = torch.full((
+    def prepare_inputs_for_generation(self, input_ids, attention_mask,
+                                      *args, **kwargs):
+        assert attention_mask.size(1) == input_ids.size(1)
+
+        seq_len = attention_mask.size(1)
 
+        attn_mask = torch.full((seq_len, seq_len), fill_value=float('-inf'))
         attn_mask = torch.triu(attn_mask, diagonal=1)
 
         return {
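On the generation side, prepare_inputs_for_generation now takes attention_mask explicitly, asserts it matches input_ids in length, and builds the additive mask at exactly that length. A quick standalone look at what such a mask contains (illustrative only; the values follow directly from torch.full and torch.triu):

import torch

seq_len = 4  # e.g. a four-token prompt
attn_mask = torch.full((seq_len, seq_len), fill_value=float('-inf'))
attn_mask = torch.triu(attn_mask, diagonal=1)
print(attn_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])

Row i leaves positions 0..i untouched and adds -inf to the scores of later positions, so each token attends only to itself and earlier tokens, and the mask's shape always tracks the current input length, which is what the added assert enforces.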