Commit ·
2f7b834
1
Parent(s): 3da1244
Refactor
Browse files- inference.py +9 -56
- train.py +16 -40
inference.py
CHANGED
|
@@ -32,10 +32,6 @@ from core import (
|
|
| 32 |
)
|
| 33 |
|
| 34 |
|
| 35 |
-
# =============================================================================
|
| 36 |
-
# INFERENCE
|
| 37 |
-
# =============================================================================
|
| 38 |
-
|
| 39 |
def run_inference(image_path: str, category: str = "bottle", model_name: str = "patchcore", checkpoint_path: str = None) -> dict:
|
| 40 |
"""
|
| 41 |
Runs inference on a single image.
|
|
@@ -54,18 +50,15 @@ def run_inference(image_path: str, category: str = "bottle", model_name: str = "
|
|
| 54 |
if not os.path.exists(ckpt):
|
| 55 |
raise FileNotFoundError(f"Checkpoint not found: {ckpt}")
|
| 56 |
|
| 57 |
-
# Load model and dataset
|
| 58 |
model = load_model(model_name)
|
| 59 |
dataset = PredictDataset(path=image_path)
|
| 60 |
|
| 61 |
-
# Run prediction (disable automatic image saving)
|
| 62 |
engine = Engine(
|
| 63 |
default_root_dir="/tmp/anomalib_inference",
|
| 64 |
callbacks=[],
|
| 65 |
)
|
| 66 |
predictions = engine.predict(model=model, dataset=dataset, ckpt_path=ckpt)
|
| 67 |
|
| 68 |
-
# Extract results
|
| 69 |
results = {"image_path": image_path, "category": category, "model": model_name}
|
| 70 |
|
| 71 |
for batch in predictions:
|
|
@@ -77,10 +70,6 @@ def run_inference(image_path: str, category: str = "bottle", model_name: str = "
|
|
| 77 |
return results
|
| 78 |
|
| 79 |
|
| 80 |
-
# =============================================================================
|
| 81 |
-
# VISUALIZATION
|
| 82 |
-
# =============================================================================
|
| 83 |
-
|
| 84 |
def visualize_results(results: dict, output_dir: Path = None) -> str:
|
| 85 |
"""
|
| 86 |
Visualizes heatmap and mask overlayed on original image.
|
|
@@ -92,15 +81,12 @@ def visualize_results(results: dict, output_dir: Path = None) -> str:
|
|
| 92 |
Returns:
|
| 93 |
Path of saved image
|
| 94 |
"""
|
| 95 |
-
# Get model name for subfolder
|
| 96 |
model_name = results.get("model", "unknown")
|
| 97 |
|
| 98 |
-
# Create output/model_name subfolder
|
| 99 |
base_output_dir = output_dir or DIR_OUTPUT
|
| 100 |
output_dir = base_output_dir / model_name
|
| 101 |
output_dir.mkdir(parents=True, exist_ok=True)
|
| 102 |
|
| 103 |
-
# Load data
|
| 104 |
image_path = results["image_path"]
|
| 105 |
anomaly_score = results["anomaly_score"]
|
| 106 |
anomaly_map = results["anomaly_map"]
|
|
@@ -109,72 +95,54 @@ def visualize_results(results: dict, output_dir: Path = None) -> str:
|
|
| 109 |
|
| 110 |
original = np.array(Image.open(image_path).convert("RGB"))
|
| 111 |
|
| 112 |
-
# Prepare anomaly map
|
| 113 |
if anomaly_map.ndim == 3:
|
| 114 |
anomaly_map = anomaly_map.squeeze(0)
|
| 115 |
|
| 116 |
-
# Adaptive normalization: use image-level normalization if values are compressed
|
| 117 |
is_efficientad = model_name.lower() == "efficientad"
|
| 118 |
|
| 119 |
-
# Adaptive normalization: use image-level normalization if values are compressed
|
| 120 |
-
# STRICTLY RESTRICTED TO EFFICIENTAD
|
| 121 |
amap_min, amap_max = anomaly_map.min(), anomaly_map.max()
|
| 122 |
amap_range = amap_max - amap_min
|
| 123 |
|
| 124 |
-
if is_efficientad and amap_range < 0.1:
|
| 125 |
-
# Apply image-level min-max normalization
|
| 126 |
if amap_range > 1e-8:
|
| 127 |
anomaly_map = (anomaly_map - amap_min) / amap_range
|
| 128 |
else:
|
| 129 |
anomaly_map = np.zeros_like(anomaly_map)
|
| 130 |
-
print(f"
|
| 131 |
else:
|
| 132 |
-
# Standard clipping for all other models
|
| 133 |
anomaly_map = np.clip(anomaly_map, 0, 1)
|
| 134 |
|
| 135 |
anomaly_map = resize_to_match(anomaly_map, original.shape[:2])
|
| 136 |
|
| 137 |
-
# EfficientAD-specific: for "good" images, force blue heatmap
|
| 138 |
-
# is_efficientad already defined above
|
| 139 |
is_good = anomaly_score is not None and anomaly_score < 0.5
|
| 140 |
-
show_mask_contours = True
|
| 141 |
|
| 142 |
if is_efficientad and is_good:
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
print(f"ℹ️ EfficientAD: Good image detected - scaling heatmap to low values")
|
| 147 |
|
| 148 |
-
# Prepare pred mask
|
| 149 |
if pred_mask is not None:
|
| 150 |
if pred_mask.ndim == 3:
|
| 151 |
pred_mask = pred_mask.squeeze(0)
|
| 152 |
pred_mask = resize_to_match(pred_mask, original.shape[:2])
|
| 153 |
|
| 154 |
-
# Create figure - always 4 columns if pred_mask exists or EfficientAD good image
|
| 155 |
show_fourth_panel = pred_mask is not None or (is_efficientad and is_good)
|
| 156 |
num_cols = 4 if show_fourth_panel else 3
|
| 157 |
fig, axes = plt.subplots(1, num_cols, figsize=(5 * num_cols, 5), facecolor='white')
|
| 158 |
|
| 159 |
-
# Set white background for all axes to avoid colored borders
|
| 160 |
for ax in axes:
|
| 161 |
ax.set_facecolor('white')
|
| 162 |
|
| 163 |
-
# 1. Original Image
|
| 164 |
axes[0].imshow(original)
|
| 165 |
axes[0].set_title("Original")
|
| 166 |
axes[0].axis("off")
|
| 167 |
|
| 168 |
-
# 2. Heatmap - use aspect='auto' to fill subplot without borders
|
| 169 |
-
# 2. Heatmap - use aspect='auto' to fill subplot without borders
|
| 170 |
-
|
| 171 |
if is_efficientad:
|
| 172 |
-
# Mask 0 values (padding) to remove blue border
|
| 173 |
anomaly_map_masked = np.ma.masked_where(anomaly_map == 0, anomaly_map)
|
| 174 |
cmap = plt.cm.jet
|
| 175 |
-
cmap.set_bad(color='none')
|
| 176 |
else:
|
| 177 |
-
# Standard visualization for other models (0 is Blue, not transparent)
|
| 178 |
anomaly_map_masked = anomaly_map
|
| 179 |
cmap = plt.cm.jet
|
| 180 |
|
|
@@ -183,14 +151,11 @@ def visualize_results(results: dict, output_dir: Path = None) -> str:
|
|
| 183 |
axes[1].axis("off")
|
| 184 |
plt.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)
|
| 185 |
|
| 186 |
-
|
| 187 |
-
# 3. Overlay - use aspect='auto' to match original image
|
| 188 |
axes[2].imshow(original, aspect='auto')
|
| 189 |
axes[2].imshow(anomaly_map_masked, cmap=cmap, alpha=0.5, vmin=0, vmax=1, aspect='auto')
|
| 190 |
axes[2].set_title("Overlay")
|
| 191 |
axes[2].axis("off")
|
| 192 |
|
| 193 |
-
# 4. Mask panel
|
| 194 |
if show_fourth_panel:
|
| 195 |
axes[3].imshow(original)
|
| 196 |
if show_mask_contours and pred_mask is not None:
|
|
@@ -198,11 +163,8 @@ def visualize_results(results: dict, output_dir: Path = None) -> str:
|
|
| 198 |
axes[3].set_title("Predicted Mask")
|
| 199 |
axes[3].axis("off")
|
| 200 |
|
| 201 |
-
# Title and Save - scale score to push towards 0 or 1 for EfficientAD
|
| 202 |
if anomaly_score is not None:
|
| 203 |
if is_efficientad:
|
| 204 |
-
# Scale EfficientAD scores (typically ~0.5) to be more extreme
|
| 205 |
-
# Good (<0.5) -> ~0.2, Anomalous (>0.5) -> high values
|
| 206 |
scaled_score = scale_efficientad_score(anomaly_score)
|
| 207 |
score_str = f"{scaled_score:.4f}"
|
| 208 |
else:
|
|
@@ -217,16 +179,11 @@ def visualize_results(results: dict, output_dir: Path = None) -> str:
|
|
| 217 |
plt.savefig(output_path, dpi=150, bbox_inches="tight")
|
| 218 |
plt.show()
|
| 219 |
|
| 220 |
-
# Always close the figure
|
| 221 |
plt.close(fig)
|
| 222 |
|
| 223 |
return str(output_path)
|
| 224 |
|
| 225 |
|
| 226 |
-
# =============================================================================
|
| 227 |
-
# MAIN
|
| 228 |
-
# =============================================================================
|
| 229 |
-
|
| 230 |
def parse_args():
|
| 231 |
"""Parse command line arguments."""
|
| 232 |
parser = argparse.ArgumentParser(
|
|
@@ -265,14 +222,12 @@ Examples:
|
|
| 265 |
def main():
|
| 266 |
args = parse_args()
|
| 267 |
|
| 268 |
-
# Validate input
|
| 269 |
if not os.path.exists(args.image_path):
|
| 270 |
raise FileNotFoundError(f"Image not found: {args.image_path}")
|
| 271 |
|
| 272 |
-
print(f"
|
| 273 |
-
print(f"
|
| 274 |
|
| 275 |
-
# Run inference
|
| 276 |
results = run_inference(
|
| 277 |
image_path=args.image_path,
|
| 278 |
category=args.category,
|
|
@@ -280,13 +235,11 @@ def main():
|
|
| 280 |
checkpoint_path=args.checkpoint
|
| 281 |
)
|
| 282 |
|
| 283 |
-
# Print result
|
| 284 |
score = results["anomaly_score"]
|
| 285 |
print(f"\n{'='*50}")
|
| 286 |
print(f"ANOMALY SCORE: {score:.4f}" if score else "ANOMALY SCORE: N/A")
|
| 287 |
print(f"{'='*50}\n")
|
| 288 |
|
| 289 |
-
# Visualize and save
|
| 290 |
output_dir = Path(args.output_dir) if args.output_dir else None
|
| 291 |
output_path = visualize_results(results, output_dir)
|
| 292 |
|
|
|
|
| 32 |
)
|
| 33 |
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
def run_inference(image_path: str, category: str = "bottle", model_name: str = "patchcore", checkpoint_path: str = None) -> dict:
|
| 36 |
"""
|
| 37 |
Runs inference on a single image.
|
|
|
|
| 50 |
if not os.path.exists(ckpt):
|
| 51 |
raise FileNotFoundError(f"Checkpoint not found: {ckpt}")
|
| 52 |
|
|
|
|
| 53 |
model = load_model(model_name)
|
| 54 |
dataset = PredictDataset(path=image_path)
|
| 55 |
|
|
|
|
| 56 |
engine = Engine(
|
| 57 |
default_root_dir="/tmp/anomalib_inference",
|
| 58 |
callbacks=[],
|
| 59 |
)
|
| 60 |
predictions = engine.predict(model=model, dataset=dataset, ckpt_path=ckpt)
|
| 61 |
|
|
|
|
| 62 |
results = {"image_path": image_path, "category": category, "model": model_name}
|
| 63 |
|
| 64 |
for batch in predictions:
|
|
|
|
| 70 |
return results
|
| 71 |
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
def visualize_results(results: dict, output_dir: Path = None) -> str:
|
| 74 |
"""
|
| 75 |
Visualizes heatmap and mask overlayed on original image.
|
|
|
|
| 81 |
Returns:
|
| 82 |
Path of saved image
|
| 83 |
"""
|
|
|
|
| 84 |
model_name = results.get("model", "unknown")
|
| 85 |
|
|
|
|
| 86 |
base_output_dir = output_dir or DIR_OUTPUT
|
| 87 |
output_dir = base_output_dir / model_name
|
| 88 |
output_dir.mkdir(parents=True, exist_ok=True)
|
| 89 |
|
|
|
|
| 90 |
image_path = results["image_path"]
|
| 91 |
anomaly_score = results["anomaly_score"]
|
| 92 |
anomaly_map = results["anomaly_map"]
|
|
|
|
| 95 |
|
| 96 |
original = np.array(Image.open(image_path).convert("RGB"))
|
| 97 |
|
|
|
|
| 98 |
if anomaly_map.ndim == 3:
|
| 99 |
anomaly_map = anomaly_map.squeeze(0)
|
| 100 |
|
|
|
|
| 101 |
is_efficientad = model_name.lower() == "efficientad"
|
| 102 |
|
|
|
|
|
|
|
| 103 |
amap_min, amap_max = anomaly_map.min(), anomaly_map.max()
|
| 104 |
amap_range = amap_max - amap_min
|
| 105 |
|
| 106 |
+
if is_efficientad and amap_range < 0.1:
|
|
|
|
| 107 |
if amap_range > 1e-8:
|
| 108 |
anomaly_map = (anomaly_map - amap_min) / amap_range
|
| 109 |
else:
|
| 110 |
anomaly_map = np.zeros_like(anomaly_map)
|
| 111 |
+
print(f"[INFO] Applied image-level normalization (original range: {amap_min:.4f}-{amap_max:.4f})")
|
| 112 |
else:
|
|
|
|
| 113 |
anomaly_map = np.clip(anomaly_map, 0, 1)
|
| 114 |
|
| 115 |
anomaly_map = resize_to_match(anomaly_map, original.shape[:2])
|
| 116 |
|
|
|
|
|
|
|
| 117 |
is_good = anomaly_score is not None and anomaly_score < 0.5
|
| 118 |
+
show_mask_contours = True
|
| 119 |
|
| 120 |
if is_efficientad and is_good:
|
| 121 |
+
anomaly_map = anomaly_map * 0.3
|
| 122 |
+
show_mask_contours = False
|
| 123 |
+
print(f"[INFO] EfficientAD: Good image detected - scaling heatmap to low values")s")
|
|
|
|
| 124 |
|
|
|
|
| 125 |
if pred_mask is not None:
|
| 126 |
if pred_mask.ndim == 3:
|
| 127 |
pred_mask = pred_mask.squeeze(0)
|
| 128 |
pred_mask = resize_to_match(pred_mask, original.shape[:2])
|
| 129 |
|
|
|
|
| 130 |
show_fourth_panel = pred_mask is not None or (is_efficientad and is_good)
|
| 131 |
num_cols = 4 if show_fourth_panel else 3
|
| 132 |
fig, axes = plt.subplots(1, num_cols, figsize=(5 * num_cols, 5), facecolor='white')
|
| 133 |
|
|
|
|
| 134 |
for ax in axes:
|
| 135 |
ax.set_facecolor('white')
|
| 136 |
|
|
|
|
| 137 |
axes[0].imshow(original)
|
| 138 |
axes[0].set_title("Original")
|
| 139 |
axes[0].axis("off")
|
| 140 |
|
|
|
|
|
|
|
|
|
|
| 141 |
if is_efficientad:
|
|
|
|
| 142 |
anomaly_map_masked = np.ma.masked_where(anomaly_map == 0, anomaly_map)
|
| 143 |
cmap = plt.cm.jet
|
| 144 |
+
cmap.set_bad(color='none')
|
| 145 |
else:
|
|
|
|
| 146 |
anomaly_map_masked = anomaly_map
|
| 147 |
cmap = plt.cm.jet
|
| 148 |
|
|
|
|
| 151 |
axes[1].axis("off")
|
| 152 |
plt.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)
|
| 153 |
|
|
|
|
|
|
|
| 154 |
axes[2].imshow(original, aspect='auto')
|
| 155 |
axes[2].imshow(anomaly_map_masked, cmap=cmap, alpha=0.5, vmin=0, vmax=1, aspect='auto')
|
| 156 |
axes[2].set_title("Overlay")
|
| 157 |
axes[2].axis("off")
|
| 158 |
|
|
|
|
| 159 |
if show_fourth_panel:
|
| 160 |
axes[3].imshow(original)
|
| 161 |
if show_mask_contours and pred_mask is not None:
|
|
|
|
| 163 |
axes[3].set_title("Predicted Mask")
|
| 164 |
axes[3].axis("off")
|
| 165 |
|
|
|
|
| 166 |
if anomaly_score is not None:
|
| 167 |
if is_efficientad:
|
|
|
|
|
|
|
| 168 |
scaled_score = scale_efficientad_score(anomaly_score)
|
| 169 |
score_str = f"{scaled_score:.4f}"
|
| 170 |
else:
|
|
|
|
| 179 |
plt.savefig(output_path, dpi=150, bbox_inches="tight")
|
| 180 |
plt.show()
|
| 181 |
|
|
|
|
| 182 |
plt.close(fig)
|
| 183 |
|
| 184 |
return str(output_path)
|
| 185 |
|
| 186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
def parse_args():
|
| 188 |
"""Parse command line arguments."""
|
| 189 |
parser = argparse.ArgumentParser(
|
|
|
|
| 222 |
def main():
|
| 223 |
args = parse_args()
|
| 224 |
|
|
|
|
| 225 |
if not os.path.exists(args.image_path):
|
| 226 |
raise FileNotFoundError(f"Image not found: {args.image_path}")
|
| 227 |
|
| 228 |
+
print(f"Inference on: {args.image_path}")
|
| 229 |
+
print(f"Model: {args.model} | Category: {args.category}")
|
| 230 |
|
|
|
|
| 231 |
results = run_inference(
|
| 232 |
image_path=args.image_path,
|
| 233 |
category=args.category,
|
|
|
|
| 235 |
checkpoint_path=args.checkpoint
|
| 236 |
)
|
| 237 |
|
|
|
|
| 238 |
score = results["anomaly_score"]
|
| 239 |
print(f"\n{'='*50}")
|
| 240 |
print(f"ANOMALY SCORE: {score:.4f}" if score else "ANOMALY SCORE: N/A")
|
| 241 |
print(f"{'='*50}\n")
|
| 242 |
|
|
|
|
| 243 |
output_dir = Path(args.output_dir) if args.output_dir else None
|
| 244 |
output_path = visualize_results(results, output_dir)
|
| 245 |
|
train.py
CHANGED
|
@@ -34,12 +34,6 @@ from core import (
|
|
| 34 |
|
| 35 |
logger = logging.getLogger(__name__)
|
| 36 |
|
| 37 |
-
# =============================================================================
|
| 38 |
-
# EFFICIENTAD MONKEY-PATCH
|
| 39 |
-
# =============================================================================
|
| 40 |
-
# Override the pretrained weights directory to keep everything inside
|
| 41 |
-
# efficientad_resources/ instead of the hardcoded ./pre_trained/
|
| 42 |
-
|
| 43 |
EFFICIENTAD_RESOURCES_DIR = Path(__file__).parent / "efficientad_resources"
|
| 44 |
|
| 45 |
def _patched_prepare_pretrained_model(self) -> None:
|
|
@@ -66,12 +60,8 @@ def patch_efficientad():
|
|
| 66 |
"""Apply monkey-patch to EfficientAd to use custom pretrained weights directory."""
|
| 67 |
from anomalib.models import EfficientAd
|
| 68 |
EfficientAd.prepare_pretrained_model = _patched_prepare_pretrained_model
|
| 69 |
-
print(f"
|
| 70 |
-
|
| 71 |
|
| 72 |
-
# =============================================================================
|
| 73 |
-
# METRICS FUNCTIONS
|
| 74 |
-
# =============================================================================
|
| 75 |
|
| 76 |
def save_metrics(category_metrics, category, model_name):
|
| 77 |
"""Saves metrics in the Anomalib directory structure."""
|
|
@@ -95,7 +85,7 @@ def save_metrics(category_metrics, category, model_name):
|
|
| 95 |
version_json_path = version_dir / "metrics.json"
|
| 96 |
with open(version_json_path, 'w', encoding='utf-8') as f:
|
| 97 |
json.dump(category_metrics, f, indent=2, ensure_ascii=False)
|
| 98 |
-
print(f"
|
| 99 |
|
| 100 |
# Save in latest (only if it exists)
|
| 101 |
latest_dir = category_base_dir / "latest"
|
|
@@ -107,7 +97,7 @@ def save_metrics(category_metrics, category, model_name):
|
|
| 107 |
|
| 108 |
def print_category_metrics(metrics):
|
| 109 |
"""Prints metrics for a category."""
|
| 110 |
-
print(f"\n
|
| 111 |
print(f" EFFICACY: AUROC img={format_metric(metrics['image_auroc'])} | "
|
| 112 |
f"AUROC pix={format_metric(metrics['pixel_auroc'])} | "
|
| 113 |
f"F1={format_metric(metrics['image_f1'])}")
|
|
@@ -123,7 +113,7 @@ def print_final_report(all_metrics, model_name):
|
|
| 123 |
return
|
| 124 |
|
| 125 |
print(f"\n{'='*100}")
|
| 126 |
-
print(f"
|
| 127 |
print(f"{'='*100}\n")
|
| 128 |
|
| 129 |
# Header
|
|
@@ -157,14 +147,10 @@ def print_final_report(all_metrics, model_name):
|
|
| 157 |
print(f"\n{'='*100}")
|
| 158 |
|
| 159 |
|
| 160 |
-
# =============================================================================
|
| 161 |
-
# MAIN TRAINING FUNCTION
|
| 162 |
-
# =============================================================================
|
| 163 |
-
|
| 164 |
def train_category(category, model_name):
|
| 165 |
"""Runs training, test, and calculates metrics for a category."""
|
| 166 |
print(f"\n{'='*60}")
|
| 167 |
-
print(f"
|
| 168 |
print(f"{'='*60}")
|
| 169 |
|
| 170 |
# Load config
|
|
@@ -181,11 +167,10 @@ def train_category(category, model_name):
|
|
| 181 |
|
| 182 |
# EfficientAd-specific setup
|
| 183 |
if model_name == "efficientad":
|
| 184 |
-
patch_efficientad()
|
| 185 |
-
# Override imagenet_dir to use absolute path inside efficientad_resources/
|
| 186 |
model_params["imagenet_dir"] = str(EFFICIENTAD_RESOURCES_DIR / "imagenette")
|
| 187 |
-
print(f"
|
| 188 |
-
print("
|
| 189 |
model_params["visualizer"] = False
|
| 190 |
|
| 191 |
model = model_class(**model_params)
|
|
@@ -222,15 +207,11 @@ def train_category(category, model_name):
|
|
| 222 |
# Output and save
|
| 223 |
print_category_metrics(category_metrics)
|
| 224 |
save_metrics(category_metrics, category, model_name)
|
| 225 |
-
print(f"\
|
| 226 |
|
| 227 |
return category_metrics
|
| 228 |
|
| 229 |
|
| 230 |
-
# =============================================================================
|
| 231 |
-
# MAIN
|
| 232 |
-
# =============================================================================
|
| 233 |
-
|
| 234 |
def parse_args():
|
| 235 |
"""Parse command line arguments."""
|
| 236 |
available_models = get_available_models()
|
|
@@ -262,37 +243,32 @@ Examples:
|
|
| 262 |
def main():
|
| 263 |
args = parse_args()
|
| 264 |
|
| 265 |
-
# Determine categories
|
| 266 |
if args.category == "all":
|
| 267 |
categories = MVTEC_CATEGORIES
|
| 268 |
-
print(f"
|
| 269 |
else:
|
| 270 |
categories = [args.category]
|
| 271 |
-
print(f"
|
| 272 |
|
| 273 |
-
# Determine models
|
| 274 |
if args.model == "all":
|
| 275 |
models = get_available_models()
|
| 276 |
-
print(f"
|
| 277 |
else:
|
| 278 |
models = [args.model]
|
| 279 |
-
print(f"
|
| 280 |
|
| 281 |
-
# Create results directory
|
| 282 |
DIR_RESULTS.mkdir(parents=True, exist_ok=True)
|
| 283 |
|
| 284 |
-
# Training loop
|
| 285 |
all_metrics = []
|
| 286 |
for model_name in models:
|
| 287 |
if len(models) > 1:
|
| 288 |
-
print(f"\n{'
|
| 289 |
-
print(f"
|
| 290 |
-
print(f"{'
|
| 291 |
|
| 292 |
model_metrics = [train_category(cat, model_name) for cat in categories]
|
| 293 |
all_metrics.extend(model_metrics)
|
| 294 |
|
| 295 |
-
# Final Report per model
|
| 296 |
print_final_report(model_metrics, model_name)
|
| 297 |
|
| 298 |
|
|
|
|
| 34 |
|
| 35 |
logger = logging.getLogger(__name__)
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
EFFICIENTAD_RESOURCES_DIR = Path(__file__).parent / "efficientad_resources"
|
| 38 |
|
| 39 |
def _patched_prepare_pretrained_model(self) -> None:
|
|
|
|
| 60 |
"""Apply monkey-patch to EfficientAd to use custom pretrained weights directory."""
|
| 61 |
from anomalib.models import EfficientAd
|
| 62 |
EfficientAd.prepare_pretrained_model = _patched_prepare_pretrained_model
|
| 63 |
+
print(f" [INFO] EfficientAd: Pretrained weights directory: {EFFICIENTAD_RESOURCES_DIR / 'pre_trained'}")
|
|
|
|
| 64 |
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
def save_metrics(category_metrics, category, model_name):
|
| 67 |
"""Saves metrics in the Anomalib directory structure."""
|
|
|
|
| 85 |
version_json_path = version_dir / "metrics.json"
|
| 86 |
with open(version_json_path, 'w', encoding='utf-8') as f:
|
| 87 |
json.dump(category_metrics, f, indent=2, ensure_ascii=False)
|
| 88 |
+
print(f" Saved: {version_json_path}")
|
| 89 |
|
| 90 |
# Save in latest (only if it exists)
|
| 91 |
latest_dir = category_base_dir / "latest"
|
|
|
|
| 97 |
|
| 98 |
def print_category_metrics(metrics):
|
| 99 |
"""Prints metrics for a category."""
|
| 100 |
+
print(f"\n[METRICS]")
|
| 101 |
print(f" EFFICACY: AUROC img={format_metric(metrics['image_auroc'])} | "
|
| 102 |
f"AUROC pix={format_metric(metrics['pixel_auroc'])} | "
|
| 103 |
f"F1={format_metric(metrics['image_f1'])}")
|
|
|
|
| 113 |
return
|
| 114 |
|
| 115 |
print(f"\n{'='*100}")
|
| 116 |
+
print(f"FINAL REPORT - {model_name.upper()} PERFORMANCE METRICS")
|
| 117 |
print(f"{'='*100}\n")
|
| 118 |
|
| 119 |
# Header
|
|
|
|
| 147 |
print(f"\n{'='*100}")
|
| 148 |
|
| 149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
def train_category(category, model_name):
|
| 151 |
"""Runs training, test, and calculates metrics for a category."""
|
| 152 |
print(f"\n{'='*60}")
|
| 153 |
+
print(f"Training: {category} ({model_name})")
|
| 154 |
print(f"{'='*60}")
|
| 155 |
|
| 156 |
# Load config
|
|
|
|
| 167 |
|
| 168 |
# EfficientAd-specific setup
|
| 169 |
if model_name == "efficientad":
|
| 170 |
+
patch_efficientad()
|
|
|
|
| 171 |
model_params["imagenet_dir"] = str(EFFICIENTAD_RESOURCES_DIR / "imagenette")
|
| 172 |
+
print(f" [INFO] EfficientAd: ImageNet directory: {EFFICIENTAD_RESOURCES_DIR / 'imagenette'}")
|
| 173 |
+
print(" [INFO] EfficientAd: Image visualization disabled")
|
| 174 |
model_params["visualizer"] = False
|
| 175 |
|
| 176 |
model = model_class(**model_params)
|
|
|
|
| 207 |
# Output and save
|
| 208 |
print_category_metrics(category_metrics)
|
| 209 |
save_metrics(category_metrics, category, model_name)
|
| 210 |
+
print(f"\nCompleted: {category}\n")
|
| 211 |
|
| 212 |
return category_metrics
|
| 213 |
|
| 214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
def parse_args():
|
| 216 |
"""Parse command line arguments."""
|
| 217 |
available_models = get_available_models()
|
|
|
|
| 243 |
def main():
|
| 244 |
args = parse_args()
|
| 245 |
|
|
|
|
| 246 |
if args.category == "all":
|
| 247 |
categories = MVTEC_CATEGORIES
|
| 248 |
+
print(f"Training on ALL {len(categories)} categories")
|
| 249 |
else:
|
| 250 |
categories = [args.category]
|
| 251 |
+
print(f"Training on: {args.category}")
|
| 252 |
|
|
|
|
| 253 |
if args.model == "all":
|
| 254 |
models = get_available_models()
|
| 255 |
+
print(f"Models: ALL ({', '.join(models)})")
|
| 256 |
else:
|
| 257 |
models = [args.model]
|
| 258 |
+
print(f"Model: {args.model}")
|
| 259 |
|
|
|
|
| 260 |
DIR_RESULTS.mkdir(parents=True, exist_ok=True)
|
| 261 |
|
|
|
|
| 262 |
all_metrics = []
|
| 263 |
for model_name in models:
|
| 264 |
if len(models) > 1:
|
| 265 |
+
print(f"\n{'='*60}")
|
| 266 |
+
print(f"MODEL: {model_name.upper()}")
|
| 267 |
+
print(f"{'='*60}")
|
| 268 |
|
| 269 |
model_metrics = [train_category(cat, model_name) for cat in categories]
|
| 270 |
all_metrics.extend(model_metrics)
|
| 271 |
|
|
|
|
| 272 |
print_final_report(model_metrics, model_name)
|
| 273 |
|
| 274 |
|