# main.py
import os
import torch
from torch.utils.data import DataLoader
from datasets.all_classes_dataset import AllClassesDataset, DatasetSplit
from models.anomaly_detector import AnomalyDetector
from utils.dump_scores import DumpScores
import logging
import json
from sklearn.metrics import average_precision_score, roc_auc_score, f1_score
import numpy as np
import torch.nn.functional as F
import random


def set_seed(seed: int):
    """
    Set the seed for reproducibility across various libraries.

    Args:
        seed (int): The seed value to be set.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # For multi-GPU setups
    # Ensure deterministic behavior in PyTorch
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # For DataLoader workers
    os.environ['PYTHONHASHSEED'] = str(seed)


def worker_init_fn(worker_id):
    """
    Initialize the seed for each DataLoader worker to ensure reproducibility.

    Args:
        worker_id (int): The worker ID.
    """
    seed = torch.initial_seed()
    np.random.seed(seed % 2**32)
    random.seed(seed % 2**32)


def compute_aupro(y_true_pixel, y_scores_pixel, num_thresholds=50):
    """
    Compute Area Under the Per-Region Overlap Curve (AUPRO).

    Note: this is a simplified approximation that averages per-image IoU over
    a sweep of thresholds rather than computing true per-connected-region
    overlap.

    Args:
        y_true_pixel (np.ndarray): Ground truth binary masks, shape [N, H, W]
        y_scores_pixel (np.ndarray): Predicted anomaly scores, shape [N, H, W]
        num_thresholds (int): Number of thresholds to evaluate.

    Returns:
        float: AUPRO score.
    """
    # Define thresholds
    thresholds = np.linspace(0, 1, num_thresholds)

    # Initialize list to store overlaps
    overlaps = []

    for thresh in thresholds:
        # Binarize predictions
        y_pred = (y_scores_pixel >= thresh).astype(int)

        # Compute Intersection over Union (IoU) for each sample
        ious = []
        for gt, pred in zip(y_true_pixel, y_pred):
            intersection = np.logical_and(gt, pred).sum()
            union = np.logical_or(gt, pred).sum()
            if union == 0:
                iou = 1.0  # If both gt and pred are all zeros
            else:
                iou = intersection / union
            ious.append(iou)
        # Average IoU over all samples
        avg_iou = np.mean(ious)
        overlaps.append(avg_iou)

    # Area under the overlap curve, normalized by the threshold range.
    # (The denominator equals 1.0 here since thresholds span [0, 1].)
    # Note: NumPy >= 2.0 renames np.trapz to np.trapezoid.
    aupro = np.trapz(overlaps, thresholds) / np.trapz([1] * len(thresholds), thresholds)
    return aupro
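# Illustrative sketch (a hypothetical helper, not called anywhere in the
# pipeline): demonstrates the input shapes compute_aupro expects, using
# synthetic data. The printed value is meaningless; it only shows the call.
def _example_compute_aupro_usage():
    rng = np.random.default_rng(0)
    y_true = (rng.random((4, 17, 17)) > 0.8).astype(int)  # binary masks [N, H, W]
    y_scores = rng.random((4, 17, 17))                    # anomaly scores in [0, 1]
    print(compute_aupro(y_true, y_scores))                # scalar in [0, 1]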
""" # Check image-level consistency if len(y_true_image) != len(y_scores_image): raise ValueError(f"Image-level y_true and y_scores have different lengths: {len(y_true_image)} vs {len(y_scores_image)}") # Check pixel-level consistency if y_true_pixel.shape != y_scores_pixel.shape: raise ValueError(f"Pixel-level y_true and y_scores have different shapes: {y_true_pixel.shape} vs {y_scores_pixel.shape}") # Image-level Metrics image_ap = average_precision_score(y_true_image, y_scores_image) image_auroc = roc_auc_score(y_true_image, y_scores_image) y_pred_image = (y_scores_image >= 0.5).astype(int) image_f1 = f1_score(y_true_image, y_pred_image) # Pixel-level Metrics pixel_ap = average_precision_score(y_true_pixel.flatten(), y_scores_pixel.flatten()) pixel_auroc = roc_auc_score(y_true_pixel.flatten(), y_scores_pixel.flatten()) pixel_aupro = compute_aupro(y_true_pixel, y_scores_pixel) y_pred_pixel = (y_scores_pixel >= 0.5).astype(int) pixel_f1 = f1_score(y_true_pixel.flatten(), y_pred_pixel.flatten()) # Compute leaderboard_score as a weighted average (example weights) # Adjust weights as per your specific requirements leaderboard_score = ( 0.25 * image_auroc + 0.25 * image_f1 + 0.25 * pixel_auroc + 0.25 * pixel_f1 ) metrics = { "image_metrics": { "image_ap": round(float(image_ap), 4), "image_auroc": round(float(image_auroc), 4), "image_f1": round(float(image_f1), 4) }, "pixel_metrics": { "pixel_ap": round(float(pixel_ap), 4), "pixel_aupro": round(float(pixel_aupro), 4), "pixel_auroc": round(float(pixel_auroc), 4), "pixel_f1": round(float(pixel_f1), 4) }, "overall_metric": { "leaderboard_score": round(float(leaderboard_score), 4) } } return metrics def get_class_name(image_path, source_dir): """ Extract the class name from the image path. Args: image_path (str): Path to the image file. source_dir (str): Root source directory. Returns: str: Class name. 
""" # Example image_path: "./data/pill/test/broken/image1.png" rel_path = os.path.relpath(image_path, source_dir) # "pill/test/broken/image1.png" parts = rel_path.split(os.sep) if len(parts) < 2: raise ValueError(f"Unexpected image path format: {image_path}") class_name = parts[0] # "pill" return class_name def main(): SEED = 41 # You can choose any integer value set_seed(SEED) # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # Configuration source_dir = "./data" output_scores_dir = "./output_scores" split = DatasetSplit.TEST # Use the Enum instead of string device = "cuda:0" if torch.cuda.is_available() else "cpu" logging.info("Initializing the dataset and dataloader...") # Initialize dataset and dataloader using AllClassesDataset with output_size=17 dataset = AllClassesDataset( source=source_dir, split=split, # output_size=16 # Set to match anomaly_map resolution ) dataloader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0) logging.info("Initializing the anomaly detector...") # Initialize anomaly detector detector = AnomalyDetector(device=device) # Initialize DumpScores dump_scores = DumpScores(output_dir=output_scores_dir) logging.info("Starting anomaly detection inference...") # Initialize containers for metrics classes = dataset.get_all_class_names() metrics_data = {cls: { "y_true_image": [], "y_scores_image": [], "y_true_pixel": [], "y_scores_pixel": [] } for cls in classes} # Iterate through the dataset for batch_idx, batch in enumerate(dataloader): image = batch['image'].squeeze(0) # Shape: [3, H, W] mask = batch['mask'].squeeze(1).numpy() # Remove all singleton dimensions to get [17, 17] image_label = batch['is_anomaly'].item() # 1 or 0 image_path = batch['image_path'][0] # Assuming batch_size=1 # Extract class name from image_path try: class_name = get_class_name(image_path, source_dir) except ValueError as e: logging.error(f"Error extracting class name: {e}") continue # Skip this sample # Extract features and compute scores using GLASS image_score, anomaly_map = detector.extract_features(image, "all") # Compute pixel-level anomaly score (already normalized) pixel_score = detector.compute_pixel_score(anomaly_map).squeeze() pixel_score_tensor = torch.from_numpy(pixel_score).float().unsqueeze(0).unsqueeze(0).to( device) # Shape: [1, 1, 17, 17] # **Upsample pixel_score to (224, 224)** # Option 1: Using PyTorch Interpolation pixel_score = F.interpolate( pixel_score_tensor, # Add batch and channel dimensions size=(224, 224), mode='bilinear', align_corners=False ).squeeze(0).cpu().numpy() # Removes all singleton dimensions, resulting in [224, 224] # Option 2: Using OpenCV (Uncomment if preferred) # pixel_score_np = pixel_score.numpy() # pixel_score = cv2.resize( # pixel_score, # dsize=(224, 224), # interpolation=cv2.INTER_LINEAR # ) # **Optional: Verify the upsampled pixel_score shape** # if pixel_score.shape != (1, 224, 224): # logging.warning( # f"Upsampled pixel score shape mismatch for image {image_path}: expected (224, 224), got {pixel_score.shape}") # continue # Skip this sample # Append to metrics_data metrics_data[class_name]["y_true_image"].append(image_label) metrics_data[class_name]["y_scores_image"].append(image_score) metrics_data[class_name]["y_true_pixel"].append(mask) metrics_data[class_name]["y_scores_pixel"].append(pixel_score) # Save individual image scores dump_scores.save_scores([image_path], [image_score], [pixel_score]) logging.info(f"[{batch_idx + 1}/{len(dataloader)}] Processed 
image: {image_path}") logging.info(f"Image-level score: {image_score:.4f}") logging.info(f"Pixel-level mean score: {pixel_score.mean():.4f}") logging.info("Anomaly detection inference completed. Computing metrics...") # Initialize dictionary to hold metrics per class classes_metrics = {} for cls in classes: y_true_image = np.array(metrics_data[cls]["y_true_image"]) y_scores_image = np.array(metrics_data[cls]["y_scores_image"]) y_true_pixel = np.array(metrics_data[cls]["y_true_pixel"]) y_scores_pixel = np.array(metrics_data[cls]["y_scores_pixel"]) # Check if there are any samples for the class if len(y_true_image) == 0: logging.warning(f"No samples found for class {cls}. Skipping metric computation.") continue try: metrics = compute_metrics(y_true_image, y_scores_image, y_true_pixel, y_scores_pixel) classes_metrics[cls] = metrics logging.info(f"Metrics computed for class: {cls}") except Exception as e: logging.error(f"Failed to compute metrics for class {cls}: {e}") # Save metrics to JSON os.makedirs(output_scores_dir, exist_ok=True) metrics_json_path = os.path.join(output_scores_dir, "metrics.json") try: with open(metrics_json_path, "w") as f: json.dump(classes_metrics, f, indent=4) logging.info(f"Metrics successfully saved to {metrics_json_path}") except Exception as e: logging.error(f"Failed to save metrics to {metrics_json_path}: {e}") if __name__ == "__main__": main()