Spaces:
Sleeping
Sleeping
| """ | |
| FADNet Gradio GUI | |
| ================= | |
| Thermal Hotspot & Crack Detection — Interactive Inference Dashboard | |
| Supports: Standard, Multi-Resolution WBF, and SAHI inference modes. | |
| Run: | |
| pip install gradio ultralytics ensemble-boxes opencv-python-headless | |
| python app.py | |
| """ | |
| import os, sys, math, cv2, pathlib, warnings, textwrap | |
| import numpy as np | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| warnings.filterwarnings("ignore") | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 0. Constants & Paths (edit these to match your environment) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Directory containing this file; every other path below is derived from it.
BASE_DIR = pathlib.Path(__file__).parent
# Checkpoints are looked up directly next to this file.
CKPT_DIR = BASE_DIR
# Friendly name (shown in the UI) -> checkpoint path on disk.
CHECKPOINTS = {
    "FADNet Finetune (Best)": str(CKPT_DIR / "fadnet_finetune_best.pt"),
    "FADNet YOLO Backbone": str(CKPT_DIR / "fadnet_yolo_best.pt"),
}
# Index in this list == dataset label id used throughout the inference helpers.
CLASS_NAMES = ["Hotspot", "Crack"]
N_CLASSES = len(CLASS_NAMES)
# F1-optimal defaults (from notebook Cell 19/20)
DEFAULT_CONF_HOTSPOT = 0.20
DEFAULT_CONF_CRACK = 0.20
# Colour palette. Every tuple is in BGR channel order because the boxes are
# drawn with cv2 on a BGR image that is converted to RGB only afterwards.
# "Hotspot", "Crack" and "TP" used to be written in RGB order, which made them
# render with red and blue swapped relative to their descriptions.
COLORS = {
    "Hotspot": (60, 80, 255),    # bright red-orange
    "Crack": (255, 140, 60),     # cornflower blue
    "GT": (0, 220, 0),           # green
    "TP": (200, 200, 0),         # cyan
    "FP": (0, 0, 220),           # red
    "FN": (0, 200, 220),         # yellow-ish
}
# PNG files found in the optional "working" folder (empty when it is absent).
GALLERY_IMAGES = sorted((BASE_DIR / "working").glob("*.png")) if (BASE_DIR / "working").exists() else []
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 1. CoordAtt Patch (required before loading any FADNet checkpoint) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
class h_sigmoid(nn.Module):
    """Hard sigmoid activation: ``relu6(x + 3) / 6``."""

    def forward(self, x):
        shifted = x + 3
        return nn.functional.relu6(shifted) / 6


class h_swish(nn.Module):
    """Hard swish activation: the input scaled by its hard sigmoid."""

    def forward(self, x):
        # The gate is created per call, so this module owns no sub-modules.
        gate = h_sigmoid()
        return x * gate(x)


class CoordAtt(nn.Module):
    """Coordinate-attention block operating on 4-D ``(B, C, H, W)`` tensors.

    The input is average-pooled along the width and along the height, both
    pooled maps pass through one shared 1x1 conv / batch-norm / h_swish stage,
    and two separate 1x1 convs turn them into sigmoid gates that multiply
    the input.

    Args:
        inp: Number of input channels.
        oup: Number of gate channels; falls back to ``inp`` when falsy.
        reduction: Divisor for the bottleneck width (never below 8 channels).
    """

    def __init__(self, inp, oup=None, reduction=32):
        super().__init__()
        gate_channels = oup or inp
        bottleneck = max(8, inp // reduction)
        self.conv1 = nn.Conv2d(inp, bottleneck, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(bottleneck)
        self.act = h_swish()
        self.conv_h = nn.Conv2d(bottleneck, gate_channels, 1, bias=False)
        self.conv_w = nn.Conv2d(bottleneck, gate_channels, 1, bias=False)

    def forward(self, x):
        _, _, height, width = x.shape
        # Pool over width -> (B, C, H, 1); pool over height and transpose -> (B, C, W, 1).
        pooled_h = x.mean(dim=3, keepdim=True)
        pooled_w = x.mean(dim=2, keepdim=True).permute(0, 1, 3, 2)
        # Stack both descriptors along dim 2 so they share the bottleneck stage.
        stacked = torch.cat([pooled_h, pooled_w], dim=2)
        encoded = self.act(self.bn1(self.conv1(stacked)))
        enc_h, enc_w = torch.split(encoded, [height, width], dim=2)
        gate_h = torch.sigmoid(self.conv_h(enc_h))
        gate_w = torch.sigmoid(self.conv_w(enc_w.permute(0, 1, 3, 2)))
        return x * gate_h * gate_w
def patch_ultralytics():
    """Inject CoordAtt into Ultralytics so FADNet checkpoints load cleanly.

    The patch has two stages:

    1. In memory: ``CoordAtt`` is attached to ``ultralytics.nn.modules`` and
       ``ultralytics.nn.tasks``, and a synthetic module is registered in
       ``sys.modules`` as ``ultralytics.nn.modules.coord_att``.
    2. On disk (best effort): a ``coord_att.py`` file is written next to the
       Ultralytics modules and an import line is prepended to ``tasks.py``.

    Returns:
        ``(ok, message)``. ``ok`` is False only when the in-memory stage fails.
        A failure of the on-disk stage alone (for example a read-only install)
        is reported in the message but still returns True, because stage 1 has
        already completed for the running process. The function never raises.
    """
    # ---- Stage 1: in-memory injection -------------------------------------
    try:
        import ultralytics.nn.modules as M
        import ultralytics.nn.tasks as T

        M.CoordAtt = CoordAtt
        T.CoordAtt = CoordAtt
        # type(sys) is the module type; this builds an empty module object.
        fake_mod = type(sys)("ultralytics.nn.modules.coord_att")
        fake_mod.CoordAtt = CoordAtt
        fake_mod.h_swish = h_swish
        fake_mod.h_sigmoid = h_sigmoid
        sys.modules["ultralytics.nn.modules.coord_att"] = fake_mod
        M.coord_att = fake_mod
    except Exception as e:
        return False, f"Patch failed: {e}"

    # ---- Stage 2: on-disk patch (best effort) -----------------------------
    try:
        import shutil

        d = pathlib.Path(M.__file__).parent
        coord_att_src = textwrap.dedent("""\
            import torch, torch.nn as nn
            class h_sigmoid(nn.Module):
                def forward(self, x): return nn.functional.relu6(x + 3) / 6
            class h_swish(nn.Module):
                def forward(self, x): return x * h_sigmoid()(x)
            class CoordAtt(nn.Module):
                def __init__(self, inp, oup=None, reduction=32):
                    super().__init__()
                    oup = oup or inp; mip = max(8, inp // reduction)
                    self.conv1 = nn.Conv2d(inp, mip, 1, bias=False)
                    self.bn1 = nn.BatchNorm2d(mip)
                    self.act = h_swish()
                    self.conv_h = nn.Conv2d(mip, oup, 1, bias=False)
                    self.conv_w = nn.Conv2d(mip, oup, 1, bias=False)
                def forward(self, x):
                    B,C,H,W = x.shape
                    xh = x.mean(3, keepdim=True)
                    xw = x.mean(2, keepdim=True).permute(0,1,3,2)
                    y = self.act(self.bn1(self.conv1(torch.cat([xh,xw],2))))
                    xh, xw = torch.split(y, [H, W], 2)
                    return x*torch.sigmoid(self.conv_h(xh))*torch.sigmoid(self.conv_w(xw.permute(0,1,3,2)))
            """)
        changed = False

        # Explicit UTF-8 everywhere: the platform default encoding may not be
        # able to decode or encode the installed source files.
        coord_att_path = d / "coord_att.py"
        if not coord_att_path.exists() or coord_att_path.read_text(encoding="utf-8") != coord_att_src:
            coord_att_path.write_text(coord_att_src, encoding="utf-8")
            changed = True

        tp = pathlib.Path(T.__file__).with_suffix(".py")
        txt = tp.read_text(encoding="utf-8")
        if "coord_att" not in txt:
            tp.write_text(
                "from ultralytics.nn.modules.coord_att import CoordAtt\n" + txt,
                encoding="utf-8",
            )
            changed = True

        # Only drop bytecode caches when a source file was actually rewritten.
        if changed:
            shutil.rmtree(tp.parent / "__pycache__", ignore_errors=True)
            shutil.rmtree(d / "__pycache__", ignore_errors=True)
    except Exception as e:
        return True, f"CoordAtt patch applied in memory only (on-disk patch skipped: {e})"

    return True, "CoordAtt patch applied β"


# Apply patch at startup
_patch_ok, _patch_msg = patch_ultralytics()
print(_patch_msg)
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 2. Model Cache | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Friendly checkpoint name -> loaded YOLO object, filled lazily by load_model().
_model_cache: dict[str, object] = {}


def load_model(ckpt_name: str):
    """Load (and cache) a YOLO checkpoint by friendly name.

    Args:
        ckpt_name: A key of ``CHECKPOINTS``.

    Returns:
        The ``ultralytics.YOLO`` object for that checkpoint; repeated calls
        with the same name return the same cached object.

    Raises:
        ValueError: ``ckpt_name`` is not a key of ``CHECKPOINTS``.
        FileNotFoundError: the checkpoint file does not exist on disk.
    """
    ckpt_path = CHECKPOINTS.get(ckpt_name)
    if not ckpt_path:
        raise ValueError(f"Unknown checkpoint: {ckpt_name}")
    if not os.path.exists(ckpt_path):
        # The hint refers to the printed path instead of a hard-coded folder
        # name, so it stays correct whatever CKPT_DIR is set to.
        raise FileNotFoundError(
            f"Checkpoint not found at:\n {ckpt_path}\n\n"
            "Copy the .pt file to that location (checkpoints are looked up in CKPT_DIR)."
        )
    if ckpt_name not in _model_cache:
        # Imported only on a cache miss so argument validation above does not
        # depend on Ultralytics being importable.
        from ultralytics import YOLO

        _model_cache[ckpt_name] = YOLO(ckpt_path)
    return _model_cache[ckpt_name]
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 3. Drawing helpers | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def _draw_box(img, x1, y1, x2, y2, color_bgr, label, font_scale=0.48, thickness=2):
    """Draw a box outline and a filled caption tag onto ``img`` in place.

    ``(x1, y1)`` / ``(x2, y2)`` are integer pixel corners. The caption is
    white text on a tag filled with ``color_bgr``.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.rectangle(img, (x1, y1), (x2, y2), color_bgr, thickness)

    text_size, _ = cv2.getTextSize(label, font, font_scale, 1)
    text_w, text_h = text_size
    # The tag sits just above the box, pushed down if it would cross the top edge.
    tag_bottom = max(y1 - 4, text_h + 4)
    tag_top_left = (x1, tag_bottom - text_h - 4)
    tag_bottom_right = (x1 + text_w + 6, tag_bottom)
    cv2.rectangle(img, tag_top_left, tag_bottom_right, color_bgr, -1)

    text_origin = (x1 + 3, tag_bottom - 2)
    cv2.putText(img, label, text_origin, font, font_scale, (255, 255, 255), 1, cv2.LINE_AA)
def annotate_image(img_bgr, boxes_norm, scores, labels,
                   conf_thrs=(0.20, 0.20), draw_conf=True):
    """Render predictions onto a copy of a BGR image and return it as RGB.

    Args:
        img_bgr: Source image in BGR order; it is not modified.
        boxes_norm: Boxes as ``[x1, y1, x2, y2]`` with coordinates in [0, 1].
        scores: Confidence per box.
        labels: Class index per box (index into ``CLASS_NAMES``).
        conf_thrs: Per-class minimum score; lower-scoring boxes are skipped.
        draw_conf: When True the caption includes the score.

    Returns:
        RGB numpy array with the boxes drawn.
    """
    canvas = img_bgr.copy()
    img_h, img_w = canvas.shape[:2]

    # Visit detections from highest to lowest score.
    by_score = sorted(range(len(scores)), key=lambda idx: scores[idx], reverse=True)
    for idx in by_score:
        cls_id, conf = labels[idx], scores[idx]
        if conf < conf_thrs[cls_id]:
            continue
        nx1, ny1, nx2, ny2 = boxes_norm[idx][:4]
        class_name = CLASS_NAMES[cls_id]
        caption = f"{class_name} {conf:.2f}" if draw_conf else class_name
        _draw_box(
            canvas,
            int(nx1 * img_w), int(ny1 * img_h),
            int(nx2 * img_w), int(ny2 * img_h),
            COLORS[class_name], caption,
        )
    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 4. Inference Modes | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _yolo_predict(model, img_path_or_arr, imgsz, conf_raw, iou_raw, device): | |
| """Run YOLO.predict and return (boxes_norm, scores, labels).""" | |
| is_arr = isinstance(img_path_or_arr, np.ndarray) | |
| src = img_path_or_arr | |
| # Get image dims for normalisation | |
| if is_arr: | |
| H, W = src.shape[:2] | |
| else: | |
| tmp = cv2.imread(str(img_path_or_arr)) | |
| H, W = tmp.shape[:2] | |
| res = model.predict( | |
| src, imgsz=imgsz, conf=conf_raw, iou=iou_raw, | |
| verbose=False, save=False, device=device, | |
| ) | |
| r = res[0] | |
| boxes, scores, labels = [], [], [] | |
| if len(r.boxes): | |
| for box in r.boxes: | |
| x1, y1, x2, y2 = box.xyxy[0].cpu().tolist() | |
| boxes.append([ | |
| max(0.0, x1 / W), max(0.0, y1 / H), | |
| min(1.0, x2 / W), min(1.0, y2 / H), | |
| ]) | |
| scores.append(float(box.conf[0])) | |
| # Label flip: model cls 0βdataset 1 and vice-versa | |
| labels.append(1 - int(box.cls[0])) | |
| return boxes, scores, labels | |
def infer_standard(model, img_bgr, conf_hotspot, conf_crack, nms_iou, imgsz, device):
    """Single-resolution inference followed by per-class score filtering.

    The model runs with a raw confidence floor of 0.01; detections are then
    kept only when their score reaches the threshold of their own class
    (index 0 -> ``conf_hotspot``, index 1 -> ``conf_crack``).

    Returns:
        ``(boxes, scores, labels)`` as three parallel lists (possibly empty).
    """
    raw_boxes, raw_scores, raw_labels = _yolo_predict(
        model, img_bgr, imgsz, conf_raw=0.01, iou_raw=nms_iou, device=device
    )
    class_thr = [conf_hotspot, conf_crack]

    kept_boxes, kept_scores, kept_labels = [], [], []
    for det_box, det_score, det_label in zip(raw_boxes, raw_scores, raw_labels):
        if det_score >= class_thr[det_label]:
            kept_boxes.append(det_box)
            kept_scores.append(det_score)
            kept_labels.append(det_label)
    return kept_boxes, kept_scores, kept_labels
def infer_multires_wbf(model, img_bgr, conf_hotspot, conf_crack,
                       nms_iou, imgsz_list, wbf_iou, wbf_skip, device):
    """Multi-resolution Weighted Box Fusion (Lever 3 from notebook).

    The model runs once per entry of ``imgsz_list`` (raw confidence 0.01,
    NMS IoU 0.99). The runs are fused class by class with equal weights, and
    the fused boxes are filtered by the per-class thresholds.

    ``nms_iou`` is accepted for signature parity but is not used here.

    Returns:
        ``(boxes, scores, labels)`` as three parallel lists (possibly empty).

    Raises:
        ImportError: the ``ensemble_boxes`` package is not installed.
    """
    try:
        from ensemble_boxes import weighted_boxes_fusion
    except ImportError:
        raise ImportError("Install ensemble-boxes: pip install ensemble-boxes")

    # One (boxes, scores, labels) triple per resolution.
    runs = [_yolo_predict(model, img_bgr, size, 0.01, 0.99, device) for size in imgsz_list]
    equal_weights = [1.0] * len(imgsz_list)

    fused_boxes, fused_scores, fused_labels = [], [], []
    for cls_id in range(N_CLASSES):
        # Per-run lists restricted to the current class.
        cls_boxes, cls_scores = [], []
        for run_boxes, run_scores, run_labels in runs:
            cls_boxes.append([bx for bx, lb in zip(run_boxes, run_labels) if lb == cls_id])
            cls_scores.append([sc for sc, lb in zip(run_scores, run_labels) if lb == cls_id])
        if not any(cls_boxes):
            continue  # no run produced a box of this class
        cls_ids = [[cls_id] * len(sc) for sc in cls_scores]
        out_boxes, out_scores, out_labels = weighted_boxes_fusion(
            cls_boxes, cls_scores, cls_ids,
            weights=equal_weights,
            iou_thr=wbf_iou, skip_box_thr=wbf_skip,
        )
        fused_boxes.extend(out_boxes.tolist())
        fused_scores.extend(out_scores.tolist())
        fused_labels.extend(int(v) for v in out_labels)

    class_thr = [conf_hotspot, conf_crack]
    kept_boxes, kept_scores, kept_labels = [], [], []
    for det_box, det_score, det_label in zip(fused_boxes, fused_scores, fused_labels):
        if det_score >= class_thr[det_label]:
            kept_boxes.append(det_box)
            kept_scores.append(det_score)
            kept_labels.append(det_label)
    return kept_boxes, kept_scores, kept_labels
| def _generate_tiles(H, W, tile_size, overlap_ratio): | |
| stride = int(tile_size * (1 - overlap_ratio)) | |
| tiles = [] | |
| y = 0 | |
| while y < H: | |
| x = 0 | |
| while x < W: | |
| x2 = min(x + tile_size, W); y2 = min(y + tile_size, H) | |
| x1 = max(0, x2 - tile_size); y1 = max(0, y2 - tile_size) | |
| tiles.append((x1, y1, x2, y2)) | |
| if x2 == W: break | |
| x += stride | |
| if y2 == H: break | |
| y += stride | |
| return tiles | |
def infer_sahi(model, img_bgr, conf_hotspot, conf_crack,
               tile_size, overlap, model_imgsz, wbf_iou, wbf_skip,
               full_weight, tile_weight, device):
    """SAHI Sliced Inference (Lever 4 from notebook).

    The model runs on the full image and on every tile from
    ``_generate_tiles``. Tile detections are mapped back to full-image
    normalised coordinates, then all runs are fused class by class with
    Weighted Box Fusion (full image weighted ``full_weight``, each tile
    ``tile_weight``) and filtered by the per-class thresholds.

    NOTE(review): every tile is handed to WBF as its own "model". Confirm
    against the ensemble-boxes documentation how the fused confidence scales
    with the number of models before tuning the thresholds.

    Returns:
        ``(boxes, scores, labels)`` as three parallel lists (possibly empty).

    Raises:
        ImportError: the ``ensemble_boxes`` package is not installed.
    """
    try:
        from ensemble_boxes import weighted_boxes_fusion
    except ImportError:
        raise ImportError("Install ensemble-boxes: pip install ensemble-boxes")

    H, W = img_bgr.shape[:2]

    # Run 0 is always the un-sliced image.
    full_boxes, full_scores, full_labels = _yolo_predict(
        model, img_bgr, model_imgsz, 0.01, 0.99, device
    )
    run_boxes = [full_boxes]
    run_scores = [full_scores]
    run_labels = [full_labels]
    run_weights = [full_weight]

    for left, top, right, bottom in _generate_tiles(H, W, tile_size, overlap):
        crop = img_bgr[top:bottom, left:right]
        crop_h, crop_w = crop.shape[:2]
        if crop_h < 8 or crop_w < 8:
            continue  # skip slivers
        t_boxes, t_scores, t_labels = _yolo_predict(
            model, crop, model_imgsz, 0.01, 0.99, device
        )
        # Tile-relative normalised coords -> full-image normalised coords.
        remapped = []
        for bx in t_boxes:
            gx1 = (bx[0] * crop_w + left) / W
            gy1 = (bx[1] * crop_h + top) / H
            gx2 = (bx[2] * crop_w + left) / W
            gy2 = (bx[3] * crop_h + top) / H
            remapped.append([max(0.0, gx1), max(0.0, gy1), min(1.0, gx2), min(1.0, gy2)])
        run_boxes.append(remapped)
        run_scores.append(t_scores)
        run_labels.append(t_labels)
        run_weights.append(tile_weight)

    # WBF fusion, one class at a time.
    fused_boxes, fused_scores, fused_labels = [], [], []
    for cls_id in range(N_CLASSES):
        cls_boxes = [
            [bx for bx, lb in zip(boxes_i, labels_i) if lb == cls_id]
            for boxes_i, labels_i in zip(run_boxes, run_labels)
        ]
        cls_scores = [
            [sc for sc, lb in zip(scores_i, labels_i) if lb == cls_id]
            for scores_i, labels_i in zip(run_scores, run_labels)
        ]
        if not any(cls_boxes):
            continue
        out_boxes, out_scores, out_labels = weighted_boxes_fusion(
            cls_boxes, cls_scores, [[cls_id] * len(sc) for sc in cls_scores],
            weights=run_weights,
            iou_thr=wbf_iou, skip_box_thr=wbf_skip,
        )
        fused_boxes.extend(out_boxes.tolist())
        fused_scores.extend(out_scores.tolist())
        fused_labels.extend(int(v) for v in out_labels)

    class_thr = [conf_hotspot, conf_crack]
    kept_boxes, kept_scores, kept_labels = [], [], []
    for det_box, det_score, det_label in zip(fused_boxes, fused_scores, fused_labels):
        if det_score >= class_thr[det_label]:
            kept_boxes.append(det_box)
            kept_scores.append(det_score)
            kept_labels.append(det_label)
    return kept_boxes, kept_scores, kept_labels
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 5. Main inference callback (called by Gradio) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def run_inference(
    image_np,
    ckpt_name,
    infer_mode,
    conf_hotspot,
    conf_crack,
    nms_iou,
    imgsz,
    # Multi-res options
    use_736,
    wbf_iou,
    wbf_skip,
    # SAHI options
    sahi_tile,
    sahi_overlap,
    sahi_full_weight,
):
    """Gradio callback: run the selected inference mode on an uploaded image.

    Returns:
        ``(annotated_rgb_or_None, summary_markdown, table_rows)``. Failures
        (no image, checkpoint problems, inference exceptions, unknown mode)
        are reported through the summary text with ``None`` for the image and
        an empty table; they are not raised.
    """
    if image_np is None:
        return None, "β οΈ Please upload an image first.", []

    # GPU index 0 when CUDA is available, otherwise CPU.
    device = 0 if torch.cuda.is_available() else "cpu"

    try:
        model = load_model(ckpt_name)
    except (FileNotFoundError, ValueError) as e:
        return None, f"β {e}", []

    # The image arrives as RGB; the drawing / inference helpers work on BGR.
    img_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

    try:
        if infer_mode == "Standard":
            detections = infer_standard(
                model, img_bgr, conf_hotspot, conf_crack, nms_iou, int(imgsz), device
            )
        elif infer_mode == "Multi-Res WBF":
            resolutions = [640, 736] if use_736 else [640]
            detections = infer_multires_wbf(
                model, img_bgr, conf_hotspot, conf_crack,
                nms_iou, resolutions, wbf_iou, wbf_skip, device
            )
        elif infer_mode == "SAHI":
            detections = infer_sahi(
                model, img_bgr, conf_hotspot, conf_crack,
                int(sahi_tile), sahi_overlap, int(imgsz),
                wbf_iou, wbf_skip, sahi_full_weight, 1.0, device
            )
        else:
            return None, "Unknown inference mode.", []
    except Exception:
        import traceback
        return None, f"β Inference error:\n{traceback.format_exc()}", []
    boxes, scores, labels = detections

    thrs = [conf_hotspot, conf_crack]
    vis = annotate_image(img_bgr, boxes, scores, labels, conf_thrs=thrs)

    # Detection table: highest score first, below-threshold entries dropped.
    ranked = sorted(zip(boxes, scores, labels), key=lambda det: det[1], reverse=True)
    rows = [
        [
            CLASS_NAMES[lbl],
            f"{sc:.3f}",
            f"[{bx[0]:.3f}, {bx[1]:.3f}, {bx[2]:.3f}, {bx[3]:.3f}]",
        ]
        for bx, sc, lbl in ranked
        if sc >= thrs[lbl]
    ]

    # Per-class counts over the detections that pass their class threshold.
    passing = [lbl for lbl, sc in zip(labels, scores) if sc >= thrs[lbl]]
    n_hotspot = passing.count(0)
    n_crack = passing.count(1)

    if device != "cpu":
        device_str = f"GPU (cuda:{device})"
    else:
        device_str = "CPU"
    summary = (
        f"β **{n_hotspot + n_crack} detection(s)** β "
        f"{n_hotspot} Hotspot Β· {n_crack} Crack\n\n"
        f"Mode: `{infer_mode}` Β· Checkpoint: `{ckpt_name}` Β· Device: `{device_str}`"
    )
    return vis, summary, rows
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 6. Gradio UI | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Theme variables that use the same value in light and dark mode; the
# "<name>_dark" twin of each entry is generated below.
_DUAL_MODE_VARS = {
    "body_background_fill": "#0f1117",
    "block_background_fill": "#1a1e2e",
    "block_border_color": "#2d3148",
    "block_label_text_color": "#c9d1e0",
    "input_background_fill": "#22273a",
    "slider_color": "#f97316",
    "body_text_color": "#e2e8f0",
}
_theme_vars = {
    "button_primary_background_fill": "#f97316",
    "button_primary_background_fill_hover": "#ea6a0b",
    "button_primary_text_color": "#ffffff",
}
for _var_name, _var_value in _DUAL_MODE_VARS.items():
    _theme_vars[_var_name] = _var_value
    _theme_vars[_var_name + "_dark"] = _var_value

# Dark dashboard theme with orange accents.
THEME = gr.themes.Base(
    primary_hue=gr.themes.colors.orange,
    secondary_hue=gr.themes.colors.slate,
    neutral_hue=gr.themes.colors.slate,
    font=[gr.themes.GoogleFont("Inter"), "sans-serif"],
).set(**_theme_vars)

# Extra CSS for the title banner, the detections table and the mode card;
# it also hides the page footer.
CSS = """
#title-banner {
background: linear-gradient(135deg, #1e2235 0%, #252b42 50%, #1a1e2e 100%);
border: 1px solid #f97316;
border-radius: 12px;
padding: 24px 32px;
margin-bottom: 8px;
}
#title-banner h1 { color: #f97316 !important; margin: 0 0 4px 0; font-size: 2rem; }
#title-banner p { color: #94a3b8 !important; margin: 0; }
.detect-table thead th { background: #252b42 !important; color: #f97316 !important; }
.detect-table tbody tr:nth-child(even) { background: #1f2333 !important; }
.mode-card { border-left: 3px solid #f97316; padding-left: 10px; }
footer { display: none !important; }
"""
def build_ui():
    """Assemble the Gradio Blocks app and return it (without launching).

    Layout: a header banner and three tabs (inference controls and outputs,
    pre-computed analytics charts read from ``BASE_DIR / "working"``, and a
    static model-info page). Event wiring connects the mode selector, the run
    button and the clear button.
    """
    # ---- Static content -----------------------------------------------
    header_html = """
<div id="title-banner">
<h1>π₯ FADNet β Thermal Defect Detector</h1>
<p>Hotspot & Crack detection in thermal images Β· YOLOv8 + CoordAtt Β·
mAP@0.5 = 91.51% (Multi-Res WBF)</p>
</div>
"""
    idle_hint = "*Upload an image and click **Run Detection**.*"
    # (file name inside the working folder, heading shown above the chart)
    chart_specs = [
        ("fadnet_metrics_dashboard.png", "π Full Metrics Dashboard"),
        ("fadnet_advanced_push.png", "π Technique Comparison"),
        ("perclass_thresh_heatmap.png", "π‘οΈ Per-Class Threshold Heatmap"),
        ("f1_optimal_curves.png", "π F1-Optimal Threshold Curves"),
        ("fadnet_result_grid.png", "πΌοΈ Result Image Grid (GT vs Pred)"),
        ("fadnet_live_inference.png", "π΄ Live Inference Samples"),
        ("fadnet_bbox_quality.png", "π Bounding Box Quality Inspector"),
    ]
    model_info_md = """
## FADNet β Architecture & Results
### ποΈ Architecture
FADNet is a **YOLOv8-based thermal defect detector** enhanced with **CoordAttention (CoordAtt)**
β a coordinate-aware channel attention mechanism that captures long-range spatial dependencies
in both horizontal and vertical directions simultaneously.
| Component | Detail |
|-------------------|---------------------------------------------|
| Base architecture | YOLOv8 |
| Attention module | CoordAtt (Hou et al., 2021) |
| Classes | Hotspot (thermal) Β· Crack (structural) |
| Input resolution | 640 Γ 640 px (default) |
| Dataset | Thermal-H&C (Roboflow) |
---
### π Checkpoints
| File | Role |
|----------------------------|------------------------------|
| `fadnet_finetune_best.pt` | **Primary** β fine-tuned FADNet (**recommended**) |
| `fadnet_yolo_best.pt` | YOLO backbone variant |
| `fadnet_unet_best.pth` | U-Net segmentation head |
---
### π Benchmark Results (test set)
| Technique | mAP@0.5 | Hotspot AP | Crack AP | Ξ vs Baseline |
|-----------------------|---------|------------|----------|---------------|
| Baseline WBF | 90.92% | β | β | β |
| Per-class threshold | 90.40% | β | β | β0.52% |
| + Soft-NMS (Ο=0.3) | 90.60% | β | β | β0.32% |
| **Multi-res WBF** π | **91.51%** | **94.15%** | **88.86%** | **+0.59%** |
| SAHI (tile=384) | 82.92% | β | β | β8.00% |
---
### π¬ Inference Modes
**Standard** β Single-scale YOLO inference with per-class thresholds.
Fast, minimal overhead. Use for quick evaluation.
**Multi-Res WBF** β Runs inference at 640 px and 736 px, then fuses predictions
with Weighted Box Fusion. Achieves the best mAP@0.5 (91.51%).
**SAHI** β Sliced Adaptive Inference (Akyon et al., 2022). Divides the image into
overlapping tiles, runs the model on each, then merges with WBF. Best for detecting
very small hotspots in high-resolution images.
---
### ποΈ F1-Optimal Thresholds (paper settings)
```
crack_conf = 0.20
hotspot_conf = 0.20
mAP@0.5 = 0.9151
mean F1 = ~0.88
```
"""
    # Description shown under the mode selector, keyed by mode name.
    mode_blurbs = {
        "Standard": "<div class='mode-card'>Single-scale inference at your chosen resolution. Fast & accurate.</div>",
        "Multi-Res WBF": "<div class='mode-card'>Runs at 640 & 736 px, fuses with WBF β <strong>best mAP@0.5 (91.51%)</strong>.</div>",
        "SAHI": "<div class='mode-card'>Slices image into overlapping tiles. Best for small hotspots in high-res images.</div>",
    }
    ckpt_names = list(CHECKPOINTS.keys())
    charts_dir = BASE_DIR / "working"

    with gr.Blocks(theme=THEME, css=CSS, title="FADNet β Thermal Defect Detector") as demo:
        # ---- Header -------------------------------------------------------
        gr.HTML(header_html)

        with gr.Tabs():
            # ---- TAB 1: Inference ----------------------------------------
            with gr.Tab("π― Inference", id="infer"):
                with gr.Row(equal_height=False):
                    # Left column: settings
                    with gr.Column(scale=1, min_width=300):
                        gr.Markdown("### βοΈ Checkpoint")
                        ckpt_radio = gr.Radio(
                            choices=ckpt_names,
                            value=ckpt_names[0],
                            label="Model checkpoint",
                            show_label=False,
                        )
                        gr.Markdown("### π§ Inference Mode")
                        mode_radio = gr.Radio(
                            choices=["Standard", "Multi-Res WBF", "SAHI"],
                            value="Standard",
                            label="Inference mode",
                            show_label=False,
                        )
                        mode_desc = gr.Markdown(
                            "<div class='mode-card'>Single-scale inference. Fast & accurate.</div>",
                            elem_classes=["mode-card"],
                        )
                        gr.Markdown("### π§ Per-Class Thresholds")
                        conf_hot = gr.Slider(
                            0.01, 0.99, value=DEFAULT_CONF_HOTSPOT, step=0.01,
                            label="Hotspot confidence threshold",
                        )
                        conf_crk = gr.Slider(
                            0.01, 0.99, value=DEFAULT_CONF_CRACK, step=0.01,
                            label="Crack confidence threshold",
                        )
                        nms_iou = gr.Slider(
                            0.10, 0.90, value=0.45, step=0.05,
                            label="NMS / WBF IoU threshold",
                        )
                        imgsz = gr.Slider(
                            320, 1280, value=640, step=32,
                            label="Model input resolution (px)",
                        )
                        # Multi-Res options (hidden until that mode is selected).
                        # NOTE(review): the two WBF sliders below are also passed
                        # to SAHI inference, but this group is only made visible
                        # in Multi-Res WBF mode.
                        with gr.Group(visible=False) as multires_group:
                            gr.Markdown("#### Multi-Res WBF Options")
                            use_736 = gr.Checkbox(value=True, label="Also run at 736 px")
                            wbf_iou = gr.Slider(0.10, 0.80, value=0.45, step=0.05, label="WBF IoU threshold")
                            wbf_skip = gr.Slider(0.001, 0.10, value=0.001, step=0.001, label="WBF skip box threshold")
                        # SAHI options (hidden until that mode is selected).
                        with gr.Group(visible=False) as sahi_group:
                            gr.Markdown("#### SAHI Options")
                            sahi_tile = gr.Slider(192, 512, value=320, step=32, label="Tile size (px)")
                            sahi_overlap = gr.Slider(0.10, 0.60, value=0.40, step=0.05, label="Tile overlap ratio")
                            sahi_full_w = gr.Slider(0.5, 3.0, value=1.5, step=0.1, label="Full-image weight (vs tile=1.0)")
                        run_btn = gr.Button("βΆ Run Detection", variant="primary", size="lg")
                        clear_btn = gr.Button("π Clear", variant="secondary")
                    # Right column: input / output
                    with gr.Column(scale=2):
                        with gr.Row():
                            input_img = gr.Image(
                                type="numpy", label="Input Image",
                                height=400,
                            )
                            output_img = gr.Image(
                                type="numpy", label="Detection Result",
                                height=400,
                            )
                        summary_md = gr.Markdown(idle_hint)
                        detect_table = gr.Dataframe(
                            headers=["Class", "Confidence", "Box [x1, y1, x2, y2]"],
                            datatype=["str", "str", "str"],
                            label="Detections",
                            wrap=True,
                            elem_classes=["detect-table"],
                        )

            # ---- TAB 2: Analytics ----------------------------------------
            with gr.Tab("π Analytics"):
                gr.Markdown("### Pre-computed Metrics from Training Run")
                for chart_file, chart_title in chart_specs:
                    chart_path = charts_dir / chart_file
                    if not chart_path.exists():
                        gr.Markdown(
                            f"*`{chart_file}` not found β run the notebook to generate it.*"
                        )
                        continue
                    gr.Markdown(f"#### {chart_title}")
                    gr.Image(value=str(chart_path), label=chart_title, show_label=False)

            # ---- TAB 3: Model Info ---------------------------------------
            with gr.Tab("βΉοΈ Model Info"):
                gr.Markdown(model_info_md)

        # ---- Event wiring -------------------------------------------------
        def on_mode_change(mode):
            # Swap the description and show only the option group of the chosen mode.
            show_multires = gr.update(visible=(mode == "Multi-Res WBF"))
            show_sahi = gr.update(visible=(mode == "SAHI"))
            return mode_blurbs[mode], show_multires, show_sahi

        mode_radio.change(
            on_mode_change,
            inputs=mode_radio,
            outputs=[mode_desc, multires_group, sahi_group],
        )
        # Input order must match the parameter order of run_inference.
        run_btn.click(
            run_inference,
            inputs=[
                input_img, ckpt_radio, mode_radio,
                conf_hot, conf_crk, nms_iou, imgsz,
                use_736, wbf_iou, wbf_skip,
                sahi_tile, sahi_overlap, sahi_full_w,
            ],
            outputs=[output_img, summary_md, detect_table],
        )
        clear_btn.click(
            lambda: (None, None, "*Upload an image and click **Run Detection**.*", []),
            outputs=[input_img, output_img, summary_md, detect_table],
        )
    return demo
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 7. Entry point | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
if __name__ == "__main__":
    # Listen on all interfaces at port 7860, without a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
        "favicon_path": None,
    }
    demo = build_ui()
    demo.launch(**launch_options)