#!/usr/bin/env python3 """ Training loop for T10 Triplet Next-Action Prediction. Usage example: python3 experiments/train_seqpred.py \ --model dailyactformer \ --modalities imu,emg,eyetrack,mocap,pressure \ --t_obs 8 --t_fut 2 \ --epochs 40 --batch_size 32 --lr 3e-4 \ --output_dir results/seqpred/ours_all5_tfut2_seed42 \ --seed 42 """ from __future__ import annotations # pandas must be imported BEFORE torch/numpy to avoid a GLIBCXX load-order bug # on this cluster (libstdc++ from Anaconda vs system). import pandas # noqa: F401 import argparse import json import os import random import sys import time from pathlib import Path from typing import Dict import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader # Make sibling modules importable from either (a) the neurips26 root (running # as `python experiments/train_seqpred.py`) or (b) the frozen row/code/ folder # (running via the per-row run.sh after setup_row.sh snapshots the code). 
THIS = Path(__file__).resolve()
# Put both candidate roots at the front of sys.path in a single splice:
# neurips26/ (THIS.parents[1]) ends up first, row/code/ (THIS.parent) second.
sys.path[:0] = [str(THIS.parents[1]), str(THIS.parent)]

try:
    # Package layout: running from the neurips26 root.
    from experiments.dataset_seqpred import (
        TripletSeqPredDataset,
        build_train_test,
        collate_triplet,
        TRAIN_VOLS_V3,
        TEST_VOLS_V3,
    )
    from experiments.models_seqpred import build_model
    from experiments.taxonomy import (
        NUM_VERB_FINE,
        NUM_VERB_COMPOSITE,
        NUM_NOUN,
        NUM_HAND,
    )
except ModuleNotFoundError:
    # Flat layout: the frozen row/code/ snapshot, where the modules are siblings.
    from dataset_seqpred import (
        TripletSeqPredDataset,
        build_train_test,
        collate_triplet,
        TRAIN_VOLS_V3,
        TEST_VOLS_V3,
    )
    from models_seqpred import build_model
    from taxonomy import (
        NUM_VERB_FINE,
        NUM_VERB_COMPOSITE,
        NUM_NOUN,
        NUM_HAND,
    )


# ---------------------------------------------------------------------------
# Utilities
# ---------------------------------------------------------------------------

def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and torch (CPU + all CUDA devices) RNGs."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


def top_k_correct(logits: torch.Tensor, target: torch.Tensor, k: int) -> torch.Tensor:
    """Return a bool tensor (B,): True where `target` is among the top-k logits.

    `k` is clamped to the number of columns of `logits`.
    """
    effective_k = min(k, logits.size(1))
    top_idx = logits.topk(effective_k, dim=1).indices
    hits = top_idx == target.unsqueeze(1)
    return hits.any(dim=1)


def mean_class_recall(logits: torch.Tensor, target: torch.Tensor, num_classes: int) -> float:
    """Average the argmax recall over the classes that occur in `target`.

    Classes with no samples in `target` are skipped; returns 0.0 when no class
    in ``range(num_classes)`` occurs at all.
    """
    predicted = logits.argmax(dim=1)
    recalls = []
    for cls in range(num_classes):
        members = target == cls
        if int(members.sum().item()) == 0:
            continue
        cls_recall = (predicted[members] == cls).float().mean().item()
        recalls.append(float(cls_recall))
    if not recalls:
        return 0.0
    return float(np.mean(recalls))


def build_class_weights(counts: np.ndarray) -> torch.Tensor:
    """Inverse-frequency weights, normalized so mean weight = 1.

    Counts are floored at 1 before inversion so empty classes do not divide
    by zero.
    """
    safe_counts = counts.astype(np.float32).clip(min=1.0)
    inverse = 1.0 / safe_counts
    normalized = inverse / inverse.mean()
    return torch.from_numpy(normalized)


# ---------------------------------------------------------------------------
# Core loss
# ---------------------------------------------------------------------------

# Names of the four classification heads, in the order they are reported.
_HEADS = ("verb_fine", "verb_composite", "noun", "hand")


def triplet_loss(
    logits: Dict[str, torch.Tensor],
    y: Dict[str, torch.Tensor],
    weights: Dict[str, torch.Tensor],
    lambda_cfg: Dict[str, float],
    label_smoothing: float = 0.05,
) -> Dict[str, torch.Tensor]:
    """Per-head cross-entropy plus their lambda-weighted sum.

    Args:
        logits: head name -> logits tensor; must contain every head in `_HEADS`.
        y: head name -> target class indices for the same heads.
        weights: head name -> per-class CE weight tensor. Heads missing from
            the mapping are unweighted; ``None`` is treated as an empty mapping.
        lambda_cfg: head name -> multiplier for that head in the total. Missing
            heads default to 1.0; ``None`` is treated as an empty mapping.
        label_smoothing: forwarded to ``F.cross_entropy``.

    Returns:
        Dict with one loss tensor per head plus ``"total"``.
    """
    weights = weights or {}
    lambda_cfg = lambda_cfg or {}
    losses = {}
    for head in _HEADS:
        head_logits = logits[head]
        class_w = weights.get(head)
        if class_w is not None:
            # No-op when the caller already placed the weights on this device.
            class_w = class_w.to(head_logits.device)
        losses[head] = F.cross_entropy(
            head_logits, y[head],
            weight=class_w,
            label_smoothing=label_smoothing,
        )
    losses["total"] = sum(lambda_cfg.get(h, 1.0) * losses[h] for h in _HEADS)
    return losses


def _unpack_batch(batch):
    """Split a collated batch into ``(x, mask, y, prev)``.

    Backward-compatible unpack: collate returns 5 or 6 elements; the optional
    sixth element is the previous-action dict (``None`` when absent).
    """
    if len(batch) == 6:
        x, mask, _lens, y, _meta, prev = batch
    else:
        x, mask, _lens, y, _meta = batch
        prev = None
    return x, mask, y, prev


def _prev_action_kwargs(model, prev, device) -> dict:
    """Build the previous-action kwargs for `model`, or ``{}`` if unused."""
    if prev is None or not getattr(model, "use_prev_action", False):
        return {}
    return {
        "prev_v_comp": prev["verb_composite"].to(device),
        "prev_noun": prev["noun"].to(device),
    }


def parse_modalities(spec: str) -> tuple:
    """Parse a comma-separated modality list such as ``"imu, emg"``.

    Surrounding whitespace and empty entries are dropped so that
    ``"imu, emg,"`` yields ``("imu", "emg")``.

    Raises:
        ValueError: if no modality name remains after cleaning.
    """
    names = tuple(part.strip() for part in spec.split(",") if part.strip())
    if not names:
        raise ValueError(f"--modalities must name at least one modality, got {spec!r}")
    return names


# ---------------------------------------------------------------------------
# Eval
# ---------------------------------------------------------------------------

@torch.no_grad()
def evaluate(model, loader, device) -> Dict[str, float]:
    """Run `model` over `loader` and compute per-head and joint metrics.

    Returns a dict with, for each head, ``<head>_top1`` and ``<head>_mcr``
    (mean class recall), ``<head>_top5`` for heads with more than 5 classes,
    and the joint metrics ``action_vn_top1``, ``action_vn_top5``,
    ``action_top1`` and ``action_top5``.

    Raises:
        RuntimeError: if `loader` yields no batches.
    """
    model.eval()
    all_logits = {k: [] for k in _HEADS}
    all_y = {k: [] for k in _HEADS}
    for batch in loader:
        x, mask, y, prev = _unpack_batch(batch)
        x = {m: t.to(device) for m, t in x.items()}
        mask = mask.to(device)
        logits = model(x, mask, **_prev_action_kwargs(model, prev, device))
        for k in _HEADS:
            all_logits[k].append(logits[k].cpu())
            all_y[k].append(y[k])

    if not all_logits[_HEADS[0]]:
        # torch.cat on an empty list fails with an opaque message; be explicit.
        raise RuntimeError("evaluate() received an empty loader; nothing to score")

    logits_cat = {k: torch.cat(v, dim=0) for k, v in all_logits.items()}
    y_cat = {k: torch.cat(v, dim=0) for k, v in all_y.items()}

    m = {}
    for k, K in [("verb_fine", NUM_VERB_FINE),
                 ("verb_composite", NUM_VERB_COMPOSITE),
                 ("noun", NUM_NOUN),
                 ("hand", NUM_HAND)]:
        preds = logits_cat[k].argmax(dim=1)
        m[f"{k}_top1"] = float((preds == y_cat[k]).float().mean().item())
        if K > 5:
            # Top-5 is only informative when the head has more than 5 classes.
            top5 = top_k_correct(logits_cat[k], y_cat[k], 5)
            m[f"{k}_top5"] = float(top5.float().mean().item())
        m[f"{k}_mcr"] = mean_class_recall(logits_cat[k], y_cat[k], K)

    # Per-head argmax predictions
    vf_pred = logits_cat["verb_fine"].argmax(dim=1)
    n_pred = logits_cat["noun"].argmax(dim=1)
    h_pred = logits_cat["hand"].argmax(dim=1)

    # Headline (current default): action_vn = (verb_fine, noun) joint top-1.
    # Hand is dropped from the joint metric because the hand label is dominated
    # by a single majority class (~48% train, ~42% test) so a constant predictor
    # already saturates it; including hand in the joint compresses the signal
    # from the verb / noun heads where models actually learn. Hand is still
    # reported separately as `hand_top1`.
    vn_correct = (vf_pred == y_cat["verb_fine"]) & (n_pred == y_cat["noun"])
    m["action_vn_top1"] = float(vn_correct.float().mean().item())

    # Top-5 action over (verb_fine, noun)
    vf_top5 = top_k_correct(logits_cat["verb_fine"], y_cat["verb_fine"], 5)
    n_top5 = top_k_correct(logits_cat["noun"], y_cat["noun"], 5)
    m["action_vn_top5"] = float((vf_top5 & n_top5).float().mean().item())

    # Legacy: include hand in the joint, kept for backward compatibility with
    # earlier reports. Will be deprecated.
    h_top1 = (h_pred == y_cat["hand"])
    m["action_top1"] = float((vn_correct & h_top1).float().mean().item())
    m["action_top5"] = float((vf_top5 & n_top5 & h_top1).float().mean().item())
    return m


# ---------------------------------------------------------------------------
# Modality dropout (train-time only)
# ---------------------------------------------------------------------------

def apply_modality_dropout(x: Dict[str, torch.Tensor], p: float) -> Dict[str, torch.Tensor]:
    """Per-sample per-modality dropout.

    Each (sample, modality) cell is zeroed independently with probability `p`,
    but one randomly chosen modality per sample is always kept so the model
    never receives an all-zero input.

    Returns `x` itself (no copy) when ``p <= 0`` or when there is at most one
    modality; otherwise a new dict of masked tensors. Surviving values are not
    rescaled.
    """
    if p <= 0.0:
        return x
    mods = list(x.keys())
    if len(mods) <= 1:
        return x
    any_t = next(iter(x.values()))
    B = any_t.shape[0]
    device = any_t.device
    keep = (torch.rand(B, len(mods), device=device) >= p)
    # Force-keep one random modality per sample.
    forced = torch.randint(len(mods), (B,), device=device)
    keep[torch.arange(B, device=device), forced] = True
    out = {}
    for i, m in enumerate(mods):
        # Reshape the (B,) keep column so it broadcasts over trailing dims.
        km = keep[:, i].to(x[m].dtype).view(B, *([1] * (x[m].ndim - 1)))
        out[m] = x[m] * km
    return out


# ---------------------------------------------------------------------------
# Main training
# ---------------------------------------------------------------------------

def main():
    """CLI entry point: train a model and write config/results/checkpoint.

    Writes ``config.json``, ``model_best.pt`` and ``results.json`` into the
    output directory (suffixed with ``_<tag>`` when ``--tag`` is given).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", type=str, default="deepconvlstm",
                    choices=["deepconvlstm", "dailyactformer", "rulstm", "futr",
                             "afft", "handformer", "actionllm"])
    ap.add_argument("--modalities", type=str,
                    default="imu,emg,eyetrack,mocap,pressure")
    ap.add_argument("--t_obs", type=float, default=8.0,
                    help="Anticipation mode only: observation window length (s).")
    ap.add_argument("--t_fut", type=float, default=2.0,
                    help="Anticipation mode only: prediction horizon (s).")
    ap.add_argument("--mode", type=str, default="recognition",
                    choices=["recognition", "anticipation"],
                    help="recognition = classify segment from its own [start,end] sensor "
                         "window (default). anticipation = legacy T10 setup, predict from "
                         "[start-t_fut-t_obs, start-t_fut].")
    ap.add_argument("--downsample", type=int, default=5)
    ap.add_argument("--epochs", type=int, default=40)
    ap.add_argument("--batch_size", type=int, default=32)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--weight_decay", type=float, default=1e-4)
    ap.add_argument("--grad_clip", type=float, default=1.0,
                    help="Max gradient norm; values <= 0 disable clipping.")
    ap.add_argument("--label_smoothing", type=float, default=0.05)
    ap.add_argument("--dropout", type=float, default=0.1,
                    help="Dropout used inside DAF stems / transformer / pool.")
    ap.add_argument("--use_prev_action", action="store_true",
                    help="Condition DAF on previous-segment (verb_composite, noun) "
                         "labels via embedding concat to pooled features. Only DAF "
                         "uses this; baselines ignore it.")
    ap.add_argument("--modality_dropout", type=float, default=0.0,
                    help="Train-time per-sample per-modality dropout prob "
                         "(0.0=off). At least one modality is always kept.")
    ap.add_argument("--use_class_weights", action="store_true",
                    help="Weight CE by inverse class frequency (better for tail).")
    ap.add_argument("--lambda_verb_fine", type=float, default=1.0)
    ap.add_argument("--lambda_verb_composite", type=float, default=0.5)
    ap.add_argument("--lambda_noun", type=float, default=1.0)
    ap.add_argument("--lambda_hand", type=float, default=0.5)
    ap.add_argument("--patience", type=int, default=12)
    ap.add_argument("--warmup_epochs", type=int, default=0,
                    help="Linear LR warmup over the first N epochs (0=off).")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--output_dir", type=str, required=True)
    ap.add_argument("--num_workers", type=int, default=0)
    ap.add_argument("--tag", type=str, default="")
    args = ap.parse_args()

    # Fail fast on a malformed modality list before any data is loaded.
    try:
        mods = parse_modalities(args.modalities)
    except ValueError as exc:
        ap.error(str(exc))

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.mode == "anticipation":
        print(f"[cfg] model={args.model} modalities={args.modalities} "
              f"mode={args.mode} T_obs={args.t_obs}s T_fut={args.t_fut}s seed={args.seed}")
    else:
        print(f"[cfg] model={args.model} modalities={args.modalities} "
              f"mode={args.mode} (segment-aligned window) seed={args.seed}")
    print(f"[cfg] device={device} epochs={args.epochs} lr={args.lr} "
          f"batch_size={args.batch_size}")

    train_ds, test_ds = build_train_test(
        modalities=mods,
        t_obs_sec=args.t_obs,
        t_fut_sec=args.t_fut,
        downsample=args.downsample,
        mode=args.mode,
    )
    print(f"[data] train={len(train_ds)} test={len(test_ds)} "
          f"modality_dims={train_ds.modality_dims}")

    # Class counts for weighting (train only)
    counts = train_ds.class_counts()
    weights: Dict[str, torch.Tensor] = {}
    if args.use_class_weights:
        for k in _HEADS:
            # Move to the training device once instead of on every step.
            weights[k] = build_class_weights(counts[k]).to(device)

    train_loader = DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True,
        collate_fn=collate_triplet, num_workers=args.num_workers, drop_last=True,
    )
    test_loader = DataLoader(
        test_ds, batch_size=args.batch_size, shuffle=False,
        collate_fn=collate_triplet, num_workers=args.num_workers,
    )

    # For DailyActFormer: causal mask only when doing anticipation; bidirectional
    # attention for recognition (the default). Other models ignore unknown kwargs.
    extra_kwargs = {}
    if args.model in ("dailyactformer", "ours", "daf"):
        extra_kwargs["causal"] = (args.mode == "anticipation")
        extra_kwargs["dropout"] = args.dropout
    # Every model class now accepts use_prev_action; pass it uniformly.
    extra_kwargs["use_prev_action"] = args.use_prev_action
    model = build_model(args.model, train_ds.modality_dims, **extra_kwargs).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[model] {args.model} params={n_params:,}")

    opt = torch.optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
    )
    if args.warmup_epochs > 0:
        warmup = torch.optim.lr_scheduler.LinearLR(
            opt, start_factor=1.0 / max(1, args.warmup_epochs),
            end_factor=1.0, total_iters=args.warmup_epochs,
        )
        cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=max(1, args.epochs - args.warmup_epochs),
            eta_min=args.lr * 0.05,
        )
        sched = torch.optim.lr_scheduler.SequentialLR(
            opt, schedulers=[warmup, cosine],
            milestones=[args.warmup_epochs],
        )
    else:
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=args.epochs, eta_min=args.lr * 0.05,
        )

    lambda_cfg = {
        "verb_fine": args.lambda_verb_fine,
        "verb_composite": args.lambda_verb_composite,
        "noun": args.lambda_noun,
        "hand": args.lambda_hand,
    }

    # Output directory
    out_dir = Path(args.output_dir)
    if args.tag:
        out_dir = out_dir.parent / f"{out_dir.name}_{args.tag}"
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(out_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(vars(args) | {"n_params": n_params}, f, indent=2)

    best = {"action_vn_top1": -1.0, "action_top1": -1.0}
    best_epoch = 0
    best_path = out_dir / "model_best.pt"
    patience = 0
    history = []

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        model.train()
        losses_epoch = {k: 0.0 for k in (*_HEADS, "total")}
        n_batches = 0
        for batch in train_loader:
            x, mask, y, prev = _unpack_batch(batch)
            x = {m: t.to(device) for m, t in x.items()}
            mask = mask.to(device)
            y = {k: v.to(device) for k, v in y.items()}
            if args.modality_dropout > 0.0:
                x = apply_modality_dropout(x, args.modality_dropout)
            kwargs = _prev_action_kwargs(model, prev, device)
            opt.zero_grad()
            logits = model(x, mask, **kwargs)
            step_losses = triplet_loss(logits, y, weights, lambda_cfg,
                                       label_smoothing=args.label_smoothing)
            step_losses["total"].backward()
            if args.grad_clip > 0:
                # max_norm == 0 would scale every gradient to zero and silently
                # stall training, so non-positive values mean "no clipping".
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            opt.step()
            for k in losses_epoch:
                losses_epoch[k] += float(step_losses[k].detach().item())
            n_batches += 1
        for k in losses_epoch:
            losses_epoch[k] /= max(1, n_batches)
        sched.step()

        # NOTE(review): checkpoint selection and early stopping below are driven
        # by metrics on the test loader; no separate validation split is used here.
        metrics = evaluate(model, test_loader, device)
        dur = time.time() - t0
        print(
            f" E{epoch:3d} loss={losses_epoch['total']:.3f} "
            f"(vf={losses_epoch['verb_fine']:.2f} "
            f"n={losses_epoch['noun']:.2f} "
            f"h={losses_epoch['hand']:.2f}) | "
            f"act_vn@1={metrics['action_vn_top1']:.3f} "
            f"vf@1={metrics['verb_fine_top1']:.3f} "
            f"n@1={metrics['noun_top1']:.3f} "
            f"h@1={metrics['hand_top1']:.3f} | "
            f"{dur:.1f}s",
            flush=True,
        )
        history.append({"epoch": epoch, **losses_epoch, **metrics})

        if metrics["action_vn_top1"] > best["action_vn_top1"]:
            best = dict(metrics)
            best_epoch = epoch
            patience = 0
            torch.save(
                {"state_dict": {k: v.cpu().clone() for k, v in model.state_dict().items()},
                 "epoch": epoch, "metrics": metrics},
                best_path,
            )
        else:
            patience += 1
            if patience >= args.patience:
                print(f" early stop at epoch {epoch} (best epoch {best_epoch})")
                break

    # Write results
    results = {
        "best_epoch": best_epoch,
        "best_test_metrics": best,
        "history": history,
        "n_params": n_params,
        "train_size": len(train_ds),
        "test_size": len(test_ds),
        "train_class_counts": {k: v.tolist() for k, v in counts.items()},
        "modality_dims": train_ds.modality_dims,
        "args": vars(args),
    }
    with open(out_dir / "results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"\n[done] best action_vn@1 = {best['action_vn_top1']:.4f} "
          f"(legacy action@1 = {best['action_top1']:.4f}, epoch {best_epoch}) "
          f"saved to {out_dir}")


if __name__ == "__main__":
    main()