# PULSE-code / experiments / data / dataset_forecast.py
# (Hugging Face upload artifact: uploaded by "velvet-pine-22" via
#  huggingface_hub, commit b4b2877 verified — page residue converted to a
#  comment header so the module parses.)
"""Frame-level future motor-primitive forecasting dataset.
Task definition
---------------
At a sampled anchor time t in a recording:
past = sensor frames over [t - T_obs, t] ← input
future = per-frame verb_fine labels over (t, t + T_fut] ← target
We use NUM_VERB_FINE (= 17) as a sentinel "idle / no segment" class for
frames not covered by any annotated segment, so every future frame has a
valid label (output cardinality = NUM_VERB_FINE + 1 = 18).
Anchors are sampled at fixed stride within each recording so the model
sees both intra-segment future (mostly stationary) and across-boundary
future (where the next-action label changes — the interesting cases).
"""
from __future__ import annotations
import os
import sys
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
import torch
from torch.utils.data import Dataset
THIS = Path(__file__).resolve()
sys.path.insert(0, str(THIS.parent))
sys.path.insert(0, str(THIS.parents[1]))
try:
from experiments.dataset_seqpred import (
SAMPLING_RATE_HZ, _load_recording_sensors, _load_annotations,
parse_ts_range, TRAIN_VOLS_V3, TEST_VOLS_V3,
DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
)
from experiments.taxonomy import (
classify_segment, NUM_VERB_FINE,
)
except ModuleNotFoundError:
from dataset_seqpred import (
SAMPLING_RATE_HZ, _load_recording_sensors, _load_annotations,
parse_ts_range, TRAIN_VOLS_V3, TEST_VOLS_V3,
DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
)
from taxonomy import classify_segment, NUM_VERB_FINE
# Sentinel class assigned to frames not covered by any annotated segment.
IDLE_LABEL = NUM_VERB_FINE  # = 17, sentinel for "no segment covers this frame"
# Forecast output cardinality: all fine verbs plus the idle sentinel.
NUM_FORECAST_CLASSES = NUM_VERB_FINE + 1  # = 18
class ForecastDataset(Dataset):
    """Forecast next T_fut seconds of per-frame verb_fine given past T_obs.

    Each item is ``(x, y_seq, meta)``:

    * ``x``     -- dict[modality] -> float32 window of shape (T_obs, F_mod)
    * ``y_seq`` -- int64 labels of shape (T_fut,), values in
      [0, NUM_VERB_FINE] where NUM_VERB_FINE is the idle sentinel
    * ``meta``  -- {"vol", "scene", "anchor_idx"} locating the anchor
    """

    def __init__(
        self,
        volunteers: Sequence[str],
        modalities: Sequence[str],
        t_obs_sec: float = 1.5,
        t_fut_sec: float = 0.5,
        anchor_stride_sec: float = 0.25,
        downsample: int = 5,
        dataset_dir: Path = DEFAULT_DATASET_DIR,
        annot_dir: Path = DEFAULT_ANNOT_DIR,
        stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None,
        expected_dims: Optional[Dict[str, int]] = None,
        contact_only: bool = False,
        contact_threshold_g: float = 5.0,
        log: bool = True,
    ):
        """Scan every recording of *volunteers* and collect anchor windows.

        Parameters
        ----------
        volunteers : subject IDs; each is a sub-directory of *dataset_dir*.
        modalities : sensor streams the model receives as input.
        t_obs_sec / t_fut_sec : past / future window lengths in seconds.
        anchor_stride_sec : spacing between sampled anchors, in seconds.
        downsample : keep every k-th raw frame (SAMPLING_RATE_HZ / k output rate).
        dataset_dir / annot_dir : roots for sensor recordings and JSON annotations.
        stats : per-modality (mean, std) for z-scoring. When None (train
            split) they are computed from this dataset's own windows; the
            test split should pass the train split's stats.
        expected_dims : per-modality feature dims to enforce by zero-padding
            or truncation, so the test layout matches train exactly.
        contact_only : keep only anchors whose past+future window contains
            at least one frame with summed pressure above *contact_threshold_g*.
        contact_threshold_g : contact threshold on the pressure-channel sum.
        log : print a one-line summary after construction.

        Raises
        ------
        RuntimeError
            If zero anchors were collected (bad paths / modalities).
        """
        super().__init__()
        self.modalities = list(modalities)
        self.t_obs_sec = float(t_obs_sec)
        self.t_fut_sec = float(t_fut_sec)
        self.anchor_stride_sec = float(anchor_stride_sec)
        self.downsample = int(downsample)
        # Effective sample rate after downsampling (integer division —
        # assumes downsample divides SAMPLING_RATE_HZ evenly).
        self.sr = SAMPLING_RATE_HZ // self.downsample
        self.dataset_dir = Path(dataset_dir)
        self.annot_dir = Path(annot_dir)
        self.contact_only = bool(contact_only)
        self.contact_threshold_g = float(contact_threshold_g)
        # Output time-step counts (after downsample)
        self.T_obs = int(round(self.t_obs_sec * self.sr))
        self.T_fut = int(round(self.t_fut_sec * self.sr))
        # Collected anchor samples; filled by the scan loop below.
        self._items: List[dict] = []
        # Pre-seed modality dims if caller (e.g. test set) provides them
        self._modality_dims: Dict[str, int] = dict(expected_dims) if expected_dims else {}
        for vol in volunteers:
            vol_dir = self.dataset_dir / vol
            if not vol_dir.is_dir():
                continue
            # Scenario directories are named s* — TODO confirm naming scheme.
            for scenario_dir in sorted(vol_dir.glob("s*")):
                if not scenario_dir.is_dir():
                    continue
                scene = scenario_dir.name
                annot_path = self.annot_dir / vol / f"{scene}.json"
                if not annot_path.exists():
                    continue
                # Always include pressure for the filter, even if model
                # doesn't see it as input. We separate "filter sensors"
                # (load_mods) from "model input sensors" (self.modalities).
                load_mods = list(dict.fromkeys(list(self.modalities) + ["pressure"]))
                try:
                    sensors_all = _load_recording_sensors(
                        scenario_dir, vol, scene, load_mods
                    )
                except Exception:
                    # Best-effort scan: a recording that fails to load is
                    # skipped rather than aborting the whole dataset.
                    continue
                if sensors_all is None or any(a is None for a in sensors_all.values()):
                    continue
                pressure_full = sensors_all.get("pressure")  # (T, 50)
                # Subset to model-input modalities for everything downstream
                sensors = {m: sensors_all[m] for m in self.modalities}
                # Track modality dim consistency
                for m, arr in sensors.items():
                    if m in self._modality_dims:
                        target = self._modality_dims[m]
                        if arr.shape[1] != target:
                            if arr.shape[1] < target:
                                # Narrower than expected: zero-pad features.
                                pad = np.zeros((arr.shape[0], target - arr.shape[1]),
                                               dtype=np.float32)
                                sensors[m] = np.concatenate([arr, pad], axis=1)
                            else:
                                # Wider than expected: truncate features.
                                sensors[m] = arr[:, :target]
                    else:
                        # First sighting of this modality defines its dim.
                        self._modality_dims[m] = arr.shape[1]
                # Shortest stream bounds the usable recording length.
                T_avail = min(a.shape[0] for a in sensors.values())
                if T_avail < (self.T_obs + self.T_fut) * self.downsample:
                    continue
                # Build per-frame verb_fine timeline at full 100 Hz
                timeline = np.full(T_avail, IDLE_LABEL, dtype=np.int64)
                segs = _load_annotations(annot_path)
                for seg in segs:
                    a = seg.get("action_annotation", {})
                    labels = classify_segment(a)
                    if labels is None:
                        continue
                    start_sec, end_sec = parse_ts_range(seg.get("timestamp", ""))
                    # Convert segment times to raw-frame indices, clamped
                    # to the available recording length.
                    s = int(round(start_sec * SAMPLING_RATE_HZ))
                    e = int(round(end_sec * SAMPLING_RATE_HZ))
                    s = max(0, s); e = min(T_avail, e)
                    if e > s:
                        # NOTE(review): later segments overwrite earlier
                        # ones on overlap — confirm that is intended.
                        timeline[s:e] = labels["verb_fine"]
                # Downsample timeline to 20 Hz
                timeline_ds = timeline[::self.downsample]
                T_ds = len(timeline_ds)
                # Downsample sensors to 20 Hz (kept as full record;
                # we'll slice windows below)
                sensors_ds = {m: arr[::self.downsample] for m, arr in sensors.items()}
                # Build contact mask at 20 Hz (per-frame): is pressure-sum > thr?
                # Pressure is 50 channels; we follow the T2 contact convention
                # (sum across all fingertips and threshold at 5 g).
                if pressure_full is not None:
                    # NOTE(review): pressure's frame count may differ from
                    # T_ds (it is not part of the T_avail min when not a
                    # model input); the window slice below tolerates a
                    # shorter mask — confirm streams are time-aligned.
                    pressure_ds = pressure_full[::self.downsample]
                    contact_ds = pressure_ds.sum(axis=1) > self.contact_threshold_g
                else:
                    # No pressure stream: no frame ever counts as contact.
                    contact_ds = np.zeros(T_ds, dtype=bool)
                # Sample anchors at fixed stride (in 20 Hz frames)
                stride = max(1, int(round(self.anchor_stride_sec * self.sr)))
                first_anchor = self.T_obs
                last_anchor = T_ds - self.T_fut
                if last_anchor <= first_anchor:
                    continue
                for anchor in range(first_anchor, last_anchor + 1, stride):
                    # contact-rich filter: any contact frame in past or future window?
                    if self.contact_only:
                        win = contact_ds[max(0, anchor - self.T_obs):
                                         min(T_ds, anchor + self.T_fut)]
                        if not win.any():
                            continue
                    past_slice = {m: arr[anchor - self.T_obs:anchor]
                                  for m, arr in sensors_ds.items()}
                    fut_labels = timeline_ds[anchor:anchor + self.T_fut].copy()
                    # length sanity
                    if any(w.shape[0] != self.T_obs for w in past_slice.values()):
                        continue
                    if fut_labels.shape[0] != self.T_fut:
                        continue
                    self._items.append({
                        "x": past_slice,       # dict[mod] -> (T_obs, F_mod)
                        "y_seq": fut_labels,   # (T_fut,) int in [0..17]
                        "meta": {"vol": vol, "scene": scene, "anchor_idx": int(anchor)},
                    })
        if not self._items:
            raise RuntimeError("ForecastDataset: collected 0 anchors. Check annot_dir / modalities.")
        # Per-modality z-score using training stats
        if stats is None:
            stats = self._compute_stats()
        self._stats = stats
        self._apply_stats(stats)
        if log:
            print(f"[ForecastDataset] vols={len(volunteers)} "
                  f"anchors={len(self._items)} "
                  f"T_obs={self.T_obs} T_fut={self.T_fut} "
                  f"contact_only={self.contact_only} "
                  f"modality_dims={self._modality_dims} "
                  f"sr={self.sr}Hz", flush=True)

    # ----- Stats / normalization -----
    def _compute_stats(self) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
        """Per-modality (mean, std) over all collected past windows.

        Stds below 1e-6 are replaced by 1.0 to avoid division blow-up on
        near-constant channels.
        """
        accs = {m: [] for m in self._modality_dims}
        for it in self._items:
            for m, w in it["x"].items():
                accs[m].append(w)
        out = {}
        for m, ws in accs.items():
            # Concatenate along time so stats are per-feature-channel.
            cat = np.concatenate(ws, axis=0)
            mu = cat.mean(axis=0)
            sd = cat.std(axis=0); sd = np.where(sd < 1e-6, 1.0, sd)
            out[m] = (mu.astype(np.float32), sd.astype(np.float32))
        return out

    def _apply_stats(self, stats):
        """Z-score every stored window in place using *stats* (mean, std)."""
        for it in self._items:
            for m, w in it["x"].items():
                if m in stats:
                    mu, sd = stats[m]
                    it["x"][m] = ((w - mu) / sd).astype(np.float32)

    # ----- Dataset protocol -----
    def __len__(self) -> int:
        """Number of collected anchors."""
        return len(self._items)

    def __getitem__(self, idx):
        """Return ``(x_dict, y_seq, meta)`` with arrays converted to tensors."""
        it = self._items[idx]
        x = {m: torch.from_numpy(np.ascontiguousarray(w)) for m, w in it["x"].items()}
        y_seq = torch.from_numpy(np.ascontiguousarray(it["y_seq"]))  # (T_fut,)
        return x, y_seq, it["meta"]

    @property
    def modality_dims(self):
        """Copy of the per-modality feature-dimension map."""
        return dict(self._modality_dims)

    def class_freq(self) -> np.ndarray:
        """Count future-frame labels per class; shape (NUM_FORECAST_CLASSES,)."""
        c = np.zeros(NUM_FORECAST_CLASSES, dtype=np.int64)
        for it in self._items:
            for v in it["y_seq"]:
                c[int(v)] += 1
        return c
def collate_forecast(batch):
    """Stack (x_dict, y_seq, meta) -> batched tensors. All samples share T_obs/T_fut.

    Fix: dropped the unused local ``B = len(batch)``.

    Parameters
    ----------
    batch : list of (x_dict, y_seq, meta) triples as produced by
        ``ForecastDataset.__getitem__``; every sample must have the same
        modality keys and window lengths.

    Returns
    -------
    (x_out, y_out, metas) where
      x_out : dict[mod] -> tensor (B, T_obs, F_mod)
      y_out : tensor (B, T_fut)
      metas : list of the per-sample meta dicts, batch order preserved.
    """
    xs, ys, metas = zip(*batch)
    # Modality keys are taken from the first sample; all samples share them.
    mods = list(xs[0].keys())
    x_out: Dict[str, torch.Tensor] = {}
    for m in mods:
        x_out[m] = torch.stack([x[m] for x in xs], dim=0)  # (B, T_obs, F_mod)
    y_out = torch.stack(ys, dim=0)  # (B, T_fut)
    return x_out, y_out, list(metas)
def build_train_test(
    modalities: Sequence[str],
    t_obs_sec: float = 1.5,
    t_fut_sec: float = 0.5,
    anchor_stride_sec: float = 0.25,
    downsample: int = 5,
    dataset_dir: Path = DEFAULT_DATASET_DIR,
    annot_dir: Path = DEFAULT_ANNOT_DIR,
    contact_only: bool = False,
    contact_threshold_g: float = 5.0,
):
    """Build the (train, test) ForecastDataset pair over the V3 splits.

    Normalization stats and modality dims are computed on the train split
    and handed to the test split so both are z-scored identically and
    share the same feature layout.
    """
    shared = dict(
        modalities=modalities,
        t_obs_sec=t_obs_sec,
        t_fut_sec=t_fut_sec,
        anchor_stride_sec=anchor_stride_sec,
        downsample=downsample,
        dataset_dir=dataset_dir,
        annot_dir=annot_dir,
        contact_only=contact_only,
        contact_threshold_g=contact_threshold_g,
        log=True,
    )
    train = ForecastDataset(TRAIN_VOLS_V3, stats=None, **shared)
    test = ForecastDataset(
        TEST_VOLS_V3,
        stats=train._stats,
        expected_dims=train._modality_dims,
        **shared,
    )
    return train, test
if __name__ == "__main__":
    # Quick smoke test: build both splits and dump shapes / class counts.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--modalities", type=str,
                        default="imu,emg,eyetrack,mocap,pressure")
    parser.add_argument("--t_obs", type=float, default=1.5)
    parser.add_argument("--t_fut", type=float, default=0.5)
    parser.add_argument("--stride", type=float, default=0.25)
    cli = parser.parse_args()

    tr, te = build_train_test(
        modalities=cli.modalities.split(","),
        t_obs_sec=cli.t_obs,
        t_fut_sec=cli.t_fut,
        anchor_stride_sec=cli.stride,
    )
    print(f"\nTrain={len(tr)} Test={len(te)} T_obs={tr.T_obs} T_fut={tr.T_fut}")
    print(f"Train class freq:\n{tr.class_freq()}")
    print(f"Test class freq:\n{te.class_freq()}")
    x, y, meta = tr[0]
    print(f"Sample: x={ {m: tuple(v.shape) for m,v in x.items()} } y_seq={tuple(y.shape)}")