| import math |
| from typing import Dict, Optional, Tuple |
| import numpy as np |
| import pandas as pd |
| import cv2 |
|
|
| |
def load_yolo_labels_xyxy(txt_path: str, img_w: int, img_h: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Parse a YOLO-format label file into class ids and pixel-space xyxy boxes.

    Each valid line is "cls xc yc w h" with coordinates normalized to [0, 1];
    lines that do not have exactly 5 whitespace-separated fields are skipped.

    Returns:
        cls_ids: (N,) int32
        boxes_xyxy: (N, 4) float32 in pixel coords
    """
    records = []
    with open(txt_path, "r") as fh:
        for raw in fh:
            fields = raw.strip().split()
            if len(fields) != 5:
                continue
            cls_id = int(float(fields[0]))
            xc, yc, bw, bh = (float(v) for v in fields[1:])
            # Convert normalized center/size to pixel-space corners.
            cx, cy = xc * img_w, yc * img_h
            half_w = bw * img_w / 2.0
            half_h = bh * img_h / 2.0
            records.append((cls_id, cx - half_w, cy - half_h, cx + half_w, cy + half_h))
    if not records:
        return np.zeros((0,), dtype=np.int32), np.zeros((0, 4), dtype=np.float32)
    ids = np.array([r[0] for r in records], dtype=np.int32)
    boxes = np.array([r[1:] for r in records], dtype=np.float32)
    return ids, boxes
|
|
| |
def iou_matrix(a_xyxy: np.ndarray, b_xyxy: np.ndarray) -> np.ndarray:
    """Pairwise IoU between two box sets: (Na,4) vs (Nb,4) -> (Na,Nb)."""
    n_a, n_b = a_xyxy.shape[0], b_xyxy.shape[0]
    if a_xyxy.size == 0 or b_xyxy.size == 0:
        return np.zeros((n_a, n_b), dtype=np.float32)
    # Broadcast A over rows, B over columns: (Na,1,2) vs (1,Nb,2).
    top_left = np.maximum(a_xyxy[:, None, :2], b_xyxy[None, :, :2])
    bot_right = np.minimum(a_xyxy[:, None, 2:], b_xyxy[None, :, 2:])
    wh = np.maximum(0.0, bot_right - top_left)
    inter = wh[..., 0] * wh[..., 1]
    area_a = (a_xyxy[:, 2] - a_xyxy[:, 0]) * (a_xyxy[:, 3] - a_xyxy[:, 1])
    area_b = (b_xyxy[:, 2] - b_xyxy[:, 0]) * (b_xyxy[:, 3] - b_xyxy[:, 1])
    # Clamp the union away from zero to avoid division blow-ups.
    union = np.maximum(1e-9, area_a[:, None] + area_b[None, :] - inter)
    return (inter / union).astype(np.float32)
|
|
def greedy_match_per_class(
    pred_boxes: np.ndarray, pred_scores: np.ndarray, pred_cls: np.ndarray,
    gt_boxes: np.ndarray, gt_cls: np.ndarray,
    iou_thr: float
):
    """
    Greedy IoU matching restricted to same-class pairs.

    Within each class, repeatedly takes the highest remaining IoU pair and
    matches it, until no remaining pair reaches `iou_thr`. Each prediction
    and each GT box is matched at most once.

    Note: `pred_scores` is accepted for interface compatibility but is not
    used — ordering is purely by IoU, with no score tie-breaking.

    Returns:
        matches: list of (pred_idx, gt_idx)
        pred_unmatched: np.ndarray of unmatched pred indices
        gt_unmatched: np.ndarray of unmatched gt indices
    """
    matches = []
    pred_unmatched = np.ones(len(pred_boxes), dtype=bool)
    gt_unmatched = np.ones(len(gt_boxes), dtype=bool)

    for c in np.union1d(pred_cls, gt_cls):
        p_idx = np.where(pred_cls == c)[0]
        g_idx = np.where(gt_cls == c)[0]
        if len(p_idx) == 0 or len(g_idx) == 0:
            continue

        iou = iou_matrix(pred_boxes[p_idx], gt_boxes[g_idx])
        while iou.size:
            i, j = np.unravel_index(np.argmax(iou), iou.shape)
            if iou[i, j] < iou_thr:
                break
            matches.append((int(p_idx[i]), int(g_idx[j])))
            pred_unmatched[p_idx[i]] = False
            gt_unmatched[g_idx[j]] = False
            # Invalidate the whole row and column so neither endpoint can
            # be selected again; -1 is below any valid IoU and any sane
            # threshold, which also makes a "re-pick a used index" guard
            # unnecessary (the original code's check here was unreachable).
            iou[i, :] = -1.0
            iou[:, j] = -1.0

    return matches, np.where(pred_unmatched)[0], np.where(gt_unmatched)[0]
|
|
| |
def count_metrics(actual_counts: Dict[int, int], pred_counts: Dict[int, int]) -> Tuple[pd.DataFrame, Dict]:
    """
    Score predicted per-class counts against actual per-class counts.

    Per class: TP = min(a, p), FP = max(p - a, 0), FN = max(a - p, 0),
    absolute count error and sMAPE (denominator falls back to 1.0 when both
    counts are zero). The second return value micro-averages across classes.
    """
    def _ratio(num, den):
        # NaN when the denominator is empty, mirroring undefined P/R.
        return num / den if den > 0 else float('nan')

    def _f1(p, r):
        if math.isnan(p) or math.isnan(r) or (p + r) <= 0:
            return float('nan')
        return 2 * p * r / (p + r)

    rows = []
    tot_tp = tot_fp = tot_fn = 0
    err_total = 0
    denom_total = 0
    for cls in sorted(set(actual_counts) | set(pred_counts)):
        act = int(actual_counts.get(cls, 0))
        prd = int(pred_counts.get(cls, 0))
        tp, fp, fn = min(act, prd), max(prd - act, 0), max(act - prd, 0)
        err = abs(prd - act)
        denom = (abs(act) + abs(prd)) / 2 if (act + prd) > 0 else 1.0
        prec = _ratio(tp, tp + fp)
        rec = _ratio(tp, tp + fn)
        rows.append({"class_id": cls, "actual": act, "pred": prd,
                     "abs_err": err, "sMAPE": err / denom,
                     "P": prec, "R": rec, "F1": _f1(prec, rec)})
        tot_tp += tp; tot_fp += fp; tot_fn += fn
        err_total += err; denom_total += denom
    micro_p = _ratio(tot_tp, tot_tp + tot_fp)
    micro_r = _ratio(tot_tp, tot_tp + tot_fn)
    overall = {"sum_abs_count_error": err_total,
               "micro_precision": micro_p,
               "micro_recall": micro_r,
               "micro_f1": _f1(micro_p, micro_r),
               "micro_sMAPE": err_total / (denom_total or 1.0)}
    return pd.DataFrame(rows), overall
|
|
| |
def evaluate_one_image(
    out: Dict,
    label_txt_path: str,
    img_w: int, img_h: int,
    iou_thr: float = 0.50,
    conf_thr: float = 0.25,
    return_vis: bool = False,
    image_rgb: Optional[np.ndarray] = None
):
    """
    Evaluate one image's detections against a YOLO-format label file.

    Args:
        out: detector output dict with "xyxy" (N,4), "conf" (N,), "cls" (N,)
             arrays, plus optional "names" mapping class_id -> class name.
        label_txt_path: ground-truth .txt path (YOLO normalized format).
        img_w, img_h: image dimensions used to denormalize label boxes.
        iou_thr: minimum IoU for a prediction/GT pair to count as matched.
        conf_thr: predictions below this confidence are dropped first.
        return_vis: also return an annotated image (requires image_rgb).
        image_rgb: image to annotate — presumably RGB per the name; cv2
                   drawing works on any channel order.

    Returns:
        det_df: per-class detection metrics (TP/FP/FN, precision/recall/F1).
        overall: micro-averaged detection metrics dict.
        count_df, count_overall: count-only metrics (see count_metrics).
        vis: annotated copy, only when return_vis and image_rgb is not None.
    """
    # --- Predictions: confidence filter ---
    p_boxes = out["xyxy"].astype(np.float32)
    p_scores = out["conf"].astype(np.float32)
    p_cls = out["cls"].astype(np.int32)
    keep = p_scores >= float(conf_thr)
    p_boxes, p_scores, p_cls = p_boxes[keep], p_scores[keep], p_cls[keep]
    names: Dict[int, str] = out.get("names", {})

    # --- Ground truth ---
    g_cls, g_boxes = load_yolo_labels_xyxy(label_txt_path, img_w, img_h)

    # --- Count-only metrics (localization ignored) ---
    actual_counts = {int(c): int((g_cls == c).sum()) for c in np.unique(g_cls)} if len(g_cls) else {}
    pred_counts = {int(c): int((p_cls == c).sum()) for c in np.unique(p_cls)} if len(p_cls) else {}
    count_df, count_overall = count_metrics(actual_counts, pred_counts)

    # --- Greedy per-class IoU matching ---
    matches, p_unmatched_idx, g_unmatched_idx = greedy_match_per_class(
        p_boxes, p_scores, p_cls, g_boxes, g_cls, iou_thr=iou_thr
    )
    matched_p = np.array([m[0] for m in matches], dtype=int) if matches else np.array([], dtype=int)

    # --- Per-class detection metrics ---
    classes = sorted(set(list(actual_counts.keys()) + list(pred_counts.keys())))
    rows = []
    for c in classes:
        tp = int(np.sum(p_cls[matched_p] == c))
        fp = int(np.sum((p_cls == c))) - tp
        fn = int(np.sum((g_cls == c))) - tp
        prec = tp/(tp+fp) if (tp+fp) > 0 else float('nan')
        rec = tp/(tp+fn) if (tp+fn) > 0 else float('nan')
        f1 = 2*prec*rec/(prec+rec) if (not math.isnan(prec) and not math.isnan(rec) and (prec+rec) > 0) else float('nan')
        rows.append({
            "class_id": c,
            "class_name": names.get(c, str(c)),
            "gt": int(np.sum(g_cls == c)),
            "pred": int(np.sum(p_cls == c)),
            "TP": tp, "FP": fp, "FN": fn,
            "precision": prec, "recall": rec, "F1": f1
        })
    # Explicit columns keep sort_values("class_id") from raising KeyError
    # when there are no classes at all (empty GT and empty predictions).
    det_cols = ["class_id", "class_name", "gt", "pred", "TP", "FP", "FN",
                "precision", "recall", "F1"]
    det_df = pd.DataFrame(rows, columns=det_cols).sort_values("class_id").reset_index(drop=True)

    # --- Micro (instance-level) detection metrics ---
    TP = int(len(matches))
    FP = int(len(p_boxes) - TP)
    FN = int(len(g_boxes) - TP)
    micro_p = TP/(TP+FP) if (TP+FP) > 0 else float('nan')
    micro_r = TP/(TP+FN) if (TP+FN) > 0 else float('nan')
    micro_f1 = 2*micro_p*micro_r/(micro_p+micro_r) if (not math.isnan(micro_p) and not math.isnan(micro_r) and (micro_p+micro_r) > 0) else float('nan')

    overall = {
        "gt_instances": int(len(g_boxes)),
        "pred_instances": int(len(p_boxes)),
        "TP": TP, "FP": FP, "FN": FN,
        "micro_precision": micro_p,
        "micro_recall": micro_r,
        "micro_F1": micro_f1,
        "iou_thr": iou_thr,
        "conf_thr": conf_thr
    }

    if not return_vis or image_rgb is None:
        return det_df, overall, count_df, count_overall

    # --- Visualization ---
    vis = image_rgb.copy()
    # Ground-truth boxes: yellow, unlabeled.
    for i in range(len(g_boxes)):
        color = (240, 230, 70)
        x1, y1, x2, y2 = g_boxes[i].astype(int)
        cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)
    # Matched predictions: green, labeled with class name and score.
    for pi in matched_p:
        x1, y1, x2, y2 = p_boxes[pi].astype(int)
        c = int(p_cls[pi]); sc = float(p_scores[pi])
        label = f"{names.get(c,str(c))} {sc:.2f}"
        cv2.rectangle(vis, (x1, y1), (x2, y2), (60, 220, 60), 2)
        cv2.putText(vis, label, (x1+2, max(0, y1-5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (60, 220, 60), 2, cv2.LINE_AA)
    # Unmatched predictions (false positives): red.
    for pi in p_unmatched_idx:
        x1, y1, x2, y2 = p_boxes[pi].astype(int)
        c = int(p_cls[pi]); sc = float(p_scores[pi])
        label = f"{names.get(c,str(c))} {sc:.2f}"
        cv2.rectangle(vis, (x1, y1), (x2, y2), (10, 60, 240), 2)
        cv2.putText(vis, label, (x1+2, max(0, y1-5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (10, 60, 240), 2, cv2.LINE_AA)
    return det_df, overall, count_df, count_overall, vis
|
|