Spaces:
Build error
Build error
| """PyTorch model for the PaDiM model implementation.""" | |
| # Copyright (C) 2020 Intel Corporation | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, | |
| # software distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions | |
| # and limitations under the License. | |
from random import sample
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
import torchvision
from torch import Tensor, nn

from anomalib.models.components import FeatureExtractor, MultiVariateGaussian
from anomalib.models.padim.anomaly_map import AnomalyMapGenerator
from anomalib.pre_processing import Tiler
# Per-backbone embedding dimensions.
#   orig_dims    — size of the index range the random channel subset is drawn from
#                  (presumably the total channel count of the concatenated feature
#                  maps — TODO confirm against FeatureExtractor output);
#   reduced_dims — number of randomly sub-sampled embedding channels kept;
#   emb_scale    — divisor applied to the input size to get the embedding map size.
DIMS = {
    "resnet18": {"orig_dims": 448, "reduced_dims": 100, "emb_scale": 4},
    "wide_resnet50_2": {"orig_dims": 1792, "reduced_dims": 550, "emb_scale": 4},
}
| class PadimModel(nn.Module): | |
| """Padim Module. | |
| Args: | |
| layers (List[str]): Layers used for feature extraction | |
| input_size (Tuple[int, int]): Input size for the model. | |
| tile_size (Tuple[int, int]): Tile size | |
| tile_stride (int): Stride for tiling | |
| apply_tiling (bool, optional): Apply tiling. Defaults to False. | |
| backbone (str, optional): Pre-trained model backbone. Defaults to "resnet18". | |
| """ | |
| def __init__( | |
| self, | |
| layers: List[str], | |
| input_size: Tuple[int, int], | |
| backbone: str = "resnet18", | |
| apply_tiling: bool = False, | |
| tile_size: Optional[Tuple[int, int]] = None, | |
| tile_stride: Optional[int] = None, | |
| ): | |
| super().__init__() | |
| self.backbone = getattr(torchvision.models, backbone) | |
| self.layers = layers | |
| self.apply_tiling = apply_tiling | |
| self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.layers) | |
| self.dims = DIMS[backbone] | |
| # pylint: disable=not-callable | |
| # Since idx is randomly selected, save it with model to get same results | |
| self.register_buffer( | |
| "idx", | |
| torch.tensor(sample(range(0, DIMS[backbone]["orig_dims"]), DIMS[backbone]["reduced_dims"])), | |
| ) | |
| self.idx: Tensor | |
| self.loss = None | |
| self.anomaly_map_generator = AnomalyMapGenerator(image_size=input_size) | |
| n_features = DIMS[backbone]["reduced_dims"] | |
| patches_dims = torch.tensor(input_size) / DIMS[backbone]["emb_scale"] | |
| n_patches = patches_dims.ceil().prod().int().item() | |
| self.gaussian = MultiVariateGaussian(n_features, n_patches) | |
| if apply_tiling: | |
| assert tile_size is not None | |
| assert tile_stride is not None | |
| self.tiler = Tiler(tile_size, tile_stride) | |
| def forward(self, input_tensor: Tensor) -> Tensor: | |
| """Forward-pass image-batch (N, C, H, W) into model to extract features. | |
| Args: | |
| input_tensor: Image-batch (N, C, H, W) | |
| input_tensor: Tensor: | |
| Returns: | |
| Features from single/multiple layers. | |
| Example: | |
| >>> x = torch.randn(32, 3, 224, 224) | |
| >>> features = self.extract_features(input_tensor) | |
| >>> features.keys() | |
| dict_keys(['layer1', 'layer2', 'layer3']) | |
| >>> [v.shape for v in features.values()] | |
| [torch.Size([32, 64, 56, 56]), | |
| torch.Size([32, 128, 28, 28]), | |
| torch.Size([32, 256, 14, 14])] | |
| """ | |
| if self.apply_tiling: | |
| input_tensor = self.tiler.tile(input_tensor) | |
| with torch.no_grad(): | |
| features = self.feature_extractor(input_tensor) | |
| embeddings = self.generate_embedding(features) | |
| if self.apply_tiling: | |
| embeddings = self.tiler.untile(embeddings) | |
| if self.training: | |
| output = embeddings | |
| else: | |
| output = self.anomaly_map_generator( | |
| embedding=embeddings, mean=self.gaussian.mean, inv_covariance=self.gaussian.inv_covariance | |
| ) | |
| return output | |
| def generate_embedding(self, features: Dict[str, Tensor]) -> Tensor: | |
| """Generate embedding from hierarchical feature map. | |
| Args: | |
| features (Dict[str, Tensor]): Hierarchical feature map from a CNN (ResNet18 or WideResnet) | |
| Returns: | |
| Embedding vector | |
| """ | |
| embeddings = features[self.layers[0]] | |
| for layer in self.layers[1:]: | |
| layer_embedding = features[layer] | |
| layer_embedding = F.interpolate(layer_embedding, size=embeddings.shape[-2:], mode="nearest") | |
| embeddings = torch.cat((embeddings, layer_embedding), 1) | |
| # subsample embeddings | |
| idx = self.idx.to(embeddings.device) | |
| embeddings = torch.index_select(embeddings, 1, idx) | |
| return embeddings | |