Spaces:
Build error
Build error
File size: 6,976 Bytes
c8c12e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
"""Base Anomaly Module for Training Task."""
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
from abc import ABC
from typing import Any, List, Optional, Union
import pytorch_lightning as pl
from omegaconf import DictConfig, ListConfig
from pytorch_lightning.callbacks.base import Callback
from torch import Tensor, nn
from anomalib.utils.metrics import (
AdaptiveThreshold,
AnomalyScoreDistribution,
MinMax,
get_metrics,
)
class AnomalyModule(pl.LightningModule, ABC):
    """AnomalyModule to train, validate, predict and test images.

    Acts as a base class for all the Anomaly Modules in the library.

    Args:
        params (Union[DictConfig, ListConfig]): Configuration. Must provide at least
            ``model.threshold.image_default``, ``model.threshold.pixel_default`` and
            ``model.threshold.adaptive``.
    """

    def __init__(self, params: Union[DictConfig, ListConfig]):
        super().__init__()
        # Force the type for hparams so that it works with OmegaConfig style of accessing
        self.hparams: Union[DictConfig, ListConfig]  # type: ignore
        self.save_hyperparameters(params)
        # Annotation-only declarations; presumably assigned by subclasses — confirm.
        self.loss: Tensor
        self.callbacks: List[Callback]

        # Thresholds start from the configured defaults. They are recomputed in
        # `_compute_adaptive_threshold` when `model.threshold.adaptive` is enabled.
        self.image_threshold = AdaptiveThreshold(self.hparams.model.threshold.image_default).cpu()
        self.pixel_threshold = AdaptiveThreshold(self.hparams.model.threshold.pixel_default).cpu()

        # Not used in this class; presumably consumed by callbacks — confirm.
        self.training_distribution = AnomalyScoreDistribution().cpu()
        self.min_max = MinMax().cpu()

        self.model: nn.Module

        # metrics
        self.image_metrics, self.pixel_metrics = get_metrics(self.hparams)
        self.image_metrics.set_threshold(self.hparams.model.threshold.image_default)
        self.pixel_metrics.set_threshold(self.hparams.model.threshold.pixel_default)

    def forward(self, batch):  # pylint: disable=arguments-differ
        """Forward-pass input tensor to the module.

        Args:
            batch (Tensor): Input Tensor

        Returns:
            Tensor: Output tensor from the model.
        """
        return self.model(batch)

    def validation_step(self, batch, batch_idx) -> dict:  # type: ignore  # pylint: disable=arguments-differ
        """To be implemented in the subclasses.

        The returned dict is read by the other hooks of this class. It needs
        ``pred_scores`` (or ``anomaly_maps`` so that `_post_process` can derive it)
        and ``label``. ``mask`` and ``anomaly_maps`` are optional pixel-level entries.
        """
        raise NotImplementedError

    def predict_step(self, batch: Any, batch_idx: int, _dataloader_idx: Optional[int] = None) -> Any:
        """Step function called during :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.

        Runs `validation_step` and fills in ``pred_scores`` if missing. It then
        thresholds the scores (and anomaly maps, if present) with the current
        image/pixel thresholds.

        Args:
            batch (Tensor): Current batch
            batch_idx (int): Index of current batch
            _dataloader_idx (int): Index of the current dataloader (unused)

        Return:
            The `validation_step` output dict, extended with ``pred_labels`` and,
            when ``anomaly_maps`` is present, ``pred_masks``. Both are boolean
            tensors produced by ``>=`` against the threshold value.
        """
        outputs = self.validation_step(batch, batch_idx)
        self._post_process(outputs)
        outputs["pred_labels"] = outputs["pred_scores"] >= self.image_threshold.value
        if "anomaly_maps" in outputs:
            outputs["pred_masks"] = outputs["anomaly_maps"] >= self.pixel_threshold.value
        return outputs

    def test_step(self, batch, _):  # pylint: disable=arguments-differ
        """Calls validation_step for anomaly map/score calculation.

        Args:
            batch (Tensor): Input batch
            _: Index of the batch.

        Returns:
            Whatever `validation_step` returns for this batch.
        """
        return self.validation_step(batch, _)

    def validation_step_end(self, val_step_outputs):  # pylint: disable=arguments-differ
        """Move the step's tensors to CPU and derive ``pred_scores`` if missing."""
        self._outputs_to_cpu(val_step_outputs)
        self._post_process(val_step_outputs)
        return val_step_outputs

    def test_step_end(self, test_step_outputs):  # pylint: disable=arguments-differ
        """Move the step's tensors to CPU and derive ``pred_scores`` if missing."""
        self._outputs_to_cpu(test_step_outputs)
        self._post_process(test_step_outputs)
        return test_step_outputs

    def validation_epoch_end(self, outputs):
        """Compute threshold and performance metrics.

        Args:
            outputs: Per-step outputs collected over the validation epoch
                (the dicts returned by `validation_step_end`).
        """
        if self.hparams.model.threshold.adaptive:
            self._compute_adaptive_threshold(outputs)
        self._collect_outputs(self.image_metrics, self.pixel_metrics, outputs)
        self._log_metrics()

    def test_epoch_end(self, outputs):
        """Compute and log performance metrics on the test set.

        Thresholds are not recomputed here; the current values are used.

        Args:
            outputs: Per-step outputs collected over the test epoch
                (the dicts returned by `test_step_end`).
        """
        self._collect_outputs(self.image_metrics, self.pixel_metrics, outputs)
        self._log_metrics()

    def _compute_adaptive_threshold(self, outputs):
        """Fit the image/pixel thresholds on ``outputs`` and apply them to the metrics.

        If no output carries both ``mask`` and ``anomaly_maps``, the pixel
        threshold falls back to the image threshold. An empty ``outputs`` leaves
        the current thresholds unchanged.
        """
        if not outputs:
            # Nothing to fit on; indexing/computing would fail on an empty list.
            return
        # NOTE(review): no reset() of the threshold metrics is visible in this class,
        # so their state may accumulate across validation epochs — confirm intended.
        self._collect_outputs(self.image_threshold, self.pixel_threshold, outputs)
        self.image_threshold.compute()
        # Use the same per-output criterion as `_collect_outputs`, so the pixel
        # threshold is computed exactly when it actually received updates.
        has_pixel_data = any("mask" in output and "anomaly_maps" in output for output in outputs)
        if has_pixel_data:
            self.pixel_threshold.compute()
        else:
            self.pixel_threshold.value = self.image_threshold.value
        self.image_metrics.set_threshold(self.image_threshold.value.item())
        self.pixel_metrics.set_threshold(self.pixel_threshold.value.item())

    def _collect_outputs(self, image_metric, pixel_metric, outputs):
        """Feed every step output into the given image- and pixel-level metrics.

        Args:
            image_metric: Updated with ``pred_scores`` against ``label`` (as int).
            pixel_metric: Updated with flattened ``anomaly_maps`` against flattened
                ``mask`` (as int), only for outputs that contain both keys.
            outputs: Iterable of step-output dicts.
        """
        # Move the metrics once, not on every iteration. Step outputs were
        # already moved to CPU by `_outputs_to_cpu` in the *_step_end hooks.
        image_metric.cpu()
        pixel_metric.cpu()
        for output in outputs:
            image_metric.update(output["pred_scores"], output["label"].int())
            if "mask" in output and "anomaly_maps" in output:
                pixel_metric.update(output["anomaly_maps"].flatten(), output["mask"].flatten().int())

    def _post_process(self, outputs):
        """Derive ``pred_scores`` from ``anomaly_maps`` when the step did not provide them.

        Each sample's score is the maximum over all non-batch dimensions of its
        anomaly map. ``outputs`` is modified in place.
        """
        if "pred_scores" not in outputs and "anomaly_maps" in outputs:
            outputs["pred_scores"] = (
                outputs["anomaly_maps"].reshape(outputs["anomaly_maps"].shape[0], -1).max(dim=1).values
            )

    def _outputs_to_cpu(self, output):
        """Replace every tensor value in the ``output`` dict with its CPU copy, in place."""
        for key, value in output.items():
            if isinstance(value, Tensor):
                # Reassigning an existing key does not change the dict's size,
                # so this is safe while iterating.
                output[key] = value.cpu()

    def _log_metrics(self):
        """Log computed performance metrics.

        Pixel metrics are logged only if they received at least one update.
        """
        self.log_dict(self.image_metrics)
        if self.pixel_metrics.update_called:
            self.log_dict(self.pixel_metrics)
|