| |
| """v11d — activity threshold calibration + Top-K̂ decode for frame-track CSVs. |
| |
| Re-decodes the **already-dumped** ``*__pred.csv`` files under different |
| gating strategies and reports which split prefers what. |
| |
| Decode modes |
| ------------ |
| ``threshold`` |
| Keep predictions with ``activity_prob >= threshold``. Sweep multiple |
| thresholds in one run. |
| |
| ``topk_hat`` |
| Per (sample, frame), keep the top-K̂ tracks by ``activity_prob`` where |
| K̂ comes from the ``num_active_pred`` column dumped by v10's |
| ``num_active_head``. Requires that column to be present. |
| |
| ``topk_hat_min`` |
| Top-K̂ AND ``activity_prob >= threshold`` (intersection). Recovers the |
| K̂ gate's source-count discipline while suppressing low-confidence |
| overprediction in single-source frames. |
| |
| Metrics mirror ``analyze_csv_dump.py``: per-split ``hit_cls_and_angle``, |
| ``class_right_angle_wrong``, ``no_same_class_pred_but_other_preds_exist`` |
| and matched-TP / FP rates. Use those fields to pick the best decode for |
| each real_ov{1,2,3} split. |
| |
| Usage |
| ----- |
| |
| python scripts/calibrate_activity.py \ |
| --dump-dir checkpoints/.../valid_csv_dump \ |
| --thresholds 0.3 0.4 0.5 0.6 \ |
| --modes threshold topk_hat topk_hat_min \ |
| --json-out calibration.json |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import csv |
| import json |
| import math |
| import statistics |
| from collections import Counter, defaultdict |
| from pathlib import Path |
| from typing import Any, Optional |
|
|
|
|
def angular_distance_deg(azi1: float, ele1: float, azi2: float, ele2: float) -> float:
    """Return the angle in degrees between two (azimuth, elevation) directions.

    All inputs are in degrees. Each direction is turned into a Cartesian unit
    vector and the angle is ``acos`` of the dot product. The dot product is
    clamped to [-1, 1] first so floating-point rounding can never push it
    outside ``acos``'s domain.
    """

    def _unit_vector(azi_deg: float, ele_deg: float) -> tuple[float, float, float]:
        azi_rad = math.radians(azi_deg)
        ele_rad = math.radians(ele_deg)
        cos_ele = math.cos(ele_rad)
        return cos_ele * math.cos(azi_rad), cos_ele * math.sin(azi_rad), math.sin(ele_rad)

    ux, uy, uz = _unit_vector(azi1, ele1)
    vx, vy, vz = _unit_vector(azi2, ele2)
    cosine = ux * vx + uy * vy + uz * vz
    cosine = max(-1.0, min(1.0, cosine))
    return math.degrees(math.acos(cosine))
|
|
|
|
def infer_split(name: str) -> Optional[str]:
    """Map a dump file name to its evaluation split, or ``None`` if unrecognised.

    Prefixes are tried in order. The ``*_real_`` forms are listed before the
    generic ``valid__ovN_`` forms, which would otherwise match them too.
    ``valid__hm3d__`` files are bucketed together with ``ov1``.
    """
    prefix_to_split = (
        ("valid__ov1_real_", "real_ov1"),
        ("valid__ov2_real_", "real_ov2"),
        ("valid__ov3_real_", "real_ov3"),
        ("valid__ov1_", "ov1"),
        ("valid__ov2_", "ov2"),
        ("valid__ov3_", "ov3"),
        ("valid__hm3d__", "ov1"),
    )
    return next(
        (split for prefix, split in prefix_to_split if name.startswith(prefix)),
        None,
    )
|
|
|
|
| def _row_dict(row: dict[str, str]) -> dict[str, Any]: |
| out = { |
| "frame_idx": int(row["frame_idx"]), |
| "class_idx": int(row["class_idx"]), |
| "class_name": row.get("class_name", ""), |
| "azi": float(row["azimuth_deg"]), |
| "ele": float(row["elevation_deg"]), |
| "activity_prob": float(row["activity_prob"]), |
| "track_or_src": int(row["src_or_track_idx"]), |
| } |
| if "num_active_pred" in row and row["num_active_pred"] != "": |
| out["num_active_pred"] = int(row["num_active_pred"]) |
| return out |
|
|
|
|
def _load_rows_by_frame(csv_path: Path) -> dict[int, list[dict[str, Any]]]:
    """Read a frame-track CSV and group its parsed rows by ``frame_idx``.

    The file is opened with ``newline=""``, as the :mod:`csv` documentation
    requires, so quoted fields that contain line breaks are parsed correctly.
    Rows are parsed with ``_row_dict``. Within a frame they keep file order.
    Returns a ``defaultdict(list)``.
    """
    rows_by_frame: dict[int, list[dict[str, Any]]] = defaultdict(list)
    with csv_path.open(newline="") as f:
        for row in csv.DictReader(f):
            parsed = _row_dict(row)
            rows_by_frame[parsed["frame_idx"]].append(parsed)
    return rows_by_frame


def load_pred_rows(csv_path: Path) -> dict[int, list[dict[str, Any]]]:
    """Load a ``*__pred.csv`` dump as ``{frame_idx: [parsed prediction rows]}``."""
    return _load_rows_by_frame(csv_path)


def load_gt_rows(csv_path: Path) -> dict[int, list[dict[str, Any]]]:
    """Load a ``*__gt.csv`` dump as ``{frame_idx: [parsed GT rows]}``.

    GT files are parsed with the same row schema as prediction files.
    """
    return _load_rows_by_frame(csv_path)
|
|
|
|
def filter_preds_threshold(
    pred_by_frame: dict[int, list[dict[str, Any]]],
    threshold: float,
) -> dict[int, list[dict[str, Any]]]:
    """Threshold decode: keep predictions with ``activity_prob >= threshold``.

    Every input frame remains a key of the result. A frame whose predictions
    are all dropped maps to ``[]``. The result is a ``defaultdict(list)``.
    """
    kept_by_frame: dict[int, list[dict[str, Any]]] = defaultdict(list)
    kept_by_frame.update(
        (frame, [pred for pred in frame_preds if pred["activity_prob"] >= threshold])
        for frame, frame_preds in pred_by_frame.items()
    )
    return kept_by_frame
|
|
|
|
def filter_preds_topk_hat(
    pred_by_frame: dict[int, list[dict[str, Any]]],
    threshold: Optional[float],
) -> dict[int, list[dict[str, Any]]]:
    """Top-K̂ decode: keep each frame's K̂ most active predictions.

    K̂ is read from ``num_active_pred`` on the first prediction in the frame
    that carries that key. A frame where no prediction carries it, or where
    K̂ <= 0, decodes to ``[]``. Ranking is by descending ``activity_prob``.
    The sort is stable, so ties keep input order. When ``threshold`` is not
    ``None``, the kept set is further restricted to
    ``activity_prob >= threshold`` (the ``topk_hat_min`` intersection).

    Returns a ``defaultdict(list)`` with one entry per input frame.
    """
    decoded: dict[int, list[dict[str, Any]]] = defaultdict(list)
    for frame, frame_preds in pred_by_frame.items():
        k_hat = next(
            (int(p["num_active_pred"]) for p in frame_preds if "num_active_pred" in p),
            None,
        )
        if k_hat is None or k_hat < 1:
            decoded[frame] = []
        else:
            ranked = sorted(frame_preds, key=lambda p: p["activity_prob"], reverse=True)
            top = ranked[:k_hat]
            if threshold is not None:
                top = [p for p in top if p["activity_prob"] >= threshold]
            decoded[frame] = top
    return decoded
|
|
|
|
def _tally_frame_relation(frame_relation: Counter, num_pred: int, num_gt: int) -> None:
    """Bucket one frame as under/equal/over-predicted and flag GT-without-pred frames."""
    if num_pred < num_gt:
        frame_relation["under"] += 1
    elif num_pred == num_gt:
        frame_relation["equal"] += 1
    else:
        frame_relation["over"] += 1
    # Counted in addition to the bucket above (such frames are also "under").
    if num_gt > 0 and num_pred == 0:
        frame_relation["gt_no_pred"] += 1


def _tally_gt_outcomes(
    preds: list[dict[str, Any]],
    gts: list[dict[str, Any]],
    gt_outcomes: Counter,
    same_class_best_angles: list[float],
    angle_threshold_deg: float,
) -> None:
    """Classify each GT event by its closest same-class prediction in the frame.

    This is not a one-to-one assignment: several GT events may be credited to
    the same prediction. The best angle of every GT event that has at least
    one same-class prediction is appended to ``same_class_best_angles``.
    """
    for gt in gts:
        same_class_preds = [p for p in preds if p["class_idx"] == gt["class_idx"]]
        if not same_class_preds:
            if preds:
                gt_outcomes["no_same_class_pred_but_other_preds_exist"] += 1
            else:
                gt_outcomes["no_pred_in_frame"] += 1
            continue
        best_angle = min(
            angular_distance_deg(gt["azi"], gt["ele"], p["azi"], p["ele"])
            for p in same_class_preds
        )
        same_class_best_angles.append(best_angle)
        if best_angle <= angle_threshold_deg:
            gt_outcomes["hit_cls_and_angle"] += 1
        else:
            gt_outcomes["class_right_angle_wrong"] += 1


def _tally_pred_outcomes(
    preds: list[dict[str, Any]],
    gts: list[dict[str, Any]],
    pred_outcomes: Counter,
    angle_threshold_deg: float,
) -> None:
    """Greedily match predictions to GT one-to-one and classify the leftovers.

    Candidate (pred, GT) pairs share a class and lie within
    ``angle_threshold_deg``. They are consumed in ascending-angle order, with
    ties broken by pred index and then GT index. Each accepted pair counts as
    ``matched_tp``. An unmatched prediction counts as
    ``same_class_angle_wrong_fp`` when the frame has any GT of its class
    (including one already taken by another prediction). Otherwise it counts
    as ``wrong_class_or_spurious_fp``.
    """
    candidates: list[tuple[float, int, int]] = []
    for pi, p in enumerate(preds):
        for gi, gt in enumerate(gts):
            if p["class_idx"] != gt["class_idx"]:
                continue
            angle = angular_distance_deg(gt["azi"], gt["ele"], p["azi"], p["ele"])
            if angle <= angle_threshold_deg:
                candidates.append((angle, pi, gi))
    candidates.sort()

    used_pred = [False] * len(preds)
    used_gt = [False] * len(gts)
    for _, pi, gi in candidates:
        if used_pred[pi] or used_gt[gi]:
            continue
        used_pred[pi] = used_gt[gi] = True
        pred_outcomes["matched_tp"] += 1

    gt_classes = {gt["class_idx"] for gt in gts}
    for p, used in zip(preds, used_pred):
        if used:
            continue
        if p["class_idx"] in gt_classes:
            pred_outcomes["same_class_angle_wrong_fp"] += 1
        else:
            pred_outcomes["wrong_class_or_spurious_fp"] += 1


def analyze_one_pair(
    pred_by_frame_filtered: dict[int, list[dict[str, Any]]],
    gt_by_frame: dict[int, list[dict[str, Any]]],
    pred_path_name: str,
    angle_threshold_deg: float = 20.0,
) -> dict[str, Any]:
    """Score one decoded prediction file against its GT file.

    Args:
        pred_by_frame_filtered: ``{frame_idx: [pred rows]}`` after decoding.
        gt_by_frame: ``{frame_idx: [GT rows]}``.
        pred_path_name: file name echoed back under ``"file"``.
        angle_threshold_deg: maximum angular error (inclusive) for a
            same-class prediction to count as a hit / matched TP. The default
            of 20.0 keeps the previous hard-coded behaviour.

    Returns:
        A dict with ``file``, ``frames``, ``avg_gt_per_frame`` and
        ``avg_pred_per_frame``, plus the ``frame_relation``, ``gt_outcomes``
        and ``pred_outcomes`` Counters and ``mean_same_class_best_angle``
        (``None`` when no GT event had a same-class prediction). Frames are
        the union of the keys of both inputs.
    """
    all_frames = sorted(set(pred_by_frame_filtered) | set(gt_by_frame))
    gt_outcomes: Counter = Counter()
    pred_outcomes: Counter = Counter()
    frame_relation: Counter = Counter()
    same_class_best_angles: list[float] = []
    total_gt = 0
    total_pred = 0

    for frame_idx in all_frames:
        preds = pred_by_frame_filtered.get(frame_idx, [])
        gts = gt_by_frame.get(frame_idx, [])
        total_gt += len(gts)
        total_pred += len(preds)
        _tally_frame_relation(frame_relation, len(preds), len(gts))
        _tally_gt_outcomes(
            preds, gts, gt_outcomes, same_class_best_angles, angle_threshold_deg
        )
        _tally_pred_outcomes(preds, gts, pred_outcomes, angle_threshold_deg)

    num_frames = len(all_frames)
    return {
        "file": pred_path_name,
        "frames": num_frames,
        "avg_gt_per_frame": total_gt / num_frames if num_frames else 0.0,
        "avg_pred_per_frame": total_pred / num_frames if num_frames else 0.0,
        "frame_relation": frame_relation,
        "gt_outcomes": gt_outcomes,
        "pred_outcomes": pred_outcomes,
        "mean_same_class_best_angle": (
            statistics.mean(same_class_best_angles) if same_class_best_angles else None
        ),
    }
|
|
|
|
def aggregate_rows(rows: list[dict[str, Any]]) -> dict[str, Any]:
    """Pool per-file ``analyze_one_pair`` results into one split-level summary.

    Outcome counters are summed. Per-frame averages are weighted by each
    file's frame count. ``mean_best_angle_when_same_class_exists`` is weighted
    by the number of GT events that had a same-class prediction
    (``hit_cls_and_angle`` + ``class_right_angle_wrong``, exactly the samples
    behind each per-file mean). It is therefore the mean over all such GT
    events, not an unweighted mean of per-file means, which would over-weight
    files with few events. All ratio fields are ``None`` when their
    denominator is zero.
    """
    agg_gt: Counter = Counter()
    agg_pred: Counter = Counter()
    agg_frame: Counter = Counter()
    angle_weighted_sum = 0.0
    angle_weight = 0
    for r in rows:
        agg_gt.update(r["gt_outcomes"])
        agg_pred.update(r["pred_outcomes"])
        agg_frame.update(r["frame_relation"])
        file_mean = r["mean_same_class_best_angle"]
        if file_mean is None:
            continue
        outcomes = r["gt_outcomes"]
        n_same_class = (
            outcomes.get("hit_cls_and_angle", 0)
            + outcomes.get("class_right_angle_wrong", 0)
        )
        angle_weighted_sum += file_mean * n_same_class
        angle_weight += n_same_class

    total_frames = sum(r["frames"] for r in rows)
    avg_gt = (
        sum(r["avg_gt_per_frame"] * r["frames"] for r in rows) / total_frames
        if total_frames else 0.0
    )
    avg_pred = (
        sum(r["avg_pred_per_frame"] * r["frames"] for r in rows) / total_frames
        if total_frames else 0.0
    )
    total_gt = sum(agg_gt.values())
    total_pred = sum(agg_pred.values())
    same_class_total = agg_gt["hit_cls_and_angle"] + agg_gt["class_right_angle_wrong"]
    return {
        "samples": len(rows),
        "frames": total_frames,
        "avg_gt_per_frame": avg_gt,
        "avg_pred_per_frame": avg_pred,
        "frame_relation": dict(agg_frame),
        "gt_outcomes": dict(agg_gt),
        "pred_outcomes": dict(agg_pred),
        "gt_total": total_gt,
        "pred_total": total_pred,
        "hit_share": (
            agg_gt["hit_cls_and_angle"] / total_gt if total_gt else None
        ),
        "same_class_angle_le_20_share": (
            agg_gt["hit_cls_and_angle"] / same_class_total if same_class_total else None
        ),
        "matched_tp_precision": (
            agg_pred["matched_tp"] / total_pred if total_pred else None
        ),
        "matched_tp_recall": (
            agg_pred["matched_tp"] / total_gt if total_gt else None
        ),
        "mean_best_angle_when_same_class_exists": (
            angle_weighted_sum / angle_weight if angle_weight else None
        ),
    }
|
|
|
|
def evaluate_decode(
    raw_pred_by_split: dict[str, list[tuple[str, dict[int, list[dict[str, Any]]]]]],
    gt_by_split: dict[str, list[tuple[str, dict[int, list[dict[str, Any]]]]]],
    decode_mode: str,
    threshold: Optional[float],
) -> dict[str, Any]:
    """Re-decode every prediction dump with one gating config and score it per split.

    Args:
        raw_pred_by_split: split -> list of ``(pred_file_name, rows_by_frame)``
            holding the undecoded predictions.
        gt_by_split: split -> list of ``(pred_file_name, rows_by_frame)`` for
            the matching GT. Entries are joined to predictions by file name.
            A split missing here is treated as having no GT.
        decode_mode: ``"threshold"``, ``"topk_hat"`` or ``"topk_hat_min"``.
        threshold: activity threshold. ``"threshold"`` treats ``None`` as
            0.0, ``"topk_hat"`` ignores it, and for ``"topk_hat_min"``
            ``None`` disables the minimum.

    Returns:
        split -> ``aggregate_rows`` summary, or ``{"samples": 0}`` when no
        prediction file of that split has a GT partner.

    Raises:
        ValueError: ``decode_mode`` is not one of the modes above. This is
            checked up front, so it is raised even when there is nothing to
            decode.
    """
    decoders = {
        "threshold": lambda rows: filter_preds_threshold(rows, threshold or 0.0),
        "topk_hat": lambda rows: filter_preds_topk_hat(rows, None),
        "topk_hat_min": lambda rows: filter_preds_topk_hat(rows, threshold),
    }
    try:
        decode = decoders[decode_mode]
    except KeyError:
        raise ValueError(f"unknown decode mode: {decode_mode}") from None

    summary: dict[str, Any] = {}
    for split, pred_entries in sorted(raw_pred_by_split.items()):
        gt_entries = dict(gt_by_split.get(split, ()))
        rows = [
            analyze_one_pair(decode(pred_rows), gt_entries[pred_name], pred_name)
            for pred_name, pred_rows in pred_entries
            if pred_name in gt_entries
        ]
        summary[split] = aggregate_rows(rows) if rows else {"samples": 0}
    return summary
|
|
|
|
def fmt_pct(num: int, den: int) -> str:
    """Format ``num / den`` as a 5.1f percentage such as ``" 25.0%"``.

    A placeholder dash is returned when ``den`` is not positive, so empty
    splits do not divide by zero.
    """
    if den > 0:
        share = 100.0 * num / den
        return f"{share:5.1f}%"
    return " - "
|
|
|
|
def print_summary(label: str, summary: dict[str, dict[str, Any]]) -> None:
    """Print one results table: a title, a header, and one row per populated split.

    Splits whose stats are empty or report zero samples are skipped. The
    hit / cls_ok_ang_bad / no_same_cls / no_pred columns are shares of GT
    events. ``tp_prec`` is matched_tp over predictions and ``tp_rec`` is
    matched_tp over GT events. A blank line ends the table.
    """
    print(f"=== {label} ===")
    column_specs = (
        ("samples", 7), ("avg_pred", 9), ("avg_gt", 7), ("hit", 6),
        ("cls_ok_ang_bad", 16), ("no_same_cls", 13), ("no_pred", 9),
        ("tp_prec", 9), ("tp_rec", 9),
    )
    header_cells = [f"{'split':<10}"]
    header_cells.extend(f"{title:>{width}}" for title, width in column_specs)
    print(" ".join(header_cells))

    for split, stats in summary.items():
        if not stats or stats.get("samples", 0) == 0:
            continue
        outcomes = stats["gt_outcomes"]
        n_gt = stats["gt_total"]
        n_pred = stats["pred_total"]
        cells = [
            f"{split:<10}",
            f"{stats['samples']:>7d}",
            f"{stats['avg_pred_per_frame']:>9.2f}",
            f"{stats['avg_gt_per_frame']:>7.2f}",
            f"{fmt_pct(outcomes.get('hit_cls_and_angle', 0), n_gt):>6}",
            f"{fmt_pct(outcomes.get('class_right_angle_wrong', 0), n_gt):>16}",
            f"{fmt_pct(outcomes.get('no_same_class_pred_but_other_preds_exist', 0), n_gt):>13}",
            f"{fmt_pct(outcomes.get('no_pred_in_frame', 0), n_gt):>9}",
            f"{fmt_pct(stats['pred_outcomes'].get('matched_tp', 0), n_pred):>9}",
            f"{fmt_pct(stats['pred_outcomes'].get('matched_tp', 0), n_gt):>9}",
        ]
        print(" ".join(cells))
    print()
|
|
|
|
def pick_best(
    results: dict[str, dict[str, Any]],
    metric: str = "matched_tp_recall",
) -> dict[str, dict[str, Any]]:
    """For each split, find the run config that maximizes ``metric``.

    Splits with empty stats, zero samples, or a ``None`` metric value are
    ignored. On ties the earliest config in ``results`` iteration order
    wins, because a later config must be strictly greater to replace it.
    Each entry records the winning config label, the metric name and value,
    and a small subset of that config's stats.
    """
    report_keys = (
        "hit_share",
        "matched_tp_precision",
        "matched_tp_recall",
        "avg_pred_per_frame",
        "avg_gt_per_frame",
    )
    winners: dict[str, dict[str, Any]] = {}
    for config_label, per_split in results.items():
        for split_name, split_stats in per_split.items():
            if not split_stats or split_stats.get("samples", 0) == 0:
                continue
            score = split_stats.get(metric)
            if score is None:
                continue
            incumbent = winners.get(split_name)
            if incumbent is not None and not score > incumbent["value"]:
                continue
            winners[split_name] = {
                "config": config_label,
                "metric": metric,
                "value": score,
                "stats": {key: split_stats.get(key) for key in report_keys},
            }
    return winners
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Define the command-line interface of the calibration script."""
    parser = argparse.ArgumentParser(
        description="Calibrate activity threshold + Top-K̂ decode (v11d)."
    )
    parser.add_argument(
        "--dump-dir", type=Path, required=True,
        help="Directory containing paired *__pred.csv and *__gt.csv files.",
    )
    parser.add_argument(
        "--thresholds", type=float, nargs="*", default=[0.3, 0.4, 0.5, 0.6],
        help="Activity thresholds to sweep (used by threshold and topk_hat_min).",
    )
    parser.add_argument(
        "--modes", nargs="*", default=["threshold", "topk_hat", "topk_hat_min"],
        choices=["threshold", "topk_hat", "topk_hat_min"],
        help="Decode modes to evaluate.",
    )
    parser.add_argument(
        "--best-metric", default="matched_tp_recall",
        choices=[
            "matched_tp_recall", "matched_tp_precision",
            "hit_share", "same_class_angle_le_20_share",
        ],
        # The best-config report is always printed; there is no separate flag.
        help="Metric used to rank configs per split in the best-config report.",
    )
    parser.add_argument(
        "--json-out", type=Path, default=None,
        help="Optional path to write the full result as JSON.",
    )
    return parser


def _load_dump_pairs(
    dump_dir: Path,
) -> tuple[
    dict[str, list[tuple[str, dict[int, list[dict[str, Any]]]]]],
    dict[str, list[tuple[str, dict[int, list[dict[str, Any]]]]]],
]:
    """Load every ``*__pred.csv`` in ``dump_dir`` that has a sibling ``*__gt.csv``.

    Files whose name maps to no known split (see ``infer_split``) and files
    without a GT partner are skipped. Both returned mappings are keyed by
    split and hold ``(pred_file_name, rows_by_frame)`` tuples. GT entries are
    keyed by the *pred* file name so the two sides can be joined by name.
    """
    pred_suffix = "__pred.csv"
    gt_suffix = "__gt.csv"
    raw_pred_by_split: dict[str, list[tuple[str, dict[int, list[dict[str, Any]]]]]] = (
        defaultdict(list)
    )
    gt_by_split: dict[str, list[tuple[str, dict[int, list[dict[str, Any]]]]]] = (
        defaultdict(list)
    )
    for pred_path in sorted(dump_dir.glob("*" + pred_suffix)):
        split = infer_split(pred_path.name)
        if split is None:
            continue
        # Swap only the filename suffix. Replacing on the full path string
        # would also rewrite any directory component containing "__pred.csv".
        gt_path = pred_path.with_name(pred_path.name[: -len(pred_suffix)] + gt_suffix)
        if not gt_path.exists():
            continue
        raw_pred_by_split[split].append((pred_path.name, load_pred_rows(pred_path)))
        gt_by_split[split].append((pred_path.name, load_gt_rows(gt_path)))
    return raw_pred_by_split, gt_by_split


def _run_decodes(
    modes: list[str],
    thresholds: list[float],
    raw_pred_by_split: dict[str, list[tuple[str, dict[int, list[dict[str, Any]]]]]],
    gt_by_split: dict[str, list[tuple[str, dict[int, list[dict[str, Any]]]]]],
) -> dict[str, dict[str, Any]]:
    """Evaluate every requested decode config, printing each table as it finishes.

    ``topk_hat`` runs once with no activity minimum. The other modes run once
    per threshold. Results are keyed by a human-readable config label.
    """
    results: dict[str, dict[str, Any]] = {}
    for mode in modes:
        if mode == "topk_hat":
            label = "topk_hat (no min)"
            results[label] = evaluate_decode(
                raw_pred_by_split, gt_by_split, mode, threshold=None,
            )
            print_summary(label, results[label])
            continue
        for thr in thresholds:
            label = f"{mode} thr={thr:g}"
            results[label] = evaluate_decode(
                raw_pred_by_split, gt_by_split, mode, threshold=thr,
            )
            print_summary(label, results[label])
    return results


def _print_best(best: dict[str, dict[str, Any]], metric: str) -> None:
    """Print the winning config per split (nothing when ``best`` is empty)."""
    if not best:
        return
    print(f"=== Best config per split (by {metric}) ===")
    for split, info in sorted(best.items()):
        stats = info["stats"]
        print(
            f" {split}: {info['config']} "
            f"{metric}={info['value']:.3f} "
            f"hit={stats.get('hit_share') or 0:.3f} "
            f"prec={stats.get('matched_tp_precision') or 0:.3f} "
            f"rec={stats.get('matched_tp_recall') or 0:.3f} "
            f"avg_pred={stats.get('avg_pred_per_frame') or 0:.2f} "
            f"avg_gt={stats.get('avg_gt_per_frame') or 0:.2f}"
        )
    print()


def main() -> None:
    """Run the calibration sweep from the command line.

    The steps are:

    1. Load all pred/GT dump pairs.
    2. Evaluate every requested decode config and print its table.
    3. Print the best config per split by ``--best-metric``.
    4. Optionally write everything to ``--json-out`` as UTF-8 JSON.
    """
    import sys

    args = _build_arg_parser().parse_args()

    raw_pred_by_split, gt_by_split = _load_dump_pairs(args.dump_dir)
    if not raw_pred_by_split:
        print(
            f"[Warn] no *__pred.csv files with a known split prefix and a "
            f"matching *__gt.csv found in {args.dump_dir}",
            file=sys.stderr,
        )

    results = _run_decodes(args.modes, args.thresholds, raw_pred_by_split, gt_by_split)

    best = pick_best(results, metric=args.best_metric)
    _print_best(best, args.best_metric)

    if args.json_out is not None:
        payload = {
            "dump_dir": str(args.dump_dir),
            "best_metric": args.best_metric,
            "best_per_split": best,
            "results": results,
        }
        # ensure_ascii=False may emit non-ASCII text, so do not rely on the
        # platform's default locale encoding.
        args.json_out.write_text(
            json.dumps(payload, indent=2, ensure_ascii=False, default=str),
            encoding="utf-8",
        )
        print(f"[Saved] {args.json_out}")
|
|
|
|
if __name__ == "__main__":
    # Run the calibration CLI only when executed as a script; importing this
    # module (e.g. to reuse the decode/metric helpers) does not invoke main().
    main()
|
|