""" Copyright (c) 2023, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import torch import torch.distributed as dist import torch.nn as nn from torch.nn import functional as F from transformers.utils import ModelOutput from typing import Optional, Tuple from dist_utils import is_dist_avail_and_initialized from base_model import all_gather_with_grad, concat_all_gather from dataclasses import dataclass from blip2 import ( Blip2Base, compute_sim_matrix, ) @dataclass class PDQ_Output(ModelOutput): # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional. loss: Optional[torch.FloatTensor] = None loss_itc: Optional[torch.FloatTensor] = None loss_itm: Optional[torch.FloatTensor] = None loss_lm: Optional[torch.FloatTensor] = None FSUIE_inputs: Optional[torch.FloatTensor] = None cross_attentions: Optional[torch.FloatTensor] = None class PDQ(Blip2Base): PRETRAINED_MODEL_CONFIG_DICT = { "pretrain": "configs/models/blip2/blip2_pretrain.yaml", "coco": "configs/models/blip2/blip2_coco.yaml", } def __init__( self, vision_width=1664, num_query_token=32, embed_dim=256, max_txt_len=512, device='cuda:0' ): super().__init__() self.tokenizer = self.init_tokenizer() self.model_device=device self.num_query_token=num_query_token self.Qformer, self.query_tokens = self.init_Qformer( num_query_token, vision_width) self.Qformer.resize_token_embeddings(len(self.tokenizer)) state_dict = self.Qformer.state_dict() for name, param in self.Qformer.named_parameters(): if "_query" in name: key_orig = name.replace("_query", "") param.data.copy_(state_dict[key_orig]) self.vision_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim) self.text_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim) self.itm_head = nn.Linear(self.Qformer.config.hidden_size, 2) self.temp = nn.Parameter(0.07 * torch.ones([])) self.max_txt_len = max_txt_len def forward(self, samples, no_its_and_itm=False): image_embeds = samples["image_embeds"]# torch.Size([6, 257, 1664]) query_ids = samples["query_inputs"]# torch.Size([6, 32]) # text_tokens = samples["answer_inputs"] text_tokens = samples["scene_graph"] text_tokens['input_ids']=text_tokens['input_ids']# torch.Size([6, 512]) text_tokens['attention_mask']=text_tokens['attention_mask'] image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device) # ori_query_tokens=self.Qformer.bert.embeddings(query_ids) # torch.Size([6, 32, 768]) # query_tokens = ori_query_tokens.expand(image_embeds.shape[0], -1, -1) #torch.Size([6, 32, 768]) query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) query_output = self.Qformer.bert( query_embeds=query_tokens, encoder_hidden_states=image_embeds, encoder_attention_mask=image_atts, use_cache=True, return_dict=True, output_attentions=True ) cross_attentions=query_output.cross_attentions if no_its_and_itm: loss_itc=0.0 loss_itm=0.0 loss_lm=0.0 return PDQ_Output( loss=loss_itc + loss_itm + loss_lm, # + loss_lm, loss_itc=loss_itc, loss_itm=loss_itm, loss_lm=loss_lm, FSUIE_inputs=query_output.last_hidden_state ) image_feats = F.normalize( self.vision_proj(query_output.last_hidden_state), dim=-1 ) text_output = self.Qformer.bert( text_tokens['input_ids'], attention_mask=text_tokens['attention_mask'], return_dict=True, ) text_feat = F.normalize( self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1 ) ###============== Image-text Contrastive 
        image_feats_all = concat_all_gather(image_feats)  # [batch_size*num_gpu, num_query_tokens, embed_dim]
        text_feat_all = concat_all_gather(text_feat)  # [batch_size*num_gpu, embed_dim]

        sim_q2t = torch.matmul(
            image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)
        ).squeeze(-1)
        # [batch_size, batch_size*num_gpu, num_query_tokens]

        # image-text similarity: aggregate across all query tokens
        sim_i2t, _ = sim_q2t.max(-1)
        sim_i2t = sim_i2t / self.temp

        # text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens]
        sim_t2q = torch.matmul(
            text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)
        ).squeeze(-2)

        # text-image similarity: aggregate across all query tokens
        sim_t2i, _ = sim_t2q.max(-1)
        sim_t2i = sim_t2i / self.temp  # [batch_size, batch_size*num_gpu]

        if is_dist_avail_and_initialized():
            rank = dist.get_rank()
            bs = image_embeds.size(0)
            targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(
                image_embeds.device
            )
        else:
            bs = image_embeds.size(0)
            targets = torch.arange(bs).to(image_embeds.device)

        loss_itc = (
            F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
            + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
        ) / 2

        ###============== Image-text Matching ===================###
        text_input_ids_world = concat_all_gather(text_tokens["input_ids"])
        text_attention_mask_world = concat_all_gather(text_tokens["attention_mask"])
        image_embeds_world = all_gather_with_grad(image_embeds)
        with torch.no_grad():
            if is_dist_avail_and_initialized():
                weights_t2i = F.softmax(sim_t2i, dim=1) + 1e-4
                weights_i2t = F.softmax(sim_i2t, dim=1) + 1e-4
                weights_t2i[:, rank * bs : rank * bs + bs].fill_diagonal_(0)
                weights_i2t[:, rank * bs : rank * bs + bs].fill_diagonal_(0)
            else:
                weights_t2i = F.softmax(sim_t2i, dim=1) + 1e-4
                weights_i2t = F.softmax(sim_i2t, dim=1) + 1e-4
                weights_t2i.fill_diagonal_(0)
                weights_i2t.fill_diagonal_(0)

        # select a negative image for each text
        image_embeds_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_t2i[b], 1).item()
            image_embeds_neg.append(image_embeds_world[neg_idx])
        image_embeds_neg = torch.stack(image_embeds_neg, dim=0)

        # select a negative text for each image
        text_ids_neg = []
        text_atts_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_i2t[b], 1).item()
            text_ids_neg.append(text_input_ids_world[neg_idx])
            text_atts_neg.append(text_attention_mask_world[neg_idx])

        text_ids_neg = torch.stack(text_ids_neg, dim=0)
        text_atts_neg = torch.stack(text_atts_neg, dim=0)

        text_ids_all = torch.cat(
            [text_tokens["input_ids"], text_tokens["input_ids"], text_ids_neg], dim=0
        )  # pos, pos, neg
        text_atts_all = torch.cat(
            [text_tokens["attention_mask"], text_tokens["attention_mask"], text_atts_neg],
            dim=0,
        )

        # query_tokens_itm = torch.cat([ori_query_tokens, ori_query_tokens, ori_query_tokens], dim=0)
        query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
        query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(
            image_embeds.device
        )
        attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)

        image_embeds_all = torch.cat(
            [image_embeds, image_embeds_neg, image_embeds], dim=0
        )  # pos, neg, pos
        image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to(
            image_embeds.device
        )
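
        # ITM: score three groups of (image, text) pairs with the Q-Former.
        # The first `bs` rows are matched pairs (label 1); the next `bs` pair each
        # text with a hard-negative image, and the last `bs` pair each image with a
        # hard-negative text (label 0). Negatives are sampled in proportion to the
        # softmaxed contrastive similarities computed above.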
        output_itm = self.Qformer.bert(
            text_ids_all,
            query_embeds=query_tokens_itm,
            attention_mask=attention_mask_all,
            encoder_hidden_states=image_embeds_all,
            encoder_attention_mask=image_atts_all,
            return_dict=True,
        )

        vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]
        vl_output = self.itm_head(vl_embeddings)
        logits = vl_output.mean(dim=1)

        itm_labels = torch.cat(
            [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
            dim=0,
        ).to(image_embeds.device)
        loss_itm = F.cross_entropy(logits, itm_labels)

        ##================= Image Captioning ========================##
        # decoder_input_ids = text_tokens.input_ids.clone()
        decoder_input_ids = text_tokens["input_ids"].clone()
        decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
        labels = decoder_input_ids.masked_fill(
            decoder_input_ids == self.tokenizer.pad_token_id, -100
        )

        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
            image_embeds.device
        )
        # attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1)
        attention_mask = torch.cat([query_atts, text_tokens["attention_mask"]], dim=1)
        lm_output = self.Qformer(
            decoder_input_ids,
            attention_mask=attention_mask,
            past_key_values=query_output.past_key_values,
            return_dict=True,
            labels=labels,
        )

        loss_lm = lm_output.loss

        return PDQ_Output(
            loss=loss_itc + loss_itm + loss_lm,
            loss_itc=loss_itc,
            loss_itm=loss_itm,
            loss_lm=loss_lm,
            FSUIE_inputs=query_output.last_hidden_state,
            cross_attentions=cross_attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        samples,
        use_nucleus_sampling=False,
        num_beams=3,
        max_length=30,
        min_length=10,
        top_p=0.9,
        repetition_penalty=1.0,
    ):
        """
        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
            use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use beam search.
            num_beams (int): Number of beams for beam search. 1 means no beam search.
            max_length (int): The maximum length of the sequence to be generated.
            min_length (int): The minimum length of the sequence to be generated.
            top_p (float): The cumulative probability for nucleus sampling.
            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
        Returns:
            captions (list): A list of strings of length batch_size.
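        Example (illustrative; assumes a loaded model whose vision encoder is
        initialized via Blip2Base, and a preprocessed image batch):
            >>> captions = model.generate({"image": image}, num_beams=5, max_length=30)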
""" image = samples["image"] image_embeds = self.ln_vision(self.visual_encoder(image)) if not use_nucleus_sampling: image_embeds = image_embeds.repeat_interleave(num_beams, dim=0) else: num_beams = 1 image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device)#.to(self.model_device) model_kwargs = { "encoder_hidden_states": image_embeds, "encoder_attention_mask": image_atts, } input_ids = ( torch.LongTensor(image.size(0), 1) .fill_(self.tokenizer.bos_token_id) .to(image_embeds.device)#.to(self.model_device) ) query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) outputs = self.Qformer.generate( input_ids=input_ids, query_embeds=query_tokens, max_length=max_length, min_length=min_length, num_beams=num_beams, do_sample=use_nucleus_sampling, top_p=top_p, eos_token_id=self.tokenizer.sep_token_id, pad_token_id=self.tokenizer.pad_token_id, **model_kwargs ) captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) return captions def forward_image(self, image): image_embeds = self.ln_vision(self.visual_encoder(image)) image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)#.to(self.model_device) query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) query_output = self.Qformer.bert( query_embeds=query_tokens, encoder_hidden_states=image_embeds, encoder_attention_mask=image_atts, return_dict=True, ) return query_output.last_hidden_state, image_embeds def forward_text(self, text_tokens): text_output = self.Qformer.bert( text_tokens['input_ids'], attention_mask=text_tokens['attention_mask'], return_dict=True, ) return text_output.last_hidden_state[:, 0, :] def compute_itm(self, image_inputs, text_ids, text_atts): image_atts = torch.ones(image_inputs.size()[:-1], dtype=torch.long).to(image_inputs.device) query_tokens = self.query_tokens.expand(image_inputs.shape[0], -1, -1) query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image_inputs.device) attention_mask = torch.cat([query_atts, text_atts], dim=1) output_itm = self.Qformer.bert( text_ids, query_embeds=query_tokens, attention_mask=attention_mask, encoder_hidden_states=image_inputs, encoder_attention_mask=image_atts, return_dict=True, ) vl_embeddings = output_itm.last_hidden_state[:, : query_tokens.size(1), :] itm_logit = self.itm_head(vl_embeddings) itm_logit = itm_logit[:, :, 1].mean(dim=1) return itm_logit @classmethod def from_config(cls, cfg): img_size = cfg.get("image_size") num_query_token = cfg.get("num_query_token") drop_path_rate = cfg.get("drop_path_rate", 0) use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) vit_precision = cfg.get("vit_precision", "fp16") freeze_vit = cfg.get("freeze_vit", True) max_txt_len = cfg.get("max_txt_len", 32) model = cls( img_size=img_size, drop_path_rate=drop_path_rate, use_grad_checkpoint=use_grad_checkpoint, vit_precision=vit_precision, freeze_vit=freeze_vit, num_query_token=num_query_token, max_txt_len=max_txt_len, ) model.load_checkpoint_from_config(cfg) return model def compute_sim_matrix(self, data_loader, task_cfg): """ Compute similarity i2t, t2i matrix for the given data loader. """ k_test = task_cfg.k_test return compute_sim_matrix(model=self, data_loader=data_loader, k_test=k_test)