# Source: PULSE-code / scripts / eval_subset.py
# Provenance: uploaded via huggingface_hub by velvet-pine-22, commit b4b2877 (verified)
#!/usr/bin/env python3
"""Per-subset evaluator.
Given a (modalities, t_obs, t_fut) triple, evaluate ALL trained seed dirs
across all 27 rows whose results.json matches that triple. Builds the test
dataset exactly once for the given triple, then iterates over matching
seeds, loads each model_best.pt, runs inference, and writes
<seed_dir>/eval_macrof1.json.
Used by dispatch_eval.sh to run 16 of these in parallel on the cluster.
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from pathlib import Path
import pandas as pd  # noqa: F401 (must come before torch on this cluster)
import numpy as np
import torch
from sklearn.metrics import f1_score, accuracy_score
from torch.utils.data import DataLoader

# Repo root. The original literal Path("${PULSE_ROOT}") is a shell-style
# template that Python never expands, so it pointed at a directory literally
# named "${PULSE_ROOT}". expandvars substitutes the environment variable at
# runtime and, if PULSE_ROOT is unset, leaves the text unchanged — i.e. it
# degrades to exactly the old behavior.
REPO = Path(os.path.expandvars("${PULSE_ROOT}"))
sys.path.insert(0, str(REPO / "experiments"))

from dataset_seqpred import (  # noqa: E402
    build_train_test, collate_triplet,
)
from models_seqpred import build_model  # noqa: E402
def find_matching_seeds(mods_canon: str, t_obs: float, t_fut: float):
    """Collect seed dirs whose row matches the given subset triple.

    A row matches when its seed42 results.json records the same canonical
    (sorted) modality set and the same t_obs/t_fut (within 1e-6). For every
    matching row, all seed dirs that contain both model_best.pt and
    results.json are returned, in sorted order.
    """
    tables = (
        "table1_main_comparison",
        "table3_horizon_curve",
        "table4_modality_ablation",
        "table5_component_ablation",
        "table7_missing_modality",
    )
    matches = []
    for table_name in tables:
        for row_dir in sorted((REPO / table_name).glob("row*")):
            # seed42's results.json serves as the row's reference config.
            ref_json = row_dir / "seeds" / "seed42" / "results.json"
            if not ref_json.exists():
                continue
            with open(ref_json) as fh:
                row_args = json.load(fh)["args"]
            canon = ",".join(sorted(row_args["modalities"].split(",")))
            if canon != mods_canon:
                continue
            if abs(float(row_args["t_obs"]) - t_obs) >= 1e-6:
                continue
            if abs(float(row_args["t_fut"]) - t_fut) >= 1e-6:
                continue
            matches.extend(
                sd for sd in sorted((row_dir / "seeds").glob("seed*"))
                if (sd / "model_best.pt").exists()
                and (sd / "results.json").exists()
            )
    return matches
_TASKS = ("verb_fine", "verb_composite", "noun", "hand")


def _group_by_order(seed_dirs):
    """Group seed dirs by the original (un-sorted) modality string each used.

    Different orderings imply different branch indices in the model, so each
    distinct ordering needs its own test loader. Returns a dict mapping the
    original comma-joined modality string to a list of (seed_dir, results)
    pairs, where results is that seed's parsed results.json.
    """
    orders = {}
    for sd in seed_dirs:
        with open(sd / "results.json") as f:
            d = json.load(f)
        orders.setdefault(d["args"]["modalities"], []).append((sd, d))
    return orders


def _evaluate_seed(sd, results, test_loader, modality_dims, device):
    """Load one seed's best checkpoint, run inference, and return its metrics.

    Returns a dict with per-task accuracy / macro-F1 / weighted-F1, the joint
    action accuracy (verb_fine AND noun AND hand all correct), and n_params.
    Any load/inference failure propagates; the caller counts failures.
    """
    model = build_model(results["args"]["model"], modality_dims).to(device)
    # NOTE(security): weights_only=False unpickles arbitrary objects from the
    # checkpoint. Acceptable here because model_best.pt files come from our
    # own training runs — never point this at untrusted files.
    state = torch.load(sd / "model_best.pt", map_location=device,
                       weights_only=False)
    model.load_state_dict(state["state_dict"])
    model.eval()
    all_logits = {k: [] for k in _TASKS}
    all_y = {k: [] for k in _TASKS}
    with torch.no_grad():
        for x, mask, _lens, y, _meta in test_loader:
            x = {m: t.to(device) for m, t in x.items()}
            mask = mask.to(device)
            logits = model(x, mask)
            for k in all_logits:
                all_logits[k].append(logits[k].cpu())
                all_y[k].append(y[k])  # labels stay on CPU
    y_cat = {k: torch.cat(v, dim=0).numpy() for k, v in all_y.items()}
    pred_cat = {k: torch.cat(v, dim=0).argmax(dim=1).numpy()
                for k, v in all_logits.items()}
    out = {}
    for k in _TASKS:
        out[f"{k}_acc"] = float(accuracy_score(y_cat[k], pred_cat[k]))
        out[f"{k}_macro_f1"] = float(f1_score(y_cat[k], pred_cat[k],
                                              average="macro", zero_division=0))
        out[f"{k}_weighted_f1"] = float(f1_score(y_cat[k], pred_cat[k],
                                                 average="weighted",
                                                 zero_division=0))
    # Joint action accuracy: all three of verb_fine / noun / hand correct.
    correct = ((pred_cat["verb_fine"] == y_cat["verb_fine"]) &
               (pred_cat["noun"] == y_cat["noun"]) &
               (pred_cat["hand"] == y_cat["hand"]))
    out["action_acc"] = float(correct.mean())
    out["n_params"] = sum(p.numel() for p in model.parameters())
    return out


def main():
    """CLI entry point: evaluate every trained seed matching one subset triple.

    Builds the test dataset once per distinct original modality ordering,
    then for each matching seed loads model_best.pt, runs inference, and
    writes <seed_dir>/eval_macrof1.json. Per-seed failures are logged and
    counted but do not abort the run.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--modalities", required=True,
                    help="Sorted comma-separated list, e.g. 'emg,eyetrack,imu,mocap,pressure'")
    ap.add_argument("--t_obs", type=float, required=True)
    ap.add_argument("--t_fut", type=float, required=True)
    args = ap.parse_args()
    seed_dirs = find_matching_seeds(args.modalities, args.t_obs, args.t_fut)
    print(f"Subset key=({args.modalities!r}, t_obs={args.t_obs}, t_fut={args.t_fut})", flush=True)
    print(f"Matched {len(seed_dirs)} seed dirs", flush=True)
    for sd in seed_dirs:
        print(f" {sd.relative_to(REPO)}", flush=True)
    if not seed_dirs:
        return
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device={device}", flush=True)
    # Each seed's args.modalities preserves the original (possibly unsorted)
    # order, which fixes the model's branch ordering — so we need one test
    # loader per distinct ordering, not just per canonical set.
    orders = _group_by_order(seed_dirs)
    print(f"Distinct original modality orderings under this canonical key: {len(orders)}",
          flush=True)
    n_ok, n_fail = 0, 0
    t0 = time.time()
    for orig_mods, group in orders.items():
        mods_list = orig_mods.split(",")
        print(f"\n=== Building test loader for original order: {mods_list} ===",
              flush=True)
        tb0 = time.time()
        train_ds, test_ds = build_train_test(
            modalities=mods_list,
            t_obs_sec=args.t_obs, t_fut_sec=args.t_fut,
        )
        del train_ds  # only need test stats which test_ds carries
        test_loader = DataLoader(test_ds, batch_size=64, shuffle=False,
                                 collate_fn=collate_triplet, num_workers=0)
        modality_dims = test_ds.modality_dims
        print(f" build took {time.time()-tb0:.1f}s; test n={len(test_ds)}",
              flush=True)
        for sd, results in group:
            try:
                out = _evaluate_seed(sd, results, test_loader,
                                     modality_dims, device)
                with open(sd / "eval_macrof1.json", "w") as f:
                    json.dump(out, f, indent=2)
                print(f" OK {sd.relative_to(REPO)} action_acc={out['action_acc']:.4f}",
                      flush=True)
                n_ok += 1
                # Model went out of scope in the helper; reclaim GPU memory.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception as e:
                # Broad catch is deliberate: one bad checkpoint must not
                # abort the remaining seeds in this cluster job.
                print(f" FAIL {sd.relative_to(REPO)}: {e}", flush=True)
                n_fail += 1
    print(f"\nSubset done. ok={n_ok} fail={n_fail} elapsed={time.time()-t0:.1f}s",
          flush=True)
if __name__ == "__main__":
main()