| """ | |
| Copyright (c) 2022, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| import datetime | |
| import logging | |
| import time | |
| import lavis.common.dist_utils as dist_utils | |
| import numpy as np | |
| import torch | |
| import torch.distributed as dist | |
| import torch.nn.functional as F | |
| from lavis.common.config import node_to_dict | |
| from lavis.common.dist_utils import get_rank | |
| from lavis.common.logger import MetricLogger | |
| from lavis.common.registry import registry | |
| from lavis.models.alpro_models import AlproBase | |
| from lavis.models.alpro_models.alpro_outputs import AlproIntermediateOutput, AlproOutput | |
| from lavis.models.base_model import all_gather_with_grad | |
| from lavis.models.med import XBertEncoder | |
| from lavis.models.timesformer.vit import TimeSformer | |
| from torch import nn | |

@registry.register_model("alpro_retrieval")
class AlproRetrieval(AlproBase):
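    """ALPRO model for video-text retrieval.

    A TimeSformer video encoder and a BERT-based text encoder are trained with
    two objectives: a video-text contrastive (VTC) loss on projected first-token
    features, and a video-text matching (VTM) loss computed by running the text
    encoder in fusion mode over concatenated text and video token sequences.
    """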
    PRETRAINED_MODEL_CONFIG_DICT = {
        "msrvtt": "configs/models/alpro_retrieval_msrvtt.yaml",
        "didemo": "configs/models/alpro_retrieval_didemo.yaml",
    }

    def __init__(
        self,
        visual_encoder,
        text_encoder,
        vision_width=768,
        text_width=768,
        embed_dim=256,
        max_txt_len=35,
        temp=0.07,
    ):
        super().__init__()

        self.temp = nn.Parameter(torch.ones([]) * temp)

        self.tokenizer = self.init_tokenizer()

        self.visual_encoder = visual_encoder
        self.text_encoder = text_encoder

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        self.itm_head = nn.Linear(text_width, 2)

        self.max_txt_len = max_txt_len

    def forward(self, samples):
        with torch.no_grad():
            self.temp.clamp_(0.001, 0.5)

        visual_inputs = samples["video"]
        caption = samples["text_input"]

        b, t, c, h, w = visual_inputs.shape

        # forward text
        text = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(self.device)

        text_output = self.text_encoder.forward_text(
            text,
            token_type_ids=torch.zeros(
                text.input_ids.shape, dtype=torch.long, device=self.device
            ),
        )
        text_embeds = text_output.last_hidden_state
        text_feat = F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1)

        # forward visual
        # timeSformer asks for (b, c, t, h, w) as input.
        video_embeds = self.visual_encoder.forward_features(visual_inputs)
        video_feat = F.normalize(self.vision_proj(video_embeds[:, 0, :]), dim=-1)
        video_atts = torch.ones(video_embeds.size()[:-1], dtype=torch.long).to(
            self.device
        )
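        # Both retrieval features above are taken from the first output token of
        # each encoder ([:, 0, :]), presumably the [CLS]/global token, then
        # projected to the shared embedding space and L2-normalized.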

        # ========== (in-batch) ITC loss ==========
        gathered_video_feats = all_gather_with_grad(video_feat)
        gathered_text_feats = all_gather_with_grad(text_feat)

        sim_v2t = video_feat @ gathered_text_feats.t() / self.temp
        sim_t2v = text_feat @ gathered_video_feats.t() / self.temp

        sim_targets = torch.zeros_like(sim_v2t)

        local_rank = get_rank()
        b_start, b_end = b * local_rank, b * (local_rank + 1)
        sim_targets[:, b_start:b_end] = torch.eye(b)
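        # Features are gathered across ranks, so sim_* has shape
        # (local_batch, world_size * local_batch); the positives form the
        # identity block at this rank's own columns [b * rank, b * (rank + 1)).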
        loss_v2t = -torch.sum(F.log_softmax(sim_v2t, dim=1) * sim_targets, dim=1).mean()
        loss_t2v = -torch.sum(F.log_softmax(sim_t2v, dim=1) * sim_targets, dim=1).mean()

        vtc_loss = (loss_v2t + loss_t2v) / 2

        (
            vtm_loss,
            vtm_logits,
            vtm_labels,
            encoder_output,
            encoder_output_neg,
        ) = self.compute_vtm(
            text_embeds=text_embeds,
            text_atts=text.attention_mask,
            image_embeds=video_embeds,
            image_atts=video_atts,
            sim_i2t=sim_v2t.clone(),  # for hard mining
            sim_t2i=sim_t2v.clone(),  # for hard mining
        )

        loss = vtc_loss + vtm_loss

        return AlproOutput(
            loss=loss,
            loss_vtc=vtc_loss,
            loss_vtm=vtm_loss,
            intermediate_output=AlproIntermediateOutput(
                video_embeds=video_embeds,
                text_embeds=text_embeds,
                encoder_output=encoder_output,
                encoder_output_neg=encoder_output_neg,
                vtm_logits=vtm_logits,
                vtm_labels=vtm_labels,
            ),
        )

    def compute_vtm(
        self, text_embeds, text_atts, image_embeds, image_atts, sim_i2t, sim_t2i
    ):
        device = self.device

        # ====== positive pairs =======
        attention_mask = torch.cat([text_atts, image_atts], dim=1)
        embedding_output_pos = torch.cat([text_embeds, image_embeds], dim=1)

        encoder_outputs_pos = self.text_encoder(
            encoder_embeds=embedding_output_pos,
            attention_mask=attention_mask,
            return_dict=True,
            mode="fusion",
        )

        # ====== negative pairs =======
        bs = text_embeds.shape[0]

        local_rank = get_rank()
        b_start, b_end = bs * local_rank, bs * (local_rank + 1)

        with torch.no_grad():
            weights_v2t = sim_i2t[:, b_start:b_end]
            weights_t2v = sim_t2i[:, b_start:b_end]

            # never select self as negative
            weights_v2t.fill_diagonal_(-np.inf)
            weights_t2v.fill_diagonal_(-np.inf)

            weights_v2t = F.softmax(weights_v2t, dim=1)
            weights_t2v = F.softmax(weights_t2v, dim=1)
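        # Hard negative mining: for each sample, a negative from this rank's
        # local batch is drawn with probability proportional to its softmaxed
        # similarity, with the true pair masked out via the -inf diagonal.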
        # select a negative image for each text
        # FIXME to optimize using indexing operations
        image_embeds_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_t2v[b], 1).item()
            image_embeds_neg.append(image_embeds[neg_idx])
        image_embeds_neg = torch.stack(image_embeds_neg, dim=0)

        # select a negative text for each image
        text_embeds_neg = []
        text_atts_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_v2t[b], 1).item()
            text_embeds_neg.append(text_embeds[neg_idx])
            text_atts_neg.append(text_atts[neg_idx])

        text_embeds_neg = torch.stack(text_embeds_neg, dim=0)
        text_atts_neg = torch.stack(text_atts_neg, dim=0)

        text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0)
        text_atts_all = torch.cat([text_atts, text_atts_neg], dim=0)

        video_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)
        video_atts_all = torch.cat([image_atts, image_atts], dim=0)

        attention_mask_all = torch.cat([text_atts_all, video_atts_all], dim=1)
        embedding_output_all = torch.cat([text_embeds_all, video_embeds_all], dim=1)
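        # The 2*bs negative pairs are: original text with a sampled negative
        # video (first half) and sampled negative text with the original video
        # (second half); both halves go through the fusion encoder below.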
        # forward negative pairs via cross encoder
        encoder_outputs_neg = self.text_encoder(
            encoder_embeds=embedding_output_all,
            attention_mask=attention_mask_all,
            return_dict=True,
            mode="fusion",
        )

        vl_embeddings = torch.cat(
            [
                encoder_outputs_pos.last_hidden_state[:, 0, :],
                encoder_outputs_neg.last_hidden_state[:, 0, :],
            ],
            dim=0,
        )
        vtm_logits = self.itm_head(vl_embeddings)

        vtm_labels = torch.cat(
            [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
            dim=0,
        ).to(device)
        vtm_loss = F.cross_entropy(vtm_logits, vtm_labels)

        return (
            vtm_loss,
            vtm_logits,
            vtm_labels,
            encoder_outputs_pos,
            encoder_outputs_neg,
        )

    def compute_sim_matrix(self, data_loader, task_cfg):
        k_test = task_cfg.get("k_test")

        metric_logger = MetricLogger(delimiter=" ")
        header = "Evaluation:"

        logging.info("Computing features for evaluation...")
        start_time = time.time()

        texts = data_loader.dataset.text
        num_text = len(texts)
        text_bs = 256
        text_ids = []
        text_embeds = []
        text_feats = []
        text_atts = []
        for i in range(0, num_text, text_bs):
            text = texts[i : min(num_text, i + text_bs)]
            text_input = self.tokenizer(
                text,
                padding="max_length",
                truncation=True,
                max_length=self.max_txt_len,
                return_tensors="pt",
            ).to(self.device)
            text_output = self.text_encoder.forward_text(
                text_input,
                token_type_ids=torch.zeros(
                    text_input.input_ids.shape, dtype=torch.long, device=self.device
                ),
            )
            text_feats.append(text_output.last_hidden_state.cpu())
            text_embed = F.normalize(
                self.text_proj(text_output.last_hidden_state[:, 0, :])
            )
            text_embeds.append(text_embed)
            text_ids.append(text_input.input_ids)
            text_atts.append(text_input.attention_mask)

        text_embeds = torch.cat(text_embeds, dim=0)
        text_ids = torch.cat(text_ids, dim=0)
        text_atts = torch.cat(text_atts, dim=0)
        text_feats = torch.cat(text_feats, dim=0)

        video_feats = []
        video_embeds = []
        for samples in data_loader:
            video = samples["video"]
            video = video.to(self.device)

            video_feat = self.visual_encoder.forward_features(video)
            video_embed = self.vision_proj(video_feat[:, 0, :])
            video_embed = F.normalize(video_embed, dim=-1)

            video_feats.append(video_feat.cpu())
            video_embeds.append(video_embed)

        video_feats = torch.cat(video_feats, dim=0)
        video_embeds = torch.cat(video_embeds, dim=0)

        sims_matrix = video_embeds @ text_embeds.t()
        score_matrix_v2t = torch.full(
            (len(data_loader.dataset.image), len(texts)), -100.0
        ).to(self.device)

        num_tasks = dist_utils.get_world_size()
        rank = dist_utils.get_rank()
        step = sims_matrix.size(0) // num_tasks + 1
        start = rank * step
        end = min(sims_matrix.size(0), start + step)
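        # Rows of sims_matrix are sharded across ranks: each rank reranks only
        # its slice [start, end) and the partial score matrices are summed
        # across ranks after the loops. Reranking takes the k_test candidates
        # with the highest dot-product similarity and adds the ITM (matching)
        # head score from the fusion encoder to that similarity.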
        # video-to-text
        for i, sims in enumerate(
            metric_logger.log_every(sims_matrix[start:end], 50, header)
        ):
            topk_sim, topk_idx = sims.topk(k=k_test, dim=0)

            video_feats_repeat = (
                video_feats[start + i].repeat(k_test, 1, 1).to(self.device)
            )
            video_atts_repeat = torch.ones(
                video_feats_repeat.size()[:-1], dtype=torch.long
            ).to(self.device)

            attention_mask = torch.cat([text_atts[topk_idx], video_atts_repeat], dim=1)
            embedding_output = torch.cat(
                [text_feats[topk_idx].to(self.device), video_feats_repeat], dim=1
            )

            output = self.text_encoder(
                encoder_embeds=embedding_output,
                attention_mask=attention_mask,
                return_dict=True,
                mode="fusion",
            )

            score = self.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
            score_matrix_v2t[start + i, topk_idx] = score + topk_sim

        # text-to-video
        sims_matrix = sims_matrix.t()
        score_matrix_t2v = torch.full(
            (len(texts), len(data_loader.dataset.image)), -100.0
        ).to(self.device)

        step = sims_matrix.size(0) // num_tasks + 1
        start = rank * step
        end = min(sims_matrix.size(0), start + step)

        for i, sims in enumerate(
            metric_logger.log_every(sims_matrix[start:end], 50, header)
        ):
            topk_sim, topk_idx = sims.topk(k=k_test, dim=0)

            text_feats_repeat = (
                text_feats[start + i].repeat(k_test, 1, 1).to(self.device)
            )
            text_atts_repeat = text_atts[start + i].repeat(k_test, 1).to(self.device)

            video_atts = torch.ones(
                video_feats[topk_idx].size()[:-1], dtype=torch.long
            ).to(self.device)

            embedding_output = torch.cat(
                [text_feats_repeat, video_feats[topk_idx].to(self.device)], dim=1
            )
            attention_mask = torch.cat([text_atts_repeat, video_atts], dim=1)

            output = self.text_encoder(
                encoder_embeds=embedding_output,
                attention_mask=attention_mask,
                return_dict=True,
                mode="fusion",
            )

            score = self.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
            score_matrix_t2v[start + i, topk_idx] = score + topk_sim

        if dist_utils.is_dist_avail_and_initialized():
            dist.barrier()
            torch.distributed.all_reduce(
                score_matrix_v2t, op=torch.distributed.ReduceOp.SUM
            )
            torch.distributed.all_reduce(
                score_matrix_t2v, op=torch.distributed.ReduceOp.SUM
            )

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        logging.info("Evaluation time {}".format(total_time_str))

        return score_matrix_v2t.cpu().numpy(), score_matrix_t2v.cpu().numpy()

    @classmethod
    def from_config(cls, cfg):
        # vision encoder
        visual_encoder_config = node_to_dict(cfg.timesformer)
        visual_encoder = TimeSformer(**visual_encoder_config)

        # text encoder
        text_encoder = XBertEncoder.from_config(cfg)

        max_txt_len = cfg.get("max_txt_len", 35)

        model = cls(
            visual_encoder=visual_encoder,
            text_encoder=text_encoder,
            max_txt_len=max_txt_len,
        )
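        # num_patches and num_frames are passed so that the base class can
        # presumably resize the pretrained spatial and temporal position
        # embeddings to the configured input resolution and frame count.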
        num_patches = (
            visual_encoder_config["image_size"] // visual_encoder_config["patch_size"]
        ) ** 2
        num_frames = visual_encoder_config["n_frms"]

        model.load_checkpoint_from_config(
            cfg, num_frames=num_frames, num_patches=num_patches
        )

        return model
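
# Minimal usage sketch (not part of the upstream file). Names below are
# assumptions: it presumes the LAVIS helper `lavis.models.load_model`, the
# registered "alpro_retrieval" / "msrvtt" config, and a preprocessed
# (b, t, c, h, w) video tensor whose frame count and resolution match that
# config. forward() computes the training losses and relies on the distributed
# gather helpers, so for plain inference it is simpler to reuse the encoders
# and projection heads directly, as this class itself does:
#
#   import torch
#   import torch.nn.functional as F
#   from lavis.models import load_model
#
#   model = load_model("alpro_retrieval", "msrvtt", is_eval=True, device="cuda")
#
#   video = torch.randn(1, 8, 3, 224, 224, device="cuda")  # preprocessed clip
#   text = model.tokenizer(
#       ["a dog catches a frisbee"],
#       padding="max_length", truncation=True,
#       max_length=model.max_txt_len, return_tensors="pt",
#   ).to(model.device)
#
#   with torch.no_grad():
#       vid_emb = model.visual_encoder.forward_features(video)
#       vid_feat = F.normalize(model.vision_proj(vid_emb[:, 0, :]), dim=-1)
#       txt_out = model.text_encoder.forward_text(
#           text, token_type_ids=torch.zeros_like(text.input_ids)
#       )
#       txt_feat = F.normalize(
#           model.text_proj(txt_out.last_hidden_state[:, 0, :]), dim=-1
#       )
#       similarity = vid_feat @ txt_feat.t()  # cosine similarity for retrieval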