PULSE-code / experiments /analysis /analysis_figures.py
velvet-pine-22's picture
Upload folder using huggingface_hub
b4b2877 verified
#!/usr/bin/env python3
"""Generate three showcase figures for the main paper:
1. Eye-Hand-Contact coordination (gaze fixation + hand velocity + pressure)
2. Pressure fingerprints per action category
3. 3D hand trajectory colored by pressure
"""
import os, glob, json, re
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter
# Dataset input and figure output locations. "${PULSE_ROOT}" is a shell-style
# placeholder: it is expanded from the environment when PULSE_ROOT is set and
# left as the literal text otherwise (os.path.expandvars leaves unknown
# variables untouched).
DATASET = os.path.expandvars("${PULSE_ROOT}/dataset")
OUT_DIR = os.path.expandvars("${PULSE_ROOT}/paper/figures")
os.makedirs(OUT_DIR, exist_ok=True)
# Threshold on summed pressure (columns are suffixed "(g)") above which a
# frame counts as "in contact" for grasp-onset detection.
PRESSURE_THRESHOLD = 5.0
# Frame rate of the aligned_*_100hz input files, in frames per second.
FPS = 100
# ============================================================
# Shared data-loading helpers
# ============================================================
def load_pressure(scenario_dir):
    """Load per-sensor right- and left-hand pressure for one recording.

    Reads ``aligned_pressure_100hz.csv`` from *scenario_dir*. Columns whose
    name starts with ``R`` (right) or ``L`` (left) and ends with ``(g)`` are
    the sensor channels; non-numeric cells are coerced to NaN and then
    replaced with 0.

    Returns:
        ``(right, left)`` -- a tuple of two arrays of shape ``(T, n_right)``
        and ``(T, n_left)``, one column per sensor (channels are NOT summed),
        or ``None`` if the file is missing or either hand has fewer than 20
        matching columns.
    """
    path = os.path.join(scenario_dir, "aligned_pressure_100hz.csv")
    if not os.path.exists(path):
        return None
    df = pd.read_csv(path, low_memory=False)
    r_cols = [c for c in df.columns if c.startswith('R') and c.endswith('(g)')]
    l_cols = [c for c in df.columns if c.startswith('L') and c.endswith('(g)')]
    if len(r_cols) < 20 or len(l_cols) < 20:
        return None

    def _to_array(cols):
        # Coerce text/garbage cells to NaN, then treat them as zero pressure.
        return df[cols].apply(pd.to_numeric, errors='coerce').fillna(0).values

    return _to_array(r_cols), _to_array(l_cols)
def load_emg(scenario_dir):
    """Load every numeric EMG channel for one recording.

    Reads ``aligned_emg_100hz.csv`` from *scenario_dir* and keeps all numeric
    columns except the bookkeeping columns ``time``, ``UTC`` and ``Frame``.

    Returns a float32 array of shape (T, n_channels) passed through
    ``np.nan_to_num`` (NaN -> 0), or ``None`` when the file is missing or
    fewer than 4 channels remain.
    """
    csv_path = os.path.join(scenario_dir, "aligned_emg_100hz.csv")
    if os.path.exists(csv_path):
        frame = pd.read_csv(csv_path, low_memory=False)
        excluded = ('time', 'UTC', 'Frame')
        channels = [name for name in frame.select_dtypes(include=[np.number]).columns
                    if name not in excluded]
        if len(channels) >= 4:
            return np.nan_to_num(frame[channels].values.astype(np.float32))
    return None
def load_gaze(scenario_dir):
    """Load 2D scene-camera gaze coordinates for one recording.

    Uses the first column of ``aligned_eyetrack_100hz.csv`` whose name
    contains both ``'Gaze X'`` and ``'Scene Cam'`` (likewise for ``'Gaze Y'``).
    Non-numeric or missing samples become 0.

    Returns an array of shape (T, 2) with columns (x, y), or ``None`` when the
    file or either gaze column is missing.
    """
    csv_path = os.path.join(scenario_dir, "aligned_eyetrack_100hz.csv")
    if not os.path.exists(csv_path):
        return None
    frame = pd.read_csv(csv_path, low_memory=False)

    def first_match(axis_tag):
        return next((name for name in frame.columns
                     if axis_tag in name and 'Scene Cam' in name), None)

    x_name = first_match('Gaze X')
    y_name = first_match('Gaze Y')
    if x_name is None or y_name is None:
        return None
    coords = [pd.to_numeric(frame[name], errors='coerce').fillna(0).values
              for name in (x_name, y_name)]
    return np.stack(coords, axis=1)
def load_mocap_hand(scenario_dir, vol, scenario):
    """Load right- and left-wrist 3D positions from the aligned mocap TSV.

    Reads ``aligned_{vol}{scenario}_s_Q.tsv`` from *scenario_dir*. For each
    hand the first marker-naming convention whose ``_X``/``_Y``/``_Z``
    columns are all present is used: ``RightHand``, then ``R_Hand``, then
    ``Q_RWristIn`` (the left hand mirrors this with ``LeftHand``/``L_Hand``/
    ``Q_LWristIn``). Non-numeric or missing samples are replaced with 0.

    Returns:
        ``(right_wrist, left_wrist)`` -- each an array of shape (T, 3), or
        ``None`` for a hand whose columns were not found. Returns
        ``(None, None)`` if the file does not exist.
    """
    path = os.path.join(scenario_dir, f"aligned_{vol}{scenario}_s_Q.tsv")
    if not os.path.exists(path):
        return None, None
    df = pd.read_csv(path, sep='\t', low_memory=False)

    def _first_xyz(prefixes):
        # Return the first prefix whose X/Y/Z columns all exist, as (T, 3).
        for prefix in prefixes:
            cols = [f"{prefix}_{axis}" for axis in "XYZ"]
            if all(c in df.columns for c in cols):
                # NOTE(review): NaN samples (presumably marker dropouts) become
                # 0, i.e. the origin -- downstream velocity will spike there.
                return df[cols].apply(pd.to_numeric, errors='coerce').fillna(0).values
        return None

    r_wrist = _first_xyz(("RightHand", "R_Hand", "Q_RWristIn"))
    l_wrist = _first_xyz(("LeftHand", "L_Hand", "Q_LWristIn"))
    return r_wrist, l_wrist
def compute_velocity(position, window=5, polyorder=2):
    """Per-frame speed of a trajectory, Savitzky-Golay smoothed.

    Args:
        position: array-like of shape (T, D), e.g. wrist XYZ per frame.
        window: smoothing half-width; the filter length is ``2*window + 1``,
            capped at the largest odd length that fits in the signal.
        polyorder: polynomial order of the Savitzky-Golay filter.

    Returns:
        Array of shape (T,): norm of the frame-to-frame difference (frame 0
        is 0), in position units per frame. When the signal is too short for
        a filter longer than *polyorder*, the unsmoothed speed is returned.
    """
    position = np.asarray(position)
    vel = np.zeros_like(position)
    vel[1:] = position[1:] - position[:-1]
    mag = np.linalg.norm(vel, axis=1)
    n = len(mag)
    # savgol_filter (default mode='interp') needs polyorder < window_length
    # <= len(x); keep the length odd by using n-1 for even-length signals.
    window_length = min(2 * window + 1, n if n % 2 else n - 1)
    if window_length > polyorder:
        mag = savgol_filter(mag, window_length=window_length, polyorder=polyorder)
    return mag
def detect_grasp_events(hand_pressure, threshold=PRESSURE_THRESHOLD, min_gap=50,
                        confirm_len=10, confirm_frac=0.7,
                        release_len=5, release_frac=0.3):
    """Detect sustained pressure-onset (contact) events.

    A frame is "above" when the summed pressure exceeds *threshold*. The
    detector is a two-state hysteresis machine:

    * idle -> contact at frame ``i`` when frame ``i`` is above and more than
      *confirm_frac* of frames ``i .. i+confirm_len-1`` are above. The onset
      is recorded only if it lies more than *min_gap* frames after the last
      recorded onset, but the state switches to contact either way so one
      contact is never reported twice.
    * contact -> idle at frame ``i`` when frame ``i`` is below and fewer than
      *release_frac* of frames ``i .. i+release_len-1`` are above.

    Args:
        hand_pressure: (T, C) per-sensor pressure (summed over C) or (T,)
            already-summed pressure.
        threshold: contact threshold applied to the summed pressure.
        min_gap: minimum spacing, in frames, between recorded onsets.
        confirm_len, confirm_frac: look-ahead window used to confirm onset.
        release_len, release_frac: look-ahead window used to confirm release.

    Returns:
        Increasing list of onset frame indices.
    """
    total = hand_pressure.sum(axis=1) if hand_pressure.ndim == 2 else hand_pressure
    above = np.asarray(total) > threshold
    n = len(above)
    onsets = []
    in_contact = False
    for i in range(n):
        if above[i] and not in_contact:
            # A complete look-ahead window (i+confirm_len <= n) is required.
            if i + confirm_len <= n and above[i:i + confirm_len].mean() > confirm_frac:
                if not onsets or i - onsets[-1] > min_gap:
                    onsets.append(i)
                in_contact = True
        elif not above[i] and in_contact:
            if i + release_len <= n and above[i:i + release_len].mean() < release_frac:
                in_contact = False
    return onsets
def emg_envelope(emg, window=20):
    """Summed moving-average envelope of multi-channel EMG.

    Each channel is mean-centred, rectified (absolute value) and smoothed
    with a centred length-*window* boxcar; the smoothed channels are summed.

    Args:
        emg: array of shape (T, C).
        window: boxcar length in frames.

    Returns:
        Array of shape (T,). Empty input yields an empty array.
    """
    emg = np.asarray(emg)
    n_frames = emg.shape[0]
    if n_frames == 0 or emg.shape[1] == 0:
        return np.zeros(n_frames)
    rect = np.abs(emg - np.mean(emg, axis=0))
    kernel = np.ones(window) / window
    # np.convolve(mode='same') returns max(T, window) samples, which is too
    # long when the recording is shorter than the kernel. In that case take
    # the T-sample slice of the full convolution with the same alignment that
    # 'same' uses when T >= window.
    start = (window - 1) // 2

    def smooth(channel):
        if n_frames >= window:
            return np.convolve(channel, kernel, mode='same')
        return np.convolve(channel, kernel, mode='full')[start:start + n_frames]

    env = np.stack([smooth(rect[:, c]) for c in range(rect.shape[1])], axis=1)
    return env.sum(axis=1)
def gaze_velocity(gaze_xy, window=5, polyorder=2, max_window_length=15):
    """Magnitude of gaze velocity -- high = saccade, low = fixation.

    Args:
        gaze_xy: array-like of shape (T, 2) gaze coordinates per frame.
        window: smoothing half-width; the Savitzky-Golay length is
            ``2*window + 1``, capped at *max_window_length* and at the largest
            odd length that fits in the signal.
        polyorder: polynomial order of the Savitzky-Golay filter.
        max_window_length: upper bound on the filter length.

    Returns:
        Array of shape (T,): smoothed norm of the frame-to-frame gaze
        difference (frame 0 is 0). When the signal is too short for a filter
        longer than *polyorder*, the unsmoothed speed is returned.
    """
    gaze_xy = np.asarray(gaze_xy)
    v = np.zeros_like(gaze_xy)
    v[1:] = gaze_xy[1:] - gaze_xy[:-1]
    mag = np.linalg.norm(v, axis=1)
    n = len(mag)
    # savgol_filter (mode='interp') requires polyorder < window_length <= len(x).
    window_length = min(2 * window + 1, max_window_length, n if n % 2 else n - 1)
    if window_length > polyorder:
        mag = savgol_filter(mag, window_length=window_length, polyorder=polyorder)
    return mag
# ============================================================
# FIGURE 1: Eye-Hand-Contact coordination
# ============================================================
def _collect_eye_hand_events(context, after, max_events):
    """Gather gaze/hand/pressure windows around right-hand contact onsets.

    Walks every ``DATASET/v*/s*`` recording whose alignment metadata lists
    pressure, eyetrack and mocap. For each right-hand onset ``o`` found by
    ``detect_grasp_events`` it keeps the window ``[o - context, o + after)``
    of gaze speed, right-wrist speed and summed right-hand pressure. Returns
    as soon as *max_events* windows have been collected (across all
    volunteers, so no further recordings are loaded).
    """
    events = []
    for vol_dir in sorted(glob.glob(f"{DATASET}/v*")):
        vol = os.path.basename(vol_dir)
        for sd in sorted(glob.glob(f"{vol_dir}/s*")):
            scenario = os.path.basename(sd)
            meta_path = os.path.join(sd, "alignment_metadata.json")
            if not os.path.exists(meta_path):
                continue
            with open(meta_path) as fh:
                meta = json.load(fh)
            if not {'pressure', 'eyetrack', 'mocap'}.issubset(meta['modalities']):
                continue
            p = load_pressure(sd)
            g = load_gaze(sd)
            r_wrist, _ = load_mocap_hand(sd, vol, scenario)
            if p is None or g is None or r_wrist is None:
                continue
            r_p = p[0]
            n = min(len(r_p), len(g), len(r_wrist))
            r_p, g, r_wrist = r_p[:n], g[:n], r_wrist[:n]
            hand_vel = compute_velocity(r_wrist)
            gvel = gaze_velocity(g)
            total_p = r_p.sum(axis=1)
            for o in detect_grasp_events(r_p):
                if o < context or o + after >= n:
                    continue
                # Require a quiescent pre-grasp: mean wrist speed in frames
                # o-150..o-100 must not exceed half the speed over the final
                # 50 frames before contact (needs context >= 150).
                if hand_vel[o - 150:o - 100].mean() > hand_vel[o - 50:o].mean() * 0.5:
                    continue
                gv_seg = gvel[o - context:o + after]
                if len(gv_seg) != context + after or np.isnan(gv_seg).any():
                    continue
                events.append({'gv': gv_seg,
                               'hv': hand_vel[o - context:o + after],
                               'p': total_p[o - context:o + after]})
                if len(events) >= max_events:
                    return events
    return events


def _norm01(arr):
    """Min-max normalise each row of a 2D array into [0, 1] (constant rows -> 0)."""
    arr = np.asarray(arr, dtype=float)
    arr = arr - arr.min(axis=1, keepdims=True)
    return arr / (arr.max(axis=1, keepdims=True) + 1e-8)


def make_eye_hand_contact_figure():
    """Figure 1: event-averaged gaze fixation, hand speed and pressure.

    Collects up to 400 right-hand contact onsets (skipping the figure if
    fewer than 50 are found), min-max normalises each event's traces, and
    plots the across-event mean with a +/-0.4 SD band versus time relative
    to contact onset. Saves ``OUT_DIR/eye_hand_contact.{pdf,png}``.
    """
    print("=== Figure 1: Eye-Hand-Contact coordination ===")
    context = 200  # frames before onset (2 s at 100 Hz)
    after = 50     # frames after onset (0.5 s at 100 Hz)
    events = _collect_eye_hand_events(context, after, max_events=400)
    print(f" Collected {len(events)} events")
    if len(events) < 50:
        print(" Not enough events, skipping")
        return

    gv_stack = _norm01([e['gv'] for e in events])
    hv_stack = _norm01([e['hv'] for e in events])
    p_stack = _norm01([e['p'] for e in events])
    # Fixation = low gaze speed, so plot the re-normalised complement of the
    # normalised gaze speed: it rises as the gaze settles.
    gaze_fix_plot = _norm01(1 - gv_stack)

    ms_per_frame = 1000.0 / FPS
    time_axis = np.arange(-context, after) * ms_per_frame
    fig, ax = plt.subplots(figsize=(9, 4.5))
    for stack, color, label in [
        (gaze_fix_plot, '#8E44AD', 'Gaze fixation'),
        (hv_stack, '#3498DB', 'Hand velocity'),
        (p_stack, '#27AE60', 'Pressure (contact)'),
    ]:
        mean = stack.mean(axis=0)
        std = stack.std(axis=0)
        ax.plot(time_axis, mean, color=color, linewidth=2.5, label=label)
        # Band is narrowed to +/-0.4 SD for legibility.
        ax.fill_between(time_axis, mean - std * 0.4, mean + std * 0.4, color=color, alpha=0.15)
    ax.axvline(0, color='black', linestyle='--', linewidth=1.2, alpha=0.7)
    ax.set_xlabel('Time relative to contact onset (ms)', fontsize=12)
    ax.set_ylabel('Normalized amplitude', fontsize=12)
    ax.set_title(f'Gaze → Hand → Contact coordination ({len(events)} events)',
                 fontsize=13, fontweight='bold')
    ax.set_xlim(-context * ms_per_frame, after * ms_per_frame)
    ax.legend(loc='upper left', fontsize=10, frameon=True)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-0.05, 1.1)
    plt.tight_layout()
    out_path = os.path.join(OUT_DIR, 'eye_hand_contact.pdf')
    plt.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.savefig(out_path.replace('.pdf', '.png'), dpi=150, bbox_inches='tight')
    plt.close()
    print(f" Saved {out_path}")
# ============================================================
# FIGURE 2: Pressure fingerprints per action category
# ============================================================
def make_pressure_fingerprints():
    """Figure 2: mean per-sensor pressure for the most frequent actions.

    For each recording with pressure data, frames are grouped by their
    fine-grained action label (``load_annotations(..., use_coarse=False)``)
    and per-sensor pressure is summed per action together with the frame
    count. The six non-``'Idle'`` actions with the most labelled frames are
    drawn as 5x5 heat maps (right hand top row, left hand bottom row) on a
    shared colour scale and saved to ``OUT_DIR/pressure_fingerprints.{pdf,png}``.
    """
    print("\n=== Figure 2: Pressure fingerprints ===")
    import sys
    # Make the repository root importable so ``experiments.train_exp2``
    # resolves; the ${PULSE_ROOT} placeholder is expanded when it is set.
    sys.path.insert(0, os.path.expandvars('${PULSE_ROOT}'))
    # Imported once here rather than inside the per-recording loop.
    from experiments.train_exp2 import ACTION_NAMES, load_annotations

    # action name -> [right per-sensor sum, left per-sensor sum, frame count]
    sums = {}
    for vol_dir in sorted(glob.glob(f"{DATASET}/v*")):
        vol = os.path.basename(vol_dir)
        for sd in sorted(glob.glob(f"{vol_dir}/s*")):
            scenario = os.path.basename(sd)
            meta_path = os.path.join(sd, "alignment_metadata.json")
            if not os.path.exists(meta_path):
                continue
            with open(meta_path) as fh:
                meta = json.load(fh)
            if 'pressure' not in meta['modalities']:
                continue
            p = load_pressure(sd)
            if p is None:
                continue
            r_p, l_p = p
            labels = load_annotations(vol, scenario, len(r_p), sampling_rate=FPS, use_coarse=False)
            if labels is None:
                continue
            # A boolean mask must match the pressure length exactly, so trim
            # both sides to the shorter of the two.
            n_frames = min(len(labels), len(r_p), len(l_p))
            labels = np.asarray(labels)[:n_frames]
            r_p, l_p = r_p[:n_frames], l_p[:n_frames]
            for a_id, a_name in ACTION_NAMES.items():
                if a_name == 'Idle':
                    continue
                mask = labels == a_id
                count = int(mask.sum())
                if count < 10:
                    continue
                r_sum = r_p[mask].sum(axis=0)
                l_sum = l_p[mask].sum(axis=0)
                if a_name not in sums:
                    sums[a_name] = [np.zeros(r_sum.shape), np.zeros(l_sum.shape), 0]
                acc = sums[a_name]
                if acc[0].shape != r_sum.shape or acc[1].shape != l_sum.shape:
                    continue  # different sensor layout; cannot be averaged together
                acc[0] += r_sum
                acc[1] += l_sum
                acc[2] += count

    results = {a: {'r': r_sum / cnt, 'l': l_sum / cnt}
               for a, (r_sum, l_sum, cnt) in sums.items() if cnt > 0}
    print(f" Action categories: {list(results.keys())}")
    if not results:
        print(" No data")
        return
    # Six actions with the most labelled frames.
    top = sorted(results, key=lambda a: sums[a][2], reverse=True)[:6]

    def to_grid(values):
        # Stylised layout, not the true sensor geometry: channel k is drawn at
        # row k // 5, column k % 5. Missing channels stay 0; channels beyond
        # the 25th are not drawn.
        flat = np.zeros(25)
        k = min(25, len(values))
        flat[:k] = values[:k]
        return flat.reshape(5, 5)

    n_cols = len(top)
    fig, axes = plt.subplots(2, n_cols, figsize=(2.0 * n_cols, 4.8), squeeze=False)
    vmax = max(max(results[a]['r'].max(), results[a]['l'].max()) for a in top)
    im = None
    for i, a in enumerate(top):
        for row, (hand, title) in enumerate([('r', 'Right'), ('l', 'Left')]):
            ax = axes[row][i]
            im = ax.imshow(to_grid(results[a][hand]), cmap='hot', vmin=0, vmax=vmax, aspect='equal')
            ax.set_xticks([])
            ax.set_yticks([])
            if row == 0:
                ax.set_title(a, fontsize=11, fontweight='bold')
            if i == 0:
                ax.set_ylabel(title, fontsize=10)
    fig.suptitle('Per-action fingertip pressure signatures (mean across events)',
                 fontsize=12, fontweight='bold', y=0.98)
    cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.7, pad=0.02)
    cbar.set_label('Pressure (g)', fontsize=10)
    plt.savefig(os.path.join(OUT_DIR, 'pressure_fingerprints.pdf'), bbox_inches='tight')
    plt.savefig(os.path.join(OUT_DIR, 'pressure_fingerprints.png'), dpi=150, bbox_inches='tight')
    plt.close()
    print(" Saved pressure_fingerprints.pdf")
# ============================================================
# FIGURE 3: 3D hand trajectory colored by pressure
# ============================================================
def make_3d_trajectory():
    """Figure 3: right-wrist 3D trajectories coloured by right-hand pressure.

    Tries a fixed list of (volunteer, scenario) recordings in order and keeps
    up to three that have pressure and right-wrist mocap data plus at least
    one detected grasp onset. Each panel draws the wrist path from 150 frames
    before to 150 frames after the first onset as line segments coloured by
    pressure (normalised to the window maximum), with a star at the pressure
    peak. Saves ``OUT_DIR/hand_trajectory_3d.{pdf,png}``.
    """
    print("\n=== Figure 3: 3D hand trajectory + pressure coloring ===")
    # Imported for its side effect of registering the '3d' projection on
    # older matplotlib versions.
    from mpl_toolkits.mplot3d import Axes3D

    def first_grasp_window(vol, scn):
        # (trajectory, summed pressure) around the first onset, or None.
        scenario_dir = f"{DATASET}/{vol}/{scn}"
        if not os.path.isdir(scenario_dir):
            return None
        pressure = load_pressure(scenario_dir)
        wrist, _ = load_mocap_hand(scenario_dir, vol, scn)
        if pressure is None or wrist is None:
            return None
        right = pressure[0]
        usable = min(len(right), len(wrist))
        grasp_onsets = detect_grasp_events(right[:usable])
        if not grasp_onsets:
            return None
        # ~3 s centred on the first onset.
        first = grasp_onsets[0]
        lo = max(0, first - 150)
        hi = min(usable, first + 150)
        return wrist[lo:hi], right[:usable].sum(axis=1)[lo:hi]

    # Recordings with rich grasping, tried in order of preference.
    preferred = [('v1', 's3'), ('v2', 's4'), ('v1', 's5'), ('v1', 's7')]
    chosen = []
    for vol, scn in preferred:
        window = first_grasp_window(vol, scn)
        if window is None:
            continue
        chosen.append((vol, scn) + window)
        if len(chosen) >= 3:
            break
    if not chosen:
        print(" No valid recordings found")
        return

    n_panels = len(chosen)
    fig = plt.figure(figsize=(3.5 * n_panels, 4))
    for panel, (vol, scn, path_xyz, pressure_trace) in enumerate(chosen, start=1):
        ax = fig.add_subplot(1, n_panels, panel, projection='3d')
        colour_level = pressure_trace / (pressure_trace.max() + 1e-6)
        # One 2-point segment per frame so each can take its own colour.
        for j in range(len(path_xyz) - 1):
            seg = path_xyz[j:j + 2]
            ax.plot(seg[:, 0], seg[:, 1], seg[:, 2],
                    color=plt.cm.coolwarm(colour_level[j]), linewidth=2.5, alpha=0.85)
        peak = np.argmax(pressure_trace)
        ax.scatter(path_xyz[peak, 0], path_xyz[peak, 1], path_xyz[peak, 2],
                   color='red', s=50, marker='*', zorder=5, label='Peak contact')
        ax.set_title(f'{vol}/{scn}', fontsize=10)
        ax.set_xlabel('X', fontsize=8)
        ax.set_ylabel('Y', fontsize=8)
        ax.set_zlabel('Z', fontsize=8)
        ax.tick_params(labelsize=7)

    mappable = plt.cm.ScalarMappable(cmap='coolwarm', norm=matplotlib.colors.Normalize(vmin=0, vmax=1))
    mappable.set_array([])
    cbar = fig.colorbar(mappable, ax=fig.axes, shrink=0.6, pad=0.02)
    cbar.set_label('Normalised pressure', fontsize=10)
    fig.suptitle('Right-hand wrist 3D trajectory coloured by fingertip pressure',
                 fontsize=12, fontweight='bold', y=1.02)
    plt.savefig(os.path.join(OUT_DIR, 'hand_trajectory_3d.pdf'), bbox_inches='tight')
    plt.savefig(os.path.join(OUT_DIR, 'hand_trajectory_3d.png'), dpi=150, bbox_inches='tight')
    plt.close()
    print(" Saved hand_trajectory_3d.pdf")
if __name__ == '__main__':
    # Build each paper figure in turn; each writes its own files to OUT_DIR.
    for build_figure in (make_eye_hand_contact_figure,
                         make_pressure_fingerprints,
                         make_3d_trajectory):
        build_figure()
    print("\nAll figures generated in", OUT_DIR)