csmith715 committed on
Commit
8021aca
·
1 Parent(s): 621ba56

Adding Tiling functionality

Browse files
Files changed (3) hide show
  1. app.py +49 -57
  2. tiling.py +238 -0
  3. tiling_test.py +231 -0
app.py CHANGED
@@ -8,6 +8,7 @@ import PIL.Image as Image
8
  from fastapi import FastAPI, UploadFile, File, HTTPException, Request
9
  from pydantic import BaseModel
10
  from ultralytics import YOLO
 
11
 
12
  MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", 8 * 1024 * 1024)) # 8 MB default
13
  MAX_SIDE = int(os.getenv("MAX_SIDE", 2000)) # downscale largest side to this
@@ -20,14 +21,17 @@ HIGH_CLASS_NAMES = [
20
 
21
  LOW_CLASS_NAMES = ["shop_bw", "shop_sw", "field_bw", "Insulation"]
22
 
 
 
23
  # -----------------------------
24
  # App setup
25
  # -----------------------------
26
 
27
  app = FastAPI(title="YOLO Weld Type Detector API", version="1.0.0")
28
 
29
- model = YOLO("top_reduced_best.pt")
30
- low_model = YOLO("best_low_072725.pt")
 
31
 
32
 
33
  # -----------------------------
@@ -58,29 +62,45 @@ def downscale_if_needed(img_rgb: np.ndarray) -> np.ndarray:
58
  new_w, new_h = int(w * scale), int(h * scale)
59
  return cv2.resize(img_rgb, (new_w, new_h), interpolation=cv2.INTER_AREA)
60
 
61
- def detect_weld_types(image_bgr: np.ndarray, model_type: str) -> dict:
62
- if model_type == "top":
63
- results = model(image_bgr)
64
- class_names = HIGH_CLASS_NAMES
65
- else:
66
- results = low_model(image_bgr)
67
- class_names = LOW_CLASS_NAMES
68
-
69
- boxes = results[0].boxes
70
- class_ids = boxes.cls.cpu().numpy().astype(int) if boxes and boxes.cls is not None else []
71
-
72
- counts = {}
73
- for cid in class_ids:
74
- if 0 <= cid < len(class_names):
75
- name = class_names[cid]
76
- counts[name] = counts.get(name, 0) + 1
77
  return counts
78
-
79
- # def merge_counts(a: dict, b: dict) -> dict:
80
- # out = dict(a)
81
- # for k, v in b.items():
82
- # out[k] = out.get(k, 0) + v
83
- # return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  # -----------------------------
86
  # Endpoints
@@ -110,11 +130,11 @@ async def predict_multipart(file: UploadFile = File(default=None)):
110
 
111
  img_rgb = downscale_if_needed(pil_to_numpy_rgb(img))
112
  img_bgr = numpy_rgb_to_bgr(img_rgb)
113
-
114
- high = detect_weld_types(img_bgr, "top")
115
- low = detect_weld_types(img_bgr, "low")
116
- merged = high | low
117
- return PredictResponse(detections=merged)
118
 
119
  @app.post("/ping")
120
  async def ping():
@@ -126,34 +146,6 @@ async def echo(req: Request):
126
  ct = req.headers.get("content-type", "")
127
  return {"ok": True, "content_type": ct}
128
 
129
- # @app.post("/predict_base64", response_model=PredictResponse)
130
- # def predict_base64(payload: PredictQuery = Body(...)):
131
- # b64 = payload.image_base64
132
- # # Size guard for base64 (approx raw size)
133
- # try:
134
- # raw = base64.b64decode(b64, validate=True)
135
- # except Exception:
136
- # raise HTTPException(status_code=400, detail="Invalid base64.")
137
- #
138
- # if len(raw) > MAX_UPLOAD_BYTES:
139
- # raise HTTPException(
140
- # status_code=413,
141
- # detail=f"Image too large after base64 decode ({len(raw)/1024/1024:.2f} MB). "
142
- # f"Use multipart /predict or reduce image size."
143
- # )
144
- #
145
- # try:
146
- # img = Image.open(io.BytesIO(raw))
147
- # except Exception:
148
- # raise HTTPException(status_code=400, detail="Invalid image.")
149
- #
150
- # img_rgb = downscale_if_needed(pil_to_numpy_rgb(img))
151
- # img_bgr = numpy_rgb_to_bgr(img_rgb)
152
- #
153
- # high = detect_weld_types(img_bgr, "top")
154
- # low = detect_weld_types(img_bgr, "low")
155
- # return PredictResponse(detections=merge_counts(low, high))
156
-
157
 
158
  if __name__ == "__main__":
159
  uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
8
  from fastapi import FastAPI, UploadFile, File, HTTPException, Request
9
  from pydantic import BaseModel
10
  from ultralytics import YOLO
11
+ from tiling import detect_tiled_softnms
12
 
13
  MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", 8 * 1024 * 1024)) # 8 MB default
14
  MAX_SIDE = int(os.getenv("MAX_SIDE", 2000)) # downscale largest side to this
 
21
 
22
  LOW_CLASS_NAMES = ["shop_bw", "shop_sw", "field_bw", "Insulation"]
23
 
24
+ ALL_CLASS_NAMES = HIGH_CLASS_NAMES + LOW_CLASS_NAMES
25
+
26
  # -----------------------------
27
  # App setup
28
  # -----------------------------
29
 
30
  app = FastAPI(title="YOLO Weld Type Detector API", version="1.0.0")
31
 
32
+ model = YOLO("best_7-15-25.pt")
33
+ # model = YOLO("top_reduced_best.pt")
34
+ # low_model = YOLO("best_low_072725.pt")
35
 
36
 
37
  # -----------------------------
 
62
  new_w, new_h = int(w * scale), int(h * scale)
63
  return cv2.resize(img_rgb, (new_w, new_h), interpolation=cv2.INTER_AREA)
64
 
65
+ def normalize_prediction(output):
66
+ weld_counts = {}
67
+ for cls_pred in output['cls']:
68
+ weld_key = output['names'][cls_pred]
69
+ weld_counts[weld_key] = weld_counts.get(weld_key, 0) + 1
70
+ return weld_counts
71
+
72
+ def detect_weld_types(image_bgr: np.ndarray, model) -> dict:
73
+ out = detect_tiled_softnms(
74
+ model, image_bgr,
75
+ tile_size=1024, overlap=0.23,
76
+ per_tile_conf=0.2, per_tile_iou=0.7,
77
+ softnms_iou=0.6, softnms_method="hard", softnms_sigma=0.5,
78
+ final_conf=0.38, device=None, imgsz=1280
79
+ )
80
+ counts = normalize_prediction(out)
81
  return counts
82
+ # {'file': 50.724137931034484,
83
+ # 'soft_iou': 0.5982183908045983,
84
+ # 'final_conf': 0.37854022988505753,
85
+ # 'olap': 0.22752873563218376}
86
+
87
+ # def detect_weld_types(image_bgr: np.ndarray, model_type: str) -> dict:
88
+ # if model_type == "top":
89
+ # results = model.predict(image_bgr)
90
+ # class_names = HIGH_CLASS_NAMES
91
+ # else:
92
+ # results = low_model.predict(image_bgr, conf=0.10, iou=0.55, max_det=300, imgsz=1920, augment=True)
93
+ # class_names = LOW_CLASS_NAMES
94
+ #
95
+ # boxes = results[0].boxes
96
+ # class_ids = boxes.cls.cpu().numpy().astype(int) if boxes and boxes.cls is not None else []
97
+ #
98
+ # counts = {}
99
+ # for cid in class_ids:
100
+ # if 0 <= cid < len(class_names):
101
+ # name = class_names[cid]
102
+ # counts[name] = counts.get(name, 0) + 1
103
+ # return counts
104
 
105
  # -----------------------------
106
  # Endpoints
 
130
 
131
  img_rgb = downscale_if_needed(pil_to_numpy_rgb(img))
132
  img_bgr = numpy_rgb_to_bgr(img_rgb)
133
+ welds = detect_weld_types(img_bgr, model)
134
+ # high = detect_weld_types(img_bgr, "top")
135
+ # low = detect_weld_types(img_bgr, "low")
136
+ # merged = high | low
137
+ return PredictResponse(detections=welds)
138
 
139
  @app.post("/ping")
140
  async def ping():
 
146
  ct = req.headers.get("content-type", "")
147
  return {"ok": True, "content_type": ct}
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  if __name__ == "__main__":
151
  uvicorn.run("app:app", host="0.0.0.0", port=7860)
tiling.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ tiled_yolo_softnms.py
3
+ Tiled inference + class-wise Soft-NMS for YOLO (Ultralytics).
4
+ - Runs YOLO on overlapping tiles to boost recall on small symbols.
5
+ - Maps all tile detections back to full-image coords.
6
+ - Fuses duplicates with Soft-NMS per class.
7
+
8
+ Usage
9
+ -----
10
+ from ultralytics import YOLO
11
+ import cv2
12
+
13
+ model = YOLO("best.pt") # your YOLO v12/v11/v8 checkpoint
14
+ img = cv2.imread("example.jpg")[:, :, ::-1] # BGR->RGB (optional; YOLO accepts BGR too)
15
+
16
+ out = detect_tiled_softnms(
17
+ model, img,
18
+ tile_size=1024, overlap=0.25,
19
+ per_tile_conf=0.2, per_tile_iou=0.7,
20
+ softnms_iou=0.55, softnms_method="linear", softnms_sigma=0.5,
21
+ final_conf=0.25, device=None, imgsz=None
22
+ )
23
+
24
+ # Access results
25
+ xyxy = out["xyxy"]
26
+ conf = out["conf"]
27
+ cls = out["cls"]
28
+ annot = draw_detections(img.copy(), xyxy, conf, cls, out["names"])
29
+ cv2.imwrite("annotated.jpg", annot[:, :, ::-1]) # RGB->BGR for writing
30
+ """
31
+
32
+ from typing import List, Tuple, Dict, Optional
33
+ import numpy as np
34
+ import cv2
35
+
36
+ # ---------------------------
37
+ # Utilities
38
+ # ---------------------------
39
+
40
def make_overlapping_tiles(H: int, W: int, tile: int, overlap: float) -> List[Tuple[int, int, int, int]]:
    """Compute (x0, y0, x1, y1) windows that cover an H x W image.

    Consecutive windows overlap by `overlap` (fraction of `tile`); an extra
    window is appended per axis whenever the stride grid stops short of the
    right/bottom edge, so the whole image is always covered.
    """
    assert 0.0 <= overlap < 1.0
    step = max(1, int(tile * (1.0 - overlap)))
    x_starts = list(range(0, max(W - tile, 0) + 1, step))
    y_starts = list(range(0, max(H - tile, 0) + 1, step))
    # Guarantee coverage of the far edges.
    if x_starts[-1] + tile < W:
        x_starts.append(W - tile)
    if y_starts[-1] + tile < H:
        y_starts.append(H - tile)
    return [
        (max(0, x), max(0, y), min(W, max(0, x) + tile), min(H, max(0, y) + tile))
        for y in y_starts
        for x in x_starts
    ]
57
+
58
def iou_xyxy(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """IoU of a single box `a` (4,) against each row of `b` (N, 4)."""
    ix1 = np.maximum(a[0], b[:, 0])
    iy1 = np.maximum(a[1], b[:, 1])
    ix2 = np.minimum(a[2], b[:, 2])
    iy2 = np.minimum(a[3], b[:, 3])
    inter = np.maximum(0, ix2 - ix1) * np.maximum(0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    areas_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    # Clamp the union away from zero to avoid division by zero on degenerate boxes.
    return inter / np.maximum(1e-9, area_a + areas_b - inter)
69
+
70
def soft_nms_classwise(
    boxes: np.ndarray, scores: np.ndarray, classes: np.ndarray,
    iou_thr: float = 0.55, method: str = "linear", sigma: float = 0.5,
    score_thresh: float = 1e-3, max_det: Optional[int] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Soft-NMS applied independently per class.

    Parameters
    ----------
    boxes : (N, 4) xyxy boxes.
    scores : (N,) confidences.
    classes : (N,) integer class ids.
    iou_thr : IoU threshold for decay ("linear") / suppression ("hard").
    method : "linear", "gaussian", or "hard" ("hard" == standard NMS).
    sigma : Gaussian decay bandwidth (used only by "gaussian").
    score_thresh : boxes whose decayed score falls below this are dropped.
    max_det : optional cap on the number of returned detections.

    Returns
    -------
    (boxes, scores, classes) filtered and sorted by descending score.

    Raises
    ------
    ValueError if `method` is not one of the three supported modes.
    """
    keep_boxes, keep_scores, keep_classes = [], [], []
    for c in np.unique(classes):
        m = classes == c
        b = boxes[m].astype(np.float32).copy()
        s = scores[m].astype(np.float32).copy()
        idxs = np.arange(b.shape[0])

        kept = []
        while len(idxs):
            # Anchor on the highest-scoring remaining box.
            i = idxs[np.argmax(s[idxs])]
            kept.append(i)

            idxs = idxs[idxs != i]
            if len(idxs) == 0:
                break
            # Fix: pass b[i] directly — the previous M/Ms copies were dead
            # code (Ms was never read, and iou_xyxy does not mutate its input).
            ious = iou_xyxy(b[i], b[idxs])
            if method == "linear":
                # Decay overlapping neighbours proportionally to their overlap.
                decay = np.where(ious > iou_thr, 1.0 - ious, 1.0)
                s[idxs] *= decay
            elif method == "gaussian":
                s[idxs] *= np.exp(-(ious ** 2) / sigma)
            elif method == "hard":
                # standard NMS behaviour
                idxs = idxs[ious <= iou_thr]
            else:
                raise ValueError("method must be 'linear', 'gaussian', or 'hard'")

            # prune very low scores
            idxs = idxs[s[idxs] >= score_thresh]

        if kept:
            kb, ks = b[kept], s[kept]
            order = np.argsort(-ks)
            kb, ks = kb[order], ks[order]
            kc = np.full(len(ks), c, dtype=classes.dtype)
            keep_boxes.append(kb)
            keep_scores.append(ks)
            keep_classes.append(kc)

    if not keep_boxes:
        # No class produced any survivors: return empty, correctly-typed arrays.
        return (np.zeros((0, 4), dtype=np.float32),
                np.zeros((0,), dtype=np.float32),
                np.zeros((0,), dtype=classes.dtype))

    B = np.concatenate(keep_boxes, axis=0)
    S = np.concatenate(keep_scores, axis=0)
    C = np.concatenate(keep_classes, axis=0)

    # Global score ordering across classes, optionally truncated to max_det.
    order = np.argsort(-S)
    if max_det is not None:
        order = order[:max_det]
    return B[order], S[order], C[order]
134
+
135
def draw_detections(img: np.ndarray, boxes: np.ndarray, scores: np.ndarray, classes: np.ndarray, names: Dict[int, str]) -> np.ndarray:
    """Draw labelled boxes onto `img` in place (RGB in, RGB out)."""
    font = cv2.FONT_HERSHEY_SIMPLEX
    box_color = (0, 180, 255)
    for box, score, cls_id in zip(boxes.astype(int), scores, classes.astype(int)):
        x1, y1, x2, y2 = box
        label = f"{names.get(cls_id, str(cls_id))} {score:.2f}"
        cv2.rectangle(img, (x1, y1), (x2, y2), box_color, 2)
        (tw, th), bl = cv2.getTextSize(label, font, 0.6, 2)
        # Filled strip behind the caption keeps it readable on busy backgrounds.
        cv2.rectangle(img, (x1, y1 - th - 6), (x1 + tw + 4, y1), box_color, -1)
        cv2.putText(img, label, (x1 + 2, y1 - 4), font, 0.6, (0, 0, 0), 2, cv2.LINE_AA)
    return img
144
+
145
+ # ---------------------------
146
+ # Main tiled inference
147
+ # ---------------------------
148
+
149
def detect_tiled_softnms(
    model, image: np.ndarray,
    tile_size: int = 1024, overlap: float = 0.25,
    per_tile_conf: float = 0.25, per_tile_iou: float = 0.7,
    softnms_iou: float = 0.55, softnms_method: str = "linear", softnms_sigma: float = 0.5,
    final_conf: float = 0.25, max_det: int = 3000,
    device: Optional[str] = None, imgsz: Optional[int] = None,
    class_agnostic_nms: bool = False
) -> Dict[str, np.ndarray]:
    """
    Run a YOLO model over overlapping tiles of `image`, map every tile
    detection back to full-image coordinates, and fuse duplicates with
    class-wise Soft-NMS.

    Returns dict: {"xyxy", "conf", "cls", "names"} where "names" is the
    class-id -> label mapping taken from the model (if it has one).
    """
    assert image.ndim == 3, "image must be HxWx3"
    height, width = image.shape[:2]
    names = getattr(model, "names", {i: str(i) for i in range(1000)})

    fused_boxes, fused_scores, fused_classes = [], [], []

    for x0, y0, x1, y1 in make_overlapping_tiles(height, width, tile=tile_size, overlap=overlap):
        crop = image[y0:y1, x0:x1]
        # Ultralytics reports boxes in the tile's own pixel coordinates (pre-letterbox).
        preds = model.predict(
            source=crop,
            conf=per_tile_conf,
            iou=per_tile_iou,
            imgsz=imgsz,  # None -> model default
            device=device,
            verbose=False
        )
        if not preds:
            continue
        det = preds[0]
        if det.boxes is None or det.boxes.shape[0] == 0:
            continue

        xyxy = det.boxes.xyxy.cpu().numpy()
        conf = det.boxes.conf.cpu().numpy()
        cls = det.boxes.cls.cpu().numpy().astype(int)

        # Shift from tile coordinates into full-image coordinates, then clip.
        xyxy[:, [0, 2]] += x0
        xyxy[:, [1, 3]] += y0
        xyxy[:, 0] = np.clip(xyxy[:, 0], 0, width - 1)
        xyxy[:, 1] = np.clip(xyxy[:, 1], 0, height - 1)
        xyxy[:, 2] = np.clip(xyxy[:, 2], 0, width - 1)
        xyxy[:, 3] = np.clip(xyxy[:, 3], 0, height - 1)

        # Drop boxes that collapsed to zero width/height after clipping.
        ok = (xyxy[:, 2] > xyxy[:, 0]) & (xyxy[:, 3] > xyxy[:, 1])
        if not np.any(ok):
            continue
        fused_boxes.append(xyxy[ok])
        fused_scores.append(conf[ok])
        fused_classes.append(cls[ok])

    if not fused_boxes:
        return {"xyxy": np.zeros((0, 4), dtype=np.float32),
                "conf": np.zeros((0,), dtype=np.float32),
                "cls": np.zeros((0,), dtype=np.int32),
                "names": names}

    boxes = np.concatenate(fused_boxes, axis=0).astype(np.float32)
    scores = np.concatenate(fused_scores, axis=0).astype(np.float32)
    classes = np.concatenate(fused_classes, axis=0).astype(np.int32)

    if class_agnostic_nms:
        # Collapse all ids to a single class so NMS suppresses across classes.
        classes = np.zeros_like(classes)

    boxes, scores, classes = soft_nms_classwise(
        boxes, scores, classes,
        iou_thr=softnms_iou,
        method=softnms_method,
        sigma=softnms_sigma,
        score_thresh=1e-3,
        max_det=max_det
    )

    # Final confidence gate.
    keep = scores >= final_conf
    return {"xyxy": boxes[keep], "conf": scores[keep],
            "cls": classes[keep], "names": names}
tiling_test.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Dict, Optional, Tuple
3
+ import numpy as np
4
+ import pandas as pd
5
+ import cv2
6
+
7
+ # --- Parse YOLO txt (normalized) -> pixel xyxy ---
8
def load_yolo_labels_xyxy(txt_path: str, img_w: int, img_h: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Parse a YOLO-format label file (normalized `cls cx cy w h` per line)
    into pixel-space xyxy boxes.

    Returns:
        cls_ids: (N,) int32
        boxes_xyxy: (N, 4) float32 in pixel coords
    """
    ids = []
    rects = []
    with open(txt_path, "r") as fh:
        for raw in fh:
            fields = raw.strip().split()
            if len(fields) != 5:
                continue  # silently skip malformed lines
            cls_id = int(float(fields[0]))
            cx, cy, bw, bh = (float(v) for v in fields[1:])
            # normalized centre/size -> pixel corner coordinates
            px, py = cx * img_w, cy * img_h
            half_w = bw * img_w / 2.0
            half_h = bh * img_h / 2.0
            rects.append([px - half_w, py - half_h, px + half_w, py + half_h])
            ids.append(cls_id)
    if not rects:
        return np.zeros((0,), dtype=np.int32), np.zeros((0, 4), dtype=np.float32)
    return np.array(ids, dtype=np.int32), np.array(rects, dtype=np.float32)
37
+
38
+ # --- IoU & matching ---
39
def iou_matrix(a_xyxy: np.ndarray, b_xyxy: np.ndarray) -> np.ndarray:
    """Pairwise IoU between two box sets: (Na,4) x (Nb,4) -> (Na,Nb) float32."""
    if a_xyxy.size == 0 or b_xyxy.size == 0:
        return np.zeros((a_xyxy.shape[0], b_xyxy.shape[0]), dtype=np.float32)
    # Keep `a` columns as (Na,1) so they broadcast against (Nb,) rows of `b`.
    ax1, ay1, ax2, ay2 = (a_xyxy[:, k:k + 1] for k in range(4))
    bx1, by1, bx2, by2 = (b_xyxy[:, k] for k in range(4))
    left = np.maximum(ax1, bx1)
    top = np.maximum(ay1, by1)
    right = np.minimum(ax2, bx2)
    bottom = np.minimum(ay2, by2)
    inter = np.maximum(0, right - left) * np.maximum(0, bottom - top)
    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = np.maximum(1e-9, area_a + area_b - inter)
    return (inter / union).astype(np.float32)
54
+
55
def greedy_match_per_class(
    pred_boxes: np.ndarray, pred_scores: np.ndarray, pred_cls: np.ndarray,
    gt_boxes: np.ndarray, gt_cls: np.ndarray,
    iou_thr: float
):
    """
    Greedy IoU matching restricted to same-class pairs.

    Returns:
        matches: list of (pred_idx, gt_idx) pairs
        pred_unmatched: np.ndarray of unmatched prediction indices
        gt_unmatched: np.ndarray of unmatched ground-truth indices
    """
    matches = []
    pred_open = np.ones(len(pred_boxes), dtype=bool)
    gt_open = np.ones(len(gt_boxes), dtype=bool)

    for c in np.union1d(pred_cls, gt_cls):
        p_idx = np.where(pred_cls == c)[0]
        g_idx = np.where(gt_cls == c)[0]
        if len(p_idx) == 0 or len(g_idx) == 0:
            continue

        iou = iou_matrix(pred_boxes[p_idx], gt_boxes[g_idx])
        taken_p = set()
        taken_g = set()
        # Repeatedly claim the best remaining pair until IoU drops below the gate.
        while iou.size:
            best = np.max(iou)
            if best < iou_thr:
                break
            i, j = np.unravel_index(np.argmax(iou), iou.shape)
            if (i in taken_p) or (j in taken_g):
                # Defensive: should not trigger after row/column wipe below.
                iou[i, j] = -1.0
                continue
            matches.append((p_idx[i], g_idx[j]))
            taken_p.add(i)
            taken_g.add(j)
            iou[i, :] = -1.0
            iou[:, j] = -1.0

        # Matched entries are no longer "open" (unmatched).
        for i in taken_p:
            pred_open[p_idx[i]] = False
        for j in taken_g:
            gt_open[g_idx[j]] = False

    return matches, np.where(pred_open)[0], np.where(gt_open)[0]
103
+
104
+ # --- Count metrics (optional but handy) ---
105
def count_metrics(actual_counts: Dict[int, int], pred_counts: Dict[int, int]) -> Tuple[pd.DataFrame, Dict]:
    """
    Per-class and micro-averaged count metrics: min(actual, pred) counts as
    TP, the surplus on either side as FP/FN, plus sMAPE on the raw counts.

    Returns (per-class DataFrame, overall dict).
    """
    labels = sorted(set(actual_counts) | set(pred_counts))
    rows = []
    tp_total = fp_total = fn_total = 0
    abs_total = 0
    denom_total = 0
    for label in labels:
        actual = int(actual_counts.get(label, 0))
        pred = int(pred_counts.get(label, 0))
        tp = min(actual, pred)
        fp = max(pred - actual, 0)
        fn = max(actual - pred, 0)
        abs_err = abs(pred - actual)
        denom = (abs(actual) + abs(pred)) / 2 if (actual + pred) > 0 else 1.0
        smape = abs_err / denom
        prec = tp / (tp + fp) if (tp + fp) > 0 else float('nan')
        rec = tp / (tp + fn) if (tp + fn) > 0 else float('nan')
        f1 = (2 * prec * rec / (prec + rec)
              if (not math.isnan(prec) and not math.isnan(rec) and (prec + rec) > 0)
              else float('nan'))
        rows.append({"class_id": label, "actual": actual, "pred": pred,
                     "abs_err": abs_err, "sMAPE": smape, "P": prec, "R": rec, "F1": f1})
        tp_total += tp
        fp_total += fp
        fn_total += fn
        abs_total += abs_err
        denom_total += denom
    micro_p = tp_total / (tp_total + fp_total) if (tp_total + fp_total) > 0 else float('nan')
    micro_r = tp_total / (tp_total + fn_total) if (tp_total + fn_total) > 0 else float('nan')
    micro_f1 = (2 * micro_p * micro_r / (micro_p + micro_r)
                if (not math.isnan(micro_p) and not math.isnan(micro_r) and (micro_p + micro_r) > 0)
                else float('nan'))
    overall = {"sum_abs_count_error": abs_total,
               "micro_precision": micro_p,
               "micro_recall": micro_r,
               "micro_f1": micro_f1,
               "micro_sMAPE": abs_total / (denom_total or 1.0)}
    return pd.DataFrame(rows), overall
128
+
129
+ # --- Pretty eval for ONE image ---
130
def evaluate_one_image(
    out: Dict,                      # output dict of detect_tiled_softnms(...)
    label_txt_path: str,
    img_w: int, img_h: int,
    iou_thr: float = 0.50,
    conf_thr: float = 0.25,
    return_vis: bool = False,
    image_rgb: Optional[np.ndarray] = None
):
    """
    Score one image's tiled detections against its YOLO label file.

    Returns:
        det_df: per-class detection metrics (IoU-matched P/R/F1)
        overall: micro-averaged detection metrics
        count_df, count_overall: count-based metrics (no localization)
        vis (only when return_vis and image_rgb given): annotated RGB image
    """
    # --- predictions, gated by confidence ---
    boxes_p = out["xyxy"].astype(np.float32)
    scores_p = out["conf"].astype(np.float32)
    cls_p = out["cls"].astype(np.int32)
    mask = scores_p >= float(conf_thr)
    boxes_p, scores_p, cls_p = boxes_p[mask], scores_p[mask], cls_p[mask]
    names: Dict[int, str] = out.get("names", {})

    # --- ground truth ---
    cls_g, boxes_g = load_yolo_labels_xyxy(label_txt_path, img_w, img_h)

    # --- per-class count sanity metrics ---
    actual_counts = {int(c): int((cls_g == c).sum()) for c in np.unique(cls_g)} if len(cls_g) else {}
    pred_counts = {int(c): int((cls_p == c).sum()) for c in np.unique(cls_p)} if len(cls_p) else {}
    count_df, count_overall = count_metrics(actual_counts, pred_counts)

    # --- IoU matching ---
    matches, unmatched_p, unmatched_g = greedy_match_per_class(
        boxes_p, scores_p, cls_p, boxes_g, cls_g, iou_thr=iou_thr
    )
    matched_p = np.array([m[0] for m in matches], dtype=int) if matches else np.array([], dtype=int)

    # --- per-class detection metrics ---
    rows = []
    for c in sorted(set(list(actual_counts.keys()) + list(pred_counts.keys()))):
        tp = int(np.sum(cls_p[matched_p] == c))  # matched pairs are already class-consistent
        fp = int(np.sum((cls_p == c))) - tp
        fn = int(np.sum((cls_g == c))) - tp
        prec = tp / (tp + fp) if (tp + fp) > 0 else float('nan')
        rec = tp / (tp + fn) if (tp + fn) > 0 else float('nan')
        f1 = (2 * prec * rec / (prec + rec)
              if (not math.isnan(prec) and not math.isnan(rec) and (prec + rec) > 0)
              else float('nan'))
        rows.append({
            "class_id": c,
            "class_name": names.get(c, str(c)),
            "gt": int(np.sum(cls_g == c)),
            "pred": int(np.sum(cls_p == c)),
            "TP": tp, "FP": fp, "FN": fn,
            "precision": prec, "recall": rec, "F1": f1
        })
    det_df = pd.DataFrame(rows).sort_values("class_id").reset_index(drop=True)

    # --- overall micro-averages ---
    TP = int(len(matches))
    FP = int(len(boxes_p) - TP)
    FN = int(len(boxes_g) - TP)
    micro_p = TP / (TP + FP) if (TP + FP) > 0 else float('nan')
    micro_r = TP / (TP + FN) if (TP + FN) > 0 else float('nan')
    micro_f1 = (2 * micro_p * micro_r / (micro_p + micro_r)
                if (not math.isnan(micro_p) and not math.isnan(micro_r) and (micro_p + micro_r) > 0)
                else float('nan'))
    overall = {
        "gt_instances": int(len(boxes_g)),
        "pred_instances": int(len(boxes_p)),
        "TP": TP, "FP": FP, "FN": FN,
        "micro_precision": micro_p,
        "micro_recall": micro_r,
        "micro_F1": micro_f1,
        "iou_thr": iou_thr,
        "conf_thr": conf_thr
    }

    if not return_vis or image_rgb is None:
        return det_df, overall, count_df, count_overall

    # --- visualization: GT yellow, matched preds green, unmatched preds red ---
    vis = image_rgb.copy()
    for gt_box in boxes_g:
        x1, y1, x2, y2 = gt_box.astype(int)
        cv2.rectangle(vis, (x1, y1), (x2, y2), (240, 230, 70), 2)
    for pi in matched_p:
        x1, y1, x2, y2 = boxes_p[pi].astype(int)
        c = int(cls_p[pi]); sc = float(scores_p[pi])
        label = f"{names.get(c, str(c))} {sc:.2f}"
        cv2.rectangle(vis, (x1, y1), (x2, y2), (60, 220, 60), 2)
        cv2.putText(vis, label, (x1 + 2, max(0, y1 - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (60, 220, 60), 2, cv2.LINE_AA)
    for pi in unmatched_p:
        x1, y1, x2, y2 = boxes_p[pi].astype(int)
        c = int(cls_p[pi]); sc = float(scores_p[pi])
        label = f"{names.get(c, str(c))} {sc:.2f}"
        cv2.rectangle(vis, (x1, y1), (x2, y2), (10, 60, 240), 2)
        cv2.putText(vis, label, (x1 + 2, max(0, y1 - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (10, 60, 240), 2, cv2.LINE_AA)
    return det_df, overall, count_df, count_overall, vis