File size: 6,214 Bytes
ace9173 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import copy
import random
from xtuner.dataset.utils import get_bos_eos_token_ids
from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
import json
# Sentinel token ids spliced into token sequences in place of real vocab ids.
# They are negative so they can never collide with tokenizer vocabulary;
# presumably a downstream model swaps them for image/query embeddings —
# TODO(review): confirm against the consuming model code.
INPUT_IMAGE_TOKEN_INDEX = IMAGE_TOKEN_INDEX  # placeholder for an input-side image
OUTPUT_IMAGE_TOKEN_INDEX = -300  # placeholder for an image the model should produce
QUERY_TOKEN_INDEX = -400  # placeholder for the literal '<query>' marker below
QUERY_TOKEN = '<query>'
def crop2square(pil_img):
    """Randomly crop ``pil_img`` to a square.

    The square's side equals the image's shorter edge; the crop window is
    placed at a random offset along the longer edge (a square input is
    returned unchanged in content, since the only valid offset is 0).
    """
    width, height = pil_img.width, pil_img.height
    if width > height:
        # Landscape: keep full height, slide the window horizontally.
        offset = random.randint(0, width - height)
        box = (offset, 0, offset + height, height)
    else:
        # Portrait (or square): keep full width, slide vertically.
        offset = random.randint(0, height - width)
        box = (0, offset, width, offset + width)
    return pil_img.crop(box=box)
def load_jsonl(json_file):
    """Read a JSON-Lines file and return its records as a list.

    Each non-empty line of ``json_file`` must be a complete JSON document;
    the documents are parsed in file order.
    """
    with open(json_file) as f:
        # A text file iterates line by line, so parse as we stream.
        return [json.loads(line) for line in f]
def encode_fn(example,
              tokenizer,
              max_length=None,
              image_length=1,
              query_length=1,
              input_ids_with_output=True,
              with_image_token=False,
              prompt_template=None,
              truncation='right'):
    """Tokenize one conversation example into ``input_ids`` and ``labels``.

    Prompt (input) tokens are masked in ``labels`` with ``IGNORE_INDEX`` so
    they contribute no loss; output tokens and the trailing EOS keep their
    ids unless the turn sets ``output_with_loss`` to ``False``.  Occurrences
    of ``DEFAULT_IMAGE_TOKEN`` or ``QUERY_TOKEN`` in the text are replaced
    by ``image_length`` / ``query_length`` copies of the corresponding
    negative sentinel id (presumably resolved to embeddings downstream —
    TODO(review): confirm against the model code).

    Note: ``prompt_template`` is accepted but never read in this body.

    Only support the following three scenarios:

    1. Incremental pretraining dataset.
        example['conversation'] = [
                {
                    'input': '',
                    'output': '### Human: Can you write xxx'
                }
        ]

    2. Single-turn conversation dataset.
        example['conversation'] = [
                {
                    'input': 'Give three tips for staying healthy.',
                    'output': '1.Eat a balanced diet xxx'
                }
        ]

    3. Multi-turn conversation dataset.
        example['conversation'] = [
                {
                    'input': 'Give three tips for staying healthy.',
                    'output': '1.Eat a balanced diet xxx'
                },
                {
                    'input': 'Please expand on the second point.',
                    'output': 'Here is an expanded explanation of the xxx'
                }
        ]
    """
    bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
    # Multi-turn data only makes sense when outputs are part of the sequence.
    is_multi_turn_conversation = len(example['conversation']) > 1
    if is_multi_turn_conversation:
        assert input_ids_with_output

    input_ids, labels = [], []
    # BOS is emitted before the first turn, and again after any turn that
    # ended with an EOS token (see need_eos_token handling below).
    next_needs_bos_token = True
    for single_turn_conversation in example['conversation']:
        input = single_turn_conversation['input']
        if DEFAULT_IMAGE_TOKEN in input and with_image_token:
            # Tokenize the text around the image marker separately, then
            # stitch the pieces back together with sentinel image ids.
            chunk_encode = [
                tokenizer.encode(chunk, add_special_tokens=False)
                for chunk in input.split(DEFAULT_IMAGE_TOKEN)
            ]
            # Exactly one image marker per input is supported (2 chunks).
            assert len(chunk_encode) == 2
            input_encode = []
            for idx, cur_chunk_encode in enumerate(chunk_encode):
                input_encode.extend(cur_chunk_encode)
                if idx != len(chunk_encode) - 1:
                    # One sentinel id per visual token slot.
                    input_encode += [INPUT_IMAGE_TOKEN_INDEX] * image_length
        else:
            input_encode = tokenizer.encode(input, add_special_tokens=False)
        if next_needs_bos_token:
            input_ids += bos_token_id
            labels += [IGNORE_INDEX] * len(bos_token_id)
        # Prompt tokens never receive loss.
        input_ids += input_encode
        labels += [IGNORE_INDEX] * len(input_encode)
        if input_ids_with_output and 'output' in single_turn_conversation:
            # Add output
            output_with_loss = single_turn_conversation.get(
                'output_with_loss', True)
            output = single_turn_conversation['output']
            if DEFAULT_IMAGE_TOKEN in output and with_image_token:
                # Same stitch-around-marker scheme as the input branch,
                # but with the output-side image sentinel.
                chunk_encode = [
                    tokenizer.encode(chunk, add_special_tokens=False)
                    for chunk in output.split(DEFAULT_IMAGE_TOKEN)
                ]
                assert len(chunk_encode) == 2
                output_encode = []
                for idx, cur_chunk_encode in enumerate(chunk_encode):
                    output_encode.extend(cur_chunk_encode)
                    if idx != len(chunk_encode) - 1:
                        output_encode += [OUTPUT_IMAGE_TOKEN_INDEX] * image_length
            elif QUERY_TOKEN in output:
                # NOTE(review): unlike the image branches this one is not
                # gated on with_image_token — presumably intentional.
                chunk_encode = [
                    tokenizer.encode(chunk, add_special_tokens=False)
                    for chunk in output.split(QUERY_TOKEN)
                ]
                assert len(chunk_encode) == 2
                output_encode = []
                for idx, cur_chunk_encode in enumerate(chunk_encode):
                    output_encode.extend(cur_chunk_encode)
                    if idx != len(chunk_encode) - 1:
                        output_encode += [QUERY_TOKEN_INDEX] * query_length
            else:
                output_encode = tokenizer.encode(output, add_special_tokens=False)
            input_ids += output_encode
            if output_with_loss:
                # Labels mirror the output ids (sentinels included).
                labels += copy.deepcopy(output_encode)
            else:
                labels += [IGNORE_INDEX] * len(output_encode)
            # Add EOS_TOKEN (with loss)
            if single_turn_conversation.get('need_eos_token', True):
                # Sequence segment ends here, so the next turn restarts
                # with a BOS token.
                next_needs_bos_token = True
                input_ids += eos_token_id
                if output_with_loss:
                    labels += copy.deepcopy(eos_token_id)
                else:
                    labels += [IGNORE_INDEX] * len(eos_token_id)
            else:
                next_needs_bos_token = False
            # Add SEP (without loss)
            sep = single_turn_conversation.get('sep', '')
            if sep != '':
                sep_encode = tokenizer.encode(sep, add_special_tokens=False)
                input_ids += sep_encode
                labels += [IGNORE_INDEX] * len(sep_encode)
    # Truncate from the chosen side; labels stay aligned with input_ids.
    if max_length is not None and len(input_ids) > max_length:
        if truncation == 'right':
            input_ids = input_ids[:max_length]
            labels = labels[:max_length]
        elif truncation == 'left':
            input_ids = input_ids[-max_length:]
            labels = labels[-max_length:]
        else:
            assert truncation is None
    return {'input_ids': input_ids, 'labels': labels}
|