Spaces:
Build error
Build error
| """Common helpers for both nightly and pre-merge model tests.""" | |
| # Copyright (C) 2020 Intel Corporation | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, | |
| # software distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions | |
| # and limitations under the License. | |
| import os | |
| from typing import Dict, List, Tuple, Union | |
| import numpy as np | |
| from omegaconf import DictConfig, ListConfig | |
| from pytorch_lightning import LightningDataModule, Trainer | |
| from pytorch_lightning.callbacks import ModelCheckpoint | |
| from anomalib.config import get_configurable_parameters, update_nncf_config | |
| from anomalib.data import get_datamodule | |
| from anomalib.models import get_model | |
| from anomalib.models.components import AnomalyModule | |
| from anomalib.utils.callbacks import VisualizerCallback, get_callbacks | |
def setup_model_train(
    model_name: str,
    dataset_path: str,
    project_path: str,
    nncf: bool,
    category: str,
    score_type: Union[str, None] = None,
    weight_file: str = "weights/model.ckpt",
    fast_run: bool = False,
    device: Union[List[int], Tuple[int, ...], int] = (0,),
) -> Tuple[Union[DictConfig, ListConfig], LightningDataModule, AnomalyModule, Trainer]:
    """Train the model based on the parameters passed.

    Args:
        model_name (str): Name of the model to train.
        dataset_path (str): Location of the dataset.
        project_path (str): Path to temporary project folder.
        nncf (bool): Add nncf callback.
        category (str): Category to train on.
        score_type (str, optional): Only used for DFM. Defaults to None.
        weight_file (str, optional): Path to weight file. An empty string removes the ``weight_file`` key from
            the model config. When non-empty and ``fast_run`` is set, ``"weights/last.ckpt"`` is used instead.
            Defaults to ``"weights/model.ckpt"``.
        fast_run (bool, optional): If set to true, the model trains for only 1 epoch. We train for one epoch as
            this ensures that both anomalous and non-anomalous images are present in the validation step.
        device (List[int], Tuple[int, ...], int, optional): Value assigned to ``config.trainer.gpus``; lists and
            tuples are copied into a new list. Defaults to ``(0,)``, i.e. the first GPU.

    Returns:
        Tuple[DictConfig, LightningDataModule, AnomalyModule, Trainer]: config, datamodule, trained model, trainer
    """
    config = get_configurable_parameters(model_name=model_name)
    if score_type is not None:
        config.model.score_type = score_type
    config.project.seed = 42
    config.dataset.category = category
    config.dataset.path = dataset_path
    config.project.log_images_to = []
    # Copy sequences so neither the caller's object nor the default is shared with the config.
    config.trainer.gpus = list(device) if isinstance(device, (list, tuple)) else device

    if weight_file == "":
        # An empty weight file means the config should carry no `weight_file` entry at all,
        # whether or not the loaded config already had one.
        config.model.pop("weight_file", None)
    else:
        # fast_run writes its checkpoint as weights/last.ckpt (see the ModelCheckpoint below).
        config.model.weight_file = "weights/last.ckpt" if fast_run else weight_file

    if nncf:
        config.optimization.nncf.apply = True
        config = update_nncf_config(config)
        config.init_weights = None

    # reassign project path as config is updated in `update_nncf_config`
    config.project.path = project_path

    datamodule = get_datamodule(config)
    model = get_model(config)
    callbacks = get_callbacks(config)

    if fast_run:
        # Force model checkpoint to create checkpoint after first epoch: replace any configured
        # checkpoint callback with one that does not monitor a metric and saves as "last".
        callbacks = [callback for callback in callbacks if not isinstance(callback, ModelCheckpoint)]
        callbacks.append(
            ModelCheckpoint(
                dirpath=os.path.join(config.project.path, "weights"),
                filename="last",
                monitor=None,
                mode="max",
                save_last=True,
                auto_insert_metric_name=False,
            )
        )

    # Remove visualizer callbacks as saving results takes time.
    callbacks = [callback for callback in callbacks if not isinstance(callback, VisualizerCallback)]

    # Train the model.
    if fast_run:
        config.trainer.max_epochs = 1
        config.trainer.check_val_every_n_epoch = 1

    trainer = Trainer(callbacks=callbacks, **config.trainer)
    trainer.fit(model=model, datamodule=datamodule)
    return config, datamodule, model, trainer
def model_load_test(config: Union[DictConfig, ListConfig], datamodule: LightningDataModule, results: Dict) -> None:
    """Create a new model based on the weights specified in config and check it reproduces ``results``.

    Args:
        config (Union[DictConfig, ListConfig]): Model config.
        datamodule (LightningDataModule): Dataloader
        results (Dict): Results from original model.

    Raises:
        AssertionError: If the reloaded model's ``image_AUROC`` (and, for ``segmentation`` tasks,
            ``pixel_AUROC``) is not close to the corresponding value in ``results``.
    """
    loaded_model = get_model(config)  # get new model

    # Remove visualizer callbacks as saving results takes time.
    callbacks = [callback for callback in get_callbacks(config) if not isinstance(callback, VisualizerCallback)]

    # create new trainer object with LoadModel callback (assumes it is present)
    trainer = Trainer(callbacks=callbacks, **config.trainer)

    # Assumes the new model has LoadModel callback and the old one had ModelCheckpoint callback.
    # `[0]` takes the metrics of the first entry returned by `trainer.test`.
    new_results = trainer.test(model=loaded_model, datamodule=datamodule)[0]

    assert np.isclose(
        results["image_AUROC"], new_results["image_AUROC"]
    ), "Loaded model does not yield close performance results"
    if config.dataset.task == "segmentation":
        assert np.isclose(
            results["pixel_AUROC"], new_results["pixel_AUROC"]
        ), "Loaded model does not yield close performance results"