foreversheikh commited on
Commit
4868e52
·
verified ·
1 Parent(s): 5229a53

Upload load_model.py

Browse files
Files changed (1) hide show
  1. load_model.py +114 -0
load_model.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This module contains functions for loading models."""
2
+
3
+ import logging
4
+ from os import path
5
+ from typing import Tuple
6
+
7
+ import torch
8
+
9
+ from network.anomaly_detector_model import AnomalyDetector
10
+ from network.c3d import C3D
11
+ from network.MFNET import MFNET_3D
12
+ from network.resnet import generate_model
13
+ from network.TorchUtils import TorchModel
14
+ from utils.types import Device, FeatureExtractor
15
+
16
+
17
+ def load_feature_extractor(
18
+ features_method: str, feature_extractor_path: str, device: Device
19
+ ) -> FeatureExtractor:
20
+ """Load feature extractor from given path.
21
+
22
+ Args:
23
+ features_method (str): The feature extractor model type to use. Either c3d | mfnet | r3d101 | r3d152.
24
+ feature_extractor_path (str): Path to the feature extractor model.
25
+ device (Union[torch.device, str]): Device to use for the model.
26
+
27
+ Raises:
28
+ FileNotFoundError: The path to the model does not exist.
29
+ NotImplementedError: The provided feature extractor method is not implemented.
30
+
31
+ Returns:
32
+ FeatureExtractor
33
+ """
34
+ if not path.exists(feature_extractor_path):
35
+ raise FileNotFoundError(
36
+ f"Couldn't find feature extractor {feature_extractor_path}.\n"
37
+ + r"If you are using resnet, download it first from:\n"
38
+ + r"r3d101: https://drive.google.com/file/d/1p80RJsghFIKBSLKgtRG94LE38OGY5h4y/view?usp=share_link"
39
+ + "\n"
40
+ + r"r3d152: https://drive.google.com/file/d/1irIdC_v7wa-sBpTiBlsMlS7BYNdj4Gr7/view?usp=share_link"
41
+ )
42
+ logging.info(f"Loading feature extractor from {feature_extractor_path}")
43
+
44
+ model: FeatureExtractor
45
+
46
+ if features_method == "c3d":
47
+ model = C3D(pretrained=feature_extractor_path)
48
+ elif features_method == "mfnet":
49
+ model = MFNET_3D()
50
+ model.load_state(state_dict=feature_extractor_path)
51
+ elif features_method == "r3d101":
52
+ model = generate_model(model_depth=101)
53
+ param_dict = torch.load(feature_extractor_path)["state_dict"]
54
+ param_dict.pop("fc.weight")
55
+ param_dict.pop("fc.bias")
56
+ model.load_state_dict(param_dict)
57
+ elif features_method == "r3d152":
58
+ model = generate_model(model_depth=152)
59
+ param_dict = torch.load(feature_extractor_path)["state_dict"]
60
+ param_dict.pop("fc.weight")
61
+ param_dict.pop("fc.bias")
62
+ model.load_state_dict(param_dict)
63
+ else:
64
+ raise NotImplementedError(
65
+ f"Features extraction method {features_method} not implemented"
66
+ )
67
+
68
+ return model.to(device).eval()
69
+
70
+
71
+ def load_anomaly_detector(ad_model_path: str, device: Device) -> AnomalyDetector:
72
+ """Load anomaly detection model from given path.
73
+
74
+ Args:
75
+ ad_model_path (str): Path to the anomaly detection model.
76
+ device (Device): Device to use for the model.
77
+
78
+ Raises:
79
+ FileNotFoundError: The path to the model does not exist.
80
+
81
+ Returns:
82
+ AnomalyDetector
83
+ """
84
+ if not path.exists(ad_model_path):
85
+ raise FileNotFoundError(f"Couldn't find anomaly detector {ad_model_path}.")
86
+ logging.info(f"Loading anomaly detector from {ad_model_path}")
87
+
88
+ anomaly_detector = TorchModel.load_model(ad_model_path).to(device)
89
+ return anomaly_detector.eval()
90
+
91
+
92
+ def load_models(
93
+ feature_extractor_path: str,
94
+ ad_model_path: str,
95
+ features_method: str = "c3d",
96
+ device: Device = "cuda",
97
+ ) -> Tuple[AnomalyDetector, FeatureExtractor]:
98
+ """Loads both feature extractor and anomaly detector from the given paths.
99
+
100
+ Args:
101
+ feature_extractor_path (str): Path of the features extractor weights to load.
102
+ ad_model_path (str): Path of the anomaly detector weights to load.
103
+ features_method (str, optional): Name of the model to use for features extraction.
104
+ Defaults to "c3d".
105
+ device (str, optional): Device to use for the models. Defaults to "cuda".
106
+
107
+ Returns:
108
+ Tuple[nn.Module, nn.Module]
109
+ """
110
+ feature_extractor = load_feature_extractor(
111
+ features_method, feature_extractor_path, device
112
+ )
113
+ anomaly_detector = load_anomaly_detector(ad_model_path, device)
114
+ return anomaly_detector, feature_extractor