eho69 commited on
Commit
1be4124
Β·
verified Β·
1 Parent(s): 7307be4

classifucation

Browse files
Files changed (1) hide show
  1. app.py +212 -207
app.py CHANGED
@@ -272,6 +272,7 @@ logger = logging.getLogger(__name__)
272
 
273
  class FeatureExtractor:
274
  def __init__(self):
 
275
  backbone = models.resnet50(weights="IMAGENET1K_V1")
276
  self.model = nn.Sequential(*list(backbone.children())[:-1])
277
  self.model.eval()
@@ -291,15 +292,24 @@ class FeatureExtractor:
291
  rgb = np.array(rgb.convert("RGB"))
292
  if rgb.dtype != np.uint8:
293
  rgb = rgb.astype(np.uint8)
 
294
  if len(rgb.shape) == 2:
295
  rgb = cv2.cvtColor(rgb, cv2.COLOR_GRAY2RGB)
296
-
 
 
 
 
 
297
  input_tensor = self.transform(Image.fromarray(rgb)).unsqueeze(0)
298
-
 
299
  with torch.no_grad():
 
 
300
  backbone = models.resnet50(weights="IMAGENET1K_V1")
301
  backbone.eval()
302
-
303
  x = backbone.conv1(input_tensor)
304
  x = backbone.bn1(x)
305
  x = backbone.relu(x)
@@ -307,23 +317,27 @@ class FeatureExtractor:
307
  x = backbone.layer1(x)
308
  x = backbone.layer2(x)
309
  x = backbone.layer3(x)
310
- features_spatial = backbone.layer4(x) # [1, 2048, 7, 7]
311
-
 
312
  feat = torch.mean(features_spatial, dim=[2, 3]).squeeze().cpu().numpy()
313
-
 
314
  amap = torch.sum(features_spatial, dim=1).squeeze().cpu().numpy()
315
  amap = np.maximum(amap, 0)
316
  amap /= (np.max(amap) + 1e-8)
317
  amap = cv2.resize(amap, (rgb.shape[1], rgb.shape[0]))
318
  amap = np.uint8(255 * amap)
319
  heatmap = cv2.applyColorMap(amap, cv2.COLORMAP_JET)
 
 
 
320
  heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
321
  overlay = cv2.addWeighted(rgb, 0.6, heatmap_rgb, 0.4, 0)
322
 
323
  norm = np.linalg.norm(feat)
324
  return (feat / norm if norm > 1e-8 else feat), overlay
325
 
326
-
327
  # ───────────────────────────────────────────────────────────────────────────────
328
  # MASTER ORCHESTRATOR β”‚ EnginePartDetector
329
  # ───────────────────────────────────────────────────────────────────────────────
@@ -333,87 +347,118 @@ class EnginePartDetector:
333
 
334
  def __init__(self):
335
  self.feature_extractor = FeatureExtractor()
336
- self.templates: dict[str, np.ndarray] = {}
337
- self._load_templates()
 
 
 
338
 
339
  # ── Persistence ───────────────────────────────────────────────────────────
340
 
341
- def _load_templates(self) -> None:
342
  if os.path.exists(self.TEMPLATE_FILE):
343
  try:
344
  with open(self.TEMPLATE_FILE, "rb") as f:
345
- self.templates = pickle.load(f)
346
- logger.info(f"Loaded {len(self.templates)} template(s).")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  except Exception as e:
348
- logger.error(f"Template load failed: {e}")
349
- self.templates = {}
350
 
351
- def _persist_templates(self) -> None:
352
  try:
353
  with open(self.TEMPLATE_FILE, "wb") as f:
354
- pickle.dump(self.templates, f)
 
 
 
 
355
  except Exception as e:
356
- logger.error(f"Template save failed: {e}")
357
 
358
  # ── Layer 1: ROI Detection & Extraction ───────────────────────────────────
359
 
360
  @staticmethod
361
- def detect_connect_and_crop(
362
- image_source: np.ndarray,
363
- ) -> tuple[np.ndarray, np.ndarray, str, list, list]:
364
-
 
 
 
365
  img_rgb = image_source
366
  img_h, img_w = img_rgb.shape[:2]
367
  gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
368
  gray = cv2.GaussianBlur(gray, (7, 7), 0)
369
-
 
370
  circles = cv2.HoughCircles(
371
  gray, cv2.HOUGH_GRADIENT, dp=1.2, minDist=60,
372
  param1=100, param2=35, minRadius=12, maxRadius=45
373
  )
374
-
375
  if circles is None:
376
- return img_rgb, img_rgb, "❌ No bolt holes detected.", [], []
377
-
378
  circles = np.round(circles[0]).astype(int)
379
-
 
380
  ys = sorted([c[1] for c in circles])
381
  y_median = np.median(ys)
382
-
383
  top_row = sorted([c for c in circles if c[1] < y_median], key=lambda x: x[0])
384
  bot_row = sorted([c for c in circles if c[1] >= y_median], key=lambda x: x[0])
385
-
386
  if len(top_row) < 2 or len(bot_row) < 2:
387
- return img_rgb, img_rgb, "⚠️ Insufficient hole rows for localization.", [], []
388
 
 
389
  y_top = int(np.mean([c[1] for c in top_row]))
390
  y_bot = int(np.mean([c[1] for c in bot_row]))
391
-
 
392
  xs = [c[0] for c in circles]
393
  x_min, x_max = min(xs), max(xs)
394
  padding_h = 60
395
  padding_v = 20
396
-
397
  x_start = max(0, x_min - padding_h)
398
  x_end = min(img_w, x_max + padding_h)
399
  y_start = max(0, min(y_top, y_bot) - padding_v)
400
  y_end = min(img_h, max(y_top, y_bot) + padding_v)
401
 
 
402
  vis_img = img_rgb.copy()
403
  LINE_COLOR = (0, 255, 0)
404
  HOLE_COLOR = (255, 0, 0)
405
-
 
406
  cv2.line(vis_img, (0, y_top), (img_w, y_top), LINE_COLOR, 3)
407
  cv2.line(vis_img, (0, y_bot), (img_w, y_bot), LINE_COLOR, 3)
408
-
409
  for (x, y, r) in circles:
410
  cv2.circle(vis_img, (x, y), r, HOLE_COLOR, 3)
411
  cv2.circle(vis_img, (x, y), 2, (255, 255, 255), -1)
412
 
 
413
  cropped_img = img_rgb[y_start:y_end, x_start:x_end]
414
-
415
  if cropped_img.size == 0:
416
- return vis_img, img_rgb, "⚠️ ROI selection failed.", [], []
417
 
418
  stats_text = (
419
  f"βœ… **Full Saddle Band Extracted**\n"
@@ -423,72 +468,22 @@ class EnginePartDetector:
423
  f"β€’ ROI Size: {cropped_img.shape[1]}x{cropped_img.shape[0]} px"
424
  )
425
 
426
- return vis_img, cropped_img, stats_text, list(top_row), list(bot_row)
427
-
428
- # ─�� Vertical-line detection on structural edge map ───────────────────────
429
-
430
- @staticmethod
431
- def detect_vertical_lines_on_edge_map(
432
- roi_enhanced: np.ndarray,
433
- angle_tolerance_deg: float = 12.0,
434
- min_line_length_ratio: float = 0.15,
435
- ) -> tuple[bool, np.ndarray, str]:
436
-
437
- gray = cv2.cvtColor(roi_enhanced, cv2.COLOR_RGB2GRAY)
438
- clahe = cv2.createCLAHE(clipLimit=2.8, tileGridSize=(8, 8))
439
- gray = clahe.apply(gray)
440
- edges = cv2.Canny(gray, 50, 150)
441
-
442
- h, w = edges.shape
443
- min_len = max(20, int(h * min_line_length_ratio))
444
-
445
- lines = cv2.HoughLinesP(
446
- edges, rho=1, theta=np.pi / 180,
447
- threshold=20, minLineLength=min_len, maxLineGap=10,
448
- )
449
-
450
- # RGB canvas from edge map
451
- canvas = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
452
-
453
- vertical_lines = []
454
- if lines is not None:
455
- for seg in lines:
456
- x1, y1, x2, y2 = seg[0]
457
- dx = abs(x2 - x1)
458
- dy = abs(y2 - y1)
459
- angle = np.degrees(np.arctan2(dx, dy + 1e-6))
460
- if angle <= angle_tolerance_deg:
461
- vertical_lines.append((x1, y1, x2, y2, dy))
462
-
463
- # Sort by length β€” longest first
464
- vertical_lines.sort(key=lambda v: v[4], reverse=True)
465
- has_vertical = len(vertical_lines) > 0
466
-
467
- if has_vertical:
468
- for (x1, y1, x2, y2, _) in vertical_lines:
469
- cv2.line(canvas, (x1, y1), (x2, y2), (0, 255, 0), 2)
470
- cv2.rectangle(canvas, (0, 0), (240, 46), (0, 150, 0), -1)
471
- cv2.putText(canvas, f"PRESENT ({len(vertical_lines)})",
472
- (6, 34), cv2.FONT_HERSHEY_DUPLEX, 0.85, (255, 255, 255), 2)
473
- status = (f"βœ… **Vertical lines PRESENT** β€” "
474
- f"{len(vertical_lines)} near-vertical line(s) detected.")
475
- else:
476
- cv2.rectangle(canvas, (0, 0), (190, 46), (180, 0, 0), -1)
477
- cv2.putText(canvas, "ABSENT",
478
- (6, 34), cv2.FONT_HERSHEY_DUPLEX, 1.1, (255, 255, 255), 2)
479
- status = "❌ **Vertical lines ABSENT** β€” No near-vertical lines on edge map."
480
-
481
- return has_vertical, canvas, status
482
 
483
  @staticmethod
484
  def enhance_roi(roi: np.ndarray) -> np.ndarray:
485
  """Apply high-contrast CLAHE to highlight blurred lines/features."""
486
  if roi is None or roi.size == 0:
487
  return roi
 
 
488
  lab = cv2.cvtColor(roi, cv2.COLOR_RGB2LAB)
489
  l, a, b = cv2.split(lab)
490
- clahe = cv2.createCLAHE(clipLimit=3.9, tileGridSize=(8, 8))
 
 
491
  cl = clahe.apply(l)
 
492
  merged = cv2.merge((cl, a, b))
493
  enhanced = cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
494
  return enhanced
@@ -504,104 +499,112 @@ class EnginePartDetector:
504
 
505
  # ── Public API ────────────────────────────────────────────────────────────
506
 
507
- def save_template(self, image: np.ndarray, part_name: str) -> tuple[str, np.ndarray | None]:
508
  if image is None:
509
  return "❌ No image supplied.", None
510
- if not part_name or not part_name.strip():
511
- return "❌ Part name is empty.", None
512
 
513
- part_name = part_name.strip()
514
-
515
- vis, roi, log, top_row, bot_row = self.detect_connect_and_crop(image)
 
516
  if "❌" in log or "⚠️" in log:
517
  return log, None
518
 
 
519
  roi_enhanced = self.enhance_roi(roi)
 
 
520
  features, _ = self.feature_extractor.extract(roi_enhanced)
521
- self.templates[part_name] = {
522
- "features": features,
523
- "roi": roi_enhanced,
524
- }
525
- self._persist_templates()
 
 
 
526
 
527
- return f"βœ… Template '{part_name}' saved!\n\n{log}", roi
528
 
529
  def match_part(
530
  self,
531
  image: np.ndarray,
532
  threshold: float = 0.70,
533
- ) -> tuple[str, dict | None, np.ndarray | None, np.ndarray | None, np.ndarray | None]:
534
- """
535
- Returns:
536
- report_text, label_dict, field_vis, attention_map, annotated_edge_map
537
- """
538
  if image is None:
539
- return "❌ No image supplied.", None, None, None, None
540
- if not self.templates:
541
- return "⚠️ No templates yet. Add at least one template first.", None, None, None, None
542
 
543
- # ── Layer 1: ROI ──────────────────────────────────────────────────────
544
- vis, roi, log, top_row, bot_row = self.detect_connect_and_crop(image)
545
  if "❌" in log or "⚠️" in log:
546
- return log, None, vis, None, None
547
 
548
- # ── Layer 2: Feature Matching ─────────────────────────────────────────
549
  roi_enhanced = self.enhance_roi(roi)
 
 
550
  query_feat, attention_map = self.feature_extractor.extract(roi_enhanced)
551
 
552
- scores = []
553
- for name, data in self.templates.items():
554
- sim = self._cosine(query_feat, data["features"])
555
- scores.append((name, sim))
556
- scores.sort(key=lambda x: x[1], reverse=True)
557
-
558
- best_name, best_score = scores[0]
559
- feat_matched = best_score >= threshold
560
-
561
- # ── Structural Edge Map + Vertical-Line Detection ─────────────────────
562
- has_vertical, edge_annotated, vline_status = self.detect_vertical_lines_on_edge_map(roi_enhanced)
563
-
564
- # ── Final verdict: feature match only (vertical line is informational) ─
565
- final_pass = feat_matched
566
- final_icon = "βœ… PASS" if final_pass else "❌ FAIL"
567
-
568
- report_lines = [
569
- f"## 🏁 Final Verdict: {final_icon}",
570
- f"",
571
- f"### πŸ” Feature Match",
572
- f"{'βœ…' if feat_matched else '❌'} **Best Match**: `{best_name}`",
573
- f"πŸ“Š **Confidence**: {best_score:.2%}",
574
- f"",
575
- f"### πŸ“ Vertical Line Analysis (Edge Map)",
576
- vline_status,
577
- f"",
578
- f"---",
579
- f"",
580
- f"### πŸ“Έ Field Detection",
581
  log,
582
  ]
 
 
 
 
 
583
 
584
- if len(scores) > 1:
585
- report_lines.append("\n**Other Template Scores:**")
586
- for name, sim in scores[1:5]:
587
- report_lines.append(f" β€’ `{name}`: {sim:.3f}")
588
-
589
- label_dict = {name: float(sim) for name, sim in scores[:5]}
590
-
591
- return "\n".join(report_lines), label_dict, vis, attention_map, edge_annotated
592
 
593
- def get_template_roi(self, part_name: str) -> np.ndarray | None:
594
- if part_name in self.templates:
595
- return self.templates[part_name].get("roi")
596
- return None
597
 
598
  def list_templates(self) -> str:
599
- if not self.templates:
600
- return "No templates saved yet."
601
- header = f"Total: {len(self.templates)} template(s)\n" + "─" * 30
602
- body = "\n".join(f" β€’ {n}" for n in sorted(self.templates.keys()))
603
- return f"{header}\n{body}"
604
-
 
605
 
606
  # ───────────────────────────────────────────────────────────────────────────────
607
  # Gradio Application
@@ -609,54 +612,44 @@ class EnginePartDetector:
609
 
610
  detector = EnginePartDetector()
611
 
612
-
613
  def detect_part(image, threshold):
614
  return detector.match_part(image, threshold)
615
 
 
 
616
 
617
- def add_template(image, part_name):
618
- return detector.save_template(image, part_name)
619
-
620
-
621
- def list_templates():
622
  return detector.list_templates()
623
 
624
-
625
  custom_css = """
626
- .container { max-width: 1400px; margin: auto; }
627
- .header { text-align: center; margin-bottom: 2rem; }
628
- .footer { text-align: center; margin-top: 2rem; color: #666; }
629
  """
630
 
631
  with gr.Blocks(title="Engine Part CV System", theme=gr.themes.Soft(), css=custom_css) as demo:
632
  gr.Markdown("""
633
  <div class="header">
634
- <h1>πŸ”§ Engine Part Detection System</h1>
635
- <p>
636
- <strong>Layer 1:</strong> Hough Bolt-Hole Detection &amp; Crop &nbsp;|&nbsp;
637
- <strong>Layer 2:</strong> ResNet50 Feature Matching &nbsp;|&nbsp;
638
- <strong>Edge Map:</strong> Vertical-Line Detection
639
- </p>
640
  </div>
641
  """)
642
 
643
  with gr.Tab("πŸ” Match Inspection"):
644
  with gr.Row():
645
  with gr.Column(scale=1):
646
- detect_input = gr.Image(sources=["upload", "webcam"], type="numpy", label="Input Image")
647
  threshold_slider = gr.Slider(0.5, 0.99, value=0.75, step=0.01, label="Matching Threshold")
648
- detect_btn = gr.Button("πŸ” Run Inspection", variant="primary")
649
 
650
  with gr.Column(scale=1):
651
  detect_output = gr.Markdown(label="Match Report")
652
- match_label = gr.Label(label="Top Scores", num_top_classes=5)
653
-
654
- with gr.Row():
655
- vis_output = gr.Image(label="Field Visualization (bolt holes)")
656
- attn_output = gr.Image(label="AI Attention Heatmap")
657
-
658
- with gr.Row():
659
- edge_output = gr.Image(label="Structural Edge Map (green lines = vertical PRESENT | red banner = ABSENT)")
660
 
661
  detect_btn.click(
662
  fn=detect_part,
@@ -665,39 +658,51 @@ with gr.Blocks(title="Engine Part CV System", theme=gr.themes.Soft(), css=custom
665
  api_name="detect_part",
666
  )
667
 
668
- with gr.Tab("πŸ’Ύ Add Golden Template"):
669
  with gr.Row():
670
  with gr.Column(scale=1):
671
- template_input = gr.Image(sources=["upload"], type="numpy", label="Reference Image")
672
- part_name_input = gr.Textbox(label="Part Name", placeholder="e.g. bearing_cap_v8_A")
673
- add_btn = gr.Button("πŸ’Ύ Register Template", variant="primary")
 
 
 
 
 
674
  with gr.Column(scale=1):
675
- add_status = gr.Textbox(label="Registration Status", lines=5)
676
- add_roi_view = gr.Image(label="Registered Cropped ROI", interactive=False)
677
 
678
  add_btn.click(
679
- fn=add_template,
680
- inputs=[template_input, part_name_input],
681
  outputs=[add_status, add_roi_view],
682
- api_name="add_template",
683
  )
684
 
685
- with gr.Tab("πŸ“‹ Library"):
686
  with gr.Row():
687
  with gr.Column(scale=1):
688
- template_list = gr.Textbox(label="Current Golden Templates", lines=12)
689
- refresh_btn = gr.Button("πŸ”„ Refresh Library")
690
  with gr.Column(scale=1):
691
- library_roi_view = gr.Image(label="Template ROI Preview", interactive=False)
692
-
693
  def update_library_preview():
694
- if detector.templates:
695
- first_name = sorted(detector.templates.keys())[0]
696
  return detector.list_templates(), detector.get_template_roi(first_name)
697
- return "No templates saved yet.", None
698
 
699
  refresh_btn.click(fn=update_library_preview, outputs=[template_list, library_roi_view])
700
  demo.load(fn=update_library_preview, outputs=[template_list, library_roi_view])
701
 
 
 
 
 
 
 
 
702
  if __name__ == "__main__":
703
- demo.launch(share=False, show_error=True)
 
272
 
273
  class FeatureExtractor:
274
  def __init__(self):
275
+ # Using ResNet50 for 2048-D feature vectors
276
  backbone = models.resnet50(weights="IMAGENET1K_V1")
277
  self.model = nn.Sequential(*list(backbone.children())[:-1])
278
  self.model.eval()
 
292
  rgb = np.array(rgb.convert("RGB"))
293
  if rgb.dtype != np.uint8:
294
  rgb = rgb.astype(np.uint8)
295
+
296
  if len(rgb.shape) == 2:
297
  rgb = cv2.cvtColor(rgb, cv2.COLOR_GRAY2RGB)
298
+
299
+ # We want the layer BEFORE the global pooling to get spatial info
300
+ # resnet.layer4 is the last block
301
+ # self.model is nn.Sequential(*list(backbone.children())[:-1])
302
+ # children()[:-1] = [conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4]
303
+
304
  input_tensor = self.transform(Image.fromarray(rgb)).unsqueeze(0)
305
+
306
+ # Get activations from the last conv layer (Layer 4)
307
  with torch.no_grad():
308
+ # Run through the layers up to global pooling
309
+ # Using the original backbone for Easier Access to sub-layers
310
  backbone = models.resnet50(weights="IMAGENET1K_V1")
311
  backbone.eval()
312
+
313
  x = backbone.conv1(input_tensor)
314
  x = backbone.bn1(x)
315
  x = backbone.relu(x)
 
317
  x = backbone.layer1(x)
318
  x = backbone.layer2(x)
319
  x = backbone.layer3(x)
320
+ features_spatial = backbone.layer4(x) # [1, 2048, 7, 7]
321
+
322
+ # Global Average Pooling to get the vector
323
  feat = torch.mean(features_spatial, dim=[2, 3]).squeeze().cpu().numpy()
324
+
325
+ # Create Heatmap: sum across channels to see "hot" regions
326
  amap = torch.sum(features_spatial, dim=1).squeeze().cpu().numpy()
327
  amap = np.maximum(amap, 0)
328
  amap /= (np.max(amap) + 1e-8)
329
  amap = cv2.resize(amap, (rgb.shape[1], rgb.shape[0]))
330
  amap = np.uint8(255 * amap)
331
  heatmap = cv2.applyColorMap(amap, cv2.COLORMAP_JET)
332
+
333
+ # Overlay heatmap on original image
334
+ # Convert BGR heatmap to RGB
335
  heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
336
  overlay = cv2.addWeighted(rgb, 0.6, heatmap_rgb, 0.4, 0)
337
 
338
  norm = np.linalg.norm(feat)
339
  return (feat / norm if norm > 1e-8 else feat), overlay
340
 
 
341
  # ───────────────────────────────────────────────────────────────────────────────
342
  # MASTER ORCHESTRATOR β”‚ EnginePartDetector
343
  # ───────────────────────────────────────────────────────────────────────────────
 
347
 
348
  def __init__(self):
349
  self.feature_extractor = FeatureExtractor()
350
+ # Changed from simple templates to class-based feature lists
351
+ self.classes: dict[str, list[np.ndarray]] = {}
352
+ # We also store an example ROI for each class for visualization
353
+ self.class_rois: dict[str, np.ndarray] = {}
354
+ self._load_data()
355
 
356
  # ── Persistence ───────────────────────────────────────────────────────────
357
 
358
+ def _load_data(self) -> None:
359
  if os.path.exists(self.TEMPLATE_FILE):
360
  try:
361
  with open(self.TEMPLATE_FILE, "rb") as f:
362
+ data = pickle.load(f)
363
+ # Support legacy format if needed, but here we assume the new format
364
+ if isinstance(data, dict):
365
+ # If old format was {name: {"features": feat, "roi": roi}}
366
+ # we convert it to {name: [feat]}
367
+ self.classes = {}
368
+ self.class_rois = {}
369
+ for k, v in data.items():
370
+ if isinstance(v, dict) and "features" in v:
371
+ self.classes[k] = [v["features"]]
372
+ self.class_rois[k] = v.get("roi")
373
+ else:
374
+ self.classes[k] = v
375
+ else:
376
+ self.classes = {}
377
+ logger.info(f"Loaded {len(self.classes)} class(es).")
378
  except Exception as e:
379
+ logger.error(f"Data load failed: {e}")
380
+ self.classes = {}
381
 
382
+ def _persist_data(self) -> None:
383
  try:
384
  with open(self.TEMPLATE_FILE, "wb") as f:
385
+ pickle.dump(self.classes, f)
386
+ # Separately save ROIs if needed, but for now we just persist classes
387
+ # In a real app we'd save ROIs too. Let's include them in a combined dict.
388
+ with open("class_data.pkl", "wb") as f:
389
+ pickle.dump({"classes": self.classes, "rois": self.class_rois}, f)
390
  except Exception as e:
391
+ logger.error(f"Data save failed: {e}")
392
 
393
  # ── Layer 1: ROI Detection & Extraction ───────────────────────────────────
394
 
395
  @staticmethod
396
+ def detect_connect_and_crop(image_source: np.ndarray) -> tuple[np.ndarray, np.ndarray, str]:
397
+ """
398
+ 1. Detects bolt holes.
399
+ 2. Separates into Top and Bottom rows.
400
+ 3. Fits horizontal reference lines.
401
+ 4. Crops the FULL horizontal band between rows (includes regions between saddles).
402
+ """
403
  img_rgb = image_source
404
  img_h, img_w = img_rgb.shape[:2]
405
  gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
406
  gray = cv2.GaussianBlur(gray, (7, 7), 0)
407
+
408
+ # ── Step 1: Detect Circles ────────────────────────────────────────────
409
  circles = cv2.HoughCircles(
410
  gray, cv2.HOUGH_GRADIENT, dp=1.2, minDist=60,
411
  param1=100, param2=35, minRadius=12, maxRadius=45
412
  )
413
+
414
  if circles is None:
415
+ return img_rgb, img_rgb, "❌ No bolt holes detected."
416
+
417
  circles = np.round(circles[0]).astype(int)
418
+
419
+ # ── Step 2: Row Separation ────────────────────────────────────────────
420
  ys = sorted([c[1] for c in circles])
421
  y_median = np.median(ys)
422
+
423
  top_row = sorted([c for c in circles if c[1] < y_median], key=lambda x: x[0])
424
  bot_row = sorted([c for c in circles if c[1] >= y_median], key=lambda x: x[0])
425
+
426
  if len(top_row) < 2 or len(bot_row) < 2:
427
+ return img_rgb, img_rgb, "⚠️ Insufficient hole rows for localization."
428
 
429
+ # ── Step 3: Reference Lines ───────────────────────────────────────────
430
  y_top = int(np.mean([c[1] for c in top_row]))
431
  y_bot = int(np.mean([c[1] for c in bot_row]))
432
+
433
+ # Horizontal bounds (First hole to Last hole)
434
  xs = [c[0] for c in circles]
435
  x_min, x_max = min(xs), max(xs)
436
  padding_h = 60
437
  padding_v = 20
438
+
439
  x_start = max(0, x_min - padding_h)
440
  x_end = min(img_w, x_max + padding_h)
441
  y_start = max(0, min(y_top, y_bot) - padding_v)
442
  y_end = min(img_h, max(y_top, y_bot) + padding_v)
443
 
444
+ # ── Step 4: Visualization ─────────────────────────────────────────────
445
  vis_img = img_rgb.copy()
446
  LINE_COLOR = (0, 255, 0)
447
  HOLE_COLOR = (255, 0, 0)
448
+
449
+ # Draw lines and detected holes
450
  cv2.line(vis_img, (0, y_top), (img_w, y_top), LINE_COLOR, 3)
451
  cv2.line(vis_img, (0, y_bot), (img_w, y_bot), LINE_COLOR, 3)
452
+
453
  for (x, y, r) in circles:
454
  cv2.circle(vis_img, (x, y), r, HOLE_COLOR, 3)
455
  cv2.circle(vis_img, (x, y), 2, (255, 255, 255), -1)
456
 
457
+ # ── Step 5: Full Band Crop ────────────────────────────────────────────
458
  cropped_img = img_rgb[y_start:y_end, x_start:x_end]
459
+
460
  if cropped_img.size == 0:
461
+ return vis_img, img_rgb, "⚠️ ROI selection failed."
462
 
463
  stats_text = (
464
  f"βœ… **Full Saddle Band Extracted**\n"
 
468
  f"β€’ ROI Size: {cropped_img.shape[1]}x{cropped_img.shape[0]} px"
469
  )
470
 
471
+ return vis_img, cropped_img, stats_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
 
473
  @staticmethod
474
  def enhance_roi(roi: np.ndarray) -> np.ndarray:
475
  """Apply high-contrast CLAHE to highlight blurred lines/features."""
476
  if roi is None or roi.size == 0:
477
  return roi
478
+
479
+ # Convert to LAB space to apply CLAHE on L (luminance) channel
480
  lab = cv2.cvtColor(roi, cv2.COLOR_RGB2LAB)
481
  l, a, b = cv2.split(lab)
482
+
483
+ # ClipLimit 10.0 provides very high contrast as requested
484
+ clahe = cv2.createCLAHE(clipLimit=10.0, tileGridSize=(8, 8))
485
  cl = clahe.apply(l)
486
+
487
  merged = cv2.merge((cl, a, b))
488
  enhanced = cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
489
  return enhanced
 
499
 
500
  # ── Public API ────────────────────────────────────────────────────────────
501
 
502
+ def add_to_class(self, image: np.ndarray, class_name: str) -> tuple[str, np.ndarray | None]:
503
  if image is None:
504
  return "❌ No image supplied.", None
505
+ if not class_name or not class_name.strip():
506
+ return "❌ Class name is empty.", None
507
 
508
+ class_name = class_name.strip()
509
+
510
+ # Layer 1: Localization
511
+ vis, roi, log = self.detect_connect_and_crop(image)
512
  if "❌" in log or "⚠️" in log:
513
  return log, None
514
 
515
+ # Enhance ROI
516
  roi_enhanced = self.enhance_roi(roi)
517
+
518
+ # Layer 2: Feature Extraction
519
  features, _ = self.feature_extractor.extract(roi_enhanced)
520
+
521
+ if class_name not in self.classes:
522
+ self.classes[class_name] = []
523
+
524
+ self.classes[class_name].append(features)
525
+ self.class_rois[class_name] = roi_enhanced # Keep the latest ROI as reference
526
+
527
+ self._persist_data()
528
 
529
+ return f"βœ… Image added to class '{class_name}'! (Now has {len(self.classes[class_name])} samples)\n\n{log}", roi
530
 
531
  def match_part(
532
  self,
533
  image: np.ndarray,
534
  threshold: float = 0.70,
535
+ ) -> tuple[str, dict | None, np.ndarray | None, np.ndarray | None]:
 
 
 
 
536
  if image is None:
537
+ return "❌ No image supplied.", None, None, None
538
+ if not self.classes:
539
+ return "⚠️ No trained classes yet. Add samples to at least one class (e.g. 'Perfect').", None, None, None
540
 
541
+ # Layer 1: Localization
542
+ vis, roi, log = self.detect_connect_and_crop(image)
543
  if "❌" in log or "⚠️" in log:
544
+ return log, None, vis, None
545
 
546
+ # Enhance ROI
547
  roi_enhanced = self.enhance_roi(roi)
548
+
549
+ # Layer 2: Feature Extraction
550
  query_feat, attention_map = self.feature_extractor.extract(roi_enhanced)
551
 
552
+ # Layer 3: Latent Space Matching (Cosine Similarity to centroids)
553
+ class_scores = []
554
+ for name, vectors in self.classes.items():
555
+ # Calculate centroid (neighborhood center)
556
+ centroid = np.mean(vectors, axis=0)
557
+ sim = self._cosine(query_feat, centroid)
558
+ class_scores.append((name, sim))
559
+
560
+ class_scores.sort(key=lambda x: x[1], reverse=True)
561
+
562
+ best_class, best_score = class_scores[0]
563
+ matched = best_score >= threshold
564
+ status = f"βœ… CLASSIFIED AS: {best_class}" if matched else "❌ UNCERTAIN (below threshold)"
565
+
566
+ lines = [
567
+ f"{'βœ…' if matched else '❌'} **Top Prediction**: `{best_class}`",
568
+ f"πŸ“Š **Cosine Similarity**: {best_score:.2%}",
569
+ f"🎯 **Status**: {status}",
570
+ "",
571
+ "### πŸ” Multi-Stage Architecture Analysis",
572
+ "1. **Localization**: Bolt holes detected, horizontal band cropped.",
573
+ "2. **Feature Extraction**: ResNet50 extracted unique mathematical fingerprint.",
574
+ "3. **Matching**: Nearest cluster identified in latent space via Cosine Similarity.",
575
+ "",
576
+ "The heatmap on the right shows exactly where the AI is focusing.",
577
+ "- **Red Regions**: Areas defining the class (e.g., surface quality, edges).",
578
+ "",
579
+ "---",
 
580
  log,
581
  ]
582
+
583
+ if len(class_scores) > 1:
584
+ lines.append("\n**Class Probabilities (Latent Distance):**")
585
+ for name, sim in class_scores:
586
+ lines.append(f" β€’ `{name}`: {sim:.3f}")
587
 
588
+ label_dict = {name: float(sim) for name, sim in class_scores}
589
+
590
+ # Edge Map for structural analysis
591
+ gray_enhanced = cv2.cvtColor(roi_enhanced, cv2.COLOR_RGB2GRAY)
592
+ edges = cv2.Canny(gray_enhanced, 50, 150)
593
+ edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
594
+
595
+ return "\n".join(lines), label_dict, vis, attention_map, edges_rgb
596
 
597
+ def get_template_roi(self, class_name: str) -> np.ndarray | None:
598
+ return self.class_rois.get(class_name)
 
 
599
 
600
  def list_templates(self) -> str:
601
+ if not self.classes:
602
+ return "No classes trained yet."
603
+ header = f"Total: {len(self.classes)} class(es)\n" + "─" * 30
604
+ body = []
605
+ for name, vectors in sorted(self.classes.items()):
606
+ body.append(f" β€’ {name}: {len(vectors)} samples")
607
+ return f"{header}\n" + "\n".join(body)
608
 
609
  # ───────────────────────────────────────────────────────────────────────────────
610
  # Gradio Application
 
612
 
613
  detector = EnginePartDetector()
614
 
 
615
  def detect_part(image, threshold):
616
  return detector.match_part(image, threshold)
617
 
618
+ def add_sample(image, class_name):
619
+ return detector.add_to_class(image, class_name)
620
 
621
+ def list_classes():
 
 
 
 
622
  return detector.list_templates()
623
 
624
+ # Custom CSS for premium look
625
  custom_css = """
626
+ .container { max-width: 1200px; margin: auto; }
627
+ .header { text-align: center; margin-bottom: 2rem; }
628
+ .footer { text-align: center; margin-top: 2rem; color: #666; }
629
  """
630
 
631
  with gr.Blocks(title="Engine Part CV System", theme=gr.themes.Soft(), css=custom_css) as demo:
632
  gr.Markdown("""
633
  <div class="header">
634
+ <h1>πŸ”§ Engine Part CV System</h1>
635
+ <p><strong>Multi-Stage Architecture:</strong> Localization β†’ Feature Fingerprint (ResNet) β†’ Latent Space Matching</p>
 
 
 
 
636
  </div>
637
  """)
638
 
639
  with gr.Tab("πŸ” Match Inspection"):
640
  with gr.Row():
641
  with gr.Column(scale=1):
642
+ detect_input = gr.Image(sources=["upload", "webcam"], type="numpy", label="Input Image")
643
  threshold_slider = gr.Slider(0.5, 0.99, value=0.75, step=0.01, label="Matching Threshold")
644
+ detect_btn = gr.Button("πŸ” Run Inspection", variant="primary")
645
 
646
  with gr.Column(scale=1):
647
  detect_output = gr.Markdown(label="Match Report")
648
+ match_label = gr.Label(label="Top Scores", num_top_classes=5)
649
+ with gr.Row():
650
+ vis_output = gr.Image(label="Field Visualization")
651
+ attn_output = gr.Image(label="AI Attention Heatmap")
652
+ edge_output = gr.Image(label="Structural Edge Map (Line Detection)")
 
 
 
653
 
654
  detect_btn.click(
655
  fn=detect_part,
 
658
  api_name="detect_part",
659
  )
660
 
661
+ with gr.Tab("πŸ’Ύ Train Latent Space"):
662
  with gr.Row():
663
  with gr.Column(scale=1):
664
+ template_input = gr.Image(sources=["upload"], type="numpy", label="Training Image")
665
+ class_name_input = gr.Dropdown(
666
+ choices=["Perfect", "Defected", "Unknown"],
667
+ label="Class Label",
668
+ value="Perfect",
669
+ allow_custom_value=True
670
+ )
671
+ add_btn = gr.Button("πŸ’Ύ Add to Cluster", variant="primary")
672
  with gr.Column(scale=1):
673
+ add_status = gr.Textbox(label="Training Status", lines=5)
674
+ add_roi_view = gr.Image(label="Processed Training ROI", interactive=False)
675
 
676
  add_btn.click(
677
+ fn=add_sample,
678
+ inputs=[template_input, class_name_input],
679
  outputs=[add_status, add_roi_view],
680
+ api_name="add_sample",
681
  )
682
 
683
+ with gr.Tab("πŸ“‹ Class Library"):
684
  with gr.Row():
685
  with gr.Column(scale=1):
686
+ template_list = gr.Textbox(label="Current Trained Classes", lines=12)
687
+ refresh_btn = gr.Button("πŸ”„ Refresh Clusters")
688
  with gr.Column(scale=1):
689
+ library_roi_view = gr.Image(label="Last Reference ROI", interactive=False)
690
+
691
  def update_library_preview():
692
+ if detector.classes:
693
+ first_name = sorted(detector.classes.keys())[0]
694
  return detector.list_templates(), detector.get_template_roi(first_name)
695
+ return "No classes trained yet.", None
696
 
697
  refresh_btn.click(fn=update_library_preview, outputs=[template_list, library_roi_view])
698
  demo.load(fn=update_library_preview, outputs=[template_list, library_roi_view])
699
 
700
+ gr.Markdown("""
701
+ ---
702
+ <div class="footer">
703
+ <p>Engine Part CV System β€’ Powered by PyTorch & OpenCV</p>
704
+ </div>
705
+ """)
706
+
707
  if __name__ == "__main__":
708
+ demo.launch(share=False, show_error=True)