"""This module contains functions for loading models.""" import logging from os import path from typing import Tuple import torch from network.anomaly_detector_model import AnomalyDetector from network.c3d import C3D from network.MFNET import MFNET_3D from network.resnet import generate_model from network.TorchUtils import TorchModel from utils.types import Device, FeatureExtractor def load_feature_extractor( features_method: str, feature_extractor_path: str, device: Device ) -> FeatureExtractor: """Load feature extractor from given path. Args: features_method (str): The feature extractor model type to use. Either c3d | mfnet | r3d101 | r3d152. feature_extractor_path (str): Path to the feature extractor model. device (Union[torch.device, str]): Device to use for the model. Raises: FileNotFoundError: The path to the model does not exist. NotImplementedError: The provided feature extractor method is not implemented. Returns: FeatureExtractor """ if not path.exists(feature_extractor_path): raise FileNotFoundError( f"Couldn't find feature extractor {feature_extractor_path}.\n" + r"If you are using resnet, download it first from:\n" + r"r3d101: https://drive.google.com/file/d/1p80RJsghFIKBSLKgtRG94LE38OGY5h4y/view?usp=share_link" + "\n" + r"r3d152: https://drive.google.com/file/d/1irIdC_v7wa-sBpTiBlsMlS7BYNdj4Gr7/view?usp=share_link" ) logging.info(f"Loading feature extractor from {feature_extractor_path}") model: FeatureExtractor if features_method == "c3d": model = C3D(pretrained=feature_extractor_path) elif features_method == "mfnet": model = MFNET_3D() model.load_state(state_dict=feature_extractor_path) elif features_method == "r3d101": model = generate_model(model_depth=101) param_dict = torch.load(feature_extractor_path)["state_dict"] param_dict.pop("fc.weight") param_dict.pop("fc.bias") model.load_state_dict(param_dict) elif features_method == "r3d152": model = generate_model(model_depth=152) param_dict = torch.load(feature_extractor_path)["state_dict"] param_dict.pop("fc.weight") param_dict.pop("fc.bias") model.load_state_dict(param_dict) else: raise NotImplementedError( f"Features extraction method {features_method} not implemented" ) return model.to(device).eval() def load_anomaly_detector(ad_model_path: str, device: Device) -> AnomalyDetector: """Load anomaly detection model from given path. Args: ad_model_path (str): Path to the anomaly detection model. device (Device): Device to use for the model. Raises: FileNotFoundError: The path to the model does not exist. Returns: AnomalyDetector """ if not path.exists(ad_model_path): raise FileNotFoundError(f"Couldn't find anomaly detector {ad_model_path}.") logging.info(f"Loading anomaly detector from {ad_model_path}") anomaly_detector = TorchModel.load_model(ad_model_path).to(device) return anomaly_detector.eval() def load_models( feature_extractor_path: str, ad_model_path: str, features_method: str = "c3d", device: Device = "cuda", ) -> Tuple[AnomalyDetector, FeatureExtractor]: """Loads both feature extractor and anomaly detector from the given paths. Args: feature_extractor_path (str): Path of the features extractor weights to load. ad_model_path (str): Path of the anomaly detector weights to load. features_method (str, optional): Name of the model to use for features extraction. Defaults to "c3d". device (str, optional): Device to use for the models. Defaults to "cuda". Returns: Tuple[nn.Module, nn.Module] """ feature_extractor = load_feature_extractor( features_method, feature_extractor_path, device ) anomaly_detector = load_anomaly_detector(ad_model_path, device) return anomaly_detector, feature_extractor