# File size: 6,954 Bytes (extraction metadata; blame-hash and line-number dump removed)
import copy
import random
import json
def get_bos_eos_token_ids(tokenizer):
    """Return ``(bos_token_id, eos_token_id)`` for *tokenizer* as lists.

    Qwen-family tokenizers have no BOS token, so an empty list is returned
    for it; ChatGLM uses the fixed two-token prefix ``[64790, 64792]``.
    For any other tokenizer the ids are read straight off the object.
    """

    def as_id_list(token_id):
        # Wrap a bare int id in a list; pass lists (or None) through as-is.
        return [token_id] if isinstance(token_id, int) else token_id

    name = type(tokenizer).__name__
    eos_token_id = tokenizer.eos_token_id
    if name in ('QWenTokenizer', 'QWen2Tokenizer', 'Qwen2TokenizerFast'):
        assert eos_token_id is not None, \
            'Please set eos_token for Qwen tokenizer!'
        bos_token_id = []
    elif name == 'ChatGLMTokenizer':
        bos_token_id = [64790, 64792]
    else:
        bos_token_id = tokenizer.bos_token_id
    return as_id_list(bos_token_id), as_id_list(eos_token_id)
# Label value for tokens that must not contribute to the loss (used by
# encode_fn for prompt tokens and masked outputs).
IGNORE_INDEX = -100
# Fallback padding token id.
DEFAULT_PAD_TOKEN_INDEX = 0
# Negative sentinel ids spliced into token streams in place of multimodal
# markers; presumably replaced by real embeddings downstream — TODO confirm.
IMAGE_TOKEN_INDEX = -200
# Textual marker that encode_fn expands into image sentinel ids.
DEFAULT_IMAGE_TOKEN = '<image>'
# Images appearing in the prompt reuse IMAGE_TOKEN_INDEX; generated images
# get their own sentinel so the two directions stay distinguishable.
INPUT_IMAGE_TOKEN_INDEX = IMAGE_TOKEN_INDEX
OUTPUT_IMAGE_TOKEN_INDEX = -300
# Sentinel id and textual marker for query tokens inside outputs.
QUERY_TOKEN_INDEX = -400
QUERY_TOKEN = '<query>'
def crop2square(pil_img):
    """Randomly crop *pil_img* to a square of side ``min(width, height)``.

    The crop spans the full shorter dimension and sits at a uniformly random
    offset along the longer one.
    """
    w, h = pil_img.width, pil_img.height
    side = min(w, h)
    # Exactly one randint call, matching the original's random-stream usage.
    left = random.randint(0, w - side) if w > h else 0
    top = 0 if w > h else random.randint(0, h - side)
    return pil_img.crop(box=(left, top, left + side, top + side))
def load_jsonl(json_file):
    """Load a JSON Lines file into a list of decoded objects.

    Args:
        json_file: Path to a ``.jsonl`` file, one JSON document per line.

    Returns:
        list: One decoded Python object per non-blank line, in file order.
    """
    data = []
    # Stream line by line instead of readlines() so large files are not
    # held twice in memory; skip blank lines (e.g. a trailing newline),
    # which would otherwise raise json.JSONDecodeError.
    with open(json_file) as f:
        for line in f:
            line = line.strip()
            if line:
                data.append(json.loads(line))
    return data
def _encode_with_placeholder(tokenizer, text, placeholder, token_index,
                             num_tokens):
    """Tokenize ``text`` containing exactly one ``placeholder`` marker.

    The text is split on the marker, each side is tokenized without special
    tokens, and ``num_tokens`` copies of the sentinel ``token_index`` are
    spliced in where the marker was.
    """
    chunk_encode = [
        tokenizer.encode(chunk, add_special_tokens=False)
        for chunk in text.split(placeholder)
    ]
    # Exactly one marker per message is supported (two chunks after split).
    assert len(chunk_encode) == 2
    token_ids = []
    for idx, cur_chunk in enumerate(chunk_encode):
        token_ids.extend(cur_chunk)
        if idx != len(chunk_encode) - 1:
            token_ids += [token_index] * num_tokens
    return token_ids


def encode_fn(example,
              tokenizer,
              max_length=None,
              image_length=1,
              query_length=1,
              input_ids_with_output=True,
              with_image_token=False,
              prompt_template=None,
              truncation='right'):
    """Tokenize a (possibly multi-turn) conversation into input_ids/labels.

    Only support the following three scenarios:

    1. Incremental pretraining dataset.
        example['conversation'] = [
            {
                'input': '',
                'output': '### Human: Can you write xxx'
            }
        ]

    2. Single-turn conversation dataset.
        example['conversation'] = [
            {
                'input': 'Give three tips for staying healthy.',
                'output': '1.Eat a balanced diet xxx'
            }
        ]

    3. Multi-turn conversation dataset.
        example['conversation'] = [
            {
                'input': 'Give three tips for staying healthy.',
                'output': '1.Eat a balanced diet xxx'
            },
            {
                'input': 'Please expand on the second point.',
                'output': 'Here is an expanded explanation of the xxx'
            }
        ]

    Args:
        example (dict): Holds a 'conversation' list of turn dicts. A turn may
            also carry 'output_with_loss', 'need_eos_token' and 'sep' keys.
        tokenizer: Tokenizer providing ``encode(text, add_special_tokens=...)``.
        max_length (int | None): If set, truncate the result to this length.
        image_length (int): Number of image sentinel ids per '<image>' marker.
        query_length (int): Number of query sentinel ids per '<query>' marker.
        input_ids_with_output (bool): Append assistant outputs (required for
            multi-turn conversations).
        with_image_token (bool): Expand '<image>' markers into sentinel ids.
        prompt_template: Unused; kept for backward-compatible signature.
        truncation (str | None): 'right', 'left' or None.

    Returns:
        dict: ``{'input_ids': list[int], 'labels': list[int]}`` where tokens
        excluded from the loss are masked with ``IGNORE_INDEX`` in labels.
    """
    bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
    is_multi_turn_conversation = len(example['conversation']) > 1
    if is_multi_turn_conversation:
        assert input_ids_with_output

    input_ids, labels = [], []
    next_needs_bos_token = True
    for single_turn_conversation in example['conversation']:
        # Renamed from `input` to avoid shadowing the builtin.
        input_text = single_turn_conversation['input']
        if DEFAULT_IMAGE_TOKEN in input_text and with_image_token:
            input_encode = _encode_with_placeholder(
                tokenizer, input_text, DEFAULT_IMAGE_TOKEN,
                INPUT_IMAGE_TOKEN_INDEX, image_length)
        else:
            input_encode = tokenizer.encode(input_text,
                                            add_special_tokens=False)
        if next_needs_bos_token:
            input_ids += bos_token_id
            labels += [IGNORE_INDEX] * len(bos_token_id)
        input_ids += input_encode
        # Prompt tokens never contribute to the loss.
        labels += [IGNORE_INDEX] * len(input_encode)
        if input_ids_with_output and 'output' in single_turn_conversation:
            # Add output
            output_with_loss = single_turn_conversation.get(
                'output_with_loss', True)
            output = single_turn_conversation['output']
            if DEFAULT_IMAGE_TOKEN in output and with_image_token:
                output_encode = _encode_with_placeholder(
                    tokenizer, output, DEFAULT_IMAGE_TOKEN,
                    OUTPUT_IMAGE_TOKEN_INDEX, image_length)
            elif QUERY_TOKEN in output:
                # Query markers are expanded regardless of with_image_token.
                output_encode = _encode_with_placeholder(
                    tokenizer, output, QUERY_TOKEN,
                    QUERY_TOKEN_INDEX, query_length)
            else:
                output_encode = tokenizer.encode(output,
                                                 add_special_tokens=False)
            input_ids += output_encode
            labels += (list(output_encode) if output_with_loss else
                       [IGNORE_INDEX] * len(output_encode))
            # Add EOS_TOKEN (with loss, unless the output itself is masked).
            if single_turn_conversation.get('need_eos_token', True):
                next_needs_bos_token = True
                input_ids += eos_token_id
                labels += (list(eos_token_id) if output_with_loss else
                           [IGNORE_INDEX] * len(eos_token_id))
            else:
                next_needs_bos_token = False
        # Add SEP (without loss)
        sep = single_turn_conversation.get('sep', '')
        if sep != '':
            sep_encode = tokenizer.encode(sep, add_special_tokens=False)
            input_ids += sep_encode
            labels += [IGNORE_INDEX] * len(sep_encode)

    if max_length is not None and len(input_ids) > max_length:
        if truncation == 'right':
            input_ids = input_ids[:max_length]
            labels = labels[:max_length]
        elif truncation == 'left':
            input_ids = input_ids[-max_length:]
            labels = labels[-max_length:]
        else:
            assert truncation is None
    return {'input_ids': input_ids, 'labels': labels}