|
|
|
|
|
import random |
|
|
from xtuner.utils import DEFAULT_IMAGE_TOKEN |
|
|
|
|
|
# Question templates that ask for a segmentation mask of a single class.
# Every entry has exactly one ``{class_name}`` replacement field.  Each
# phrasing is listed twice: once with its closing punctuation mark and once
# without it.
SEG_QUESTIONS = [
    # First set of phrasings, with closing punctuation.
    "Can you segment the {class_name} in this image?",
    "Please segment {class_name} in this image.",
    "What is {class_name} in this image? Please respond with segmentation mask.",
    "What is {class_name} in this image? Please output segmentation mask.",

    # First set of phrasings, closing punctuation removed.
    "Can you segment the {class_name} in this image",
    "Please segment {class_name} in this image",
    "What is {class_name} in this image? Please respond with segmentation mask",
    "What is {class_name} in this image? Please output segmentation mask",

    # Second set of phrasings, with closing punctuation.
    "Could you provide a segmentation mask for the {class_name} in this image?",
    "Please identify and segment the {class_name} in this image.",
    "Where is the {class_name} in this picture? Please respond with a segmentation mask.",
    "Can you highlight the {class_name} in this image with a segmentation mask?",

    # Second set of phrasings, closing punctuation removed.
    "Could you provide a segmentation mask for the {class_name} in this image",
    "Please identify and segment the {class_name} in this image",
    "Where is the {class_name} in this picture? Please respond with a segmentation mask",
    "Can you highlight the {class_name} in this image with a segmentation mask",
]
|
|
|
|
|
# Fixed answer strings for segmentation questions.  Each one contains the
# literal ``[SEG]`` marker and has no replacement fields.
ANSWER_LIST = [
    "It is [SEG].",
    "Sure, [SEG].",
    "Sure, it is [SEG].",
    "Sure, the segmentation result is [SEG].",
    "[SEG].",
]
|
|
|
|
|
# Answer templates for the GCG-format variants.  The single positional
# ``{}`` field is filled with the class name, which ends up wrapped in
# ``<p> ... </p>`` and followed by the ``[SEG]`` marker.
ANSWER_LIST_GCG_FORMAT = [
    "<p> {} </p> [SEG].",
]
|
|
|
|
|
def semantic_seg_conversations(labels):
    """Build one single-turn sample per (label, question, answer) combination.

    For every label, each template in ``SEG_QUESTIONS`` is paired with each
    entry of ``ANSWER_LIST``.

    Args:
        labels: Iterable of class-name strings.  A label's position in the
            iteration is stored as its ``class_id``.

    Returns:
        list[dict]: Items with a ``'conversations'`` key (one ``'human'``
        turn followed by one ``'gpt'`` turn) and a ``'class_id'`` key,
        ordered by label, then question template, then answer.

    Raises:
        AssertionError: If a stripped label contains the ``"||"`` separator.
    """
    samples = []
    for class_id, raw_name in enumerate(labels):
        name = raw_name.strip()
        # A label has to be a single name, not several joined with "||".
        assert "||" not in name
        lowered = name.lower()
        samples.extend(
            {
                'conversations': [
                    {
                        'from': 'human',
                        'value': DEFAULT_IMAGE_TOKEN
                        + template.format(class_name=lowered),
                    },
                    {'from': 'gpt', 'value': reply},
                ],
                'class_id': class_id,
            }
            for template in SEG_QUESTIONS
            for reply in ANSWER_LIST
        )
    return samples
|
|
|
|
|
def semantic_seg_map_fn(example):
    """Turn ``example['conversations']`` into a list of input/output pairs.

    Leading ``'gpt'`` turns are skipped.  Text of consecutive ``'human'``
    turns is concatenated, and every ``'gpt'`` turn closes a pair
    ``{'input': <accumulated human text>, 'output': <gpt text>}``.

    A human turn that contains ``DEFAULT_IMAGE_TOKEN`` is rewritten in place:
    every occurrence of the token is removed, the remainder is stripped, and
    the token plus a newline is put in front (the result is stripped again).

    Args:
        example: Mapping with a ``'conversations'`` list of
            ``{'from': ..., 'value': ...}`` dicts.

    Returns:
        The same ``example`` object, updated with a ``'conversation'`` key
        that holds the list of pairs.

    Raises:
        NotImplementedError: If a turn's ``'from'`` value is neither
            ``'human'`` nor ``'gpt'``.
    """
    turns = example['conversations']
    # A reply that precedes the first human turn has no prompt; drop it.
    while turns and turns[0]['from'] == 'gpt':
        turns = turns[1:]

    pairs = []
    prompt = ''
    for turn in turns:
        role = turn['from']
        if role == 'gpt':
            pairs.append({'input': prompt, 'output': turn['value']})
            prompt = ''
        elif role == 'human':
            if DEFAULT_IMAGE_TOKEN in turn['value']:
                remainder = turn['value'].replace(DEFAULT_IMAGE_TOKEN,
                                                  '').strip()
                # Normalise so the image token leads the human text.
                turn['value'] = (DEFAULT_IMAGE_TOKEN + '\n'
                                 + remainder).strip()
            prompt += turn['value']
        else:
            raise NotImplementedError

    example.update({'conversation': pairs})
    return example
|
|
|
|
|
def pascal_part_conversation(selected_labels):
    """Create a multi-turn segmentation dialogue, one round per label.

    Each round is a randomly chosen ``SEG_QUESTIONS`` template filled with
    the lower-cased label, answered by a randomly chosen ``ANSWER_LIST``
    entry.  Only the first question is prefixed with
    ``DEFAULT_IMAGE_TOKEN``.

    Args:
        selected_labels: Iterable of label strings.

    Returns:
        list[dict]: Alternating ``'human'`` / ``'gpt'`` turns, each of the
        form ``{'from': ..., 'value': ...}``.
    """
    dialogue = []
    for round_idx, name in enumerate(selected_labels):
        template = random.choice(SEG_QUESTIONS)
        prompt = template.format(class_name=name.lower()).strip()
        reply = random.choice(ANSWER_LIST)
        if round_idx == 0:
            # The image token is attached to the opening question only.
            prompt = DEFAULT_IMAGE_TOKEN + prompt
        dialogue.extend((
            {'from': 'human', 'value': prompt},
            {'from': 'gpt', 'value': reply},
        ))
    return dialogue
|
|
|
|
|
def pascal_part_preprocess(example):
    """Attach a generated dialogue to ``example`` under ``'conversations'``.

    Args:
        example: Mapping with a ``'selected_labels'`` entry, which is passed
            to ``pascal_part_conversation``.

    Returns:
        The same ``example`` object with ``'conversations'`` set.
    """
    example['conversations'] = pascal_part_conversation(
        example["selected_labels"])
    return example
|
|
|
|
|
def pascal_part_map_fn(example):
    """Generate a dialogue for ``example`` and convert it to input/output pairs.

    Steps, in order:
      1. ``pascal_part_preprocess`` fills ``example['conversations']``.
      2. ``example['image']`` is set to ``example['file_name']``.
      3. The turns are folded into ``{'input': ..., 'output': ...}`` pairs:
         leading ``'gpt'`` turns are skipped, human text is accumulated, and
         each ``'gpt'`` turn closes a pair.  A human turn containing
         ``DEFAULT_IMAGE_TOKEN`` is rewritten in place so the token comes
         first, followed by a newline and the stripped remainder.

    Args:
        example: Mapping with ``'selected_labels'`` and ``'file_name'``.

    Returns:
        The example returned by ``pascal_part_preprocess``, updated with
        ``'image'`` and ``'conversation'`` keys.

    Raises:
        NotImplementedError: If a turn's ``'from'`` value is neither
            ``'human'`` nor ``'gpt'``.
    """
    example = pascal_part_preprocess(example)
    example['image'] = example["file_name"]

    turns = example['conversations']
    # A reply that precedes the first human turn has no prompt; drop it.
    while turns and turns[0]['from'] == 'gpt':
        turns = turns[1:]

    pairs = []
    prompt = ''
    for turn in turns:
        role = turn['from']
        if role == 'gpt':
            pairs.append({'input': prompt, 'output': turn['value']})
            prompt = ''
        elif role == 'human':
            if DEFAULT_IMAGE_TOKEN in turn['value']:
                remainder = turn['value'].replace(DEFAULT_IMAGE_TOKEN,
                                                  '').strip()
                # Normalise so the image token leads the human text.
                turn['value'] = (DEFAULT_IMAGE_TOKEN + '\n'
                                 + remainder).strip()
            prompt += turn['value']
        else:
            raise NotImplementedError

    example.update({'conversation': pairs})
    return example
|
|
|
|
|
|
|
|
def semantic_seg_gcg_format_conversations(labels):
    """Build single-turn samples whose answers use the GCG templates.

    For every label, each template in ``SEG_QUESTIONS`` is paired with each
    template in ``ANSWER_LIST_GCG_FORMAT``.  The question receives the
    lower-cased label; the answer receives the lower-cased label with its
    first character capitalised.

    Args:
        labels: Iterable of class-name strings.  A label's position in the
            iteration is stored as its ``class_id``.

    Returns:
        list[dict]: Items with a ``'conversations'`` key (one ``'human'``
        turn followed by one ``'gpt'`` turn) and a ``'class_id'`` key,
        ordered by label, then question template, then answer template.

    Raises:
        AssertionError: If a stripped label contains the ``"||"`` separator.
    """
    samples = []
    for class_id, raw_name in enumerate(labels):
        name = raw_name.strip()
        # A label has to be a single name, not several joined with "||".
        assert "||" not in name
        lowered = name.lower()
        samples.extend(
            {
                'conversations': [
                    {
                        'from': 'human',
                        'value': DEFAULT_IMAGE_TOKEN
                        + question.format(class_name=lowered),
                    },
                    {
                        'from': 'gpt',
                        'value': answer.format(lowered.capitalize()),
                    },
                ],
                'class_id': class_id,
            }
            for question in SEG_QUESTIONS
            for answer in ANSWER_LIST_GCG_FORMAT
        )
    return samples
|
|
|
|
|
def semantic_seg_gcg_format_map_fn(example):
    """Turn ``example['conversations']`` into a list of input/output pairs.

    Leading ``'gpt'`` turns are skipped.  Text of consecutive ``'human'``
    turns is concatenated, and every ``'gpt'`` turn closes a pair
    ``{'input': <accumulated human text>, 'output': <gpt text>}``.

    A human turn that contains ``DEFAULT_IMAGE_TOKEN`` is rewritten in place:
    every occurrence of the token is removed, the remainder is stripped, and
    the token plus a newline is put in front (the result is stripped again).

    Args:
        example: Mapping with a ``'conversations'`` list of
            ``{'from': ..., 'value': ...}`` dicts.

    Returns:
        The same ``example`` object, updated with a ``'conversation'`` key
        that holds the list of pairs.

    Raises:
        NotImplementedError: If a turn's ``'from'`` value is neither
            ``'human'`` nor ``'gpt'``.
    """
    turns = example['conversations']
    # A reply that precedes the first human turn has no prompt; drop it.
    while turns and turns[0]['from'] == 'gpt':
        turns = turns[1:]

    pairs = []
    prompt = ''
    for turn in turns:
        role = turn['from']
        if role == 'gpt':
            pairs.append({'input': prompt, 'output': turn['value']})
            prompt = ''
        elif role == 'human':
            if DEFAULT_IMAGE_TOKEN in turn['value']:
                remainder = turn['value'].replace(DEFAULT_IMAGE_TOKEN,
                                                  '').strip()
                # Normalise so the image token leads the human text.
                turn['value'] = (DEFAULT_IMAGE_TOKEN + '\n'
                                 + remainder).strip()
            prompt += turn['value']
        else:
            raise NotImplementedError

    example.update({'conversation': pairs})
    return example
|
|
|
|
|
def pascal_part_gcg_format_conversation(selected_labels):
    """Create a multi-turn dialogue whose answers use the GCG templates.

    Each round is a randomly chosen ``SEG_QUESTIONS`` template filled with
    the lower-cased label, answered by a randomly chosen
    ``ANSWER_LIST_GCG_FORMAT`` template filled with the capitalised label.
    Only the first question is prefixed with ``DEFAULT_IMAGE_TOKEN``.

    Args:
        selected_labels: Iterable of label strings.

    Returns:
        list[dict]: Alternating ``'human'`` / ``'gpt'`` turns, each of the
        form ``{'from': ..., 'value': ...}``.
    """
    conversations = []
    for i, selected_label in enumerate(selected_labels):
        class_name = selected_label.lower()
        question = random.choice(SEG_QUESTIONS).format(
            class_name=class_name).strip()
        # Draw from the GCG templates.  The plain ANSWER_LIST entries have
        # no replacement field, so formatting them silently discarded the
        # label and produced the same output as the non-GCG variant.
        answer = random.choice(ANSWER_LIST_GCG_FORMAT).format(
            class_name.capitalize())
        if i == 0:
            # The image token is attached to the opening question only.
            question = DEFAULT_IMAGE_TOKEN + question
        conversations.append({'from': 'human', 'value': question})
        conversations.append({'from': 'gpt', 'value': answer})
    return conversations
|
|
|
|
|
def pascal_part_gcg_format_preprocess(example):
    """Attach a generated GCG-format dialogue under ``'conversations'``.

    Args:
        example: Mapping with a ``'selected_labels'`` entry, which is passed
            to ``pascal_part_gcg_format_conversation``.

    Returns:
        The same ``example`` object with ``'conversations'`` set.
    """
    example['conversations'] = pascal_part_gcg_format_conversation(
        example["selected_labels"])
    return example
|
|
|
|
|
def pascal_part_gcg_format_map_fn(example):
    """Generate a GCG-format dialogue and convert it to input/output pairs.

    Steps, in order:
      1. ``pascal_part_gcg_format_preprocess`` fills
         ``example['conversations']``.
      2. ``example['image']`` is set to ``example['file_name']``.
      3. The turns are folded into ``{'input': ..., 'output': ...}`` pairs:
         leading ``'gpt'`` turns are skipped, human text is accumulated, and
         each ``'gpt'`` turn closes a pair.  A human turn containing
         ``DEFAULT_IMAGE_TOKEN`` is rewritten in place so the token comes
         first, followed by a newline and the stripped remainder.

    Args:
        example: Mapping with ``'selected_labels'`` and ``'file_name'``.

    Returns:
        The example returned by ``pascal_part_gcg_format_preprocess``,
        updated with ``'image'`` and ``'conversation'`` keys.

    Raises:
        NotImplementedError: If a turn's ``'from'`` value is neither
            ``'human'`` nor ``'gpt'``.
    """
    example = pascal_part_gcg_format_preprocess(example)
    example['image'] = example["file_name"]

    turns = example['conversations']
    # A reply that precedes the first human turn has no prompt; drop it.
    while turns and turns[0]['from'] == 'gpt':
        turns = turns[1:]

    pairs = []
    prompt = ''
    for turn in turns:
        role = turn['from']
        if role == 'gpt':
            pairs.append({'input': prompt, 'output': turn['value']})
            prompt = ''
        elif role == 'human':
            if DEFAULT_IMAGE_TOKEN in turn['value']:
                remainder = turn['value'].replace(DEFAULT_IMAGE_TOKEN,
                                                  '').strip()
                # Normalise so the image token leads the human text.
                turn['value'] = (DEFAULT_IMAGE_TOKEN + '\n'
                                 + remainder).strip()
            prompt += turn['value']
        else:
            raise NotImplementedError

    example.update({'conversation': pairs})
    return example
|
|
|
|
|
|
|
|
|