#!/usr/bin/env python3 """ Experiment D: EMG -> hand pose regression. Predict right-hand finger pose (5 fingertip positions relative to the wrist) from 8-channel surface EMG. 15-dim per-timestep regression target. This directly supports the paper's stated prosthetics use case: "The paired EMG and finger-level hand kinematics support EMG-to-hand-pose decoding for myoelectric prostheses." """ import os import sys import json import time import random import argparse import numpy as np import pandas as pd import torch import torch.nn as nn from torch.utils.data import Dataset, DataLoader from torch.nn.utils.rnn import pad_sequence from scipy.stats import pearsonr sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from data.dataset import ( DATASET_DIR, MODALITY_FILES, TRAIN_VOLS, TEST_VOLS, load_modality_array, SCENE_LABELS, ) from tasks.train_exp_grip import GripRegressor, set_seed, masked_huber # Right-hand fingertip markers (relative to wrist) WRIST = 'RightHand' FINGERTIPS = ['RightHandThumb3', 'RightHandIndex3', 'RightHandMiddle3', 'RightHandRing3', 'RightHandPinky3'] def load_hand_pose_target(tsv_path): """Load MoCap TSV and return wrist-relative fingertip positions as (T, 15) array: [5 tips × 3 XYZ], in the raw coordinate frame.""" try: df = pd.read_csv(tsv_path, sep='\t') except Exception: return None cols = set(df.columns) needed = [f"{WRIST}_{ax}" for ax in 'XYZ'] for tip in FINGERTIPS: needed.extend([f"{tip}_{ax}" for ax in 'XYZ']) if not all(c in cols for c in needed): return None wrist = df[[f"{WRIST}_{ax}" for ax in 'XYZ']].values.astype(np.float32) tips = [] for tip in FINGERTIPS: t = df[[f"{tip}_{ax}" for ax in 'XYZ']].values.astype(np.float32) tips.append(t - wrist) # wrist-relative pose = np.concatenate(tips, axis=1) # (T, 15) return pose class EMG2PoseDataset(Dataset): """Per-frame regression: EMG -> (5 wrist-relative fingertip XYZ = 15d).""" def __init__(self, volunteers, downsample=5, stats=None, target_stats=None): self.downsample = downsample self.data = [] self.targets = [] self.sample_info = [] for vol in volunteers: vol_dir = os.path.join(DATASET_DIR, vol) if not os.path.isdir(vol_dir): continue for scenario in sorted(os.listdir(vol_dir)): scenario_dir = os.path.join(vol_dir, scenario) if not os.path.isdir(scenario_dir) or scenario not in SCENE_LABELS: continue emg_fp = os.path.join(scenario_dir, MODALITY_FILES['emg']) mocap_fp = os.path.join(scenario_dir, f"aligned_{vol}{scenario}_s_Q.tsv") if not (os.path.exists(emg_fp) and os.path.exists(mocap_fp)): continue emg = load_modality_array(emg_fp, 'emg') if emg is None: continue pose = load_hand_pose_target(mocap_fp) if pose is None: continue T_min = min(emg.shape[0], pose.shape[0]) emg = emg[:T_min:downsample] pose = pose[:T_min:downsample] if emg.shape[0] < 10: continue self.data.append(emg.astype(np.float32)) self.targets.append(pose.astype(np.float32)) self.sample_info.append(f"{vol}/{scenario}") if len(self.data) == 0: raise RuntimeError("No data loaded.") print(f" Loaded {len(self.data)} recordings, avg T " f"{np.mean([d.shape[0] for d in self.data]):.0f}") # Normalize EMG if stats is not None: self.mean, self.std = stats else: all_ = np.concatenate(self.data, axis=0).astype(np.float64) self.mean = all_.mean(axis=0, keepdims=True) self.std = all_.std(axis=0, keepdims=True) self.std[self.std < 1e-8] = 1.0 for i in range(len(self.data)): self.data[i] = ((self.data[i].astype(np.float64) - self.mean) / self.std).astype(np.float32) self.data[i] = np.nan_to_num(self.data[i], nan=0.0, posinf=0.0, neginf=0.0) # Normalize target (mm) if target_stats is not None: self.t_mean, self.t_std = target_stats else: all_t = np.concatenate(self.targets, axis=0).astype(np.float64) self.t_mean = all_t.mean(axis=0, keepdims=True) self.t_std = all_t.std(axis=0, keepdims=True) self.t_std[self.t_std < 1e-8] = 1.0 for i in range(len(self.targets)): self.targets[i] = ((self.targets[i].astype(np.float64) - self.t_mean) / self.t_std).astype(np.float32) self.targets[i] = np.nan_to_num(self.targets[i], nan=0.0, posinf=0.0, neginf=0.0) def get_stats(self): return (self.mean, self.std) def get_target_stats(self): return (self.t_mean, self.t_std) @property def feat_dim(self): return 8 # EMG always 8-channel @property def target_dim(self): return 15 def __len__(self): return len(self.data) def __getitem__(self, idx): return (torch.from_numpy(self.data[idx]), torch.from_numpy(self.targets[idx])) def collate_fn(batch): seqs, targs = zip(*batch) lens = torch.LongTensor([s.shape[0] for s in seqs]) padded = pad_sequence(seqs, batch_first=True, padding_value=0.0) padded_t = pad_sequence(targs, batch_first=True, padding_value=0.0) max_len = padded.shape[1] mask = torch.arange(max_len).unsqueeze(0) < lens.unsqueeze(1) return padded, padded_t, mask, lens @torch.no_grad() def evaluate(model, loader, device, tmean, tstd): model.eval() total_loss = 0.0 n_frames = 0 all_preds, all_trues = [], [] for x, y, mask, _ in loader: x, y, mask = x.to(device), y.to(device), mask.to(device) pred = model(x, mask) loss = masked_huber(pred, y, mask, delta=1.0) nf = mask.sum().item() total_loss += loss.item() * nf n_frames += nf pred_np = pred.cpu().numpy() * tstd + tmean true_np = y.cpu().numpy() * tstd + tmean m_np = mask.cpu().numpy() for b in range(pred_np.shape[0]): valid = m_np[b] all_preds.append(pred_np[b, valid]) all_trues.append(true_np[b, valid]) P = np.concatenate(all_preds, axis=0) # (total_T, 15) T = np.concatenate(all_trues, axis=0) # Per-coord metrics mae = float(np.mean(np.abs(P - T))) rs = [] for d in range(15): if np.std(P[:, d]) < 1e-6 or np.std(T[:, d]) < 1e-6: rs.append(0.0) else: rs.append(float(pearsonr(P[:, d], T[:, d])[0])) r_mean = float(np.mean(rs)) # Per-finger MAE (group by 5 fingertips) finger_mae = [] for i in range(5): finger_mae.append(float(np.mean(np.abs(P[:, 3*i:3*i+3] - T[:, 3*i:3*i+3])))) # Overall 3D Euclidean error per fingertip tip_eucl = [] for i in range(5): d = np.linalg.norm(P[:, 3*i:3*i+3] - T[:, 3*i:3*i+3], axis=1) tip_eucl.append(float(np.mean(d))) return { 'loss': total_loss / max(n_frames, 1), 'mae': mae, 'pearson_r_mean': r_mean, 'pearson_r_per_coord': rs, 'finger_mae': dict(zip(FINGERTIPS, finger_mae)), 'finger_eucl_mm': dict(zip(FINGERTIPS, tip_eucl)), 'avg_eucl_mm': float(np.mean(tip_eucl)), } def run_experiment(args): set_seed(args.seed) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Device: {device}") print(f"Backbone: {args.backbone} | seed: {args.seed}") print("Loading train...") train_ds = EMG2PoseDataset(TRAIN_VOLS, downsample=args.downsample) stats = train_ds.get_stats() tstats = train_ds.get_target_stats() print(f" target mean: {tstats[0].flatten()[:3]} ... std: {tstats[1].flatten()[:3]} ...") print("Loading test...") test_ds = EMG2PoseDataset(TEST_VOLS, downsample=args.downsample, stats=stats, target_stats=tstats) train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn, num_workers=0) test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0) model = GripRegressor(args.backbone, 8, hidden_dim=args.hidden_dim, output_dim=15, dropout=args.dropout).to(device) n_params = sum(p.numel() for p in model.parameters()) print(f"Params: {n_params:,}") optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode='min', factor=0.5, patience=7, min_lr=1e-6, ) exp_name = f"pose_{args.backbone}_emg_seed{args.seed}" if args.tag: exp_name += f"_{args.tag}" out_dir = os.path.join(args.output_dir, exp_name) os.makedirs(out_dir, exist_ok=True) best_eucl = float('inf') best_metrics = None best_state = None best_epoch = 0 patience_counter = 0 for epoch in range(1, args.epochs + 1): t0 = time.time() model.train() tr_loss = 0.0 n = 0 for x, y, mask, _ in train_loader: x, y, mask = x.to(device), y.to(device), mask.to(device) optimizer.zero_grad() pred = model(x, mask) loss = masked_huber(pred, y, mask, delta=1.0) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() nf = mask.sum().item() tr_loss += loss.item() * nf n += nf tr_loss /= max(n, 1) m = evaluate(model, test_loader, device, tstats[0], tstats[1]) scheduler.step(m['loss']) print(f" E{epoch:3d} | tr {tr_loss:.4f} | te_loss {m['loss']:.4f} " f"mae {m['mae']:.2f}mm eucl {m['avg_eucl_mm']:.2f}mm " f"r {m['pearson_r_mean']:.3f} | {time.time()-t0:.1f}s") if m['avg_eucl_mm'] < best_eucl: best_eucl = m['avg_eucl_mm'] best_metrics = m best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()} best_epoch = epoch patience_counter = 0 else: patience_counter += 1 if patience_counter >= args.patience: print(f" Early stop (best epoch {best_epoch})") break if best_state is not None: torch.save(best_state, os.path.join(out_dir, 'model_best.pt')) results = { 'experiment': exp_name, 'backbone': args.backbone, 'seed': args.seed, 'best_epoch': best_epoch, 'best_test_metrics': best_metrics, 'train_size': len(train_ds), 'test_size': len(test_ds), 'target_mean': tstats[0].flatten().tolist(), 'target_std': tstats[1].flatten().tolist(), 'args': vars(args), } with open(os.path.join(out_dir, 'results.json'), 'w') as f: json.dump(results, f, indent=2) print(f"Saved: {out_dir}/results.json") return results def main(): p = argparse.ArgumentParser() p.add_argument('--backbone', type=str, default='transformer', choices=['transformer', 'lstm', 'cnn']) p.add_argument('--epochs', type=int, default=60) p.add_argument('--batch_size', type=int, default=8) p.add_argument('--lr', type=float, default=1e-3) p.add_argument('--weight_decay', type=float, default=1e-4) p.add_argument('--hidden_dim', type=int, default=128) p.add_argument('--dropout', type=float, default=0.2) p.add_argument('--downsample', type=int, default=5) p.add_argument('--patience', type=int, default=12) p.add_argument('--seed', type=int, default=42) p.add_argument('--output_dir', type=str, required=True) p.add_argument('--tag', type=str, default='') args = p.parse_args() run_experiment(args) if __name__ == '__main__': main()