Spaces:
Sleeping
Sleeping
"""Generalisation Evaluation – Stage 6 of the Generalisation pipeline."""
| import streamlit as st | |
| import cv2 | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| import plotly.figure_factory as ff | |
| from src.models import BACKBONES | |
| def _iou(a, b): | |
| xi1 = max(a[0], b[0]); yi1 = max(a[1], b[1]) | |
| xi2 = min(a[2], b[2]); yi2 = min(a[3], b[3]) | |
| inter = max(0, xi2 - xi1) * max(0, yi2 - yi1) | |
| aa = (a[2] - a[0]) * (a[3] - a[1]) | |
| ab = (b[2] - b[0]) * (b[3] - b[1]) | |
| return inter / (aa + ab - inter + 1e-6) | |
def match_detections(dets, gt_list, iou_thr):
    """Greedily match detections to ground-truth boxes, highest confidence first.

    Each detection is (x1, y1, x2, y2, label, conf); each ground truth is
    (box, label).  A GT box can be claimed by at most one detection.

    Returns (results, n_unmatched_gt, matched_gt_indices) where each result
    is (det, gt_label_or_None, best_iou).
    """
    claimed = set()
    outcome = []
    for det in sorted(dets, key=lambda d: d[5], reverse=True):
        box = det[:4]
        top_iou = 0.0
        top_idx = -1
        top_lbl = None
        # scan only GT boxes not yet claimed by a higher-confidence detection
        for idx, (gt_box, gt_lbl) in enumerate(gt_list):
            if idx in claimed:
                continue
            score = _iou(box, gt_box)
            if score > top_iou:
                top_iou, top_idx, top_lbl = score, idx, gt_lbl
        if top_idx >= 0 and top_iou >= iou_thr:
            claimed.add(top_idx)
            outcome.append((det, top_lbl, top_iou))
        else:
            outcome.append((det, None, top_iou))
    return outcome, len(gt_list) - len(claimed), claimed
def compute_pr_curve(dets, gt_list, iou_thr, steps=50):
    """Sweep confidence thresholds and compute precision/recall/F1 at each.

    Returns (thresholds, precisions, recalls, f1s) as plain lists; all four
    are empty when there are no detections at all.
    """
    if not dets:
        return [], [], [], []
    thresholds = np.linspace(0.0, 1.0, steps)
    precs, recs, f1_vals = [], [], []
    for thr in thresholds:
        kept = [d for d in dets if d[5] >= thr]
        if not kept:
            # nothing survives the cut: precision 1 by convention, recall 0
            precs.append(1.0)
            recs.append(0.0)
            f1_vals.append(0.0)
            continue
        matched, missed, _ = match_detections(kept, gt_list, iou_thr)
        tp = sum(gt is not None for _, gt, _ in matched)
        fp = len(matched) - tp
        fn = missed
        p = tp / (tp + fp) if tp + fp else 1.0
        r = tp / (tp + fn) if tp + fn else 0.0
        f = 2 * p * r / (p + r) if p + r else 0.0
        precs.append(p)
        recs.append(r)
        f1_vals.append(f)
    return thresholds.tolist(), precs, recs, f1_vals
def build_confusion_matrix(dets, gt_list, iou_thr):
    """Build a (predicted x actual) confusion matrix over GT labels + 'background'.

    Unmatched detections land in the 'background' column; unmatched ground
    truths land in the 'background' row.  Predicted labels that are not among
    the GT labels are folded into the 'background' row.

    Returns (matrix, labels) with labels ordering both axes.
    """
    labels = sorted({lbl for _, lbl in gt_list}) + ["background"]
    index_of = {lbl: i for i, lbl in enumerate(labels)}
    bg = index_of["background"]
    cm = np.zeros((len(labels), len(labels)), dtype=int)
    matched, _, taken = match_detections(dets, gt_list, iou_thr)
    for det, gt_lbl, _ in matched:
        row = index_of.get(det[4], bg)
        col = index_of[gt_lbl] if gt_lbl is not None else bg
        cm[row][col] += 1
    # ground truths that no detection claimed are false negatives
    for gi, (_, gt_lbl) in enumerate(gt_list):
        if gi not in taken:
            cm[bg][index_of[gt_lbl]] += 1
    return cm, labels
def render():
    """Render the Stage-6 evaluation page.

    Shows per-method confusion matrices, precision-recall and F1-vs-threshold
    curves, and a summary table.  Ground truth comes from the Data Lab ROIs
    stored in ``st.session_state["gen_pipeline"]``; detections come from the
    Real-Time Detection stage (``rce_dets`` / ``cnn_dets`` / ``orb_dets``).
    """
    # NOTE(review): the leading "π"/"π²" characters in several headings look
    # like mojibake of emoji from a re-encoding pass — confirm against the
    # original source before "fixing" them.
    st.title("π Evaluation: Confusion Matrix & PR Curves")
    # Guard: the pipeline dict is created by the Data Lab stage.
    pipe = st.session_state.get("gen_pipeline")
    if not pipe:
        st.error("Complete the **Data Lab** first.")
        st.stop()
    crop = pipe.get("crop")
    crop_aug = pipe.get("crop_aug", crop)
    # Fall back to the full crop extent when no explicit bbox was stored.
    bbox = pipe.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0])) if crop is not None else None
    # Default to a single generic ROI when the Data Lab stored none.
    rois = pipe.get("rois", [{"label": "object", "bbox": bbox,
                              "crop": crop, "crop_aug": crop_aug}])
    rce_dets = pipe.get("rce_dets")
    cnn_dets = pipe.get("cnn_dets")
    orb_dets = pipe.get("orb_dets")
    # Need at least one detector's output to evaluate anything.
    if rce_dets is None and cnn_dets is None and orb_dets is None:
        st.warning("Run detection first in **Real-Time Detection**.")
        st.stop()
    # Ground truth as (bbox, label) pairs, the format match_detections expects.
    gt_boxes = [(roi["bbox"], roi["label"]) for roi in rois]
    st.sidebar.subheader("Evaluation Settings")
    iou_thresh = st.sidebar.slider("IoU Threshold", 0.1, 0.9, 0.5, 0.05,
                                   help="Minimum IoU to count as TP",
                                   key="gen_eval_iou")
    st.subheader("Ground Truth (from Data Lab ROIs)")
    st.caption(f"{len(gt_boxes)} ground-truth ROIs defined")
    # Draw GT boxes on a copy so the stored test image stays untouched.
    gt_vis = pipe["test_image"].copy()
    for (bx0, by0, bx1, by1), lbl in gt_boxes:
        cv2.rectangle(gt_vis, (bx0, by0), (bx1, by1), (0, 255, 255), 2)
        cv2.putText(gt_vis, lbl, (bx0, by0 - 6),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 1)
    # OpenCV images are BGR; Streamlit expects RGB.
    st.image(cv2.cvtColor(gt_vis, cv2.COLOR_BGR2RGB),
             caption="Ground Truth Annotations", use_container_width=True)
    st.divider()
    # Only detectors that actually produced output participate below.
    methods = {}
    if rce_dets is not None:
        methods["RCE"] = rce_dets
    if cnn_dets is not None:
        methods["CNN"] = cnn_dets
    if orb_dets is not None:
        methods["ORB"] = orb_dets
    # Confusion Matrices
    st.subheader("π² Confusion Matrices")
    # One column per method; methods is non-empty thanks to the guard above.
    cm_cols = st.columns(len(methods))
    for col, (name, dets) in zip(cm_cols, methods.items()):
        with col:
            st.markdown(f"**{name}**")
            matrix, labels = build_confusion_matrix(dets, gt_boxes, iou_thresh)
            fig_cm = ff.create_annotated_heatmap(
                z=matrix.tolist(), x=labels, y=labels,
                colorscale="Blues", showscale=True)
            fig_cm.update_layout(title=f"{name} Confusion Matrix",
                                 xaxis_title="Actual", yaxis_title="Predicted",
                                 template="plotly_dark", height=350)
            # Put the first predicted label at the top, matrix-style.
            fig_cm.update_yaxes(autorange="reversed")
            st.plotly_chart(fig_cm, use_container_width=True)
            # Point metrics at the full detection set (no confidence cut).
            matched, n_missed, _ = match_detections(dets, gt_boxes, iou_thresh)
            tp = sum(1 for _, g, _ in matched if g is not None)
            fp = sum(1 for _, g, _ in matched if g is None)
            fn = n_missed
            prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
            rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
            f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
            m1, m2, m3 = st.columns(3)
            m1.metric("Precision", f"{prec:.1%}")
            m2.metric("Recall", f"{rec:.1%}")
            m3.metric("F1 Score", f"{f1:.1%}")
    # PR Curves
    st.divider()
    st.subheader("π Precision-Recall Curves")
    method_colors = {"RCE": "#00ff88", "CNN": "#4488ff", "ORB": "#ff8800"}
    fig_pr = go.Figure()
    fig_f1 = go.Figure()
    summary_rows = []
    for name, dets in methods.items():
        thrs, precs, recs, f1s = compute_pr_curve(dets, gt_boxes, iou_thresh)
        clr = method_colors.get(name, "#ffffff")
        fig_pr.add_trace(go.Scatter(
            x=recs, y=precs, mode="lines+markers",
            name=name, line=dict(color=clr, width=2), marker=dict(size=4)))
        fig_f1.add_trace(go.Scatter(
            x=thrs, y=f1s, mode="lines",
            name=name, line=dict(color=clr, width=2)))
        # 11-point interpolated AP (VOC standard)
        if recs and precs:
            ap = 0.0
            # Max precision at recall >= t, averaged over t = 0.0, 0.1, ..., 1.0.
            for t in np.arange(0.0, 1.1, 0.1):
                p_at_r = [p for p, r in zip(precs, recs) if r >= t]
                ap += max(p_at_r) if p_at_r else 0.0
            ap /= 11.0
        else:
            ap = 0.0
        best_f1_idx = int(np.argmax(f1s)) if f1s else 0
        summary_rows.append({
            "Method": name,
            "AP": f"{ap:.3f}",
            "Best F1": f"{f1s[best_f1_idx]:.3f}" if f1s else "N/A",
            "@ Threshold": f"{thrs[best_f1_idx]:.2f}" if thrs else "N/A",
            "Detections": len(dets),
        })
    fig_pr.update_layout(title="Precision vs Recall",
                         xaxis_title="Recall", yaxis_title="Precision",
                         template="plotly_dark", height=400,
                         xaxis=dict(range=[0, 1.05]), yaxis=dict(range=[0, 1.05]))
    fig_f1.update_layout(title="F1 Score vs Confidence Threshold",
                         xaxis_title="Confidence Threshold", yaxis_title="F1 Score",
                         template="plotly_dark", height=400,
                         xaxis=dict(range=[0, 1.05]), yaxis=dict(range=[0, 1.05]))
    pc1, pc2 = st.columns(2)
    pc1.plotly_chart(fig_pr, use_container_width=True)
    pc2.plotly_chart(fig_f1, use_container_width=True)
    # Summary Table
    st.divider()
    st.subheader("π Summary")
    # Deferred import: pandas is only needed for this table.
    import pandas as pd
    st.dataframe(pd.DataFrame(summary_rows), use_container_width=True, hide_index=True)
    st.caption(f"All metrics computed at IoU threshold = **{iou_thresh:.2f}**.")