| |
| """Train + evaluate binary "is_grasping" recognition (T5 v3 / TGSR). |
| |
Predicts a class label for the future T_fut window from the past T_obs window
of input modalities. In the default binary mode the ground truth is the
annotation-based grasp-verb mask; --label_mode also supports three_class, verb
and object targets.
| |
| Comparison: input includes pressure (treatment) vs not (control), under the |
| same cross-modal kinematic baseline. Lift = macro_F1(with) − macro_F1(without). |
| """ |
| from __future__ import annotations |
| import argparse |
| import json |
| import random |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
|
|
THIS = Path(__file__).resolve()
# Make the script's parent directory and its own directory importable, in
# that priority order (equivalent to inserting THIS.parent, then
# THIS.parents[1], each at position 0).
sys.path[:0] = [str(THIS.parents[1]), str(THIS.parent)]
|
|
| try: |
| from experiments.dataset_grasp_state import ( |
| GraspStateDataset, collate_grasp_state, |
| build_grasp_train_test, EVENT_NAMES, |
| CLASS_NAMES_BINARY, CLASS_NAMES_THREE, VERB_LIST, OBJECT_TOP_LIST, |
| ) |
| except ModuleNotFoundError: |
| from dataset_grasp_state import ( |
| GraspStateDataset, collate_grasp_state, |
| build_grasp_train_test, EVENT_NAMES, |
| CLASS_NAMES_BINARY, CLASS_NAMES_THREE, VERB_LIST, OBJECT_TOP_LIST, |
| ) |
| from nets.models_forecast import build_forecast_model |
|
|
|
|
| class GraspStateClassifier(nn.Module): |
| """Wrap the existing forecasting backbone for binary classification. |
| |
| Reuses build_forecast_model with output dim = num_classes, then mean-pools |
| over the T_fut output axis to produce (B, num_classes) logits. |
| """ |
| def __init__(self, base_name, modality_dims, t_obs, t_fut, |
| d_model, dropout, num_classes=2): |
| super().__init__() |
| self.base = build_forecast_model( |
| base_name, modality_dims, |
| num_classes=num_classes, |
| t_obs=t_obs, t_fut=t_fut, |
| d_model=d_model, dropout=dropout, |
| ) |
|
|
| def forward(self, x): |
| out = self.base(x) |
| return out.mean(dim=1) |
|
|
|
|
def set_seed(seed: int):
    """Seed the Python, NumPy, PyTorch CPU and all-CUDA-device RNGs with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
|
|
|
|
def train_epoch(model, loader, optimizer, device, class_weight=None,
                max_grad_norm=1.0):
    """Run one optimisation pass over ``loader`` and return the mean CE loss.

    Args:
        model: module called with a dict of tensors; its output is passed as
            logits to ``F.cross_entropy``.
        loader: iterable of ``(x, y, event_type, _)`` batches. Only ``x`` (a
            dict of tensors) and ``y`` (class targets) are used here.
        optimizer: optimizer over ``model``'s parameters.
        device: device that ``x`` and ``y`` are moved to.
        class_weight: optional per-class weight tensor for ``F.cross_entropy``.
        max_grad_norm: total gradient norm to clip to before each optimizer
            step. ``None`` disables clipping. The default of 1.0 matches the
            previous hard-coded value.

    Returns:
        The per-batch losses averaged with weights equal to each batch's
        number of targets, or 0.0 if ``loader`` yields nothing. With
        ``class_weight`` set, each batch loss is already a class-weighted
        mean, so the result is a size-weighted average of those means.
    """
    model.train()
    loss_sum, n_seen = 0.0, 0
    for x, y, _et, _ in loader:
        x = {m: v.to(device) for m, v in x.items()}
        y = y.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(x), y, weight=class_weight)
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        batch_n = y.numel()
        loss_sum += loss.item() * batch_n
        n_seen += batch_n
    return loss_sum / max(n_seen, 1)
|
|
|
|
@torch.no_grad()
def evaluate(model, loader, device, num_classes=2, class_names=None):
    """Return overall + per-event-stratified macro-F1, accuracy and confusion.

    Args:
        model: classifier called with a dict of tensors. The argmax over its
            output's last axis is taken as the predicted class.
        loader: iterable of ``(x, y, event_type, _)`` batches. ``x`` is a dict
            of tensors moved to ``device``. ``y`` and ``event_type`` must be
            CPU tensors (``.numpy()`` is called on them directly) holding
            integer-valued indices. ``event_type`` is expected in 0..3,
            because slot 4 of the confusion tensor accumulates all samples.
        device: device the model inputs are moved to.
        num_classes: size of each (true x predicted) confusion matrix.
        class_names: optional index -> name mapping (dict or sequence). If
            omitted, it is guessed from ``num_classes``. NOTE: the guess cannot
            distinguish verb from object labels when
            ``len(VERB_LIST) == len(OBJECT_TOP_LIST)``, so pass names
            explicitly when the label mode is known.

    Returns:
        Dict keyed by the ``EVENT_NAMES`` entry of each stratum 0..3, plus
        ``"overall"``. Each value holds:
          - ``n``, ``accuracy`` and ``macro_f1``;
          - ``f1_per_class``, keyed by class name, or ``str(index)`` when no
            name is known for that index;
          - ``confusion``, with rows = true class and columns = predicted.

    A class with no true and no predicted samples in a stratum contributes
    an F1 of 0 to that stratum's macro average.
    """
    if class_names is not None:
        names = class_names
    elif num_classes == 2:
        names = CLASS_NAMES_BINARY
    elif num_classes == 3:
        names = CLASS_NAMES_THREE
    elif num_classes == len(VERB_LIST):
        names = dict(enumerate(VERB_LIST))
    else:
        names = dict(enumerate(OBJECT_TOP_LIST))

    def _name(c):
        # Fall back to the numeric index rather than crashing when the
        # name table does not cover every class index.
        try:
            return names[c]
        except (KeyError, IndexError):
            return str(c)

    model.eval()
    n_events = 4  # event strata 0..3; slot n_events accumulates every sample
    cm = np.zeros((n_events + 1, num_classes, num_classes), dtype=np.int64)
    for x, y, et, _ in loader:
        x = {m: v.to(device) for m, v in x.items()}
        pred = model(x).argmax(dim=-1).cpu().numpy().reshape(-1).astype(np.int64)
        y_np = y.numpy().reshape(-1).astype(np.int64)
        et_np = et.numpy().reshape(-1).astype(np.int64)
        # np.add.at is unbuffered, so repeated (event, true, pred) index
        # triples within one batch are each counted.
        np.add.at(cm, (et_np, y_np, pred), 1)
        np.add.at(cm[n_events], (y_np, pred), 1)

    out = {}
    for e in range(n_events + 1):
        m = cm[e]
        n = int(m.sum())
        tp = np.diag(m)
        fp = m.sum(axis=0) - tp  # predicted as c, true label differs
        fn = m.sum(axis=1) - tp  # true label c, predicted otherwise
        prec = tp / np.maximum(tp + fp, 1)
        rec = tp / np.maximum(tp + fn, 1)
        f1s = [float(v) for v in 2 * prec * rec / np.maximum(prec + rec, 1e-9)]
        name = EVENT_NAMES.get(e, "overall") if e < n_events else "overall"
        out[name] = {
            "n": n,
            "accuracy": float(np.trace(m)) / max(n, 1),
            "macro_f1": float(np.mean(f1s)),
            "f1_per_class": {_name(c): f1s[c] for c in range(num_classes)},
            "confusion": m.tolist(),
        }
    return out
|
|
|
|
def _parse_args():
    """Define and parse the command-line options for one training run."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True, choices=["daf", "futr", "deepconvlstm"])
    ap.add_argument("--input_modalities", required=True,
                    help="comma-separated, e.g. 'emg,imu,mocap' or 'emg,imu,mocap,pressure'")
    ap.add_argument("--t_obs", type=float, default=1.0)
    ap.add_argument("--t_fut", type=float, default=0.5)
    ap.add_argument("--anchor_stride", type=float, default=0.25)
    ap.add_argument("--per_class_max", type=int, default=15000,
                    help="Cap each class to this many anchors in train (for balance).")
    ap.add_argument("--epochs", type=int, default=30)
    ap.add_argument("--batch_size", type=int, default=64)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--weight_decay", type=float, default=1e-4)
    ap.add_argument("--d_model", type=int, default=128)
    ap.add_argument("--dropout", type=float, default=0.1)
    ap.add_argument("--num_workers", type=int, default=2)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--patience", type=int, default=6)
    ap.add_argument("--no_class_weight", action="store_true",
                    help="Skip class-weighted CE; rely on per_class_max balancing.")
    ap.add_argument("--label_mode", default="binary", choices=["binary", "three_class", "verb", "object"])
    ap.add_argument("--sustained_threshold_sec", type=float, default=0.3,
                    help="(3-class only) min contiguous contact run for SustainedGrasp class.")
    ap.add_argument("--require_lift_for_sustained", action="store_true",
                    help="(3-class only) Class 2 also requires verb ∈ LIFT_VERBS or hand_type=both.")
    ap.add_argument("--train_vols", default=None,
                    help="comma-separated volunteer IDs to override the default TRAIN split (for CV).")
    ap.add_argument("--test_vols", default=None,
                    help="comma-separated volunteer IDs to override the default TEST split (for CV).")
    ap.add_argument("--output_dir", required=True)
    return ap.parse_args()


def _split_csv(spec):
    """Split a comma-separated CLI value into stripped, non-empty tokens."""
    return [tok.strip() for tok in spec.split(",") if tok.strip()]


def _class_names_for(label_mode, num_classes):
    """Map each class index to a display name for the given --label_mode.

    Indices the name table does not cover fall back to ``str(index)``.
    """
    table = {
        "binary": CLASS_NAMES_BINARY,
        "three_class": CLASS_NAMES_THREE,
        "verb": VERB_LIST,
        "object": OBJECT_TOP_LIST,
    }[label_mode]
    names = {}
    for c in range(num_classes):
        try:
            names[c] = table[c]
        except (KeyError, IndexError):
            names[c] = str(c)
    return names


def _class_weights(train_ds, num_classes, device):
    """Inverse-frequency CE weights: total / (num_classes * count_c).

    Counts come from ``train_ds._items[*]["label"]``. Classes absent from
    training are counted as 1, so their weight stays finite.
    """
    counts = np.zeros(num_classes, dtype=np.int64)
    for item in train_ds._items:
        counts[item["label"]] += 1
    weights = counts.sum() / (num_classes * np.maximum(counts, 1))
    return torch.tensor(weights, dtype=torch.float32).to(device)


def main():
    """CLI entry point: train, early-stop on macro-F1, and write results.

    Writes ``model_best.pt`` (the CPU state_dict of the best epoch) and
    ``results.json`` into ``--output_dir``.

    NOTE(review): the best epoch and early stopping are both selected by
    macro-F1 on the *test* loader, and ``best_macro_f1`` is reported on that
    same split, so it is optimistically biased. An unbiased estimate needs a
    separate validation split.
    """
    args = _parse_args()

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = _split_csv(args.input_modalities)
    print(f"device={device} seed={args.seed} model={args.model} "
          f"inputs={inputs} t_obs={args.t_obs} t_fut={args.t_fut}", flush=True)

    tr_v = _split_csv(args.train_vols) if args.train_vols else None
    te_v = _split_csv(args.test_vols) if args.test_vols else None
    train_ds, test_ds = build_grasp_train_test(
        input_modalities=inputs,
        t_obs_sec=args.t_obs, t_fut_sec=args.t_fut,
        anchor_stride_sec=args.anchor_stride,
        per_class_max=args.per_class_max,
        label_mode=args.label_mode,
        sustained_threshold_sec=args.sustained_threshold_sec,
        require_lift_for_sustained=args.require_lift_for_sustained,
        rng_seed=args.seed,
        train_vols=tr_v, test_vols=te_v,
    )
    num_classes = train_ds.num_classes
    # Resolve class names from the label mode instead of letting evaluate()
    # guess from num_classes, which is ambiguous when class counts coincide.
    class_names = _class_names_for(args.label_mode, num_classes)
    print(f"train={len(train_ds)} test={len(test_ds)} num_classes={num_classes}", flush=True)

    tr_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                           num_workers=args.num_workers, collate_fn=collate_grasp_state,
                           drop_last=False)
    te_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False,
                           num_workers=args.num_workers, collate_fn=collate_grasp_state)

    model = GraspStateClassifier(
        args.model, train_ds.modality_dims,
        t_obs=train_ds.T_obs, t_fut=train_ds.T_fut,
        d_model=args.d_model, dropout=args.dropout,
        num_classes=num_classes,
    ).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"params={n_params:,}", flush=True)

    if args.no_class_weight:
        cw = None
    else:
        cw = _class_weights(train_ds, num_classes, device)
        print(f"class_weight={cw.tolist()}", flush=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=args.lr * 0.05)

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    best_f1 = -1.0
    best_epoch, best_eval = 0, None
    patience_counter = 0
    for ep in range(1, args.epochs + 1):
        t0 = time.perf_counter()
        tr_loss = train_epoch(model, tr_loader, optimizer, device, class_weight=cw)
        ev = evaluate(model, te_loader, device, num_classes=num_classes,
                      class_names=class_names)
        sched.step()
        f1 = ev["overall"]["macro_f1"]
        print(f" E{ep:2d} | tr_ce {tr_loss:.4f} | overall_f1 {f1:.4f} acc {ev['overall']['accuracy']:.4f} "
              f"| pre_f1 {ev['pre-contact']['macro_f1']:.3f} "
              f"steady {ev['steady-grip']['macro_f1']:.3f} "
              f"release {ev['release']['macro_f1']:.3f} "
              f"non {ev['non-contact']['macro_f1']:.3f} | {time.perf_counter()-t0:.1f}s", flush=True)
        if f1 > best_f1:
            best_f1 = f1
            best_epoch = ep
            best_eval = ev
            torch.save({k: v.cpu() for k, v in model.state_dict().items()},
                       out_dir / "model_best.pt")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f" early stop at epoch {ep} (best {best_epoch})", flush=True)
                break

    out = {
        "method": args.model,
        "input_modalities": inputs,
        "seed": args.seed, "n_params": n_params,
        "T_obs": train_ds.T_obs, "T_fut": train_ds.T_fut,
        "best_epoch": int(best_epoch),
        "best_macro_f1": float(best_f1),
        "eval": best_eval,
        "args": vars(args),
    }
    with open(out_dir / "results.json", "w") as f:
        json.dump(out, f, indent=2)
    print(f"\n[done] best macro_F1={best_f1:.4f} at epoch {best_epoch}", flush=True)
|
|
|
|
# Run the training CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|