"""Frame-level future motor-primitive forecasting dataset. Task definition --------------- At a sampled anchor time t in a recording: past = sensor frames over [t - T_obs, t] ← input future = per-frame verb_fine labels over (t, t + T_fut] ← target We use NUM_VERB_FINE (= 17) as a sentinel "idle / no segment" class for frames not covered by any annotated segment, so every future frame has a valid label (output cardinality = NUM_VERB_FINE + 1 = 18). Anchors are sampled at fixed stride within each recording so the model sees both intra-segment future (mostly stationary) and across-boundary future (where the next-action label changes — the interesting cases). """ from __future__ import annotations import os import sys from pathlib import Path from typing import Dict, List, Optional, Sequence, Tuple import numpy as np import torch from torch.utils.data import Dataset THIS = Path(__file__).resolve() sys.path.insert(0, str(THIS.parent)) sys.path.insert(0, str(THIS.parents[1])) try: from experiments.dataset_seqpred import ( SAMPLING_RATE_HZ, _load_recording_sensors, _load_annotations, parse_ts_range, TRAIN_VOLS_V3, TEST_VOLS_V3, DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR, ) from experiments.taxonomy import ( classify_segment, NUM_VERB_FINE, ) except ModuleNotFoundError: from dataset_seqpred import ( SAMPLING_RATE_HZ, _load_recording_sensors, _load_annotations, parse_ts_range, TRAIN_VOLS_V3, TEST_VOLS_V3, DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR, ) from taxonomy import classify_segment, NUM_VERB_FINE IDLE_LABEL = NUM_VERB_FINE # = 17, sentinel for "no segment covers this frame" NUM_FORECAST_CLASSES = NUM_VERB_FINE + 1 # = 18 class ForecastDataset(Dataset): """Forecast next T_fut seconds of per-frame verb_fine given past T_obs.""" def __init__( self, volunteers: Sequence[str], modalities: Sequence[str], t_obs_sec: float = 1.5, t_fut_sec: float = 0.5, anchor_stride_sec: float = 0.25, downsample: int = 5, dataset_dir: Path = DEFAULT_DATASET_DIR, annot_dir: Path = DEFAULT_ANNOT_DIR, stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None, expected_dims: Optional[Dict[str, int]] = None, contact_only: bool = False, contact_threshold_g: float = 5.0, log: bool = True, ): super().__init__() self.modalities = list(modalities) self.t_obs_sec = float(t_obs_sec) self.t_fut_sec = float(t_fut_sec) self.anchor_stride_sec = float(anchor_stride_sec) self.downsample = int(downsample) self.sr = SAMPLING_RATE_HZ // self.downsample self.dataset_dir = Path(dataset_dir) self.annot_dir = Path(annot_dir) self.contact_only = bool(contact_only) self.contact_threshold_g = float(contact_threshold_g) # Output time-step counts (after downsample) self.T_obs = int(round(self.t_obs_sec * self.sr)) self.T_fut = int(round(self.t_fut_sec * self.sr)) self._items: List[dict] = [] # Pre-seed modality dims if caller (e.g. test set) provides them self._modality_dims: Dict[str, int] = dict(expected_dims) if expected_dims else {} for vol in volunteers: vol_dir = self.dataset_dir / vol if not vol_dir.is_dir(): continue for scenario_dir in sorted(vol_dir.glob("s*")): if not scenario_dir.is_dir(): continue scene = scenario_dir.name annot_path = self.annot_dir / vol / f"{scene}.json" if not annot_path.exists(): continue # Always include pressure for the filter, even if model # doesn't see it as input. We separate "filter sensors" # (load_mods) from "model input sensors" (self.modalities). load_mods = list(dict.fromkeys(list(self.modalities) + ["pressure"])) try: sensors_all = _load_recording_sensors( scenario_dir, vol, scene, load_mods ) except Exception: continue if sensors_all is None or any(a is None for a in sensors_all.values()): continue pressure_full = sensors_all.get("pressure") # (T, 50) # Subset to model-input modalities for everything downstream sensors = {m: sensors_all[m] for m in self.modalities} # Track modality dim consistency for m, arr in sensors.items(): if m in self._modality_dims: target = self._modality_dims[m] if arr.shape[1] != target: if arr.shape[1] < target: pad = np.zeros((arr.shape[0], target - arr.shape[1]), dtype=np.float32) sensors[m] = np.concatenate([arr, pad], axis=1) else: sensors[m] = arr[:, :target] else: self._modality_dims[m] = arr.shape[1] T_avail = min(a.shape[0] for a in sensors.values()) if T_avail < (self.T_obs + self.T_fut) * self.downsample: continue # Build per-frame verb_fine timeline at full 100 Hz timeline = np.full(T_avail, IDLE_LABEL, dtype=np.int64) segs = _load_annotations(annot_path) for seg in segs: a = seg.get("action_annotation", {}) labels = classify_segment(a) if labels is None: continue start_sec, end_sec = parse_ts_range(seg.get("timestamp", "")) s = int(round(start_sec * SAMPLING_RATE_HZ)) e = int(round(end_sec * SAMPLING_RATE_HZ)) s = max(0, s); e = min(T_avail, e) if e > s: timeline[s:e] = labels["verb_fine"] # Downsample timeline to 20 Hz timeline_ds = timeline[::self.downsample] T_ds = len(timeline_ds) # Downsample sensors to 20 Hz (kept as full record; # we'll slice windows below) sensors_ds = {m: arr[::self.downsample] for m, arr in sensors.items()} # Build contact mask at 20 Hz (per-frame): is pressure-sum > thr? # Pressure is 50 channels; we follow the T2 contact convention # (sum across all fingertips and threshold at 5 g). if pressure_full is not None: pressure_ds = pressure_full[::self.downsample] contact_ds = pressure_ds.sum(axis=1) > self.contact_threshold_g else: contact_ds = np.zeros(T_ds, dtype=bool) # Sample anchors at fixed stride (in 20 Hz frames) stride = max(1, int(round(self.anchor_stride_sec * self.sr))) first_anchor = self.T_obs last_anchor = T_ds - self.T_fut if last_anchor <= first_anchor: continue for anchor in range(first_anchor, last_anchor + 1, stride): # contact-rich filter: any contact frame in past or future window? if self.contact_only: win = contact_ds[max(0, anchor - self.T_obs): min(T_ds, anchor + self.T_fut)] if not win.any(): continue past_slice = {m: arr[anchor - self.T_obs:anchor] for m, arr in sensors_ds.items()} fut_labels = timeline_ds[anchor:anchor + self.T_fut].copy() # length sanity if any(w.shape[0] != self.T_obs for w in past_slice.values()): continue if fut_labels.shape[0] != self.T_fut: continue self._items.append({ "x": past_slice, # dict[mod] -> (T_obs, F_mod) "y_seq": fut_labels, # (T_fut,) int in [0..17] "meta": {"vol": vol, "scene": scene, "anchor_idx": int(anchor)}, }) if not self._items: raise RuntimeError("ForecastDataset: collected 0 anchors. Check annot_dir / modalities.") # Per-modality z-score using training stats if stats is None: stats = self._compute_stats() self._stats = stats self._apply_stats(stats) if log: print(f"[ForecastDataset] vols={len(volunteers)} " f"anchors={len(self._items)} " f"T_obs={self.T_obs} T_fut={self.T_fut} " f"contact_only={self.contact_only} " f"modality_dims={self._modality_dims} " f"sr={self.sr}Hz", flush=True) # ----- Stats / normalization ----- def _compute_stats(self) -> Dict[str, Tuple[np.ndarray, np.ndarray]]: accs = {m: [] for m in self._modality_dims} for it in self._items: for m, w in it["x"].items(): accs[m].append(w) out = {} for m, ws in accs.items(): cat = np.concatenate(ws, axis=0) mu = cat.mean(axis=0) sd = cat.std(axis=0); sd = np.where(sd < 1e-6, 1.0, sd) out[m] = (mu.astype(np.float32), sd.astype(np.float32)) return out def _apply_stats(self, stats): for it in self._items: for m, w in it["x"].items(): if m in stats: mu, sd = stats[m] it["x"][m] = ((w - mu) / sd).astype(np.float32) # ----- Dataset protocol ----- def __len__(self): return len(self._items) def __getitem__(self, idx): it = self._items[idx] x = {m: torch.from_numpy(np.ascontiguousarray(w)) for m, w in it["x"].items()} y_seq = torch.from_numpy(np.ascontiguousarray(it["y_seq"])) # (T_fut,) return x, y_seq, it["meta"] @property def modality_dims(self): return dict(self._modality_dims) def class_freq(self) -> np.ndarray: c = np.zeros(NUM_FORECAST_CLASSES, dtype=np.int64) for it in self._items: for v in it["y_seq"]: c[int(v)] += 1 return c def collate_forecast(batch): """Stack (x_dict, y_seq, meta) -> batched tensors. All samples share T_obs/T_fut.""" xs, ys, metas = zip(*batch) B = len(batch) mods = list(xs[0].keys()) x_out: Dict[str, torch.Tensor] = {} for m in mods: x_out[m] = torch.stack([x[m] for x in xs], dim=0) # (B, T_obs, F_mod) y_out = torch.stack(ys, dim=0) # (B, T_fut) return x_out, y_out, list(metas) def build_train_test( modalities: Sequence[str], t_obs_sec: float = 1.5, t_fut_sec: float = 0.5, anchor_stride_sec: float = 0.25, downsample: int = 5, dataset_dir: Path = DEFAULT_DATASET_DIR, annot_dir: Path = DEFAULT_ANNOT_DIR, contact_only: bool = False, contact_threshold_g: float = 5.0, ): train = ForecastDataset( TRAIN_VOLS_V3, modalities=modalities, t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec, anchor_stride_sec=anchor_stride_sec, downsample=downsample, dataset_dir=dataset_dir, annot_dir=annot_dir, contact_only=contact_only, contact_threshold_g=contact_threshold_g, stats=None, log=True, ) test = ForecastDataset( TEST_VOLS_V3, modalities=modalities, t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec, anchor_stride_sec=anchor_stride_sec, downsample=downsample, dataset_dir=dataset_dir, annot_dir=annot_dir, contact_only=contact_only, contact_threshold_g=contact_threshold_g, stats=train._stats, expected_dims=train._modality_dims, log=True, ) return train, test if __name__ == "__main__": import argparse ap = argparse.ArgumentParser() ap.add_argument("--modalities", type=str, default="imu,emg,eyetrack,mocap,pressure") ap.add_argument("--t_obs", type=float, default=1.5) ap.add_argument("--t_fut", type=float, default=0.5) ap.add_argument("--stride", type=float, default=0.25) args = ap.parse_args() mods = args.modalities.split(",") tr, te = build_train_test( modalities=mods, t_obs_sec=args.t_obs, t_fut_sec=args.t_fut, anchor_stride_sec=args.stride, ) print(f"\nTrain={len(tr)} Test={len(te)} T_obs={tr.T_obs} T_fut={tr.T_fut}") print(f"Train class freq:\n{tr.class_freq()}") print(f"Test class freq:\n{te.class_freq()}") x, y, meta = tr[0] print(f"Sample: x={ {m: tuple(v.shape) for m,v in x.items()} } y_seq={tuple(y.shape)}")