| """Frame-level future motor-primitive forecasting dataset. |
| |
| Task definition |
| --------------- |
| At a sampled anchor time t in a recording: |
| past = sensor frames over [t - T_obs, t] ← input |
| future = per-frame verb_fine labels over (t, t + T_fut] ← target |
| |
| We use NUM_VERB_FINE (= 17) as a sentinel "idle / no segment" class for |
| frames not covered by any annotated segment, so every future frame has a |
| valid label (output cardinality = NUM_VERB_FINE + 1 = 18). |
| |
| Anchors are sampled at fixed stride within each recording so the model |
| sees both intra-segment future (mostly stationary) and across-boundary |
| future (where the next-action label changes — the interesting cases). |
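
A minimal usage sketch (illustrative; volunteer ids and modality names
must match the local dataset layout):

    train, test = build_train_test(modalities=["imu", "emg"])
    x, y_seq, meta = train[0]   # x[m]: (T_obs, D_m) floats, y_seq: (T_fut,) ints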
| """ |
from __future__ import annotations

import sys
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset

# Make sibling modules importable whether this file runs as a script or
# as part of the `experiments` package.
THIS = Path(__file__).resolve()
sys.path.insert(0, str(THIS.parent))
sys.path.insert(0, str(THIS.parents[1]))

# Prefer package-qualified imports; fall back to flat imports when the
# module is executed directly from its own directory.
try:
| from experiments.dataset_seqpred import ( |
| SAMPLING_RATE_HZ, _load_recording_sensors, _load_annotations, |
| parse_ts_range, TRAIN_VOLS_V3, TEST_VOLS_V3, |
| DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR, |
| ) |
| from experiments.taxonomy import ( |
| classify_segment, NUM_VERB_FINE, |
| ) |
| except ModuleNotFoundError: |
| from dataset_seqpred import ( |
| SAMPLING_RATE_HZ, _load_recording_sensors, _load_annotations, |
| parse_ts_range, TRAIN_VOLS_V3, TEST_VOLS_V3, |
| DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR, |
| ) |
    from taxonomy import classify_segment, NUM_VERB_FINE


# Sentinel class for future frames not covered by any annotated segment.
IDLE_LABEL = NUM_VERB_FINE
# Output cardinality: all fine-grained verbs plus the idle sentinel.
NUM_FORECAST_CLASSES = NUM_VERB_FINE + 1
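
# Concretely (assuming NUM_VERB_FINE == 17, as the module docstring states):
# verb_fine ids occupy 0..16, IDLE_LABEL == 17, NUM_FORECAST_CLASSES == 18.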


| class ForecastDataset(Dataset): |
| """Forecast next T_fut seconds of per-frame verb_fine given past T_obs.""" |
|
|
| def __init__( |
| self, |
| volunteers: Sequence[str], |
| modalities: Sequence[str], |
| t_obs_sec: float = 1.5, |
| t_fut_sec: float = 0.5, |
| anchor_stride_sec: float = 0.25, |
| downsample: int = 5, |
| dataset_dir: Path = DEFAULT_DATASET_DIR, |
| annot_dir: Path = DEFAULT_ANNOT_DIR, |
| stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None, |
| expected_dims: Optional[Dict[str, int]] = None, |
| contact_only: bool = False, |
| contact_threshold_g: float = 5.0, |
| log: bool = True, |
| ): |
| super().__init__() |
| self.modalities = list(modalities) |
| self.t_obs_sec = float(t_obs_sec) |
| self.t_fut_sec = float(t_fut_sec) |
| self.anchor_stride_sec = float(anchor_stride_sec) |
| self.downsample = int(downsample) |
        self.sr = SAMPLING_RATE_HZ // self.downsample  # effective rate after decimation
| self.dataset_dir = Path(dataset_dir) |
| self.annot_dir = Path(annot_dir) |
| self.contact_only = bool(contact_only) |
        self.contact_threshold_g = float(contact_threshold_g)

        # Window lengths in downsampled frames.
        self.T_obs = int(round(self.t_obs_sec * self.sr))
        self.T_fut = int(round(self.t_fut_sec * self.sr))
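        # (Hypothetical numbers: if SAMPLING_RATE_HZ were 100 and downsample=5,
        # sr would be 20 Hz and the defaults give T_obs=30, T_fut=10 frames.)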
|
|
        self._items: List[dict] = []
        # Per-modality feature dims; seeded from `expected_dims` (e.g. the train
        # split's dims) so a test split pads/truncates to the same widths.
        self._modality_dims: Dict[str, int] = dict(expected_dims) if expected_dims else {}

| for vol in volunteers: |
| vol_dir = self.dataset_dir / vol |
| if not vol_dir.is_dir(): |
| continue |
| for scenario_dir in sorted(vol_dir.glob("s*")): |
| if not scenario_dir.is_dir(): |
| continue |
| scene = scenario_dir.name |
                annot_path = self.annot_dir / vol / f"{scene}.json"
                if not annot_path.exists():
                    continue

                # Always load pressure (deduplicated into the request) so the
                # contact gate works even when pressure is not a model input.
                load_mods = list(dict.fromkeys(list(self.modalities) + ["pressure"]))
                try:
                    sensors_all = _load_recording_sensors(
                        scenario_dir, vol, scene, load_mods
                    )
                except Exception:
                    continue
                if sensors_all is None or any(a is None for a in sensors_all.values()):
                    continue
                pressure_full = sensors_all.get("pressure")
                # Keep only the modalities the model actually consumes.
                sensors = {m: sensors_all[m] for m in self.modalities}

                # Reconcile per-modality feature dims across recordings:
                # zero-pad or truncate to the first-seen (or expected) width.
                for m, arr in sensors.items():
                    if m in self._modality_dims:
                        target = self._modality_dims[m]
                        if arr.shape[1] != target:
                            if arr.shape[1] < target:
                                pad = np.zeros((arr.shape[0], target - arr.shape[1]),
                                               dtype=np.float32)
                                sensors[m] = np.concatenate([arr, pad], axis=1)
                            else:
                                sensors[m] = arr[:, :target]
                    else:
                        self._modality_dims[m] = arr.shape[1]

                # Require enough full-rate frames for one obs+future window
                # (T_obs and T_fut are counted at the downsampled rate).
                T_avail = min(a.shape[0] for a in sensors.values())
                if T_avail < (self.T_obs + self.T_fut) * self.downsample:
                    continue
                # Rasterise annotated segments onto a full-rate per-frame
                # verb_fine timeline; uncovered frames keep IDLE_LABEL.
                timeline = np.full(T_avail, IDLE_LABEL, dtype=np.int64)
                segs = _load_annotations(annot_path)
                for seg in segs:
                    a = seg.get("action_annotation", {})
                    labels = classify_segment(a)
                    if labels is None:
                        continue
                    start_sec, end_sec = parse_ts_range(seg.get("timestamp", ""))
                    s = max(0, int(round(start_sec * SAMPLING_RATE_HZ)))
                    e = min(T_avail, int(round(end_sec * SAMPLING_RATE_HZ)))
                    if e > s:
                        timeline[s:e] = labels["verb_fine"]
                # Decimate labels and sensors to the working rate.
                timeline_ds = timeline[::self.downsample]
                T_ds = len(timeline_ds)
                sensors_ds = {m: arr[::self.downsample] for m, arr in sensors.items()}
                # Per-frame contact gate from summed pressure. Trim to T_avail
                # before decimation so the mask aligns with timeline_ds.
                if pressure_full is not None:
                    pressure_ds = pressure_full[:T_avail:self.downsample]
                    contact_ds = pressure_ds.sum(axis=1) > self.contact_threshold_g
                else:
                    contact_ds = np.zeros(T_ds, dtype=bool)
                # Enumerate anchors that have full past and future coverage
                # (when last_anchor == first_anchor exactly one anchor fits).
                stride = max(1, int(round(self.anchor_stride_sec * self.sr)))
                first_anchor = self.T_obs
                last_anchor = T_ds - self.T_fut
                if last_anchor < first_anchor:
                    continue
                for anchor in range(first_anchor, last_anchor + 1, stride):
                    # Optionally keep only windows that ever show contact.
                    if self.contact_only:
                        win = contact_ds[max(0, anchor - self.T_obs):
                                         min(T_ds, anchor + self.T_fut)]
                        if not win.any():
                            continue
                    past_slice = {m: arr[anchor - self.T_obs:anchor]
                                  for m, arr in sensors_ds.items()}
                    fut_labels = timeline_ds[anchor:anchor + self.T_fut].copy()
                    # Guard against ragged per-modality lengths at the edges.
                    if any(w.shape[0] != self.T_obs for w in past_slice.values()):
                        continue
                    if fut_labels.shape[0] != self.T_fut:
                        continue
                    self._items.append({
                        "x": past_slice,
                        "y_seq": fut_labels,
                        "meta": {"vol": vol, "scene": scene, "anchor_idx": int(anchor)},
                    })
|
|
| if not self._items: |
            raise RuntimeError(
                "ForecastDataset: collected 0 anchors. Check annot_dir / modalities."
            )

        # Normalise with train-split statistics; a test split receives `stats`
        # from the train split so no test information leaks into the scaling.
        if stats is None:
            stats = self._compute_stats()
        self._stats = stats
        self._apply_stats(stats)
|
|
| if log: |
| print(f"[ForecastDataset] vols={len(volunteers)} " |
| f"anchors={len(self._items)} " |
| f"T_obs={self.T_obs} T_fut={self.T_fut} " |
| f"contact_only={self.contact_only} " |
| f"modality_dims={self._modality_dims} " |
| f"sr={self.sr}Hz", flush=True) |
|
|
| |
    def _compute_stats(self) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
        """Per-modality (mean, std) over all past windows, for z-scoring."""
        accs = {m: [] for m in self._modality_dims}
| for it in self._items: |
| for m, w in it["x"].items(): |
| accs[m].append(w) |
| out = {} |
| for m, ws in accs.items(): |
            cat = np.concatenate(ws, axis=0)
            mu = cat.mean(axis=0)
            sd = cat.std(axis=0)
            sd = np.where(sd < 1e-6, 1.0, sd)  # guard near-constant channels
            out[m] = (mu.astype(np.float32), sd.astype(np.float32))
| return out |
|
|
    def _apply_stats(self, stats):
        """Z-score every stored past window in place."""
        for it in self._items:
| for m, w in it["x"].items(): |
| if m in stats: |
| mu, sd = stats[m] |
| it["x"][m] = ((w - mu) / sd).astype(np.float32) |
|
|
| |
| def __len__(self): |
| return len(self._items) |
|
|
| def __getitem__(self, idx): |
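        """Return (x: dict of (T_obs, D_m) float tensors, y_seq: (T_fut,) int64, meta)."""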
| it = self._items[idx] |
| x = {m: torch.from_numpy(np.ascontiguousarray(w)) for m, w in it["x"].items()} |
| y_seq = torch.from_numpy(np.ascontiguousarray(it["y_seq"])) |
| return x, y_seq, it["meta"] |
|
|
| @property |
| def modality_dims(self): |
| return dict(self._modality_dims) |
|
|
    def class_freq(self) -> np.ndarray:
        """Count each forecast class over all future targets."""
        c = np.zeros(NUM_FORECAST_CLASSES, dtype=np.int64)
        for it in self._items:
            c += np.bincount(it["y_seq"], minlength=NUM_FORECAST_CLASSES)
        return c
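
    # A common use is inverse-frequency loss weighting, e.g. (a sketch; the
    # clamp to 1 avoids division by zero for classes never observed):
    #   freq = ds.class_freq()
    #   weights = freq.sum() / np.maximum(freq, 1)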


| def collate_forecast(batch): |
| """Stack (x_dict, y_seq, meta) -> batched tensors. All samples share T_obs/T_fut.""" |
| xs, ys, metas = zip(*batch) |
| B = len(batch) |
| mods = list(xs[0].keys()) |
| x_out: Dict[str, torch.Tensor] = {} |
| for m in mods: |
| x_out[m] = torch.stack([x[m] for x in xs], dim=0) |
| y_out = torch.stack(ys, dim=0) |
| return x_out, y_out, list(metas) |
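

# Typical wiring (a sketch; batch size and shuffle choice are illustrative):
#   from torch.utils.data import DataLoader
#   loader = DataLoader(train, batch_size=64, shuffle=True,
#                       collate_fn=collate_forecast)
#   for x, y_seq, metas in loader:
#       ...  # x[m]: (B, T_obs, D_m), y_seq: (B, T_fut)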


| def build_train_test( |
| modalities: Sequence[str], |
| t_obs_sec: float = 1.5, |
| t_fut_sec: float = 0.5, |
| anchor_stride_sec: float = 0.25, |
| downsample: int = 5, |
| dataset_dir: Path = DEFAULT_DATASET_DIR, |
| annot_dir: Path = DEFAULT_ANNOT_DIR, |
| contact_only: bool = False, |
| contact_threshold_g: float = 5.0, |
| ): |
| train = ForecastDataset( |
| TRAIN_VOLS_V3, modalities=modalities, |
| t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec, |
| anchor_stride_sec=anchor_stride_sec, downsample=downsample, |
| dataset_dir=dataset_dir, annot_dir=annot_dir, |
| contact_only=contact_only, contact_threshold_g=contact_threshold_g, |
| stats=None, log=True, |
| ) |
| test = ForecastDataset( |
| TEST_VOLS_V3, modalities=modalities, |
| t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec, |
| anchor_stride_sec=anchor_stride_sec, downsample=downsample, |
| dataset_dir=dataset_dir, annot_dir=annot_dir, |
| contact_only=contact_only, contact_threshold_g=contact_threshold_g, |
| stats=train._stats, expected_dims=train._modality_dims, log=True, |
| ) |
| return train, test |


if __name__ == "__main__":
    # Smoke test: build both splits and report shapes and label frequencies.
    import argparse
    ap = argparse.ArgumentParser()
| ap.add_argument("--modalities", type=str, default="imu,emg,eyetrack,mocap,pressure") |
| ap.add_argument("--t_obs", type=float, default=1.5) |
| ap.add_argument("--t_fut", type=float, default=0.5) |
| ap.add_argument("--stride", type=float, default=0.25) |
| args = ap.parse_args() |
    mods = [m.strip() for m in args.modalities.split(",") if m.strip()]
| tr, te = build_train_test( |
| modalities=mods, |
| t_obs_sec=args.t_obs, t_fut_sec=args.t_fut, |
| anchor_stride_sec=args.stride, |
| ) |
| print(f"\nTrain={len(tr)} Test={len(te)} T_obs={tr.T_obs} T_fut={tr.T_fut}") |
| print(f"Train class freq:\n{tr.class_freq()}") |
| print(f"Test class freq:\n{te.class_freq()}") |
    x, y, meta = tr[0]
    print(f"Sample: x={ {m: tuple(v.shape) for m, v in x.items()} } y_seq={tuple(y.shape)}")
|
|