Spaces:
Sleeping
Sleeping
vari fix
Browse files- README.md +13 -0
- gradio_ui/content.py +3 -20
- gradio_ui/handlers.py +3 -25
- gradio_ui/tabs/compare.py +1 -1
- gradio_ui/visualization.py +5 -3
- scripts/download_checkpoints.py +83 -4
README.md
CHANGED
|
@@ -25,6 +25,9 @@ A comprehensive benchmark for anomaly detection models on the [MVTec AD dataset]
|
|
| 25 |
- **Multiple Models**: PatchCore, EfficientAD, FastFlow, STFPM, PaDiM
|
| 26 |
- **Full Benchmark**: Train and evaluate on all 15 MVTec categories
|
| 27 |
- **Interactive Demo**: [Gradio UI for real-time anomaly detection](https://huggingface.co/spaces/micguida1/mvtec-anomaly-benchmark)
|
|
|
|
|
|
|
|
|
|
| 28 |
- **Easy Configuration**: YAML-based model configs
|
| 29 |
|
| 30 |
## π¦ Installation
|
|
@@ -105,6 +108,16 @@ python app.py
|
|
| 105 |
|
| 106 |
The demo will be available at `http://localhost:7860`.
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
## π Project Structure
|
| 109 |
|
| 110 |
```
|
|
|
|
| 25 |
- **Multiple Models**: PatchCore, EfficientAD, FastFlow, STFPM, PaDiM
|
| 26 |
- **Full Benchmark**: Train and evaluate on all 15 MVTec categories
|
| 27 |
- **Interactive Demo**: [Gradio UI for real-time anomaly detection](https://huggingface.co/spaces/micguida1/mvtec-anomaly-benchmark)
|
| 28 |
+
- **Sample Image Gallery**: Browse and select sample images from MVTec dataset with automatic category detection
|
| 29 |
+
- **Draw Defects**: Draw artificial defects on images and see how models detect them
|
| 30 |
+
- **Model Comparison**: Compare multiple models side-by-side on the same image
|
| 31 |
- **Easy Configuration**: YAML-based model configs
|
| 32 |
|
| 33 |
## π¦ Installation
|
|
|
|
| 108 |
|
| 109 |
The demo will be available at `http://localhost:7860`.
|
| 110 |
|
| 111 |
+
#### Demo Features
|
| 112 |
+
|
| 113 |
+
- **π€ Upload Image**: Upload any image and analyze it for anomalies
|
| 114 |
+
- **βοΈ Draw Defects**: Load a sample image and draw artificial defects to test detection
|
| 115 |
+
- **π Compare Models**: Compare multiple models side-by-side on the same image
|
| 116 |
+
- **π Learn**: Educational content about each anomaly detection model
|
| 117 |
+
- **π Metrics**: View detailed performance metrics for each model
|
| 118 |
+
|
| 119 |
+
**Sample Image Gallery**: Each tab includes a gallery of sample images from the MVTec dataset. Click on any image to load it and the category will be automatically selected.
|
| 120 |
+
|
| 121 |
## π Project Structure
|
| 122 |
|
| 123 |
```
|
gradio_ui/content.py
CHANGED
|
@@ -20,26 +20,11 @@ Test how well the anomaly detection model captures defects:
|
|
| 20 |
1. **Upload a GOOD image** (normal, without defects)
|
| 21 |
2. **Use the brush to draw** artificial defects (scratches, stains, cracks, etc.)
|
| 22 |
3. **Click Analyze** to see if the heatmap detects your drawn defects
|
| 23 |
-
|
| 24 |
-
> π‘ **Tip:** Use different brush sizes and colors to simulate various defect types!
|
| 25 |
"""
|
| 26 |
|
| 27 |
-
BRUSH_COLORS_INFO = ""
|
| 28 |
-
---
|
| 29 |
-
**Brush Colors:**
|
| 30 |
-
- β« Black - Dark scratches, contamination
|
| 31 |
-
- π΄ Red - Highlighted defects
|
| 32 |
-
- π€ Brown - Rust, stains
|
| 33 |
-
- βͺ Gray/White - Light marks
|
| 34 |
-
"""
|
| 35 |
|
| 36 |
-
HEATMAP_INTERPRETATION = ""
|
| 37 |
-
---
|
| 38 |
-
**How to interpret:**
|
| 39 |
-
- π΄ **Red areas** in heatmap = High anomaly probability (model detected something)
|
| 40 |
-
- π΅ **Blue areas** = Normal regions
|
| 41 |
-
- Compare drawn defects with heatmap response!
|
| 42 |
-
"""
|
| 43 |
|
| 44 |
# Compare Models Tab Instructions
|
| 45 |
COMPARE_MODELS_INSTRUCTIONS = """
|
|
@@ -47,10 +32,8 @@ COMPARE_MODELS_INSTRUCTIONS = """
|
|
| 47 |
|
| 48 |
Compare how different anomaly detection models perform on the same image:
|
| 49 |
1. **Upload an image** (normal or with defects)
|
| 50 |
-
2. **Select 2-
|
| 51 |
3. **Click Compare** to see heatmaps side-by-side
|
| 52 |
-
|
| 53 |
-
> π‘ **Tip:** This helps identify which model is most sensitive to certain defect types!
|
| 54 |
"""
|
| 55 |
|
| 56 |
# Learn About Models Tab Content
|
|
|
|
| 20 |
1. **Upload a GOOD image** (normal, without defects)
|
| 21 |
2. **Use the brush to draw** artificial defects (scratches, stains, cracks, etc.)
|
| 22 |
3. **Click Analyze** to see if the heatmap detects your drawn defects
|
|
|
|
|
|
|
| 23 |
"""
|
| 24 |
|
| 25 |
+
BRUSH_COLORS_INFO = ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
+
HEATMAP_INTERPRETATION = ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
# Compare Models Tab Instructions
|
| 30 |
COMPARE_MODELS_INSTRUCTIONS = """
|
|
|
|
| 32 |
|
| 33 |
Compare how different anomaly detection models perform on the same image:
|
| 34 |
1. **Upload an image** (normal or with defects)
|
| 35 |
+
2. **Select 2-5 models** to compare
|
| 36 |
3. **Click Compare** to see heatmaps side-by-side
|
|
|
|
|
|
|
| 37 |
"""
|
| 38 |
|
| 39 |
# Learn About Models Tab Content
|
gradio_ui/handlers.py
CHANGED
|
@@ -164,8 +164,8 @@ def predict_compare(image, selected_models: list, category: str):
|
|
| 164 |
if not selected_models or len(selected_models) < 2:
|
| 165 |
return None, "β οΈ Please select at least 2 models to compare"
|
| 166 |
|
| 167 |
-
if len(selected_models) >
|
| 168 |
-
return None, "β οΈ Please select at most
|
| 169 |
|
| 170 |
try:
|
| 171 |
# Save temp image if needed
|
|
@@ -204,29 +204,7 @@ def predict_compare(image, selected_models: list, category: str):
|
|
| 204 |
# Create comparison visualization
|
| 205 |
viz_image = create_comparison_visualization(original, results_list)
|
| 206 |
|
| 207 |
-
|
| 208 |
-
summary_lines = ["### π Comparison Results\n"]
|
| 209 |
-
summary_lines.append("| Model | Score | Status |")
|
| 210 |
-
summary_lines.append("|-------|-------|--------|")
|
| 211 |
-
|
| 212 |
-
for result in results_list:
|
| 213 |
-
if result.get('error'):
|
| 214 |
-
summary_lines.append(f"| {result['model_name']} | β Error | {result['error'][:30]}... |")
|
| 215 |
-
else:
|
| 216 |
-
score = result['score']
|
| 217 |
-
model_name = result['model_name']
|
| 218 |
-
|
| 219 |
-
# Scale if EfficientAD
|
| 220 |
-
display_score = score
|
| 221 |
-
if "efficientad" in model_name.lower():
|
| 222 |
-
display_score = scale_efficientad_score(score)
|
| 223 |
-
|
| 224 |
-
status = "π΄ Anomaly" if display_score > 0.5 else "π’ Normal"
|
| 225 |
-
summary_lines.append(f"| {model_name} | {display_score:.4f} | {status} |")
|
| 226 |
-
|
| 227 |
-
summary = "\n".join(summary_lines)
|
| 228 |
-
|
| 229 |
-
return viz_image, summary
|
| 230 |
|
| 231 |
except Exception as e:
|
| 232 |
return None, f"β Error: {str(e)}"
|
|
|
|
| 164 |
if not selected_models or len(selected_models) < 2:
|
| 165 |
return None, "β οΈ Please select at least 2 models to compare"
|
| 166 |
|
| 167 |
+
if len(selected_models) > 5:
|
| 168 |
+
return None, "β οΈ Please select at most 5 models"
|
| 169 |
|
| 170 |
try:
|
| 171 |
# Save temp image if needed
|
|
|
|
| 204 |
# Create comparison visualization
|
| 205 |
viz_image = create_comparison_visualization(original, results_list)
|
| 206 |
|
| 207 |
+
return viz_image, ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
except Exception as e:
|
| 210 |
return None, f"β Error: {str(e)}"
|
gradio_ui/tabs/compare.py
CHANGED
|
@@ -35,7 +35,7 @@ def create_compare_tab(available_models: list, initial_categories: list):
|
|
| 35 |
height=300
|
| 36 |
)
|
| 37 |
|
| 38 |
-
gr.Markdown("### π§ Select Models (2-
|
| 39 |
|
| 40 |
compare_model_checkboxes = gr.CheckboxGroup(
|
| 41 |
choices=available_models,
|
|
|
|
| 35 |
height=300
|
| 36 |
)
|
| 37 |
|
| 38 |
+
gr.Markdown("### π§ Select Models (2-5)")
|
| 39 |
|
| 40 |
compare_model_checkboxes = gr.CheckboxGroup(
|
| 41 |
choices=available_models,
|
gradio_ui/visualization.py
CHANGED
|
@@ -264,10 +264,12 @@ def create_comparison_visualization(original: np.ndarray, results_list: list) ->
|
|
| 264 |
axes[row, 3].set_title(col_titles[3], fontsize=12, fontweight='bold')
|
| 265 |
|
| 266 |
# Add model name as row label on the left using annotation
|
| 267 |
-
#
|
| 268 |
-
|
|
|
|
| 269 |
xy=(-0.1, 0.5), xycoords='axes fraction',
|
| 270 |
-
fontsize=11, fontweight='bold', ha='right', va='center'
|
|
|
|
| 271 |
|
| 272 |
# Add colorbar
|
| 273 |
cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
|
|
|
|
| 264 |
axes[row, 3].set_title(col_titles[3], fontsize=12, fontweight='bold')
|
| 265 |
|
| 266 |
# Add model name as row label on the left using annotation
|
| 267 |
+
# Show model name, score and status
|
| 268 |
+
status_color = 'red' if display_score > 0.5 else 'green'
|
| 269 |
+
axes[row, 0].annotate(f"{model_name}\nScore: {display_score:.2f}\n{status}",
|
| 270 |
xy=(-0.1, 0.5), xycoords='axes fraction',
|
| 271 |
+
fontsize=11, fontweight='bold', ha='right', va='center',
|
| 272 |
+
color=status_color)
|
| 273 |
|
| 274 |
# Add colorbar
|
| 275 |
cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
|
scripts/download_checkpoints.py
CHANGED
|
@@ -61,6 +61,21 @@ def get_checkpoint_hf_path(model_name: str, category: str) -> str:
|
|
| 61 |
return f"{dirname}/MVTecAD/{category}/latest/weights/lightning/model.ckpt"
|
| 62 |
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
def get_local_checkpoint_path(model_name: str, category: str) -> Path:
|
| 65 |
"""
|
| 66 |
Returns the local path where the checkpoint should be stored.
|
|
@@ -76,6 +91,21 @@ def get_local_checkpoint_path(model_name: str, category: str) -> Path:
|
|
| 76 |
return DIR_RESULTS / dirname / "MVTecAD" / category / "latest" / "weights" / "lightning" / "model.ckpt"
|
| 77 |
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
def download_checkpoint(model_name: str, category: str, force: bool = False) -> bool:
|
| 80 |
"""
|
| 81 |
Downloads a single checkpoint from HuggingFace Hub.
|
|
@@ -115,13 +145,52 @@ def download_checkpoint(model_name: str, category: str, force: bool = False) ->
|
|
| 115 |
return False
|
| 116 |
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
def download_all_checkpoints(
|
| 119 |
models: list[str] = None,
|
| 120 |
categories: list[str] = None,
|
| 121 |
force: bool = False
|
| 122 |
) -> dict:
|
| 123 |
"""
|
| 124 |
-
Downloads checkpoints for specified models and categories.
|
| 125 |
|
| 126 |
Args:
|
| 127 |
models: List of model names (None = all available)
|
|
@@ -136,11 +205,11 @@ def download_all_checkpoints(
|
|
| 136 |
if categories is None:
|
| 137 |
categories = MVTEC_CATEGORIES
|
| 138 |
|
| 139 |
-
stats = {"downloaded": 0, "existed": 0, "failed": 0}
|
| 140 |
|
| 141 |
total = len(models) * len(categories)
|
| 142 |
|
| 143 |
-
print(f"π¦ Downloading checkpoints from: {HF_REPO_ID}")
|
| 144 |
print(f" Models: {', '.join(models)}")
|
| 145 |
print(f" Categories: {len(categories)} total")
|
| 146 |
print()
|
|
@@ -156,6 +225,10 @@ def download_all_checkpoints(
|
|
| 156 |
stats["downloaded"] += 1
|
| 157 |
else:
|
| 158 |
stats["failed"] += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
pbar.update(1)
|
| 161 |
|
|
@@ -179,6 +252,7 @@ def check_checkpoint_exists(model_name: str, category: str) -> bool:
|
|
| 179 |
def ensure_checkpoint(model_name: str, category: str) -> Path:
|
| 180 |
"""
|
| 181 |
Ensures a checkpoint exists, downloading if necessary.
|
|
|
|
| 182 |
|
| 183 |
This is the main function to call from inference/app code.
|
| 184 |
|
|
@@ -195,11 +269,15 @@ def ensure_checkpoint(model_name: str, category: str) -> Path:
|
|
| 195 |
local_path = get_local_checkpoint_path(model_name, category)
|
| 196 |
|
| 197 |
if local_path.exists():
|
|
|
|
|
|
|
| 198 |
return local_path
|
| 199 |
|
| 200 |
print(f"β¬ Checkpoint not found locally. Downloading {model_name}/{category}...")
|
| 201 |
|
| 202 |
if download_checkpoint(model_name, category):
|
|
|
|
|
|
|
| 203 |
if local_path.exists():
|
| 204 |
print(f"β Downloaded successfully")
|
| 205 |
return local_path
|
|
@@ -258,8 +336,9 @@ def main():
|
|
| 258 |
# Report
|
| 259 |
print()
|
| 260 |
print("=" * 50)
|
| 261 |
-
print(f"β
|
| 262 |
print(f"β Already existed: {stats['existed']}")
|
|
|
|
| 263 |
if stats['failed'] > 0:
|
| 264 |
print(f"β Failed: {stats['failed']}")
|
| 265 |
print("=" * 50)
|
|
|
|
| 61 |
return f"{dirname}/MVTecAD/{category}/latest/weights/lightning/model.ckpt"
|
| 62 |
|
| 63 |
|
| 64 |
+
def get_metrics_hf_path(model_name: str, category: str) -> str:
|
| 65 |
+
"""
|
| 66 |
+
Returns the path of the metrics.json file in the HF repository.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
model_name: Name of the model
|
| 70 |
+
category: MVTec category
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
Path string relative to HF repo root
|
| 74 |
+
"""
|
| 75 |
+
dirname = MODEL_TO_DIRNAME.get(model_name, model_name.capitalize())
|
| 76 |
+
return f"{dirname}/MVTecAD/{category}/latest/metrics.json"
|
| 77 |
+
|
| 78 |
+
|
| 79 |
def get_local_checkpoint_path(model_name: str, category: str) -> Path:
|
| 80 |
"""
|
| 81 |
Returns the local path where the checkpoint should be stored.
|
|
|
|
| 91 |
return DIR_RESULTS / dirname / "MVTecAD" / category / "latest" / "weights" / "lightning" / "model.ckpt"
|
| 92 |
|
| 93 |
|
| 94 |
+
def get_local_metrics_path(model_name: str, category: str) -> Path:
|
| 95 |
+
"""
|
| 96 |
+
Returns the local path where the metrics.json should be stored.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
model_name: Name of the model
|
| 100 |
+
category: MVTec category
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
Path object for local metrics file
|
| 104 |
+
"""
|
| 105 |
+
dirname = MODEL_TO_DIRNAME.get(model_name, model_name.capitalize())
|
| 106 |
+
return DIR_RESULTS / dirname / "MVTecAD" / category / "latest" / "metrics.json"
|
| 107 |
+
|
| 108 |
+
|
| 109 |
def download_checkpoint(model_name: str, category: str, force: bool = False) -> bool:
|
| 110 |
"""
|
| 111 |
Downloads a single checkpoint from HuggingFace Hub.
|
|
|
|
| 145 |
return False
|
| 146 |
|
| 147 |
|
| 148 |
+
def download_metrics(model_name: str, category: str, force: bool = False) -> bool:
|
| 149 |
+
"""
|
| 150 |
+
Downloads metrics.json for a model/category from HuggingFace Hub.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
model_name: Name of the model
|
| 154 |
+
category: MVTec category
|
| 155 |
+
force: If True, re-download even if exists
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
True if downloaded/exists, False if failed
|
| 159 |
+
"""
|
| 160 |
+
local_path = get_local_metrics_path(model_name, category)
|
| 161 |
+
|
| 162 |
+
# Skip if already exists
|
| 163 |
+
if local_path.exists() and not force:
|
| 164 |
+
return True
|
| 165 |
+
|
| 166 |
+
hf_path = get_metrics_hf_path(model_name, category)
|
| 167 |
+
|
| 168 |
+
try:
|
| 169 |
+
# Create parent directories
|
| 170 |
+
local_path.parent.mkdir(parents=True, exist_ok=True)
|
| 171 |
+
|
| 172 |
+
# Download from HF Hub
|
| 173 |
+
downloaded_path = hf_hub_download(
|
| 174 |
+
repo_id=HF_REPO_ID,
|
| 175 |
+
filename=hf_path,
|
| 176 |
+
local_dir=DIR_RESULTS,
|
| 177 |
+
local_dir_use_symlinks=False,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
return True
|
| 181 |
+
|
| 182 |
+
except Exception as e:
|
| 183 |
+
# Metrics file is optional, don't print error
|
| 184 |
+
return False
|
| 185 |
+
|
| 186 |
+
|
| 187 |
def download_all_checkpoints(
|
| 188 |
models: list[str] = None,
|
| 189 |
categories: list[str] = None,
|
| 190 |
force: bool = False
|
| 191 |
) -> dict:
|
| 192 |
"""
|
| 193 |
+
Downloads checkpoints and metrics for specified models and categories.
|
| 194 |
|
| 195 |
Args:
|
| 196 |
models: List of model names (None = all available)
|
|
|
|
| 205 |
if categories is None:
|
| 206 |
categories = MVTEC_CATEGORIES
|
| 207 |
|
| 208 |
+
stats = {"downloaded": 0, "existed": 0, "failed": 0, "metrics_downloaded": 0}
|
| 209 |
|
| 210 |
total = len(models) * len(categories)
|
| 211 |
|
| 212 |
+
print(f"π¦ Downloading checkpoints and metrics from: {HF_REPO_ID}")
|
| 213 |
print(f" Models: {', '.join(models)}")
|
| 214 |
print(f" Categories: {len(categories)} total")
|
| 215 |
print()
|
|
|
|
| 225 |
stats["downloaded"] += 1
|
| 226 |
else:
|
| 227 |
stats["failed"] += 1
|
| 228 |
+
|
| 229 |
+
# Also download metrics.json if available
|
| 230 |
+
if download_metrics(model, category, force):
|
| 231 |
+
stats["metrics_downloaded"] += 1
|
| 232 |
|
| 233 |
pbar.update(1)
|
| 234 |
|
|
|
|
| 252 |
def ensure_checkpoint(model_name: str, category: str) -> Path:
|
| 253 |
"""
|
| 254 |
Ensures a checkpoint exists, downloading if necessary.
|
| 255 |
+
Also downloads metrics.json if available.
|
| 256 |
|
| 257 |
This is the main function to call from inference/app code.
|
| 258 |
|
|
|
|
| 269 |
local_path = get_local_checkpoint_path(model_name, category)
|
| 270 |
|
| 271 |
if local_path.exists():
|
| 272 |
+
# Also try to download metrics if not present
|
| 273 |
+
download_metrics(model_name, category)
|
| 274 |
return local_path
|
| 275 |
|
| 276 |
print(f"β¬ Checkpoint not found locally. Downloading {model_name}/{category}...")
|
| 277 |
|
| 278 |
if download_checkpoint(model_name, category):
|
| 279 |
+
# Also download metrics
|
| 280 |
+
download_metrics(model_name, category)
|
| 281 |
if local_path.exists():
|
| 282 |
print(f"β Downloaded successfully")
|
| 283 |
return local_path
|
|
|
|
| 336 |
# Report
|
| 337 |
print()
|
| 338 |
print("=" * 50)
|
| 339 |
+
print(f"β Checkpoints downloaded: {stats['downloaded']}")
|
| 340 |
print(f"β Already existed: {stats['existed']}")
|
| 341 |
+
print(f"π Metrics downloaded: {stats['metrics_downloaded']}")
|
| 342 |
if stats['failed'] > 0:
|
| 343 |
print(f"β Failed: {stats['failed']}")
|
| 344 |
print("=" * 50)
|