Raminnit committed on
Commit
f7828db
Β·
verified Β·
1 Parent(s): 8771e52

Upload 4 files

Browse files
Files changed (4) hide show
  1. app (1).py +751 -0
  2. fadnet_finetune_best.pt +3 -0
  3. fadnet_yolo_best.pt +3 -0
  4. requirements.txt +5 -0
app (1).py ADDED
@@ -0,0 +1,751 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FADNet Gradio GUI
3
+ =================
4
+ Thermal Hotspot & Crack Detection — Interactive Inference Dashboard
5
+ Supports: Standard, Multi-Resolution WBF, and SAHI inference modes.
6
+
7
+ Run:
8
+ pip install gradio ultralytics ensemble-boxes opencv-python-headless
9
+ python app.py
10
+ """
11
+
12
+ import os, sys, math, cv2, pathlib, warnings, textwrap
13
+ import numpy as np
14
+ import gradio as gr
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ warnings.filterwarnings("ignore")
19
+
20
+ # ─────────────────────────────────────────────────────────────────────────────
21
+ # 0. Constants & Paths (edit these to match your environment)
22
+ # ─────────────────────────────────────────────────────────────────────────────
23
BASE_DIR = pathlib.Path(__file__).parent
CKPT_DIR = BASE_DIR / "checkpoints"


def _resolve_checkpoint(filename):
    """Return the path of checkpoint *filename* as a string.

    ``checkpoints/<filename>`` is preferred.  When that file is missing but a
    copy sits directly next to this script, the copy is used instead.  If
    neither exists, the ``checkpoints/`` path is returned so the "not found"
    error raised at load time still points at the documented location.
    """
    preferred = CKPT_DIR / filename
    beside_app = BASE_DIR / filename
    if not preferred.exists() and beside_app.exists():
        return str(beside_app)
    return str(preferred)


# Friendly name shown in the UI -> checkpoint path on disk.
CHECKPOINTS = {
    "FADNet Finetune (Best)": _resolve_checkpoint("fadnet_finetune_best.pt"),
    "FADNet YOLO Backbone": _resolve_checkpoint("fadnet_yolo_best.pt"),
}

# Index in this list == dataset class id used throughout the app.
CLASS_NAMES = ["Hotspot", "Crack"]
N_CLASSES = len(CLASS_NAMES)

# F1-optimal defaults (from notebook Cell 19/20)
DEFAULT_CONF_HOTSPOT = 0.20
DEFAULT_CONF_CRACK = 0.20

# Colour palette in BGR channel order: boxes are drawn with cv2 on a BGR image,
# which is converted to RGB only afterwards for Gradio.  The tuples therefore
# read (blue, green, red).
COLORS = {
    "Hotspot": (60, 80, 255),    # bright red-orange
    "Crack": (255, 140, 60),     # cornflower blue
    "GT": (0, 220, 0),           # green
    "TP": (200, 200, 0),         # cyan
    "FP": (0, 0, 220),           # red
    "FN": (0, 200, 220),         # yellow-ish
}

GALLERY_IMAGES = sorted((BASE_DIR / "working").glob("*.png")) if (BASE_DIR / "working").exists() else []
49
+
50
+ # ─────────────────────────────────────────────────────────────────────────────
51
+ # 1. CoordAtt Patch (required before loading any FADNet checkpoint)
52
+ # ─────────────────────────────────────────────────────────────────────────────
53
+
54
class h_sigmoid(nn.Module):
    """Hard-sigmoid activation, computed as ``relu6(x + 3) / 6``."""

    def forward(self, x):
        """Return the element-wise hard-sigmoid of ``x``."""
        shifted = x + 3
        return nn.functional.relu6(shifted) / 6
56
+
57
class h_swish(nn.Module):
    """Hard-swish activation: the input scaled by its hard-sigmoid."""

    def forward(self, x):
        """Return ``x * relu6(x + 3) / 6`` element-wise."""
        gate = nn.functional.relu6(x + 3) / 6
        return x * gate
59
+
60
class CoordAtt(nn.Module):
    """Coordinate-attention block.

    The input is average-pooled along the width and along the height, the two
    pooled strips share a 1x1 conv + batch-norm + hard-swish, and each strip
    is then turned into a sigmoid gate that rescales the input.

    Args:
        inp: number of input channels.
        oup: number of gate channels; defaults to ``inp``.
        reduction: channel reduction factor for the shared bottleneck
            (the bottleneck never has fewer than 8 channels).
    """

    def __init__(self, inp, oup=None, reduction=32):
        super().__init__()
        gate_channels = oup or inp
        bottleneck = max(8, inp // reduction)
        self.conv1 = nn.Conv2d(inp, bottleneck, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(bottleneck)
        self.act = h_swish()
        self.conv_h = nn.Conv2d(bottleneck, gate_channels, 1, bias=False)
        self.conv_w = nn.Conv2d(bottleneck, gate_channels, 1, bias=False)

    def forward(self, x):
        """Return ``x`` multiplied by its height gate and its width gate."""
        _, _, height, width = x.shape
        strip_h = x.mean(dim=3, keepdim=True)
        # Swap the last two axes so both strips can be stacked along dim 2.
        strip_w = x.mean(dim=2, keepdim=True).permute(0, 1, 3, 2)
        stacked = torch.cat([strip_h, strip_w], dim=2)
        stacked = self.act(self.bn1(self.conv1(stacked)))
        strip_h, strip_w = torch.split(stacked, [height, width], dim=2)
        strip_w = strip_w.permute(0, 1, 3, 2)
        gate_h = torch.sigmoid(self.conv_h(strip_h))
        gate_w = torch.sigmoid(self.conv_w(strip_w))
        return x * gate_h * gate_w
80
+
81
+
82
def patch_ultralytics():
    """Inject CoordAtt into Ultralytics so FADNet checkpoints load cleanly.

    The patch has two parts:

    1. In-memory: ``CoordAtt`` is attached to ``ultralytics.nn.modules`` and
       ``ultralytics.nn.tasks``, and a synthetic module is registered in
       ``sys.modules`` under ``ultralytics.nn.modules.coord_att``.
    2. On-disk (best effort): the same classes are written to
       ``coord_att.py`` inside the installed package and an import line is
       prepended to ``tasks.py``.  NOTE: this edits the installed Ultralytics
       package in place.  It can fail on read-only installs; such a failure
       no longer discards the in-memory patch from step 1.

    Returns:
        ``(ok, message)`` - ``ok`` is True when at least the in-memory patch
        was applied; ``message`` is a human-readable status line.
    """
    try:
        import ultralytics.nn.modules as M
        import ultralytics.nn.tasks as T
    except Exception as e:
        # Without Ultralytics there is nothing to patch.
        return False, f"Patch failed: {e}"

    # ── Step 1: in-memory injection ─────────────────────────────────────────
    M.CoordAtt = CoordAtt
    T.CoordAtt = CoordAtt

    fake_mod = type(sys)("ultralytics.nn.modules.coord_att")
    fake_mod.CoordAtt = CoordAtt
    fake_mod.h_swish = h_swish
    fake_mod.h_sigmoid = h_sigmoid
    sys.modules["ultralytics.nn.modules.coord_att"] = fake_mod
    M.coord_att = fake_mod

    # ── Step 2: best-effort on-disk copy ────────────────────────────────────
    try:
        import shutil

        d = pathlib.Path(M.__file__).parent
        coord_att_src = textwrap.dedent("""\
            import torch, torch.nn as nn
            class h_sigmoid(nn.Module):
                def forward(self, x): return nn.functional.relu6(x + 3) / 6
            class h_swish(nn.Module):
                def forward(self, x): return x * h_sigmoid()(x)
            class CoordAtt(nn.Module):
                def __init__(self, inp, oup=None, reduction=32):
                    super().__init__()
                    oup = oup or inp; mip = max(8, inp // reduction)
                    self.conv1 = nn.Conv2d(inp, mip, 1, bias=False)
                    self.bn1 = nn.BatchNorm2d(mip)
                    self.act = h_swish()
                    self.conv_h = nn.Conv2d(mip, oup, 1, bias=False)
                    self.conv_w = nn.Conv2d(mip, oup, 1, bias=False)
                def forward(self, x):
                    B,C,H,W = x.shape
                    xh = x.mean(3, keepdim=True)
                    xw = x.mean(2, keepdim=True).permute(0,1,3,2)
                    y = self.act(self.bn1(self.conv1(torch.cat([xh,xw],2))))
                    xh, xw = torch.split(y, [H, W], 2)
                    return x*torch.sigmoid(self.conv_h(xh))*torch.sigmoid(self.conv_w(xw.permute(0,1,3,2)))
        """)

        files_changed = False

        coord_att_path = d / "coord_att.py"
        if (not coord_att_path.exists()
                or coord_att_path.read_text(encoding="utf-8") != coord_att_src):
            coord_att_path.write_text(coord_att_src, encoding="utf-8")
            files_changed = True

        tp = pathlib.Path(T.__file__).with_suffix(".py")
        # Explicit UTF-8 so the package source is never re-encoded with a
        # platform-specific default codec.
        txt = tp.read_text(encoding="utf-8")
        if "coord_att" not in txt:
            tp.write_text(
                "from ultralytics.nn.modules.coord_att import CoordAtt\n" + txt,
                encoding="utf-8",
            )
            files_changed = True

        # Cached bytecode only goes stale when a source file was rewritten.
        if files_changed:
            shutil.rmtree(tp.parent / "__pycache__", ignore_errors=True)
            shutil.rmtree(d / "__pycache__", ignore_errors=True)
    except Exception as e:
        return True, f"CoordAtt patch applied in memory ✓ (on-disk patch skipped: {e})"

    return True, "CoordAtt patch applied ✓"


# Apply patch at startup
_patch_ok, _patch_msg = patch_ultralytics()
print(_patch_msg)
140
+
141
+
142
+ # ─────────────────────────────────────────────────────────────────────────────
143
+ # 2. Model Cache
144
+ # ─────────────────────────────────────────────────────────────────────────────
145
# Friendly checkpoint name -> already constructed YOLO model.
_model_cache: dict[str, object] = {}


def load_model(ckpt_name: str):
    """Return the YOLO model registered under ``ckpt_name``, loading it once.

    Raises:
        ValueError: ``ckpt_name`` is not a key of ``CHECKPOINTS``.
        FileNotFoundError: the checkpoint file is not present on disk.
    """
    from ultralytics import YOLO

    ckpt_path = CHECKPOINTS.get(ckpt_name)
    if not ckpt_path:
        raise ValueError(f"Unknown checkpoint: {ckpt_name}")

    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(
            f"Checkpoint not found at:\n {ckpt_path}\n\n"
            "Copy the .pt files into the checkpoints/ folder next to app.py."
        )

    model = _model_cache.get(ckpt_name)
    if model is None:
        model = YOLO(ckpt_path)
        _model_cache[ckpt_name] = model
    return model
162
+
163
+
164
+ # ─────────────────────────────────────────────────────────────────────────────
165
+ # 3. Drawing helpers
166
+ # ─────────────────────────────────────────────────────────────────────────────
167
+
168
def _draw_box(img, x1, y1, x2, y2, color_bgr, label, font_scale=0.48, thickness=2):
    """Draw a rectangle and a filled caption tag on ``img`` in place.

    The tag sits just above the box's top edge, or is pushed down far enough
    to stay inside the image when the box touches the top border.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.rectangle(img, (x1, y1), (x2, y2), color_bgr, thickness)

    text_size, _ = cv2.getTextSize(label, font, font_scale, 1)
    text_w, text_h = text_size
    tag_bottom = max(y1 - 4, text_h + 4)

    tag_top_left = (x1, tag_bottom - text_h - 4)
    tag_bottom_right = (x1 + text_w + 6, tag_bottom)
    cv2.rectangle(img, tag_top_left, tag_bottom_right, color_bgr, -1)

    text_origin = (x1 + 3, tag_bottom - 2)
    cv2.putText(img, label, text_origin,
                font, font_scale, (255, 255, 255), 1, cv2.LINE_AA)
175
+
176
+
177
def annotate_image(img_bgr, boxes_norm, scores, labels,
                   conf_thrs=(0.20, 0.20), draw_conf=True):
    """
    Render detections onto a copy of a BGR image and return it as RGB.

    boxes_norm : list of [x1,y1,x2,y2] in [0,1]
    scores     : one confidence per box
    labels     : one class id per box (index into CLASS_NAMES / conf_thrs)
    conf_thrs  : per-class minimum score; lower-scoring boxes are not drawn
    draw_conf  : append the score to the caption when True
    """
    canvas = img_bgr.copy()
    height, width = canvas.shape[:2]

    # Highest score first; the sort is stable, so ties keep their input order.
    ranked = sorted(enumerate(scores), key=lambda item: -item[1])
    for idx, score in ranked:
        cls_id = labels[idx]
        if score < conf_thrs[cls_id]:
            continue

        box = boxes_norm[idx]
        left, top = int(box[0] * width), int(box[1] * height)
        right, bottom = int(box[2] * width), int(box[3] * height)

        cls_name = CLASS_NAMES[cls_id]
        if draw_conf:
            caption = f"{cls_name} {score:.2f}"
        else:
            caption = cls_name
        _draw_box(canvas, left, top, right, bottom, COLORS[cls_name], caption)

    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
200
+
201
+
202
+ # ─────────────────────────────────────────────────────────────────────────────
203
+ # 4. Inference Modes
204
+ # ─────────────────────────────────────────────────────────────────────────────
205
+
206
+ def _yolo_predict(model, img_path_or_arr, imgsz, conf_raw, iou_raw, device):
207
+ """Run YOLO.predict and return (boxes_norm, scores, labels)."""
208
+ is_arr = isinstance(img_path_or_arr, np.ndarray)
209
+ src = img_path_or_arr
210
+
211
+ # Get image dims for normalisation
212
+ if is_arr:
213
+ H, W = src.shape[:2]
214
+ else:
215
+ tmp = cv2.imread(str(img_path_or_arr))
216
+ H, W = tmp.shape[:2]
217
+
218
+ res = model.predict(
219
+ src, imgsz=imgsz, conf=conf_raw, iou=iou_raw,
220
+ verbose=False, save=False, device=device,
221
+ )
222
+ r = res[0]
223
+ boxes, scores, labels = [], [], []
224
+ if len(r.boxes):
225
+ for box in r.boxes:
226
+ x1, y1, x2, y2 = box.xyxy[0].cpu().tolist()
227
+ boxes.append([
228
+ max(0.0, x1 / W), max(0.0, y1 / H),
229
+ min(1.0, x2 / W), min(1.0, y2 / H),
230
+ ])
231
+ scores.append(float(box.conf[0]))
232
+ # Label flip: model cls 0β†’dataset 1 and vice-versa
233
+ labels.append(1 - int(box.cls[0]))
234
+ return boxes, scores, labels
235
+
236
+
237
def infer_standard(model, img_bgr, conf_hotspot, conf_crack, nms_iou, imgsz, device):
    """Run one forward pass at ``imgsz`` and keep boxes above their class threshold.

    The model is queried with a very low raw confidence (0.01) so that the
    per-class cut-offs ``conf_hotspot`` (class 0) and ``conf_crack`` (class 1)
    decide what survives.  Returns ``(boxes, scores, labels)`` as three lists.
    """
    raw_boxes, raw_scores, raw_labels = _yolo_predict(
        model, img_bgr, imgsz, conf_raw=0.01, iou_raw=nms_iou, device=device
    )

    cutoffs = (conf_hotspot, conf_crack)
    kept_boxes, kept_scores, kept_labels = [], [], []
    for box, score, cls_id in zip(raw_boxes, raw_scores, raw_labels):
        if score >= cutoffs[cls_id]:
            kept_boxes.append(box)
            kept_scores.append(score)
            kept_labels.append(cls_id)
    return kept_boxes, kept_scores, kept_labels
249
+
250
+
251
def infer_multires_wbf(model, img_bgr, conf_hotspot, conf_crack,
                       nms_iou, imgsz_list, wbf_iou, wbf_skip, device):
    """Multi-resolution Weighted Box Fusion (Lever 3 from notebook).

    The image is predicted once per entry of ``imgsz_list`` (raw confidence
    0.01, NMS IoU 0.99), the per-resolution results are fused class by class
    with ``weighted_boxes_fusion`` using equal weights, and the fused boxes
    are filtered by the per-class thresholds.  ``nms_iou`` is accepted for
    signature parity with the other modes but is not used here.

    Returns ``(boxes, scores, labels)`` as three lists.
    """
    try:
        from ensemble_boxes import weighted_boxes_fusion
    except ImportError:
        raise ImportError("Install ensemble-boxes: pip install ensemble-boxes")

    # One (boxes, scores, labels) triple per resolution.
    per_scale = [
        _yolo_predict(model, img_bgr, size, 0.01, 0.99, device)
        for size in imgsz_list
    ]

    fused_boxes, fused_scores, fused_labels = [], [], []
    for cls_id in range(N_CLASSES):
        cls_boxes, cls_scores = [], []
        for boxes, scores, labels in per_scale:
            picked = [i for i, lab in enumerate(labels) if lab == cls_id]
            cls_boxes.append([boxes[i] for i in picked])
            cls_scores.append([scores[i] for i in picked])

        # Nothing of this class at any resolution: skip the fusion call.
        if not any(cls_boxes):
            continue

        cls_label_lists = [[cls_id] * len(s) for s in cls_scores]
        b_f, s_f, l_f = weighted_boxes_fusion(
            cls_boxes, cls_scores, cls_label_lists,
            weights=[1.0] * len(imgsz_list),
            iou_thr=wbf_iou, skip_box_thr=wbf_skip,
        )
        fused_boxes += b_f.tolist()
        fused_scores += s_f.tolist()
        fused_labels += [int(v) for v in l_f]

    cutoffs = (conf_hotspot, conf_crack)
    kept_boxes, kept_scores, kept_labels = [], [], []
    for box, score, cls_id in zip(fused_boxes, fused_scores, fused_labels):
        if score >= cutoffs[cls_id]:
            kept_boxes.append(box)
            kept_scores.append(score)
            kept_labels.append(cls_id)
    return kept_boxes, kept_scores, kept_labels
287
+
288
+
289
+ def _generate_tiles(H, W, tile_size, overlap_ratio):
290
+ stride = int(tile_size * (1 - overlap_ratio))
291
+ tiles = []
292
+ y = 0
293
+ while y < H:
294
+ x = 0
295
+ while x < W:
296
+ x2 = min(x + tile_size, W); y2 = min(y + tile_size, H)
297
+ x1 = max(0, x2 - tile_size); y1 = max(0, y2 - tile_size)
298
+ tiles.append((x1, y1, x2, y2))
299
+ if x2 == W: break
300
+ x += stride
301
+ if y2 == H: break
302
+ y += stride
303
+ return tiles
304
+
305
+
306
def infer_sahi(model, img_bgr, conf_hotspot, conf_crack,
               tile_size, overlap, model_imgsz, wbf_iou, wbf_skip,
               full_weight, tile_weight, device):
    """SAHI Sliced Inference (Lever 4 from notebook).

    The full image and every overlapping tile are predicted separately (raw
    confidence 0.01, NMS IoU 0.99).  Tile boxes are mapped back to
    full-image normalised coordinates, then all prediction sets are fused
    class by class with ``weighted_boxes_fusion`` - the full image weighted
    by ``full_weight``, each tile by ``tile_weight`` - and finally filtered
    by the per-class thresholds.

    Returns ``(boxes, scores, labels)`` as three lists.
    """
    try:
        from ensemble_boxes import weighted_boxes_fusion
    except ImportError:
        raise ImportError("Install ensemble-boxes: pip install ensemble-boxes")

    H, W = img_bgr.shape[:2]

    # Each entry: (boxes, scores, labels) for one prediction pass.
    passes = [_yolo_predict(model, img_bgr, model_imgsz, 0.01, 0.99, device)]
    pass_weights = [full_weight]

    for (tx1, ty1, tx2, ty2) in _generate_tiles(H, W, tile_size, overlap):
        crop = img_bgr[ty1:ty2, tx1:tx2]
        crop_h, crop_w = crop.shape[:2]
        if crop_h < 8 or crop_w < 8:
            continue

        t_boxes, t_scores, t_labels = _yolo_predict(
            model, crop, model_imgsz, 0.01, 0.99, device
        )

        # remap tile-relative coords -> full image normalised
        remapped = []
        for bx in t_boxes:
            gx1 = (bx[0] * crop_w + tx1) / W
            gy1 = (bx[1] * crop_h + ty1) / H
            gx2 = (bx[2] * crop_w + tx1) / W
            gy2 = (bx[3] * crop_h + ty1) / H
            remapped.append([
                max(0.0, gx1), max(0.0, gy1),
                min(1.0, gx2), min(1.0, gy2),
            ])

        passes.append((remapped, t_scores, t_labels))
        pass_weights.append(tile_weight)

    # WBF fusion, one class at a time
    fused_boxes, fused_scores, fused_labels = [], [], []
    for cls_id in range(N_CLASSES):
        cls_boxes, cls_scores = [], []
        for boxes, scores, labels in passes:
            picked = [i for i, lab in enumerate(labels) if lab == cls_id]
            cls_boxes.append([boxes[i] for i in picked])
            cls_scores.append([scores[i] for i in picked])

        if not any(cls_boxes):
            continue

        cls_label_lists = [[cls_id] * len(s) for s in cls_scores]
        b_f, s_f, l_f = weighted_boxes_fusion(
            cls_boxes, cls_scores, cls_label_lists,
            weights=pass_weights,
            iou_thr=wbf_iou, skip_box_thr=wbf_skip,
        )
        fused_boxes += b_f.tolist()
        fused_scores += s_f.tolist()
        fused_labels += [int(v) for v in l_f]

    cutoffs = (conf_hotspot, conf_crack)
    kept_boxes, kept_scores, kept_labels = [], [], []
    for box, score, cls_id in zip(fused_boxes, fused_scores, fused_labels):
        if score >= cutoffs[cls_id]:
            kept_boxes.append(box)
            kept_scores.append(score)
            kept_labels.append(cls_id)
    return kept_boxes, kept_scores, kept_labels
367
+
368
+
369
+ # ─────────────────────────────────────────────────────────────────────────────
370
+ # 5. Main inference callback (called by Gradio)
371
+ # ─────────────────────────────────────────────────────────────────────────────
372
+
373
def run_inference(
    image_np,
    ckpt_name,
    infer_mode,
    conf_hotspot,
    conf_crack,
    nms_iou,
    imgsz,
    # Multi-res options
    use_736,
    wbf_iou,
    wbf_skip,
    # SAHI options
    sahi_tile,
    sahi_overlap,
    sahi_full_weight,
):
    """Gradio callback: detect defects in one image.

    Args:
        image_np: input image as a numpy array (RGB, RGBA or grayscale), or
            None when nothing has been uploaded.
        ckpt_name: key of ``CHECKPOINTS`` selecting the model.
        infer_mode: ``"Standard"``, ``"Multi-Res WBF"`` or ``"SAHI"``.
        conf_hotspot, conf_crack: per-class confidence thresholds.
        nms_iou: NMS IoU threshold (Standard mode).
        imgsz: model input resolution (Standard and SAHI modes).
        use_736: in Multi-Res mode, also run at 736 px besides 640 px.
        wbf_iou, wbf_skip: Weighted Box Fusion parameters.
        sahi_tile, sahi_overlap, sahi_full_weight: SAHI tiling parameters.

    Returns:
        ``(annotated_rgb_image_or_None, summary_markdown, table_rows)``.
        Every failure is reported through the summary text rather than by
        raising, so the UI always receives three values.
    """
    if image_np is None:
        return None, "⚠️ Please upload an image first.", []

    # ── Resolve device ──────────────────────────────────────────────────────
    device = 0 if torch.cuda.is_available() else "cpu"

    # ── Load model ──────────────────────────────────────────────────────────
    try:
        model = load_model(ckpt_name)
    except (FileNotFoundError, ValueError) as e:
        return None, f"❌ {e}", []
    except Exception as e:
        # e.g. Ultralytics missing, or the checkpoint cannot be deserialised.
        return None, f"❌ Could not load checkpoint `{ckpt_name}`: {e}", []

    # ── Convert image ────────────────────────────────────────────────────────
    if image_np.ndim == 2:
        img_bgr = cv2.cvtColor(image_np, cv2.COLOR_GRAY2BGR)
    elif image_np.shape[2] == 4:
        img_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGBA2BGR)
    else:
        img_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

    try:
        if infer_mode == "Standard":
            boxes, scores, labels = infer_standard(
                model, img_bgr, conf_hotspot, conf_crack, nms_iou, int(imgsz), device
            )
        elif infer_mode == "Multi-Res WBF":
            res_list = [640, 736] if use_736 else [640]
            boxes, scores, labels = infer_multires_wbf(
                model, img_bgr, conf_hotspot, conf_crack,
                nms_iou, res_list, wbf_iou, wbf_skip, device
            )
        elif infer_mode == "SAHI":
            boxes, scores, labels = infer_sahi(
                model, img_bgr, conf_hotspot, conf_crack,
                int(sahi_tile), sahi_overlap, int(imgsz),
                wbf_iou, wbf_skip, sahi_full_weight, 1.0, device
            )
        else:
            return None, "Unknown inference mode.", []
    except Exception:
        import traceback
        return None, f"❌ Inference error:\n{traceback.format_exc()}", []

    # ── Annotate ─────────────────────────────────────────────────────────────
    thrs = [conf_hotspot, conf_crack]
    vis = annotate_image(img_bgr, boxes, scores, labels, conf_thrs=thrs)

    # ── Build detection table ─────────────────────────────────────────────────
    # Filter once, then sort by descending score (stable for ties).
    kept = sorted(
        ((b, s, l) for b, s, l in zip(boxes, scores, labels) if s >= thrs[l]),
        key=lambda det: -det[1],
    )
    rows = [
        [
            CLASS_NAMES[l],
            f"{s:.3f}",
            f"[{b[0]:.3f}, {b[1]:.3f}, {b[2]:.3f}, {b[3]:.3f}]",
        ]
        for b, s, l in kept
    ]

    # ── Summary text ──────────────────────────────────────────────────────────
    n_hotspot = sum(1 for _, _, l in kept if l == 0)
    n_crack = sum(1 for _, _, l in kept if l == 1)
    device_str = f"GPU (cuda:{device})" if device != "cpu" else "CPU"
    summary = (
        f"✅ **{n_hotspot + n_crack} detection(s)** — "
        f"{n_hotspot} Hotspot · {n_crack} Crack\n\n"
        f"Mode: `{infer_mode}` · Checkpoint: `{ckpt_name}` · Device: `{device_str}`"
    )

    return vis, summary, rows
456
+
457
+
458
+ # ─────────────────────────────────────────────────────────────────────────────
459
+ # 6. Gradio UI
460
+ # ─────────────────────────────────────────────────────────────────────────────
461
+
462
# Dark dashboard theme: slate neutrals with an orange accent.  Every colour is
# set for both the light and the dark variant so the app looks the same in
# either browser mode.
THEME = gr.themes.Base(
    primary_hue=gr.themes.colors.orange,
    secondary_hue=gr.themes.colors.slate,
    neutral_hue=gr.themes.colors.slate,
    font=[gr.themes.GoogleFont("Inter"), "sans-serif"],
).set(
    body_background_fill="#0f1117",
    body_background_fill_dark="#0f1117",
    block_background_fill="#1a1e2e",
    block_background_fill_dark="#1a1e2e",
    block_border_color="#2d3148",
    block_border_color_dark="#2d3148",
    block_label_text_color="#c9d1e0",
    block_label_text_color_dark="#c9d1e0",
    input_background_fill="#22273a",
    input_background_fill_dark="#22273a",
    slider_color="#f97316",
    slider_color_dark="#f97316",
    button_primary_background_fill="#f97316",
    button_primary_background_fill_hover="#ea6a0b",
    button_primary_text_color="#ffffff",
    body_text_color="#e2e8f0",
    body_text_color_dark="#e2e8f0",
)

# Extra stylesheet passed to gr.Blocks.  It styles the title banner
# (#title-banner), the detections table (.detect-table) and the mode
# description card (.mode-card), and hides the page footer.
CSS = """
#title-banner {
background: linear-gradient(135deg, #1e2235 0%, #252b42 50%, #1a1e2e 100%);
border: 1px solid #f97316;
border-radius: 12px;
padding: 24px 32px;
margin-bottom: 8px;
}
#title-banner h1 { color: #f97316 !important; margin: 0 0 4px 0; font-size: 2rem; }
#title-banner p { color: #94a3b8 !important; margin: 0; }

.detect-table thead th { background: #252b42 !important; color: #f97316 !important; }
.detect-table tbody tr:nth-child(even) { background: #1f2333 !important; }

.mode-card { border-left: 3px solid #f97316; padding-left: 10px; }

footer { display: none !important; }
"""
505
+
506
def build_ui():
    """Build and return the Gradio ``Blocks`` application.

    Layout:
        * Inference tab - checkpoint / mode / threshold controls on the left,
          input image, annotated result, summary and detections table on the
          right.
        * Analytics tab - pre-computed charts read from ``working/`` next to
          the script; a placeholder note is shown for every missing chart.
        * Model Info tab - static description of the model and its results.

    The mode radio toggles the Multi-Res and SAHI option groups, the run
    button calls ``run_inference`` and the clear button resets the outputs.
    """
    # Description shown under the mode selector; also used for the initial
    # text so the card reads the same before and after the first toggle.
    MODE_DESCS = {
        "Standard": "<div class='mode-card'>Single-scale inference at your chosen resolution. Fast &amp; accurate.</div>",
        "Multi-Res WBF": "<div class='mode-card'>Runs at 640 &amp; 736 px, fuses with WBF — <strong>best mAP@0.5 (91.51%)</strong>.</div>",
        "SAHI": "<div class='mode-card'>Slices image into overlapping tiles. Best for small hotspots in high-res images.</div>",
    }

    with gr.Blocks(theme=THEME, css=CSS, title="FADNet — Thermal Defect Detector") as demo:

        # ── Header ──────────────────────────────────────────────────────────
        gr.HTML("""
        <div id="title-banner">
        <h1>🔥 FADNet — Thermal Defect Detector</h1>
        <p>Hotspot &amp; Crack detection in thermal images · YOLOv8 + CoordAtt ·
        mAP@0.5 = 91.51% (Multi-Res WBF)</p>
        </div>
        """)

        with gr.Tabs():

            # ══════════════════════════════════════════════════════════════════
            # TAB 1 — Inference
            # ══════════════════════════════════════════════════════════════════
            with gr.Tab("🎯 Inference", id="infer"):
                with gr.Row(equal_height=False):

                    # ── LEFT COLUMN — Settings ─────────────────────────────
                    with gr.Column(scale=1, min_width=300):
                        gr.Markdown("### ⚙️ Checkpoint")
                        ckpt_radio = gr.Radio(
                            choices=list(CHECKPOINTS.keys()),
                            value=list(CHECKPOINTS.keys())[0],
                            label="Model checkpoint",
                            show_label=False,
                        )

                        gr.Markdown("### 🧠 Inference Mode")
                        mode_radio = gr.Radio(
                            choices=["Standard", "Multi-Res WBF", "SAHI"],
                            value="Standard",
                            label="Inference mode",
                            show_label=False,
                        )
                        mode_desc = gr.Markdown(
                            MODE_DESCS["Standard"],
                            elem_classes=["mode-card"],
                        )

                        gr.Markdown("### 🔧 Per-Class Thresholds")
                        conf_hot = gr.Slider(
                            0.01, 0.99, value=DEFAULT_CONF_HOTSPOT, step=0.01,
                            label="Hotspot confidence threshold",
                        )
                        conf_crk = gr.Slider(
                            0.01, 0.99, value=DEFAULT_CONF_CRACK, step=0.01,
                            label="Crack confidence threshold",
                        )
                        nms_iou = gr.Slider(
                            0.10, 0.90, value=0.45, step=0.05,
                            label="NMS / WBF IoU threshold",
                        )
                        imgsz = gr.Slider(
                            320, 1280, value=640, step=32,
                            label="Model input resolution (px)",
                        )

                        # Multi-Res options (shown only in Multi-Res WBF mode)
                        with gr.Group(visible=False) as multires_group:
                            gr.Markdown("#### Multi-Res WBF Options")
                            use_736 = gr.Checkbox(value=True, label="Also run at 736 px")
                            wbf_iou = gr.Slider(0.10, 0.80, value=0.45, step=0.05, label="WBF IoU threshold")
                            wbf_skip = gr.Slider(0.001, 0.10, value=0.001, step=0.001, label="WBF skip box threshold")

                        # SAHI options (shown only in SAHI mode)
                        with gr.Group(visible=False) as sahi_group:
                            gr.Markdown("#### SAHI Options")
                            sahi_tile = gr.Slider(192, 512, value=320, step=32, label="Tile size (px)")
                            sahi_overlap = gr.Slider(0.10, 0.60, value=0.40, step=0.05, label="Tile overlap ratio")
                            sahi_full_w = gr.Slider(0.5, 3.0, value=1.5, step=0.1, label="Full-image weight (vs tile=1.0)")

                        run_btn = gr.Button("▶ Run Detection", variant="primary", size="lg")
                        clear_btn = gr.Button("🗑 Clear", variant="secondary")

                    # ── RIGHT COLUMN — I/O ────────────────────────────────
                    with gr.Column(scale=2):
                        with gr.Row():
                            input_img = gr.Image(
                                type="numpy", label="Input Image",
                                height=400,
                            )
                            output_img = gr.Image(
                                type="numpy", label="Detection Result",
                                height=400,
                            )

                        summary_md = gr.Markdown("*Upload an image and click **Run Detection**.*")

                        detect_table = gr.Dataframe(
                            headers=["Class", "Confidence", "Box [x1, y1, x2, y2]"],
                            datatype=["str", "str", "str"],
                            label="Detections",
                            wrap=True,
                            elem_classes=["detect-table"],
                        )

            # ══════════════════════════════════════════════════════════════════
            # TAB 2 — Analytics
            # ══════════════════════════════════════════════════════════════════
            with gr.Tab("📊 Analytics"):
                gr.Markdown("### Pre-computed Metrics from Training Run")

                # (file name inside working/, heading shown above the chart)
                CHART_META = [
                    ("fadnet_metrics_dashboard.png", "📈 Full Metrics Dashboard"),
                    ("fadnet_advanced_push.png", "🚀 Technique Comparison"),
                    ("perclass_thresh_heatmap.png", "🌡️ Per-Class Threshold Heatmap"),
                    ("f1_optimal_curves.png", "📉 F1-Optimal Threshold Curves"),
                    ("fadnet_result_grid.png", "🖼️ Result Image Grid (GT vs Pred)"),
                    ("fadnet_live_inference.png", "🔴 Live Inference Samples"),
                    ("fadnet_bbox_quality.png", "🔍 Bounding Box Quality Inspector"),
                ]

                working_dir = BASE_DIR / "working"
                for fname, label in CHART_META:
                    fpath = working_dir / fname
                    if fpath.exists():
                        gr.Markdown(f"#### {label}")
                        gr.Image(value=str(fpath), label=label, show_label=False)
                    else:
                        gr.Markdown(
                            f"*`{fname}` not found — run the notebook to generate it.*"
                        )

            # ══════════════════════════════════════════════════════════════════
            # TAB 3 — Model Info
            # ══════════════════════════════════════════════════════════════════
            with gr.Tab("ℹ️ Model Info"):
                # The Markdown text is kept at column 0 so it renders the same
                # whether or not the component strips leading indentation.
                gr.Markdown("""
## FADNet — Architecture & Results

### 🏗️ Architecture
FADNet is a **YOLOv8-based thermal defect detector** enhanced with **CoordAttention (CoordAtt)**
— a coordinate-aware channel attention mechanism that captures long-range spatial dependencies
in both horizontal and vertical directions simultaneously.

| Component | Detail |
|-------------------|---------------------------------------------|
| Base architecture | YOLOv8 |
| Attention module | CoordAtt (Hou et al., 2021) |
| Classes | Hotspot (thermal) · Crack (structural) |
| Input resolution | 640 × 640 px (default) |
| Dataset | Thermal-H&C (Roboflow) |

---

### 📋 Checkpoints

| File | Role |
|----------------------------|------------------------------|
| `fadnet_finetune_best.pt` | **Primary** — fine-tuned FADNet (**recommended**) |
| `fadnet_yolo_best.pt` | YOLO backbone variant |
| `fadnet_unet_best.pth` | U-Net segmentation head |

---

### 📈 Benchmark Results (test set)

| Technique | mAP@0.5 | Hotspot AP | Crack AP | Δ vs Baseline |
|-----------------------|---------|------------|----------|---------------|
| Baseline WBF | 90.92% | — | — | — |
| Per-class threshold | 90.40% | — | — | −0.52% |
| + Soft-NMS (σ=0.3) | 90.60% | — | — | −0.32% |
| **Multi-res WBF** 🏆 | **91.51%** | **94.15%** | **88.86%** | **+0.59%** |
| SAHI (tile=384) | 82.92% | — | — | −8.00% |

---

### 🔬 Inference Modes

**Standard** — Single-scale YOLO inference with per-class thresholds.
Fast, minimal overhead. Use for quick evaluation.

**Multi-Res WBF** — Runs inference at 640 px and 736 px, then fuses predictions
with Weighted Box Fusion. Achieves the best mAP@0.5 (91.51%).

**SAHI** — Slicing Aided Hyper Inference (Akyon et al., 2022). Divides the image into
overlapping tiles, runs the model on each, then merges with WBF. Best for detecting
very small hotspots in high-resolution images.

---

### 🎛️ F1-Optimal Thresholds (paper settings)
```
crack_conf = 0.20
hotspot_conf = 0.20
mAP@0.5 = 0.9151
mean F1 = ~0.88
```
""")

        # ── Event Wiring ────────────────────────────────────────────────────

        def on_mode_change(mode):
            """Return the mode description and the visibility of both option groups."""
            return (
                MODE_DESCS[mode],
                gr.update(visible=(mode == "Multi-Res WBF")),
                gr.update(visible=(mode == "SAHI")),
            )

        mode_radio.change(
            on_mode_change,
            inputs=mode_radio,
            outputs=[mode_desc, multires_group, sahi_group],
        )

        # The input order must match the parameter order of run_inference.
        run_btn.click(
            run_inference,
            inputs=[
                input_img, ckpt_radio, mode_radio,
                conf_hot, conf_crk, nms_iou, imgsz,
                use_736, wbf_iou, wbf_skip,
                sahi_tile, sahi_overlap, sahi_full_w,
            ],
            outputs=[output_img, summary_md, detect_table],
        )

        clear_btn.click(
            lambda: (None, None, "*Upload an image and click **Run Detection**.*", []),
            outputs=[input_img, output_img, summary_md, detect_table],
        )

    return demo
737
+
738
+
739
+ # ─────────────────────────────────────────────────────────────────────────────
740
+ # 7. Entry point
741
+ # ─────────────────────────────────────────────────────────────────────────────
742
+
743
if __name__ == "__main__":
    # Listen on all interfaces, port 7860, without a public share link;
    # errors raised by callbacks are shown in the browser.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
        "favicon_path": None,
    }
    demo = build_ui()
    demo.launch(**launch_options)
fadnet_finetune_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:636630314d68463c16ec53d1d94310a8f417dd68636e38f960b111ac015e5a06
3
+ size 29437836
fadnet_yolo_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bbb5a164a345db498e36e66d3c8ea72f35def1dd58fd742dba8dbdfeff4495a0
3
+ size 29437900
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ ultralytics
2
+ ensemble-boxes
3
+ opencv-python-headless
4
+ torch
5
+ torchvision