# CosFly-Track / scripts / exp4_strict_offline_analysis.py
# Provenance (from the hosting page, kept as comments so the file parses):
#   commit 34bc2eb (verified) — "Add scripts and checkpoints (CosFly-Track release)"
#   by Ys404; file size 15.4 kB; page links: raw / history / blame.
#!/usr/bin/env python3
"""Offline strict analysis from raw_errors_*.json files.
This computes EXACT metrics that need sample-level data (joint constraints,
percentile distributions, failure mode breakdown, Pareto frontier, etc.) that
the on-line eval script cannot easily aggregate.
"""
import json
import glob
import os
import math
from collections import OrderedDict, defaultdict
import numpy as np
# Root output directories of the two strict-eval runs being compared
# (setting A vs setting B; each contains eval_strict_<map>/ subdirs).
OUT_A = "/mnt/sfs_turbo_new/R11181/project_vlm/exp_v5/output/job_exp4_settingA_20260430_083003"
OUT_B = "/mnt/sfs_turbo_new/R11181/project_vlm/exp_v5/output/job_exp4_settingB_20260430_083037"
# The six per-waypoint error dimensions stored in each raw sample dict.
DIMS = ["dx", "dy", "dz", "dpitch", "dyaw", "droll"]
# Map split used by the OOD analysis in main() SECTION 5: maps "seen by B"
# (near-domain) vs the TRUE hold-out map.
SEEN_BY_B = {"Town01_Opt", "Town02_Opt", "Town03_Opt", "Town04_Opt",
             "Town05_Opt", "Town06_Opt", "Town07_Opt"}
UNSEEN_BY_B = {"Town10HD"}
def load_raw(out_dir):
    """Load per-map raw error payloads from an eval output directory.

    Scans ``out_dir`` for ``eval_strict_<map>`` subdirectories; each one is
    expected to hold at least one ``raw_errors_*.json`` file whose payload has
    an ``errors_per_sample`` list (one dict of dim -> per-waypoint errors per
    sample). Maps without any raw_errors file are silently skipped.

    Returns:
        dict: {map_name: list of sample dicts (each sample has dim->list[per_wp])}.
    """
    res = {}
    for d in sorted(glob.glob(f"{out_dir}/eval_strict_*")):
        if not os.path.isdir(d):
            continue
        map_name = os.path.basename(d).replace("eval_strict_", "")
        # glob() order is OS-dependent; sort so the SAME file is picked on
        # every run when several raw_errors_*.json exist (was files[0] on an
        # unsorted list, i.e. nondeterministic).
        files = sorted(glob.glob(f"{d}/raw_errors_*.json"))
        if not files:
            continue
        with open(files[0]) as f:
            payload = json.load(f)
        res[map_name] = payload["errors_per_sample"]
    return res
def per_sample_pos_rot(sample):
    """Convert one sample's {dim: [per_wp]} errors into magnitude lists.

    Returns a pair ``(pos_per_wp, rot_per_wp)`` where each entry is the
    Euclidean norm of the (dx, dy, dz) resp. (dpitch, dyaw, droll) errors
    at that waypoint.
    """
    n_wp = len(sample["dx"])
    pos = [
        math.sqrt(sample["dx"][i] ** 2 + sample["dy"][i] ** 2 + sample["dz"][i] ** 2)
        for i in range(n_wp)
    ]
    rot = [
        math.sqrt(sample["dpitch"][i] ** 2 + sample["dyaw"][i] ** 2 + sample["droll"][i] ** 2)
        for i in range(n_wp)
    ]
    return pos, rot
def aggregate_metrics(samples):
    """Compute EXACT strict metrics from sample-level raw data.

    Args:
        samples: list of sample dicts, each mapping an error dimension
            ("dx", "dy", "dz", "dpitch", "dyaw", "droll") to a list of
            per-waypoint errors (see per_sample_pos_rot).

    Returns:
        OrderedDict of metric name -> value. Rates are fractions in [0, 1];
        error metrics are in the input units (m / deg). Empty dict when
        ``samples`` is empty.
    """
    if not samples:
        return {}
    n = len(samples)
    # Per-sample (pos_per_wp, rot_per_wp) Euclidean error magnitudes.
    pos_rot = [per_sample_pos_rot(s) for s in samples]
    all_pos = [p for poss, _ in pos_rot for p in poss]
    all_rot = [r for _, rots in pos_rot for r in rots]
    fde = [poss[-1] for poss, _ in pos_rot]               # final-waypoint pos error
    ade = [sum(poss) / len(poss) for poss, _ in pos_rot]  # mean pos error per sample
    fde_rot = [rots[-1] for _, rots in pos_rot]           # final-waypoint rot error
    m = OrderedDict()
    # ---- Sample-level rates (any wp under threshold) ----
    POS_THRS = [0.1, 0.2, 0.3, 0.5, 1.0, 2.0]
    ROT_THRS = [0.5, 1.0, 2.0, 5.0, 10.0]
    for thr in POS_THRS:
        m[f"SR@{thr}m"] = sum(1 for p in all_pos if p < thr) / len(all_pos)
    for thr in ROT_THRS:
        m[f"RotAcc@{thr}deg"] = sum(1 for r in all_rot if r < thr) / len(all_rot)
    # ---- Trajectory-level (ALL wps under threshold) ----
    TRAJ_POS = [0.3, 0.5, 1.0, 2.0]
    TRAJ_ROT = [1.0, 2.0, 5.0, 10.0]
    for thr in TRAJ_POS:
        m[f"TrajSR@{thr}m"] = sum(1 for poss, _ in pos_rot if all(p < thr for p in poss)) / n
    for thr in TRAJ_ROT:
        m[f"TrajRotSR@{thr}deg"] = sum(1 for _, rots in pos_rot if all(r < thr for r in rots)) / n
    # ---- TRUE Joint constraint rates (any wp satisfies BOTH pos AND rot) ----
    JOINT = [(0.5, 1.0), (0.5, 5.0), (0.5, 2.0),
             (0.3, 1.0), (1.0, 1.0), (1.0, 5.0)]
    for pt, rt in JOINT:
        hit = 0
        for poss, rots in pos_rot:
            if any(p < pt and r < rt for p, r in zip(poss, rots)):
                hit += 1
        m[f"JointSR@({pt}m,{rt}deg)"] = hit / n
    # ---- Trajectory-level TRUE Joint (ALL wps satisfy BOTH) ----
    for pt, rt in JOINT:
        hit = 0
        for poss, rots in pos_rot:
            if all(p < pt and r < rt for p, r in zip(poss, rots)):
                hit += 1
        m[f"TrajJointSR@({pt}m,{rt}deg)"] = hit / n
    # ---- Percentile / tail metrics ----
    fde_arr = np.array(fde)
    ade_arr = np.array(ade)
    rot_arr = np.array(all_rot)
    pos_arr = np.array(all_pos)
    for p in [50, 75, 90, 95, 99]:
        m[f"FDE_p{p}"] = float(np.percentile(fde_arr, p))
        m[f"ADE_p{p}"] = float(np.percentile(ade_arr, p))
        m[f"rot_err_p{p}"] = float(np.percentile(rot_arr, p))
        m[f"pos_err_p{p}"] = float(np.percentile(pos_arr, p))
    m["FDE_max"] = float(fde_arr.max())
    m["ADE_max"] = float(ade_arr.max())
    m["rot_err_max"] = float(rot_arr.max())
    # ---- Hard failure rates ----
    for thr in [1.0, 2.0, 5.0, 10.0]:
        m[f"HardFailRate_FDE_gt_{thr}m"] = sum(1 for f in fde if f > thr) / n
    # Hoisted out of the threshold loop below: this list is loop-invariant
    # (it was previously rebuilt once per rotation threshold).
    per_sample_max_rot = [max(rots) if rots else 0 for _, rots in pos_rot]
    for thr in [10.0, 30.0, 60.0]:
        m[f"HardFailRate_max_rot_gt_{thr}deg"] = sum(1 for r in per_sample_max_rot if r > thr) / n
    # ---- Standard summary ----
    m["FDE_mean"] = float(fde_arr.mean())
    m["ADE_mean"] = float(ade_arr.mean())
    m["FDE_rot_mean"] = float(np.array(fde_rot).mean())
    m["pos_mae"] = float(pos_arr.mean())
    m["rot_mae"] = float(rot_arr.mean())
    m["pos_rmse"] = float(np.sqrt((pos_arr ** 2).mean()))
    m["rot_rmse"] = float(np.sqrt((rot_arr ** 2).mean()))
    m["n_samples"] = n
    return m
def fmt_pct(v):
    """Render a fraction as a fixed-width percentage, e.g. 0.5 -> ' 50.00%'."""
    return "{:6.2f}%".format(v * 100)
def fmt_num(v, d=4):
    """Render a number right-aligned in width 7 with ``d`` decimals (default 4)."""
    return format(v, f"7.{d}f")
def main():
    """Load both runs' raw errors, compute exact per-map metrics, and print a
    six-section A-vs-B comparison report to stdout."""
    print("Loading raw error data ...")
    A = load_raw(OUT_A)
    B = load_raw(OUT_B)
    # Only compare maps for which BOTH runs produced raw data.
    maps = sorted(set(A.keys()) & set(B.keys()))
    if not maps:
        print("[ERROR] no maps with raw_errors_*.json found.")
        print("Did you run eval_exp4_strict_parallel.sh first?")
        return
    print(f"Maps with raw data: {maps}\n")
    # Compute exact metrics per map
    metrics_A = {m: aggregate_metrics(A[m]) for m in maps}
    metrics_B = {m: aggregate_metrics(B[m]) for m in maps}
    # MEAN across maps (exclude all)
    # NOTE(review): "all" is presumably a pseudo-map aggregate directory;
    # if every map were named "all", eval_maps would be empty and the
    # eval_maps[0] lookup below would raise IndexError — confirm upstream.
    eval_maps = [m for m in maps if m != "all"]
    mean_A = OrderedDict()
    mean_B = OrderedDict()
    for k in metrics_A[eval_maps[0]].keys():
        if k == "n_samples":
            continue
        # Unweighted mean over maps: each map counts equally regardless of
        # its n_samples.
        mean_A[k] = sum(metrics_A[m][k] for m in eval_maps) / len(eval_maps)
        mean_B[k] = sum(metrics_B[m][k] for m in eval_maps) / len(eval_maps)
    # ========================================================================
    # SECTION 1: Layered metrics (loose -> extreme strict)
    # ========================================================================
    print("=" * 100)
    print(" SECTION 1 — Layered precision (sample-level rates, EXACT)")
    print("=" * 100)
    # Metric layers ordered from saturated/loose thresholds to extreme ones;
    # each entry is (metric key, better-direction, unit tag — unit is unused).
    LAYERED = OrderedDict([
        ("L1 LOOSE (saturated)", [
            ("SR@1.0m", "higher", "%"),
            ("SR@2.0m", "higher", "%"),
            ("RotAcc@10.0deg", "higher", "%"),
        ]),
        ("L2 STANDARD", [
            ("SR@0.5m", "higher", "%"),
            ("RotAcc@5.0deg", "higher", "%"),
            ("TrajSR@1.0m", "higher", "%"),
        ]),
        ("L3 STRICT", [
            ("SR@0.3m", "higher", "%"),
            ("RotAcc@2.0deg", "higher", "%"),
            ("RotAcc@1.0deg", "higher", "%"),
            ("TrajSR@0.5m", "higher", "%"),
            ("TrajRotSR@5.0deg", "higher", "%"),
        ]),
        ("L4 EXTREME", [
            ("SR@0.2m", "higher", "%"),
            ("SR@0.1m", "higher", "%"),
            ("RotAcc@0.5deg", "higher", "%"),
            ("TrajSR@0.3m", "higher", "%"),
            ("TrajRotSR@1.0deg", "higher", "%"),
        ]),
    ])
    for layer, entries in LAYERED.items():
        print(f"\n>>> {layer}")
        print(f" {'Metric':25s}{'A mean':>12s}{'B mean':>12s}{'B - A':>12s}{'Win%':>10s}")
        print(" " + "-" * 75)
        for key, direction, _ in entries:
            a, b = mean_A.get(key), mean_B.get(key)
            if a is None or b is None:
                continue
            # win rate across maps
            wins = sum(1 for m in eval_maps
                       if (metrics_B[m][key] > metrics_A[m][key] if direction == "higher"
                           else metrics_B[m][key] < metrics_A[m][key]))
            ties = sum(1 for m in eval_maps if metrics_A[m][key] == metrics_B[m][key])
            diff = (b - a) * 100  # gap in percentage points
            print(f" {key:25s}{fmt_pct(a):>12s}{fmt_pct(b):>12s}"
                  f"{diff:+11.2f}pp{wins:>3d}/{len(eval_maps)}+{ties}t")
    # ========================================================================
    # SECTION 2: TRUE JOINT constraints (exact)
    # ========================================================================
    print("\n" + "=" * 100)
    print(" SECTION 2 — TRUE JOINT constraints (sample-level AND, exact)")
    print("=" * 100)
    print(" This is the GOLD STANDARD: each sample must satisfy BOTH pos+rot.")
    print()
    print(f" {'Metric':40s}{'A mean':>12s}{'B mean':>12s}{'B - A':>12s}{'Win%':>10s}")
    print(" " + "-" * 90)
    # (metric key, human-readable description — description is not printed).
    JOINT_ROWS = [
        ("JointSR@(0.5m,1.0deg)", "any wp pos<0.5 AND rot<1°"),
        ("JointSR@(0.5m,5.0deg)", "any wp pos<0.5 AND rot<5°"),
        ("JointSR@(0.3m,1.0deg)", "any wp pos<0.3 AND rot<1°"),
        ("JointSR@(1.0m,1.0deg)", "any wp pos<1.0 AND rot<1°"),
        ("TrajJointSR@(0.5m,1.0deg)", "ALL wps pos<0.5 AND rot<1°"),
        ("TrajJointSR@(0.5m,5.0deg)", "ALL wps pos<0.5 AND rot<5°"),
        ("TrajJointSR@(1.0m,5.0deg)", "ALL wps pos<1.0 AND rot<5°"),
    ]
    for key, _desc in JOINT_ROWS:
        a, b = mean_A.get(key), mean_B.get(key)
        if a is None or b is None:
            continue
        wins = sum(1 for m in eval_maps if metrics_B[m][key] > metrics_A[m][key])
        ties = sum(1 for m in eval_maps if metrics_A[m][key] == metrics_B[m][key])
        diff = (b - a) * 100
        print(f" {key:40s}{fmt_pct(a):>12s}{fmt_pct(b):>12s}"
              f"{diff:+11.2f}pp{wins:>3d}/{len(eval_maps)}+{ties}t")
    # ========================================================================
    # SECTION 3: Percentile distributions (tail risk)
    # ========================================================================
    print("\n" + "=" * 100)
    print(" SECTION 3 — Percentile distributions (tail risk, exact)")
    print("=" * 100)
    print(" Lower is better for all (these are error percentiles).")
    print()
    print(f" {'Metric':25s}{'A':>10s}{'B':>10s}{'B improves':>14s}{'Win%':>10s}")
    print(" " + "-" * 75)
    PCT_ROWS = ["FDE_p50", "FDE_p75", "FDE_p90", "FDE_p95", "FDE_p99",
                "ADE_p50", "ADE_p75", "ADE_p90", "ADE_p95", "ADE_p99",
                "rot_err_p50", "rot_err_p75", "rot_err_p90", "rot_err_p95", "rot_err_p99",
                "pos_err_p50", "pos_err_p75", "pos_err_p90", "pos_err_p95", "pos_err_p99",
                "FDE_max", "ADE_max", "rot_err_max"]
    for key in PCT_ROWS:
        a, b = mean_A.get(key), mean_B.get(key)
        if a is None or b is None:
            continue
        wins = sum(1 for m in eval_maps if metrics_B[m][key] < metrics_A[m][key])
        # Relative improvement of B over A (positive = B has lower error).
        rel = (a - b) / max(abs(a), 1e-9) * 100
        print(f" {key:25s}{fmt_num(a):>10s}{fmt_num(b):>10s}"
              f"{rel:+12.2f}% {wins:>3d}/{len(eval_maps)}")
    # ========================================================================
    # SECTION 4: Hard failure rates (catastrophic predictions)
    # ========================================================================
    print("\n" + "=" * 100)
    print(" SECTION 4 — HARD failure rates (catastrophic predictions)")
    print("=" * 100)
    print(" Lower is better. These are samples where the model went seriously wrong.")
    print()
    print(f" {'Metric':40s}{'A':>10s}{'B':>10s}{'B improves':>14s}{'Win%':>10s}")
    print(" " + "-" * 90)
    HARD_ROWS = ["HardFailRate_FDE_gt_1.0m", "HardFailRate_FDE_gt_2.0m",
                 "HardFailRate_FDE_gt_5.0m", "HardFailRate_FDE_gt_10.0m",
                 "HardFailRate_max_rot_gt_10.0deg",
                 "HardFailRate_max_rot_gt_30.0deg",
                 "HardFailRate_max_rot_gt_60.0deg"]
    for key in HARD_ROWS:
        a, b = mean_A.get(key), mean_B.get(key)
        if a is None or b is None:
            continue
        wins = sum(1 for m in eval_maps if metrics_B[m][key] < metrics_A[m][key])
        # Relative reduction; guarded so a zero A-rate prints +0.00%.
        rel = (a - b) / max(abs(a), 1e-9) * 100 if a > 0 else 0
        print(f" {key:40s}{fmt_pct(a):>10s}{fmt_pct(b):>10s}"
              f"{rel:+12.2f}% {wins:>3d}/{len(eval_maps)}")
    # ========================================================================
    # SECTION 5: OOD analysis (Town10HD vs seen maps)
    # ========================================================================
    print("\n" + "=" * 100)
    print(" SECTION 5 — OOD generalization (Town10HD = TRUE hold-out)")
    print("=" * 100)
    seen_maps = sorted(set(eval_maps) & SEEN_BY_B)
    unseen_maps = sorted(set(eval_maps) & UNSEEN_BY_B)
    if not unseen_maps:
        print(" No OOD maps in eval set, skipping.")
    else:
        print(f" Seen by B (near-domain): {seen_maps}")
        print(f" TRUE OOD: {unseen_maps}")
        print()
        OOD_KEYS = ["JointSR@(0.5m,1.0deg)", "TrajJointSR@(0.5m,5.0deg)",
                    "RotAcc@1.0deg", "FDE_p95", "HardFailRate_FDE_gt_2.0m"]
        for k in OOD_KEYS:
            # NOTE(review): divides by len(seen_maps) — assumes at least one
            # seen map is always present when an OOD map is; verify for new
            # eval sets or this raises ZeroDivisionError.
            a_seen = sum(metrics_A[m][k] for m in seen_maps) / len(seen_maps)
            b_seen = sum(metrics_B[m][k] for m in seen_maps) / len(seen_maps)
            a_uns = sum(metrics_A[m][k] for m in unseen_maps) / len(unseen_maps)
            b_uns = sum(metrics_B[m][k] for m in unseen_maps) / len(unseen_maps)
            # Heuristic: metric-name substring decides %-formatting vs raw number.
            is_pct = "SR" in k or "Acc" in k or "Rate" in k
            f = fmt_pct if is_pct else fmt_num
            print(f" {k:40s}")
            print(f" A seen: {f(a_seen)} B seen: {f(b_seen)} "
                  f"A unseen: {f(a_uns)} B unseen: {f(b_uns)}")
            if is_pct:
                gap_seen = (b_seen - a_seen) * 100
                gap_uns = (b_uns - a_uns) * 100
                print(f" B-A on seen: {gap_seen:+.2f}pp, "
                      f"B-A on OOD: {gap_uns:+.2f}pp, "
                      f"OOD-loss-A: {(a_seen-a_uns)*100:.2f}pp, "
                      f"OOD-loss-B: {(b_seen-b_uns)*100:.2f}pp")
    # ========================================================================
    # SECTION 6: Composite verdict
    # ========================================================================
    print("\n" + "=" * 100)
    print(" SECTION 6 — Verdict")
    print("=" * 100)
    # Win rate over a curated set
    KEY_VERDICT_METRICS = [
        ("SR@0.5m", "higher"),
        ("RotAcc@1.0deg", "higher"),
        ("JointSR@(0.5m,1.0deg)", "higher"),
        ("TrajJointSR@(0.5m,5.0deg)", "higher"),
        ("FDE_p95", "lower"),
        ("HardFailRate_FDE_gt_2.0m", "lower"),
        ("rot_err_p95", "lower"),
    ]
    print(f" {'Verdict metric':40s}{'A':>11s}{'B':>11s}{'B advantage':>15s}")
    print(" " + "-" * 80)
    a_wins = 0; b_wins = 0
    for key, direction in KEY_VERDICT_METRICS:
        a, b = mean_A.get(key), mean_B.get(key)
        if a is None or b is None:
            continue
        is_pct = "SR" in key or "Acc" in key or "Rate" in key
        f = fmt_pct if is_pct else fmt_num
        if direction == "higher":
            # "higher is better": advantage reported in percentage points.
            adv = f"{(b-a)*100:+.2f}pp"
            if b > a: b_wins += 1
            else: a_wins += 1
        else:
            # "lower is better": advantage reported as relative error reduction.
            adv = f"{(a-b)/max(abs(a),1e-9)*100:+.2f}%"
            if b < a: b_wins += 1
            else: a_wins += 1
        print(f" {key:40s}{f(a):>11s}{f(b):>11s}{adv:>15s}")
    print()
    print(f" Overall: B wins {b_wins}/{a_wins+b_wins} verdict metrics.")
# Script entry point; keeps the module importable without side effects.
if __name__ == "__main__":
    main()