# Source: forensics-grpo/code/libs/modeling/meta_archs.py
# (commit 33569f9, "Add source code", by sdzt; 34.9 kB)
import math
import torch
from torch import nn
from torch.nn import functional as F
from .models import register_meta_arch, make_backbone, make_neck, make_generator
from .blocks import MaskedConv1D, Scale, LayerNorm
from .losses import ctr_diou_loss_1d, sigmoid_focal_loss
from ..utils import batched_nms
from IPython import embed
import torch
import torch.nn as nn
class TimeEmbedding(nn.Module):
    """Embed (integer) diffusion timesteps with sinusoids followed by an MLP.

    Args:
        dim: size of the produced embedding (also the MLP width in/out).
        max_period: sets the slowest sinusoid; frequencies are
            ``max_period ** (-k / (dim // 2))`` for ``k = 0 .. dim // 2 - 1``.

    A timestep tensor of shape ``[...]`` is mapped to ``[..., dim]``.
    """

    def __init__(self, dim, max_period=10000):
        super().__init__()
        self.dim = dim
        self.max_period = max_period
        hidden = dim * 2
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, t: torch.Tensor):
        n_freq = self.dim // 2
        # log-spaced frequencies, starting at 1 and decreasing towards 1 / max_period
        log_scaled = -math.log(self.max_period) * torch.arange(0, n_freq, device=t.device)
        freqs = torch.exp(log_scaled / n_freq)
        phase = t.float()[..., None] * freqs  # [..., n_freq]
        emb = torch.cat((phase.cos(), phase.sin()), dim=-1)
        if self.dim % 2:
            # odd dim: append a single zero channel so emb has exactly `dim` features
            emb = torch.cat((emb, emb[..., :1] * 0), dim=-1)
        return self.mlp(emb)  # [..., dim]
class Denoiser1D(nn.Module):
    """Residual 1D conv block conditioned on a time embedding (FiLM style).

    The input is projected with a 1x1 conv, group-normalised, modulated by a
    per-channel scale/shift predicted from ``temb``, passed through a
    depthwise (k=3) and a pointwise conv with GELU, projected again and added
    back to the input.

    Args:
        channels: number of feature channels C (must be divisible by 8 for
            the GroupNorm with 8 groups).
        temb_dim: size of the time embedding.
    """

    def __init__(self, channels=128, temb_dim=128):
        super().__init__()
        self.in_conv = nn.Conv1d(channels, channels, 1)
        self.dw = nn.Conv1d(channels, channels, kernel_size=3, padding=1, groups=channels)
        self.pw = nn.Conv1d(channels, channels, 1)
        self.gn = nn.GroupNorm(8, channels)
        self.to_scale = nn.Linear(temb_dim, channels)
        self.to_shift = nn.Linear(temb_dim, channels)
        self.out_conv = nn.Conv1d(channels, channels, 1)

    def forward(self, x, temb):
        # x: [B, C, L]; temb: [B, temb_dim] -> scale / shift: [B, C, 1]
        scale = self.to_scale(temb)[..., None]
        shift = self.to_shift(temb)[..., None]
        h = self.gn(self.in_conv(x))
        h = h * (1 + scale) + shift
        for conv in (self.dw, self.pw):
            h = F.gelu(conv(h))
        # residual connection around the whole block
        return x + self.out_conv(h)
class NDRefiner(nn.Module):
    """
    Diffusion-style refiner for a 1D feature map.

    The input is forward-noised with a linear-beta schedule (``q_sample``)
    and then passed through ``steps`` DDIM-style reverse updates, each one
    predicting the noise with a ``Denoiser1D`` conditioned on a
    ``TimeEmbedding`` of the current step.

    Args:
        channels: number of feature channels C (also the time-embedding size).
        steps: number of diffusion steps in the schedule.
        beta_start, beta_end: endpoints of the linear beta schedule.
        eta: stochasticity of the reverse update; 0 makes it deterministic
            (no extra noise is drawn). For eta > 1 the direction coefficient
            ``1 - a_bar_prev - sigma_t**2`` can go negative; it is clamped at 0
            instead of producing NaN.
    """
    def __init__(self, channels=128, steps=4, beta_start=1e-4, beta_end=0.02, eta=0.0):
        super().__init__()
        self.channels = channels
        self.steps = steps
        self.eta = eta
        betas = torch.linspace(beta_start, beta_end, steps)
        alphas = 1.0 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alpha_bars', alpha_bars)
        self.temb = TimeEmbedding(channels)
        self.denoiser = Denoiser1D(channels, temb_dim=channels)

    def q_sample(self, x0, t_idx, noise=None):
        """
        Forward-noise ``x0`` to the per-sample levels ``t_idx``.

        Returns ``(sqrt(a_bar) * x0 + sqrt(1 - a_bar) * noise, noise)`` where
        ``a_bar = alpha_bars[t_idx]`` is viewed as [B, 1, 1] (so ``x0`` is
        expected to be [B, C, L]). ``noise`` is drawn from a standard normal
        when not given.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        a_bar = self.alpha_bars[t_idx].view(-1, 1, 1)
        return a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise, noise

    def forward(self, x):
        """
        Noise ``x`` and run the full reverse chain from step ``steps - 1`` to 0.

        Args:
            x: feature map of shape [B, C, L].

        Returns:
            Refined tensor of shape [B, C, L].

        Notes:
            Noise is sampled in both train and eval mode, so eval outputs are
            stochastic unless the RNG is seeded.
            NOTE(review): in training each sample is noised to a random level
            ``t_idx[b]`` while the reverse loop always starts at ``steps - 1``
            for the whole batch; confirm this mismatch is intended.
            NOTE(review): no padding mask is applied, so padded positions
            can become non-zero after refinement.
        """
        B, _, _ = x.shape  # also enforces a 3-D input
        device = x.device
        if self.training:
            t_idx = torch.randint(0, self.steps, (B,), device=device)
        else:
            t_idx = torch.full((B,), self.steps - 1, device=device)
        x_t, _ = self.q_sample(x, t_idx)
        # a_bar before the first step is 1; build it from the buffer so its
        # dtype/device follow the module (a float32 literal would upcast
        # half-precision features on the final step).
        a_bar_init = self.alpha_bars.new_ones(())
        for idx in reversed(range(self.steps)):
            t = torch.full((B,), idx, device=device)
            t_emb = self.temb(t)
            eps_pred = self.denoiser(x_t, t_emb)  # [B, C, L]
            a_bar_t = self.alpha_bars[idx]
            a_t = self.alphas[idx]
            a_bar_prev = self.alpha_bars[idx - 1] if idx > 0 else a_bar_init
            a_bar_prev = a_bar_prev.view(1, 1, 1)
            # predicted clean signal, bounded to [-3, 3]
            x0_pred = (x_t - (1 - a_bar_t).sqrt() * eps_pred) / a_bar_t.sqrt()
            x0_pred = x0_pred.clamp(-3.0, 3.0)
            sigma_t = self.eta * ((1 - a_bar_prev) / (1 - a_bar_t) * (1 - a_t)).sqrt()
            noise = torch.randn_like(x_t) if sigma_t > 0 else 0.0
            # non-negative for eta <= 1; clamp guards against NaN otherwise
            dir_coef = (1 - a_bar_prev - sigma_t ** 2).clamp(min=0.0).sqrt()
            x_t = (
                a_bar_prev.sqrt() * x0_pred +
                dir_coef * eps_pred +
                sigma_t * noise
            )
        return x_t
class PtTransformerClsHead(nn.Module):
    """
    1D Conv heads for classification, shared across FPN levels.

    ``num_layers - 1`` masked convs (each followed by an optional LayerNorm
    and the activation) feed a final masked conv that outputs
    ``num_classes`` logits per position.

    Args:
        input_dim: channels of the incoming FPN features.
        feat_dim: hidden channels of the head.
        num_classes: number of output logits per position.
        prior_prob: initial foreground probability used to set the
            classifier bias to ``-log((1 - p) / p)``.
        num_layers: total layers including the classifier.
            NOTE(review): with num_layers == 1 the classifier expects
            ``feat_dim`` inputs, so input_dim must equal feat_dim.
        kernel_size: kernel size of every conv (padding keeps length).
        act_layer: activation class, instantiated once.
        with_ln: add LayerNorm after each hidden conv (convs then have no bias).
        empty_cls: optional iterable of class indices whose bias is set to a
            near-zero prior (p = 1e-6). ``None`` means no such classes.
    """
    def __init__(
        self,
        input_dim,
        feat_dim,
        num_classes,
        prior_prob=0.01,
        num_layers=3,
        kernel_size=3,
        act_layer=nn.ReLU,
        with_ln=False,
        empty_cls=None
    ):
        super().__init__()
        self.act = act_layer()

        # build the head
        self.head = nn.ModuleList()
        self.norm = nn.ModuleList()
        for idx in range(num_layers - 1):
            in_dim = input_dim if idx == 0 else feat_dim
            self.head.append(
                MaskedConv1D(
                    in_dim, feat_dim, kernel_size,
                    stride=1,
                    padding=kernel_size // 2,
                    bias=(not with_ln)
                )
            )
            self.norm.append(LayerNorm(feat_dim) if with_ln else nn.Identity())

        # classifier
        self.cls_head = MaskedConv1D(
            feat_dim, num_classes, kernel_size,
            stride=1, padding=kernel_size // 2
        )

        # focal-loss style prior: sigmoid(bias) == prior_prob at init
        bias_value = -(math.log((1 - prior_prob) / prior_prob))
        torch.nn.init.constant_(self.cls_head.conv.bias, bias_value)

        # classes known to be absent get an (almost) zero prior
        if empty_cls is not None and len(empty_cls) > 0:
            empty_bias = -(math.log((1 - 1e-6) / 1e-6))
            for idx in empty_cls:
                torch.nn.init.constant_(self.cls_head.conv.bias[idx], empty_bias)

    def forward(self, fpn_feats, fpn_masks):
        """
        Apply the head to every pyramid level.

        Args:
            fpn_feats: sequence of per-level feature maps.
            fpn_masks: sequence of per-level masks (same length).

        Returns:
            tuple of per-level logits from ``cls_head``.
        """
        assert len(fpn_feats) == len(fpn_masks)

        out_logits = tuple()
        for cur_feat, cur_mask in zip(fpn_feats, fpn_masks):
            cur_out = cur_feat
            for conv, norm in zip(self.head, self.norm):
                cur_out, _ = conv(cur_out, cur_mask)
                cur_out = self.act(norm(cur_out))
            cur_logits, _ = self.cls_head(cur_out, cur_mask)
            out_logits += (cur_logits, )
        return out_logits
class PtTransformerRegHead(nn.Module):
    """
    Shared 1D Conv heads for regression.

    Mirrors ``PtTransformerClsHead`` but ends in a 2-channel offset conv and
    multiplies each level's output by its own learnable ``Scale`` before a
    ReLU, so offsets are non-negative.

    Args:
        input_dim: channels of the incoming FPN features.
        feat_dim: hidden channels of the head.
        fpn_levels: number of pyramid levels (one ``Scale`` per level).
        num_layers: total layers including the offset conv.
        kernel_size: kernel size of every conv (padding keeps length).
        act_layer: activation class, instantiated once.
        with_ln: add LayerNorm after each hidden conv (convs then have no bias).
    """
    def __init__(
        self,
        input_dim,
        feat_dim,
        fpn_levels,
        num_layers=3,
        kernel_size=3,
        act_layer=nn.ReLU,
        with_ln=False
    ):
        super().__init__()
        self.fpn_levels = fpn_levels
        self.act = act_layer()

        pad = kernel_size // 2
        self.head = nn.ModuleList()
        self.norm = nn.ModuleList()
        for layer_idx in range(num_layers - 1):
            layer_in = feat_dim if layer_idx else input_dim
            self.head.append(
                MaskedConv1D(
                    layer_in, feat_dim, kernel_size,
                    stride=1, padding=pad, bias=not with_ln
                )
            )
            self.norm.append(LayerNorm(feat_dim) if with_ln else nn.Identity())

        # one learnable multiplier per pyramid level
        self.scale = nn.ModuleList([Scale() for _ in range(fpn_levels)])

        # segment regression: (left, right) offsets
        self.offset_head = MaskedConv1D(
            feat_dim, 2, kernel_size,
            stride=1, padding=pad
        )

    def forward(self, fpn_feats, fpn_masks):
        """Return a tuple of non-negative per-level offset maps."""
        assert len(fpn_feats) == len(fpn_masks)
        assert len(fpn_feats) == self.fpn_levels

        offsets_per_level = []
        for level, (feat, mask) in enumerate(zip(fpn_feats, fpn_masks)):
            h = feat
            for conv, norm in zip(self.head, self.norm):
                h, _ = conv(h, mask)
                h = self.act(norm(h))
            raw_offsets, _ = self.offset_head(h, mask)
            offsets_per_level.append(F.relu(self.scale[level](raw_offsets)))
        # fpn_masks remain unchanged
        return tuple(offsets_per_level)
@register_meta_arch("LocPointTransformer")
class PtTransformer(nn.Module):
    """
    Transformer based model for single stage action localization.

    ``forward`` runs backbone -> neck (FPN) -> one ``NDRefiner`` per FPN
    level -> shared classification / regression heads. In training mode it
    returns a dict of losses; in eval mode it returns a list of per-video
    result dicts (segments stay in feature-grid units).
    """
    def __init__(
        self,
        backbone_type,         # a string defines which backbone we use
        fpn_type,              # a string defines which fpn we use
        backbone_arch,         # a tuple defines # layers in embed / stem / branch
        scale_factor,          # scale factor between branch layers
        input_dim,             # input feat dim
        max_seq_len,           # max sequence length (used for training)
        max_buffer_len_factor, # max buffer size (defined a factor of max_seq_len)
        n_head,                # number of heads for self-attention in transformer
        n_mha_win_size,        # window size for self attention; -1 to use full seq
        embd_kernel_size,      # kernel size of the embedding network
        embd_dim,              # output feat channel of the embedding network
        embd_with_ln,          # attach layernorm to embedding network
        fpn_dim,               # feature dim on FPN
        fpn_with_ln,           # if to apply layer norm at the end of fpn
        fpn_start_level,       # start level of fpn
        head_dim,              # feature dim for head
        regression_range,      # regression range on each level of FPN
        head_num_layers,       # number of layers in the head (including the classifier)
        head_kernel_size,      # kernel size for reg/cls heads
        head_with_ln,          # attach layernorm to reg/cls heads
        use_abs_pe,            # if to use abs position encoding
        use_rel_pe,            # if to use rel position encoding
        num_classes,           # number of action classes
        train_cfg,             # other cfg for training
        test_cfg               # other cfg for testing
    ):
        super().__init__()
        # re-distribute params to backbone / neck / head
        self.fpn_strides = [scale_factor**i for i in range(
            fpn_start_level, backbone_arch[-1] + 1
        )]
        self.reg_range = regression_range
        assert len(self.fpn_strides) == len(self.reg_range)
        self.scale_factor = scale_factor
        # The cls head predicts ``num_classes`` independent sigmoid scores;
        # a point whose target row is all zeros is treated as background.
        self.num_classes = num_classes

        # check the feature pyramid and local attention window size
        self.max_seq_len = max_seq_len
        if isinstance(n_mha_win_size, int):
            self.mha_win_size = [n_mha_win_size] * (1 + backbone_arch[-1])
        else:
            assert len(n_mha_win_size) == (1 + backbone_arch[-1])
            self.mha_win_size = n_mha_win_size
        max_div_factor = 1
        # NOTE(review): fpn_strides starts at fpn_start_level while
        # mha_win_size starts at level 0, so the pairs below are offset when
        # fpn_start_level > 0 — confirm this is intended.
        for s, w in zip(self.fpn_strides, self.mha_win_size):
            stride = s * (w // 2) * 2 if w > 1 else s
            assert max_seq_len % stride == 0, "max_seq_len must be divisible by fpn stride and window size"
            max_div_factor = max(max_div_factor, stride)
        self.max_div_factor = max_div_factor

        # training time config
        self.train_center_sample = train_cfg['center_sample']
        assert self.train_center_sample in ['radius', 'none']
        self.train_center_sample_radius = train_cfg['center_sample_radius']
        self.train_loss_weight = train_cfg['loss_weight']
        self.train_cls_prior_prob = train_cfg['cls_prior_prob']
        self.train_dropout = train_cfg['dropout']
        self.train_droppath = train_cfg['droppath']
        self.train_label_smoothing = train_cfg['label_smoothing']

        # test time config
        self.test_pre_nms_thresh = test_cfg['pre_nms_thresh']
        self.test_pre_nms_topk = test_cfg['pre_nms_topk']
        self.test_iou_threshold = test_cfg['iou_threshold']
        self.test_min_score = test_cfg['min_score']
        self.test_max_seg_num = test_cfg['max_seg_num']
        self.test_nms_method = test_cfg['nms_method']
        assert self.test_nms_method in ['soft', 'hard', 'none']
        self.test_duration_thresh = test_cfg['duration_thresh']
        self.test_multiclass_nms = test_cfg['multiclass_nms']
        self.test_nms_sigma = test_cfg['nms_sigma']
        self.test_voting_thresh = test_cfg['voting_thresh']

        # backbone network: conv + transformer
        assert backbone_type in ['convTransformer', 'conv']
        if backbone_type == 'convTransformer':
            self.backbone = make_backbone(
                'convTransformer',
                **{
                    'n_in' : input_dim,
                    'n_embd' : embd_dim,
                    'n_head': n_head,
                    'n_embd_ks': embd_kernel_size,
                    'max_len': max_seq_len,
                    'arch' : backbone_arch,
                    'mha_win_size': self.mha_win_size,
                    'scale_factor' : scale_factor,
                    'with_ln' : embd_with_ln,
                    'attn_pdrop' : 0.0,
                    'proj_pdrop' : self.train_dropout,
                    'path_pdrop' : self.train_droppath,
                    'use_abs_pe' : use_abs_pe,
                    'use_rel_pe' : use_rel_pe
                }
            )
        else:
            self.backbone = make_backbone(
                'conv',
                **{
                    'n_in': input_dim,
                    'n_embd': embd_dim,
                    'n_embd_ks': embd_kernel_size,
                    'arch': backbone_arch,
                    'scale_factor': scale_factor,
                    'with_ln' : embd_with_ln
                }
            )

        # fpn network: convs
        assert fpn_type in ['fpn', 'identity']
        self.neck = make_neck(
            fpn_type,
            **{
                'in_channels' : [embd_dim] * (backbone_arch[-1] + 1),
                'out_channel' : fpn_dim,
                'scale_factor' : scale_factor,
                'start_level' : fpn_start_level,
                'with_ln' : fpn_with_ln
            }
        )

        # location generator: points
        self.point_generator = make_generator(
            'point',
            **{
                'max_seq_len' : max_seq_len * max_buffer_len_factor,
                'fpn_strides' : self.fpn_strides,
                'regression_range' : self.reg_range
            }
        )

        # classification and regression heads
        self.cls_head = PtTransformerClsHead(
            fpn_dim, head_dim, self.num_classes,
            kernel_size=head_kernel_size,
            prior_prob=self.train_cls_prior_prob,
            with_ln=head_with_ln,
            num_layers=head_num_layers,
            empty_cls=train_cfg['head_empty_cls']
        )
        self.reg_head = PtTransformerRegHead(
            fpn_dim, head_dim, len(self.fpn_strides),
            kernel_size=head_kernel_size,
            num_layers=head_num_layers,
            with_ln=head_with_ln
        )

        # maintain an EMA of #foreground to stabilize the loss normalizer
        # useful for small mini-batch training
        self.loss_normalizer = train_cfg['init_loss_norm']
        self.loss_normalizer_momentum = 0.9

        # one diffusion refiner per FPN level (indexed by level in forward).
        # Previously hard-coded to 6, which only matched 6-level pyramids and
        # raised IndexError for deeper ones.
        self.nd_blocks = nn.ModuleList([
            NDRefiner(
                channels=fpn_dim,
                steps=train_cfg['denoise_steps'],
                beta_start=train_cfg['beta_start_end'][0],
                beta_end=train_cfg['beta_start_end'][1],
                eta=train_cfg['eta'],
            )
            for _ in range(len(self.fpn_strides))
        ])

    @property
    def device(self):
        # device of the model parameters (assumes a single-device model)
        return list(set(p.device for p in self.parameters()))[0]

    def forward(self, video_list):
        """
        Run the model on a list of video dicts.

        Each dict must provide 'feats' (C x T); training additionally needs
        'segments' and 'labels'; inference needs 'video_id', 'duration',
        'gt_time' and 'feat_num_frames'.

        Returns:
            training: dict with 'cls_loss', 'reg_loss', 'final_loss'.
            eval: list of per-video result dicts from ``postprocessing``.
        """
        # batch the video list into feats (B, C, T) and masks (B, 1, T)
        batched_inputs, batched_masks = self.preprocessing(video_list)

        # forward the network (backbone -> neck -> refiners -> heads)
        feats, masks = self.backbone(batched_inputs, batched_masks)
        fpn_feats, fpn_masks = self.neck(feats, masks)
        # refine each pyramid level with its own NDRefiner
        # NOTE(review): the refiners do not see fpn_masks
        fpn_feats = [self.nd_blocks[i](feat) for i, feat in enumerate(fpn_feats)]

        # compute the point coordinate along the FPN; used for computing the
        # GT or decoding the final results.
        # points: List[T x 4] with length = # fpn levels
        # (shared across all samples in the mini-batch)
        points = self.point_generator(fpn_feats)

        # out_cls: List[B, #cls, T_i]
        out_cls_logits = self.cls_head(fpn_feats, fpn_masks)
        # out_offset: List[B, 2, T_i]
        out_offsets = self.reg_head(fpn_feats, fpn_masks)

        # permute the outputs
        # out_cls: F List[B, #cls, T_i] -> F List[B, T_i, #cls]
        out_cls_logits = [x.permute(0, 2, 1) for x in out_cls_logits]
        # out_offset: F List[B, 2 (xC), T_i] -> F List[B, T_i, 2 (xC)]
        out_offsets = [x.permute(0, 2, 1) for x in out_offsets]
        # fpn_masks: F list[B, 1, T_i] -> F List[B, T_i]
        fpn_masks = [x.squeeze(1) for x in fpn_masks]

        if self.training:
            # generate segment/label List[N x 2] / List[N] with length = B
            assert video_list[0]['segments'] is not None, "GT action segments do not exist"
            assert video_list[0]['labels'] is not None, "GT action labels does not exist"
            device = self.device
            gt_segments = [x['segments'].to(device) for x in video_list]
            gt_labels = [x['labels'].to(device) for x in video_list]

            # compute the gt labels for cls & reg (list of prediction targets)
            gt_cls_labels, gt_offsets = self.label_points(
                points, gt_segments, gt_labels)

            # compute the loss and return
            return self.losses(
                fpn_masks,
                out_cls_logits, out_offsets,
                gt_cls_labels, gt_offsets
            )

        # decode the actions (sigmoid / stride, etc)
        return self.inference(
            video_list, points, fpn_masks,
            out_cls_logits, out_offsets
        )

    @torch.no_grad()
    def preprocessing(self, video_list, padding_val=0.0):
        """
        Generate batched features and masks from a list of dict items.

        Training: every video is padded to ``max_seq_len`` (longer inputs
        fail the assert). Inference: batch size must be 1; the video is padded
        to ``max_seq_len`` or, if longer, up to the next multiple of
        ``max_div_factor``.

        Returns:
            batched_inputs: (B, C, T) on ``self.device``.
            batched_masks: (B, 1, T) boolean validity mask on ``self.device``.
        """
        feats = [x['feats'] for x in video_list]
        feats_lens = torch.as_tensor([feat.shape[-1] for feat in feats])
        max_len = feats_lens.max(0).values.item()

        if self.training:
            assert max_len <= self.max_seq_len, "Input length must be smaller than max_seq_len during training"
            # set max_len to self.max_seq_len
            max_len = self.max_seq_len
            # batch input shape B, C, T
            batch_shape = [len(feats), feats[0].shape[0], max_len]
            batched_inputs = feats[0].new_full(batch_shape, padding_val)
            for feat, pad_feat in zip(feats, batched_inputs):
                pad_feat[..., :feat.shape[-1]].copy_(feat)
        else:
            assert len(video_list) == 1, "Only support batch_size = 1 during inference"
            if max_len <= self.max_seq_len:
                # input length < self.max_seq_len, pad to max_seq_len
                max_len = self.max_seq_len
            else:
                # pad the input to the next divisible size
                stride = self.max_div_factor
                max_len = (max_len + (stride - 1)) // stride * stride
            padding_size = [0, max_len - int(feats_lens[0])]
            batched_inputs = F.pad(
                feats[0], padding_size, value=padding_val).unsqueeze(0)

        # generate the mask
        batched_masks = torch.arange(max_len)[None, :] < feats_lens[:, None]

        # push to device
        device = self.device
        batched_inputs = batched_inputs.to(device)
        batched_masks = batched_masks.unsqueeze(1).to(device)

        return batched_inputs, batched_masks

    @torch.no_grad()
    def label_points(self, points, gt_segments, gt_labels):
        """
        Compute per-video targets for all points of all FPN levels.

        Returns two B-length lists: cls targets [FT, num_classes] and
        regression targets [FT, 2].
        """
        # concat points on all fpn levels List[T x 4] -> F T x 4
        # This is shared for all samples in the mini-batch
        concat_points = torch.cat(points, dim=0)
        gt_cls, gt_offset = [], []

        # loop over each video sample
        for gt_segment, gt_label in zip(gt_segments, gt_labels):
            cls_targets, reg_targets = self.label_points_single_video(
                concat_points, gt_segment, gt_label
            )
            # append to list (len = # videos, each of size FT x C)
            gt_cls.append(cls_targets)
            gt_offset.append(reg_targets)

        return gt_cls, gt_offset

    @torch.no_grad()
    def label_points_single_video(self, concat_points, gt_segment, gt_label):
        """
        Assign classification / regression targets to every point of one video.

        Args:
            concat_points: [FT, 4]; column 0 is the point location, columns
                1-2 the regression range, column 3 the stride.
            gt_segment: [N, 2] (start, end) of each ground-truth action.
            gt_label: [N] integer class index of each action.

        Returns:
            cls_targets: [FT, num_classes] multi-hot targets (all zeros = background).
            reg_targets: [FT, 2] (left, right) distances divided by the stride.
        """
        num_pts = concat_points.shape[0]
        num_gts = gt_segment.shape[0]

        # corner case where current sample does not have actions
        if num_gts == 0:
            # Use the dtype the non-empty path produces (concat_points minus
            # gt_segment). With integer segments, gt_segment-typed zeros would
            # be int64 and break the in-place label smoothing in `losses`.
            target_dtype = torch.promote_types(concat_points.dtype, gt_segment.dtype)
            cls_targets = gt_segment.new_zeros((num_pts, self.num_classes), dtype=target_dtype)
            reg_targets = gt_segment.new_zeros((num_pts, 2), dtype=target_dtype)
            return cls_targets, reg_targets

        # compute the lengths of all segments -> F T x N
        lens = gt_segment[:, 1] - gt_segment[:, 0]
        lens = lens[None, :].repeat(num_pts, 1)

        # compute the distance of every point to each segment boundary
        # auto broadcasting for all reg target -> F T x N x 2
        gt_segs = gt_segment[None].expand(num_pts, num_gts, 2)
        left = concat_points[:, 0, None] - gt_segs[:, :, 0]
        right = gt_segs[:, :, 1] - concat_points[:, 0, None]
        reg_targets = torch.stack((left, right), dim=-1)

        if self.train_center_sample == 'radius':
            # center of all segments F T x N
            center_pts = 0.5 * (gt_segs[:, :, 0] + gt_segs[:, :, 1])
            # center sampling based on stride radius
            # concat_points[:, 3] stores the stride
            t_mins = \
                center_pts - concat_points[:, 3, None] * self.train_center_sample_radius
            t_maxs = \
                center_pts + concat_points[:, 3, None] * self.train_center_sample_radius
            # prevent t_mins / maxs from over-running the action boundary
            # F T x N (distance to the new boundary)
            cb_dist_left = concat_points[:, 0, None] \
                - torch.maximum(t_mins, gt_segs[:, :, 0])
            cb_dist_right = torch.minimum(t_maxs, gt_segs[:, :, 1]) \
                - concat_points[:, 0, None]
            # F T x N x 2
            center_seg = torch.stack(
                (cb_dist_left, cb_dist_right), -1)
            # F T x N
            inside_gt_seg_mask = center_seg.min(-1)[0] > 0
        else:
            # inside an gt action
            inside_gt_seg_mask = reg_targets.min(-1)[0] > 0

        # limit the regression range for each location
        max_regress_distance = reg_targets.max(-1)[0]
        # F T x N
        inside_regress_range = torch.logical_and(
            (max_regress_distance >= concat_points[:, 1, None]),
            (max_regress_distance <= concat_points[:, 2, None])
        )

        # if there are still more than one actions for one moment
        # pick the one with the shortest duration (easiest to regress).
        # A large finite sentinel is used instead of inf: segments may be an
        # integer dtype and masked_fill_ cannot write inf into int64.
        max_int = 1000000
        lens.masked_fill_(inside_gt_seg_mask == 0, max_int)
        lens.masked_fill_(inside_regress_range == 0, max_int)
        # F T x N -> F T
        min_len, min_len_inds = lens.min(dim=1)

        # corner case: multiple actions with very similar durations (e.g., THUMOS14)
        min_len_mask = torch.logical_and(
            (lens <= (min_len[:, None] + 1e-3)), (lens < max_int)
        ).to(reg_targets.dtype)

        # cls_targets: F T x C; reg_targets F T x 2
        gt_label_one_hot = F.one_hot(
            gt_label, self.num_classes
        ).to(reg_targets.dtype)
        cls_targets = min_len_mask @ gt_label_one_hot
        # to prevent multiple GT actions with the same label and boundaries
        cls_targets.clamp_(min=0.0, max=1.0)
        # OK to use min_len_inds
        reg_targets = reg_targets[range(num_pts), min_len_inds]
        # normalization based on stride
        reg_targets /= concat_points[:, 3, None]

        return cls_targets, reg_targets

    def losses(
        self, fpn_masks,
        out_cls_logits, out_offsets,
        gt_cls_labels, gt_offsets
    ):
        """
        Compute focal classification loss and DIoU regression loss.

        Args:
            fpn_masks: F-list of [B, T_i] boolean masks of valid positions.
            out_cls_logits: F-list of [B, T_i, num_classes] logits.
            out_offsets: F-list of [B, T_i, 2] predicted offsets.
            gt_cls_labels: B-list of [FT, num_classes] targets.
            gt_offsets: B-list of [FT, 2] targets.

        Returns:
            dict with 'cls_loss', 'reg_loss' and 'final_loss'.

        Side effect: updates the EMA ``self.loss_normalizer``.
        """
        # fpn_masks -> (B, FT)
        valid_mask = torch.cat(fpn_masks, dim=1)

        # 1. classification loss
        # stack the list -> (B, FT) -> (# Valid, )
        gt_cls = torch.stack(gt_cls_labels)
        pos_mask = torch.logical_and((gt_cls.sum(-1) > 0), valid_mask)

        # cat the predicted offsets -> (B, FT, 2 (xC)) -> # (#Pos, 2 (xC))
        pred_offsets = torch.cat(out_offsets, dim=1)[pos_mask]
        gt_offsets = torch.stack(gt_offsets)[pos_mask]

        # update the loss normalizer
        num_pos = pos_mask.sum().item()
        self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + (
            1 - self.loss_normalizer_momentum
        ) * max(num_pos, 1)

        # gt_cls is already one hot encoded; boolean indexing makes a copy,
        # so the in-place smoothing below does not touch gt_cls
        gt_target = gt_cls[valid_mask]

        # optional label smoothing
        gt_target *= 1 - self.train_label_smoothing
        gt_target += self.train_label_smoothing / (self.num_classes + 1)

        # focal loss
        cls_loss = sigmoid_focal_loss(
            torch.cat(out_cls_logits, dim=1)[valid_mask],
            gt_target,
            reduction='sum'
        )
        cls_loss /= self.loss_normalizer

        # 2. regression using IoU/GIoU loss (defined on positive samples)
        if num_pos == 0:
            # keep the graph connected with a zero loss
            reg_loss = 0 * pred_offsets.sum()
        else:
            # giou loss defined on positive samples
            reg_loss = ctr_diou_loss_1d(
                pred_offsets,
                gt_offsets,
                reduction='sum'
            )
            reg_loss /= self.loss_normalizer

        if self.train_loss_weight > 0:
            loss_weight = self.train_loss_weight
        else:
            # adaptive weight: balance reg loss against the cls loss
            loss_weight = cls_loss.detach() / max(reg_loss.item(), 0.01)

        # return a dict of losses
        final_loss = cls_loss + reg_loss * loss_weight
        return {'cls_loss' : cls_loss,
                'reg_loss' : reg_loss,
                'final_loss' : final_loss}

    @torch.no_grad()
    def inference(
        self,
        video_list,
        points, fpn_masks,
        out_cls_logits, out_offsets
    ):
        """
        Decode detections for every video and run post-processing.

        video_list: B (list) [dict]; points: F (list) [T_i, 4];
        fpn_masks, out_*: F (list) [B, T_i, C].
        """
        results = []

        # 1: gather video meta information
        vid_idxs = [x['video_id'] for x in video_list]
        durations = [x['duration'] for x in video_list]
        gt_times = [x['gt_time'] for x in video_list]
        vid_ft_nframes = [x['feat_num_frames'] for x in video_list]

        # 2: inference on each single video and gather the results
        # upto this point, all results use timestamps defined on feature grids
        for idx, (vidx, nframes, duration, gt_time) in enumerate(
            zip(vid_idxs, vid_ft_nframes, durations, gt_times)
        ):
            # gather per-video outputs
            cls_logits_per_vid = [x[idx] for x in out_cls_logits]
            offsets_per_vid = [x[idx] for x in out_offsets]
            fpn_masks_per_vid = [x[idx] for x in fpn_masks]
            # inference on a single video (should always be the case)
            results_per_vid = self.inference_single_video(
                points, fpn_masks_per_vid,
                cls_logits_per_vid, offsets_per_vid
            )
            # pass through video meta info
            results_per_vid['video_id'] = vidx
            results_per_vid['duration'] = duration
            results_per_vid['gt_time'] = gt_time
            results_per_vid['feat_num_frames'] = nframes
            results.append(results_per_vid)

        # step 3: postprocessing
        return self.postprocessing(results)

    @torch.no_grad()
    def inference_single_video(
        self,
        points,
        fpn_masks,
        out_cls_logits,
        out_offsets,
    ):
        """
        Decode candidate segments of one video across all FPN levels.

        points: F (list) [T_i, 4]; fpn_masks, out_*: F (list) [T_i, C].
        Returns a dict with 'segments' [N, 2], 'scores' [N], 'labels' [N].
        """
        segs_all = []
        scores_all = []
        cls_idxs_all = []

        # loop over fpn levels
        for cls_i, offsets_i, pts_i, mask_i in zip(
            out_cls_logits, out_offsets, points, fpn_masks
        ):
            # sigmoid normalization for output logits; flattened index is
            # point_idx * num_classes + class_idx
            pred_prob = (cls_i.sigmoid() * mask_i.unsqueeze(-1)).flatten()

            # Apply filtering to make NMS faster following detectron2
            # 1. Keep seg with confidence score > a threshold
            keep_idxs1 = (pred_prob > self.test_pre_nms_thresh)
            pred_prob = pred_prob[keep_idxs1]
            topk_idxs = keep_idxs1.nonzero(as_tuple=True)[0]

            # 2. Keep top k top scoring boxes only
            num_topk = min(self.test_pre_nms_topk, topk_idxs.size(0))
            pred_prob, idxs = pred_prob.sort(descending=True)
            pred_prob = pred_prob[:num_topk].clone()
            topk_idxs = topk_idxs[idxs[:num_topk]].clone()

            # recover point / class indices from the flattened index
            pt_idxs = torch.div(
                topk_idxs, self.num_classes, rounding_mode='floor'
            )
            cls_idxs = torch.fmod(topk_idxs, self.num_classes)

            # 3. gather predicted offsets
            offsets = offsets_i[pt_idxs]
            pts = pts_i[pt_idxs]

            # 4. compute predicted segments (denorm by stride for output offsets)
            seg_left = pts[:, 0] - offsets[:, 0] * pts[:, 3]
            seg_right = pts[:, 0] + offsets[:, 1] * pts[:, 3]
            pred_segs = torch.stack((seg_left, seg_right), -1)

            # 5. Keep seg with duration > a threshold (relative to feature grids)
            seg_areas = seg_right - seg_left
            keep_idxs2 = seg_areas > self.test_duration_thresh

            # *_all : N (filtered # of segments) x 2 / 1
            segs_all.append(pred_segs[keep_idxs2])
            scores_all.append(pred_prob[keep_idxs2])
            cls_idxs_all.append(cls_idxs[keep_idxs2])

        # cat along the FPN levels (F N_i, C)
        segs_all, scores_all, cls_idxs_all = [
            torch.cat(x) for x in [segs_all, scores_all, cls_idxs_all]
        ]
        results = {'segments' : segs_all,
                   'scores' : scores_all,
                   'labels' : cls_idxs_all}

        return results

    @torch.no_grad()
    def postprocessing(self, results):
        """
        Move per-video results to CPU and optionally apply (soft-)NMS.

        Segments are returned in feature-grid units; no conversion to
        seconds is performed here.
        """
        processed_results = []
        for results_per_vid in results:
            # unpack the meta info
            vidx = results_per_vid['video_id']
            gt_time = results_per_vid['gt_time']
            duration = results_per_vid['duration']
            # 1: unpack the results and move to CPU
            segs = results_per_vid['segments'].detach().cpu()
            scores = results_per_vid['scores'].detach().cpu()
            labels = results_per_vid['labels'].detach().cpu()
            if self.test_nms_method != 'none':
                # 2: batched nms on CPU tensors
                segs, scores, labels = batched_nms(
                    segs, scores, labels,
                    self.test_iou_threshold,
                    self.test_min_score,
                    self.test_max_seg_num,
                    use_soft_nms = (self.test_nms_method == 'soft'),
                    multiclass = self.test_multiclass_nms,
                    sigma = self.test_nms_sigma,
                    voting_thresh = self.test_voting_thresh
                )
            # 3: repack the results
            processed_results.append(
                {'video_id' : vidx,
                 'segments' : segs,
                 'scores' : scores,
                 'labels' : labels,
                 'duration' : duration,
                 'gt_time' : gt_time,
                }
            )

        return processed_results