oweller2 commited on
Commit ·
2d5427f
1
Parent(s): 082b6b3
done
Browse files- config.json +1 -1
- modeling_flexbert.py +0 -22
- tokenizer.py +15 -3
config.json
CHANGED
|
@@ -70,7 +70,7 @@
|
|
| 70 |
"num_hidden_layers": 22,
|
| 71 |
"num_initial_layers": 1,
|
| 72 |
"pad_logits": true,
|
| 73 |
-
"pad_token_id":
|
| 74 |
"padding": "unpadded",
|
| 75 |
"pooling_type": "cls",
|
| 76 |
"position_embedding_type": "absolute",
|
|
|
|
| 70 |
"num_hidden_layers": 22,
|
| 71 |
"num_initial_layers": 1,
|
| 72 |
"pad_logits": true,
|
| 73 |
+
"pad_token_id": null,
|
| 74 |
"padding": "unpadded",
|
| 75 |
"pooling_type": "cls",
|
| 76 |
"position_embedding_type": "absolute",
|
modeling_flexbert.py
CHANGED
|
@@ -1713,36 +1713,14 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
|
|
| 1713 |
self,
|
| 1714 |
input_ids: torch.Tensor,
|
| 1715 |
attention_mask: Optional[torch.Tensor] = None,
|
| 1716 |
-
position_ids: Optional[torch.Tensor] = None,
|
| 1717 |
**kwargs
|
| 1718 |
) -> dict:
|
| 1719 |
if attention_mask is None:
|
| 1720 |
attention_mask = torch.ones_like(input_ids)
|
| 1721 |
|
| 1722 |
-
# Calculate sequence-local positions
|
| 1723 |
-
seqlens = attention_mask.sum(dim=-1) # Get length of each sequence
|
| 1724 |
-
position_ids = torch.zeros_like(input_ids)
|
| 1725 |
-
for i in range(len(seqlens)):
|
| 1726 |
-
position_ids[i, :seqlens[i]] = torch.arange(seqlens[i], device=input_ids.device)
|
| 1727 |
-
|
| 1728 |
-
|
| 1729 |
-
batch_size, seq_len = input_ids.shape[:2]
|
| 1730 |
-
if self.unpad_embeddings:
|
| 1731 |
-
input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = self.unpad_inputs(
|
| 1732 |
-
input_ids, attention_mask, position_ids, None
|
| 1733 |
-
)
|
| 1734 |
-
else:
|
| 1735 |
-
indices = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).repeat(batch_size, 1)
|
| 1736 |
-
cu_seqlens = None
|
| 1737 |
-
max_seqlen = None
|
| 1738 |
return {
|
| 1739 |
"input_ids": input_ids,
|
| 1740 |
"attention_mask": attention_mask,
|
| 1741 |
-
"position_ids": position_ids,
|
| 1742 |
-
"indices": indices,
|
| 1743 |
-
"cu_seqlens": cu_seqlens,
|
| 1744 |
-
"max_seqlen": max_seqlen,
|
| 1745 |
-
"batch_size": batch_size,
|
| 1746 |
}
|
| 1747 |
|
| 1748 |
def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
|
|
|
|
| 1713 |
self,
|
| 1714 |
input_ids: torch.Tensor,
|
| 1715 |
attention_mask: Optional[torch.Tensor] = None,
|
|
|
|
| 1716 |
**kwargs
|
| 1717 |
) -> dict:
|
| 1718 |
if attention_mask is None:
|
| 1719 |
attention_mask = torch.ones_like(input_ids)
|
| 1720 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1721 |
return {
|
| 1722 |
"input_ids": input_ids,
|
| 1723 |
"attention_mask": attention_mask,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1724 |
}
|
| 1725 |
|
| 1726 |
def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
|
tokenizer.py
CHANGED
|
@@ -23,7 +23,11 @@ class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
|
|
| 23 |
ends_with_eos(seq) for seq in input_ids
|
| 24 |
], dtype=torch.bool)
|
| 25 |
|
| 26 |
-
if last_token_is_eos.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
# Process each sequence individually
|
| 28 |
batch_size = input_ids.shape[0]
|
| 29 |
for i in range(batch_size):
|
|
@@ -41,7 +45,11 @@ class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
|
|
| 41 |
ends_with_eos(seq) for seq in input_ids
|
| 42 |
], dtype=bool)
|
| 43 |
|
| 44 |
-
if last_token_is_eos.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
batch_size = input_ids.shape[0]
|
| 46 |
for i in range(batch_size):
|
| 47 |
if last_token_is_eos[i]:
|
|
@@ -56,7 +64,11 @@ class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
|
|
| 56 |
elif isinstance(input_ids, list):
|
| 57 |
last_token_is_eos = [ends_with_eos(seq) for seq in input_ids]
|
| 58 |
|
| 59 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
for key in ['input_ids', 'attention_mask']:
|
| 61 |
outputs[key] = [
|
| 62 |
[0] + sequence[:-1] if is_eos else sequence
|
|
|
|
| 23 |
ends_with_eos(seq) for seq in input_ids
|
| 24 |
], dtype=torch.bool)
|
| 25 |
|
| 26 |
+
if last_token_is_eos.all():
|
| 27 |
+
# If all sequences have EOS, just truncate all
|
| 28 |
+
for key in ['input_ids', 'attention_mask']:
|
| 29 |
+
outputs[key] = outputs[key][..., :-1]
|
| 30 |
+
elif last_token_is_eos.any():
|
| 31 |
# Process each sequence individually
|
| 32 |
batch_size = input_ids.shape[0]
|
| 33 |
for i in range(batch_size):
|
|
|
|
| 45 |
ends_with_eos(seq) for seq in input_ids
|
| 46 |
], dtype=bool)
|
| 47 |
|
| 48 |
+
if last_token_is_eos.all():
|
| 49 |
+
# If all sequences have EOS, just truncate all
|
| 50 |
+
for key in ['input_ids', 'attention_mask']:
|
| 51 |
+
outputs[key] = outputs[key][..., :-1]
|
| 52 |
+
elif last_token_is_eos.any():
|
| 53 |
batch_size = input_ids.shape[0]
|
| 54 |
for i in range(batch_size):
|
| 55 |
if last_token_is_eos[i]:
|
|
|
|
| 64 |
elif isinstance(input_ids, list):
|
| 65 |
last_token_is_eos = [ends_with_eos(seq) for seq in input_ids]
|
| 66 |
|
| 67 |
+
if all(last_token_is_eos):
|
| 68 |
+
# If all sequences have EOS, just truncate all
|
| 69 |
+
for key in ['input_ids', 'attention_mask']:
|
| 70 |
+
outputs[key] = [sequence[:-1] for sequence in outputs[key]]
|
| 71 |
+
elif any(last_token_is_eos):
|
| 72 |
for key in ['input_ids', 'attention_mask']:
|
| 73 |
outputs[key] = [
|
| 74 |
[0] + sequence[:-1] if is_eos else sequence
|