update modeling and tokenization
Browse files- modeling_gptpangu.py +2 -2
- tokenization_gptpangu.py +32 -10
modeling_gptpangu.py
CHANGED
|
@@ -460,7 +460,7 @@ class GPTPanguForCausalLM(GPTPanguPreTrainedModel):
|
|
| 460 |
|
| 461 |
if attention_mask is not None and position_ids is None:
|
| 462 |
# create position_ids on the fly for batch generation
|
| 463 |
-
position_ids = attention_mask.
|
| 464 |
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 465 |
if past:
|
| 466 |
position_ids = position_ids[:, -1].unsqueeze(-1)
|
|
@@ -521,7 +521,7 @@ class GPTPanguForCausalLM(GPTPanguPreTrainedModel):
|
|
| 521 |
shift_logits = lm_logits[..., :-1, :].contiguous()
|
| 522 |
shift_labels = labels[..., 1:].contiguous()
|
| 523 |
# Flatten the tokens
|
| 524 |
-
loss_fct = nn.CrossEntropyLoss()
|
| 525 |
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
| 526 |
|
| 527 |
if not return_dict:
|
|
|
|
| 460 |
|
| 461 |
if attention_mask is not None and position_ids is None:
|
| 462 |
# create position_ids on the fly for batch generation
|
| 463 |
+
position_ids = attention_mask.int().cumsum(-1).long() - 1
|
| 464 |
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 465 |
if past:
|
| 466 |
position_ids = position_ids[:, -1].unsqueeze(-1)
|
|
|
|
| 521 |
shift_logits = lm_logits[..., :-1, :].contiguous()
|
| 522 |
shift_labels = labels[..., 1:].contiguous()
|
| 523 |
# Flatten the tokens
|
| 524 |
+
loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
|
| 525 |
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
| 526 |
|
| 527 |
if not return_dict:
|
tokenization_gptpangu.py
CHANGED
|
@@ -6,6 +6,13 @@ import numpy as np
|
|
| 6 |
|
| 7 |
from transformers.tokenization_utils import PreTrainedTokenizer
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
class GPTPanguTokenizer(PreTrainedTokenizer):
|
| 11 |
# Ref: https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha/src/branch/master/tokenization_jieba.py
|
|
@@ -69,10 +76,25 @@ class GPTPanguTokenizer(PreTrainedTokenizer):
|
|
| 69 |
|
| 70 |
if isinstance(tokens, str):
|
| 71 |
return self._convert_token_to_id_with_added_voc(tokens)
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
def _convert_token_to_id(self, token):
|
| 78 |
return self.sp.piece_to_id(token)
|
|
@@ -83,16 +105,16 @@ class GPTPanguTokenizer(PreTrainedTokenizer):
|
|
| 83 |
def convert_ids_to_tokens(self, ids):
|
| 84 |
return self.decode(ids)
|
| 85 |
|
| 86 |
-
def decode(self,
|
| 87 |
-
if isinstance(
|
| 88 |
-
|
| 89 |
|
| 90 |
if kwargs.get('skip_special_tokens', None) is True:
|
| 91 |
-
|
| 92 |
-
text = self.sp.decode(
|
| 93 |
if isinstance(text, list):
|
| 94 |
text = text[0]
|
| 95 |
-
text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')
|
| 96 |
return text
|
| 97 |
|
| 98 |
@property
|
|
|
|
| 6 |
|
| 7 |
from transformers.tokenization_utils import PreTrainedTokenizer
|
| 8 |
|
| 9 |
# Register the model's special tokens with jieba so the segmenter
# treats each of them as one atomic word instead of splitting it
# into individual characters/punctuation.
for _special in ('<s>', '</s>', '<eot>', '<unk>', '<sep>', '<pad>'):
    jieba.add_word(_special)
|
| 15 |
+
|
| 16 |
|
| 17 |
class GPTPanguTokenizer(PreTrainedTokenizer):
|
| 18 |
# Ref: https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha/src/branch/master/tokenization_jieba.py
|
|
|
|
| 76 |
|
| 77 |
if isinstance(tokens, str):
|
| 78 |
return self._convert_token_to_id_with_added_voc(tokens)
|
| 79 |
+
|
| 80 |
+
special_tokens_index = [i for i, token in enumerate(tokens) if token in self.all_special_tokens]
|
| 81 |
+
|
| 82 |
+
ids = []
|
| 83 |
+
i = 0
|
| 84 |
+
for j in special_tokens_index:
|
| 85 |
+
new_seg = " ".join(tokens[i:j])
|
| 86 |
+
ids.extend(self.sp.encode(new_seg))
|
| 87 |
+
ids.append(self._convert_token_to_id(tokens[j]))
|
| 88 |
+
i = j + 1
|
| 89 |
+
|
| 90 |
+
new_seg = " ".join(tokens[i:])
|
| 91 |
+
ids.extend(self.sp.encode(new_seg))
|
| 92 |
+
|
| 93 |
+
return ids
|
| 94 |
+
|
| 95 |
+
# new_seg = " ".join(tokens)
|
| 96 |
+
# return self.sp.encode(new_seg)
|
| 97 |
+
# # return tokens
|
| 98 |
|
| 99 |
def _convert_token_to_id(self, token):
    """Return the sentencepiece vocabulary id for a single token string.

    Delegates directly to the underlying SentencePiece processor;
    unknown pieces map to whatever id ``piece_to_id`` assigns them.
    """
    return self.sp.piece_to_id(token)
|
|
|
|
| 105 |
def convert_ids_to_tokens(self, ids):
    """Convert token ids back to text by delegating to ``decode``.

    NOTE(review): this returns a decoded *string*, not a list of token
    strings as the usual ``PreTrainedTokenizer`` contract suggests —
    confirm callers expect that.
    """
    return self.decode(ids)
|
| 107 |
|
| 108 |
+
def decode(self, ids, **kwargs):
    """Decode a sequence of token ids to a text string.

    ``ids`` may be a Python list, a ``torch.Tensor``, or a numpy array;
    tensors/arrays are converted to lists first.  When the caller passes
    ``skip_special_tokens=True`` (exactly ``True``), ids found in
    ``self.all_special_ids`` are dropped before decoding.

    The sentencepiece output inserts literal spaces between pieces and
    encodes real spaces/newlines as '\u2582'/'\u2583' placeholders, so
    the piece separators are stripped first and the placeholders are
    then mapped back to ' ' and '\n'.
    """
    if isinstance(ids, (torch.Tensor, np.ndarray)):
        ids = ids.tolist()

    if kwargs.get('skip_special_tokens', None) is True:
        ids = [token_id for token_id in ids if token_id not in self.all_special_ids]

    decoded = self.sp.decode(ids)
    # sentencepiece may return a list when given a batch; keep the first entry.
    if isinstance(decoded, list):
        decoded = decoded[0]
    return decoded.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')
|
| 119 |
|
| 120 |
@property
|