Spaces:
Build error
Build error
| """Base Anomaly Module for Training Task.""" | |
| # Copyright (C) 2020 Intel Corporation | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, | |
| # software distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions | |
| # and limitations under the License. | |
| from abc import ABC | |
| from typing import Any, List, Optional, Union | |
| import pytorch_lightning as pl | |
| from omegaconf import DictConfig, ListConfig | |
| from pytorch_lightning.callbacks.base import Callback | |
| from torch import Tensor, nn | |
| from anomalib.utils.metrics import ( | |
| AdaptiveThreshold, | |
| AnomalyScoreDistribution, | |
| MinMax, | |
| get_metrics, | |
| ) | |
class AnomalyModule(pl.LightningModule, ABC):
    """AnomalyModule to train, validate, predict and test images.

    Acts as a base class for all the Anomaly Modules in the library.
    Subclasses must implement :meth:`validation_step`; this class turns its
    outputs into predictions, adaptive thresholds and logged metrics.

    Args:
        params (Union[DictConfig, ListConfig]): Configuration
    """

    def __init__(self, params: Union[DictConfig, ListConfig]):
        super().__init__()
        # Force the type for hparams so that it works with OmegaConfig style of accessing
        self.hparams: Union[DictConfig, ListConfig]  # type: ignore
        self.save_hyperparameters(params)
        self.loss: Tensor
        self.callbacks: List[Callback]

        # Thresholds start from the configured defaults; they are refitted on the
        # validation outputs when ``model.threshold.adaptive`` is enabled.
        self.image_threshold = AdaptiveThreshold(self.hparams.model.threshold.image_default).cpu()
        self.pixel_threshold = AdaptiveThreshold(self.hparams.model.threshold.pixel_default).cpu()

        # Not used by this class itself — presumably consumed by subclasses/callbacks (verify).
        self.training_distribution = AnomalyScoreDistribution().cpu()
        self.min_max = MinMax().cpu()

        # Annotation only; never assigned here (presumably set by subclasses).
        self.model: nn.Module

        # metrics
        self.image_metrics, self.pixel_metrics = get_metrics(self.hparams)
        self.image_metrics.set_threshold(self.hparams.model.threshold.image_default)
        self.pixel_metrics.set_threshold(self.hparams.model.threshold.pixel_default)

    def forward(self, batch):  # pylint: disable=arguments-differ
        """Forward-pass input tensor to the module.

        Args:
            batch (Tensor): Input Tensor

        Returns:
            Tensor: Output tensor from the model.
        """
        return self.model(batch)

    def validation_step(self, batch, batch_idx) -> dict:  # type: ignore  # pylint: disable=arguments-differ
        """To be implemented in the subclasses."""
        raise NotImplementedError

    def predict_step(self, batch: Any, batch_idx: int, _dataloader_idx: Optional[int] = None) -> Any:
        """Step function called during :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.

        Runs :meth:`validation_step`, fills in ``pred_scores`` from ``anomaly_maps``
        when missing, then thresholds the scores (and maps, if present) with the
        current image/pixel thresholds.

        Args:
            batch (Tensor): Current batch
            batch_idx (int): Index of current batch
            _dataloader_idx (int): Index of the current dataloader (unused)

        Return:
            The ``validation_step`` output dict, extended with ``pred_labels`` and,
            when ``anomaly_maps`` is present, ``pred_masks``.
        """
        outputs = self.validation_step(batch, batch_idx)
        self._post_process(outputs)
        outputs["pred_labels"] = outputs["pred_scores"] >= self.image_threshold.value
        if "anomaly_maps" in outputs:
            outputs["pred_masks"] = outputs["anomaly_maps"] >= self.pixel_threshold.value
        return outputs

    def test_step(self, batch, _):  # pylint: disable=arguments-differ
        """Calls validation_step for anomaly map/score calculation.

        Args:
            batch (Tensor): Input batch
            _: Index of the batch.

        Returns:
            Dictionary containing images, features, true labels and masks.
            These are required in `test_epoch_end` for feature concatenation.
        """
        return self.validation_step(batch, _)

    def validation_step_end(self, val_step_outputs):  # pylint: disable=arguments-differ
        """Move the step outputs to CPU and derive ``pred_scores`` if missing."""
        self._outputs_to_cpu(val_step_outputs)
        self._post_process(val_step_outputs)
        return val_step_outputs

    def test_step_end(self, test_step_outputs):  # pylint: disable=arguments-differ
        """Move the step outputs to CPU and derive ``pred_scores`` if missing."""
        self._outputs_to_cpu(test_step_outputs)
        self._post_process(test_step_outputs)
        return test_step_outputs

    def validation_epoch_end(self, outputs):
        """Compute threshold and performance metrics.

        Args:
            outputs: Batch of outputs from the validation step
        """
        if self.hparams.model.threshold.adaptive:
            self._compute_adaptive_threshold(outputs)
        self._collect_outputs(self.image_metrics, self.pixel_metrics, outputs)
        self._log_metrics()

    def test_epoch_end(self, outputs):
        """Compute and log performance metrics on the test set.

        Args:
            outputs: Batch of outputs from the test step
        """
        self._collect_outputs(self.image_metrics, self.pixel_metrics, outputs)
        self._log_metrics()

    @staticmethod
    def _has_pixel_targets(output) -> bool:
        """Return True if a step output carries both a ground-truth mask and an anomaly map."""
        return "mask" in output and "anomaly_maps" in output

    def _compute_adaptive_threshold(self, outputs):
        """Fit the image/pixel thresholds on ``outputs`` and push them into the metrics.

        If no output carries pixel-level targets, the pixel threshold falls back
        to the image threshold.
        """
        self._collect_outputs(self.image_threshold, self.pixel_threshold, outputs)
        self.image_threshold.compute()
        # Inspect every output rather than only ``outputs[0]``: this mirrors the
        # per-output check in ``_collect_outputs`` (so pixel states that were
        # updated are also computed) and avoids indexing into an empty list.
        if any(self._has_pixel_targets(output) for output in outputs):
            self.pixel_threshold.compute()
        else:
            self.pixel_threshold.value = self.image_threshold.value

        self.image_metrics.set_threshold(self.image_threshold.value.item())
        self.pixel_metrics.set_threshold(self.pixel_threshold.value.item())

    def _collect_outputs(self, image_metric, pixel_metric, outputs):
        """Update an image-level and a pixel-level metric with every step output.

        Image metric receives ``(pred_scores, label)``; pixel metric receives the
        flattened ``(anomaly_maps, mask)`` only for outputs that contain both.
        """
        # Step outputs were moved to CPU in ``*_step_end``; move the metrics once
        # (instead of on every iteration) so their states are on the same device.
        image_metric.cpu()
        pixel_metric.cpu()
        for output in outputs:
            image_metric.update(output["pred_scores"], output["label"].int())
            if self._has_pixel_targets(output):
                pixel_metric.update(output["anomaly_maps"].flatten(), output["mask"].flatten().int())

    def _post_process(self, outputs):
        """Derive ``pred_scores`` as the per-sample max of ``anomaly_maps`` when missing."""
        if "pred_scores" not in outputs and "anomaly_maps" in outputs:
            anomaly_maps = outputs["anomaly_maps"]
            # Flatten everything except the first (batch) dimension, then take the max.
            outputs["pred_scores"] = anomaly_maps.reshape(anomaly_maps.shape[0], -1).max(dim=1).values

    def _outputs_to_cpu(self, output):
        """Move every tensor value of the ``output`` dict to CPU, in place."""
        for key, value in output.items():
            if isinstance(value, Tensor):
                # Re-binding existing keys does not change the dict size, so this is safe mid-iteration.
                output[key] = value.cpu()

    def _log_metrics(self):
        """Log computed performance metrics; pixel metrics only if they received updates."""
        self.log_dict(self.image_metrics)
        if self.pixel_metrics.update_called:
            self.log_dict(self.pixel_metrics)