Spaces:
Running
Running
Upload 4 files
Browse files- data_loader.py +155 -0
- feature_extractor.py +314 -0
- features_loader.py +201 -0
- generate_ROC.py +113 -0
data_loader.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This module contains a video loader."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
from typing import List, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
from torch.utils import data
|
| 11 |
+
from torchvision.datasets.video_utils import VideoClips
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class VideoIter(data.Dataset):
    """Dataset that serves fixed-length clips from a directory of videos.

    Videos are indexed with torchvision's ``VideoClips``; each item is a clip
    of ``clip_length`` frames sampled every ``frame_stride`` frames.
    """

    def __init__(
        self,
        clip_length: int,
        frame_stride: int,
        dataset_path: str = None,
        video_transform=None,
        return_label: bool = False,
    ) -> None:
        """
        Args:
            clip_length: Number of frames returned per clip.
            frame_stride: Sampling interval between consecutive frames.
            dataset_path: Root directory scanned (recursively) for .mp4 files.
            video_transform: Optional transform applied to every clip.
            return_label: If True, ``__getitem__`` also returns a binary label
                (0 for videos whose path contains "Normal", 1 otherwise).
        """
        super().__init__()
        # video clip properties
        self.frames_stride = frame_stride
        # Read a window large enough to keep `clip_length` frames after striding.
        self.total_clip_length_in_frames = clip_length * frame_stride
        self.video_transform = video_transform

        # IO
        self.dataset_path = dataset_path
        self.video_list = self._get_video_list(dataset_path=self.dataset_path)
        self.return_label = return_label

        # data loading: non-overlapping windows over every video
        self.video_clips = VideoClips(
            video_paths=self.video_list,
            clip_length_in_frames=self.total_clip_length_in_frames,
            frames_between_clips=self.total_clip_length_in_frames,
        )

    @property
    def video_count(self) -> int:
        """Retrieve the number of the videos in the dataset."""
        return len(self.video_list)

    def getitem_from_raw_video(
        self, idx: int
    ) -> Union[Tuple[Tensor, int, str, str], Tuple[Tensor, int, int, str, str]]:
        """Fetch a sample from the dataset.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            Tuple[Tensor, int, str, str]: Video clip, clip idx in the video, directory name, and file
            Tuple[Tensor, int, int, str, str]: Video clip, label, clip idx in the video, directory name, and file
        """
        video, _, _, _ = self.video_clips.get_clip(idx)
        video_idx, clip_idx = self.video_clips.get_clip_location(idx)
        video_path = self.video_clips.video_paths[video_idx]
        # Keep every `frames_stride`-th frame of the raw window.
        in_clip_frames = list(
            range(0, self.total_clip_length_in_frames, self.frames_stride)
        )
        video = video[in_clip_frames]
        if self.video_transform is not None:
            video = self.video_transform(video)

        dir, file = video_path.split(os.sep)[-2:]
        file = file.split(".")[0]  # strip the extension

        if self.return_label:
            label = 0 if "Normal" in video_path else 1
            return video, label, clip_idx, dir, file

        return video, clip_idx, dir, file

    def __len__(self) -> int:
        return len(self.video_clips)

    def __getitem__(self, index: int):
        # Corrupted/unreadable clips are skipped by retrying a random index.
        succ = False
        while not succ:
            try:
                batch = self.getitem_from_raw_video(index)
                succ = True
            except Exception as e:
                index = np.random.choice(range(0, self.__len__()))
                trace_back = sys.exc_info()[2]
                if trace_back is not None:
                    line = str(trace_back.tb_lineno)
                else:
                    line = "no-line"
                # pylint: disable=line-too-long
                logging.warning(
                    f"VideoIter:: ERROR (line number {line}) !! (Force using another index:\n{index})\n{e}"
                )

        return batch

    def _get_video_list(self, dataset_path: str) -> List[str]:
        """Fetches all videos in a directory and sub-directories.

        Args:
            dataset_path (str): A string that represents the directory of the dataset.

        Raises:
            FileNotFoundError: The directory could not be found in the provided path.

        Returns:
            List[str]
        """
        if not os.path.exists(dataset_path):
            raise FileNotFoundError(f"VideoIter:: failed to locate: `{dataset_path}`")

        vid_list = []
        for path, _, files in os.walk(dataset_path):
            for name in files:
                # BUG FIX: match the actual extension instead of the original
                # `"mp4" not in name`, which also accepted names such as
                # `clip.mp4.part` or `my_mp4_notes.txt`.
                if not name.endswith(".mp4"):
                    continue
                vid_list.append(os.path.join(path, name))

        logging.info(f"Found {len(vid_list)} video files in {dataset_path}")
        return vid_list
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class SingleVideoIter(VideoIter):
    """Loader for a single video."""

    def __init__(
        self,
        clip_length,
        frame_stride,
        video_path,
        video_transform=None,
        return_label=False,
    ) -> None:
        # Delegate everything to VideoIter; the only specialization is that
        # the "dataset" consists of exactly one file (see _get_video_list).
        super().__init__(
            clip_length, frame_stride, video_path, video_transform, return_label
        )

    def _get_video_list(self, dataset_path: str) -> List[str]:
        # A single video: the given path itself is the whole list.
        return [dataset_path]

    def __getitem__(self, idx: int) -> Tensor:
        """Return the (optionally transformed) clip at position ``idx``."""
        clip, _, _, _ = self.video_clips.get_clip(idx)
        frame_indices = range(0, self.total_clip_length_in_frames, self.frames_stride)
        clip = clip[list(frame_indices)]
        if self.video_transform is None:
            return clip
        return self.video_transform(clip)
|
feature_extractor.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This module contains a training procedure for video feature extraction."""
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
from os import mkdir, path
|
| 7 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
from torch.backends import cudnn
|
| 13 |
+
from torch.utils.data import DataLoader
|
| 14 |
+
|
| 15 |
+
from data_loader import VideoIter
|
| 16 |
+
from network.TorchUtils import get_torch_device
|
| 17 |
+
from utils.load_model import load_feature_extractor
|
| 18 |
+
from utils.utils import build_transforms, register_logger
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_args() -> argparse.Namespace:
    """Build the command-line parser for feature extraction and return the
    parsed arguments."""
    parser = argparse.ArgumentParser(description="Video Feature Extraction Parser")

    # io
    parser.add_argument("--dataset_path",
                        default="../kinetics2/kinetics2/AnomalyDetection",
                        help="path to dataset")
    parser.add_argument("--clip-length", type=int, default=16,
                        help="define the length of each input sample.")
    parser.add_argument("--num_workers", type=int, default=8,
                        help="define the number of workers used for loading the videos")
    parser.add_argument("--frame-interval", type=int, default=1,
                        help="define the sampling interval between frames.")
    parser.add_argument("--log-every", type=int, default=50,
                        help="log the writing of clips every n steps.")
    parser.add_argument("--log-file", type=str, help="set logging file.")
    parser.add_argument("--save_dir", type=str, default="features",
                        help="set output directory for the features.")

    # optimization
    parser.add_argument("--batch-size", type=int, default=8, help="batch size")

    # model
    parser.add_argument("--model_type", type=str, required=True,
                        help="type of feature extractor",
                        choices=["c3d", "i3d", "mfnet", "3dResNet"])
    parser.add_argument("--pretrained_3d", type=str,
                        help="load default 3D pretrained model.")

    return parser.parse_args()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def to_segments(
    data: Union[Tensor, np.ndarray], n_segments: int = 32
) -> List[np.ndarray]:
    """Temporally pool a sequence of clip features into `n_segments` segments.

    Adapted from:
    # https://github.com/rajanjitenpatel/C3D_feature_extraction/blob/b5894fa06d43aa62b3b64e85b07feb0853e7011a/extract_C3D_feature.py#L805

    Args:
        data (Union[Tensor, np.ndarray]): List of features of a certain video
        n_segments (int, optional): Number of segments

    Returns:
        List[np.ndarray]: List of `num` segments
    """
    data = np.array(data)
    segments_features = []
    # n_segments + 1 evenly-spaced boundaries over the clip axis.
    boundaries = np.round(np.linspace(0, len(data) - 1, num=n_segments + 1)).astype(
        int
    )
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        if start == end:
            # Degenerate segment: take the single clip feature at `start`.
            segment = data[min(start, data.shape[0] - 1), :]
        else:
            segment = data[start:end, :].mean(axis=0)

        # BUG FIX: check the norm BEFORE dividing. The original divided first
        # and checked afterwards, so an all-zero segment became NaN and was
        # still appended (a NaN norm compares unequal to 0).
        norm = np.linalg.norm(segment)
        if norm != 0:
            segments_features.append((segment / norm).tolist())

    return segments_features
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class FeaturesWriter:
    """Accumulates per-clip features of a single video and dumps them to disk
    as one text file per video (features are segmented and L2-normalized by
    ``to_segments`` on dump)."""

    def __init__(self, num_videos: int, chunk_size: int = 16) -> None:
        """
        Args:
            num_videos: Total number of videos (used only for progress logs).
            chunk_size: Not referenced by this class; kept for interface
                compatibility with existing callers.
        """
        self.path = ""  # output file of the video currently being accumulated
        self.dir = ""  # output directory of the current video
        self.data = {}  # clip index -> feature vector (list of floats)
        self.chunk_size = chunk_size
        self.num_videos = num_videos
        self.dump_count = 0  # number of videos dumped so far

    def _init_video(self, video_name: str, dir: str) -> None:
        """Initialize the state of the writer for a new video.

        Args:
            video_name (str): Name of the video to initialize for.
            dir (str): Directory where the video features will be stored.
        """
        self.path = path.join(dir, f"{video_name}.txt")
        self.dir = dir
        self.data = {}

    def has_video(self) -> bool:
        """Checks whether the writer is initialized with a video.

        BUG FIX: the original returned ``self.data is not None``, which is
        always True because ``data`` starts as ``{}`` — that triggered a
        spurious (empty) dump on the very first ``write`` call.

        Returns:
            bool
        """
        return bool(self.path)

    def dump(self, dir: str) -> None:
        """Saves the accumulated features to disk.

        The features will be segmented and normalized (via ``to_segments``).

        Args:
            dir: Output directory, created if missing.
        """
        logging.info(f"{self.dump_count} / {self.num_videos}: Dumping {self.path}")
        self.dump_count += 1
        self.dir = dir
        if not path.exists(self.dir):
            os.makedirs(self.dir, exist_ok=True)
        # Nothing accumulated yet -- skip this dump.
        if len(self.data) == 0:
            logging.warning("No data to dump, skipping.")
            return
        # Clips are written in temporal order (keys are clip indices).
        features = to_segments(np.array([self.data[key] for key in sorted(self.data)]))
        with open(self.path, "w") as fp:
            for d in features:
                d_str = [str(x) for x in d]
                fp.write(" ".join(d_str) + "\n")

    def _is_new_video(self, video_name: str, dir: str) -> bool:
        """Checks whether the given video differs from the one the writer is
        currently accumulating.

        Args:
            video_name (str): Name of the possibly new video.
            dir (str): Directory where the video is stored.

        Returns:
            bool
        """
        new_path = path.join(dir, f"{video_name}.txt")
        if self.path != new_path and self.path is not None:
            return True

        return False

    def store(self, feature: Union[Tensor, np.ndarray], idx: int) -> None:
        """Accumulate features.

        Args:
            feature (Union[Tensor, np.ndarray]): Features to be accumulated.
            idx (int): Index of the clip within the video.
        """
        self.data[idx] = list(feature)

    def write(
        self, feature: Union[Tensor, np.ndarray], video_name: str, idx: int, dir: str
    ) -> None:
        """Store one clip feature, dumping the previous video first if
        ``video_name`` starts a new one."""
        if not self.has_video():
            self._init_video(video_name, dir)
        if self._is_new_video(video_name, dir):
            self.dump(dir)
            self._init_video(video_name, dir)

        self.store(feature, idx)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def read_features(file_path, cache: Optional[Dict[str, Tensor]] = None) -> Tensor:
    """Load a feature matrix from a whitespace-separated text file.

    Each line of the file holds the feature vector of one video segment.

    Args:
        file_path: Path to the text file containing the features.
        cache (Dict, optional): Maps file paths to already-loaded tensors.
            Pass ``None`` to disable caching. Defaults to None.

    Raises:
        FileNotFoundError: The provided path does not exist.

    Returns:
        Tensor
    """
    if cache is not None and file_path in cache:
        return cache[file_path]

    if not path.exists(file_path):
        raise FileNotFoundError(f"Feature doesn't exist: `{file_path}`")

    with open(file_path) as fp:
        lines = fp.read().splitlines(keepends=False)
    rows = [line.split(" ") for line in lines]
    features = torch.tensor(np.stack(rows).astype(np.float32))

    if cache is not None:
        cache[file_path] = features
    return features
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def get_features_loader(
    dataset_path: str,
    clip_length: int,
    frame_interval: int,
    batch_size: int,
    num_workers: int,
    mode: str,
) -> Tuple[VideoIter, DataLoader]:
    """Create the clip dataset for `dataset_path` and a DataLoader over it.

    `mode` selects the transform pipeline appropriate for the chosen
    feature-extractor model (see ``build_transforms``).
    """
    dataset = VideoIter(
        dataset_path=dataset_path,
        clip_length=clip_length,
        frame_stride=frame_interval,
        video_transform=build_transforms(mode),
        return_label=False,
    )

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # keep clips in video order so the writer can group them
        num_workers=num_workers,
        pin_memory=True,
    )

    return dataset, loader
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
if __name__ == "__main__":
    # Feature-extraction entry point: run every clip of every video through
    # the chosen 3D backbone and write per-video feature files.
    device = get_torch_device()

    args = get_args()
    register_logger(log_file=args.log_file)

    # enable cudnn autotuner (input sizes are fixed per batch)
    cudnn.benchmark = True

    data_loader, data_iter = get_features_loader(
        args.dataset_path,
        args.clip_length,
        args.frame_interval,
        args.batch_size,
        args.num_workers,
        args.model_type,
    )

    # Backbone in eval mode: used purely as a frozen feature extractor.
    network = load_feature_extractor(args.model_type, args.pretrained_3d, device).eval()

    if not path.exists(args.save_dir):
        mkdir(args.save_dir)

    features_writer = FeaturesWriter(num_videos=data_loader.video_count)
    # loop_i counts steps modulo log_every so progress is logged periodically.
    loop_i = 0
    # Tracks the output directory of the most recent clip, so the final
    # (unflushed) video can be dumped after the loop.
    global_dir: str = "none"
    with torch.no_grad():
        for data, clip_idxs, dirs, vid_names in data_iter:
            outputs = network(data.to(device)).detach().cpu().numpy()

            for i, (_dir, vid_name, clip_idx) in enumerate(
                zip(dirs, vid_names, clip_idxs)
            ):
                if loop_i == 0:
                    # pylint: disable=line-too-long
                    logging.info(
                        f"Video {features_writer.dump_count} / {features_writer.num_videos} : Writing clip {clip_idx} of video {vid_name}"
                    )
                loop_i += 1
                loop_i %= args.log_every
                # Mirror the dataset's directory structure under save_dir.
                _dir = path.join(args.save_dir, _dir)
                global_dir = _dir
                features_writer.write(
                    feature=outputs[i],
                    video_name=vid_name,
                    idx=clip_idx,
                    dir=_dir,
                )

    # Flush the last accumulated video (the writer only dumps on video change).
    features_writer.dump(global_dir)
|
features_loader.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This module contains a video feature loader."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
from typing import List, Tuple
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
from torch.utils import data
|
| 11 |
+
|
| 12 |
+
from feature_extractor import read_features
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class FeaturesLoader:
    """Loads video features that are stored as text files.

    Serves balanced "buckets" of `bucket_size` normal and `bucket_size`
    anomalous feature files, drawn at random on every access.
    """

    def __init__(
        self,
        features_path: str,
        annotation_path: str,
        bucket_size: int = 30,
        iterations: int = 20000,
    ) -> None:
        """
        Args:
            features_path: Path to the directory that contains the features in text files
            annotation_path: Path to the annotation file
            bucket_size: Size of each bucket
            iterations: How many iterations the loader should perform
        """

        super().__init__()
        self._features_path = features_path
        self._bucket_size = bucket_size

        # load video list
        (
            self.features_list_normal,
            self.features_list_anomaly,
        ) = FeaturesLoader._get_features_list(
            features_path=self._features_path, annotation_path=annotation_path
        )

        self._iterations = iterations
        # cache shared across all reads so each feature file is parsed once
        self._features_cache = {}
        # internal step counter; the loader "exhausts" after `iterations` items
        self._i = 0

    def __len__(self) -> int:
        # The nominal epoch length, not the number of distinct feature files.
        return self._iterations

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
        # NOTE(review): `index` is never used for sampling — get_features()
        # draws a random bucket each call. Presumably intentional for
        # stochastic training; confirm before relying on index semantics.
        if self._i == len(self):
            self._i = 0
            raise StopIteration

        succ = False
        while not succ:
            try:
                feature, label = self.get_features()
                succ = True
            except Exception as e:
                # Re-draw on failure (e.g. an unreadable feature file) and log.
                index = np.random.choice(range(0, self.__len__()))
                logging.warning(
                    f"VideoIter:: ERROR!! (Force using another index:\n{index})\n{e}"
                )

        self._i += 1
        return feature, label

    def get_features(self) -> Tuple[Tensor, Tensor]:
        """Fetches a bucket sample from the dataset.

        Returns a stack of 2*bucket_size feature tensors (normal first, then
        anomalous) and the matching 0/1 label vector.
        """
        normal_paths = np.random.choice(
            self.features_list_normal, size=self._bucket_size
        )
        abnormal_paths = np.random.choice(
            self.features_list_anomaly, size=self._bucket_size
        )
        all_paths = np.concatenate([normal_paths, abnormal_paths])
        features = torch.stack(
            [
                read_features(f"{feature_subpath}.txt", self._features_cache)
                for feature_subpath in all_paths
            ]
        )
        return (
            features,
            torch.cat([torch.zeros(self._bucket_size), torch.ones(self._bucket_size)]),
        )

    @staticmethod
    def _get_features_list(
        features_path: str, annotation_path: str
    ) -> Tuple[List[str], List[str]]:
        """Retrieves the paths of all feature files contained within the
        annotation file.

        Args:
            features_path: Path to the directory that contains feature text files
            annotation_path: Path to the annotation file

        Returns:
            Tuple[List[str], List[str]]: Two list that contain the corresponding paths of normal and abnormal
            feature files.
        """
        assert os.path.exists(features_path)
        features_list_normal = []
        features_list_anomaly = []
        with open(annotation_path) as f:
            lines = f.read().splitlines(keepends=False)
            for line in lines:
                items = line.split()
                # first column is the video file name; drop the extension
                file = items[0].split(".")[0]
                file = file.replace("/", os.sep)
                feature_path = os.path.join(features_path, file)
                # Videos are classified by path: "Normal" marks normal videos.
                if "Normal" in feature_path:
                    features_list_normal.append(feature_path)
                else:
                    features_list_anomaly.append(feature_path)

        return features_list_normal, features_list_anomaly

    @property
    def get_feature_dim(self) -> int:
        # NOTE(review): this loads a full bucket just to read the feature
        # dimension — cheap only thanks to the cache; confirm acceptable cost.
        return self[0][0].shape[-1]
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class FeaturesLoaderVal(data.Dataset):
    """Loader for video features for validation phase.

    Each item is the full 32-segment feature matrix of one video, together
    with its temporal anomaly annotations and frame count.
    """

    def __init__(self, features_path, annotation_path):
        """
        Args:
            features_path: Directory containing the per-video feature files.
            annotation_path: Path to the temporal annotation file.
        """
        super().__init__()
        self.features_path = features_path
        # load video list
        self.state = "Normal"
        self.features_list = FeaturesLoaderVal._get_features_list(
            features_path=features_path, annotation_path=annotation_path
        )

    def __len__(self):
        return len(self.features_list)

    def __getitem__(self, index: int):
        """Return ``(features, start_end_couples, length)`` for ``index``.

        BUG FIX: the original retried the SAME index in an infinite loop when
        ``get_feature`` raised (the index was never changed), hanging the
        evaluation. We now log the failure and re-raise instead.
        """
        try:
            return self.get_feature(index)
        except Exception as e:
            logging.warning(
                f"VideoIter:: ERROR!! (Force using another index:\n{index})\n{e}"
            )
            raise

    def get_feature(self, index: int):
        """Fetch feature that matches given index in the dataset.

        Args:
            index (int): Index of the feature to fetch.

        Returns:
            Tuple of (features Tensor, start/end annotation Tensor, video
            length in frames).
        """
        feature_subpath, start_end_couples, length = self.features_list[index]
        features = read_features(f"{feature_subpath}.txt")
        return features, start_end_couples, length

    @staticmethod
    def _get_features_list(
        features_path: str, annotation_path: str
    ) -> List[Tuple[str, Tensor, int]]:
        """Retrieves the paths of all feature files contained within the
        annotation file.

        Annotation line format (whitespace separated):
        ``<video file> <num frames> <class> <s1> <e1> <s2> <e2>``
        where (s, e) pairs are anomaly intervals, -1 meaning "absent".

        Args:
            features_path: Path to the directory that contains feature text files
            annotation_path: Path to the annotation file

        Returns:
            List[Tuple[str, Tensor, int]]: A list of tuples that describe each video and the temporal annotations
            of anomalies in the videos
        """
        assert os.path.exists(features_path)
        features_list = []
        with open(annotation_path) as f:
            lines = f.read().splitlines(keepends=False)
            for line in lines:
                items = line.split()
                # columns 3+ hold up to two (start, end) anomaly intervals
                anomalies_frames = [int(x) for x in items[3:]]
                start_end_couples = torch.tensor(
                    [anomalies_frames[:2], anomalies_frames[2:]]
                )
                file = items[0].split(".")[0]
                file = file.replace("/", os.sep)
                feature_path = os.path.join(features_path, file)
                length = int(items[1])

                features_list.append((feature_path, start_end_couples, length))

        return features_list
|
generate_ROC.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This module contains an evaluation procedure for video anomaly
|
| 2 |
+
detection."""
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import os
|
| 6 |
+
from os import path
|
| 7 |
+
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from sklearn.metrics import auc, roc_curve
|
| 12 |
+
from torch.backends import cudnn
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
from features_loader import FeaturesLoaderVal
|
| 16 |
+
from network.TorchUtils import TorchModel
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_args() -> argparse.Namespace:
    """Build the command-line parser for ROC evaluation and return the parsed
    arguments."""
    parser = argparse.ArgumentParser(
        description="Video Anomaly Detection Evaluation Parser"
    )
    parser.add_argument("--features_path", type=str,
                        default="../anomaly_features", required=True,
                        help="path to features")
    parser.add_argument("--annotation_path", default="Test_Annotation.txt",
                        help="path to annotations")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to the anomaly detector.")
    return parser.parse_args()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
if __name__ == "__main__":
    # Evaluation entry point: score every validation video, build frame-level
    # labels/predictions, and plot the ROC curve.
    args = get_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    data_loader = FeaturesLoaderVal(
        features_path=args.features_path,
        annotation_path=args.annotation_path,
    )

    data_iter = torch.utils.data.DataLoader(
        data_loader,
        batch_size=1,
        shuffle=False,
        num_workers=0,  # 4, # change this part accordingly
        pin_memory=True,
    )

    model = TorchModel.load_model(args.model_path).to(device).eval()

    # enable cudnn tune
    cudnn.benchmark = True

    # pylint: disable=not-callable
    # Frame-level ground truth and scores accumulated over all videos.
    y_trues = np.array([])
    y_preds = np.array([])

    with torch.no_grad():
        for features, start_end_couples, lengths in tqdm(data_iter):
            # features is a batch where each item is a tensor of 32 4096D features
            features = features.to(device)
            outputs = model(features).squeeze(-1)  # (batch_size, 32)
            for vid_len, couples, output in zip(
                lengths, start_end_couples, outputs.cpu().numpy()
            ):
                y_true = np.zeros(vid_len)
                y_pred = np.zeros(vid_len)

                # NOTE(review): integer division truncates, so the last
                # vid_len % 32 frames keep a prediction of 0 — confirm this
                # matches the intended evaluation protocol.
                segments_len = vid_len // 32
                # Mark annotated anomaly intervals; (-1, -1) means "absent".
                for couple in couples:
                    if couple[0] != -1:
                        y_true[couple[0] : couple[1]] = 1

                # Broadcast each of the 32 segment scores to its frame span.
                for i in range(32):
                    segment_start_frame = i * segments_len
                    segment_end_frame = (i + 1) * segments_len
                    y_pred[segment_start_frame:segment_end_frame] = output[i]

                # NOTE(review): y_trues is initialized to an empty array above,
                # never None, so this branch is dead — only the concatenate
                # path ever runs (which is harmless with an empty array).
                if y_trues is None:
                    y_trues = y_true
                    y_preds = y_pred
                else:
                    y_trues = np.concatenate([y_trues, y_true])
                    y_preds = np.concatenate([y_preds, y_pred])

    fpr, tpr, thresholds = roc_curve(y_true=y_trues, y_score=y_preds, pos_label=1)

    plt.figure()
    lw = 2
    roc_auc = auc(fpr, tpr)
    plt.plot(
        fpr, tpr, color="darkorange", lw=lw, label=f"ROC curve (area = {roc_auc:0.2f})"
    )
    # Diagonal reference line: performance of a random classifier.
    plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend(loc="lower right")

    os.makedirs("graphs", exist_ok=True)
    plt.savefig(path.join("graphs", "roc_auc.png"))
    plt.close()
    print(f"ROC curve (area = {roc_auc:0.2f})")
|