Spaces:
Build error
Build error
File size: 6,177 Bytes
c8c12e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
"""CFLOW: Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows.
https://arxiv.org/pdf/2107.12571v1.pdf
"""
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
import logging
import einops
import torch
import torch.nn.functional as F
from pytorch_lightning.callbacks import EarlyStopping
from torch import optim
from anomalib.models.cflow.torch_model import CflowModel
from anomalib.models.cflow.utils import get_logp, positional_encoding_2d
from anomalib.models.components import AnomalyModule
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Public API of this module: only the Lightning wrapper class is exported.
__all__ = ["CflowLightning"]
class CflowLightning(AnomalyModule):
    """PL Lightning Module for the CFLOW algorithm.

    Wraps a ``CflowModel`` and trains only its decoders. Optimization is done
    manually inside ``training_step`` because every input batch is split into
    several "fiber" batches, each of which gets its own optimizer step.
    """

    def __init__(self, hparams):
        """Build the underlying ``CflowModel`` and switch to manual optimization.

        Args:
            hparams: Configuration object forwarded to ``AnomalyModule`` and
                ``CflowModel``. This class reads ``hparams.model.lr`` and
                ``hparams.model.early_stopping.{metric,patience,mode}`` from it
                (via ``self.hparams``).
        """
        super().__init__(hparams)
        logger.info("Initializing Cflow Lightning model.")

        self.model: CflowModel = CflowModel(hparams)
        self.loss_val = 0
        # training_step calls zero_grad/manual_backward/step itself, so
        # Lightning's automatic optimization has to be turned off.
        self.automatic_optimization = False

    def configure_callbacks(self):
        """Configure model-specific callbacks.

        Returns:
            A single-element list holding an ``EarlyStopping`` callback built
            from the ``model.early_stopping`` section of the hyper-parameters.
        """
        early_stopping_cfg = self.hparams.model.early_stopping
        early_stopping = EarlyStopping(
            monitor=early_stopping_cfg.metric,
            patience=early_stopping_cfg.patience,
            mode=early_stopping_cfg.mode,
        )
        return [early_stopping]

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configure a single optimizer shared by all decoders.

        Only decoder parameters are collected; the encoder's parameters are
        not handed to the optimizer.

        Returns:
            Optimizer: Adam optimizer over the parameters of every decoder,
            using ``hparams.model.lr`` as learning rate.
        """
        decoders_parameters = []
        # One decoder per pooling layer: index decoders by pool-layer position.
        for decoder_idx in range(len(self.model.pool_layers)):
            decoders_parameters.extend(self.model.decoders[decoder_idx].parameters())

        optimizer = optim.Adam(
            params=decoders_parameters,
            lr=self.hparams.model.lr,
        )
        return optimizer

    def training_step(self, batch, _):  # pylint: disable=arguments-differ
        """Training Step of CFLOW.

        For each batch, decoder layers are trained with a dynamic fiber batch size.
        Training step is performed manually as multiple training steps are involved
        per batch of input images.

        Args:
            batch: Input batch; the images are read from ``batch["image"]``.
            _: Index of the batch (unused).

        Returns:
            Dictionary with a single ``"loss"`` entry: a one-element float64
            tensor holding the *sum* of the per-fiber losses over every fiber
            batch of every pooling layer.

        Raises:
            AssertionError: If a layer yields fewer feature vectors than
                ``self.model.fiber_batch_size`` (no full fiber batch fits).
        """
        opt = self.optimizers()
        # The encoder only provides features: it is put in eval mode and its
        # activations are detached below, so only the decoders are trained.
        self.model.encoder.eval()

        images = batch["image"]
        activation = self.model.encoder(images)
        avg_loss = torch.zeros([1], dtype=torch.float64).to(images.device)

        fiber_batch_size = self.model.fiber_batch_size

        for layer_idx, layer in enumerate(self.model.pool_layers):
            encoder_activations = activation[layer].detach()  # BxCxHxW

            batch_size, dim_feature_vector, im_height, im_width = encoder_activations.size()
            # number of rows in the conditional vector (one per spatial location per image)
            embedding_length = batch_size * im_height * im_width

            # repeats positional encoding for the entire batch 1 C H W to B C H W
            pos_encoding = einops.repeat(
                positional_encoding_2d(self.model.condition_vector, im_height, im_width).unsqueeze(0),
                "b c h w-> (tile b) c h w",
                tile=batch_size,
            ).to(images.device)
            c_r = einops.rearrange(pos_encoding, "b c h w -> (b h w) c")  # BHWxP
            e_r = einops.rearrange(encoder_activations, "b c h w -> (b h w) c")  # BHWxC
            perm = torch.randperm(embedding_length)  # BHW

            decoder = self.model.decoders[layer_idx].to(images.device)

            fiber_batches = embedding_length // fiber_batch_size  # number of fiber batches
            assert fiber_batches > 0, "Make sure we have enough fibers, otherwise decrease N or batch-size!"

            for batch_num in range(fiber_batches):  # per-fiber processing
                opt.zero_grad()

                start = batch_num * fiber_batch_size
                if batch_num < (fiber_batches - 1):
                    end = start + fiber_batch_size
                else:
                    # The last fiber batch also takes the remainder, so it runs
                    # up to the end of the permutation and may exceed N rows.
                    end = embedding_length

                # get random vectors: a contiguous slice of the random
                # permutation selects the same rows as indexing it with
                # arange(start, end).
                fiber_idx = perm[start:end]
                c_p = c_r[fiber_idx]  # NxP
                e_p = e_r[fiber_idx]  # NxC

                # decoder returns the transformed variable z and the log Jacobian determinant
                p_u, log_jac_det = decoder(e_p, [c_p])

                decoder_log_prob = get_logp(dim_feature_vector, p_u, log_jac_det)
                log_prob = decoder_log_prob / dim_feature_vector  # likelihood per dim
                loss = -F.logsigmoid(log_prob)

                # One optimizer step per fiber batch, on the mean fiber loss.
                self.manual_backward(loss.mean())
                opt.step()
                avg_loss += loss.sum()

        return {"loss": avg_loss}

    def validation_step(self, batch, _):  # pylint: disable=arguments-differ
        """Validation Step of CFLOW.

        Runs the full model on ``batch["image"]`` and stores the result in the
        batch under the ``"anomaly_maps"`` key.

        Args:
            batch: Input batch; mutated in place by adding ``"anomaly_maps"``.
            _: Index of the batch (unused).

        Returns:
            The same ``batch`` dictionary, now also containing the anomaly
            maps. These are required in `validation_epoch_end` for feature
            concatenation.
        """
        batch["anomaly_maps"] = self.model(batch["image"])
        return batch
|