File size: 1,972 Bytes
032e687 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
from xtuner.utils import DEFAULT_IMAGE_TOKEN
def _strip_angle_brackets(text):
    """Return *text* with every '<' and '>' character removed."""
    return text.replace('<', '').replace('>', '')


def mdpv_points_preprocess(example):
    """Rewrite the conversation turns of *example* into the mark-prompt form.

    The first turn must come from ``'human'``. Its text becomes::

        DEFAULT_IMAGE_TOKEN + 'There are some marks:'
        + ' Mark 1 <mark>, Mark 2 <mark>, ... Mark N <mark>.\\n'
        + <original text with '<' and '>' removed>

    where ``N`` is ``example['num_marks']``. When ``N`` is zero (or negative)
    no mark list and no trailing ``'.\\n'`` is emitted. Every later turn only
    has its ``'<'`` and ``'>'`` characters removed.

    Note that the bracket stripping is unconditional: any ``<...>`` style
    token already present in a turn's text loses its angle brackets.

    Args:
        example: Mapping with a ``'conversations'`` sequence of dicts (each
            holding ``'from'`` and ``'value'``) and an integer
            ``'num_marks'``.

    Returns:
        The same ``example`` object. The turn dicts are modified in place and
        re-assigned to ``example['conversations']``.

    Raises:
        AssertionError: If the first turn is not from ``'human'``.
    """
    conversations = example['conversations']
    num_marks = example['num_marks']

    for turn_idx, conversation in enumerate(conversations):
        if turn_idx == 0:
            role = conversation['from']
            assert role == 'human', (
                "first conversation turn must be from 'human', "
                'got {!r}'.format(role))

            # Use a dedicated loop variable for the marks so the enumerate
            # index above is never clobbered.
            marks = ','.join(
                ' Mark {} <mark>'.format(mark_no)
                for mark_no in range(1, num_marks + 1))
            if marks:
                # The list is terminated only when at least one mark exists.
                marks += '.\n'
            prefix = DEFAULT_IMAGE_TOKEN + 'There are some marks:' + marks
        else:
            prefix = ''

        conversation['value'] = prefix + _strip_angle_brackets(
            conversation['value'])

    example['conversations'] = conversations
    return example
def mdpv_points_map_fn(example):
    """Convert a marked-points example into input/output training pairs.

    The example is first passed through ``mdpv_points_preprocess`` (which
    reads ``'conversations'`` and ``'num_marks'``). The resulting turns are
    then folded into a list of ``{'input': ..., 'output': ...}`` dicts:

    * leading ``'gpt'`` turns are skipped;
    * each ``'human'`` turn is appended (without a separator) to the pending
      input; if it contains ``DEFAULT_IMAGE_TOKEN`` the token is removed from
      the text and re-inserted once at the front, followed by a newline, and
      the turn's ``'value'`` is updated in place;
    * each ``'gpt'`` turn closes a pair using the pending input, which is
      then reset to the empty string.

    Human text that is never followed by a ``'gpt'`` turn is not emitted.

    Args:
        example: Mapping holding at least ``'conversations'`` and
            ``'num_marks'``.

    Returns:
        The preprocessed ``example`` with the pairs stored under the
        ``'conversation'`` key.

    Raises:
        NotImplementedError: If a turn's ``'from'`` is neither ``'human'``
            nor ``'gpt'``.
    """
    example = mdpv_points_preprocess(example)

    # LLaVA-style folding of the turns into (input, output) pairs.
    turns = example['conversations']

    # Locate the first turn that is not from 'gpt' and start from there.
    start = 0
    while start < len(turns) and turns[start]['from'] == 'gpt':
        start += 1
    turns = turns[start:]

    pairs = []
    pending = ''
    for turn in turns:
        speaker = turn['from']

        if speaker == 'gpt':
            pairs.append({'input': pending, 'output': turn['value']})
            pending = ''
            continue

        if speaker != 'human':
            raise NotImplementedError

        text = turn['value']
        if DEFAULT_IMAGE_TOKEN in text:
            # Normalise so the image token appears exactly once, up front.
            text = text.replace(DEFAULT_IMAGE_TOKEN, '').strip()
            text = (DEFAULT_IMAGE_TOKEN + '\n' + text).strip()
            turn['value'] = text
        pending += text

    example.update({'conversation': pairs})
    return example