| | """Sample evaluation script for track 2.""" |
| |
|
| | import os |
| | from datetime import datetime |
| | from pathlib import Path |
| |
|
| | |
# Keep all downloaded weights (torch hub, Hugging Face) under ./checkpoint.
# These assignments sit above the torch / anomalib imports further down,
# presumably so those libraries see the variables when first imported.
for _env_var, _env_dir in (
    ('TORCH_HOME', './checkpoint'),
    ('HF_HOME', './checkpoint/huggingface'),
    ('TRANSFORMERS_CACHE', './checkpoint/huggingface/transformers'),
    ('HF_HUB_CACHE', './checkpoint/huggingface/hub'),
):
    os.environ[_env_var] = _env_dir

# Pre-create the Hugging Face cache directories (no-op if they already exist).
for _env_dir in ('./checkpoint/huggingface/transformers', './checkpoint/huggingface/hub'):
    os.makedirs(_env_dir, exist_ok=True)

del _env_var, _env_dir
| |
|
| | import argparse |
| | import importlib |
| | import importlib.util |
| |
|
| | import torch |
| | import logging |
| | from torch import nn |
| |
|
| | |
| | |
| | |
| | |
| | |
| | from anomalib.data import MVTecLoco |
| | from anomalib.metrics.f1_max import F1Max |
| | from anomalib.metrics.auroc import AUROC |
| | from tabulate import tabulate |
| | import numpy as np |
| |
|
# Indices into the training split whose images are handed to the model as
# few-shot references (the first four training samples).
FEW_SHOT_SAMPLES = list(range(4))
| |
|
def write_results_to_markdown(category, results_data, module_path, results_dir="results", k_shots=None):
    """Merge one category's metrics into the per-protocol markdown summary table.

    The summary file is ``<results_dir>/few_shot_results.md`` when
    ``module_path`` contains ``"few_shot"`` and ``<results_dir>/full_data_results.md``
    otherwise. Rows already present in that file are read back and kept, the
    row for ``category`` is inserted or overwritten, and the whole file is
    rewritten with rows sorted by category name.

    Args:
        category (str): Dataset category name; used as the row key.
        results_data (dict): Metric values keyed by ``f1_image``,
            ``auroc_image``, ``f1_logical``, ``auroc_logical``,
            ``f1_structural`` and ``auroc_structural``. Each value is written
            with two decimals.
        module_path (str): Model module path; only used to pick the protocol
            (few-shot vs. full-data) from its name.
        results_dir (str | os.PathLike): Directory holding the summary files.
            Created if missing. Defaults to ``"results"``.
        k_shots (int | None): Value for the K-shots column. Defaults to
            ``len(FEW_SHOT_SAMPLES)``.
    """
    is_few_shot = "few_shot" in module_path
    protocol = "Few-shot" if is_few_shot else "Full-data"
    protocol_suffix = "few_shot" if is_few_shot else "full_data"

    results_dir = Path(results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)
    combined_file = results_dir / f"{protocol_suffix}_results.md"

    if k_shots is None:
        k_shots = len(FEW_SHOT_SAMPLES)

    metric_keys = (
        'f1_image', 'auroc_image',
        'f1_logical', 'auroc_logical',
        'f1_structural', 'auroc_structural',
    )
    # Table column order after the leading Category column.
    column_keys = ('k_shots',) + metric_keys

    existing_results = {}
    if combined_file.exists():
        for line in combined_file.read_text(encoding='utf-8').splitlines():
            # Title / metadata lines have no (or few) pipes; a table row has
            # at least 8, which guarantees parts[1]..parts[8] exist below.
            if line.count('|') < 8:
                continue
            parts = [part.strip() for part in line.split('|')]
            row_key = parts[1]
            # Skip the header row and the |-----|-----| separator row, whose
            # first cell is a run of dashes (optionally with ':' alignment
            # markers); it must not be mistaken for a category.
            if row_key == 'Category' or not row_key.strip('-:'):
                continue
            existing_results[row_key] = dict(zip(column_keys, parts[2:9]))

    existing_results[category] = {
        'k_shots': str(k_shots),
        **{key: f"{results_data[key]:.2f}" for key in metric_keys},
    }

    with open(combined_file, 'w', encoding='utf-8') as f:
        f.write(f"# All Categories - {protocol} Protocol Results\n\n")
        f.write(f"**Last Updated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        f.write(f"**Protocol:** {protocol}\n")
        f.write(f"**Available Categories:** {', '.join(sorted(existing_results))}\n\n")

        f.write("## Summary Table\n\n")
        f.write("| Category | K-shots | F1-Max (Image) | AUROC (Image) | F1-Max (Logical) | AUROC (Logical) | F1-Max (Structural) | AUROC (Structural) |\n")
        f.write("|----------|---------|----------------|---------------|------------------|-----------------|---------------------|-------------------|\n")

        for cat in sorted(existing_results):
            cells = [cat] + [existing_results[cat][key] for key in column_keys]
            f.write("| " + " | ".join(cells) + " |\n")

    print("\n✓ Results saved to:")
    print(f" - Combined: {combined_file}")
| |
|
def parse_args() -> argparse.Namespace:
    """Parse command line arguments.

    ``--dataset_path`` defaults to the ``MVTEC_LOCO_ROOT`` environment variable
    when it is set, and otherwise to the original built-in location. This lets
    the script run against a local copy of MVTec LOCO without editing the code.

    Returns:
        argparse.Namespace: Parsed arguments with attributes ``module_path``,
        ``class_name``, ``weights_path``, ``dataset_path``, ``category`` and
        ``viz``.
    """
    default_dataset_path = os.environ.get(
        "MVTEC_LOCO_ROOT", '/home/bhu/Project/datasets/mvtec_loco_anomaly_detection/'
    )

    parser = argparse.ArgumentParser(description="Evaluate a model on one MVTec LOCO category.")
    parser.add_argument(
        "--module_path", type=str, required=True,
        help="Dotted import path of the module defining the model class. "
             "Names containing 'few_shot' are reported under the few-shot protocol.",
    )
    parser.add_argument(
        "--class_name", default='MyModel', type=str, required=False,
        help="Name of the model class inside --module_path (default: MyModel).",
    )
    parser.add_argument(
        "--weights_path", type=str, required=False,
        help="Optional state-dict file loaded into the model after construction.",
    )
    parser.add_argument(
        "--dataset_path", default=default_dataset_path, type=str, required=False,
        help="Root directory of the MVTec LOCO dataset "
             "(default: $MVTEC_LOCO_ROOT if set, else a built-in path).",
    )
    parser.add_argument(
        "--category", type=str, required=True,
        help="MVTec LOCO category to evaluate, e.g. breakfast_box.",
    )
    parser.add_argument(
        "--viz", action='store_true', default=False,
        help="Enable the model's visualization mode (forwarded to model.set_viz).",
    )
    return parser.parse_args()
| |
|
| |
|
def load_model(module_path: str, class_name: str, weights_path: str) -> nn.Module:
    """Import, instantiate and optionally restore the model to evaluate.

    Args:
        module_path (str): Dotted import path of the module containing the model
            class (passed to ``importlib.import_module``).
        class_name (str): Name of the model class in that module. The class is
            instantiated with no arguments.
        weights_path (str): Path to a saved state dict. When empty or None, the
            freshly constructed model is returned without loading weights.

    Returns:
        nn.Module: The instantiated (and possibly restored) model.
    """
    model_class = getattr(importlib.import_module(module_path), class_name)
    model = model_class()

    if weights_path:
        # Map tensors to the CPU so checkpoints saved on a GPU can still be read
        # on CPU-only hosts; callers such as run() move the model to the target
        # device afterwards.
        # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
        state_dict = torch.load(weights_path, map_location="cpu")
        model.load_state_dict(state_dict)
    return model
| |
|
| |
|
def run(module_path: str, class_name: str, weights_path: str, dataset_path: str, category: str, viz: bool) -> None:
    """Evaluate a model on one MVTec LOCO category and report image-level metrics.

    The model is built with ``load_model`` and receives a few-shot reference set
    taken from the training split (indices in ``FEW_SHOT_SAMPLES``) through
    ``model.setup``. It is then scored on every test image. F1-Max and AUROC
    are accumulated for three subsets:

    * ``image``: every test image;
    * ``logical``: test images whose parent directory name does not contain
      "structural";
    * ``structural``: test images whose parent directory name does not contain
      "logical".

    Scores are logged as a table (percent, two decimals) and merged into the
    markdown summary via ``write_results_to_markdown``.

    Args:
        module_path (str): Dotted import path of the module containing the model class.
        class_name (str): Name of the model class.
        weights_path (str): Optional path to the model weights (state dict).
        dataset_path (str): Root directory of the MVTec LOCO dataset.
        category (str): Category of the dataset to evaluate.
        viz (bool): Forwarded to ``model.set_viz`` to toggle visualization.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    model = load_model(module_path, class_name, weights_path)
    model.to(device)

    datamodule = MVTecLoco(root=dataset_path, eval_batch_size=1, image_size=(448, 448), category=category)
    datamodule.setup()

    model.set_viz(viz)

    # One F1-Max / AUROC pair per evaluated subset.
    subsets = ("image", "logical", "structural")
    metrics = {subset: {"f1": F1Max(), "auroc": AUROC()} for subset in subsets}

    train_data = datamodule.train_data
    setup_data = {
        "few_shot_samples": torch.stack([train_data[idx]["image"] for idx in FEW_SHOT_SAMPLES]).to(device),
        "few_shot_samples_path": [train_data[idx]["image_path"] for idx in FEW_SHOT_SAMPLES],
        "dataset_category": category,
    }
    model.setup(setup_data)

    # Score in evaluation mode (dropout off, batch-norm uses running stats)
    # and without building autograd graphs.
    model.eval()
    with torch.no_grad():
        for data in datamodule.test_dataloader():
            image_path = data["image_path"]
            output = model(data["image"].to(device), image_path)
            pred_score = output["pred_score"].cpu()
            label = data["label"]

            # eval_batch_size=1, so the first path describes the whole batch.
            # Only the parent directory name is inspected, so words such as
            # "logical" elsewhere in the dataset root cannot misroute samples.
            anomaly_dir = Path(image_path[0]).parent.name
            targets = ["image"]
            if "logical" not in anomaly_dir:
                targets.append("structural")
            if "structural" not in anomaly_dir:
                targets.append("logical")
            for subset in targets:
                for metric in metrics[subset].values():
                    metric.update(pred_score, label)

    # Silence library logging so the results table is the main output.
    logging.getLogger().setLevel(logging.ERROR)
    logging.getLogger('anomalib').setLevel(logging.ERROR)
    logging.getLogger('lightning').setLevel(logging.ERROR)
    logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)

    # Dedicated logger for the results table; clear handlers first so repeated
    # calls in one process do not stack duplicate handlers.
    logger = logging.getLogger('evaluation')
    logger.handlers.clear()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S')
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # Compute every metric exactly once (percent, rounded to 2 decimals) and
    # reuse the values for both the console table and the markdown summary.
    results_data = {
        f"{kind}_{subset}": np.round(metric.compute().item() * 100, decimals=2)
        for subset in subsets
        for kind, metric in metrics[subset].items()
    }

    column_order = ('f1_image', 'auroc_image', 'f1_logical', 'auroc_logical', 'f1_structural', 'auroc_structural')
    table_ls = [[category, str(len(FEW_SHOT_SAMPLES))] + [str(results_data[key]) for key in column_order]]

    results = tabulate(table_ls, headers=['category', 'K-shots', 'F1-Max(image)', 'AUROC(image)', 'F1-Max (logical)', 'AUROC (logical)', 'F1-Max (structural)', 'AUROC (structural)'], tablefmt="pipe")

    logger.info("\n%s", results)

    write_results_to_markdown(category, results_data, module_path)
| |
|
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: evaluate a single category with the given model.
    cli_args = parse_args()
    run(
        module_path=cli_args.module_path,
        class_name=cli_args.class_name,
        weights_path=cli_args.weights_path,
        dataset_path=cli_args.dataset_path,
        category=cli_args.category,
        viz=cli_args.viz,
    )
| |
|