from __future__ import annotations import pickle from pathlib import Path from types import SimpleNamespace import numpy as np from PIL import Image import torch from torch.utils.data import DataLoader from sim_rlbench.dataset import RLBenchOfflineChunkDataset, sample_weight_from_action_norms def _arm(pose_xyz: tuple[float, float, float], gripper_open: float) -> SimpleNamespace: return SimpleNamespace( gripper_pose=np.asarray([*pose_xyz, 0.0, 0.0, 0.0, 1.0], dtype=np.float32), joint_positions=np.linspace(0.0, 0.6, 7, dtype=np.float32), gripper_open=float(gripper_open), ) def _obs(step: int) -> SimpleNamespace: offset = 0.01 * float(step) return SimpleNamespace( right=_arm((0.10 + offset, 0.20, 0.30), 1.0), left=_arm((-0.10 - offset, 0.15, 0.25), 0.0), ) def _write_rgb_frame(directory: Path, step: int, value: int) -> None: directory.mkdir(parents=True, exist_ok=True) image = np.full((8, 8, 3), fill_value=value, dtype=np.uint8) Image.fromarray(image).save(directory / f"rgb_{step:04d}.png") def test_rlbench_dataset_emits_task_metadata(tmp_path: Path) -> None: task_name = "bimanual_push_box" episode_dir = tmp_path / task_name / "all_variations" / "episodes" / "episode0" episode_dir.mkdir(parents=True, exist_ok=True) with (episode_dir / "variation_descriptions.pkl").open("wb") as handle: pickle.dump(["push the box together"], handle) with (episode_dir / "low_dim_obs.pkl").open("wb") as handle: pickle.dump([_obs(0), _obs(1)], handle) for camera_name, pixel_value in (("front", 32), ("wrist_left", 96), ("wrist_right", 160)): camera_dir = episode_dir / f"{camera_name}_rgb" _write_rgb_frame(camera_dir, 0, pixel_value) _write_rgb_frame(camera_dir, 1, pixel_value + 1) dataset = RLBenchOfflineChunkDataset( dataset_root=tmp_path, tasks=[task_name], episode_indices=[0], resolution=8, chunk_size=2, history_steps=1, ) item = dataset[0] assert item["task"] == task_name assert item["task_name"] == task_name assert int(item["task_id"]) == -1 assert item["texts"] == "push the box together" batch = next(iter(DataLoader(dataset, batch_size=1, shuffle=False))) assert batch["task_name"] == [task_name] assert batch["task_id"].tolist() == [-1] def test_sample_weight_from_action_norms_biases_large_actions() -> None: weights = sample_weight_from_action_norms( torch.tensor([0.01, 0.05, 0.20], dtype=torch.float32), min_norm=0.05, power=2.0, ) assert np.allclose(weights.numpy(), np.asarray([1.0, 1.0, 16.0], dtype=np.float32))