Gong Baitao
committed on
Commit
·
2156c56
1
Parent(s):
32554d7
Update tokenization_cpmbee.py
Browse files- tokenization_cpmbee.py +130 -0
tokenization_cpmbee.py
CHANGED
|
@@ -18,6 +18,7 @@ import os
|
|
| 18 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 19 |
|
| 20 |
import numpy as np
|
|
|
|
| 21 |
from typing_extensions import TypedDict
|
| 22 |
|
| 23 |
from transformers.tokenization_utils import PaddingStrategy, PreTrainedTokenizer, TensorType
|
|
@@ -866,3 +867,132 @@ class CpmBeeTokenizer(PreTrainedTokenizer):
|
|
| 866 |
)
|
| 867 |
|
| 868 |
return batch_outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 19 |
|
| 20 |
import numpy as np
|
| 21 |
+
from numpy.typing import NDArray
|
| 22 |
from typing_extensions import TypedDict
|
| 23 |
|
| 24 |
from transformers.tokenization_utils import PaddingStrategy, PreTrainedTokenizer, TensorType
|
|
|
|
| 867 |
)
|
| 868 |
|
| 869 |
return batch_outputs
|
| 870 |
+
|
| 871 |
+
def prepare_for_finetune(
|
| 872 |
+
self,
|
| 873 |
+
data_list: List[Dict],
|
| 874 |
+
max_length: int = 2048
|
| 875 |
+
):
|
| 876 |
+
_inputs: List[NDArray[np.int32]] = []
|
| 877 |
+
_inputs_sub: List[NDArray[np.int32]] = []
|
| 878 |
+
_context: List[NDArray[np.int8]] = []
|
| 879 |
+
_sample_ids: List[NDArray[np.int32]] = []
|
| 880 |
+
_segments: List[NDArray[np.int32]] = []
|
| 881 |
+
_num_segments: List[NDArray[np.int32]] = []
|
| 882 |
+
_segment_rel_offset: List[NDArray[np.int32]] = []
|
| 883 |
+
_segment_rel: List[NDArray[np.int32]] = []
|
| 884 |
+
_spans: List[List[int]] = []
|
| 885 |
+
_raw_data: List[List[Any]] = []
|
| 886 |
+
|
| 887 |
+
raw_data = {}
|
| 888 |
+
for data in data_list:
|
| 889 |
+
(
|
| 890 |
+
input_ids,
|
| 891 |
+
input_id_subs,
|
| 892 |
+
context,
|
| 893 |
+
segment_ids,
|
| 894 |
+
segment_rel,
|
| 895 |
+
n_segments,
|
| 896 |
+
_
|
| 897 |
+
) = self.convert_data_to_id(data)
|
| 898 |
+
|
| 899 |
+
input_ids = input_ids[: max_length]
|
| 900 |
+
context = context[: max_length]
|
| 901 |
+
segment_ids = segment_ids[: max_length]
|
| 902 |
+
raw_data["input"] = data
|
| 903 |
+
raw_data["samples"] = []
|
| 904 |
+
|
| 905 |
+
sample_ids = np.zeros(input_ids.shape, dtype=np.int32)
|
| 906 |
+
segment_rel_offset = np.zeros(input_ids.shape, dtype=np.int32)
|
| 907 |
+
num_segments = np.full(input_ids.shape, n_segments, dtype=np.int32)
|
| 908 |
+
|
| 909 |
+
_inputs.append(input_ids)
|
| 910 |
+
_inputs_sub.append(input_id_subs)
|
| 911 |
+
_context.append(context)
|
| 912 |
+
_sample_ids.append(sample_ids)
|
| 913 |
+
_segments.append(segment_ids)
|
| 914 |
+
_num_segments.append(num_segments)
|
| 915 |
+
_segment_rel_offset.append(segment_rel_offset)
|
| 916 |
+
_segment_rel.append(segment_rel)
|
| 917 |
+
_spans.append([input_ids.shape[0]])
|
| 918 |
+
_raw_data.append([raw_data])
|
| 919 |
+
|
| 920 |
+
batch_size = len(_inputs)
|
| 921 |
+
inputs = np.zeros((batch_size, max_length), dtype=np.int32)
|
| 922 |
+
inputs_sub = np.zeros((batch_size, max_length), dtype=np.int32)
|
| 923 |
+
context = np.zeros((batch_size, max_length), dtype=np.int8)
|
| 924 |
+
sample_ids = np.zeros((batch_size, max_length), dtype=np.int32)
|
| 925 |
+
segments = np.zeros((batch_size, max_length), dtype=np.int32)
|
| 926 |
+
num_segments = np.zeros((batch_size, max_length), dtype=np.int32)
|
| 927 |
+
segment_rel_offset = np.zeros((batch_size, max_length), dtype=np.int32)
|
| 928 |
+
tgt = np.full((batch_size, max_length), -100, dtype=np.int32)
|
| 929 |
+
|
| 930 |
+
max_rel = 0
|
| 931 |
+
for i in range(batch_size):
|
| 932 |
+
max_rel = max(max_rel, _segment_rel[i].shape[0])
|
| 933 |
+
segment_rel = np.zeros((batch_size, max_rel), dtype=np.int32)
|
| 934 |
+
spans = np.zeros((batch_size, max_length), dtype=np.int32)
|
| 935 |
+
length = np.zeros((batch_size,), dtype=np.int32)
|
| 936 |
+
|
| 937 |
+
batch_ext_table_map: Dict[Tuple[int, int], int] = {}
|
| 938 |
+
batch_ext_table_ids: List[int] = []
|
| 939 |
+
batch_ext_table_sub: List[int] = []
|
| 940 |
+
raw_data_list: List[Any] = []
|
| 941 |
+
|
| 942 |
+
for i in range(batch_size):
|
| 943 |
+
instance_length = _inputs[i].shape[0]
|
| 944 |
+
rel_size = _segment_rel[i].shape[0]
|
| 945 |
+
inputs[i, :instance_length] = _inputs[i]
|
| 946 |
+
inputs_sub[i, :instance_length] = _inputs_sub[i]
|
| 947 |
+
context[i, :instance_length] = _context[i]
|
| 948 |
+
sample_ids[i, :instance_length] = _sample_ids[i]
|
| 949 |
+
segments[i, :instance_length] = _segments[i]
|
| 950 |
+
num_segments[i, :instance_length] = _num_segments[i]
|
| 951 |
+
segment_rel_offset[i, :instance_length] = _segment_rel_offset[i]
|
| 952 |
+
segment_rel[i, :rel_size] = _segment_rel[i]
|
| 953 |
+
|
| 954 |
+
span_begin = 0
|
| 955 |
+
for span_id, span_end in enumerate(_spans[i]):
|
| 956 |
+
spans[i, span_begin:span_end] = span_id
|
| 957 |
+
span_begin = span_end
|
| 958 |
+
length[i] = instance_length
|
| 959 |
+
raw_data_list.extend(_raw_data[i])
|
| 960 |
+
|
| 961 |
+
for j in range(instance_length):
|
| 962 |
+
idx, idx_sub = _inputs[i][j], _inputs_sub[i][j]
|
| 963 |
+
tgt_idx = idx
|
| 964 |
+
if idx_sub > 0:
|
| 965 |
+
# need to be in ext table
|
| 966 |
+
if (idx, idx_sub) not in batch_ext_table_map:
|
| 967 |
+
batch_ext_table_map[(idx, idx_sub)] = len(batch_ext_table_map)
|
| 968 |
+
batch_ext_table_ids.append(idx)
|
| 969 |
+
batch_ext_table_sub.append(idx_sub)
|
| 970 |
+
tgt_idx = batch_ext_table_map[(idx, idx_sub)] + self.vocab_size
|
| 971 |
+
if j > 1 and context[i, j - 1] == 0:
|
| 972 |
+
if idx != self.bos_token_id:
|
| 973 |
+
tgt[i, j - 1] = tgt_idx
|
| 974 |
+
else:
|
| 975 |
+
tgt[i, j - 1] = self.eos_token_id
|
| 976 |
+
if context[i, instance_length - 1] == 0:
|
| 977 |
+
tgt[i, instance_length - 1] = self.eos_token_id
|
| 978 |
+
|
| 979 |
+
if len(batch_ext_table_map) == 0:
|
| 980 |
+
# placeholder
|
| 981 |
+
batch_ext_table_ids.append(0)
|
| 982 |
+
batch_ext_table_sub.append(1)
|
| 983 |
+
|
| 984 |
+
return BatchEncoding({
|
| 985 |
+
"input_ids": inputs,
|
| 986 |
+
"input_id_sub": inputs_sub,
|
| 987 |
+
"length": length,
|
| 988 |
+
"context": context > 0,
|
| 989 |
+
"sample_ids": sample_ids,
|
| 990 |
+
"num_segments": num_segments,
|
| 991 |
+
"segment": segments,
|
| 992 |
+
"segment_rel_offset": segment_rel_offset,
|
| 993 |
+
"segment_rel": segment_rel,
|
| 994 |
+
"span": spans,
|
| 995 |
+
"labels": tgt,
|
| 996 |
+
"ext_table_ids": np.array(batch_ext_table_ids, dtype=np.int32),
|
| 997 |
+
"ext_table_sub": np.array(batch_ext_table_sub, dtype=np.int32)
|
| 998 |
+
}, tensor_type="pt")
|