File size: 15,445 Bytes
1b1b1e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
#!/usr/bin/env python3
"""Offline strict analysis from raw_errors_*.json files.

This computes EXACT metrics that need sample-level data (joint constraints,
percentile distributions, failure mode breakdown, Pareto frontier, etc.) that
the on-line eval script cannot easily aggregate.
"""
import json
import glob
import os
import math
from collections import OrderedDict, defaultdict
import numpy as np


# Output directories of the two evaluation jobs that main() compares
# (setting A vs. setting B). load_raw() scans each one for eval_strict_* dirs.
OUT_A = "/mnt/sfs_turbo_new/R11181/project_vlm/exp_v5/output/job_exp4_settingA_20260430_083003"
OUT_B = "/mnt/sfs_turbo_new/R11181/project_vlm/exp_v5/output/job_exp4_settingB_20260430_083037"

# Keys of the six per-waypoint error components found in every raw sample:
# three positional components followed by three rotational components.
DIMS = ["dx", "dy", "dz"] + ["dpitch", "dyaw", "droll"]

# Map names reported as "Seen by B" in the OOD section of the report.
SEEN_BY_B = {
    "Town01_Opt",
    "Town02_Opt",
    "Town03_Opt",
    "Town04_Opt",
    "Town05_Opt",
    "Town06_Opt",
    "Town07_Opt",
}
# Map names reported as "TRUE OOD" (hold-out) in the OOD section of the report.
UNSEEN_BY_B = {"Town10HD"}


def load_raw(out_dir):
    """Load the raw per-sample errors of every evaluated map under *out_dir*.

    Every sub-directory of ``out_dir`` named ``eval_strict_<map_name>`` is
    inspected; the first ``raw_errors_*.json`` file (in sorted name order) is
    parsed and its ``"errors_per_sample"`` entry is stored under ``<map_name>``.
    Entries that are not directories, and directories without any
    ``raw_errors_*.json`` file, are skipped.

    Args:
        out_dir: Path of the job output directory to scan.

    Returns:
        {map_name: list of sample dicts (each sample has dim->list[per_wp])}.

    Raises:
        KeyError: If a parsed file has no ``"errors_per_sample"`` key.
        json.JSONDecodeError: If a raw-error file is not valid JSON.
    """
    prefix = "eval_strict_"
    res = {}
    for d in sorted(glob.glob(f"{out_dir}/{prefix}*")):
        if not os.path.isdir(d):
            continue
        # Strip only the leading prefix; str.replace would also remove any
        # later occurrence of the prefix inside the map name.
        map_name = os.path.basename(d)[len(prefix):]
        # glob returns matches in arbitrary order, so sort to make the choice
        # of file reproducible when a directory holds more than one.
        files = sorted(glob.glob(f"{d}/raw_errors_*.json"))
        if not files:
            continue
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(files[0], encoding="utf-8") as f:
            payload = json.load(f)
        res[map_name] = payload["errors_per_sample"]
    return res


def per_sample_pos_rot(sample):
    """Collapse one sample's six error components into two per-waypoint norms.

    For every waypoint index of ``sample["dx"]`` the positional error is the
    Euclidean norm of (dx, dy, dz) and the rotational error is the Euclidean
    norm of (dpitch, dyaw, droll).

    Returns:
        A ``(pos_per_wp, rot_per_wp)`` tuple of two equally long lists.
    """
    def magnitude(k, first, second, third):
        # Root of the sum of squares of three components at waypoint k.
        return math.sqrt(
            sample[first][k]**2 + sample[second][k]**2 + sample[third][k]**2
        )

    # Build (pos, rot) per waypoint in one pass, then split into two lists.
    pairs = [
        (magnitude(k, "dx", "dy", "dz"),
         magnitude(k, "dpitch", "dyaw", "droll"))
        for k in range(len(sample["dx"]))
    ]
    return [p for p, _ in pairs], [r for _, r in pairs]


def aggregate_metrics(samples):
    """Compute EXACT strict metrics from sample-level raw data.

    Each sample is reduced to per-waypoint positional and rotational error
    norms via ``per_sample_pos_rot``. Samples that contain no waypoints carry
    no error information and are ignored (they are not counted in
    ``n_samples``).

    Metric keys label positional thresholds with ``m`` and rotational
    thresholds with ``deg``; the raw errors are assumed to be in the same
    units (TODO confirm against the script that writes raw_errors_*.json).

    Args:
        samples: List of sample dicts, each mapping a dimension name to a
            list of per-waypoint errors.

    Returns:
        An ``OrderedDict`` of metric name -> value, or ``{}`` when there is no
        sample with at least one waypoint.
    """
    if not samples:
        return {}
    # Keep only trajectories with at least one waypoint: FDE (last waypoint)
    # and ADE (mean over waypoints) are undefined for an empty trajectory.
    pos_rot = [pr for pr in (per_sample_pos_rot(s) for s in samples) if pr[0]]
    if not pos_rot:
        return {}
    n = len(pos_rot)
    all_pos = [p for poss, _ in pos_rot for p in poss]
    all_rot = [r for _, rots in pos_rot for r in rots]
    fde = [poss[-1] for poss, _ in pos_rot]             # final-waypoint error
    ade = [sum(poss)/len(poss) for poss, _ in pos_rot]  # mean over waypoints
    fde_rot = [rots[-1] for _, rots in pos_rot]
    # Worst rotational error of each trajectory; independent of the threshold,
    # so it is computed once here rather than inside the threshold loop.
    per_sample_max_rot = [max(rots) for _, rots in pos_rot]

    m = OrderedDict()

    # ---- Waypoint-level rates (fraction of ALL waypoints, pooled over
    #      samples, that are under the threshold) ----
    POS_THRS = [0.1, 0.2, 0.3, 0.5, 1.0, 2.0]
    ROT_THRS = [0.5, 1.0, 2.0, 5.0, 10.0]
    for thr in POS_THRS:
        m[f"SR@{thr}m"] = sum(1 for p in all_pos if p < thr) / len(all_pos)
    for thr in ROT_THRS:
        m[f"RotAcc@{thr}deg"] = sum(1 for r in all_rot if r < thr) / len(all_rot)

    # ---- Trajectory-level (ALL wps under threshold) ----
    TRAJ_POS = [0.3, 0.5, 1.0, 2.0]
    TRAJ_ROT = [1.0, 2.0, 5.0, 10.0]
    for thr in TRAJ_POS:
        m[f"TrajSR@{thr}m"] = sum(1 for poss, _ in pos_rot if all(p < thr for p in poss)) / n
    for thr in TRAJ_ROT:
        m[f"TrajRotSR@{thr}deg"] = sum(1 for _, rots in pos_rot if all(r < thr for r in rots)) / n

    # ---- TRUE Joint constraint rates (fraction of samples in which at least
    #      one wp satisfies BOTH pos AND rot) ----
    JOINT = [(0.5, 1.0), (0.5, 5.0), (0.5, 2.0),
             (0.3, 1.0), (1.0, 1.0), (1.0, 5.0)]
    for pt, rt in JOINT:
        hit = sum(1 for poss, rots in pos_rot
                  if any(p < pt and r < rt for p, r in zip(poss, rots)))
        m[f"JointSR@({pt}m,{rt}deg)"] = hit / n

    # ---- Trajectory-level TRUE Joint (ALL wps satisfy BOTH) ----
    for pt, rt in JOINT:
        hit = sum(1 for poss, rots in pos_rot
                  if all(p < pt and r < rt for p, r in zip(poss, rots)))
        m[f"TrajJointSR@({pt}m,{rt}deg)"] = hit / n

    # ---- Percentile / tail metrics ----
    fde_arr = np.array(fde)
    ade_arr = np.array(ade)
    rot_arr = np.array(all_rot)
    pos_arr = np.array(all_pos)
    for p in [50, 75, 90, 95, 99]:
        m[f"FDE_p{p}"] = float(np.percentile(fde_arr, p))
        m[f"ADE_p{p}"] = float(np.percentile(ade_arr, p))
        m[f"rot_err_p{p}"] = float(np.percentile(rot_arr, p))
        m[f"pos_err_p{p}"] = float(np.percentile(pos_arr, p))
    m["FDE_max"] = float(fde_arr.max())
    m["ADE_max"] = float(ade_arr.max())
    m["rot_err_max"] = float(rot_arr.max())

    # ---- Hard failure rates (fraction of samples beyond the threshold) ----
    for thr in [1.0, 2.0, 5.0, 10.0]:
        m[f"HardFailRate_FDE_gt_{thr}m"] = sum(1 for f in fde if f > thr) / n
    for thr in [10.0, 30.0, 60.0]:
        m[f"HardFailRate_max_rot_gt_{thr}deg"] = sum(1 for r in per_sample_max_rot if r > thr) / n

    # ---- Standard summary ----
    m["FDE_mean"] = float(fde_arr.mean())
    m["ADE_mean"] = float(ade_arr.mean())
    m["FDE_rot_mean"] = float(np.array(fde_rot).mean())
    m["pos_mae"] = float(pos_arr.mean())
    m["rot_mae"] = float(rot_arr.mean())
    m["pos_rmse"] = float(np.sqrt((pos_arr ** 2).mean()))
    m["rot_rmse"] = float(np.sqrt((rot_arr ** 2).mean()))
    m["n_samples"] = n
    return m


def fmt_pct(v):
    """Format the fraction *v* as a percentage (width 6, 2 decimals, '%')."""
    percent = v * 100
    return "{:6.2f}%".format(percent)


def fmt_num(v, d=4):
    """Format *v* as fixed-point, right-aligned to width 7 with *d* decimals."""
    spec = "7.{}f".format(d)
    return format(v, spec)


def main():
    """Print the A-vs-B strict comparison report to stdout.

    Loads the raw per-sample errors of both jobs (``OUT_A`` and ``OUT_B``),
    computes exact metrics for every map present in both, averages them over
    the maps (the aggregate entry named ``"all"`` and maps without usable
    metrics are excluded) and prints six sections: layered precision, joint
    constraints, percentile distributions, hard failure rates, seen-vs-unseen
    (OOD) comparison and a final verdict.

    Returns early with an error message when there is nothing to compare.
    """
    print("Loading raw error data ...")
    A = load_raw(OUT_A)
    B = load_raw(OUT_B)
    maps = sorted(set(A.keys()) & set(B.keys()))
    if not maps:
        print("[ERROR] no maps with raw_errors_*.json found.")
        print("Did you run eval_exp4_strict_parallel.sh first?")
        return

    print(f"Maps with raw data: {maps}\n")

    # Compute exact metrics per map
    metrics_A = {m: aggregate_metrics(A[m]) for m in maps}
    metrics_B = {m: aggregate_metrics(B[m]) for m in maps}

    # MEAN across maps (exclude all). aggregate_metrics returns {} for a map
    # without usable samples; such a map has no metric keys to average or
    # compare, so it is left out as well.
    eval_maps = [m for m in maps
                 if m != "all" and metrics_A[m] and metrics_B[m]]
    if not eval_maps:
        print("[ERROR] no per-map metrics available to compare.")
        return
    n_maps = len(eval_maps)

    mean_A = OrderedDict()
    mean_B = OrderedDict()
    for k in metrics_A[eval_maps[0]].keys():
        if k == "n_samples":
            continue
        mean_A[k] = sum(metrics_A[m][k] for m in eval_maps) / n_maps
        mean_B[k] = sum(metrics_B[m][k] for m in eval_maps) / n_maps

    def win_count(key, direction):
        """Number of maps on which B is strictly better than A for *key*."""
        if direction == "higher":
            return sum(1 for m in eval_maps
                       if metrics_B[m][key] > metrics_A[m][key])
        return sum(1 for m in eval_maps
                   if metrics_B[m][key] < metrics_A[m][key])

    def tie_count(key):
        """Number of maps on which A and B have exactly the same *key*."""
        return sum(1 for m in eval_maps
                   if metrics_A[m][key] == metrics_B[m][key])

    # ========================================================================
    # SECTION 1: Layered metrics (loose -> extreme strict)
    # ========================================================================
    print("=" * 100)
    print(" SECTION 1 — Layered precision (sample-level rates, EXACT)")
    print("=" * 100)
    LAYERED = OrderedDict([
        ("L1 LOOSE (saturated)", [
            ("SR@1.0m", "higher", "%"),
            ("SR@2.0m", "higher", "%"),
            ("RotAcc@10.0deg", "higher", "%"),
        ]),
        ("L2 STANDARD", [
            ("SR@0.5m", "higher", "%"),
            ("RotAcc@5.0deg", "higher", "%"),
            ("TrajSR@1.0m", "higher", "%"),
        ]),
        ("L3 STRICT", [
            ("SR@0.3m", "higher", "%"),
            ("RotAcc@2.0deg", "higher", "%"),
            ("RotAcc@1.0deg", "higher", "%"),
            ("TrajSR@0.5m", "higher", "%"),
            ("TrajRotSR@5.0deg", "higher", "%"),
        ]),
        ("L4 EXTREME", [
            ("SR@0.2m", "higher", "%"),
            ("SR@0.1m", "higher", "%"),
            ("RotAcc@0.5deg", "higher", "%"),
            ("TrajSR@0.3m", "higher", "%"),
            ("TrajRotSR@1.0deg", "higher", "%"),
        ]),
    ])

    for layer, entries in LAYERED.items():
        print(f"\n>>> {layer}")
        print(f"  {'Metric':25s}{'A mean':>12s}{'B mean':>12s}{'B - A':>12s}{'Win%':>10s}")
        print("  " + "-" * 75)
        for key, direction, _ in entries:
            a, b = mean_A.get(key), mean_B.get(key)
            if a is None or b is None:
                continue
            # win rate across maps
            wins = win_count(key, direction)
            ties = tie_count(key)
            diff = (b - a) * 100  # difference in percentage points
            print(f"  {key:25s}{fmt_pct(a):>12s}{fmt_pct(b):>12s}"
                  f"{diff:+11.2f}pp{wins:>3d}/{n_maps}+{ties}t")

    # ========================================================================
    # SECTION 2: TRUE JOINT constraints (exact)
    # ========================================================================
    print("\n" + "=" * 100)
    print(" SECTION 2 — TRUE JOINT constraints (sample-level AND, exact)")
    print("=" * 100)
    print("  This is the GOLD STANDARD: each sample must satisfy BOTH pos+rot.")
    print()
    print(f"  {'Metric':40s}{'A mean':>12s}{'B mean':>12s}{'B - A':>12s}{'Win%':>10s}")
    print("  " + "-" * 90)
    JOINT_ROWS = [
        ("JointSR@(0.5m,1.0deg)",     "any wp pos<0.5 AND rot<1°"),
        ("JointSR@(0.5m,5.0deg)",     "any wp pos<0.5 AND rot<5°"),
        ("JointSR@(0.3m,1.0deg)",     "any wp pos<0.3 AND rot<1°"),
        ("JointSR@(1.0m,1.0deg)",     "any wp pos<1.0 AND rot<1°"),
        ("TrajJointSR@(0.5m,1.0deg)", "ALL wps pos<0.5 AND rot<1°"),
        ("TrajJointSR@(0.5m,5.0deg)", "ALL wps pos<0.5 AND rot<5°"),
        ("TrajJointSR@(1.0m,5.0deg)", "ALL wps pos<1.0 AND rot<5°"),
    ]
    for key, _desc in JOINT_ROWS:
        a, b = mean_A.get(key), mean_B.get(key)
        if a is None or b is None:
            continue
        wins = win_count(key, "higher")
        ties = tie_count(key)
        diff = (b - a) * 100
        print(f"  {key:40s}{fmt_pct(a):>12s}{fmt_pct(b):>12s}"
              f"{diff:+11.2f}pp{wins:>3d}/{n_maps}+{ties}t")

    # ========================================================================
    # SECTION 3: Percentile distributions (tail risk)
    # ========================================================================
    print("\n" + "=" * 100)
    print(" SECTION 3 — Percentile distributions (tail risk, exact)")
    print("=" * 100)
    print("  Lower is better for all (these are error percentiles).")
    print()
    print(f"  {'Metric':25s}{'A':>10s}{'B':>10s}{'B improves':>14s}{'Win%':>10s}")
    print("  " + "-" * 75)
    PCT_ROWS = ["FDE_p50", "FDE_p75", "FDE_p90", "FDE_p95", "FDE_p99",
                "ADE_p50", "ADE_p75", "ADE_p90", "ADE_p95", "ADE_p99",
                "rot_err_p50", "rot_err_p75", "rot_err_p90", "rot_err_p95", "rot_err_p99",
                "pos_err_p50", "pos_err_p75", "pos_err_p90", "pos_err_p95", "pos_err_p99",
                "FDE_max", "ADE_max", "rot_err_max"]
    for key in PCT_ROWS:
        a, b = mean_A.get(key), mean_B.get(key)
        if a is None or b is None:
            continue
        wins = win_count(key, "lower")
        # Relative reduction of B w.r.t. A; the floor avoids division by zero.
        rel = (a - b) / max(abs(a), 1e-9) * 100
        print(f"  {key:25s}{fmt_num(a):>10s}{fmt_num(b):>10s}"
              f"{rel:+12.2f}%  {wins:>3d}/{n_maps}")

    # ========================================================================
    # SECTION 4: Hard failure rates (catastrophic predictions)
    # ========================================================================
    print("\n" + "=" * 100)
    print(" SECTION 4 — HARD failure rates (catastrophic predictions)")
    print("=" * 100)
    print("  Lower is better. These are samples where the model went seriously wrong.")
    print()
    print(f"  {'Metric':40s}{'A':>10s}{'B':>10s}{'B improves':>14s}{'Win%':>10s}")
    print("  " + "-" * 90)
    HARD_ROWS = ["HardFailRate_FDE_gt_1.0m", "HardFailRate_FDE_gt_2.0m",
                 "HardFailRate_FDE_gt_5.0m", "HardFailRate_FDE_gt_10.0m",
                 "HardFailRate_max_rot_gt_10.0deg",
                 "HardFailRate_max_rot_gt_30.0deg",
                 "HardFailRate_max_rot_gt_60.0deg"]
    for key in HARD_ROWS:
        a, b = mean_A.get(key), mean_B.get(key)
        if a is None or b is None:
            continue
        wins = win_count(key, "lower")
        # A zero failure rate for A has no meaningful relative change.
        rel = (a - b) / max(abs(a), 1e-9) * 100 if a > 0 else 0
        print(f"  {key:40s}{fmt_pct(a):>10s}{fmt_pct(b):>10s}"
              f"{rel:+12.2f}%  {wins:>3d}/{n_maps}")

    # ========================================================================
    # SECTION 5: OOD analysis (Town10HD vs seen maps)
    # ========================================================================
    print("\n" + "=" * 100)
    print(" SECTION 5 — OOD generalization (Town10HD = TRUE hold-out)")
    print("=" * 100)
    seen_maps = sorted(set(eval_maps) & SEEN_BY_B)
    unseen_maps = sorted(set(eval_maps) & UNSEEN_BY_B)
    if not unseen_maps:
        print("  No OOD maps in eval set, skipping.")
    elif not seen_maps:
        # Without seen maps the seen-side averages would divide by zero.
        print("  No seen maps in eval set, skipping.")
    else:
        print(f"  Seen by B (near-domain): {seen_maps}")
        print(f"  TRUE OOD: {unseen_maps}")
        print()
        OOD_KEYS = ["JointSR@(0.5m,1.0deg)", "TrajJointSR@(0.5m,5.0deg)",
                    "RotAcc@1.0deg", "FDE_p95", "HardFailRate_FDE_gt_2.0m"]
        for k in OOD_KEYS:
            a_seen = sum(metrics_A[m][k] for m in seen_maps) / len(seen_maps)
            b_seen = sum(metrics_B[m][k] for m in seen_maps) / len(seen_maps)
            a_uns = sum(metrics_A[m][k] for m in unseen_maps) / len(unseen_maps)
            b_uns = sum(metrics_B[m][k] for m in unseen_maps) / len(unseen_maps)
            # Rate-like metrics are shown as percentages, errors as numbers.
            is_pct = "SR" in k or "Acc" in k or "Rate" in k
            f = fmt_pct if is_pct else fmt_num
            print(f"  {k:40s}")
            print(f"     A seen: {f(a_seen)}  B seen: {f(b_seen)}  "
                  f"A unseen: {f(a_uns)}  B unseen: {f(b_uns)}")
            if is_pct:
                gap_seen = (b_seen - a_seen) * 100
                gap_uns = (b_uns - a_uns) * 100
                print(f"     B-A on seen: {gap_seen:+.2f}pp,  "
                      f"B-A on OOD: {gap_uns:+.2f}pp,  "
                      f"OOD-loss-A: {(a_seen-a_uns)*100:.2f}pp,  "
                      f"OOD-loss-B: {(b_seen-b_uns)*100:.2f}pp")

    # ========================================================================
    # SECTION 6: Composite verdict
    # ========================================================================
    print("\n" + "=" * 100)
    print(" SECTION 6 — Verdict")
    print("=" * 100)
    # Win rate over a curated set
    KEY_VERDICT_METRICS = [
        ("SR@0.5m", "higher"),
        ("RotAcc@1.0deg", "higher"),
        ("JointSR@(0.5m,1.0deg)", "higher"),
        ("TrajJointSR@(0.5m,5.0deg)", "higher"),
        ("FDE_p95", "lower"),
        ("HardFailRate_FDE_gt_2.0m", "lower"),
        ("rot_err_p95", "lower"),
    ]
    print(f"  {'Verdict metric':40s}{'A':>11s}{'B':>11s}{'B advantage':>15s}")
    print("  " + "-" * 80)
    a_wins = 0
    b_wins = 0
    for key, direction in KEY_VERDICT_METRICS:
        a, b = mean_A.get(key), mean_B.get(key)
        if a is None or b is None:
            continue
        is_pct = "SR" in key or "Acc" in key or "Rate" in key
        f = fmt_pct if is_pct else fmt_num
        # A tie is not a B win; it is tallied on A's side so that
        # a_wins + b_wins equals the number of verdict metrics evaluated.
        if direction == "higher":
            adv = f"{(b-a)*100:+.2f}pp"
            if b > a:
                b_wins += 1
            else:
                a_wins += 1
        else:
            adv = f"{(a-b)/max(abs(a),1e-9)*100:+.2f}%"
            if b < a:
                b_wins += 1
            else:
                a_wins += 1
        print(f"  {key:40s}{f(a):>11s}{f(b):>11s}{adv:>15s}")
    print()
    print(f"  Overall: B wins {b_wins}/{a_wins+b_wins} verdict metrics.")


# Run the report only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()