"""
Experiment 3: Grasp/Contact Event Detection
Use pressure as ground truth, predict contact from other modalities.
Binary classification per frame: contact vs non-contact for left and right hands.
"""
|
|
| import os |
| import sys |
| import json |
| import time |
| import random |
| import argparse |
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| from sklearn.metrics import f1_score, precision_score, recall_score |
| from torch.utils.data import Dataset, DataLoader |
| from torch.nn.utils.rnn import pad_sequence |
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| from data.dataset import ( |
| DATASET_DIR, MODALITY_FILES, SKIP_COLS, SKIP_COL_SUFFIXES, |
| TRAIN_VOLS, VAL_VOLS, TEST_VOLS, load_modality_array, get_modality_filepath |
| ) |
|
|
# Total grip force (grams) above which a hand is labeled "in contact".
PRESSURE_THRESHOLD = 5.0
# Sliding-window length and stride, in frames (applied after downsampling).
WINDOW_SIZE = 256
WINDOW_STRIDE = 128
|
|
|
|
def set_seed(seed):
    """Seed every RNG in play (python, numpy, torch CPU and CUDA) for reproducibility."""
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
|
|
|
|
def load_modality(scenario_dir, modality, vol=None, scenario=None):
    """Load one modality's feature array from its CSV file.

    When both `vol` and `scenario` are truthy, the path is resolved through
    get_modality_filepath; otherwise the default per-scenario filename from
    MODALITY_FILES is joined onto `scenario_dir`.
    """
    if not (vol and scenario):
        filepath = os.path.join(scenario_dir, MODALITY_FILES[modality])
    else:
        filepath = get_modality_filepath(scenario_dir, modality, vol, scenario)
    return load_modality_array(filepath, modality)
|
|
|
|
def generate_contact_labels(scenario_dir, n_frames):
    """Derive per-frame binary contact labels from the pressure CSV.

    A hand counts as "in contact" on a frame when its summed sensor readings
    exceed PRESSURE_THRESHOLD grams. Returns an (n_frames, 2) float32 array
    with column 0 = right hand, column 1 = left hand; frames beyond the end
    of the pressure recording stay zero (no contact).
    """
    pressure_csv = os.path.join(scenario_dir, MODALITY_FILES['pressure'])
    df = pd.read_csv(pressure_csv)

    labels = np.zeros((n_frames, 2), dtype=np.float32)
    for col_idx, prefix in enumerate(('R', 'L')):
        # Sensor columns look like "<prefix>...(g)"; coerce non-numeric cells
        # to NaN, then treat NaN as zero force.
        cols = [c for c in df.columns
                if c.startswith(prefix) and c.endswith('(g)')]
        grams = df[cols].apply(pd.to_numeric, errors='coerce').values
        total = np.nan_to_num(grams, nan=0.0).sum(axis=1)
        contact = (total > PRESSURE_THRESHOLD).astype(np.float32)

        usable = min(len(contact), n_frames)
        labels[:usable, col_idx] = contact[:usable]

    return labels
|
|
|
|
class ContactDataset(Dataset):
    """Sliding window dataset for contact detection.

    Scans each volunteer's scenario directories, concatenates the requested
    input modalities per frame, derives pressure-based contact labels, and
    cuts the sequences into fixed-size overlapping windows.  Each item is a
    (features, labels) pair with shapes (window_size, feat_dim) and
    (window_size, 2); label columns are (right_contact, left_contact).
    """

    def __init__(self, volunteers, input_modalities, window_size=WINDOW_SIZE,
                 stride=WINDOW_STRIDE, downsample=2, stats=None):
        # (features, labels) windows collected while scanning scenarios.
        self.windows = []
        self.input_modalities = input_modalities
        # Feature dimensionality, discovered from the first loaded scenario.
        self._feat_dim = None

        print(f" Loading contact data for {len(volunteers)} volunteers...")
        all_features = []

        for vol in volunteers:
            vol_dir = os.path.join(DATASET_DIR, vol)
            if not os.path.isdir(vol_dir):
                continue
            for scenario in sorted(os.listdir(vol_dir)):
                scenario_dir = os.path.join(vol_dir, scenario)
                if not os.path.isdir(scenario_dir):
                    continue
                meta_path = os.path.join(scenario_dir, 'alignment_metadata.json')
                if not os.path.exists(meta_path):
                    continue
                with open(meta_path) as f:
                    meta = json.load(f)

                # Skip scenarios missing any input modality or the pressure
                # ground truth needed for labels.
                available = set(meta['modalities'])
                required = set(input_modalities) | {'pressure'}
                if not required.issubset(available):
                    continue

                parts = []
                for mod in input_modalities:
                    arr = load_modality(scenario_dir, mod, vol, scenario)
                    parts.append(arr)

                # Modalities may have slightly different lengths; truncate all
                # to the shortest before concatenating along the feature axis.
                min_len = min(p.shape[0] for p in parts)
                features = np.concatenate([p[:min_len] for p in parts], axis=1)

                # Temporal downsampling (e.g. 100Hz -> 50Hz for downsample=2).
                features = features[::downsample]

                # Labels are generated at the pre-downsample length, then
                # downsampled identically so frames stay aligned with features.
                labels = generate_contact_labels(scenario_dir, min_len)
                labels = labels[::downsample]

                if self._feat_dim is None:
                    self._feat_dim = features.shape[1]

                all_features.append(features)

                # Slide a fixed-size window over the sequence.
                T = features.shape[0]
                for start in range(0, T - window_size + 1, stride):
                    end = start + window_size
                    self.windows.append((
                        features[start:end],
                        labels[start:end],
                    ))

        # Normalization stats: reuse the caller's (train-set) stats when
        # given, otherwise compute per-feature mean/std from loaded data.
        if stats is not None:
            self.mean, self.std = stats
        else:
            if all_features:
                all_data = np.concatenate(all_features, axis=0)
                self.mean = np.mean(all_data, axis=0, keepdims=True).astype(np.float32)
                self.std = np.std(all_data, axis=0, keepdims=True).astype(np.float32)
                # Guard against division by ~zero for constant features.
                self.std[self.std < 1e-8] = 1.0
            else:
                self.mean = np.zeros((1, self._feat_dim or 1), dtype=np.float32)
                self.std = np.ones((1, self._feat_dim or 1), dtype=np.float32)

        # Standardize all windows in place with the chosen stats.
        self.windows = [
            ((w[0] - self.mean) / self.std, w[1])
            for w in self.windows
        ]

        # Report class balance per hand (positive-frame fraction).
        all_labels = np.concatenate([w[1] for w in self.windows], axis=0) if self.windows else np.array([])
        if len(all_labels) > 0:
            r_pos = all_labels[:, 0].mean()
            l_pos = all_labels[:, 1].mean()
            print(f" Windows: {len(self.windows)}, R_contact: {r_pos:.2%}, L_contact: {l_pos:.2%}")

    def get_stats(self):
        """Return (mean, std) so val/test sets can reuse train normalization."""
        return (self.mean, self.std)

    @property
    def feat_dim(self):
        # None if no scenario was loaded.
        return self._feat_dim

    def __len__(self):
        return len(self.windows)

    def __getitem__(self, idx):
        features, labels = self.windows[idx]
        return torch.from_numpy(features), torch.from_numpy(labels)
|
|
|
|
| |
| |
| |
|
|
class TCN(nn.Module):
    """Dilated temporal ConvNet emitting two contact logits per frame.

    Stacks `num_layers` conv blocks whose dilation doubles at each level;
    padding is sized so the temporal length is preserved, and a 1x1 conv
    head maps hidden channels to 2 per-frame logits (right, left).
    """

    def __init__(self, input_dim, hidden_dim=64, num_layers=4, kernel_size=5):
        super().__init__()
        blocks = []
        channels = input_dim
        for level in range(num_layers):
            d = 2 ** level
            pad = (kernel_size - 1) * d // 2  # "same" padding for odd kernels
            blocks.append(nn.Sequential(
                nn.Conv1d(channels, hidden_dim, kernel_size, padding=pad, dilation=d),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.1),
            ))
            channels = hidden_dim
        self.net = nn.ModuleList(blocks)
        self.head = nn.Conv1d(hidden_dim, 2, 1)

    def forward(self, x):
        # (B, T, C) -> (B, C, T) for Conv1d, back to (B, T, 2) at the end.
        h = x.permute(0, 2, 1)
        for block in self.net:
            h = block(h)
        return self.head(h).permute(0, 2, 1)
|
|
|
|
class BiLSTMContact(nn.Module):
    """Bidirectional LSTM producing two contact logits per frame."""

    def __init__(self, input_dim, hidden_dim=64, num_layers=2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            # Inter-layer dropout is only valid when there is more than one layer.
            dropout=0.2 if num_layers > 1 else 0,
        )
        # Bidirectional output concatenates forward/backward hidden states.
        self.head = nn.Linear(hidden_dim * 2, 2)

    def forward(self, x):
        hidden, _state = self.lstm(x)
        return self.head(hidden)
|
|
|
|
class CNN1DContact(nn.Module):
    """Three-layer 1D CNN (kernels 7/5/3) emitting per-frame contact logits."""

    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        stages = []
        in_ch = input_dim
        # (kernel, dropout) per stage; the last stage has no dropout.
        for kernel, dropout in ((7, 0.1), (5, 0.1), (3, None)):
            stages += [
                nn.Conv1d(in_ch, hidden_dim, kernel, padding=kernel // 2),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
            ]
            if dropout is not None:
                stages.append(nn.Dropout(dropout))
            in_ch = hidden_dim
        self.net = nn.Sequential(*stages)
        self.head = nn.Conv1d(hidden_dim, 2, 1)

    def forward(self, x):
        # (B, T, C) -> (B, C, T) for conv stack, back to (B, T, 2).
        h = self.net(x.permute(0, 2, 1))
        return self.head(h).permute(0, 2, 1)
|
|
|
|
def build_contact_model(name, input_dim, hidden_dim=64):
    """Factory: instantiate a contact-detection model by name.

    Published-baseline models are imported lazily so their modules are only
    required when actually requested.  Raises ValueError for unknown names.
    """
    if name == 'tcn':
        return TCN(input_dim, hidden_dim)
    if name == 'lstm':
        return BiLSTMContact(input_dim, hidden_dim)
    if name == 'cnn':
        return CNN1DContact(input_dim, hidden_dim)
    if name == 'asformer':
        from experiments.published_baselines import ASFormerContact
        return ASFormerContact(input_dim, hidden_dim,
                               num_layers=5, num_decoders=2)
    if name == 'deepconvlstm':
        from experiments.published_models import DeepConvLSTMContact
        return DeepConvLSTMContact(input_dim, hidden_dim)
    if name == 'inceptiontime':
        from experiments.published_models import InceptionTimeContact
        return InceptionTimeContact(input_dim, hidden_dim)
    if name == 'underpressure':
        from experiments.published_models import UnderPressureContact
        return UnderPressureContact(input_dim, hidden_dim)
    raise ValueError(f"Unknown model: {name}")
|
|
|
|
| |
| |
| |
|
|
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over `loader`; return mean per-sample loss.

    Gradients are clipped to a global norm of 1.0 before each step.
    """
    model.train()
    loss_sum, sample_count = 0.0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        logits = model(inputs)
        # Flatten (B, T, 2) to (B*T, 2) frame-level predictions.
        loss = criterion(logits.reshape(-1, 2), targets.reshape(-1, 2))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        batch = inputs.size(0)
        loss_sum += loss.item() * batch
        sample_count += batch
    return loss_sum / sample_count
|
|
|
|
@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Evaluate `model` on `loader`; return (avg_loss, metrics dict).

    Predictions are thresholded at sigmoid > 0.5.  Metrics hold F1,
    precision and recall per hand ('right_*', 'left_*') plus 'avg_f1',
    the mean of the two hands' F1 scores.
    """
    model.eval()
    loss_sum, sample_count = 0.0, 0
    # Per-hand accumulators: index 0 = right hand, 1 = left hand.
    pred_chunks = {0: [], 1: []}
    label_chunks = {0: [], 1: []}

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        logits = model(inputs)
        loss = criterion(logits.reshape(-1, 2), targets.reshape(-1, 2))
        batch = inputs.size(0)
        loss_sum += loss.item() * batch
        sample_count += batch

        binary = (torch.sigmoid(logits) > 0.5).cpu().numpy()
        target_np = targets.cpu().numpy()
        for hand_idx in (0, 1):
            pred_chunks[hand_idx].append(binary[:, :, hand_idx].flatten())
            label_chunks[hand_idx].append(target_np[:, :, hand_idx].flatten())

    avg_loss = loss_sum / sample_count

    metrics = {}
    for hand_idx, hand in ((0, 'right'), (1, 'left')):
        preds = np.concatenate(pred_chunks[hand_idx])
        labels = np.concatenate(label_chunks[hand_idx])
        metrics[f'{hand}_f1'] = f1_score(labels, preds, zero_division=0)
        metrics[f'{hand}_precision'] = precision_score(labels, preds, zero_division=0)
        metrics[f'{hand}_recall'] = recall_score(labels, preds, zero_division=0)

    metrics['avg_f1'] = (metrics['right_f1'] + metrics['left_f1']) / 2
    return avg_loss, metrics
|
|
|
|
def run_experiment(args):
    """Train and evaluate one contact-detection configuration.

    Builds train/val/test datasets from the volunteer splits, trains with
    early stopping on validation avg-F1, reloads the best checkpoint and
    reports test metrics.  Returns the results dict (also written to
    results.json in the experiment output directory), or None when no
    training data exists for the requested modality combination.
    """
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    input_mods = args.modalities.split(',')

    print(f"\n{'='*60}")
    print(f"Exp3 Contact Detection | Model: {args.model} | Input: {input_mods}")
    print(f"{'='*60}")

    # Normalization stats come from the train set only and are reused for
    # val/test to avoid leakage.
    train_ds = ContactDataset(TRAIN_VOLS, input_mods, downsample=args.downsample)
    stats = train_ds.get_stats()
    val_ds = ContactDataset(VAL_VOLS, input_mods, downsample=args.downsample, stats=stats)
    test_ds = ContactDataset(TEST_VOLS, input_mods, downsample=args.downsample, stats=stats)

    if len(train_ds) == 0:
        print("No training data available for this modality combination!")
        return None

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, num_workers=0)

    # Fall back to the test set for early stopping when no val split exists
    # (reported test numbers are then optimistic).
    if len(val_ds) > 0:
        val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=0)
    else:
        val_loader = test_loader
        print(" No val data, using test set for early stopping.")

    model = build_contact_model(args.model, train_ds.feat_dim, args.hidden_dim).to(device)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model params: {n_params:,}, feat_dim: {train_ds.feat_dim}")

    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=7, factor=0.5)

    mod_str = '-'.join(input_mods)
    exp_name = f"exp3_{args.model}_{mod_str}_s{args.seed}"
    out_dir = os.path.join(args.output_dir, exp_name)
    os.makedirs(out_dir, exist_ok=True)

    best_val_f1 = 0
    best_epoch = 0
    patience_counter = 0

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_metrics = evaluate(model, val_loader, criterion, device)
        # LR schedule keys on validation loss; early stopping keys on val F1.
        scheduler.step(val_loss)
        elapsed = time.time() - t0

        print(f" Epoch {epoch:3d} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} F1: {val_metrics['avg_f1']:.4f} | {elapsed:.1f}s")

        # Checkpoint whenever validation F1 improves.
        if val_metrics['avg_f1'] > best_val_f1:
            best_val_f1 = val_metrics['avg_f1']
            best_epoch = epoch
            patience_counter = 0
            torch.save(model.state_dict(), os.path.join(out_dir, 'model_best.pt'))
        else:
            patience_counter += 1

        if patience_counter >= args.patience:
            print(f" Early stopping at epoch {epoch}")
            break

    # Restore the best-validation checkpoint before the final test pass.
    model.load_state_dict(torch.load(os.path.join(out_dir, 'model_best.pt'), weights_only=True))
    test_loss, test_metrics = evaluate(model, test_loader, criterion, device)

    print(f"\n--- Test Results (epoch {best_epoch}) ---")
    for k, v in test_metrics.items():
        print(f" {k}: {v:.4f}")

    results = {
        'experiment': exp_name,
        'model': args.model,
        'input_modalities': input_mods,
        'best_epoch': best_epoch,
        'test_metrics': {k: float(v) for k, v in test_metrics.items()},
        'n_params': n_params,
        'train_windows': len(train_ds),
        'val_windows': len(val_ds),
        'test_windows': len(test_ds),
        'args': vars(args),
    }
    with open(os.path.join(out_dir, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2)
    print(f" Saved to {out_dir}")
    return results
|
|
|
|
def run_all(args):
    """Run all modality combinations for contact detection.

    Sweeps every (modality combination, model) pair by mutating `args`
    in place and calling run_experiment; failures are caught so one bad
    configuration does not abort the sweep.  Writes an aggregate
    exp3_summary.json and prints a final results table.
    """
    modality_combos = [
        'mocap',
        'emg',
        'imu',
        'eyetrack',
        'mocap,emg',
        'mocap,emg,eyetrack',
        'mocap,emg,eyetrack,imu',
    ]
    models = ['cnn', 'lstm', 'tcn']
    all_results = []

    for mod_combo in modality_combos:
        for model_name in models:
            # run_experiment reads the configuration from args.
            args.modalities = mod_combo
            args.model = model_name
            try:
                result = run_experiment(args)
                if result:
                    all_results.append(result)
            except Exception as e:
                # Record the failure and keep sweeping.
                print(f"FAILED: {model_name}/{mod_combo}: {e}")
                all_results.append({'experiment': f"exp3_{model_name}_{mod_combo}", 'error': str(e)})

    summary_path = os.path.join(args.output_dir, 'exp3_summary.json')
    with open(summary_path, 'w') as f:
        json.dump(all_results, f, indent=2)

    # Summary table (successful runs only).
    print(f"\n{'='*60}")
    print(f"{'Model':<10} {'Input Modalities':<30} {'R_F1':<8} {'L_F1':<8} {'Avg_F1':<8}")
    print('-' * 70)
    for r in all_results:
        if 'error' in r:
            continue
        m = r['test_metrics']
        mods = ','.join(r['input_modalities'])
        print(f"{r['model']:<10} {mods:<30} {m['right_f1']:.4f} {m['left_f1']:.4f} {m['avg_f1']:.4f}")
|
|
|
|
def main():
    """CLI entry point for Exp3 contact detection.

    Parses arguments, expands environment variables in the output path,
    and dispatches to either the full sweep (--run_all) or a single run.
    """
    parser = argparse.ArgumentParser(description='Exp3: Contact Detection')
    parser.add_argument('--model', type=str, default='tcn',
                        choices=['cnn', 'lstm', 'tcn', 'asformer',
                                 'deepconvlstm', 'inceptiontime', 'underpressure'])
    parser.add_argument('--modalities', type=str, default='mocap,emg',
                        help='Input modalities (excluding pressure which is GT)')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--downsample', type=int, default=2,
                        help='Downsample from 100Hz (2 = 50Hz)')
    parser.add_argument('--patience', type=int, default=10)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--output_dir', type=str,
                        default='${PULSE_ROOT}/results/exp3')
    parser.add_argument('--run_all', action='store_true')
    args = parser.parse_args()
    # Bug fix: the default output_dir contains a shell-style ${PULSE_ROOT}
    # placeholder that Python never expands, so makedirs would create a
    # literal '${PULSE_ROOT}' directory. Expand environment variables first.
    args.output_dir = os.path.expandvars(args.output_dir)
    os.makedirs(args.output_dir, exist_ok=True)

    if args.run_all:
        run_all(args)
    else:
        run_experiment(args)
|
|
|
|
# Script entry point.
if __name__ == '__main__':
    main()
|
|