# Source: PULSE-code/experiments/tasks/train_grasp_state.py
# Uploaded by velvet-pine-22 via huggingface_hub (commit b4b2877, verified).
#!/usr/bin/env python3
"""Train + evaluate grasp-state recognition (T5 v3 / TGSR).

Predicts a class label for the future T_fut window from the past T_obs window
of the input modalities. In the default binary mode the label is
"is_grasping", with ground truth taken from the annotation-based grasp-verb
mask; three-class, verb and object label modes are also available via
--label_mode.

Comparison: the input either includes pressure (treatment) or does not
(control), under the same cross-modal kinematic baseline.
Lift = macro_F1(with) − macro_F1(without).
"""
from __future__ import annotations
import argparse
import json
import random
import sys
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
# Make sibling modules importable when this file is run directly as a script.
# Both directories are prepended, so THIS.parents[1] ends up ahead of
# THIS.parent on sys.path.
THIS = Path(__file__).resolve()
sys.path.insert(0, str(THIS.parent))
sys.path.insert(0, str(THIS.parents[1]))
# Prefer the package-qualified import and fall back to the bare module name.
# NOTE(review): if this file lives at <root>/experiments/tasks/, then
# THIS.parents[1] is <root>/experiments itself, so `experiments.` presumably
# only resolves when <root> is already on sys.path (cwd / PYTHONPATH) —
# confirm. ModuleNotFoundError is also raised when a dependency *inside*
# dataset_grasp_state is missing; in that case the fallback import fails again
# and surfaces the error.
try:
    from experiments.dataset_grasp_state import (
        GraspStateDataset, collate_grasp_state,
        build_grasp_train_test, EVENT_NAMES,
        CLASS_NAMES_BINARY, CLASS_NAMES_THREE, VERB_LIST, OBJECT_TOP_LIST,
    )
except ModuleNotFoundError:
    from dataset_grasp_state import (
        GraspStateDataset, collate_grasp_state,
        build_grasp_train_test, EVENT_NAMES,
        CLASS_NAMES_BINARY, CLASS_NAMES_THREE, VERB_LIST, OBJECT_TOP_LIST,
    )
# Resolved through the sys.path entries inserted above.
from nets.models_forecast import build_forecast_model # type: ignore
class GraspStateClassifier(nn.Module):
    """Sequence classifier built on top of a forecasting backbone.

    The backbone from ``build_forecast_model`` is configured to emit
    ``num_classes`` channels per future time step. ``forward`` averages those
    per-step outputs over axis 1, yielding one logit vector per sample.
    """

    def __init__(self, base_name, modality_dims, t_obs, t_fut,
                 d_model, dropout, num_classes=2):
        super().__init__()
        backbone_kwargs = dict(
            num_classes=num_classes,
            t_obs=t_obs,
            t_fut=t_fut,
            d_model=d_model,
            dropout=dropout,
        )
        self.base = build_forecast_model(base_name, modality_dims, **backbone_kwargs)

    def forward(self, x):
        # Backbone output is expected to be (B, T_fut, num_classes); averaging
        # over the T_fut axis gives (B, num_classes) logits.
        per_step = self.base(x)
        return torch.mean(per_step, dim=1)
def set_seed(seed: int):
    """Seed Python's ``random``, NumPy, and torch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
def train_epoch(model, loader, optimizer, device, class_weight=None):
    """Run one optimisation pass over ``loader``.

    Each batch is unpacked as ``(inputs_by_modality, labels, event_type, extra)``;
    only the first two are used. The loss is (optionally class-weighted)
    cross-entropy, and gradients are clipped to a global norm of 1.0 before
    each step.

    Returns the mean of per-batch losses weighted by batch label count, or 0.0
    when the loader yields nothing.
    """
    model.train()
    loss_sum = 0.0
    seen = 0
    for inputs, labels, _event, _extra in loader:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        labels = labels.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), labels, weight=class_weight)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        batch_n = labels.numel()
        loss_sum += loss.item() * batch_n
        seen += batch_n
    return loss_sum / seen if seen else 0.0
def _class_label(names, idx):
    """Return ``names[idx]``, or ``"class_<idx>"`` when ``names`` has no such entry."""
    try:
        return names[idx]
    except (KeyError, IndexError):
        return f"class_{idx}"


@torch.no_grad()
def evaluate(model, loader, device, num_classes=2, class_names=None):
    """Return overall + per-event-stratified macro-F1, accuracy and confusion.

    Args:
        model: classifier taking a ``{modality: tensor}`` dict and returning
            logits whose last axis is the class axis.
        loader: iterable of ``(x, y, event_type, _)`` batches; ``event_type``
            carries a per-sample event index (expected 0..3).
        device: device the inputs are moved to.
        num_classes: number of label classes.
        class_names: optional index -> name mapping (dict or sequence). When
            omitted, a table is guessed from ``num_classes``: 2 -> binary,
            3 -> three-class, ``len(VERB_LIST)`` -> verbs, else objects.
            Indices without a name are reported as ``"class_<i>"``.

    Returns:
        dict keyed by stratum name (``EVENT_NAMES[e]`` for events 0..3, plus
        ``"overall"``). Each entry holds ``n``, ``accuracy``, ``macro_f1``,
        ``f1_per_class`` and ``confusion`` (rows = truth, cols = prediction).

    ``macro_f1`` averages only over classes that occur in the stratum's ground
    truth or predictions (scikit-learn's default label set). A class absent
    from both has undefined F1; counting it as 0 would cap a perfectly
    predicted single-class stratum at 1/num_classes. ``f1_per_class`` still
    lists every class, with 0.0 for absent ones.
    """
    if class_names is not None:
        names = class_names
    elif num_classes == 2:
        names = CLASS_NAMES_BINARY
    elif num_classes == 3:
        names = CLASS_NAMES_THREE
    elif num_classes == len(VERB_LIST):
        names = dict(enumerate(VERB_LIST))
    else:
        names = dict(enumerate(OBJECT_TOP_LIST))

    model.eval()
    overall = 4  # strata 0..3 are events, index 4 is the overall stratum
    cm = np.zeros((overall + 1, num_classes, num_classes), dtype=np.int64)
    for x, y, et, _ in loader:
        x = {m: v.to(device) for m, v in x.items()}
        pred = model(x).argmax(dim=-1).cpu().numpy().reshape(-1).astype(np.int64)
        y_np = y.cpu().numpy().reshape(-1).astype(np.int64)
        et_np = et.cpu().numpy().reshape(-1).astype(np.int64)
        # Unbuffered scatter-add: repeated (event, truth, pred) triples all count.
        np.add.at(cm, (et_np, y_np, pred), 1)
        np.add.at(cm[overall], (y_np, pred), 1)

    out = {}
    for e in range(overall + 1):
        m = cm[e]
        n = int(m.sum())
        tp = np.diag(m).astype(np.float64)
        pred_tot = m.sum(axis=0)  # tp + fp per class
        true_tot = m.sum(axis=1)  # tp + fn per class
        prec = tp / np.maximum(pred_tot, 1)
        rec = tp / np.maximum(true_tot, 1)
        f1 = 2 * prec * rec / np.maximum(prec + rec, 1e-9)
        f1s = [float(v) for v in f1]
        present = (true_tot + pred_tot) > 0
        macro_f1 = float(np.mean(f1[present])) if present.any() else 0.0
        acc = float(np.trace(m)) / max(n, 1)
        # Unknown event ids get their own key instead of colliding with "overall".
        name = "overall" if e == overall else EVENT_NAMES.get(e, f"event_{e}")
        out[name] = {
            "n": n,
            "accuracy": acc,
            "macro_f1": macro_f1,
            "f1_per_class": {_class_label(names, c): f1s[c] for c in range(num_classes)},
            "confusion": m.tolist(),
        }
    return out
def _split_csv(text):
    """Split a comma-separated CLI value into stripped, non-empty tokens.

    Returns ``None`` when ``text`` is ``None``/empty or yields no tokens, so
    "not given" and "given but blank" are treated alike.
    """
    if not text:
        return None
    tokens = [tok.strip() for tok in text.split(",")]
    tokens = [tok for tok in tokens if tok]
    return tokens or None


def _class_names_for(label_mode, num_classes):
    """Map class index -> display name, choosing the table from ``label_mode``.

    Selecting by label mode, rather than guessing from ``num_classes``, keeps
    verb and object vocabularies of equal length from being confused. Indices
    missing from the table are named ``"class_<i>"``.
    """
    table = {
        "binary": CLASS_NAMES_BINARY,
        "three_class": CLASS_NAMES_THREE,
        "verb": VERB_LIST,
        "object": OBJECT_TOP_LIST,
    }[label_mode]
    names = {}
    for idx in range(num_classes):
        try:
            names[idx] = table[idx]
        except (KeyError, IndexError):
            names[idx] = f"class_{idx}"
    return names


def _stratum_f1(ev, name):
    """Macro-F1 of stratum ``name`` in an ``evaluate`` result, or NaN if absent."""
    return ev.get(name, {}).get("macro_f1", float("nan"))


def main():
    """CLI entry point: train a grasp-state classifier and write its results.

    Writes ``model_best.pt`` (CPU state dict from the best epoch) and
    ``results.json`` (metrics of the best epoch plus run metadata) into
    ``--output_dir``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True, choices=["daf", "futr", "deepconvlstm"])
    ap.add_argument("--input_modalities", required=True,
                    help="comma-separated, e.g. 'emg,imu,mocap' or 'emg,imu,mocap,pressure'")
    ap.add_argument("--t_obs", type=float, default=1.0)
    ap.add_argument("--t_fut", type=float, default=0.5)
    ap.add_argument("--anchor_stride", type=float, default=0.25)
    ap.add_argument("--per_class_max", type=int, default=15000,
                    help="Cap each class to this many anchors in train (for balance).")
    ap.add_argument("--epochs", type=int, default=30)
    ap.add_argument("--batch_size", type=int, default=64)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--weight_decay", type=float, default=1e-4)
    ap.add_argument("--d_model", type=int, default=128)
    ap.add_argument("--dropout", type=float, default=0.1)
    ap.add_argument("--num_workers", type=int, default=2)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--patience", type=int, default=6)
    ap.add_argument("--no_class_weight", action="store_true",
                    help="Skip class-weighted CE; rely on per_class_max balancing.")
    ap.add_argument("--label_mode", default="binary", choices=["binary", "three_class", "verb", "object"])
    ap.add_argument("--sustained_threshold_sec", type=float, default=0.3,
                    help="(3-class only) min contiguous contact run for SustainedGrasp class.")
    ap.add_argument("--require_lift_for_sustained", action="store_true",
                    help="(3-class only) Class 2 also requires verb ∈ LIFT_VERBS or hand_type=both.")
    ap.add_argument("--train_vols", default=None,
                    help="comma-separated volunteer IDs to override the default TRAIN split (for CV).")
    ap.add_argument("--test_vols", default=None,
                    help="comma-separated volunteer IDs to override the default TEST split (for CV).")
    ap.add_argument("--output_dir", required=True)
    args = ap.parse_args()

    # Strip whitespace so "emg, imu" is accepted; reject an empty list up front.
    inputs = _split_csv(args.input_modalities)
    if not inputs:
        ap.error("--input_modalities must name at least one modality")

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device={device} seed={args.seed} model={args.model} "
          f"inputs={inputs} t_obs={args.t_obs} t_fut={args.t_fut}", flush=True)

    tr_v = _split_csv(args.train_vols)
    te_v = _split_csv(args.test_vols)
    train_ds, test_ds = build_grasp_train_test(
        input_modalities=inputs,
        t_obs_sec=args.t_obs, t_fut_sec=args.t_fut,
        anchor_stride_sec=args.anchor_stride,
        per_class_max=args.per_class_max,
        label_mode=args.label_mode,
        sustained_threshold_sec=args.sustained_threshold_sec,
        require_lift_for_sustained=args.require_lift_for_sustained,
        rng_seed=args.seed,
        train_vols=tr_v, test_vols=te_v,
    )
    num_classes = train_ds.num_classes
    class_names = _class_names_for(args.label_mode, num_classes)
    print(f"train={len(train_ds)} test={len(test_ds)} num_classes={num_classes}", flush=True)

    tr_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                           num_workers=args.num_workers, collate_fn=collate_grasp_state,
                           drop_last=False)
    te_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False,
                           num_workers=args.num_workers, collate_fn=collate_grasp_state)

    model = GraspStateClassifier(
        args.model, train_ds.modality_dims,
        t_obs=train_ds.T_obs, t_fut=train_ds.T_fut,
        d_model=args.d_model, dropout=args.dropout,
        num_classes=num_classes,
    ).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"params={n_params:,}", flush=True)

    # Class weight = inverse class frequency in train (classes with zero
    # samples are treated as having one, to avoid division by zero).
    if args.no_class_weight:
        cw = None
    else:
        ny = np.zeros(num_classes, dtype=np.int64)
        for it in train_ds._items:
            ny[it["label"]] += 1
        cw = torch.tensor(ny.sum() / (num_classes * np.maximum(ny, 1)),
                          dtype=torch.float32).to(device)
        print(f"class_weight={cw.tolist()}", flush=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=args.lr * 0.05)
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): best-epoch selection and early stopping are driven by the
    # TEST split, so best_macro_f1 is optimistically biased. A held-out
    # validation split would give an unbiased estimate.
    best_f1 = -1.0
    best_epoch, best_eval = 0, None
    patience_counter = 0
    for ep in range(1, args.epochs + 1):
        t0 = time.time()
        tr_loss = train_epoch(model, tr_loader, optimizer, device, class_weight=cw)
        ev = evaluate(model, te_loader, device, num_classes=num_classes,
                      class_names=class_names)
        sched.step()
        f1 = ev["overall"]["macro_f1"]
        # Missing event strata print as "nan" instead of raising KeyError.
        print(f" E{ep:2d} | tr_ce {tr_loss:.4f} | overall_f1 {f1:.4f} acc {ev['overall']['accuracy']:.4f} "
              f"| pre_f1 {_stratum_f1(ev, 'pre-contact'):.3f} "
              f"steady {_stratum_f1(ev, 'steady-grip'):.3f} "
              f"release {_stratum_f1(ev, 'release'):.3f} "
              f"non {_stratum_f1(ev, 'non-contact'):.3f} | {time.time()-t0:.1f}s", flush=True)
        if f1 > best_f1:
            best_f1 = f1
            best_epoch = ep
            best_eval = ev
            torch.save({k: v.cpu() for k, v in model.state_dict().items()},
                       out_dir / "model_best.pt")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f" early stop at epoch {ep} (best {best_epoch})", flush=True)
                break

    out = {
        "method": args.model,
        "input_modalities": inputs,
        "seed": args.seed, "n_params": n_params,
        "T_obs": train_ds.T_obs, "T_fut": train_ds.T_fut,
        # Index order of the confusion matrices in "eval".
        "num_classes": int(num_classes),
        "class_names": [class_names[i] for i in range(num_classes)],
        "best_epoch": int(best_epoch),
        "best_macro_f1": float(best_f1),
        "eval": best_eval,
        "args": vars(args),
    }
    with open(out_dir / "results.json", "w") as f:
        json.dump(out, f, indent=2)
    print(f"\n[done] best macro_F1={best_f1:.4f} at epoch {best_epoch}", flush=True)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()