# ZeroShot-AD — models/anomaly_detector.py (author: HoomKh, revision e5461d8)
# models/anomaly_detector.py
import torch
import torch.nn.functional as F
import numpy as np
from .glass import GLASS # Ensure correct import
import os
import logging
from torchvision import models
# Module-level logger named after this module, per standard logging convention.
LOGGER = logging.getLogger(__name__)
class AnomalyDetector:
    """Anomaly detector wrapping a GLASS model on a ResNet-50 backbone.

    Loads backbone weights from a local file, configures GLASS with the
    hyper-parameters used at training time, and exposes per-image anomaly
    scoring via :meth:`extract_features`.
    """

    def __init__(self, device='cuda'):
        """Build the backbone, configure GLASS, and switch to eval mode.

        Args:
            device (str): Preferred torch device; silently falls back to
                CPU when CUDA is unavailable.

        Raises:
            FileNotFoundError: If the backbone weight file is missing.
            RuntimeError: If the backbone state_dict does not match the model.
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # Backbone without ImageNet weights; real weights are loaded below.
        # NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13
        # (use `weights=None`) — kept as-is for older-version compatibility.
        backbone = models.resnet50(pretrained=False)

        backbone_weights_path = './backbones/resnet50_backbone.pth'  # Update this path as needed
        if os.path.exists(backbone_weights_path):
            LOGGER.info("Loading ResNet-50 backbone weights from '%s'", backbone_weights_path)
            # SECURITY: torch.load unpickles arbitrary objects — only load
            # checkpoints from trusted sources.
            checkpoint = torch.load(backbone_weights_path, map_location="cpu")
            try:
                backbone.load_state_dict(checkpoint, strict=True)
                LOGGER.info("ResNet-50 backbone weights loaded successfully.")
            except RuntimeError as e:
                LOGGER.error("Error loading ResNet-50 backbone state_dict: %s", e)
                raise
        else:
            LOGGER.error("Backbone weights not found at '%s'", backbone_weights_path)
            raise FileNotFoundError(f"Backbone weights not found at '{backbone_weights_path}'")

        # Initialize the GLASS model and configure it exactly as during training.
        self.glass = GLASS(device=self.device)

        layers_to_extract_from = ['layer4']  # extract only the last ResNet stage
        input_shape = (3, 224, 224)          # training input shape
        pretrain_embed_dimension = 2048      # channel dim of ResNet-50 'layer4'
        target_embed_dimension = 1024        # projection dim used at training

        self.glass.load(
            backbone=backbone,
            layers_to_extract_from=layers_to_extract_from,
            device=self.device,
            input_shape=input_shape,
            pretrain_embed_dimension=pretrain_embed_dimension,
            target_embed_dimension=target_embed_dimension,
            patchsize=3,
            patchstride=1,
            meta_epochs=640,  # training-only parameter, but required by load()
            eval_epochs=1,
            dsc_layers=2,
            dsc_hidden=1024,
            dsc_margin=0.5,
            train_backbone=False,
            pre_proj=1,
            mining=1,
            noise=0.015,
            radius=0.75,
            p=0.5,
            lr=0.0001,
            svd=0,
            step=20,
            limit=392,
        )

        # Directory bookkeeping GLASS uses when locating per-class weights.
        model_dir = "./models"            # Base directory for models
        dataset_name = "rayan_dataset"    # Example dataset name
        self.glass.set_model_dir(model_dir, dataset_name)

        self.glass.to(self.device)
        self.glass.eval()  # inference only — disable dropout/batchnorm updates

        # Cache of class names whose per-class weights have been loaded.
        self.loaded_classes = set()

    def load_model_weights(self, model_dir, classname):
        """
        Load the saved model weights for a specific class.

        Args:
            model_dir (str): Base directory where models are saved.
            classname (str): The class name whose model weights to load.

        Raises:
            FileNotFoundError: If no checkpoint exists for ``classname``.
            RuntimeError: If the checkpoint does not match the GLASS model.
        """
        checkpoint_path = os.path.join(model_dir, classname, f"best_model_{classname}.pth")
        if os.path.exists(checkpoint_path):
            LOGGER.info("Loading model weights from '%s' for class '%s'", checkpoint_path, classname)
            # SECURITY: torch.load unpickles arbitrary objects — trusted files only.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            try:
                self.glass.load_state_dict(checkpoint, strict=True)
                LOGGER.info("Model weights loaded successfully for class '%s'", classname)
            except RuntimeError as e:
                LOGGER.error("Error loading state_dict for class '%s': %s", classname, e)
                raise
        else:
            LOGGER.error("Checkpoint not found at '%s' for class '%s'", checkpoint_path, classname)
            raise FileNotFoundError(f"Checkpoint not found at '{checkpoint_path}' for class '{classname}'")

    def extract_features(self, image, classname):
        """
        Use GLASS to extract features and generate anomaly scores for a class.

        Args:
            image (torch.Tensor): Image tensor of shape [3, H, W].
            classname (str): The class name for which to perform detection
                (currently used for logging only).

        Returns:
            tuple: (image_score, anomaly_map) — a float image-level score and
            a numpy anomaly map of shape [1, H_patches, W_patches].

        Raises:
            ValueError: If GLASS returns patch shapes in an unexpected format.
        """
        # NOTE(review): per-class weight loading via load_model_weights() is
        # currently disabled; self.loaded_classes is never populated here.

        # Add batch dimension: [3, H, W] -> [1, 3, H, W].
        image = image.unsqueeze(0).to(self.device)

        with torch.no_grad():
            patch_features, patch_shapes = self.glass._embed(image, evaluation=True)
            if self.glass.pre_proj > 0:
                patch_features = self.glass.pre_projection(patch_features)
                # pre_projection may return (features, aux); keep features only.
                if isinstance(patch_features, (tuple, list)):
                    patch_features = patch_features[0]

            # Discriminator produces one anomaly score per patch.
            patch_scores = self.glass.discriminator(patch_features)
            patch_scores = self.glass.patch_maker.unpatch_scores(patch_scores, batchsize=image.shape[0])

            # Only one layer ('layer4') is extracted, so take the last shape
            # entry, e.g. [17, 17] for a 224x224 input.
            last_patch_shape = patch_shapes[-1]
            if isinstance(last_patch_shape, (list, tuple)) and len(last_patch_shape) == 2:
                # [1, N, 1] -> [1, N] -> [1, H_patches, W_patches]
                patch_scores = patch_scores.squeeze(-1)
                patch_scores = patch_scores.reshape(image.shape[0], *last_patch_shape)
            else:
                LOGGER.error("Unexpected patch_shapes format: %s", patch_shapes)
                raise ValueError(f"Unexpected patch_shapes format: {patch_shapes}")

            # Image-level score: mean of all patch scores.
            image_score = patch_scores.mean().item()

            # Anomaly map is the patch score grid, clipped to [0, 1].
            anomaly_map = patch_scores.cpu().numpy()
            anomaly_map = np.clip(anomaly_map, 0, 1)

        LOGGER.info(
            "Anomaly map stats for class '%s': min=%.4f, max=%.4f, mean=%.4f",
            classname, anomaly_map.min(), anomaly_map.max(), anomaly_map.mean(),
        )
        return image_score, anomaly_map

    def compute_pixel_score(self, anomaly_map):
        """
        Processes the anomaly map for pixel-level evaluation.

        Args:
            anomaly_map (np.ndarray): Raw anomaly map (e.g. shape [17, 17]).

        Returns:
            np.ndarray: Min-max normalized anomaly map of the same shape, or a
            zero map when the input has (near-)zero variance.
        """
        min_val = anomaly_map.min()
        max_val = anomaly_map.max()
        if max_val - min_val < 1e-8:
            # A flat map carries no ranking information; avoid divide-by-zero.
            LOGGER.warning("Anomaly map has zero variance. Returning zero map.")
            return np.zeros_like(anomaly_map)
        # Epsilon keeps the division safe even at the variance threshold.
        anomaly_map = (anomaly_map - min_val) / (max_val - min_val + 1e-8)
        return anomaly_map