import re

import clip
import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from transformers import GPT2LMHeadModel, GPT2Tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
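
# NLX-GPT demo: a frozen CLIP ViT-B/16 backbone produces patch ("grid") features
# that a fine-tuned GPT-2 decoder attends to via cross-attention, generating an
# answer followed by a natural-language explanation ("... because ...").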


def proc_ques(ques):
    """Lowercase a question and strip punctuation; map '-' and '/' to spaces."""
    return re.sub(r"([.,'!?\"()*#:;])", '', ques.lower()).replace('-', ' ').replace('/', ' ')


def change_requires_grad(model, req_grad):
    """Enable or disable gradient tracking for all parameters of `model`."""
    for p in model.parameters():
        p.requires_grad = req_grad


def load_checkpoint(ckpt_path, epoch):
    """Load the fine-tuned GPT-2 decoder and its tokenizer from `ckpt_path`."""
    model_name = 'nle_model_{}'.format(str(epoch))
    tokenizer_name = 'nle_gpt2_tokenizer_0'
    tokenizer = GPT2Tokenizer.from_pretrained(ckpt_path + tokenizer_name)
    model = GPT2LMHeadModel.from_pretrained(ckpt_path + model_name).to(device)
    return tokenizer, model


class ImageEncoder(nn.Module):
    """Frozen CLIP ViT-B/16 visual backbone exposing patch (grid) features."""

    def __init__(self):
        super(ImageEncoder, self).__init__()
        self.encoder, _ = clip.load("ViT-B/16", device=device)

    def forward(self, x):
        """
        Expects a tensor of size (batch_size, 3, 224, 224).
        Returns the CLIP ViT patch (grid) features of size (batch_size, 196, 768).
        """
        with torch.no_grad():
            x = x.type(self.encoder.visual.conv1.weight.dtype)
            x = self.encoder.visual.conv1(x)           # (batch, width, 14, 14) patch embeddings
            x = x.reshape(x.shape[0], x.shape[1], -1)  # flatten the 14x14 grid
            x = x.permute(0, 2, 1)                     # (batch, 196, width)
            # Prepend the class token and add positional embeddings.
            x = torch.cat([self.encoder.visual.class_embedding.to(x.dtype)
                           + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
                           x], dim=1)
            x = x + self.encoder.visual.positional_embedding.to(x.dtype)
            x = self.encoder.visual.ln_pre(x)
            x = x.permute(1, 0, 2)                     # NLD -> LND (CLIP's transformer convention)
            x = self.encoder.visual.transformer(x)
            grid_feats = x.permute(1, 0, 2)            # LND -> NLD
            # Drop the class token; keep only the 196 patch features.
            grid_feats = self.encoder.visual.ln_post(grid_feats[:, 1:])

        return grid_feats.float()
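
# The (batch, 196, 768) grid features above are what the GPT-2 decoder attends to
# through `encoder_hidden_states` (cross-attention) in sample_sequences below.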


def top_filtering(logits, top_k=0., top_p=0.9, threshold=-float('Inf'), filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k, top-p (nucleus) and/or threshold filtering."""
    assert logits.dim() == 1  # only batch size 1 is supported
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Remove all tokens whose logit is below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens whose cumulative probability exceeds the nucleus threshold ...
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # ... shifting the mask right so the first token above the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Map the mask back to the original (unsorted) vocabulary indices.
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value

    # Finally, remove anything below the absolute logit threshold.
    indices_to_remove = logits < threshold
    logits[indices_to_remove] = filter_value

    return logits
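
# Usage sketch (hypothetical values): sample from the nucleus covering 90% of the mass:
#   filtered = top_filtering(logits.clone(), top_k=0, top_p=0.9)
#   next_id = torch.multinomial(F.softmax(filtered, dim=-1), 1)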


def sample_sequences(img, model, input_ids, segment_ids, tokenizer):
    """Autoregressively decode '<answer> because <explanation>' for one image/question pair.

    Relies on the module-level `image_encoder` and decoding settings
    (temperature, top_k, top_p, no_sample) defined below.
    """
    SPECIAL_TOKENS = ['<|endoftext|>', '<pad>', '<question>', '<answer>', '<explanation>']
    special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    because_token = tokenizer.convert_tokens_to_ids('Ġbecause')
    max_len = 20
    current_output = []
    img_embeddings = image_encoder(img)
    always_exp = False

    with torch.no_grad():
        for step in range(max_len):
            outputs = model(input_ids=input_ids,
                            past_key_values=None,
                            attention_mask=None,
                            token_type_ids=segment_ids,
                            position_ids=None,
                            encoder_hidden_states=img_embeddings,
                            encoder_attention_mask=None,
                            labels=None,
                            use_cache=False,
                            output_attentions=True,
                            return_dict=True)

            lm_logits = outputs.logits
            xa_maps = outputs.cross_attentions
            logits = lm_logits[0, -1, :] / temperature
            logits = top_filtering(logits, top_k=top_k, top_p=top_p)
            probs = F.softmax(logits, dim=-1)
            # Greedy decoding when no_sample, otherwise sample from the filtered distribution.
            prev = torch.topk(probs, 1)[1] if no_sample else torch.multinomial(probs, 1)

            if prev.item() in special_tokens_ids:
                break

            # Generated tokens carry the <answer> segment id until 'because' appears;
            # from then on every token belongs to the <explanation> segment.
            if not always_exp:
                if prev.item() != because_token:
                    new_segment = special_tokens_ids[-2]  # <answer>
                else:
                    new_segment = special_tokens_ids[-1]  # <explanation>
                    always_exp = True
            else:
                new_segment = special_tokens_ids[-1]      # <explanation>

            new_segment = torch.LongTensor([new_segment]).to(device)
            current_output.append(prev.item())
            input_ids = torch.cat((input_ids, prev.unsqueeze(0)), dim=1)
            segment_ids = torch.cat((segment_ids, new_segment.unsqueeze(0)), dim=1)

    decoded_sequences = tokenizer.decode(current_output, skip_special_tokens=True).lstrip()

    return decoded_sequences, xa_maps
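
# Note: the returned xa_maps come from the last forward pass of the decoding loop;
# inference() below uses the final layer's maps to visualize the image evidence.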

# Decoding settings and checkpoint configuration.
img_size = 224
ckpt_path = 'VQAX_p/'
max_seq_len = 40
load_from_epoch = 11
no_sample = True   # greedy decoding
top_k = 0
top_p = 0.9
temperature = 1

image_encoder = ImageEncoder().to(device)
change_requires_grad(image_encoder, False)  # keep the CLIP backbone frozen
tokenizer, model = load_checkpoint(ckpt_path, load_from_epoch)
model.eval()
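
# Note: these are the standard ImageNet normalization statistics, matching what
# the checkpoint was presumably trained with.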
img_transform = transforms.Compose([transforms.Resize((img_size, img_size)),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                         std=[0.229, 0.224, 0.225])])


def get_inputs(text, tokenizer):
    """Tokenize the question and prime the decoder with '<bos> the answer is'."""
    q_segment_id, a_segment_id, e_segment_id = tokenizer.convert_tokens_to_ids(['<question>', '<answer>', '<explanation>'])
    tokens = tokenizer.tokenize(text)
    segment_ids = [q_segment_id] * len(tokens)
    answer = [tokenizer.bos_token] + tokenizer.tokenize(" the answer is")
    answer_len = len(answer)
    tokens += answer
    segment_ids += [a_segment_id] * answer_len
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    segment_ids = torch.tensor(segment_ids, dtype=torch.long)
    return input_ids.unsqueeze(0).to(device), segment_ids.unsqueeze(0).to(device)
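
# For "What is the man doing?" the decoder is primed with roughly:
#   what is the man doing <|endoftext|> the answer is
# using <question> segment ids for the question tokens and <answer> ids for the rest.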


def inference(raw_image, question):
    oimg = raw_image.convert('RGB').resize((224, 224))
    img = img_transform(oimg).unsqueeze(0).to(device)
    text = proc_ques(question)
    input_ids, segment_ids = get_inputs(text, tokenizer)
    question_len = len(tokenizer.convert_ids_to_tokens(input_ids[0]))
    seq, xa_maps = sample_sequences(img, model, input_ids, segment_ids, tokenizer)
    # Average the last layer's attention heads and keep only the generated positions.
    last_am = xa_maps[-1].mean(1)[0, question_len:]
    # Attention of the first generated token over the 14x14 patch grid, upsampled
    # to image size and used to scale the pixels.
    mask = last_am[0, :].reshape(14, 14).cpu().numpy()
    mask = cv2.resize(mask / mask.max(), oimg.size)[..., np.newaxis]
    attention_map = (mask * oimg).astype("uint8")
    splitted_seq = seq.split("because")
    return splitted_seq[0].strip(), "because " + splitted_seq[-1].strip(), Image.fromarray(attention_map)


inputs = [gr.Image(type='pil', label="Load the image of your interest"),
          gr.Textbox(label="Ask a question on this image")]
outputs = [gr.Textbox(label="Answer"),
           gr.Textbox(label="Textual Explanation"),
           gr.Image(type='pil', label="Visual Explanation")]

title = "NLX-GPT: Explanations with Natural Text (Visual Question Answering Demo)"
gr.Interface(inference, inputs, outputs, title=title).launch()
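
# Pass share=True to launch() to expose a temporary public URL, e.g. when running
# in a notebook or on a remote machine.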