"""Anchor-based grasp-state classification dataset (T5 v3 / TGSR).

At each sampled anchor frame ``t`` of a recording (frame indices on the
downsampled grid):

    input = sensor frames in [t - T_obs, t)        (past window, excludes t)
    label = summary of annotations over [t, t + T_fut)   (future window)

For the default ``label_mode="binary"`` the label is a majority vote of the
grasp-annotation mask over the future window (the fraction of grasp frames
must reach ``majority_threshold``, 0.5 by default).

Ground-truth source: annotations_v3 verb segments. A frame is marked
"is_grasp" if it falls inside a segment whose action_name belongs to
GRASP_VERBS (set below).

In the binary, verb and object label modes the label is derived purely from
the annotations and is independent of pressure, so adding/removing pressure
as an input does NOT leak the label. This is the cleanest test of "does
pressure improve recognition of object-interaction state when
human-annotated grasp segments are GT?".

Caveat: in the three_class mode the "sustained" class additionally depends
on pressure contact within the future window, so that mode is NOT
pressure-independent.
"""
from __future__ import annotations

import json
import sys
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset

THIS = Path(__file__).resolve()
# Make both the sibling module and the package root importable so this file
# works when run as a script and when imported as part of `experiments`.
sys.path.insert(0, str(THIS.parent))
sys.path.insert(0, str(THIS.parents[1]))
try:
    from experiments.dataset_seqpred import (
        SAMPLING_RATE_HZ, _load_recording_sensors,
        TRAIN_VOLS_V3, TEST_VOLS_V3,
        DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
    )
except ModuleNotFoundError:
    from dataset_seqpred import (
        SAMPLING_RATE_HZ, _load_recording_sensors,
        TRAIN_VOLS_V3, TEST_VOLS_V3,
        DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
    )

# Verbs whose annotated segments count as "is_grasp" for the binary label.
GRASP_VERBS = {
    "grasp", "hold", "pick_up", "move", "place", "put_down",
    "pull", "rotate", "insert", "remove",
}

# User-specified subset of action verbs that mean "the object has been lifted
# off its resting surface and held in hand" (used as Class 2 stricter definition).
LIFT_VERBS = {"grasp", "open", "move", "pick_up", "hold"}

# Multi-class verb taxonomy (annotations_v3 verb_fine universe).
# Verb 0 = background (anchor outside any segment).
VERB_LIST = [
    "background", "grasp", "move", "place", "adjust", "pick_up", "close",
    "put_down", "pull", "hold", "open", "rotate", "release", "push",
    "insert", "remove", "align", "stabilize",
]
VERB_TO_IDX = {v: i for i, v in enumerate(VERB_LIST)}

# Top-15 most common object categories with non-zero coverage in the
# pressure-bearing test set (annotations_v3 survey of TRAIN+TEST_VOLS_V3).
# Index 0 = "_other": anchor outside any segment OR object not in top-15.
# Note: "coat" excluded because it appears only in v14, which has no
# pressure-aligned sessions and is silently dropped by the loader.
OBJECT_TOP_LIST = [
    "_other", "sealed jar", "towel", "tablecloth", "box", "pot", "rice bowl",
    "tape", "pants", "spoon", "plate", "marker", "cloth", "laptop",
    "toothbrush case", "tea canister",
]
OBJECT_TO_IDX = {o: i for i, o in enumerate(OBJECT_TOP_LIST)}

EVENT_NAMES = {0: "non-contact", 1: "pre-contact", 2: "steady-grip", 3: "release"}
CLASS_NAMES_BINARY = {0: "non-grasp", 1: "grasp"}
CLASS_NAMES_THREE = {0: "no-grasp", 1: "attempted", 2: "sustained"}
# Back-compat default (used by binary code paths)
CLASS_NAMES = CLASS_NAMES_BINARY

# Segments whose end time lies more than 10% past the recording end are
# considered mis-timed and dropped (not clipped).
_MAX_END_OVERRUN = 1.10


def _parse_one(x: str, fmt_mode: str) -> float:
    """Convert one timestamp string to seconds.

    * Two fields (``"MM:SS"``) are always minutes/seconds.
    * Three fields are ``HH:MM:SS`` when ``fmt_mode == "hhmmss"``, otherwise
      ``MM:SS:FF`` where FF is a frame number at 30 fps.

    Raises ValueError / IndexError on malformed input; callers catch these.
    """
    p = x.split(":")
    if len(p) == 2:
        return int(p[0]) * 60 + int(p[1])
    if fmt_mode == "hhmmss":
        return int(p[0]) * 3600 + int(p[1]) * 60 + int(p[2])
    return int(p[0]) * 60 + int(p[1]) + int(p[2]) / 30.0  # mmssff @ 30fps


def _detect_fmt(segments, rec_sec: float) -> str:
    """Guess how three-field timestamps in this file should be read.

    If any segment end, read as HH:MM:SS, lies more than 5% beyond the
    recording length, the file is assumed to use MM:SS:FF instead.
    Malformed segments are ignored here (they are skipped later as well),
    so one bad entry cannot abort the whole dataset build.
    """
    for s in segments:
        try:
            p = s["timestamp"].split("-")[1].split(":")
            if len(p) != 3:
                continue
            as_hhmmss = int(p[0]) * 3600 + int(p[1]) * 60 + int(p[2])
        except (KeyError, AttributeError, TypeError, ValueError, IndexError):
            continue
        if as_hhmmss > rec_sec * 1.05:
            return "mmssff"
    return "hhmmss"


def _load_segments(annot_path: Path) -> list:
    """Return the ``segments`` list of an annotation JSON file.

    Missing, unreadable or malformed files yield ``[]`` (best effort: the
    recording then simply gets all-background / all-False labels).
    """
    if not annot_path.exists():
        return []
    try:
        # `with` closes the handle; the previous json.load(open(...)) leaked it.
        with open(annot_path, encoding="utf-8") as f:
            ann = json.load(f)
    except (OSError, ValueError):  # ValueError covers JSON and Unicode decode errors
        return []
    if not isinstance(ann, dict):
        return []
    segments = ann.get("segments", [])
    return segments if isinstance(segments, list) else []


def _iter_segment_spans(annot_path: Path, n_frames: int, sr: int):
    """Yield ``(action_annotation, i0, i1)`` for every usable segment.

    ``i0:i1`` is the segment's frame span at rate ``sr``, clipped to
    ``[0, n_frames)``. Segments are yielded in file order, so callers that
    paint a per-frame array let later segments overwrite earlier ones.
    Segments with an unparseable timestamp, a non-positive duration, or an
    end time beyond ``_MAX_END_OVERRUN`` x the recording length are skipped.
    """
    segments = _load_segments(annot_path)
    if not segments:
        return
    rec_sec = n_frames / sr
    fmt = _detect_fmt(segments, rec_sec)
    for s in segments:
        if not isinstance(s, dict):
            continue
        try:
            a, b = s["timestamp"].split("-")
            t0 = _parse_one(a, fmt)
            t1 = _parse_one(b, fmt)
        except (KeyError, AttributeError, TypeError, ValueError, IndexError):
            continue
        if t1 <= t0 or t1 > rec_sec * _MAX_END_OVERRUN:
            continue
        ann = s.get("action_annotation")
        if not isinstance(ann, dict):
            ann = {}
        i0 = max(0, int(round(t0 * sr)))
        i1 = min(n_frames, int(round(t1 * sr)))
        yield ann, i0, i1


def build_object_label(annot_path: Path, n_frames: int,
                       sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
    """Per-frame object index (top-15 + '_other' fallback as class 0).

    Returns an int8 array of length ``n_frames``; frames outside any segment,
    or inside a segment whose object is not in OBJECT_TOP_LIST, stay 0.
    """
    label = np.zeros(n_frames, dtype=np.int8)
    for ann, i0, i1 in _iter_segment_spans(annot_path, n_frames, sr):
        idx = OBJECT_TO_IDX.get(ann.get("object_name"), 0)
        if idx:  # 0 = "_other"/background: leave untouched
            label[i0:i1] = idx
    return label


def build_lift_eligible_mask(annot_path: Path, n_frames: int,
                             sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
    """Per-frame bool: True inside a segment meeting the lifted-grasp criterion.

    A segment qualifies when its verb is in LIFT_VERBS OR its hand_type is
    'both'. Used by the 3-class label_mode when require_lift_for_sustained=True.
    """
    mask = np.zeros(n_frames, dtype=bool)
    for ann, i0, i1 in _iter_segment_spans(annot_path, n_frames, sr):
        if ann.get("action_name") in LIFT_VERBS or ann.get("hand_type", "") == "both":
            mask[i0:i1] = True
    return mask


def build_verb_label(annot_path: Path, n_frames: int,
                     sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
    """Per-frame verb index (int8). Default (no segment) = 0 (background).

    Verbs not listed in VERB_LIST are treated as background.
    """
    label = np.zeros(n_frames, dtype=np.int8)
    for ann, i0, i1 in _iter_segment_spans(annot_path, n_frames, sr):
        v_idx = VERB_TO_IDX.get(ann.get("action_name"), 0)  # unknown verb -> background
        if v_idx:
            label[i0:i1] = v_idx
    return label


def build_grasp_mask(annot_path: Path, n_frames: int,
                     sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
    """Return bool array of shape (n_frames,): True inside GRASP_VERBS segments."""
    mask = np.zeros(n_frames, dtype=bool)
    for ann, i0, i1 in _iter_segment_spans(annot_path, n_frames, sr):
        if ann.get("action_name") in GRASP_VERBS:
            mask[i0:i1] = True
    return mask


def _longest_true_run(flags) -> int:
    """Length of the longest run of consecutive truthy values in ``flags``."""
    longest = cur = 0
    for v in flags:
        if v:
            cur += 1
            longest = max(longest, cur)
        else:
            cur = 0
    return longest


class GraspStateDataset(Dataset):
    """Anchor-based grasp-state classification dataset.

    For every anchor frame ``t`` (on the downsampled grid) a sample holds the
    past sensor window ``[t - T_obs, t)`` of each input modality. Its label
    summarizes the annotations over the future window ``[t, t + T_fut)``
    according to ``label_mode``:

    * ``"binary"``      - 1 if at least ``majority_threshold`` of the future
      frames lie inside a GRASP_VERBS segment, else 0.
    * ``"three_class"`` - 0 when the binary label is 0; otherwise 2
      ("sustained") if the longest contiguous pressure-contact run in the
      future window is at least ``sustained_threshold_sec`` (and, with
      ``require_lift_for_sustained``, at least half of the window lies in a
      lift-eligible segment), else 1 ("attempted").
    * ``"verb"``        - most frequent VERB_LIST index in the future window.
    * ``"object"``      - most frequent OBJECT_TOP_LIST index in the window.

    Each item also carries an ``event_type`` (see EVENT_NAMES) derived from
    pressure contact in the past vs. the future window, for stratification.

    Other notable arguments:
        per_class_max: if set, randomly subsample every class pool to at most
            this many anchors (seeded by ``rng_seed``).
        input_stats: per-modality ``(mean, std)`` to normalize inputs with;
            computed from this dataset when None (pass the training stats to
            a test set).
        expected_input_dims: per-modality feature widths to pad/truncate to,
            e.g. the training set's ``modality_dims``.

    ``__getitem__`` returns ``(x, label, event_type, meta)`` where ``x`` maps
    modality name to a tensor of shape ``(T_obs, dim)``.
    """

    def __init__(
        self,
        volunteers: Sequence[str],
        input_modalities: Sequence[str],
        t_obs_sec: float = 1.0,
        t_fut_sec: float = 0.5,
        anchor_stride_sec: float = 0.25,
        downsample: int = 5,
        dataset_dir: Path = DEFAULT_DATASET_DIR,
        annot_dir: Path = DEFAULT_ANNOT_DIR,
        contact_threshold_g: float = 5.0,          # legacy sum-threshold (kept for back-compat, unused if use_per_cell_contact=True)
        per_cell_threshold_g: float = 10.0,        # per-cell threshold to declare a sensor cell "active"
        min_active_cells: int = 3,                 # need >= this many active cells to declare contact
        use_per_cell_contact: bool = True,         # default: use per-cell active-count for contact detection
        label_mode: str = "binary",                # "binary", "three_class", "verb" or "object"
        sustained_threshold_sec: float = 0.3,      # (3-class only) min contiguous contact for "Sustained"
        require_lift_for_sustained: bool = False,  # (3-class only) Class 2 also requires verb in LIFT_VERBS or hand_type == "both"
        per_class_max: Optional[int] = None,
        input_stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None,
        expected_input_dims: Optional[Dict[str, int]] = None,
        majority_threshold: float = 0.5,
        rng_seed: int = 0,
        log: bool = True,
    ):
        super().__init__()
        self.input_modalities = list(input_modalities)
        self.t_obs_sec = float(t_obs_sec)
        self.t_fut_sec = float(t_fut_sec)
        self.anchor_stride_sec = float(anchor_stride_sec)
        self.downsample = int(downsample)
        if self.downsample < 1:
            # 0 would divide by zero below; a negative slice step would
            # silently reverse every recording.
            raise ValueError(f"downsample must be >= 1, got {downsample}")
        self.sr = SAMPLING_RATE_HZ // self.downsample
        self.dataset_dir = Path(dataset_dir)
        self.annot_dir = Path(annot_dir)
        self.contact_threshold_g = float(contact_threshold_g)
        self.per_cell_threshold_g = float(per_cell_threshold_g)
        self.min_active_cells = int(min_active_cells)
        self.use_per_cell_contact = bool(use_per_cell_contact)

        self.label_mode = str(label_mode)
        num_classes_by_mode = {
            "binary": 2,
            "three_class": 3,
            "verb": len(VERB_LIST),
            "object": len(OBJECT_TOP_LIST),
        }
        if self.label_mode not in num_classes_by_mode:
            raise ValueError(f"label_mode must be binary|three_class|verb|object, got {label_mode}")
        self.num_classes = num_classes_by_mode[self.label_mode]

        self.sustained_threshold_sec = float(sustained_threshold_sec)
        self.require_lift_for_sustained = bool(require_lift_for_sustained)
        self.per_class_max = per_class_max
        self.majority_threshold = float(majority_threshold)
        self.T_obs = int(round(self.t_obs_sec * self.sr))
        self.T_fut = int(round(self.t_fut_sec * self.sr))
        self._items: List[dict] = []
        self._modality_dims: Dict[str, int] = dict(expected_input_dims) if expected_input_dims else {}
        rng = np.random.default_rng(rng_seed)

        # Load pressure even if not in inputs: it drives contact detection
        # (event_type and the three_class "sustained" criterion).
        load_mods = list(dict.fromkeys(list(self.input_modalities) + ["pressure"]))

        pools: Dict[int, List[dict]] = {c: [] for c in range(self.num_classes)}
        sustained_thresh_frames = int(round(self.sustained_threshold_sec * self.sr))
        stride = max(1, int(round(self.anchor_stride_sec * self.sr)))
        n_unusable = 0  # recordings skipped because their sensors failed to load

        for vol in volunteers:
            vol_dir = self.dataset_dir / vol
            if not vol_dir.is_dir():
                continue
            for scenario_dir in sorted(vol_dir.glob("s*")):
                if not scenario_dir.is_dir():
                    continue
                scene = scenario_dir.name
                annot_path = self.annot_dir / vol / f"{scene}.json"
                if not annot_path.exists():
                    continue
                try:
                    sensors_all = _load_recording_sensors(scenario_dir, vol, scene, load_mods)
                except Exception:
                    # Best effort: one broken recording must not abort the build;
                    # the count is reported in the log line.
                    n_unusable += 1
                    continue
                if sensors_all is None or any(a is None for a in sensors_all.values()):
                    n_unusable += 1
                    continue

                pressure_full = sensors_all["pressure"]
                input_arrs = {m: sensors_all[m] for m in self.input_modalities}
                for m in list(input_arrs):
                    self._enforce_dim(input_arrs, m, input_arrs[m], self._modality_dims)

                T_avail = min([pressure_full.shape[0]] + [a.shape[0] for a in input_arrs.values()])
                if T_avail < (self.T_obs + self.T_fut) * self.downsample:
                    continue

                ds = slice(0, T_avail, self.downsample)
                # Annotation labels are built at the native sampling rate,
                # then downsampled with the same slice as the sensors.
                mask_ds = build_grasp_mask(annot_path, T_avail, sr=SAMPLING_RATE_HZ)[ds]
                verb_ds = (build_verb_label(annot_path, T_avail, sr=SAMPLING_RATE_HZ)[ds]
                           if self.label_mode == "verb" else None)
                obj_ds = (build_object_label(annot_path, T_avail, sr=SAMPLING_RATE_HZ)[ds]
                          if self.label_mode == "object" else None)
                lift_eligible_ds = (
                    build_lift_eligible_mask(annot_path, T_avail, sr=SAMPLING_RATE_HZ)[ds]
                    if self.label_mode == "three_class" and self.require_lift_for_sustained
                    else None
                )
                input_ds = {m: arr[ds] for m, arr in input_arrs.items()}
                contact_frame = self._contact_frames(pressure_full[ds])
                T_ds = mask_ds.shape[0]

                first_anchor = self.T_obs
                last_anchor = T_ds - self.T_fut
                # `<` (not `<=`): when equal there is exactly one valid anchor.
                if last_anchor < first_anchor:
                    continue
                for anchor in range(first_anchor, last_anchor + 1, stride):
                    fut = slice(anchor, anchor + self.T_fut)
                    fut_mask = mask_ds[fut]
                    if fut_mask.shape[0] != self.T_fut:
                        continue
                    annotation_is_grasp = fut_mask.mean() >= self.majority_threshold

                    if self.label_mode == "binary":
                        label = int(annotation_is_grasp)
                    elif self.label_mode == "three_class":
                        label = 0  # NoGrasp
                        if annotation_is_grasp:
                            longest = _longest_true_run(contact_frame[fut])
                            is_sustained = longest >= sustained_thresh_frames
                            if is_sustained and self.require_lift_for_sustained:
                                # Demote to Class 1 unless the majority of the future
                                # window is lift-eligible (verb in LIFT_VERBS or hand=both).
                                if lift_eligible_ds[fut].mean() < 0.5:
                                    is_sustained = False
                            label = 2 if is_sustained else 1
                    elif self.label_mode == "verb":
                        counts = np.bincount(verb_ds[fut], minlength=self.num_classes)
                        label = int(np.argmax(counts))  # ties -> lowest index
                    else:  # object - majority object in future window
                        counts = np.bincount(obj_ds[fut], minlength=self.num_classes)
                        label = int(np.argmax(counts))

                    # event_type for stratification (4-class transition taxonomy):
                    # 0 non-contact, 1 pre-contact, 2 steady-grip, 3 release.
                    past_high = contact_frame[anchor - self.T_obs:anchor].mean() > 0.5
                    fut_high = contact_frame[fut].mean() > 0.5
                    if past_high:
                        et = 2 if fut_high else 3
                    else:
                        et = 1 if fut_high else 0

                    past_slice = {m: arr[anchor - self.T_obs:anchor] for m, arr in input_ds.items()}
                    if any(w.shape[0] != self.T_obs for w in past_slice.values()):
                        continue
                    pools[label].append({
                        "x": past_slice,
                        "label": label,
                        "event_type": et,
                        "meta": {"vol": vol, "scene": scene, "anchor_idx": int(anchor)},
                    })

        # Balance classes if requested (cap larger pools to per_class_max).
        if self.per_class_max is not None:
            for c, pool in pools.items():
                if len(pool) > self.per_class_max:
                    keep = rng.choice(len(pool), size=self.per_class_max, replace=False)
                    pools[c] = [pool[i] for i in sorted(keep)]

        self._items = [it for c in range(self.num_classes) for it in pools[c]]
        if not self._items:
            raise RuntimeError("GraspStateDataset: collected 0 anchors.")

        # Z-score inputs (a test set should receive the training statistics).
        if input_stats is None:
            input_stats = self._compute_input_stats()
        self._input_stats = input_stats
        self._apply_input_stats(input_stats)

        if log:
            self._log_summary(volunteers, n_unusable)

    def _contact_frames(self, pressure_ds: np.ndarray) -> np.ndarray:
        """Per-frame bool contact flag derived from the pressure array.

        Per-cell mode: at least ``min_active_cells`` cells exceed
        ``per_cell_threshold_g``. Legacy mode: the summed pressure exceeds
        ``contact_threshold_g``.
        """
        if self.use_per_cell_contact:
            n_active = (pressure_ds > self.per_cell_threshold_g).sum(axis=1)
            return n_active >= self.min_active_cells
        return pressure_ds.sum(axis=1) > self.contact_threshold_g

    def _log_summary(self, volunteers, n_unusable: int) -> None:
        """Print class / event-type counts and the dataset geometry."""
        if self.label_mode == "binary":
            class_names = CLASS_NAMES_BINARY
        elif self.label_mode == "three_class":
            class_names = CLASS_NAMES_THREE
        elif self.label_mode == "verb":
            class_names = dict(enumerate(VERB_LIST))
        else:  # object
            class_names = dict(enumerate(OBJECT_TOP_LIST))
        counts_class = {class_names[c]: sum(1 for it in self._items if it["label"] == c)
                        for c in range(self.num_classes)}
        counts_event = {EVENT_NAMES[k]: sum(1 for it in self._items if it["event_type"] == k)
                        for k in (0, 1, 2, 3)}
        print(f"[GraspStateDataset] vols={len(volunteers)} "
              f"inputs={self.input_modalities} "
              f"anchors={len(self._items)} class={counts_class} "
              f"event={counts_event} "
              f"T_obs={self.T_obs} T_fut={self.T_fut} sr={self.sr}Hz "
              f"input_dims={self._modality_dims} "
              f"skipped_recordings={n_unusable}", flush=True)

    @staticmethod
    def _enforce_dim(arrs, m, arr, dim_dict):
        """Pad/truncate ``arrs[m]`` on axis 1 to the registered width for ``m``.

        The first recording seen for a modality registers its width in
        ``dim_dict`` (unless pre-seeded via ``expected_input_dims``); later
        recordings are zero-padded or truncated to match, so windows from
        different recordings can be stacked.
        """
        if m not in dim_dict:
            dim_dict[m] = arr.shape[1]
            return
        tgt = dim_dict[m]
        cur = arr.shape[1]
        if cur < tgt:
            pad = np.zeros((arr.shape[0], tgt - cur), dtype=np.float32)
            arrs[m] = np.concatenate([arr, pad], axis=1)
        elif cur > tgt:
            arrs[m] = arr[:, :tgt]

    def _compute_input_stats(self):
        """Per-modality, per-channel ``(mean, std)`` over all collected windows.

        Channels with (near-)zero std get std = 1, so they are only
        mean-centered instead of being divided by ~0.
        """
        accs = {m: [] for m in self._modality_dims}
        for it in self._items:
            for m, w in it["x"].items():
                accs[m].append(w)
        out = {}
        for m, ws in accs.items():
            if not ws:
                continue  # modality known only via expected_input_dims; nothing to average
            cat = np.concatenate(ws, axis=0)
            mu = cat.mean(axis=0).astype(np.float32)
            sd = cat.std(axis=0)
            sd = np.where(sd < 1e-6, 1.0, sd)
            out[m] = (mu, sd.astype(np.float32))
        return out

    def _apply_input_stats(self, stats):
        """Normalize every stored window in place as ``(w - mean) / std`` (float32)."""
        for it in self._items:
            for m, w in it["x"].items():
                if m in stats:
                    mu, sd = stats[m]
                    it["x"][m] = ((w - mu) / sd).astype(np.float32)

    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        """Return ``(x, label, event_type, meta)`` for anchor ``idx``."""
        it = self._items[idx]
        x = {m: torch.from_numpy(np.ascontiguousarray(w)) for m, w in it["x"].items()}
        return x, int(it["label"]), int(it["event_type"]), it["meta"]

    @property
    def modality_dims(self):
        """Copy of the per-modality feature widths used by this dataset."""
        return dict(self._modality_dims)


def collate_grasp_state(batch):
    """Collate ``(x, label, event_type, meta)`` items into batched tensors.

    Returns ``(x_out, y_out, et_out, metas)`` where ``x_out`` maps each
    modality to a tensor stacked on a new leading batch axis.
    """
    xs, labels, ets, metas = zip(*batch)
    mods = list(xs[0].keys())
    x_out = {m: torch.stack([x[m] for x in xs], dim=0) for m in mods}
    y_out = torch.tensor(labels, dtype=torch.long)
    et_out = torch.tensor(ets, dtype=torch.long)
    return x_out, y_out, et_out, list(metas)


def build_grasp_train_test(
    input_modalities,
    t_obs_sec=1.0,
    t_fut_sec=0.5,
    anchor_stride_sec=0.25,
    downsample=5,
    dataset_dir=DEFAULT_DATASET_DIR,
    annot_dir=DEFAULT_ANNOT_DIR,
    contact_threshold_g=5.0,
    per_class_max=None,
    label_mode="binary",
    sustained_threshold_sec=0.3,
    require_lift_for_sustained=False,
    rng_seed=0,
    train_vols=None,
    test_vols=None,
    per_cell_threshold_g=10.0,
    min_active_cells=3,
    use_per_cell_contact=True,
    majority_threshold=0.5,
):
    """Build the (train, test) GraspStateDataset pair.

    The test set reuses the training set's normalization statistics and
    feature widths, is never class-capped, and uses ``rng_seed + 1``.
    ``train_vols`` / ``test_vols`` default to TRAIN_VOLS_V3 / TEST_VOLS_V3.
    The trailing contact / majority parameters are forwarded to both sets
    (their defaults match GraspStateDataset's).
    """
    if train_vols is None:
        train_vols = TRAIN_VOLS_V3
    if test_vols is None:
        test_vols = TEST_VOLS_V3
    common = dict(
        input_modalities=input_modalities,
        t_obs_sec=t_obs_sec,
        t_fut_sec=t_fut_sec,
        anchor_stride_sec=anchor_stride_sec,
        downsample=downsample,
        dataset_dir=dataset_dir,
        annot_dir=annot_dir,
        contact_threshold_g=contact_threshold_g,
        per_cell_threshold_g=per_cell_threshold_g,
        min_active_cells=min_active_cells,
        use_per_cell_contact=use_per_cell_contact,
        label_mode=label_mode,
        sustained_threshold_sec=sustained_threshold_sec,
        require_lift_for_sustained=require_lift_for_sustained,
        majority_threshold=majority_threshold,
        log=True,
    )
    train = GraspStateDataset(
        train_vols,
        per_class_max=per_class_max,
        rng_seed=rng_seed,
        **common,
    )
    test = GraspStateDataset(
        test_vols,
        per_class_max=None,  # don't cap test
        input_stats=train._input_stats,
        expected_input_dims=train._modality_dims,
        rng_seed=rng_seed + 1,
        **common,
    )
    return train, test


if __name__ == "__main__":
    import argparse

    ap = argparse.ArgumentParser(description="Smoke-test GraspStateDataset construction.")
    ap.add_argument("--input_modalities", default="emg,imu,mocap")
    ap.add_argument("--t_obs", type=float, default=1.0)
    ap.add_argument("--t_fut", type=float, default=0.5)
    args = ap.parse_args()

    tr, te = build_grasp_train_test(
        input_modalities=[m.strip() for m in args.input_modalities.split(",") if m.strip()],
        t_obs_sec=args.t_obs,
        t_fut_sec=args.t_fut,
    )
    x, y, et, meta = tr[0]
    shapes = {m: tuple(v.shape) for m, v in x.items()}
    print(f"sample: x={shapes} y={y} et={et}")