"""
Experiment D: EMG -> hand pose regression.

Predict right-hand finger pose (5 fingertip positions relative to the wrist)
from 8-channel surface EMG. 15-dim per-timestep regression target.

This directly supports the paper's stated prosthetics use case:
"The paired EMG and finger-level hand kinematics support EMG-to-hand-pose
decoding for myoelectric prostheses."
"""
|
|
| import os |
| import sys |
| import json |
| import time |
| import random |
| import argparse |
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import Dataset, DataLoader |
| from torch.nn.utils.rnn import pad_sequence |
| from scipy.stats import pearsonr |
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| from data.dataset import ( |
| DATASET_DIR, MODALITY_FILES, TRAIN_VOLS, TEST_VOLS, |
| load_modality_array, SCENE_LABELS, |
| ) |
| from tasks.train_exp_grip import GripRegressor, set_seed, masked_huber |
|
|
| |
WRIST = 'RightHand'
FINGERTIPS = ['RightHandThumb3', 'RightHandIndex3', 'RightHandMiddle3',
              'RightHandRing3', 'RightHandPinky3']


def load_hand_pose_target(tsv_path):
    """Read an aligned MoCap TSV and build the per-frame hand-pose target.

    Returns a (T, 15) array of wrist-relative fingertip positions
    ([thumb, index, middle, ring, pinky] x XYZ, raw coordinate frame),
    or None when the file cannot be parsed or any marker column is missing.
    """
    try:
        frame = pd.read_csv(tsv_path, sep='\t')
    except Exception:
        # Best-effort loader: unreadable/corrupt files are simply skipped
        # by the caller, so any parse failure maps to None.
        return None

    def xyz_cols(marker):
        # Column names follow the "<Marker>_<axis>" convention of the TSVs.
        return [f"{marker}_{ax}" for ax in 'XYZ']

    required = xyz_cols(WRIST)
    for tip in FINGERTIPS:
        required += xyz_cols(tip)
    if not set(required).issubset(frame.columns):
        return None

    wrist_xyz = frame[xyz_cols(WRIST)].values.astype(np.float32)
    # Subtracting the wrist makes the target invariant to arm translation.
    relative = [frame[xyz_cols(tip)].values.astype(np.float32) - wrist_xyz
                for tip in FINGERTIPS]
    return np.concatenate(relative, axis=1)
|
|
|
|
class EMG2PoseDataset(Dataset):
    """Sequence dataset mapping 8-channel EMG to 15-d wrist-relative pose.

    One item is one full recording: (emg (T, 8), pose (T, 15)), both
    z-scored per channel/coordinate. Normalization statistics are computed
    from the loaded recordings unless `stats` / `target_stats` are given,
    so the test split can reuse the train-split statistics.
    """

    def __init__(self, volunteers, downsample=5, stats=None, target_stats=None):
        self.downsample = downsample
        self.data = []          # list of (T, 8) float32 EMG arrays
        self.targets = []       # list of (T, 15) float32 pose arrays
        self.sample_info = []   # "volunteer/scenario" tag per recording
        for vol in volunteers:
            self._collect_volunteer(vol)
        if not self.data:
            raise RuntimeError("No data loaded.")
        print(f" Loaded {len(self.data)} recordings, avg T "
              f"{np.mean([d.shape[0] for d in self.data]):.0f}")
        # Normalize inputs and targets in place; keep the stats used so
        # callers (the test split, de-normalization at eval) can reuse them.
        self.mean, self.std = self._normalize(self.data, stats)
        self.t_mean, self.t_std = self._normalize(self.targets, target_stats)

    def _collect_volunteer(self, vol):
        """Append every usable (EMG, pose) recording of one volunteer."""
        vol_dir = os.path.join(DATASET_DIR, vol)
        if not os.path.isdir(vol_dir):
            return
        for scenario in sorted(os.listdir(vol_dir)):
            scenario_dir = os.path.join(vol_dir, scenario)
            if not os.path.isdir(scenario_dir) or scenario not in SCENE_LABELS:
                continue
            emg_fp = os.path.join(scenario_dir, MODALITY_FILES['emg'])
            mocap_fp = os.path.join(scenario_dir,
                                    f"aligned_{vol}{scenario}_s_Q.tsv")
            if not (os.path.exists(emg_fp) and os.path.exists(mocap_fp)):
                continue
            emg = load_modality_array(emg_fp, 'emg')
            if emg is None:
                continue
            pose = load_hand_pose_target(mocap_fp)
            if pose is None:
                continue
            # Trim both streams to the shorter one, then subsample frames.
            n_frames = min(emg.shape[0], pose.shape[0])
            emg = emg[:n_frames:self.downsample]
            pose = pose[:n_frames:self.downsample]
            if emg.shape[0] < 10:
                continue  # too short to be a useful training sequence
            self.data.append(emg.astype(np.float32))
            self.targets.append(pose.astype(np.float32))
            self.sample_info.append(f"{vol}/{scenario}")

    @staticmethod
    def _normalize(arrays, given):
        """Z-score `arrays` in place; return the (mean, std) that was used.

        When `given` is None the stats are computed over the concatenation
        of all arrays (float64 for accuracy); near-zero stds are clamped
        to 1 to avoid division blow-ups, and any remaining NaN/inf frames
        are zeroed out.
        """
        if given is not None:
            mean, std = given
        else:
            stacked = np.concatenate(arrays, axis=0).astype(np.float64)
            mean = stacked.mean(axis=0, keepdims=True)
            std = stacked.std(axis=0, keepdims=True)
            std[std < 1e-8] = 1.0
        for i, arr in enumerate(arrays):
            z = ((arr.astype(np.float64) - mean) / std).astype(np.float32)
            arrays[i] = np.nan_to_num(z, nan=0.0, posinf=0.0, neginf=0.0)
        return mean, std

    def get_stats(self):
        return (self.mean, self.std)

    def get_target_stats(self):
        return (self.t_mean, self.t_std)

    @property
    def feat_dim(self):
        # 8 surface-EMG channels.
        return 8

    @property
    def target_dim(self):
        # 5 fingertips x XYZ.
        return 15

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return (torch.from_numpy(self.data[idx]),
                torch.from_numpy(self.targets[idx]))
|
|
|
|
def collate_fn(batch):
    """Pad variable-length (emg, pose) pairs into one batch.

    Returns (x, y, mask, lengths): x (B, Tmax, 8) and y (B, Tmax, 15)
    zero-padded on the time axis, a boolean mask (B, Tmax) marking real
    (unpadded) frames, and the original per-item lengths.
    """
    emg_seqs = [item[0] for item in batch]
    pose_seqs = [item[1] for item in batch]
    lengths = torch.LongTensor([seq.shape[0] for seq in emg_seqs])
    x = pad_sequence(emg_seqs, batch_first=True, padding_value=0.0)
    y = pad_sequence(pose_seqs, batch_first=True, padding_value=0.0)
    # mask[b, t] is True iff frame t exists in recording b.
    frame_idx = torch.arange(x.shape[1])
    mask = frame_idx.unsqueeze(0) < lengths.unsqueeze(1)
    return x, y, mask, lengths
|
|
|
|
@torch.no_grad()
def evaluate(model, loader, device, tmean, tstd):
    """Run `model` over `loader` and compute de-normalized test metrics.

    `tmean`/`tstd` undo the target z-scoring, so errors are reported in
    the raw MoCap units (keys say mm — assumed, TODO confirm). Returns a
    dict with the frame-averaged masked Huber loss, overall MAE, mean and
    per-coordinate Pearson r, and per-fingertip MAE / Euclidean distance.
    """
    model.eval()
    loss_sum = 0.0
    frame_count = 0
    preds, trues = [], []
    for x, y, mask, _ in loader:
        x, y, mask = x.to(device), y.to(device), mask.to(device)
        out = model(x, mask)
        batch_loss = masked_huber(out, y, mask, delta=1.0)
        # Weight each batch by its valid-frame count so the final loss is
        # a true per-frame average across the whole loader.
        valid = mask.sum().item()
        loss_sum += batch_loss.item() * valid
        frame_count += valid
        out_np = out.cpu().numpy() * tstd + tmean
        y_np = y.cpu().numpy() * tstd + tmean
        keep = mask.cpu().numpy()
        for b in range(out_np.shape[0]):
            preds.append(out_np[b, keep[b]])
            trues.append(y_np[b, keep[b]])
    P = np.concatenate(preds, axis=0)
    T = np.concatenate(trues, axis=0)

    mae = float(np.mean(np.abs(P - T)))

    # Per-coordinate correlation; a (near-)constant coordinate gets r = 0
    # instead of the NaN pearsonr would produce.
    rs = []
    for d in range(15):
        degenerate = np.std(P[:, d]) < 1e-6 or np.std(T[:, d]) < 1e-6
        rs.append(0.0 if degenerate
                  else float(pearsonr(P[:, d], T[:, d])[0]))

    finger_mae, tip_eucl = [], []
    for i in range(5):
        dP = P[:, 3*i:3*i+3]
        dT = T[:, 3*i:3*i+3]
        finger_mae.append(float(np.mean(np.abs(dP - dT))))
        tip_eucl.append(float(np.mean(np.linalg.norm(dP - dT, axis=1))))

    return {
        'loss': loss_sum / max(frame_count, 1),
        'mae': mae,
        'pearson_r_mean': float(np.mean(rs)),
        'pearson_r_per_coord': rs,
        'finger_mae': dict(zip(FINGERTIPS, finger_mae)),
        'finger_eucl_mm': dict(zip(FINGERTIPS, tip_eucl)),
        'avg_eucl_mm': float(np.mean(tip_eucl)),
    }
|
|
|
|
def run_experiment(args):
    """Train an EMG -> hand-pose regressor and save the best model + metrics.

    Loads the train/test splits (the test split reuses train normalization
    statistics), trains with a masked Huber loss, keeps the checkpoint with
    the lowest mean fingertip Euclidean error, and writes `model_best.pt`
    and `results.json` under `<args.output_dir>/<exp_name>/`.

    NOTE(review): the LR scheduler and early stopping are both driven by
    TEST-set metrics (there is no validation split), so model selection
    leaks the test set — confirm this is intentional for this benchmark.
    """
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    print(f"Backbone: {args.backbone} | seed: {args.seed}")

    print("Loading train...")
    train_ds = EMG2PoseDataset(TRAIN_VOLS, downsample=args.downsample)
    # Input/target z-scoring stats come from the train split only; the
    # test split below reuses them to avoid leaking test statistics.
    stats = train_ds.get_stats()
    tstats = train_ds.get_target_stats()
    print(f" target mean: {tstats[0].flatten()[:3]} ... std: {tstats[1].flatten()[:3]} ...")

    print("Loading test...")
    test_ds = EMG2PoseDataset(TEST_VOLS, downsample=args.downsample,
                              stats=stats, target_stats=tstats)

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              collate_fn=collate_fn, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False,
                             collate_fn=collate_fn, num_workers=0)

    # 8 EMG channels in, 15 pose coordinates out (5 fingertips x XYZ).
    model = GripRegressor(args.backbone, 8, hidden_dim=args.hidden_dim,
                          output_dim=15, dropout=args.dropout).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Params: {n_params:,}")

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=7, min_lr=1e-6,
    )

    exp_name = f"pose_{args.backbone}_emg_seed{args.seed}"
    if args.tag:
        exp_name += f"_{args.tag}"
    out_dir = os.path.join(args.output_dir, exp_name)
    os.makedirs(out_dir, exist_ok=True)

    # Best-checkpoint bookkeeping for early stopping.
    best_eucl = float('inf')
    best_metrics = None
    best_state = None
    best_epoch = 0
    patience_counter = 0

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        model.train()
        tr_loss = 0.0
        n = 0
        for x, y, mask, _ in train_loader:
            x, y, mask = x.to(device), y.to(device), mask.to(device)
            optimizer.zero_grad()
            pred = model(x, mask)
            loss = masked_huber(pred, y, mask, delta=1.0)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            # Weight each batch by its valid (unpadded) frame count so
            # tr_loss ends up a true per-frame average.
            nf = mask.sum().item()
            tr_loss += loss.item() * nf
            n += nf
        tr_loss /= max(n, 1)

        m = evaluate(model, test_loader, device, tstats[0], tstats[1])
        # See NOTE(review) in the docstring: scheduled on test loss.
        scheduler.step(m['loss'])
        print(f" E{epoch:3d} | tr {tr_loss:.4f} | te_loss {m['loss']:.4f} "
              f"mae {m['mae']:.2f}mm eucl {m['avg_eucl_mm']:.2f}mm "
              f"r {m['pearson_r_mean']:.3f} | {time.time()-t0:.1f}s")
        # Selection criterion: mean fingertip Euclidean error. Weights are
        # copied to CPU so the checkpoint survives later training steps.
        if m['avg_eucl_mm'] < best_eucl:
            best_eucl = m['avg_eucl_mm']
            best_metrics = m
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            best_epoch = epoch
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f" Early stop (best epoch {best_epoch})")
                break

    if best_state is not None:
        torch.save(best_state, os.path.join(out_dir, 'model_best.pt'))

    # Persist everything needed to reproduce/de-normalize the run.
    results = {
        'experiment': exp_name,
        'backbone': args.backbone,
        'seed': args.seed,
        'best_epoch': best_epoch,
        'best_test_metrics': best_metrics,
        'train_size': len(train_ds),
        'test_size': len(test_ds),
        'target_mean': tstats[0].flatten().tolist(),
        'target_std': tstats[1].flatten().tolist(),
        'args': vars(args),
    }
    with open(os.path.join(out_dir, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2)
    print(f"Saved: {out_dir}/results.json")
    return results
|
|
|
|
def main():
    """CLI entry point: parse hyperparameters, then run the experiment."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add('--backbone', type=str, default='transformer',
        choices=['transformer', 'lstm', 'cnn'])
    add('--epochs', type=int, default=60)
    add('--batch_size', type=int, default=8)
    add('--lr', type=float, default=1e-3)
    add('--weight_decay', type=float, default=1e-4)
    add('--hidden_dim', type=int, default=128)
    add('--dropout', type=float, default=0.2)
    add('--downsample', type=int, default=5)
    add('--patience', type=int, default=12)
    add('--seed', type=int, default=42)
    add('--output_dir', type=str, required=True)
    add('--tag', type=str, default='')
    run_experiment(parser.parse_args())
|
|
|
|
# Script entry point: run the experiment only when executed directly.
if __name__ == '__main__':
    main()
|
|