Spaces:
Build error
Build error
| import pylab | |
| from lxmert.src.modeling_frcnn import GeneralizedRCNN | |
| import lxmert.src.vqa_utils as utils | |
| from lxmert.src.processing_image import Preprocess | |
| from transformers import LxmertTokenizer | |
| from lxmert.src.huggingface_lxmert import LxmertForQuestionAnswering | |
| from lxmert.src.lxmert_lrp import LxmertForQuestionAnswering as LxmertForQuestionAnsweringLRP | |
| from tqdm import tqdm | |
| from lxmert.src.ExplanationGenerator import GeneratorOurs, GeneratorBaselines, GeneratorOursAblationNoAggregation | |
| import random | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| from captum.attr import visualization | |
| import requests | |
# Locations of the vocabulary / label files that ship with this checkout.
# NOTE: despite the ``_URL`` suffix, each value is a relative path on the local
# filesystem (the file names merely encode the URL they were mirrored from).

# Object-class vocabulary file.
OBJ_URL = (
    "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_py-bottom-up-attention_master_demo_data_genome_1600-400-20_objects_vocab.txt"
)

# Attribute vocabulary file.
ATTR_URL = (
    "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_py-bottom-up-attention_master_demo_data_genome_1600-400-20_attributes_vocab.txt"
)

# VQA label-to-answer JSON; this is the file handed to ``utils.get_data``.
VQA_URL = (
    "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_lxmert_master_data_vqa_trainval_label2ans.json"
)
class ModelUsage:
    """Run a question about an image through Faster R-CNN and LXMERT.

    The constructor loads every model component from local checkpoints under
    ``./lxmert/unc-nlp``. ``forward`` then runs one (image, question) pair and
    records several intermediate values on the instance (see ``forward``).
    """

    def __init__(self, use_lrp=False):
        """Load the detector, image preprocessor, tokenizer and VQA model.

        Args:
            use_lrp: When true, load the ``LxmertForQuestionAnsweringLRP``
                variant of the VQA model instead of the plain
                ``LxmertForQuestionAnswering``.
        """
        self.vqa_answers = utils.get_data(VQA_URL)

        # load models and model components
        self.frcnn_cfg = utils.Config.from_pretrained("./lxmert/unc-nlp/frcnn-vg-finetuned")
        # The detector is pinned to the CPU.
        self.frcnn_cfg.MODEL.DEVICE = "cpu"
        self.frcnn = GeneralizedRCNN.from_pretrained(
            "./lxmert/unc-nlp/frcnn-vg-finetuned", config=self.frcnn_cfg
        )
        self.image_preprocess = Preprocess(self.frcnn_cfg)
        self.lxmert_tokenizer = LxmertTokenizer.from_pretrained("./lxmert/unc-nlp/lxmert-base-uncased")

        if use_lrp:
            self.lxmert_vqa = LxmertForQuestionAnsweringLRP.from_pretrained(
                "./lxmert/unc-nlp/lxmert-vqa-uncased"
            )
        else:
            self.lxmert_vqa = LxmertForQuestionAnswering.from_pretrained(
                "./lxmert/unc-nlp/lxmert-vqa-uncased"
            )
        self.lxmert_vqa.eval()
        # Alias of the VQA model under a second attribute name.
        # NOTE(review): presumably read by the explanation generators - confirm.
        self.model = self.lxmert_vqa

    def forward(self, item):
        """Answer one question about one image.

        Args:
            item: A ``(URL, question)`` pair. ``URL`` is passed unchanged to
                the image preprocessor; ``question`` is the question text.

        Returns:
            The output of the LXMERT VQA model (also stored as
            ``self.output``).

        Side effects:
            Sets ``self.image_file_path``, ``self.question_tokens``,
            ``self.text_len``, ``self.image_boxes_len``, ``self.bboxes`` and
            ``self.output`` for later inspection.
        """
        URL, question = item
        self.image_file_path = URL

        # run frcnn
        images, sizes, scales_yx = self.image_preprocess(URL)
        output_dict = self.frcnn(
            images,
            sizes,
            scales_yx=scales_yx,
            padding="max_detections",
            max_detections=self.frcnn_cfg.max_detections,
            return_tensors="pt",
        )

        inputs = self.lxmert_tokenizer(
            question,
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        self.question_tokens = self.lxmert_tokenizer.convert_ids_to_tokens(inputs.input_ids.flatten())
        self.text_len = len(self.question_tokens)

        # Very important that the boxes are normalized
        normalized_boxes = output_dict.get("normalized_boxes")
        features = output_dict.get("roi_features")
        self.image_boxes_len = features.shape[1]
        self.bboxes = output_dict.get("boxes")

        # Fix: a stray ")" after ``inputs.input_ids`` used to terminate this
        # call early, which made the whole file a syntax error.
        self.output = self.lxmert_vqa(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            visual_feats=features,
            visual_pos=normalized_boxes,
            token_type_ids=inputs.token_type_ids,
            return_dict=True,
            output_attentions=False,
        )
        return self.output
def save_image_vis(image_file_path, bbox_scores,
                   output_path='lxmert/lxmert/experiments/paper/new.jpg'):
    """Write a copy of an image dimmed by per-box relevance scores.

    Each pixel is scaled by the highest score among the boxes that cover it,
    after the resulting mask has been min-max normalised.

    Box coordinates are read from the module-level ``model_lrp.bboxes[0]``,
    so ``model_lrp.forward`` must already have been run on the same image.
    Each box is used as ``(x1, y1, x2, y2)`` pixel corners.

    Args:
        image_file_path: Path of the image to read with ``cv2.imread``.
        bbox_scores: 1-D tensor holding one score per box, in box order.
        output_path: Where the result is written. Defaults to the location
            that used to be hard-coded.

    Returns:
        The masked image as a float NumPy array of shape ``img.shape``.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot load the image.
    """
    # Fix: the parameter used to be overwritten with the global
    # ``image_scores``, so the caller's scores were silently ignored.
    img = cv2.imread(image_file_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"could not read image: {image_file_path!r}")

    boxes = model_lrp.bboxes[0]
    mask = torch.zeros(img.shape[0], img.shape[1])
    for index, score in enumerate(bbox_scores):
        x1, y1, x2, y2 = (int(v) for v in boxes[index])
        region = mask[y1:y2, x1:x2]
        # Keep the largest score seen so far wherever boxes overlap.
        mask[y1:y2, x1:x2] = torch.clamp(region, min=score.item())

    # Min-max normalise; skipped for a uniform mask, where the span is zero
    # and the division would fill the mask with NaN.
    low, high = mask.min(), mask.max()
    if high > low:
        mask = (mask - low) / (high - low)

    # Broadcast the (H, W) mask across the image's channel axis.
    mask = mask.unsqueeze(-1).expand(img.shape)
    img = img * mask.detach().cpu().numpy()
    cv2.imwrite(output_path, img)
    return img
# One LRP-enabled pipeline instance is shared by both explanation generators.
model_lrp = ModelUsage(use_lrp=True)
lrp = GeneratorOurs(model_lrp)
baselines = GeneratorBaselines(model_lrp)
vqa_answers = utils.get_data(VQA_URL)

# Paper samples: each COCO val2014 image id paired with its test question.
_PAPER_SAMPLES = (
    # giraffe
    ("COCO_val2014_000000185590", "is the animal eating?"),
    # baseball
    ("COCO_val2014_000000127510", "did he catch the ball?"),
    # bath
    ("COCO_val2014_000000324266", "is the tub white ?"),
    # frisbee
    ("COCO_val2014_000000200717", "did the man just catch the frisbee?"),
)

# Parallel lists: entry i of one belongs with entry i of the other.
image_ids = [sample_id for sample_id, _ in _PAPER_SAMPLES]
test_questions_for_images = [sample_question for _, sample_question in _PAPER_SAMPLES]