Spaces:
Running
on
Zero
Running
on
Zero
| from __future__ import annotations | |
| import os | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from segmenters import PCASegmenter | |
| from segmenters.sam3 import SAM3Segmenter | |
| from backbones import get_backbone | |
| from dataset.dataloader import load_mvtec | |
| from models.model_bank_knn import PatchKNNDetector | |
| from evaluation.anomaly_evaluator import AnomalyEvaluator | |
| def main( | |
| category: str = "bottle", | |
| root: str | None = None, | |
| backbone_name: str = "dinov3_small", | |
| use_sam3: bool = True, | |
| use_pca: bool = False, | |
| pca_backbone_name: str | None = None, | |
| return_results: bool = False, | |
| backbone_model=None, | |
| segmenter_obj=None, | |
| n_ref: int = 1, | |
| ): | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| # 1. Setup | |
| root = root or os.path.join("dataset", "mvtec_anomaly_detection") | |
| train_paths, test_paths = load_mvtec(category=category, root=root) | |
| test_paths = test_paths[::5] | |
| train_paths = train_paths[::5] | |
| print(f"{category}: {len(train_paths)} train, {len(test_paths)} test images") | |
| # 2. Initialize Evaluator with PRO support | |
| evaluator = AnomalyEvaluator(pixel_subsample_rate=0.01, compute_pro=False) | |
| # 3. Model Init | |
| segmenter = segmenter_obj | |
| if segmenter is None: | |
| if use_pca: | |
| pca_backbone = pca_backbone_name or backbone_name | |
| segmenter = PCASegmenter(backbone_name=pca_backbone, device=device) | |
| print(f"Using PCA segmenter (backbone={pca_backbone})") | |
| elif use_sam3: | |
| segmenter = SAM3Segmenter(text_prompt=category, device=device) | |
| print("Using SAM3 segmenter") | |
| else: | |
| print("No segmenter selected; using full-image foreground.") | |
| backbone = backbone_model or get_backbone(backbone_name) | |
| model = PatchKNNDetector( | |
| backbone=backbone, | |
| segmenter=segmenter, | |
| device=device, | |
| k_neighbors=1, | |
| ) | |
| print(f"Fitting model... (backbone={backbone_name}, sam3={use_sam3})") | |
| model.fit(train_paths, n_ref=n_ref) | |
| # 4. Evaluation Loop | |
| print(f"Starting evaluation on {len(test_paths)} images...") | |
| for i, path in enumerate(test_paths): | |
| # Predict | |
| image, amap, score = model.predict(path) | |
| # Ground Truth Logic | |
| is_anomaly = 0 if "good" in path else 1 | |
| if is_anomaly == 0: | |
| gt_mask = np.zeros_like(amap) | |
| else: | |
| mask_path = path.replace("test", "ground_truth").replace(".png", "_mask.png") | |
| if os.path.exists(mask_path): | |
| gt_mask = cv2.imread(mask_path, 0) | |
| if gt_mask.shape != amap.shape: | |
| gt_mask = cv2.resize(gt_mask, (amap.shape[1], amap.shape[0]), interpolation=cv2.INTER_NEAREST) | |
| gt_mask = (gt_mask > 0).astype(int) | |
| else: | |
| gt_mask = np.zeros_like(amap) | |
| # Update Evaluator | |
| evaluator.update(image_score=score, gt_label=is_anomaly, anomaly_map=amap, gt_mask=gt_mask) | |
| if i % 20 == 0: | |
| print(f"Processed {i}/{len(test_paths)}...") | |
| # 5. Compute & Print Results | |
| results = evaluator.compute() | |
| print("\n" + "="*40) | |
| print(f"FINAL RESULTS: {category}") | |
| print("-" * 40) | |
| # Image Level | |
| print(f"Image AUROC: {results['image_auroc']:.4f}") | |
| print(f"Image F1-Max: {results['image_f1_max']:.4f}") | |
| print(f"Image AP: {results['image_ap']:.4f}") | |
| print("-" * 40) | |
| # Pixel Level | |
| if 'pixel_auroc' in results: | |
| print(f"Pixel AUROC: {results['pixel_auroc']:.4f}") | |
| print(f"Pixel F1-Max: {results['pixel_f1_max']:.4f}") | |
| #print(f"PRO Score: {results['pixel_pro']:.4f}") | |
| print("="*40) | |
| return results if return_results else None | |
| if __name__ == "__main__": | |
| category = os.environ.get("MVTec_CATEGORY", "bottle") | |
| root = os.environ.get("MVTec_ROOT", None) | |
| backbone_name = os.environ.get("BACKBONE_NAME", "dinov3_small") | |
| use_sam3_env = os.environ.get("USE_SAM3", "0").lower() | |
| use_sam3 = use_sam3_env not in {"0", "false", "no"} | |
| use_pca_env = os.environ.get("USE_PCA", "1").lower() | |
| use_pca = use_pca_env in {"1", "true", "yes"} | |
| pca_backbone_name = os.environ.get("PCA_BACKBONE", None) | |
| main( | |
| category=category, | |
| root=root, | |
| backbone_name=backbone_name, | |
| use_sam3=use_sam3, | |
| use_pca=use_pca, | |
| pca_backbone_name=pca_backbone_name, | |
| n_ref=int(os.environ.get("N_REF", 1)), | |
| ) | |