AlessandroSchmitt commited on
Commit
2f7b834
·
1 Parent(s): 3da1244
Files changed (2) hide show
  1. inference.py +9 -56
  2. train.py +16 -40
inference.py CHANGED
@@ -32,10 +32,6 @@ from core import (
32
  )
33
 
34
 
35
- # =============================================================================
36
- # INFERENCE
37
- # =============================================================================
38
-
39
  def run_inference(image_path: str, category: str = "bottle", model_name: str = "patchcore", checkpoint_path: str = None) -> dict:
40
  """
41
  Runs inference on a single image.
@@ -54,18 +50,15 @@ def run_inference(image_path: str, category: str = "bottle", model_name: str = "
54
  if not os.path.exists(ckpt):
55
  raise FileNotFoundError(f"Checkpoint not found: {ckpt}")
56
 
57
- # Load model and dataset
58
  model = load_model(model_name)
59
  dataset = PredictDataset(path=image_path)
60
 
61
- # Run prediction (disable automatic image saving)
62
  engine = Engine(
63
  default_root_dir="/tmp/anomalib_inference",
64
  callbacks=[],
65
  )
66
  predictions = engine.predict(model=model, dataset=dataset, ckpt_path=ckpt)
67
 
68
- # Extract results
69
  results = {"image_path": image_path, "category": category, "model": model_name}
70
 
71
  for batch in predictions:
@@ -77,10 +70,6 @@ def run_inference(image_path: str, category: str = "bottle", model_name: str = "
77
  return results
78
 
79
 
80
- # =============================================================================
81
- # VISUALIZATION
82
- # =============================================================================
83
-
84
  def visualize_results(results: dict, output_dir: Path = None) -> str:
85
  """
86
  Visualizes heatmap and mask overlayed on original image.
@@ -92,15 +81,12 @@ def visualize_results(results: dict, output_dir: Path = None) -> str:
92
  Returns:
93
  Path of saved image
94
  """
95
- # Get model name for subfolder
96
  model_name = results.get("model", "unknown")
97
 
98
- # Create output/model_name subfolder
99
  base_output_dir = output_dir or DIR_OUTPUT
100
  output_dir = base_output_dir / model_name
101
  output_dir.mkdir(parents=True, exist_ok=True)
102
 
103
- # Load data
104
  image_path = results["image_path"]
105
  anomaly_score = results["anomaly_score"]
106
  anomaly_map = results["anomaly_map"]
@@ -109,72 +95,54 @@ def visualize_results(results: dict, output_dir: Path = None) -> str:
109
 
110
  original = np.array(Image.open(image_path).convert("RGB"))
111
 
112
- # Prepare anomaly map
113
  if anomaly_map.ndim == 3:
114
  anomaly_map = anomaly_map.squeeze(0)
115
 
116
- # Adaptive normalization: use image-level normalization if values are compressed
117
  is_efficientad = model_name.lower() == "efficientad"
118
 
119
- # Adaptive normalization: use image-level normalization if values are compressed
120
- # STRICTLY RESTRICTED TO EFFICIENTAD
121
  amap_min, amap_max = anomaly_map.min(), anomaly_map.max()
122
  amap_range = amap_max - amap_min
123
 
124
- if is_efficientad and amap_range < 0.1: # Compressed values (e.g., EfficientAD output)
125
- # Apply image-level min-max normalization
126
  if amap_range > 1e-8:
127
  anomaly_map = (anomaly_map - amap_min) / amap_range
128
  else:
129
  anomaly_map = np.zeros_like(anomaly_map)
130
- print(f"ℹ️ Applied image-level normalization (original range: {amap_min:.4f}-{amap_max:.4f})")
131
  else:
132
- # Standard clipping for all other models
133
  anomaly_map = np.clip(anomaly_map, 0, 1)
134
 
135
  anomaly_map = resize_to_match(anomaly_map, original.shape[:2])
136
 
137
- # EfficientAD-specific: for "good" images, force blue heatmap
138
- # is_efficientad already defined above
139
  is_good = anomaly_score is not None and anomaly_score < 0.5
140
- show_mask_contours = True # Whether to draw contours on mask panel
141
 
142
  if is_efficientad and is_good:
143
- # Scale heatmap to low values but keep texture variation (0-0.3 range)
144
- anomaly_map = anomaly_map * 0.3 # Keep variation but in low range
145
- show_mask_contours = False # Don't draw contours for good images
146
- print(f"ℹ️ EfficientAD: Good image detected - scaling heatmap to low values")
147
 
148
- # Prepare pred mask
149
  if pred_mask is not None:
150
  if pred_mask.ndim == 3:
151
  pred_mask = pred_mask.squeeze(0)
152
  pred_mask = resize_to_match(pred_mask, original.shape[:2])
153
 
154
- # Create figure - always 4 columns if pred_mask exists or EfficientAD good image
155
  show_fourth_panel = pred_mask is not None or (is_efficientad and is_good)
156
  num_cols = 4 if show_fourth_panel else 3
157
  fig, axes = plt.subplots(1, num_cols, figsize=(5 * num_cols, 5), facecolor='white')
158
 
159
- # Set white background for all axes to avoid colored borders
160
  for ax in axes:
161
  ax.set_facecolor('white')
162
 
163
- # 1. Original Image
164
  axes[0].imshow(original)
165
  axes[0].set_title("Original")
166
  axes[0].axis("off")
167
 
168
- # 2. Heatmap - use aspect='auto' to fill subplot without borders
169
- # 2. Heatmap - use aspect='auto' to fill subplot without borders
170
-
171
  if is_efficientad:
172
- # Mask 0 values (padding) to remove blue border
173
  anomaly_map_masked = np.ma.masked_where(anomaly_map == 0, anomaly_map)
174
  cmap = plt.cm.jet
175
- cmap.set_bad(color='none') # Transparent for masked values (padding)
176
  else:
177
- # Standard visualization for other models (0 is Blue, not transparent)
178
  anomaly_map_masked = anomaly_map
179
  cmap = plt.cm.jet
180
 
@@ -183,14 +151,11 @@ def visualize_results(results: dict, output_dir: Path = None) -> str:
183
  axes[1].axis("off")
184
  plt.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)
185
 
186
-
187
- # 3. Overlay - use aspect='auto' to match original image
188
  axes[2].imshow(original, aspect='auto')
189
  axes[2].imshow(anomaly_map_masked, cmap=cmap, alpha=0.5, vmin=0, vmax=1, aspect='auto')
190
  axes[2].set_title("Overlay")
191
  axes[2].axis("off")
192
 
193
- # 4. Mask panel
194
  if show_fourth_panel:
195
  axes[3].imshow(original)
196
  if show_mask_contours and pred_mask is not None:
@@ -198,11 +163,8 @@ def visualize_results(results: dict, output_dir: Path = None) -> str:
198
  axes[3].set_title("Predicted Mask")
199
  axes[3].axis("off")
200
 
201
- # Title and Save - scale score to push towards 0 or 1 for EfficientAD
202
  if anomaly_score is not None:
203
  if is_efficientad:
204
- # Scale EfficientAD scores (typically ~0.5) to be more extreme
205
- # Good (<0.5) -> ~0.2, Anomalous (>0.5) -> high values
206
  scaled_score = scale_efficientad_score(anomaly_score)
207
  score_str = f"{scaled_score:.4f}"
208
  else:
@@ -217,16 +179,11 @@ def visualize_results(results: dict, output_dir: Path = None) -> str:
217
  plt.savefig(output_path, dpi=150, bbox_inches="tight")
218
  plt.show()
219
 
220
- # Always close the figure
221
  plt.close(fig)
222
 
223
  return str(output_path)
224
 
225
 
226
- # =============================================================================
227
- # MAIN
228
- # =============================================================================
229
-
230
  def parse_args():
231
  """Parse command line arguments."""
232
  parser = argparse.ArgumentParser(
@@ -265,14 +222,12 @@ Examples:
265
  def main():
266
  args = parse_args()
267
 
268
- # Validate input
269
  if not os.path.exists(args.image_path):
270
  raise FileNotFoundError(f"Image not found: {args.image_path}")
271
 
272
- print(f"🔍 Inference on: {args.image_path}")
273
- print(f"📦 Model: {args.model} | Category: {args.category}")
274
 
275
- # Run inference
276
  results = run_inference(
277
  image_path=args.image_path,
278
  category=args.category,
@@ -280,13 +235,11 @@ def main():
280
  checkpoint_path=args.checkpoint
281
  )
282
 
283
- # Print result
284
  score = results["anomaly_score"]
285
  print(f"\n{'='*50}")
286
  print(f"ANOMALY SCORE: {score:.4f}" if score else "ANOMALY SCORE: N/A")
287
  print(f"{'='*50}\n")
288
 
289
- # Visualize and save
290
  output_dir = Path(args.output_dir) if args.output_dir else None
291
  output_path = visualize_results(results, output_dir)
292
 
 
32
  )
33
 
34
 
 
 
 
 
35
  def run_inference(image_path: str, category: str = "bottle", model_name: str = "patchcore", checkpoint_path: str = None) -> dict:
36
  """
37
  Runs inference on a single image.
 
50
  if not os.path.exists(ckpt):
51
  raise FileNotFoundError(f"Checkpoint not found: {ckpt}")
52
 
 
53
  model = load_model(model_name)
54
  dataset = PredictDataset(path=image_path)
55
 
 
56
  engine = Engine(
57
  default_root_dir="/tmp/anomalib_inference",
58
  callbacks=[],
59
  )
60
  predictions = engine.predict(model=model, dataset=dataset, ckpt_path=ckpt)
61
 
 
62
  results = {"image_path": image_path, "category": category, "model": model_name}
63
 
64
  for batch in predictions:
 
70
  return results
71
 
72
 
 
 
 
 
73
  def visualize_results(results: dict, output_dir: Path = None) -> str:
74
  """
75
  Visualizes heatmap and mask overlayed on original image.
 
81
  Returns:
82
  Path of saved image
83
  """
 
84
  model_name = results.get("model", "unknown")
85
 
 
86
  base_output_dir = output_dir or DIR_OUTPUT
87
  output_dir = base_output_dir / model_name
88
  output_dir.mkdir(parents=True, exist_ok=True)
89
 
 
90
  image_path = results["image_path"]
91
  anomaly_score = results["anomaly_score"]
92
  anomaly_map = results["anomaly_map"]
 
95
 
96
  original = np.array(Image.open(image_path).convert("RGB"))
97
 
 
98
  if anomaly_map.ndim == 3:
99
  anomaly_map = anomaly_map.squeeze(0)
100
 
 
101
  is_efficientad = model_name.lower() == "efficientad"
102
 
 
 
103
  amap_min, amap_max = anomaly_map.min(), anomaly_map.max()
104
  amap_range = amap_max - amap_min
105
 
106
+ if is_efficientad and amap_range < 0.1:
 
107
  if amap_range > 1e-8:
108
  anomaly_map = (anomaly_map - amap_min) / amap_range
109
  else:
110
  anomaly_map = np.zeros_like(anomaly_map)
111
+ print(f"[INFO] Applied image-level normalization (original range: {amap_min:.4f}-{amap_max:.4f})")
112
  else:
 
113
  anomaly_map = np.clip(anomaly_map, 0, 1)
114
 
115
  anomaly_map = resize_to_match(anomaly_map, original.shape[:2])
116
 
 
 
117
  is_good = anomaly_score is not None and anomaly_score < 0.5
118
+ show_mask_contours = True
119
 
120
  if is_efficientad and is_good:
121
+ anomaly_map = anomaly_map * 0.3
122
+ show_mask_contours = False
123
+ print(f"[INFO] EfficientAD: Good image detected - scaling heatmap to low values")s")
 
124
 
 
125
  if pred_mask is not None:
126
  if pred_mask.ndim == 3:
127
  pred_mask = pred_mask.squeeze(0)
128
  pred_mask = resize_to_match(pred_mask, original.shape[:2])
129
 
 
130
  show_fourth_panel = pred_mask is not None or (is_efficientad and is_good)
131
  num_cols = 4 if show_fourth_panel else 3
132
  fig, axes = plt.subplots(1, num_cols, figsize=(5 * num_cols, 5), facecolor='white')
133
 
 
134
  for ax in axes:
135
  ax.set_facecolor('white')
136
 
 
137
  axes[0].imshow(original)
138
  axes[0].set_title("Original")
139
  axes[0].axis("off")
140
 
 
 
 
141
  if is_efficientad:
 
142
  anomaly_map_masked = np.ma.masked_where(anomaly_map == 0, anomaly_map)
143
  cmap = plt.cm.jet
144
+ cmap.set_bad(color='none')
145
  else:
 
146
  anomaly_map_masked = anomaly_map
147
  cmap = plt.cm.jet
148
 
 
151
  axes[1].axis("off")
152
  plt.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)
153
 
 
 
154
  axes[2].imshow(original, aspect='auto')
155
  axes[2].imshow(anomaly_map_masked, cmap=cmap, alpha=0.5, vmin=0, vmax=1, aspect='auto')
156
  axes[2].set_title("Overlay")
157
  axes[2].axis("off")
158
 
 
159
  if show_fourth_panel:
160
  axes[3].imshow(original)
161
  if show_mask_contours and pred_mask is not None:
 
163
  axes[3].set_title("Predicted Mask")
164
  axes[3].axis("off")
165
 
 
166
  if anomaly_score is not None:
167
  if is_efficientad:
 
 
168
  scaled_score = scale_efficientad_score(anomaly_score)
169
  score_str = f"{scaled_score:.4f}"
170
  else:
 
179
  plt.savefig(output_path, dpi=150, bbox_inches="tight")
180
  plt.show()
181
 
 
182
  plt.close(fig)
183
 
184
  return str(output_path)
185
 
186
 
 
 
 
 
187
  def parse_args():
188
  """Parse command line arguments."""
189
  parser = argparse.ArgumentParser(
 
222
  def main():
223
  args = parse_args()
224
 
 
225
  if not os.path.exists(args.image_path):
226
  raise FileNotFoundError(f"Image not found: {args.image_path}")
227
 
228
+ print(f"Inference on: {args.image_path}")
229
+ print(f"Model: {args.model} | Category: {args.category}")
230
 
 
231
  results = run_inference(
232
  image_path=args.image_path,
233
  category=args.category,
 
235
  checkpoint_path=args.checkpoint
236
  )
237
 
 
238
  score = results["anomaly_score"]
239
  print(f"\n{'='*50}")
240
  print(f"ANOMALY SCORE: {score:.4f}" if score else "ANOMALY SCORE: N/A")
241
  print(f"{'='*50}\n")
242
 
 
243
  output_dir = Path(args.output_dir) if args.output_dir else None
244
  output_path = visualize_results(results, output_dir)
245
 
train.py CHANGED
@@ -34,12 +34,6 @@ from core import (
34
 
35
  logger = logging.getLogger(__name__)
36
 
37
- # =============================================================================
38
- # EFFICIENTAD MONKEY-PATCH
39
- # =============================================================================
40
- # Override the pretrained weights directory to keep everything inside
41
- # efficientad_resources/ instead of the hardcoded ./pre_trained/
42
-
43
  EFFICIENTAD_RESOURCES_DIR = Path(__file__).parent / "efficientad_resources"
44
 
45
  def _patched_prepare_pretrained_model(self) -> None:
@@ -66,12 +60,8 @@ def patch_efficientad():
66
  """Apply monkey-patch to EfficientAd to use custom pretrained weights directory."""
67
  from anomalib.models import EfficientAd
68
  EfficientAd.prepare_pretrained_model = _patched_prepare_pretrained_model
69
- print(f" ℹ️ EfficientAd: Pretrained weights will be saved to {EFFICIENTAD_RESOURCES_DIR / 'pre_trained'}")
70
-
71
 
72
- # =============================================================================
73
- # METRICS FUNCTIONS
74
- # =============================================================================
75
 
76
  def save_metrics(category_metrics, category, model_name):
77
  """Saves metrics in the Anomalib directory structure."""
@@ -95,7 +85,7 @@ def save_metrics(category_metrics, category, model_name):
95
  version_json_path = version_dir / "metrics.json"
96
  with open(version_json_path, 'w', encoding='utf-8') as f:
97
  json.dump(category_metrics, f, indent=2, ensure_ascii=False)
98
- print(f" 💾 Saved: {version_json_path}")
99
 
100
  # Save in latest (only if it exists)
101
  latest_dir = category_base_dir / "latest"
@@ -107,7 +97,7 @@ def save_metrics(category_metrics, category, model_name):
107
 
108
  def print_category_metrics(metrics):
109
  """Prints metrics for a category."""
110
- print(f"\n📊 Metrics:")
111
  print(f" EFFICACY: AUROC img={format_metric(metrics['image_auroc'])} | "
112
  f"AUROC pix={format_metric(metrics['pixel_auroc'])} | "
113
  f"F1={format_metric(metrics['image_f1'])}")
@@ -123,7 +113,7 @@ def print_final_report(all_metrics, model_name):
123
  return
124
 
125
  print(f"\n{'='*100}")
126
- print(f"📊 FINAL REPORT - {model_name.upper()} PERFORMANCE METRICS")
127
  print(f"{'='*100}\n")
128
 
129
  # Header
@@ -157,14 +147,10 @@ def print_final_report(all_metrics, model_name):
157
  print(f"\n{'='*100}")
158
 
159
 
160
- # =============================================================================
161
- # MAIN TRAINING FUNCTION
162
- # =============================================================================
163
-
164
  def train_category(category, model_name):
165
  """Runs training, test, and calculates metrics for a category."""
166
  print(f"\n{'='*60}")
167
- print(f"🚀 Training: {category} ({model_name})")
168
  print(f"{'='*60}")
169
 
170
  # Load config
@@ -181,11 +167,10 @@ def train_category(category, model_name):
181
 
182
  # EfficientAd-specific setup
183
  if model_name == "efficientad":
184
- patch_efficientad() # Redirect pretrained weights to efficientad_resources/
185
- # Override imagenet_dir to use absolute path inside efficientad_resources/
186
  model_params["imagenet_dir"] = str(EFFICIENTAD_RESOURCES_DIR / "imagenette")
187
- print(f" ℹ️ EfficientAd: ImageNet data will be saved to {EFFICIENTAD_RESOURCES_DIR / 'imagenette'}")
188
- print(" ℹ️ EfficientAd: Disabling image visualization (saving storage)")
189
  model_params["visualizer"] = False
190
 
191
  model = model_class(**model_params)
@@ -222,15 +207,11 @@ def train_category(category, model_name):
222
  # Output and save
223
  print_category_metrics(category_metrics)
224
  save_metrics(category_metrics, category, model_name)
225
- print(f"\n✓ Completed: {category}\n")
226
 
227
  return category_metrics
228
 
229
 
230
- # =============================================================================
231
- # MAIN
232
- # =============================================================================
233
-
234
  def parse_args():
235
  """Parse command line arguments."""
236
  available_models = get_available_models()
@@ -262,37 +243,32 @@ Examples:
262
  def main():
263
  args = parse_args()
264
 
265
- # Determine categories
266
  if args.category == "all":
267
  categories = MVTEC_CATEGORIES
268
- print(f"🔄 Training on ALL {len(categories)} categories")
269
  else:
270
  categories = [args.category]
271
- print(f"🎯 Training on: {args.category}")
272
 
273
- # Determine models
274
  if args.model == "all":
275
  models = get_available_models()
276
- print(f"📦 Models: ALL ({', '.join(models)})")
277
  else:
278
  models = [args.model]
279
- print(f"📦 Model: {args.model}")
280
 
281
- # Create results directory
282
  DIR_RESULTS.mkdir(parents=True, exist_ok=True)
283
 
284
- # Training loop
285
  all_metrics = []
286
  for model_name in models:
287
  if len(models) > 1:
288
- print(f"\n{'#'*60}")
289
- print(f"# MODEL: {model_name.upper()}")
290
- print(f"{'#'*60}")
291
 
292
  model_metrics = [train_category(cat, model_name) for cat in categories]
293
  all_metrics.extend(model_metrics)
294
 
295
- # Final Report per model
296
  print_final_report(model_metrics, model_name)
297
 
298
 
 
34
 
35
  logger = logging.getLogger(__name__)
36
 
 
 
 
 
 
 
37
  EFFICIENTAD_RESOURCES_DIR = Path(__file__).parent / "efficientad_resources"
38
 
39
  def _patched_prepare_pretrained_model(self) -> None:
 
60
  """Apply monkey-patch to EfficientAd to use custom pretrained weights directory."""
61
  from anomalib.models import EfficientAd
62
  EfficientAd.prepare_pretrained_model = _patched_prepare_pretrained_model
63
+ print(f" [INFO] EfficientAd: Pretrained weights directory: {EFFICIENTAD_RESOURCES_DIR / 'pre_trained'}")
 
64
 
 
 
 
65
 
66
  def save_metrics(category_metrics, category, model_name):
67
  """Saves metrics in the Anomalib directory structure."""
 
85
  version_json_path = version_dir / "metrics.json"
86
  with open(version_json_path, 'w', encoding='utf-8') as f:
87
  json.dump(category_metrics, f, indent=2, ensure_ascii=False)
88
+ print(f" Saved: {version_json_path}")
89
 
90
  # Save in latest (only if it exists)
91
  latest_dir = category_base_dir / "latest"
 
97
 
98
  def print_category_metrics(metrics):
99
  """Prints metrics for a category."""
100
+ print(f"\n[METRICS]")
101
  print(f" EFFICACY: AUROC img={format_metric(metrics['image_auroc'])} | "
102
  f"AUROC pix={format_metric(metrics['pixel_auroc'])} | "
103
  f"F1={format_metric(metrics['image_f1'])}")
 
113
  return
114
 
115
  print(f"\n{'='*100}")
116
+ print(f"FINAL REPORT - {model_name.upper()} PERFORMANCE METRICS")
117
  print(f"{'='*100}\n")
118
 
119
  # Header
 
147
  print(f"\n{'='*100}")
148
 
149
 
 
 
 
 
150
  def train_category(category, model_name):
151
  """Runs training, test, and calculates metrics for a category."""
152
  print(f"\n{'='*60}")
153
+ print(f"Training: {category} ({model_name})")
154
  print(f"{'='*60}")
155
 
156
  # Load config
 
167
 
168
  # EfficientAd-specific setup
169
  if model_name == "efficientad":
170
+ patch_efficientad()
 
171
  model_params["imagenet_dir"] = str(EFFICIENTAD_RESOURCES_DIR / "imagenette")
172
+ print(f" [INFO] EfficientAd: ImageNet directory: {EFFICIENTAD_RESOURCES_DIR / 'imagenette'}")
173
+ print(" [INFO] EfficientAd: Image visualization disabled")
174
  model_params["visualizer"] = False
175
 
176
  model = model_class(**model_params)
 
207
  # Output and save
208
  print_category_metrics(category_metrics)
209
  save_metrics(category_metrics, category, model_name)
210
+ print(f"\nCompleted: {category}\n")
211
 
212
  return category_metrics
213
 
214
 
 
 
 
 
215
  def parse_args():
216
  """Parse command line arguments."""
217
  available_models = get_available_models()
 
243
  def main():
244
  args = parse_args()
245
 
 
246
  if args.category == "all":
247
  categories = MVTEC_CATEGORIES
248
+ print(f"Training on ALL {len(categories)} categories")
249
  else:
250
  categories = [args.category]
251
+ print(f"Training on: {args.category}")
252
 
 
253
  if args.model == "all":
254
  models = get_available_models()
255
+ print(f"Models: ALL ({', '.join(models)})")
256
  else:
257
  models = [args.model]
258
+ print(f"Model: {args.model}")
259
 
 
260
  DIR_RESULTS.mkdir(parents=True, exist_ok=True)
261
 
 
262
  all_metrics = []
263
  for model_name in models:
264
  if len(models) > 1:
265
+ print(f"\n{'='*60}")
266
+ print(f"MODEL: {model_name.upper()}")
267
+ print(f"{'='*60}")
268
 
269
  model_metrics = [train_category(cat, model_name) for cat in categories]
270
  all_metrics.extend(model_metrics)
271
 
 
272
  print_final_report(model_metrics, model_name)
273
 
274