Update modeling_duo_predict_gpt2.py
Browse files
modeling_duo_predict_gpt2.py
CHANGED
|
@@ -129,10 +129,10 @@ def sdpa_attention_forward(
|
|
| 129 |
query,
|
| 130 |
key,
|
| 131 |
value,
|
| 132 |
-
attn_mask=create_attention_mask_matrix(query.shape[-2]).to(query.device),
|
| 133 |
dropout_p=dropout,
|
| 134 |
scale=scaling,
|
| 135 |
-
is_causal=
|
| 136 |
)
|
| 137 |
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 138 |
|
|
@@ -582,9 +582,12 @@ class DuoPredictGPT2Model(DuoPredictGPT2PretrainedModel):
|
|
| 582 |
inputs_embeds = self.wte(input_ids)
|
| 583 |
position_embeds = self.wpe(position_ids)
|
| 584 |
###TODO: correctly initialized
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
|
|
|
|
|
|
|
|
|
| 588 |
|
| 589 |
# Attention mask.
|
| 590 |
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
|
|
@@ -897,5 +900,6 @@ if __name__ == "__main__":
|
|
| 897 |
model = DuoPredictGPT2LMHeadModel(cg)
|
| 898 |
from src.utils.model_utlis import print_trainable_parameters
|
| 899 |
print_trainable_parameters(model)
|
|
|
|
| 900 |
model(torch.randint(0, 10000, (1, 100)))
|
| 901 |
print()
|
|
|
|
| 129 |
query,
|
| 130 |
key,
|
| 131 |
value,
|
| 132 |
+
attn_mask=create_attention_mask_matrix(query.shape[-2]).to(query.device) if module.training else None,
|
| 133 |
dropout_p=dropout,
|
| 134 |
scale=scaling,
|
| 135 |
+
is_causal=False if module.training else True,
|
| 136 |
)
|
| 137 |
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 138 |
|
|
|
|
| 582 |
inputs_embeds = self.wte(input_ids)
|
| 583 |
position_embeds = self.wpe(position_ids)
|
| 584 |
###TODO: correctly initialized
|
| 585 |
+
if inputs_embeds.shape[1] != position_embeds.shape[1]:
|
| 586 |
+
hidden_states = torch.empty((batch_size, input_shape[-1], self.embed_dim), device=device)
|
| 587 |
+
hidden_states[:, ::2] = inputs_embeds[:, ::2] + position_embeds.to(inputs_embeds.device)
|
| 588 |
+
hidden_states[:, 1::2] = inputs_embeds[:, 1::2] + position_embeds[:, :self.config.max_position_embeddings-1].to(inputs_embeds.device)
|
| 589 |
+
else:
|
| 590 |
+
hidden_states = inputs_embeds + position_embeds
|
| 591 |
|
| 592 |
# Attention mask.
|
| 593 |
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
|
|
|
|
| 900 |
model = DuoPredictGPT2LMHeadModel(cg)
|
| 901 |
from src.utils.model_utlis import print_trainable_parameters
|
| 902 |
print_trainable_parameters(model)
|
| 903 |
+
model.eval()
|
| 904 |
model(torch.randint(0, 10000, (1, 100)))
|
| 905 |
print()
|