Model save

Files changed:
- README.md (+39 -39)
- model.safetensors (+1 -1)
- modeling_duo_predict_gpt2.py (+23 -15)
README.md
CHANGED
@@ -17,9 +17,9 @@ should probably proofread and complete it, then remove this comment. -->
 
 This model is a fine-tuned version of [](https://huggingface.co/) on an unknown dataset.
 It achieves the following results on the evaluation set:
-- Loss:
+- Loss: 2.2546
 - Accuracy: 0.0073
-- Perplexity:
+- Perplexity: 9.5311
 - Bleu: 1.0
 
 ## Model description
@@ -50,43 +50,43 @@ The following hyperparameters were used during training:
 
 ### Training results
 
-| Training Loss | Epoch | Step | Validation Loss | Accuracy | Perplexity | Bleu
+| Training Loss | Epoch | Step | Validation Loss | Accuracy | Perplexity | Bleu |
+|:-------------:|:------:|:-----:|:---------------:|:--------:|:----------:|:----:|
+| 7.6654 | 0.1403 | 500 | 3.7315 | 0.0073 | 41.7396 | 1.0 |
+| 7.0276 | 0.2807 | 1000 | 3.4735 | 0.0073 | 32.2490 | 1.0 |
+| 6.4629 | 0.4210 | 1500 | 3.1863 | 0.0073 | 24.1987 | 1.0 |
+| 5.9671 | 0.5613 | 2000 | 2.9542 | 0.0073 | 19.1873 | 1.0 |
+| 5.6969 | 0.7017 | 2500 | 2.8233 | 0.0073 | 16.8331 | 1.0 |
+| 5.5077 | 0.8420 | 3000 | 2.7351 | 0.0073 | 15.4112 | 1.0 |
+| 5.3536 | 0.9823 | 3500 | 2.6607 | 0.0073 | 14.3059 | 1.0 |
+| 5.2099 | 1.1226 | 4000 | 2.6000 | 0.0073 | 13.4641 | 1.0 |
+| 5.1158 | 1.2630 | 4500 | 2.5493 | 0.0073 | 12.7980 | 1.0 |
+| 5.0453 | 1.4033 | 5000 | 2.5125 | 0.0073 | 12.3362 | 1.0 |
+| 4.955 | 1.5436 | 5500 | 2.4806 | 0.0073 | 11.9489 | 1.0 |
+| 4.9157 | 1.6840 | 6000 | 2.4537 | 0.0073 | 11.6310 | 1.0 |
+| 4.8756 | 1.8243 | 6500 | 2.4300 | 0.0073 | 11.3584 | 1.0 |
+| 4.844 | 1.9646 | 7000 | 2.4100 | 0.0073 | 11.1342 | 1.0 |
+| 4.7136 | 2.1050 | 7500 | 2.3948 | 0.0073 | 10.9657 | 1.0 |
+| 4.6911 | 2.2453 | 8000 | 2.3805 | 0.0073 | 10.8105 | 1.0 |
+| 4.6741 | 2.3856 | 8500 | 2.3668 | 0.0073 | 10.6637 | 1.0 |
+| 4.6485 | 2.5260 | 9000 | 2.3538 | 0.0073 | 10.5257 | 1.0 |
+| 4.623 | 2.6663 | 9500 | 2.3416 | 0.0073 | 10.3976 | 1.0 |
+| 4.6016 | 2.8066 | 10000 | 2.3303 | 0.0073 | 10.2806 | 1.0 |
+| 4.5823 | 2.9470 | 10500 | 2.3202 | 0.0073 | 10.1776 | 1.0 |
+| 4.4802 | 3.0873 | 11000 | 2.3143 | 0.0073 | 10.1182 | 1.0 |
+| 4.4671 | 3.2276 | 11500 | 2.3073 | 0.0073 | 10.0469 | 1.0 |
+| 4.4557 | 3.3679 | 12000 | 2.3006 | 0.0073 | 9.9800 | 1.0 |
+| 4.4437 | 3.5083 | 12500 | 2.2928 | 0.0073 | 9.9023 | 1.0 |
+| 4.4402 | 3.6486 | 13000 | 2.2862 | 0.0073 | 9.8375 | 1.0 |
+| 4.4482 | 3.7889 | 13500 | 2.2800 | 0.0073 | 9.7763 | 1.0 |
+| 4.4279 | 3.9293 | 14000 | 2.2752 | 0.0073 | 9.7303 | 1.0 |
+| 4.3188 | 4.0696 | 14500 | 2.2730 | 0.0073 | 9.7087 | 1.0 |
+| 4.3193 | 4.2099 | 15000 | 2.2691 | 0.0073 | 9.6704 | 1.0 |
+| 4.3158 | 4.3503 | 15500 | 2.2652 | 0.0073 | 9.6329 | 1.0 |
+| 4.3196 | 4.4906 | 16000 | 2.2619 | 0.0073 | 9.6012 | 1.0 |
+| 4.2946 | 4.6309 | 16500 | 2.2589 | 0.0073 | 9.5722 | 1.0 |
+| 4.3078 | 4.7713 | 17000 | 2.2564 | 0.0073 | 9.5487 | 1.0 |
+| 4.2974 | 4.9116 | 17500 | 2.2546 | 0.0073 | 9.5311 | 1.0 |
 
 
 ### Framework versions
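Note: the reported perplexity is exp(validation loss), and every row of the table is consistent with that relation (a quick check, assuming the loss is mean token-level cross-entropy):

```python
# Sanity check: perplexity = exp(cross-entropy loss).
import math

print(math.exp(2.2546))  # ~9.5315; matches the final row's 9.5311 up to loss rounding
print(math.exp(3.7315))  # ~41.74; matches the first row's 41.7396
```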
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:9839d5f4cd171dea6e15ba583be6e1cce05dea4e8444d950b8ccd0ca73da4483
 size 1417229824
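The LFS pointer records the SHA-256 of the weight file, so a download can be verified locally. A minimal sketch, assuming model.safetensors sits in the working directory:

```python
# Check a downloaded model.safetensors against the pointer's oid (the
# SHA-256 of the file content, per the git-lfs pointer spec).
import hashlib

expected = "9839d5f4cd171dea6e15ba583be6e1cce05dea4e8444d950b8ccd0ca73da4483"
sha = hashlib.sha256()
with open("model.safetensors", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # stream in 1 MiB chunks
        sha.update(chunk)
print(sha.hexdigest() == expected)
```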
modeling_duo_predict_gpt2.py
CHANGED
@@ -77,7 +77,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
     attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
     if is_causal:
         assert attn_mask is None
-        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
         attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
         attn_bias.to(query.dtype)
 
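The one-line change pins temp_mask to query.device: masked_fill_ cannot mix a CUDA attn_bias with a CPU boolean mask, so building the mask on the default device fails as soon as the model runs on GPU. A standalone sketch of the construction:

```python
# tril-based causal bias, with the boolean mask allocated on the same
# device as the scores so masked_fill_ never mixes CPU and GPU tensors.
import torch

L = S = 4
query = torch.randn(1, 1, L, 8)  # (batch, heads, seq, head_dim); CUDA in training
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
print(attn_bias)  # zeros on and below the diagonal, -inf above
```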
@@ -129,10 +129,10 @@ def sdpa_attention_forward(
         query,
         key,
         value,
-        attn_mask=create_attention_mask_matrix(query.shape[-2]).to(query.device),
+        attn_mask=create_attention_mask_matrix(query.shape[-2]).to(query.device) if query.shape[1] > module.config.max_position_embeddings else None,
         dropout_p=dropout,
         scale=scaling,
-        is_causal=
+        is_causal=False if query.shape[1] > module.config.max_position_embeddings else True,
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
 
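This hunk makes the SDPA call conditional: the custom mask is only materialized when the sequence outgrows max_position_embeddings; otherwise the kernel's fused is_causal=True path is used. A hedged sketch of that dispatch, with toy_mask standing in for the model's create_attention_mask_matrix (defined elsewhere in the file); the sketch keys off the sequence dimension query.shape[-2], the same one used to size the mask, where the commit itself tests query.shape[1]:

```python
# Dispatch between fused causal SDPA and an explicit additive mask,
# depending on whether the sequence exceeds the position limit.
import torch
import torch.nn.functional as F

def toy_mask(seq_len: int) -> torch.Tensor:
    # placeholder pattern: a plain causal mask in additive (-inf) form
    return torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

q = k = v = torch.randn(1, 2, 6, 8)  # (batch, heads, seq, head_dim)
max_pos = 4                          # stand-in for config.max_position_embeddings

if q.shape[-2] > max_pos:
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=toy_mask(q.shape[-2]))
else:
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 2, 6, 8])
```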
@@ -582,9 +582,12 @@ class DuoPredictGPT2Model(DuoPredictGPT2PretrainedModel):
         inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
         ###TODO: correctly initialized
+        if inputs_embeds.shape[1] != position_embeds.shape[1]:
+            hidden_states = torch.empty((batch_size, input_shape[-1], self.embed_dim), device=device)
+            hidden_states[:, ::2] = inputs_embeds[:, ::2] + position_embeds.to(inputs_embeds.device)
+            hidden_states[:, 1::2] = inputs_embeds[:, 1::2] + position_embeds[:, :self.config.max_position_embeddings - 1].to(inputs_embeds.device)
+        else:
+            hidden_states = inputs_embeds + position_embeds
 
         # Attention mask.
         _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
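The new branch adds position embeddings to an interleaved hidden sequence that is longer than the position table: even slots take positions 0..n-1 and odd slots the first n-1 positions, so each even/odd pair shares one wpe entry. A toy trace with constant token embeddings, assuming an interleaved length of 2·max_pos − 1:

```python
# Interleaved position add: slots 0,2,4,6 get positions 0..3 and
# slots 1,3,5 get positions 0..2, pairing each even slot with its twin.
import torch

max_pos, dim = 4, 2
seq = 2 * max_pos - 1  # interleaved duo-predict length
inputs_embeds = torch.zeros(1, seq, dim)
position_embeds = torch.arange(max_pos, dtype=torch.float).view(1, max_pos, 1).expand(1, max_pos, dim)

hidden = torch.empty(1, seq, dim)
hidden[:, ::2] = inputs_embeds[:, ::2] + position_embeds
hidden[:, 1::2] = inputs_embeds[:, 1::2] + position_embeds[:, :max_pos - 1]
print(hidden[0, :, 0].tolist())  # [0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0]
```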
@@ -836,17 +839,21 @@ class DuoPredictGPT2LMHeadModel(DuoPredictGPT2PretrainedModel, GenerationMixin):
         lm_logits = self.lm_head(hidden_states)
 
         loss = None
+        bs, seq = lm_logits.shape[:2]
         if labels is not None:
+            if seq > labels.shape[1]:
+                # Flatten the tokens
+                total_labels = torch.full((bs, seq - 1), -100, dtype=input_ids.dtype, device=input_ids.device)
+                total_labels[:, :-1:2] = labels[:, 1:]
+                total_labels[:, 1::2] = labels[:, :-1]
+            else:
+                total_labels = labels[:, 1:]
             loss = self.loss_function(
-                lm_logits,
+                lm_logits[:, :-1],
+                total_labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
 
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
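The loss hunk builds targets for the interleaved logits: on this reading, even slots are scored against the next original token and odd slots against the current one, with -100 (the default cross-entropy ignore_index) covering anything left unfilled; lm_logits[:, :-1] then aligns the predictions with these targets. A toy trace for L = 4 labels:

```python
# Duo-predict target interleave: even slots take next-token labels,
# odd slots current-token labels.
import torch

labels = torch.tensor([[10, 11, 12, 13]])  # L = 4 original tokens
bs, L = labels.shape
seq = 2 * L - 1                            # interleaved logits length

total_labels = torch.full((bs, seq - 1), -100, dtype=labels.dtype)
total_labels[:, :-1:2] = labels[:, 1:]     # slots 0,2,4 -> 11, 12, 13
total_labels[:, 1::2] = labels[:, :-1]     # slots 1,3,5 -> 10, 11, 12
print(total_labels)  # tensor([[11, 10, 12, 11, 13, 12]])
```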
@@ -897,5 +904,6 @@ if __name__ == "__main__":
     model = DuoPredictGPT2LMHeadModel(cg)
     from src.utils.model_utlis import print_trainable_parameters
     print_trainable_parameters(model)
+    model.eval()
     model(torch.randint(0, 10000, (1, 100)))
     print()
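The added model.eval() disables dropout, so the random-input smoke test exercises the deterministic inference path rather than a stochastic training pass.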