#!/usr/bin/env python3 """ Training script for MVTec AD using Anomalib models. Examples: python train.py # Train Patchcore on bottle (default) python train.py --model all # Train all models on bottle python train.py --category all # Train Patchcore on all categories """ import os import time import json import logging import argparse from pathlib import Path import torch from anomalib.data import MVTecAD from anomalib.engine import Engine from anomalib.data.utils import download_and_extract from core import ( MVTEC_CATEGORIES, DIR_RESULTS, DIR_DATASET, get_available_models, load_model_config, get_class_from_path, get_model_size_mb, format_metric, safe_mean, ) logger = logging.getLogger(__name__) EFFICIENTAD_RESOURCES_DIR = Path(__file__).parent / "efficientad_resources" def _patched_prepare_pretrained_model(self) -> None: """Patched version that uses efficientad_resources/pre_trained/ directory.""" from anomalib.models.image.efficient_ad.lightning_model import WEIGHTS_DOWNLOAD_INFO from anomalib.models.image.efficient_ad.torch_model import EfficientAdModelSize pretrained_models_dir = EFFICIENTAD_RESOURCES_DIR / "pre_trained" pretrained_models_dir.mkdir(parents=True, exist_ok=True) weights_dir = pretrained_models_dir / "efficientad_pretrained_weights" if not weights_dir.is_dir(): download_and_extract(pretrained_models_dir, WEIGHTS_DOWNLOAD_INFO) model_size_str = self.model_size.value if isinstance(self.model_size, EfficientAdModelSize) else self.model_size teacher_path = weights_dir / f"pretrained_teacher_{model_size_str}.pth" logger.info(f"Load pretrained teacher model from {teacher_path}") self.model.teacher.load_state_dict( torch.load(teacher_path, map_location=torch.device(self.device), weights_only=True), ) def patch_efficientad(): """Apply monkey-patch to EfficientAd to use custom pretrained weights directory.""" from anomalib.models import EfficientAd EfficientAd.prepare_pretrained_model = _patched_prepare_pretrained_model print(f" [INFO] EfficientAd: Pretrained weights directory: {EFFICIENTAD_RESOURCES_DIR / 'pre_trained'}") def save_metrics(category_metrics, category, model_name): """Saves metrics in the Anomalib directory structure.""" config = load_model_config(model_name) result_dirname = config["result_dirname"] category_base_dir = DIR_RESULTS / result_dirname / "MVTecAD" / category if not category_base_dir.exists(): return # Find current version (v0, v1, v2, ...) versions = [d.name for d in category_base_dir.iterdir() if d.is_dir() and d.name.startswith('v') and d.name[1:].isdigit()] if not versions: return latest_version = sorted(versions, key=lambda x: int(x[1:]))[-1] # Save in v_n version_dir = category_base_dir / latest_version version_json_path = version_dir / "metrics.json" with open(version_json_path, 'w', encoding='utf-8') as f: json.dump(category_metrics, f, indent=2, ensure_ascii=False) print(f" Saved: {version_json_path}") # Save in latest (only if it exists) latest_dir = category_base_dir / "latest" if latest_dir.exists(): latest_json_path = latest_dir / "metrics.json" with open(latest_json_path, 'w', encoding='utf-8') as f: json.dump(category_metrics, f, indent=2, ensure_ascii=False) def print_category_metrics(metrics): """Prints metrics for a category.""" print(f"\n[METRICS]") print(f" EFFICACY: AUROC img={format_metric(metrics['image_auroc'])} | " f"AUROC pix={format_metric(metrics['pixel_auroc'])} | " f"F1={format_metric(metrics['image_f1'])}") print(f" EFFICIENCY: Train={format_metric(metrics['train_time_s'], 1)}s | " f"Inf={format_metric(metrics['inference_time_ms'], 1)}ms | " f"FPS={format_metric(metrics['fps'], 1)} | " f"Size={format_metric(metrics['model_size_mb'], 1)}MB") def print_final_report(all_metrics, model_name): """Prints final report with all metrics.""" if not all_metrics: return print(f"\n{'='*100}") print(f"FINAL REPORT - {model_name.upper()} PERFORMANCE METRICS") print(f"{'='*100}\n") # Header header = f"{'Category':<12} | {'Img AUROC':<10} | {'Pix AUROC':<10} | {'Img F1':<10} | {'Train(s)':<10} | {'Inf(ms)':<10} | {'FPS':<8} | {'Size(MB)':<10}" print(header) print("-" * len(header)) # Rows for m in all_metrics: print(f"{m['category']:<12} | " f"{format_metric(m['image_auroc']):<10} | " f"{format_metric(m['pixel_auroc']):<10} | " f"{format_metric(m['image_f1']):<10} | " f"{format_metric(m['train_time_s'], 2):<10} | " f"{format_metric(m['inference_time_ms'], 2):<10} | " f"{format_metric(m['fps'], 1):<8} | " f"{format_metric(m['model_size_mb'], 2):<10}") # Average (only if more than one category) if len(all_metrics) > 1: print("-" * len(header)) print(f"{'AVERAGE':<12} | " f"{format_metric(safe_mean([m['image_auroc'] for m in all_metrics])):<10} | " f"{format_metric(safe_mean([m['pixel_auroc'] for m in all_metrics])):<10} | " f"{format_metric(safe_mean([m['image_f1'] for m in all_metrics])):<10} | " f"{format_metric(safe_mean([m['train_time_s'] for m in all_metrics]), 2):<10} | " f"{format_metric(safe_mean([m['inference_time_ms'] for m in all_metrics]), 2):<10} | " f"{format_metric(safe_mean([m['fps'] for m in all_metrics]), 1):<8} | " f"{format_metric(safe_mean([m['model_size_mb'] for m in all_metrics]), 2):<10}") print(f"\n{'='*100}") def train_category(category, model_name): """Runs training, test, and calculates metrics for a category.""" print(f"\n{'='*60}") print(f"Training: {category} ({model_name})") print(f"{'='*60}") # Load config config = load_model_config(model_name) # Initialize data with train_batch_size if specified (required for EfficientAD) train_batch_size = config.get("train_batch_size", 32) datamodule = MVTecAD(root=str(DIR_DATASET), category=category, train_batch_size=train_batch_size) # Initialize model model_class = get_class_from_path(config["class_path"]) model_params = config["init_args"] # EfficientAd-specific setup if model_name == "efficientad": patch_efficientad() model_params["imagenet_dir"] = str(EFFICIENTAD_RESOURCES_DIR / "imagenette") print(f" [INFO] EfficientAd: ImageNet directory: {EFFICIENTAD_RESOURCES_DIR / 'imagenette'}") print(" [INFO] EfficientAd: Image visualization disabled") model_params["visualizer"] = False model = model_class(**model_params) # Training train_start = time.time() max_epochs = config.get("max_epochs", 100) engine = Engine(default_root_dir=str(DIR_RESULTS), max_epochs=max_epochs) engine.fit(model=model, datamodule=datamodule) train_time = time.time() - train_start # Test test_results = engine.test(model=model, datamodule=datamodule) metrics = test_results[0] if test_results else {} # Inference for FPS measurement inference_start = time.time() predictions = engine.predict(model=model, datamodule=datamodule) inference_time = time.time() - inference_start num_images = len(predictions) if predictions else 1 # Collect metrics category_metrics = { "category": category, "image_auroc": metrics.get('image_AUROC'), "pixel_auroc": metrics.get('pixel_AUROC'), "image_f1": metrics.get('image_F1Score'), "train_time_s": train_time, "inference_time_ms": (inference_time / num_images) * 1000, "fps": num_images / inference_time if inference_time > 0 else 0, "model_size_mb": get_model_size_mb(model), } # Output and save print_category_metrics(category_metrics) save_metrics(category_metrics, category, model_name) print(f"\nCompleted: {category}\n") return category_metrics def parse_args(): """Parse command line arguments.""" available_models = get_available_models() parser = argparse.ArgumentParser( description="Training script for MVTec AD using Anomalib", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: python train.py # Training on bottle (default model: patchcore) python train.py --model all # Train all models on bottle python train.py --category all # Train Patchcore on all categories python train.py --model all --category all # Train all models on all categories """ ) parser.add_argument( "--category", type=str, default="bottle", choices=MVTEC_CATEGORIES + ["all"], help="Category to train on, or 'all' (default: bottle)" ) parser.add_argument( "--model", type=str, default="patchcore", choices=available_models + ["all"], help="Model to use, or 'all' (default: patchcore)" ) return parser.parse_args() def main(): args = parse_args() if args.category == "all": categories = MVTEC_CATEGORIES print(f"Training on ALL {len(categories)} categories") else: categories = [args.category] print(f"Training on: {args.category}") if args.model == "all": models = get_available_models() print(f"Models: ALL ({', '.join(models)})") else: models = [args.model] print(f"Model: {args.model}") DIR_RESULTS.mkdir(parents=True, exist_ok=True) all_metrics = [] for model_name in models: if len(models) > 1: print(f"\n{'='*60}") print(f"MODEL: {model_name.upper()}") print(f"{'='*60}") model_metrics = [train_category(cat, model_name) for cat in categories] all_metrics.extend(model_metrics) print_final_report(model_metrics, model_name) if __name__ == "__main__": main()