Spaces:
Running on Zero
Running on Zero
File size: 6,412 Bytes
becf13a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 | import copy
import random
from xtuner.dataset.utils import get_bos_eos_token_ids
from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
import json
# Alias of IMAGE_TOKEN_INDEX imported from xtuner.utils.
# NOTE(review): presumably the placeholder id for image positions on the
# input side -- confirm against callers.
INPUT_IMAGE_TOKEN_INDEX = IMAGE_TOKEN_INDEX
# NOTE(review): presumably the placeholder id for image positions on the
# output side; it is not referenced by the code visible in this file --
# confirm against callers.
OUTPUT_IMAGE_TOKEN_INDEX = -300
def crop2square(pil_img):
    """Randomly crop an image to a square whose side is its shorter edge.

    The crop always spans the full shorter dimension. Along the longer
    dimension the window offset is drawn uniformly with ``random.randint``
    so that the window stays inside the image.

    Args:
        pil_img: Image-like object exposing ``width``, ``height`` and
            ``crop(box=(x0, y0, x1, y1))``.

    Returns:
        The result of ``pil_img.crop`` for the selected square box.
    """
    w, h = pil_img.width, pil_img.height
    side = min(w, h)
    if w > h:
        # Landscape: slide the window horizontally, offset in [0, w - h].
        left = random.randint(0, w - h)
        top = 0
    else:
        # Portrait or already square: slide vertically, offset in [0, h - w].
        top = random.randint(0, h - w)
        left = 0
    return pil_img.crop(box=(left, top, left + side, top + side))
def center_crop_to_square(pil_img):
    """Crop the central square out of an image.

    The square's side equals the shorter of the image's two dimensions; the
    surplus along the longer dimension is split between both ends, with the
    leading margin rounded down.

    Args:
        pil_img: Image-like object exposing ``width``, ``height`` and
            ``crop(box=(x0, y0, x1, y1))``.

    Returns:
        The result of ``pil_img.crop`` for the centered
        ``min(width, height)`` x ``min(width, height)`` box.
    """
    w, h = pil_img.width, pil_img.height
    side = min(w, h)
    # Only the longer axis has a non-zero margin; the other offset is 0.
    left = (w - side) // 2
    top = (h - side) // 2
    return pil_img.crop(box=(left, top, left + side, top + side))
def load_jsonl(json_file):
    """Load a JSON Lines file into a list of decoded objects.

    The file is read as UTF-8 regardless of the platform's default locale
    encoding. Lines that are empty or contain only whitespace (for example
    a trailing newline at the end of the file) are skipped instead of
    being passed to the JSON decoder.

    Args:
        json_file: Path of the ``.jsonl`` file to read.

    Returns:
        A list with one decoded JSON value per non-blank line, in file
        order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(json_file, encoding='utf-8') as f:
        # Iterate the file lazily rather than materializing every raw line.
        return [json.loads(line) for line in f if line.strip()]
def _encode_with_image_token(text, tokenizer, image_length):
    """Tokenize ``text`` around its image token and insert placeholder ids.

    ``text`` is split on ``DEFAULT_IMAGE_TOKEN``; each piece is tokenized
    without special tokens and ``image_length`` copies of
    ``IMAGE_TOKEN_INDEX`` are placed between the two pieces.

    Raises:
        AssertionError: If ``text`` does not contain exactly one
            ``DEFAULT_IMAGE_TOKEN`` (i.e. it does not split into two parts).
    """
    chunk_encode = [
        tokenizer.encode(chunk, add_special_tokens=False)
        for chunk in text.split(DEFAULT_IMAGE_TOKEN)
    ]
    # Exactly one image token per text is supported.
    assert len(chunk_encode) == 2
    before, after = chunk_encode
    encoded = list(before)
    encoded += [IMAGE_TOKEN_INDEX] * image_length
    encoded.extend(after)
    return encoded


def encode_fn(example,
              tokenizer,
              max_length=None,
              image_length=1,
              input_ids_with_output=True,
              with_image_token=False,
              truncation='right'):
    """Encode a conversation example into ``input_ids`` and ``labels``.

    We only support the following three scenarios:

    1. Incremental pretraining dataset.
        example['conversation'] = [
            {
                'input': '',
                'output': '### Human: Can you write xxx'
            }
        ]

    2. Single-turn conversation dataset.
        example['conversation'] = [
            {
                'input': 'Give three tips for staying healthy.',
                'output': '1.Eat a balanced diet xxx'
            }
        ]

    3. Multi-turn conversation dataset.
        example['conversation'] = [
            {
                'input': 'Give three tips for staying healthy.',
                'output': '1.Eat a balanced diet xxx'
            },
            {
                'input': 'Please expand on the second point.',
                'output': 'Here is an expanded explanation of the xxx'
            }
        ]

    Besides 'input' and 'output', each turn may carry the optional keys
    'output_with_loss' (default True), 'need_eos_token' (default True) and
    'sep' (default '').

    Args:
        example: Mapping with a 'conversation' list of turn dicts as shown
            above.
        tokenizer: Object providing
            ``encode(text, add_special_tokens=False)``; it is also passed to
            ``get_bos_eos_token_ids``.
        max_length: If not None and the encoded sequence is longer, the
            sequence is truncated according to ``truncation``.
        image_length: Number of ``IMAGE_TOKEN_INDEX`` placeholders inserted
            for the image token found in a text.
        input_ids_with_output: Whether the turns' outputs (plus EOS / sep)
            are appended. Must be truthy for multi-turn conversations.
        with_image_token: Whether ``DEFAULT_IMAGE_TOKEN`` occurrences are
            expanded into placeholder ids; when False the text is tokenized
            as-is.
        truncation: 'right' keeps the first ``max_length`` tokens, 'left'
            keeps the last ``max_length`` tokens, None disables truncation.

    Returns:
        dict with 'input_ids' and 'labels', two lists of equal length.
        Positions that carry no loss are set to ``IGNORE_INDEX`` in
        'labels'.

    Raises:
        AssertionError: For a multi-turn example with
            ``input_ids_with_output`` falsy, for a text that does not
            contain exactly one image token while ``with_image_token`` is
            set, or for an unsupported ``truncation`` value when the
            sequence exceeds ``max_length``.
    """
    # Both values are used as sequences of ids below (``len`` / ``+=``).
    bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
    is_multi_turn_conversation = len(example['conversation']) > 1
    if is_multi_turn_conversation:
        assert input_ids_with_output

    input_ids, labels = [], []
    next_needs_bos_token = True
    for turn in example['conversation']:
        prompt = turn['input']
        if DEFAULT_IMAGE_TOKEN in prompt and with_image_token:
            input_encode = _encode_with_image_token(
                prompt, tokenizer, image_length)
        else:
            input_encode = tokenizer.encode(prompt, add_special_tokens=False)

        # BOS and the prompt never contribute to the loss.
        if next_needs_bos_token:
            input_ids += bos_token_id
            labels += [IGNORE_INDEX] * len(bos_token_id)
        input_ids += input_encode
        labels += [IGNORE_INDEX] * len(input_encode)

        if not (input_ids_with_output and 'output' in turn):
            continue

        # Add output
        output_with_loss = turn.get('output_with_loss', True)
        output = turn['output']
        if DEFAULT_IMAGE_TOKEN in output and with_image_token:
            # NOTE(review): outputs use IMAGE_TOKEN_INDEX as placeholder,
            # same as inputs; OUTPUT_IMAGE_TOKEN_INDEX is not used here --
            # confirm this is intended.
            output_encode = _encode_with_image_token(
                output, tokenizer, image_length)
        else:
            output_encode = tokenizer.encode(output, add_special_tokens=False)
        input_ids += output_encode
        if output_with_loss:
            # ``+=`` copies the ids into ``labels``; no extra copy needed.
            labels += output_encode
        else:
            labels += [IGNORE_INDEX] * len(output_encode)

        # Add EOS_TOKEN (with loss)
        if turn.get('need_eos_token', True):
            # A finished turn means the next one starts with BOS again.
            next_needs_bos_token = True
            input_ids += eos_token_id
            if output_with_loss:
                labels += eos_token_id
            else:
                labels += [IGNORE_INDEX] * len(eos_token_id)
        else:
            next_needs_bos_token = False

        # Add SEP (without loss)
        sep = turn.get('sep', '')
        if sep != '':
            sep_encode = tokenizer.encode(sep, add_special_tokens=False)
            input_ids += sep_encode
            labels += [IGNORE_INDEX] * len(sep_encode)

    if max_length is not None and len(input_ids) > max_length:
        if truncation == 'right':
            input_ids = input_ids[:max_length]
            labels = labels[:max_length]
        elif truncation == 'left':
            input_ids = input_ids[-max_length:]
            labels = labels[-max_length:]
        else:
            # Any other value must be None, which means "do not truncate".
            assert truncation is None
    return {'input_ids': input_ids, 'labels': labels}
|