File size: 5,284 Bytes
b4b2877
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!/usr/bin/env python3
"""Re-evaluate v3 saved models to compute action_vn@3 and action_vn@5.

Loads model_best.pt from each seed dir, runs test set, computes:
  - action_vn_top1 / top3 / top5 (verb_fine top-K AND noun top-K)
  - verb_fine_top1 / top3 / top5
  - noun_top1 / top3 / top5

Writes results into <seed_dir>/eval_topk.json so the aggregator can pick them up.
"""

from __future__ import annotations
import json, sys, time
from pathlib import Path

import pandas as pd  # noqa
import torch
from torch.utils.data import DataLoader

REPO = Path("${PULSE_ROOT}")
sys.path.insert(0, str(REPO / "experiments"))

from dataset_seqpred import build_train_test, collate_triplet  # noqa
from models_seqpred import build_model  # noqa


def topk_correct(logits, y, k):
    """Return a bool tensor marking rows whose true label is in the top-k logits.

    Args:
        logits: (N, C) float tensor of per-class scores.
        y: (N,) long tensor of true class indices.
        k: number of top predictions to consider; clamped to C when k > C.

    Returns:
        (N,) bool tensor; element i is True iff y[i] is among the k
        highest-scoring classes of logits[i].
    """
    # Clamp without mutating the caller-visible meaning of k mid-function:
    # topk raises if asked for more entries than there are classes.
    k = min(k, logits.shape[1])
    _, topk = logits.topk(k, dim=1)
    return (topk == y.unsqueeze(1)).any(dim=1)


def find_v3_seed_dirs():
    """Collect completed v3 seed directories under table1_main_comparison.

    Scans row*/seeds_v3{,_bidir,_sf}/seed* and keeps only seed dirs that
    contain both model_best.pt and results.json (i.e. finished runs that
    can be re-evaluated). Returns them in sorted traversal order.
    """
    base = REPO / "table1_main_comparison"
    variants = ("seeds_v3", "seeds_v3_bidir", "seeds_v3_sf")
    required = ("model_best.pt", "results.json")
    found = []
    for row_dir in sorted(base.glob("row*")):
        for variant in variants:
            for seed_dir in sorted((row_dir / variant).glob("seed*")):
                if all((seed_dir / name).exists() for name in required):
                    found.append(seed_dir)
    return found


# Cache keyed by (tuple(modalities), mode) -> (test DataLoader, modality_dims),
# so the expensive dataset build in main() runs once per unique configuration
# instead of once per seed dir.
_loader_cache = {}


def main():
    """Re-evaluate every v3 seed dir and write per-seed eval_topk.json files.

    For each seed dir found by find_v3_seed_dirs(): rebuild the test
    DataLoader (cached per (modalities, mode) pair in _loader_cache),
    reload model_best.pt, run inference over the full test set, and dump
    top-1/3/5 metrics plus joint verb_fine∧noun "action_vn" metrics to
    <seed_dir>/eval_topk.json. Failures are logged and counted but do not
    abort the sweep.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device={device}", flush=True)
    seed_dirs = find_v3_seed_dirs()
    print(f"Found {len(seed_dirs)} v3 seed dirs", flush=True)

    t0 = time.time()
    n_ok, n_fail = 0, 0
    for i, sd in enumerate(seed_dirs, 1):
        try:
            # The training args were recorded in results.json; reuse them so
            # the rebuilt dataset and model match the saved checkpoint.
            with open(sd / "results.json") as f:
                results = json.load(f)
            args = results["args"]
            mods_list = args["modalities"].split(",")
            mods_key = tuple(mods_list)  # hashable form for the cache key
            mode = args.get("mode", "anticipation")

            # Dataset construction is expensive; build the loader once per
            # (modalities, mode) combination and reuse it across seeds.
            if (mods_key, mode) not in _loader_cache:
                print(f"  [build loader] mode={mode} modalities={mods_list}", flush=True)
                train_ds, test_ds = build_train_test(modalities=mods_list, mode=mode)
                del train_ds  # only the test split is needed; free memory early
                test_loader = DataLoader(test_ds, batch_size=64, shuffle=False,
                                         collate_fn=collate_triplet, num_workers=0)
                _loader_cache[(mods_key, mode)] = (test_loader, test_ds.modality_dims)
            test_loader, modality_dims = _loader_cache[(mods_key, mode)]

            # These model names take a `causal` flag: causal attention for
            # anticipation mode, bidirectional otherwise.
            extra = {}
            if args["model"] in ("dailyactformer", "ours", "daf"):
                extra["causal"] = (mode == "anticipation")
            model = build_model(args["model"], modality_dims, **extra).to(device)
            # weights_only=False: checkpoints are trusted local files produced
            # by our own training runs and may hold non-tensor objects.
            state = torch.load(sd / "model_best.pt", map_location=device, weights_only=False)
            model.load_state_dict(state["state_dict"])
            model.eval()

            # Accumulate per-head logits and labels across the whole test set;
            # logits are moved to CPU per batch to keep GPU memory flat.
            all_logits = {k: [] for k in ("verb_fine", "verb_composite", "noun", "hand")}
            all_y = {k: [] for k in ("verb_fine", "verb_composite", "noun", "hand")}
            with torch.no_grad():
                for x, mask, lens, y, meta in test_loader:
                    x = {m: t.to(device) for m, t in x.items()}
                    mask = mask.to(device)
                    logits = model(x, mask)
                    for k in all_logits:
                        all_logits[k].append(logits[k].cpu())
                        all_y[k].append(y[k])

            logits_cat = {k: torch.cat(v, dim=0) for k, v in all_logits.items()}
            y_cat = {k: torch.cat(v, dim=0) for k, v in all_y.items()}

            # Per-head top-1/3/5 accuracies.
            out = {}
            for k in ("verb_fine", "verb_composite", "noun", "hand"):
                preds_top1 = logits_cat[k].argmax(dim=1)
                out[f"{k}_top1"] = float((preds_top1 == y_cat[k]).float().mean())
                out[f"{k}_top3"] = float(topk_correct(logits_cat[k], y_cat[k], 3).float().mean())
                out[f"{k}_top5"] = float(topk_correct(logits_cat[k], y_cat[k], 5).float().mean())

            # Joint action_vn (verb_fine ∧ noun) at top-1, top-3, top-5
            for K, lbl in [(1, "top1"), (3, "top3"), (5, "top5")]:
                vf_ok = topk_correct(logits_cat["verb_fine"], y_cat["verb_fine"], K)
                n_ok2 = topk_correct(logits_cat["noun"], y_cat["noun"], K)
                out[f"action_vn_{lbl}"] = float((vf_ok & n_ok2).float().mean())

            with open(sd / "eval_topk.json", "w") as f:
                json.dump(out, f, indent=2)
            n_ok += 1
            # Progress line for the first few seeds, then every fifth one.
            if i % 5 == 0 or i <= 3:
                rel = sd.relative_to(REPO)
                print(f"  [{i:>3}/{len(seed_dirs)}] {rel}  vn@1={out['action_vn_top1']:.4f}  vn@3={out['action_vn_top3']:.4f}  vn@5={out['action_vn_top5']:.4f}", flush=True)
            del model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # release GPU memory before the next seed
        except Exception as e:
            # Best-effort sweep: record the failure and move on to the rest.
            n_fail += 1
            print(f"  [{i:>3}/{len(seed_dirs)}] FAIL {sd.relative_to(REPO)}: {e}", flush=True)

    print(f"Done. ok={n_ok} fail={n_fail} elapsed={time.time()-t0:.1f}s", flush=True)


if __name__ == "__main__":
    main()