from typing import Optional

import mmengine
import torch
import torch.nn.functional as F
from einops import rearrange

from mmaction.registry import MODELS
from .vindlu import VindLUBase


@MODELS.register_module()
class VindLUVQA(VindLUBase):
    """VindLU VQA.

    Args:
        text_decoder (dict): Backbone for extracting multi-modal features.
            We apply this part as the VQA fusion module.
        answer_list_path (str, optional): Path to `answer_list.json`, the
            list of candidate answers ranked at inference time.
            Defaults to None.
        max_question_len (int): Max text length of the question text.
            Defaults to 25.
        max_answer_len (int): Max text length of the answer text.
            Defaults to 5.
        num_ans_candidates (int): Number of answer candidates, used to filter
            out answers with low probability. Defaults to 128.
        **kwargs: Other keyword arguments accepted by VindLUBase.
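
    Example:
        A minimal, illustrative config snippet; the decoder config and the
        path below are placeholders, not values from a reference config::

            model = dict(
                type='VindLUVQA',
                # hypothetical decoder config; see the VindLU configs
                text_decoder=dict(type=...),
                answer_list_path='data/answer_list.json',
                max_question_len=25,
                max_answer_len=5,
                num_ans_candidates=128)

        Arguments accepted by VindLUBase are passed through ``**kwargs``.
        ``answer_list.json`` is loaded with ``mmengine.load`` and is
        expected to be a JSON list of answer strings.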
    """

    def __init__(self,
                 text_decoder: dict,
                 answer_list_path: Optional[str] = None,
                 max_question_len: int = 25,
                 max_answer_len: int = 5,
                 num_ans_candidates: int = 128,
                 **kwargs):
        super().__init__(**kwargs)

        self.max_question_len = max_question_len
        self.max_answer_len = max_answer_len
        self.num_ans_candidates = num_ans_candidates
        self.answer_list_path = answer_list_path
        self.text_decoder_cfg = text_decoder

        # load the candidate answers ranked at inference time
        if answer_list_path:
            self.answer_list = mmengine.load(answer_list_path)

        # the contrastive/matching heads of the base model are not used
        # for VQA, so drop them
        extra_attributes = ['vision_proj', 'text_proj', 'temp', 'itm_head']
        for attr in extra_attributes:
            delattr(self, attr)

        self.text_decoder_cfg.gradient_checkpointing = \
            self.gradient_checkpointing
        self.text_decoder = MODELS.build(self.text_decoder_cfg)

    def forward_encoder(self, inputs, data_samples):
        """Encode the video and fuse it with the tokenized questions.

        Returns the tokenized questions and the output of the multi-modal
        text encoder, whose last hidden state serves as the fused
        question-video representation.
        """
        # forward vision encoder and flatten the frame dimension
        image_embeds, _ = self.encode_vision(inputs)
        image_embeds = rearrange(image_embeds, 'b t l c -> b (t l) c')
        image_atts = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long).to(inputs.device)

        # tokenize the questions
        questions = [sample.question for sample in data_samples]
        questions = self.tokenizer(
            questions,
            padding='max_length',
            truncation=True,
            max_length=self.max_question_len,
            return_tensors='pt').to(inputs.device)

        # fuse question tokens with video features via cross-attention
        question_output = self.text_encoder(
            questions.input_ids,
            attention_mask=questions.attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True)

        return questions, question_output

    def loss(self, inputs, data_samples):
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (dict): A batch of inputs. The input tensor with
                at least one modality. For image, the value is a tensor
                of shape (N, C, ...) in general.
                For text, the value is a dict of tokenized text inputs.
            data_samples (Optional[List[DataSample]]):
                The annotation data of every sample. Defaults to None.

        Returns:
            Dict[str, torch.tensor]: A dictionary of loss components.
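
        Example:
            An illustrative data sample (all values are hypothetical)::

                sample.question = 'what is the cat doing?'
                sample.gt_answer = ['sleeping', 'resting']
                sample.gt_answer_weight = [0.8, 0.2]

            Each ground-truth answer contributes its decoder loss scaled
            by its weight, so soft-labeled answers are all supervised.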
        """

        questions, question_output = self.forward_encoder(inputs, data_samples)

        # gather the ground-truth answers and their soft weights; one
        # question may have several valid answers
        weights = torch.cat(
            [torch.tensor(sample.gt_answer_weight) for sample in data_samples],
            dim=0).to(inputs.device)
        raw_answers = []
        for sample in data_samples:
            raw_answers.extend(sample.gt_answer)
        answer_count = torch.tensor([
            len(sample.gt_answer) for sample in data_samples
        ]).to(inputs.device)
        answers = [a + ' ' + '[SEP]' for a in raw_answers]
        answers = self.tokenizer(
            answers,
            padding='max_length',
            truncation=True,
            max_length=self.max_answer_len,
            return_tensors='pt').to(inputs.device)

        # ignore padding tokens in the language-modeling loss
        answer_targets = answers.input_ids.masked_fill(
            answers.input_ids == self.tokenizer.pad_token_id, -100)

        # repeat each question's states once per ground-truth answer
        question_states = []
        question_atts = []
        for b, n in enumerate(answer_count):
            question_states += [question_output.last_hidden_state[b]] * n
            question_atts += [questions.attention_mask[b]] * n
        question_states = torch.stack(question_states, 0).to(inputs.device)
        question_atts = torch.stack(question_atts, 0).to(inputs.device)

        answer_output = self.text_decoder(
            answers.input_ids,
            attention_mask=answers.attention_mask,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
            labels=answer_targets,
            return_dict=True,
            reduction='none',
        )

        # weight the per-answer losses and average over the batch
        loss = weights * answer_output.loss
        loss = loss.sum() / inputs.size(0)

        return dict(loss=loss)

    def predict(self, inputs, data_samples, **kwargs):
        """Predict answers for a batch of inputs and data samples."""
        questions, question_output = self.forward_encoder(inputs, data_samples)

        # tokenize the candidate answers from the answer list
        raw_answers = self.answer_list
        answers = [a + ' ' + '[SEP]' for a in raw_answers]
        answers = self.tokenizer(
            answers,
            padding='max_length',
            truncation=True,
            max_length=self.max_answer_len,
            return_tensors='pt',
        ).to(inputs.device)

        topk_ids, topk_probs = self.rank_answer(
            question_output.last_hidden_state, questions.attention_mask,
            answers.input_ids, answers.attention_mask, self.num_ans_candidates)

        # keep the most probable candidate for each question
        out_data_samples = []
        for data_sample, topk_id, topk_prob in zip(data_samples, topk_ids,
                                                   topk_probs):
            _, pred = topk_prob.max(dim=0)
            data_sample.pred_answer = raw_answers[topk_id[pred]]
            out_data_samples.append(data_sample)

        return out_data_samples

    def rank_answer(self, question_states, question_atts, answer_ids,
                    answer_atts, k):
        """Rank the candidate answers for each question.

        Args:
            question_states (torch.Tensor): Fused question states of shape
                (bsz, Lq, d).
            question_atts (torch.Tensor): Attention mask of the question
                states, of shape (bsz, Lq).
            answer_ids (torch.Tensor): Answer input ids after tokenization,
                of shape (#answers, La).
            answer_atts (torch.Tensor): Attention mask of the answers, of
                shape (#answers, La).
            k (int): Number of candidates to shortlist per question.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The indices of the top-k
            answers and their re-ranked probabilities.
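
        Note:
            Ranking runs in two stages: candidates are first shortlisted by
            the decoder probability of their first answer token, and the
            shortlist is then re-scored with the full-sequence
            log-likelihood assembled via the chain rule::

                score(answer) = log p(w_1) + sum_i>1 log p(w_i | w_<i)

            where ``-output.loss`` below supplies the per-token conditional
            log-likelihoods of the shortlisted answers.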
        """
        num_ques = question_states.size(0)
        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token

        start_output = self.text_decoder(
            start_ids,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
            return_dict=True,
            reduction='none',
        )
        logits = start_output.logits[:, 0, :]  # first token's logits

        # topk_probs: top-k probability of each answer's first token
        # topk_ids: (num_ques, k)
        answer_first_token = answer_ids[:, 1]
        prob_first_token = F.softmax(
            logits, dim=1).index_select(
                dim=1, index=answer_first_token)
        topk_probs, topk_ids = prob_first_token.topk(k, dim=1)

        # gather the shortlisted answers: (num_ques * k, answer_len)
        input_ids = []
        input_atts = []
        for b, topk_id in enumerate(topk_ids):
            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
        input_ids = torch.cat(input_ids, dim=0)
        input_atts = torch.cat(input_atts, dim=0)

        targets_ids = input_ids.masked_fill(
            input_ids == self.tokenizer.pad_token_id, -100)

        # repeat the encoder's output for the top-k answers
        question_states = question_states.repeat_interleave(k, dim=0)
        question_atts = question_atts.repeat_interleave(k, dim=0)

        output = self.text_decoder(
            input_ids,
            attention_mask=input_atts,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
            labels=targets_ids,
            return_dict=True,
            reduction='none',
        )

        answer_loss = output.loss
        answer_loss = answer_loss.view(input_ids.size(0), -1)

        # topk_probs: first token probability
        topk_probs = topk_probs.view(-1, 1)
        log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1)

        # re-calculate log probabilities for the answer sequences
        # using the chain rule
        log_probs_sum = log_probs.sum(1)
        log_probs_sum = log_probs_sum.view(num_ques, k)

        topk_probs = F.softmax(log_probs_sum, dim=-1)

        # get top-k after re-ranking
        topk_probs, rerank_id = topk_probs.topk(k, dim=1)
        topk_ids = torch.gather(topk_ids, 1, rerank_id)

        return topk_ids, topk_probs

    def preprocess_state_dict(self, state_dict):
        """Preprocess pretrained checkpoint for ``text_encoder`` and
        ``text_decoder``.
        """
        for key in list(state_dict.keys()):
            if 'bert' in key:
                encoder_key = key.replace('bert.', '')
                state_dict[encoder_key] = state_dict[key]

            # initialize the text decoder from the upper (fusion) layers of
            # text_encoder; only needed for generation tasks like VQA
            if self.text_decoder_cfg and 'text_encoder' in key:
                if 'layer' in key:
                    encoder_keys = key.split('.')
                    layer_num = int(encoder_keys[4])
                    if layer_num < self.text_encoder_cfg.fusion_layer:
                        del state_dict[key]
                        continue
                    else:
                        # shift the index so the first kept layer becomes
                        # decoder layer 0 (the hard-coded 9 assumes the
                        # reference fusion layer)
                        decoder_layer_num = layer_num - 9
                        encoder_keys[4] = str(decoder_layer_num)
                        encoder_key = '.'.join(encoder_keys)
                else:
                    encoder_key = key
                decoder_key = encoder_key.replace('text_encoder',
                                                  'text_decoder')
                state_dict[decoder_key] = state_dict[key]

                del state_dict[key]
        return state_dict