Spatial-BEATs / scripts /calibrate_activity.py
dieKarotte's picture
Add files using upload-large-folder tool
4fdc640 verified
Raw
History Blame Contribute Delete
17.3 kB
#!/usr/bin/env python3
"""v11d — activity threshold calibration + Top-K̂ decode for frame-track CSVs.
Re-decodes the **already-dumped** ``*__pred.csv`` files under different
gating strategies and reports which split prefers what.
Decode modes
------------
``threshold``
Keep predictions with ``activity_prob >= threshold``. Sweep multiple
thresholds in one run.
``topk_hat``
Per (sample, frame), keep the top-K̂ tracks by ``activity_prob`` where
K̂ comes from the ``num_active_pred`` column dumped by v10's
``num_active_head``. Requires that column to be present.
``topk_hat_min``
Top-K̂ AND ``activity_prob >= threshold`` (intersection). Recovers the
K̂ gate's source-count discipline while suppressing low-confidence
overprediction in single-source frames.
Metrics mirror ``analyze_csv_dump.py``: per-split ``hit_cls_and_angle``,
``class_right_angle_wrong``, ``no_same_class_pred_but_other_preds_exist``
and matched-TP / FP rates. Use those fields to pick the best decode for
each real_ov{1,2,3} split.
Usage
-----
python scripts/calibrate_activity.py \
--dump-dir checkpoints/.../valid_csv_dump \
--thresholds 0.3 0.4 0.5 0.6 \
--modes threshold topk_hat topk_hat_min \
--json-out calibration.json
"""
from __future__ import annotations
import argparse
import csv
import json
import math
import statistics
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any, Optional
def angular_distance_deg(azi1: float, ele1: float, azi2: float, ele2: float) -> float:
    """Return the angle in degrees between two azimuth/elevation directions.

    Both directions are given in degrees, mapped to unit vectors, and the
    angle is recovered from the arccosine of their dot product.
    """

    def _unit_vector(azi_deg: float, ele_deg: float) -> tuple[float, float, float]:
        azi_rad = math.radians(azi_deg)
        ele_rad = math.radians(ele_deg)
        return (
            math.cos(ele_rad) * math.cos(azi_rad),
            math.cos(ele_rad) * math.sin(azi_rad),
            math.sin(ele_rad),
        )

    ux, uy, uz = _unit_vector(azi1, ele1)
    vx, vy, vz = _unit_vector(azi2, ele2)
    cosine = ux * vx + uy * vy + uz * vz
    # Rounding can push the dot product marginally outside [-1, 1], which
    # math.acos would reject; clamp from above first, then from below.
    cosine = min(1.0, cosine)
    cosine = max(-1.0, cosine)
    return math.degrees(math.acos(cosine))
def infer_split(name: str) -> Optional[str]:
    """Map a dump file name to its evaluation split, or ``None`` if unknown.

    The split is decided by the first matching file-name prefix.
    """
    # Order matters: every "valid__ovN_real_" name also starts with
    # "valid__ovN_", so the more specific real-recording prefixes go first.
    prefix_to_split = (
        ("valid__ov1_real_", "real_ov1"),
        ("valid__ov2_real_", "real_ov2"),
        ("valid__ov3_real_", "real_ov3"),
        ("valid__ov1_", "ov1"),
        ("valid__ov2_", "ov2"),
        ("valid__ov3_", "ov3"),
        ("valid__hm3d__", "ov1"),
    )
    return next(
        (label for head, label in prefix_to_split if name.startswith(head)),
        None,
    )
def _row_dict(row: dict[str, str]) -> dict[str, Any]:
out = {
"frame_idx": int(row["frame_idx"]),
"class_idx": int(row["class_idx"]),
"class_name": row.get("class_name", ""),
"azi": float(row["azimuth_deg"]),
"ele": float(row["elevation_deg"]),
"activity_prob": float(row["activity_prob"]),
"track_or_src": int(row["src_or_track_idx"]),
}
if "num_active_pred" in row and row["num_active_pred"] != "":
out["num_active_pred"] = int(row["num_active_pred"])
return out
def load_pred_rows(csv_path: Path) -> dict[int, list[dict[str, Any]]]:
    """Read a prediction CSV and group its typed rows by ``frame_idx``.

    Args:
        csv_path: Path to a CSV file with a header row.

    Returns:
        A ``defaultdict(list)`` mapping frame index to the rows of that
        frame, in file order. Each row is converted with ``_row_dict``.
    """
    rows_by_frame: dict[int, list[dict[str, Any]]] = defaultdict(list)
    # The csv module requires newline="" so that newlines embedded in quoted
    # fields are parsed by the reader rather than translated by the file layer.
    with csv_path.open(newline="") as f:
        for row in csv.DictReader(f):
            d = _row_dict(row)
            rows_by_frame[d["frame_idx"]].append(d)
    return rows_by_frame
def load_gt_rows(csv_path: Path) -> dict[int, list[dict[str, Any]]]:
    """Read a ground-truth CSV and group its typed rows by ``frame_idx``.

    The file is parsed with the same ``_row_dict`` converter as prediction
    dumps, so it must provide the same required columns.

    Args:
        csv_path: Path to a CSV file with a header row.

    Returns:
        A ``defaultdict(list)`` mapping frame index to the rows of that
        frame, in file order.
    """
    rows_by_frame: dict[int, list[dict[str, Any]]] = defaultdict(list)
    # The csv module requires newline="" so that newlines embedded in quoted
    # fields are parsed by the reader rather than translated by the file layer.
    with csv_path.open(newline="") as f:
        for row in csv.DictReader(f):
            d = _row_dict(row)
            rows_by_frame[d["frame_idx"]].append(d)
    return rows_by_frame
def filter_preds_threshold(
    pred_by_frame: dict[int, list[dict[str, Any]]],
    threshold: float,
) -> dict[int, list[dict[str, Any]]]:
    """Keep, per frame, only predictions with ``activity_prob >= threshold``.

    Every input frame is present in the result, even when all of its
    predictions are dropped.
    """
    survivors = {
        frame_idx: [
            pred for pred in frame_preds if pred["activity_prob"] >= threshold
        ]
        for frame_idx, frame_preds in pred_by_frame.items()
    }
    return defaultdict(list, survivors)
def filter_preds_topk_hat(
    pred_by_frame: dict[int, list[dict[str, Any]]],
    threshold: Optional[float],
) -> dict[int, list[dict[str, Any]]]:
    """Keep, per frame, the K̂ most active predictions.

    K̂ is read from the first prediction in the frame that carries a
    ``num_active_pred`` value. Frames without such a value, or with
    K̂ <= 0, keep nothing. When ``threshold`` is given, the top-K̂ entries
    must additionally satisfy ``activity_prob >= threshold``.
    """
    gated: dict[int, list[dict[str, Any]]] = defaultdict(list)
    for frame_idx, frame_preds in pred_by_frame.items():
        k_hat = next(
            (
                int(pred["num_active_pred"])
                for pred in frame_preds
                if "num_active_pred" in pred
            ),
            None,
        )
        if k_hat is not None and k_hat > 0:
            # Stable descending sort: ties keep their original file order.
            ranked = list(frame_preds)
            ranked.sort(key=lambda pred: pred["activity_prob"], reverse=True)
            top = ranked[:k_hat]
            if threshold is None:
                gated[frame_idx] = top
            else:
                gated[frame_idx] = [
                    pred for pred in top if pred["activity_prob"] >= threshold
                ]
        else:
            gated[frame_idx] = []
    return gated
def analyze_one_pair(
    pred_by_frame_filtered: dict[int, list[dict[str, Any]]],
    gt_by_frame: dict[int, list[dict[str, Any]]],
    pred_path_name: str,
    *,
    angle_threshold_deg: float = 20.0,
) -> dict[str, Any]:
    """Score one (prediction, ground-truth) file pair frame by frame.

    Three views are accumulated over the union of frames present in either
    input:

    * ``frame_relation`` -- whether the frame has fewer / equal / more
      predictions than GT events (``under`` / ``equal`` / ``over``), plus
      ``gt_no_pred`` for frames with GT but no prediction at all.
    * ``gt_outcomes`` -- one label per GT event: ``hit_cls_and_angle`` when
      the closest same-class prediction is within the angle threshold,
      ``class_right_angle_wrong`` when a same-class prediction exists but is
      too far, ``no_same_class_pred_but_other_preds_exist`` or
      ``no_pred_in_frame`` otherwise.
    * ``pred_outcomes`` -- one label per prediction after greedy one-to-one
      matching (smallest angle first) of same-class pairs within the
      threshold: ``matched_tp``, ``same_class_angle_wrong_fp`` or
      ``wrong_class_or_spurious_fp``.

    Args:
        pred_by_frame_filtered: Decoded predictions grouped by frame index.
        gt_by_frame: Ground-truth events grouped by frame index.
        pred_path_name: Identifier stored under ``"file"`` in the result.
        angle_threshold_deg: Maximum angular error (inclusive) for a
            same-class prediction to count as a hit / match.

    Returns:
        A dict with the file name, frame count, average GT and prediction
        counts per frame, the three counters above, and
        ``mean_same_class_best_angle`` (``None`` when no GT event had a
        same-class prediction).
    """
    all_frames = sorted(set(pred_by_frame_filtered) | set(gt_by_frame))
    gt_outcomes = Counter()
    pred_outcomes = Counter()
    frame_relation = Counter()
    same_class_best_angles: list[float] = []
    total_gt = 0
    total_pred = 0

    for frame_idx in all_frames:
        preds = pred_by_frame_filtered.get(frame_idx, [])
        gts = gt_by_frame.get(frame_idx, [])
        num_gt = len(gts)
        num_pred = len(preds)
        total_gt += num_gt
        total_pred += num_pred

        if num_pred < num_gt:
            frame_relation["under"] += 1
        elif num_pred == num_gt:
            frame_relation["equal"] += 1
        else:
            frame_relation["over"] += 1
        if num_gt > 0 and num_pred == 0:
            frame_relation["gt_no_pred"] += 1

        # Compute the angular error of every same-class (pred, GT) pair once;
        # both the GT-centric labels and the matching step reuse it.
        best_angle_for_gt: list[Optional[float]] = [None] * num_gt
        candidates: list[tuple[float, int, int]] = []
        for pi, pred in enumerate(preds):
            for gi, gt in enumerate(gts):
                if pred["class_idx"] != gt["class_idx"]:
                    continue
                angle = angular_distance_deg(
                    gt["azi"], gt["ele"], pred["azi"], pred["ele"]
                )
                current_best = best_angle_for_gt[gi]
                if current_best is None or angle < current_best:
                    best_angle_for_gt[gi] = angle
                if angle <= angle_threshold_deg:
                    candidates.append((angle, pi, gi))

        # GT-centric view: how well is each GT event covered?
        for best_angle in best_angle_for_gt:
            if best_angle is not None:
                same_class_best_angles.append(best_angle)
                if best_angle <= angle_threshold_deg:
                    gt_outcomes["hit_cls_and_angle"] += 1
                else:
                    gt_outcomes["class_right_angle_wrong"] += 1
            elif preds:
                gt_outcomes["no_same_class_pred_but_other_preds_exist"] += 1
            else:
                gt_outcomes["no_pred_in_frame"] += 1

        # Prediction-centric view: greedy one-to-one matching, closest pairs
        # first (ties broken by prediction index, then GT index).
        used_pred = [False] * num_pred
        used_gt = [False] * num_gt
        candidates.sort()
        for _, pi, gi in candidates:
            if used_pred[pi] or used_gt[gi]:
                continue
            used_pred[pi] = True
            used_gt[gi] = True
            pred_outcomes["matched_tp"] += 1

        for pi, pred in enumerate(preds):
            if used_pred[pi]:
                continue
            if any(gt["class_idx"] == pred["class_idx"] for gt in gts):
                pred_outcomes["same_class_angle_wrong_fp"] += 1
            else:
                pred_outcomes["wrong_class_or_spurious_fp"] += 1

    num_frames = len(all_frames)
    return {
        "file": pred_path_name,
        "frames": num_frames,
        "avg_gt_per_frame": total_gt / num_frames if num_frames else 0.0,
        "avg_pred_per_frame": total_pred / num_frames if num_frames else 0.0,
        "frame_relation": frame_relation,
        "gt_outcomes": gt_outcomes,
        "pred_outcomes": pred_outcomes,
        "mean_same_class_best_angle": (
            statistics.mean(same_class_best_angles) if same_class_best_angles else None
        ),
    }
def aggregate_rows(rows: list[dict[str, Any]]) -> dict[str, Any]:
    """Pool per-file analysis results into one per-split summary.

    Outcome counters are summed across files, and per-frame averages are
    re-weighted by each file's frame count. The mean same-class best angle
    is weighted by each file's number of GT events that had a same-class
    prediction (``hit_cls_and_angle`` + ``class_right_angle_wrong``), so it
    is the mean over GT events rather than an unweighted mean of per-file
    means.

    Args:
        rows: Per-file result dicts with the keys ``frames``,
            ``avg_gt_per_frame``, ``avg_pred_per_frame``, ``gt_outcomes``,
            ``pred_outcomes``, ``frame_relation`` and
            ``mean_same_class_best_angle``.

    Returns:
        A summary dict with pooled counts and derived shares. Ratio fields
        are ``None`` when their denominator is zero.
    """
    agg_gt = Counter()
    agg_pred = Counter()
    agg_frame = Counter()
    total_frames = sum(r["frames"] for r in rows)
    avg_gt = (
        sum(r["avg_gt_per_frame"] * r["frames"] for r in rows) / total_frames
        if total_frames else 0.0
    )
    avg_pred = (
        sum(r["avg_pred_per_frame"] * r["frames"] for r in rows) / total_frames
        if total_frames else 0.0
    )

    angle_weighted_sum = 0.0
    angle_weight = 0
    for r in rows:
        agg_gt.update(r["gt_outcomes"])
        agg_pred.update(r["pred_outcomes"])
        agg_frame.update(r["frame_relation"])
        file_mean = r["mean_same_class_best_angle"]
        if file_mean is None:
            continue
        # Number of GT events that contributed to this file's mean.
        n_same_class = (
            r["gt_outcomes"].get("hit_cls_and_angle", 0)
            + r["gt_outcomes"].get("class_right_angle_wrong", 0)
        )
        # A row that reports a mean but no contributing GT count is
        # inconsistent; give it unit weight instead of discarding it.
        weight = n_same_class if n_same_class > 0 else 1
        angle_weighted_sum += file_mean * weight
        angle_weight += weight

    total_gt = sum(agg_gt.values())
    total_pred = sum(agg_pred.values())
    same_class_total = agg_gt["hit_cls_and_angle"] + agg_gt["class_right_angle_wrong"]
    return {
        "samples": len(rows),
        "frames": total_frames,
        "avg_gt_per_frame": avg_gt,
        "avg_pred_per_frame": avg_pred,
        "frame_relation": dict(agg_frame),
        "gt_outcomes": dict(agg_gt),
        "pred_outcomes": dict(agg_pred),
        "gt_total": total_gt,
        "pred_total": total_pred,
        "hit_share": (
            agg_gt["hit_cls_and_angle"] / total_gt if total_gt else None
        ),
        "same_class_angle_le_20_share": (
            agg_gt["hit_cls_and_angle"] / same_class_total if same_class_total else None
        ),
        "matched_tp_precision": (
            agg_pred["matched_tp"] / total_pred if total_pred else None
        ),
        "matched_tp_recall": (
            agg_pred["matched_tp"] / total_gt if total_gt else None
        ),
        "mean_best_angle_when_same_class_exists": (
            angle_weighted_sum / angle_weight if angle_weight else None
        ),
    }
def evaluate_decode(
    raw_pred_by_split: dict[str, list[tuple[str, dict[int, list[dict[str, Any]]]]]],
    gt_by_split: dict[str, list[tuple[str, dict[int, list[dict[str, Any]]]]]],
    decode_mode: str,
    threshold: Optional[float],
) -> dict[str, Any]:
    """Apply one decode configuration to every split and summarize it.

    For each split (in sorted order), every prediction file that has a
    ground-truth entry under the same name is re-decoded with
    ``decode_mode`` and scored; the per-file results are then aggregated.
    Splits with no scored file map to ``{"samples": 0}``.

    Raises:
        ValueError: If ``decode_mode`` is not a known mode. This is only
            raised once a file actually needs decoding.
    """

    def _decode(frame_preds: dict[int, list[dict[str, Any]]]):
        if decode_mode == "threshold":
            return filter_preds_threshold(frame_preds, threshold or 0.0)
        if decode_mode == "topk_hat":
            return filter_preds_topk_hat(frame_preds, None)
        if decode_mode == "topk_hat_min":
            return filter_preds_topk_hat(frame_preds, threshold)
        raise ValueError(f"unknown decode mode: {decode_mode}")

    per_split: dict[str, Any] = {}
    for split_name in sorted(raw_pred_by_split):
        gt_lookup = dict(gt_by_split[split_name])
        per_file = [
            analyze_one_pair(_decode(frame_preds), gt_lookup[file_name], file_name)
            for file_name, frame_preds in raw_pred_by_split[split_name]
            if file_name in gt_lookup
        ]
        if per_file:
            per_split[split_name] = aggregate_rows(per_file)
        else:
            per_split[split_name] = {"samples": 0}
    return per_split
def fmt_pct(num: int, den: int) -> str:
    """Format ``num / den`` as a percentage, or a dash placeholder.

    The placeholder is returned whenever the denominator is not positive.
    """
    return f"{100.0 * num / den:5.1f}%" if den > 0 else " - "
def print_summary(label: str, summary: dict[str, dict[str, Any]]) -> None:
    """Print a fixed-width metrics table for one decode configuration.

    A banner with ``label`` and a column header are always printed. Each
    split with at least one sample then gets a row; splits with empty stats
    or zero samples are skipped. A blank line closes the table.
    """
    print(f"=== {label} ===")
    header_cells = [
        f"{'split':<10}",
        f"{'samples':>7}",
        f"{'avg_pred':>9}",
        f"{'avg_gt':>7}",
        f"{'hit':>6}",
        f"{'cls_ok_ang_bad':>16}",
        f"{'no_same_cls':>13}",
        f"{'no_pred':>9}",
        f"{'tp_prec':>9}",
        f"{'tp_rec':>9}",
    ]
    print(" ".join(header_cells))

    for split, stats in summary.items():
        if not stats or stats.get("samples", 0) == 0:
            continue
        gt_counts = stats["gt_outcomes"]
        n_gt = stats["gt_total"]
        n_pred = stats["pred_total"]
        # GT-based columns are shares of all GT events; tp_prec divides the
        # matched true positives by predictions, tp_rec by GT events.
        row_cells = [
            f"{split:<10}",
            f"{stats['samples']:>7d}",
            f"{stats['avg_pred_per_frame']:>9.2f}",
            f"{stats['avg_gt_per_frame']:>7.2f}",
            f"{fmt_pct(gt_counts.get('hit_cls_and_angle', 0), n_gt):>6}",
            f"{fmt_pct(gt_counts.get('class_right_angle_wrong', 0), n_gt):>16}",
            f"{fmt_pct(gt_counts.get('no_same_class_pred_but_other_preds_exist', 0), n_gt):>13}",
            f"{fmt_pct(gt_counts.get('no_pred_in_frame', 0), n_gt):>9}",
            f"{fmt_pct(stats['pred_outcomes'].get('matched_tp', 0), n_pred):>9}",
            f"{fmt_pct(stats['pred_outcomes'].get('matched_tp', 0), n_gt):>9}",
        ]
        print(" ".join(row_cells))
    print()
def pick_best(
    results: dict[str, dict[str, Any]],
    metric: str = "matched_tp_recall",
) -> dict[str, dict[str, Any]]:
    """Select, for every split, the config with the highest ``metric``.

    Splits with no samples or a missing / ``None`` metric value are ignored
    for that config. On ties the config encountered first is kept.
    """
    reported_keys = (
        "hit_share",
        "matched_tp_precision",
        "matched_tp_recall",
        "avg_pred_per_frame",
        "avg_gt_per_frame",
    )
    winners: dict[str, dict[str, Any]] = {}
    for config_label, per_split in results.items():
        for split_name, split_stats in per_split.items():
            if not split_stats or split_stats.get("samples", 0) == 0:
                continue
            score = split_stats.get(metric)
            if score is None:
                continue
            incumbent = winners.get(split_name)
            # Only a strictly better score displaces the current winner.
            if incumbent is not None and not (score > incumbent["value"]):
                continue
            winners[split_name] = {
                "config": config_label,
                "metric": metric,
                "value": score,
                "stats": {key: split_stats.get(key) for key in reported_keys},
            }
    return winners
def main() -> None:
    """Command-line entry point: load dumps, sweep decode configs, report.

    Steps:
      1. Parse the CLI arguments.
      2. Pair every ``*__pred.csv`` in ``--dump-dir`` with its sibling
         ``*__gt.csv``; files with an unknown split or no GT file are
         skipped.
      3. Evaluate each requested decode mode (``topk_hat`` once, the other
         modes once per threshold) and print a table for each.
      4. Print the best config per split according to ``--best-metric``.
      5. Optionally write everything to ``--json-out`` as UTF-8 JSON.
    """
    parser = argparse.ArgumentParser(
        description="Calibrate activity threshold + Top-K̂ decode (v11d)."
    )
    parser.add_argument(
        "--dump-dir", type=Path, required=True,
        help="Directory containing paired *__pred.csv and *__gt.csv files.",
    )
    parser.add_argument(
        "--thresholds", type=float, nargs="*", default=[0.3, 0.4, 0.5, 0.6],
        help="Activity thresholds to sweep (used by threshold and topk_hat_min).",
    )
    parser.add_argument(
        "--modes", nargs="*", default=["threshold", "topk_hat", "topk_hat_min"],
        choices=["threshold", "topk_hat", "topk_hat_min"],
        help="Decode modes to evaluate.",
    )
    parser.add_argument(
        "--best-metric", default="matched_tp_recall",
        choices=[
            "matched_tp_recall", "matched_tp_precision",
            "hit_share", "same_class_angle_le_20_share",
        ],
        # The best-config report is always printed; there is no separate flag.
        help="Metric used to rank configs per split in the best-config report.",
    )
    parser.add_argument(
        "--json-out", type=Path, default=None,
        help="Optional path to write the full result as JSON.",
    )
    args = parser.parse_args()

    raw_pred_by_split: dict[str, list[tuple[str, dict[int, list[dict[str, Any]]]]]] = (
        defaultdict(list)
    )
    gt_by_split: dict[str, list[tuple[str, dict[int, list[dict[str, Any]]]]]] = (
        defaultdict(list)
    )
    pred_suffix = "__pred.csv"
    gt_suffix = "__gt.csv"
    for pred_path in sorted(args.dump_dir.glob(f"*{pred_suffix}")):
        split = infer_split(pred_path.name)
        if split is None:
            continue
        # Swap only the file-name suffix; a plain str.replace over the whole
        # path would also rewrite matching directory components.
        gt_path = pred_path.with_name(
            pred_path.name[: -len(pred_suffix)] + gt_suffix
        )
        if not gt_path.exists():
            continue
        raw_pred_by_split[split].append(
            (pred_path.name, load_pred_rows(pred_path))
        )
        # GT entries are keyed by the prediction file name so that
        # evaluate_decode can look them up by that same name.
        gt_by_split[split].append(
            (pred_path.name, load_gt_rows(gt_path))
        )

    results: dict[str, dict[str, Any]] = {}
    for mode in args.modes:
        if mode == "topk_hat":
            # Pure Top-K̂ has no threshold, so it is evaluated exactly once.
            label = "topk_hat (no min)"
            results[label] = evaluate_decode(
                raw_pred_by_split, gt_by_split, mode, threshold=None,
            )
            print_summary(label, results[label])
        else:
            for thr in args.thresholds:
                label = f"{mode} thr={thr:g}"
                results[label] = evaluate_decode(
                    raw_pred_by_split, gt_by_split, mode, threshold=thr,
                )
                print_summary(label, results[label])

    best = pick_best(results, metric=args.best_metric)
    if best:
        print(f"=== Best config per split (by {args.best_metric}) ===")
        for split, info in sorted(best.items()):
            stats = info["stats"]
            print(
                f" {split}: {info['config']} "
                f"{args.best_metric}={info['value']:.3f} "
                f"hit={stats.get('hit_share') or 0:.3f} "
                f"prec={stats.get('matched_tp_precision') or 0:.3f} "
                f"rec={stats.get('matched_tp_recall') or 0:.3f} "
                f"avg_pred={stats.get('avg_pred_per_frame') or 0:.2f} "
                f"avg_gt={stats.get('avg_gt_per_frame') or 0:.2f}"
            )
        print()

    if args.json_out is not None:
        payload = {
            "dump_dir": str(args.dump_dir),
            "best_metric": args.best_metric,
            "best_per_split": best,
            "results": results,
        }
        # ensure_ascii=False can emit non-ASCII characters, so write UTF-8
        # explicitly instead of relying on the locale's default encoding.
        args.json_out.write_text(
            json.dumps(payload, indent=2, ensure_ascii=False, default=str),
            encoding="utf-8",
        )
        print(f"[Saved] {args.json_out}")


if __name__ == "__main__":
    main()