anomaly_dector / generate_ROC.py
foreversheikh's picture
Upload 10 files
152e3c7 verified
""" "This module contains an evaluation procedure for video anomaly
detection."""
import argparse
import os
from os import path
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import auc, roc_curve
from torch.backends import cudnn
from tqdm import tqdm
from features_loader import FeaturesLoaderVal
from network.TorchUtils import TorchModel
def get_args() -> argparse.Namespace:
"""Reads command line args and returns the parser object the represent the
specified arguments."""
parser = argparse.ArgumentParser(
description="Video Anomaly Detection Evaluation Parser"
)
parser.add_argument(
"--features_path",
type=str,
default="../anomaly_features",
required=True,
help="path to features",
)
parser.add_argument(
"--annotation_path", default="Test_Annotation.txt", help="path to annotations"
)
parser.add_argument(
"--model_path", type=str, required=True, help="Path to the anomaly detector."
)
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_loader = FeaturesLoaderVal(
features_path=args.features_path,
annotation_path=args.annotation_path,
)
data_iter = torch.utils.data.DataLoader(
data_loader,
batch_size=1,
shuffle=False,
num_workers=0, # 4, # change this part accordingly
pin_memory=True,
)
model = TorchModel.load_model(args.model_path).to(device).eval()
# enable cudnn tune
cudnn.benchmark = True
# pylint: disable=not-callable
y_trues = np.array([])
y_preds = np.array([])
with torch.no_grad():
for features, start_end_couples, lengths in tqdm(data_iter):
# features is a batch where each item is a tensor of 32 4096D features
features = features.to(device)
outputs = model(features).squeeze(-1) # (batch_size, 32)
for vid_len, couples, output in zip(
lengths, start_end_couples, outputs.cpu().numpy()
):
y_true = np.zeros(vid_len)
y_pred = np.zeros(vid_len)
segments_len = vid_len // 32
for couple in couples:
if couple[0] != -1:
y_true[couple[0] : couple[1]] = 1
for i in range(32):
segment_start_frame = i * segments_len
segment_end_frame = (i + 1) * segments_len
y_pred[segment_start_frame:segment_end_frame] = output[i]
if y_trues is None:
y_trues = y_true
y_preds = y_pred
else:
y_trues = np.concatenate([y_trues, y_true])
y_preds = np.concatenate([y_preds, y_pred])
fpr, tpr, thresholds = roc_curve(y_true=y_trues, y_score=y_preds, pos_label=1)
plt.figure()
lw = 2
roc_auc = auc(fpr, tpr)
plt.plot(
fpr, tpr, color="darkorange", lw=lw, label=f"ROC curve (area = {roc_auc:0.2f})"
)
plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.legend(loc="lower right")
os.makedirs("graphs", exist_ok=True)
plt.savefig(path.join("graphs", "roc_auc.png"))
plt.close()
print(f"ROC curve (area = {roc_auc:0.2f})")