"""
Copyright (c) 2023, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn import functional as F
from transformers.utils import ModelOutput
from typing import Optional, Tuple
from dist_utils import is_dist_avail_and_initialized
from base_model import all_gather_with_grad, concat_all_gather
from dataclasses import dataclass

from blip2 import (
    Blip2Base,
    compute_sim_matrix,
)


@dataclass
class PDQ_Output(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    loss_itc: Optional[torch.FloatTensor] = None
    loss_itm: Optional[torch.FloatTensor] = None
    loss_lm: Optional[torch.FloatTensor] = None
    # Last hidden states of the Q-Former query tokens, passed on as FSUIE inputs.
    FSUIE_inputs: Optional[torch.FloatTensor] = None
    # Cross-attention maps between the query tokens and the image features.
    cross_attentions: Optional[torch.FloatTensor] = None


class PDQ(Blip2Base):
    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain": "configs/models/blip2/blip2_pretrain.yaml",
        "coco": "configs/models/blip2/blip2_coco.yaml",
    }

    def __init__(
        self,
        vision_width=1664,
        num_query_token=32,
        embed_dim=256,
        max_txt_len=512,
        device='cuda:0',
    ):
        super().__init__()

        self.tokenizer = self.init_tokenizer()

        self.model_device = device
        self.num_query_token = num_query_token
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, vision_width
        )

        self.Qformer.resize_token_embeddings(len(self.tokenizer))
        state_dict = self.Qformer.state_dict()

        # Initialise the query-specific ("_query") parameter copies in the Q-Former
        # from their corresponding text-branch weights.
        for name, param in self.Qformer.named_parameters():
            if "_query" in name:
                key_orig = name.replace("_query", "")
                param.data.copy_(state_dict[key_orig])

        self.vision_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)
        self.text_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)

        self.itm_head = nn.Linear(self.Qformer.config.hidden_size, 2)

        # Learnable temperature for the image-text contrastive loss.
        self.temp = nn.Parameter(0.07 * torch.ones([]))

        self.max_txt_len = max_txt_len

    def forward(self, samples, no_its_and_itm=False):
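        """
        Compute the BLIP-2-style pretraining objectives (ITC, ITM, LM) from
        precomputed image embeddings and tokenized scene-graph text.

        Args:
            samples (dict): Expected keys, as read below:
                - "image_embeds" (torch.Tensor): frozen vision features, typically of
                  shape (batch_size, num_image_tokens, vision_width).
                - "query_inputs": auxiliary query inputs (currently unused here).
                - "scene_graph" (dict): tokenizer output with "input_ids" and
                  "attention_mask".
            no_its_and_itm (bool): If True, skip the ITC/ITM/LM losses and only
                return the Q-Former query features.

        Returns:
            PDQ_Output with the individual losses, their sum, the query-token
            hidden states (FSUIE_inputs), and the cross-attention maps.
        """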
        image_embeds = samples["image_embeds"]
        query_ids = samples["query_inputs"]  # kept for interface parity; not used below

        text_tokens = samples["scene_graph"]

        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image_embeds.device
        )

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            use_cache=True,
            return_dict=True,
            output_attentions=True,
        )
        cross_attentions = query_output.cross_attentions

        if no_its_and_itm:
            # Skip the contrastive, matching, and language-modelling objectives
            # and only return the query features.
            loss_itc = 0.0
            loss_itm = 0.0
            loss_lm = 0.0
            return PDQ_Output(
                loss=loss_itc + loss_itm + loss_lm,
                loss_itc=loss_itc,
                loss_itm=loss_itm,
                loss_lm=loss_lm,
                FSUIE_inputs=query_output.last_hidden_state,
            )

        # ----- Image-text contrastive (ITC) loss -----
        image_feats = F.normalize(
            self.vision_proj(query_output.last_hidden_state), dim=-1
        )

        text_output = self.Qformer.bert(
            text_tokens['input_ids'],
            attention_mask=text_tokens['attention_mask'],
            return_dict=True,
        )
        text_feat = F.normalize(
            self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
        )

        # Gather features from all ranks (a no-op when not distributed).
        image_feats_all = concat_all_gather(image_feats)
        text_feat_all = concat_all_gather(text_feat)

        # Query-to-text similarity:
        # [batch, 1, num_query, dim] @ [batch*world, dim, 1] -> [batch, batch*world, num_query]
        sim_q2t = torch.matmul(
            image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)
        ).squeeze(-1)

        # Image-to-text similarity: keep the best-matching query token.
        sim_i2t, _ = sim_q2t.max(-1)
        sim_i2t = sim_i2t / self.temp

        # Text-to-query similarity:
        # [batch, 1, 1, dim] @ [batch*world, dim, num_query] -> [batch, batch*world, num_query]
        sim_t2q = torch.matmul(
            text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)
        ).squeeze(-2)

        # Text-to-image similarity: again keep the best-matching query token.
        sim_t2i, _ = sim_t2q.max(-1)
        sim_t2i = sim_t2i / self.temp

        if is_dist_avail_and_initialized():
            rank = dist.get_rank()
            bs = image_embeds.size(0)
            targets = torch.linspace(
                rank * bs, rank * bs + bs - 1, bs, dtype=int
            ).to(image_embeds.device)
        else:
            bs = image_embeds.size(0)
            targets = torch.arange(bs).to(image_embeds.device)

        loss_itc = (
            F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
            + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
        ) / 2

        # ----- Image-text matching (ITM) loss -----
        # Gather text tokens and image embeddings from all ranks so that hard
        # negatives can be mined across the global batch.
        text_input_ids_world = concat_all_gather(text_tokens['input_ids'])
        text_attention_mask_world = concat_all_gather(text_tokens['attention_mask'])
        image_embeds_world = all_gather_with_grad(image_embeds)
        with torch.no_grad():
            if is_dist_avail_and_initialized():
                weights_t2i = F.softmax(sim_t2i, dim=1) + 1e-4
                weights_i2t = F.softmax(sim_i2t, dim=1) + 1e-4
                # Mask out the positive pairs of the local batch.
                weights_t2i[:, rank * bs : rank * bs + bs].fill_diagonal_(0)
                weights_i2t[:, rank * bs : rank * bs + bs].fill_diagonal_(0)
            else:
                weights_t2i = F.softmax(sim_t2i, dim=1) + 1e-4
                weights_i2t = F.softmax(sim_i2t, dim=1) + 1e-4
                weights_t2i.fill_diagonal_(0)
                weights_i2t.fill_diagonal_(0)

        # Select a hard negative image for each text.
        image_embeds_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_t2i[b], 1).item()
            image_embeds_neg.append(image_embeds_world[neg_idx])
        image_embeds_neg = torch.stack(image_embeds_neg, dim=0)

        # Select a hard negative text for each image.
        text_ids_neg = []
        text_atts_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_i2t[b], 1).item()
            text_ids_neg.append(text_input_ids_world[neg_idx])
            text_atts_neg.append(text_attention_mask_world[neg_idx])

        text_ids_neg = torch.stack(text_ids_neg, dim=0)
        text_atts_neg = torch.stack(text_atts_neg, dim=0)

        # Build the ITM batch: [positive pairs, (pos text, neg image), (neg text, pos image)].
        text_ids_all = torch.cat(
            [text_tokens['input_ids'], text_tokens['input_ids'], text_ids_neg], dim=0
        )
        text_atts_all = torch.cat(
            [text_tokens['attention_mask'], text_tokens['attention_mask'], text_atts_neg],
            dim=0,
        )

        query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
        query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(
            image_embeds.device
        )
        attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)
        image_embeds_all = torch.cat(
            [image_embeds, image_embeds_neg, image_embeds], dim=0
        )
        image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to(
            image_embeds.device
        )

        output_itm = self.Qformer.bert(
            text_ids_all,
            query_embeds=query_tokens_itm,
            attention_mask=attention_mask_all,
            encoder_hidden_states=image_embeds_all,
            encoder_attention_mask=image_atts_all,
            return_dict=True,
        )

        # Average the two-way matching logits over the query positions.
        vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]
        vl_output = self.itm_head(vl_embeddings)
        logits = vl_output.mean(dim=1)

        itm_labels = torch.cat(
            [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
            dim=0,
        ).to(image_embeds.device)
        loss_itm = F.cross_entropy(logits, itm_labels)

        # ----- Image-grounded language modelling (LM) loss -----
        decoder_input_ids = text_tokens['input_ids'].clone()
        decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
        labels = decoder_input_ids.masked_fill(
            decoder_input_ids == self.tokenizer.pad_token_id, -100
        )

        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
            image_embeds.device
        )

        attention_mask = torch.cat([query_atts, text_tokens["attention_mask"]], dim=1)
        lm_output = self.Qformer(
            decoder_input_ids,
            attention_mask=attention_mask,
            past_key_values=query_output.past_key_values,
            return_dict=True,
            labels=labels,
        )

        loss_lm = lm_output.loss

        return PDQ_Output(
            loss=loss_itc + loss_itm + loss_lm,
            loss_itc=loss_itc,
            loss_itm=loss_itm,
            loss_lm=loss_lm,
            FSUIE_inputs=query_output.last_hidden_state,
            cross_attentions=cross_attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        samples,
        use_nucleus_sampling=False,
        num_beams=3,
        max_length=30,
        min_length=10,
        top_p=0.9,
        repetition_penalty=1.0,
    ):
        """
        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
            use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use beam search.
            num_beams (int): Number of beams for beam search. 1 means no beam search.
            max_length (int): The maximum length of the sequence to be generated.
            min_length (int): The minimum length of the sequence to be generated.
            top_p (float): The cumulative probability for nucleus sampling.
            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
        Returns:
            captions (list): A list of strings of length batch_size.
        """
        image = samples["image"]
        # `visual_encoder` and `ln_vision` are expected to come from the Blip2Base
        # vision-encoder setup; they are not created in this class's __init__.
        image_embeds = self.ln_vision(self.visual_encoder(image))

        if not use_nucleus_sampling:
            image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
        else:
            num_beams = 1
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image_embeds.device
        )

        model_kwargs = {
            "encoder_hidden_states": image_embeds,
            "encoder_attention_mask": image_atts,
        }

        input_ids = (
            torch.LongTensor(image.size(0), 1)
            .fill_(self.tokenizer.bos_token_id)
            .to(image_embeds.device)
        )
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

        outputs = self.Qformer.generate(
            input_ids=input_ids,
            query_embeds=query_tokens,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            do_sample=use_nucleus_sampling,
            top_p=top_p,
            eos_token_id=self.tokenizer.sep_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            **model_kwargs,
        )
        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return captions
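
    # Usage sketch (illustrative only, not part of the training pipeline): assuming
    # the vision tower has been attached, e.g. via the Blip2Base helpers, captioning
    # a batch of images could look like:
    #
    #     model = PDQ(num_query_token=32)
    #     model.visual_encoder, model.ln_vision = ...  # hypothetical setup step
    #     captions = model.generate({"image": images}, num_beams=3, max_length=30)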

    def forward_image(self, image):
        image_embeds = self.ln_vision(self.visual_encoder(image))
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image.device
        )

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        return query_output.last_hidden_state, image_embeds

    def forward_text(self, text_tokens):
        text_output = self.Qformer.bert(
            text_tokens['input_ids'],
            attention_mask=text_tokens['attention_mask'],
            return_dict=True,
        )
        return text_output.last_hidden_state[:, 0, :]

    def compute_itm(self, image_inputs, text_ids, text_atts):
        image_atts = torch.ones(image_inputs.size()[:-1], dtype=torch.long).to(
            image_inputs.device
        )
        query_tokens = self.query_tokens.expand(image_inputs.shape[0], -1, -1)
        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
            image_inputs.device
        )
        attention_mask = torch.cat([query_atts, text_atts], dim=1)
        output_itm = self.Qformer.bert(
            text_ids,
            query_embeds=query_tokens,
            attention_mask=attention_mask,
            encoder_hidden_states=image_inputs,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        vl_embeddings = output_itm.last_hidden_state[:, : query_tokens.size(1), :]
        itm_logit = self.itm_head(vl_embeddings)
        # Matching score: the "match" class logit averaged over the query positions.
        itm_logit = itm_logit[:, :, 1].mean(dim=1)
        return itm_logit

    @classmethod
    def from_config(cls, cfg):
        img_size = cfg.get("image_size")
        num_query_token = cfg.get("num_query_token")

        drop_path_rate = cfg.get("drop_path_rate", 0)
        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
        vit_precision = cfg.get("vit_precision", "fp16")
        freeze_vit = cfg.get("freeze_vit", True)

        max_txt_len = cfg.get("max_txt_len", 32)

        # Note: PDQ.__init__ only accepts Q-Former related arguments (it does not
        # build a vision encoder), so the ViT options read above are not forwarded.
        model = cls(
            num_query_token=num_query_token,
            max_txt_len=max_txt_len,
        )
        model.load_checkpoint_from_config(cfg)

        return model

    def compute_sim_matrix(self, data_loader, task_cfg):
        """
        Compute similarity i2t, t2i matrix for the given data loader.
        """
        k_test = task_cfg.k_test

        return compute_sim_matrix(model=self, data_loader=data_loader, k_test=k_test)
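

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original pipeline. It assumes the
    # blip2/dist_utils/base_model modules imported above are available, runs on CPU,
    # and uses illustrative tensor shapes and dummy scene-graph strings.
    model = PDQ(vision_width=1664, num_query_token=32, embed_dim=256,
                max_txt_len=512, device="cpu")

    batch_size, num_patches = 2, 257
    dummy_image_embeds = torch.randn(batch_size, num_patches, 1664)
    scene_graph = model.tokenizer(
        ["a dog chasing a ball", "a person riding a bicycle"],
        padding=True,
        truncation=True,
        max_length=model.max_txt_len,
        return_tensors="pt",
    )
    samples = {
        "image_embeds": dummy_image_embeds,
        "query_inputs": None,  # read but unused by forward()
        "scene_graph": scene_graph,
    }

    out = model(samples)
    print("loss:", out.loss, "itc:", out.loss_itc, "itm:", out.loss_itm, "lm:", out.loss_lm)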