Spaces:
Running
Running
File size: 4,262 Bytes
17ee76b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
"""This module contains functions for loading models."""
import logging
from os import path
from typing import Tuple
import torch
from network.anomaly_detector_model import AnomalyDetector
from network.c3d import C3D
from network.MFNET import MFNET_3D
from network.resnet import generate_model
from network.TorchUtils import TorchModel
from utils.types import Device, FeatureExtractor
def _load_resnet_3d(model_depth: int, checkpoint_path: str) -> FeatureExtractor:
    """Build a 3D ResNet of ``model_depth`` and load checkpoint weights minus the fc head."""
    resnet = generate_model(model_depth=model_depth)
    # Map tensors to CPU so checkpoints saved on a GPU also load on CPU-only
    # hosts; the caller moves the model to the requested device afterwards.
    param_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
    # Drop the classification head weights (if the checkpoint has them)
    # before loading the remaining parameters into the model.
    param_dict.pop("fc.weight", None)
    param_dict.pop("fc.bias", None)
    resnet.load_state_dict(param_dict)
    return resnet


def load_feature_extractor(
    features_method: str, feature_extractor_path: str, device: Device
) -> FeatureExtractor:
    """Load feature extractor from given path.

    Args:
        features_method (str): The feature extractor model type to use.
            Either c3d | mfnet | r3d101 | r3d152.
        feature_extractor_path (str): Path to the feature extractor model.
        device (Device): Device to move the loaded model to.

    Raises:
        FileNotFoundError: The path to the model does not exist.
        NotImplementedError: The provided feature extractor method is not implemented.

    Returns:
        FeatureExtractor: The loaded model, moved to ``device`` and set to eval mode.
    """
    if not path.exists(feature_extractor_path):
        # Plain (non-raw) literals so every "\n" is a real line break.
        raise FileNotFoundError(
            f"Couldn't find feature extractor {feature_extractor_path}.\n"
            "If you are using resnet, download it first from:\n"
            "r3d101: https://drive.google.com/file/d/1p80RJsghFIKBSLKgtRG94LE38OGY5h4y/view?usp=share_link"
            "\n"
            "r3d152: https://drive.google.com/file/d/1irIdC_v7wa-sBpTiBlsMlS7BYNdj4Gr7/view?usp=share_link"
        )
    logging.info("Loading feature extractor from %s", feature_extractor_path)

    # ResNet variants differ only in depth, so they share one loading path.
    resnet_depths = {"r3d101": 101, "r3d152": 152}

    model: FeatureExtractor
    if features_method == "c3d":
        model = C3D(pretrained=feature_extractor_path)
    elif features_method == "mfnet":
        model = MFNET_3D()
        model.load_state(state_dict=feature_extractor_path)
    elif features_method in resnet_depths:
        model = _load_resnet_3d(resnet_depths[features_method], feature_extractor_path)
    else:
        raise NotImplementedError(
            f"Features extraction method {features_method} not implemented"
        )
    return model.to(device).eval()
def load_anomaly_detector(ad_model_path: str, device: Device) -> AnomalyDetector:
    """Load the anomaly detection model stored at ``ad_model_path``.

    Args:
        ad_model_path (str): Location of the saved anomaly detection model.
        device (Device): Device the loaded model is moved to.

    Raises:
        FileNotFoundError: Nothing exists at ``ad_model_path``.

    Returns:
        AnomalyDetector: The loaded model, moved to ``device`` and switched to eval mode.
    """
    model_exists = path.exists(ad_model_path)
    if not model_exists:
        raise FileNotFoundError(f"Couldn't find anomaly detector {ad_model_path}.")

    logging.info(f"Loading anomaly detector from {ad_model_path}")
    loaded_model = TorchModel.load_model(ad_model_path)
    on_device = loaded_model.to(device)
    return on_device.eval()
def load_models(
    feature_extractor_path: str,
    ad_model_path: str,
    features_method: str = "c3d",
    device: Device = "cuda",
) -> Tuple[AnomalyDetector, FeatureExtractor]:
    """Loads both feature extractor and anomaly detector from the given paths.

    Args:
        feature_extractor_path (str): Path of the features extractor weights to load.
        ad_model_path (str): Path of the anomaly detector weights to load.
        features_method (str, optional): Name of the model to use for features
            extraction (c3d | mfnet | r3d101 | r3d152). Defaults to "c3d".
        device (Device, optional): Device to use for both models. Defaults to "cuda".

    Raises:
        FileNotFoundError: Either model path does not exist.
        NotImplementedError: ``features_method`` is not a supported extractor.

    Returns:
        Tuple[AnomalyDetector, FeatureExtractor]: The anomaly detector first and
        the feature extractor second (the reverse of the argument order).
    """
    # The feature extractor is loaded first, so problems with its path or
    # method are reported before the anomaly detector path is checked.
    feature_extractor = load_feature_extractor(
        features_method, feature_extractor_path, device
    )
    anomaly_detector = load_anomaly_detector(ad_model_path, device)
    return anomaly_detector, feature_extractor
|