| import math |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
| from .models import register_meta_arch, make_backbone, make_neck, make_generator |
| from .blocks import MaskedConv1D, Scale, LayerNorm |
| from .losses import ctr_diou_loss_1d, sigmoid_focal_loss |
|
|
| from ..utils import batched_nms |
| from IPython import embed |
|
|
| import torch |
| import torch.nn as nn |
|
|
class TimeEmbedding(nn.Module):
    """
    Sinusoidal embedding of diffusion timesteps followed by a 2-layer MLP.

    Args:
        dim: size of the sinusoidal embedding and of the MLP output.
        max_period: sets the lowest sinusoid frequency (1 / max_period).
    """
    def __init__(self, dim, max_period=10000):
        super().__init__()
        self.dim = dim
        self.max_period = max_period
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim),
        )

    def forward(self, t: torch.Tensor):
        """
        Embed timesteps ``t`` of shape ``(...)`` into a tensor of shape ``(..., dim)``.

        The first ``dim // 2`` features are cosines, the next ``dim // 2`` are
        sines; an odd ``dim`` gets one trailing zero feature before the MLP.
        """
        half = self.dim // 2
        # geometric frequency ladder starting at 1 and decreasing towards 1/max_period
        freqs = torch.exp(
            -math.log(self.max_period)
            * torch.arange(0, half, device=t.device)
            / half
        )
        args = t.float().unsqueeze(-1) * freqs
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.dim % 2 == 1:
            # Append one zero column so the MLP always receives `dim` features.
            # F.pad also works when half == 0 (dim == 1), where slicing the
            # empty `emb` could not produce that column.
            emb = F.pad(emb, (0, 1))
        return self.mlp(emb)
|
|
|
|
class Denoiser1D(nn.Module):
    """
    Residual 1D conv block whose normalized features are modulated by a time embedding.

    Args:
        channels: number of input and output channels; must be divisible by 8
            because of the ``GroupNorm(8, channels)`` layer.
        temb_dim: length of the time-embedding vector given to ``forward``.
    """
    def __init__(self, channels=128, temb_dim=128):
        super().__init__()
        # Submodules are created in a fixed order with fixed names so that
        # parameter initialisation and state_dict keys stay stable.
        self.in_conv = nn.Conv1d(channels, channels, 1)
        self.dw = nn.Conv1d(channels, channels, kernel_size=3, padding=1, groups=channels)
        self.pw = nn.Conv1d(channels, channels, 1)
        self.gn = nn.GroupNorm(8, channels)
        self.to_scale = nn.Linear(temb_dim, channels)
        self.to_shift = nn.Linear(temb_dim, channels)
        self.out_conv = nn.Conv1d(channels, channels, 1)

    def forward(self, x, temb):
        """
        Args:
            x: features of shape (B, channels, L).
            temb: time embedding of shape (B, temb_dim).

        Returns:
            ``x`` plus the block's residual, same shape as ``x``.
        """
        # per-sample, per-channel affine parameters broadcast along the length axis
        gamma = self.to_scale(temb)[..., None]
        beta = self.to_shift(temb)[..., None]

        y = self.gn(self.in_conv(x))
        y = y * (1 + gamma) + beta
        for layer in (self.dw, self.pw):
            y = F.gelu(layer(y))

        residual = self.out_conv(y)
        return x + residual
|
|
class NDRefiner(nn.Module):
    """
    Noise-and-denoise refiner for a 1D feature map, using a DDIM-style sampler.

    ``forward`` perturbs its input with the forward diffusion process
    (``q_sample``). It then runs ``steps`` reverse updates, each using
    ``Denoiser1D`` to predict the added noise. No noise-prediction loss is
    computed here; the module only learns through losses applied to its output.

    Args:
        channels: channel count C of the (B, C, L) input.
        steps: number of diffusion steps (length of the beta schedule).
        beta_start, beta_end: endpoints of the linear beta schedule.
        eta: DDIM stochasticity of the reverse updates. 0 makes them
            deterministic, although the initial ``q_sample`` noise is
            still random.
    """
    def __init__(self, channels=128, steps=4, beta_start=1e-4, beta_end=0.02, eta=0.0):
        super().__init__()
        self.channels = channels
        self.steps = steps
        self.eta = eta
        # linear beta schedule and its cumulative products
        betas = torch.linspace(beta_start, beta_end, steps)
        alphas = 1.0 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alpha_bars', alpha_bars)

        self.temb = TimeEmbedding(channels)
        self.denoiser = Denoiser1D(channels, temb_dim=channels)

    def q_sample(self, x0, t_idx, noise=None):
        """
        Forward diffusion: x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise.

        Args:
            x0: clean input; the coefficients are reshaped to (-1, 1, 1), so x0
                is presumably (B, C, L) as in ``forward``.
            t_idx: (B,) integer indices into the schedule.
            noise: optional noise tensor; standard normal noise is drawn when omitted.

        Returns:
            Tuple ``(x_t, noise)``.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        a_bar = self.alpha_bars[t_idx].view(-1, 1, 1)
        return a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise, noise

    def forward(self, x):
        """
        Noise ``x`` and run the full reverse process on it.

        Args:
            x: features of shape (B, C, L).

        Returns:
            Refined features with the same shape as ``x``.
        """
        batch_size, _, _ = x.shape
        device = x.device

        if self.training:
            t_idx = torch.randint(0, self.steps, (batch_size,), device=device)
        else:
            t_idx = torch.full((batch_size,), self.steps - 1, device=device)

        # NOTE(review): in training t_idx is sampled per example, but the
        # reverse loop below always starts at step `steps - 1`, regardless of
        # the noise level actually applied. Confirm this is intended.
        x_t, _ = self.q_sample(x, t_idx)

        for idx in reversed(range(self.steps)):
            t = torch.full((batch_size,), idx, device=device)
            t_emb = self.temb(t)
            eps_pred = self.denoiser(x_t, t_emb)

            a_bar_t = self.alpha_bars[idx]
            a_t = self.alphas[idx]
            # alpha_bar_{t-1} is 1 before the first step. It is built from the
            # buffer so it keeps the buffer's dtype/device; a float32 literal
            # would upcast the output of a half/bfloat16 model.
            if idx > 0:
                a_bar_prev = self.alpha_bars[idx - 1]
            else:
                a_bar_prev = self.alpha_bars.new_ones(())
            a_bar_prev = a_bar_prev.view(1, 1, 1)

            # predicted clean signal, clipped to a fixed range
            x0_pred = (x_t - (1 - a_bar_t).sqrt() * eps_pred) / a_bar_t.sqrt()
            x0_pred = x0_pred.clamp(-3.0, 3.0)

            # DDIM noise scale; 1 - a_t == 1 - a_bar_t / a_bar_prev
            sigma_t = self.eta * ((1 - a_bar_prev) / (1 - a_bar_t) * (1 - a_t)).sqrt()
            noise = torch.randn_like(x_t) if sigma_t > 0 else 0.0

            # Clamp at 0: rounding (or eta > 1) can push this slightly
            # negative, and sqrt would then produce NaN.
            dir_coef = (1 - a_bar_prev - sigma_t ** 2).clamp(min=0).sqrt()
            x_t = a_bar_prev.sqrt() * x0_pred + dir_coef * eps_pred + sigma_t * noise

        return x_t
|
|
class PtTransformerClsHead(nn.Module):
    """
    1D Conv heads for classification.

    Applies ``num_layers - 1`` masked conv layers (each followed by an
    optional LayerNorm and the activation) and then a masked conv classifier
    to every pyramid level.

    Args:
        input_dim: channels of the incoming pyramid features.
        feat_dim: channels of the hidden conv layers.
        num_classes: number of output logits per position.
        prior_prob: initial foreground probability used to set the classifier bias.
        num_layers: total number of conv layers, including the classifier.
        kernel_size: temporal kernel size of every conv.
        act_layer: activation class instantiated once and shared by all layers.
        with_ln: use LayerNorm (and bias-free convs) in the hidden layers.
        empty_cls: class indices whose bias is set to a near-zero prior.
    """
    def __init__(
        self,
        input_dim,
        feat_dim,
        num_classes,
        prior_prob=0.01,
        num_layers=3,
        kernel_size=3,
        act_layer=nn.ReLU,
        with_ln=False,
        empty_cls=()
    ):
        super().__init__()
        self.act = act_layer()

        # hidden conv layers (the first one maps input_dim -> feat_dim)
        self.head = nn.ModuleList()
        self.norm = nn.ModuleList()
        for idx in range(num_layers - 1):
            in_dim = input_dim if idx == 0 else feat_dim
            self.head.append(
                MaskedConv1D(
                    in_dim, feat_dim, kernel_size,
                    stride=1,
                    padding=kernel_size // 2,
                    bias=(not with_ln)
                )
            )
            self.norm.append(LayerNorm(feat_dim) if with_ln else nn.Identity())

        # When there are no hidden layers the classifier sees the raw input
        # features, so its input width must be input_dim, not feat_dim.
        cls_in_dim = feat_dim if num_layers > 1 else input_dim
        self.cls_head = MaskedConv1D(
            cls_in_dim, num_classes, kernel_size,
            stride=1, padding=kernel_size // 2
        )

        # bias init so that sigmoid(bias) == prior_prob
        bias_value = -(math.log((1 - prior_prob) / prior_prob))
        torch.nn.init.constant_(self.cls_head.conv.bias, bias_value)

        # classes listed in empty_cls start with a ~1e-6 prior
        if len(empty_cls) > 0:
            bias_value = -(math.log((1 - 1e-6) / 1e-6))
            for idx in empty_cls:
                torch.nn.init.constant_(self.cls_head.conv.bias[idx], bias_value)

    def forward(self, fpn_feats, fpn_masks):
        """
        Args:
            fpn_feats: sequence of per-level feature tensors.
            fpn_masks: sequence of per-level masks, same length as ``fpn_feats``.

        Returns:
            Tuple of per-level classification logits.
        """
        assert len(fpn_feats) == len(fpn_masks)

        out_logits = []
        for cur_feat, cur_mask in zip(fpn_feats, fpn_masks):
            cur_out = cur_feat
            for conv, norm in zip(self.head, self.norm):
                cur_out, _ = conv(cur_out, cur_mask)
                cur_out = self.act(norm(cur_out))
            cur_logits, _ = self.cls_head(cur_out, cur_mask)
            out_logits.append(cur_logits)

        return tuple(out_logits)
|
|
|
|
class PtTransformerRegHead(nn.Module):
    """
    Shared 1D Conv heads for regression.

    Similar logic as PtTransformerClsHead, kept as a separate implementation for
    clarity. Each level gets its own learnable ``Scale``. The 2-channel output
    (left / right offsets) is passed through ReLU, so it is non-negative.

    Args:
        input_dim: channels of the incoming pyramid features.
        feat_dim: channels of the hidden conv layers.
        fpn_levels: number of pyramid levels (one ``Scale`` each).
        num_layers: total number of conv layers, including the offset predictor.
        kernel_size: temporal kernel size of every conv.
        act_layer: activation class instantiated once and shared by all layers.
        with_ln: use LayerNorm (and bias-free convs) in the hidden layers.
    """
    def __init__(
        self,
        input_dim,
        feat_dim,
        fpn_levels,
        num_layers=3,
        kernel_size=3,
        act_layer=nn.ReLU,
        with_ln=False
    ):
        super().__init__()
        self.fpn_levels = fpn_levels
        self.act = act_layer()

        # hidden conv layers (the first one maps input_dim -> feat_dim)
        self.head = nn.ModuleList()
        self.norm = nn.ModuleList()
        for idx in range(num_layers - 1):
            in_dim = input_dim if idx == 0 else feat_dim
            self.head.append(
                MaskedConv1D(
                    in_dim, feat_dim, kernel_size,
                    stride=1,
                    padding=kernel_size // 2,
                    bias=(not with_ln)
                )
            )
            self.norm.append(LayerNorm(feat_dim) if with_ln else nn.Identity())

        # one learnable scale per pyramid level
        self.scale = nn.ModuleList([Scale() for _ in range(fpn_levels)])

        # When there are no hidden layers the predictor sees the raw input
        # features, so its input width must be input_dim, not feat_dim.
        offset_in_dim = feat_dim if num_layers > 1 else input_dim
        self.offset_head = MaskedConv1D(
            offset_in_dim, 2, kernel_size,
            stride=1, padding=kernel_size // 2
        )

    def forward(self, fpn_feats, fpn_masks):
        """
        Args:
            fpn_feats: sequence of ``fpn_levels`` per-level feature tensors.
            fpn_masks: sequence of per-level masks, same length as ``fpn_feats``.

        Returns:
            Tuple of per-level non-negative offset predictions.
        """
        assert len(fpn_feats) == len(fpn_masks)
        assert len(fpn_feats) == self.fpn_levels

        out_offsets = []
        for level, (cur_feat, cur_mask) in enumerate(zip(fpn_feats, fpn_masks)):
            cur_out = cur_feat
            for conv, norm in zip(self.head, self.norm):
                cur_out, _ = conv(cur_out, cur_mask)
                cur_out = self.act(norm(cur_out))
            cur_offsets, _ = self.offset_head(cur_out, cur_mask)
            out_offsets.append(F.relu(self.scale[level](cur_offsets)))

        return tuple(out_offsets)
|
|
|
|
@register_meta_arch("LocPointTransformer")
class PtTransformer(nn.Module):
    """
    Transformer based model for single stage action localization.

    ``forward`` runs the backbone, the neck, one ``NDRefiner`` per pyramid
    level, and the shared classification / regression heads. In training mode
    it returns a dict of losses; otherwise it returns a list of per-video
    detections.
    """
    def __init__(
        self,
        backbone_type,
        fpn_type,
        backbone_arch,
        scale_factor,
        input_dim,
        max_seq_len,
        max_buffer_len_factor,
        n_head,
        n_mha_win_size,
        embd_kernel_size,
        embd_dim,
        embd_with_ln,
        fpn_dim,
        fpn_with_ln,
        fpn_start_level,
        head_dim,
        regression_range,
        head_num_layers,
        head_kernel_size,
        head_with_ln,
        use_abs_pe,
        use_rel_pe,
        num_classes,
        train_cfg,
        test_cfg
    ):
        super().__init__()
        # Stride of each pyramid level used by the neck / heads. The backbone
        # produces backbone_arch[-1] + 1 levels; the first fpn_start_level of
        # them are not used by the neck / heads.
        self.fpn_strides = [
            scale_factor ** i
            for i in range(fpn_start_level, backbone_arch[-1] + 1)
        ]
        self.reg_range = regression_range
        assert len(self.fpn_strides) == len(self.reg_range)
        self.scale_factor = scale_factor
        self.num_classes = num_classes

        # maximum input length (training) and local-attention window per level
        self.max_seq_len = max_seq_len
        if isinstance(n_mha_win_size, int):
            self.mha_win_size = [n_mha_win_size] * (1 + backbone_arch[-1])
        else:
            assert len(n_mha_win_size) == (1 + backbone_arch[-1])
            self.mha_win_size = n_mha_win_size

        # Long inference inputs are padded to a multiple of the largest
        # (stride x window) factor; see `preprocessing`.
        # NOTE(review): this zip pairs FPN level i (stride
        # scale_factor ** (fpn_start_level + i)) with the window size of
        # backbone level i. The two only line up when fpn_start_level == 0.
        max_div_factor = 1
        for s, w in zip(self.fpn_strides, self.mha_win_size):
            stride = s * (w // 2) * 2 if w > 1 else s
            assert max_seq_len % stride == 0, "max_seq_len must be divisible by fpn stride and window size"
            max_div_factor = max(max_div_factor, stride)
        self.max_div_factor = max_div_factor

        # training options
        self.train_center_sample = train_cfg['center_sample']
        assert self.train_center_sample in ['radius', 'none']
        self.train_center_sample_radius = train_cfg['center_sample_radius']
        self.train_loss_weight = train_cfg['loss_weight']
        self.train_cls_prior_prob = train_cfg['cls_prior_prob']
        self.train_dropout = train_cfg['dropout']
        self.train_droppath = train_cfg['droppath']
        self.train_label_smoothing = train_cfg['label_smoothing']

        # test-time options
        self.test_pre_nms_thresh = test_cfg['pre_nms_thresh']
        self.test_pre_nms_topk = test_cfg['pre_nms_topk']
        self.test_iou_threshold = test_cfg['iou_threshold']
        self.test_min_score = test_cfg['min_score']
        self.test_max_seg_num = test_cfg['max_seg_num']
        self.test_nms_method = test_cfg['nms_method']
        assert self.test_nms_method in ['soft', 'hard', 'none']
        self.test_duration_thresh = test_cfg['duration_thresh']
        self.test_multiclass_nms = test_cfg['multiclass_nms']
        self.test_nms_sigma = test_cfg['nms_sigma']
        self.test_voting_thresh = test_cfg['voting_thresh']

        # backbone
        assert backbone_type in ['convTransformer', 'conv']
        if backbone_type == 'convTransformer':
            self.backbone = make_backbone(
                'convTransformer',
                **{
                    'n_in' : input_dim,
                    'n_embd' : embd_dim,
                    'n_head': n_head,
                    'n_embd_ks': embd_kernel_size,
                    'max_len': max_seq_len,
                    'arch' : backbone_arch,
                    'mha_win_size': self.mha_win_size,
                    'scale_factor' : scale_factor,
                    'with_ln' : embd_with_ln,
                    'attn_pdrop' : 0.0,
                    'proj_pdrop' : self.train_dropout,
                    'path_pdrop' : self.train_droppath,
                    'use_abs_pe' : use_abs_pe,
                    'use_rel_pe' : use_rel_pe
                }
            )
        else:
            self.backbone = make_backbone(
                'conv',
                **{
                    'n_in': input_dim,
                    'n_embd': embd_dim,
                    'n_embd_ks': embd_kernel_size,
                    'arch': backbone_arch,
                    'scale_factor': scale_factor,
                    'with_ln' : embd_with_ln
                }
            )

        # neck
        assert fpn_type in ['fpn', 'identity']
        self.neck = make_neck(
            fpn_type,
            **{
                'in_channels' : [embd_dim] * (backbone_arch[-1] + 1),
                'out_channel' : fpn_dim,
                'scale_factor' : scale_factor,
                'start_level' : fpn_start_level,
                'with_ln' : fpn_with_ln
            }
        )

        # point generator for all pyramid levels
        self.point_generator = make_generator(
            'point',
            **{
                'max_seq_len' : max_seq_len * max_buffer_len_factor,
                'fpn_strides' : self.fpn_strides,
                'regression_range' : self.reg_range
            }
        )

        # classification and regression heads
        self.cls_head = PtTransformerClsHead(
            fpn_dim, head_dim, self.num_classes,
            kernel_size=head_kernel_size,
            prior_prob=self.train_cls_prior_prob,
            with_ln=head_with_ln,
            num_layers=head_num_layers,
            empty_cls=train_cfg['head_empty_cls']
        )
        self.reg_head = PtTransformerRegHead(
            fpn_dim, head_dim, len(self.fpn_strides),
            kernel_size=head_kernel_size,
            num_layers=head_num_layers,
            with_ln=head_with_ln
        )

        # running (EMA) normalizer for the losses
        self.loss_normalizer = train_cfg['init_loss_norm']
        self.loss_normalizer_momentum = 0.9

        # One refiner per pyramid level; `forward` indexes them by level and
        # reg_head requires exactly len(self.fpn_strides) levels.
        self.nd_blocks = nn.ModuleList([
            NDRefiner(
                channels=fpn_dim,
                steps=train_cfg['denoise_steps'],
                beta_start=train_cfg['beta_start_end'][0],
                beta_end=train_cfg['beta_start_end'][1],
                eta=train_cfg['eta'],
            )
            for _ in range(len(self.fpn_strides))
        ])

    @property
    def device(self):
        # device of the model parameters (assumes they all live on one device)
        return next(self.parameters()).device

    def forward(self, video_list):
        """
        Run the detector on a batch of videos.

        Args:
            video_list: list of dicts. 'feats' is always read; 'segments' and
                'labels' in training; 'video_id', 'duration', 'gt_time' and
                'feat_num_frames' in inference.

        Returns:
            In training mode, a dict with 'cls_loss', 'reg_loss' and
            'final_loss'; otherwise the list produced by ``inference``.
        """
        # batch the features into a padded tensor plus a validity mask
        batched_inputs, batched_masks = self.preprocessing(video_list)

        # multi-scale features
        feats, masks = self.backbone(batched_inputs, batched_masks)
        fpn_feats, fpn_masks = self.neck(feats, masks)

        # Refine each level with its own NDRefiner. The refiner does not see
        # fpn_masks; masks are only applied downstream (heads, loss, inference).
        fpn_feats = [self.nd_blocks[i](feat) for i, feat in enumerate(fpn_feats)]

        # candidate points for every level
        points = self.point_generator(fpn_feats)

        # per-level head outputs
        out_cls_logits = self.cls_head(fpn_feats, fpn_masks)
        out_offsets = self.reg_head(fpn_feats, fpn_masks)

        # Move the channel dim last (assumes head outputs are (B, C, T)) and
        # drop the singleton mask channel.
        out_cls_logits = [x.permute(0, 2, 1) for x in out_cls_logits]
        out_offsets = [x.permute(0, 2, 1) for x in out_offsets]
        fpn_masks = [x.squeeze(1) for x in fpn_masks]

        if self.training:
            # only the first video is checked for ground truth
            assert video_list[0]['segments'] is not None, "GT action labels does not exist"
            assert video_list[0]['labels'] is not None, "GT action labels does not exist"
            gt_segments = [x['segments'].to(self.device) for x in video_list]
            gt_labels = [x['labels'].to(self.device) for x in video_list]

            # assign targets to every point
            gt_cls_labels, gt_offsets = self.label_points(
                points, gt_segments, gt_labels)

            losses = self.losses(
                fpn_masks,
                out_cls_logits, out_offsets,
                gt_cls_labels, gt_offsets
            )
            return losses

        results = self.inference(
            video_list, points, fpn_masks,
            out_cls_logits, out_offsets
        )
        return results

    @torch.no_grad()
    def preprocessing(self, video_list, padding_val=0.0):
        """
        Generate batched features and masks from a list of dict items.

        Each item's 'feats' is a 2-D tensor whose last dim is time.

        In training, every input is padded to ``max_seq_len``. In inference
        (batch size 1), the input is padded to ``max_seq_len``; if it is
        longer, it is padded up to a multiple of ``max_div_factor`` instead.

        Returns:
            ``(batched_inputs, batched_masks)``. The masks are boolean with
            shape (B, 1, T) and are True on valid (non-padded) positions.
        """
        feats = [x['feats'] for x in video_list]
        feats_lens = torch.as_tensor([feat.shape[-1] for feat in feats])
        max_len = feats_lens.max(0).values.item()

        if self.training:
            assert max_len <= self.max_seq_len, "Input length must be smaller than max_seq_len during training"
            # always pad to the fixed training length
            max_len = self.max_seq_len
            batch_shape = [len(feats), feats[0].shape[0], max_len]
            batched_inputs = feats[0].new_full(batch_shape, padding_val)
            for feat, pad_feat in zip(feats, batched_inputs):
                pad_feat[..., :feat.shape[-1]].copy_(feat)
        else:
            assert len(video_list) == 1, "Only support batch_size = 1 during inference"
            if max_len <= self.max_seq_len:
                max_len = self.max_seq_len
            else:
                # round up so every level / attention window divides the length
                stride = self.max_div_factor
                max_len = (max_len + (stride - 1)) // stride * stride
            padding_size = [0, max_len - feats_lens[0]]
            batched_inputs = F.pad(
                feats[0], padding_size, value=padding_val).unsqueeze(0)

        # True on valid positions
        batched_masks = torch.arange(max_len)[None, :] < feats_lens[:, None]

        batched_inputs = batched_inputs.to(self.device)
        batched_masks = batched_masks.unsqueeze(1).to(self.device)

        return batched_inputs, batched_masks

    @torch.no_grad()
    def label_points(self, points, gt_segments, gt_labels):
        """
        Assign classification / regression targets to the points of every video.

        Returns:
            Two lists (one entry per video) of class targets and offset targets.
        """
        # all pyramid levels are labelled jointly
        concat_points = torch.cat(points, dim=0)
        gt_cls, gt_offset = [], []
        for gt_segment, gt_label in zip(gt_segments, gt_labels):
            cls_targets, reg_targets = self.label_points_single_video(
                concat_points, gt_segment, gt_label
            )
            gt_cls.append(cls_targets)
            gt_offset.append(reg_targets)

        return gt_cls, gt_offset

    @torch.no_grad()
    def label_points_single_video(self, concat_points, gt_segment, gt_label):
        """
        Label every point of one video.

        Column usage of ``concat_points``: 0 = point position; 1 / 2 = lower /
        upper bound of the regression range; 3 = per-point scale (presumably
        the level stride), used to normalize offsets and the sampling radius.

        Each point is assigned to the shortest GT segment that (a) contains it,
        or contains it within the center-sampling radius, and (b) has its
        largest offset inside the point's regression range. Ties within 1e-3
        in length contribute all of their classes.

        Returns:
            ``(cls_targets, reg_targets)``. ``cls_targets`` has shape
            (num_pts, num_classes) with values in [0, 1]. ``reg_targets`` has
            shape (num_pts, 2) and holds the left / right offsets divided by
            column 3.
        """
        num_pts = concat_points.shape[0]
        num_gts = gt_segment.shape[0]

        # no ground truth: everything is background
        if num_gts == 0:
            cls_targets = gt_segment.new_full((num_pts, self.num_classes), 0)
            reg_targets = gt_segment.new_zeros((num_pts, 2))
            return cls_targets, reg_targets

        # segment lengths, broadcast to (num_pts, num_gts)
        lens = gt_segment[:, 1] - gt_segment[:, 0]
        lens = lens[None, :].repeat(num_pts, 1)

        # distance from each point to the left / right boundary of each segment
        gt_segs = gt_segment[None].expand(num_pts, num_gts, 2)
        left = concat_points[:, 0, None] - gt_segs[:, :, 0]
        right = gt_segs[:, :, 1] - concat_points[:, 0, None]
        reg_targets = torch.stack((left, right), dim=-1)

        if self.train_center_sample == 'radius':
            # only points near the segment center (radius scaled by column 3),
            # clipped to the segment, count as inside
            center_pts = 0.5 * (gt_segs[:, :, 0] + gt_segs[:, :, 1])
            t_mins = \
                center_pts - concat_points[:, 3, None] * self.train_center_sample_radius
            t_maxs = \
                center_pts + concat_points[:, 3, None] * self.train_center_sample_radius
            cb_dist_left = concat_points[:, 0, None] \
                - torch.maximum(t_mins, gt_segs[:, :, 0])
            cb_dist_right = torch.minimum(t_maxs, gt_segs[:, :, 1]) \
                - concat_points[:, 0, None]
            center_seg = torch.stack(
                (cb_dist_left, cb_dist_right), -1)
            inside_gt_seg_mask = center_seg.min(-1)[0] > 0
        else:
            # any point strictly inside the segment
            inside_gt_seg_mask = reg_targets.min(-1)[0] > 0

        # the largest offset must fall inside the point's regression range
        max_regress_distance = reg_targets.max(-1)[0]
        inside_regress_range = torch.logical_and(
            (max_regress_distance >= concat_points[:, 1, None]),
            (max_regress_distance <= concat_points[:, 2, None])
        )

        # Invalid (point, segment) pairs get a large finite sentinel length.
        # float('inf') is avoided because masked_fill_ raises "value cannot be
        # converted to type int64_t without overflow" when `lens` is int64.
        max_int = 1000000
        lens.masked_fill_(inside_gt_seg_mask == 0, max_int)
        lens.masked_fill_(inside_regress_range == 0, max_int)
        # shortest valid segment for each point
        min_len, min_len_inds = lens.min(dim=1)

        # all valid segments within 1e-3 of the minimum length
        min_len_mask = torch.logical_and(
            (lens <= (min_len[:, None] + 1e-3)), (lens < max_int)
        ).to(reg_targets.dtype)

        # union of the classes of the selected segments
        gt_label_one_hot = F.one_hot(
            gt_label, self.num_classes
        ).to(reg_targets.dtype)
        cls_targets = min_len_mask @ gt_label_one_hot
        cls_targets.clamp_(min=0.0, max=1.0)

        # offsets to the single shortest segment, normalized by column 3
        reg_targets = reg_targets[range(num_pts), min_len_inds]
        reg_targets /= concat_points[:, 3, None]

        return cls_targets, reg_targets

    def losses(
        self, fpn_masks,
        out_cls_logits, out_offsets,
        gt_cls_labels, gt_offsets
    ):
        """
        Focal classification loss plus center-DIoU regression loss.

        Both losses are divided by an EMA of the number of positive points.
        ``final_loss = cls_loss + loss_weight * reg_loss``. When
        ``train_loss_weight <= 0``, ``loss_weight`` is chosen adaptively as
        ``cls_loss / max(reg_loss, 0.01)``.
        """
        # valid (non-padded) points across all levels
        valid_mask = torch.cat(fpn_masks, dim=1)

        # positives: valid points with at least one class target
        gt_cls = torch.stack(gt_cls_labels)
        pos_mask = torch.logical_and((gt_cls.sum(-1) > 0), valid_mask)

        pred_offsets = torch.cat(out_offsets, dim=1)[pos_mask]
        gt_offsets = torch.stack(gt_offsets)[pos_mask]

        # update the EMA loss normalizer
        num_pos = pos_mask.sum().item()
        self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + (
            1 - self.loss_normalizer_momentum
        ) * max(num_pos, 1)

        # classification targets with label smoothing; boolean indexing makes
        # a copy, so the in-place updates do not touch gt_cls
        gt_target = gt_cls[valid_mask]
        gt_target *= 1 - self.train_label_smoothing
        gt_target += self.train_label_smoothing / (self.num_classes + 1)

        cls_loss = sigmoid_focal_loss(
            torch.cat(out_cls_logits, dim=1)[valid_mask],
            gt_target,
            reduction='sum'
        )
        cls_loss /= self.loss_normalizer

        # regression loss only on positives (zero, but graph-connected, if none)
        if num_pos == 0:
            reg_loss = 0 * pred_offsets.sum()
        else:
            reg_loss = ctr_diou_loss_1d(
                pred_offsets,
                gt_offsets,
                reduction='sum'
            )
            reg_loss /= self.loss_normalizer

        if self.train_loss_weight > 0:
            loss_weight = self.train_loss_weight
        else:
            loss_weight = cls_loss.detach() / max(reg_loss.item(), 0.01)

        final_loss = cls_loss + reg_loss * loss_weight
        return {'cls_loss' : cls_loss,
                'reg_loss' : reg_loss,
                'final_loss' : final_loss}

    @torch.no_grad()
    def inference(
        self,
        video_list,
        points, fpn_masks,
        out_cls_logits, out_offsets
    ):
        """
        Decode detections for every video, then run ``postprocessing``.

        Returns:
            The list produced by ``postprocessing``, one dict per video.
        """
        results = []

        # per-video metadata copied into the results
        vid_idxs = [x['video_id'] for x in video_list]
        durations = [x['duration'] for x in video_list]
        gt_times = [x['gt_time'] for x in video_list]
        vid_ft_nframes = [x['feat_num_frames'] for x in video_list]

        for idx, (vidx, nframes, duration, gt_time) in enumerate(
            zip(vid_idxs, vid_ft_nframes, durations, gt_times)
        ):
            # slice this video's outputs from every level
            cls_logits_per_vid = [x[idx] for x in out_cls_logits]
            offsets_per_vid = [x[idx] for x in out_offsets]
            fpn_masks_per_vid = [x[idx] for x in fpn_masks]

            results_per_vid = self.inference_single_video(
                points, fpn_masks_per_vid,
                cls_logits_per_vid, offsets_per_vid
            )

            results_per_vid['video_id'] = vidx
            results_per_vid['duration'] = duration
            results_per_vid['gt_time'] = gt_time
            results_per_vid['feat_num_frames'] = nframes
            results.append(results_per_vid)

        results = self.postprocessing(results)

        return results

    @torch.no_grad()
    def inference_single_video(
        self,
        points,
        fpn_masks,
        out_cls_logits,
        out_offsets,
    ):
        """
        Decode candidate segments for a single video across all levels.

        Per level, the steps are:
        1. Keep (point, class) scores above ``test_pre_nms_thresh``.
        2. Take the top ``test_pre_nms_topk`` of them.
        3. Decode segments as ``position -/+ offset * points[:, 3]``.
        4. Drop segments no longer than ``test_duration_thresh``.

        Returns:
            dict with 'segments' (N, 2), 'scores' (N,) and 'labels' (N,).
        """
        segs_all = []
        scores_all = []
        cls_idxs_all = []

        for cls_i, offsets_i, pts_i, mask_i in zip(
            out_cls_logits, out_offsets, points, fpn_masks
        ):
            # probabilities of every (point, class) pair, zeroed on padding
            pred_prob = (cls_i.sigmoid() * mask_i.unsqueeze(-1)).flatten()

            # score threshold
            keep_idxs1 = (pred_prob > self.test_pre_nms_thresh)
            pred_prob = pred_prob[keep_idxs1]
            topk_idxs = keep_idxs1.nonzero(as_tuple=True)[0]

            # top-k by score
            num_topk = min(self.test_pre_nms_topk, topk_idxs.size(0))
            pred_prob, idxs = pred_prob.sort(descending=True)
            pred_prob = pred_prob[:num_topk].clone()
            topk_idxs = topk_idxs[idxs[:num_topk]].clone()

            # recover point and class indices from the flattened index
            pt_idxs = torch.div(
                topk_idxs, self.num_classes, rounding_mode='floor'
            )
            cls_idxs = torch.fmod(topk_idxs, self.num_classes)

            offsets = offsets_i[pt_idxs]
            pts = pts_i[pt_idxs]

            # decode segment boundaries
            seg_left = pts[:, 0] - offsets[:, 0] * pts[:, 3]
            seg_right = pts[:, 0] + offsets[:, 1] * pts[:, 3]
            pred_segs = torch.stack((seg_left, seg_right), -1)

            # drop too-short segments
            seg_areas = seg_right - seg_left
            keep_idxs2 = seg_areas > self.test_duration_thresh

            segs_all.append(pred_segs[keep_idxs2])
            scores_all.append(pred_prob[keep_idxs2])
            cls_idxs_all.append(cls_idxs[keep_idxs2])

        segs_all, scores_all, cls_idxs_all = [
            torch.cat(x) for x in [segs_all, scores_all, cls_idxs_all]
        ]
        results = {'segments' : segs_all,
                   'scores' : scores_all,
                   'labels' : cls_idxs_all}

        return results

    @torch.no_grad()
    def postprocessing(self, results):
        """
        Move results to CPU and optionally apply (soft) NMS.

        Segments are kept in the coordinate units produced by
        ``inference_single_video``; no conversion (e.g. to seconds) is
        applied here.
        """
        processed_results = []
        for results_per_vid in results:
            vidx = results_per_vid['video_id']
            gt_time = results_per_vid['gt_time']
            duration = results_per_vid['duration']

            segs = results_per_vid['segments'].detach().cpu()
            scores = results_per_vid['scores'].detach().cpu()
            labels = results_per_vid['labels'].detach().cpu()
            if self.test_nms_method != 'none':
                segs, scores, labels = batched_nms(
                    segs, scores, labels,
                    self.test_iou_threshold,
                    self.test_min_score,
                    self.test_max_seg_num,
                    use_soft_nms = (self.test_nms_method == 'soft'),
                    multiclass = self.test_multiclass_nms,
                    sigma = self.test_nms_sigma,
                    voting_thresh = self.test_voting_thresh
                )

            processed_results.append(
                {'video_id' : vidx,
                 'segments' : segs,
                 'scores' : scores,
                 'labels' : labels,
                 'duration' : duration,
                 'gt_time' : gt_time,
                }
            )

        return processed_results
|
|