|
|
|
|
| import os
|
| import torch
|
| from torch.utils.data import DataLoader
|
| from datasets.all_classes_dataset import AllClassesDataset, DatasetSplit
|
| from models.anomaly_detector import AnomalyDetector
|
| from utils.dump_scores import DumpScores
|
| import logging
|
| import json
|
| from sklearn.metrics import average_precision_score, roc_auc_score, f1_score
|
| import numpy as np
|
| import torch.nn.functional as F
|
| import random
|
|
|
def set_seed(seed: int):
    """
    Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (plus every CUDA device when CUDA is available), switches cuDNN into
    deterministic, non-autotuning mode, and exports ``PYTHONHASHSEED``.

    Args:
        seed (int): Seed applied to every generator.
    """
    # Same order as before: random -> NumPy -> torch (CPU).
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        # Also covers multi-GPU machines.
        torch.cuda.manual_seed_all(seed)

    # Give up cuDNN autotuning speed in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    # NOTE: the running interpreter's hash randomization is fixed at startup;
    # this variable only influences Python subprocesses launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
|
|
|
def worker_init_fn(worker_id):
    """
    Derive NumPy and ``random`` seeds for a DataLoader worker from torch's seed.

    Both libraries are re-seeded with ``torch.initial_seed()`` reduced modulo
    2**32, so they follow whatever torch seed is active in the calling process.

    Args:
        worker_id (int): Worker index supplied by the DataLoader (unused here).
    """
    # np.random.seed only accepts values in [0, 2**32).
    derived_seed = torch.initial_seed() % 2**32
    np.random.seed(derived_seed)
    random.seed(derived_seed)
|
|
|
def compute_aupro(y_true_pixel, y_scores_pixel, num_thresholds=50, max_fpr=0.3):
    """
    Compute the normalized Area Under the Per-Region Overlap curve (AUPRO).

    Each connected anomalous region of the ground truth (found per image) is
    weighted equally, regardless of its size. For a threshold ``t``:

    * PRO(t) = mean over regions of the fraction of the region's pixels whose
      score is >= t;
    * FPR(t) = fraction of normal (mask == 0) pixels whose score is >= t.

    The PRO-vs-FPR curve is integrated with the trapezoidal rule for FPR in
    ``[0, max_fpr]`` and divided by ``max_fpr``, so the result lies in [0, 1].

    Args:
        y_true_pixel (np.ndarray): Ground truth masks, shape [N, ...] (one
            mask per image; a single [H, W] mask is also accepted). Non-zero
            entries are anomalous.
        y_scores_pixel (np.ndarray): Anomaly scores with the same shape. Any
            real-valued range is allowed.
        num_thresholds (int): Number of evenly spaced thresholds between the
            maximum and minimum score (both included). Must be >= 2.
        max_fpr (float): Upper FPR integration limit, in (0, 1].

    Returns:
        float: AUPRO score.

    Raises:
        ValueError: If the shapes differ, the parameters are out of range, or
            there is no anomalous region / no normal pixel (AUPRO undefined).
    """
    from scipy import ndimage  # local import: only this metric needs scipy

    masks = np.asarray(y_true_pixel) != 0
    scores = np.asarray(y_scores_pixel, dtype=np.float64)
    if masks.shape != scores.shape:
        raise ValueError(f"Mask and score shapes differ: {masks.shape} vs {scores.shape}")
    if num_thresholds < 2:
        raise ValueError(f"num_thresholds must be >= 2, got {num_thresholds}")
    if not 0.0 < max_fpr <= 1.0:
        raise ValueError(f"max_fpr must be in (0, 1], got {max_fpr}")
    if masks.ndim == 2:
        masks, scores = masks[np.newaxis], scores[np.newaxis]

    # Sorted scores of every connected ground-truth region. Labelling image by
    # image keeps regions of different images from being merged.
    region_scores = []
    for mask, score in zip(masks, scores):
        labels, num_regions = ndimage.label(mask)
        for region_id in range(1, num_regions + 1):
            region_scores.append(np.sort(score[labels == region_id]))
    normal_scores = np.sort(scores[~masks])

    if not region_scores:
        raise ValueError("AUPRO is undefined without any anomalous ground-truth region.")
    if normal_scores.size == 0:
        raise ValueError("AUPRO is undefined without any normal pixel.")

    # Descending thresholds so FPR grows along the curve. The last threshold
    # equals the minimum score, so the final point always has FPR == 1.
    thresholds = np.linspace(scores.max(), scores.min(), num_thresholds)

    def _fraction_at_or_above(sorted_values):
        # Number of values >= t is size minus the count of values < t.
        below = np.searchsorted(sorted_values, thresholds, side="left")
        return (sorted_values.size - below) / sorted_values.size

    fprs = _fraction_at_or_above(normal_scores)
    pros = np.mean([_fraction_at_or_above(r) for r in region_scores], axis=0)

    # Anchor the curve at the origin (a threshold above every score).
    fprs = np.concatenate(([0.0], fprs))
    pros = np.concatenate(([0.0], pros))

    # Truncate at max_fpr, interpolating PRO at the cut-off point.
    pro_at_limit = np.interp(max_fpr, fprs, pros)
    inside = fprs < max_fpr
    fprs = np.append(fprs[inside], max_fpr)
    pros = np.append(pros[inside], pro_at_limit)

    # Explicit trapezoidal rule (np.trapz is deprecated in NumPy 2.0).
    area = np.sum(np.diff(fprs) * (pros[1:] + pros[:-1]) / 2.0)
    return float(area / max_fpr)
|
|
|
|
|
def compute_metrics(y_true_image, y_scores_image, y_true_pixel, y_scores_pixel, threshold=0.5):
    """
    Compute image- and pixel-level anomaly-detection metrics.

    Args:
        y_true_image (np.ndarray): Ground truth image labels, shape [N].
        y_scores_image (np.ndarray): Predicted image scores, shape [N].
        y_true_pixel (np.ndarray): Ground truth pixel masks, shape [N, H, W].
        y_scores_pixel (np.ndarray): Predicted pixel anomaly scores, same
            shape as ``y_true_pixel``.
        threshold (float): Score at or above which a prediction counts as
            anomalous when computing the image and pixel F1 scores. Defaults
            to 0.5, the value that used to be hard-coded; pass a different
            value when scores are not calibrated to [0, 1].

    Returns:
        dict: ``{"image_metrics": {...}, "pixel_metrics": {...},
        "overall_metric": {"leaderboard_score": ...}}`` with every value
        rounded to 4 decimals. ``leaderboard_score`` is the unweighted mean
        of image AUROC, image F1, pixel AUROC and pixel F1.

    Raises:
        ValueError: If image-level lengths or pixel-level shapes disagree, or
            if an underlying metric is undefined for the given labels (e.g.
            only one class present).
    """
    # Accept lists as well as arrays; the `>=` comparisons below need arrays.
    y_true_image = np.asarray(y_true_image)
    y_scores_image = np.asarray(y_scores_image)
    y_true_pixel = np.asarray(y_true_pixel)
    y_scores_pixel = np.asarray(y_scores_pixel)

    if len(y_true_image) != len(y_scores_image):
        raise ValueError(f"Image-level y_true and y_scores have different lengths: {len(y_true_image)} vs {len(y_scores_image)}")

    if y_true_pixel.shape != y_scores_pixel.shape:
        raise ValueError(f"Pixel-level y_true and y_scores have different shapes: {y_true_pixel.shape} vs {y_scores_pixel.shape}")

    # Image level.
    image_ap = average_precision_score(y_true_image, y_scores_image)
    image_auroc = roc_auc_score(y_true_image, y_scores_image)
    y_pred_image = (y_scores_image >= threshold).astype(int)
    # zero_division=0 returns the same fallback value sklearn uses by default,
    # without emitting a warning when nothing is predicted positive.
    image_f1 = f1_score(y_true_image, y_pred_image, zero_division=0)

    # Pixel level (flattened views; ravel avoids an unnecessary copy).
    true_flat = y_true_pixel.ravel()
    score_flat = y_scores_pixel.ravel()
    pixel_ap = average_precision_score(true_flat, score_flat)
    pixel_auroc = roc_auc_score(true_flat, score_flat)
    pixel_aupro = compute_aupro(y_true_pixel, y_scores_pixel)
    y_pred_pixel = (score_flat >= threshold).astype(int)
    pixel_f1 = f1_score(true_flat, y_pred_pixel, zero_division=0)

    leaderboard_score = (
        0.25 * image_auroc +
        0.25 * image_f1 +
        0.25 * pixel_auroc +
        0.25 * pixel_f1
    )

    def _r(value):
        return round(float(value), 4)

    return {
        "image_metrics": {
            "image_ap": _r(image_ap),
            "image_auroc": _r(image_auroc),
            "image_f1": _r(image_f1),
        },
        "pixel_metrics": {
            "pixel_ap": _r(pixel_ap),
            "pixel_aupro": _r(pixel_aupro),
            "pixel_auroc": _r(pixel_auroc),
            "pixel_f1": _r(pixel_f1),
        },
        "overall_metric": {
            "leaderboard_score": _r(leaderboard_score),
        },
    }
|
|
|
|
|
def get_class_name(image_path, source_dir):
    """
    Extract the class name (first directory below ``source_dir``) from a path.

    Args:
        image_path (str): Path to the image file.
        source_dir (str): Root source directory.

    Returns:
        str: Class name, i.e. the first path component of ``image_path``
        relative to ``source_dir``.

    Raises:
        ValueError: If ``image_path`` is not at least one directory deep
            below ``source_dir``, lies outside ``source_dir``, or (on
            Windows) is on a different drive.
    """
    rel_path = os.path.relpath(image_path, source_dir)
    parts = rel_path.split(os.sep)
    # A leading '..' means the image lies outside source_dir; that component
    # is a parent-directory marker, not a class name.
    if len(parts) < 2 or parts[0] == os.pardir:
        raise ValueError(f"Unexpected image path format: {image_path}")
    return parts[0]
|
|
|
|
|
def main():
    """
    Run anomaly-detection inference on the test split and write per-class metrics.

    Every image's scores are passed to ``DumpScores`` (output directory
    ``./output_scores``), and the metrics computed per class by
    ``compute_metrics`` are written to ``./output_scores/metrics.json``.
    """
    SEED = 41
    set_seed(SEED)

    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    source_dir = "./data"
    output_scores_dir = "./output_scores"
    split = DatasetSplit.TEST
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    logging.info("Initializing the dataset and dataloader...")

    dataset = AllClassesDataset(
        source=source_dir,
        split=split,
    )
    # The loop below unpacks exactly one sample per batch (`.item()`, `[0]`,
    # `squeeze(0)`), so the batch size must be 1; with the former
    # batch_size=4, `.item()` raised on every multi-sample batch.
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        worker_init_fn=worker_init_fn,
    )

    logging.info("Initializing the anomaly detector...")
    detector = AnomalyDetector(device=device)

    dump_scores = DumpScores(output_dir=output_scores_dir)

    logging.info("Starting anomaly detection inference...")

    classes = dataset.get_all_class_names()
    metrics_data = {cls: {
        "y_true_image": [],
        "y_scores_image": [],
        "y_true_pixel": [],
        "y_scores_pixel": []
    } for cls in classes}

    for batch_idx, batch in enumerate(dataloader):
        image = batch['image'].squeeze(0)
        mask = batch['mask'].squeeze(1).numpy()
        image_label = batch['is_anomaly'].item()
        image_path = batch['image_path'][0]

        try:
            class_name = get_class_name(image_path, source_dir)
        except ValueError as e:
            logging.error(f"Error extracting class name: {e}")
            continue
        # Skip (before running inference) directories that do not match any
        # class reported by the dataset; indexing metrics_data would KeyError.
        if class_name not in metrics_data:
            logging.error(f"Unknown class '{class_name}' for image {image_path}; skipping.")
            continue

        image_score, anomaly_map = detector.extract_features(image, "all")

        pixel_score = detector.compute_pixel_score(anomaly_map).squeeze()

        pixel_score_tensor = torch.from_numpy(pixel_score).float().unsqueeze(0).unsqueeze(0).to(device)

        # Resize the anomaly map to the ground-truth mask resolution (rather
        # than a hard-coded 224x224) so pixel-level metrics compare arrays of
        # identical shape.
        pixel_score = F.interpolate(
            pixel_score_tensor,
            size=tuple(mask.shape[-2:]),
            mode='bilinear',
            align_corners=False
        ).squeeze(0).cpu().numpy()

        metrics_data[class_name]["y_true_image"].append(image_label)
        metrics_data[class_name]["y_scores_image"].append(image_score)
        metrics_data[class_name]["y_true_pixel"].append(mask)
        metrics_data[class_name]["y_scores_pixel"].append(pixel_score)

        dump_scores.save_scores([image_path], [image_score], [pixel_score])

        logging.info(f"[{batch_idx + 1}/{len(dataloader)}] Processed image: {image_path}")
        logging.info(f"Image-level score: {image_score:.4f}")
        logging.info(f"Pixel-level mean score: {pixel_score.mean():.4f}")

    logging.info("Anomaly detection inference completed. Computing metrics...")

    classes_metrics = {}

    for cls in classes:
        data = metrics_data[cls]
        if not data["y_true_image"]:
            logging.warning(f"No samples found for class {cls}. Skipping metric computation.")
            continue

        # Array construction is inside the try as well: a failure for one
        # class (e.g. inconsistent shapes) should not abort the others.
        try:
            metrics = compute_metrics(
                np.array(data["y_true_image"]),
                np.array(data["y_scores_image"]),
                np.array(data["y_true_pixel"]),
                np.array(data["y_scores_pixel"]),
            )
        except Exception as e:
            logging.exception(f"Failed to compute metrics for class {cls}: {e}")
            continue
        classes_metrics[cls] = metrics
        logging.info(f"Metrics computed for class: {cls}")

    os.makedirs(output_scores_dir, exist_ok=True)
    metrics_json_path = os.path.join(output_scores_dir, "metrics.json")
    try:
        with open(metrics_json_path, "w") as f:
            json.dump(classes_metrics, f, indent=4)
        logging.info(f"Metrics successfully saved to {metrics_json_path}")
    except Exception as e:
        logging.error(f"Failed to save metrics to {metrics_json_path}: {e}")
|
|
|
|
|
# Run the evaluation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|