#!/usr/bin/env python3 """Train + evaluate frame-level future verb_fine forecasting. Outputs per-horizon top-1 frame accuracy on the test set, saved to results.json under . """ from __future__ import annotations import argparse import json import os import random import sys import time from pathlib import Path import numpy as np import torch import torch.nn as nn from torch.utils.data import DataLoader THIS = Path(__file__).resolve() sys.path.insert(0, str(THIS.parent)) sys.path.insert(0, str(THIS.parents[1])) try: from experiments.dataset_forecast import ( ForecastDataset, collate_forecast, build_train_test, IDLE_LABEL, NUM_FORECAST_CLASSES, ) from experiments.models_forecast import build_forecast_model except ModuleNotFoundError: from dataset_forecast import ( ForecastDataset, collate_forecast, build_train_test, IDLE_LABEL, NUM_FORECAST_CLASSES, ) from models_forecast import build_forecast_model def set_seed(seed: int): random.seed(seed); np.random.seed(seed) torch.manual_seed(seed); torch.cuda.manual_seed_all(seed) def train_epoch(model, loader, optimizer, criterion, device): model.train() total, n_frames, correct = 0.0, 0, 0 for x, y, _ in loader: x = {m: v.to(device) for m, v in x.items()} y = y.to(device) # (B, T_fut) optimizer.zero_grad() logits = model(x) # (B, T_fut, C) loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1)) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() total += loss.item() * y.numel() n_frames += y.numel() correct += (logits.argmax(-1) == y).sum().item() return total / max(n_frames, 1), correct / max(n_frames, 1) @torch.no_grad() def evaluate(model, loader, device, t_fut: int): model.eval() # Per-horizon counts (overall, ignore-idle) per_h_correct = np.zeros(t_fut, dtype=np.int64) per_h_total = np.zeros(t_fut, dtype=np.int64) per_h_correct_action = np.zeros(t_fut, dtype=np.int64) per_h_total_action = np.zeros(t_fut, dtype=np.int64) for x, y, _ in loader: x = {m: v.to(device) for m, v in x.items()} y = y.to(device) # (B, T_fut) logits = model(x) # (B, T_fut, C) pred = logits.argmax(-1) # (B, T_fut) for h in range(t_fut): yh = y[:, h]; ph = pred[:, h] per_h_correct[h] += (ph == yh).sum().item() per_h_total[h] += yh.numel() mask = (yh != IDLE_LABEL) per_h_correct_action[h] += ((ph == yh) & mask).sum().item() per_h_total_action[h] += mask.sum().item() return { "per_h_acc": (per_h_correct / np.maximum(per_h_total, 1)).tolist(), "per_h_acc_action": (per_h_correct_action / np.maximum(per_h_total_action, 1)).tolist(), "frame_acc": float(per_h_correct.sum() / max(per_h_total.sum(), 1)), "frame_acc_action": float(per_h_correct_action.sum() / max(per_h_total_action.sum(), 1)), } def main(): ap = argparse.ArgumentParser() ap.add_argument("--model", type=str, required=True, choices=["daf", "futr", "deepconvlstm", "rulstm", "avt"]) ap.add_argument("--modalities", type=str, default="imu,emg,eyetrack,mocap,pressure", help="Comma-separated modality list") ap.add_argument("--t_obs", type=float, default=1.5) ap.add_argument("--t_fut", type=float, default=0.5) ap.add_argument("--anchor_stride", type=float, default=0.25) ap.add_argument("--contact_only", action="store_true", help="Only keep anchors whose past+future window has any " "frame with pressure-sum > threshold (Plan B).") ap.add_argument("--contact_threshold_g", type=float, default=5.0) ap.add_argument("--epochs", type=int, default=15) ap.add_argument("--batch_size", type=int, default=64) ap.add_argument("--lr", type=float, default=3e-4) ap.add_argument("--weight_decay", type=float, default=1e-4) ap.add_argument("--d_model", type=int, default=128) ap.add_argument("--dropout", type=float, default=0.1) ap.add_argument("--label_smoothing", type=float, default=0.05) ap.add_argument("--num_workers", type=int, default=2) ap.add_argument("--seed", type=int, default=42) ap.add_argument("--patience", type=int, default=5) ap.add_argument("--output_dir", type=str, required=True) args = ap.parse_args() set_seed(args.seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"device={device} | seed={args.seed} | model={args.model} " f"modalities={args.modalities}") mods = args.modalities.split(",") train_ds, test_ds = build_train_test( modalities=mods, t_obs_sec=args.t_obs, t_fut_sec=args.t_fut, anchor_stride_sec=args.anchor_stride, contact_only=args.contact_only, contact_threshold_g=args.contact_threshold_g, ) print(f"train={len(train_ds)} test={len(test_ds)} " f"T_obs={train_ds.T_obs} T_fut={train_ds.T_fut} " f"mod_dims={train_ds.modality_dims}") tr_loader = DataLoader( train_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_forecast, drop_last=False, ) te_loader = DataLoader( test_ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=collate_forecast, ) model = build_forecast_model( args.model, train_ds.modality_dims, num_classes=NUM_FORECAST_CLASSES, t_obs=train_ds.T_obs, t_fut=train_ds.T_fut, d_model=args.d_model, dropout=args.dropout, ).to(device) n_params = sum(p.numel() for p in model.parameters()) print(f"params={n_params:,}") optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) sched = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=args.epochs, eta_min=args.lr * 0.05 ) criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) out_dir = Path(args.output_dir); out_dir.mkdir(parents=True, exist_ok=True) best = {"frame_acc_action": -1.0, "epoch": 0, "state_dict": None} for ep in range(1, args.epochs + 1): t0 = time.time() tr_loss, tr_acc = train_epoch(model, tr_loader, optimizer, criterion, device) ev = evaluate(model, te_loader, device, t_fut=train_ds.T_fut) sched.step() print(f" E{ep:2d} | tr {tr_loss:.4f}/{tr_acc:.3f} " f"| te frame_acc {ev['frame_acc']:.3f} action {ev['frame_acc_action']:.3f} " f"| {time.time()-t0:.1f}s") if ev["frame_acc_action"] > best["frame_acc_action"]: best = {**ev, "epoch": ep, "state_dict": {k: v.cpu() for k, v in model.state_dict().items()}} torch.save(best["state_dict"], out_dir / "model_best.pt") # Final reporting from best epoch final = {k: v for k, v in best.items() if k != "state_dict"} out = { "method": args.model, "modalities": mods, "seed": args.seed, "n_params": n_params, "T_obs": train_ds.T_obs, "T_fut": train_ds.T_fut, "best_epoch": int(best["epoch"]), "frame_acc": float(best["frame_acc"]), "frame_acc_action": float(best["frame_acc_action"]), "per_h_acc": list(map(float, best["per_h_acc"])), "per_h_acc_action": list(map(float, best["per_h_acc_action"])), "args": vars(args), } with open(out_dir / "results.json", "w") as f: json.dump(out, f, indent=2) print(f"\n[done] best frame_acc_action {best['frame_acc_action']:.4f} (epoch {best['epoch']})") print(f"per_h_acc_action: {[f'{a:.3f}' for a in best['per_h_acc_action']]}") print(f"saved to {out_dir}/results.json") if __name__ == "__main__": main()