eho69 commited on
Commit
7e58509
Β·
verified Β·
1 Parent(s): 9144a73
Files changed (1) hide show
  1. app.py +846 -883
app.py CHANGED
@@ -1,1003 +1,966 @@
1
- # app.py
2
- import gradio as gr
3
- import torch
4
- import torch.nn as nn
5
- from torchvision import models, transforms
6
- from PIL import Image
7
- import numpy as np
8
- import pickle
9
- import os
10
- import cv2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  # class EnginePartDetector:
13
- # def __init__(self):
 
 
 
 
 
14
  # self.model = models.resnet50(weights='IMAGENET1K_V1')
15
  # self.model = nn.Sequential(*list(self.model.children())[:-1])
16
  # self.model.eval()
17
-
 
 
 
 
 
 
 
 
 
 
18
  # self.transform = transforms.Compose([
19
  # transforms.Resize((224, 224)),
20
  # transforms.ToTensor(),
21
  # transforms.Normalize(
22
  # mean=[0.485, 0.456, 0.406],
23
- # std=[0.229, 0.224, 0.225]
24
  # )
25
  # ])
26
-
27
  # self.templates = {}
28
  # self.load_templates()
29
-
30
- # def extract_features(self, image):
31
- # if isinstance(image, np.ndarray):
32
- # image = Image.fromarray(image)
33
-
34
- # img_tensor = self.transform(image).unsqueeze(0)
35
-
36
- # with torch.no_grad():
37
- # features = self.model(img_tensor)
38
- # features = features.squeeze().numpy()
39
-
40
- # return features
41
-
42
- class EnginePartDetector:
43
- def __init__(
44
- self,
45
- clahe_clip_limit: float = 9.9,
46
- clahe_tile_grid: tuple = (8, 8),
47
- ):
48
- # ── ResNet-50 backbone (feature extractor only) ──────────────────
49
- self.model = models.resnet50(weights='IMAGENET1K_V1')
50
- self.model = nn.Sequential(*list(self.model.children())[:-1])
51
- self.model.eval()
52
-
53
- # ── CLAHE (OpenCV) β€” applied BEFORE the torch transform ──────────
54
- # Operates on grayscale to recover shadow-suppressed edges
55
- # (e.g. missing bearing saddle arcs), then merged back to RGB
56
- # so the 3-channel ResNet pipeline is unaffected.
57
- self.clahe = cv2.createCLAHE(
58
- clipLimit=clahe_clip_limit,
59
- tileGridSize=clahe_tile_grid,
60
- )
61
-
62
- # ── ResNet normalisation transform (unchanged) ───────────────────
63
- self.transform = transforms.Compose([
64
- transforms.Resize((224, 224)),
65
- transforms.ToTensor(),
66
- transforms.Normalize(
67
- mean=[0.485, 0.456, 0.406],
68
- std=[0.229, 0.224, 0.225],
69
- )
70
- ])
71
 
72
- self.templates = {}
73
- self.load_templates()
74
 
75
- # ── CLAHE preprocessing ───────────────────────────────────────────────
76
-
77
- def apply_clahe(self, image: np.ndarray) -> np.ndarray:
78
 
79
- # Convert RGB (PIL/numpy) β†’ BGR for OpenCV
80
- bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
81
 
82
- # BGR β†’ LAB
83
- lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)
84
 
85
- # Split channels; apply CLAHE only to L (luminance)
86
- l_channel, a_channel, b_channel = cv2.split(lab)
87
- l_enhanced = self.clahe.apply(l_channel)
88
 
89
- # Merge enhanced L back with untouched A and B
90
- lab_enhanced = cv2.merge([l_enhanced, a_channel, b_channel])
91
 
92
- # LAB β†’ BGR β†’ RGB
93
- bgr_enhanced = cv2.cvtColor(lab_enhanced, cv2.COLOR_LAB2BGR)
94
- rgb_enhanced = cv2.cvtColor(bgr_enhanced, cv2.COLOR_BGR2RGB)
95
 
96
- return rgb_enhanced # uint8 numpy array, same shape as input
97
 
98
- # ── Feature extraction ────────────────────────────────────────────────
99
 
100
- def extract_features(self, image) -> np.ndarray:
101
 
102
- # 1. Normalise input to numpy uint8 RGB
103
- if isinstance(image, Image.Image):
104
- image = np.array(image.convert("RGB"))
105
- elif isinstance(image, np.ndarray) and image.dtype != np.uint8:
106
- image = image.astype(np.uint8)
107
 
108
- # 2. CLAHE β€” recover shadow-suppressed structural edges
109
- image = self.apply_clahe(image)
110
 
111
- # 3. Mild Gaussian blur β€” reduces high-freq metallic sheen noise
112
- # that CLAHE can amplify; kernel (3,3) is intentionally light
113
- # so real surface-defect texture is preserved
114
- image = cv2.GaussianBlur(image, (3, 3), 0)
115
 
116
- # 4. Convert back to PIL for torchvision transforms
117
- image_pil = Image.fromarray(image)
118
 
119
- # 5. ResNet transform β†’ tensor
120
- img_tensor = self.transform(image_pil).unsqueeze(0)
121
 
122
- # 6. Forward pass (no grad needed β€” inference only)
123
- with torch.no_grad():
124
- features = self.model(img_tensor)
125
- features = features.squeeze().numpy()
126
 
127
- return features
128
 
129
- def cosine_similarity(self, feat1, feat2):
130
- return np.dot(feat1, feat2) / (np.linalg.norm(feat1) * np.linalg.norm(feat2))
131
 
132
- def save_template(self, image, part_name):
133
- if image is None or not part_name:
134
- return "Please provide both image and part name"
135
 
136
- features = self.extract_features(image)
137
- self.templates[part_name] = features
138
 
139
- with open('templates.pkl', 'wb') as f:
140
- pickle.dump(self.templates, f)
141
 
142
- return f"βœ… Template '{part_name}' saved successfully!"
143
 
144
- def load_templates(self):
145
- if os.path.exists('templates.pkl'):
146
- try:
147
- with open('templates.pkl', 'rb') as f:
148
- self.templates = pickle.load(f)
149
- except:
150
- self.templates = {}
151
 
152
- def match_part(self, image, threshold=0.7):
153
- if image is None:
154
- return "Please provide an image", None
155
 
156
- if not self.templates:
157
- return "⚠️ No templates available. Please add templates first.", None
158
 
159
- query_features = self.extract_features(image)
160
 
161
- results = []
162
- for part_name, template_features in self.templates.items():
163
- similarity = self.cosine_similarity(query_features, template_features)
164
- results.append((part_name, similarity))
165
 
166
- results.sort(key=lambda x: x[1], reverse=True)
167
 
168
- best_match = results[0]
169
- output_text = f"πŸ” **Best Match**: {best_match[0]}\n"
170
- output_text += f"πŸ“Š **Confidence**: {best_match[1]:.2%}\n\n"
171
 
172
- if best_match[1] >= threshold:
173
- output_text += "βœ… **Status**: MATCHED\n\n"
174
- else:
175
- output_text += "❌ **Status**: NO MATCH (below threshold)\n\n"
176
 
177
- output_text += "**All Results:**\n"
178
- for part, sim in results:
179
- output_text += f"- {part}: {sim:.2%}\n"
180
 
181
- matched_label = best_match[0] if best_match[1] >= threshold else None
182
- return output_text, matched_label
183
 
184
- detector = EnginePartDetector()
185
 
186
- def add_template(image, part_name):
187
- return detector.save_template(image, part_name)
188
 
189
- def detect_part(image, threshold):
190
- return detector.match_part(image, threshold)
191
 
192
- def list_templates():
193
- if not detector.templates:
194
- return "No templates saved yet"
195
- return "\n".join([f"- {name}" for name in detector.templates.keys()])
196
 
197
- with gr.Blocks(title="Engine Part Detection System") as demo:
198
- gr.Markdown("""
199
- # πŸ”§ Engine Part Detection System
200
- ### Using ResNet50 Feature Extraction & Template Matching
201
 
202
- **How to use:**
203
- 1. **Add Templates**: Upload reference images of engine parts
204
- 2. **Detect Parts**: Upload/capture images to identify parts
205
- """)
206
 
207
- with gr.Tab("πŸ” Detect Part"):
208
- with gr.Row():
209
- with gr.Column():
210
- detect_input = gr.Image(sources=["upload", "webcam"], type="numpy")
211
- threshold_slider = gr.Slider(0.5, 0.95, value=0.7, label="Similarity Threshold")
212
- detect_btn = gr.Button("Detect Part", variant="primary")
213
- with gr.Column():
214
- detect_output = gr.Textbox(label="Detection Results", lines=10)
215
- match_label = gr.Label(label="Matched Part")
216
 
217
- detect_btn.click(
218
- fn=detect_part,
219
- inputs=[detect_input, threshold_slider],
220
- outputs=[detect_output, match_label],
221
- api_name="detect"
222
- )
223
 
224
- with gr.Tab("βž• Add Template"):
225
- with gr.Row():
226
- with gr.Column():
227
- template_input = gr.Image(sources=["upload"], type="numpy")
228
- part_name_input = gr.Textbox(label="Part Name (e.g., 'spark_plug', 'piston')")
229
- add_btn = gr.Button("Save Template", variant="primary")
230
- with gr.Column():
231
- add_output = gr.Textbox(label="Status")
232
 
233
- add_btn.click(
234
- fn=add_template,
235
- inputs=[template_input, part_name_input],
236
- outputs=add_output,
237
- api_name="add_template"
238
- )
239
 
240
- with gr.Tab("πŸ“‹ View Templates"):
241
- template_list = gr.Textbox(label="Saved Templates", lines=10)
242
- refresh_btn = gr.Button("Refresh List")
243
- refresh_btn.click(
244
- fn=list_templates,
245
- outputs=template_list,
246
- api_name="list_templates"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  )
248
- demo.load(fn=list_templates, outputs=template_list)
249
 
250
- if __name__ == "__main__":
251
- demo.launch()
252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
- # # app.py
255
- # # ─────────────────────────────────────────────────────────────────────────────
256
- # # Stage 0 : Hough-Circle bolt-hole detection β†’ mid-arc brightness check
257
- # # Two bolt holes (top + bottom) per bridge are detected, connected
258
- # # by a line, and the metallic mid-arc at the midpoint is verified.
259
- # # Stage 1 : Full-image LAB color metric β†’ dark metallic grey β†’ FAIL
260
- # # Stage 2 : ResNet-50 + CLAHE β†’ cosine similarity vs golden template
261
- # # ─────────────────────────────────────────────────────────────────────────────
262
 
263
- # import gradio as gr
264
- # import cv2
265
- # import torch
266
- # import torch.nn as nn
267
- # from torchvision import models, transforms
268
- # from PIL import Image
269
- # import numpy as np
270
- # import pickle
271
- # import os
272
 
 
 
 
273
 
274
- # # ─────────────────────────────────────────────────────────────────────────────
275
- # # EnginePartDetector
276
- # # ─────────────────────────────────────────────────────────────────────────────
277
 
278
- # class EnginePartDetector:
279
- # def __init__(
280
- # self,
281
- # clahe_clip_limit: float = 9.9,
282
- # clahe_tile_grid: tuple = (8, 8),
283
- # ):
284
- # # ── ResNet-50 backbone (feature extractor only) ───────────────────
285
- # self.model = models.resnet50(weights='IMAGENET1K_V1')
286
- # self.model = nn.Sequential(*list(self.model.children())[:-1])
287
- # self.model.eval()
288
 
289
- # # ── CLAHE applied on LAB L-channel only β€” no hue/saturation shift ─
290
- # self.clahe = cv2.createCLAHE(
291
- # clipLimit=clahe_clip_limit,
292
- # tileGridSize=clahe_tile_grid,
293
- # )
 
 
 
294
 
295
- # # ── ResNet ImageNet normalisation ─────────────────────────────────
296
- # self.transform = transforms.Compose([
297
- # transforms.Resize((224, 224)),
298
- # transforms.ToTensor(),
299
- # transforms.Normalize(
300
- # mean=[0.485, 0.456, 0.406],
301
- # std=[0.229, 0.224, 0.225],
302
- # )
303
- # ])
304
 
305
- # self.templates = {}
306
- # self.load_templates()
 
307
 
308
- # # ── CLAHE helper ──────────────────────────────────────────────────────
 
 
 
309
 
310
- # def apply_clahe(self, image: np.ndarray) -> np.ndarray:
311
- # """CLAHE on LAB L-channel β€” shadow recovery, no colour shift."""
312
- # bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
313
- # lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)
314
- # l, a, b = cv2.split(lab)
315
- # l_enh = self.clahe.apply(l)
316
- # lab_enh = cv2.merge([l_enh, a, b])
317
- # bgr_enh = cv2.cvtColor(lab_enh, cv2.COLOR_LAB2BGR)
318
- # return cv2.cvtColor(bgr_enh, cv2.COLOR_BGR2RGB)
319
-
320
- # # =========================================================================
321
- # # STAGE 0 β€” Hough-Circle bolt-hole detection + mid-arc check
322
- # # =========================================================================
323
- # #
324
- # # Top-down engine block view (simplified):
325
- # #
326
- # # β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
327
- # # β”‚ [saddle1] β”‚ bridge1 β”‚ [saddle2] β”‚ bridge2 β”‚ ... β”‚
328
- # # β”‚ ●(top) β”‚ ●(top) β”‚ β”‚
329
- # # β”‚ β”‚ ← mid-arc scan box β†’ β”‚ β”‚ β”‚
330
- # # β”‚ ●(bot) β”‚ ●(bot) β”‚ β”‚
331
- # # β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
332
- # #
333
- # # β€’ Bolt holes = small dark circles (●) detected by HoughCircles
334
- # # β€’ Each bridge has one TOP hole (y β‰ˆ 200) and one BOTTOM hole (y β‰ˆ 490)
335
- # # β€’ The midpoint of the line joining them sits on the metallic mid-arc
336
- # # β€’ A horizontal scan region at the midpoint should be bright (β‰₯ 100)
337
- # # β€’ If it is dark β†’ arc absent β†’ FAIL immediately
338
- # #
339
- # # =========================================================================
340
-
341
- # def _to_uint8_rgb(self, image) -> np.ndarray:
342
- # """Normalise any input to uint8 RGB numpy array."""
343
- # if isinstance(image, Image.Image):
344
- # return np.array(image.convert("RGB"))
345
- # if image.dtype != np.uint8:
346
- # return np.clip(image, 0, 255).astype(np.uint8)
347
- # return image
348
 
349
- # def _detect_dark_circles(self, gray: np.ndarray):
350
-
351
- # h, w = gray.shape
352
-
353
- # # Bilateral filter preserves the hole edge while reducing metallic noise
354
- # denoised = cv2.bilateralFilter(gray, 9, 75, 75)
355
- # blurred = cv2.GaussianBlur(denoised, (9, 9), 2)
356
-
357
- # # Scale radii to image width (calibrated on 1280 px reference images)
358
- # scale = w / 1280.0
359
- # min_r = max(6, int(8 * scale))
360
- # max_r = max(15, int(30 * scale))
361
- # min_dist = max(15, int(20 * scale))
362
-
363
- # circles = cv2.HoughCircles(
364
- # blurred,
365
- # cv2.HOUGH_GRADIENT,
366
- # dp = 1,
367
- # minDist = min_dist,
368
- # param1 = 60, # Canny high threshold
369
- # param2 = 28, # Accumulator threshold (lower = more circles)
370
- # minRadius = min_r,
371
- # maxRadius = max_r,
372
- # )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
 
374
- # if circles is None:
375
- # return []
376
 
377
- # circles = np.round(circles[0]).astype(int)
 
 
 
 
 
378
 
379
- # dark = []
380
- # for (x, y, r) in circles:
381
- # # Sample pixel brightness at the circle centre
382
- # cy = min(y, h - 1)
383
- # cx = min(x, w - 1)
384
- # y1 = max(0, cy - r // 2); y2 = min(h, cy + r // 2)
385
- # x1 = max(0, cx - r // 2); x2 = min(w, cx + r // 2)
386
- # patch = gray[y1:y2, x1:x2]
387
- # if patch.size > 0 and patch.mean() < 80: # dark = bolt hole
388
- # dark.append((int(x), int(y), int(r)))
389
 
390
- # dark.sort(key=lambda c: c[0]) # left β†’ right
391
- # return dark
 
 
 
 
 
 
 
392
 
393
- # def _pair_into_bridges(self, circles, img_h: int):
394
-
395
- # used = [False] * len(circles)
396
- # clusters = []
397
-
398
- # for i, (xi, yi, ri) in enumerate(circles):
399
- # if used[i]:
400
- # continue
401
- # group = [(xi, yi, ri)]
402
- # used[i] = True
403
- # for j, (xj, yj, rj) in enumerate(circles):
404
- # if not used[j] and abs(xi - xj) <= 50:
405
- # group.append((xj, yj, rj))
406
- # used[j] = True
407
- # clusters.append(group)
408
-
409
- # pairs = []
410
- # min_sep = 0.18 * img_h
411
-
412
- # for group in clusters:
413
- # if len(group) < 2:
414
- # continue
415
- # group.sort(key=lambda c: c[1]) # sort by Y
416
- # top = group[0]
417
- # bot = group[-1]
418
- # if (bot[1] - top[1]) >= min_sep:
419
- # pairs.append((top, bot))
420
-
421
- # return pairs # each entry: ( (tx,ty,tr), (bx,by,br) )
422
-
423
- # def _check_mid_arc(
424
- # self,
425
- # gray: np.ndarray,
426
- # top_hole: tuple,
427
- # bot_hole: tuple,
428
- # scan_half_w: int = 50,
429
- # scan_half_h: int = 25,
430
- # bright_thresh: int = 100,
431
- # bright_ratio_min: float = 0.50,
432
- # ):
433
-
434
- # tx, ty, _ = top_hole
435
- # bx, by, _ = bot_hole
436
- # mid_x = int((tx + bx) / 2)
437
- # mid_y = int((ty + by) / 2)
438
 
439
- # h, w = gray.shape
440
- # x1 = max(0, mid_x - scan_half_w); x2 = min(w, mid_x + scan_half_w)
441
- # y1 = max(0, mid_y - scan_half_h); y2 = min(h, mid_y + scan_half_h)
442
 
443
- # region = gray[y1:y2, x1:x2]
444
- # if region.size == 0:
445
- # return False, dict(mean=0, bright_ratio=0, dark_ratio=0,
446
- # mid=(mid_x, mid_y), roi=(x1, y1, x2, y2))
 
 
 
447
 
448
- # mean_b = float(region.mean())
449
- # bright_ratio = float((region >= bright_thresh).sum()) / region.size
450
- # dark_ratio = float((region < 50).sum()) / region.size
451
 
452
- # is_present = (mean_b >= bright_thresh) and (bright_ratio >= bright_ratio_min)
 
 
 
 
 
 
 
453
 
454
- # return is_present, dict(
455
- # mean=mean_b, bright_ratio=bright_ratio, dark_ratio=dark_ratio,
456
- # mid=(mid_x, mid_y), roi=(x1, y1, x2, y2),
457
- # )
458
 
459
- # def stage0_mid_arc(self, image: np.ndarray):
460
-
461
- # image = self._to_uint8_rgb(image)
462
- # h, w = image.shape[:2]
463
- # gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
464
-
465
- # # Enhance contrast before circle detection
466
- # cg = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
467
- # gray_enh = cg.apply(gray)
468
-
469
- # # ── Step 1: detect dark bolt holes ────────────────────────────────
470
- # circles = self._detect_dark_circles(gray_enh)
471
- # debug = image.copy()
472
-
473
- # if len(circles) < 2:
474
- # reason = (f"Only {len(circles)} bolt hole(s) found "
475
- # f"(need β‰₯ 2). Stage 0 skipped.")
476
- # return "SKIP", reason, debug, []
477
-
478
- # # Draw yellow circles for all detected bolt holes
479
- # for (x, y, r) in circles:
480
- # cv2.circle(debug, (x, y), r + 4, (255, 220, 0), 2)
481
- # cv2.circle(debug, (x, y), 3, (255, 220, 0), -1)
482
-
483
- # # ── Step 2: pair holes into bridges ───────────────────────────────
484
- # pairs = self._pair_into_bridges(circles, h)
485
-
486
- # if not pairs:
487
- # reason = (f"{len(circles)} holes detected but no bridge pairs formed. "
488
- # f"Stage 0 skipped.")
489
- # return "SKIP", reason, debug, []
490
-
491
- # # ── Step 3: check mid-arc for each bridge ─────────────────────────
492
- # bridge_results = []
493
- # absent = []
494
-
495
- # for i, (top, bot) in enumerate(pairs, 1):
496
- # present, stats = self._check_mid_arc(gray_enh, top, bot)
497
-
498
- # bridge_results.append(dict(
499
- # bridge=i, top=top, bot=bot,
500
- # present=present, stats=stats,
501
- # ))
502
-
503
- # mx, my = stats['mid']
504
- # x1,y1,x2,y2 = stats['roi']
505
- # col = (0, 230, 0) if present else (0, 0, 255)
506
-
507
- # # Bridge axis line (cyan)
508
- # cv2.line(debug, (top[0], top[1]), (bot[0], bot[1]), (0, 200, 255), 2)
509
- # # Scan-box (green = OK, red = missing arc)
510
- # cv2.rectangle(debug, (x1, y1), (x2, y2), col, 2)
511
- # # Mid-point dot
512
- # cv2.circle(debug, (mx, my), 5, col, -1)
513
- # # Label
514
- # label = f"B{i} {'OK' if present else 'MISS'} {stats['mean']:.0f}"
515
- # cv2.putText(debug, label, (mx - 30, my - 32),
516
- # cv2.FONT_HERSHEY_SIMPLEX, 0.55, col, 2)
517
-
518
- # if not present:
519
- # absent.append(
520
- # f"Bridge {i} "
521
- # f"(brightness={stats['mean']:.0f}, "
522
- # f"bright_ratio={stats['bright_ratio']:.0%})"
523
- # )
524
-
525
- # # ── Step 4: decision ──────────────────────────────────────────────
526
- # if absent:
527
- # status = "FAIL"
528
- # reason = f"Mid-arc ABSENT on: {', '.join(absent)}"
529
- # else:
530
- # n = len(pairs)
531
- # status = "PASS"
532
- # reason = f"Mid-arc PRESENT on all {n} bridge(s) checked."
533
 
534
- # return status, reason, debug, bridge_results
535
 
536
- # # =========================================================================
537
- # # STAGE 1 β€” Full-image LAB color metric
538
- # # =========================================================================
 
 
 
539
 
540
- # def stage1_color(self, image: np.ndarray):
541
-
542
- # image = self._to_uint8_rgb(image)
543
- # image_eh = self.apply_clahe(image)
544
 
545
- # bgr = cv2.cvtColor(image_eh, cv2.COLOR_RGB2BGR)
546
- # lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)
547
- # l, a, b = cv2.split(lab)
548
-
549
- # # Mask: keep only bright metallic areas (saddles > engine body)
550
- # _, mask = cv2.threshold(l, 80, 255, cv2.THRESH_BINARY)
551
- # k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
552
- # mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, k)
553
- # mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, k)
554
- # if np.sum(mask) < 0.05 * mask.size: # fallback: use whole image
555
- # mask = np.ones_like(l) * 255
556
-
557
- # total = np.sum(mask > 0)
558
- # l_mean = cv2.mean(l, mask=mask)[0]
559
- # a_mean = cv2.mean(a, mask=mask)[0]
560
- # b_mean = cv2.mean(b, mask=mask)[0]
561
- # sat = np.sqrt((a_mean - 128) ** 2 + (b_mean - 128) ** 2)
562
- # dratio = np.sum((l < 105) & (mask > 0)) / total if total else 0
563
- # hist = cv2.calcHist([l], [0], mask, [256], [0, 256]).flatten()
564
- # hist /= hist.sum()
565
- # dmass = float(hist[:105].sum())
566
-
567
- # fails = []
568
- # if l_mean < 105: fails.append(f"L-mean={l_mean:.1f} (<105)")
569
- # if dratio > 0.15: fails.append(f"dark-pixel ratio={dratio:.1%} (>15%)")
570
- # if dmass > 0.20: fails.append(f"dark histogram mass={dmass:.1%} (>20%)")
571
-
572
- # if fails:
573
- # return "FAIL", "Dark metallic grey: " + "; ".join(fails)
574
-
575
- # return "PASS", (
576
- # f"Light grey surface L={l_mean:.1f} "
577
- # f"dark_ratio={dratio:.1%} sat={sat:.1f}"
578
- # )
579
 
580
- # # =========================================================================
581
- # # STAGE 2 β€” ResNet-50 + CLAHE feature extraction & template matching
582
- # # =========================================================================
583
 
584
- # def extract_features(self, image) -> np.ndarray:
585
- # if isinstance(image, Image.Image):
586
- # image = np.array(image.convert("RGB"))
587
- # if isinstance(image, np.ndarray):
588
- # if image.dtype != np.uint8:
589
- # image = np.clip(image, 0, 255).astype(np.uint8)
590
- # if image.ndim == 2:
591
- # image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
 
 
592
 
593
- # image = self.apply_clahe(image)
594
- # image = cv2.GaussianBlur(image, (3, 3), 0)
 
595
 
596
- # tensor = self.transform(Image.fromarray(image)).unsqueeze(0)
597
- # with torch.no_grad():
598
- # features = self.model(tensor).squeeze().numpy()
599
- # return features
600
 
601
- # def cosine_similarity(self, f1, f2):
602
- # return float(
603
- # np.dot(f1, f2) / (np.linalg.norm(f1) * np.linalg.norm(f2) + 1e-8)
604
- # )
 
 
 
605
 
606
- # # ── Template management ───────────────────────────────────────────────
 
 
 
607
 
608
- # def save_template(self, image, part_name: str) -> str:
609
- # if image is None or not part_name:
610
- # return "Please provide both image and part name"
611
- # self.templates[part_name] = self.extract_features(image)
612
- # with open('templates.pkl', 'wb') as f:
613
- # pickle.dump(self.templates, f)
614
- # return f"βœ… Template '{part_name}' saved successfully!"
615
 
616
- # def load_templates(self):
617
- # if os.path.exists('templates.pkl'):
618
- # try:
619
- # with open('templates.pkl', 'rb') as f:
620
- # self.templates = pickle.load(f)
621
- # except Exception:
622
- # self.templates = {}
623
 
624
- # # ── 3-stage detection pipeline ────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
625
 
626
- # def match_part(self, image, threshold: float = 0.7):
627
-
628
- # if image is None:
629
- # return "Please provide an image", None
630
- # if not self.templates:
631
- # return "⚠️ No templates available. Add templates first.", None
632
-
633
- # lines = []
634
-
635
- # # ── Stage 0: mid-arc ──────────────────────────────────────────────
636
- # s0, r0, _, br0 = self.stage0_mid_arc(image)
637
-
638
- # if s0 == "FAIL":
639
- # lines += [
640
- # "# πŸ”΄ DEFECTED PIECE\n",
641
- # "**Stage 0 β€” Mid-Arc Check** ❌ FAILED",
642
- # f"> {r0}", "",
643
- # "**Bridge details:**",
644
- # ]
645
- # for br in br0:
646
- # s = br['stats']
647
- # ico = "βœ“" if br['present'] else "βœ—"
648
- # lines.append(
649
- # f" {ico} Bridge {br['bridge']}: "
650
- # f"brightness={s['mean']:.0f} "
651
- # f"bright_ratio={s['bright_ratio']:.0%}"
652
- # )
653
- # lines.append("\n_Stage 1 & 2 skipped._")
654
- # return "\n".join(lines), None
655
-
656
- # # Log Stage-0 result (PASS or SKIP)
657
- # if s0 == "SKIP":
658
- # lines.append(f"**Stage 0 β€” Mid-Arc Check** ⚠️ SKIPPED _{r0}_")
659
- # else:
660
- # lines.append("**Stage 0 β€” Mid-Arc Check** βœ… PASSED")
661
- # lines.append(f"> {r0}")
662
- # for br in br0:
663
- # s = br['stats']
664
- # lines.append(
665
- # f" βœ“ Bridge {br['bridge']}: "
666
- # f"brightness={s['mean']:.0f} "
667
- # f"bright_ratio={s['bright_ratio']:.0%}"
668
- # )
669
- # lines.append("")
670
-
671
- # # ── Stage 1: color metric ─────────────────────────────────────────
672
- # s1, r1 = self.stage1_color(image)
673
-
674
- # if s1 == "FAIL":
675
- # lines += [
676
- # "# πŸ”΄ DEFECTED PIECE\n",
677
- # "**Stage 1 β€” Color-Metric Check** ❌ FAILED",
678
- # f"> {r1}",
679
- # "\n_Stage 2 skipped._",
680
- # ]
681
- # return "\n".join(lines), None
682
-
683
- # lines.append("**Stage 1 β€” Color-Metric Check** βœ… PASSED")
684
- # lines.append(f"> {r1}")
685
- # lines.append("")
686
-
687
- # # ── Stage 2: ResNet template match ────────────────────────────────
688
- # qf = self.extract_features(image)
689
- # sims = sorted(
690
- # [(n, self.cosine_similarity(qf, f)) for n, f in self.templates.items()],
691
- # key=lambda x: x[1], reverse=True,
692
- # )
693
- # best_name, best_score = sims[0]
694
-
695
- # lines.append("**Stage 2 β€” Template Matching**")
696
- # lines.append(f" Best match: `{best_name}` β†’ **{best_score:.2%}**")
697
- # lines.append(f" Threshold: {threshold:.0%}")
698
- # lines.append("")
699
- # lines.append("**All Similarities:**")
700
- # for name, score in sims:
701
- # bar = "β–ˆ" * int(score * 20) + "β–‘" * (20 - int(score * 20))
702
- # lines.append(f" `{name}` {bar} {score:.2%}")
703
-
704
- # if best_score >= threshold:
705
- # lines.insert(0, "# βœ… PERFECT PIECE\n")
706
- # lines.append(f"\nβœ… **Final Decision**: MATCHED β€” `{best_name}`")
707
- # return "\n".join(lines), best_name
708
- # else:
709
- # lines.insert(0, "# 🟑 DEFECTED PIECE\n")
710
- # lines.append(
711
- # f"\n❌ **Final Decision**: NO MATCH "
712
- # f"(best {best_score:.2%} < threshold {threshold:.0%})"
713
- # )
714
- # return "\n".join(lines), None
715
 
 
 
 
 
 
716
 
717
- # # ─────────────────────────────────────────────────────────────────────────────
718
- # # Live-camera edge preview (streaming, lightweight)
719
- # # ─────────────────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
720
 
721
- # def live_edge_preview(frame: np.ndarray) -> np.ndarray:
722
- # if frame is None:
723
- # return frame
724
- # if frame.dtype != np.uint8:
725
- # frame = np.clip(frame, 0, 255).astype(np.uint8)
726
- # gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
727
- # blurred = cv2.GaussianBlur(gray, (5, 5), 0)
728
- # clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
729
- # enh = clahe.apply(blurred)
730
- # edges = cv2.Canny(enh, 30, 120)
731
- # dim = (frame * 0.35).astype(np.uint8)
732
- # dim[edges > 0] = [0, 220, 220]
733
- # return dim
734
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
735
 
736
- # # ─────────────────────────────────────────────────────────────────────────────
737
- # # Gradio callbacks
738
- # # ─────────────────────────────────────────────────────────────────────────────
739
 
740
- # detector = EnginePartDetector()
741
 
 
742
 
743
- # def detect_part(image, threshold):
744
- # return detector.match_part(image, threshold)
745
 
 
 
 
746
 
747
- # def debug_arc(image):
748
- # """Return annotated debug image + per-bridge text report for Stage 0."""
749
- # if image is None:
750
- # return None, "Please upload an image."
751
- # if isinstance(image, np.ndarray) and image.dtype != np.uint8:
752
- # image = np.clip(image, 0, 255).astype(np.uint8)
753
-
754
- # status, reason, debug_img, bridges = detector.stage0_mid_arc(image)
755
-
756
- # lines = [
757
- # f"**Stage 0 Status**: `{status}`",
758
- # f"**Reason**: {reason}",
759
- # "",
760
- # "**Legend:** 🟑 yellow = bolt hole | blue line = bridge axis",
761
- # "🟒 green box = arc PRESENT | πŸ”΄ red box = arc ABSENT",
762
- # "",
763
- # "**Bridge Results:**",
764
- # ]
765
- # for br in bridges:
766
- # s = br['stats']
767
- # ico = "βœ“ PRESENT" if br['present'] else "βœ— ABSENT"
768
- # lines.append(
769
- # f" Bridge {br['bridge']}: {ico} | "
770
- # f"brightness={s['mean']:.0f} | "
771
- # f"bright_ratio={s['bright_ratio']:.0%} | "
772
- # f"dark_ratio={s['dark_ratio']:.0%}"
773
- # )
774
 
775
- # return debug_img, "\n".join(lines)
 
 
776
 
 
 
777
 
778
- # def add_template(image, part_name):
779
- # return detector.save_template(image, part_name)
780
 
 
 
 
 
 
781
 
782
- # def list_templates():
783
- # if not detector.templates:
784
- # return "No templates saved yet"
785
- # return "\n".join(f"- {n}" for n in detector.templates)
786
-
787
-
788
- # def cb_capture(frame):
789
- # return frame
790
-
791
-
792
- # # ─────────────────────────────────────────────────────────────────────────────
793
- # # CSS β€” industrial dark theme (kept from reference)
794
- # # ─────────────────────────────────────────────────────────────────────────────
795
-
796
- # CSS = """
797
- # @import url('https://fonts.googleapis.com/css2?family=Share+Tech+Mono&family=Barlow:wght@300;400;600;700&display=swap');
798
- # :root {
799
- # --bg: #0d0f12; --surface: #151820; --card: #1c2030;
800
- # --border: #2a3045; --accent: #00e5ff; --accent2: #ff6b35;
801
- # --text: #d0d8e8; --muted: #5a6480;
802
- # --mono: 'Share Tech Mono', monospace;
803
- # --sans: 'Barlow', sans-serif;
804
- # }
805
- # body, .gradio-container { background: var(--bg) !important; font-family: var(--sans) !important; color: var(--text) !important; }
806
- # .tabs > .tab-nav { background: var(--surface) !important; border-bottom: 2px solid var(--border) !important; }
807
- # .tabs > .tab-nav button { font-family: var(--mono) !important; font-size: 0.75rem !important; letter-spacing: 0.1em !important; text-transform: uppercase !important; color: var(--muted) !important; background: transparent !important; border: none !important; padding: 10px 18px !important; }
808
- # .tabs > .tab-nav button.selected, .tabs > .tab-nav button:hover { color: var(--accent) !important; border-bottom: 2px solid var(--accent) !important; }
809
- # button.primary { background: var(--accent) !important; color: #000 !important; font-family: var(--mono) !important; font-weight: 700 !important; letter-spacing: 0.08em !important; text-transform: uppercase !important; border: none !important; border-radius: 4px !important; }
810
- # button.secondary { background: transparent !important; color: var(--accent2) !important; border: 1px solid var(--accent2) !important; font-family: var(--mono) !important; border-radius: 4px !important; }
811
- # .gr-markdown { font-family: var(--mono) !important; font-size: 0.8rem !important; color: var(--text) !important; background: var(--surface) !important; border: 1px solid var(--border) !important; border-radius: 4px !important; padding: 14px !important; line-height: 1.75 !important; }
812
- # input[type="range"] { accent-color: var(--accent) !important; }
813
- # .live-tag { font-family: var(--mono); font-size: 0.68rem; color: var(--accent); letter-spacing: 0.15em; text-transform: uppercase; margin-bottom: 4px; }
814
- # """
815
-
816
- # # ─────────────────────────────────────────────────────────────────────────────
817
- # # Gradio UI
818
- # # ─────────────────────────────────────────────────────────────────────────────
819
-
820
- # with gr.Blocks(title="ENGINE PART DETECTION") as demo:
821
-
822
- # gr.HTML("""
823
- # <div style="padding:18px 0 6px;border-bottom:1px solid #2a3045;margin-bottom:18px;">
824
- # <span style="font-family:'Share Tech Mono',monospace;font-size:1.5rem;color:#00e5ff;letter-spacing:0.1em;">
825
- # βš™ ENGINE SADDLE DETECTION SYSTEM
826
- # </span><br>
827
- # <span style="font-family:'Share Tech Mono',monospace;font-size:0.72rem;color:#5a6480;letter-spacing:0.06em;">
828
- # STAGE 0: HOUGH MID-ARC Β· STAGE 1: LAB COLOR METRIC Β· STAGE 2: RESNET TEMPLATE MATCH
829
- # </span>
830
- # </div>
831
- # """)
832
 
833
- # with gr.Tabs():
834
-
835
- # # ── Tab 1: Detect ─────────────────────────────────────────────────
836
- # with gr.TabItem("DETECT"):
837
- # gr.HTML("<p class='live-tag' style='margin-bottom:8px;'>Upload or capture Β· 3-stage detection pipeline</p>")
838
- # with gr.Row():
839
- # with gr.Column(scale=1):
840
- # detect_input = gr.Image(
841
- # sources=["upload", "webcam"],
842
- # type="numpy", label="Part Image", height=320,
843
- # )
844
- # threshold_slider = gr.Slider(
845
- # 0.50, 0.95, value=0.70, step=0.01,
846
- # label="Similarity Threshold (Stage 2)",
847
- # )
848
- # detect_btn = gr.Button("β–Ά RUN 3-STAGE DETECTION", variant="primary")
849
-
850
- # with gr.Column(scale=1):
851
- # detect_output = gr.Markdown(label="Detection Report")
852
- # match_label = gr.Label(label="Matched Part", num_top_classes=3)
853
-
854
- # detect_btn.click(
855
- # fn=detect_part,
856
- # inputs=[detect_input, threshold_slider],
857
- # outputs=[detect_output, match_label],
858
- # api_name="detect_part",
859
- # )
860
 
861
- # # ── Tab 2: Debug Arc ──────────────────────────────────────────────
862
- # with gr.TabItem("πŸ”¬ DEBUG ARC"):
863
- # gr.HTML("""
864
- # <p class='live-tag'>Visualise bolt-hole detection and mid-arc scan regions</p>
865
- # <p style='font-family:Share Tech Mono,monospace;font-size:0.72rem;color:#5a6480;'>
866
- # 🟑 Yellow circles = bolt holes &nbsp;|&nbsp;
867
- # Blue line = bridge axis &nbsp;|&nbsp;
868
- # 🟒 Green box = arc PRESENT &nbsp;|&nbsp;
869
- # πŸ”΄ Red box = arc ABSENT
870
- # </p>
871
- # """)
872
- # with gr.Row():
873
- # with gr.Column(scale=1):
874
- # debug_input = gr.Image(
875
- # sources=["upload", "webcam"],
876
- # type="numpy", label="Input Image", height=320,
877
- # )
878
- # debug_btn = gr.Button("πŸ”¬ RUN ARC DEBUG", variant="primary")
879
-
880
- # with gr.Column(scale=1):
881
- # debug_out_img = gr.Image(label="Annotated Debug View", height=320)
882
- # debug_out_txt = gr.Markdown(label="Bridge Results")
883
-
884
- # debug_btn.click(
885
- # fn=debug_arc,
886
- # inputs=[debug_input],
887
- # outputs=[debug_out_img, debug_out_txt],
888
- # )
889
 
890
- # # ── Tab 3: Live Camera ────────────────────────────────────────────
891
- # with gr.TabItem("πŸ“· LIVE CAMERA"):
892
- # gr.HTML("<p class='live-tag'>● Live feed Β· Capture frame Β· Run 3-stage detection</p>")
893
- # with gr.Row():
894
- # with gr.Column(scale=1):
895
- # gr.HTML("<div class='live-tag'>● LIVE FEED</div>")
896
- # live_input = gr.Image(sources=["webcam"], streaming=True,
897
- # type="numpy", height=280)
898
- # gr.HTML("<div class='live-tag' style='margin-top:8px;'>CLAHE EDGE OVERLAY</div>")
899
- # live_output = gr.Image(label="Edge Preview", height=280)
900
-
901
- # with gr.Column(scale=1):
902
- # gr.HTML("<div class='live-tag'>CAPTURED FRAME</div>")
903
- # captured_frame = gr.Image(type="numpy", height=280, interactive=False)
904
- # with gr.Row():
905
- # capture_btn = gr.Button("πŸ“Έ CAPTURE FRAME", variant="primary")
906
- # detect_live_btn = gr.Button("β–Ά DETECT CAPTURED", variant="secondary")
907
- # live_threshold = gr.Slider(0.50, 0.95, value=0.70, step=0.01,
908
- # label="Similarity Threshold")
909
- # live_result = gr.Markdown(label="Detection Result")
910
- # live_label = gr.Label(label="Matched Part", num_top_classes=3)
911
-
912
- # live_input.stream(fn=live_edge_preview,
913
- # inputs=[live_input], outputs=[live_output])
914
- # capture_btn.click(fn=cb_capture,
915
- # inputs=[live_input], outputs=[captured_frame])
916
- # detect_live_btn.click(
917
- # fn=detect_part,
918
- # inputs=[captured_frame, live_threshold],
919
- # outputs=[live_result, live_label],
920
- # )
921
 
922
- # # ── Tab 4: Add Template ───────────────────────────────────────────
923
- # with gr.TabItem("βž• ADD TEMPLATE"):
924
- # gr.HTML("<p class='live-tag'>Register reference images Β· CLAHE + ResNet-50 feature pipeline</p>")
925
- # with gr.Row():
926
- # with gr.Column(scale=1):
927
- # template_input = gr.Image(sources=["upload", "webcam"],
928
- # type="numpy", label="Reference Image", height=300)
929
- # part_name_input = gr.Textbox(
930
- # label="Part Name",
931
- # placeholder="e.g. bearing_saddle_ok / piston_perfect",
932
- # )
933
- # add_btn = gr.Button("πŸ’Ύ SAVE TEMPLATE", variant="primary")
934
-
935
- # with gr.Column(scale=1):
936
- # add_output = gr.Markdown(label="Status")
937
- # gr.HTML("""
938
- # <div style="background:#1c2030;border:1px solid #2a3045;border-radius:4px;
939
- # padding:14px;font-family:'Share Tech Mono',monospace;
940
- # font-size:0.72rem;color:#5a6480;line-height:1.9;">
941
- # Feature extraction pipeline:<br>
942
- # &nbsp; 1 Β· CLAHE on LAB L-channel (shadow recovery)<br>
943
- # &nbsp; 2 Β· Gaussian blur 3Γ—3 (metallic sheen suppression)<br>
944
- # &nbsp; 3 Β· ResNet-50 backbone β†’ 2048-D feature vector<br>
945
- # &nbsp; 4 Β· Persisted to templates.pkl
946
- # </div>
947
- # """)
948
-
949
- # add_btn.click(
950
- # fn=add_template,
951
- # inputs=[template_input, part_name_input],
952
- # outputs=[add_output],
953
- # api_name="add_template",
954
- # )
955
 
956
- # # ── Tab 5: Manage Templates ───────────────────────────────────────
957
- # with gr.TabItem("πŸ“‹ TEMPLATES"):
958
- # gr.HTML("<p class='live-tag'>View and remove saved templates</p>")
959
- # with gr.Row():
960
- # with gr.Column(scale=2):
961
- # template_list_out = gr.Markdown(label="Saved Templates")
962
- # refresh_btn = gr.Button("↻ REFRESH LIST", variant="secondary")
963
-
964
- # with gr.Column(scale=1):
965
- # delete_name = gr.Textbox(label="Part name to delete", placeholder="exact name…")
966
- # delete_btn = gr.Button("πŸ—‘ DELETE TEMPLATE", variant="secondary")
967
- # delete_status = gr.Markdown(label="Status")
968
-
969
- # def cb_delete(name):
970
- # if name in detector.templates:
971
- # del detector.templates[name]
972
- # with open('templates.pkl', 'wb') as f:
973
- # pickle.dump(detector.templates, f)
974
- # msg = f"πŸ—‘οΈ Template '{name}' deleted."
975
- # else:
976
- # msg = f"⚠️ '{name}' not found."
977
- # return msg, list_templates()
978
-
979
- # refresh_btn.click(fn=list_templates, outputs=[template_list_out],
980
- # api_name="list_templates")
981
- # delete_btn.click(fn=cb_delete, inputs=[delete_name],
982
- # outputs=[delete_status, template_list_out])
983
-
984
- # # Hidden machine-readable endpoint (matches reference client)
985
- # raw_list_btn = gr.Button("RAW LIST", visible=False)
986
- # raw_list_btn.click(
987
- # fn=lambda: ",".join(sorted(detector.templates.keys())),
988
- # outputs=[],
989
- # api_name="list_templates_raw",
990
- # )
991
 
992
- # demo.load(fn=list_templates, outputs=[template_list_out])
 
 
 
993
 
 
 
 
994
 
995
- # # ───────────────────────────────────────────────────────────��─────────────────
 
 
 
 
 
996
 
997
- # if __name__ == "__main__":
998
- # demo.launch(
999
- # css=CSS,
1000
- # server_name="0.0.0.0",
1001
- # server_port=7860,
1002
- # ssr_mode=False,
1003
- # )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # # app.py
2
+ # import gradio as gr
3
+ # import torch
4
+ # import torch.nn as nn
5
+ # from torchvision import models, transforms
6
+ # from PIL import Image
7
+ # import numpy as np
8
+ # import pickle
9
+ # import os
10
+ # import cv2
11
+
12
+ # # class EnginePartDetector:
13
+ # # def __init__(self):
14
+ # # self.model = models.resnet50(weights='IMAGENET1K_V1')
15
+ # # self.model = nn.Sequential(*list(self.model.children())[:-1])
16
+ # # self.model.eval()
17
+
18
+ # # self.transform = transforms.Compose([
19
+ # # transforms.Resize((224, 224)),
20
+ # # transforms.ToTensor(),
21
+ # # transforms.Normalize(
22
+ # # mean=[0.485, 0.456, 0.406],
23
+ # # std=[0.229, 0.224, 0.225]
24
+ # # )
25
+ # # ])
26
+
27
+ # # self.templates = {}
28
+ # # self.load_templates()
29
+
30
+ # # def extract_features(self, image):
31
+ # # if isinstance(image, np.ndarray):
32
+ # # image = Image.fromarray(image)
33
+
34
+ # # img_tensor = self.transform(image).unsqueeze(0)
35
+
36
+ # # with torch.no_grad():
37
+ # # features = self.model(img_tensor)
38
+ # # features = features.squeeze().numpy()
39
+
40
+ # # return features
41
 
42
  # class EnginePartDetector:
43
+ # def __init__(
44
+ # self,
45
+ # clahe_clip_limit: float = 9.9,
46
+ # clahe_tile_grid: tuple = (8, 8),
47
+ # ):
48
+ # # ── ResNet-50 backbone (feature extractor only) ──────────────────
49
  # self.model = models.resnet50(weights='IMAGENET1K_V1')
50
  # self.model = nn.Sequential(*list(self.model.children())[:-1])
51
  # self.model.eval()
52
+
53
+ # # ── CLAHE (OpenCV) β€” applied BEFORE the torch transform ──────────
54
+ # # Operates on grayscale to recover shadow-suppressed edges
55
+ # # (e.g. missing bearing saddle arcs), then merged back to RGB
56
+ # # so the 3-channel ResNet pipeline is unaffected.
57
+ # self.clahe = cv2.createCLAHE(
58
+ # clipLimit=clahe_clip_limit,
59
+ # tileGridSize=clahe_tile_grid,
60
+ # )
61
+
62
+ # # ── ResNet normalisation transform (unchanged) ───────────────────
63
  # self.transform = transforms.Compose([
64
  # transforms.Resize((224, 224)),
65
  # transforms.ToTensor(),
66
  # transforms.Normalize(
67
  # mean=[0.485, 0.456, 0.406],
68
+ # std=[0.229, 0.224, 0.225],
69
  # )
70
  # ])
71
+
72
  # self.templates = {}
73
  # self.load_templates()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
+ # # ── CLAHE preprocessing ───────────────────────────────────────────────
 
76
 
77
+ # def apply_clahe(self, image: np.ndarray) -> np.ndarray:
 
 
78
 
79
+ # # Convert RGB (PIL/numpy) β†’ BGR for OpenCV
80
+ # bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
81
 
82
+ # # BGR β†’ LAB
83
+ # lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)
84
 
85
+ # # Split channels; apply CLAHE only to L (luminance)
86
+ # l_channel, a_channel, b_channel = cv2.split(lab)
87
+ # l_enhanced = self.clahe.apply(l_channel)
88
 
89
+ # # Merge enhanced L back with untouched A and B
90
+ # lab_enhanced = cv2.merge([l_enhanced, a_channel, b_channel])
91
 
92
+ # # LAB β†’ BGR β†’ RGB
93
+ # bgr_enhanced = cv2.cvtColor(lab_enhanced, cv2.COLOR_LAB2BGR)
94
+ # rgb_enhanced = cv2.cvtColor(bgr_enhanced, cv2.COLOR_BGR2RGB)
95
 
96
+ # return rgb_enhanced # uint8 numpy array, same shape as input
97
 
98
+ # # ── Feature extraction ────────────────────────────────────────────────
99
 
100
+ # def extract_features(self, image) -> np.ndarray:
101
 
102
+ # # 1. Normalise input to numpy uint8 RGB
103
+ # if isinstance(image, Image.Image):
104
+ # image = np.array(image.convert("RGB"))
105
+ # elif isinstance(image, np.ndarray) and image.dtype != np.uint8:
106
+ # image = image.astype(np.uint8)
107
 
108
+ # # 2. CLAHE β€” recover shadow-suppressed structural edges
109
+ # image = self.apply_clahe(image)
110
 
111
+ # # 3. Mild Gaussian blur β€” reduces high-freq metallic sheen noise
112
+ # # that CLAHE can amplify; kernel (3,3) is intentionally light
113
+ # # so real surface-defect texture is preserved
114
+ # image = cv2.GaussianBlur(image, (3, 3), 0)
115
 
116
+ # # 4. Convert back to PIL for torchvision transforms
117
+ # image_pil = Image.fromarray(image)
118
 
119
+ # # 5. ResNet transform β†’ tensor
120
+ # img_tensor = self.transform(image_pil).unsqueeze(0)
121
 
122
+ # # 6. Forward pass (no grad needed β€” inference only)
123
+ # with torch.no_grad():
124
+ # features = self.model(img_tensor)
125
+ # features = features.squeeze().numpy()
126
 
127
+ # return features
128
 
129
+ # def cosine_similarity(self, feat1, feat2):
130
+ # return np.dot(feat1, feat2) / (np.linalg.norm(feat1) * np.linalg.norm(feat2))
131
 
132
+ # def save_template(self, image, part_name):
133
+ # if image is None or not part_name:
134
+ # return "Please provide both image and part name"
135
 
136
+ # features = self.extract_features(image)
137
+ # self.templates[part_name] = features
138
 
139
+ # with open('templates.pkl', 'wb') as f:
140
+ # pickle.dump(self.templates, f)
141
 
142
+ # return f"βœ… Template '{part_name}' saved successfully!"
143
 
144
+ # def load_templates(self):
145
+ # if os.path.exists('templates.pkl'):
146
+ # try:
147
+ # with open('templates.pkl', 'rb') as f:
148
+ # self.templates = pickle.load(f)
149
+ # except:
150
+ # self.templates = {}
151
 
152
+ # def match_part(self, image, threshold=0.7):
153
+ # if image is None:
154
+ # return "Please provide an image", None
155
 
156
+ # if not self.templates:
157
+ # return "⚠️ No templates available. Please add templates first.", None
158
 
159
+ # query_features = self.extract_features(image)
160
 
161
+ # results = []
162
+ # for part_name, template_features in self.templates.items():
163
+ # similarity = self.cosine_similarity(query_features, template_features)
164
+ # results.append((part_name, similarity))
165
 
166
+ # results.sort(key=lambda x: x[1], reverse=True)
167
 
168
+ # best_match = results[0]
169
+ # output_text = f"πŸ” **Best Match**: {best_match[0]}\n"
170
+ # output_text += f"πŸ“Š **Confidence**: {best_match[1]:.2%}\n\n"
171
 
172
+ # if best_match[1] >= threshold:
173
+ # output_text += "βœ… **Status**: MATCHED\n\n"
174
+ # else:
175
+ # output_text += "❌ **Status**: NO MATCH (below threshold)\n\n"
176
 
177
+ # output_text += "**All Results:**\n"
178
+ # for part, sim in results:
179
+ # output_text += f"- {part}: {sim:.2%}\n"
180
 
181
+ # matched_label = best_match[0] if best_match[1] >= threshold else None
182
+ # return output_text, matched_label
183
 
184
+ # detector = EnginePartDetector()
185
 
186
+ # def add_template(image, part_name):
187
+ # return detector.save_template(image, part_name)
188
 
189
+ # def detect_part(image, threshold):
190
+ # return detector.match_part(image, threshold)
191
 
192
+ # def list_templates():
193
+ # if not detector.templates:
194
+ # return "No templates saved yet"
195
+ # return "\n".join([f"- {name}" for name in detector.templates.keys()])
196
 
197
+ # with gr.Blocks(title="Engine Part Detection System") as demo:
198
+ # gr.Markdown("""
199
+ # # πŸ”§ Engine Part Detection System
200
+ # ### Using ResNet50 Feature Extraction & Template Matching
201
 
202
+ # **How to use:**
203
+ # 1. **Add Templates**: Upload reference images of engine parts
204
+ # 2. **Detect Parts**: Upload/capture images to identify parts
205
+ # """)
206
 
207
+ # with gr.Tab("πŸ” Detect Part"):
208
+ # with gr.Row():
209
+ # with gr.Column():
210
+ # detect_input = gr.Image(sources=["upload", "webcam"], type="numpy")
211
+ # threshold_slider = gr.Slider(0.5, 0.95, value=0.7, label="Similarity Threshold")
212
+ # detect_btn = gr.Button("Detect Part", variant="primary")
213
+ # with gr.Column():
214
+ # detect_output = gr.Textbox(label="Detection Results", lines=10)
215
+ # match_label = gr.Label(label="Matched Part")
216
 
217
+ # detect_btn.click(
218
+ # fn=detect_part,
219
+ # inputs=[detect_input, threshold_slider],
220
+ # outputs=[detect_output, match_label],
221
+ # api_name="detect"
222
+ # )
223
 
224
+ # with gr.Tab("βž• Add Template"):
225
+ # with gr.Row():
226
+ # with gr.Column():
227
+ # template_input = gr.Image(sources=["upload"], type="numpy")
228
+ # part_name_input = gr.Textbox(label="Part Name (e.g., 'spark_plug', 'piston')")
229
+ # add_btn = gr.Button("Save Template", variant="primary")
230
+ # with gr.Column():
231
+ # add_output = gr.Textbox(label="Status")
232
 
233
+ # add_btn.click(
234
+ # fn=add_template,
235
+ # inputs=[template_input, part_name_input],
236
+ # outputs=add_output,
237
+ # api_name="add_template"
238
+ # )
239
 
240
+ # with gr.Tab("πŸ“‹ View Templates"):
241
+ # template_list = gr.Textbox(label="Saved Templates", lines=10)
242
+ # refresh_btn = gr.Button("Refresh List")
243
+ # refresh_btn.click(
244
+ # fn=list_templates,
245
+ # outputs=template_list,
246
+ # api_name="list_templates"
247
+ # )
248
+ # demo.load(fn=list_templates, outputs=template_list)
249
+
250
+ # if __name__ == "__main__":
251
+ # demo.launch()
252
+
253
+ """
254
+ Production-Grade Engine Part Detection System
255
+ Multi-Layer Architecture:
256
+ Layer 1 β€” Geometric Detection : HoughCircles β†’ bolt-hole centers β†’ line fitting β†’ saddle crop
257
+ Layer 2 β€” Feature Extraction : CLAHE pre-processing β†’ ResNet-50 deep features
258
+ Layer 3 β€” Template Matching : Cosine-similarity with confidence scoring
259
+ Layer 4 β€” UI / Orchestration : Gradio interface tying all layers together
260
+
261
+ Author: Senior CV Engineer
262
+ """
263
+
264
+ # ──────────────────────────────────────────────────────────────────────────────
265
+ # Standard & third-party imports
266
+ # ──────────────────────────────────────────────────────────────────────────────
267
+ import os
268
+ import pickle
269
+ import logging
270
+ from dataclasses import dataclass, field
271
+ from typing import Optional, Tuple, List
272
+
273
+ import cv2
274
+ import gradio as gr
275
+ import numpy as np
276
+ import torch
277
+ import torch.nn as nn
278
+ from PIL import Image
279
+ from scipy import stats
280
+ from torchvision import models, transforms
281
+
282
+ # ──────────────────────────────────────────────────────────────────────────────
283
+ # Logging
284
+ # ──────────────────────────────────────────────────────────────────────────────
285
+ logging.basicConfig(
286
+ level=logging.INFO,
287
+ format="[%(levelname)s] %(name)s β€” %(message)s",
288
+ )
289
+ logger = logging.getLogger("EnginePartDetector")
290
+
291
+
292
+ # ���─────────────────────────────────────────────────────────────────────────────
293
+ # Config dataclass β€” all tunable hyperparameters in one place
294
+ # ──────────────────────────────────────────────────────────────────────────────
295
+ @dataclass
296
+ class DetectorConfig:
297
+ # ── Hough circle params ──────────────────────────────────────────
298
+ hough_dp: float = 1.2 # Inverse resolution ratio
299
+ hough_param1: int = 60 # Canny high threshold
300
+ hough_param2: int = 25 # Accumulator threshold (lower = more detections)
301
+ bolt_min_radius: int = 8 # px β€” smallest bolt-hole radius expected
302
+ bolt_max_radius: int = 40 # px β€” largest bolt-hole radius expected
303
+ min_dist_factor: float = 0.06 # minDist = factor Γ— min(H, W)
304
+
305
+ # ── Row-clustering params ────────────────────────────────────────
306
+ # Y-axis gap between two circle rows must exceed this fraction of image height
307
+ row_separation_min: float = 0.05
308
+
309
+ # ── Crop padding ─────────────────────────────────────────────────
310
+ crop_padding_px: int = 12 # Extra pixels above top-line & below bottom-line
311
+
312
+ # ── CLAHE params ─────────────────────────────────────────────────
313
+ clahe_clip_limit: float = 9.9
314
+ clahe_tile_grid: Tuple[int, int] = field(default_factory=lambda: (8, 8))
315
+
316
+ # ── Template store ───────────────────────────────────────────────
317
+ templates_path: str = "templates.pkl"
318
+
319
+ # ── Detection threshold ──────────────────────────────────────────
320
+ default_threshold: float = 0.70
321
+
322
+
323
+ # ──────────────────────────────────────────────────────────────────────────────
324
+ # LAYER 1 β€” SaddleRegionExtractor
325
+ # ──────────────────────────────────────────────────────────────────────────────
326
+ class SaddleRegionExtractor:
327
+ """
328
+ Geometric detection layer.
329
+
330
+ Pipeline
331
+ --------
332
+ 1. Pre-process (CLAHE + bilateral filter) for robust edge contrast
333
+ 2. HoughCircles to find all circles in the image
334
+ 3. Filter to bolt-hole size range
335
+ 4. K-means (k=2) on Y-coordinate β†’ top row / bottom row
336
+ 5. Robust linear regression (via scipy) through each row's centers
337
+ 6. Per-column masking to crop exactly between the two fitted lines
338
+ """
339
+
340
+ def __init__(self, cfg: DetectorConfig):
341
+ self.cfg = cfg
342
+ self._clahe = cv2.createCLAHE(
343
+ clipLimit=cfg.clahe_clip_limit,
344
+ tileGridSize=cfg.clahe_tile_grid,
345
  )
 
346
 
347
+ # ── Pre-processing ────────────────────────────────────────────────────────
 
348
 
349
+ def _preprocess(self, rgb: np.ndarray) -> np.ndarray:
350
+ """Enhance contrast and reduce noise for circle detection."""
351
+ bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
352
+ lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)
353
+ l, a, b = cv2.split(lab)
354
+ l = self._clahe.apply(l)
355
+ gray = cv2.cvtColor(cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2BGR), cv2.COLOR_BGR2GRAY)
356
+ # Bilateral filter: reduces metallic sheen noise but preserves hole edges
357
+ gray = cv2.bilateralFilter(gray, d=9, sigmaColor=80, sigmaSpace=80)
358
+ return gray
359
+
360
+ # ── Circle detection ──────────────────────────────────────────────────────
361
+
362
+ def _detect_all_circles(self, rgb: np.ndarray) -> np.ndarray:
363
+ """Run HoughCircles and return Nx3 array [[cx, cy, r], …]."""
364
+ h, w = rgb.shape[:2]
365
+ gray = self._preprocess(rgb)
366
+
367
+ min_dist = max(15, int(min(h, w) * self.cfg.min_dist_factor))
368
+
369
+ raw = cv2.HoughCircles(
370
+ gray,
371
+ cv2.HOUGH_GRADIENT,
372
+ dp=self.cfg.hough_dp,
373
+ minDist=min_dist,
374
+ param1=self.cfg.hough_param1,
375
+ param2=self.cfg.hough_param2,
376
+ minRadius=self.cfg.bolt_min_radius,
377
+ maxRadius=self.cfg.bolt_max_radius,
378
+ )
379
+ if raw is None:
380
+ return np.empty((0, 3), dtype=int)
381
+ return np.round(raw[0]).astype(int)
382
+
383
+ def _filter_bolt_holes(self, circles: np.ndarray, image_shape) -> np.ndarray:
384
+ """
385
+ Keep only circles that are:
386
+ β€’ Radius within configured bolt-hole range
387
+ β€’ Fully inside the image bounds
388
+ β€’ Area < 2 % of total image area (excludes large bearing journals)
389
+ """
390
+ h, w = image_shape[:2]
391
+ img_area = float(h * w)
392
+ kept = []
393
+ for cx, cy, r in circles:
394
+ area_ratio = (np.pi * r * r) / img_area
395
+ in_bounds = (r < cx < w - r) and (r < cy < h - r)
396
+ if area_ratio < 0.02 and in_bounds:
397
+ kept.append([cx, cy, r])
398
+ return np.array(kept, dtype=int) if kept else np.empty((0, 3), dtype=int)
399
+
400
+ # ── Row clustering ────────────────────────────────────────────────────────
401
+
402
+ def _cluster_rows(
403
+ self, circles: np.ndarray, image_height: int
404
+ ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
405
+ """
406
+ Split circles into two rows using a Y-midpoint threshold.
407
+
408
+ Strategy
409
+ --------
410
+ Sort unique Y-values, find the largest gap, split there.
411
+ This is more stable than K-Means for this structured layout because
412
+ bolt-hole rows are clearly separated in Y.
413
+ """
414
+ if len(circles) < 4:
415
+ logger.warning("Too few bolt holes (%d) for row clustering.", len(circles))
416
+ return None, None
417
+
418
+ y_vals = circles[:, 1].astype(float)
419
+
420
+ # Adaptive split: find largest gap in sorted Y values
421
+ sorted_y = np.sort(np.unique(y_vals))
422
+ if len(sorted_y) < 2:
423
+ return None, None
424
+
425
+ gaps = np.diff(sorted_y)
426
+ min_sep = image_height * self.cfg.row_separation_min
427
+
428
+ # Only consider gaps larger than the minimum separation
429
+ valid_gap_idxs = np.where(gaps >= min_sep)[0]
430
+ if len(valid_gap_idxs) == 0:
431
+ logger.warning("Rows too close together; trying mean-split fallback.")
432
+ split_y = float(np.mean(y_vals))
433
+ else:
434
+ best_gap_idx = valid_gap_idxs[np.argmax(gaps[valid_gap_idxs])]
435
+ split_y = (sorted_y[best_gap_idx] + sorted_y[best_gap_idx + 1]) / 2.0
436
 
437
+ top_mask = y_vals < split_y
438
+ bot_mask = ~top_mask
 
 
 
 
 
 
439
 
440
+ top_circles = circles[top_mask]
441
+ bot_circles = circles[bot_mask]
 
 
 
 
 
 
 
442
 
443
+ if len(top_circles) == 0 or len(bot_circles) == 0:
444
+ logger.warning("One row is empty after split.")
445
+ return None, None
446
 
447
+ return top_circles, bot_circles
 
 
448
 
449
+ # ── Line fitting ──────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
450
 
451
+ @staticmethod
452
+ def _fit_line(circles: np.ndarray) -> Tuple[float, float]:
453
+ """
454
+ Fit y = slopeΒ·x + intercept through circle centers.
455
+ Falls back to a horizontal line through the median Y when < 2 points.
456
+ """
457
+ if len(circles) < 2:
458
+ return 0.0, float(np.median(circles[:, 1]))
459
 
460
+ x = circles[:, 0].astype(float)
461
+ y = circles[:, 1].astype(float)
 
 
 
 
 
 
 
462
 
463
+ # scipy linregress for numerically stable OLS
464
+ result = stats.linregress(x, y)
465
+ return float(result.slope), float(result.intercept)
466
 
467
+ @staticmethod
468
+ def _line_endpoints(slope: float, intercept: float, width: int):
469
+ """Return (pt_left, pt_right) spanning the full image width."""
470
+ return (0, int(intercept)), (width - 1, int(slope * (width - 1) + intercept))
471
 
472
+ # ── Saddle crop ───────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
 
474
+ def _crop_between_lines(
475
+ self,
476
+ rgb: np.ndarray,
477
+ top_line: Tuple[float, float],
478
+ bot_line: Tuple[float, float],
479
+ ) -> np.ndarray:
480
+ """
481
+ Per-column mask: for each x, keep rows from y_topβˆ’pad to y_bot+pad.
482
+ Then bounding-box crop the unmasked region.
483
+
484
+ This handles tilted images correctly because the mask boundary follows
485
+ the actual fitted line slope rather than a fixed row cut.
486
+ """
487
+ h, w = rgb.shape[:2]
488
+ pad = self.cfg.crop_padding_px
489
+ st, it = top_line
490
+ sb, ib = bot_line
491
+
492
+ mask = np.zeros((h, w), dtype=np.uint8)
493
+ xs = np.arange(w, dtype=float)
494
+ y_tops = np.clip((st * xs + it - pad).astype(int), 0, h - 1)
495
+ y_bots = np.clip((sb * xs + ib + pad).astype(int), 0, h - 1)
496
+
497
+ for x in range(w):
498
+ yt, yb = y_tops[x], y_bots[x]
499
+ if yb > yt:
500
+ mask[yt:yb, x] = 255
501
+
502
+ # Apply mask (background β†’ black)
503
+ masked = cv2.bitwise_and(rgb, rgb, mask=mask)
504
+
505
+ # Tight bounding box around non-zero pixels
506
+ ys_nz, xs_nz = np.where(mask > 0)
507
+ if len(ys_nz) == 0:
508
+ logger.warning("Empty mask β€” returning full image.")
509
+ return rgb
510
+
511
+ y0, y1 = ys_nz.min(), ys_nz.max()
512
+ x0, x1 = xs_nz.min(), xs_nz.max()
513
+ return masked[y0:y1, x0:x1]
514
+
515
+ # ── Public API ────────────────────────────────────────────────────────────
516
+
517
+ def extract(
518
+ self, rgb: np.ndarray
519
+ ) -> Tuple[np.ndarray, np.ndarray, bool, str]:
520
+ """
521
+ Run the full Layer-1 pipeline.
522
+
523
+ Returns
524
+ -------
525
+ cropped : np.ndarray β€” saddle region (or full image on failure)
526
+ debug_img : np.ndarray β€” annotated image for visual inspection
527
+ success : bool
528
+ message : str
529
+ """
530
+ debug = rgb.copy()
531
+ h, w = rgb.shape[:2]
532
+
533
+ # ── Step 1: detect circles ────────────────────────────────────
534
+ all_circles = self._detect_all_circles(rgb)
535
+ bolt_holes = self._filter_bolt_holes(all_circles, rgb.shape)
536
+
537
+ # Draw ALL detected circles in blue (before filtering)
538
+ for cx, cy, r in all_circles:
539
+ cv2.circle(debug, (cx, cy), r, (100, 100, 255), 1, cv2.LINE_AA)
540
+
541
+ if len(bolt_holes) < 4:
542
+ msg = (
543
+ f"⚠️ Only {len(bolt_holes)} bolt holes detected "
544
+ f"(need β‰₯ 4). Using full image."
545
+ )
546
+ logger.warning(msg)
547
+ return rgb, debug, False, msg
548
+
549
+ # Draw bolt holes in green
550
+ for cx, cy, r in bolt_holes:
551
+ cv2.circle(debug, (cx, cy), r, (0, 230, 0), 2, cv2.LINE_AA)
552
+ cv2.circle(debug, (cx, cy), 4, (0, 230, 0), -1)
553
+
554
+ # ── Step 2: cluster into rows ─────────────────────────────────
555
+ top_circles, bot_circles = self._cluster_rows(bolt_holes, h)
556
+ if top_circles is None:
557
+ msg = "⚠️ Row clustering failed. Using full image."
558
+ logger.warning(msg)
559
+ return rgb, debug, False, msg
560
+
561
+ # Highlight rows: orange=top, red=bottom
562
+ for cx, cy, r in top_circles:
563
+ cv2.circle(debug, (cx, cy), r + 5, (255, 165, 0), 2, cv2.LINE_AA)
564
+ cv2.putText(debug, "T", (cx - 6, cy + 5),
565
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 165, 0), 1, cv2.LINE_AA)
566
+ for cx, cy, r in bot_circles:
567
+ cv2.circle(debug, (cx, cy), r + 5, (0, 50, 220), 2, cv2.LINE_AA)
568
+ cv2.putText(debug, "B", (cx - 6, cy + 5),
569
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 50, 220), 1, cv2.LINE_AA)
570
+
571
+ # ── Step 3: fit lines ─────────────────────────────────────────
572
+ top_line = self._fit_line(top_circles)
573
+ bot_line = self._fit_line(bot_circles)
574
+
575
+ # Guarantee top_line is visually above bot_line at image center
576
+ cx_mid = w // 2
577
+ if (top_line[0] * cx_mid + top_line[1]) > (bot_line[0] * cx_mid + bot_line[1]):
578
+ top_line, bot_line = bot_line, top_line
579
+ top_circles, bot_circles = bot_circles, top_circles
580
+ logger.info("Swapped top/bottom lines to maintain spatial order.")
581
+
582
+ # Draw fitted lines
583
+ pt_t1, pt_t2 = self._line_endpoints(*top_line, w)
584
+ pt_b1, pt_b2 = self._line_endpoints(*bot_line, w)
585
+ cv2.line(debug, pt_t1, pt_t2, (255, 165, 0), 2, cv2.LINE_AA) # orange = top
586
+ cv2.line(debug, pt_b1, pt_b2, (0, 50, 220), 2, cv2.LINE_AA) # blue = bottom
587
+
588
+ # Label lines
589
+ cv2.putText(debug, "TOP LINE", (10, pt_t1[1] - 8),
590
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 165, 0), 2, cv2.LINE_AA)
591
+ cv2.putText(debug, "BOTTOM LINE", (10, pt_b1[1] + 18),
592
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 50, 220), 2, cv2.LINE_AA)
593
+
594
+ # Draw shaded region between lines
595
+ overlay = debug.copy()
596
+ for x in range(w):
597
+ yt = max(0, int(top_line[0] * x + top_line[1]))
598
+ yb = min(h - 1, int(bot_line[0] * x + bot_line[1]))
599
+ if yb > yt:
600
+ overlay[yt:yb, x] = (
601
+ overlay[yt:yb, x] * 0.6 + np.array([255, 255, 0]) * 0.4
602
+ ).astype(np.uint8)
603
+ cv2.addWeighted(overlay, 0.5, debug, 0.5, 0, debug)
604
+
605
+ # ── Step 4: crop saddle ───────────────────────────────────────
606
+ cropped = self._crop_between_lines(rgb, top_line, bot_line)
607
+
608
+ msg = (
609
+ f"βœ… {len(bolt_holes)} bolt holes | "
610
+ f"top-row: {len(top_circles)} | "
611
+ f"bot-row: {len(bot_circles)} | "
612
+ f"Saddle cropped: {cropped.shape[1]}Γ—{cropped.shape[0]} px"
613
+ )
614
+ logger.info(msg)
615
+ return cropped, debug, True, msg
616
 
 
 
617
 
618
+ # ──────────────────────────────────────────────────────────────────────────────
619
+ # LAYER 2 β€” FeatureExtractor (ResNet-50 + CLAHE)
620
+ # ──────────────────────────────────────────────────────────────────────────────
621
+ class FeatureExtractor:
622
+ """
623
+ Deep-feature extraction with luminance-adaptive pre-processing.
624
 
625
+ CLAHE is applied in LAB space on the L-channel only, so colour
626
+ calibration for the ResNet normalisation is unaffected.
627
+ """
 
 
 
 
 
 
 
628
 
629
+ def __init__(self, cfg: DetectorConfig):
630
+ self._clahe = cv2.createCLAHE(
631
+ clipLimit=cfg.clahe_clip_limit,
632
+ tileGridSize=cfg.clahe_tile_grid,
633
+ )
634
+ # ResNet-50 as fixed backbone; remove average-pool + FC β†’ 2048-D vector
635
+ backbone = models.resnet50(weights="IMAGENET1K_V1")
636
+ self.model = nn.Sequential(*list(backbone.children())[:-1])
637
+ self.model.eval()
638
 
639
+ self._transform = transforms.Compose([
640
+ transforms.Resize((224, 224)),
641
+ transforms.ToTensor(),
642
+ transforms.Normalize(
643
+ mean=[0.485, 0.456, 0.406],
644
+ std=[0.229, 0.224, 0.225],
645
+ ),
646
+ ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
647
 
648
+ # ── Internal CLAHE ────────────────────────────────────────────────────────
 
 
649
 
650
+ def _apply_clahe(self, rgb: np.ndarray) -> np.ndarray:
651
+ bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
652
+ lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)
653
+ l, a, b = cv2.split(lab)
654
+ l_enh = self._clahe.apply(l)
655
+ bgr_enh = cv2.cvtColor(cv2.merge([l_enh, a, b]), cv2.COLOR_LAB2BGR)
656
+ return cv2.cvtColor(bgr_enh, cv2.COLOR_BGR2RGB)
657
 
658
+ # ── Public API ────────────────────────────────────────────────────────────
 
 
659
 
660
+ def __call__(self, image) -> np.ndarray:
661
+ """Accept PIL Image or np.ndarray, return 2048-D feature vector."""
662
+ if isinstance(image, Image.Image):
663
+ arr = np.array(image.convert("RGB"))
664
+ elif isinstance(image, np.ndarray):
665
+ arr = image.astype(np.uint8)
666
+ else:
667
+ raise TypeError(f"Unsupported image type: {type(image)}")
668
 
669
+ arr = self._apply_clahe(arr)
670
+ # Light Gaussian blur suppresses metallic sheen amplified by CLAHE
671
+ arr = cv2.GaussianBlur(arr, (3, 3), 0)
 
672
 
673
+ tensor = self._transform(Image.fromarray(arr)).unsqueeze(0)
674
+ with torch.no_grad():
675
+ feat = self.model(tensor).squeeze().numpy()
676
+ return feat # shape (2048,)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
677
 
 
678
 
679
+ # ──────────────────────────────────────────────────────────────────────────────
680
+ # LAYER 3 β€” TemplateMatcher
681
+ # ──────────────────────────────────────────────────────────────────────────────
682
+ class TemplateMatcher:
683
+ """
684
+ Manages the golden-template library and performs cosine-similarity matching.
685
 
686
+ Templates are stored as {part_name: feature_vector} and persisted to disk
687
+ via pickle so they survive Gradio restarts.
688
+ """
 
689
 
690
+ def __init__(self, cfg: DetectorConfig, extractor: FeatureExtractor):
691
+ self.cfg = cfg
692
+ self.extractor = extractor
693
+ self.templates: dict[str, np.ndarray] = {}
694
+ self._load()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
695
 
696
+ # ── Persistence ───────────────────────────────────────────────────────────
 
 
697
 
698
+ def _load(self):
699
+ if os.path.exists(self.cfg.templates_path):
700
+ try:
701
+ with open(self.cfg.templates_path, "rb") as f:
702
+ self.templates = pickle.load(f)
703
+ logger.info("Loaded %d templates from %s.",
704
+ len(self.templates), self.cfg.templates_path)
705
+ except Exception as exc:
706
+ logger.error("Failed to load templates: %s", exc)
707
+ self.templates = {}
708
 
709
+ def _save(self):
710
+ with open(self.cfg.templates_path, "wb") as f:
711
+ pickle.dump(self.templates, f)
712
 
713
+ # ── Template management ───────────────────────────────────────────────────
 
 
 
714
 
715
+ def add_template(self, image, part_name: str) -> str:
716
+ if image is None or not part_name.strip():
717
+ return "❌ Provide both an image and a part name."
718
+ feat = self.extractor(image)
719
+ self.templates[part_name.strip()] = feat
720
+ self._save()
721
+ return f"βœ… Template '{part_name.strip()}' saved ({len(self.templates)} total)."
722
 
723
+ def list_templates(self) -> str:
724
+ if not self.templates:
725
+ return "No templates saved yet."
726
+ return "\n".join(f" β€’ {n}" for n in sorted(self.templates))
727
 
728
+ # ── Matching ──────────────────────────────────────────────────────────────
 
 
 
 
 
 
729
 
730
+ @staticmethod
731
+ def _cosine(a: np.ndarray, b: np.ndarray) -> float:
732
+ return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))
 
 
 
 
733
 
734
+ def match(
735
+ self,
736
+ image,
737
+ threshold: Optional[float] = None,
738
+ ) -> Tuple[str, Optional[str]]:
739
+ """
740
+ Match *image* against all stored templates.
741
+
742
+ Returns
743
+ -------
744
+ report : str β€” human-readable results
745
+ label : str | None β€” best match name if above threshold, else None
746
+ """
747
+ if image is None:
748
+ return "❌ No image provided.", None
749
+ if not self.templates:
750
+ return "⚠️ No templates stored. Add templates first.", None
751
 
752
+ thr = threshold if threshold is not None else self.cfg.default_threshold
753
+ feat = self.extractor(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
754
 
755
+ scores = sorted(
756
+ ((name, self._cosine(feat, vec)) for name, vec in self.templates.items()),
757
+ key=lambda x: x[1],
758
+ reverse=True,
759
+ )
760
 
761
+ best_name, best_score = scores[0]
762
+ matched = best_score >= thr
763
+
764
+ lines = [
765
+ f"πŸ” **Best Match** : {best_name}",
766
+ f"πŸ“Š **Confidence** : {best_score:.2%}",
767
+ f"{'βœ… MATCHED' if matched else '❌ NO MATCH'} (threshold {thr:.0%})",
768
+ "",
769
+ "**All Scores:**",
770
+ ]
771
+ for name, score in scores:
772
+ bar = "β–ˆ" * int(score * 20)
773
+ lines.append(f" {name:<30} {score:.2%} {bar}")
774
+
775
+ return "\n".join(lines), (best_name if matched else None)
776
+
777
+
778
+ # ──────────────────────────────────────────────────────────────────────────────
779
+ # LAYER 4 β€” MultiLayerPipeline (orchestrator)
780
+ # ──────────────────────────────────────────────────────────────────────────────
781
+ class MultiLayerPipeline:
782
+ """
783
+ Ties Layers 1–3 together and drives the Gradio UI callbacks.
784
+ """
785
+
786
+ def __init__(self, cfg: DetectorConfig):
787
+ self.cfg = cfg
788
+ self.extractor = FeatureExtractor(cfg)
789
+ self.saddle = SaddleRegionExtractor(cfg)
790
+
791
+ # Matcher is initialised *after* extractor so it can use it
792
+ self.matcher = TemplateMatcher(cfg, self.extractor)
793
 
794
+ # ── Detection pipeline ────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
795
 
796
+ def detect(
797
+ self,
798
+ raw_image: np.ndarray,
799
+ threshold: float,
800
+ ) -> Tuple[str, Optional[str], np.ndarray, np.ndarray]:
801
+ """
802
+ Full detection pipeline.
803
+
804
+ 1. Layer-1 β†’ crop saddle region
805
+ 2. Layer-2 β†’ extract features from cropped saddle
806
+ 3. Layer-3 β†’ match against templates
807
+
808
+ Returns
809
+ -------
810
+ report : str
811
+ label : str | None
812
+ debug_image : np.ndarray (annotated with circles + lines)
813
+ cropped_image : np.ndarray (the extracted saddle region sent to matching)
814
+ """
815
+ if raw_image is None:
816
+ return "❌ No image provided.", None, np.zeros((100, 100, 3), np.uint8), np.zeros((100, 100, 3), np.uint8)
817
+
818
+ # Layer 1 β€” geometric extraction
819
+ cropped, debug_img, geo_ok, geo_msg = self.saddle.extract(raw_image)
820
+
821
+ extraction_status = (
822
+ f"**Layer-1 (Geometric):** {geo_msg}\n\n"
823
+ )
824
 
825
+ # Layer 2+3 β€” feature extraction + matching on the cropped saddle
826
+ match_report, label = self.matcher.match(cropped, threshold)
 
827
 
828
+ full_report = extraction_status + match_report
829
 
830
+ return full_report, label, debug_img, cropped
831
 
832
+ # ── Template management pass-throughs ─────────────────────────────────────
 
833
 
834
+ def add_template(self, image: np.ndarray, part_name: str) -> str:
835
+ """Save template using the raw upload (no saddle crop β€” full control image)."""
836
+ return self.matcher.add_template(image, part_name)
837
 
838
+ def add_template_with_crop(self, image: np.ndarray, part_name: str) -> Tuple[str, np.ndarray, np.ndarray]:
839
+ """
840
+ Optional: crop saddle from the template image too, then save.
841
+ Returns status + debug + cropped for visual confirmation.
842
+ """
843
+ if image is None or not part_name.strip():
844
+ blank = np.zeros((100, 100, 3), np.uint8)
845
+ return "❌ Provide both image and name.", blank, blank
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
846
 
847
+ cropped, debug_img, ok, geo_msg = self.saddle.extract(image)
848
+ status = self.matcher.add_template(cropped, part_name)
849
+ return f"{geo_msg}\n{status}", debug_img, cropped
850
 
851
+ def list_templates(self) -> str:
852
+ return self.matcher.list_templates()
853
 
 
 
854
 
855
+ # ──────────────────────────────────────────────────────────────────────────────
856
+ # Initialise pipeline (single global instance β€” Gradio is single-process)
857
+ # ──────────────────────────────────────────────────────────────────────────────
858
+ cfg = DetectorConfig()
859
+ pipeline = MultiLayerPipeline(cfg)
860
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
861
 
862
+ # ──────────────────────────────────────────────────────────────────────────────
863
+ # Gradio UI
864
+ # ──────────────────────────────────────────────────────────────────────────────
865
+ CSS = """
866
+ #title { text-align: center; }
867
+ .status-box { font-family: monospace; font-size: 0.85rem; }
868
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
869
 
870
+ with gr.Blocks(title="Engine Part Detection β€” Multi-Layer CV System", css=CSS) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
871
 
872
+ gr.Markdown(
873
+ """
874
+ # πŸ”§ Engine Part Detection System
875
+ ### Multi-Layer Computer Vision Architecture
876
+ | Layer | Role |
877
+ |-------|------|
878
+ | **L1 β€” Geometric** | HoughCircles β†’ bolt-hole centers β†’ line fitting β†’ saddle crop |
879
+ | **L2 β€” Feature** | CLAHE + GaussianBlur β†’ ResNet-50 (2048-D embeddings) |
880
+ | **L3 β€” Matching** | Cosine-similarity against golden templates |
881
+ """,
882
+ elem_id="title",
883
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
884
 
885
+ # ── Tab 1: Detect ──────────────────────────────────────────────────────────
886
+ with gr.Tab("πŸ” Detect Part"):
887
+ gr.Markdown("Upload or capture an engine image. Layer-1 will automatically locate the bearing-saddle region before matching.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
888
 
889
+ with gr.Row():
890
+ with gr.Column(scale=1):
891
+ detect_input = gr.Image(sources=["upload", "webcam"], type="numpy", label="Input Image")
892
+ threshold_slider = gr.Slider(0.50, 0.95, value=cfg.default_threshold,
893
+ step=0.01, label="Similarity Threshold")
894
+ detect_btn = gr.Button("πŸš€ Detect Part", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
895
 
896
+ with gr.Column(scale=1):
897
+ detect_report = gr.Textbox(label="Detection Report",
898
+ lines=14, elem_classes=["status-box"])
899
+ match_label = gr.Label(label="Matched Part")
900
 
901
+ with gr.Row():
902
+ debug_output = gr.Image(label="Layer-1 Debug View (bolt holes + fitted lines)")
903
+ cropped_output = gr.Image(label="Cropped Saddle (sent to matching)")
904
 
905
+ detect_btn.click(
906
+ fn=pipeline.detect,
907
+ inputs=[detect_input, threshold_slider],
908
+ outputs=[detect_report, match_label, debug_output, cropped_output],
909
+ api_name="detect",
910
+ )
911
 
912
+ # ── Tab 2: Add Template ────────────────────────────────────────────────────
913
+ with gr.Tab("βž• Add Template"):
914
+ gr.Markdown(
915
+ "Two modes:\n"
916
+ "- **Raw upload**: save the full image as-is (best for controlled reference shots)\n"
917
+ "- **Auto-crop + save**: Layer-1 extracts the saddle first, then saves that as the template"
918
+ )
919
+
920
+ with gr.Row():
921
+ with gr.Column():
922
+ tpl_input = gr.Image(sources=["upload"], type="numpy", label="Reference Image")
923
+ tpl_name = gr.Textbox(label="Part Name (e.g. bearing_saddle_good)", placeholder="part_name")
924
+
925
+ with gr.Row():
926
+ save_raw_btn = gr.Button("πŸ’Ύ Save Raw", variant="secondary")
927
+ save_crop_btn = gr.Button("βœ‚οΈ Crop Saddle + Save", variant="primary")
928
+
929
+ with gr.Column():
930
+ tpl_status = gr.Textbox(label="Status", lines=4)
931
+ tpl_debug = gr.Image(label="Layer-1 Debug (crop mode only)")
932
+ tpl_cropped = gr.Image(label="Cropped Saddle (crop mode only)")
933
+
934
+ save_raw_btn.click(
935
+ fn=pipeline.add_template,
936
+ inputs=[tpl_input, tpl_name],
937
+ outputs=tpl_status,
938
+ api_name="add_template_raw",
939
+ )
940
+ save_crop_btn.click(
941
+ fn=pipeline.add_template_with_crop,
942
+ inputs=[tpl_input, tpl_name],
943
+ outputs=[tpl_status, tpl_debug, tpl_cropped],
944
+ api_name="add_template_crop",
945
+ )
946
+
947
+ # ── Tab 3: Template Library ────────────────────────────────────────────────
948
+ with gr.Tab("πŸ“‹ Template Library"):
949
+ tpl_list = gr.Textbox(label="Saved Templates", lines=12)
950
+ refresh_btn = gr.Button("πŸ”„ Refresh")
951
+ refresh_btn.click(fn=pipeline.list_templates, outputs=tpl_list)
952
+ demo.load(fn=pipeline.list_templates, outputs=tpl_list)
953
+
954
+ gr.Markdown(
955
+ """
956
+ ---
957
+ **Architecture notes:**
958
+ Layer-1 uses an adaptive Y-gap split (largest inter-row gap > 5 % of image height)
959
+ rather than K-Means, making it robust to images with varying numbers of bolt holes or
960
+ non-uniform spacing. Line fitting uses scipy OLS regression for stability even when
961
+ holes are slightly misaligned. The per-column mask preserves angled saddles correctly.
962
+ """
963
+ )
964
+
965
+ if __name__ == "__main__":
966
+ demo.launch(share=False)