# ais-api / tiling_test.py
# Tiling evaluation utilities ("Adding Tiling functionality", commit 8021aca, author csmith715)
import math
from typing import Dict, Optional, Tuple
import numpy as np
import pandas as pd
import cv2
# --- Parse YOLO txt (normalized) -> pixel xyxy ---
def load_yolo_labels_xyxy(txt_path: str, img_w: int, img_h: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Parse a YOLO-format label file (normalized cx, cy, w, h) into pixel xyxy boxes.

    Lines that do not have exactly 5 whitespace-separated fields are skipped.

    Returns:
        cls_ids: (N,) int32 class ids
        boxes_xyxy: (N,4) float32 boxes in pixel coordinates
    """
    class_ids = []
    corners = []
    with open(txt_path, "r") as fh:
        for raw in fh:
            fields = raw.strip().split()
            if len(fields) != 5:
                # malformed / blank line — ignore
                continue
            cls_id = int(float(fields[0]))
            cx, cy, bw, bh = (float(v) for v in fields[1:])
            # de-normalize: center + half-extent in pixel units
            cx_px = cx * img_w
            cy_px = cy * img_h
            half_w = (bw * img_w) / 2.0
            half_h = (bh * img_h) / 2.0
            corners.append([cx_px - half_w, cy_px - half_h, cx_px + half_w, cy_px + half_h])
            class_ids.append(cls_id)
    if not corners:
        return np.zeros((0,), dtype=np.int32), np.zeros((0, 4), dtype=np.float32)
    return np.array(class_ids, dtype=np.int32), np.array(corners, dtype=np.float32)
# --- IoU & matching ---
def iou_matrix(a_xyxy: np.ndarray, b_xyxy: np.ndarray) -> np.ndarray:
    """Pairwise IoU between two xyxy box sets: (Na,4) vs (Nb,4) -> (Na,Nb) float32."""
    n_a, n_b = a_xyxy.shape[0], b_xyxy.shape[0]
    if a_xyxy.size == 0 or b_xyxy.size == 0:
        return np.zeros((n_a, n_b), dtype=np.float32)
    # Broadcast a over rows and b over columns: (Na,1,4) against (1,Nb,4).
    a = a_xyxy[:, None, :]
    b = b_xyxy[None, :, :]
    top_left = np.maximum(a[..., :2], b[..., :2])
    bottom_right = np.minimum(a[..., 2:], b[..., 2:])
    wh = np.maximum(0, bottom_right - top_left)
    inter = wh[..., 0] * wh[..., 1]
    area_a = (a[..., 2] - a[..., 0]) * (a[..., 3] - a[..., 1])
    area_b = (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])
    # epsilon floor keeps the division safe for degenerate boxes
    union = np.maximum(1e-9, area_a + area_b - inter)
    return (inter / union).astype(np.float32)
def greedy_match_per_class(
    pred_boxes: np.ndarray, pred_scores: np.ndarray, pred_cls: np.ndarray,
    gt_boxes: np.ndarray, gt_cls: np.ndarray,
    iou_thr: float
):
    """
    Greedily match predictions to ground truth, one class at a time.

    For each class, repeatedly pairs the highest-IoU (pred, gt) combination
    still available, stopping once the best remaining IoU drops below
    ``iou_thr``. Each box participates in at most one match.

    Args:
        pred_boxes: (Np,4) predicted boxes, pixel xyxy.
        pred_scores: (Np,) confidences; accepted for interface compatibility
            but unused — the greedy criterion is IoU only.
        pred_cls: (Np,) predicted class ids.
        gt_boxes: (Ng,4) ground-truth boxes, pixel xyxy.
        gt_cls: (Ng,) ground-truth class ids.
        iou_thr: minimum IoU for a pair to count as a match.

    Returns:
        matches: list of (pred_idx, gt_idx) into the full input arrays
        pred_unmatched: np.ndarray of unmatched pred indices
        gt_unmatched: np.ndarray of unmatched gt indices
    """
    matches = []
    pred_unmatched = np.ones(len(pred_boxes), dtype=bool)
    gt_unmatched = np.ones(len(gt_boxes), dtype=bool)
    for c in np.union1d(pred_cls, gt_cls):
        p_idx = np.where(pred_cls == c)[0]
        g_idx = np.where(gt_cls == c)[0]
        if len(p_idx) == 0 or len(g_idx) == 0:
            continue
        iou = iou_matrix(pred_boxes[p_idx], gt_boxes[g_idx])
        # Greedy: take the best remaining pair until it falls below threshold.
        # (A previous version re-checked "already used" indices here, but that
        # branch was unreachable: invalidating the matched row/column below
        # guarantees argmax never returns a used index above the threshold.)
        while iou.size:
            i, j = np.unravel_index(np.argmax(iou), iou.shape)
            if iou[i, j] < iou_thr:
                break
            matches.append((p_idx[i], g_idx[j]))
            pred_unmatched[p_idx[i]] = False
            gt_unmatched[g_idx[j]] = False
            # Invalidate so neither box can be matched twice.
            iou[i, :] = -1.0
            iou[:, j] = -1.0
    return matches, np.where(pred_unmatched)[0], np.where(gt_unmatched)[0]
# --- Count metrics (optional but handy) ---
def count_metrics(actual_counts: Dict[int, int], pred_counts: Dict[int, int]) -> Tuple[pd.DataFrame, Dict]:
    """
    Compare per-class ground-truth vs. predicted instance counts.

    Treats min(actual, pred) per class as TP, the surpluses as FP/FN, and
    reports per-class abs error, sMAPE, P/R/F1 plus micro-averaged totals.

    Returns:
        (per-class DataFrame, overall summary dict)
    """
    records = []
    tot_tp = tot_fp = tot_fn = 0
    tot_abs = 0
    tot_denom = 0
    for class_id in sorted(set(actual_counts) | set(pred_counts)):
        actual = int(actual_counts.get(class_id, 0))
        predicted = int(pred_counts.get(class_id, 0))
        tp = min(actual, predicted)
        fp = max(predicted - actual, 0)
        fn = max(actual - predicted, 0)
        abs_err = abs(predicted - actual)
        # sMAPE denominator; 1.0 fallback keeps a 0-vs-0 class well-defined
        denom = (abs(actual) + abs(predicted)) / 2 if (actual + predicted) > 0 else 1.0
        precision = tp / (tp + fp) if (tp + fp) > 0 else float('nan')
        recall = tp / (tp + fn) if (tp + fn) > 0 else float('nan')
        if not math.isnan(precision) and not math.isnan(recall) and (precision + recall) > 0:
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = float('nan')
        records.append({"class_id": class_id, "actual": actual, "pred": predicted,
                        "abs_err": abs_err, "sMAPE": abs_err / denom,
                        "P": precision, "R": recall, "F1": f1})
        tot_tp += tp
        tot_fp += fp
        tot_fn += fn
        tot_abs += abs_err
        tot_denom += denom
    micro_p = tot_tp / (tot_tp + tot_fp) if (tot_tp + tot_fp) > 0 else float('nan')
    micro_r = tot_tp / (tot_tp + tot_fn) if (tot_tp + tot_fn) > 0 else float('nan')
    if not math.isnan(micro_p) and not math.isnan(micro_r) and (micro_p + micro_r) > 0:
        micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r)
    else:
        micro_f1 = float('nan')
    overall = {"sum_abs_count_error": tot_abs, "micro_precision": micro_p,
               "micro_recall": micro_r, "micro_f1": micro_f1,
               "micro_sMAPE": tot_abs / (tot_denom or 1.0)}
    return pd.DataFrame(records), overall
# --- Pretty eval for ONE image ---
def evaluate_one_image(
    out: Dict, # from detect_tiled_softnms(...)
    label_txt_path: str,
    img_w: int, img_h: int,
    iou_thr: float = 0.50,
    conf_thr: float = 0.25,
    return_vis: bool = False,
    image_rgb: Optional[np.ndarray] = None
):
    """
    Evaluate one image's tiled-detection output against YOLO-format labels.

    Args:
        out: detection dict; must provide "xyxy" (N,4 pixel boxes), "conf"
            (N,) scores, "cls" (N,) class ids, and may provide "names"
            mapping class id -> display name.
        label_txt_path: YOLO txt label file (normalized cx,cy,w,h per line).
        img_w, img_h: image dimensions used to de-normalize the labels.
        iou_thr: IoU threshold for counting a prediction as a true positive.
        conf_thr: predictions below this confidence are dropped up front.
        return_vis: when True (and image_rgb is given), also return a drawing.
        image_rgb: image to annotate; only used when return_vis is True.

    Returns:
        det_df: per-class detection metrics (gt/pred counts, TP/FP/FN, P/R/F1)
        overall: micro-averaged detection summary dict
        count_df, count_overall: count-only metrics from count_metrics()
        vis: annotated image copy — appended ONLY when return_vis is True and
            image_rgb is not None (the tuple then has 5 elements, not 4).
    """
    # Predictions: filter by confidence before any matching.
    p_boxes = out["xyxy"].astype(np.float32)
    p_scores = out["conf"].astype(np.float32)
    p_cls = out["cls"].astype(np.int32)
    keep = p_scores >= float(conf_thr)
    p_boxes, p_scores, p_cls = p_boxes[keep], p_scores[keep], p_cls[keep]
    names: Dict[int,str] = out.get("names", {})
    # Ground truth, converted to pixel xyxy.
    g_cls, g_boxes = load_yolo_labels_xyxy(label_txt_path, img_w, img_h)
    # Per-class instance counts (localization-agnostic sanity check).
    actual_counts = {int(c): int((g_cls == c).sum()) for c in np.unique(g_cls)} if len(g_cls) else {}
    pred_counts = {int(c): int((p_cls == c).sum()) for c in np.unique(p_cls)} if len(p_cls) else {}
    count_df, count_overall = count_metrics(actual_counts, pred_counts)
    # Greedy per-class IoU matching of predictions to ground truth.
    matches, p_unmatched_idx, g_unmatched_idx = greedy_match_per_class(
        p_boxes, p_scores, p_cls, g_boxes, g_cls, iou_thr=iou_thr
    )
    matched_p = np.array([m[0] for m in matches], dtype=int) if matches else np.array([], dtype=int)
    matched_g = np.array([m[1] for m in matches], dtype=int) if matches else np.array([], dtype=int)  # currently unused
    # Compute per-class detection metrics
    classes = sorted(set(list(actual_counts.keys()) + list(pred_counts.keys())))
    rows = []
    for c in classes:
        # Matching is per class, so a matched pred's class equals its GT's class.
        tp = int(np.sum(p_cls[matched_p] == c)) # matched pairs already class-consistent
        fp = int(np.sum((p_cls == c))) - tp
        fn = int(np.sum((g_cls == c))) - tp
        prec = tp/(tp+fp) if (tp+fp)>0 else float('nan')
        rec = tp/(tp+fn) if (tp+fn)>0 else float('nan')
        f1 = 2*prec*rec/(prec+rec) if (not math.isnan(prec) and not math.isnan(rec) and (prec+rec)>0) else float('nan')
        rows.append({
            "class_id": c,
            "class_name": names.get(c, str(c)),
            "gt": int(np.sum(g_cls==c)),
            "pred": int(np.sum(p_cls==c)),
            "TP": tp, "FP": fp, "FN": fn,
            "precision": prec, "recall": rec, "F1": f1
        })
    det_df = pd.DataFrame(rows).sort_values("class_id").reset_index(drop=True)
    # Overall detection micro-averages: each match is one TP, leftovers are FP/FN.
    TP = int(len(matches))
    FP = int(len(p_boxes) - TP)
    FN = int(len(g_boxes) - TP)
    micro_p = TP/(TP+FP) if (TP+FP)>0 else float('nan')
    micro_r = TP/(TP+FN) if (TP+FN)>0 else float('nan')
    micro_f1 = 2*micro_p*micro_r/(micro_p+micro_r) if (not math.isnan(micro_p) and not math.isnan(micro_r) and (micro_p+micro_r)>0) else float('nan')
    overall = {
        "gt_instances": int(len(g_boxes)),
        "pred_instances": int(len(p_boxes)),
        "TP": TP, "FP": FP, "FN": FN,
        "micro_precision": micro_p,
        "micro_recall": micro_r,
        "micro_F1": micro_f1,
        "iou_thr": iou_thr,
        "conf_thr": conf_thr
    }
    if not return_vis or image_rgb is None:
        return det_df, overall, count_df, count_overall
    # Annotated visualization on a copy of the input image.
    vis = image_rgb.copy()
    # Draw GT (yellow)
    for i in range(len(g_boxes)):
        color = (240, 230, 70)
        x1,y1,x2,y2 = g_boxes[i].astype(int)
        cv2.rectangle(vis, (x1,y1), (x2,y2), color, 2)
    # Draw matched predictions (green)
    for pi in matched_p:
        x1,y1,x2,y2 = p_boxes[pi].astype(int)
        c = int(p_cls[pi]); sc = float(p_scores[pi])
        label = f"{names.get(c,str(c))} {sc:.2f}"
        cv2.rectangle(vis, (x1,y1), (x2,y2), (60, 220, 60), 2)
        cv2.putText(vis, label, (x1+2, max(0,y1-5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (60,220,60), 2, cv2.LINE_AA)
    # Draw unmatched predictions ("red" per the original comment).
    # NOTE(review): (10, 60, 240) is red only under BGR channel order, but the
    # parameter is named image_rgb — confirm the intended channel convention.
    for pi in p_unmatched_idx:
        x1,y1,x2,y2 = p_boxes[pi].astype(int)
        c = int(p_cls[pi]); sc = float(p_scores[pi])
        label = f"{names.get(c,str(c))} {sc:.2f}"
        cv2.rectangle(vis, (x1,y1), (x2,y2), (10, 60, 240), 2)
        cv2.putText(vis, label, (x1+2, max(0,y1-5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (10,60,240), 2, cv2.LINE_AA)
    return det_df, overall, count_df, count_overall, vis