""" Reference source: https://github.com/tianyu0207/RTFM"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as torch_init

torch.set_default_tensor_type('torch.FloatTensor')


def weight_init(m):
    """Xavier-initialise Conv/Linear weights and zero their biases."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        torch_init.xavier_uniform_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0)


class CVA(nn.Module):
    def __init__(self, input_dim=1024):
        """
        Cross-View Attention (CVA) module.

        Args:
            input_dim (int): Dimension of the input features.
        """
        super(CVA, self).__init__()
        drop_out_rate = 0.1
        num_heads = 4
        # Created on the default device; move the whole model with .to(device).
        self.cross_attention = nn.MultiheadAttention(embed_dim=input_dim,
                                                     num_heads=num_heads,
                                                     dropout=drop_out_rate)

    def forward(self, feature1, feature2):
        """
        Args:
            feature1 (torch.Tensor): features from one path. Shape: B x T x C.
            feature2 (torch.Tensor): features from another path. Shape: B x T x C.

        Returns:
            out1 (torch.Tensor): Cross-attended features with a residual
                connection to the query path. Shape: T x B x C.
        """
        feature1 = F.layer_norm(feature1, [feature1.size(-1)])
        feature2 = F.layer_norm(feature2, [feature2.size(-1)])
        # nn.MultiheadAttention defaults to batch_first=False, so reorder to T x B x C.
        feature1 = feature1.permute(1, 0, 2)
        feature2 = feature2.permute(1, 0, 2)

        # feature1 is the query; feature2 supplies keys and values.
        out1, _ = self.cross_attention(query=feature1, key=feature2, value=feature2)
        out1 = out1 + feature1

        return out1
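
# A minimal CVA usage sketch (sizes assumed for illustration; not part of the
# original pipeline):
#   cva = CVA(input_dim=1024)
#   f1 = torch.randn(2, 32, 1024)   # B x T x C
#   f2 = torch.randn(2, 32, 1024)   # B x T x C
#   out = cva(f1, f2)               # -> (32, 2, 1024), i.e. T x B x C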


class Aggregate(nn.Module):
    def __init__(self, input_dim):
        """
        An aggregate network comprising local temporal correlation learning,
        global temporal correlation learning, and feature fusion in MTFF.

        Args:
            input_dim (int): Input feature dimension.
        """
        super(Aggregate, self).__init__()
        bn = nn.BatchNorm1d
        num_heads = 4
        self.input_dim = input_dim
        # Pyramid of dilated 1-D convolutions for local temporal correlations;
        # padding matches dilation so the temporal length is preserved.
        self.conv_1 = nn.Sequential(
            nn.Conv1d(in_channels=input_dim, out_channels=512, kernel_size=3,
                      stride=1, dilation=1, padding=1),
            nn.LeakyReLU(negative_slope=5e-2),
            bn(512)
        )
        self.conv_2 = nn.Sequential(
            nn.Conv1d(in_channels=input_dim, out_channels=512, kernel_size=3,
                      stride=1, dilation=2, padding=2),
            nn.LeakyReLU(negative_slope=5e-2),
            bn(512)
        )
        self.conv_3 = nn.Sequential(
            nn.Conv1d(in_channels=input_dim, out_channels=512, kernel_size=3,
                      stride=1, dilation=4, padding=4),
            nn.LeakyReLU(negative_slope=5e-2),
            bn(512)
        )
        # 1x1 convolution mixing the three concatenated input paths before
        # global self-attention.
        self.conv_4 = nn.Sequential(
            nn.Conv1d(in_channels=input_dim * 3, out_channels=512, kernel_size=1,
                      stride=1, padding=0, bias=False),
            nn.LeakyReLU(negative_slope=5e-2),
        )
        # Fusion: 3 local paths (3 x 512) + 1 global path (512) = 2048 channels in.
        self.conv_5 = nn.Sequential(
            nn.Conv1d(in_channels=2048, out_channels=input_dim, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=5e-2),
            nn.BatchNorm1d(input_dim),
        )
        self.self_attention = nn.MultiheadAttention(embed_dim=512, num_heads=num_heads,
                                                    dropout=0.1)

    def forward(self, input1, input2, input3):
        """
        Args:
            input1 (torch.Tensor): long-frame-length features. Shape: T x B x C.
            input2 (torch.Tensor): medium-frame-length features. Shape: T x B x C.
            input3 (torch.Tensor): short-frame-length features. Shape: T x B x C.

        Returns:
            torch.Tensor: Processed and fused output features. Shape: B x T x C.
        """
        # Conv1d expects B x C x T.
        x1 = input1.permute(1, 2, 0)
        x2 = input2.permute(1, 2, 0)
        x3 = input3.permute(1, 2, 0)
        tensor_list = [x1, x2, x3]

        residual = torch.mean(torch.stack(tensor_list), dim=0)

        # Local temporal correlations at three dilation rates.
        out1 = self.conv_1(x1)
        out2 = self.conv_2(x2)
        out3 = self.conv_3(x3)
        x = torch.cat([out1, out2, out3], dim=1)

        # Global temporal correlations via self-attention over all snippets.
        feature = torch.cat((x1, x2, x3), dim=1)
        out = self.conv_4(feature)
        out = out.permute(2, 0, 1)  # B x C x T -> T x B x C for attention
        out = F.layer_norm(out, normalized_shape=[out.size(-1)])
        out, _ = self.self_attention(out, out, out)
        out = out.permute(1, 2, 0)  # back to B x C x T

        # Fuse local and global paths, then add the mean-input residual.
        out = torch.cat((x, out), dim=1)
        out = self.conv_5(out)
        out = out + residual
        out = out.permute(0, 2, 1)  # B x C x T -> B x T x C

        return out
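
# A minimal Aggregate usage sketch (sizes assumed for illustration):
#   agg = Aggregate(input_dim=1024)
#   a = torch.randn(32, 2, 1024)             # T x B x C
#   fused = agg(a, a.clone(), a.clone())     # -> (2, 32, 1024), i.e. B x T x C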


class Encoder(nn.Module):
    def __init__(self, input_dim=1024, seg_num=32):
        """
        Multi-Temporal Feature Fusion (MTFF) module.

        Args:
            input_dim (int): Dimension of the input features.
            seg_num (int): Number of snippets in a video.
        """
        super(Encoder, self).__init__()
        self.drop_out_rate = 0.1
        self.input_dim = input_dim
        self.min_temporal_dim = seg_num
        # One cross-view attention block per ordered pair of temporal scales.
        self.CVA1 = CVA(input_dim=input_dim)
        self.CVA2 = CVA(input_dim=input_dim)
        self.CVA3 = CVA(input_dim=input_dim)

        self.aggregate = Aggregate(input_dim=input_dim)

    def forward(self, feature1, feature2, feature3):
        """
        Args:
            feature1 (torch.Tensor): long-frame-length features. Shape: B x T x C
                (batch size x number of snippets x input dimension).
            feature2 (torch.Tensor): medium-frame-length features. Shape: B x T x C.
            feature3 (torch.Tensor): short-frame-length features. Shape: B x T x C.

        Returns:
            torch.Tensor: Fused and processed output features. Shape: B x T x C.
        """
        # Cross-attend each scale to the next one, cyclically; outputs are T x B x C.
        att1 = self.CVA1(feature1, feature2)
        att2 = self.CVA2(feature2, feature3)
        att3 = self.CVA3(feature3, feature1)

        out1 = self.aggregate(att1, att2, att3)

        return out1
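
# A minimal Encoder (MTFF) usage sketch (sizes assumed; see also the smoke test
# at the bottom of this file):
#   enc = Encoder(input_dim=1024, seg_num=32)
#   f = torch.randn(2, 32, 1024)             # B x T x C
#   fused = enc(f, f.clone(), f.clone())     # -> (2, 32, 1024)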


class Model(nn.Module):
    def __init__(self, feature_dim, batch_size, seg_num=32):
        """
        Multi-Temporal Feature Learning (MTFL) recognition model.

        Args:
            feature_dim (int): Dimension of the input features.
            batch_size (int): Batch size per normal/abnormal half; the forward
                pass expects normal and abnormal videos concatenated along the
                batch dimension.
            seg_num (int): Number of snippets in a video.
        """
        super(Model, self).__init__()
        self.batch_size = batch_size
        self.num_segments = seg_num
        # Top-k snippet selection, with k one tenth of the segments (k = 3 for 32).
        self.k_abn = self.num_segments // 10
        self.k_nor = self.num_segments // 10

        self.Encoder = Encoder(input_dim=feature_dim, seg_num=seg_num)

        self.fc1 = nn.Linear(feature_dim, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 18)

        self.drop_out = nn.Dropout(0.2)
        self.relu = nn.LeakyReLU(negative_slope=5e-2)
        self.sigmoid = nn.Sigmoid()
        self.apply(weight_init)

    def forward(self, input1, input2, input3):
        """
        Args:
            input1 (torch.Tensor): long-frame-length features. Shape: B x T x feature_dim.
            input2 (torch.Tensor): medium-frame-length features. Shape: B x T x feature_dim.
            input3 (torch.Tensor): short-frame-length features. Shape: B x T x feature_dim.

        Returns:
            score_abnormal (torch.Tensor): Mean scores of the top-k abnormal snippets.
            score_normal (torch.Tensor): Mean scores of the top-k normal snippets.
            feat_select_abn (torch.Tensor): Selected abnormal features.
            feat_select_normal (torch.Tensor): Selected normal features.
            scores (torch.Tensor): All computed scores. Shape: B x T x number of classes (18).
        """
        k_abn = self.k_abn
        k_nor = self.k_nor
        ncrops = 1

        out = self.Encoder(input1, input2, input3)
        bs, t, f = out.size()
        features = self.drop_out(out)

        # Per-snippet classification head.
        scores = self.relu(self.fc1(features))
        scores = self.drop_out(scores)
        scores = self.relu(self.fc2(scores))
        scores = self.drop_out(scores)
        scores = self.sigmoid(self.fc3(scores))
        scores = scores.view(bs, t, -1)

        # The batch holds normal videos first, abnormal videos second.
        normal_features = features[0:self.batch_size]
        normal_scores = scores[0:self.batch_size]

        abnormal_features = features[self.batch_size:]
        abnormal_scores = scores[self.batch_size:]

        # Per-snippet feature magnitude (L2 norm), averaged over crops.
        feat_magnitudes = torch.norm(features, p=2, dim=2)
        feat_magnitudes = feat_magnitudes.view(bs, ncrops, -1).mean(1)
        nfea_magnitudes = feat_magnitudes[0:self.batch_size]
        afea_magnitudes = feat_magnitudes[self.batch_size:]
        n_size = nfea_magnitudes.shape[0]

        # Single-video inference: reuse the one video for both branches.
        if nfea_magnitudes.shape[0] == 1:
            afea_magnitudes = nfea_magnitudes
            abnormal_scores = normal_scores
            abnormal_features = normal_features

        select_idx = torch.ones_like(nfea_magnitudes)
        select_idx = self.drop_out(select_idx)

        # Top-k snippets by (dropout-masked) feature magnitude in the abnormal half.
        afea_magnitudes_drop = afea_magnitudes * select_idx
        idx_abn = torch.topk(afea_magnitudes_drop, k_abn, dim=1)[1]
        idx_abn_feat = idx_abn.unsqueeze(2).expand([-1, -1, abnormal_features.shape[2]])

        abnormal_features = abnormal_features.view(n_size, ncrops, t, f)
        abnormal_features = abnormal_features.permute(1, 0, 2, 3)

        total_select_abn_feature = torch.zeros(0, device=input1.device)
        for abnormal_feature in abnormal_features:
            feat_select_abn = torch.gather(abnormal_feature, 1, idx_abn_feat)
            total_select_abn_feature = torch.cat((total_select_abn_feature, feat_select_abn))

        idx_abn_score = idx_abn.unsqueeze(2).expand([-1, -1, abnormal_scores.shape[2]])

        score_abnormal = torch.mean(torch.gather(abnormal_scores, 1, idx_abn_score), dim=1)

        # Same top-k selection for the normal half.
        select_idx_normal = torch.ones_like(nfea_magnitudes)
        select_idx_normal = self.drop_out(select_idx_normal)
        nfea_magnitudes_drop = nfea_magnitudes * select_idx_normal
        idx_normal = torch.topk(nfea_magnitudes_drop, k_nor, dim=1)[1]
        idx_normal_feat = idx_normal.unsqueeze(2).expand([-1, -1, normal_features.shape[2]])

        normal_features = normal_features.view(n_size, ncrops, t, f)
        normal_features = normal_features.permute(1, 0, 2, 3)

        total_select_nor_feature = torch.zeros(0, device=input1.device)
        for nor_fea in normal_features:
            feat_select_normal = torch.gather(nor_fea, 1, idx_normal_feat)
            total_select_nor_feature = torch.cat((total_select_nor_feature, feat_select_normal))

        idx_normal_score = idx_normal.unsqueeze(2).expand([-1, -1, normal_scores.shape[2]])
        score_normal = torch.mean(torch.gather(normal_scores, 1, idx_normal_score), dim=1)

        feat_select_abn = total_select_abn_feature
        feat_select_normal = total_select_nor_feature

        return score_abnormal, score_normal, feat_select_abn, feat_select_normal, scores
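
if __name__ == "__main__":
    # Minimal CPU smoke test (all sizes assumed for illustration): two normal
    # and two abnormal videos, 32 snippets each, 1024-dim features per snippet.
    torch.manual_seed(0)
    model = Model(feature_dim=1024, batch_size=2, seg_num=32)
    x1 = torch.randn(4, 32, 1024)  # long-frame-length features
    x2 = torch.randn(4, 32, 1024)  # medium-frame-length features
    x3 = torch.randn(4, 32, 1024)  # short-frame-length features
    score_abn, score_nor, feat_abn, feat_nor, scores = model(x1, x2, x3)
    print(score_abn.shape, score_nor.shape, feat_abn.shape, feat_nor.shape, scores.shape)
    # Expected: torch.Size([2, 18]) torch.Size([2, 18]) torch.Size([2, 3, 1024])
    #           torch.Size([2, 3, 1024]) torch.Size([4, 32, 18])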