|
|
|
|
| import torch
|
| import torch.nn.functional as F
|
| import numpy as np
|
| from .glass import GLASS
|
| import os
|
| import logging
|
| from torchvision import models
|
|
|
| LOGGER = logging.getLogger(__name__)
|
|
|
class AnomalyDetector:
    """GLASS-based anomaly detector with a frozen ResNet-50 feature backbone.

    Builds a ResNet-50 from locally stored weights (no download), wires it
    into a GLASS model configured for 224x224 RGB inputs, and exposes
    per-class checkpoint loading plus image- and pixel-level scoring.
    """

    def __init__(self, device='cuda'):
        """Initialize the backbone and the GLASS model.

        Args:
            device (str): Preferred torch device; silently falls back to
                CPU when CUDA is unavailable.

        Raises:
            FileNotFoundError: If the backbone weight file is missing.
            RuntimeError: If the backbone state_dict does not match.
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # Architecture only; weights come from the local checkpoint below,
        # so no pretrained download is attempted.
        backbone = models.resnet50(pretrained=False)

        backbone_weights_path = './backbones/resnet50_backbone.pth'
        if not os.path.exists(backbone_weights_path):
            LOGGER.error("Backbone weights not found at '%s'", backbone_weights_path)
            raise FileNotFoundError(f"Backbone weights not found at '{backbone_weights_path}'")

        LOGGER.info("Loading ResNet-50 backbone weights from '%s'", backbone_weights_path)
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from a trusted source (consider weights_only=True).
        checkpoint = torch.load(backbone_weights_path, map_location="cpu")
        try:
            backbone.load_state_dict(checkpoint, strict=True)
            LOGGER.info("ResNet-50 backbone weights loaded successfully.")
        except RuntimeError as e:
            LOGGER.error("Error loading ResNet-50 backbone state_dict: %s", e)
            raise

        self.glass = GLASS(device=self.device)

        # GLASS hyperparameters: extract layer4 features (2048-d) and
        # project them down to 1024-d patch embeddings over 224x224 inputs.
        self.glass.load(
            backbone=backbone,
            layers_to_extract_from=['layer4'],
            device=self.device,
            input_shape=(3, 224, 224),
            pretrain_embed_dimension=2048,
            target_embed_dimension=1024,
            patchsize=3,
            patchstride=1,
            meta_epochs=640,
            eval_epochs=1,
            dsc_layers=2,
            dsc_hidden=1024,
            dsc_margin=0.5,
            train_backbone=False,
            pre_proj=1,
            mining=1,
            noise=0.015,
            radius=0.75,
            p=0.5,
            lr=0.0001,
            svd=0,
            step=20,
            limit=392,
        )

        self.glass.set_model_dir("./models", "rayan_dataset")

        self.glass.to(self.device)
        self.glass.eval()

        # Classes whose per-class weights have been loaded so far.
        self.loaded_classes = set()

    def load_model_weights(self, model_dir, classname):
        """
        Load the saved model weights for a specific class.

        Args:
            model_dir (str): Base directory where models are saved.
            classname (str): The class name whose model weights to load.

        Raises:
            FileNotFoundError: If no checkpoint exists for the class.
            RuntimeError: If the checkpoint does not match the model.
        """
        checkpoint_path = os.path.join(model_dir, classname, f"best_model_{classname}.pth")
        if not os.path.exists(checkpoint_path):
            LOGGER.error("Checkpoint not found at '%s' for class '%s'", checkpoint_path, classname)
            raise FileNotFoundError(f"Checkpoint not found at '{checkpoint_path}' for class '{classname}'")

        LOGGER.info("Loading model weights from '%s' for class '%s'", checkpoint_path, classname)
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        try:
            self.glass.load_state_dict(checkpoint, strict=True)
            LOGGER.info("Model weights loaded successfully for class '%s'", classname)
        except RuntimeError as e:
            LOGGER.error("Error loading state_dict for class '%s': %s", classname, e)
            raise
        # Record the successful load so callers can see which classes are ready.
        self.loaded_classes.add(classname)

    def extract_features(self, image, classname):
        """
        Use GLASS to extract features and generate anomaly scores for a specific class.

        Args:
            image (torch.Tensor): Image tensor of shape [3, H, W]
            classname (str): The class name for which to perform anomaly detection.

        Returns:
            tuple: (image_score, anomaly_map) where image_score is a float
                and anomaly_map is a numpy array over the patch grid.
        """
        # Add a batch dimension: [3, H, W] -> [1, 3, H, W].
        image = image.unsqueeze(0).to(self.device)

        with torch.no_grad():
            patch_features, patch_shapes = self.glass._embed(image, evaluation=True)
            if self.glass.pre_proj > 0:
                patch_features = self.glass.pre_projection(patch_features)

            # pre_projection may return (features, ...); keep the features only.
            if isinstance(patch_features, (tuple, list)):
                patch_features = patch_features[0]

            patch_scores = self.glass.discriminator(patch_features)
            patch_scores = self.glass.patch_maker.unpatch_scores(patch_scores, batchsize=image.shape[0])

            last_patch_shape = patch_shapes[-1]
            if isinstance(last_patch_shape, (list, tuple)) and len(last_patch_shape) == 2:
                # Drop the trailing singleton score dim, then lay the flat
                # patch scores out on the [B, H_p, W_p] patch grid.
                patch_scores = patch_scores.squeeze(-1)
                patch_scores = patch_scores.reshape(image.shape[0], *last_patch_shape)
            else:
                LOGGER.error("Unexpected patch_shapes format: %s", patch_shapes)
                raise ValueError(f"Unexpected patch_shapes format: {patch_shapes}")

            # Image-level score: mean of all patch scores.
            image_score = patch_scores.mean().item()

            anomaly_map = patch_scores.cpu().numpy()
            # Clamp raw discriminator scores into [0, 1].
            anomaly_map = np.clip(anomaly_map, 0, 1)

        LOGGER.info(
            "Anomaly map stats for class '%s': min=%.4f, max=%.4f, mean=%.4f",
            classname, anomaly_map.min(), anomaly_map.max(), anomaly_map.mean(),
        )

        return image_score, anomaly_map

    def compute_pixel_score(self, anomaly_map):
        """
        Processes the anomaly map for pixel-level evaluation via min-max
        normalization to [0, 1].

        Args:
            anomaly_map (np.ndarray): Anomaly map over the patch grid
                (e.g. [17, 17]).

        Returns:
            np.ndarray: Processed anomaly map with the same shape as the input.
        """
        min_val = anomaly_map.min()
        max_val = anomaly_map.max()
        if max_val - min_val < 1e-8:
            # A flat map carries no localization signal; return all zeros.
            LOGGER.warning("Anomaly map has zero variance. Returning zero map.")
            return np.zeros_like(anomaly_map)
        # The guard above ensures a non-zero denominator, so no epsilon is
        # needed and the maximum maps exactly to 1.0.
        return (anomaly_map - min_val) / (max_val - min_val)
|
|
|