import argparse
import json
import os
from collections import defaultdict
from copy import deepcopy

import numpy as np
import torch
from PIL import Image
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

from LLaVA.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from LLaVA.llava.conversation import conv_templates, SeparatorStyle
from LLaVA.llava.model.builder import load_pretrained_model
from LLaVA.llava.utils import disable_torch_init
from LLaVA.llava.mm_utils import get_model_name_from_path, KeywordsStoppingCriteria, tokenizer_image_object_token
from visual_search import parse_args, VSM, visual_search
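

# normalize_bbox converts a box from (x, y, w, h) pixel coordinates to the
# normalized (x1, y1, x2, y2) format used in the focus prompt, clipped to [0, 1].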
def normalize_bbox(bbox, image_width, image_height):
    normalized_bbox = [bbox[0]/image_width, bbox[1]/image_height, (bbox[0]+bbox[2])/image_width, (bbox[1]+bbox[3])/image_height]
    normalized_bbox = [np.clip(coord, 0, 1) for coord in normalized_bbox]
    return normalized_bbox
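

# expand2square pads an image to a square canvas (filled here with the image
# processor's mean color) and also returns the (left, top) paste offsets so
# callers can shift box coordinates into the padded frame.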
def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img, 0, 0
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result, 0, (width - height) // 2
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result, (height - width) // 2, 0


class VQA_LLM:
    def __init__(self, args):
        disable_torch_init()
        model_path = args.vqa_model_path
        model_name = get_model_name_from_path(model_path)
        # make sure 'llava' appears in the name so load_pretrained_model
        # selects the LLaVA loader for this checkpoint
        model_name += 'llava'
        model_base = args.vqa_model_base
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, model_base, model_name)
        self.conv_type = args.conv_type
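
    # get_patch computes a crop window centered on the object box: at least
    # patch_size x patch_size when patch_scale is None, otherwise the object
    # box scaled by patch_scale, clamped to the image bounds.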
    def get_patch(self, bbox, image_width, image_height, patch_size=224, patch_scale=None):
        object_width = int(np.ceil(bbox[2]))
        object_height = int(np.ceil(bbox[3]))

        object_center_x = int(bbox[0] + bbox[2]/2)
        object_center_y = int(bbox[1] + bbox[3]/2)

        if patch_scale is None:
            patch_width = max(object_width, patch_size)
            patch_height = max(object_height, patch_size)
        else:
            patch_width = int(object_width*patch_scale)
            patch_height = int(object_height*patch_scale)

        left = max(0, object_center_x - patch_width//2)
        right = min(left + patch_width, image_width)
        top = max(0, object_center_y - patch_height//2)
        bottom = min(top + patch_height, image_height)

        return [left, top, right, bottom]

    def get_object_crop(self, image, bbox, patch_scale):
        resized_bbox = self.get_patch(bbox, image.width, image.height, patch_scale=patch_scale)
        object_crop = image.crop((resized_bbox[0], resized_bbox[1], resized_bbox[2], resized_bbox[3]))
        object_crop = object_crop.resize((self.image_processor.crop_size['width'], self.image_processor.crop_size['height']))
        object_crop = self.image_processor.preprocess(object_crop, return_tensors='pt')['pixel_values'][0]
        return object_crop
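
    # free_form_inference generates an open-ended answer to the question;
    # optional object crops (with their images_long/objects_long flags) are
    # passed to the model as object_features.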
    def free_form_inference(self, image, question, temperature=0, top_p=None, num_beams=1, max_new_tokens=200, object_crops=None, images_long=None, objects_long=None):
        conv = conv_templates[self.conv_type].copy()
        qs = DEFAULT_IMAGE_TOKEN + '\n' + question
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        input_ids = tokenizer_image_object_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
        output_ids = self.model.generate(
            input_ids,
            images=image_tensor.unsqueeze(0).half().cuda(),
            object_features=object_crops.half().cuda() if object_crops is not None else None,
            images_long=images_long,
            objects_long=objects_long,
            do_sample=temperature > 0,
            num_beams=num_beams,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria])
        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        # decode only the newly generated tokens and strip the stop string
        outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()
        return outputs
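
    # multiple_choices_inference scores each candidate option by its
    # teacher-forced cross-entropy over the answer tokens, reusing the cached
    # key/value states of the shared question prefix, and returns the index of
    # the option with the lowest loss.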
    def multiple_choices_inference(self, image, question, options, object_crops=None, images_long=None, objects_long=None):
        conv = conv_templates[self.conv_type].copy()
        qs = DEFAULT_IMAGE_TOKEN + '\n' + question
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        question_input_ids = tokenizer_image_object_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        # run the shared question prefix once and cache its key/value states
        output_question = self.model(
            question_input_ids,
            use_cache=True,
            images=image_tensor.unsqueeze(0).half().cuda(),
            object_features=object_crops.half().cuda() if object_crops is not None else None,
            images_long=images_long,
            objects_long=objects_long)
        question_logits = output_question.logits
        question_past_key_values = output_question.past_key_values
        loss_list = []
        for option in options:
            conv = conv_templates[self.conv_type].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], option)
            full_prompt = conv.get_prompt()
            full_input_ids = tokenizer_image_object_token(full_prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
            option_answer_input_ids = full_input_ids[:, question_input_ids.shape[1]:]
            output_option = self.model(input_ids=option_answer_input_ids,
                                       use_cache=True,
                                       attention_mask=torch.ones(1, question_logits.shape[1]+option_answer_input_ids.shape[1], device=full_input_ids.device),
                                       past_key_values=question_past_key_values)
            # shift logits by one position so each position predicts the next
            # answer token (the last question logit predicts the first one)
            logits = torch.cat([question_logits[:, -1:], output_option.logits[:, :-1]], 1)
            loss_fct = CrossEntropyLoss()
            logits = logits.view(-1, self.model.config.vocab_size)
            labels = option_answer_input_ids.view(-1)
            loss = loss_fct(logits, labels)
            loss_list.append(loss)
        # pick the option whose answer tokens have the lowest average loss
        option_chosen = torch.stack(loss_list).argmin()
        return option_chosen.cpu().item()
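

# Evaluation pipeline: ask the VQA LLM directly first; if it replies that some
# objects are missing or unclear, run the visual search model (VSM) to localize
# them, then answer the multiple-choice question again with the found objects
# supplied as focus crops and normalized box coordinates.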
def eval_model(args):
    # init VQA LLM
    vqa_llm = VQA_LLM(args)
    # init VSM
    vsm_args = parse_args({})
    vsm_args.version = args.vsm_model_path
    vsm = VSM(vsm_args)

    results = {}
    per_type_acc = defaultdict(list)
    all_acc = []

    missing_objects_msg = "Sorry, I can not answer the question. Some visual information about the following objects is missing or unclear:"
    focus_msg = "Additional visual information to focus on: "
    for test_type in ['direct_attributes', 'relative_position']:
        results[test_type] = []
        folder = os.path.join(args.benchmark_folder, test_type)
        image_files = list(filter(lambda file: '.json' not in file, os.listdir(folder)))
        for image_file in tqdm(image_files):
            result_single_sample = {}
            image_path = os.path.join(folder, image_file)
            # annotations sit next to the images with the same stem
            annotation_path = os.path.splitext(image_path)[0] + '.json'
            image = Image.open(image_path).convert('RGB')
            with open(annotation_path) as f:
                annotation = json.load(f)
            image, _, _ = expand2square(image, tuple(int(x*255) for x in vqa_llm.image_processor.image_mean))
            question = annotation['question']
            # generate free-form response to check whether visual search needs to be activated
            prediction = vqa_llm.free_form_inference(image, question)
            missing_objects = []
            if missing_objects_msg in prediction:
                # parse the comma-separated object list after the trigger message
                missing_objects = prediction.split(missing_objects_msg)[-1]
                if missing_objects.endswith('.'):
                    missing_objects = missing_objects[:-1]
                missing_objects = missing_objects.split(',')
                missing_objects = [missing_object.strip() for missing_object in missing_objects]

            search_result = []
            if len(missing_objects) > 0:
                # run visual search on the original (unpadded) image for each missing object
                for object_name in missing_objects:
                    image = Image.open(image_path).convert('RGB')
                    smallest_size = max(int(np.ceil(min(image.width, image.height)/args.minimum_size_scale)), args.minimum_size)
                    final_step, path_length, search_successful, all_valid_boxes = visual_search(vsm, image, object_name, target_bbox=None, smallest_size=smallest_size)
                    if all_valid_boxes is not None:
                        # might exist multiple target instances
                        for search_bbox in all_valid_boxes:
                            # boxes are local to the final search patch; shift them
                            # back into full-image coordinates
                            search_final_patch = final_step['bbox']
                            search_bbox[0] += search_final_patch[0]
                            search_bbox[1] += search_final_patch[1]
                            search_result.append({'bbox': search_bbox.tolist(), 'name': object_name})
                    else:
                        search_bbox = final_step['detection_result']
                        search_final_patch = final_step['bbox']
                        search_bbox[0] += search_final_patch[0]
                        search_bbox[1] += search_final_patch[1]
                        search_result.append({'bbox': search_bbox.tolist(), 'name': object_name})
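
            # If search found the missing objects, the question is re-asked with
            # focus inputs: cropped object patches (enlarged by patch_scale for
            # context) plus a textual focus message giving each object's
            # normalized coordinates in the padded square image.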
            # predict the multiple-choice option
            options = annotation['options']
            image = Image.open(image_path).convert('RGB')
            if len(missing_objects) > 0:
                object_names = [obj['name'] for obj in search_result]
                bboxs = deepcopy([obj['bbox'] for obj in search_result])
                if len(object_names) <= 2:
                    images_long = [False]
                    objects_long = [True]*len(object_names)
                else:
                    images_long = [False]
                    objects_long = [False]*len(object_names)
                object_crops = []
                for bbox in bboxs:
                    object_crop = vqa_llm.get_object_crop(image, bbox, patch_scale=1.2)
                    object_crops.append(object_crop)
                object_crops = torch.stack(object_crops, 0)
                image, left, top = expand2square(image, tuple(int(x*255) for x in vqa_llm.image_processor.image_mean))
                bbox_list = []
                for bbox in bboxs:
                    bbox[0] += left
                    bbox[1] += top
                    bbox_list.append(bbox)
                bbox_list = [normalize_bbox(bbox, image.width, image.height) for bbox in bbox_list]
                cur_focus_msg = focus_msg
                for i, (object_name, bbox) in enumerate(zip(object_names, bbox_list)):
                    cur_focus_msg = cur_focus_msg + "{} <object> at location [{:.3f},{:.3f},{:.3f},{:.3f}]".format(object_name, bbox[0], bbox[1], bbox[2], bbox[3])
                    if i != len(bbox_list)-1:
                        cur_focus_msg = cur_focus_msg + "; "
                    else:
                        cur_focus_msg = cur_focus_msg + '.'
                question_with_focus = cur_focus_msg + "\n" + question
                option_chosen = vqa_llm.multiple_choices_inference(image, question_with_focus, options, object_crops, images_long=images_long, objects_long=objects_long)
            else:
                option_chosen = vqa_llm.multiple_choices_inference(image, question, options)

            # option 0 is the ground-truth answer, so accuracy is option_chosen == 0
            correct = 1 if option_chosen == 0 else 0
            per_type_acc[test_type].append(correct)
            all_acc.append(correct)

            result_single_sample['question'] = question
            result_single_sample['options'] = options
            result_single_sample['image'] = image_file
            result_single_sample['prediction_freeform'] = prediction
            result_single_sample['missing_objects'] = missing_objects
            result_single_sample['search_result'] = search_result
            result_single_sample['option_chosen'] = option_chosen
            result_single_sample['correct'] = correct
            results[test_type].append(result_single_sample)
        print(test_type, np.mean(per_type_acc[test_type]))

    print(np.mean(all_acc))

    with open(args.output_path, 'w') as f:
        json.dump(results, f, indent=4)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--vqa-model-path", type=str, default="craigwu/seal_vqa_7b")
    parser.add_argument("--vqa-model-base", type=str, default=None)
    parser.add_argument("--conv_type", type=str, default="v1")
    parser.add_argument("--benchmark-folder", type=str, default="vstar_bench")
    parser.add_argument("--vsm-model-path", type=str, default="craigwu/seal_vsm_7b")
    parser.add_argument("--output-path", type=str, default="eval_result.json")
    parser.add_argument("--minimum_size_scale", type=float, default=4.0, help="minimum sub-image scale for the termination of search")
    parser.add_argument("--minimum_size", type=int, default=224, help="minimum sub-image size for the termination of search")
    args = parser.parse_args()
    eval_model(args)
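
# Example invocation (assuming this file is saved as vstar_bench_eval.py and the
# benchmark images/annotations live under ./vstar_bench):
#   python vstar_bench_eval.py --benchmark-folder vstar_bench --output-path eval_result.json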