|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import re |
|
|
from copy import deepcopy |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from torch.utils.data import Dataset |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
class SingleRegionCaptionDataset(Dataset): |
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "true" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
image, |
|
|
mask, |
|
|
processor, |
|
|
prompt_token="<Prompt1>", |
|
|
prompt_number=5, |
|
|
visual_prompt_tokens=[ |
|
|
"<Prompt0>", |
|
|
"<Prompt1>", |
|
|
"<Prompt2>", |
|
|
"<Prompt3>", |
|
|
"<Prompt4>", |
|
|
"<NO_Prompt>", |
|
|
], |
|
|
data_dtype=torch.bfloat16, |
|
|
**kwargs, |
|
|
): |
|
|
self.processor = processor |
|
|
self.prompt_token = prompt_token |
|
|
|
|
|
self.prompt_number = prompt_number |
|
|
self.special_tokens = visual_prompt_tokens |
|
|
self.visual_prompt_ids = { |
|
|
token: self.processor.tokenizer.convert_tokens_to_ids(token) - 128256 |
|
|
for token in self.special_tokens |
|
|
} |
|
|
|
|
|
self.image = image |
|
|
self.mask = mask |
|
|
self.data_dtype = data_dtype |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.coco.anns) |
|
|
|
|
|
def _parse_annotations(self): |
|
|
image = self.image |
|
|
mask = self.mask |
|
|
|
|
|
np.array(image) |
|
|
mask_np = mask.astype(np.uint8) |
|
|
|
|
|
filled_matrix = -1 * np.ones((image.height, image.width), dtype=np.uint8) |
|
|
prompt_token = self.prompt_token |
|
|
prompt_id = self.visual_prompt_ids.get( |
|
|
prompt_token, self.visual_prompt_ids["<NO_Prompt>"] |
|
|
) |
|
|
assert prompt_id < 16, f"prompt_id should be less than {16}, got {prompt_id}" |
|
|
fill_area = (filled_matrix == -1) & mask_np.astype(bool) |
|
|
filled_matrix[fill_area] = prompt_id |
|
|
|
|
|
filled_matrix[filled_matrix == -1] = self.visual_prompt_ids["<NO_Prompt>"] |
|
|
|
|
|
bboxes = {} |
|
|
|
|
|
prompt_idx = int(re.match(r"<Prompt(\d+)>", prompt_token).group(1)) |
|
|
non_zero_coords = np.argwhere(mask_np) |
|
|
y_min, x_min = non_zero_coords.min(axis=0) |
|
|
y_max, x_max = non_zero_coords.max(axis=0) |
|
|
bbox = ( |
|
|
x_min / image.width, |
|
|
y_min / image.height, |
|
|
x_max / image.width, |
|
|
y_max / image.height, |
|
|
) |
|
|
bboxes[ |
|
|
str( |
|
|
self.processor.tokenizer.convert_tokens_to_ids( |
|
|
f"<|reserved_special_token_{prompt_idx + 2}|>" |
|
|
) |
|
|
) |
|
|
] = bbox |
|
|
|
|
|
data_dict = { |
|
|
"image": image, |
|
|
"visual_prompt": Image.fromarray(filled_matrix), |
|
|
"bboxes": bboxes, |
|
|
} |
|
|
return data_dict |
|
|
|
|
|
def __getitem__(self, index): |
|
|
data_dict = deepcopy(self._parse_annotations()) |
|
|
image = data_dict["image"] |
|
|
visual_prompt = data_dict["visual_prompt"] |
|
|
|
|
|
prompt_idx = int(re.match(r"<Prompt(\d+)>", self.prompt_token).group(1)) |
|
|
|
|
|
|
|
|
qs = f"There are some objects I am curious about: {self.prompt_token};\n{self.prompt_token}: <|reserved_special_token_{prompt_idx + 2}|>Describe this masked region in detail." |
|
|
qs = qs.replace( |
|
|
f"<|reserved_special_token_{prompt_idx + 2}|>", |
|
|
f"<|reserved_special_token_{prompt_idx + 2}|>" * 256, |
|
|
) |
|
|
|
|
|
user_content = [{"type": "image", "image": image}, {"type": "text", "text": qs}] |
|
|
|
|
|
messages = [ |
|
|
{"role": "user", "content": user_content}, |
|
|
] |
|
|
|
|
|
|
|
|
raw_prompt = self.processor.apply_chat_template( |
|
|
messages, |
|
|
add_generation_prompt=True, |
|
|
tokenize=False, |
|
|
) |
|
|
|
|
|
model_inputs = self.processor(text=[raw_prompt], images=[image], visual_prompts=[visual_prompt], return_tensors="pt") |
|
|
|
|
|
pixel_values = model_inputs["pixel_values"] |
|
|
mask_values = model_inputs["mask_values"] |
|
|
input_ids = model_inputs["input_ids"].squeeze(0) |
|
|
attention_mask = model_inputs["attention_mask"].squeeze(0) |
|
|
aspect_ratio = model_inputs["aspect_ratio"] |
|
|
|
|
|
ret = dict( |
|
|
input_ids=input_ids.unsqueeze(0), |
|
|
attention_mask=attention_mask.to(self.data_dtype).unsqueeze(0), |
|
|
pixel_values=pixel_values.to(self.data_dtype).flatten(0, 1), |
|
|
global_mask_values=mask_values.to(self.data_dtype).squeeze(), |
|
|
bboxes=[data_dict["bboxes"]], |
|
|
aspect_ratios=aspect_ratio.unsqueeze(0), |
|
|
) |
|
|
return ret |
|
|
|
|
|
|
|
|
class MultiRegionDataset(Dataset): |
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "true" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
image, |
|
|
masks, |
|
|
question_str, |
|
|
processor, |
|
|
prompt_token="<Prompt1>", |
|
|
prompt_number=5, |
|
|
visual_prompt_tokens=[ |
|
|
"<Prompt0>", |
|
|
"<Prompt1>", |
|
|
"<Prompt2>", |
|
|
"<Prompt3>", |
|
|
"<Prompt4>", |
|
|
"<NO_Prompt>", |
|
|
], |
|
|
data_dtype=torch.bfloat16, |
|
|
**kwargs, |
|
|
): |
|
|
self.processor = processor |
|
|
self.prompt_token = prompt_token |
|
|
|
|
|
self.prompt_number = prompt_number |
|
|
self.special_tokens = visual_prompt_tokens |
|
|
self.visual_prompt_ids = { |
|
|
token: self.processor.tokenizer.convert_tokens_to_ids(token) - 128256 |
|
|
for token in self.special_tokens |
|
|
} |
|
|
|
|
|
self.image = image |
|
|
self.masks = masks |
|
|
self.question_str = question_str |
|
|
self.data_dtype = data_dtype |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.coco.anns) |
|
|
|
|
|
def _parse_annotations(self): |
|
|
image = self.image |
|
|
masks = self.masks |
|
|
|
|
|
width, height = image.size |
|
|
|
|
|
np.array(image) |
|
|
masks_np = [np.array(mask).astype(np.uint8) for mask in masks] |
|
|
|
|
|
for mask_id, mask in enumerate(masks_np): |
|
|
if image.width != mask.shape[1] or image.height != mask.shape[0]: |
|
|
mask = mask.resize(image.size, Image.NEAREST) |
|
|
masks[mask_id] = mask |
|
|
masks_np[mask_id] = np.array(mask).astype(np.unint8) |
|
|
|
|
|
prompt_matches = set(re.findall(r'<Prompt\d+>', self.question_str)) |
|
|
assert len(prompt_matches) == len(masks) |
|
|
|
|
|
objects_desc = "There are some objects I am curious about: " |
|
|
sub_image_desc = "" |
|
|
for matched_prompt in prompt_matches: |
|
|
objects_desc += f"{matched_prompt}; " |
|
|
|
|
|
prompt_idx = int(re.match(r'<Prompt(\d+)>', matched_prompt).group(1)) |
|
|
sub_image_desc += f"{matched_prompt}: <|reserved_special_token_{prompt_idx + 2}|>\n" |
|
|
sub_image_desc = sub_image_desc.replace(f"<|reserved_special_token_{prompt_idx + 2}|>", f"<|reserved_special_token_{prompt_idx + 2}|>" * 256) |
|
|
|
|
|
prompt = objects_desc + "\n" + sub_image_desc + "\n" + self.question_str |
|
|
|
|
|
filled_matrix = -1 * np.ones((image.height, image.width), dtype=np.uint8) |
|
|
bboxes = {} |
|
|
for matched_prompt in prompt_matches: |
|
|
prompt_idx = int(re.match(r'<Prompt(\d+)>', matched_prompt).group(1)) |
|
|
mask = masks[prompt_idx] |
|
|
prompt_token = matched_prompt |
|
|
prompt_id = self.visual_prompt_ids.get(prompt_token, self.visual_prompt_ids["<NO_Prompt>"]) |
|
|
assert prompt_id < self.prompt_number + 1, f"prompt_id should be less than {self.prompt_numbers + 1}, got {prompt_id}" |
|
|
fill_area = (filled_matrix == -1) & mask.astype(bool) |
|
|
filled_matrix[fill_area] = prompt_id |
|
|
|
|
|
non_zero_coords = np.argwhere(masks_np[mask_id]) |
|
|
y_min, x_min = non_zero_coords.min(axis=0) |
|
|
y_max, x_max = non_zero_coords.max(axis=0) |
|
|
bbox = (x_min / image.width, y_min / image.height, x_max / image.width, y_max / image.height) |
|
|
bboxes[str(self.processor.tokenizer.convert_tokens_to_ids(f"<|reserved_special_token_{prompt_idx + 2}|>"))] = bbox |
|
|
|
|
|
filled_matrix[filled_matrix == -1] = self.visual_prompt_ids["<NO_Prompt>"] |
|
|
|
|
|
masks = [Image.fromarray((masks_np[i] * 255).astype(np.uint8)) for i in range(len(masks))] |
|
|
|
|
|
data_dict = { |
|
|
'image': image, |
|
|
'visual_prompt': Image.fromarray(filled_matrix), |
|
|
'bboxes': bboxes, |
|
|
'prompt': prompt, |
|
|
} |
|
|
return data_dict |
|
|
|
|
|
def __getitem__(self, index): |
|
|
data_dict = self._parse_annotations() |
|
|
image = data_dict["image"] |
|
|
visual_prompt = data_dict["visual_prompt"] |
|
|
qs = data_dict["prompt"] |
|
|
|
|
|
user_content = [ |
|
|
{"type": "image", "image": image}, |
|
|
{"type": "text", "text": qs} |
|
|
] |
|
|
|
|
|
messages = [ |
|
|
{"role": "user", "content": user_content}, |
|
|
] |
|
|
|
|
|
|
|
|
raw_prompt = self.processor.apply_chat_template( |
|
|
messages, |
|
|
add_generation_prompt=True, |
|
|
tokenize=False, |
|
|
) |
|
|
|
|
|
model_inputs = self.processor(text=[raw_prompt], images=[image], visual_prompts=[visual_prompt], return_tensors="pt") |
|
|
|
|
|
pixel_values = model_inputs["pixel_values"] |
|
|
mask_values = model_inputs["mask_values"] |
|
|
input_ids = model_inputs["input_ids"].squeeze(0) |
|
|
attention_mask = model_inputs["attention_mask"].squeeze(0) |
|
|
aspect_ratio = model_inputs["aspect_ratio"] |
|
|
|
|
|
ret = dict( |
|
|
input_ids=input_ids.unsqueeze(0), |
|
|
attention_mask=attention_mask.to(self.data_dtype).unsqueeze(0), |
|
|
pixel_values=pixel_values.to(self.data_dtype).flatten(0, 1), |
|
|
global_mask_values=mask_values.to(self.data_dtype).squeeze(), |
|
|
bboxes=[data_dict["bboxes"]], |
|
|
aspect_ratios=aspect_ratio.unsqueeze(0), |
|
|
) |
|
|
return ret |
|
|
|