| """Anchor-based binary "is_grasping" classification dataset (T5 v3 / TGSR). |
| |
| At each sampled anchor t in a recording: |
| past = sensor frames over [t - T_obs, t] ← input |
| label = majority vote of grasp-annotation mask over (t, t+T_fut] ← binary class |
| |
| Ground-truth source: annotations_v3 verb segments. A frame is marked |
| "is_grasp" if it falls inside a segment whose action_name belongs to |
| GRASP_VERBS (set below). The label is annotation-derived, completely |
| independent of pressure — so adding/removing pressure as input does |
| NOT leak the label. |
| |
| This is the cleanest test of "does pressure improve recognition of |
| object-interaction state when human-annotated grasp segments are GT?" |
| """ |
from __future__ import annotations

import json
import sys
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset

# Allow both package-style and direct-script execution: put this file's dir
# and its parent on sys.path, then try the package import with a flat fallback.
THIS = Path(__file__).resolve()
sys.path.insert(0, str(THIS.parent))
sys.path.insert(0, str(THIS.parents[1]))

try:
    from experiments.dataset_seqpred import (
        SAMPLING_RATE_HZ, _load_recording_sensors,
        TRAIN_VOLS_V3, TEST_VOLS_V3,
        DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
    )
except ModuleNotFoundError:
    from dataset_seqpred import (
        SAMPLING_RATE_HZ, _load_recording_sensors,
        TRAIN_VOLS_V3, TEST_VOLS_V3,
        DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
    )
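
# Worked windowing example (illustrative numbers, not read from the data):
# if SAMPLING_RATE_HZ were 100 Hz and downsample=5 (effective 20 Hz), then
# t_obs_sec=1.0 gives T_obs=20 past frames and t_fut_sec=0.5 gives T_fut=10
# future frames per anchor. With majority_threshold=0.5, the binary label is 1
# iff at least 5 of those 10 future frames fall inside a GRASP_VERBS segment.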

# Verbs whose annotated segments count as "grasping" for the binary label.
GRASP_VERBS = {
    "grasp", "hold", "pick_up", "move", "place", "put_down",
    "pull", "rotate", "insert", "remove",
}

# Verbs that plausibly involve lifting the object. A segment is "lift-eligible"
# if its verb is in this set OR its hand_type is "both"; used by the optional
# lift gate in the three-class label mode.
LIFT_VERBS = {"grasp", "open", "move", "pick_up", "hold"}

# Verb vocabulary for label_mode="verb"; index 0 is background.
VERB_LIST = [
    "background",
    "grasp", "move", "place", "adjust", "pick_up",
    "close", "put_down", "pull", "hold", "open",
    "rotate", "release", "push", "insert", "remove",
    "align", "stabilize",
]
VERB_TO_IDX = {v: i for i, v in enumerate(VERB_LIST)}

# Top objects for label_mode="object"; anything else maps to "_other" (class 0).
OBJECT_TOP_LIST = [
    "_other",
    "sealed jar", "towel", "tablecloth", "box", "pot",
    "rice bowl", "tape", "pants", "spoon", "plate",
    "marker", "cloth", "laptop", "toothbrush case", "tea canister",
]
OBJECT_TO_IDX = {o: i for i, o in enumerate(OBJECT_TOP_LIST)}

# Pressure-derived contact phase per anchor (diagnostic, returned next to the label).
EVENT_NAMES = {0: "non-contact", 1: "pre-contact", 2: "steady-grip", 3: "release"}
CLASS_NAMES_BINARY = {0: "non-grasp", 1: "grasp"}
CLASS_NAMES_THREE = {0: "no-grasp", 1: "attempted", 2: "sustained"}

# Default class-name mapping (binary mode).
CLASS_NAMES = CLASS_NAMES_BINARY
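
# Index sanity examples: VERB_TO_IDX["grasp"] == 1, OBJECT_TO_IDX["towel"] == 2;
# names missing from the lists fall back to index 0 via .get(name, 0) in the
# builders below.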


def _parse_one(x: str, fmt_mode: str) -> float:
    """Parse one timestamp. A two-part value is always mm:ss; a three-part
    value is either hh:mm:ss or mm:ss:ff (frames at 30 fps), per fmt_mode."""
    p = x.split(":")
    if len(p) == 2:
        return int(p[0]) * 60 + int(p[1])
    if fmt_mode == "hhmmss":
        return int(p[0]) * 3600 + int(p[1]) * 60 + int(p[2])
    return int(p[0]) * 60 + int(p[1]) + int(p[2]) / 30.0


def _detect_fmt(segments, rec_sec: float) -> str:
    """Disambiguate three-part timestamps for a whole recording: if any segment
    end, read as hh:mm:ss, lands past the recording length (+5% slack), the
    annotation must actually be mm:ss:ff."""
    for s in segments:
        b = s["timestamp"].split("-")[1]
        p = b.split(":")
        if len(p) == 3:
            as_hms = int(p[0]) * 3600 + int(p[1]) * 60 + int(p[2])
            if as_hms > rec_sec * 1.05:
                return "mmssff"
    return "hhmmss"
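
# Worked parsing examples (pure arithmetic, no file I/O):
#   _parse_one("01:23",    "hhmmss") -> 83     # two-part is always mm:ss
#   _parse_one("01:02:15", "hhmmss") -> 3735   # 1 h 2 min 15 s
#   _parse_one("01:02:15", "mmssff") -> 62.5   # 1 min 2 s + 15 frames / 30 fps
# _detect_fmt returns "mmssff" when some hh:mm:ss reading overshoots the
# recording, e.g. a segment ending at "01:02:15" in a 90 s recording
# (3735 > 90 * 1.05).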


def build_object_label(annot_path: Path, n_frames: int,
                       sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
    """Per-frame object index (top-15 + '_other' fallback as class 0)."""
    label = np.zeros(n_frames, dtype=np.int8)
    if not annot_path.exists():
        return label
    try:
        with annot_path.open() as f:
            ann = json.load(f)
    except Exception:
        return label
    segments = ann.get("segments", [])
    if not segments:
        return label
    rec_sec = n_frames / sr
    fmt = _detect_fmt(segments, rec_sec)
    for s in segments:
        obj = s.get("action_annotation", {}).get("object_name")
        idx = OBJECT_TO_IDX.get(obj, 0)
        if idx == 0:
            continue  # '_other' is already the default; nothing to paint
        try:
            a, b = s["timestamp"].split("-")
            t0 = _parse_one(a, fmt)
            t1 = _parse_one(b, fmt)
        except Exception:
            continue
        if t1 <= t0 or t1 > rec_sec * 1.10:
            continue  # drop inverted or clearly out-of-range segments
        i0 = max(0, int(round(t0 * sr)))
        i1 = min(n_frames, int(round(t1 * sr)))
        label[i0:i1] = idx
    return label


def build_lift_eligible_mask(annot_path: Path, n_frames: int,
                             sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
    """Per-frame bool: True if the frame is inside a segment that meets the
    lifted-grasp criterion: verb ∈ LIFT_VERBS OR hand_type == 'both'.
    Used by the 3-class label_mode when require_lift_for_sustained=True."""
    mask = np.zeros(n_frames, dtype=bool)
    if not annot_path.exists():
        return mask
    try:
        with annot_path.open() as f:
            ann = json.load(f)
    except Exception:
        return mask
    segments = ann.get("segments", [])
    if not segments:
        return mask
    rec_sec = n_frames / sr
    fmt = _detect_fmt(segments, rec_sec)
    for s in segments:
        a = s.get("action_annotation", {})
        verb = a.get("action_name")
        hand = a.get("hand_type", "")
        is_lift = (verb in LIFT_VERBS) or (hand == "both")
        if not is_lift:
            continue
        try:
            ts0, ts1 = s["timestamp"].split("-")
            t0 = _parse_one(ts0, fmt)
            t1 = _parse_one(ts1, fmt)
        except Exception:
            continue
        if t1 <= t0 or t1 > rec_sec * 1.10:
            continue
        i0 = max(0, int(round(t0 * sr)))
        i1 = min(n_frames, int(round(t1 * sr)))
        mask[i0:i1] = True
    return mask


def build_verb_label(annot_path: Path, n_frames: int,
                     sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
    """Per-frame verb index (int8). Default (no segment) = 0 (background)."""
    label = np.zeros(n_frames, dtype=np.int8)
    if not annot_path.exists():
        return label
    try:
        with annot_path.open() as f:
            ann = json.load(f)
    except Exception:
        return label
    segments = ann.get("segments", [])
    if not segments:
        return label
    rec_sec = n_frames / sr
    fmt = _detect_fmt(segments, rec_sec)
    for s in segments:
        verb = s.get("action_annotation", {}).get("action_name")
        v_idx = VERB_TO_IDX.get(verb, 0)
        if v_idx == 0:
            continue
        try:
            a, b = s["timestamp"].split("-")
            t0 = _parse_one(a, fmt)
            t1 = _parse_one(b, fmt)
        except Exception:
            continue
        if t1 <= t0 or t1 > rec_sec * 1.10:
            continue
        i0 = max(0, int(round(t0 * sr)))
        i1 = min(n_frames, int(round(t1 * sr)))
        label[i0:i1] = v_idx
    return label


def build_grasp_mask(annot_path: Path, n_frames: int,
                     sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
    """Per-frame bool array of shape (n_frames,): True inside any segment
    whose action_name is in GRASP_VERBS."""
    mask = np.zeros(n_frames, dtype=bool)
    if not annot_path.exists():
        return mask
    try:
        with annot_path.open() as f:
            ann = json.load(f)
    except Exception:
        return mask
    segments = ann.get("segments", [])
    if not segments:
        return mask
    rec_sec = n_frames / sr
    fmt = _detect_fmt(segments, rec_sec)
    for s in segments:
        verb = s.get("action_annotation", {}).get("action_name")
        if verb not in GRASP_VERBS:
            continue
        try:
            a, b = s["timestamp"].split("-")
            t0 = _parse_one(a, fmt)
            t1 = _parse_one(b, fmt)
        except Exception:
            continue
        if t1 <= t0 or t1 > rec_sec * 1.10:
            continue
        i0 = max(0, int(round(t0 * sr)))
        i1 = min(n_frames, int(round(t1 * sr)))
        mask[i0:i1] = True
    return mask
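
# Builder usage sketch (the path below is hypothetical; the real call sites are
# in GraspStateDataset.__init__, which builds masks at the full rate and then
# downsamples them with the same stride as the sensors):
#   mask = build_grasp_mask(Path("annotations_v3/vol01/s01.json"), n_frames=12000)
#   frac = mask.mean()      # fraction of frames inside any GRASP_VERBS segment
#   mask_ds = mask[::5]     # stride matching downsample=5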


class GraspStateDataset(Dataset):
    """Classify object-interaction state over a future window from past sensor
    signals (binary 'is_grasping' by default; three_class / verb / object
    label modes are also supported)."""

    def __init__(
        self,
        volunteers: Sequence[str],
        input_modalities: Sequence[str],
        t_obs_sec: float = 1.0,
        t_fut_sec: float = 0.5,
        anchor_stride_sec: float = 0.25,
        downsample: int = 5,
        dataset_dir: Path = DEFAULT_DATASET_DIR,
        annot_dir: Path = DEFAULT_ANNOT_DIR,
        contact_threshold_g: float = 5.0,
        per_cell_threshold_g: float = 10.0,
        min_active_cells: int = 3,
        use_per_cell_contact: bool = True,
        label_mode: str = "binary",
        sustained_threshold_sec: float = 0.3,
        require_lift_for_sustained: bool = False,
        per_class_max: Optional[int] = None,
        input_stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None,
        expected_input_dims: Optional[Dict[str, int]] = None,
        majority_threshold: float = 0.5,
        rng_seed: int = 0,
        log: bool = True,
    ):
        super().__init__()
        self.input_modalities = list(input_modalities)
        self.t_obs_sec = float(t_obs_sec)
        self.t_fut_sec = float(t_fut_sec)
        self.anchor_stride_sec = float(anchor_stride_sec)
        self.downsample = int(downsample)
        self.sr = SAMPLING_RATE_HZ // self.downsample
        self.dataset_dir = Path(dataset_dir)
        self.annot_dir = Path(annot_dir)
        self.contact_threshold_g = float(contact_threshold_g)
        self.per_cell_threshold_g = float(per_cell_threshold_g)
        self.min_active_cells = int(min_active_cells)
        self.use_per_cell_contact = bool(use_per_cell_contact)
        self.label_mode = str(label_mode)
        if self.label_mode not in ("binary", "three_class", "verb", "object"):
            raise ValueError(f"label_mode must be binary|three_class|verb|object, got {label_mode}")
        if self.label_mode == "binary":
            self.num_classes = 2
        elif self.label_mode == "three_class":
            self.num_classes = 3
        elif self.label_mode == "verb":
            self.num_classes = len(VERB_LIST)
        else:
            self.num_classes = len(OBJECT_TOP_LIST)
        self.sustained_threshold_sec = float(sustained_threshold_sec)
        self.require_lift_for_sustained = bool(require_lift_for_sustained)
        self.per_class_max = per_class_max
        self.majority_threshold = float(majority_threshold)
        self.T_obs = int(round(self.t_obs_sec * self.sr))
        self.T_fut = int(round(self.t_fut_sec * self.sr))

        self._items: List[dict] = []
        self._modality_dims: Dict[str, int] = dict(expected_input_dims) if expected_input_dims else {}
        rng = np.random.default_rng(rng_seed)

        # Pressure is always loaded (it drives the contact / event-phase
        # features) even when it is not an input modality; dict.fromkeys
        # dedupes while preserving order.
        load_mods = list(dict.fromkeys(list(self.input_modalities) + ["pressure"]))

        # Anchors are collected into per-class pools so per_class_max can
        # subsample the over-represented classes afterwards.
        pools: Dict[int, List[dict]] = {c: [] for c in range(self.num_classes)}
        sustained_thresh_frames = int(round(self.sustained_threshold_sec * self.sr))

        for vol in volunteers:
            vol_dir = self.dataset_dir / vol
            if not vol_dir.is_dir():
                continue
            for scenario_dir in sorted(vol_dir.glob("s*")):
                if not scenario_dir.is_dir():
                    continue
                scene = scenario_dir.name
                annot_path = self.annot_dir / vol / f"{scene}.json"
                if not annot_path.exists():
                    continue
                try:
                    sensors_all = _load_recording_sensors(
                        scenario_dir, vol, scene, load_mods
                    )
                except Exception:
                    continue
                if sensors_all is None or any(a is None for a in sensors_all.values()):
                    continue

                pressure_full = sensors_all["pressure"]
                input_arrs = {m: sensors_all[m] for m in self.input_modalities}
                for m, arr in input_arrs.items():
                    self._enforce_dim(input_arrs, m, arr, self._modality_dims)

                # Truncate all streams to the shortest one (full-rate frames).
                T_avail = min(a.shape[0] for a in input_arrs.values())
                T_avail = min(T_avail, pressure_full.shape[0])
                if T_avail < (self.T_obs + self.T_fut) * self.downsample:
                    continue

                # Per-frame labels are built at the full rate, then downsampled
                # with the same stride as the sensor streams.
                mask_full = build_grasp_mask(annot_path, T_avail,
                                             sr=SAMPLING_RATE_HZ)
                if self.label_mode == "verb":
                    verb_full = build_verb_label(annot_path, T_avail, sr=SAMPLING_RATE_HZ)
                    verb_ds = verb_full[:T_avail:self.downsample]
                else:
                    verb_ds = None
                if self.label_mode == "object":
                    obj_full = build_object_label(annot_path, T_avail, sr=SAMPLING_RATE_HZ)
                    obj_ds = obj_full[:T_avail:self.downsample]
                else:
                    obj_ds = None
                if self.label_mode == "three_class" and self.require_lift_for_sustained:
                    lift_full = build_lift_eligible_mask(annot_path, T_avail, sr=SAMPLING_RATE_HZ)
                    lift_eligible_ds = lift_full[:T_avail:self.downsample]
                else:
                    lift_eligible_ds = None
                input_ds = {m: arr[:T_avail:self.downsample] for m, arr in input_arrs.items()}
                pressure_ds = pressure_full[:T_avail:self.downsample]
                mask_ds = mask_full[:T_avail:self.downsample]
                T_ds = mask_ds.shape[0]
                if self.use_per_cell_contact:
                    # Contact = at least min_active_cells taxels above threshold.
                    n_active = (pressure_ds > self.per_cell_threshold_g).sum(axis=1)
                    contact_frame = n_active >= self.min_active_cells
                else:
                    pressure_sum = pressure_ds.sum(axis=1)
                    contact_frame = pressure_sum > self.contact_threshold_g

                stride = max(1, int(round(self.anchor_stride_sec * self.sr)))
                first_anchor = self.T_obs
                last_anchor = T_ds - self.T_fut
                if last_anchor <= first_anchor:
                    continue

                for anchor in range(first_anchor, last_anchor + 1, stride):
                    fut_mask = mask_ds[anchor:anchor + self.T_fut]
                    if fut_mask.shape[0] != self.T_fut:
                        continue
                    annotation_is_grasp = fut_mask.mean() >= self.majority_threshold

                    if self.label_mode == "binary":
                        label = int(annotation_is_grasp)
                    elif self.label_mode == "three_class":
                        if not annotation_is_grasp:
                            label = 0
                        else:
                            # "Sustained" requires a long-enough unbroken run of
                            # pressure contact inside the future window.
                            fut_contact = contact_frame[anchor:anchor + self.T_fut]
                            longest = 0
                            cur = 0
                            for v in fut_contact:
                                if v:
                                    cur += 1
                                    longest = max(longest, cur)
                                else:
                                    cur = 0
                            is_sustained = longest >= sustained_thresh_frames
                            if is_sustained and self.require_lift_for_sustained:
                                # Optional gate: the majority of the future
                                # window must also be lift-eligible.
                                fut_lift = lift_eligible_ds[anchor:anchor + self.T_fut]
                                if fut_lift.mean() < 0.5:
                                    is_sustained = False
                            label = 2 if is_sustained else 1
                    elif self.label_mode == "verb":
                        fut_v = verb_ds[anchor:anchor + self.T_fut]
                        counts = np.bincount(fut_v, minlength=self.num_classes)
                        label = int(np.argmax(counts))
                    else:
                        fut_o = obj_ds[anchor:anchor + self.T_fut]
                        counts = np.bincount(fut_o, minlength=self.num_classes)
                        label = int(np.argmax(counts))

                    # Pressure-derived contact phase (diagnostic, not the label):
                    # 0 non-contact, 1 pre-contact, 2 steady-grip, 3 release.
                    past_high = contact_frame[anchor - self.T_obs:anchor].mean() > 0.5
                    fut_high = contact_frame[anchor:anchor + self.T_fut].mean() > 0.5
                    if not past_high and not fut_high:
                        et = 0
                    elif not past_high and fut_high:
                        et = 1
                    elif past_high and fut_high:
                        et = 2
                    else:
                        et = 3

                    past_slice = {m: arr[anchor - self.T_obs:anchor]
                                  for m, arr in input_ds.items()}
                    if any(w.shape[0] != self.T_obs for w in past_slice.values()):
                        continue

                    item = {
                        "x": past_slice,
                        "label": label,
                        "event_type": et,
                        "meta": {"vol": vol, "scene": scene, "anchor_idx": int(anchor)},
                    }
                    pools[label].append(item)

        # Optional per-class cap: subsample (without replacement) any pool that
        # exceeds per_class_max, keeping collection order via sorted indices.
        if self.per_class_max is not None:
            for c, pool in pools.items():
                if len(pool) > self.per_class_max:
                    idx = rng.choice(len(pool), size=self.per_class_max, replace=False)
                    pools[c] = [pool[i] for i in sorted(idx)]
        self._items = [it for c in range(self.num_classes) for it in pools[c]]

        if not self._items:
            raise RuntimeError("GraspStateDataset: collected 0 anchors.")

        # Normalize inputs with train statistics: computed here for the train
        # split, passed in via input_stats for the test split.
        if input_stats is None:
            input_stats = self._compute_input_stats()
        self._input_stats = input_stats
        self._apply_input_stats(input_stats)

        if log:
            if self.label_mode == "binary":
                class_names = CLASS_NAMES_BINARY
            elif self.label_mode == "three_class":
                class_names = CLASS_NAMES_THREE
            elif self.label_mode == "verb":
                class_names = {i: v for i, v in enumerate(VERB_LIST)}
            else:
                class_names = {i: v for i, v in enumerate(OBJECT_TOP_LIST)}
            counts_class = {class_names[c]: sum(1 for it in self._items if it["label"] == c)
                            for c in range(self.num_classes)}
            counts_event = {EVENT_NAMES[k]: sum(1 for it in self._items if it["event_type"] == k)
                            for k in (0, 1, 2, 3)}
            print(f"[GraspStateDataset] vols={len(volunteers)} "
                  f"inputs={self.input_modalities} "
                  f"anchors={len(self._items)} class={counts_class} "
                  f"event={counts_event} "
                  f"T_obs={self.T_obs} T_fut={self.T_fut} sr={self.sr}Hz "
                  f"input_dims={self._modality_dims}", flush=True)

    @staticmethod
    def _enforce_dim(arrs, m, arr, dim_dict):
        """Pin modality m to a fixed channel count: the first recording seen
        defines the dim; later recordings are zero-padded or truncated to it."""
        if m in dim_dict:
            tgt = dim_dict[m]
            if arr.shape[1] != tgt:
                if arr.shape[1] < tgt:
                    pad = np.zeros((arr.shape[0], tgt - arr.shape[1]), dtype=np.float32)
                    arrs[m] = np.concatenate([arr, pad], axis=1)
                else:
                    arrs[m] = arr[:, :tgt]
        else:
            dim_dict[m] = arr.shape[1]

    def _compute_input_stats(self):
        """Per-modality, per-channel mean/std over all collected past windows."""
        accs = {m: [] for m in self._modality_dims}
        for it in self._items:
            for m, w in it["x"].items():
                accs[m].append(w)
        out = {}
        for m, ws in accs.items():
            cat = np.concatenate(ws, axis=0)
            mu = cat.mean(axis=0).astype(np.float32)
            sd = cat.std(axis=0)
            sd = np.where(sd < 1e-6, 1.0, sd)  # guard near-constant channels
            out[m] = (mu, sd.astype(np.float32))
        return out
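
    # Normalization sketch using these stats ("emg" is an illustrative name):
    #   mu, sd = stats["emg"]       # each of shape (D_emg,)
    #   w_norm = (w - mu) / sd      # w is (T_obs, D_emg); broadcasts per channel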
    def _apply_input_stats(self, stats):
        for it in self._items:
            for m, w in it["x"].items():
                if m in stats:
                    mu, sd = stats[m]
                    it["x"][m] = ((w - mu) / sd).astype(np.float32)

    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        it = self._items[idx]
        x = {m: torch.from_numpy(np.ascontiguousarray(w)) for m, w in it["x"].items()}
        label = int(it["label"])
        et = int(it["event_type"])
        return x, label, et, it["meta"]

    @property
    def modality_dims(self):
        return dict(self._modality_dims)

def collate_grasp_state(batch):
    xs, labels, ets, metas = zip(*batch)
    mods = list(xs[0].keys())
    x_out = {m: torch.stack([x[m] for x in xs], dim=0) for m in mods}
    y_out = torch.tensor(labels, dtype=torch.long)
    et_out = torch.tensor(ets, dtype=torch.long)
    return x_out, y_out, et_out, list(metas)
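
# Batch contract from collate_grasp_state (B = batch size): x is
# {modality: float32 tensor of shape (B, T_obs, D_m)}, y and et are (B,) int64
# (label and pressure event phase), metas is a length-B list of meta dicts.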


def build_grasp_train_test(
    input_modalities,
    t_obs_sec=1.0, t_fut_sec=0.5, anchor_stride_sec=0.25,
    downsample=5,
    dataset_dir=DEFAULT_DATASET_DIR, annot_dir=DEFAULT_ANNOT_DIR,
    contact_threshold_g=5.0, per_class_max=None,
    label_mode="binary", sustained_threshold_sec=0.3,
    require_lift_for_sustained=False,
    rng_seed=0,
    train_vols=None, test_vols=None,
):
    """Build matched train/test splits on the v3 volunteer lists. The test set
    reuses the train normalization stats and channel dims, and is never
    subsampled (per_class_max is forced to None)."""
    if train_vols is None:
        train_vols = TRAIN_VOLS_V3
    if test_vols is None:
        test_vols = TEST_VOLS_V3
    train = GraspStateDataset(
        train_vols, input_modalities=input_modalities,
        t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec,
        anchor_stride_sec=anchor_stride_sec, downsample=downsample,
        dataset_dir=dataset_dir, annot_dir=annot_dir,
        contact_threshold_g=contact_threshold_g, per_class_max=per_class_max,
        label_mode=label_mode, sustained_threshold_sec=sustained_threshold_sec,
        require_lift_for_sustained=require_lift_for_sustained,
        rng_seed=rng_seed, log=True,
    )
    test = GraspStateDataset(
        test_vols, input_modalities=input_modalities,
        t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec,
        anchor_stride_sec=anchor_stride_sec, downsample=downsample,
        dataset_dir=dataset_dir, annot_dir=annot_dir,
        contact_threshold_g=contact_threshold_g, per_class_max=None,
        label_mode=label_mode, sustained_threshold_sec=sustained_threshold_sec,
        require_lift_for_sustained=require_lift_for_sustained,
        input_stats=train._input_stats,
        expected_input_dims=train._modality_dims,
        rng_seed=rng_seed + 1, log=True,
    )
    return train, test
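
# End-to-end sketch (assumes the default dataset/annotation dirs resolve on
# this machine; the batch size is illustrative):
#   from torch.utils.data import DataLoader
#   train, test = build_grasp_train_test(input_modalities=["emg", "imu"])
#   loader = DataLoader(train, batch_size=64, shuffle=True,
#                       collate_fn=collate_grasp_state)
#   x, y, et, metas = next(iter(loader))   # one training batch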


if __name__ == "__main__":
    import argparse

    ap = argparse.ArgumentParser()
    ap.add_argument("--input_modalities", default="emg,imu,mocap")
    ap.add_argument("--t_obs", type=float, default=1.0)
    ap.add_argument("--t_fut", type=float, default=0.5)
    args = ap.parse_args()
    tr, te = build_grasp_train_test(
        input_modalities=args.input_modalities.split(","),
        t_obs_sec=args.t_obs, t_fut_sec=args.t_fut,
    )
    x, y, et, meta = tr[0]
    print(f"sample: x={ {m: tuple(v.shape) for m, v in x.items()} } y={y} et={et}")
|