"""
Multimodal scene dataset for Experiment 1: Activity Recognition.
Loads aligned 100Hz multi-modal data, supports modality selection,
subject-independent splits, and variable-length sequence handling.
"""

import os
import json
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# Resolve ${PULSE_ROOT} from the environment at import time; as a plain
# string literal the placeholder would never expand.
DATASET_DIR = os.path.expandvars("${PULSE_ROOT}/dataset")

MODALITY_FILES = {
    'mocap': None,  # Special: uses aligned_{vol}{scene}_s_Q.tsv (skeleton data)
    'emg': 'aligned_emg_100hz.csv',
    'eyetrack': 'aligned_eyetrack_100hz.csv',
    'imu': 'aligned_imu_100hz.csv',
    'pressure': 'aligned_pressure_100hz.csv',
    'video': 'video_features_100hz.npy',  # ViT-B/16 (ImageNet)
    'videomae': 'video_features_videomae_100hz.npy',  # VideoMAE (Kinetics-400)
}


def get_modality_filepath(scenario_dir, modality, vol=None, scenario=None):
    """Return the file path for a given modality.

    Mocap uses a special naming pattern: aligned_{vol}{scene}_s_Q.tsv
    All other modalities use MODALITY_FILES directly.
    """
    if modality == 'mocap':
        if vol is None or scenario is None:
            raise ValueError("vol and scenario required for mocap modality")
        return os.path.join(scenario_dir, f"aligned_{vol}{scenario}_s_Q.tsv")
    return os.path.join(scenario_dir, MODALITY_FILES[modality])
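# e.g. get_modality_filepath('/data/v2/s3', 'mocap', vol='v2', scenario='s3')
#      -> '/data/v2/s3/aligned_v2s3_s_Q.tsv'  (hypothetical path)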

SKIP_COLS = {'Frame', 'Time', 'time', 'UTC'}
SKIP_COL_SUFFIXES = (' Type',)

# Eyetrack exports sometimes include volunteer-specific marker/ICA columns
# (matching the patterns below). Benchmark inputs use only the fixed 24 core
# gaze columns; recordings missing any core column are skipped rather than
# truncating the whole dataset to a common column subset.
EYETRACK_SKIP_PATTERNS = ('Index Of Cognitive Activity', 'Marker Coordinates', 'Markers_')
EYETRACK_CORE_COLS = [
    'Dikablis Glasses 3_Eye Data_Original_Pupil X',
    'Dikablis Glasses 3_Eye Data_Original_Pupil Y',
    'Dikablis Glasses 3_Eye Data_Original_Left Eye_Pupil X',
    'Dikablis Glasses 3_Eye Data_Original_Left Eye_Pupil Y',
    'Dikablis Glasses 3_Eye Data_Original_Left Eye_Pupil Area',
    'Dikablis Glasses 3_Eye Data_Original_Left Eye_Pupil Height',
    'Dikablis Glasses 3_Eye Data_Original_Left Eye_Pupil Width',
    'Dikablis Glasses 3_Eye Data_Original_Left Eye_Fixations_Fixations',
    'Dikablis Glasses 3_Eye Data_Original_Left Eye_Fixations_Fixations Duration',
    'Dikablis Glasses 3_Eye Data_Original_Left Eye_Saccades_Saccades',
    'Dikablis Glasses 3_Eye Data_Original_Left Eye_Saccades_Saccades Duration',
    'Dikablis Glasses 3_Eye Data_Original_Left Eye_Saccades_Saccades Angle',
    'Dikablis Glasses 3_Eye Data_Original_Right Eye_Pupil X',
    'Dikablis Glasses 3_Eye Data_Original_Right Eye_Pupil Y',
    'Dikablis Glasses 3_Eye Data_Original_Right Eye_Pupil Area',
    'Dikablis Glasses 3_Eye Data_Original_Right Eye_Pupil Height',
    'Dikablis Glasses 3_Eye Data_Original_Right Eye_Pupil Width',
    'Dikablis Glasses 3_Eye Data_Original_Right Eye_Fixations_Fixations',
    'Dikablis Glasses 3_Eye Data_Original_Right Eye_Fixations_Fixations Duration',
    'Dikablis Glasses 3_Eye Data_Original_Right Eye_Saccades_Saccades',
    'Dikablis Glasses 3_Eye Data_Original_Right Eye_Saccades_Saccades Duration',
    'Dikablis Glasses 3_Eye Data_Original_Right Eye_Saccades_Saccades Angle',
    'Dikablis Glasses 3_Field Data_Scene Cam_Original_Gaze_Gaze X',
    'Dikablis Glasses 3_Field Data_Scene Cam_Original_Gaze_Gaze Y',
]
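# Recordings excluded from the eyetrack modality; load_modality_array returns
# None for these, so the affected samples are skipped.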
EYETRACK_EXCLUDED_RECORDINGS = {('v1', 's1'), ('v14', 's8')}

SCENE_LABELS = {f's{i}': i - 1 for i in range(1, 9)}
NUM_CLASSES = 8

TRAIN_VOLS = ['v1', 'v2', 'v11', 'v12', 'v13', 'v15', 'v16', 'v17', 'v19', 'v20', 'v21', 'v22', 'v23', 'v24']
VAL_VOLS = []  # No separate val set; use train for early stopping or cross-val
TEST_VOLS = ['v25', 'v26', 'v27', 'v3']
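# Subject-independent splits: no volunteer appears in more than one list.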


def _preprocess_mocap_skeleton(arr, feat_cols):
    """Convert absolute skeleton coords to hip-relative positions + velocity.

    Input:  (T, F) with absolute XYZ + quaternions
    Output: (T, F + N_pos) where N_pos = number of XYZ position features
            [hip-relative features, XYZ velocity]
    """
    col_to_idx = {c: i for i, c in enumerate(feat_cols)}

    # Find hip position for subtraction
    hip_x_idx = col_to_idx.get('Hips_X')
    hip_y_idx = col_to_idx.get('Hips_Y')
    hip_z_idx = col_to_idx.get('Hips_Z')
    if hip_x_idx is None or hip_y_idx is None or hip_z_idx is None:
        return arr  # Incomplete hip joint, skip preprocessing

    # Identify all position columns (_X, _Y, _Z)
    x_indices = [i for i, c in enumerate(feat_cols) if c.endswith('_X')]
    y_indices = [i for i, c in enumerate(feat_cols) if c.endswith('_Y')]
    z_indices = [i for i, c in enumerate(feat_cols) if c.endswith('_Z')]
    all_pos_indices = sorted(x_indices + y_indices + z_indices)

    # 1. Make XYZ positions hip-relative
    arr_rel = arr.copy()
    hip_xyz = arr[:, [hip_x_idx, hip_y_idx, hip_z_idx]]  # (T, 3)
    for idx in x_indices:
        arr_rel[:, idx] -= hip_xyz[:, 0]
    for idx in y_indices:
        arr_rel[:, idx] -= hip_xyz[:, 1]
    for idx in z_indices:
        arr_rel[:, idx] -= hip_xyz[:, 2]

    # 2. Compute velocity of position features only
    pos_data = arr_rel[:, all_pos_indices]  # (T, N_pos)
    velocity = np.zeros_like(pos_data)
    velocity[1:] = pos_data[1:] - pos_data[:-1]

    # 3. Concatenate: [hip-relative features (pos+quat), position velocity]
    return np.concatenate([arr_rel, velocity], axis=1)
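

# Shape sketch for _preprocess_mocap_skeleton (hypothetical 2-joint rig):
# with feat_cols = [Hips_X, Hips_Y, Hips_Z, 4 hip quaternion cols,
# Head_X, Head_Y, Head_Z, 4 head quaternion cols], F = 14 and N_pos = 6,
# so a (T, 14) input becomes a (T, 20) output: 14 hip-relative features
# followed by 6 XYZ velocity channels.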


def load_modality_array(filepath, modality):
    """Load a modality CSV/TSV/NPY file and return a float32 NumPy array.

    Returns None when the data is unusable: a missing .npy file, an excluded
    eyetrack recording, missing core eyetrack columns, extreme values, or
    near-total zeros (sensor dropout).
    """
    # Video features stored as .npy
    if filepath.endswith('.npy'):
        if not os.path.exists(filepath):
            return None
        arr = np.load(filepath).astype(np.float32)
        arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
        return arr
    # Mocap uses TSV with tab separator
    sep = '\t' if filepath.endswith('.tsv') else ','
    df = pd.read_csv(filepath, sep=sep, low_memory=False)
    df.columns = [str(c).strip() for c in df.columns]
    if modality == 'eyetrack':
        parts = os.path.normpath(filepath).split(os.sep)
        if len(parts) >= 3 and (parts[-3], parts[-2]) in EYETRACK_EXCLUDED_RECORDINGS:
            return None
    feat_cols = [c for c in df.columns
                 if c not in SKIP_COLS
                 and not any(c.endswith(s) for s in SKIP_COL_SUFFIXES)]
    if modality == 'eyetrack':
        feat_cols = [c for c in EYETRACK_CORE_COLS if c in feat_cols]
        if len(feat_cols) != len(EYETRACK_CORE_COLS):
            return None
    sub = df[feat_cols]
    # Coerce non-numeric columns
    obj_cols = sub.select_dtypes(include=['object']).columns
    if len(obj_cols) > 0:
        sub = sub.copy()
        sub[obj_cols] = sub[obj_cols].apply(pd.to_numeric, errors='coerce')
    arr = sub.values.astype(np.float64)
    arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
    # Quality check: reject samples with extreme values (corrupted data)
    max_abs = np.max(np.abs(arr))
    if max_abs > 1e6:
        return None  # Corrupted
    # Quality check: reject samples that are mostly zeros (sensor dropout).
    # Pressure and EMG are legitimately zero for long periods (rest, no grip)
    # so we only apply the strict near-total-loss check to the modalities
    # where a flat-zero stream is a clear dropout signal.
    if modality not in ("pressure", "emg"):
        zero_ratio = np.mean(arr == 0.0)
        if zero_ratio > 0.9:
            return None  # Near-total data loss
    # Mocap skeleton: convert to hip-relative + velocity
    if modality == 'mocap' and filepath.endswith('.tsv'):
        arr = _preprocess_mocap_skeleton(arr, feat_cols)
    arr = arr.astype(np.float32)
    return arr
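# Minimal usage sketch (hypothetical volunteer/scenario; assumes $PULSE_ROOT
# is set and the aligned IMU file exists):
#   path = get_modality_filepath(os.path.join(DATASET_DIR, 'v2', 's3'), 'imu')
#   arr = load_modality_array(path, 'imu')  # (T, D) float32 at 100 Hz, or None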


class MultimodalSceneDataset(Dataset):
    """Dataset for scene-level classification from multimodal time series."""

    def __init__(self, volunteers, modalities, downsample=5, stats=None):
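        """
        Args:
            volunteers: volunteer IDs (e.g. TRAIN_VOLS) whose recordings to load.
            modalities: keys of MODALITY_FILES to load and concatenate.
            downsample: keep every Nth frame (5 turns 100 Hz into 20 Hz).
            stats: optional (mean, std) from the training set; computed from
                the loaded data when None.
        """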
        self.modalities = modalities
        self.downsample = downsample
        self.data = []
        self.labels = []
        self.sample_info = []
        self._modality_dims = {}

        for vol in volunteers:
            vol_dir = os.path.join(DATASET_DIR, vol)
            if not os.path.isdir(vol_dir):
                continue
            for scenario in sorted(os.listdir(vol_dir)):
                scenario_dir = os.path.join(vol_dir, scenario)
                if not os.path.isdir(scenario_dir) or scenario not in SCENE_LABELS:
                    continue
                meta_path = os.path.join(scenario_dir, 'alignment_metadata.json')
                if not os.path.exists(meta_path):
                    continue
                with open(meta_path) as f:
                    meta = json.load(f)
                available = set(meta['modalities'])
                if not set(modalities).issubset(available):
                    continue

                parts = []
                skip = False
                for mod in modalities:
                    # Resolve the per-modality path; mocap uses its special
                    # aligned_{vol}{scene}_s_Q.tsv naming.
                    filepath = get_modality_filepath(scenario_dir, mod,
                                                     vol=vol, scenario=scenario)
                    if not os.path.exists(filepath):
                        skip = True
                        break
                    arr = load_modality_array(filepath, mod)
                    if arr is None:
                        print(f"  SKIP {vol}/{scenario} {mod}: corrupted data", flush=True)
                        skip = True
                        break
                    # Validate dimension consistency
                    if mod in self._modality_dims and arr.shape[1] != self._modality_dims[mod]:
                        print(f"  WARNING: {vol}/{scenario} {mod} dim {arr.shape[1]} "
                              f"!= expected {self._modality_dims[mod]}, padding/truncating",
                              flush=True)
                        expected = self._modality_dims[mod]
                        if arr.shape[1] < expected:
                            pad = np.zeros((arr.shape[0], expected - arr.shape[1]), dtype=np.float32)
                            arr = np.concatenate([arr, pad], axis=1)
                        else:
                            arr = arr[:, :expected]
                    if mod not in self._modality_dims:
                        self._modality_dims[mod] = arr.shape[1]
                    parts.append(arr)

                if skip:
                    continue

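                # Trim every modality to the shortest stream so frames stay
                # aligned, then concatenate features and downsample in time.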
                min_len = min(p.shape[0] for p in parts)
                parts = [p[:min_len] for p in parts]
                combined = np.concatenate(parts, axis=1)
                combined = combined[::downsample]

                self.data.append(combined)
                self.labels.append(SCENE_LABELS[scenario])
                self.sample_info.append(f"{vol}/{scenario}")

        print(f"  Loaded {len(self.data)} samples, modality dims: {self._modality_dims}, "
              f"total feat dim: {sum(self._modality_dims.values())}", flush=True)

        # Normalization (compute in float64 to avoid overflow)
        if stats is not None:
            self.mean, self.std = stats
        else:
            self._compute_stats()
        for i in range(len(self.data)):
            self.data[i] = ((self.data[i].astype(np.float64) - self.mean) / self.std).astype(np.float32)
            self.data[i] = np.nan_to_num(self.data[i], nan=0.0, posinf=0.0, neginf=0.0)

    def _compute_stats(self):
        # Use float64 for accumulation to prevent overflow
        all_frames = np.concatenate(self.data, axis=0).astype(np.float64)
        self.mean = np.mean(all_frames, axis=0, keepdims=True)
        self.std = np.std(all_frames, axis=0, keepdims=True)
        self.std[self.std < 1e-8] = 1.0

    def get_stats(self):
        return (self.mean, self.std)

    @property
    def feat_dim(self):
        return sum(self._modality_dims.values())

    @property
    def modality_dims(self):
        return dict(self._modality_dims)

    def get_class_weights(self):
        counts = np.bincount(self.labels, minlength=NUM_CLASSES).astype(np.float32)
        counts[counts == 0] = 1.0
        weights = 1.0 / counts
        weights = weights / weights.sum() * NUM_CLASSES
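        # Rescaling preserves ratios: equal class counts give all-1.0 weights;
        # a class with half the samples of each other class gets exactly twice
        # their weight.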
        return torch.FloatTensor(weights)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.from_numpy(self.data[idx]), self.labels[idx]


def collate_fn(batch):
    """Pad variable-length sequences and create masks."""
    sequences, labels = zip(*batch)
    lengths = torch.LongTensor([s.shape[0] for s in sequences])
    padded = pad_sequence(sequences, batch_first=True, padding_value=0.0)
    max_len = padded.shape[1]
    mask = torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)
    labels = torch.LongTensor(labels)
    return padded, labels, mask, lengths
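# e.g. two sequences with lengths [120, 80] and D features collate into
# padded (2, 120, D), lengths tensor([120, 80]), and a (2, 120) boolean mask
# where mask[1, 80:] is False (padding positions).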


def get_dataloaders(modalities, batch_size=16, downsample=5, num_workers=0):
    """Create train/val/test DataLoaders, normalizing all splits with train stats.

    Note: VAL_VOLS is empty by default, so the val loader yields no batches.
    """
    print("Loading training data...", flush=True)
    train_ds = MultimodalSceneDataset(TRAIN_VOLS, modalities, downsample)
    stats = train_ds.get_stats()

    print("Loading validation data...", flush=True)
    val_ds = MultimodalSceneDataset(VAL_VOLS, modalities, downsample, stats=stats)

    print("Loading test data...", flush=True)
    test_ds = MultimodalSceneDataset(TEST_VOLS, modalities, downsample, stats=stats)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_fn, num_workers=num_workers,
                              drop_last=False)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            collate_fn=collate_fn, num_workers=num_workers)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False,
                             collate_fn=collate_fn, num_workers=num_workers)

    info = {
        'feat_dim': train_ds.feat_dim,
        'modality_dims': train_ds.modality_dims,
        'num_classes': NUM_CLASSES,
        'train_size': len(train_ds),
        'val_size': len(val_ds),
        'test_size': len(test_ds),
        'class_weights': train_ds.get_class_weights(),
    }
    return train_loader, val_loader, test_loader, info
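

if __name__ == "__main__":
    # Minimal smoke test; a sketch that assumes $PULSE_ROOT is set and that
    # aligned IMU files exist for the listed volunteers.
    train_loader, val_loader, test_loader, info = get_dataloaders(
        ['imu'], batch_size=4)
    print(f"feat_dim={info['feat_dim']}, classes={info['num_classes']}, "
          f"train={info['train_size']}, test={info['test_size']}")
    x, y, mask, lengths = next(iter(train_loader))
    print(f"batch: x={tuple(x.shape)}, y={tuple(y.shape)}, mask={tuple(mask.shape)}")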