import math
import copy

import torch
import torch.nn.functional as F
from torch import nn

from misc.detr_utils import box_ops
from misc.detr_utils.misc import inverse_sigmoid
from misc.utils import decide_two_stage

from .matcher import build_matcher
from .deformable_transformer import build_deforamble_transformer
from pdvc.CaptioningHead import build_captioner
from .criterion import AlignCriterion, SetCriterion, ContrastiveCriterion
from .base_encoder import build_base_encoder
from .video_segmentation import *


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class PDVC(nn.Module):
    """ This is the PDVC module that performs dense video captioning """

    def __init__(self, base_encoder, transformer, captioner, num_classes, num_queries, num_feature_levels,
                 aux_loss=True, with_box_refine=False, opt=None, translator=None):
        """ Initializes the model.
        Parameters:
            transformer: torch module of the transformer architecture. See transformer.py
            captioner: captioning head that generates a sentence for each event query
            num_classes: number of foreground classes
            num_queries: number of event queries. This is the maximal number of events
                PDVC can detect in a single video. For ActivityNet Captions, we recommend 10-30 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
            with_box_refine: iterative bounding box refinement
            opt: all configs
        """
        super().__init__()
        self.opt = opt
        self.base_encoder = base_encoder
        self.transformer = transformer
        self.caption_head = captioner

        hidden_dim = transformer.d_model
        text_hidden_dim = opt.text_hidden_dim

        if self.opt.use_anchor:
            # learnable (center, duration) anchors from which the initial queries are built
            self.anchor_embed = nn.Embedding(num_queries, 2)
            self.query_embed = self.transformer.prepare_init_anchor_and_query(self.anchor_embed, hidden_dim,
                                                                              random_anchor_init=True,
                                                                              prior_anchor_duration_init=True,
                                                                              prior_duration=0.048)
            self.query_embed = nn.Parameter(self.query_embed, requires_grad=True)
        else:
            self.query_embed = nn.Embedding(num_queries, hidden_dim * 2)

        self.class_head = nn.Linear(hidden_dim, num_classes)
        self.class_refine_head = nn.Linear(hidden_dim, num_classes)
        self.count_head = nn.Linear(hidden_dim, opt.max_eseq_length + 1)
        self.bbox_head = MLP(hidden_dim, hidden_dim, 2, 3)

        self.num_feature_levels = num_feature_levels
        self.aux_loss = aux_loss
        self.with_box_refine = with_box_refine
        self.share_caption_head = opt.share_caption_head

        # initialize the classification bias so that the initial foreground
        # probability equals prior_prob (the usual focal-loss prior trick)
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        self.class_head.bias.data = torch.ones(num_classes) * bias_value
        self.class_refine_head.bias.data = torch.ones(num_classes) * bias_value
        nn.init.constant_(self.bbox_head.layers[-1].weight.data, 0)
        nn.init.constant_(self.bbox_head.layers[-1].bias.data, 0)

        if self.opt.matcher_type == 'DTW' or self.opt.matcher_type == 'Sim' or self.opt.use_pseudo_box:
            self.load_text_embed = True
        else:
            self.load_text_embed = False

        num_pred = transformer.decoder.num_layers
        if self.share_caption_head:
            print('all decoder layers share the same caption head')
            self.caption_head = nn.ModuleList([self.caption_head for _ in range(num_pred)])
        else:
            print('do NOT share the caption head')
            self.caption_head = _get_clones(self.caption_head, num_pred)

        if self.opt.use_additional_cap_layer:
            self.caption_head_refine = _get_clones(captioner, self.opt.refine_pseudo_stage_num)

        if with_box_refine:
            self.class_head = _get_clones(self.class_head, num_pred)
            self.count_head = _get_clones(self.count_head, num_pred)
            self.bbox_head = _get_clones(self.bbox_head, num_pred)
            nn.init.constant_(self.bbox_head[0].layers[-1].bias.data[1:], -2)
            # share the bbox heads with the decoder for iterative box refinement
            self.transformer.decoder.bbox_head = self.bbox_head
        else:
            nn.init.constant_(self.bbox_head.layers[-1].bias.data[1:], -2)
            self.class_head = nn.ModuleList([self.class_head for _ in range(num_pred)])
            self.count_head = nn.ModuleList([self.count_head for _ in range(num_pred)])
            self.bbox_head = nn.ModuleList([self.bbox_head for _ in range(num_pred)])
            self.transformer.decoder.bbox_head = None

        self.class_refine_head = _get_clones(self.class_refine_head, self.opt.refine_pseudo_stage_num)

        if opt.disable_contrastive_projection:
            projection_event = nn.Identity()
            projection_text = nn.Identity()
        else:
            projection_event = nn.Linear(hidden_dim, opt.contrastive_hidden_size)
            projection_text = nn.Linear(text_hidden_dim, opt.contrastive_hidden_size)
        self.contrastive_projection_event = nn.ModuleList(
            [projection_event for _ in range(num_pred)])
        self.contrastive_projection_text = nn.ModuleList(
            [projection_text for _ in range(num_pred)])
        if opt.enable_bg_for_cl:
            self.background_embed = nn.Parameter(torch.randn(1, opt.contrastive_hidden_size), requires_grad=True)
        else:
            self.background_embed = None

        self.translator = translator

        self.disable_mid_caption_heads = opt.disable_mid_caption_heads
        if self.disable_mid_caption_heads:
            print('only calculate caption loss in the last decoding layer')

        # per-video cache of pseudo boxes, keyed by video name
        self.pseudo_boxes = {}

    def get_filter_rule_for_encoder(self):
        # parameters matching this rule belong to the encoder group
        filter_rule = lambda x: 'input_proj' in x \
                                or 'transformer.encoder' in x \
                                or 'transformer.level_embed' in x \
                                or 'base_encoder' in x
        return filter_rule

    def encoder_decoder_parameters(self):
        filter_rule = self.get_filter_rule_for_encoder()
        enc_paras = []
        dec_paras = []
        for name, para in self.named_parameters():
            if filter_rule(name):
                print('enc: {}'.format(name))
                enc_paras.append(para)
            else:
                print('dec: {}'.format(name))
                dec_paras.append(para)
        return enc_paras, dec_paras
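
    # Usage sketch (illustrative, not part of this module): the two parameter
    # groups above are typically optimized with separate learning rates. The
    # AdamW settings below are assumptions for the sketch, not values taken
    # from this repo's configs.
    #
    #   enc_paras, dec_paras = model.encoder_decoder_parameters()
    #   optimizer = torch.optim.AdamW([
    #       {'params': enc_paras, 'lr': 1e-5},   # assumed encoder lr
    #       {'params': dec_paras, 'lr': 1e-4},   # assumed decoder lr
    #   ], weight_decay=1e-4)                    # assumed weight decay
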
    def forward(self, dt, criterion, contrastive_criterion, eval_mode=False):
        # `dt` is a batch dict from the dataloader; the keys accessed below
        # include 'video_tensor' (N, L, C), 'video_mask' (N, L), 'video_length'
        # (column 0: frame count, column 1: duration), 'video_key', and the
        # caption fields 'cap_tensor', 'cap_mask', 'cap_raw', 'cap_embed',
        # and 'video_target'.
        transformer_input_type = self.opt.transformer_input_type
        vf = dt['video_tensor']  # (N, L, C)
        mask = ~dt['video_mask']  # (N, L)
        duration = dt['video_length'][:, 1]
        N, L, C = vf.shape

        srcs, masks, pos = self.base_encoder(vf, mask, duration)

        src_flatten, temporal_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten = \
            self.transformer.prepare_encoder_inputs(srcs, masks, pos)
        memory = self.transformer.forward_encoder(src_flatten, temporal_shapes, level_start_index, valid_ratios,
                                                  lvl_pos_embed_flatten, mask_flatten)

        two_stage, disable_iterative_refine, proposals, proposals_mask = decide_two_stage(transformer_input_type,
                                                                                          dt, criterion)
        if two_stage:
            if transformer_input_type == 'prior_proposals':
                if self.opt.prior_manner == 'add':
                    # keep the learned content queries; take only the references from the priors
                    init_query_embed = self.query_embed.weight
                    _, tgt = torch.chunk(init_query_embed, 2, dim=1)
                    tgt = tgt.unsqueeze(0).expand(N, -1, -1)
                    init_reference, _, reference_points, query_embed = self.transformer.prepare_decoder_input_prior(
                        proposals, num_queries=self.query_embed.weight.shape[0])
                    proposals_mask = torch.ones(N, self.query_embed.weight.shape[0], device=query_embed.device).bool()
                else:
                    init_reference, tgt, reference_points, query_embed = self.transformer.prepare_decoder_input_prior(
                        proposals, num_queries=self.query_embed.weight.shape[0])
                    proposals_mask = torch.ones(N, self.query_embed.weight.shape[0], device=query_embed.device).bool()
            else:
                init_reference, tgt, reference_points, query_embed = self.transformer.prepare_decoder_input_proposal(
                    proposals)
        else:
            if self.opt.use_anchor:
                anchor = self.anchor_embed.weight
                query_anchor = (self.query_embed, anchor)
                proposals_mask = torch.ones(N, self.query_embed.shape[0], device=self.query_embed.device).bool()
                init_reference, tgt, reference_points, query_embed = self.transformer.prepare_decoder_input_anchor(
                    memory, query_anchor)
            else:
                query_embed = self.query_embed.weight
                proposals_mask = torch.ones(N, query_embed.shape[0], device=query_embed.device).bool()
                init_reference, tgt, reference_points, query_embed = self.transformer.prepare_decoder_input_query(
                    memory, query_embed)

        hs, inter_references = self.transformer.forward_decoder(tgt, reference_points, memory, temporal_shapes,
                                                                level_start_index, valid_ratios, query_embed,
                                                                mask_flatten, proposals_mask, disable_iterative_refine)

        if self.load_text_embed and not eval_mode:
            # replicate the per-video sentence embeddings once per decoder layer,
            # then project events and sentences into the contrastive space
            raw_text_embed = dt['cap_embed'] * hs.shape[0]
            event_embed = torch.stack([self.contrastive_projection_event[i](hs_i) for i, hs_i in enumerate(hs)])
            text_embed = torch.stack([self.contrastive_projection_text[j](hs_j.cuda()) for j, hs_j in enumerate(raw_text_embed)])
        else:
            raw_text_embed = None
            text_embed = None
            event_embed = hs

        if self.opt.use_pseudo_box and self.training:
            video_frame_num = dt['video_length'][:, 0].cpu().numpy()
            video_name = dt['video_key'][0]
            if self.pseudo_boxes.get(video_name) is not None \
                    and 'box' in self.pseudo_boxes[video_name] \
                    and 'loss' in self.pseudo_boxes[video_name]:
                # reuse the pseudo boxes cached for this video
                video_step_alignment = [self.pseudo_boxes[video_name]['box']]
            else:
                if self.opt.pseudo_box_type == 'align':
                    video_step_segment = [segment_video_into_steps(dt['video_tensor'][i], raw_text_embed[i].to(memory.device))
                                          for i in range(N)]
                    bbox_alignment = [torch.tensor(alignment_to_boundary(video_step_segment[i], video_frame_num)).to(memory.device)
                                      for i in range(N)]
                elif self.opt.pseudo_box_type == 'similarity':
                    if self.opt.width_ratio < 0:
                        video_step_alignment = [align_frame_into_steps(dt['video_tensor'][i], raw_text_embed[i].to(memory.device),
                                                                       topk=self.opt.top_frames, w=self.opt.window_size,
                                                                       mode=self.opt.statistic_mode) for i in range(N)]
                    else:
                        video_step_alignment = [align_frame_into_steps_order(dt['video_tensor'][i], raw_text_embed[i].to(memory.device),
                                                                             topk=self.opt.top_frames, w=self.opt.window_size,
                                                                             mode=self.opt.statistic_mode,
                                                                             ratio=self.opt.width_ratio) for i in range(N)]
                elif self.opt.pseudo_box_type == 'similarity_op':
                    video_step_alignment = [align_frame_into_steps_op(dt['video_tensor'][i], raw_text_embed[i].to(memory.device),
                                                                      topk=self.opt.top_frames, scale=self.opt.width_ratio,
                                                                      beta=1, order=False,
                                                                      num_iterations=self.opt.iteration) for i in range(N)]
                elif self.opt.pseudo_box_type == 'similarity_op_order':
                    video_step_alignment = [align_frame_into_steps_op(dt['video_tensor'][i], raw_text_embed[i].to(memory.device),
                                                                      topk=self.opt.top_frames, scale=self.opt.width_ratio,
                                                                      beta=1, order=True,
                                                                      num_iterations=self.opt.iteration) for i in range(N)]
                elif self.opt.pseudo_box_type == 'similarity_op_order_v1':
                    video_step_alignment = [align_frame_into_steps_op_v1(dt['video_tensor'][i], raw_text_embed[i].to(memory.device),
                                                                         topk=self.opt.top_frames, scale=self.opt.width_ratio,
                                                                         beta=1, order=True,
                                                                         num_iterations=self.opt.iteration) for i in range(N)]
                elif self.opt.pseudo_box_type == 'similarity_op_order_v2':
                    video_step_alignment = [align_frame_into_steps_op_order_v2(dt['video_tensor'][i], raw_text_embed[i].to(memory.device),
                                                                               topk=self.opt.top_frames, threshold=self.opt.width_th,
                                                                               ratio=self.opt.width_ratio,
                                                                               iteration=self.opt.iteration) for i in range(N)]
                elif self.opt.pseudo_box_type == 'similarity_op_v2':
                    video_step_alignment = [align_frame_into_steps_op_v2(dt['video_tensor'][i], raw_text_embed[i].to(memory.device),
                                                                         topk=self.opt.top_frames, threshold=self.opt.width_th,
                                                                         ratio=self.opt.width_ratio,
                                                                         iteration=self.opt.iteration) for i in range(N)]
                elif self.opt.pseudo_box_type == 'weight_sim':
                    if self.opt.width_ratio < 0:
                        video_step_alignment = [step_retrieval_weight_sim(dt['video_tensor'][i], raw_text_embed[i].to(memory.device),
                                                                          topk=self.opt.top_frames,
                                                                          w=self.opt.window_size) for i in range(N)]
                    else:
                        video_step_alignment = [step_retrieval_weight_sim_order(dt['video_tensor'][i], raw_text_embed[i].to(memory.device),
                                                                                topk=self.opt.top_frames, w=self.opt.window_size,
                                                                                ratio=self.opt.width_ratio) for i in range(N)]
                elif self.opt.pseudo_box_type == 'weight_index':
                    video_step_alignment = [step_retrieval_weight_index(dt['video_tensor'][i], raw_text_embed[i].to(memory.device),
                                                                        topk=self.opt.top_frames,
                                                                        w=self.opt.window_size) for i in range(N)]
                elif self.opt.pseudo_box_type == 'modeframe':
                    video_step_alignment = [align_frame_into_steps_mode(dt['video_tensor'][i], raw_text_embed[i].to(memory.device),
                                                                        topk=self.opt.top_frames, w=self.opt.window_size,
                                                                        ratio=self.opt.width_ratio) for i in range(N)]
                elif self.opt.pseudo_box_type == 'uniform':
                    video_step_alignment = [uniform_box(dt['video_tensor'][i], raw_text_embed[i].to(memory.device)) for i in range(N)]
                else:
                    raise NotImplementedError('pseudo_box_type {} is not implemented'.format(self.opt.pseudo_box_type))

                if self.opt.pseudo_box_type != 'align':
                    if self.opt.pseudo_box_type in ('similarity_op_order_v2', 'similarity_op_v2'):
                        # these variants also return an alignment loss; cache it together with the box
                        video_step_alignment, loss_op = ([out[0] for out in video_step_alignment],
                                                         [out[1] for out in video_step_alignment])
                        self.pseudo_boxes[video_name] = {'box': video_step_alignment[0], 'loss': loss_op[0].item()}
                    else:
                        self.pseudo_boxes[video_name] = {'box': video_step_alignment[0]}

            if self.opt.pseudo_box_type != 'align':
                bbox_alignment = [(torch.tensor(video_step_alignment[i]) / video_frame_num).to(memory.device).to(torch.float32)
                                  for i in range(N)]
            else:
                bbox_alignment = [torch.tensor(alignment_to_boundary(video_step_segment[i], video_frame_num)).to(memory.device)
                                  for i in range(N)]

            # convert (start, end) boundaries into the (center, duration) format used by the heads
            bbox_alignment = to_center_duration(bbox_alignment)

            for sample in range(len(dt['video_target'])):
                dt['video_target'][sample]['boxes_pseudo'] = bbox_alignment[sample]

        others = {'memory': memory,
                  'mask_flatten': mask_flatten,
                  'spatial_shapes': temporal_shapes,
                  'level_start_index': level_start_index,
                  'valid_ratios': valid_ratios,
                  'proposals_mask': proposals_mask,
                  'text_embed': text_embed,
                  'event_embed': event_embed}

        if eval_mode or self.opt.caption_loss_coef == 0:
            out, loss = self.parallel_prediction_full(dt, criterion, hs, init_reference, inter_references, others,
                                                      disable_iterative_refine, transformer_input_type)
        else:
            if self.opt.refine_pseudo_box and self.opt.use_pseudo_box:
                out, loss = self.parallel_prediction_refine_matched(dt, criterion, contrastive_criterion, hs,
                                                                    init_reference, inter_references, others,
                                                                    disable_iterative_refine, transformer_input_type)
            else:
                out, loss = self.parallel_prediction_matched(dt, criterion, contrastive_criterion, hs,
                                                             init_reference, inter_references, others,
                                                             disable_iterative_refine, transformer_input_type)
        return out, loss

    def predict_event_num(self, counter, hs_lid):
        hs_lid_pool = torch.max(hs_lid, dim=1, keepdim=False)[0]  # (bs, hidden_dim)
        outputs_class0 = counter(hs_lid_pool)  # (bs, max_eseq_length + 1)
        return outputs_class0
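
    # How the count prediction is consumed downstream (see PostProcess.forward):
    # the argmax over the (opt.max_eseq_length + 1)-way logits is taken as the
    # predicted number of events, clamped to at least one.
    #
    #   pred_count = model.predict_event_num(model.count_head[-1], hs[-1])
    #   num_events = pred_count.argmax(dim=-1).clamp(min=1)
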
    def parallel_prediction_full(self, dt, criterion, hs, init_reference, inter_references, others,
                                 disable_iterative_refine, transformer_input_type='queries'):
        '''
        hs: [decoder_layer, bs, num_query, feat_dim]
        init_reference: [bs, num_query, 1]
        inter_references: [decoder_layer, bs, num_query, 2]
        '''
        outputs_classes = []
        outputs_classes0 = []
        outputs_coords = []
        outputs_cap_probs = []
        outputs_cap_seqs = []
        num_pred = hs.shape[0]

        for l_id in range(hs.shape[0]):
            reference = init_reference if l_id == 0 else inter_references[l_id - 1]
            hs_lid = hs[l_id]
            outputs_class = self.class_head[l_id](hs_lid)
            output_count = self.predict_event_num(self.count_head[l_id], hs_lid)
            tmp = self.bbox_head[l_id](hs_lid)

            # decode captions only at the last layer; intermediate layers get placeholders
            if l_id != hs.shape[0] - 1:
                cap_probs, seq = self.caption_prediction_eval(
                    self.caption_head[l_id], dt, hs_lid, reference, others, 'none')
            else:
                cap_probs, seq = self.caption_prediction_eval(
                    self.caption_head[l_id], dt, hs_lid, reference, others, self.opt.caption_decoder_type)

            if disable_iterative_refine:
                outputs_coord = reference
            else:
                reference = inverse_sigmoid(reference)
                if self.opt.matcher_type == 'DTW':
                    assert reference.shape[-1] == 2 and tmp.shape[-1] == 2
                if reference.shape[-1] == 2:
                    tmp += reference
                else:
                    assert reference.shape[-1] == 1
                    tmp[..., :1] += reference
                outputs_coord = tmp.sigmoid()

            outputs_classes.append(outputs_class)
            outputs_classes0.append(output_count)
            outputs_coords.append(outputs_coord)
            outputs_cap_probs.append(cap_probs)
            outputs_cap_seqs.append(seq)
        outputs_class = torch.stack(outputs_classes)
        output_count = torch.stack(outputs_classes0)
        outputs_coord = torch.stack(outputs_coords)

        all_out = {'pred_logits': outputs_class,
                   'pred_count': output_count,
                   'pred_boxes': outputs_coord,
                   'caption_probs': outputs_cap_probs,
                   'seq': outputs_cap_seqs}
        out = {k: v[-1] for k, v in all_out.items()}

        if self.aux_loss:
            ks, vs = list(zip(*(all_out.items())))
            out['aux_outputs'] = [{ks[i]: vs[i][j] for i in range(len(ks))} for j in range(num_pred - 1)]

        return out, []
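
    # Layout note (sketch): with num_pred decoder layers,
    #   out['pred_boxes']                   -> predictions of the last layer
    #   out['aux_outputs'][j]['pred_boxes'] -> predictions of layer j, j < num_pred - 1
    # which is the structure the criterion consumes for auxiliary losses.
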
    def parallel_prediction_refine_matched(self, dt, criterion, contrastive_criterion, hs, init_reference,
                                           inter_references, others, disable_iterative_refine,
                                           transformer_input_type='queries'):
        outputs_classes = []
        outputs_counts = []
        outputs_coords = []
        outputs_cap_probs = []
        outputs_cap_seqs = []
        cl_match_mats = []

        num_pred = hs.shape[0]
        if self.opt.pseudo_box_aug:
            assert self.opt.use_pseudo_box
            num_sentence = dt['gt_boxes'].size(-2)
            assert num_sentence == len(dt['cap_raw'][0])
            if self.opt.pseudo_box_aug_num * num_sentence > self.opt.num_queries:
                aug_num = self.opt.num_queries // num_sentence
            else:
                aug_num = self.opt.pseudo_box_aug_num
            if self.opt.refine_pseudo_box:
                # repeat each caption aug_num times so every augmented pseudo box
                # has a matching caption; the originals are restored after refinement
                ori_dt_cap_tensor = copy.deepcopy(dt['cap_tensor'])
                ori_dt_cap_mask = copy.deepcopy(dt['cap_mask'])
                cap_dim = dt['cap_tensor'].shape[-1]
                dt['cap_tensor'] = dt['cap_tensor'].repeat(1, aug_num).reshape(-1, cap_dim)
                dt['cap_mask'] = dt['cap_mask'].repeat(1, aug_num).reshape(-1, cap_dim)

        for l_id in range(num_pred):
            hs_lid = hs[l_id]
            reference = init_reference if l_id == 0 else inter_references[l_id - 1]
            outputs_class = self.class_head[l_id](hs_lid)
            outputs_count = self.predict_event_num(self.count_head[l_id], hs_lid)
            tmp = self.bbox_head[l_id](hs_lid)
            cost_caption, loss_caption, cap_probs, seq = self.caption_prediction(self.caption_head[l_id], dt, hs_lid,
                                                                                 reference, others, 'none')

            if disable_iterative_refine:
                outputs_coord = reference
            else:
                reference = inverse_sigmoid(reference)
                if reference.shape[-1] == 2:
                    tmp += reference
                else:
                    assert reference.shape[-1] == 1
                    tmp[..., :1] += reference
                outputs_coord = tmp.sigmoid()

            if self.load_text_embed or self.opt.disable_contrastive_projection:
                assert others['text_embed'].shape[0] == num_pred, \
                    'visual features have {} levels, but text have {}'.format(num_pred, others['text_embed'].shape[0])
                text_embed = others['text_embed'][l_id]
                event_embed = others['event_embed'][l_id]
                event_embed = event_embed.reshape(-1, event_embed.shape[-1])

            if self.opt.enable_contrastive and self.opt.set_cost_cl > 0:
                assert len(others['text_embed']) == num_pred, \
                    'visual features have {} levels, but text have {}'.format(num_pred, len(others['text_embed']))
                text_embed = torch.cat(others['text_embed'][l_id], dim=0)
                event_embed = others['event_embed'][l_id]
                event_embed = event_embed.reshape(-1, event_embed.shape[-1])
                cl_match_mat = contrastive_criterion.forward_logits(text_embed, event_embed, self.background_embed).t()
                cl_match_mats.append(cl_match_mat)
            else:
                cl_match_mats.append(0)

            outputs_classes.append(outputs_class)
            outputs_counts.append(outputs_count)
            outputs_coords.append(outputs_coord)
            outputs_cap_probs.append(cap_probs)
            outputs_cap_seqs.append(seq)

        outputs_class = torch.stack(outputs_classes)
        outputs_count = torch.stack(outputs_counts)
        outputs_coord = torch.stack(outputs_coords)

        all_out = {
            'pred_logits': outputs_class,
            'pred_count': outputs_count,
            'pred_boxes': outputs_coord,
            'caption_probs': outputs_cap_probs,
            'seq': outputs_cap_seqs,
            'cl_match_mats': cl_match_mats}
        out = {k: v[-1] for k, v in all_out.items()}

        ks, vs = list(zip(*(all_out.items())))
        out['aux_outputs'] = [{ks[i]: vs[i][j] for i in range(len(ks))} for j in range(num_pred - 1)]

        mil_dict = {}
        bag_score_cache = []
        for stage in range(self.opt.refine_pseudo_stage_num):
            # shrink the augmentation jitter as refinement progresses
            aug_ratio = self.opt.pseudo_box_aug_ratio * (0.5 ** stage)
            _, last_indices, aux_indices = criterion(out, dt['video_target'], others, aug_num, aug_ratio)

            hs_lid = hs[-1]
            reference = inter_references[-1]
            indices = last_indices[0]
            query_indices = indices[0][0]
            cap_indices = indices[0][1]
            # reorder the matched query indices to follow sentence order
            cap_sort = torch.sort(cap_indices)[1]
            reorder_query_indices = query_indices[cap_sort]
            if self.opt.use_neg_pseudo_box:
                # for each sentence, sample negative queries from queries matched to other sentences
                neg_query_indices = []
                neg_cap_indices = torch.arange(0, cap_indices.size(0), aug_num).view(num_sentence, -1).repeat(1, self.opt.num_neg_box).view(-1)
                for i in range(num_sentence):
                    candidates_r = reorder_query_indices[(i + 1) * aug_num:]
                    candidates_l = reorder_query_indices[:i * aug_num]
                    if (candidates_r.size(0) > 0) and (candidates_l.size(0) > 0):
                        candidates = torch.cat((candidates_r, candidates_l))
                    else:
                        candidates = candidates_r if candidates_r.size(0) > 0 else candidates_l
                    if candidates.size(0) == 0:
                        candidates = reorder_query_indices
                    if candidates.size(0) < self.opt.num_neg_box:
                        random_selected_indices = torch.randperm(candidates.size(0))
                        padding_num = self.opt.num_neg_box - candidates.size(0)
                        random_selected_indices = torch.cat((random_selected_indices, random_selected_indices[:padding_num]))
                    else:
                        random_selected_indices = torch.randperm(reorder_query_indices.size(0) - aug_num)[:self.opt.num_neg_box]
                    neg_query_indices.append(candidates[random_selected_indices])
                neg_query_indices = torch.cat(neg_query_indices)
                neg_indices = [(neg_query_indices, neg_cap_indices)]

            if self.opt.use_additional_cap_layer:
                cap_loss, cap_probs, seq, sentence_cap_prob = self.caption_prediction(
                    self.caption_head_refine[stage], dt, hs_lid, reference, others,
                    self.opt.caption_decoder_type, indices)
                if (stage > 0) and self.opt.use_neg_pseudo_box:
                    _, _, _, neg_cap_prob = self.caption_prediction(
                        self.caption_head_refine[stage], dt, hs_lid, reference, others,
                        self.opt.caption_decoder_type, neg_indices)
            else:
                cap_loss, cap_probs, seq, sentence_cap_prob = self.caption_prediction(
                    self.caption_head[-1], dt, hs_lid, reference, others,
                    self.opt.caption_decoder_type, indices)
                if (stage > 0) and self.opt.use_neg_pseudo_box:
                    _, _, _, neg_cap_prob = self.caption_prediction(
                        self.caption_head[-1], dt, hs_lid, reference, others,
                        self.opt.caption_decoder_type, neg_indices)

            if self.opt.use_additional_score_layer:
                query_ins_score = self.class_refine_head[stage](hs_lid)[:, query_indices, :]
            else:
                query_ins_score = outputs_classes[-1][:, query_indices, :]
            query_pred_boxes = outputs_coord[-1][:, query_indices, :]
            query_pred_boxes = query_pred_boxes[0, :, :][cap_sort].view(-1, 2)
            query_ins_score = query_ins_score[0, cap_sort, 0].view(-1, aug_num)
            if self.opt.norm_ins_score == 'softmax':
                query_ins_score = torch.softmax(query_ins_score, dim=-1)
            elif self.opt.norm_ins_score == 'sigmoid':
                query_ins_score = query_ins_score.sigmoid()
            else:
                raise NotImplementedError

            # length-normalized caption likelihood, sharpened by a temperature
            temperature = 2
            sentence_cap_prob = sentence_cap_prob[cap_sort].view(-1, aug_num)
            cap_len = torch.tensor([len(cap.split()) for cap in dt['cap_raw'][0]], device=sentence_cap_prob.device).unsqueeze(1)
            sentence_cap_score = (sentence_cap_prob / cap_len) ** temperature + 1e-5
            sentence_cap_score[torch.isinf(sentence_cap_score)] = 1e8

            # scores only weight the boxes, so stop gradients through them
            sentence_cap_score = sentence_cap_score.detach()
            query_ins_score = query_ins_score.detach()

            query_score = sentence_cap_score + query_ins_score

            # MIL loss: each sentence forms a bag over its augmented boxes
            bag_score = query_score.sum(dim=-1)
            bag_score = bag_score.clamp(0, 1)
            bag_score_cache.append(bag_score)
            mil_weight = bag_score_cache[stage - 1] if self.opt.weighted_mil_loss else torch.ones_like(bag_score).to(bag_score.device)
            if stage > 0:
                if self.opt.focal_mil:
                    focal_weight = (torch.ones_like(bag_score).to(bag_score.device) - bag_score).pow(2)
                    mil_loss = -focal_weight * (bag_score + 1e-6).log()
                    mil_loss = (mil_weight * mil_loss).mean()
                else:
                    mil_loss = -(mil_weight * bag_score.log()).mean()
                if self.opt.use_neg_pseudo_box:
                    neg_cap_prob = neg_cap_prob.sigmoid()
                    neg_loss = -(neg_cap_prob.pow(2) * (1 - neg_cap_prob).log()).view(num_sentence, -1).mean(dim=-1)
                    neg_loss = (mil_weight * neg_loss).mean()
                    mil_loss += neg_loss
            else:
                mil_loss = F.binary_cross_entropy(bag_score, torch.ones_like(bag_score).to(bag_score.device))
            if 'loss_mil' in mil_dict:
                mil_dict['loss_mil'] += mil_loss
            else:
                mil_dict['loss_mil'] = mil_loss

            if self.opt.merge_criterion == 'cap_topk':
                topk_pseudo_scores, topk_pseudo_indices = torch.topk(sentence_cap_score, k=self.opt.merge_k_boxes, dim=-1)
            elif self.opt.merge_criterion == 'ins_topk':
                topk_pseudo_scores, topk_pseudo_indices = torch.topk(query_ins_score, k=self.opt.merge_k_boxes, dim=-1)
            elif self.opt.merge_criterion == 'ins_cap_topk':
                topk_pseudo_scores, topk_pseudo_indices = torch.topk(query_score, k=self.opt.merge_k_boxes, dim=-1)
            else:
                raise NotImplementedError('merge_criterion {} is not implemented'.format(self.opt.merge_criterion))

            topk_pseudo_scores = topk_pseudo_scores / (topk_pseudo_scores.sum(dim=-1, keepdim=True) + 1e-6)
            weight = topk_pseudo_scores.unsqueeze(-1).repeat(1, 1, 2)
            for i in range(len(dt['video_target'])):
                previous_pseudo_box = dt['video_target'][i]['box_pseudo_aug']
                if self.opt.use_query_box_for_refine:
                    # average the stored pseudo boxes with the boxes predicted by the matched queries
                    previous_pseudo_box = (previous_pseudo_box + query_pred_boxes) / 2
                if self.opt.merge_mode == 'weighted_sum':
                    selected_pseudo_box = torch.gather(previous_pseudo_box.view(-1, aug_num, 2), 1,
                                                       topk_pseudo_indices.unsqueeze(-1).expand(-1, -1, previous_pseudo_box.size(-1)))
                    refined_pseudo_box = (weight * selected_pseudo_box).sum(dim=1).clamp(0, 1)
                    dt['video_target'][i]['boxes_pseudo'] = refined_pseudo_box.detach().clone()
                elif self.opt.merge_mode == 'interpolate':
                    # move the current box toward the top-scoring candidate, capped at a coefficient of 0.5
                    max_pseudo_scores = topk_pseudo_scores[:, :1]
                    max_coef = 0.5 * torch.ones_like(max_pseudo_scores).to(max_pseudo_scores.device)
                    max_pseudo_box = torch.gather(previous_pseudo_box.view(-1, aug_num, 2), 1,
                                                  topk_pseudo_indices[:, :1].unsqueeze(-1).expand(-1, -1, previous_pseudo_box.size(-1)))
                    interpolate_coef = torch.min(max_pseudo_scores, max_coef)
                    refined_pseudo_box = (1 - interpolate_coef) * previous_pseudo_box[(aug_num - 1)::aug_num, :] \
                                         + interpolate_coef * max_pseudo_box.squeeze(1)
                    refined_pseudo_box = refined_pseudo_box.clamp(0, 1)
                    dt['video_target'][i]['boxes_pseudo'] = refined_pseudo_box.detach().clone()

        # restore the un-augmented captions before computing the final losses
        dt['cap_tensor'] = ori_dt_cap_tensor
        dt['cap_mask'] = ori_dt_cap_mask
        mil_dict['loss_mil'] = mil_dict['loss_mil'] / self.opt.refine_pseudo_stage_num
        criterion.pseudo_box_aug = False

        if self.aux_loss:
            ks, vs = list(zip(*(all_out.items())))
            out['aux_outputs'] = [{ks[i]: vs[i][j] for i in range(len(ks))} for j in range(num_pred - 1)]
            loss, last_indices, aux_indices = criterion(out, dt['video_target'], others)
            if self.opt.disable_rematch:
                # keep the refinement-stage matching instead of rematching with the criterion
                selected_indices = query_score.argmax(dim=-1).unsqueeze(-1)
                query_indices_in_refine = reorder_query_indices.to(selected_indices.device).view(-1, aug_num)
                query_indices_in_refine = query_indices_in_refine.gather(1, selected_indices)
                query_indices_in_refine, index_sort = torch.sort(query_indices_in_refine, 0)
                cap_indices_in_refine = last_indices[0][0][1].sort()[0]
                last_indices = [[(query_indices_in_refine.view(-1), cap_indices_in_refine[index_sort.view(-1)])], last_indices[1]]
            loss.update(mil_dict)
            criterion.pseudo_box_aug = True
            for l_id in range(hs.shape[0]):
                hs_lid = hs[l_id]
                reference = init_reference if l_id == 0 else inter_references[l_id - 1]
                indices = last_indices[0] if l_id == hs.shape[0] - 1 else aux_indices[l_id][0]
                cap_loss, cap_probs, seq, sentence_cap_prob = self.caption_prediction(
                    self.caption_head[l_id], dt, hs_lid, reference, others, self.opt.caption_decoder_type, indices)
                l_dict = {'loss_caption': cap_loss}
                if l_id != hs.shape[0] - 1:
                    l_dict = {k + f'_{l_id}': v for k, v in l_dict.items()}
                loss.update(l_dict)
            out.update({'caption_probs': cap_probs, 'seq': seq})
        else:
            loss, last_indices = criterion(out, dt['video_target'], others)
            criterion.pseudo_box_aug = True
            l_id = hs.shape[0] - 1
            reference = inter_references[l_id - 1]
            hs_lid = hs[l_id]
            indices = last_indices[0]
            cap_loss, cap_probs, seq, sentence_cap_prob = self.caption_prediction(
                self.caption_head[l_id], dt, hs_lid, reference, others, self.opt.caption_decoder_type, indices)
            l_dict = {'loss_caption': cap_loss}
            loss.update(l_dict)

            # these keys are not always produced, so pop defensively
            out.pop('caption_losses', None)
            out.pop('caption_costs', None)
            out.update({'caption_probs': cap_probs, 'seq': seq})

        return out, loss

    def parallel_prediction_matched(self, dt, criterion, contrastive_criterion, hs, init_reference, inter_references,
                                    others, disable_iterative_refine, transformer_input_type='queries'):
        outputs_classes = []
        outputs_counts = []
        outputs_coords = []
        outputs_cap_probs = []
        outputs_cap_seqs = []
        cl_match_mats = []

        num_pred = hs.shape[0]

        if self.opt.pseudo_box_aug:
            assert self.opt.use_pseudo_box
            # repeat each caption once per augmented pseudo box
            cap_dim = dt['cap_tensor'].shape[-1]
            dt['cap_tensor'] = dt['cap_tensor'].repeat(1, self.opt.pseudo_box_aug_num).reshape(-1, cap_dim)
            dt['cap_mask'] = dt['cap_mask'].repeat(1, self.opt.pseudo_box_aug_num).reshape(-1, cap_dim)

        for l_id in range(num_pred):
            hs_lid = hs[l_id]
            reference = init_reference if l_id == 0 else inter_references[l_id - 1]
            outputs_class = self.class_head[l_id](hs_lid)
            outputs_count = self.predict_event_num(self.count_head[l_id], hs_lid)
            tmp = self.bbox_head[l_id](hs_lid)

            cost_caption, loss_caption, cap_probs, seq = self.caption_prediction(self.caption_head[l_id], dt, hs_lid,
                                                                                 reference, others, 'none')

            if disable_iterative_refine:
                outputs_coord = reference
            else:
                reference = inverse_sigmoid(reference)
                if reference.shape[-1] == 2:
                    tmp += reference
                else:
                    assert reference.shape[-1] == 1
                    tmp[..., :1] += reference
                outputs_coord = tmp.sigmoid()

            if self.load_text_embed or not self.opt.disable_contrastive_projection:
                assert others['text_embed'].shape[0] == num_pred, \
                    'visual features have {} levels, but text have {}'.format(num_pred, others['text_embed'].shape[0])
                text_embed = others['text_embed'][l_id]
                event_embed = others['event_embed'][l_id]
                event_embed = event_embed.reshape(-1, event_embed.shape[-1])

            if self.opt.enable_contrastive and self.opt.set_cost_cl > 0:
                assert len(others['text_embed']) == num_pred, \
                    'visual features have {} levels, but text have {}'.format(num_pred, len(others['text_embed']))
                text_embed = torch.cat(others['text_embed'][l_id], dim=0)
                event_embed = others['event_embed'][l_id]
                event_embed = event_embed.reshape(-1, event_embed.shape[-1])
                cl_match_mat = contrastive_criterion.forward_logits(text_embed, event_embed, self.background_embed).t()
                cl_match_mats.append(cl_match_mat)
            else:
                cl_match_mats.append(0)

            outputs_classes.append(outputs_class)
            outputs_counts.append(outputs_count)
            outputs_coords.append(outputs_coord)
            outputs_cap_probs.append(cap_probs)
            outputs_cap_seqs.append(seq)

        outputs_class = torch.stack(outputs_classes)
        outputs_count = torch.stack(outputs_counts)
        outputs_coord = torch.stack(outputs_coords)

        all_out = {
            'pred_logits': outputs_class,
            'pred_count': outputs_count,
            'pred_boxes': outputs_coord,
            'caption_probs': outputs_cap_probs,
            'seq': outputs_cap_seqs,
            'cl_match_mats': cl_match_mats}
        out = {k: v[-1] for k, v in all_out.items()}

        if self.aux_loss:
            ks, vs = list(zip(*(all_out.items())))
            out['aux_outputs'] = [{ks[i]: vs[i][j] for i in range(len(ks))} for j in range(num_pred - 1)]
            if transformer_input_type == 'prior_proposals':
                loss, _, _ = criterion(out, dt['video_target'])
                # with prior proposals, queries are pre-assigned to sentences:
                # sample one query per sentence from its interval at each layer
                num_sentence = dt['cap_tensor'].shape[0]
                num_query = hs.shape[-2]
                num_query_interval = num_query // num_sentence
                query_indices = []
                for i in range(num_sentence):
                    interval_min = i * num_query_interval
                    interval_max = interval_min + num_query_interval
                    sample = torch.randint(interval_min, interval_max, (hs.shape[0],))
                    query_indices.append(sample)
                query_indices = torch.cat(query_indices, dim=0)
                gt_indices = torch.arange(num_sentence)

                last_indices = ([(query_indices[::hs.shape[0]], gt_indices)], [None, None])
                aux_indices = []
                for l_id in range(hs.shape[0] - 1):
                    aux_indices.append(([(query_indices[(l_id + 1)::hs.shape[0]], gt_indices)], [None, None]))
            else:
                loss, last_indices, aux_indices = criterion(out, dt['video_target'], others)
            for l_id in range(hs.shape[0]):
                hs_lid = hs[l_id]
                reference = init_reference if l_id == 0 else inter_references[l_id - 1]
                indices = last_indices[0] if l_id == hs.shape[0] - 1 else aux_indices[l_id][0]
                cap_loss, cap_probs, seq, sentence_cap_prob = self.caption_prediction(
                    self.caption_head[l_id], dt, hs_lid, reference, others, self.opt.caption_decoder_type, indices)

                l_dict = {'loss_caption': cap_loss}
                if self.opt.matcher_type == 'DTW' or self.opt.matcher_type == 'Sim':
                    contrastive_loss = contrastive_criterion(
                        text_embed=others['text_embed'][l_id],
                        event_embed=others['event_embed'][l_id],
                        matching_indices=indices,
                        bg_embed=self.background_embed,
                    )
                    l_dict.update({'contrastive_loss': contrastive_loss})
                if l_id != hs.shape[0] - 1:
                    l_dict = {k + f'_{l_id}': v for k, v in l_dict.items()}
                loss.update(l_dict)
            out.update({'caption_probs': cap_probs, 'seq': seq})
        else:
            loss, last_indices = criterion(out, dt['video_target'], others)

            l_id = hs.shape[0] - 1
            reference = inter_references[l_id - 1]
            hs_lid = hs[l_id]
            indices = last_indices[0]
            cap_loss, cap_probs, seq, sentence_cap_prob = self.caption_prediction(
                self.caption_head[l_id], dt, hs_lid, reference, others, self.opt.caption_decoder_type, indices)
            l_dict = {'loss_caption': cap_loss}
            loss.update(l_dict)

            # these keys are not always produced, so pop defensively
            out.pop('caption_losses', None)
            out.pop('caption_costs', None)
            out.update({'caption_probs': cap_probs, 'seq': seq})

        return out, loss

    def caption_prediction(self, cap_head, dt, hs, reference, others, captioner_type, indices=None):
        N_, N_q, C = hs.shape
        all_cap_num = len(dt['cap_tensor'])
        query_mask = others['proposals_mask']
        gt_mask = dt['gt_boxes_mask']
        mix_mask = torch.zeros(query_mask.sum().item(), gt_mask.sum().item())
        query_nums, gt_nums = query_mask.sum(1).cpu(), gt_mask.sum(1).cpu()
        hs_r = torch.masked_select(hs, query_mask.unsqueeze(-1)).reshape(-1, C)

        if indices is None:
            # no matching given: pair every query with every caption of the same video
            row_idx, col_idx = 0, 0
            for i in range(N_):
                mix_mask[row_idx: (row_idx + query_nums[i]), col_idx: (col_idx + gt_nums[i])] = 1
                row_idx = row_idx + query_nums[i]
                col_idx = col_idx + gt_nums[i]

            bigids = mix_mask.nonzero(as_tuple=False)
            feat_bigids, cap_bigids = bigids[:, 0], bigids[:, 1]
        else:
            # a matching is given: keep only the matched (query, caption) pairs
            feat_bigids = torch.zeros(sum([len(_[0]) for _ in indices])).long()
            cap_bigids = torch.zeros_like(feat_bigids)
            total_query_ids = 0
            total_cap_ids = 0
            total_ids = 0
            max_pair_num = max([len(_[0]) for _ in indices])
            new_hr_for_dsa = torch.zeros(N_, max_pair_num, C)
            cap_seq = dt['cap_tensor']
            new_seq_for_dsa = torch.zeros(N_, max_pair_num, cap_seq.shape[-1], dtype=cap_seq.dtype)
            for i, index in enumerate(indices):
                feat_ids, cap_ids = index
                feat_bigids[total_ids: total_ids + len(feat_ids)] = total_query_ids + feat_ids
                cap_bigids[total_ids: total_ids + len(feat_ids)] = total_cap_ids + cap_ids
                new_hr_for_dsa[i, :len(feat_ids)] = hs[i, feat_ids]
                new_seq_for_dsa[i, :len(feat_ids)] = cap_seq[total_cap_ids + cap_ids]
                total_query_ids += query_nums[i]
                total_cap_ids += gt_nums[i]
                total_ids += len(feat_ids)

        cap_probs = {}
        flag = True

        if captioner_type == 'none':
            cost_caption = torch.zeros(N_, N_q, all_cap_num, device=hs.device)
            loss_caption = torch.zeros(N_, N_q, all_cap_num, device=hs.device)
            cap_probs['cap_prob_train'] = torch.zeros(1, device=hs.device)
            cap_probs['cap_prob_eval'] = torch.zeros(N_, N_q, 3, device=hs.device)
            seq = torch.zeros(N_, N_q, 3, device=hs.device)
            return cost_caption, loss_caption, cap_probs, seq

        elif captioner_type in ['light']:
            clip = hs_r.unsqueeze(1)
            clip_mask = clip.new_ones(clip.shape[:2])
            event = None
        elif self.opt.caption_decoder_type == 'standard':
            # the standard decoder attends to the full memory, so it handles the
            # matched pairs directly instead of going through the `flag` path below
            if self.training:
                seq = dt['cap_tensor'][cap_bigids]
                if self.opt.caption_cost_type != 'rl':
                    if self.opt.refine_pseudo_box:
                        cap_prob, raw_cap_prob = cap_head(hs[:, feat_bigids], reference[:, feat_bigids], others, seq)
                        cap_probs['cap_prob_train'] = cap_prob
                        cap_probs['raw_cap_prob'] = raw_cap_prob
                    else:
                        cap_prob = cap_head(hs[:, feat_bigids], reference[:, feat_bigids], others, seq)
                        cap_probs['cap_prob_train'] = cap_prob
            else:
                with torch.no_grad():
                    cap_prob = cap_head(hs[:, feat_bigids], reference[:, feat_bigids], others,
                                        dt['cap_tensor'][cap_bigids])
                    seq, cap_prob_eval = cap_head.sample(hs, reference, others)
                    if len(seq):
                        seq = seq.reshape(-1, N_q, seq.shape[-1])
                        cap_prob_eval = cap_prob_eval.reshape(-1, N_q, cap_prob_eval.shape[-1])
                    cap_probs['cap_prob_eval'] = cap_prob_eval

            flag = False

        if flag:
            clip_ext = clip[feat_bigids]
            clip_mask_ext = clip_mask[feat_bigids]

            if self.training:
                seq = dt['cap_tensor'][cap_bigids]
                if self.opt.caption_cost_type != 'rl':
                    cap_prob = cap_head(event, clip_ext, clip_mask_ext, seq)
                    cap_probs['cap_prob_train'] = cap_prob
            else:
                with torch.no_grad():
                    seq_gt = dt['cap_tensor'][cap_bigids]
                    cap_prob = cap_head(event, clip_ext, clip_mask_ext, seq_gt)
                    seq, cap_prob_eval = cap_head.sample(event, clip, clip_mask)

                    if len(seq):
                        seq = seq.reshape(-1, N_q, seq.shape[-1])
                        cap_prob_eval = cap_prob_eval.reshape(-1, N_q, cap_prob_eval.shape[-1])
                    cap_probs['cap_prob_eval'] = cap_prob_eval

        if self.opt.caption_cost_type == 'loss':
            cap_prob = cap_prob.reshape(-1, cap_prob.shape[-2], cap_prob.shape[-1])
            caption_tensor = dt['cap_tensor'][:, 1:][cap_bigids]
            caption_mask = dt['cap_mask'][:, 1:][cap_bigids]
            cap_loss = cap_head.build_loss(cap_prob, caption_tensor, caption_mask)
            cap_cost = cap_loss
        else:
            raise AssertionError('caption cost type error')

        # per-pair caption log-likelihood (higher is better)
        sentence_cap_prob = -cap_loss

        if indices:
            return cap_loss.mean(), cap_probs, seq, sentence_cap_prob

        # build dense cost/loss matrices over all (query, caption) pairs for the matcher
        cap_id, query_id = cap_bigids, feat_bigids
        cost_caption = hs_r.new_zeros((max(query_id) + 1, max(cap_id) + 1))
        cost_caption[query_id, cap_id] = cap_cost
        loss_caption = hs_r.new_zeros((max(query_id) + 1, max(cap_id) + 1))
        loss_caption[query_id, cap_id] = cap_loss
        cost_caption = cost_caption.reshape(-1, N_q, max(cap_id) + 1)
        loss_caption = loss_caption.reshape(-1, N_q, max(cap_id) + 1)
        return cost_caption, loss_caption, cap_probs, seq
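
    # Return-type note for caption_prediction (as implemented above):
    #   with `indices` given -> (mean caption loss, cap_probs, seq, per-pair log-prob)
    #   with `indices` None  -> (cost_caption, loss_caption, cap_probs, seq), where
    #     the first two are dense [bs, num_queries, num_captions] matrices consumed
    #     by the matcher.
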
    def caption_prediction_eval(self, cap_head, dt, hs, reference, others, decoder_type, pred_num=None, indices=None):
        assert indices is None
        N_, N_q, C = hs.shape
        query_mask = others['proposals_mask']
        gt_mask = dt['gt_boxes_mask']
        mix_mask = torch.zeros(query_mask.sum().item(), gt_mask.sum().item())
        query_nums, gt_nums = query_mask.sum(1).cpu(), gt_mask.sum(1).cpu()
        hs_r = torch.masked_select(hs, query_mask.unsqueeze(-1)).reshape(-1, C)

        row_idx, col_idx = 0, 0
        for i in range(N_):
            mix_mask[row_idx: (row_idx + query_nums[i]), col_idx: (col_idx + gt_nums[i])] = 1
            row_idx = row_idx + query_nums[i]
            col_idx = col_idx + gt_nums[i]

        cap_probs = {}

        if decoder_type in ['none']:
            cap_probs['cap_prob_train'] = torch.zeros(1, device=hs.device)
            cap_probs['cap_prob_eval'] = torch.zeros(N_, N_q, 3, device=hs.device)
            seq = torch.zeros(N_, N_q, 3, device=hs.device)
            return cap_probs, seq

        elif decoder_type in ['light']:
            clip = hs_r.unsqueeze(1)
            clip_mask = clip.new_ones(clip.shape[:2])
            event = None
            seq, cap_prob_eval = cap_head.sample(event, clip, clip_mask)
            if len(seq):
                seq = seq.reshape(-1, N_q, seq.shape[-1])
                cap_prob_eval = cap_prob_eval.reshape(-1, N_q, cap_prob_eval.shape[-1])
            cap_probs['cap_prob_eval'] = cap_prob_eval

        elif decoder_type in ['standard']:
            assert N_ == 1, 'only support batchsize = 1'
            with torch.no_grad():
                if self.opt.transformer_input_type == 'prior_proposals':
                    # average-pool the queries within each sentence's interval so
                    # that each caption is decoded from a single pooled query
                    if pred_num:
                        num_cap = pred_num
                    else:
                        num_cap = dt['cap_tensor'].shape[0]
                    interval = N_q // num_cap
                    pool_layer = torch.nn.AvgPool1d(interval, stride=interval)
                    hs = pool_layer(hs.permute(0, 2, 1)).permute(0, 2, 1)[:, :num_cap, :]
                    reference = pool_layer(reference.permute(0, 2, 1)).permute(0, 2, 1)[:, :num_cap, :]
                    seq, cap_prob_eval = cap_head.sample(hs, reference, others)
                    if len(seq):
                        seq = seq.reshape(-1, num_cap, seq.shape[-1])
                        cap_prob_eval = cap_prob_eval.reshape(-1, num_cap, cap_prob_eval.shape[-1])
                    cap_probs['cap_prob_eval'] = cap_prob_eval
                else:
                    seq, cap_prob_eval = cap_head.sample(hs, reference, others)
                    if len(seq):
                        seq = seq.reshape(-1, N_q, seq.shape[-1])
                        cap_prob_eval = cap_prob_eval.reshape(-1, N_q, cap_prob_eval.shape[-1])
                    cap_probs['cap_prob_eval'] = cap_prob_eval
        return cap_probs, seq


class PostProcess(nn.Module):
    """ This module converts the model's output into the format expected by the coco api"""

    def __init__(self, opt):
        super().__init__()
        self.opt = opt

    @torch.no_grad()
    def forward(self, outputs, target_sizes, loader):
        """ Perform the computation
        Parameters:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size] containing the duration of each video in the batch
        """
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
        N, N_q, N_class = out_logits.shape
        assert len(out_logits) == len(target_sizes)
        prob = out_logits.sigmoid()

        if self.opt.transformer_input_type == 'prior_proposals':
            # keep as many predictions as decoded captions
            topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), outputs['seq'].shape[1], dim=1)
        else:
            topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), N_q, dim=1)
        scores = topk_values
        topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode='floor')
        labels = topk_indexes % out_logits.shape[2]
        boxes = box_ops.box_cl_to_xy(out_bbox)  # (center, length) -> (start, end)
        raw_boxes = copy.deepcopy(boxes)
        boxes.clamp_(min=0, max=1)
        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 2))

        # rescale normalized segments to seconds
        scale_fct = torch.stack([target_sizes, target_sizes], dim=1)
        boxes = boxes * scale_fct[:, None, :]
        seq = outputs['seq']
        cap_prob = outputs['caption_probs']['cap_prob_eval']
        eseq_lens = outputs['pred_count'].argmax(dim=-1).clamp(min=1)

        if len(seq):
            mask = (seq > 0).float()
            cap_scores = (mask * cap_prob).sum(2).cpu().numpy().astype('float')
            seq = seq.detach().cpu().numpy().astype('int')
            caps = [[loader.dataset.translator.rtranslate(s) for s in s_vid] for s_vid in seq]
            if self.opt.transformer_input_type != 'prior_proposals':
                # reorder captions and their scores to follow the top-k ranking
                caps = [[caps[batch][idx] for idx in b] for batch, b in enumerate(topk_boxes)]
                cap_scores = [[cap_scores[batch, idx] for idx in b] for batch, b in enumerate(topk_boxes)]
        else:
            bs, num_queries = boxes.shape[:2]
            cap_scores = [[-1e5] * num_queries] * bs
            caps = [[''] * num_queries] * bs

        results = [
            {'scores': s, 'labels': l, 'boxes': b, 'raw_boxes': rb, 'captions': c, 'caption_scores': cs,
             'query_id': qid, 'vid_duration': ts, 'pred_seq_len': sl}
            for s, l, b, rb, c, cs, qid, ts, sl in
            zip(scores, labels, boxes, raw_boxes, caps, cap_scores, topk_boxes, target_sizes, eseq_lens)]
        return results

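# Usage sketch (illustrative): `dt` and `loader` follow this repo's dataloader
# conventions; the variable names below are assumptions for the example.
#
#   target_sizes = dt['video_length'][:, 1]              # durations in seconds
#   results = postprocessors['bbox'](out, target_sizes, loader)
#   # results[i]['boxes'] are (start, end) segments in seconds, and
#   # results[i]['captions'] the decoded sentences, ranked by results[i]['scores'].
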
class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

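# The localization head above is instantiated as MLP(hidden_dim, hidden_dim, 2, 3):
# three Linear layers with ReLU in between, ending in 2 outputs per query
# (segment center offset, segment length) that the caller passes through sigmoid.
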
def build(args):
    device = torch.device(args.device)
    base_encoder = build_base_encoder(args)
    transformer = build_deforamble_transformer(args)
    captioner = build_captioner(args)

    model = PDVC(
        base_encoder,
        transformer,
        captioner,
        num_classes=args.num_classes,
        num_queries=args.num_queries,
        num_feature_levels=args.num_feature_levels,
        aux_loss=args.aux_loss,
        with_box_refine=args.with_box_refine,
        opt=args
    )

    matcher = build_matcher(args)
    if args.matcher_type == 'DTW' and args.use_anchor:
        weight_dict = {'loss_ce': args.cls_loss_coef,
                       'loss_bbox': args.bbox_loss_coef,
                       'loss_giou': args.giou_loss_coef,
                       'loss_self_iou': args.self_iou_loss_coef,
                       'loss_ref_rank': args.ref_rank_loss_coef,
                       'loss_counter': args.count_loss_coef,
                       'loss_caption': args.caption_loss_coef,
                       'contrastive_loss': args.contrastive_loss_start_coef,
                       }
    else:
        weight_dict = {'loss_ce': args.cls_loss_coef,
                       'loss_bbox': args.bbox_loss_coef,
                       'loss_giou': args.giou_loss_coef,
                       'loss_counter': args.count_loss_coef,
                       'loss_caption': args.caption_loss_coef,
                       'contrastive_loss': args.contrastive_loss_start_coef,
                       }
    if args.refine_pseudo_box:
        weight_dict.update({'loss_mil': args.mil_loss_coef})

    if args.aux_loss:
        aux_weight_dict = {}
        for i in range(args.dec_layers - 1):
            aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)

    losses = ['labels', 'boxes', 'cardinality']

    if args.matcher_type == 'DTW' or args.matcher_type == 'Sim':
        criterion = AlignCriterion(args.num_classes, matcher, weight_dict, losses, focal_alpha=args.focal_alpha,
                                   focal_gamma=args.focal_gamma, opt=args)
        contrastive_criterion = ContrastiveCriterion(temperature=args.contrastive_loss_temperature,
                                                     enable_cross_video_cl=args.enable_cross_video_cl,
                                                     enable_e2t_cl=args.enable_e2t_cl,
                                                     enable_bg_for_cl=args.enable_bg_for_cl)
        contrastive_criterion.to(device)
    else:
        criterion = SetCriterion(args.num_classes, matcher, weight_dict, losses, focal_alpha=args.focal_alpha,
                                 focal_gamma=args.focal_gamma, opt=args)
        contrastive_criterion = None

    criterion.to(device)
    postprocessors = {'bbox': PostProcess(args)}

    return model, criterion, contrastive_criterion, postprocessors
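

# Minimal end-to-end sketch (assumptions: `args` is the namespace produced by
# this repo's opts, and `dt` is one batch from its dataloader; neither is
# defined in this file).
#
#   model, criterion, contrastive_criterion, postprocessors = build(args)
#   model.to(args.device)
#   out, loss = model(dt, criterion, contrastive_criterion, eval_mode=False)
#   total_loss = sum(loss[k] * criterion.weight_dict[k]
#                    for k in loss.keys() if k in criterion.weight_dict)
#   total_loss.backward()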