#!/usr/bin/env python3 """Re-evaluate all 135 trained seeds with paper-style metrics. For each /seeds/seed*/model_best.pt: - Reload the model with the right modalities - Build the test loader for that modality subset - Run inference, collect predictions - Compute Acc, Macro-F1, Weighted-F1 per head (verb_fine, verb_composite, noun, hand) and for the joint "action" (= verb_fine ∧ noun ∧ hand) - Write /eval_macrof1.json Cache the test_ds per modality subset so we don't rebuild it 135 times. """ from __future__ import annotations import json import os import sys import time from pathlib import Path import pandas as pd # noqa: F401 (dataset_seqpred imports pandas first) import numpy as np import torch from sklearn.metrics import f1_score, accuracy_score from torch.utils.data import DataLoader REPO = Path("${PULSE_ROOT}") sys.path.insert(0, str(REPO / "experiments")) from dataset_seqpred import ( # noqa: E402 TripletSeqPredDataset, build_train_test, collate_triplet, TRAIN_VOLS_V3, TEST_VOLS_V3, ) from models_seqpred import build_model # noqa: E402 def find_seed_dirs(): out = [] for table_name in [ "table1_main_comparison", "table3_horizon_curve", "table4_modality_ablation", "table5_component_ablation", "table7_missing_modality", ]: td = REPO / table_name for row_dir in sorted(td.glob("row*")): for sd in sorted((row_dir / "seeds").glob("seed*")): if (sd / "model_best.pt").exists() and (sd / "results.json").exists(): out.append(sd) return out _test_cache = {} # (modalities_tuple, t_obs, t_fut) -> (test_loader, modality_dims) def get_test_loader(modalities, t_obs, t_fut, downsample, num_workers=0): key = (tuple(modalities), float(t_obs), float(t_fut), int(downsample)) if key in _test_cache: return _test_cache[key] print(f" [build test loader] modalities={modalities} t_obs={t_obs} t_fut={t_fut}", flush=True) train_ds, test_ds = build_train_test( modalities=list(modalities), t_obs_sec=t_obs, t_fut_sec=t_fut, downsample=downsample, ) test_loader = DataLoader(test_ds, batch_size=64, shuffle=False, collate_fn=collate_triplet, num_workers=num_workers) md = test_ds.modality_dims _test_cache[key] = (test_loader, md) return test_loader, md def eval_one(seed_dir: Path, device: torch.device): res_p = seed_dir / "results.json" with open(res_p) as f: results = json.load(f) args = results["args"] model_name = args["model"] modalities = args["modalities"].split(",") t_obs = args["t_obs"] t_fut = args["t_fut"] downsample = args.get("downsample", 5) test_loader, modality_dims = get_test_loader(modalities, t_obs, t_fut, downsample) model = build_model(model_name, modality_dims).to(device) state = torch.load(seed_dir / "model_best.pt", map_location=device, weights_only=False) model.load_state_dict(state["state_dict"]) model.eval() all_logits = {k: [] for k in ("verb_fine", "verb_composite", "noun", "hand")} all_y = {k: [] for k in ("verb_fine", "verb_composite", "noun", "hand")} with torch.no_grad(): for x, mask, lens, y, meta in test_loader: x = {m: t.to(device) for m, t in x.items()} mask = mask.to(device) logits = model(x, mask) for k in all_logits: all_logits[k].append(logits[k].cpu()) all_y[k].append(y[k]) logits_cat = {k: torch.cat(v, dim=0) for k, v in all_logits.items()} y_cat = {k: torch.cat(v, dim=0).numpy() for k, v in all_y.items()} pred_cat = {k: logits_cat[k].argmax(dim=1).numpy() for k in logits_cat} out = {} for k in ("verb_fine", "verb_composite", "noun", "hand"): out[f"{k}_acc"] = float(accuracy_score(y_cat[k], pred_cat[k])) out[f"{k}_macro_f1"] = float(f1_score(y_cat[k], pred_cat[k], average="macro", zero_division=0)) out[f"{k}_weighted_f1"] = float(f1_score(y_cat[k], pred_cat[k], average="weighted", zero_division=0)) # Joint action = verb_fine AND noun AND hand correct correct = ((pred_cat["verb_fine"] == y_cat["verb_fine"]) & (pred_cat["noun"] == y_cat["noun"]) & (pred_cat["hand"] == y_cat["hand"])) out["action_acc"] = float(correct.mean()) # n_params (cheap) out["n_params"] = sum(p.numel() for p in model.parameters()) out_p = seed_dir / "eval_macrof1.json" with open(out_p, "w") as f: json.dump(out, f, indent=2) return out def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"device={device}", flush=True) seed_dirs = find_seed_dirs() print(f"Found {len(seed_dirs)} seed dirs", flush=True) t0 = time.time() n_ok = 0 n_fail = 0 for i, sd in enumerate(seed_dirs, 1): try: res = eval_one(sd, device) n_ok += 1 if i % 10 == 0 or i <= 3: rel = sd.relative_to(REPO) print(f" [{i:>3}/{len(seed_dirs)}] {rel} " f"action_acc={res['action_acc']:.4f} " f"verb_fine_macroF1={res['verb_fine_macro_f1']:.4f} " f"noun_macroF1={res['noun_macro_f1']:.4f}", flush=True) except Exception as e: n_fail += 1 print(f" [{i:>3}/{len(seed_dirs)}] FAIL {sd.relative_to(REPO)}: {e}", flush=True) dur = time.time() - t0 print(f"Done. ok={n_ok} fail={n_fail} elapsed={dur:.1f}s", flush=True) if __name__ == "__main__": main()