#!/usr/bin/env python3
"""Per-subset evaluator.
Given a (modalities, t_obs, t_fut) triple, evaluate ALL trained seed dirs
across all 27 rows whose results.json matches that triple. Builds the test
dataset exactly once for the given triple, then iterates over matching
seeds, loads each model_best.pt, runs inference, and writes
<seed_dir>/eval_macrof1.json.
Used by dispatch_eval.sh to run 16 of these in parallel on the cluster.
"""
from __future__ import annotations
import argparse
import json
import sys
import time
from pathlib import Path
import pandas as pd # noqa: F401 (must come before torch on this cluster)
import numpy as np
import torch
from sklearn.metrics import f1_score, accuracy_score
from torch.utils.data import DataLoader
REPO = Path("${PULSE_ROOT}")
sys.path.insert(0, str(REPO / "experiments"))
from dataset_seqpred import ( # noqa: E402
build_train_test, collate_triplet,
)
from models_seqpred import build_model # noqa: E402
def find_matching_seeds(mods_canon: str, t_obs: float, t_fut: float,
                        tol: float = 1e-6) -> list:
    """Return every trained seed dir whose run matches the given subset triple.

    Scans the five table directories under REPO for
    ``row*/seeds/seed42/results.json``, canonicalizes each row's modality list
    (comma-split, sorted, re-joined), and — when the canonical modality set and
    both time horizons match — collects every sibling ``seed*`` dir that has
    both a ``model_best.pt`` checkpoint and a ``results.json``.

    Args:
        mods_canon: Sorted comma-separated modality list, e.g. ``"emg,imu"``.
        t_obs: Observation window in seconds.
        t_fut: Prediction horizon in seconds.
        tol: Absolute tolerance for comparing the horizons (they round-trip
            through JSON as floats, so exact equality is not assumed).

    Returns:
        List of seed-dir ``Path`` objects, sorted within each row.
    """
    table_dirs = [
        "table1_main_comparison",
        "table3_horizon_curve",
        "table4_modality_ablation",
        "table5_component_ablation",
        "table7_missing_modality",
    ]
    out = []
    for tt in table_dirs:
        for row_dir in sorted((REPO / tt).glob("row*")):
            # seed42 is the reference seed: its results.json holds the row's args.
            seed42 = row_dir / "seeds" / "seed42" / "results.json"
            if not seed42.exists():
                continue  # row not (yet) trained
            a = json.loads(seed42.read_text())["args"]
            row_mods_canon = ",".join(sorted(a["modalities"].split(",")))
            if (row_mods_canon == mods_canon
                    and abs(float(a["t_obs"]) - t_obs) < tol
                    and abs(float(a["t_fut"]) - t_fut) < tol):
                for sd in sorted((row_dir / "seeds").glob("seed*")):
                    # Keep only seeds that finished training: checkpoint + results.
                    if (sd / "model_best.pt").exists() and (sd / "results.json").exists():
                        out.append(sd)
    return out
# The four prediction heads every trained model emits.
_TASKS = ("verb_fine", "verb_composite", "noun", "hand")


def _group_seed_dirs_by_order(seed_dirs):
    """Group seed dirs by the ORIGINAL (un-sorted) modality string each run used.

    Different original orderings imply different branch indices inside the
    saved models, so each group needs its own test loader built with that
    exact order.

    Returns:
        dict mapping original modality string -> list of
        (seed_dir, parsed results.json dict) pairs.
    """
    orders = {}
    for sd in seed_dirs:
        with open(sd / "results.json") as f:
            results = json.load(f)
        orig_mods = results["args"]["modalities"]  # original, possibly unsorted
        orders.setdefault(orig_mods, []).append((sd, results))
    return orders


def _eval_one_seed(sd, results, test_loader, modality_dims, device):
    """Evaluate a single trained seed dir on the shared test loader.

    Loads ``<sd>/model_best.pt``, runs full-test-set inference, computes
    per-head accuracy / macro-F1 / weighted-F1 plus joint action accuracy,
    and writes the metrics to ``<sd>/eval_macrof1.json``.

    Returns the metrics dict. Raises on any load/inference failure; the
    caller decides how to report it.
    """
    model = build_model(results["args"]["model"], modality_dims).to(device)
    # weights_only=False: our own checkpoints may carry non-tensor objects
    # (args, optimizer state). Do not load untrusted checkpoints this way.
    state = torch.load(sd / "model_best.pt", map_location=device,
                       weights_only=False)
    model.load_state_dict(state["state_dict"])
    model.eval()
    all_logits = {k: [] for k in _TASKS}
    all_y = {k: [] for k in _TASKS}
    with torch.no_grad():
        for x, mask, lens, y, meta in test_loader:
            x = {m: t.to(device) for m, t in x.items()}
            mask = mask.to(device)
            logits = model(x, mask)
            for k in _TASKS:
                all_logits[k].append(logits[k].cpu())
                all_y[k].append(y[k])
    y_cat = {k: torch.cat(v, dim=0).numpy() for k, v in all_y.items()}
    pred_cat = {k: torch.cat(v, dim=0).argmax(dim=1).numpy()
                for k, v in all_logits.items()}
    out = {}
    for k in _TASKS:
        out[f"{k}_acc"] = float(accuracy_score(y_cat[k], pred_cat[k]))
        out[f"{k}_macro_f1"] = float(f1_score(y_cat[k], pred_cat[k],
                                              average="macro", zero_division=0))
        out[f"{k}_weighted_f1"] = float(f1_score(y_cat[k], pred_cat[k],
                                                 average="weighted", zero_division=0))
    # Joint "action" accuracy: verb_fine, noun AND hand all correct at once.
    correct = ((pred_cat["verb_fine"] == y_cat["verb_fine"]) &
               (pred_cat["noun"] == y_cat["noun"]) &
               (pred_cat["hand"] == y_cat["hand"]))
    out["action_acc"] = float(correct.mean())
    out["n_params"] = sum(p.numel() for p in model.parameters())
    with open(sd / "eval_macrof1.json", "w") as f:
        json.dump(out, f, indent=2)
    # Free GPU memory before the next seed's model is built.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return out


def main():
    """Entry point: evaluate every trained seed matching one subset triple.

    Builds the test dataset once per distinct original modality ordering,
    then evaluates all matching seeds against it, writing per-seed
    eval_macrof1.json files. Failures are reported but do not abort the run.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--modalities", required=True,
                    help="Sorted comma-separated list, e.g. 'emg,eyetrack,imu,mocap,pressure'")
    ap.add_argument("--t_obs", type=float, required=True)
    ap.add_argument("--t_fut", type=float, required=True)
    args = ap.parse_args()

    seed_dirs = find_matching_seeds(args.modalities, args.t_obs, args.t_fut)
    print(f"Subset key=({args.modalities!r}, t_obs={args.t_obs}, t_fut={args.t_fut})", flush=True)
    print(f"Matched {len(seed_dirs)} seed dirs", flush=True)
    for sd in seed_dirs:
        print(f"  {sd.relative_to(REPO)}", flush=True)
    if not seed_dirs:
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device={device}", flush=True)

    # Seeds that listed modalities in a different original order need separate
    # loaders: the order determines the model's branch indices, even when the
    # canonical (sorted) set — and hence this job's key — is the same.
    orders = _group_seed_dirs_by_order(seed_dirs)
    print(f"Distinct original modality orderings under this canonical key: {len(orders)}",
          flush=True)

    n_ok, n_fail = 0, 0
    t0 = time.time()
    for orig_mods, group in orders.items():
        mods_list = orig_mods.split(",")
        print(f"\n=== Building test loader for original order: {mods_list} ===",
              flush=True)
        tb0 = time.time()
        train_ds, test_ds = build_train_test(
            modalities=mods_list,
            t_obs_sec=args.t_obs, t_fut_sec=args.t_fut,
        )
        del train_ds  # only need test stats, which test_ds carries
        test_loader = DataLoader(test_ds, batch_size=64, shuffle=False,
                                 collate_fn=collate_triplet, num_workers=0)
        modality_dims = test_ds.modality_dims
        print(f"  build took {time.time()-tb0:.1f}s; test n={len(test_ds)}",
              flush=True)
        for sd, results in group:
            try:
                out = _eval_one_seed(sd, results, test_loader, modality_dims, device)
                print(f"  OK {sd.relative_to(REPO)} action_acc={out['action_acc']:.4f}",
                      flush=True)
                n_ok += 1
            except Exception as e:
                # Best-effort batch evaluation: report and continue. Include the
                # exception type — str(e) alone can be empty (e.g. bare KeyError).
                print(f"  FAIL {sd.relative_to(REPO)}: {type(e).__name__}: {e}", flush=True)
                n_fail += 1
    print(f"\nSubset done. ok={n_ok} fail={n_fail} elapsed={time.time()-t0:.1f}s",
          flush=True)


if __name__ == "__main__":
    main()