Spaces:
Build error
Build error
| """Callbacks for Anomalib models.""" | |
| # Copyright (C) 2020 Intel Corporation | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, | |
| # software distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions | |
| # and limitations under the License. | |
| import os | |
| from importlib import import_module | |
| from typing import List, Union | |
| import yaml | |
| from omegaconf import DictConfig, ListConfig, OmegaConf | |
| from pytorch_lightning.callbacks import Callback, ModelCheckpoint | |
| from .cdf_normalization import CdfNormalizationCallback | |
| from .min_max_normalization import MinMaxNormalizationCallback | |
| from .model_loader import LoadModelCallback | |
| from .timer import TimerCallback | |
| from .visualizer_callback import VisualizerCallback | |
# Public API of this package: every callback class re-exported above plus the
# ``get_callbacks`` factory defined below.
__all__ = [
    "CdfNormalizationCallback",
    "LoadModelCallback",
    "MinMaxNormalizationCallback",
    "TimerCallback",
    "VisualizerCallback",
    "get_callbacks",
]
def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]:
    """Return base callbacks for all the lightning models.

    The returned list always starts with a ``ModelCheckpoint`` (writing a
    checkpoint named ``model`` under ``<project.path>/weights``) and a
    ``TimerCallback``. Depending on ``config`` it may additionally contain a
    ``LoadModelCallback``, a score-normalization callback (CDF or min-max), a
    ``VisualizerCallback``, an NNCF callback and an ``OpenVINOCallback``.

    Args:
        config (Union[ListConfig, DictConfig]): Model config.

    Raises:
        NotImplementedError: If CDF normalization is requested for a model other
            than ``padim``/``stfpm``, or while NNCF optimization is enabled.
        ValueError: If ``model.normalization_method`` is present but is not one
            of ``"none"``, ``"cdf"`` or ``"min_max"``.

    Returns:
        List[Callback]: List of callbacks.
    """
    callbacks: List[Callback] = []

    # Checkpoint on the early-stopping metric when one is configured;
    # otherwise no metric is monitored.
    if "early_stopping" in config.model.keys():
        monitor_metric = config.model.early_stopping.metric
        monitor_mode = config.model.early_stopping.mode
    else:
        monitor_metric = None
        monitor_mode = "max"

    checkpoint = ModelCheckpoint(
        dirpath=os.path.join(config.project.path, "weights"),
        filename="model",
        monitor=monitor_metric,
        mode=monitor_mode,
        auto_insert_metric_name=False,
    )
    callbacks.extend([checkpoint, TimerCallback()])

    if "weight_file" in config.model.keys():
        load_model = LoadModelCallback(os.path.join(config.project.path, config.model.weight_file))
        callbacks.append(load_model)

    # The ``optimization`` section is optional. Every access to it is guarded, so
    # configs without it (e.g. CDF normalization with no optimization section)
    # no longer fail with a missing-key error.
    has_optimization = "optimization" in config.keys()
    nncf_enabled = has_optimization and "nncf" in config.optimization and config.optimization.nncf.apply

    # A missing ``normalization_method`` behaves exactly like "none": no
    # normalization callback is added and the visualizer is told its inputs are
    # not normalized.
    normalization_method = (
        config.model.normalization_method if "normalization_method" in config.model.keys() else "none"
    )
    if normalization_method == "cdf":
        if config.model.name not in ["padim", "stfpm"]:
            raise NotImplementedError("Score Normalization is currently supported for PADIM and STFPM only.")
        if nncf_enabled:
            raise NotImplementedError("CDF Score Normalization is currently not compatible with NNCF.")
        callbacks.append(CdfNormalizationCallback())
    elif normalization_method == "min_max":
        callbacks.append(MinMaxNormalizationCallback())
    elif normalization_method != "none":
        raise ValueError(f"Normalization method not recognized: {normalization_method}")

    if config.project.log_images_to != []:
        callbacks.append(
            VisualizerCallback(
                task=config.dataset.task,
                inputs_are_normalized=normalization_method != "none",
            )
        )

    if nncf_enabled:
        # NNCF wraps torch's jit which conflicts with kornia's jit calls.
        # Hence, nncf is imported only when required
        nncf_module = import_module("anomalib.utils.callbacks.nncf.callback")
        nncf_callback = getattr(nncf_module, "NNCFCallback")
        nncf_config = yaml.safe_load(OmegaConf.to_yaml(config.optimization.nncf))
        callbacks.append(
            nncf_callback(
                config=nncf_config,
                export_dir=os.path.join(config.project.path, "compressed"),
            )
        )

    if has_optimization and "openvino" in config.optimization and config.optimization.openvino.apply:
        from .openvino import (  # pylint: disable=import-outside-toplevel
            OpenVINOCallback,
        )

        callbacks.append(
            OpenVINOCallback(
                input_size=config.model.input_size,
                dirpath=os.path.join(config.project.path, "openvino"),
                filename="openvino_model",
            )
        )

    return callbacks