|
|
import torch |
|
|
from sklearn.metrics import auc, roc_curve, average_precision_score |
|
|
from tqdm import tqdm |
|
|
import os |
|
|
import matplotlib.pyplot as plt |
|
|
import option |
|
|
|
|
|
from torch.utils.data import DataLoader |
|
|
from dataset import Dataset |
|
|
from model import Model |
|
|
import warnings |
|
|
from sklearn.exceptions import UndefinedMetricWarning |
|
|
|
|
|
# sklearn's roc_curve emits UndefinedMetricWarning when a video's ground truth
# contains no positive (or no negative) frames, e.g. an entirely normal video.
# The resulting per-video AUC is NaN and is recorded as-is, so the warning is
# silenced to keep the evaluation output readable.
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)
|
|
|
|
|
|
|
|
def get_gt(start_end_couples, num_frames, device):
    """
    Build a per-frame ground-truth vector from flat (start, end) annotations.

    Args:
        start_end_couples (Sequence | None): Flat sequence
            ``[start_0, end_0, start_1, end_1, ...]``. Entries may be Python
            ints or single-element tensors (e.g. as yielded by a batch-size-1
            DataLoader). A pair containing ``-1`` is skipped, and a trailing
            unpaired entry is ignored. ``None`` means no events.
        num_frames (int | Tensor): Total number of frames in the sequence. A
            single-element tensor is accepted.
        device: Device on which the returned tensor is allocated.

    Returns:
        torch.Tensor: Float tensor of shape ``(num_frames,)``. Frames in
        ``[start, end)`` of any annotated event are ``1.0`` (anomalous) and all
        other frames are ``0.0`` (normal).

    Raises:
        TypeError: If ``num_frames`` is ``None``.
    """
    # int() accepts a Python int as well as a single-element tensor, and raises
    # TypeError for None (the old post-allocation None check was unreachable).
    # Allocating directly on ``device`` avoids a CPU allocation plus a copy.
    gt = torch.zeros(int(num_frames), device=device)
    if start_end_couples is None:
        return gt

    # Entries are consumed two at a time as (start, end). The end bound is
    # exclusive, following Python slicing.
    for i in range(0, len(start_end_couples) - 1, 2):
        start = int(start_end_couples[i])
        end = int(start_end_couples[i + 1])
        if start != -1 and end != -1:
            gt[start:end] = 1.0

    return gt
|
|
|
|
|
|
|
|
def save_scores(pred, start_end_couples, save_path):
    """
    Plot per-frame anomaly scores with shaded annotated regions and save it.

    Args:
        pred (Sequence[float] | array-like): Per-frame anomaly scores, plotted
            against frame index. The y-axis is fixed to [0, 1].
        start_end_couples (Sequence | None): Flat annotations
            ``[start_0, end_0, start_1, end_1, ...]`` given as ints or
            single-element tensors. Each valid pair is shaded in red. Pairs
            containing ``-1`` are skipped, and ``None`` draws no regions.
        save_path (str): Output image path. Its base name, up to the first
            ``"."``, is used as the legend label.
    """
    fig = plt.figure()
    try:
        file_name = os.path.basename(save_path).split(".")[0]
        plt.plot(pred, label=file_name, color='blue')

        # None is accepted here for consistency with get_gt.
        if start_end_couples is not None:
            for i in range(0, len(start_end_couples) - 1, 2):
                start = int(start_end_couples[i])
                end = int(start_end_couples[i + 1])
                if start != -1 and end != -1:
                    plt.axvspan(start, end, color='red', alpha=0.3)

        plt.ylim(0, 1)
        plt.xlabel('Frames', fontdict={'size': 16})
        plt.ylabel('Anomaly Score', fontdict={'size': 16})
        plt.yticks(size=14)
        plt.xticks(size=14)
        plt.legend(prop={'size': 16})

        plt.savefig(save_path)
    finally:
        # Always release the figure, even if plotting or saving fails, so
        # figures do not accumulate when this is called once per video.
        plt.close(fig)
|
|
|
|
|
|
|
|
def _segment_scores_to_frames(segment_scores, n_frames):
    """
    Expand 1-D per-segment scores into exactly one score per frame.

    Each segment score is repeated ``n_frames // n_segments`` times. Frames left
    over by the integer division are filled with the last segment's score. The
    repeat count is clamped to at least 1, so a video with fewer frames than
    segments yields its first ``n_frames`` segment scores instead of an empty
    tensor (which previously crashed on indexing ``[-1]``).

    Args:
        segment_scores (torch.Tensor): 1-D tensor with one score per segment.
        n_frames (int): Number of frames in the video.

    Returns:
        torch.Tensor: 1-D tensor of length ``n_frames`` (for ``n_frames >= 0``).
    """
    repeats = max(1, n_frames // segment_scores.size(0))
    frame_scores = segment_scores.repeat_interleave(repeats)
    missing = n_frames - frame_scores.size(0)
    if missing > 0:
        frame_scores = torch.cat((frame_scores, frame_scores[-1].repeat(missing)))
    return frame_scores[:n_frames]


def test(dataloader, model, device, gen_scores=False, save_dir=None):
    """
    Evaluate frame-level anomaly detection performance on ``dataloader``.

    Per-segment scores from ``model`` are expanded to per-frame scores and
    compared with the frame-level ground truth from ``get_gt``. ROC AUC is
    computed per video and over all frames concatenated, and average
    precision over all frames.

    Args:
        dataloader (DataLoader): Test loader, assumed to use batch size 1,
            yielding ``(input1, input2, input3, label, start_end_couples,
            num_frames, file)``.
        model: Model called as ``model(input1, input2, input3)``. Its last
            output holds the per-segment scores.
        device: Device to run the model on.
        gen_scores (bool): If True, save a score plot per video under
            ``save_dir``.
        save_dir (str | None): Directory for the plots. Required when
            ``gen_scores`` is True.

    Returns:
        single_video_AUC (dict): ``{"video": [...], "AUC": [...]}`` with one
            entry per batch. ``"video"`` holds the batch's ``file`` value as
            yielded by the loader. AUC is NaN for videos whose ground truth
            contains a single class.
        overall_auc (float): ROC AUC over all frames of all videos.
        ap (float): Average precision over all frames of all videos.

    Raises:
        ValueError: If ``gen_scores`` is True but ``save_dir`` is None.
    """
    # Fail before running any inference rather than on the first plot.
    if gen_scores and save_dir is None:
        raise ValueError("save_dir must be provided when gen_scores is True")

    single_video_AUC = {"video": [], "AUC": []}
    # Collect per-video tensors and concatenate once after the loop.
    # Concatenating inside the loop re-copies every earlier frame on each
    # iteration, which is quadratic in the total frame count.
    frame_preds, frame_gts = [], []

    with torch.no_grad():
        model.to(device).eval()

        for input1, input2, input3, _label, start_end_couples, num_frames, file in tqdm(dataloader):
            # Only the last model output (segment scores) is used for evaluation.
            *_, scores = model(input1.to(device), input2.to(device), input3.to(device))

            # Drop dims 0 and 2 so one score per segment remains
            # (assumes scores is (1, n_segments, 1); TODO confirm against Model).
            segment_scores = torch.squeeze(scores, dim=(0, 2))
            n_frames = int(num_frames)

            sig = _segment_scores_to_frames(segment_scores, n_frames).cpu()
            cur_gt = get_gt(start_end_couples, num_frames, device).cpu()
            frame_preds.append(sig)
            frame_gts.append(cur_gt)

            sig_np = sig.numpy()
            fpr, tpr, _ = roc_curve(cur_gt.numpy(), sig_np)
            single_video_AUC["video"].append(file)
            single_video_AUC["AUC"].append(auc(fpr, tpr))

            if gen_scores:
                save_path = os.path.join(save_dir, file[0] + '.png')
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                save_scores(sig_np, start_end_couples, save_path)

        # Seeding with an empty tensor keeps torch.cat from failing on an
        # empty loader (the result is then an empty array, as before).
        pred = torch.cat([torch.zeros(0), *frame_preds]).numpy()
        gt = torch.cat([torch.zeros(0), *frame_gts]).numpy()

        ap = average_precision_score(gt, pred)
        fpr, tpr, _ = roc_curve(gt, pred)
        overall_auc = auc(fpr, tpr)
        print('\n' + 'Overall auc : ' + str(overall_auc) + ', Average Precision : ' + str(ap) + '\n')

    return single_video_AUC, overall_auc, ap
|
|
|
|
|
|
|
|
def main():
    """
    Evaluate a trained detection model on the test split and write the results.

    Parses command-line options with ``option.test_parser`` and runs
    :func:`test`, writing score plots under ``<output_dir>/scores``. Per-video
    and overall metrics are written to
    ``<output_dir>/AUC/<sub_dir>/results.txt``, where ``<sub_dir>`` is the
    parent directory name of the first test video's path.
    """
    args = option.test_parser.parse_args()

    # GPU selection must happen before CUDA is initialised to take effect.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # Skip when no GPU id was given; str() allows an integer-typed option.
    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    AUC_path = os.path.join(args.output_dir, 'AUC')
    scores_path = os.path.join(args.output_dir, 'scores')

    test_loader = DataLoader(Dataset(args, test_mode=True),
                             batch_size=1, shuffle=False,
                             num_workers=args.workers, pin_memory=True)
    model = Model(feature_dim=args.feature_size, batch_size=1, seg_num=args.seg_num)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine,
    # consistent with the CPU fallback chosen for ``device`` above.
    model.load_state_dict(torch.load(args.detection_model, map_location=device))

    single_video_AUC, overall_auc, ap = test(dataloader=test_loader,
                                             model=model,
                                             device=device,
                                             gen_scores=True,
                                             save_dir=scores_path)

    # Each "video" entry is the loader's batch-of-1 ``file`` value, hence the
    # double index to reach the first path string.
    video_sub_dir = os.path.basename(os.path.dirname(single_video_AUC["video"][0][0]))
    file_path = os.path.join(AUC_path, video_sub_dir, 'results.txt')
    os.makedirs(os.path.dirname(file_path), exist_ok=True)

    with open(file_path, "w", encoding="utf-8") as f:
        for video, single_auc in zip(single_video_AUC["video"], single_video_AUC["AUC"]):
            f.write(f"Video: {video}, AUC: {single_auc}\n")
        f.write("Overall AUC: {}, Average Precision: {}\n".format(overall_auc, ap))
|
|
|
|
|
|
|
|
# Run the evaluation only when this file is executed as a script, not when it
# is imported (e.g. to reuse get_gt / test from another module).
if __name__ == '__main__':


    main()
|
|
|
|
|
|
|
|
|