File size: 2,717 Bytes
9c74dfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from __future__ import annotations

import pickle
from pathlib import Path
from types import SimpleNamespace

import numpy as np
from PIL import Image
import torch
from torch.utils.data import DataLoader

from sim_rlbench.dataset import RLBenchOfflineChunkDataset, sample_weight_from_action_norms


def _arm(pose_xyz: tuple[float, float, float], gripper_open: float) -> SimpleNamespace:
    return SimpleNamespace(
        gripper_pose=np.asarray([*pose_xyz, 0.0, 0.0, 0.0, 1.0], dtype=np.float32),
        joint_positions=np.linspace(0.0, 0.6, 7, dtype=np.float32),
        gripper_open=float(gripper_open),
    )


def _obs(step: int) -> SimpleNamespace:
    """Build a fake bimanual observation for timestep ``step``.

    The right arm's x coordinate grows by 0.01 per step and the left arm's
    shrinks by the same amount, so consecutive steps yield distinct poses.
    The right gripper is built with ``gripper_open=1.0``; the left with 0.0.
    """
    shift = float(step) * 0.01
    right_arm = _arm((0.10 + shift, 0.20, 0.30), 1.0)
    left_arm = _arm((-0.10 - shift, 0.15, 0.25), 0.0)
    return SimpleNamespace(right=right_arm, left=left_arm)


def _write_rgb_frame(directory: Path, step: int, value: int) -> None:
    """Save an 8x8 solid-colour RGB PNG as ``rgb_<step:04d>.png`` in ``directory``.

    ``directory`` (and any missing parents) is created first. Every channel
    of every pixel is set to ``value``, stored as uint8.
    """
    directory.mkdir(parents=True, exist_ok=True)
    pixels = np.full((8, 8, 3), value, dtype=np.uint8)
    frame_path = directory / f"rgb_{step:04d}.png"
    Image.fromarray(pixels).save(frame_path)


def test_rlbench_dataset_emits_task_metadata(tmp_path: Path) -> None:
    """Items and collated batches carry task metadata and the language text.

    Builds a one-episode, two-step fixture on disk containing pickled
    variation descriptions, pickled low-dim observations, and per-camera RGB
    frames. It then checks single items and a ``DataLoader`` batch for the
    task name, a ``task_id`` of -1, and the description string.
    """
    task_name = "bimanual_push_box"
    episode_dir = tmp_path.joinpath(task_name, "all_variations", "episodes", "episode0")
    episode_dir.mkdir(parents=True, exist_ok=True)

    pickled_payloads = {
        "variation_descriptions.pkl": ["push the box together"],
        "low_dim_obs.pkl": [_obs(0), _obs(1)],
    }
    for filename, payload in pickled_payloads.items():
        with open(episode_dir / filename, "wb") as handle:
            pickle.dump(payload, handle)

    # Each camera gets a distinct base intensity; step 1 is base + 1 so the
    # two frames differ.
    camera_base_values = {"front": 32, "wrist_left": 96, "wrist_right": 160}
    for camera_name, base_value in camera_base_values.items():
        frames_dir = episode_dir / f"{camera_name}_rgb"
        for step in range(2):
            _write_rgb_frame(frames_dir, step, base_value + step)

    dataset = RLBenchOfflineChunkDataset(
        dataset_root=tmp_path,
        tasks=[task_name],
        episode_indices=[0],
        resolution=8,
        chunk_size=2,
        history_steps=1,
    )

    sample = dataset[0]
    assert sample["task"] == task_name
    assert sample["task_name"] == task_name
    assert int(sample["task_id"]) == -1
    assert sample["texts"] == "push the box together"

    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    first_batch = next(iter(loader))
    # Default collation turns the string field into a list and the id into a tensor.
    assert first_batch["task_name"] == [task_name]
    assert first_batch["task_id"].tolist() == [-1]


def test_sample_weight_from_action_norms_biases_large_actions() -> None:
    """Norms at or below ``min_norm`` get weight 1; larger norms get more.

    The expected values here (0.01 -> 1, 0.05 -> 1, 0.20 -> 16) are
    consistent with ``max(norm / min_norm, 1) ** power``. That formula is
    inferred from this data, not verified against the implementation.
    """
    action_norms = torch.tensor([0.01, 0.05, 0.20], dtype=torch.float32)
    expected_weights = np.asarray([1.0, 1.0, 16.0], dtype=np.float32)

    weights = sample_weight_from_action_norms(action_norms, min_norm=0.05, power=2.0)

    assert np.allclose(weights.numpy(), expected_weights)