#!/usr/bin/env python3 """v11d — activity threshold calibration + Top-K̂ decode for frame-track CSVs. Re-decodes the **already-dumped** ``*__pred.csv`` files under different gating strategies and reports which split prefers what. Decode modes ------------ ``threshold`` Keep predictions with ``activity_prob >= threshold``. Sweep multiple thresholds in one run. ``topk_hat`` Per (sample, frame), keep the top-K̂ tracks by ``activity_prob`` where K̂ comes from the ``num_active_pred`` column dumped by v10's ``num_active_head``. Requires that column to be present. ``topk_hat_min`` Top-K̂ AND ``activity_prob >= threshold`` (intersection). Recovers the K̂ gate's source-count discipline while suppressing low-confidence overprediction in single-source frames. Metrics mirror ``analyze_csv_dump.py``: per-split ``hit_cls_and_angle``, ``class_right_angle_wrong``, ``no_same_class_pred_but_other_preds_exist`` and matched-TP / FP rates. Use those fields to pick the best decode for each real_ov{1,2,3} split. Usage ----- python scripts/calibrate_activity.py \ --dump-dir checkpoints/.../valid_csv_dump \ --thresholds 0.3 0.4 0.5 0.6 \ --modes threshold topk_hat topk_hat_min \ --json-out calibration.json """ from __future__ import annotations import argparse import csv import json import math import statistics from collections import Counter, defaultdict from pathlib import Path from typing import Any, Optional def angular_distance_deg(azi1: float, ele1: float, azi2: float, ele2: float) -> float: a1, e1 = math.radians(azi1), math.radians(ele1) a2, e2 = math.radians(azi2), math.radians(ele2) x1 = math.cos(e1) * math.cos(a1) y1 = math.cos(e1) * math.sin(a1) z1 = math.sin(e1) x2 = math.cos(e2) * math.cos(a2) y2 = math.cos(e2) * math.sin(a2) z2 = math.sin(e2) dot = max(-1.0, min(1.0, x1 * x2 + y1 * y2 + z1 * z2)) return math.degrees(math.acos(dot)) def infer_split(name: str) -> Optional[str]: prefixes = ( ("valid__ov1_real_", "real_ov1"), ("valid__ov2_real_", "real_ov2"), ("valid__ov3_real_", "real_ov3"), ("valid__ov1_", "ov1"), ("valid__ov2_", "ov2"), ("valid__ov3_", "ov3"), ("valid__hm3d__", "ov1"), ) for prefix, split in prefixes: if name.startswith(prefix): return split return None def _row_dict(row: dict[str, str]) -> dict[str, Any]: out = { "frame_idx": int(row["frame_idx"]), "class_idx": int(row["class_idx"]), "class_name": row.get("class_name", ""), "azi": float(row["azimuth_deg"]), "ele": float(row["elevation_deg"]), "activity_prob": float(row["activity_prob"]), "track_or_src": int(row["src_or_track_idx"]), } if "num_active_pred" in row and row["num_active_pred"] != "": out["num_active_pred"] = int(row["num_active_pred"]) return out def load_pred_rows(csv_path: Path) -> dict[int, list[dict[str, Any]]]: rows_by_frame: dict[int, list[dict[str, Any]]] = defaultdict(list) with csv_path.open() as f: for row in csv.DictReader(f): d = _row_dict(row) rows_by_frame[d["frame_idx"]].append(d) return rows_by_frame def load_gt_rows(csv_path: Path) -> dict[int, list[dict[str, Any]]]: rows_by_frame: dict[int, list[dict[str, Any]]] = defaultdict(list) with csv_path.open() as f: for row in csv.DictReader(f): d = _row_dict(row) rows_by_frame[d["frame_idx"]].append(d) return rows_by_frame def filter_preds_threshold( pred_by_frame: dict[int, list[dict[str, Any]]], threshold: float, ) -> dict[int, list[dict[str, Any]]]: out: dict[int, list[dict[str, Any]]] = defaultdict(list) for t, preds in pred_by_frame.items(): out[t] = [p for p in preds if p["activity_prob"] >= threshold] return out def filter_preds_topk_hat( pred_by_frame: dict[int, list[dict[str, Any]]], threshold: Optional[float], ) -> dict[int, list[dict[str, Any]]]: out: dict[int, list[dict[str, Any]]] = defaultdict(list) for t, preds in pred_by_frame.items(): k_hat = None for p in preds: if "num_active_pred" in p: k_hat = int(p["num_active_pred"]) break if k_hat is None or k_hat <= 0: out[t] = [] continue sorted_preds = sorted(preds, key=lambda p: p["activity_prob"], reverse=True) kept = sorted_preds[:k_hat] if threshold is not None: kept = [p for p in kept if p["activity_prob"] >= threshold] out[t] = kept return out def analyze_one_pair( pred_by_frame_filtered: dict[int, list[dict[str, Any]]], gt_by_frame: dict[int, list[dict[str, Any]]], pred_path_name: str, ) -> dict[str, Any]: all_frames = sorted(set(pred_by_frame_filtered) | set(gt_by_frame)) gt_outcomes = Counter() pred_outcomes = Counter() frame_relation = Counter() same_class_best_angles: list[float] = [] for frame_idx in all_frames: preds = pred_by_frame_filtered.get(frame_idx, []) gts = gt_by_frame.get(frame_idx, []) num_gt = len(gts) num_pred = len(preds) if num_pred < num_gt: frame_relation["under"] += 1 elif num_pred == num_gt: frame_relation["equal"] += 1 else: frame_relation["over"] += 1 if num_gt > 0 and num_pred == 0: frame_relation["gt_no_pred"] += 1 for gt in gts: same_class_preds = [p for p in preds if p["class_idx"] == gt["class_idx"]] if same_class_preds: best_angle = min( angular_distance_deg(gt["azi"], gt["ele"], p["azi"], p["ele"]) for p in same_class_preds ) same_class_best_angles.append(best_angle) if best_angle <= 20.0: gt_outcomes["hit_cls_and_angle"] += 1 else: gt_outcomes["class_right_angle_wrong"] += 1 else: if preds: gt_outcomes["no_same_class_pred_but_other_preds_exist"] += 1 else: gt_outcomes["no_pred_in_frame"] += 1 used_pred = [False] * len(preds) used_gt = [False] * len(gts) candidates: list[tuple[float, int, int]] = [] for pi, p in enumerate(preds): for gi, gt in enumerate(gts): if p["class_idx"] != gt["class_idx"]: continue angle = angular_distance_deg(gt["azi"], gt["ele"], p["azi"], p["ele"]) if angle <= 20.0: candidates.append((angle, pi, gi)) candidates.sort() for _, pi, gi in candidates: if used_pred[pi] or used_gt[gi]: continue used_pred[pi] = True used_gt[gi] = True pred_outcomes["matched_tp"] += 1 for pi, p in enumerate(preds): if used_pred[pi]: continue same_class_gt = [gt for gt in gts if gt["class_idx"] == p["class_idx"]] if same_class_gt: pred_outcomes["same_class_angle_wrong_fp"] += 1 else: pred_outcomes["wrong_class_or_spurious_fp"] += 1 return { "file": pred_path_name, "frames": len(all_frames), "avg_gt_per_frame": ( sum(len(gt_by_frame.get(t, [])) for t in all_frames) / len(all_frames) if all_frames else 0.0 ), "avg_pred_per_frame": ( sum(len(pred_by_frame_filtered.get(t, [])) for t in all_frames) / len(all_frames) if all_frames else 0.0 ), "frame_relation": frame_relation, "gt_outcomes": gt_outcomes, "pred_outcomes": pred_outcomes, "mean_same_class_best_angle": ( statistics.mean(same_class_best_angles) if same_class_best_angles else None ), } def aggregate_rows(rows: list[dict[str, Any]]) -> dict[str, Any]: agg_gt = Counter() agg_pred = Counter() agg_frame = Counter() total_frames = sum(r["frames"] for r in rows) avg_gt = ( sum(r["avg_gt_per_frame"] * r["frames"] for r in rows) / total_frames if total_frames else 0.0 ) avg_pred = ( sum(r["avg_pred_per_frame"] * r["frames"] for r in rows) / total_frames if total_frames else 0.0 ) same_class_means = [] for r in rows: agg_gt.update(r["gt_outcomes"]) agg_pred.update(r["pred_outcomes"]) agg_frame.update(r["frame_relation"]) if r["mean_same_class_best_angle"] is not None: same_class_means.append(r["mean_same_class_best_angle"]) total_gt = sum(agg_gt.values()) total_pred = sum(agg_pred.values()) same_class_total = agg_gt["hit_cls_and_angle"] + agg_gt["class_right_angle_wrong"] return { "samples": len(rows), "frames": total_frames, "avg_gt_per_frame": avg_gt, "avg_pred_per_frame": avg_pred, "frame_relation": dict(agg_frame), "gt_outcomes": dict(agg_gt), "pred_outcomes": dict(agg_pred), "gt_total": total_gt, "pred_total": total_pred, "hit_share": ( agg_gt["hit_cls_and_angle"] / total_gt if total_gt else None ), "same_class_angle_le_20_share": ( agg_gt["hit_cls_and_angle"] / same_class_total if same_class_total else None ), "matched_tp_precision": ( agg_pred["matched_tp"] / total_pred if total_pred else None ), "matched_tp_recall": ( agg_pred["matched_tp"] / total_gt if total_gt else None ), "mean_best_angle_when_same_class_exists": ( statistics.mean(same_class_means) if same_class_means else None ), } def evaluate_decode( raw_pred_by_split: dict[str, list[tuple[str, dict[int, list[dict[str, Any]]]]]], gt_by_split: dict[str, list[tuple[str, dict[int, list[dict[str, Any]]]]]], decode_mode: str, threshold: Optional[float], ) -> dict[str, Any]: summary: dict[str, Any] = {} for split, pred_entries in sorted(raw_pred_by_split.items()): gt_entries = dict(gt_by_split[split]) rows: list[dict[str, Any]] = [] for pred_name, pred_rows in pred_entries: if pred_name not in gt_entries: continue if decode_mode == "threshold": filtered = filter_preds_threshold(pred_rows, threshold or 0.0) elif decode_mode == "topk_hat": filtered = filter_preds_topk_hat(pred_rows, None) elif decode_mode == "topk_hat_min": filtered = filter_preds_topk_hat(pred_rows, threshold) else: raise ValueError(f"unknown decode mode: {decode_mode}") rows.append(analyze_one_pair(filtered, gt_entries[pred_name], pred_name)) summary[split] = aggregate_rows(rows) if rows else {"samples": 0} return summary def fmt_pct(num: int, den: int) -> str: if den <= 0: return " - " return f"{100.0 * num / den:5.1f}%" def print_summary(label: str, summary: dict[str, dict[str, Any]]) -> None: print(f"=== {label} ===") header = ( f"{'split':<10} {'samples':>7} {'avg_pred':>9} {'avg_gt':>7} " f"{'hit':>6} {'cls_ok_ang_bad':>16} {'no_same_cls':>13} {'no_pred':>9} " f"{'tp_prec':>9} {'tp_rec':>9}" ) print(header) for split, stats in summary.items(): if not stats or stats.get("samples", 0) == 0: continue gt = stats["gt_outcomes"] gt_total = stats["gt_total"] pred_total = stats["pred_total"] print( f"{split:<10} " f"{stats['samples']:>7d} " f"{stats['avg_pred_per_frame']:>9.2f} " f"{stats['avg_gt_per_frame']:>7.2f} " f"{fmt_pct(gt.get('hit_cls_and_angle', 0), gt_total):>6} " f"{fmt_pct(gt.get('class_right_angle_wrong', 0), gt_total):>16} " f"{fmt_pct(gt.get('no_same_class_pred_but_other_preds_exist', 0), gt_total):>13} " f"{fmt_pct(gt.get('no_pred_in_frame', 0), gt_total):>9} " f"{fmt_pct(stats['pred_outcomes'].get('matched_tp', 0), pred_total):>9} " f"{fmt_pct(stats['pred_outcomes'].get('matched_tp', 0), gt_total):>9}" ) print() def pick_best( results: dict[str, dict[str, Any]], metric: str = "matched_tp_recall", ) -> dict[str, dict[str, Any]]: """For each split, find the run config that maximizes ``metric``.""" best: dict[str, dict[str, Any]] = {} for label, summary in results.items(): for split, stats in summary.items(): if not stats or stats.get("samples", 0) == 0: continue value = stats.get(metric) if value is None: continue cur = best.get(split) if cur is None or value > cur["value"]: best[split] = { "config": label, "metric": metric, "value": value, "stats": { k: stats.get(k) for k in ( "hit_share", "matched_tp_precision", "matched_tp_recall", "avg_pred_per_frame", "avg_gt_per_frame", ) }, } return best def main() -> None: parser = argparse.ArgumentParser( description="Calibrate activity threshold + Top-K̂ decode (v11d)." ) parser.add_argument( "--dump-dir", type=Path, required=True, help="Directory containing paired *__pred.csv and *__gt.csv files.", ) parser.add_argument( "--thresholds", type=float, nargs="*", default=[0.3, 0.4, 0.5, 0.6], help="Activity thresholds to sweep (used by threshold and topk_hat_min).", ) parser.add_argument( "--modes", nargs="*", default=["threshold", "topk_hat", "topk_hat_min"], choices=["threshold", "topk_hat", "topk_hat_min"], help="Decode modes to evaluate.", ) parser.add_argument( "--best-metric", default="matched_tp_recall", choices=[ "matched_tp_recall", "matched_tp_precision", "hit_share", "same_class_angle_le_20_share", ], help="Metric used by --print-best to rank configs per split.", ) parser.add_argument( "--json-out", type=Path, default=None, help="Optional path to write the full result as JSON.", ) args = parser.parse_args() raw_pred_by_split: dict[str, list[tuple[str, dict[int, list[dict[str, Any]]]]]] = ( defaultdict(list) ) gt_by_split: dict[str, list[tuple[str, dict[int, list[dict[str, Any]]]]]] = ( defaultdict(list) ) for pred_path in sorted(args.dump_dir.glob("*__pred.csv")): split = infer_split(pred_path.name) if split is None: continue gt_path = Path(str(pred_path).replace("__pred.csv", "__gt.csv")) if not gt_path.exists(): continue raw_pred_by_split[split].append( (pred_path.name, load_pred_rows(pred_path)) ) gt_by_split[split].append( (pred_path.name, load_gt_rows(gt_path)) ) results: dict[str, dict[str, Any]] = {} for mode in args.modes: if mode == "topk_hat": label = "topk_hat (no min)" results[label] = evaluate_decode( raw_pred_by_split, gt_by_split, mode, threshold=None, ) print_summary(label, results[label]) else: for thr in args.thresholds: label = f"{mode} thr={thr:g}" results[label] = evaluate_decode( raw_pred_by_split, gt_by_split, mode, threshold=thr, ) print_summary(label, results[label]) best = pick_best(results, metric=args.best_metric) if best: print(f"=== Best config per split (by {args.best_metric}) ===") for split, info in sorted(best.items()): stats = info["stats"] print( f" {split}: {info['config']} " f"{args.best_metric}={info['value']:.3f} " f"hit={stats.get('hit_share') or 0:.3f} " f"prec={stats.get('matched_tp_precision') or 0:.3f} " f"rec={stats.get('matched_tp_recall') or 0:.3f} " f"avg_pred={stats.get('avg_pred_per_frame') or 0:.2f} " f"avg_gt={stats.get('avg_gt_per_frame') or 0:.2f}" ) print() if args.json_out is not None: payload = { "dump_dir": str(args.dump_dir), "best_metric": args.best_metric, "best_per_split": best, "results": results, } args.json_out.write_text( json.dumps(payload, indent=2, ensure_ascii=False, default=str) ) print(f"[Saved] {args.json_out}") if __name__ == "__main__": main()