PULSE-code / experiments /tasks /train_exp4.py
velvet-pine-22's picture
Upload folder using huggingface_hub
b4b2877 verified
#!/usr/bin/env python3
"""
Experiment 4: Cross-Modal Prediction
Sub-tasks:
4a: MoCap (hand joints) → Pressure (50ch)
4b: EMG (8ch) → Hand Pose (fingertip positions, 30D)
4c: Body skeleton → Gaze (2D gaze point)
"""
import os
import sys
import json
import time
import random
import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from scipy.stats import pearsonr
from torch.utils.data import Dataset, DataLoader
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.dataset import (
DATASET_DIR, MODALITY_FILES, SKIP_COLS, SKIP_COL_SUFFIXES,
TRAIN_VOLS, VAL_VOLS, TEST_VOLS
)
# Sliding-window length, counted in frames of the already-downsampled recording.
WINDOW_SIZE = 256
# Hop between consecutive window starts; half of WINDOW_SIZE, so neighbouring
# windows overlap by 50%.
WINDOW_STRIDE = 128
def set_seed(seed):
    """Seed every random number generator this script relies on.

    Covers Python's ``random`` module, NumPy's global generator and PyTorch
    (the CPU generator plus every CUDA device).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
def load_modality_with_cols(scenario_dir, modality, vol=None, scenario=None):
    """Load one modality file of a scenario as a clean float32 array.

    Args:
        scenario_dir: Directory holding the scenario's aligned modality files.
        modality: Modality key. ``'mocap'`` uses the per-recording file name
            ``aligned_{vol}{scenario}_s_Q.tsv``; any other key is looked up
            in ``MODALITY_FILES``.
        vol: Volunteer id used in the MoCap file name. When omitted it is
            taken from the name of ``scenario_dir``'s parent directory.
        scenario: Scenario id used in the MoCap file name. When omitted it is
            taken from the last component of ``scenario_dir``.

    Returns:
        ``(array, column_names)`` where ``array`` is a float32 matrix with one
        column per entry of ``column_names``. Columns listed in ``SKIP_COLS``
        or ending with one of ``SKIP_COL_SUFFIXES`` are dropped; NaN/inf
        become 0 and values are clipped to ``[-1e5, 1e5]``.
    """
    if modality == 'mocap':
        # MoCap uses special naming: aligned_{vol}{scene}_s_Q.tsv
        if vol is None or scenario is None:
            # Infer only the missing identifier(s) from the directory layout
            # <...>/<vol>/<scenario>. normpath drops trailing separators and
            # os.path handles the platform's separator, unlike split('/').
            base = os.path.normpath(scenario_dir)
            if scenario is None:
                scenario = os.path.basename(base)
            if vol is None:
                vol = os.path.basename(os.path.dirname(base))
        filepath = os.path.join(scenario_dir, f"aligned_{vol}{scenario}_s_Q.tsv")
    else:
        filepath = os.path.join(scenario_dir, MODALITY_FILES[modality])

    sep = '\t' if filepath.endswith('.tsv') else ','
    df = pd.read_csv(filepath, sep=sep, low_memory=False)

    feat_cols = [c for c in df.columns
                 if c not in SKIP_COLS
                 and not any(c.endswith(s) for s in SKIP_COL_SUFFIXES)]
    sub = df[feat_cols]

    # Columns parsed as strings are coerced to numbers; unparsable cells
    # become NaN and are zeroed below.
    obj_cols = sub.select_dtypes(include=['object']).columns
    if len(obj_cols) > 0:
        sub = sub.copy()
        sub[obj_cols] = sub[obj_cols].apply(pd.to_numeric, errors='coerce')

    arr = sub.values.astype(np.float64)
    arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
    # Clip to reasonable sensor range (some MoCap recordings have corrupted
    # values up to 1e304) before narrowing to float32.
    arr = np.clip(arr, -1e5, 1e5).astype(np.float32)
    return arr, feat_cols
def get_subtask_config(subtask):
    """Describe a sub-task as (input_modality, output_modality, input_col_filter, output_col_filter).

    Each filter is either ``None`` (keep every column) or a callable mapping a
    list of column names to the sub-list to keep, in the original order.

    Raises:
        ValueError: if ``subtask`` is not one of ``'4a'``, ``'4b'``, ``'4c'``.
    """
    def keep_matching(keywords, limit=None):
        # Build a selector keeping columns whose name contains any keyword,
        # optionally truncated to the first `limit` matches.
        def selector(cols):
            return [c for c in cols if any(k in c for k in keywords)][:limit]
        return selector

    if subtask == '4a':
        # MoCap hand joints -> Pressure
        hand_keywords = ('Hand', 'Wrist', 'Thumb', 'Index', 'Middle', 'Ring', 'Pinky')
        return 'mocap', 'pressure', keep_matching(hand_keywords), None
    if subtask == '4b':
        # EMG -> Hand fingertip positions
        return 'emg', 'mocap', None, keep_matching(('Tip',))
    if subtask == '4c':
        # Body skeleton -> Gaze point (first two pupil columns)
        return 'mocap', 'eyetrack', None, keep_matching(('Pupil X', 'Pupil Y'), limit=2)
    raise ValueError(f"Unknown subtask: {subtask}")
class CrossModalDataset(Dataset):
    """Windowed (input, target) pairs for one cross-modal sub-task.

    For every requested volunteer the constructor walks
    ``DATASET_DIR/<volunteer>/<scenario>`` and keeps scenarios whose
    ``alignment_metadata.json`` lists both modalities of the sub-task. Each
    kept recording is column-filtered, truncated to the shorter of the two
    streams, downsampled and cut into overlapping windows. Windows are
    z-scored with the supplied ``stats`` or, when ``stats`` is ``None``, with
    statistics computed over all frames loaded here.
    """

    def __init__(self, volunteers, subtask, window_size=WINDOW_SIZE,
                 stride=WINDOW_STRIDE, downsample=2, stats=None):
        self.windows = []
        self._input_dim = None
        self._output_dim = None
        src_mod, tgt_mod, src_filter, tgt_filter = get_subtask_config(subtask)
        needed = {src_mod, tgt_mod}
        src_frames, tgt_frames = [], []

        for vol in volunteers:
            vol_dir = os.path.join(DATASET_DIR, vol)
            if not os.path.isdir(vol_dir):
                continue
            for scenario in sorted(os.listdir(vol_dir)):
                scenario_dir = os.path.join(vol_dir, scenario)
                if not os.path.isdir(scenario_dir):
                    continue
                meta_path = os.path.join(scenario_dir, 'alignment_metadata.json')
                if not os.path.exists(meta_path):
                    continue
                with open(meta_path) as f:
                    meta = json.load(f)
                # Skip recordings that lack either modality of this sub-task.
                if not needed <= set(meta['modalities']):
                    continue

                src, src_cols = load_modality_with_cols(scenario_dir, src_mod, vol, scenario)
                tgt, tgt_cols = load_modality_with_cols(scenario_dir, tgt_mod, vol, scenario)
                src = self._pick_columns(src, src_cols, src_filter)
                tgt = self._pick_columns(tgt, tgt_cols, tgt_filter)

                # Truncate both streams to the shorter one, then downsample.
                usable = min(src.shape[0], tgt.shape[0])
                src = src[:usable:downsample]
                tgt = tgt[:usable:downsample]

                # Feature dimensions are taken from the first recording loaded.
                if self._input_dim is None:
                    self._input_dim = src.shape[1]
                    self._output_dim = tgt.shape[1]
                src_frames.append(src)
                tgt_frames.append(tgt)

                last_start = src.shape[0] - window_size
                self.windows.extend(
                    (src[s:s + window_size], tgt[s:s + window_size])
                    for s in range(0, last_start + 1, stride)
                )

        # Normalisation statistics: reuse the caller's, compute from the
        # loaded frames, or fall back to identity when nothing was loaded.
        if stats is not None:
            self.in_mean, self.in_std, self.out_mean, self.out_std = stats
        elif src_frames:
            self.in_mean, self.in_std = self._mean_std(src_frames)
            self.out_mean, self.out_std = self._mean_std(tgt_frames)
        else:
            d_in = self._input_dim or 1
            d_out = self._output_dim or 1
            self.in_mean = np.zeros((1, d_in), dtype=np.float32)
            self.in_std = np.ones((1, d_in), dtype=np.float32)
            self.out_mean = np.zeros((1, d_out), dtype=np.float32)
            self.out_std = np.ones((1, d_out), dtype=np.float32)

        self.windows = [
            ((src_win - self.in_mean) / self.in_std,
             (tgt_win - self.out_mean) / self.out_std)
            for src_win, tgt_win in self.windows
        ]
        print(f" Loaded {len(self.windows)} windows, "
              f"input_dim={self._input_dim}, output_dim={self._output_dim}")

    @staticmethod
    def _pick_columns(arr, cols, col_filter):
        """Apply ``col_filter`` to ``arr``; an empty selection keeps every column."""
        if not col_filter:
            return arr
        chosen = col_filter(cols) or cols  # fallback to all
        return arr[:, [cols.index(c) for c in chosen]]

    @staticmethod
    def _mean_std(frames):
        """Per-channel float32 mean/std over stacked frames; near-zero std becomes 1."""
        stacked = np.concatenate(frames, axis=0).astype(np.float64)
        mean = np.mean(stacked, axis=0, keepdims=True).astype(np.float32)
        std = np.std(stacked, axis=0, keepdims=True).astype(np.float32)
        std[std < 1e-8] = 1.0
        return mean, std

    def get_stats(self):
        """Return ``(in_mean, in_std, out_mean, out_std)`` for reuse by other splits."""
        return (self.in_mean, self.in_std, self.out_mean, self.out_std)

    @property
    def input_dim(self):
        """Input feature count, or ``None`` when no recording was loaded."""
        return self._input_dim

    @property
    def output_dim(self):
        """Target feature count, or ``None`` when no recording was loaded."""
        return self._output_dim

    def __len__(self):
        return len(self.windows)

    def __getitem__(self, idx):
        src_win, tgt_win = self.windows[idx]
        return torch.from_numpy(src_win), torch.from_numpy(tgt_win)
# ============================================================
# Models for sequence-to-sequence regression
# ============================================================
class MLPSeq(nn.Module):
    """Frame-wise MLP baseline.

    Two hidden ``Linear -> ReLU -> Dropout(0.1)`` stages followed by a linear
    read-out, applied to the last dimension of the input.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        layers = []
        width = input_dim
        for _ in range(2):
            layers += [nn.Linear(width, hidden_dim), nn.ReLU(), nn.Dropout(0.1)]
            width = hidden_dim
        layers.append(nn.Linear(width, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class UNet1D(nn.Module):
    """1D U-Net over the time axis.

    Three convolutional encoder stages (the last two with stride 2), two
    transposed-convolution decoder stages with skip connections, and a 1x1
    convolution producing ``output_dim`` channels per frame.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=64):
        super().__init__()
        h1, h2, h4 = hidden_dim, hidden_dim * 2, hidden_dim * 4
        # Encoder
        self.enc1 = self._norm_act(nn.Conv1d(input_dim, h1, kernel_size=7, padding=3))
        self.enc2 = self._norm_act(nn.Conv1d(h1, h2, kernel_size=5, stride=2, padding=2))
        self.enc3 = self._norm_act(nn.Conv1d(h2, h4, kernel_size=3, stride=2, padding=1))
        # Decoder; dec2 and dec1 take doubled channels because of the skip concat
        self.dec3 = self._norm_act(nn.ConvTranspose1d(h4, h2, kernel_size=4, stride=2, padding=1))
        self.dec2 = self._norm_act(nn.ConvTranspose1d(h4, h1, kernel_size=4, stride=2, padding=1))
        self.dec1 = nn.Conv1d(h2, output_dim, kernel_size=1)

    @staticmethod
    def _norm_act(conv):
        """Wrap a (transposed) convolution with BatchNorm1d and ReLU."""
        return nn.Sequential(conv, nn.BatchNorm1d(conv.out_channels), nn.ReLU())

    def forward(self, x):
        # (B, T, C) -> (B, C, T) for the convolutions
        feats = x.transpose(1, 2)
        skip1 = self.enc1(feats)
        skip2 = self.enc2(skip1)
        bottleneck = self.enc3(skip2)
        # Upsampling can overshoot by one frame for odd lengths; crop to the skip
        up2 = self.dec3(bottleneck)[..., :skip2.size(-1)]
        up1 = self.dec2(torch.cat((up2, skip2), dim=1))[..., :skip1.size(-1)]
        out = self.dec1(torch.cat((up1, skip1), dim=1))
        return out.transpose(1, 2)  # (B, T, output_dim)
class Seq2SeqLSTM(nn.Module):
    """Stacked LSTM regressor: bidirectional encoder feeding a forward decoder.

    The 2-layer bidirectional encoder output (``2 * hidden_dim`` features per
    frame) is passed frame-by-frame through a single-layer LSTM and a linear
    head. No attention mechanism is involved.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        self.encoder = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.2,
        )
        self.decoder = nn.LSTM(
            input_size=hidden_dim * 2,
            hidden_size=hidden_dim,
            num_layers=1,
            batch_first=True,
        )
        self.head = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Only the per-frame outputs are used; final hidden/cell states are discarded.
        encoded, _ = self.encoder(x)
        decoded, _ = self.decoder(encoded)
        return self.head(decoded)
class TransformerRegressor(nn.Module):
    """Transformer encoder mapping each input frame to an output frame.

    Frames are linearly projected to ``d_model``, passed through
    ``num_layers`` encoder layers (feed-forward width ``4 * d_model``,
    dropout 0.1, batch-first) and read out by a linear head. No positional
    encoding is added.
    """

    def __init__(self, input_dim, output_dim, d_model=128, nhead=4, num_layers=2):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=0.1,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.head = nn.Linear(d_model, output_dim)

    def forward(self, x):
        return self.head(self.encoder(self.input_proj(x)))
def build_model(name, input_dim, output_dim, hidden_dim=128):
    """Instantiate the model registered under ``name``.

    The in-file baselines are built from a dispatch table; the published
    architectures are imported lazily so they are only required when asked
    for. ``'unet'`` receives half of ``hidden_dim``.

    Raises:
        ValueError: if ``name`` is not a known model.
    """
    local_builders = {
        'mlp': lambda: MLPSeq(input_dim, output_dim, hidden_dim),
        'unet': lambda: UNet1D(input_dim, output_dim, hidden_dim // 2),
        'lstm': lambda: Seq2SeqLSTM(input_dim, output_dim, hidden_dim),
        'transformer': lambda: TransformerRegressor(input_dim, output_dim, hidden_dim),
    }
    if name in local_builders:
        return local_builders[name]()

    if name == 'underpressure':
        from experiments.published_models import UnderPressureRegressor
        return UnderPressureRegressor(input_dim, output_dim, hidden_dim)
    if name in ('emg2pose', 'emg2pose_direct'):
        from experiments.published_models import EMG2Pose
        if name == 'emg2pose':
            return EMG2Pose(input_dim, output_dim, hidden_dim)
        return EMG2Pose(input_dim, output_dim, hidden_dim, use_velocity=False)

    raise ValueError(f"Unknown model: {name}")
# ============================================================
# Training
# ============================================================
def compute_metrics(preds, targets, out_std):
    """Compute RMSE, R² and mean per-channel Pearson r in original units.

    Args:
        preds: Normalised predictions, shape ``(N, C)`` or ``(N,)``.
        targets: Normalised targets, same shape as ``preds``.
        out_std: Per-channel standard deviation used for normalisation;
            must broadcast against ``(N, C)``.

    Returns:
        Dict with float entries ``'rmse'``, ``'r2'`` and ``'pearson'``.
        ``'pearson'`` averages over channels where both series vary and the
        correlation is finite; it is 0.0 when no channel qualifies.
    """
    preds = np.asarray(preds)
    targets = np.asarray(targets)
    # Treat 1-D input as a single channel so every case below sees (N, C).
    # Without this a (N, 1) array was handed to pearsonr as a 2-D input.
    if preds.ndim == 1:
        preds = preds[:, None]
    if targets.ndim == 1:
        targets = targets[:, None]

    # Only the scale is restored. The normalisation mean would shift
    # predictions and targets equally, which cancels in RMSE, R² and Pearson.
    preds_orig = preds * out_std
    targets_orig = targets * out_std

    rmse = np.sqrt(np.mean((preds_orig - targets_orig) ** 2))

    # R² (coefficient of determination), pooled over all channels
    ss_res = np.sum((targets_orig - preds_orig) ** 2)
    ss_tot = np.sum((targets_orig - np.mean(targets_orig, axis=0)) ** 2)
    r2 = 1 - ss_res / (ss_tot + 1e-8)

    # Per-channel Pearson correlation; constant channels are skipped because
    # the correlation is undefined for them.
    correlations = []
    for ch in range(preds_orig.shape[1]):
        p = preds_orig[:, ch]
        t = targets_orig[:, ch]
        if np.std(t) > 1e-8 and np.std(p) > 1e-8:
            corr, _ = pearsonr(p, t)
            if np.isfinite(corr):
                correlations.append(corr)
    avg_pearson = np.mean(correlations) if correlations else 0.0

    return {'rmse': float(rmse), 'r2': float(r2), 'pearson': float(avg_pearson)}
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Batches whose loss is NaN or infinite are skipped without an optimiser
    step and do not count towards the average. Gradients are clipped to a
    global norm of 1.0. Returns 0 when no batch contributed.
    """
    model.train()
    loss_sum, seen = 0, 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        if not torch.isfinite(loss):
            continue
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        batch = inputs.size(0)
        loss_sum += loss.item() * batch
        seen += batch
    return loss_sum / max(seen, 1)
@torch.no_grad()
def evaluate(model, loader, criterion, device, out_std):
    """Evaluate ``model`` on ``loader`` and return loss plus regression metrics.

    Args:
        model: Module mapping a batch ``x`` to predictions shaped like ``y``.
        loader: Iterable of ``(x, y)`` tensor batches.
        criterion: Loss called as ``criterion(pred, y)``.
        device: Device the batches are moved to.
        out_std: Per-channel target std, forwarded to ``compute_metrics``.

    Returns:
        The dict from ``compute_metrics`` extended with ``'loss'``, the
        sample-weighted mean of the criterion.

    Raises:
        ValueError: if ``loader`` yields no samples. This replaces the
            ``ZeroDivisionError`` that an empty loader used to trigger.
    """
    model.eval()
    total_loss = 0.0
    n = 0
    all_preds, all_targets = [], []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = criterion(pred, y)
        total_loss += loss.item() * x.size(0)
        n += x.size(0)
        # Flatten batch and time so the metrics see one row per frame.
        all_preds.append(pred.cpu().numpy().reshape(-1, pred.shape[-1]))
        all_targets.append(y.cpu().numpy().reshape(-1, y.shape[-1]))

    if n == 0:
        raise ValueError("evaluate() received no samples: the loader is empty")

    avg_loss = total_loss / n
    preds = np.concatenate(all_preds, axis=0)
    targets = np.concatenate(all_targets, axis=0)
    metrics = compute_metrics(preds, targets, out_std)
    metrics['loss'] = avg_loss
    return metrics
def run_experiment(args):
    """Train and test one (subtask, model) combination.

    Reads ``seed``, ``subtask``, ``model``, ``downsample``, ``batch_size``,
    ``hidden_dim``, ``lr``, ``weight_decay``, ``epochs``, ``patience`` and
    ``output_dir`` from ``args``. The best checkpoint (lowest validation
    loss) is written to ``<output_dir>/exp4_<subtask>_<model>/model_best.pt``
    and the test metrics to ``results.json`` in the same directory.

    Returns:
        The results dict written to ``results.json``, or ``None`` when there
        is no training data, no data to validate on, or no test data.
    """
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\n{'='*60}")
    print(f"Exp4 Cross-Modal | Subtask: {args.subtask} | Model: {args.model}")
    print(f"{'='*60}")

    train_ds = CrossModalDataset(TRAIN_VOLS, args.subtask, downsample=args.downsample)
    # Bail out before loading the other splits when there is nothing to train on.
    if len(train_ds) == 0:
        print("No training data!")
        return None
    stats = train_ds.get_stats()
    val_ds = CrossModalDataset(VAL_VOLS, args.subtask, downsample=args.downsample, stats=stats)
    test_ds = CrossModalDataset(TEST_VOLS, args.subtask, downsample=args.downsample, stats=stats)

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False)

    # Use test set for validation when val set is empty. NOTE: model selection
    # then sees the test data, so the reported test metrics are optimistic.
    if len(val_ds) == 0:
        if len(test_ds) == 0:
            # Previously this crashed inside evaluate() after the first epoch.
            print(" No val or test data, nothing to evaluate on!")
            return None
        val_loader = test_loader
        print(" No val data, using test set for early stopping.")

    model = build_model(args.model, train_ds.input_dim, train_ds.output_dim,
                        args.hidden_dim).to(device)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Params: {n_params:,}, input_dim: {train_ds.input_dim}, output_dim: {train_ds.output_dim}")

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=7, factor=0.5)

    exp_name = f"exp4_{args.subtask}_{args.model}"
    out_dir = os.path.join(args.output_dir, exp_name)
    os.makedirs(out_dir, exist_ok=True)
    model_path = os.path.join(out_dir, 'model_best.pt')
    out_std = train_ds.out_std.flatten()

    best_val_loss = float('inf')
    best_epoch = 0
    patience_counter = 0
    # Tracks whether THIS run wrote a checkpoint, so a stale model_best.pt
    # left by an earlier run is never mistaken for the current best.
    checkpoint_saved = False

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_metrics = evaluate(model, val_loader, criterion, device, out_std)
        scheduler.step(val_metrics['loss'])
        elapsed = time.time() - t0
        print(f" Epoch {epoch:3d} | Train: {train_loss:.4f} | "
              f"Val: loss={val_metrics['loss']:.4f} rmse={val_metrics['rmse']:.4f} "
              f"r2={val_metrics['r2']:.4f} pearson={val_metrics['pearson']:.4f} | {elapsed:.1f}s")

        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            best_epoch = epoch
            patience_counter = 0
            torch.save(model.state_dict(), model_path)
            checkpoint_saved = True
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f" Early stopping at epoch {epoch}")
                break

    if checkpoint_saved:
        model.load_state_dict(
            torch.load(model_path, map_location=device, weights_only=True))
    else:
        # Happens e.g. when every validation loss was NaN or epochs == 0.
        print(" WARNING: No best model saved, using last model")
        torch.save(model.state_dict(), model_path)

    if len(test_ds) == 0:
        print(" No test data!")
        return None

    test_metrics = evaluate(model, test_loader, criterion, device, out_std)
    print(f"\n--- Test Results (epoch {best_epoch}) ---", flush=True)
    for k, v in test_metrics.items():
        print(f" {k}: {v:.4f}", flush=True)

    results = {
        'experiment': exp_name,
        'subtask': args.subtask,
        'model': args.model,
        'best_epoch': best_epoch,
        'test_metrics': test_metrics,
        'n_params': n_params,
        'input_dim': train_ds.input_dim,
        'output_dim': train_ds.output_dim,
        'train_windows': len(train_ds),
        # Snapshot: vars(args) is the live namespace dict and would change
        # if the caller later mutates args.
        'args': dict(vars(args)),
    }
    with open(os.path.join(out_dir, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2)
    return results
def run_all(args):
    """Run every subtask × baseline-model combination and write a summary.

    Each combination runs on its own copy of ``args``, so the caller's
    namespace is left untouched and every stored result keeps the settings
    it was produced with. A failing combination is recorded with an
    ``'error'`` entry and does not stop the sweep. The collected results go
    to ``<output_dir>/exp4_summary.json`` and a table is printed.
    """
    import traceback

    subtasks = ['4a', '4b', '4c']
    models = ['mlp', 'unet', 'lstm', 'transformer']
    all_results = []

    for subtask in subtasks:
        for model_name in models:
            # Fresh namespace per run: mutating the shared `args` made every
            # result's 'args' entry report the last subtask/model of the sweep.
            run_args = argparse.Namespace(**vars(args))
            run_args.subtask = subtask
            run_args.model = model_name
            try:
                result = run_experiment(run_args)
                if result:
                    all_results.append(result)
            except Exception as e:
                # Sweep boundary: log the failure and continue with the next run.
                print(f"FAILED: {subtask}/{model_name}: {e}")
                traceback.print_exc()
                all_results.append({'experiment': f"exp4_{subtask}_{model_name}", 'error': str(e)})

    os.makedirs(args.output_dir, exist_ok=True)
    summary_path = os.path.join(args.output_dir, 'exp4_summary.json')
    with open(summary_path, 'w') as f:
        json.dump(all_results, f, indent=2)

    print(f"\n{'='*60}")
    print(f"{'Subtask':<10} {'Model':<15} {'RMSE':<10} {'R²':<10} {'Pearson':<10}")
    print('-' * 55)
    for r in all_results:
        if 'error' in r:
            continue
        m = r['test_metrics']
        # Pad the numbers to the header's column widths so the table lines up.
        print(f"{r['subtask']:<10} {r['model']:<15} {m['rmse']:<10.4f} "
              f"{m['r2']:<10.4f} {m['pearson']:<10.4f}")
def main():
    """Parse command-line options and run one experiment or the full sweep."""
    parser = argparse.ArgumentParser(description='Exp4: Cross-Modal Prediction')
    parser.add_argument('--subtask', type=str, default='4a',
                        choices=['4a', '4b', '4c'])
    parser.add_argument('--model', type=str, default='unet',
                        choices=['mlp', 'unet', 'lstm', 'transformer',
                                 'underpressure', 'emg2pose', 'emg2pose_direct'])
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    parser.add_argument('--hidden_dim', type=int, default=128)
    parser.add_argument('--downsample', type=int, default=2)
    parser.add_argument('--patience', type=int, default=10)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--output_dir', type=str,
                        default='${PULSE_ROOT}/results/exp4')
    parser.add_argument('--run_all', action='store_true')
    args = parser.parse_args()

    # The default contains a ${PULSE_ROOT} placeholder that nothing expanded,
    # so a directory literally named '${PULSE_ROOT}' was created. Expand
    # environment variables and '~'; unset variables are left as written.
    args.output_dir = os.path.expanduser(os.path.expandvars(args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    if args.run_all:
        run_all(args)
    else:
        run_experiment(args)
# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()