classifucation
Browse files
app.py
CHANGED
|
@@ -272,6 +272,7 @@ logger = logging.getLogger(__name__)
|
|
| 272 |
|
| 273 |
class FeatureExtractor:
|
| 274 |
def __init__(self):
|
|
|
|
| 275 |
backbone = models.resnet50(weights="IMAGENET1K_V1")
|
| 276 |
self.model = nn.Sequential(*list(backbone.children())[:-1])
|
| 277 |
self.model.eval()
|
|
@@ -291,15 +292,24 @@ class FeatureExtractor:
|
|
| 291 |
rgb = np.array(rgb.convert("RGB"))
|
| 292 |
if rgb.dtype != np.uint8:
|
| 293 |
rgb = rgb.astype(np.uint8)
|
|
|
|
| 294 |
if len(rgb.shape) == 2:
|
| 295 |
rgb = cv2.cvtColor(rgb, cv2.COLOR_GRAY2RGB)
|
| 296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
input_tensor = self.transform(Image.fromarray(rgb)).unsqueeze(0)
|
| 298 |
-
|
|
|
|
| 299 |
with torch.no_grad():
|
|
|
|
|
|
|
| 300 |
backbone = models.resnet50(weights="IMAGENET1K_V1")
|
| 301 |
backbone.eval()
|
| 302 |
-
|
| 303 |
x = backbone.conv1(input_tensor)
|
| 304 |
x = backbone.bn1(x)
|
| 305 |
x = backbone.relu(x)
|
|
@@ -307,23 +317,27 @@ class FeatureExtractor:
|
|
| 307 |
x = backbone.layer1(x)
|
| 308 |
x = backbone.layer2(x)
|
| 309 |
x = backbone.layer3(x)
|
| 310 |
-
features_spatial = backbone.layer4(x)
|
| 311 |
-
|
|
|
|
| 312 |
feat = torch.mean(features_spatial, dim=[2, 3]).squeeze().cpu().numpy()
|
| 313 |
-
|
|
|
|
| 314 |
amap = torch.sum(features_spatial, dim=1).squeeze().cpu().numpy()
|
| 315 |
amap = np.maximum(amap, 0)
|
| 316 |
amap /= (np.max(amap) + 1e-8)
|
| 317 |
amap = cv2.resize(amap, (rgb.shape[1], rgb.shape[0]))
|
| 318 |
amap = np.uint8(255 * amap)
|
| 319 |
heatmap = cv2.applyColorMap(amap, cv2.COLORMAP_JET)
|
|
|
|
|
|
|
|
|
|
| 320 |
heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
|
| 321 |
overlay = cv2.addWeighted(rgb, 0.6, heatmap_rgb, 0.4, 0)
|
| 322 |
|
| 323 |
norm = np.linalg.norm(feat)
|
| 324 |
return (feat / norm if norm > 1e-8 else feat), overlay
|
| 325 |
|
| 326 |
-
|
| 327 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 328 |
# MASTER ORCHESTRATOR β EnginePartDetector
|
| 329 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -333,87 +347,118 @@ class EnginePartDetector:
|
|
| 333 |
|
| 334 |
def __init__(self):
|
| 335 |
self.feature_extractor = FeatureExtractor()
|
| 336 |
-
|
| 337 |
-
self.
|
|
|
|
|
|
|
|
|
|
| 338 |
|
| 339 |
# ββ Persistence βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 340 |
|
| 341 |
-
def
|
| 342 |
if os.path.exists(self.TEMPLATE_FILE):
|
| 343 |
try:
|
| 344 |
with open(self.TEMPLATE_FILE, "rb") as f:
|
| 345 |
-
|
| 346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
except Exception as e:
|
| 348 |
-
logger.error(f"
|
| 349 |
-
self.
|
| 350 |
|
| 351 |
-
def
|
| 352 |
try:
|
| 353 |
with open(self.TEMPLATE_FILE, "wb") as f:
|
| 354 |
-
pickle.dump(self.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
except Exception as e:
|
| 356 |
-
logger.error(f"
|
| 357 |
|
| 358 |
# ββ Layer 1: ROI Detection & Extraction βββββββββββββββββββββββββββββββββββ
|
| 359 |
|
| 360 |
@staticmethod
|
| 361 |
-
def detect_connect_and_crop(
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
|
|
|
|
|
|
|
|
|
| 365 |
img_rgb = image_source
|
| 366 |
img_h, img_w = img_rgb.shape[:2]
|
| 367 |
gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
|
| 368 |
gray = cv2.GaussianBlur(gray, (7, 7), 0)
|
| 369 |
-
|
|
|
|
| 370 |
circles = cv2.HoughCircles(
|
| 371 |
gray, cv2.HOUGH_GRADIENT, dp=1.2, minDist=60,
|
| 372 |
param1=100, param2=35, minRadius=12, maxRadius=45
|
| 373 |
)
|
| 374 |
-
|
| 375 |
if circles is None:
|
| 376 |
-
return img_rgb, img_rgb, "β No bolt holes detected."
|
| 377 |
-
|
| 378 |
circles = np.round(circles[0]).astype(int)
|
| 379 |
-
|
|
|
|
| 380 |
ys = sorted([c[1] for c in circles])
|
| 381 |
y_median = np.median(ys)
|
| 382 |
-
|
| 383 |
top_row = sorted([c for c in circles if c[1] < y_median], key=lambda x: x[0])
|
| 384 |
bot_row = sorted([c for c in circles if c[1] >= y_median], key=lambda x: x[0])
|
| 385 |
-
|
| 386 |
if len(top_row) < 2 or len(bot_row) < 2:
|
| 387 |
-
return img_rgb, img_rgb, "β οΈ Insufficient hole rows for localization."
|
| 388 |
|
|
|
|
| 389 |
y_top = int(np.mean([c[1] for c in top_row]))
|
| 390 |
y_bot = int(np.mean([c[1] for c in bot_row]))
|
| 391 |
-
|
|
|
|
| 392 |
xs = [c[0] for c in circles]
|
| 393 |
x_min, x_max = min(xs), max(xs)
|
| 394 |
padding_h = 60
|
| 395 |
padding_v = 20
|
| 396 |
-
|
| 397 |
x_start = max(0, x_min - padding_h)
|
| 398 |
x_end = min(img_w, x_max + padding_h)
|
| 399 |
y_start = max(0, min(y_top, y_bot) - padding_v)
|
| 400 |
y_end = min(img_h, max(y_top, y_bot) + padding_v)
|
| 401 |
|
|
|
|
| 402 |
vis_img = img_rgb.copy()
|
| 403 |
LINE_COLOR = (0, 255, 0)
|
| 404 |
HOLE_COLOR = (255, 0, 0)
|
| 405 |
-
|
|
|
|
| 406 |
cv2.line(vis_img, (0, y_top), (img_w, y_top), LINE_COLOR, 3)
|
| 407 |
cv2.line(vis_img, (0, y_bot), (img_w, y_bot), LINE_COLOR, 3)
|
| 408 |
-
|
| 409 |
for (x, y, r) in circles:
|
| 410 |
cv2.circle(vis_img, (x, y), r, HOLE_COLOR, 3)
|
| 411 |
cv2.circle(vis_img, (x, y), 2, (255, 255, 255), -1)
|
| 412 |
|
|
|
|
| 413 |
cropped_img = img_rgb[y_start:y_end, x_start:x_end]
|
| 414 |
-
|
| 415 |
if cropped_img.size == 0:
|
| 416 |
-
return vis_img, img_rgb, "β οΈ ROI selection failed."
|
| 417 |
|
| 418 |
stats_text = (
|
| 419 |
f"β
**Full Saddle Band Extracted**\n"
|
|
@@ -423,72 +468,22 @@ class EnginePartDetector:
|
|
| 423 |
f"β’ ROI Size: {cropped_img.shape[1]}x{cropped_img.shape[0]} px"
|
| 424 |
)
|
| 425 |
|
| 426 |
-
return vis_img, cropped_img, stats_text
|
| 427 |
-
|
| 428 |
-
# βοΏ½οΏ½ Vertical-line detection on structural edge map βββββββββββββββββββββββ
|
| 429 |
-
|
| 430 |
-
@staticmethod
|
| 431 |
-
def detect_vertical_lines_on_edge_map(
|
| 432 |
-
roi_enhanced: np.ndarray,
|
| 433 |
-
angle_tolerance_deg: float = 12.0,
|
| 434 |
-
min_line_length_ratio: float = 0.15,
|
| 435 |
-
) -> tuple[bool, np.ndarray, str]:
|
| 436 |
-
|
| 437 |
-
gray = cv2.cvtColor(roi_enhanced, cv2.COLOR_RGB2GRAY)
|
| 438 |
-
clahe = cv2.createCLAHE(clipLimit=2.8, tileGridSize=(8, 8))
|
| 439 |
-
gray = clahe.apply(gray)
|
| 440 |
-
edges = cv2.Canny(gray, 50, 150)
|
| 441 |
-
|
| 442 |
-
h, w = edges.shape
|
| 443 |
-
min_len = max(20, int(h * min_line_length_ratio))
|
| 444 |
-
|
| 445 |
-
lines = cv2.HoughLinesP(
|
| 446 |
-
edges, rho=1, theta=np.pi / 180,
|
| 447 |
-
threshold=20, minLineLength=min_len, maxLineGap=10,
|
| 448 |
-
)
|
| 449 |
-
|
| 450 |
-
# RGB canvas from edge map
|
| 451 |
-
canvas = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
|
| 452 |
-
|
| 453 |
-
vertical_lines = []
|
| 454 |
-
if lines is not None:
|
| 455 |
-
for seg in lines:
|
| 456 |
-
x1, y1, x2, y2 = seg[0]
|
| 457 |
-
dx = abs(x2 - x1)
|
| 458 |
-
dy = abs(y2 - y1)
|
| 459 |
-
angle = np.degrees(np.arctan2(dx, dy + 1e-6))
|
| 460 |
-
if angle <= angle_tolerance_deg:
|
| 461 |
-
vertical_lines.append((x1, y1, x2, y2, dy))
|
| 462 |
-
|
| 463 |
-
# Sort by length β longest first
|
| 464 |
-
vertical_lines.sort(key=lambda v: v[4], reverse=True)
|
| 465 |
-
has_vertical = len(vertical_lines) > 0
|
| 466 |
-
|
| 467 |
-
if has_vertical:
|
| 468 |
-
for (x1, y1, x2, y2, _) in vertical_lines:
|
| 469 |
-
cv2.line(canvas, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
| 470 |
-
cv2.rectangle(canvas, (0, 0), (240, 46), (0, 150, 0), -1)
|
| 471 |
-
cv2.putText(canvas, f"PRESENT ({len(vertical_lines)})",
|
| 472 |
-
(6, 34), cv2.FONT_HERSHEY_DUPLEX, 0.85, (255, 255, 255), 2)
|
| 473 |
-
status = (f"β
**Vertical lines PRESENT** β "
|
| 474 |
-
f"{len(vertical_lines)} near-vertical line(s) detected.")
|
| 475 |
-
else:
|
| 476 |
-
cv2.rectangle(canvas, (0, 0), (190, 46), (180, 0, 0), -1)
|
| 477 |
-
cv2.putText(canvas, "ABSENT",
|
| 478 |
-
(6, 34), cv2.FONT_HERSHEY_DUPLEX, 1.1, (255, 255, 255), 2)
|
| 479 |
-
status = "β **Vertical lines ABSENT** β No near-vertical lines on edge map."
|
| 480 |
-
|
| 481 |
-
return has_vertical, canvas, status
|
| 482 |
|
| 483 |
@staticmethod
|
| 484 |
def enhance_roi(roi: np.ndarray) -> np.ndarray:
|
| 485 |
"""Apply high-contrast CLAHE to highlight blurred lines/features."""
|
| 486 |
if roi is None or roi.size == 0:
|
| 487 |
return roi
|
|
|
|
|
|
|
| 488 |
lab = cv2.cvtColor(roi, cv2.COLOR_RGB2LAB)
|
| 489 |
l, a, b = cv2.split(lab)
|
| 490 |
-
|
|
|
|
|
|
|
| 491 |
cl = clahe.apply(l)
|
|
|
|
| 492 |
merged = cv2.merge((cl, a, b))
|
| 493 |
enhanced = cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
|
| 494 |
return enhanced
|
|
@@ -504,104 +499,112 @@ class EnginePartDetector:
|
|
| 504 |
|
| 505 |
# ββ Public API ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 506 |
|
| 507 |
-
def
|
| 508 |
if image is None:
|
| 509 |
return "β No image supplied.", None
|
| 510 |
-
if not
|
| 511 |
-
return "β
|
| 512 |
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
|
|
|
| 516 |
if "β" in log or "β οΈ" in log:
|
| 517 |
return log, None
|
| 518 |
|
|
|
|
| 519 |
roi_enhanced = self.enhance_roi(roi)
|
|
|
|
|
|
|
| 520 |
features, _ = self.feature_extractor.extract(roi_enhanced)
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
self.
|
|
|
|
|
|
|
|
|
|
| 526 |
|
| 527 |
-
return f"β
|
| 528 |
|
| 529 |
def match_part(
|
| 530 |
self,
|
| 531 |
image: np.ndarray,
|
| 532 |
threshold: float = 0.70,
|
| 533 |
-
) -> tuple[str, dict | None, np.ndarray | None, np.ndarray | None
|
| 534 |
-
"""
|
| 535 |
-
Returns:
|
| 536 |
-
report_text, label_dict, field_vis, attention_map, annotated_edge_map
|
| 537 |
-
"""
|
| 538 |
if image is None:
|
| 539 |
-
return "β No image supplied.", None, None, None
|
| 540 |
-
if not self.
|
| 541 |
-
return "β οΈ No
|
| 542 |
|
| 543 |
-
#
|
| 544 |
-
vis, roi, log
|
| 545 |
if "β" in log or "β οΈ" in log:
|
| 546 |
-
return log, None, vis, None
|
| 547 |
|
| 548 |
-
#
|
| 549 |
roi_enhanced = self.enhance_roi(roi)
|
|
|
|
|
|
|
| 550 |
query_feat, attention_map = self.feature_extractor.extract(roi_enhanced)
|
| 551 |
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
f"
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
f"### πΈ Field Detection",
|
| 581 |
log,
|
| 582 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 583 |
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
return "\n".join(
|
| 592 |
|
| 593 |
-
def get_template_roi(self,
|
| 594 |
-
|
| 595 |
-
return self.templates[part_name].get("roi")
|
| 596 |
-
return None
|
| 597 |
|
| 598 |
def list_templates(self) -> str:
|
| 599 |
-
if not self.
|
| 600 |
-
return "No
|
| 601 |
-
header = f"Total: {len(self.
|
| 602 |
-
body
|
| 603 |
-
|
| 604 |
-
|
|
|
|
| 605 |
|
| 606 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 607 |
# Gradio Application
|
|
@@ -609,54 +612,44 @@ class EnginePartDetector:
|
|
| 609 |
|
| 610 |
detector = EnginePartDetector()
|
| 611 |
|
| 612 |
-
|
| 613 |
def detect_part(image, threshold):
|
| 614 |
return detector.match_part(image, threshold)
|
| 615 |
|
|
|
|
|
|
|
| 616 |
|
| 617 |
-
def
|
| 618 |
-
return detector.save_template(image, part_name)
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
def list_templates():
|
| 622 |
return detector.list_templates()
|
| 623 |
|
| 624 |
-
|
| 625 |
custom_css = """
|
| 626 |
-
.container { max-width:
|
| 627 |
-
.header
|
| 628 |
-
.footer
|
| 629 |
"""
|
| 630 |
|
| 631 |
with gr.Blocks(title="Engine Part CV System", theme=gr.themes.Soft(), css=custom_css) as demo:
|
| 632 |
gr.Markdown("""
|
| 633 |
<div class="header">
|
| 634 |
-
<h1>π§ Engine Part
|
| 635 |
-
<p>
|
| 636 |
-
<strong>Layer 1:</strong> Hough Bolt-Hole Detection & Crop |
|
| 637 |
-
<strong>Layer 2:</strong> ResNet50 Feature Matching |
|
| 638 |
-
<strong>Edge Map:</strong> Vertical-Line Detection
|
| 639 |
-
</p>
|
| 640 |
</div>
|
| 641 |
""")
|
| 642 |
|
| 643 |
with gr.Tab("π Match Inspection"):
|
| 644 |
with gr.Row():
|
| 645 |
with gr.Column(scale=1):
|
| 646 |
-
detect_input
|
| 647 |
threshold_slider = gr.Slider(0.5, 0.99, value=0.75, step=0.01, label="Matching Threshold")
|
| 648 |
-
detect_btn
|
| 649 |
|
| 650 |
with gr.Column(scale=1):
|
| 651 |
detect_output = gr.Markdown(label="Match Report")
|
| 652 |
-
match_label
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
with gr.Row():
|
| 659 |
-
edge_output = gr.Image(label="Structural Edge Map (green lines = vertical PRESENT | red banner = ABSENT)")
|
| 660 |
|
| 661 |
detect_btn.click(
|
| 662 |
fn=detect_part,
|
|
@@ -665,39 +658,51 @@ with gr.Blocks(title="Engine Part CV System", theme=gr.themes.Soft(), css=custom
|
|
| 665 |
api_name="detect_part",
|
| 666 |
)
|
| 667 |
|
| 668 |
-
with gr.Tab("πΎ
|
| 669 |
with gr.Row():
|
| 670 |
with gr.Column(scale=1):
|
| 671 |
-
template_input
|
| 672 |
-
|
| 673 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 674 |
with gr.Column(scale=1):
|
| 675 |
-
add_status
|
| 676 |
-
add_roi_view = gr.Image(label="
|
| 677 |
|
| 678 |
add_btn.click(
|
| 679 |
-
fn=
|
| 680 |
-
inputs=[template_input,
|
| 681 |
outputs=[add_status, add_roi_view],
|
| 682 |
-
api_name="
|
| 683 |
)
|
| 684 |
|
| 685 |
-
with gr.Tab("π Library"):
|
| 686 |
with gr.Row():
|
| 687 |
with gr.Column(scale=1):
|
| 688 |
-
template_list = gr.Textbox(label="Current
|
| 689 |
-
refresh_btn
|
| 690 |
with gr.Column(scale=1):
|
| 691 |
-
library_roi_view = gr.Image(label="
|
| 692 |
-
|
| 693 |
def update_library_preview():
|
| 694 |
-
if detector.
|
| 695 |
-
first_name = sorted(detector.
|
| 696 |
return detector.list_templates(), detector.get_template_roi(first_name)
|
| 697 |
-
return "No
|
| 698 |
|
| 699 |
refresh_btn.click(fn=update_library_preview, outputs=[template_list, library_roi_view])
|
| 700 |
demo.load(fn=update_library_preview, outputs=[template_list, library_roi_view])
|
| 701 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 702 |
if __name__ == "__main__":
|
| 703 |
-
demo.launch(share=False, show_error=True)
|
|
|
|
| 272 |
|
| 273 |
class FeatureExtractor:
|
| 274 |
def __init__(self):
|
| 275 |
+
# Using ResNet50 for 2048-D feature vectors
|
| 276 |
backbone = models.resnet50(weights="IMAGENET1K_V1")
|
| 277 |
self.model = nn.Sequential(*list(backbone.children())[:-1])
|
| 278 |
self.model.eval()
|
|
|
|
| 292 |
rgb = np.array(rgb.convert("RGB"))
|
| 293 |
if rgb.dtype != np.uint8:
|
| 294 |
rgb = rgb.astype(np.uint8)
|
| 295 |
+
|
| 296 |
if len(rgb.shape) == 2:
|
| 297 |
rgb = cv2.cvtColor(rgb, cv2.COLOR_GRAY2RGB)
|
| 298 |
+
|
| 299 |
+
# We want the layer BEFORE the global pooling to get spatial info
|
| 300 |
+
# resnet.layer4 is the last block
|
| 301 |
+
# self.model is nn.Sequential(*list(backbone.children())[:-1])
|
| 302 |
+
# children()[:-1] = [conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4]
|
| 303 |
+
|
| 304 |
input_tensor = self.transform(Image.fromarray(rgb)).unsqueeze(0)
|
| 305 |
+
|
| 306 |
+
# Get activations from the last conv layer (Layer 4)
|
| 307 |
with torch.no_grad():
|
| 308 |
+
# Run through the layers up to global pooling
|
| 309 |
+
# Using the original backbone for Easier Access to sub-layers
|
| 310 |
backbone = models.resnet50(weights="IMAGENET1K_V1")
|
| 311 |
backbone.eval()
|
| 312 |
+
|
| 313 |
x = backbone.conv1(input_tensor)
|
| 314 |
x = backbone.bn1(x)
|
| 315 |
x = backbone.relu(x)
|
|
|
|
| 317 |
x = backbone.layer1(x)
|
| 318 |
x = backbone.layer2(x)
|
| 319 |
x = backbone.layer3(x)
|
| 320 |
+
features_spatial = backbone.layer4(x) # [1, 2048, 7, 7]
|
| 321 |
+
|
| 322 |
+
# Global Average Pooling to get the vector
|
| 323 |
feat = torch.mean(features_spatial, dim=[2, 3]).squeeze().cpu().numpy()
|
| 324 |
+
|
| 325 |
+
# Create Heatmap: sum across channels to see "hot" regions
|
| 326 |
amap = torch.sum(features_spatial, dim=1).squeeze().cpu().numpy()
|
| 327 |
amap = np.maximum(amap, 0)
|
| 328 |
amap /= (np.max(amap) + 1e-8)
|
| 329 |
amap = cv2.resize(amap, (rgb.shape[1], rgb.shape[0]))
|
| 330 |
amap = np.uint8(255 * amap)
|
| 331 |
heatmap = cv2.applyColorMap(amap, cv2.COLORMAP_JET)
|
| 332 |
+
|
| 333 |
+
# Overlay heatmap on original image
|
| 334 |
+
# Convert BGR heatmap to RGB
|
| 335 |
heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
|
| 336 |
overlay = cv2.addWeighted(rgb, 0.6, heatmap_rgb, 0.4, 0)
|
| 337 |
|
| 338 |
norm = np.linalg.norm(feat)
|
| 339 |
return (feat / norm if norm > 1e-8 else feat), overlay
|
| 340 |
|
|
|
|
| 341 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 342 |
# MASTER ORCHESTRATOR β EnginePartDetector
|
| 343 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 347 |
|
| 348 |
def __init__(self):
|
| 349 |
self.feature_extractor = FeatureExtractor()
|
| 350 |
+
# Changed from simple templates to class-based feature lists
|
| 351 |
+
self.classes: dict[str, list[np.ndarray]] = {}
|
| 352 |
+
# We also store an example ROI for each class for visualization
|
| 353 |
+
self.class_rois: dict[str, np.ndarray] = {}
|
| 354 |
+
self._load_data()
|
| 355 |
|
| 356 |
# ββ Persistence βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 357 |
|
| 358 |
+
def _load_data(self) -> None:
|
| 359 |
if os.path.exists(self.TEMPLATE_FILE):
|
| 360 |
try:
|
| 361 |
with open(self.TEMPLATE_FILE, "rb") as f:
|
| 362 |
+
data = pickle.load(f)
|
| 363 |
+
# Support legacy format if needed, but here we assume the new format
|
| 364 |
+
if isinstance(data, dict):
|
| 365 |
+
# If old format was {name: {"features": feat, "roi": roi}}
|
| 366 |
+
# we convert it to {name: [feat]}
|
| 367 |
+
self.classes = {}
|
| 368 |
+
self.class_rois = {}
|
| 369 |
+
for k, v in data.items():
|
| 370 |
+
if isinstance(v, dict) and "features" in v:
|
| 371 |
+
self.classes[k] = [v["features"]]
|
| 372 |
+
self.class_rois[k] = v.get("roi")
|
| 373 |
+
else:
|
| 374 |
+
self.classes[k] = v
|
| 375 |
+
else:
|
| 376 |
+
self.classes = {}
|
| 377 |
+
logger.info(f"Loaded {len(self.classes)} class(es).")
|
| 378 |
except Exception as e:
|
| 379 |
+
logger.error(f"Data load failed: {e}")
|
| 380 |
+
self.classes = {}
|
| 381 |
|
| 382 |
+
def _persist_data(self) -> None:
|
| 383 |
try:
|
| 384 |
with open(self.TEMPLATE_FILE, "wb") as f:
|
| 385 |
+
pickle.dump(self.classes, f)
|
| 386 |
+
# Separately save ROIs if needed, but for now we just persist classes
|
| 387 |
+
# In a real app we'd save ROIs too. Let's include them in a combined dict.
|
| 388 |
+
with open("class_data.pkl", "wb") as f:
|
| 389 |
+
pickle.dump({"classes": self.classes, "rois": self.class_rois}, f)
|
| 390 |
except Exception as e:
|
| 391 |
+
logger.error(f"Data save failed: {e}")
|
| 392 |
|
| 393 |
# ββ Layer 1: ROI Detection & Extraction βββββββββββββββββββββββββββββββββββ
|
| 394 |
|
| 395 |
@staticmethod
|
| 396 |
+
def detect_connect_and_crop(image_source: np.ndarray) -> tuple[np.ndarray, np.ndarray, str]:
|
| 397 |
+
"""
|
| 398 |
+
1. Detects bolt holes.
|
| 399 |
+
2. Separates into Top and Bottom rows.
|
| 400 |
+
3. Fits horizontal reference lines.
|
| 401 |
+
4. Crops the FULL horizontal band between rows (includes regions between saddles).
|
| 402 |
+
"""
|
| 403 |
img_rgb = image_source
|
| 404 |
img_h, img_w = img_rgb.shape[:2]
|
| 405 |
gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
|
| 406 |
gray = cv2.GaussianBlur(gray, (7, 7), 0)
|
| 407 |
+
|
| 408 |
+
# ββ Step 1: Detect Circles ββββββββββββββββββββββββββββββββββββββββββββ
|
| 409 |
circles = cv2.HoughCircles(
|
| 410 |
gray, cv2.HOUGH_GRADIENT, dp=1.2, minDist=60,
|
| 411 |
param1=100, param2=35, minRadius=12, maxRadius=45
|
| 412 |
)
|
| 413 |
+
|
| 414 |
if circles is None:
|
| 415 |
+
return img_rgb, img_rgb, "β No bolt holes detected."
|
| 416 |
+
|
| 417 |
circles = np.round(circles[0]).astype(int)
|
| 418 |
+
|
| 419 |
+
# ββ Step 2: Row Separation ββββββββββββββββββββββββββββββββββββββββββββ
|
| 420 |
ys = sorted([c[1] for c in circles])
|
| 421 |
y_median = np.median(ys)
|
| 422 |
+
|
| 423 |
top_row = sorted([c for c in circles if c[1] < y_median], key=lambda x: x[0])
|
| 424 |
bot_row = sorted([c for c in circles if c[1] >= y_median], key=lambda x: x[0])
|
| 425 |
+
|
| 426 |
if len(top_row) < 2 or len(bot_row) < 2:
|
| 427 |
+
return img_rgb, img_rgb, "β οΈ Insufficient hole rows for localization."
|
| 428 |
|
| 429 |
+
# ββ Step 3: Reference Lines βββββββββββββββββββββββββββββββββββββββββββ
|
| 430 |
y_top = int(np.mean([c[1] for c in top_row]))
|
| 431 |
y_bot = int(np.mean([c[1] for c in bot_row]))
|
| 432 |
+
|
| 433 |
+
# Horizontal bounds (First hole to Last hole)
|
| 434 |
xs = [c[0] for c in circles]
|
| 435 |
x_min, x_max = min(xs), max(xs)
|
| 436 |
padding_h = 60
|
| 437 |
padding_v = 20
|
| 438 |
+
|
| 439 |
x_start = max(0, x_min - padding_h)
|
| 440 |
x_end = min(img_w, x_max + padding_h)
|
| 441 |
y_start = max(0, min(y_top, y_bot) - padding_v)
|
| 442 |
y_end = min(img_h, max(y_top, y_bot) + padding_v)
|
| 443 |
|
| 444 |
+
# ββ Step 4: Visualization βββββββββββββββββββββββββββββββββββββββββββββ
|
| 445 |
vis_img = img_rgb.copy()
|
| 446 |
LINE_COLOR = (0, 255, 0)
|
| 447 |
HOLE_COLOR = (255, 0, 0)
|
| 448 |
+
|
| 449 |
+
# Draw lines and detected holes
|
| 450 |
cv2.line(vis_img, (0, y_top), (img_w, y_top), LINE_COLOR, 3)
|
| 451 |
cv2.line(vis_img, (0, y_bot), (img_w, y_bot), LINE_COLOR, 3)
|
| 452 |
+
|
| 453 |
for (x, y, r) in circles:
|
| 454 |
cv2.circle(vis_img, (x, y), r, HOLE_COLOR, 3)
|
| 455 |
cv2.circle(vis_img, (x, y), 2, (255, 255, 255), -1)
|
| 456 |
|
| 457 |
+
# ββ Step 5: Full Band Crop ββββββββββββββββββββββββββββββββββββββββββββ
|
| 458 |
cropped_img = img_rgb[y_start:y_end, x_start:x_end]
|
| 459 |
+
|
| 460 |
if cropped_img.size == 0:
|
| 461 |
+
return vis_img, img_rgb, "β οΈ ROI selection failed."
|
| 462 |
|
| 463 |
stats_text = (
|
| 464 |
f"β
**Full Saddle Band Extracted**\n"
|
|
|
|
| 468 |
f"β’ ROI Size: {cropped_img.shape[1]}x{cropped_img.shape[0]} px"
|
| 469 |
)
|
| 470 |
|
| 471 |
+
return vis_img, cropped_img, stats_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
|
| 473 |
@staticmethod
|
| 474 |
def enhance_roi(roi: np.ndarray) -> np.ndarray:
|
| 475 |
"""Apply high-contrast CLAHE to highlight blurred lines/features."""
|
| 476 |
if roi is None or roi.size == 0:
|
| 477 |
return roi
|
| 478 |
+
|
| 479 |
+
# Convert to LAB space to apply CLAHE on L (luminance) channel
|
| 480 |
lab = cv2.cvtColor(roi, cv2.COLOR_RGB2LAB)
|
| 481 |
l, a, b = cv2.split(lab)
|
| 482 |
+
|
| 483 |
+
# ClipLimit 10.0 provides very high contrast as requested
|
| 484 |
+
clahe = cv2.createCLAHE(clipLimit=10.0, tileGridSize=(8, 8))
|
| 485 |
cl = clahe.apply(l)
|
| 486 |
+
|
| 487 |
merged = cv2.merge((cl, a, b))
|
| 488 |
enhanced = cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
|
| 489 |
return enhanced
|
|
|
|
| 499 |
|
| 500 |
# ββ Public API ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 501 |
|
| 502 |
+
def add_to_class(self, image: np.ndarray, class_name: str) -> tuple[str, np.ndarray | None]:
|
| 503 |
if image is None:
|
| 504 |
return "β No image supplied.", None
|
| 505 |
+
if not class_name or not class_name.strip():
|
| 506 |
+
return "β Class name is empty.", None
|
| 507 |
|
| 508 |
+
class_name = class_name.strip()
|
| 509 |
+
|
| 510 |
+
# Layer 1: Localization
|
| 511 |
+
vis, roi, log = self.detect_connect_and_crop(image)
|
| 512 |
if "β" in log or "β οΈ" in log:
|
| 513 |
return log, None
|
| 514 |
|
| 515 |
+
# Enhance ROI
|
| 516 |
roi_enhanced = self.enhance_roi(roi)
|
| 517 |
+
|
| 518 |
+
# Layer 2: Feature Extraction
|
| 519 |
features, _ = self.feature_extractor.extract(roi_enhanced)
|
| 520 |
+
|
| 521 |
+
if class_name not in self.classes:
|
| 522 |
+
self.classes[class_name] = []
|
| 523 |
+
|
| 524 |
+
self.classes[class_name].append(features)
|
| 525 |
+
self.class_rois[class_name] = roi_enhanced # Keep the latest ROI as reference
|
| 526 |
+
|
| 527 |
+
self._persist_data()
|
| 528 |
|
| 529 |
+
return f"β
Image added to class '{class_name}'! (Now has {len(self.classes[class_name])} samples)\n\n{log}", roi
|
| 530 |
|
| 531 |
def match_part(
|
| 532 |
self,
|
| 533 |
image: np.ndarray,
|
| 534 |
threshold: float = 0.70,
|
| 535 |
+
) -> tuple[str, dict | None, np.ndarray | None, np.ndarray | None]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
if image is None:
|
| 537 |
+
return "β No image supplied.", None, None, None
|
| 538 |
+
if not self.classes:
|
| 539 |
+
return "β οΈ No trained classes yet. Add samples to at least one class (e.g. 'Perfect').", None, None, None
|
| 540 |
|
| 541 |
+
# Layer 1: Localization
|
| 542 |
+
vis, roi, log = self.detect_connect_and_crop(image)
|
| 543 |
if "β" in log or "β οΈ" in log:
|
| 544 |
+
return log, None, vis, None
|
| 545 |
|
| 546 |
+
# Enhance ROI
|
| 547 |
roi_enhanced = self.enhance_roi(roi)
|
| 548 |
+
|
| 549 |
+
# Layer 2: Feature Extraction
|
| 550 |
query_feat, attention_map = self.feature_extractor.extract(roi_enhanced)
|
| 551 |
|
| 552 |
+
# Layer 3: Latent Space Matching (Cosine Similarity to centroids)
|
| 553 |
+
class_scores = []
|
| 554 |
+
for name, vectors in self.classes.items():
|
| 555 |
+
# Calculate centroid (neighborhood center)
|
| 556 |
+
centroid = np.mean(vectors, axis=0)
|
| 557 |
+
sim = self._cosine(query_feat, centroid)
|
| 558 |
+
class_scores.append((name, sim))
|
| 559 |
+
|
| 560 |
+
class_scores.sort(key=lambda x: x[1], reverse=True)
|
| 561 |
+
|
| 562 |
+
best_class, best_score = class_scores[0]
|
| 563 |
+
matched = best_score >= threshold
|
| 564 |
+
status = f"β
CLASSIFIED AS: {best_class}" if matched else "β UNCERTAIN (below threshold)"
|
| 565 |
+
|
| 566 |
+
lines = [
|
| 567 |
+
f"{'β
' if matched else 'β'} **Top Prediction**: `{best_class}`",
|
| 568 |
+
f"π **Cosine Similarity**: {best_score:.2%}",
|
| 569 |
+
f"π― **Status**: {status}",
|
| 570 |
+
"",
|
| 571 |
+
"### π Multi-Stage Architecture Analysis",
|
| 572 |
+
"1. **Localization**: Bolt holes detected, horizontal band cropped.",
|
| 573 |
+
"2. **Feature Extraction**: ResNet50 extracted unique mathematical fingerprint.",
|
| 574 |
+
"3. **Matching**: Nearest cluster identified in latent space via Cosine Similarity.",
|
| 575 |
+
"",
|
| 576 |
+
"The heatmap on the right shows exactly where the AI is focusing.",
|
| 577 |
+
"- **Red Regions**: Areas defining the class (e.g., surface quality, edges).",
|
| 578 |
+
"",
|
| 579 |
+
"---",
|
|
|
|
| 580 |
log,
|
| 581 |
]
|
| 582 |
+
|
| 583 |
+
if len(class_scores) > 1:
|
| 584 |
+
lines.append("\n**Class Probabilities (Latent Distance):**")
|
| 585 |
+
for name, sim in class_scores:
|
| 586 |
+
lines.append(f" β’ `{name}`: {sim:.3f}")
|
| 587 |
|
| 588 |
+
label_dict = {name: float(sim) for name, sim in class_scores}
|
| 589 |
+
|
| 590 |
+
# Edge Map for structural analysis
|
| 591 |
+
gray_enhanced = cv2.cvtColor(roi_enhanced, cv2.COLOR_RGB2GRAY)
|
| 592 |
+
edges = cv2.Canny(gray_enhanced, 50, 150)
|
| 593 |
+
edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
|
| 594 |
+
|
| 595 |
+
return "\n".join(lines), label_dict, vis, attention_map, edges_rgb
|
| 596 |
|
| 597 |
+
def get_template_roi(self, class_name: str) -> np.ndarray | None:
|
| 598 |
+
return self.class_rois.get(class_name)
|
|
|
|
|
|
|
| 599 |
|
| 600 |
def list_templates(self) -> str:
|
| 601 |
+
if not self.classes:
|
| 602 |
+
return "No classes trained yet."
|
| 603 |
+
header = f"Total: {len(self.classes)} class(es)\n" + "β" * 30
|
| 604 |
+
body = []
|
| 605 |
+
for name, vectors in sorted(self.classes.items()):
|
| 606 |
+
body.append(f" β’ {name}: {len(vectors)} samples")
|
| 607 |
+
return f"{header}\n" + "\n".join(body)
|
| 608 |
|
| 609 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 610 |
# Gradio Application
|
|
|
|
| 612 |
|
| 613 |
detector = EnginePartDetector()
|
| 614 |
|
|
|
|
| 615 |
def detect_part(image, threshold):
|
| 616 |
return detector.match_part(image, threshold)
|
| 617 |
|
| 618 |
+
def add_sample(image, class_name):
|
| 619 |
+
return detector.add_to_class(image, class_name)
|
| 620 |
|
| 621 |
+
def list_classes():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 622 |
return detector.list_templates()
|
| 623 |
|
| 624 |
+
# Custom CSS for premium look
|
| 625 |
custom_css = """
|
| 626 |
+
.container { max-width: 1200px; margin: auto; }
|
| 627 |
+
.header { text-align: center; margin-bottom: 2rem; }
|
| 628 |
+
.footer { text-align: center; margin-top: 2rem; color: #666; }
|
| 629 |
"""
|
| 630 |
|
| 631 |
with gr.Blocks(title="Engine Part CV System", theme=gr.themes.Soft(), css=custom_css) as demo:
|
| 632 |
gr.Markdown("""
|
| 633 |
<div class="header">
|
| 634 |
+
<h1>π§ Engine Part CV System</h1>
|
| 635 |
+
<p><strong>Multi-Stage Architecture:</strong> Localization β Feature Fingerprint (ResNet) β Latent Space Matching</p>
|
|
|
|
|
|
|
|
|
|
|
|
|
| 636 |
</div>
|
| 637 |
""")
|
| 638 |
|
| 639 |
with gr.Tab("π Match Inspection"):
|
| 640 |
with gr.Row():
|
| 641 |
with gr.Column(scale=1):
|
| 642 |
+
detect_input = gr.Image(sources=["upload", "webcam"], type="numpy", label="Input Image")
|
| 643 |
threshold_slider = gr.Slider(0.5, 0.99, value=0.75, step=0.01, label="Matching Threshold")
|
| 644 |
+
detect_btn = gr.Button("π Run Inspection", variant="primary")
|
| 645 |
|
| 646 |
with gr.Column(scale=1):
|
| 647 |
detect_output = gr.Markdown(label="Match Report")
|
| 648 |
+
match_label = gr.Label(label="Top Scores", num_top_classes=5)
|
| 649 |
+
with gr.Row():
|
| 650 |
+
vis_output = gr.Image(label="Field Visualization")
|
| 651 |
+
attn_output = gr.Image(label="AI Attention Heatmap")
|
| 652 |
+
edge_output = gr.Image(label="Structural Edge Map (Line Detection)")
|
|
|
|
|
|
|
|
|
|
| 653 |
|
| 654 |
detect_btn.click(
|
| 655 |
fn=detect_part,
|
|
|
|
| 658 |
api_name="detect_part",
|
| 659 |
)
|
| 660 |
|
| 661 |
+
with gr.Tab("πΎ Train Latent Space"):
|
| 662 |
with gr.Row():
|
| 663 |
with gr.Column(scale=1):
|
| 664 |
+
template_input = gr.Image(sources=["upload"], type="numpy", label="Training Image")
|
| 665 |
+
class_name_input = gr.Dropdown(
|
| 666 |
+
choices=["Perfect", "Defected", "Unknown"],
|
| 667 |
+
label="Class Label",
|
| 668 |
+
value="Perfect",
|
| 669 |
+
allow_custom_value=True
|
| 670 |
+
)
|
| 671 |
+
add_btn = gr.Button("πΎ Add to Cluster", variant="primary")
|
| 672 |
with gr.Column(scale=1):
|
| 673 |
+
add_status = gr.Textbox(label="Training Status", lines=5)
|
| 674 |
+
add_roi_view = gr.Image(label="Processed Training ROI", interactive=False)
|
| 675 |
|
| 676 |
add_btn.click(
|
| 677 |
+
fn=add_sample,
|
| 678 |
+
inputs=[template_input, class_name_input],
|
| 679 |
outputs=[add_status, add_roi_view],
|
| 680 |
+
api_name="add_sample",
|
| 681 |
)
|
| 682 |
|
| 683 |
+
with gr.Tab("π Class Library"):
|
| 684 |
with gr.Row():
|
| 685 |
with gr.Column(scale=1):
|
| 686 |
+
template_list = gr.Textbox(label="Current Trained Classes", lines=12)
|
| 687 |
+
refresh_btn = gr.Button("π Refresh Clusters")
|
| 688 |
with gr.Column(scale=1):
|
| 689 |
+
library_roi_view = gr.Image(label="Last Reference ROI", interactive=False)
|
| 690 |
+
|
| 691 |
def update_library_preview():
|
| 692 |
+
if detector.classes:
|
| 693 |
+
first_name = sorted(detector.classes.keys())[0]
|
| 694 |
return detector.list_templates(), detector.get_template_roi(first_name)
|
| 695 |
+
return "No classes trained yet.", None
|
| 696 |
|
| 697 |
refresh_btn.click(fn=update_library_preview, outputs=[template_list, library_roi_view])
|
| 698 |
demo.load(fn=update_library_preview, outputs=[template_list, library_roi_view])
|
| 699 |
|
| 700 |
+
gr.Markdown("""
|
| 701 |
+
---
|
| 702 |
+
<div class="footer">
|
| 703 |
+
<p>Engine Part CV System β’ Powered by PyTorch & OpenCV</p>
|
| 704 |
+
</div>
|
| 705 |
+
""")
|
| 706 |
+
|
| 707 |
if __name__ == "__main__":
|
| 708 |
+
demo.launch(share=False, show_error=True)
|