mik170802 committed on
Commit
1d2dba1
·
1 Parent(s): 687d7bd
README.md CHANGED
@@ -25,6 +25,9 @@ A comprehensive benchmark for anomaly detection models on the [MVTec AD dataset]
25
  - **Multiple Models**: PatchCore, EfficientAD, FastFlow, STFPM, PaDiM
26
  - **Full Benchmark**: Train and evaluate on all 15 MVTec categories
27
  - **Interactive Demo**: [Gradio UI for real-time anomaly detection](https://huggingface.co/spaces/micguida1/mvtec-anomaly-benchmark)
 
 
 
28
  - **Easy Configuration**: YAML-based model configs
29
 
30
  ## 📦 Installation
@@ -105,6 +108,16 @@ python app.py
105
 
106
  The demo will be available at `http://localhost:7860`.
107
 
 
 
 
 
 
 
 
 
 
 
108
  ## πŸ“ Project Structure
109
 
110
  ```
 
25
  - **Multiple Models**: PatchCore, EfficientAD, FastFlow, STFPM, PaDiM
26
  - **Full Benchmark**: Train and evaluate on all 15 MVTec categories
27
  - **Interactive Demo**: [Gradio UI for real-time anomaly detection](https://huggingface.co/spaces/micguida1/mvtec-anomaly-benchmark)
28
+ - **Sample Image Gallery**: Browse and select sample images from MVTec dataset with automatic category detection
29
+ - **Draw Defects**: Draw artificial defects on images and see how models detect them
30
+ - **Model Comparison**: Compare multiple models side-by-side on the same image
31
  - **Easy Configuration**: YAML-based model configs
32
 
33
  ## 📦 Installation
 
108
 
109
  The demo will be available at `http://localhost:7860`.
110
 
111
+ #### Demo Features
112
+
113
+ - **📤 Upload Image**: Upload any image and analyze it for anomalies
114
+ - **✏️ Draw Defects**: Load a sample image and draw artificial defects to test detection
115
+ - **🔄 Compare Models**: Compare multiple models side-by-side on the same image
116
+ - **📚 Learn**: Educational content about each anomaly detection model
117
+ - **📊 Metrics**: View detailed performance metrics for each model
118
+
119
+ **Sample Image Gallery**: Each tab includes a gallery of sample images from the MVTec dataset. Click on any image to load it and the category will be automatically selected.
120
+
121
  ## πŸ“ Project Structure
122
 
123
  ```
gradio_ui/content.py CHANGED
@@ -20,26 +20,11 @@ Test how well the anomaly detection model captures defects:
20
  1. **Upload a GOOD image** (normal, without defects)
21
  2. **Use the brush to draw** artificial defects (scratches, stains, cracks, etc.)
22
  3. **Click Analyze** to see if the heatmap detects your drawn defects
23
-
24
- > πŸ’‘ **Tip:** Use different brush sizes and colors to simulate various defect types!
25
  """
26
 
27
- BRUSH_COLORS_INFO = """
28
- ---
29
- **Brush Colors:**
30
- - ⚫ Black - Dark scratches, contamination
31
- - πŸ”΄ Red - Highlighted defects
32
- - 🟀 Brown - Rust, stains
33
- - βšͺ Gray/White - Light marks
34
- """
35
 
36
- HEATMAP_INTERPRETATION = """
37
- ---
38
- **How to interpret:**
39
- - πŸ”΄ **Red areas** in heatmap = High anomaly probability (model detected something)
40
- - πŸ”΅ **Blue areas** = Normal regions
41
- - Compare drawn defects with heatmap response!
42
- """
43
 
44
  # Compare Models Tab Instructions
45
  COMPARE_MODELS_INSTRUCTIONS = """
@@ -47,10 +32,8 @@ COMPARE_MODELS_INSTRUCTIONS = """
47
 
48
  Compare how different anomaly detection models perform on the same image:
49
  1. **Upload an image** (normal or with defects)
50
- 2. **Select 2-4 models** to compare
51
  3. **Click Compare** to see heatmaps side-by-side
52
-
53
- > πŸ’‘ **Tip:** This helps identify which model is most sensitive to certain defect types!
54
  """
55
 
56
  # Learn About Models Tab Content
 
20
  1. **Upload a GOOD image** (normal, without defects)
21
  2. **Use the brush to draw** artificial defects (scratches, stains, cracks, etc.)
22
  3. **Click Analyze** to see if the heatmap detects your drawn defects
 
 
23
  """
24
 
25
+ BRUSH_COLORS_INFO = ""
 
 
 
 
 
 
 
26
 
27
+ HEATMAP_INTERPRETATION = ""
 
 
 
 
 
 
28
 
29
  # Compare Models Tab Instructions
30
  COMPARE_MODELS_INSTRUCTIONS = """
 
32
 
33
  Compare how different anomaly detection models perform on the same image:
34
  1. **Upload an image** (normal or with defects)
35
+ 2. **Select 2-5 models** to compare
36
  3. **Click Compare** to see heatmaps side-by-side
 
 
37
  """
38
 
39
  # Learn About Models Tab Content
gradio_ui/handlers.py CHANGED
@@ -164,8 +164,8 @@ def predict_compare(image, selected_models: list, category: str):
164
  if not selected_models or len(selected_models) < 2:
165
  return None, "⚠️ Please select at least 2 models to compare"
166
 
167
- if len(selected_models) > 4:
168
- return None, "⚠️ Please select at most 4 models"
169
 
170
  try:
171
  # Save temp image if needed
@@ -204,29 +204,7 @@ def predict_compare(image, selected_models: list, category: str):
204
  # Create comparison visualization
205
  viz_image = create_comparison_visualization(original, results_list)
206
 
207
- # Create summary text
208
- summary_lines = ["### πŸ“Š Comparison Results\n"]
209
- summary_lines.append("| Model | Score | Status |")
210
- summary_lines.append("|-------|-------|--------|")
211
-
212
- for result in results_list:
213
- if result.get('error'):
214
- summary_lines.append(f"| {result['model_name']} | ❌ Error | {result['error'][:30]}... |")
215
- else:
216
- score = result['score']
217
- model_name = result['model_name']
218
-
219
- # Scale if EfficientAD
220
- display_score = score
221
- if "efficientad" in model_name.lower():
222
- display_score = scale_efficientad_score(score)
223
-
224
- status = "πŸ”΄ Anomaly" if display_score > 0.5 else "🟒 Normal"
225
- summary_lines.append(f"| {model_name} | {display_score:.4f} | {status} |")
226
-
227
- summary = "\n".join(summary_lines)
228
-
229
- return viz_image, summary
230
 
231
  except Exception as e:
232
  return None, f"❌ Error: {str(e)}"
 
164
  if not selected_models or len(selected_models) < 2:
165
  return None, "⚠️ Please select at least 2 models to compare"
166
 
167
+ if len(selected_models) > 5:
168
+ return None, "⚠️ Please select at most 5 models"
169
 
170
  try:
171
  # Save temp image if needed
 
204
  # Create comparison visualization
205
  viz_image = create_comparison_visualization(original, results_list)
206
 
207
+ return viz_image, ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
  except Exception as e:
210
  return None, f"❌ Error: {str(e)}"
gradio_ui/tabs/compare.py CHANGED
@@ -35,7 +35,7 @@ def create_compare_tab(available_models: list, initial_categories: list):
35
  height=300
36
  )
37
 
38
- gr.Markdown("### 🧠 Select Models (2-4)")
39
 
40
  compare_model_checkboxes = gr.CheckboxGroup(
41
  choices=available_models,
 
35
  height=300
36
  )
37
 
38
+ gr.Markdown("### 🧠 Select Models (2-5)")
39
 
40
  compare_model_checkboxes = gr.CheckboxGroup(
41
  choices=available_models,
gradio_ui/visualization.py CHANGED
@@ -264,10 +264,12 @@ def create_comparison_visualization(original: np.ndarray, results_list: list) ->
264
  axes[row, 3].set_title(col_titles[3], fontsize=12, fontweight='bold')
265
 
266
  # Add model name as row label on the left using annotation
267
- # Removing numerical score as requested
268
- axes[row, 0].annotate(f"{model_name}\n{status}",
 
269
  xy=(-0.1, 0.5), xycoords='axes fraction',
270
- fontsize=11, fontweight='bold', ha='right', va='center')
 
271
 
272
  # Add colorbar
273
  cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
 
264
  axes[row, 3].set_title(col_titles[3], fontsize=12, fontweight='bold')
265
 
266
  # Add model name as row label on the left using annotation
267
+ # Show model name, score and status
268
+ status_color = 'red' if display_score > 0.5 else 'green'
269
+ axes[row, 0].annotate(f"{model_name}\nScore: {display_score:.2f}\n{status}",
270
  xy=(-0.1, 0.5), xycoords='axes fraction',
271
+ fontsize=11, fontweight='bold', ha='right', va='center',
272
+ color=status_color)
273
 
274
  # Add colorbar
275
  cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
scripts/download_checkpoints.py CHANGED
@@ -61,6 +61,21 @@ def get_checkpoint_hf_path(model_name: str, category: str) -> str:
61
  return f"{dirname}/MVTecAD/{category}/latest/weights/lightning/model.ckpt"
62
 
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def get_local_checkpoint_path(model_name: str, category: str) -> Path:
65
  """
66
  Returns the local path where the checkpoint should be stored.
@@ -76,6 +91,21 @@ def get_local_checkpoint_path(model_name: str, category: str) -> Path:
76
  return DIR_RESULTS / dirname / "MVTecAD" / category / "latest" / "weights" / "lightning" / "model.ckpt"
77
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def download_checkpoint(model_name: str, category: str, force: bool = False) -> bool:
80
  """
81
  Downloads a single checkpoint from HuggingFace Hub.
@@ -115,13 +145,52 @@ def download_checkpoint(model_name: str, category: str, force: bool = False) ->
115
  return False
116
 
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  def download_all_checkpoints(
119
  models: list[str] = None,
120
  categories: list[str] = None,
121
  force: bool = False
122
  ) -> dict:
123
  """
124
- Downloads checkpoints for specified models and categories.
125
 
126
  Args:
127
  models: List of model names (None = all available)
@@ -136,11 +205,11 @@ def download_all_checkpoints(
136
  if categories is None:
137
  categories = MVTEC_CATEGORIES
138
 
139
- stats = {"downloaded": 0, "existed": 0, "failed": 0}
140
 
141
  total = len(models) * len(categories)
142
 
143
- print(f"πŸ“¦ Downloading checkpoints from: {HF_REPO_ID}")
144
  print(f" Models: {', '.join(models)}")
145
  print(f" Categories: {len(categories)} total")
146
  print()
@@ -156,6 +225,10 @@ def download_all_checkpoints(
156
  stats["downloaded"] += 1
157
  else:
158
  stats["failed"] += 1
 
 
 
 
159
 
160
  pbar.update(1)
161
 
@@ -179,6 +252,7 @@ def check_checkpoint_exists(model_name: str, category: str) -> bool:
179
  def ensure_checkpoint(model_name: str, category: str) -> Path:
180
  """
181
  Ensures a checkpoint exists, downloading if necessary.
 
182
 
183
  This is the main function to call from inference/app code.
184
 
@@ -195,11 +269,15 @@ def ensure_checkpoint(model_name: str, category: str) -> Path:
195
  local_path = get_local_checkpoint_path(model_name, category)
196
 
197
  if local_path.exists():
 
 
198
  return local_path
199
 
200
  print(f"⬇ Checkpoint not found locally. Downloading {model_name}/{category}...")
201
 
202
  if download_checkpoint(model_name, category):
 
 
203
  if local_path.exists():
204
  print(f"βœ“ Downloaded successfully")
205
  return local_path
@@ -258,8 +336,9 @@ def main():
258
  # Report
259
  print()
260
  print("=" * 50)
261
- print(f"βœ“ Downloaded: {stats['downloaded']}")
262
  print(f"β—‹ Already existed: {stats['existed']}")
 
263
  if stats['failed'] > 0:
264
  print(f"βœ— Failed: {stats['failed']}")
265
  print("=" * 50)
 
61
  return f"{dirname}/MVTecAD/{category}/latest/weights/lightning/model.ckpt"
62
 
63
 
64
+ def get_metrics_hf_path(model_name: str, category: str) -> str:
65
+ """
66
+ Returns the path of the metrics.json file in the HF repository.
67
+
68
+ Args:
69
+ model_name: Name of the model
70
+ category: MVTec category
71
+
72
+ Returns:
73
+ Path string relative to HF repo root
74
+ """
75
+ dirname = MODEL_TO_DIRNAME.get(model_name, model_name.capitalize())
76
+ return f"{dirname}/MVTecAD/{category}/latest/metrics.json"
77
+
78
+
79
  def get_local_checkpoint_path(model_name: str, category: str) -> Path:
80
  """
81
  Returns the local path where the checkpoint should be stored.
 
91
  return DIR_RESULTS / dirname / "MVTecAD" / category / "latest" / "weights" / "lightning" / "model.ckpt"
92
 
93
 
94
+ def get_local_metrics_path(model_name: str, category: str) -> Path:
95
+ """
96
+ Returns the local path where the metrics.json should be stored.
97
+
98
+ Args:
99
+ model_name: Name of the model
100
+ category: MVTec category
101
+
102
+ Returns:
103
+ Path object for local metrics file
104
+ """
105
+ dirname = MODEL_TO_DIRNAME.get(model_name, model_name.capitalize())
106
+ return DIR_RESULTS / dirname / "MVTecAD" / category / "latest" / "metrics.json"
107
+
108
+
109
  def download_checkpoint(model_name: str, category: str, force: bool = False) -> bool:
110
  """
111
  Downloads a single checkpoint from HuggingFace Hub.
 
145
  return False
146
 
147
 
148
+ def download_metrics(model_name: str, category: str, force: bool = False) -> bool:
149
+ """
150
+ Downloads metrics.json for a model/category from HuggingFace Hub.
151
+
152
+ Args:
153
+ model_name: Name of the model
154
+ category: MVTec category
155
+ force: If True, re-download even if exists
156
+
157
+ Returns:
158
+ True if downloaded/exists, False if failed
159
+ """
160
+ local_path = get_local_metrics_path(model_name, category)
161
+
162
+ # Skip if already exists
163
+ if local_path.exists() and not force:
164
+ return True
165
+
166
+ hf_path = get_metrics_hf_path(model_name, category)
167
+
168
+ try:
169
+ # Create parent directories
170
+ local_path.parent.mkdir(parents=True, exist_ok=True)
171
+
172
+ # Download from HF Hub
173
+ downloaded_path = hf_hub_download(
174
+ repo_id=HF_REPO_ID,
175
+ filename=hf_path,
176
+ local_dir=DIR_RESULTS,
177
+ local_dir_use_symlinks=False,
178
+ )
179
+
180
+ return True
181
+
182
+ except Exception as e:
183
+ # Metrics file is optional, don't print error
184
+ return False
185
+
186
+
187
  def download_all_checkpoints(
188
  models: list[str] = None,
189
  categories: list[str] = None,
190
  force: bool = False
191
  ) -> dict:
192
  """
193
+ Downloads checkpoints and metrics for specified models and categories.
194
 
195
  Args:
196
  models: List of model names (None = all available)
 
205
  if categories is None:
206
  categories = MVTEC_CATEGORIES
207
 
208
+ stats = {"downloaded": 0, "existed": 0, "failed": 0, "metrics_downloaded": 0}
209
 
210
  total = len(models) * len(categories)
211
 
212
+ print(f"πŸ“¦ Downloading checkpoints and metrics from: {HF_REPO_ID}")
213
  print(f" Models: {', '.join(models)}")
214
  print(f" Categories: {len(categories)} total")
215
  print()
 
225
  stats["downloaded"] += 1
226
  else:
227
  stats["failed"] += 1
228
+
229
+ # Also download metrics.json if available
230
+ if download_metrics(model, category, force):
231
+ stats["metrics_downloaded"] += 1
232
 
233
  pbar.update(1)
234
 
 
252
  def ensure_checkpoint(model_name: str, category: str) -> Path:
253
  """
254
  Ensures a checkpoint exists, downloading if necessary.
255
+ Also downloads metrics.json if available.
256
 
257
  This is the main function to call from inference/app code.
258
 
 
269
  local_path = get_local_checkpoint_path(model_name, category)
270
 
271
  if local_path.exists():
272
+ # Also try to download metrics if not present
273
+ download_metrics(model_name, category)
274
  return local_path
275
 
276
  print(f"⬇ Checkpoint not found locally. Downloading {model_name}/{category}...")
277
 
278
  if download_checkpoint(model_name, category):
279
+ # Also download metrics
280
+ download_metrics(model_name, category)
281
  if local_path.exists():
282
  print(f"βœ“ Downloaded successfully")
283
  return local_path
 
336
  # Report
337
  print()
338
  print("=" * 50)
339
+ print(f"βœ“ Checkpoints downloaded: {stats['downloaded']}")
340
  print(f"β—‹ Already existed: {stats['existed']}")
341
+ print(f"πŸ“Š Metrics downloaded: {stats['metrics_downloaded']}")
342
  if stats['failed'] > 0:
343
  print(f"βœ— Failed: {stats['failed']}")
344
  print("=" * 50)