Adding tiling functionality

Files changed:
- app.py (+49 −57)
- tiling.py (+238, new file)
- tiling_test.py (+231, new file)
app.py
CHANGED
|
@@ -8,6 +8,7 @@ import PIL.Image as Image
|
|
| 8 |
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
|
| 9 |
from pydantic import BaseModel
|
| 10 |
from ultralytics import YOLO
|
|
|
|
| 11 |
|
| 12 |
MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", 8 * 1024 * 1024)) # 8 MB default
|
| 13 |
MAX_SIDE = int(os.getenv("MAX_SIDE", 2000)) # downscale largest side to this
|
|
@@ -20,14 +21,17 @@ HIGH_CLASS_NAMES = [
|
|
| 20 |
|
| 21 |
LOW_CLASS_NAMES = ["shop_bw", "shop_sw", "field_bw", "Insulation"]
|
| 22 |
|
|
|
|
|
|
|
| 23 |
# -----------------------------
|
| 24 |
# App setup
|
| 25 |
# -----------------------------
|
| 26 |
|
| 27 |
app = FastAPI(title="YOLO Weld Type Detector API", version="1.0.0")
|
| 28 |
|
| 29 |
-
model = YOLO("
|
| 30 |
-
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
# -----------------------------
|
|
@@ -58,29 +62,45 @@ def downscale_if_needed(img_rgb: np.ndarray) -> np.ndarray:
|
|
| 58 |
new_w, new_h = int(w * scale), int(h * scale)
|
| 59 |
return cv2.resize(img_rgb, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
| 60 |
|
| 61 |
-
def
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
return counts
|
| 78 |
-
|
| 79 |
-
#
|
| 80 |
-
#
|
| 81 |
-
#
|
| 82 |
-
|
| 83 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
# -----------------------------
|
| 86 |
# Endpoints
|
|
@@ -110,11 +130,11 @@ async def predict_multipart(file: UploadFile = File(default=None)):
|
|
| 110 |
|
| 111 |
img_rgb = downscale_if_needed(pil_to_numpy_rgb(img))
|
| 112 |
img_bgr = numpy_rgb_to_bgr(img_rgb)
|
| 113 |
-
|
| 114 |
-
high = detect_weld_types(img_bgr, "top")
|
| 115 |
-
low = detect_weld_types(img_bgr, "low")
|
| 116 |
-
merged = high | low
|
| 117 |
-
return PredictResponse(detections=
|
| 118 |
|
| 119 |
@app.post("/ping")
|
| 120 |
async def ping():
|
|
@@ -126,34 +146,6 @@ async def echo(req: Request):
|
|
| 126 |
ct = req.headers.get("content-type", "")
|
| 127 |
return {"ok": True, "content_type": ct}
|
| 128 |
|
| 129 |
-
# @app.post("/predict_base64", response_model=PredictResponse)
|
| 130 |
-
# def predict_base64(payload: PredictQuery = Body(...)):
|
| 131 |
-
# b64 = payload.image_base64
|
| 132 |
-
# # Size guard for base64 (approx raw size)
|
| 133 |
-
# try:
|
| 134 |
-
# raw = base64.b64decode(b64, validate=True)
|
| 135 |
-
# except Exception:
|
| 136 |
-
# raise HTTPException(status_code=400, detail="Invalid base64.")
|
| 137 |
-
#
|
| 138 |
-
# if len(raw) > MAX_UPLOAD_BYTES:
|
| 139 |
-
# raise HTTPException(
|
| 140 |
-
# status_code=413,
|
| 141 |
-
# detail=f"Image too large after base64 decode ({len(raw)/1024/1024:.2f} MB). "
|
| 142 |
-
# f"Use multipart /predict or reduce image size."
|
| 143 |
-
# )
|
| 144 |
-
#
|
| 145 |
-
# try:
|
| 146 |
-
# img = Image.open(io.BytesIO(raw))
|
| 147 |
-
# except Exception:
|
| 148 |
-
# raise HTTPException(status_code=400, detail="Invalid image.")
|
| 149 |
-
#
|
| 150 |
-
# img_rgb = downscale_if_needed(pil_to_numpy_rgb(img))
|
| 151 |
-
# img_bgr = numpy_rgb_to_bgr(img_rgb)
|
| 152 |
-
#
|
| 153 |
-
# high = detect_weld_types(img_bgr, "top")
|
| 154 |
-
# low = detect_weld_types(img_bgr, "low")
|
| 155 |
-
# return PredictResponse(detections=merge_counts(low, high))
|
| 156 |
-
|
| 157 |
|
| 158 |
if __name__ == "__main__":
|
| 159 |
uvicorn.run("app:app", host="0.0.0.0", port=7860)
|
|
|
|
| 8 |
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
|
| 9 |
from pydantic import BaseModel
|
| 10 |
from ultralytics import YOLO
|
| 11 |
+
from tiling import detect_tiled_softnms
|
| 12 |
|
| 13 |
MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", 8 * 1024 * 1024)) # 8 MB default
|
| 14 |
MAX_SIDE = int(os.getenv("MAX_SIDE", 2000)) # downscale largest side to this
|
|
|
|
| 21 |
|
| 22 |
LOW_CLASS_NAMES = ["shop_bw", "shop_sw", "field_bw", "Insulation"]
|
| 23 |
|
| 24 |
+
ALL_CLASS_NAMES = HIGH_CLASS_NAMES + LOW_CLASS_NAMES
|
| 25 |
+
|
| 26 |
# -----------------------------
|
| 27 |
# App setup
|
| 28 |
# -----------------------------
|
| 29 |
|
| 30 |
app = FastAPI(title="YOLO Weld Type Detector API", version="1.0.0")
|
| 31 |
|
| 32 |
+
model = YOLO("best_7-15-25.pt")
|
| 33 |
+
# model = YOLO("top_reduced_best.pt")
|
| 34 |
+
# low_model = YOLO("best_low_072725.pt")
|
| 35 |
|
| 36 |
|
| 37 |
# -----------------------------
|
|
|
|
| 62 |
new_w, new_h = int(w * scale), int(h * scale)
|
| 63 |
return cv2.resize(img_rgb, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
| 64 |
|
| 65 |
+
def normalize_prediction(output):
    """Collapse raw detection output into per-class-name counts.

    Args:
        output: dict with keys "cls" (iterable of class ids) and "names"
            (mapping class id -> class name), as produced by
            detect_tiled_softnms.

    Returns:
        dict mapping class name -> number of detections of that class.
    """
    names = output['names']
    counts = {}
    for class_id in output['cls']:
        label = names[class_id]
        counts[label] = counts.get(label, 0) + 1
    return counts
|
| 71 |
+
|
| 72 |
+
def detect_weld_types(image_bgr: np.ndarray, model) -> dict:
    """Run tiled inference on a full drawing and return per-class weld counts.

    Args:
        image_bgr: full-resolution image as an HxWx3 BGR array.
        model: a loaded Ultralytics YOLO model.

    Returns:
        dict mapping weld class name -> detection count.
    """
    # Hyper-parameters below were tuned offline — TODO confirm against the
    # recorded tuning results before changing them.
    raw = detect_tiled_softnms(
        model,
        image_bgr,
        tile_size=1024,
        overlap=0.23,
        per_tile_conf=0.2,
        per_tile_iou=0.7,
        softnms_iou=0.6,
        softnms_method="hard",
        softnms_sigma=0.5,
        final_conf=0.38,
        device=None,
        imgsz=1280,
    )
    return normalize_prediction(raw)
|
| 82 |
+
# {'file': 50.724137931034484,
|
| 83 |
+
# 'soft_iou': 0.5982183908045983,
|
| 84 |
+
# 'final_conf': 0.37854022988505753,
|
| 85 |
+
# 'olap': 0.22752873563218376}
|
| 86 |
+
|
| 87 |
+
# def detect_weld_types(image_bgr: np.ndarray, model_type: str) -> dict:
|
| 88 |
+
# if model_type == "top":
|
| 89 |
+
# results = model.predict(image_bgr)
|
| 90 |
+
# class_names = HIGH_CLASS_NAMES
|
| 91 |
+
# else:
|
| 92 |
+
# results = low_model.predict(image_bgr, conf=0.10, iou=0.55, max_det=300, imgsz=1920, augment=True)
|
| 93 |
+
# class_names = LOW_CLASS_NAMES
|
| 94 |
+
#
|
| 95 |
+
# boxes = results[0].boxes
|
| 96 |
+
# class_ids = boxes.cls.cpu().numpy().astype(int) if boxes and boxes.cls is not None else []
|
| 97 |
+
#
|
| 98 |
+
# counts = {}
|
| 99 |
+
# for cid in class_ids:
|
| 100 |
+
# if 0 <= cid < len(class_names):
|
| 101 |
+
# name = class_names[cid]
|
| 102 |
+
# counts[name] = counts.get(name, 0) + 1
|
| 103 |
+
# return counts
|
| 104 |
|
| 105 |
# -----------------------------
|
| 106 |
# Endpoints
|
|
|
|
| 130 |
|
| 131 |
img_rgb = downscale_if_needed(pil_to_numpy_rgb(img))
|
| 132 |
img_bgr = numpy_rgb_to_bgr(img_rgb)
|
| 133 |
+
welds = detect_weld_types(img_bgr, model)
|
| 134 |
+
# high = detect_weld_types(img_bgr, "top")
|
| 135 |
+
# low = detect_weld_types(img_bgr, "low")
|
| 136 |
+
# merged = high | low
|
| 137 |
+
return PredictResponse(detections=welds)
|
| 138 |
|
| 139 |
@app.post("/ping")
|
| 140 |
async def ping():
|
|
|
|
| 146 |
ct = req.headers.get("content-type", "")
|
| 147 |
return {"ok": True, "content_type": ct}
|
| 148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
if __name__ == "__main__":
|
| 151 |
uvicorn.run("app:app", host="0.0.0.0", port=7860)
|
tiling.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
tiled_yolo_softnms.py
|
| 3 |
+
Tiled inference + class-wise Soft-NMS for YOLO (Ultralytics).
|
| 4 |
+
- Runs YOLO on overlapping tiles to boost recall on small symbols.
|
| 5 |
+
- Maps all tile detections back to full-image coords.
|
| 6 |
+
- Fuses duplicates with Soft-NMS per class.
|
| 7 |
+
|
| 8 |
+
Usage
|
| 9 |
+
-----
|
| 10 |
+
from ultralytics import YOLO
|
| 11 |
+
import cv2
|
| 12 |
+
|
| 13 |
+
model = YOLO("best.pt") # your YOLO v12/v11/v8 checkpoint
|
| 14 |
+
img = cv2.imread("example.jpg")[:, :, ::-1] # BGR->RGB (optional; YOLO accepts BGR too)
|
| 15 |
+
|
| 16 |
+
out = detect_tiled_softnms(
|
| 17 |
+
model, img,
|
| 18 |
+
tile_size=1024, overlap=0.25,
|
| 19 |
+
per_tile_conf=0.2, per_tile_iou=0.7,
|
| 20 |
+
softnms_iou=0.55, softnms_method="linear", softnms_sigma=0.5,
|
| 21 |
+
final_conf=0.25, device=None, imgsz=None
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# Access results
|
| 25 |
+
xyxy = out["xyxy"]
|
| 26 |
+
conf = out["conf"]
|
| 27 |
+
cls = out["cls"]
|
| 28 |
+
annot = draw_detections(img.copy(), xyxy, conf, cls, out["names"])
|
| 29 |
+
cv2.imwrite("annotated.jpg", annot[:, :, ::-1]) # RGB->BGR for writing
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
from typing import List, Tuple, Dict, Optional
|
| 33 |
+
import numpy as np
|
| 34 |
+
import cv2
|
| 35 |
+
|
| 36 |
+
# ---------------------------
|
| 37 |
+
# Utilities
|
| 38 |
+
# ---------------------------
|
| 39 |
+
|
| 40 |
+
def make_overlapping_tiles(H: int, W: int, tile: int, overlap: float) -> List[Tuple[int, int, int, int]]:
    """Compute (x0, y0, x1, y1) windows that fully cover an HxW image.

    Consecutive windows overlap by the given fraction of the tile size;
    an extra row/column of windows is appended flush with the bottom/right
    edges when the regular stride would leave a gap.  Windows are clipped
    to the image bounds, so they may be smaller than ``tile`` when the
    image itself is smaller.
    """
    assert 0.0 <= overlap < 1.0
    step = max(1, int(tile * (1.0 - overlap)))
    col_starts = list(range(0, max(W - tile, 0) + 1, step))
    row_starts = list(range(0, max(H - tile, 0) + 1, step))
    # Guarantee the final row/column of tiles reaches the image border.
    if col_starts[-1] + tile < W:
        col_starts.append(W - tile)
    if row_starts[-1] + tile < H:
        row_starts.append(H - tile)
    windows = []
    for top in row_starts:
        for left in col_starts:
            x0, y0 = max(0, left), max(0, top)
            windows.append((x0, y0, min(W, x0 + tile), min(H, y0 + tile)))
    return windows
|
| 57 |
+
|
| 58 |
+
def iou_xyxy(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """IoU of one box ``a`` (shape (4,), xyxy) against each row of ``b`` (N,4)."""
    ix1 = np.maximum(a[0], b[:, 0])
    iy1 = np.maximum(a[1], b[:, 1])
    ix2 = np.minimum(a[2], b[:, 2])
    iy2 = np.minimum(a[3], b[:, 3])
    inter = np.maximum(0, ix2 - ix1) * np.maximum(0, iy2 - iy1)
    a_area = (a[2] - a[0]) * (a[3] - a[1])
    b_areas = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    # Small epsilon in the denominator guards against zero-area unions.
    denom = np.maximum(1e-9, a_area + b_areas - inter)
    return inter / denom
|
| 69 |
+
|
| 70 |
+
def soft_nms_classwise(
    boxes: np.ndarray, scores: np.ndarray, classes: np.ndarray,
    iou_thr: float = 0.55, method: str = "linear", sigma: float = 0.5,
    score_thresh: float = 1e-3, max_det: Optional[int] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Apply Soft-NMS independently within each class.

    Args:
        boxes: (N,4) xyxy boxes.
        scores: (N,) confidences.
        classes: (N,) integer class ids.
        iou_thr: overlap threshold ("linear"/"hard" methods).
        method: "linear" | "gaussian" | "hard" (hard == classic NMS).
        sigma: Gaussian decay parameter.
        score_thresh: candidates whose decayed score falls below this are dropped.
        max_det: optional global cap on returned detections.

    Returns:
        Filtered (boxes, scores, classes), sorted by descending score.
    """
    out_b, out_s, out_c = [], [], []
    for cls_id in np.unique(classes):
        sel = classes == cls_id
        cls_boxes = boxes[sel].astype(np.float32).copy()
        cls_scores = scores[sel].astype(np.float32).copy()
        remaining = np.arange(cls_boxes.shape[0])
        survivors = []
        while len(remaining):
            # Take the highest-scoring remaining candidate as the anchor.
            top = remaining[np.argmax(cls_scores[remaining])]
            survivors.append(top)
            remaining = remaining[remaining != top]
            if len(remaining) == 0:
                break
            # IoU of the anchor against every remaining candidate.
            tb, rb = cls_boxes[top], cls_boxes[remaining]
            ix1 = np.maximum(tb[0], rb[:, 0])
            iy1 = np.maximum(tb[1], rb[:, 1])
            ix2 = np.minimum(tb[2], rb[:, 2])
            iy2 = np.minimum(tb[3], rb[:, 3])
            inter = np.maximum(0, ix2 - ix1) * np.maximum(0, iy2 - iy1)
            union = np.maximum(
                1e-9,
                (tb[2] - tb[0]) * (tb[3] - tb[1])
                + (rb[:, 2] - rb[:, 0]) * (rb[:, 3] - rb[:, 1])
                - inter,
            )
            overlaps = inter / union
            if method == "linear":
                cls_scores[remaining] *= np.where(overlaps > iou_thr, 1.0 - overlaps, 1.0)
            elif method == "gaussian":
                cls_scores[remaining] *= np.exp(-(overlaps ** 2) / sigma)
            elif method == "hard":
                # Standard NMS: remove everything above the IoU threshold.
                remaining = remaining[overlaps <= iou_thr]
            else:
                raise ValueError("method must be 'linear', 'gaussian', or 'hard'")
            # Prune candidates whose decayed score fell below the floor.
            remaining = remaining[cls_scores[remaining] >= score_thresh]
        if survivors:
            kb, ks = cls_boxes[survivors], cls_scores[survivors]
            order = np.argsort(-ks)
            out_b.append(kb[order])
            out_s.append(ks[order])
            out_c.append(np.full(len(ks), cls_id, dtype=classes.dtype))

    if not out_b:
        return (np.zeros((0, 4), dtype=np.float32),
                np.zeros((0,), dtype=np.float32),
                np.zeros((0,), dtype=classes.dtype))

    B = np.concatenate(out_b, axis=0)
    S = np.concatenate(out_s, axis=0)
    C = np.concatenate(out_c, axis=0)

    order = np.argsort(-S)
    if max_det is not None:
        order = order[:max_det]
    return B[order], S[order], C[order]
|
| 134 |
+
|
| 135 |
+
def draw_detections(img: np.ndarray, boxes: np.ndarray, scores: np.ndarray, classes: np.ndarray, names: Dict[int, str]) -> np.ndarray:
    """Draw labeled boxes on an image, in place, and return it.

    Args:
        img: HxWx3 image; mutated in place (callers pass a copy if they
            need the original preserved). Color order is whatever the
            caller uses consistently — the (0, 180, 255) box color is
            orange in RGB.
        boxes: (N,4) xyxy pixel coords (cast to int for drawing).
        scores: (N,) confidences shown in the label.
        classes: (N,) class ids, mapped to text via ``names``.
        names: class id -> display name; unknown ids fall back to str(id).

    Returns:
        The same ``img`` array, annotated.
    """
    for (x1, y1, x2, y2), sc, cl in zip(boxes.astype(int), scores, classes.astype(int)):
        label = f"{names.get(cl, str(cl))} {sc:.2f}"
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 180, 255), 2)
        # Filled background strip behind the label text.
        (tw, th), bl = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
        # NOTE(review): y1 - th - 6 can go negative for boxes at the top
        # edge; OpenCV clips silently, so the label strip may be cut off.
        cv2.rectangle(img, (x1, y1 - th - 6), (x1 + tw + 4, y1), (0, 180, 255), -1)
        cv2.putText(img, label, (x1 + 2, y1 - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2, cv2.LINE_AA)
    return img
|
| 144 |
+
|
| 145 |
+
# ---------------------------
|
| 146 |
+
# Main tiled inference
|
| 147 |
+
# ---------------------------
|
| 148 |
+
|
| 149 |
+
def detect_tiled_softnms(
    model, image: np.ndarray,
    tile_size: int = 1024, overlap: float = 0.25,
    per_tile_conf: float = 0.25, per_tile_iou: float = 0.7,
    softnms_iou: float = 0.55, softnms_method: str = "linear", softnms_sigma: float = 0.5,
    final_conf: float = 0.25, max_det: int = 3000,
    device: Optional[str] = None, imgsz: Optional[int] = None,
    class_agnostic_nms: bool = False
) -> Dict[str, np.ndarray]:
    """
    Run YOLO on overlapping tiles, then fuse globally with class-wise Soft-NMS.

    Pipeline:
      1. Cut the image into overlapping tiles (make_overlapping_tiles).
      2. Run ``model.predict`` per tile with a deliberately low confidence
         (per_tile_conf) to maximize recall on small symbols.
      3. Shift tile-local boxes into full-image coordinates and clip.
      4. Fuse duplicates from overlapping tiles with class-wise Soft-NMS.
      5. Apply the final confidence gate (final_conf).

    Args:
        model: Ultralytics YOLO model (must expose .predict and .names).
        image: HxWx3 array.
        tile_size / overlap: tiling geometry (overlap is a fraction in [0,1)).
        per_tile_conf / per_tile_iou: thresholds passed to each tile predict.
        softnms_*: parameters forwarded to soft_nms_classwise.
        final_conf: confidence gate applied after fusion.
        max_det: global cap on fused detections.
        device: forwarded to Ultralytics (None -> its default).
        imgsz: inference size per tile. NOTE(review): None is forwarded
            directly to model.predict — confirm Ultralytics accepts
            imgsz=None on the installed version; app.py always passes 1280.
        class_agnostic_nms: if True, all classes are fused as one class.

    Returns dict: {"xyxy","conf","cls","names"} (numpy arrays + name map).
    """
    assert image.ndim == 3, "image must be HxWx3"
    H, W = image.shape[:2]
    # Fallback name map covers models without .names; assumes < 1000 classes.
    names = getattr(model, "names", {i: str(i) for i in range(1000)})

    tiles = make_overlapping_tiles(H, W, tile=tile_size, overlap=overlap)

    all_boxes, all_scores, all_classes = [], [], []

    for (x0, y0, x1, y1) in tiles:
        tile = image[y0:y1, x0:x1]
        # Ultralytics returns boxes in original tile coords (pre-letterbox)
        results = model.predict(
            source=tile,
            conf=per_tile_conf,
            iou=per_tile_iou,
            imgsz=imgsz,  # None -> model default
            device=device,
            verbose=False
        )

        if not results:
            continue

        r = results[0]
        if r.boxes is None or r.boxes.shape[0] == 0:
            continue

        b = r.boxes.xyxy.cpu().numpy()
        s = r.boxes.conf.cpu().numpy()
        c = r.boxes.cls.cpu().numpy().astype(int)

        # Map to full-image coordinates
        b[:, [0, 2]] += x0
        b[:, [1, 3]] += y0

        # Clip
        # NOTE(review): clipping to W-1/H-1 (not W/H) shaves one pixel off
        # boxes that touch the right/bottom edge — confirm intended.
        b[:, 0] = np.clip(b[:, 0], 0, W - 1)
        b[:, 1] = np.clip(b[:, 1], 0, H - 1)
        b[:, 2] = np.clip(b[:, 2], 0, W - 1)
        b[:, 3] = np.clip(b[:, 3], 0, H - 1)

        # Filter degenerate boxes
        valid = (b[:, 2] > b[:, 0]) & (b[:, 3] > b[:, 1])
        if not np.any(valid):
            continue
        all_boxes.append(b[valid])
        all_scores.append(s[valid])
        all_classes.append(c[valid])

    if not all_boxes:
        # No detections anywhere: return empty arrays with stable dtypes.
        return {"xyxy": np.zeros((0, 4), dtype=np.float32),
                "conf": np.zeros((0,), dtype=np.float32),
                "cls": np.zeros((0,), dtype=np.int32),
                "names": names}

    boxes = np.concatenate(all_boxes, axis=0).astype(np.float32)
    scores = np.concatenate(all_scores, axis=0).astype(np.float32)
    classes = np.concatenate(all_classes, axis=0).astype(np.int32)

    # Global fusion: class-wise Soft-NMS or class-agnostic if chosen
    if class_agnostic_nms:
        classes = np.zeros_like(classes)

    boxes, scores, classes = soft_nms_classwise(
        boxes, scores, classes,
        iou_thr=softnms_iou,
        method=softnms_method,
        sigma=softnms_sigma,
        score_thresh=1e-3,
        max_det=max_det
    )

    # Final confidence gate
    keep = scores >= final_conf
    boxes, scores, classes = boxes[keep], scores[keep], classes[keep]

    return {"xyxy": boxes, "conf": scores, "cls": classes, "names": names}
|
tiling_test.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Dict, Optional, Tuple
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import cv2
|
| 6 |
+
|
| 7 |
+
# --- Parse YOLO txt (normalized) -> pixel xyxy ---
|
| 8 |
+
def load_yolo_labels_xyxy(txt_path: str, img_w: int, img_h: int) -> Tuple[np.ndarray, np.ndarray]:
    """Parse a YOLO-format label file into pixel-space xyxy boxes.

    Each line is "cls xc yc w h" with coordinates normalized to [0, 1].
    Lines that do not have exactly 5 fields are skipped.

    Returns:
        cls_ids: (N,) int32
        boxes_xyxy: (N,4) float32 pixel coords
    """
    ids, rects = [], []
    with open(txt_path, "r") as fh:
        for raw in fh:
            fields = raw.strip().split()
            if len(fields) != 5:
                # Skip blank or malformed lines.
                continue
            cls_id = int(float(fields[0]))
            xc, yc, bw, bh = (float(v) for v in fields[1:])
            # Normalized center/size -> pixel center/size -> corner coords.
            cx, cy = xc * img_w, yc * img_h
            half_w = bw * img_w / 2.0
            half_h = bh * img_h / 2.0
            rects.append([cx - half_w, cy - half_h, cx + half_w, cy + half_h])
            ids.append(cls_id)
    if not rects:
        return np.zeros((0,), dtype=np.int32), np.zeros((0, 4), dtype=np.float32)
    return np.array(ids, dtype=np.int32), np.array(rects, dtype=np.float32)
|
| 37 |
+
|
| 38 |
+
# --- IoU & matching ---
|
| 39 |
+
def iou_matrix(a_xyxy: np.ndarray, b_xyxy: np.ndarray) -> np.ndarray:
    """Pairwise IoU between two box sets: (Na,4) vs (Nb,4) -> (Na,Nb) float32."""
    if a_xyxy.size == 0 or b_xyxy.size == 0:
        return np.zeros((a_xyxy.shape[0], b_xyxy.shape[0]), dtype=np.float32)
    # Keep a's coords as column vectors and b's as row vectors so all
    # arithmetic broadcasts to an (Na, Nb) grid.
    ax1, ay1 = a_xyxy[:, 0:1], a_xyxy[:, 1:2]
    ax2, ay2 = a_xyxy[:, 2:3], a_xyxy[:, 3:4]
    bx1, by1 = b_xyxy[:, 0], b_xyxy[:, 1]
    bx2, by2 = b_xyxy[:, 2], b_xyxy[:, 3]
    iw = np.maximum(0, np.minimum(ax2, bx2) - np.maximum(ax1, bx1))
    ih = np.maximum(0, np.minimum(ay2, by2) - np.maximum(ay1, by1))
    inter = iw * ih
    union = np.maximum(1e-9, (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter)
    return (inter / union).astype(np.float32)
|
| 54 |
+
|
| 55 |
+
def greedy_match_per_class(
    pred_boxes: np.ndarray, pred_scores: np.ndarray, pred_cls: np.ndarray,
    gt_boxes: np.ndarray, gt_cls: np.ndarray,
    iou_thr: float
):
    """
    Greedy IoU matching per class: within each class, repeatedly pair the
    prediction/ground-truth combination with the globally highest IoU until
    the best remaining IoU drops below ``iou_thr``.

    NOTE(review): ``pred_scores`` is accepted but never used — matching is
    driven purely by IoU, not confidence. Confirm whether score-ordered
    matching was intended.

    Returns:
        matches: list of (pred_idx, gt_idx)  # indices into the full arrays
        pred_unmatched: np.ndarray of unmatched pred indices
        gt_unmatched: np.ndarray of unmatched gt indices
    """
    matches = []
    pred_unmatched = np.ones(len(pred_boxes), dtype=bool)
    gt_unmatched = np.ones(len(gt_boxes), dtype=bool)

    # Only classes present on either side need matching.
    classes = np.union1d(pred_cls, gt_cls)
    for c in classes:
        p_idx = np.where(pred_cls == c)[0]
        g_idx = np.where(gt_cls == c)[0]
        if len(p_idx) == 0 or len(g_idx) == 0:
            continue

        IoU = iou_matrix(pred_boxes[p_idx], gt_boxes[g_idx])
        # Greedy: repeatedly pick the best remaining pair
        used_p = set(); used_g = set()
        while True:
            if IoU.size == 0:
                break
            m = np.max(IoU)
            if m < iou_thr:
                break
            i, j = np.unravel_index(np.argmax(IoU), IoU.shape)
            pi, gi = p_idx[i], g_idx[j]
            # Defensive guard; effectively unreachable because matched rows
            # and columns are overwritten with -1 below.
            if (i in used_p) or (j in used_g):
                IoU[i, j] = -1.0
                continue
            matches.append((pi, gi))
            used_p.add(i); used_g.add(j)
            # Knock out this prediction and this GT from further matching.
            IoU[i, :] = -1.0
            IoU[:, j] = -1.0

        # mark matched as not unmatched
        for i in used_p:
            pred_unmatched[p_idx[i]] = False
        for j in used_g:
            gt_unmatched[g_idx[j]] = False

    return matches, np.where(pred_unmatched)[0], np.where(gt_unmatched)[0]
|
| 103 |
+
|
| 104 |
+
# --- Count metrics (optional but handy) ---
|
| 105 |
+
def count_metrics(actual_counts: Dict[int, int], pred_counts: Dict[int, int]) -> Tuple[pd.DataFrame, Dict]:
    """Per-class and micro-averaged count-agreement metrics.

    Count-level scoring (no box matching): min(actual, pred) is treated as
    TP, surplus predictions as FP and shortfalls as FN; sMAPE is reported
    on the raw counts. Classes absent from both dicts are ignored.

    Returns:
        (per-class DataFrame, overall dict with micro P/R/F1 and sMAPE)
    """
    all_classes = sorted(set(actual_counts) | set(pred_counts))
    records = []
    tot_tp = tot_fp = tot_fn = 0
    tot_abs = 0
    tot_denom = 0
    for cls_id in all_classes:
        a = int(actual_counts.get(cls_id, 0))
        p = int(pred_counts.get(cls_id, 0))
        tp = min(a, p)
        fp = max(p - a, 0)
        fn = max(a - p, 0)
        abs_err = abs(p - a)
        # sMAPE denominator; fall back to 1.0 when both counts are zero.
        denom = (abs(a) + abs(p)) / 2 if (a + p) > 0 else 1.0
        smape = abs_err / denom
        prec = tp / (tp + fp) if (tp + fp) > 0 else float('nan')
        rec = tp / (tp + fn) if (tp + fn) > 0 else float('nan')
        if not math.isnan(prec) and not math.isnan(rec) and (prec + rec) > 0:
            f1 = 2 * prec * rec / (prec + rec)
        else:
            f1 = float('nan')
        records.append({"class_id": cls_id, "actual": a, "pred": p,
                        "abs_err": abs_err, "sMAPE": smape,
                        "P": prec, "R": rec, "F1": f1})
        tot_tp += tp
        tot_fp += fp
        tot_fn += fn
        tot_abs += abs_err
        tot_denom += denom
    micro_p = tot_tp / (tot_tp + tot_fp) if (tot_tp + tot_fp) > 0 else float('nan')
    micro_r = tot_tp / (tot_tp + tot_fn) if (tot_tp + tot_fn) > 0 else float('nan')
    if not math.isnan(micro_p) and not math.isnan(micro_r) and (micro_p + micro_r) > 0:
        micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r)
    else:
        micro_f1 = float('nan')
    overall = {"sum_abs_count_error": tot_abs,
               "micro_precision": micro_p,
               "micro_recall": micro_r,
               "micro_f1": micro_f1,
               "micro_sMAPE": tot_abs / (tot_denom or 1.0)}
    return pd.DataFrame(records), overall
|
| 128 |
+
|
| 129 |
+
# --- Pretty eval for ONE image ---
|
| 130 |
+
def evaluate_one_image(
    out: Dict,                       # from detect_tiled_softnms(...)
    label_txt_path: str,
    img_w: int, img_h: int,
    iou_thr: float = 0.50,
    conf_thr: float = 0.25,
    return_vis: bool = False,
    image_rgb: Optional[np.ndarray] = None
):
    """
    Score one image's tiled-detection output against its YOLO label file.

    Args:
        out: dict from detect_tiled_softnms with "xyxy"/"conf"/"cls"/"names".
        label_txt_path: YOLO-format ground-truth .txt for this image.
        img_w, img_h: image dimensions used to denormalize the labels.
        iou_thr: IoU threshold for greedy prediction<->GT matching.
        conf_thr: predictions below this confidence are dropped first.
        return_vis: if True (and image_rgb given), also return an annotated image.
        image_rgb: RGB image to annotate; only used when return_vis is True.

    Returns:
        det_df (per-class detection P/R/F1), overall (micro-averages),
        count_df / count_overall (count-agreement metrics),
        and, when visualizing, the annotated RGB image as a fifth value.
    """
    # Predictions (filter by conf)
    p_boxes = out["xyxy"].astype(np.float32)
    p_scores = out["conf"].astype(np.float32)
    p_cls = out["cls"].astype(np.int32)
    keep = p_scores >= float(conf_thr)
    p_boxes, p_scores, p_cls = p_boxes[keep], p_scores[keep], p_cls[keep]
    names: Dict[int, str] = out.get("names", {})

    # Ground truth
    g_cls, g_boxes = load_yolo_labels_xyxy(label_txt_path, img_w, img_h)

    # Per-class counts (sanity)
    actual_counts = {int(c): int((g_cls == c).sum()) for c in np.unique(g_cls)} if len(g_cls) else {}
    pred_counts = {int(c): int((p_cls == c).sum()) for c in np.unique(p_cls)} if len(p_cls) else {}
    count_df, count_overall = count_metrics(actual_counts, pred_counts)

    # Matching
    matches, p_unmatched_idx, g_unmatched_idx = greedy_match_per_class(
        p_boxes, p_scores, p_cls, g_boxes, g_cls, iou_thr=iou_thr
    )
    matched_p = np.array([m[0] for m in matches], dtype=int) if matches else np.array([], dtype=int)
    # matched_g is computed for symmetry; only matched_p is consumed below.
    matched_g = np.array([m[1] for m in matches], dtype=int) if matches else np.array([], dtype=int)

    # Compute per-class detection metrics
    classes = sorted(set(list(actual_counts.keys()) + list(pred_counts.keys())))
    rows = []
    for c in classes:
        tp = int(np.sum(p_cls[matched_p] == c))  # matched pairs already class-consistent
        fp = int(np.sum((p_cls == c))) - tp
        fn = int(np.sum((g_cls == c))) - tp
        prec = tp/(tp+fp) if (tp+fp)>0 else float('nan')
        rec = tp/(tp+fn) if (tp+fn)>0 else float('nan')
        f1 = 2*prec*rec/(prec+rec) if (not math.isnan(prec) and not math.isnan(rec) and (prec+rec)>0) else float('nan')
        rows.append({
            "class_id": c,
            "class_name": names.get(c, str(c)),
            "gt": int(np.sum(g_cls==c)),
            "pred": int(np.sum(p_cls==c)),
            "TP": tp, "FP": fp, "FN": fn,
            "precision": prec, "recall": rec, "F1": f1
        })
    # NOTE(review): if both prediction and GT are empty, ``rows`` is empty
    # and sort_values("class_id") raises KeyError on the column-less frame —
    # confirm callers never hit this case.
    det_df = pd.DataFrame(rows).sort_values("class_id").reset_index(drop=True)

    # Overall detection micro-averages
    TP = int(len(matches))
    FP = int(len(p_boxes) - TP)
    FN = int(len(g_boxes) - TP)
    micro_p = TP/(TP+FP) if (TP+FP)>0 else float('nan')
    micro_r = TP/(TP+FN) if (TP+FN)>0 else float('nan')
    micro_f1 = 2*micro_p*micro_r/(micro_p+micro_r) if (not math.isnan(micro_p) and not math.isnan(micro_r) and (micro_p+micro_r)>0) else float('nan')

    overall = {
        "gt_instances": int(len(g_boxes)),
        "pred_instances": int(len(p_boxes)),
        "TP": TP, "FP": FP, "FN": FN,
        "micro_precision": micro_p,
        "micro_recall": micro_r,
        "micro_F1": micro_f1,
        "iou_thr": iou_thr,
        "conf_thr": conf_thr
    }

    if not return_vis or image_rgb is None:
        return det_df, overall, count_df, count_overall

    # Annotated visualization
    vis = image_rgb.copy()
    # Draw GT (yellow)
    for i in range(len(g_boxes)):
        color = (240, 230, 70)
        x1,y1,x2,y2 = g_boxes[i].astype(int)
        cv2.rectangle(vis, (x1,y1), (x2,y2), color, 2)
    # Draw matched predictions (green)
    for pi in matched_p:
        x1,y1,x2,y2 = p_boxes[pi].astype(int)
        c = int(p_cls[pi]); sc = float(p_scores[pi])
        label = f"{names.get(c,str(c))} {sc:.2f}"
        cv2.rectangle(vis, (x1,y1), (x2,y2), (60, 220, 60), 2)
        cv2.putText(vis, label, (x1+2, max(0,y1-5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (60,220,60), 2, cv2.LINE_AA)
    # Draw unmatched predictions (red)
    for pi in p_unmatched_idx:
        x1,y1,x2,y2 = p_boxes[pi].astype(int)
        c = int(p_cls[pi]); sc = float(p_scores[pi])
        label = f"{names.get(c,str(c))} {sc:.2f}"
        cv2.rectangle(vis, (x1,y1), (x2,y2), (10, 60, 240), 2)
        cv2.putText(vis, label, (x1+2, max(0,y1-5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (10,60,240), 2, cv2.LINE_AA)
    return det_df, overall, count_df, count_overall, vis
|