#!/usr/bin/env python3
"""
Experiment 3: Grasp/Contact Event Detection
Use pressure as ground truth, predict contact from other modalities.
Binary classification per frame: contact vs non-contact for left and right hands.
"""

import os
import sys
import json
import time
import random
import argparse
import traceback

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import f1_score, precision_score, recall_score
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.dataset import (
    DATASET_DIR, MODALITY_FILES, SKIP_COLS, SKIP_COL_SUFFIXES,
    TRAIN_VOLS, VAL_VOLS, TEST_VOLS, load_modality_array, get_modality_filepath
)

# Contact threshold in grams, compared against the SUM of one hand's
# pressure columns on each frame.
PRESSURE_THRESHOLD = 5.0
# Window length / stride are counted in frames AFTER downsampling:
# 256 frames = 2.56 s at 100 Hz (downsample=1), 5.12 s at 50 Hz (downsample=2).
WINDOW_SIZE = 256
WINDOW_STRIDE = 128


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def load_modality(scenario_dir, modality, vol=None, scenario=None):
    """Load a single modality's features from CSV.

    When both ``vol`` and ``scenario`` are given, the file path is resolved
    with ``get_modality_filepath``; otherwise ``MODALITY_FILES[modality]`` is
    joined onto ``scenario_dir``. Parsing is delegated to
    ``load_modality_array``.
    """
    if vol and scenario:
        filepath = get_modality_filepath(scenario_dir, modality, vol, scenario)
    else:
        filepath = os.path.join(scenario_dir, MODALITY_FILES[modality])
    return load_modality_array(filepath, modality)


def generate_contact_labels(scenario_dir, n_frames):
    """Generate binary contact labels from pressure data.

    A hand is labelled "in contact" on a frame when the sum of its pressure
    columns (``R*(g)`` for right, ``L*(g)`` for left) exceeds
    ``PRESSURE_THRESHOLD``. Non-numeric or missing readings count as 0 g.

    Args:
        scenario_dir: Directory that contains the pressure CSV.
        n_frames: Length of the returned label sequence. The pressure
            sequence is truncated to it, or zero-padded (= "no contact").

    Returns:
        float32 array of shape (n_frames, 2); column 0 = right hand,
        column 1 = left hand.
    """
    pressure_path = os.path.join(scenario_dir, MODALITY_FILES['pressure'])
    df = pd.read_csv(pressure_path)

    # Right hand: R1(g) to R25(g), Left hand: L1(g) to L25(g)
    r_cols = [c for c in df.columns if c.startswith('R') and c.endswith('(g)')]
    l_cols = [c for c in df.columns if c.startswith('L') and c.endswith('(g)')]

    r_pressure = df[r_cols].apply(pd.to_numeric, errors='coerce').values
    l_pressure = df[l_cols].apply(pd.to_numeric, errors='coerce').values
    r_pressure = np.nan_to_num(r_pressure, nan=0.0)
    l_pressure = np.nan_to_num(l_pressure, nan=0.0)

    r_total = np.sum(r_pressure, axis=1)
    l_total = np.sum(l_pressure, axis=1)

    r_contact = (r_total > PRESSURE_THRESHOLD).astype(np.float32)
    l_contact = (l_total > PRESSURE_THRESHOLD).astype(np.float32)

    # Truncate or pad to match n_frames
    min_len = min(len(r_contact), n_frames)
    labels = np.zeros((n_frames, 2), dtype=np.float32)
    labels[:min_len, 0] = r_contact[:min_len]
    labels[:min_len, 1] = l_contact[:min_len]
    return labels  # (T, 2)


class ContactDataset(Dataset):
    """Sliding window dataset for contact detection.

    For every ``DATASET_DIR/<vol>/<scenario>`` directory whose
    ``alignment_metadata.json`` lists all requested input modalities plus
    ``pressure``:

    1. The input modalities are concatenated along the feature axis, after
       truncating them to the shortest one.
    2. The result is downsampled by ``downsample``.
    3. It is cut into windows of ``window_size`` frames every ``stride``
       frames. A trailing partial window is dropped.

    Per-frame labels come from :func:`generate_contact_labels`.

    Features are z-scored with ``stats`` when given (e.g. the training set's
    ``get_stats()``). Otherwise mean/std are computed over every loaded
    (downsampled) frame of this dataset.
    """

    def __init__(self, volunteers, input_modalities, window_size=WINDOW_SIZE,
                 stride=WINDOW_STRIDE, downsample=2, stats=None):
        self.windows = []  # (features, labels) pairs
        self.input_modalities = input_modalities
        self._feat_dim = None

        print(f" Loading contact data for {len(volunteers)} volunteers...")
        all_features = []
        required = set(input_modalities) | {'pressure'}

        for vol in volunteers:
            vol_dir = os.path.join(DATASET_DIR, vol)
            if not os.path.isdir(vol_dir):
                continue
            for scenario in sorted(os.listdir(vol_dir)):
                scenario_dir = os.path.join(vol_dir, scenario)
                if not os.path.isdir(scenario_dir):
                    continue
                meta_path = os.path.join(scenario_dir, 'alignment_metadata.json')
                if not os.path.exists(meta_path):
                    continue
                with open(meta_path) as f:
                    meta = json.load(f)
                available = set(meta['modalities'])
                if not required.issubset(available):
                    continue

                # Load input modalities and align them to the shortest one.
                parts = [load_modality(scenario_dir, mod, vol, scenario)
                         for mod in input_modalities]
                min_len = min(p.shape[0] for p in parts)
                features = np.concatenate([p[:min_len] for p in parts], axis=1)

                # Downsample (less aggressive for frame-level task)
                features = features[::downsample]

                # Labels are built at the original rate, then strided the same
                # way so frame i of `features` lines up with frame i of `labels`.
                labels = generate_contact_labels(scenario_dir, min_len)
                labels = labels[::downsample]

                if self._feat_dim is None:
                    self._feat_dim = features.shape[1]
                all_features.append(features)

                # Extract sliding windows
                T = features.shape[0]
                for start in range(0, T - window_size + 1, stride):
                    end = start + window_size
                    self.windows.append((
                        features[start:end],
                        labels[start:end],
                    ))

        # Compute normalization stats
        if stats is not None:
            self.mean, self.std = stats
        elif all_features:
            all_data = np.concatenate(all_features, axis=0)
            self.mean = np.mean(all_data, axis=0, keepdims=True).astype(np.float32)
            self.std = np.std(all_data, axis=0, keepdims=True).astype(np.float32)
            # Constant features would divide by ~0; leave them unscaled instead.
            self.std[self.std < 1e-8] = 1.0
        else:
            self.mean = np.zeros((1, self._feat_dim or 1), dtype=np.float32)
            self.std = np.ones((1, self._feat_dim or 1), dtype=np.float32)

        # Apply normalization. Cast to float32 so the tensors returned by
        # __getitem__ match float32 model weights, even if the loader
        # produced float64 features. copy=False makes this free when the
        # result is already float32.
        self.windows = [
            (((w[0] - self.mean) / self.std).astype(np.float32, copy=False), w[1])
            for w in self.windows
        ]

        # Count positive ratio
        all_labels = (np.concatenate([w[1] for w in self.windows], axis=0)
                      if self.windows else np.array([]))
        if len(all_labels) > 0:
            r_pos = all_labels[:, 0].mean()
            l_pos = all_labels[:, 1].mean()
            print(f" Windows: {len(self.windows)}, R_contact: {r_pos:.2%}, L_contact: {l_pos:.2%}")

    def get_stats(self):
        """Return the ``(mean, std)`` arrays used for normalization."""
        return (self.mean, self.std)

    @property
    def feat_dim(self):
        """Feature dimension of the first loaded scenario (None if nothing loaded)."""
        return self._feat_dim

    def __len__(self):
        return len(self.windows)

    def __getitem__(self, idx):
        """Return ``(features, labels)`` tensors of shape (window_size, C) and (window_size, 2)."""
        features, labels = self.windows[idx]
        return torch.from_numpy(features), torch.from_numpy(labels)


# ============================================================
# Models
# ============================================================

class TCN(nn.Module):
    """Temporal Convolutional Network for frame-level prediction.

    A stack of dilated Conv1d blocks with dilation ``2**i``. Padding is
    symmetric (non-causal), so the sequence length is preserved for odd
    ``kernel_size``. Input is (B, T, C); output logits are (B, T, 2).
    """

    def __init__(self, input_dim, hidden_dim=64, num_layers=4, kernel_size=5):
        super().__init__()
        layers = []
        in_ch = input_dim
        for i in range(num_layers):
            dilation = 2 ** i
            padding = (kernel_size - 1) * dilation // 2
            layers.append(nn.Sequential(
                nn.Conv1d(in_ch, hidden_dim, kernel_size, padding=padding, dilation=dilation),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.1),
            ))
            in_ch = hidden_dim
        self.net = nn.ModuleList(layers)
        self.head = nn.Conv1d(hidden_dim, 2, 1)  # 2 outputs: right_contact, left_contact

    def forward(self, x):
        # x: (B, T, C) -> (B, C, T)
        x = x.permute(0, 2, 1)
        for layer in self.net:
            x = layer(x)
        out = self.head(x)  # (B, 2, T)
        return out.permute(0, 2, 1)  # (B, T, 2)


class BiLSTMContact(nn.Module):
    """Bi-LSTM for frame-level contact prediction: (B, T, C) -> (B, T, 2) logits."""

    def __init__(self, input_dim, hidden_dim=64, num_layers=2):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True, bidirectional=True,
                            dropout=0.2 if num_layers > 1 else 0)
        self.head = nn.Linear(hidden_dim * 2, 2)

    def forward(self, x):
        out, _ = self.lstm(x)
        return self.head(out)  # (B, T, 2)


class CNN1DContact(nn.Module):
    """1D CNN for frame-level contact prediction: (B, T, C) -> (B, T, 2) logits.

    Three same-padded conv layers (kernels 7, 5, 3), so T is preserved.
    """

    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim, 7, padding=3),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv1d(hidden_dim, hidden_dim, 5, padding=2),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv1d(hidden_dim, hidden_dim, 3, padding=1),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
        )
        self.head = nn.Conv1d(hidden_dim, 2, 1)

    def forward(self, x):
        x = x.permute(0, 2, 1)
        x = self.net(x)
        out = self.head(x)
        return out.permute(0, 2, 1)


def build_contact_model(name, input_dim, hidden_dim=64):
    """Instantiate a contact model by name.

    Published baselines are imported lazily so their modules are only
    required when selected.

    Raises:
        ValueError: if ``name`` is not a known model.
    """
    if name == 'tcn':
        return TCN(input_dim, hidden_dim)
    elif name == 'lstm':
        return BiLSTMContact(input_dim, hidden_dim)
    elif name == 'cnn':
        return CNN1DContact(input_dim, hidden_dim)
    elif name == 'asformer':
        from experiments.published_baselines import ASFormerContact
        return ASFormerContact(input_dim, hidden_dim, num_layers=5, num_decoders=2)
    elif name == 'deepconvlstm':
        from experiments.published_models import DeepConvLSTMContact
        return DeepConvLSTMContact(input_dim, hidden_dim)
    elif name == 'inceptiontime':
        from experiments.published_models import InceptionTimeContact
        return InceptionTimeContact(input_dim, hidden_dim)
    elif name == 'underpressure':
        from experiments.published_models import UnderPressureContact
        return UnderPressureContact(input_dim, hidden_dim)
    else:
        raise ValueError(f"Unknown model: {name}")


# ============================================================
# Training
# ============================================================

def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one training epoch with grad-norm clipping at 1.0.

    Returns:
        Mean per-window loss (batch losses weighted by batch size).
    """
    model.train()
    total_loss = 0
    n_samples = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        pred = model(x)  # (B, T, 2)
        loss = criterion(pred.reshape(-1, 2), y.reshape(-1, 2))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        n_samples += x.size(0)
    return total_loss / n_samples


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Evaluate frame-level contact predictions.

    Logits are thresholded at sigmoid > 0.5. F1, precision and recall are
    computed over all frames of the loader, separately for each hand.

    Returns:
        ``(avg_loss, metrics)``. ``metrics`` has keys
        ``{right,left}_{f1,precision,recall}`` and ``avg_f1`` (the mean of
        the two hands' F1).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0
    n_samples = 0
    all_preds_r, all_labels_r = [], []
    all_preds_l, all_labels_l = [], []

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = criterion(pred.reshape(-1, 2), y.reshape(-1, 2))
        total_loss += loss.item() * x.size(0)
        n_samples += x.size(0)

        pred_binary = (torch.sigmoid(pred) > 0.5).cpu().numpy()
        y_np = y.cpu().numpy()
        all_preds_r.append(pred_binary[:, :, 0].flatten())
        all_labels_r.append(y_np[:, :, 0].flatten())
        all_preds_l.append(pred_binary[:, :, 1].flatten())
        all_labels_l.append(y_np[:, :, 1].flatten())

    if n_samples == 0:
        # Previously this surfaced as ZeroDivisionError / np.concatenate error.
        raise ValueError("evaluate() received an empty loader; nothing to score")

    avg_loss = total_loss / n_samples
    preds_r = np.concatenate(all_preds_r)
    labels_r = np.concatenate(all_labels_r)
    preds_l = np.concatenate(all_preds_l)
    labels_l = np.concatenate(all_labels_l)

    metrics = {}
    for hand, preds, labels in [('right', preds_r, labels_r), ('left', preds_l, labels_l)]:
        metrics[f'{hand}_f1'] = f1_score(labels, preds, zero_division=0)
        metrics[f'{hand}_precision'] = precision_score(labels, preds, zero_division=0)
        metrics[f'{hand}_recall'] = recall_score(labels, preds, zero_division=0)
    metrics['avg_f1'] = (metrics['right_f1'] + metrics['left_f1']) / 2
    return avg_loss, metrics


def run_experiment(args):
    """Train one model on one modality combination and evaluate on the test set.

    The checkpoint with the best validation ``avg_f1`` is written to
    ``<output_dir>/<exp_name>/model_best.pt`` and reloaded for testing.
    Results are written to ``results.json`` in the same directory.

    Returns:
        The results dict, or None when there is no training data or no
        val/test data for this modality combination.
    """
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    input_mods = args.modalities.split(',')

    print(f"\n{'='*60}")
    print(f"Exp3 Contact Detection | Model: {args.model} | Input: {input_mods}")
    print(f"{'='*60}")

    train_ds = ContactDataset(TRAIN_VOLS, input_mods, downsample=args.downsample)
    # Bail out before loading val/test: nothing can be trained anyway.
    if len(train_ds) == 0:
        print("No training data available for this modality combination!")
        return None

    stats = train_ds.get_stats()
    val_ds = ContactDataset(VAL_VOLS, input_mods, downsample=args.downsample, stats=stats)
    test_ds = ContactDataset(TEST_VOLS, input_mods, downsample=args.downsample, stats=stats)

    if len(val_ds) == 0 and len(test_ds) == 0:
        print("No val/test data available for this modality combination!")
        return None

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, num_workers=0)

    # Use test set for validation when val set is empty.
    # NOTE: test metrics are then optimistically biased (model selection on test).
    if len(val_ds) > 0:
        val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=0)
    else:
        val_loader = test_loader
        print(" No val data, using test set for early stopping.")

    model = build_contact_model(args.model, train_ds.feat_dim, args.hidden_dim).to(device)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model params: {n_params:,}, feat_dim: {train_ds.feat_dim}")

    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=7, factor=0.5)

    mod_str = '-'.join(input_mods)
    exp_name = f"exp3_{args.model}_{mod_str}_s{args.seed}"
    out_dir = os.path.join(args.output_dir, exp_name)
    os.makedirs(out_dir, exist_ok=True)
    best_path = os.path.join(out_dir, 'model_best.pt')

    # -inf (not 0) guarantees epoch 1 always writes a checkpoint. With 0, a
    # run whose val F1 never exceeds 0 would either crash on torch.load or
    # silently reload a stale model_best.pt from a previous run.
    best_val_f1 = float('-inf')
    best_epoch = 0
    patience_counter = 0

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_metrics = evaluate(model, val_loader, criterion, device)
        scheduler.step(val_loss)
        elapsed = time.time() - t0

        print(f" Epoch {epoch:3d} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} F1: {val_metrics['avg_f1']:.4f} | {elapsed:.1f}s")

        if val_metrics['avg_f1'] > best_val_f1:
            best_val_f1 = val_metrics['avg_f1']
            best_epoch = epoch
            patience_counter = 0
            torch.save(model.state_dict(), best_path)
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f" Early stopping at epoch {epoch}")
                break

    # Test
    model.load_state_dict(torch.load(best_path, weights_only=True))
    test_loss, test_metrics = evaluate(model, test_loader, criterion, device)

    print(f"\n--- Test Results (epoch {best_epoch}) ---")
    for k, v in test_metrics.items():
        print(f" {k}: {v:.4f}")

    results = {
        'experiment': exp_name,
        'model': args.model,
        'input_modalities': input_mods,
        'best_epoch': best_epoch,
        'test_metrics': {k: float(v) for k, v in test_metrics.items()},
        'n_params': n_params,
        'train_windows': len(train_ds),
        'val_windows': len(val_ds),
        'test_windows': len(test_ds),
        'args': vars(args),
    }
    with open(os.path.join(out_dir, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2)
    print(f" Saved to {out_dir}")
    return results


def run_all(args):
    """Run all modality combinations for contact detection.

    Each (modality combo, model) pair is run via :func:`run_experiment`.
    ``args.modalities`` and ``args.model`` are overwritten in place. A
    failing run is reported and recorded as an ``error`` entry instead of
    aborting the sweep. The summary is written to
    ``<output_dir>/exp3_summary.json`` and printed as a table.
    """
    modality_combos = [
        'mocap',
        'emg',
        'imu',
        'eyetrack',
        'mocap,emg',
        'mocap,emg,eyetrack',
        'mocap,emg,eyetrack,imu',
    ]
    models = ['cnn', 'lstm', 'tcn']

    all_results = []
    for mod_combo in modality_combos:
        for model_name in models:
            args.modalities = mod_combo
            args.model = model_name
            try:
                result = run_experiment(args)
                if result:
                    all_results.append(result)
            except Exception as e:
                # Sweep boundary: keep going, but keep the traceback for debugging.
                print(f"FAILED: {model_name}/{mod_combo}: {e}")
                traceback.print_exc()
                # Same naming scheme as run_experiment's exp_name.
                mod_str = mod_combo.replace(',', '-')
                all_results.append({'experiment': f"exp3_{model_name}_{mod_str}_s{args.seed}",
                                    'error': str(e)})

    summary_path = os.path.join(args.output_dir, 'exp3_summary.json')
    with open(summary_path, 'w') as f:
        json.dump(all_results, f, indent=2)

    print(f"\n{'='*60}")
    print(f"{'Model':<10} {'Input Modalities':<30} {'R_F1':<8} {'L_F1':<8} {'Avg_F1':<8}")
    print('-' * 70)
    for r in all_results:
        if 'error' in r:
            continue
        m = r['test_metrics']
        mods = ','.join(r['input_modalities'])
        print(f"{r['model']:<10} {mods:<30} {m['right_f1']:.4f} {m['left_f1']:.4f} {m['avg_f1']:.4f}")


def main():
    """Parse CLI arguments and run one experiment or the full sweep."""
    parser = argparse.ArgumentParser(description='Exp3: Contact Detection')
    parser.add_argument('--model', type=str, default='tcn',
                        choices=['cnn', 'lstm', 'tcn', 'asformer', 'deepconvlstm',
                                 'inceptiontime', 'underpressure'])
    parser.add_argument('--modalities', type=str, default='mocap,emg',
                        help='Input modalities (excluding pressure which is GT)')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--downsample', type=int, default=2,
                        help='Downsample from 100Hz (2 = 50Hz)')
    parser.add_argument('--patience', type=int, default=10)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--output_dir', type=str, default='${PULSE_ROOT}/results/exp3')
    parser.add_argument('--run_all', action='store_true')
    args = parser.parse_args()

    # epochs < 1 would skip training and leave no checkpoint to load;
    # downsample < 1 is an invalid slice step.
    if args.epochs < 1:
        parser.error("--epochs must be >= 1")
    if args.downsample < 1:
        parser.error("--downsample must be >= 1")

    # The default contains a ${PULSE_ROOT} placeholder. Without expansion,
    # makedirs would create a directory literally named "${PULSE_ROOT}".
    args.output_dir = os.path.expandvars(os.path.expanduser(args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    if args.run_all:
        run_all(args)
    else:
        run_experiment(args)


if __name__ == '__main__':
    main()