"""Anchor-based binary "is_grasping" classification dataset (T5 v3 / TGSR).
At each sampled anchor t in a recording:
past = sensor frames over [t - T_obs, t] ← input
label = majority vote of grasp-annotation mask over (t, t+T_fut] ← binary class
Ground-truth source: annotations_v3 verb segments. A frame is marked
"is_grasp" if it falls inside a segment whose action_name belongs to
GRASP_VERBS (set below). The label is annotation-derived, completely
independent of pressure — so adding/removing pressure as input does
NOT leak the label.
This is the cleanest test of "does pressure improve recognition of
object-interaction state when human-annotated grasp segments are GT?"
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
import torch
from torch.utils.data import Dataset
THIS = Path(__file__).resolve()
sys.path.insert(0, str(THIS.parent))
sys.path.insert(0, str(THIS.parents[1]))
try:
from experiments.dataset_seqpred import (
SAMPLING_RATE_HZ, _load_recording_sensors,
TRAIN_VOLS_V3, TEST_VOLS_V3,
DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
)
except ModuleNotFoundError:
from dataset_seqpred import (
SAMPLING_RATE_HZ, _load_recording_sensors,
TRAIN_VOLS_V3, TEST_VOLS_V3,
DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
)
GRASP_VERBS = {
"grasp", "hold", "pick_up", "move", "place", "put_down",
"pull", "rotate", "insert", "remove",
}
# User-specified subset of action verbs that mean "the object has been lifted
# off its resting surface and is held in hand" (the stricter Class 2 definition).
LIFT_VERBS = {"grasp", "open", "move", "pick_up", "hold"}
# Multi-class verb taxonomy (annotations_v3 verb_fine universe).
# Verb 0 = background (anchor outside any segment).
VERB_LIST = [
"background",
"grasp", "move", "place", "adjust", "pick_up",
"close", "put_down", "pull", "hold", "open",
"rotate", "release", "push", "insert", "remove",
"align", "stabilize",
]
VERB_TO_IDX = {v: i for i, v in enumerate(VERB_LIST)}
# Top-15 most common object categories with non-zero coverage in the
# pressure-bearing test set (annotations_v3 survey of TRAIN+TEST_VOLS_V3).
# Index 0 = "_other": anchor outside any segment OR object not in top-15.
# Note: "coat" excluded because it appears only in v14, which has no
# pressure-aligned sessions and is silently dropped by the loader.
OBJECT_TOP_LIST = [
"_other",
"sealed jar", "towel", "tablecloth", "box", "pot",
"rice bowl", "tape", "pants", "spoon", "plate",
"marker", "cloth", "laptop", "toothbrush case", "tea canister",
]
OBJECT_TO_IDX = {o: i for i, o in enumerate(OBJECT_TOP_LIST)}
EVENT_NAMES = {0: "non-contact", 1: "pre-contact", 2: "steady-grip", 3: "release"}
CLASS_NAMES_BINARY = {0: "non-grasp", 1: "grasp"}
CLASS_NAMES_THREE = {0: "no-grasp", 1: "attempted", 2: "sustained"}
# Back-compat default (used by binary code paths)
CLASS_NAMES = CLASS_NAMES_BINARY
def _parse_one(x: str, fmt_mode: str) -> float:
    """Parse one timestamp ("mm:ss", "hh:mm:ss", or "mm:ss:ff") to seconds."""
    p = x.split(":")
    if len(p) == 2:
        return int(p[0]) * 60 + int(p[1])
    if fmt_mode == "hhmmss":
        return int(p[0]) * 3600 + int(p[1]) * 60 + int(p[2])
    return int(p[0]) * 60 + int(p[1]) + int(p[2]) / 30.0  # mm:ss:ff @ 30 fps
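# Example (illustrative): "01:23" parses to 83.0 s in either mode, while
# "00:01:15" parses to 75.0 s under "hhmmss" but 1.5 s under "mmssff"
# (0 min, 1 s, 15/30 frames).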
def _detect_fmt(segments, rec_sec: float) -> str:
    """Disambiguate 3-part timestamps: "hh:mm:ss" vs "mm:ss:ff" (ff = frames
    @ 30 fps). If any segment end, read as hh:mm:ss, would overshoot the
    recording length, the file must be mm:ss:ff."""
    for s in segments:
        try:
            b = s["timestamp"].split("-")[1]
            p = b.split(":")
            if len(p) == 3:
                hh = int(p[0]) * 3600 + int(p[1]) * 60 + int(p[2])
                if hh > rec_sec * 1.05:
                    return "mmssff"
        except Exception:
            continue
    return "hhmmss"
def build_object_label(annot_path: Path, n_frames: int,
sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
"""Per-frame object index (top-15 + '_other' fallback as class 0)."""
label = np.zeros(n_frames, dtype=np.int8)
if not annot_path.exists():
return label
    try:
        with open(annot_path) as f:
            ann = json.load(f)
    except Exception:
        return label
segments = ann.get("segments", [])
if not segments:
return label
rec_sec = n_frames / sr
fmt = _detect_fmt(segments, rec_sec)
for s in segments:
obj = s.get("action_annotation", {}).get("object_name")
idx = OBJECT_TO_IDX.get(obj, 0)
if idx == 0:
continue # leave as 0 ("_other"/background)
try:
a, b = s["timestamp"].split("-")
t0 = _parse_one(a, fmt); t1 = _parse_one(b, fmt)
except Exception:
continue
if t1 <= t0 or t1 > rec_sec * 1.10:
continue
i0 = max(0, int(round(t0 * sr)))
i1 = min(n_frames, int(round(t1 * sr)))
label[i0:i1] = idx
return label
def build_lift_eligible_mask(annot_path: Path, n_frames: int,
sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
"""Per-frame bool: True if frame is inside a segment that meets the
lifted-grasp criterion: verb ∈ LIFT_VERBS OR hand_type == 'both'.
Used by 3-class label_mode when require_lift_for_sustained=True."""
mask = np.zeros(n_frames, dtype=bool)
if not annot_path.exists():
return mask
    try:
        with open(annot_path) as f:
            ann = json.load(f)
    except Exception:
        return mask
segments = ann.get("segments", [])
if not segments:
return mask
rec_sec = n_frames / sr
fmt = _detect_fmt(segments, rec_sec)
for s in segments:
a = s.get("action_annotation", {})
verb = a.get("action_name")
hand = a.get("hand_type", "")
is_lift = (verb in LIFT_VERBS) or (hand == "both")
if not is_lift:
continue
try:
ts0, ts1 = s["timestamp"].split("-")
t0 = _parse_one(ts0, fmt); t1 = _parse_one(ts1, fmt)
except Exception:
continue
if t1 <= t0 or t1 > rec_sec * 1.10:
continue
i0 = max(0, int(round(t0 * sr)))
i1 = min(n_frames, int(round(t1 * sr)))
mask[i0:i1] = True
return mask
def build_verb_label(annot_path: Path, n_frames: int,
sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
"""Per-frame verb index (int8). Default (no segment) = 0 (background)."""
label = np.zeros(n_frames, dtype=np.int8)
if not annot_path.exists():
return label
    try:
        with open(annot_path) as f:
            ann = json.load(f)
    except Exception:
        return label
segments = ann.get("segments", [])
if not segments:
return label
rec_sec = n_frames / sr
fmt = _detect_fmt(segments, rec_sec)
for s in segments:
verb = s.get("action_annotation", {}).get("action_name")
v_idx = VERB_TO_IDX.get(verb, 0) # unknown verb → background
if v_idx == 0:
continue
try:
a, b = s["timestamp"].split("-")
t0 = _parse_one(a, fmt); t1 = _parse_one(b, fmt)
except Exception:
continue
if t1 <= t0 or t1 > rec_sec * 1.10:
continue
i0 = max(0, int(round(t0 * sr)))
i1 = min(n_frames, int(round(t1 * sr)))
label[i0:i1] = v_idx
return label
def build_grasp_mask(annot_path: Path, n_frames: int,
sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
"""Return bool array of shape (n_frames,)."""
mask = np.zeros(n_frames, dtype=bool)
if not annot_path.exists():
return mask
    try:
        with open(annot_path) as f:
            ann = json.load(f)
    except Exception:
        return mask
segments = ann.get("segments", [])
if not segments:
return mask
rec_sec = n_frames / sr
fmt = _detect_fmt(segments, rec_sec)
for s in segments:
verb = s.get("action_annotation", {}).get("action_name")
if verb not in GRASP_VERBS:
continue
try:
a, b = s["timestamp"].split("-")
t0 = _parse_one(a, fmt); t1 = _parse_one(b, fmt)
except Exception:
continue
if t1 <= t0 or t1 > rec_sec * 1.10:
continue
i0 = max(0, int(round(t0 * sr)))
i1 = min(n_frames, int(round(t1 * sr)))
mask[i0:i1] = True
return mask
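# Minimal usage sketch (hypothetical volunteer/scene names; e.g. a 10-minute
# recording at 100 Hz has n_frames = 60_000):
#   mask = build_grasp_mask(DEFAULT_ANNOT_DIR / "v01" / "s1.json", n_frames=60_000)
#   print(f"grasp coverage: {mask.mean():.1%}")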
class GraspStateDataset(Dataset):
"""Predict binary 'is_grasping' label over future window from past sensor signals."""
def __init__(
self,
volunteers: Sequence[str],
input_modalities: Sequence[str],
t_obs_sec: float = 1.0,
t_fut_sec: float = 0.5,
anchor_stride_sec: float = 0.25,
downsample: int = 5,
dataset_dir: Path = DEFAULT_DATASET_DIR,
annot_dir: Path = DEFAULT_ANNOT_DIR,
contact_threshold_g: float = 5.0, # legacy sum-threshold (kept for back-compat, unused if use_per_cell_contact=True)
per_cell_threshold_g: float = 10.0, # per-cell threshold to declare a sensor cell "active"
min_active_cells: int = 3, # need ≥ this many active cells to declare contact
        use_per_cell_contact: bool = True,  # default: per-cell active count for contact detection
        label_mode: str = "binary",  # "binary", "three_class", "verb", or "object"
sustained_threshold_sec: float = 0.3, # (3-class only) min contiguous contact for "Sustained"
require_lift_for_sustained: bool = False, # (3-class only) Class 2 also requires verb ∈ LIFT_VERBS
per_class_max: Optional[int] = None,
input_stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None,
expected_input_dims: Optional[Dict[str, int]] = None,
majority_threshold: float = 0.5,
rng_seed: int = 0,
log: bool = True,
):
super().__init__()
self.input_modalities = list(input_modalities)
self.t_obs_sec = float(t_obs_sec)
self.t_fut_sec = float(t_fut_sec)
self.anchor_stride_sec = float(anchor_stride_sec)
self.downsample = int(downsample)
self.sr = SAMPLING_RATE_HZ // self.downsample
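        # NOTE: assumes `downsample` divides SAMPLING_RATE_HZ evenly; otherwise
        # the integer division here underestimates the effective rate.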
self.dataset_dir = Path(dataset_dir)
self.annot_dir = Path(annot_dir)
self.contact_threshold_g = float(contact_threshold_g)
self.per_cell_threshold_g = float(per_cell_threshold_g)
self.min_active_cells = int(min_active_cells)
self.use_per_cell_contact = bool(use_per_cell_contact)
self.label_mode = str(label_mode)
if self.label_mode not in ("binary", "three_class", "verb", "object"):
raise ValueError(f"label_mode must be binary|three_class|verb|object, got {label_mode}")
if self.label_mode == "binary":
self.num_classes = 2
elif self.label_mode == "three_class":
self.num_classes = 3
elif self.label_mode == "verb":
self.num_classes = len(VERB_LIST)
else: # object
self.num_classes = len(OBJECT_TOP_LIST)
self.sustained_threshold_sec = float(sustained_threshold_sec)
self.require_lift_for_sustained = bool(require_lift_for_sustained)
self.per_class_max = per_class_max
self.majority_threshold = float(majority_threshold)
self.T_obs = int(round(self.t_obs_sec * self.sr))
self.T_fut = int(round(self.t_fut_sec * self.sr))
self._items: List[dict] = []
self._modality_dims: Dict[str, int] = dict(expected_input_dims) if expected_input_dims else {}
rng = np.random.default_rng(rng_seed)
        # Load pressure even if not in inputs: needed for event_type
        # stratification and for three_class contact detection. The
        # dict.fromkeys idiom dedupes while preserving modality order.
load_mods = list(dict.fromkeys(list(self.input_modalities) + ["pressure"]))
# Per-class anchor pool
pools: Dict[int, List[dict]] = {c: [] for c in range(self.num_classes)}
sustained_thresh_frames = int(round(self.sustained_threshold_sec * self.sr))
for vol in volunteers:
vol_dir = self.dataset_dir / vol
if not vol_dir.is_dir():
continue
for scenario_dir in sorted(vol_dir.glob("s*")):
if not scenario_dir.is_dir():
continue
scene = scenario_dir.name
annot_path = self.annot_dir / vol / f"{scene}.json"
if not annot_path.exists():
continue
try:
sensors_all = _load_recording_sensors(
scenario_dir, vol, scene, load_mods
)
except Exception:
continue
if sensors_all is None or any(a is None for a in sensors_all.values()):
continue
pressure_full = sensors_all["pressure"] # (T, 50)
input_arrs = {m: sensors_all[m] for m in self.input_modalities}
for m, arr in input_arrs.items():
self._enforce_dim(input_arrs, m, arr, self._modality_dims)
T_avail = min(a.shape[0] for a in input_arrs.values())
T_avail = min(T_avail, pressure_full.shape[0])
if T_avail < (self.T_obs + self.T_fut) * self.downsample:
continue
# Build grasp mask at 100 Hz, then downsample.
mask_full = build_grasp_mask(annot_path, T_avail,
sr=SAMPLING_RATE_HZ)
if self.label_mode == "verb":
verb_full = build_verb_label(annot_path, T_avail, sr=SAMPLING_RATE_HZ)
verb_ds = verb_full[:T_avail:self.downsample]
else:
verb_ds = None
if self.label_mode == "object":
obj_full = build_object_label(annot_path, T_avail, sr=SAMPLING_RATE_HZ)
obj_ds = obj_full[:T_avail:self.downsample]
else:
obj_ds = None
if self.label_mode == "three_class" and self.require_lift_for_sustained:
lift_full = build_lift_eligible_mask(annot_path, T_avail, sr=SAMPLING_RATE_HZ)
lift_eligible_ds = lift_full[:T_avail:self.downsample]
else:
lift_eligible_ds = None
input_ds = {m: arr[:T_avail:self.downsample] for m, arr in input_arrs.items()}
pressure_ds = pressure_full[:T_avail:self.downsample]
mask_ds = mask_full[:T_avail:self.downsample]
T_ds = mask_ds.shape[0]
if self.use_per_cell_contact:
# n_active per frame: count cells with value > per_cell_threshold_g
n_active = (pressure_ds > self.per_cell_threshold_g).sum(axis=1)
contact_frame = n_active >= self.min_active_cells
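                    # e.g. with per_cell_threshold_g=10 and min_active_cells=3,
                    # a frame whose 50 cells read [12, 11, 15, 2, 0, ...] g has
                    # 3 active cells and therefore counts as contact.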
else:
pressure_sum = pressure_ds.sum(axis=1)
contact_frame = pressure_sum > self.contact_threshold_g
stride = max(1, int(round(self.anchor_stride_sec * self.sr)))
first_anchor = self.T_obs
last_anchor = T_ds - self.T_fut
if last_anchor <= first_anchor:
continue
for anchor in range(first_anchor, last_anchor + 1, stride):
fut_mask = mask_ds[anchor:anchor + self.T_fut]
if fut_mask.shape[0] != self.T_fut:
continue
annotation_is_grasp = fut_mask.mean() >= self.majority_threshold
if self.label_mode == "binary":
label = int(annotation_is_grasp)
elif self.label_mode == "three_class":
if not annotation_is_grasp:
label = 0 # NoGrasp
else:
                            # Longest contiguous run of contact frames in the
                            # future window decides "sustained" vs "attempted".
                            fut_contact = contact_frame[anchor:anchor + self.T_fut]
                            longest = 0
                            cur = 0
                            for v in fut_contact:
                                if v:
                                    cur += 1
                                    longest = max(longest, cur)
                                else:
                                    cur = 0
is_sustained = longest >= sustained_thresh_frames
if is_sustained and self.require_lift_for_sustained:
# Demote to Class 1 unless majority of future window is in
# a "lift-eligible" segment (verb ∈ LIFT_VERBS or hand=both).
fut_lift = lift_eligible_ds[anchor:anchor + self.T_fut]
if fut_lift.mean() < 0.5:
is_sustained = False
label = 2 if is_sustained else 1
elif self.label_mode == "verb":
fut_v = verb_ds[anchor:anchor + self.T_fut]
counts = np.bincount(fut_v, minlength=self.num_classes)
label = int(np.argmax(counts))
else: # object — majority object in future window
fut_o = obj_ds[anchor:anchor + self.T_fut]
counts = np.bincount(fut_o, minlength=self.num_classes)
label = int(np.argmax(counts))
# event_type for stratification (4-class transition taxonomy)
past_high = contact_frame[anchor - self.T_obs:anchor].mean() > 0.5
fut_high = contact_frame[anchor:anchor + self.T_fut].mean() > 0.5
                    if not past_high and not fut_high:
                        et = 0  # non-contact
                    elif not past_high and fut_high:
                        et = 1  # pre-contact (contact onset)
                    elif past_high and fut_high:
                        et = 2  # steady-grip
                    else:
                        et = 3  # release
past_slice = {m: arr[anchor - self.T_obs:anchor]
for m, arr in input_ds.items()}
if any(w.shape[0] != self.T_obs for w in past_slice.values()):
continue
item = {
"x": past_slice,
"label": label,
"event_type": et,
"meta": {"vol": vol, "scene": scene, "anchor_idx": int(anchor)},
}
pools[label].append(item)
        # Balance classes if requested: cap each pool at per_class_max.
if self.per_class_max is not None:
for c, pool in pools.items():
if len(pool) > self.per_class_max:
idx = rng.choice(len(pool), size=self.per_class_max, replace=False)
pools[c] = [pool[i] for i in sorted(idx)]
self._items = [it for c in range(self.num_classes) for it in pools[c]]
if not self._items:
raise RuntimeError("GraspStateDataset: collected 0 anchors.")
# Z-score inputs
if input_stats is None:
input_stats = self._compute_input_stats()
self._input_stats = input_stats
self._apply_input_stats(input_stats)
if log:
if self.label_mode == "binary":
class_names = CLASS_NAMES_BINARY
elif self.label_mode == "three_class":
class_names = CLASS_NAMES_THREE
elif self.label_mode == "verb":
class_names = {i: v for i, v in enumerate(VERB_LIST)}
else: # object
class_names = {i: v for i, v in enumerate(OBJECT_TOP_LIST)}
counts_class = {class_names[c]: sum(1 for it in self._items if it["label"] == c)
for c in range(self.num_classes)}
counts_event = {EVENT_NAMES[k]: sum(1 for it in self._items if it["event_type"] == k)
for k in (0, 1, 2, 3)}
print(f"[GraspStateDataset] vols={len(volunteers)} "
f"inputs={self.input_modalities} "
f"anchors={len(self._items)} class={counts_class} "
f"event={counts_event} "
f"T_obs={self.T_obs} T_fut={self.T_fut} sr={self.sr}Hz "
f"input_dims={self._modality_dims}", flush=True)
@staticmethod
    def _enforce_dim(arrs, m, arr, dim_dict):
        """Pad or truncate arr's feature dim to the first-seen dim for modality m."""
if m in dim_dict:
tgt = dim_dict[m]
if arr.shape[1] != tgt:
if arr.shape[1] < tgt:
pad = np.zeros((arr.shape[0], tgt - arr.shape[1]), dtype=np.float32)
arrs[m] = np.concatenate([arr, pad], axis=1)
else:
arrs[m] = arr[:, :tgt]
else:
dim_dict[m] = arr.shape[1]
    def _compute_input_stats(self):
        """Channel-wise (mean, std) per modality over all collected windows;
        near-zero stds are clamped to 1.0 to avoid division blow-ups."""
accs = {m: [] for m in self._modality_dims}
for it in self._items:
for m, w in it["x"].items():
accs[m].append(w)
out = {}
for m, ws in accs.items():
cat = np.concatenate(ws, axis=0)
mu = cat.mean(axis=0).astype(np.float32)
sd = cat.std(axis=0); sd = np.where(sd < 1e-6, 1.0, sd)
out[m] = (mu, sd.astype(np.float32))
return out
def _apply_input_stats(self, stats):
for it in self._items:
for m, w in it["x"].items():
if m in stats:
mu, sd = stats[m]
it["x"][m] = ((w - mu) / sd).astype(np.float32)
def __len__(self): return len(self._items)
def __getitem__(self, idx):
it = self._items[idx]
x = {m: torch.from_numpy(np.ascontiguousarray(w)) for m, w in it["x"].items()}
label = int(it["label"])
et = int(it["event_type"])
return x, label, et, it["meta"]
@property
def modality_dims(self): return dict(self._modality_dims)
def collate_grasp_state(batch):
    """Collate (x_dict, label, event_type, meta) samples into batched tensors."""
xs, labels, ets, metas = zip(*batch)
mods = list(xs[0].keys())
x_out = {m: torch.stack([x[m] for x in xs], dim=0) for m in mods}
y_out = torch.tensor(labels, dtype=torch.long)
et_out = torch.tensor(ets, dtype=torch.long)
return x_out, y_out, et_out, list(metas)
def build_grasp_train_test(
input_modalities,
t_obs_sec=1.0, t_fut_sec=0.5, anchor_stride_sec=0.25,
downsample=5,
dataset_dir=DEFAULT_DATASET_DIR, annot_dir=DEFAULT_ANNOT_DIR,
contact_threshold_g=5.0, per_class_max=None,
label_mode="binary", sustained_threshold_sec=0.3,
require_lift_for_sustained=False,
rng_seed=0,
train_vols=None, test_vols=None,
):
if train_vols is None: train_vols = TRAIN_VOLS_V3
if test_vols is None: test_vols = TEST_VOLS_V3
train = GraspStateDataset(
train_vols, input_modalities=input_modalities,
t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec,
anchor_stride_sec=anchor_stride_sec, downsample=downsample,
dataset_dir=dataset_dir, annot_dir=annot_dir,
contact_threshold_g=contact_threshold_g, per_class_max=per_class_max,
label_mode=label_mode, sustained_threshold_sec=sustained_threshold_sec,
require_lift_for_sustained=require_lift_for_sustained,
rng_seed=rng_seed, log=True,
)
test = GraspStateDataset(
test_vols, input_modalities=input_modalities,
t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec,
anchor_stride_sec=anchor_stride_sec, downsample=downsample,
dataset_dir=dataset_dir, annot_dir=annot_dir,
contact_threshold_g=contact_threshold_g, per_class_max=None, # don't cap test
label_mode=label_mode, sustained_threshold_sec=sustained_threshold_sec,
require_lift_for_sustained=require_lift_for_sustained,
input_stats=train._input_stats,
expected_input_dims=train._modality_dims,
rng_seed=rng_seed + 1, log=True,
)
return train, test
if __name__ == "__main__":
import argparse
ap = argparse.ArgumentParser()
ap.add_argument("--input_modalities", default="emg,imu,mocap")
ap.add_argument("--t_obs", type=float, default=1.0)
ap.add_argument("--t_fut", type=float, default=0.5)
args = ap.parse_args()
tr, te = build_grasp_train_test(
input_modalities=args.input_modalities.split(","),
t_obs_sec=args.t_obs, t_fut_sec=args.t_fut,
)
x, y, et, meta = tr[0]
print(f"sample: x={ {m: tuple(v.shape) for m,v in x.items()} } y={y} et={et}")