# Author: erkutt
# Upload open source code of MTFL model
# Commit: 28e129b (verified)
import torch
from sklearn.metrics import auc, roc_curve, average_precision_score
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import option
from torch.utils.data import DataLoader
from dataset import Dataset
from model import Model
import warnings
from sklearn.exceptions import UndefinedMetricWarning
# Silence sklearn's UndefinedMetricWarning globally; presumably raised by
# roc_curve for videos whose ground truth has a single class (e.g. normal
# videos with no anomalous frames) — verify if the metric code changes.
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)
def get_gt(start_end_couples, num_frames, device):
    """
    Generate a frame-level ground-truth tensor from start/end annotation pairs.

    Args:
        start_end_couples: Flat sequence ``[s0, e0, s1, e1, ...]`` of frame
            indices (plain ints or single-element tensors). A pair whose start
            or end is -1 is skipped. If None, no events are present. A
            trailing unpaired value is ignored.
        num_frames: Total number of frames in the time sequence (int or
            single-element tensor).
        device: Device where the tensor should be placed.

    Returns:
        gt: Float tensor of shape (num_frames,). 1.0 marks frames inside
            ``[start, end)`` of some event (anomalous), 0.0 marks normal frames.
    """
    # Allocate directly on the target device; int() accepts both Python ints
    # and single-element tensors (as produced by a batch_size=1 DataLoader).
    gt = torch.zeros(int(num_frames), device=device)
    if start_end_couples is None:
        return gt
    # Pair up consecutive entries; zip drops an unpaired trailing value.
    for start, end in zip(start_end_couples[0::2], start_end_couples[1::2]):
        start, end = int(start), int(end)
        if start != -1 and end != -1:
            # End index is exclusive, as in the original slicing.
            gt[start:end] = 1.0
    return gt
def save_scores(pred, start_end_couples, save_path):
    """
    Save a plot of anomaly scores with the annotated anomalous regions shaded.

    The legend label is the base name of ``save_path`` up to its first ``.``.

    Args:
        pred (array-like): Per-frame anomaly scores.
        start_end_couples: Flat sequence ``[s0, e0, s1, e1, ...]`` of frame
            indices (single-element tensors); pairs containing -1 are skipped.
        save_path (str): Path of the image file to write.
    """
    file_name = os.path.basename(save_path).split(".")[0]
    fig = plt.figure()
    try:
        plt.plot(pred, label=file_name, color='blue')
        # Shade each annotated anomalous region.
        for i in range(0, len(start_end_couples) - 1, 2):
            start = start_end_couples[i].item()
            end = start_end_couples[i + 1].item()
            if start != -1 and end != -1:
                plt.axvspan(start, end, color='red', alpha=0.3)
        plt.ylim(0, 1)
        plt.xlabel('Frames', fontdict={'size': 16})
        plt.ylabel('Anomaly Score', fontdict={'size': 16})
        plt.yticks(size=14)
        plt.xticks(size=14)
        plt.legend(prop={'size': 16})
        plt.savefig(save_path)
    finally:
        # Always release the figure, even if plotting/saving fails, so that
        # repeated calls over many videos don't accumulate open figures.
        plt.close(fig)
def _expand_to_frames(seg_scores, n_frames):
    """Map a 1-D tensor of per-segment scores to ``n_frames`` per-frame scores.

    Each segment covers ``n_frames // n_segments`` consecutive frames, and any
    leftover frames at the end reuse the last segment's score. When there are
    fewer frames than segments, frame ``f`` takes the score of segment
    ``f * n_segments // n_frames``.
    """
    n_segments = seg_scores.size(0)
    if n_frames <= 0:
        return seg_scores[:0]
    frame_idx = torch.arange(n_frames, device=seg_scores.device)
    if n_frames >= n_segments:
        seg_len = n_frames // n_segments
        seg_idx = torch.clamp(
            torch.div(frame_idx, seg_len, rounding_mode="floor"), max=n_segments - 1
        )
    else:
        # seg_len would be 0 here; sample segments proportionally instead.
        seg_idx = torch.div(frame_idx * n_segments, n_frames, rounding_mode="floor")
    return seg_scores[seg_idx]


def test(dataloader, model, device, gen_scores=False, save_dir=None):
    """
    Test the model's performance on the given dataloader.

    Segment-level scores from the model are expanded to frame level and
    compared with frame-level ground truth, both per video and over the
    whole dataset.

    Args:
        dataloader (DataLoader): DataLoader for test data (batch_size=1).
        model: The model to be tested; its last output is used as the scores.
        device: Device to perform testing on.
        gen_scores (bool): Whether to generate and save anomaly scores plot.
        save_dir (str): Directory to save generated plots (required when
            ``gen_scores`` is True).

    Returns:
        single_video_AUC (dict): ``{"video": [...], "AUC": [...]}`` with one
            entry per video. Videos with single-class ground truth presumably
            get an undefined (NaN) AUC.
        overall_auc (float): Overall frame-level AUC.
        ap (float): Overall average precision.
    """
    single_video_AUC = {"video": [], "AUC": []}
    # Collect per-video tensors and concatenate once at the end instead of
    # re-concatenating (O(n^2) copying) on every iteration.
    preds, gts = [], []
    with torch.no_grad():
        model.to(device).eval()
        for input1, input2, input3, _label, start_end_couples, num_frames, file in tqdm(dataloader):
            input1 = input1.to(device)
            input2 = input2.to(device)
            input3 = input3.to(device)
            # Only the final model output (segment scores) is needed here.
            *_, scores = model(input1, input2, input3)
            seg_scores = torch.squeeze(scores, dim=(0, 2))  # T scores
            sig = _expand_to_frames(seg_scores, num_frames.item())  # Frames
            cur_gt = get_gt(start_end_couples, num_frames, device)
            preds.append(sig)
            gts.append(cur_gt)

            sig_np = sig.cpu().numpy()
            cur_gt_np = cur_gt.cpu().numpy()
            fpr, tpr, _ = roc_curve(cur_gt_np, sig_np)
            single_video_AUC["video"].append(file)
            single_video_AUC["AUC"].append(auc(fpr, tpr))
            if gen_scores:
                save_path = os.path.join(save_dir, file[0] + '.png')
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                save_scores(sig_np, start_end_couples, save_path)

    pred = (torch.cat(preds) if preds else torch.zeros(0)).cpu().numpy()
    gt = (torch.cat(gts) if gts else torch.zeros(0)).cpu().numpy()
    ap = average_precision_score(gt, pred)
    fpr, tpr, _ = roc_curve(gt, pred)
    overall_auc = auc(fpr, tpr)
    print('\n' + 'Overall auc : ' + str(overall_auc) + ', Average Precision : ' + str(ap) + '\n')
    return single_video_AUC, overall_auc, ap
def main():
    """
    Evaluate a trained detection model on the test set and save the results.

    Parses options from ``option.test_parser``, runs ``test`` with per-video
    score plots written under ``<output_dir>/scores``, and writes per-video
    and overall metrics to ``<output_dir>/AUC/<sub_dir>/results.txt``.
    """
    args = option.test_parser.parse_args()
    # GPU selection via environment variables only takes effect if set
    # before CUDA is initialised.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    AUC_path = os.path.join(args.output_dir, 'AUC')
    scores_path = os.path.join(args.output_dir, 'scores')
    test_loader = DataLoader(Dataset(args, test_mode=True),
                             batch_size=1, shuffle=False,
                             num_workers=args.workers, pin_memory=True)
    model = Model(feature_dim=args.feature_size, batch_size=1, seg_num=args.seg_num)
    # map_location lets a checkpoint saved on GPU load on the CPU fallback.
    model.load_state_dict(torch.load(args.detection_model, map_location=device))
    single_video_AUC, overall_auc, ap = test(dataloader=test_loader,
                                             model=model,
                                             device=device,
                                             gen_scores=True,
                                             save_dir=scores_path)
    # save AUC results
    # Each "video" entry is the collated batch of file names (batch_size=1),
    # so [0][0] is the first video's path; its parent dir names the subfolder.
    video_sub_dir = os.path.basename(os.path.dirname(single_video_AUC["video"][0][0]))
    file_path = os.path.join(AUC_path, video_sub_dir, 'results.txt')
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "w", encoding="utf-8") as f:
        for video, single_auc in zip(single_video_AUC["video"], single_video_AUC["AUC"]):
            f.write(f"Video: {video}, AUC: {single_auc}\n")
        f.write("Overall AUC: {}, Average Precision: {}\n".format(overall_auc, ap))


if __name__ == '__main__':
    main()