Spaces:
Build error
Build error
| """CFLOW: Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows. | |
| https://arxiv.org/pdf/2107.12571v1.pdf | |
| """ | |
| # Copyright (C) 2020 Intel Corporation | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, | |
| # software distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions | |
| # and limitations under the License. | |
| import logging | |
| import einops | |
| import torch | |
| import torch.nn.functional as F | |
| from pytorch_lightning.callbacks import EarlyStopping | |
| from torch import optim | |
| from anomalib.models.cflow.torch_model import CflowModel | |
| from anomalib.models.cflow.utils import get_logp, positional_encoding_2d | |
| from anomalib.models.components import AnomalyModule | |
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Names exported by a star-import of this module.
__all__ = ["CflowLightning"]
class CflowLightning(AnomalyModule):
    """PL Lightning Module for the CFLOW algorithm.

    Wraps a ``CflowModel`` and trains its decoders with manual optimization:
    each input batch is split into several "fiber batches" per pooling layer
    and one optimizer step is taken for every fiber batch.
    """

    def __init__(self, hparams):
        """Build the underlying ``CflowModel`` from ``hparams``.

        Args:
            hparams: Configuration object forwarded to ``AnomalyModule`` and ``CflowModel``.
        """
        super().__init__(hparams)
        logger.info("Initializing Cflow Lightning model.")

        self.model: CflowModel = CflowModel(hparams)
        self.loss_val = 0
        # ``training_step`` calls ``zero_grad``/``manual_backward``/``step`` itself,
        # several times per batch, so Lightning must not drive the optimizer.
        self.automatic_optimization = False

    def configure_callbacks(self):
        """Configure model-specific callbacks.

        Returns:
            A list holding a single ``EarlyStopping`` callback whose metric, patience
            and mode are read from ``hparams.model.early_stopping``.
        """
        early_stopping = EarlyStopping(
            monitor=self.hparams.model.early_stopping.metric,
            patience=self.hparams.model.early_stopping.patience,
            mode=self.hparams.model.early_stopping.mode,
        )
        return [early_stopping]

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configure a single optimizer shared by all decoders.

        Returns:
            Optimizer: Adam optimizer over the parameters of every decoder that
            corresponds to an entry of ``self.model.pool_layers``.
        """
        decoders_parameters = []
        for decoder_idx in range(len(self.model.pool_layers)):
            decoders_parameters.extend(self.model.decoders[decoder_idx].parameters())

        optimizer = optim.Adam(
            params=decoders_parameters,
            lr=self.hparams.model.lr,
        )
        return optimizer

    def training_step(self, batch, _):  # pylint: disable=arguments-differ
        """Training Step of CFLOW.

        For each batch, decoder layers are trained with a dynamic fiber batch size.
        Training step is performed manually as multiple training steps are involved
        per batch of input images.

        Args:
            batch: Input batch; the images are read from ``batch["image"]``.
            _: Index of the batch (unused).

        Returns:
            Dictionary with key ``"loss"`` holding a one-element float64 tensor: the
            sum of the per-fiber losses over all fiber batches and all pooling layers.
            The tensor is detached from the autograd graph.

        Raises:
            AssertionError: If a layer yields fewer fibers than one fiber batch.
        """
        opt = self.optimizers()
        self.model.encoder.eval()

        images = batch["image"]
        # Only the decoders are optimized and every activation is detached below,
        # so there is no need to record an autograd graph for the encoder pass.
        with torch.no_grad():
            activation = self.model.encoder(images)
        total_loss = torch.zeros([1], dtype=torch.float64).to(images.device)
        fiber_batch_size = self.model.fiber_batch_size

        for layer_idx, layer in enumerate(self.model.pool_layers):
            encoder_activations = activation[layer].detach()  # BxCxHxW

            batch_size, dim_feature_vector, im_height, im_width = encoder_activations.size()
            # number of rows in the conditional vector
            embedding_length = batch_size * im_height * im_width

            # repeats positional encoding for the entire batch 1 C H W to B C H W
            pos_encoding = einops.repeat(
                positional_encoding_2d(self.model.condition_vector, im_height, im_width).unsqueeze(0),
                "b c h w-> (tile b) c h w",
                tile=batch_size,
            ).to(images.device)
            c_r = einops.rearrange(pos_encoding, "b c h w -> (b h w) c")  # BHWxP
            e_r = einops.rearrange(encoder_activations, "b c h w -> (b h w) c")  # BHWxC
            perm = torch.randperm(embedding_length)  # BHW
            decoder = self.model.decoders[layer_idx].to(images.device)

            fiber_batches = embedding_length // fiber_batch_size  # number of fiber batches
            assert fiber_batches > 0, "Make sure we have enough fibers, otherwise decrease N or batch-size!"

            for batch_num in range(fiber_batches):  # per-fiber processing
                opt.zero_grad()
                start = batch_num * fiber_batch_size
                if batch_num < (fiber_batches - 1):
                    stop = start + fiber_batch_size
                else:
                    # The last fiber batch also takes the remainder left by the
                    # integer division, so it may hold more than N fibers.
                    stop = embedding_length
                fiber_idx = perm[start:stop]

                # get random vectors
                c_p = c_r[fiber_idx]  # NxP
                e_p = e_r[fiber_idx]  # NxC
                # decoder returns the transformed variable z and the log Jacobian determinant
                p_u, log_jac_det = decoder(e_p, [c_p])

                decoder_log_prob = get_logp(dim_feature_vector, p_u, log_jac_det)
                log_prob = decoder_log_prob / dim_feature_vector  # likelihood per dim
                loss = -F.logsigmoid(log_prob)
                self.manual_backward(loss.mean())
                opt.step()
                # Detach before accumulating: the running total is only reported,
                # never back-propagated, and must not keep each graph referenced.
                total_loss += loss.sum().detach()

        return {"loss": total_loss}

    def validation_step(self, batch, _):  # pylint: disable=arguments-differ
        """Validation Step of CFLOW.

        Runs the wrapped model on the batch images and stores the result in the
        batch under the ``"anomaly_maps"`` key.

        Args:
            batch: Input batch; the images are read from ``batch["image"]``.
            _: Index of the batch (unused).

        Returns:
            The same ``batch`` dictionary, extended in place with ``"anomaly_maps"``.
        """
        batch["anomaly_maps"] = self.model(batch["image"])

        return batch