File size: 9,784 Bytes
032e687
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import numpy as np
import random
from xtuner.utils import DEFAULT_IMAGE_TOKEN
import re

# Paraphrased prompts asking for a detailed description of a single region.
# Every template contains exactly one '<region>' placeholder. The region
# caption conversation builders pick one template at random per region
# (random.choice) and replace the placeholder with 'region{k} <region>'.
# The last ten entries repeat the first ten with "please" (or "exactly")
# added, purely for phrasing variety.
REGION_QUESTIONS = [
    'Can you provide me with a detailed description of the region in the picture marked by <region>?',
    "I'm curious about the region represented by <region> in the picture. Could you describe it in detail?",
    'What can you tell me about the region indicated by <region> in the image?',
    "I'd like to know more about the area in the photo labeled <region>. Can you give me a detailed description?",
    'Could you describe the region shown as <region> in the picture in great detail?',
    'What details can you give me about the region outlined by <region> in the photo?',
    'Please provide me with a comprehensive description of the region marked with <region> in the image.',
    'Can you give me a detailed account of the region labeled as <region> in the picture?',
    "I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail?",
    'What is the region outlined by <region> in the picture like? Could you give me a detailed description?',
    'Can you provide me with a detailed description of the region in the picture marked by <region>, please?',
    "I'm curious about the region represented by <region> in the picture. Could you describe it in detail, please?",
    'What can you tell me about the region indicated by <region> in the image, exactly?',
    "I'd like to know more about the area in the photo labeled <region>, please. Can you give me a detailed description?",
    'Could you describe the region shown as <region> in the picture in great detail, please?',
    'What details can you give me about the region outlined by <region> in the photo, please?',
    'Please provide me with a comprehensive description of the region marked with <region> in the image, please.',
    'Can you give me a detailed account of the region labeled as <region> in the picture, please?',
    "I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail, please?",
    'What is the region outlined by <region> in the picture like, please? Could you give me a detailed description?',
]

def _build_region_caption_conversation(descriptions, seg_answer_template):
    """Build a region-caption conversation followed by one segmentation turn.

    Shared implementation of ``region_caption_conversation`` and
    ``region_caption_gcg_format_conversation``; the two differ only in the
    wording of the final segmentation answer.

    Args:
        descriptions: Sequence of region description strings. Any
            ``<region>`` placeholder in description ``i`` is rewritten to
            ``region{i + 1}``.
        seg_answer_template: ``str.format`` template for the answer to the
            segmentation question; it is formatted with the 1-based index of
            the chosen region (templates without ``{}`` are used verbatim).

    Returns:
        ``(conversations, seg_region_idx)``: ``conversations`` alternates
        ``{'from': 'human', ...}`` / ``{'from': 'gpt', ...}`` dicts, and
        ``seg_region_idx`` is a one-element list holding the 0-based index of
        the region the segmentation turn asks about.

    Raises:
        ValueError: If ``descriptions`` is empty (no region to caption or to
            segment).
    """
    num_regions = len(descriptions)
    if num_regions == 0:
        # np.random.randint(0, 0) below would also raise ValueError, but with
        # the opaque message "high <= 0"; fail early with an actionable one.
        raise ValueError(
            'at least one region description is required to build a '
            'region caption conversation')

    questions = []
    answers = []
    for i, description in enumerate(descriptions):
        region_name = f'region{i + 1}'
        question = random.choice(REGION_QUESTIONS).strip().replace(
            '<region>', f'{region_name} <region>')
        if i == 0:
            # Only the first question carries the image token.
            question = DEFAULT_IMAGE_TOKEN + question
        questions.append(question)
        answers.append(description.replace('<region>', region_name))

    # Segmentation QA about one randomly chosen region (1-based in the text).
    selected_seg_idx = 1 + np.random.randint(0, num_regions)
    questions.append("Please segment the region{}.".format(selected_seg_idx))
    answers.append(seg_answer_template.format(selected_seg_idx))

    conversations = []
    for question, answer in zip(questions, answers):
        conversations.append({'from': 'human', 'value': question})
        conversations.append({'from': 'gpt', 'value': answer})
    return conversations, [selected_seg_idx - 1]

def region_caption_conversation(descriptions):
    """Build a region-caption conversation with a plain ``[SEG]`` answer.

    The final segmentation turn is answered with ``"Sure, it is [SEG]."``.
    Returns ``(conversations, seg_region_idx)``; raises ``ValueError`` when
    ``descriptions`` is empty.
    """
    return _build_region_caption_conversation(
        descriptions, "Sure, it is [SEG].")

def region_caption_gcg_format_conversation(descriptions):
    """Build a region-caption conversation with a GCG-style ``[SEG]`` answer.

    The final segmentation turn is answered with
    ``"<p> Region{k} </p> [SEG]."`` where ``k`` is the 1-based region index.
    Returns ``(conversations, seg_region_idx)``; raises ``ValueError`` when
    ``descriptions`` is empty.
    """
    return _build_region_caption_conversation(
        descriptions, "<p> Region{} </p> [SEG].")

# Any "<...>" tag in a raw description; each match is normalized to the
# generic '<region>' placeholder before conversations are built.
_REGION_TAG_PATTERN = re.compile(r'<[^>]*>')

def _prepare_region_caption_example(example, conversation_fn, max_regions):
    """Sample region descriptions from ``example`` and attach a conversation.

    Shared implementation of ``region_caption_preprocess`` and
    ``region_caption_gcg_format_preprocess``.

    Args:
        example: Mutable mapping with a ``'description'`` sequence of strings.
        conversation_fn: Callable taking the selected descriptions and
            returning ``(conversations, seg_region_idx)``.
        max_regions: Maximum number of descriptions to keep; when at least
            this many exist, exactly ``max_regions`` are sampled at random
            without replacement, otherwise all are kept in order.

    Returns:
        ``example`` itself, with ``'conversations'``, ``'sampled_inds'`` and
        ``'seg_region_idx'`` set.
    """
    descriptions = example['description']
    num_descriptions = len(descriptions)

    if num_descriptions >= max_regions:
        # NOTE(review): this branch stores a NumPy array in 'sampled_inds'
        # while the else-branch stores a plain list; kept as-is because
        # downstream consumers are not visible here.
        sampled_inds = np.random.choice(
            list(range(num_descriptions)), size=max_regions, replace=False
        )
    else:
        sampled_inds = list(range(num_descriptions))

    selected_descriptions = [
        _REGION_TAG_PATTERN.sub('<region>', descriptions[idx])
        for idx in sampled_inds
    ]

    conversations, seg_region_idx = conversation_fn(selected_descriptions)
    example['conversations'] = conversations
    example['sampled_inds'] = sampled_inds
    example['seg_region_idx'] = seg_region_idx
    return example

def region_caption_preprocess(example, *, max_regions=3):
    """Sample up to ``max_regions`` descriptions and build a caption conversation.

    The conversation comes from ``region_caption_conversation`` (plain
    ``[SEG]`` answer). ``example`` is updated in place with
    ``'conversations'``, ``'sampled_inds'`` and ``'seg_region_idx'`` and
    returned.
    """
    return _prepare_region_caption_example(
        example, region_caption_conversation, max_regions)

def region_caption_gcg_format_preprocess(example, *, max_regions=3):
    """Sample up to ``max_regions`` descriptions and build a GCG-style conversation.

    The conversation comes from ``region_caption_gcg_format_conversation``
    (``<p> Region{k} </p> [SEG].`` answer). ``example`` is updated in place
    with ``'conversations'``, ``'sampled_inds'`` and ``'seg_region_idx'`` and
    returned.
    """
    return _prepare_region_caption_example(
        example, region_caption_gcg_format_conversation, max_regions)

def _llava_format_conversation(example):
    """Convert ``example['conversations']`` into the ``'conversation'`` format.

    Leading ``'gpt'`` messages are skipped. Consecutive ``'human'`` messages
    are concatenated (without a separator) into the ``'input'`` of the next
    ``'gpt'`` message, whose text becomes the ``'output'``. A human message
    containing ``DEFAULT_IMAGE_TOKEN`` is rewritten in place so the token
    appears once, at the start, followed by a newline. Human text after the
    last ``'gpt'`` message is not emitted.

    Returns:
        ``example`` itself, updated with a ``'conversation'`` list of
        ``{'input': ..., 'output': ...}`` dicts.

    Raises:
        NotImplementedError: If a message's ``'from'`` is neither ``'human'``
            nor ``'gpt'``.
    """
    messages = example['conversations']
    start = 0
    while start < len(messages) and messages[start]['from'] == 'gpt':
        start += 1

    conversation = []
    pending_input = ''
    for msg in messages[start:]:
        role = msg['from']
        if role == 'human':
            if DEFAULT_IMAGE_TOKEN in msg['value']:
                text = msg['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
                msg['value'] = (DEFAULT_IMAGE_TOKEN + '\n' + text).strip()
            pending_input += msg['value']
        elif role == 'gpt':
            conversation.append({'input': pending_input, 'output': msg['value']})
            pending_input = ''
        else:
            raise NotImplementedError(f'unsupported message role: {role!r}')
    example.update({'conversation': conversation})
    return example

def osprey_region_caption_map_fn(example):
    """Map a region-caption sample to the ``'conversation'`` format.

    ``example`` presumably holds ``'image'`` and ``'description'`` keys; only
    ``'description'`` is read (by ``region_caption_preprocess``).
    """
    example = region_caption_preprocess(example)
    return _llava_format_conversation(example)

def osprey_region_caption_gcg_format_map_fn(example):
    """Map a region-caption sample to the ``'conversation'`` format (GCG answers).

    Same as ``osprey_region_caption_map_fn`` but uses
    ``region_caption_gcg_format_preprocess``.
    """
    example = region_caption_gcg_format_preprocess(example)
    return _llava_format_conversation(example)

def _clean_region_text(text):
    """Remove every '<' and '>' and rewrite the misspelling 'regin' as 'region'."""
    return text.replace('<', '').replace('>', '').replace('regin', 'region')

def _region_list_prompt(num_regions):
    """Return the image-token + region-list header for the first human turn.

    For N >= 1 this is ``DEFAULT_IMAGE_TOKEN + 'There are some regions:'``
    followed by ``' region1 <region>, ..., regionN <region>.\\n'``; for N <= 0
    only the ``'There are some regions:'`` part is emitted.
    """
    region_labels = [f'region{k} <region>' for k in range(1, num_regions + 1)]
    prompt = DEFAULT_IMAGE_TOKEN + 'There are some regions:'
    if region_labels:
        prompt += ' ' + ', '.join(region_labels) + '.\n'
    return prompt

def region_conversations_preprocess(example):
    """Prefix the first turn with a region list and clean every turn in place.

    Every turn's ``'value'`` has ``<``/``>`` removed and ``regin`` rewritten
    to ``region``. The first turn is additionally prefixed with the header
    produced from ``example['num_regions']``.

    Returns:
        ``example`` with its ``'conversations'`` turns modified in place.

    Raises:
        ValueError: If the first turn is not from ``'human'``. (This was an
            ``assert``, which is skipped under ``python -O``.)
    """
    conversations = example['conversations']
    num_regions = example['num_regions']

    for turn_idx, turn in enumerate(conversations):
        prefix = ''
        if turn_idx == 0:
            role = turn['from']
            if role != 'human':
                raise ValueError(
                    f"the first conversation turn must come from 'human', got {role!r}")
            prefix = _region_list_prompt(num_regions)
        turn['value'] = prefix + _clean_region_text(turn['value'])

    example['conversations'] = conversations
    return example

def osprey_region_conversation_map_fn(example):
    """Map a region-conversation sample to the ``'conversation'`` format.

    ``example`` presumably holds ``'image'`` and ``'conversations'`` keys;
    ``'conversations'`` and ``'num_regions'`` are read by
    ``region_conversations_preprocess``.
    """
    example = region_conversations_preprocess(example)
    return _llava_format_conversation(example)