File size: 4,262 Bytes
17ee76b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""This module contains functions for loading models."""

import logging
from os import path
from typing import Tuple

import torch

from network.anomaly_detector_model import AnomalyDetector
from network.c3d import C3D
from network.MFNET import MFNET_3D
from network.resnet import generate_model
from network.TorchUtils import TorchModel
from utils.types import Device, FeatureExtractor


def load_feature_extractor(

    features_method: str, feature_extractor_path: str, device: Device

) -> FeatureExtractor:
    """Load feature extractor from given path.



    Args:

        features_method (str): The feature extractor model type to use. Either c3d | mfnet | r3d101 | r3d152.

        feature_extractor_path (str): Path to the feature extractor model.

        device (Union[torch.device, str]): Device to use for the model.



    Raises:

        FileNotFoundError: The path to the model does not exist.

        NotImplementedError: The provided feature extractor method is not implemented.



    Returns:

        FeatureExtractor

    """
    if not path.exists(feature_extractor_path):
        raise FileNotFoundError(
            f"Couldn't find feature extractor {feature_extractor_path}.\n"
            + r"If you are using resnet, download it first from:\n"
            + r"r3d101: https://drive.google.com/file/d/1p80RJsghFIKBSLKgtRG94LE38OGY5h4y/view?usp=share_link"
            + "\n"
            + r"r3d152: https://drive.google.com/file/d/1irIdC_v7wa-sBpTiBlsMlS7BYNdj4Gr7/view?usp=share_link"
        )
    logging.info(f"Loading feature extractor from {feature_extractor_path}")

    model: FeatureExtractor

    if features_method == "c3d":
        model = C3D(pretrained=feature_extractor_path)
    elif features_method == "mfnet":
        model = MFNET_3D()
        model.load_state(state_dict=feature_extractor_path)
    elif features_method == "r3d101":
        model = generate_model(model_depth=101)
        param_dict = torch.load(feature_extractor_path)["state_dict"]
        param_dict.pop("fc.weight")
        param_dict.pop("fc.bias")
        model.load_state_dict(param_dict)
    elif features_method == "r3d152":
        model = generate_model(model_depth=152)
        param_dict = torch.load(feature_extractor_path)["state_dict"]
        param_dict.pop("fc.weight")
        param_dict.pop("fc.bias")
        model.load_state_dict(param_dict)
    else:
        raise NotImplementedError(
            f"Features extraction method {features_method} not implemented"
        )

    return model.to(device).eval()


def load_anomaly_detector(ad_model_path: str, device: Device) -> AnomalyDetector:
    """Load anomaly detection model from given path.



    Args:

        ad_model_path (str): Path to the anomaly detection model.

        device (Device): Device to use for the model.



    Raises:

        FileNotFoundError: The path to the model does not exist.



    Returns:

        AnomalyDetector

    """
    if not path.exists(ad_model_path):
        raise FileNotFoundError(f"Couldn't find anomaly detector {ad_model_path}.")
    logging.info(f"Loading anomaly detector from {ad_model_path}")

    anomaly_detector = TorchModel.load_model(ad_model_path).to(device)
    return anomaly_detector.eval()


def load_models(

    feature_extractor_path: str,

    ad_model_path: str,

    features_method: str = "c3d",

    device: Device = "cuda",

) -> Tuple[AnomalyDetector, FeatureExtractor]:
    """Loads both feature extractor and anomaly detector from the given paths.



    Args:

        feature_extractor_path (str): Path of the features extractor weights to load.

        ad_model_path (str): Path of the anomaly detector weights to load.

        features_method (str, optional): Name of the model to use for features extraction.

            Defaults to "c3d".

        device (str, optional): Device to use for the models. Defaults to "cuda".



    Returns:

        Tuple[nn.Module, nn.Module]

    """
    feature_extractor = load_feature_extractor(
        features_method, feature_extractor_path, device
    )
    anomaly_detector = load_anomaly_detector(ad_model_path, device)
    return anomaly_detector, feature_extractor