#!/usr/bin/env python3 """Re-evaluate v3 saved models to compute action_vn@3 and action_vn@5. Loads model_best.pt from each seed dir, runs test set, computes: - action_vn_top1 / top3 / top5 (verb_fine top-K AND noun top-K) - verb_fine_top1 / top3 / top5 - noun_top1 / top3 / top5 Writes results into /eval_topk.json so the aggregator can pick them up. """ from __future__ import annotations import json, sys, time from pathlib import Path import pandas as pd # noqa import torch from torch.utils.data import DataLoader REPO = Path("${PULSE_ROOT}") sys.path.insert(0, str(REPO / "experiments")) from dataset_seqpred import build_train_test, collate_triplet # noqa from models_seqpred import build_model # noqa def topk_correct(logits, y, k): if k > logits.shape[1]: k = logits.shape[1] _, topk = logits.topk(k, dim=1) return (topk == y.unsqueeze(1)).any(dim=1) def find_v3_seed_dirs(): """Walk table1_main_comparison/row*/seeds_v3{,_bidir,_sf}/seed*/model_best.pt""" out = [] base = REPO / "table1_main_comparison" for row_dir in sorted(base.glob("row*")): for sub in ("seeds_v3", "seeds_v3_bidir", "seeds_v3_sf"): for sd in sorted((row_dir / sub).glob("seed*")): if (sd / "model_best.pt").exists() and (sd / "results.json").exists(): out.append(sd) return out _loader_cache = {} def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"device={device}", flush=True) seed_dirs = find_v3_seed_dirs() print(f"Found {len(seed_dirs)} v3 seed dirs", flush=True) t0 = time.time() n_ok, n_fail = 0, 0 for i, sd in enumerate(seed_dirs, 1): try: with open(sd / "results.json") as f: results = json.load(f) args = results["args"] mods_list = args["modalities"].split(",") mods_key = tuple(mods_list) mode = args.get("mode", "anticipation") if (mods_key, mode) not in _loader_cache: print(f" [build loader] mode={mode} modalities={mods_list}", flush=True) train_ds, test_ds = build_train_test(modalities=mods_list, mode=mode) del train_ds test_loader = DataLoader(test_ds, batch_size=64, shuffle=False, collate_fn=collate_triplet, num_workers=0) _loader_cache[(mods_key, mode)] = (test_loader, test_ds.modality_dims) test_loader, modality_dims = _loader_cache[(mods_key, mode)] extra = {} if args["model"] in ("dailyactformer", "ours", "daf"): extra["causal"] = (mode == "anticipation") model = build_model(args["model"], modality_dims, **extra).to(device) state = torch.load(sd / "model_best.pt", map_location=device, weights_only=False) model.load_state_dict(state["state_dict"]) model.eval() all_logits = {k: [] for k in ("verb_fine", "verb_composite", "noun", "hand")} all_y = {k: [] for k in ("verb_fine", "verb_composite", "noun", "hand")} with torch.no_grad(): for x, mask, lens, y, meta in test_loader: x = {m: t.to(device) for m, t in x.items()} mask = mask.to(device) logits = model(x, mask) for k in all_logits: all_logits[k].append(logits[k].cpu()) all_y[k].append(y[k]) logits_cat = {k: torch.cat(v, dim=0) for k, v in all_logits.items()} y_cat = {k: torch.cat(v, dim=0) for k, v in all_y.items()} out = {} for k in ("verb_fine", "verb_composite", "noun", "hand"): preds_top1 = logits_cat[k].argmax(dim=1) out[f"{k}_top1"] = float((preds_top1 == y_cat[k]).float().mean()) out[f"{k}_top3"] = float(topk_correct(logits_cat[k], y_cat[k], 3).float().mean()) out[f"{k}_top5"] = float(topk_correct(logits_cat[k], y_cat[k], 5).float().mean()) # Joint action_vn (verb_fine ∧ noun) at top-1, top-3, top-5 for K, lbl in [(1, "top1"), (3, "top3"), (5, "top5")]: vf_ok = topk_correct(logits_cat["verb_fine"], y_cat["verb_fine"], K) n_ok2 = topk_correct(logits_cat["noun"], y_cat["noun"], K) out[f"action_vn_{lbl}"] = float((vf_ok & n_ok2).float().mean()) with open(sd / "eval_topk.json", "w") as f: json.dump(out, f, indent=2) n_ok += 1 if i % 5 == 0 or i <= 3: rel = sd.relative_to(REPO) print(f" [{i:>3}/{len(seed_dirs)}] {rel} vn@1={out['action_vn_top1']:.4f} vn@3={out['action_vn_top3']:.4f} vn@5={out['action_vn_top5']:.4f}", flush=True) del model if torch.cuda.is_available(): torch.cuda.empty_cache() except Exception as e: n_fail += 1 print(f" [{i:>3}/{len(seed_dirs)}] FAIL {sd.relative_to(REPO)}: {e}", flush=True) print(f"Done. ok={n_ok} fail={n_fail} elapsed={time.time()-t0:.1f}s", flush=True) if __name__ == "__main__": main()