#!/usr/bin/env python3
"""Per-subset evaluator.

Given a (modalities, t_obs, t_fut) triple, evaluate ALL trained seed dirs
across all 27 rows whose results.json matches that triple. Builds the test
dataset exactly once for the given triple, then iterates over matching seeds,
loads each model_best.pt, runs inference, and writes
<seed_dir>/eval_macrof1.json.

Used by dispatch_eval.sh to run 16 of these in parallel on the cluster.
"""
from __future__ import annotations

import argparse
import json
import os
import sys
import time
from pathlib import Path

import pandas as pd  # noqa: F401 (must come before torch on this cluster)
import torch
from sklearn.metrics import accuracy_score, f1_score
from torch.utils.data import DataLoader

# PULSE_ROOT is assumed to be exported by dispatch_eval.sh; expandvars keeps
# the "${PULSE_ROOT}" placeholder working as an environment-variable lookup.
REPO = Path(os.path.expandvars("${PULSE_ROOT}"))
sys.path.insert(0, str(REPO / "experiments"))

from dataset_seqpred import (  # noqa: E402
    build_train_test,
    collate_triplet,
)
from models_seqpred import build_model  # noqa: E402


def find_matching_seeds(mods_canon: str, t_obs: float, t_fut: float):
    """Collect seed dirs whose (canonical modalities, t_obs, t_fut) match.

    Only seed42's results.json is inspected per row, since all seeds in a
    row were trained with the same args.
    """
    out = []
    for tt in [
        "table1_main_comparison",
        "table3_horizon_curve",
        "table4_modality_ablation",
        "table5_component_ablation",
        "table7_missing_modality",
    ]:
        td = REPO / tt
        for row_dir in sorted(td.glob("row*")):
            seed42 = row_dir / "seeds" / "seed42" / "results.json"
            if not seed42.exists():
                continue
            with open(seed42) as f:
                d = json.load(f)
            a = d["args"]
            row_mods_canon = ",".join(sorted(a["modalities"].split(",")))
            if (row_mods_canon == mods_canon
                    and abs(float(a["t_obs"]) - t_obs) < 1e-6
                    and abs(float(a["t_fut"]) - t_fut) < 1e-6):
                for sd in sorted((row_dir / "seeds").glob("seed*")):
                    if (sd / "model_best.pt").exists() and (sd / "results.json").exists():
                        out.append(sd)
    return out


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--modalities", required=True,
                    help="Sorted comma-separated list, e.g. "
                         "'emg,eyetrack,imu,mocap,pressure'")
    ap.add_argument("--t_obs", type=float, required=True)
    ap.add_argument("--t_fut", type=float, required=True)
    args = ap.parse_args()

    seed_dirs = find_matching_seeds(args.modalities, args.t_obs, args.t_fut)
    print(f"Subset key=({args.modalities!r}, t_obs={args.t_obs}, t_fut={args.t_fut})", flush=True)
    print(f"Matched {len(seed_dirs)} seed dirs", flush=True)
    for sd in seed_dirs:
        print(f"  {sd.relative_to(REPO)}", flush=True)
    if not seed_dirs:
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device={device}", flush=True)

    # Each seed dir's args.modalities preserves the original (possibly
    # unsorted) order, which determines the model's branch ordering. All
    # seeds within a row share the same order, but rows with the same
    # canonical set and different original orders land in the same
    # dispatcher job (the canonical key matches), so order divergence must
    # be handled here. Group seed_dirs by the original (un-sorted) modality
    # list each used, because different orders → different branch indices
    # in the model, and build one test loader per group.
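    # Shape sketch of the grouping built below (paths and orderings are
    # hypothetical): a dict keyed by the original-order modality string,
    # mapping to (seed_dir, parsed results.json) pairs, e.g.
    #   {"emg,imu": [(Path(".../row01/seeds/seed42"), {...}), ...],
    #    "imu,emg": [(Path(".../row09/seeds/seed42"), {...}), ...]}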
    orders = {}
    for sd in seed_dirs:
        with open(sd / "results.json") as f:
            d = json.load(f)
        orig_mods = d["args"]["modalities"]  # original order
        orders.setdefault(orig_mods, []).append((sd, d))
    print(f"Distinct original modality orderings under this canonical key: {len(orders)}", flush=True)

    n_ok, n_fail = 0, 0
    t0 = time.time()
    for orig_mods, group in orders.items():
        mods_list = orig_mods.split(",")
        print(f"\n=== Building test loader for original order: {mods_list} ===", flush=True)
        tb0 = time.time()
        train_ds, test_ds = build_train_test(
            modalities=mods_list,
            t_obs_sec=args.t_obs,
            t_fut_sec=args.t_fut,
        )
        del train_ds  # only need test stats, which test_ds carries
        test_loader = DataLoader(test_ds, batch_size=64, shuffle=False,
                                 collate_fn=collate_triplet, num_workers=0)
        modality_dims = test_ds.modality_dims
        print(f"  build took {time.time()-tb0:.1f}s; test n={len(test_ds)}", flush=True)

        heads = ("verb_fine", "verb_composite", "noun", "hand")
        for sd, results in group:
            args_d = results["args"]
            try:
                model = build_model(args_d["model"], modality_dims).to(device)
                state = torch.load(sd / "model_best.pt", map_location=device,
                                   weights_only=False)
                model.load_state_dict(state["state_dict"])
                model.eval()

                all_logits = {k: [] for k in heads}
                all_y = {k: [] for k in heads}
                with torch.no_grad():
                    for x, mask, lens, y, meta in test_loader:
                        x = {m: t.to(device) for m, t in x.items()}
                        mask = mask.to(device)
                        logits = model(x, mask)
                        for k in heads:
                            all_logits[k].append(logits[k].cpu())
                            all_y[k].append(y[k])

                logits_cat = {k: torch.cat(v, dim=0) for k, v in all_logits.items()}
                y_cat = {k: torch.cat(v, dim=0).numpy() for k, v in all_y.items()}
                pred_cat = {k: logits_cat[k].argmax(dim=1).numpy() for k in logits_cat}

                out = {}
                for k in heads:
                    out[f"{k}_acc"] = float(accuracy_score(y_cat[k], pred_cat[k]))
                    out[f"{k}_macro_f1"] = float(f1_score(
                        y_cat[k], pred_cat[k], average="macro", zero_division=0))
                    out[f"{k}_weighted_f1"] = float(f1_score(
                        y_cat[k], pred_cat[k], average="weighted", zero_division=0))
                # Exact-match action accuracy: verb_fine, noun, and hand must
                # all be correct on the same sample.
                correct = ((pred_cat["verb_fine"] == y_cat["verb_fine"])
                           & (pred_cat["noun"] == y_cat["noun"])
                           & (pred_cat["hand"] == y_cat["hand"]))
                out["action_acc"] = float(correct.mean())
                out["n_params"] = sum(p.numel() for p in model.parameters())

                with open(sd / "eval_macrof1.json", "w") as f:
                    json.dump(out, f, indent=2)
                print(f"  OK   {sd.relative_to(REPO)} action_acc={out['action_acc']:.4f}", flush=True)
                n_ok += 1

                # Free the model between seeds so long subsets don't OOM.
                del model
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception as e:
                print(f"  FAIL {sd.relative_to(REPO)}: {e}", flush=True)
                n_fail += 1

    print(f"\nSubset done. ok={n_ok} fail={n_fail} elapsed={time.time()-t0:.1f}s", flush=True)


if __name__ == "__main__":
    main()
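# Example invocation (values are illustrative, taken from the --modalities
# help string; dispatch_eval.sh supplies the real subset keys, and the
# script filename here is a guess):
#
#   PULSE_ROOT=/path/to/pulse python eval_subset.py \
#       --modalities emg,eyetrack,imu,mocap,pressure \
#       --t_obs 2.0 --t_fut 1.0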