# Uploaded by erkutt - "Upload open source code of MTFL model" (commit 28e129b, verified)
""" Reference source: https://github.com/tianyu0207/RTFM"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as torch_init
# Make float32 the default floating-point dtype for newly created tensors.
# ``torch.set_default_tensor_type`` is deprecated since PyTorch 2.1; ``set_default_dtype``
# is its recommended replacement for this purpose.
torch.set_default_dtype(torch.float32)
def weight_init(m):
    """
    Xavier-uniform initialise the weight of conv/linear layers and zero their bias.

    Meant to be passed to ``nn.Module.apply``. Modules are matched by class name: any class
    whose name contains 'Conv' or 'Linear' is a candidate. Candidates that have no ``weight``
    tensor (e.g. a container class named ``ConvBlock``) are skipped rather than raising
    ``AttributeError``.

    Args:
        m (nn.Module): Module visited by ``apply``.
    """
    classname = type(m).__name__
    if 'Conv' not in classname and 'Linear' not in classname:
        return
    weight = getattr(m, 'weight', None)
    if not isinstance(weight, torch.Tensor):
        # Name matched but there is no weight to initialise.
        return
    torch_init.xavier_uniform_(weight)
    bias = getattr(m, 'bias', None)
    if bias is not None:  # layers built with bias=False store bias=None
        torch_init.zeros_(bias)
class CVA(nn.Module):
    def __init__(self, input_dim=1024, num_heads=4, dropout=0.1, device=None):
        """
        Cross-View Attention (CVA) module.

        Queries come from one view, keys/values from another. The layer-normalised query
        view is added back as a residual.

        Args:
            input_dim (int): Dimension of the input features (must be divisible by num_heads).
            num_heads (int): Number of attention heads.
            dropout (float): Dropout probability inside the attention layer.
            device: Device for the attention parameters. ``None`` uses the default device;
                move the whole model with ``.to(...)`` / ``.cuda()`` like any other module.
        """
        super(CVA, self).__init__()
        self.cross_attention = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads,
                                                     dropout=dropout, device=device)

    def forward(self, feature1, feature2):
        """
        Args:
            feature1 (torch.Tensor): query-view features. Shape: B x T x C.
            feature2 (torch.Tensor): key/value-view features. Shape: B x T x C.
        Returns:
            torch.Tensor: cross-attended features plus the normalised feature1 residual.
                Shape: T x B x C (sequence-first, as nn.MultiheadAttention returns it with
                the default batch_first=False).
        """
        # Parameter-free layer norm over the channel dimension.
        feature1 = F.layer_norm(feature1, [feature1.size(-1)])
        feature2 = F.layer_norm(feature2, [feature2.size(-1)])
        # nn.MultiheadAttention (batch_first=False) expects T x B x C inputs.
        query = feature1.permute(1, 0, 2)
        key_value = feature2.permute(1, 0, 2)
        attended, _ = self.cross_attention(query=query, key=key_value, value=key_value)
        return attended + query  # residual connection; T x B x C
class Aggregate(nn.Module):
    def __init__(self, input_dim, num_heads=4, device=None):
        """
        An aggregate network including local temporal correlation learning, global temporal
        correlation learning, and feature fusion in MTFF.

        Args:
            input_dim (int): input features dim.
            num_heads (int): number of heads of the global self-attention (embed dim 512).
            device: device for the self-attention parameters. ``None`` uses the default
                device; move the whole model with ``.to(...)`` / ``.cuda()``.
        """
        super(Aggregate, self).__init__()
        hidden = 512  # channels produced by each local branch and by the global branch
        self.input_dim = input_dim
        # Local temporal branches, one per view, with increasing dilation (1, 2, 4).
        self.conv_1 = self._dilated_branch(input_dim, hidden, dilation=1)
        self.conv_2 = self._dilated_branch(input_dim, hidden, dilation=2)
        self.conv_3 = self._dilated_branch(input_dim, hidden, dilation=4)
        # Global branch: 1x1 conv reducing the concatenated three views before self-attention.
        self.conv_4 = nn.Sequential(
            nn.Conv1d(in_channels=input_dim * 3, out_channels=hidden, kernel_size=1,
                      stride=1, padding=0, bias=False),
            nn.LeakyReLU(negative_slope=5e-2),
        )
        # Fusion: 3 local branches + 1 global branch (4 * 512 = 2048 channels) -> input_dim.
        self.conv_5 = nn.Sequential(
            nn.Conv1d(in_channels=4 * hidden, out_channels=input_dim, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=5e-2),
            nn.BatchNorm1d(input_dim),
        )
        self.self_attention = nn.MultiheadAttention(embed_dim=hidden, num_heads=num_heads,
                                                    dropout=0.1, device=device)

    @staticmethod
    def _dilated_branch(in_channels, out_channels, dilation):
        """Build a kernel-3 dilated Conv1d -> LeakyReLU -> BatchNorm1d branch; padding=dilation keeps T."""
        return nn.Sequential(
            nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3,
                      stride=1, dilation=dilation, padding=dilation),
            nn.LeakyReLU(negative_slope=5e-2),
            nn.BatchNorm1d(out_channels),
        )

    def forward(self, input1, input2, input3):
        """
        Args:
            input1 (torch.Tensor): long-frame-length features. Shape: T x B x C.
            input2 (torch.Tensor): medium-frame-length features. Shape: T x B x C.
            input3 (torch.Tensor): short-frame-length features. Shape: T x B x C.
        Returns:
            torch.Tensor: Processed and fused output features. Shape: B x T x C.
        """
        x1 = input1.permute(1, 2, 0)  # B x C x T
        x2 = input2.permute(1, 2, 0)
        x3 = input3.permute(1, 2, 0)
        residual = torch.mean(torch.stack([x1, x2, x3]), dim=0)  # B x C x T
        # Local temporal correlations: one dilated conv branch per view.
        local_feat = torch.cat([self.conv_1(x1), self.conv_2(x2), self.conv_3(x3)],
                               dim=1)  # B x 3*512 x T
        # Global temporal correlations: reduce all views, then self-attend over time.
        global_feat = self.conv_4(torch.cat((x1, x2, x3), dim=1))  # B x 512 x T
        global_feat = global_feat.permute(2, 0, 1)  # T x B x 512
        global_feat = F.layer_norm(global_feat, normalized_shape=[global_feat.size(-1)])
        global_feat, _ = self.self_attention(global_feat, global_feat, global_feat)
        global_feat = global_feat.permute(1, 2, 0)  # B x 512 x T
        # Fuse local and global features back to C channels and add the mean-view residual.
        out = self.conv_5(torch.cat((local_feat, global_feat), dim=1))  # B x C x T
        out = out + residual
        return out.permute(0, 2, 1)  # B x T x C
class Encoder(nn.Module):
    def __init__(self, input_dim=1024, seg_num=32):
        """
        Multi-Temporal Feature Fusion (MTFF) module.

        Three cross-view attention modules (CVA1..CVA3) relate the views pairwise, and an
        Aggregate network fuses their outputs.

        Args:
            input_dim (int): Dimension of the input features.
            seg_num (int): Number of snippets in a video (kept as ``min_temporal_dim``).
        """
        super().__init__()
        self.drop_out_rate = 0.1  # stored for reference; not used inside this module
        self.input_dim = input_dim
        self.min_temporal_dim = seg_num
        # Same creation and registration order as CVA1, CVA2, CVA3, then aggregate.
        for attr_name in ('CVA1', 'CVA2', 'CVA3'):
            setattr(self, attr_name, CVA(input_dim=input_dim))
        self.aggregate = Aggregate(input_dim=input_dim)

    def forward(self, feature1, feature2, feature3):
        """
        Args:
            feature1 (torch.Tensor): long-frame-length features. Shape: B x T x C.
                (Batch size X The number of snippets x Input dimensions)
            feature2 (torch.Tensor): medium-frame-length features. Shape: B x T x C.
            feature3 (torch.Tensor): short-frame-length features. Shape: B x T x C.
        Returns:
            torch.Tensor: Fused and processed output features. Shape: B x T x C.
        """
        # Views attend cyclically: long->medium, medium->short, short->long.
        return self.aggregate(
            self.CVA1(feature1, feature2),
            self.CVA2(feature2, feature3),
            self.CVA3(feature3, feature1),
        )
class Model(nn.Module):
    def __init__(self, feature_dim, batch_size, seg_num=32, num_classes=18):
        """
        Multi-Temporal Feature Learning (MTFL) recognition model.

        Args:
            feature_dim (int): Dimension of the input features.
            batch_size (int): Number of normal videos in a batch. ``forward`` treats the
                first ``batch_size`` videos of its input as normal and the rest as abnormal.
            seg_num (int): Number of snippets in a video.
            num_classes (int): Number of output classes (18 by default).
        """
        super(Model, self).__init__()
        self.batch_size = batch_size
        self.num_segments = seg_num
        # Snippets selected per video by feature magnitude (3 for seg_num=32). Clamped to at
        # least 1 so the score mean in forward is never taken over an empty selection.
        self.k_abn = max(1, self.num_segments // 10)
        self.k_nor = max(1, self.num_segments // 10)
        self.Encoder = Encoder(input_dim=feature_dim, seg_num=seg_num)
        # Fully connected layers for classification
        self.fc1 = nn.Linear(feature_dim, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)
        self.drop_out = nn.Dropout(0.2)
        self.relu = nn.LeakyReLU(negative_slope=5e-2)  # leaky ReLU despite the attribute name
        self.sigmoid = nn.Sigmoid()
        self.apply(weight_init)

    def _select_topk(self, magnitudes, features, scores, k, ncrops=1):
        """
        Select, per video, the k snippets with the largest feature magnitude.

        The magnitudes are first multiplied by ``self.drop_out`` applied to a mask of ones.
        In training mode this zeroes a random subset of snippets before ranking; in eval
        mode the mask is all ones and has no effect.

        Args:
            magnitudes (torch.Tensor): snippet feature L2 norms. Shape: N x T.
            features (torch.Tensor): snippet features. Shape: (N * ncrops) x T x D.
            scores (torch.Tensor): snippet class scores. Shape: N x T x num_classes.
            k (int): number of snippets to select per video.
            ncrops (int): number of spatial crops folded into the batch dimension.
        Returns:
            tuple: (mean score over the selected snippets, N x num_classes;
                    selected snippet features, (ncrops * N) x k x D)
        """
        mask = self.drop_out(torch.ones_like(magnitudes))
        idx = torch.topk(magnitudes * mask, k, dim=1)[1]  # N x k
        _, t, d = features.shape
        idx_feat = idx.unsqueeze(2).expand(-1, -1, d)
        per_crop = features.view(-1, ncrops, t, d).permute(1, 0, 2, 3)  # ncrops x N x T x D
        selected_feats = torch.cat([torch.gather(crop, 1, idx_feat) for crop in per_crop])
        idx_score = idx.unsqueeze(2).expand(-1, -1, scores.shape[2])
        mean_score = torch.mean(torch.gather(scores, 1, idx_score), dim=1)
        return mean_score, selected_feats

    def forward(self, input1, input2, input3):
        """
        Args:
            input1 (torch.Tensor): long-frame-length features. Shape: B x T x feature_dim.
            input2 (torch.Tensor): medium-frame-length features. Shape: B x T x feature_dim.
            input3 (torch.Tensor): short-frame-length features. Shape: B x T x feature_dim.
                The first ``batch_size`` videos are treated as normal, the remainder as
                abnormal. When there is no abnormal remainder (e.g. inference with B = 1),
                the normal part is used for both the abnormal and the normal outputs.
        Returns:
            score_abnormal (torch.Tensor): per-class mean score of the top-k (by feature
                magnitude) snippets of each abnormal video. Shape: N_abn x num_classes.
            score_normal (torch.Tensor): same for normal videos. Shape: N_nor x num_classes.
            feat_select_abn (torch.Tensor): selected abnormal snippet features.
                Shape: N_abn x k x feature_dim.
            feat_select_normal (torch.Tensor): selected normal snippet features.
                Shape: N_nor x k x feature_dim.
            scores (torch.Tensor): All computed scores. Shape: B x T x num_classes.
        """
        ncrops = 1  # Reserving the parameter for spatial cropping, which is not used and defaults to 1
        # Multi-Temporal Feature Fusion
        out = self.Encoder(input1, input2, input3)
        bs, t, f = out.size()
        features = self.drop_out(out)  # B x T x D
        # classification layers
        scores = self.relu(self.fc1(features))
        scores = self.drop_out(scores)
        scores = self.relu(self.fc2(scores))
        scores = self.drop_out(scores)
        scores = self.sigmoid(self.fc3(scores))
        scores = scores.view(bs, t, -1)  # B x T x num_classes

        # Feature magnitudes, averaged over crops (identity for ncrops = 1): B x T
        feat_magnitudes = torch.norm(features, p=2, dim=2)
        feat_magnitudes = feat_magnitudes.view(bs, ncrops, -1).mean(1)

        normal_features = features[:self.batch_size]
        normal_scores = scores[:self.batch_size]
        nfea_magnitudes = feat_magnitudes[:self.batch_size]
        abnormal_features = features[self.batch_size:]
        abnormal_scores = scores[self.batch_size:]
        afea_magnitudes = feat_magnitudes[self.batch_size:]

        # No abnormal half (inference): reuse the normal half. Checking the abnormal half
        # instead of "one normal video" keeps training with batch_size == 1 correct and
        # allows inference on more than one video.
        if abnormal_features.shape[0] == 0:
            abnormal_features = normal_features
            abnormal_scores = normal_scores
            afea_magnitudes = nfea_magnitudes

        # topk cannot select more snippets than the input actually has.
        k_abn = min(self.k_abn, t)
        k_nor = min(self.k_nor, t)
        # Abnormal before normal, so the dropout masks are drawn in the same order as before.
        score_abnormal, feat_select_abn = self._select_topk(
            afea_magnitudes, abnormal_features, abnormal_scores, k_abn, ncrops)
        score_normal, feat_select_normal = self._select_topk(
            nfea_magnitudes, normal_features, normal_scores, k_nor, ncrops)
        return score_abnormal, score_normal, feat_select_abn, feat_select_normal, scores