#!/usr/bin/env python3
"""Train + evaluate "is_grasping" recognition (T5 v3 / TGSR).

Predicts a class label over the future T_fut window from the past T_obs
window of input modalities. The default ``--label_mode binary`` predicts the
binary is-grasping label; ``three_class``, ``verb`` and ``object`` label modes
are also supported. Ground truth = annotation-based grasp-verb mask.

Comparison: input includes pressure (treatment) vs not (control), under the
same cross-modal kinematic baseline.
Lift = macro_F1(with) − macro_F1(without).
"""
from __future__ import annotations

import argparse
import json
import random
import sys
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

THIS = Path(__file__).resolve()
sys.path.insert(0, str(THIS.parent))
sys.path.insert(0, str(THIS.parents[1]))

try:
    from experiments.dataset_grasp_state import (
        GraspStateDataset, collate_grasp_state, build_grasp_train_test,
        EVENT_NAMES, CLASS_NAMES_BINARY, CLASS_NAMES_THREE,
        VERB_LIST, OBJECT_TOP_LIST,
    )
except ModuleNotFoundError:
    from dataset_grasp_state import (
        GraspStateDataset, collate_grasp_state, build_grasp_train_test,
        EVENT_NAMES, CLASS_NAMES_BINARY, CLASS_NAMES_THREE,
        VERB_LIST, OBJECT_TOP_LIST,
    )
from nets.models_forecast import build_forecast_model  # type: ignore

# Number of per-event strata reported by evaluate(); confusion-matrix index
# N_EVENTS holds the pooled "overall" stratum.
N_EVENTS = 4


class GraspStateClassifier(nn.Module):
    """Wrap the existing forecasting backbone for sequence classification.

    Reuses ``build_forecast_model`` with output dim = ``num_classes``, then
    mean-pools over the T_fut output axis to produce ``(B, num_classes)``
    logits.

    Args:
        base_name: backbone name forwarded to ``build_forecast_model``.
        modality_dims: per-modality input sizes forwarded to the backbone
            (presumably ``{modality: channels}`` — confirm in nets.models_forecast).
        t_obs, t_fut: window lengths as exposed by the dataset's
            ``T_obs`` / ``T_fut`` (presumably sample counts — confirm).
        d_model, dropout: backbone width and dropout rate.
        num_classes: number of output classes (2 for the binary mode).
    """

    def __init__(self, base_name, modality_dims, t_obs, t_fut, d_model,
                 dropout, num_classes=2):
        super().__init__()
        self.base = build_forecast_model(
            base_name, modality_dims, num_classes=num_classes,
            t_obs=t_obs, t_fut=t_fut, d_model=d_model, dropout=dropout,
        )

    def forward(self, x):
        """Return ``(B, num_classes)`` logits for the modality dict ``x``."""
        out = self.base(x)        # (B, T_fut, num_classes)
        # Average the per-timestep logits over the forecast horizon.
        return out.mean(dim=1)    # (B, num_classes) ← logits


def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and torch (CPU + all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def train_epoch(model, loader, optimizer, device, class_weight=None):
    """Run one training epoch and return the per-sample mean CE loss.

    Each batch is unpacked as ``(x, y, event_type, meta)``; only ``x`` (a dict
    of modality tensors) and ``y`` (class indices) are used here.

    Note: with ``class_weight`` set, ``F.cross_entropy`` returns a
    weight-normalised batch mean, so the value logged here (re-weighted by
    batch size) is an approximation of the epoch objective.
    """
    model.train()
    total, n = 0.0, 0
    for x, y, _et, _ in loader:
        x = {m: v.to(device) for m, v in x.items()}
        y = y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y, weight=class_weight)
        loss.backward()
        # Clip the global gradient norm to 1.0 before each update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total += loss.item() * y.numel()
        n += y.numel()
    return total / max(n, 1)


def _default_class_names(num_classes):
    """Guess an index -> name table from ``num_classes`` alone (legacy fallback)."""
    if num_classes == 2:
        return CLASS_NAMES_BINARY
    if num_classes == 3:
        return CLASS_NAMES_THREE
    if num_classes == len(VERB_LIST):
        return dict(enumerate(VERB_LIST))
    return dict(enumerate(OBJECT_TOP_LIST))


def _class_name(names, c):
    """Look up class ``c`` in a dict/sequence name table, falling back to ``class_{c}``."""
    try:
        return names[c]
    except (IndexError, KeyError):
        return f"class_{c}"


@torch.no_grad()
def evaluate(model, loader, device, num_classes=2, class_names=None):
    """Return overall + per-event-stratified F1, accuracy, confusion.

    Args:
        model: classifier returning ``(B, num_classes)`` logits.
        loader: yields ``(x, y, event_type, meta)`` batches. ``event_type``
            values index the per-event strata ``0 .. N_EVENTS - 1``.
            NOTE(review): a negative event code would wrap to the last
            (overall) stratum and be double counted — confirm the dataset
            never emits one.
        device: device the inputs are moved to.
        num_classes: size of the label space.
        class_names: optional index -> name table (dict or sequence) used
            for the ``f1_per_class`` keys. When omitted it is guessed from
            ``num_classes`` (2 -> binary, 3 -> three-class, len(VERB_LIST)
            -> verbs, otherwise objects).

    Returns:
        Dict keyed by ``EVENT_NAMES[e]`` for each event stratum plus
        ``"overall"``. Each value holds ``n``, ``accuracy``, ``macro_f1``,
        ``f1_per_class`` and ``confusion`` (rows = true, cols = predicted).
    """
    names = class_names if class_names is not None else _default_class_names(num_classes)
    model.eval()
    overall = N_EVENTS
    # One confusion matrix per event stratum plus one pooled "overall".
    cm = np.zeros((N_EVENTS + 1, num_classes, num_classes), dtype=np.int64)
    for x, y, et, _ in loader:
        x = {m: v.to(device) for m, v in x.items()}
        pred = model(x).argmax(dim=-1).cpu().numpy().reshape(-1)
        y_np = y.numpy().reshape(-1).astype(np.int64)
        et_np = et.numpy().reshape(-1).astype(np.int64)
        # np.add.at is unbuffered, so repeated (event, true, pred) triples
        # within a batch are all counted.
        np.add.at(cm, (et_np, y_np, pred), 1)
        np.add.at(cm[overall], (y_np, pred), 1)

    out = {}
    for e in range(N_EVENTS + 1):
        m = cm[e]
        n = int(m.sum())
        # Per-class F1 from the confusion matrix (zero when undefined).
        tp = np.diag(m)
        fp = m.sum(axis=0) - tp
        fn = m.sum(axis=1) - tp
        prec = tp / np.maximum(tp + fp, 1)
        rec = tp / np.maximum(tp + fn, 1)
        f1s = 2 * prec * rec / np.maximum(prec + rec, 1e-9)
        macro_f1 = float(np.mean(f1s))
        acc = float(np.trace(m)) / max(n, 1)
        if e < N_EVENTS:
            # Never fall back to "overall" here: that key would be silently
            # overwritten by the pooled stratum below.
            name = EVENT_NAMES.get(e, f"event_{e}")
        else:
            name = "overall"
        out[name] = {
            "n": n,
            "accuracy": acc,
            "macro_f1": macro_f1,
            "f1_per_class": {_class_name(names, c): float(f1s[c])
                             for c in range(num_classes)},
            "confusion": m.tolist(),
        }
    return out


def _class_names_for_mode(label_mode):
    """Return the index -> name table for a ``--label_mode`` value (None if unknown)."""
    if label_mode == "binary":
        return CLASS_NAMES_BINARY
    if label_mode == "three_class":
        return CLASS_NAMES_THREE
    if label_mode == "verb":
        return dict(enumerate(VERB_LIST))
    if label_mode == "object":
        return dict(enumerate(OBJECT_TOP_LIST))
    return None


def _split_csv(text):
    """Split a comma-separated CLI value into stripped, non-empty items (None if empty)."""
    if not text:
        return None
    items = [tok.strip() for tok in text.split(",") if tok.strip()]
    return items or None


def _build_arg_parser():
    """Build the command-line parser for this script."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True,
                    choices=["daf", "futr", "deepconvlstm"])
    ap.add_argument("--input_modalities", required=True,
                    help="comma-separated, e.g. 'emg,imu,mocap' or 'emg,imu,mocap,pressure'")
    ap.add_argument("--t_obs", type=float, default=1.0)
    ap.add_argument("--t_fut", type=float, default=0.5)
    ap.add_argument("--anchor_stride", type=float, default=0.25)
    ap.add_argument("--per_class_max", type=int, default=15000,
                    help="Cap each class to this many anchors in train (for balance).")
    ap.add_argument("--epochs", type=int, default=30)
    ap.add_argument("--batch_size", type=int, default=64)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--weight_decay", type=float, default=1e-4)
    ap.add_argument("--d_model", type=int, default=128)
    ap.add_argument("--dropout", type=float, default=0.1)
    ap.add_argument("--num_workers", type=int, default=2)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--patience", type=int, default=6)
    ap.add_argument("--no_class_weight", action="store_true",
                    help="Skip class-weighted CE; rely on per_class_max balancing.")
    ap.add_argument("--label_mode", default="binary",
                    choices=["binary", "three_class", "verb", "object"])
    ap.add_argument("--sustained_threshold_sec", type=float, default=0.3,
                    help="(3-class only) min contiguous contact run for SustainedGrasp class.")
    ap.add_argument("--require_lift_for_sustained", action="store_true",
                    help="(3-class only) Class 2 also requires verb ∈ LIFT_VERBS or hand_type=both.")
    ap.add_argument("--train_vols", default=None,
                    help="comma-separated volunteer IDs to override the default TRAIN split (for CV).")
    ap.add_argument("--test_vols", default=None,
                    help="comma-separated volunteer IDs to override the default TEST split (for CV).")
    ap.add_argument("--output_dir", required=True)
    return ap


def _compute_class_weight(train_ds, num_classes, device):
    """Inverse-frequency class weights ``N / (C * count_c)`` from the train items."""
    counts = np.zeros(num_classes, dtype=np.int64)
    for it in train_ds._items:
        counts[it["label"]] += 1
    weights = counts.sum() / (num_classes * np.maximum(counts, 1))
    return torch.tensor(weights, dtype=torch.float32).to(device)


def main():
    """CLI entry point: build data, train with early stopping, write results.json."""
    ap = _build_arg_parser()
    args = ap.parse_args()
    inputs = _split_csv(args.input_modalities)
    if not inputs:
        ap.error("--input_modalities must name at least one modality")

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device={device} seed={args.seed} model={args.model} "
          f"inputs={inputs} t_obs={args.t_obs} t_fut={args.t_fut}", flush=True)

    train_ds, test_ds = build_grasp_train_test(
        input_modalities=inputs,
        t_obs_sec=args.t_obs, t_fut_sec=args.t_fut,
        anchor_stride_sec=args.anchor_stride,
        per_class_max=args.per_class_max,
        label_mode=args.label_mode,
        sustained_threshold_sec=args.sustained_threshold_sec,
        require_lift_for_sustained=args.require_lift_for_sustained,
        rng_seed=args.seed,
        train_vols=_split_csv(args.train_vols),
        test_vols=_split_csv(args.test_vols),
    )
    num_classes = train_ds.num_classes
    # Pick class names from the label mode rather than guessing from the
    # class count (verb/object counts may coincide).
    class_names = _class_names_for_mode(args.label_mode)
    print(f"train={len(train_ds)} test={len(test_ds)} num_classes={num_classes}",
          flush=True)

    tr_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                           num_workers=args.num_workers,
                           collate_fn=collate_grasp_state, drop_last=False)
    te_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False,
                           num_workers=args.num_workers,
                           collate_fn=collate_grasp_state)

    model = GraspStateClassifier(
        args.model, train_ds.modality_dims,
        t_obs=train_ds.T_obs, t_fut=train_ds.T_fut,
        d_model=args.d_model, dropout=args.dropout,
        num_classes=num_classes,
    ).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"params={n_params:,}", flush=True)

    # Class weight = inverse class frequency in train
    if args.no_class_weight:
        cw = None
    else:
        cw = _compute_class_weight(train_ds, num_classes, device)
        print(f"class_weight={cw.tolist()}", flush=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr,
                                  weight_decay=args.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs, eta_min=args.lr * 0.05)

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    best_f1 = -1.0
    best_epoch, best_eval = 0, None
    patience_counter = 0

    for ep in range(1, args.epochs + 1):
        t0 = time.time()
        tr_loss = train_epoch(model, tr_loader, optimizer, device, class_weight=cw)
        ev = evaluate(model, te_loader, device, num_classes=num_classes,
                      class_names=class_names)
        sched.step()
        f1 = ev["overall"]["macro_f1"]
        print(f"  E{ep:2d} | tr_ce {tr_loss:.4f} | overall_f1 {f1:.4f} acc {ev['overall']['accuracy']:.4f} "
              f"| pre_f1 {ev['pre-contact']['macro_f1']:.3f} "
              f"steady {ev['steady-grip']['macro_f1']:.3f} "
              f"release {ev['release']['macro_f1']:.3f} "
              f"non {ev['non-contact']['macro_f1']:.3f} | {time.time()-t0:.1f}s",
              flush=True)
        if f1 > best_f1:
            best_f1 = f1
            best_epoch = ep
            best_eval = ev
            torch.save({k: v.cpu() for k, v in model.state_dict().items()},
                       out_dir / "model_best.pt")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f"  early stop at epoch {ep} (best {best_epoch})", flush=True)
                break

    out = {
        "method": args.model,
        "input_modalities": inputs,
        "seed": args.seed,
        "n_params": n_params,
        "T_obs": train_ds.T_obs,
        "T_fut": train_ds.T_fut,
        "best_epoch": int(best_epoch),
        "best_macro_f1": float(best_f1),
        "eval": best_eval,
        "args": vars(args),
    }
    with open(out_dir / "results.json", "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)
    print(f"\n[done] best macro_F1={best_f1:.4f} at epoch {best_epoch}", flush=True)


if __name__ == "__main__":
    main()