Parikshit Rathode committed on
Commit
c5732cc
·
1 Parent(s): c653c53

initial commit

Browse files
Files changed (10) hide show
  1. .gitignore +69 -0
  2. README.md +2 -2
  3. app.py +290 -0
  4. config.py +85 -0
  5. core/explain.py +97 -0
  6. core/inference.py +98 -0
  7. core/postprocess.py +180 -0
  8. core/visualization.py +146 -0
  9. models/model_loader.py +95 -0
  10. requirements.txt +37 -0
.gitignore ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+ MANIFEST
23
+
24
+ # Virtual Environment
25
+ venv/
26
+ ENV/
27
+ env/
28
+ .venv
29
+
30
+ # IDE
31
+ .vscode/
32
+ .idea/
33
+ *.swp
34
+ *.swo
35
+ *~
36
+ .DS_Store
37
+
38
+ # Project specific
39
+ models/* # Cached models are large, keep them out of git
40
+ datasets/
41
+ *.tar.xz
42
+ *.zip
43
+ *.ckpt
44
+ *.pth
45
+ *.onnx
46
+
47
+ # Environment variables
48
+ .env
49
+ .env.local
50
+ .env.*.local
51
+
52
+ # Logs
53
+ *.log
54
+ logs/
55
+
56
+ # Temporary files
57
+ tmp/
58
+ temp/
59
+ saved_pillow_image.png
60
+
61
+ # Jupyter
62
+ .ipynb_checkpoints/
63
+ *.ipynb
64
+
65
+ # OS
66
+ Thumbs.db
67
+
68
+ # Gradio
69
+ gradio_cookie_*.json
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Iad Explainable
3
  emoji: 🌖
4
  colorFrom: red
5
  colorTo: indigo
@@ -8,7 +8,7 @@ sdk_version: 6.10.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
- short_description: Multi-model anomaly detection (PatchCore + EfficientAD) with
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Industrial Anomaly Detection & Explainability System
3
  emoji: 🌖
4
  colorFrom: red
5
  colorTo: indigo
 
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ short_description: Multi-model anomaly detection (PatchCore + EfficientAD) with Explainable AI using Gemini
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main application entry point with Gradio UI.
3
+
4
+ This module orchestrates the anomaly detection pipeline by integrating
5
+ all core modules and providing a user-friendly web interface.
6
+ """
7
+
8
+ import gradio as gr
9
+ import numpy as np
10
+ import cv2
11
+
12
+ from config import THRESHOLDS, MVTEC_CATEGORIES, IMAGE_SIZE, THRESHOLD_MULTIPLIER
13
+ from models.model_loader import load_model
14
+ from core.inference import run_inference
15
+ from core.postprocess import postprocess
16
+ from core.visualization import create_visuals
17
+ from core.explain import get_explanation, init_gemini_client
18
+
19
+
20
def get_threshold(model_name: str, category: str) -> float:
    """Return the decision threshold for a model/category pair.

    Looks up the precomputed per-category threshold in THRESHOLDS and
    scales it by THRESHOLD_MULTIPLIER so overall sensitivity can be tuned
    in one place.

    Args:
        model_name: Name of the model ("patchcore" or "efficientad").
        category: MVTec AD category.

    Returns:
        The multiplier-adjusted threshold value.
    """
    return THRESHOLDS[model_name][category] * THRESHOLD_MULTIPLIER
33
+
34
+
35
def get_status(score: float, model_name: str, category: str) -> str:
    """Translate an anomaly score into a human-readable status label.

    Scores below the adjusted threshold are normal; scores within 0.1
    above it count as a slight deviation; anything higher is a strong
    anomaly.

    Args:
        score: Anomaly score.
        model_name: Name of the model.
        category: MVTec AD category.

    Returns:
        Status string prefixed with a traffic-light emoji.
    """
    limit = get_threshold(model_name, category)

    # Ordered bands: first upper bound the score falls under wins.
    bands = (
        (limit, "🟢 Normal"),
        (limit + 0.1, "🟡 Slight Deviation"),
    )
    for upper_bound, label in bands:
        if score < upper_bound:
            return label
    return "🔴 Strong Anomaly"
55
+
56
+
57
def is_valid_anomaly(score: float, model_name: str, category: str) -> bool:
    """Return True when *score* strictly exceeds the adjusted threshold.

    Args:
        score: Anomaly score.
        model_name: Name of the model.
        category: MVTec AD category.

    Returns:
        True if the score is above the model/category threshold.
    """
    return score > get_threshold(model_name, category)
71
+
72
+
73
def scale_efficientad_score(score: float) -> float:
    """Rescale a raw EfficientAD score for display purposes.

    Scores below 0.5 are compressed with a quadratic curve; scores at or
    above 0.5 are pushed toward 1.0 with a very steep sigmoid, making the
    displayed value saturate quickly once the midpoint is crossed.

    Args:
        score: Raw EfficientAD score.

    Returns:
        Scaled score for display.
    """
    if score >= 0.5:
        steepness = 500  # steep logistic: saturates within ~0.01 of the midpoint
        return 1 / (1 + np.exp(-steepness * (score - 0.5)))
    return (score * 2) ** 2 / 4
88
+
89
+
90
def detect(image, model_name: str, category: str, gemini_client):
    """Run the full anomaly detection pipeline on one uploaded image.

    Loads the requested model, runs inference, postprocesses the heatmap
    into a mask and bounding boxes, builds the display visuals, and packs
    everything needed later by the "Explain" action into a state dict.

    Args:
        image: Input image (numpy array from the Gradio component, or None).
        model_name: Selected model name ("patchcore" or "efficientad").
        category: Selected MVTec AD category.
        gemini_client: Initialized Gemini client, stashed in the state.

    Returns:
        Tuple of (original, heatmap, overlay, mask visuals, score text,
        threshold text, status text, state dict). All-empty placeholders
        are returned when no image was provided.
    """
    if image is None:
        return None, None, None, None, "", "", "", None

    img_array = np.array(image)

    model = load_model(model_name, category)

    # Raw model outputs; the raw predicted mask is unused downstream.
    heatmap, _pred_mask_raw, score = run_inference(model, img_array, model_name, category)

    anomalous = is_valid_anomaly(score, model_name, category)

    # Postprocessing operates on the model-resolution image.
    resized = cv2.resize(img_array, (IMAGE_SIZE, IMAGE_SIZE))
    final_mask, bboxes, heatmap_vis = postprocess(heatmap, resized, model_name, anomalous)

    original_vis, heatmap_color, overlay, mask_vis = create_visuals(
        img_array, heatmap_vis, final_mask, bboxes, model_name
    )

    threshold = get_threshold(model_name, category)
    status = get_status(score, model_name, category)

    # EfficientAD scores get a display-only rescale; the raw score is kept
    # in the state for the explanation step.
    display_score = scale_efficientad_score(score) if model_name == "efficientad" else score

    state = {
        "image": img_array,
        "bboxes": bboxes,
        "score": score,
        "category": category,
        "gemini_client": gemini_client,
    }

    return (
        original_vis,
        heatmap_color,
        overlay,
        mask_vis,
        f"{display_score:.4f}",
        f"{threshold:.4f}",
        status,
        state,
    )
157
+
158
+
159
def explain(state):
    """Generate a Gemini explanation for the most recent detection.

    Args:
        state: State dict produced by `detect` (image, bboxes, score,
            category, gemini_client), or None if detection never ran.

    Returns:
        Explanation text, or a user-facing message when preconditions
        are not met.
    """
    if state is None:
        return "Run detection first."

    client = state.get("gemini_client")
    if client is None:
        return "Gemini client not initialized. Please set GEMINI_API_KEY environment variable."

    return get_explanation(
        state["image"],
        state["bboxes"],
        state["score"],
        state["category"],
        client,
    )
183
+
184
+
185
def create_ui(gemini_client):
    """
    Create and configure the Gradio UI.

    Args:
        gemini_client: Initialized Gemini client. It is captured in the
            detect button's lambda closure because Gradio callbacks only
            receive component values.

    Returns:
        Gradio Blocks interface
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🔍 Industrial Anomaly Detection")
        gr.Markdown("PatchCore + EfficientAD + Explainable AI")

        # Holds the last detection result (image, bboxes, score, category,
        # client) so the Explain button can run without re-detecting.
        state = gr.State()

        with gr.Row():
            # Left column: inputs and controls.
            with gr.Column(scale=1):
                input_image = gr.Image(label="Upload Image", type="numpy", height=300)

                model_dropdown = gr.Dropdown(
                    choices=["patchcore", "efficientad"],
                    value="patchcore",
                    label="Model"
                )

                # Categories come from the threshold table so only entries
                # with a precomputed threshold are selectable.
                category_dropdown = gr.Dropdown(
                    choices=list(THRESHOLDS["patchcore"].keys()),
                    value="bottle",
                    label="Category"
                )

                detect_btn = gr.Button("🚀 Run Detection")
                explain_btn = gr.Button("🧠 Explain Anomaly")

            # Right column: result visuals and metrics.
            with gr.Column(scale=2):
                with gr.Row():
                    out_original = gr.Image(label="Original")
                    out_heatmap = gr.Image(label="Heatmap")

                with gr.Row():
                    out_overlay = gr.Image(label="Overlay")
                    out_mask = gr.Image(label="Predicted Mask")

                with gr.Row():
                    score_box = gr.Textbox(label="Score")
                    threshold_box = gr.Textbox(label="Threshold")
                    status_box = gr.Textbox(label="Status")

                explanation_box = gr.Textbox(label="Explanation", lines=3)

        # Button actions
        # The lambda binds gemini_client into the detect call; output order
        # must match the tuple returned by detect().
        detect_btn.click(
            fn=lambda img, model, cat: detect(img, model, cat, gemini_client),
            inputs=[input_image, model_dropdown, category_dropdown],
            outputs=[
                out_original,
                out_heatmap,
                out_overlay,
                out_mask,
                score_box,
                threshold_box,
                status_box,
                state
            ],
        )

        explain_btn.click(
            fn=explain,
            inputs=[state],
            outputs=explanation_box
        )

    return demo
259
+
260
+
261
def main():
    """Application entry point: load config, build the UI, and launch it."""
    import os
    from dotenv import load_dotenv

    # Pull GEMINI_API_KEY / GRADIO_* settings from a local .env if present.
    load_dotenv()

    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise ValueError(
            "GEMINI_API_KEY not found. Please set it in .env file or environment variables."
        )

    demo = create_ui(init_gemini_client(api_key))

    # Environment flags default to off unless explicitly set to "true".
    def _env_flag(name):
        return os.getenv(name, "False").lower() == "true"

    demo.launch(share=_env_flag("GRADIO_SHARE"), debug=_env_flag("GRADIO_DEBUG"))
287
+
288
+
289
# Launch the Gradio app only when executed as a script, not on import.
if __name__ == "__main__":
    main()
config.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration settings for the anomaly detection project.
3
+
4
+ This module contains all configurable parameters including thresholds,
5
+ model mappings, and other constants used throughout the application.
6
+ """
7
+
8
+ # Model directory mapping
9
+ MODEL_TO_DIR = {
10
+ "patchcore": "Patchcore",
11
+ "efficientad": "EfficientAd",
12
+ }
13
+
14
+ # MVTec AD dataset categories
15
+ MVTEC_CATEGORIES = [
16
+ "bottle", "cable", "capsule", "carpet", "grid",
17
+ "hazelnut", "leather", "metal_nut", "pill", "screw",
18
+ "tile", "toothbrush", "transistor", "wood", "zipper"
19
+ ]
20
+
21
+ # Precomputed thresholds for each model and category
22
+ # These thresholds are computed at the 95th percentile of normal training scores
23
+ THRESHOLDS = {
24
+ "patchcore": {
25
+ "bottle": 0.3218444108963013,
26
+ "cable": 0.34408192038536073,
27
+ "capsule": 0.5454285681247711,
28
+ "carpet": 0.3088440954685211,
29
+ "grid": 0.25913039445877073,
30
+ "hazelnut": 0.10068576037883759,
31
+ "leather": 0.2726534068584442,
32
+ "metal_nut": 0.34413049668073653,
33
+ "pill": 0.26968240439891816,
34
+ "screw": 0.49187072515487673,
35
+ "tile": 0.3581161931157112,
36
+ "toothbrush": 0.3721309259533882,
37
+ "transistor": 0.45495494604110714,
38
+ "wood": 0.1711873710155487,
39
+ "zipper": 0.4981046631932258
40
+ },
41
+ "efficientad": {
42
+ "bottle": 0.49928921461105347,
43
+ "cable": 0.4673861160874367,
44
+ "capsule": 0.5370000839233399,
45
+ "carpet": 0.49847708493471143,
46
+ "grid": 0.5295769184827804,
47
+ "hazelnut": 0.5202932059764862,
48
+ "leather": 0.504090940952301,
49
+ "metal_nut": 0.5047085165977478,
50
+ "pill": 0.5043391764163971,
51
+ "screw": 0.7167768508195878,
52
+ "tile": 0.5030474990606308,
53
+ "toothbrush": 0.5439804702997207,
54
+ "transistor": 0.5076832294464111,
55
+ "wood": 0.5024313390254974,
56
+ "zipper": 1.0
57
+ }
58
+ }
59
+
60
+ # Hugging Face repository ID for checkpoints
61
+ HF_REPO_ID = "micguida1/mvtec-anomaly-checkpoints"
62
+
63
+ # Image size for model input
64
+ IMAGE_SIZE = 256
65
+
66
+ # Threshold multiplier for sensitivity adjustment
67
+ # Lower value = more sensitive (more anomalies detected)
68
+ THRESHOLD_MULTIPLIER = 0.85
69
+
70
+ # Visualization parameters
71
+ HEATMAP_ALPHA = 0.5
72
+ OVERLAY_ALPHA = 0.5
73
+
74
+ # PatchCore specific parameters
75
+ PATCHCORE_BINARY_THRESHOLD = 0.60
76
+ PATCHCORE_MIN_CONTOUR_AREA = 100
77
+ PATCHCORE_MAX_INTENSITY_THRESHOLD = 0.75
78
+ PATCHCORE_BLUR_KERNEL = (7, 7)
79
+ PATCHCORE_MORPH_KERNEL = (5, 5)
80
+ PATCHCORE_FG_THRESHOLD = 15
81
+ PATCHCORE_FG_MORPH_KERNEL = (9, 9)
82
+
83
+ # EfficientAD specific parameters
84
+ EFFICIENTAD_BINARY_THRESHOLD = 0.5
85
+ EFFICIENTAD_MIN_CONTOUR_AREA = 5
core/explain.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Explainability module using Gemini VLM.
3
+
4
+ This module provides functions to generate human-readable explanations
5
+ for detected anomalies using Google's Gemini Vision Language Model.
6
+ """
7
+
8
+ from PIL import Image
9
+ import cv2
10
+ from google import genai
11
+ import numpy as np
12
+
13
+ # Model configuration
14
+ GEMINI_MODEL = "gemini-flash-lite-latest"
15
+
16
+
17
def get_explanation(
    original_image: np.ndarray,
    bboxes: list,
    score: float,
    category: str,
    client
) -> str:
    """Ask Gemini to describe the defect highlighted by the bounding boxes.

    Draws the detected boxes (scaled from 256x256 model space up to the
    original resolution) in red on a copy of the image and sends it to the
    Gemini VLM together with an inspection prompt.

    Args:
        original_image: Original input image in RGB format.
        bboxes: List of bounding boxes [x1, y1, x2, y2] in 256x256 scale.
        score: Anomaly score (currently unused by the prompt).
        category: MVTec AD category.
        client: Initialized Gemini API client.

    Returns:
        Explanation text from the model, or a fallback message.
    """
    if not bboxes:
        return "No anomaly detected."

    # Map boxes from model space (256x256) back to the original resolution.
    height, width = original_image.shape[:2]
    sx = width / 256.0
    sy = height / 256.0

    # Line width grows with resolution so the box stays visible.
    thickness = max(2, int(max(height, width) * 0.005))

    annotated = original_image.copy()
    for x1, y1, x2, y2 in bboxes:
        top_left = (int(x1 * sx), int(y1 * sy))
        bottom_right = (int(x2 * sx), int(y2 * sy))
        cv2.rectangle(annotated, top_left, bottom_right, (255, 0, 0), thickness)

    annotated_pil = Image.fromarray(annotated)

    prompt = f"""
    You are an expert industrial quality control inspector.
    We are inspecting a: {category}

    An anomaly detection model has flagged a potential defect, highlighted by the RED BOUNDING BOX in the provided image.

    Your task is to classify the defect inside the red box and assess its severity.
    Common defects for {category} include: scratches, cuts, cracks, holes, structural damage, or severe discoloration.

    Analyze the highlighted region carefully in the context of the whole object.

    Only Provide your final assessment strictly in this format:
    Defect: <Short name, e.g., Deep Scratch, Surface Cut, Crack, Contamination, Colouration>
    Location: <Where is it on the object?>
    Severity: <Low/Medium/High>
    """

    # Any API failure becomes a user-visible message instead of a crash.
    try:
        response = client.models.generate_content(
            model=GEMINI_MODEL,
            contents=[prompt, annotated_pil]
        )
        return response.text
    except Exception as e:
        return f"Failed to generate explanation: {str(e)}"
85
+
86
+
87
def init_gemini_client(api_key: str):
    """Build and return a Gemini API client.

    Args:
        api_key: Gemini API key.

    Returns:
        Initialized genai client.
    """
    client = genai.Client(api_key=api_key)
    return client
core/inference.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference module for anomaly detection.
3
+
4
+ This module handles image preprocessing, model inference, and basic output extraction.
5
+ It does not include postprocessing logic for mask generation.
6
+ """
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import torch
11
+
12
+ from config import IMAGE_SIZE
13
+
14
+
15
def preprocess_image(image: np.ndarray) -> torch.Tensor:
    """Convert an RGB image into a normalized NCHW float32 tensor.

    Args:
        image: Input image in RGB format (H, W, 3).

    Returns:
        Preprocessed tensor of shape (1, 3, IMAGE_SIZE, IMAGE_SIZE).
    """
    resized = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE))

    scaled = resized / 255.0                # scale pixel values to [0, 1]
    chw = np.transpose(scaled, (2, 0, 1))   # HWC -> CHW
    batched = chw[np.newaxis, ...]          # prepend batch dimension

    return torch.tensor(batched, dtype=torch.float32)
41
+
42
+
43
def run_inference(model, image: np.ndarray, model_name: str, category: str):
    """Run one forward pass and extract the anomaly outputs.

    Handles the three output conventions seen from anomalib models:
    an object with attributes, a dict, or a (score, heatmap) tuple.

    Args:
        model: Loaded anomaly detection model.
        image: Input image in RGB format.
        model_name: Name of the model being used (unused here, kept for API symmetry).
        category: MVTec AD category (unused here, kept for API symmetry).

    Returns:
        tuple: (heatmap, pred_mask_raw, score)
            - heatmap: Raw anomaly heatmap as a numpy array (H, W)
            - pred_mask_raw: Raw predicted mask (H, W) or None if absent
            - score: Anomaly score as a Python float
    """
    # Preprocess and move the batch onto the model's device.
    batch = preprocess_image(image).to(next(model.parameters()).device)

    with torch.no_grad():
        raw = model(batch)

    mask_raw = None

    # Accept attribute-style, dict-style, or tuple-style outputs.
    if hasattr(raw, "anomaly_map") and hasattr(raw, "pred_score"):
        amap = raw.anomaly_map
        score = raw.pred_score
        mask_raw = getattr(raw, "pred_mask", None)
    elif isinstance(raw, dict) and "anomaly_map" in raw and "pred_score" in raw:
        amap = raw["anomaly_map"]
        score = raw["pred_score"]
        mask_raw = raw.get("pred_mask", None)
    elif isinstance(raw, tuple) and len(raw) >= 2:
        score, amap = raw[0], raw[1]
    else:
        raise ValueError(
            f"Model output must contain anomaly_map and pred_score. "
            f"Got output type: {type(raw)}. "
            f"If using a dict, ensure it has 'anomaly_map' and 'pred_score' keys. "
            f"If using an object, ensure it has 'anomaly_map' and 'pred_score' attributes."
        )

    # Normalize everything to plain numpy / float for downstream code.
    amap = amap.squeeze().cpu().numpy()
    score = float(score.cpu().numpy() if torch.is_tensor(score) else score)
    if mask_raw is not None:
        mask_raw = mask_raw.squeeze().cpu().numpy()

    return amap, mask_raw, score
core/postprocess.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Postprocessing module for anomaly detection.
3
+
4
+ This module handles mask generation, bounding box extraction, and validation
5
+ from raw anomaly heatmaps. It contains model-specific logic for both PatchCore
6
+ and EfficientAD.
7
+ """
8
+
9
+ import cv2
10
+ import numpy as np
11
+
12
+ from config import (
13
+ PATCHCORE_BINARY_THRESHOLD,
14
+ PATCHCORE_MIN_CONTOUR_AREA,
15
+ PATCHCORE_MAX_INTENSITY_THRESHOLD,
16
+ PATCHCORE_BLUR_KERNEL,
17
+ PATCHCORE_MORPH_KERNEL,
18
+ PATCHCORE_FG_THRESHOLD,
19
+ PATCHCORE_FG_MORPH_KERNEL,
20
+ EFFICIENTAD_BINARY_THRESHOLD,
21
+ EFFICIENTAD_MIN_CONTOUR_AREA,
22
+ )
23
+
24
+
25
def postprocess_patchcore(heatmap: np.ndarray, original_image: np.ndarray, is_anomaly: bool):
    """
    Postprocess heatmap for PatchCore model.

    Fix vs. original: removed the unused `h, w = heatmap.shape` locals and
    the dead `x, y` unpacking in the area-filter loop; logic is unchanged.

    Args:
        heatmap: Raw anomaly heatmap (H, W)
        original_image: Original resized image (H, W, 3), RGB
        is_anomaly: Whether the image is classified as an anomaly

    Returns:
        tuple: (final_mask, bboxes, heatmap_vis)
            - final_mask: Binary mask of anomaly regions (uint8, 0/255)
            - bboxes: List of bounding boxes [x1, y1, x2, y2]
            - heatmap_vis: Normalized heatmap for visualization (H, W)
    """
    # Blur the heatmap for smoother contours
    heatmap_blurred = cv2.GaussianBlur(heatmap, PATCHCORE_BLUR_KERNEL, 0)

    # Normalize to [0, 1]; map a (near-)constant heatmap to all zeros to
    # avoid dividing by an ~0 range.
    h_min = float(heatmap_blurred.min())
    h_max = float(heatmap_blurred.max())
    h_range = h_max - h_min
    if h_range < 1e-6:
        heatmap_vis = np.zeros_like(heatmap_blurred)
    else:
        heatmap_vis = (heatmap_blurred - h_min) / h_range

    # Foreground masking: zero out heatmap responses over dark background
    # pixels so background noise cannot become a detection.
    gray = cv2.cvtColor(original_image, cv2.COLOR_RGB2GRAY)
    _, fg_mask = cv2.threshold(gray, PATCHCORE_FG_THRESHOLD, 255, cv2.THRESH_BINARY)
    kernel_fg = np.ones(PATCHCORE_FG_MORPH_KERNEL, np.uint8)
    fg_mask = cv2.morphologyEx(fg_mask, cv2.MORPH_CLOSE, kernel_fg)

    heatmap_vis[fg_mask == 0] = 0

    # Initialize outputs
    final_mask = np.zeros_like(heatmap_vis, dtype=np.uint8)
    bboxes = []

    if is_anomaly:
        # Threshold to binary, then clean up speckle with open + close.
        binary = (heatmap_vis > PATCHCORE_BINARY_THRESHOLD).astype(np.uint8) * 255
        kernel = np.ones(PATCHCORE_MORPH_KERNEL, np.uint8)
        binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)
        binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)

        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Keep contours that are both large enough and intense enough.
        valid_contours = []
        for c in contours:
            _, _, cw, ch = cv2.boundingRect(c)
            if cw * ch < PATCHCORE_MIN_CONTOUR_AREA:
                continue

            # Require a strong peak inside the contour to reject weak blobs.
            mask_temp = np.zeros_like(heatmap_vis, dtype=np.uint8)
            cv2.drawContours(mask_temp, [c], -1, 1, thickness=-1)
            max_intensity = heatmap_vis[mask_temp == 1].max()

            if max_intensity > PATCHCORE_MAX_INTENSITY_THRESHOLD:
                valid_contours.append(c)

        # Draw valid contours and extract bounding boxes, largest first.
        if valid_contours:
            valid_contours = sorted(valid_contours, key=cv2.contourArea, reverse=True)
            cv2.drawContours(final_mask, valid_contours, -1, 255, thickness=-1)

            for c in valid_contours:
                x, y, cw, ch = cv2.boundingRect(c)
                bboxes.append([x, y, x + cw, y + ch])

    return final_mask, bboxes, heatmap_vis
106
+
107
+
108
def postprocess_efficientad(heatmap: np.ndarray, is_anomaly: bool):
    """
    Postprocess heatmap for EfficientAD model.

    Fix vs. original: removed the unused `h, w = heatmap.shape` locals;
    logic is unchanged.

    Args:
        heatmap: Raw anomaly heatmap (H, W)
        is_anomaly: Whether the image is classified as an anomaly

    Returns:
        tuple: (final_mask, bboxes, heatmap_vis)
            - final_mask: Binary mask of anomaly regions (uint8, 0/255)
            - bboxes: List of bounding boxes [x1, y1, x2, y2]
            - heatmap_vis: Normalized heatmap for visualization (H, W)
    """
    # Adaptive normalization with robust handling:
    #   - small but nonzero range: stretch to [0, 1]
    #   - essentially constant: all zeros (avoid dividing by ~0)
    #   - otherwise: assume the map is roughly normalized already and clip
    amap_min = float(heatmap.min())
    amap_max = float(heatmap.max())
    amap_range = amap_max - amap_min

    if amap_range < 0.1:
        if amap_range > 1e-6:
            heatmap_vis = (heatmap - amap_min) / amap_range
        else:
            heatmap_vis = np.zeros_like(heatmap)
    else:
        heatmap_vis = np.clip(heatmap, 0, 1)

    # Dim the visualization for normal images so faint noise looks faint.
    if not is_anomaly:
        heatmap_vis = heatmap_vis * 0.3

    # Initialize outputs
    final_mask = np.zeros_like(heatmap, dtype=np.uint8)
    bboxes = []

    if is_anomaly:
        # Threshold exactly at the configured cutoff (0.5).
        binary = (heatmap_vis > EFFICIENTAD_BINARY_THRESHOLD).astype(np.uint8) * 255

        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Keep only regions above the minimum bounding-box area.
        for c in contours:
            x, y, cw, ch = cv2.boundingRect(c)
            if cw * ch > EFFICIENTAD_MIN_CONTOUR_AREA:
                cv2.drawContours(final_mask, [c], -1, 255, thickness=-1)
                bboxes.append([x, y, x + cw, y + ch])

    return final_mask, bboxes, heatmap_vis
162
+
163
+
164
def postprocess(heatmap: np.ndarray, original_image: np.ndarray, model_name: str, is_anomaly: bool):
    """Dispatch the heatmap to the model-specific postprocessor.

    Any model name other than "efficientad" falls through to the PatchCore
    path, matching the original routing behavior.

    Args:
        heatmap: Raw anomaly heatmap (H, W).
        original_image: Original resized image (H, W, 3); only used by PatchCore.
        model_name: Name of the model ("patchcore" or "efficientad").
        is_anomaly: Whether the image is classified as an anomaly.

    Returns:
        tuple: (final_mask, bboxes, heatmap_vis)
    """
    if model_name != "efficientad":
        return postprocess_patchcore(heatmap, original_image, is_anomaly)
    return postprocess_efficientad(heatmap, is_anomaly)
core/visualization.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Visualization module for anomaly detection results.
3
+
4
+ This module provides functions to create visual outputs including heatmaps,
5
+ overlays, and predicted mask visualizations.
6
+ """
7
+
8
+ import cv2
9
+ import numpy as np
10
+
11
+ from config import HEATMAP_ALPHA, OVERLAY_ALPHA
12
+
13
+
14
def create_overlay(image: np.ndarray, heatmap: np.ndarray, model_name: str) -> np.ndarray:
    """Blend the colored heatmap on top of the (resized) original image.

    Args:
        image: Original image in RGB format (H, W, 3).
        heatmap: Normalized heatmap (H, W) in range [0, 1].
        model_name: Name of the model for model-specific handling.

    Returns:
        Overlay image in RGB format (256, 256, 3).
    """
    resized = cv2.resize(image, (256, 256))

    # Colorize the heatmap (JET) and convert OpenCV's BGR output to RGB.
    jet = cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)
    jet = cv2.cvtColor(jet, cv2.COLOR_BGR2RGB)

    blended = cv2.addWeighted(resized, OVERLAY_ALPHA, jet, HEATMAP_ALPHA, 0)

    if model_name != "efficientad":
        return blended

    # EfficientAD pads with exact zeros; show the raw image in those pixels.
    padding = (heatmap == 0)[..., np.newaxis]
    return np.where(padding, resized, blended)
42
+
43
+
44
def create_mask_visualization(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Overlay the predicted binary mask on the image as translucent red.

    Args:
        image: Original image in RGB format (H, W, 3).
        mask: Binary mask (H, W) where non-zero values indicate anomaly.

    Returns:
        256x256 RGB visualization with a semi-transparent red mask and
        white contour outlines; a plain copy of the resized image when the
        mask is empty.
    """
    base = cv2.resize(image, (256, 256))
    if not np.any(mask):
        return base.copy()

    # Paint anomalous pixels red, then alpha-blend onto the image.
    red_layer = np.zeros_like(base)
    red_layer[mask > 0] = [255, 0, 0]  # RGB red
    blended = cv2.addWeighted(base, 0.7, red_layer, 0.3, 0)

    # Trace each region's outline in white for crisp boundaries.
    outlines, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(blended, outlines, -1, (255, 255, 255), 2)

    return blended
71
+
72
+
73
def draw_bounding_boxes(overlay: np.ndarray, mask_vis: np.ndarray, bboxes: list):
    """Draw each bounding box on both result images (in place).

    Boxes are drawn (0, 255, 0) on the overlay and (255, 0, 0) on the mask
    visualization — green and red respectively, since the pipeline's images
    are RGB (see create_mask_visualization).

    Args:
        overlay: Overlay image to draw on (modified in-place).
        mask_vis: Mask visualization image to draw on (modified in-place).
        bboxes: List of bounding boxes [x1, y1, x2, y2].
    """
    for x1, y1, x2, y2 in bboxes:
        top_left, bottom_right = (x1, y1), (x2, y2)
        cv2.rectangle(overlay, top_left, bottom_right, (0, 255, 0), 2)
        cv2.rectangle(mask_vis, top_left, bottom_right, (255, 0, 0), 2)
87
+
88
+
89
def create_heatmap_color(heatmap: np.ndarray, model_name: str) -> np.ndarray:
    """Colorize a normalized heatmap for display.

    Args:
        heatmap: Normalized heatmap (H, W) in range [0, 1].
        model_name: Name of the model for model-specific handling.

    Returns:
        Colored heatmap in RGB format (H, W, 3).
    """
    # JET colormap, then BGR -> RGB since OpenCV colorizes in BGR order.
    colored = cv2.cvtColor(
        cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET),
        cv2.COLOR_BGR2RGB,
    )

    # EfficientAD pads with exact zeros; render those pixels black.
    if model_name == "efficientad":
        colored[heatmap == 0] = [0, 0, 0]

    return colored
109
+
110
+
111
def create_visuals(
    image: np.ndarray,
    heatmap: np.ndarray,
    mask: np.ndarray,
    bboxes: list,
    model_name: str
) -> tuple:
    """Build all display images for one inference result.

    Args:
        image: Original input image in RGB format.
        heatmap: Normalized heatmap (H, W).
        mask: Binary mask (H, W).
        bboxes: List of bounding boxes.
        model_name: Name of the model.

    Returns:
        tuple: (original_resized, heatmap_color, overlay, mask_vis),
        all 256x256 RGB images.
    """
    base = cv2.resize(image, (256, 256))
    colored = create_heatmap_color(heatmap, model_name)
    overlay = create_overlay(image, heatmap, model_name)
    mask_view = create_mask_visualization(image, mask)

    # Boxes are drawn last, in place, on the overlay and mask views.
    draw_bounding_boxes(overlay, mask_view, bboxes)

    return base, colored, overlay, mask_view
models/model_loader.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model loading and caching module.
3
+
4
+ This module provides functions to load anomaly detection models from
5
+ Hugging Face Hub with caching support to avoid reloading the same model multiple times.
6
+ """
7
+
8
+ import os
9
+ import torch
10
+ from collections import OrderedDict
11
+ from huggingface_hub import hf_hub_download
12
+ from anomalib.models import Patchcore, EfficientAd
13
+
14
+ from config import HF_REPO_ID, MODEL_TO_DIR
15
+
16
+ # Maximum number of models to keep in cache (prevents unbounded memory growth)
17
+ MAX_MODEL_CACHE_SIZE = 30
18
+
19
+ # Global model cache with LRU eviction (using OrderedDict)
20
+ _model_cache = OrderedDict()
21
+
22
+
23
def get_ckpt_path(model_name: str, category: str) -> str:
    """
    Download (or reuse from the local cache) the checkpoint for a model/category pair.

    Args:
        model_name: Name of the model ("patchcore" or "efficientad")
        category: MVTec AD category (e.g., "bottle", "cable")

    Returns:
        Path to the downloaded checkpoint file
    """
    # Build the repo-relative path to the Lightning checkpoint.
    remote_path = "/".join([
        MODEL_TO_DIR[model_name],
        "MVTecAD",
        category,
        "latest/weights/lightning/model.ckpt",
    ])

    return hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=remote_path,
        local_dir="models",
        local_dir_use_symlinks=False,
    )
43
+
44
+
45
def load_model(model_name: str, category: str):
    """
    Load an anomaly detection model with caching and LRU eviction.

    Args:
        model_name: Name of the model ("patchcore" or "efficientad")
        category: MVTec AD category

    Returns:
        Loaded model in eval mode on the appropriate device (CUDA if available)

    Raises:
        ValueError: If an unknown model name is provided
    """
    key = f"{model_name}_{category}"

    # Cache hit: mark as most recently used and return.
    if key in _model_cache:
        _model_cache.move_to_end(key)
        return _model_cache[key]

    # Evict the least recently used model if the cache is full.
    if len(_model_cache) >= MAX_MODEL_CACHE_SIZE:
        _, evicted = _model_cache.popitem(last=False)
        # Drop the last strong reference and release cached GPU blocks;
        # popping the dict entry alone leaves the evicted model's CUDA
        # memory held by the allocator, defeating the cache size bound.
        del evicted
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Download (or reuse the locally cached) checkpoint.
    ckpt = get_ckpt_path(model_name, category)

    # Load the appropriate model type.
    if model_name == "patchcore":
        model = Patchcore.load_from_checkpoint(ckpt)
    elif model_name == "efficientad":
        model = EfficientAd.load_from_checkpoint(ckpt)
    else:
        raise ValueError(f"Unknown model: {model_name}")

    # Inference-only usage: eval mode, best available device.
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Insert as most recently used.
    _model_cache[key] = model

    return model
90
+
91
+
92
def clear_model_cache():
    """Empty the model cache and release any cached GPU memory."""
    # No `global` needed: .clear() mutates the OrderedDict in place
    # rather than rebinding the module-level name.
    _model_cache.clear()
    # Return freed CUDA blocks to the driver so other processes can use them.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
requirements.txt ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Anomaly Detection Dependencies
2
+
3
+ # Core ML Framework
4
+ torch>=2.0.0
5
+ torchvision
6
+
7
+ # Anomalib for pre-trained models
8
+ anomalib==2.2.0
9
+
10
+ # Gradio for UI
11
+ gradio>=4.0.0
12
+
13
+ # OpenCV for image processing
14
+ opencv-python-headless>=4.8.0
15
+
16
+ # Image processing
17
+ Pillow>=10.0.0
18
+ matplotlib>=3.7.0
19
+
20
+ # Hugging Face Hub for model downloads
21
+ huggingface-hub[cli]>=0.19.0
22
23
+
24
+ # Google Gemini for explainable AI
25
+ google-genai>=0.3.0
26
+
27
+ # Utilities
28
+ numpy>=1.24.0
29
+ tqdm>=4.65.0
30
+ python-dotenv>=1.0.0
31
+
32
+ # Training support
33
+ tensorboard>=2.16.0
34
+
35
+ # Optional: ONNX support (if needed)
36
+ # onnx>=1.14.0
37
+ # openvino>=2023.3.0