# Commit header (non-code artifact, preserved as a comment):
# DariusGiannoli — "fix: batch of bug fixes and missing features" (2d788b3)
"""Generalisation Evaluation β€” Stage 6 of the Generalisation pipeline."""
import streamlit as st
import cv2
import numpy as np
import plotly.graph_objects as go
import plotly.figure_factory as ff
from src.models import BACKBONES
def _iou(a, b):
xi1 = max(a[0], b[0]); yi1 = max(a[1], b[1])
xi2 = min(a[2], b[2]); yi2 = min(a[3], b[3])
inter = max(0, xi2 - xi1) * max(0, yi2 - yi1)
aa = (a[2] - a[0]) * (a[3] - a[1])
ab = (b[2] - b[0]) * (b[3] - b[1])
return inter / (aa + ab - inter + 1e-6)
def match_detections(dets, gt_list, iou_thr):
    """Greedily match detections to ground-truth boxes, highest confidence first.

    Args:
        dets: iterable of (x0, y0, x1, y1, label, confidence) tuples.
        gt_list: list of ((x0, y0, x1, y1), label) tuples.
        iou_thr: minimum IoU for a detection to claim a GT box.

    Returns:
        (results, n_unmatched_gt, claimed_gt_indices) where results pairs every
        detection with its matched GT label (None for a false positive) and the
        best IoU found, in descending-confidence order.
    """
    by_confidence = sorted(dets, key=lambda d: d[5], reverse=True)
    claimed = set()
    outcome = []
    for det in by_confidence:
        box = det[:4]
        top_iou, top_idx, top_label = 0.0, -1, None
        # Scan only GT boxes not already claimed by a higher-confidence detection.
        for idx, (gt_box, gt_label) in enumerate(gt_list):
            if idx in claimed:
                continue
            overlap = _iou(box, gt_box)
            if overlap > top_iou:
                top_iou, top_idx, top_label = overlap, idx, gt_label
        if top_idx >= 0 and top_iou >= iou_thr:
            claimed.add(top_idx)
            outcome.append((det, top_label, top_iou))
        else:
            # No sufficiently-overlapping free GT box: false positive.
            outcome.append((det, None, top_iou))
    return outcome, len(gt_list) - len(claimed), claimed
def compute_pr_curve(dets, gt_list, iou_thr, steps=50):
    """Sweep confidence thresholds and collect precision / recall / F1.

    Args:
        dets: list of (x0, y0, x1, y1, label, confidence) tuples.
        gt_list: list of ((x0, y0, x1, y1), label) ground-truth entries.
        iou_thr: IoU threshold forwarded to the matcher.
        steps: number of evenly spaced confidence thresholds in [0, 1].

    Returns:
        (thresholds, precisions, recalls, f1s) as parallel lists; all four
        are empty when there are no detections at all.
    """
    if not dets:
        return [], [], [], []
    thresholds = np.linspace(0.0, 1.0, steps)
    precisions, recalls, f1s = [], [], []
    for thr in thresholds:
        kept = [d for d in dets if d[5] >= thr]
        if not kept:
            # Nothing survives the cut: perfect precision by convention, zero recall.
            precisions.append(1.0)
            recalls.append(0.0)
            f1s.append(0.0)
            continue
        matched, missed, _ = match_detections(kept, gt_list, iou_thr)
        tp = sum(1 for _, gt_lbl, _ in matched if gt_lbl is not None)
        fp = len(matched) - tp  # every unmatched detection is a false positive
        fn = missed
        prec = tp / (tp + fp) if (tp + fp) > 0 else 1.0
        rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
        precisions.append(prec)
        recalls.append(rec)
        f1s.append(f1)
    return thresholds.tolist(), precisions, recalls, f1s
def build_confusion_matrix(dets, gt_list, iou_thr):
    """Build a (predicted x actual) confusion matrix over class labels.

    Rows are predicted labels, columns are ground-truth labels. A trailing
    "background" row/column accounts for false positives (detection with no
    matching GT) and false negatives (GT box no detection claimed).

    Args:
        dets: list of (x0, y0, x1, y1, label, confidence) tuples.
        gt_list: list of ((x0, y0, x1, y1), label) ground-truth entries.
        iou_thr: IoU threshold forwarded to the matcher.

    Returns:
        (matrix, labels): an (n, n) int ndarray and the shared row/column
        label order (class labels sorted, then "background").
    """
    # Collect labels from BOTH ground truth and detections. Previously a
    # predicted label unseen in the GT was folded into "background", so a
    # false positive of a novel class inflated the background/background
    # cell and a matched novel-class detection was counted as a miss.
    class_labels = sorted({lbl for _, lbl in gt_list} | {d[4] for d in dets})
    all_labels = class_labels + ["background"]
    label_to_idx = {lbl: i for i, lbl in enumerate(all_labels)}
    bg = label_to_idx["background"]
    matrix = np.zeros((len(all_labels), len(all_labels)), dtype=int)
    matched, _, matched_gt_indices = match_detections(dets, gt_list, iou_thr)
    for det, gt_lbl, _ in matched:
        pi = label_to_idx[det[4]]
        # Matched detections land in (pred, gt); unmatched detections are
        # false positives against the background column.
        gi = label_to_idx[gt_lbl] if gt_lbl is not None else bg
        matrix[pi][gi] += 1
    # Ground-truth boxes left unclaimed are misses: background row.
    for gt_idx, (_, gt_lbl) in enumerate(gt_list):
        if gt_idx not in matched_gt_indices:
            matrix[bg][label_to_idx[gt_lbl]] += 1
    return matrix, all_labels
def render():
    """Render Stage 6 of the Generalisation pipeline: evaluation dashboards.

    Reads the pipeline dict from ``st.session_state["gen_pipeline"]``, treats
    the Data Lab ROIs as ground truth, and shows per-method (RCE / CNN / ORB)
    confusion matrices, precision-recall and F1-vs-threshold curves, 11-point
    interpolated AP, and a summary table. Stops early (``st.stop``) when the
    pipeline or the detections are missing.
    """
    st.title("📈 Evaluation: Confusion Matrix & PR Curves")
    pipe = st.session_state.get("gen_pipeline")
    if not pipe:
        # Earlier pipeline stages have not been run; nothing to evaluate.
        st.error("Complete the **Data Lab** first.")
        st.stop()
    # Legacy single-ROI fallback: older pipelines stored one crop/bbox instead
    # of a "rois" list, so synthesize a one-entry ROI list from those fields.
    crop = pipe.get("crop")
    crop_aug = pipe.get("crop_aug", crop)
    # NOTE(review): if crop is None AND "rois" is absent, bbox stays None and
    # cv2.rectangle below would fail — presumably the Data Lab guarantees one
    # of the two is present; confirm upstream.
    bbox = pipe.get("crop_bbox", (0, 0, crop.shape[1], crop.shape[0])) if crop is not None else None
    rois = pipe.get("rois", [{"label": "object", "bbox": bbox,
                              "crop": crop, "crop_aug": crop_aug}])
    # Detections per method; each entry is presumably a list of
    # (x0, y0, x1, y1, label, confidence) tuples — matches what the helpers
    # above index (d[4] label, d[5] confidence).
    rce_dets = pipe.get("rce_dets")
    cnn_dets = pipe.get("cnn_dets")
    orb_dets = pipe.get("orb_dets")
    if rce_dets is None and cnn_dets is None and orb_dets is None:
        st.warning("Run detection first in **Real-Time Detection**.")
        st.stop()
    # Ground truth: one (bbox, label) pair per ROI drawn in the Data Lab.
    gt_boxes = [(roi["bbox"], roi["label"]) for roi in rois]
    st.sidebar.subheader("Evaluation Settings")
    iou_thresh = st.sidebar.slider("IoU Threshold", 0.1, 0.9, 0.5, 0.05,
                                   help="Minimum IoU to count as TP",
                                   key="gen_eval_iou")
    # ---- Ground-truth visualisation ------------------------------------
    st.subheader("Ground Truth (from Data Lab ROIs)")
    st.caption(f"{len(gt_boxes)} ground-truth ROIs defined")
    gt_vis = pipe["test_image"].copy()
    for (bx0, by0, bx1, by1), lbl in gt_boxes:
        cv2.rectangle(gt_vis, (bx0, by0), (bx1, by1), (0, 255, 255), 2)
        cv2.putText(gt_vis, lbl, (bx0, by0 - 6),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 1)
    # OpenCV images are BGR; Streamlit expects RGB.
    st.image(cv2.cvtColor(gt_vis, cv2.COLOR_BGR2RGB),
             caption="Ground Truth Annotations", use_container_width=True)
    st.divider()
    # Only methods that actually produced detections get evaluated; the dict
    # is non-empty here because of the early-exit check above.
    methods = {}
    if rce_dets is not None:
        methods["RCE"] = rce_dets
    if cnn_dets is not None:
        methods["CNN"] = cnn_dets
    if orb_dets is not None:
        methods["ORB"] = orb_dets
    # Confusion Matrices
    st.subheader("🔲 Confusion Matrices")
    cm_cols = st.columns(len(methods))
    for col, (name, dets) in zip(cm_cols, methods.items()):
        with col:
            st.markdown(f"**{name}**")
            matrix, labels = build_confusion_matrix(dets, gt_boxes, iou_thresh)
            fig_cm = ff.create_annotated_heatmap(
                z=matrix.tolist(), x=labels, y=labels,
                colorscale="Blues", showscale=True)
            fig_cm.update_layout(title=f"{name} Confusion Matrix",
                                 xaxis_title="Actual", yaxis_title="Predicted",
                                 template="plotly_dark", height=350)
            # Reversed y-axis puts the first predicted label at the top,
            # matrix-style, instead of Plotly's default bottom-up order.
            fig_cm.update_yaxes(autorange="reversed")
            st.plotly_chart(fig_cm, use_container_width=True)
            # Headline metrics at the chosen IoU threshold, using ALL
            # detections (no confidence cut applied here).
            matched, n_missed, _ = match_detections(dets, gt_boxes, iou_thresh)
            tp = sum(1 for _, g, _ in matched if g is not None)
            fp = sum(1 for _, g, _ in matched if g is None)
            fn = n_missed
            prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
            rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
            f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
            m1, m2, m3 = st.columns(3)
            m1.metric("Precision", f"{prec:.1%}")
            m2.metric("Recall", f"{rec:.1%}")
            m3.metric("F1 Score", f"{f1:.1%}")
    # PR Curves
    st.divider()
    st.subheader("📉 Precision-Recall Curves")
    method_colors = {"RCE": "#00ff88", "CNN": "#4488ff", "ORB": "#ff8800"}
    fig_pr = go.Figure()
    fig_f1 = go.Figure()
    summary_rows = []
    for name, dets in methods.items():
        thrs, precs, recs, f1s = compute_pr_curve(dets, gt_boxes, iou_thresh)
        clr = method_colors.get(name, "#ffffff")
        fig_pr.add_trace(go.Scatter(
            x=recs, y=precs, mode="lines+markers",
            name=name, line=dict(color=clr, width=2), marker=dict(size=4)))
        fig_f1.add_trace(go.Scatter(
            x=thrs, y=f1s, mode="lines",
            name=name, line=dict(color=clr, width=2)))
        # 11-point interpolated AP (VOC standard)
        if recs and precs:
            ap = 0.0
            for t in np.arange(0.0, 1.1, 0.1):
                # Interpolated precision: the max precision over all points
                # whose recall is at least t.
                p_at_r = [p for p, r in zip(precs, recs) if r >= t]
                ap += max(p_at_r) if p_at_r else 0.0
            ap /= 11.0
        else:
            ap = 0.0
        best_f1_idx = int(np.argmax(f1s)) if f1s else 0
        summary_rows.append({
            "Method": name,
            "AP": f"{ap:.3f}",
            "Best F1": f"{f1s[best_f1_idx]:.3f}" if f1s else "N/A",
            "@ Threshold": f"{thrs[best_f1_idx]:.2f}" if thrs else "N/A",
            "Detections": len(dets),
        })
    fig_pr.update_layout(title="Precision vs Recall",
                         xaxis_title="Recall", yaxis_title="Precision",
                         template="plotly_dark", height=400,
                         xaxis=dict(range=[0, 1.05]), yaxis=dict(range=[0, 1.05]))
    fig_f1.update_layout(title="F1 Score vs Confidence Threshold",
                         xaxis_title="Confidence Threshold", yaxis_title="F1 Score",
                         template="plotly_dark", height=400,
                         xaxis=dict(range=[0, 1.05]), yaxis=dict(range=[0, 1.05]))
    pc1, pc2 = st.columns(2)
    pc1.plotly_chart(fig_pr, use_container_width=True)
    pc2.plotly_chart(fig_f1, use_container_width=True)
    # Summary Table
    st.divider()
    st.subheader("📊 Summary")
    # Local import keeps pandas off the critical path of pages that never
    # reach this stage.
    import pandas as pd
    st.dataframe(pd.DataFrame(summary_rows), use_container_width=True, hide_index=True)
    st.caption(f"All metrics computed at IoU threshold = **{iou_thresh:.2f}**.")