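# Demo script: visualizing LXMERT VQA predictions with relevance maps.
# A Faster R-CNN backbone extracts region features from the image, an LXMERT
# VQA head answers the question, and the relevance generators from
# lxmert.src.ExplanationGenerator score each image region and question token.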
import pylab
from lxmert.src.modeling_frcnn import GeneralizedRCNN
import lxmert.src.vqa_utils as utils
from lxmert.src.processing_image import Preprocess
from transformers import LxmertTokenizer
from lxmert.src.huggingface_lxmert import LxmertForQuestionAnswering
from lxmert.src.lxmert_lrp import LxmertForQuestionAnswering as LxmertForQuestionAnsweringLRP
from tqdm import tqdm
from lxmert.src.ExplanationGenerator import GeneratorOurs, GeneratorBaselines, GeneratorOursAblationNoAggregation
import random
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
from captum.attr import visualization
import requests
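# Local copies of the Visual Genome object/attribute vocabularies and the VQA
# trainval answer list; the filenames preserve their original download URLs.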
OBJ_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_py-bottom-up-attention_master_demo_data_genome_1600-400-20_objects_vocab.txt"
ATTR_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_py-bottom-up-attention_master_demo_data_genome_1600-400-20_attributes_vocab.txt"
VQA_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_lxmert_master_data_vqa_trainval_label2ans.json"
class ModelUsage:
    """Wraps the Faster R-CNN feature extractor and the LXMERT VQA model."""

    def __init__(self, use_lrp=False):
        self.vqa_answers = utils.get_data(VQA_URL)

        # load models and model components
        self.frcnn_cfg = utils.Config.from_pretrained("./lxmert/unc-nlp/frcnn-vg-finetuned")
        self.frcnn_cfg.MODEL.DEVICE = "cpu"
        self.frcnn = GeneralizedRCNN.from_pretrained("./lxmert/unc-nlp/frcnn-vg-finetuned", config=self.frcnn_cfg)
        self.image_preprocess = Preprocess(self.frcnn_cfg)

        self.lxmert_tokenizer = LxmertTokenizer.from_pretrained("./lxmert/unc-nlp/lxmert-base-uncased")

        # The LRP variant exposes the relevance-propagation hooks used by the
        # explanation generators; the plain Hugging Face model does not.
        if use_lrp:
            self.lxmert_vqa = LxmertForQuestionAnsweringLRP.from_pretrained("./lxmert/unc-nlp/lxmert-vqa-uncased")
        else:
            self.lxmert_vqa = LxmertForQuestionAnswering.from_pretrained("./lxmert/unc-nlp/lxmert-vqa-uncased")

        self.lxmert_vqa.eval()
        self.model = self.lxmert_vqa

        # self.vqa_dataset = vqa_data.VQADataset(splits="valid")
    def forward(self, item):
        # item is an (image_file_path, question) pair
        URL, question = item
        self.image_file_path = URL

        # run frcnn: extract region proposals and ROI features from the image
        images, sizes, scales_yx = self.image_preprocess(URL)
        output_dict = self.frcnn(
            images,
            sizes,
            scales_yx=scales_yx,
            padding="max_detections",
            max_detections=self.frcnn_cfg.max_detections,
            return_tensors="pt"
        )

        # tokenize the question for LXMERT
        inputs = self.lxmert_tokenizer(
            question,
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt"
        )
        self.question_tokens = self.lxmert_tokenizer.convert_ids_to_tokens(inputs.input_ids.flatten())
        self.text_len = len(self.question_tokens)

        # Very important that the boxes are normalized
        normalized_boxes = output_dict.get("normalized_boxes")
        features = output_dict.get("roi_features")
        self.image_boxes_len = features.shape[1]
        self.bboxes = output_dict.get("boxes")  # unnormalized boxes, kept for visualization

        # run LXMERT on the question tokens plus the visual features/positions
        self.output = self.lxmert_vqa(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            visual_feats=features,
            visual_pos=normalized_boxes,
            token_type_ids=inputs.token_type_ids,
            return_dict=True,
            output_attentions=False,
        )
        return self.output
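# Instantiate the LRP-enabled model and the relevance generators ("ours" plus
# the baseline methods, though only the former is used below).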
model_lrp = ModelUsage(use_lrp=True)
lrp = GeneratorOurs(model_lrp)
baselines = GeneratorBaselines(model_lrp)
vqa_answers = utils.get_data(VQA_URL)
def save_image_vis(image_file_path, question):
    # Compute text-to-text (R_t_t) and text-to-image (R_t_i) relevance maps.
    R_t_t, R_t_i = lrp.generate_ours((image_file_path, question), use_lrp=False,
                                     normalize_self_attention=True,
                                     method_name="ours")
    image_scores = R_t_i[0]
    text_scores = R_t_t[0]
    # bbox_scores = image_scores
    _, top_bboxes_indices = image_scores.topk(k=1, dim=-1)  # most relevant box (kept for inspection)

    # Paint each detected box into a mask with its relevance score, keeping the
    # maximum score where boxes overlap. Boxes are (x1, y1, x2, y2) pixel coords.
    img = cv2.imread(image_file_path)
    mask = torch.zeros(img.shape[0], img.shape[1])
    for index in range(len(image_scores)):
        [x1, y1, x2, y2] = model_lrp.bboxes[0][index]
        curr_score_tensor = mask[int(y1):int(y2), int(x1):int(x2)]
        new_score_tensor = torch.ones_like(curr_score_tensor) * image_scores[index].item()
        mask[int(y1):int(y2), int(x1):int(x2)] = torch.max(new_score_tensor, curr_score_tensor)

    # Normalize the mask to [0, 1] and darken the image proportionally, so that
    # irrelevant regions fade toward black.
    mask = (mask - mask.min()) / (mask.max() - mask.min())
    mask = mask.unsqueeze(-1).expand(img.shape)
    img = img * mask.cpu().data.numpy()
    # img = Image.fromarray(np.uint8(img)).convert('RGB')
    cv2.imwrite('lxmert/experiments/paper/new.jpg', img)
    img = Image.open('lxmert/experiments/paper/new.jpg')
    img = img.resize([448, 300])
    orig_image = Image.open(image_file_path)
    orig_image = orig_image.resize([448, 300])

    # Render per-token relevance with captum (the numeric metadata fields of
    # VisualizationDataRecord are unused placeholders here).
    text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
    vis_data_records = [visualization.VisualizationDataRecord(text_scores, 0, 0, 0, 0, 0, model_lrp.question_tokens, 1)]
    html1 = visualization.visualize_text(vis_data_records)

    answer = vqa_answers[model_lrp.output.question_answering_score.argmax()]
    return orig_image, img, html1.data, answer
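# Example usage -- a minimal sketch; the image path and question below are
# hypothetical placeholders, not files shipped with the original repo.
if __name__ == "__main__":
    image_path = "path/to/some_image.jpg"  # hypothetical: any local RGB image
    question = "What is on the table?"     # hypothetical question
    orig_image, masked_image, text_html, answer = save_image_vis(image_path, question)
    print("Predicted answer:", answer)
    # masked_image highlights the regions most relevant to the answer;
    # text_html is captum's HTML rendering of per-token relevance scores.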