|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import pytest |
|
|
import torch |
|
|
|
|
|
from lerobot.configs.types import ( |
|
|
FeatureType, |
|
|
PipelineFeatureType, |
|
|
PolicyFeature, |
|
|
) |
|
|
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep |
|
|
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR |
|
|
|
|
|
|
|
|
# Leading (batch) dimension used for every tensor built in this module.
BATCH_SIZE = 2
# Width of the flat state vector used by the shared fixtures.
STATE_DIM = 16
# Test frames use the same value for height and width.
IMG_HEIGHT = IMG_WIDTH = 64

# Observation keys the shared ``processor`` fixture is configured with.
TEST_STATE_KEY = "test_state_obs"
TEST_CAMERA_KEY = "test_rgb_cam"
|
|
|
|
|
|
|
|
@pytest.fixture
def processor():
    """Build an ``IsaaclabArenaProcessorStep`` configured with the module's test keys."""
    step_kwargs = {
        "state_keys": (TEST_STATE_KEY,),
        "camera_keys": (TEST_CAMERA_KEY,),
    }
    return IsaaclabArenaProcessorStep(**step_kwargs)
|
|
|
|
|
|
|
|
@pytest.fixture
def sample_observation():
    """Return a raw observation with one state entry and one camera entry.

    Layout of the returned dict:
        * ``f"{OBS_STR}.policy"`` maps ``TEST_STATE_KEY`` to a
          ``torch.randn`` tensor of shape ``(BATCH_SIZE, STATE_DIM)``.
        * ``f"{OBS_STR}.camera_obs"`` maps ``TEST_CAMERA_KEY`` to a ``uint8``
          tensor of shape ``(BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3)``.
    """
    state = torch.randn(BATCH_SIZE, STATE_DIM)
    # ``high`` is exclusive in ``torch.randint``; 256 lets pixels reach 255.
    frames = torch.randint(0, 256, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8)
    return {
        f"{OBS_STR}.policy": {TEST_STATE_KEY: state},
        f"{OBS_STR}.camera_obs": {TEST_CAMERA_KEY: frames},
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_state_extraction(processor, sample_observation):
    """Check the state lands under ``OBS_STATE`` with the expected shape and float32 dtype."""
    result = processor.observation(sample_observation)

    assert OBS_STATE in result
    state = result[OBS_STATE]
    assert state.shape == (BATCH_SIZE, STATE_DIM)
    assert state.dtype == torch.float32
|
|
|
|
|
|
|
|
def test_state_concatenation_multiple_keys():
    """Check that several state keys are joined in ``state_keys`` order."""
    alpha_width, beta_width = 10, 6
    step = IsaaclabArenaProcessorStep(
        state_keys=("state_alpha", "state_beta"),
        camera_keys=(),
    )

    # Distinct constant fills (1s, then 2s) make the column order observable.
    policy_obs = {
        "state_alpha": torch.ones(BATCH_SIZE, alpha_width),
        "state_beta": torch.ones(BATCH_SIZE, beta_width) * 2,
    }
    result = step.observation({f"{OBS_STR}.policy": policy_obs})

    merged = result[OBS_STATE]
    assert merged.shape == (BATCH_SIZE, alpha_width + beta_width)

    head = merged[:, :alpha_width]
    tail = merged[:, alpha_width:]
    assert torch.all(head == 1.0)
    assert torch.all(tail == 2.0)
|
|
|
|
|
|
|
|
def test_state_flattening_higher_dims():
    """Check that a ``(B, 4, 4)`` state entry comes back with shape ``(B, 16)``."""
    step = IsaaclabArenaProcessorStep(state_keys=("multidim_state",), camera_keys=())
    raw_obs = {
        f"{OBS_STR}.policy": {"multidim_state": torch.randn(BATCH_SIZE, 4, 4)},
    }

    flattened = step.observation(raw_obs)[OBS_STATE]

    assert flattened.shape == (BATCH_SIZE, 16)
|
|
|
|
|
|
|
|
def test_state_filters_to_configured_keys():
    """Check that policy entries outside ``state_keys`` do not widen the state."""
    step = IsaaclabArenaProcessorStep(state_keys=("included_key",), camera_keys=())
    policy_obs = {
        "included_key": torch.randn(BATCH_SIZE, 10),
        "excluded_key": torch.randn(BATCH_SIZE, 6),
    }

    result = step.observation({f"{OBS_STR}.policy": policy_obs})

    # Width 10 (rather than 16) means only "included_key" contributed.
    assert result[OBS_STATE].shape == (BATCH_SIZE, 10)
|
|
|
|
|
|
|
|
def test_missing_state_key_skipped():
    """Check that a configured state key absent from the observation is ignored."""
    step = IsaaclabArenaProcessorStep(
        state_keys=("present_key", "missing_key"),
        camera_keys=(),
    )
    # "missing_key" is deliberately left out of the policy observation.
    policy_obs = {"present_key": torch.randn(BATCH_SIZE, 10)}

    result = step.observation({f"{OBS_STR}.policy": policy_obs})

    # The state is built from "present_key" alone.
    assert result[OBS_STATE].shape == (BATCH_SIZE, 10)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_camera_permutation_bhwc_to_bchw(processor, sample_observation):
    """Check images are permuted from ``(B, H, W, C)`` to ``(B, C, H, W)``.

    Two inputs are checked: the shared ``sample_observation`` fixture and a
    frame batch whose height and width differ, so that swapping the two
    spatial axes would change the resulting shape.
    """
    img_key = f"{OBS_IMAGES}.{TEST_CAMERA_KEY}"

    processed = processor.observation(sample_observation)
    assert img_key in processed
    assert processed[img_key].shape == (BATCH_SIZE, 3, IMG_HEIGHT, IMG_WIDTH)

    # With IMG_HEIGHT == IMG_WIDTH the assertion above cannot tell
    # (B, C, H, W) from (B, C, W, H); a non-square frame makes them differ.
    height, width = IMG_HEIGHT, IMG_WIDTH + 8
    non_square_obs = {
        f"{OBS_STR}.camera_obs": {
            TEST_CAMERA_KEY: torch.randint(0, 256, (BATCH_SIZE, height, width, 3), dtype=torch.uint8),
        },
    }
    non_square = processor.observation(non_square_obs)[img_key]
    assert non_square.shape == (BATCH_SIZE, 3, height, width)
|
|
|
|
|
|
|
|
def test_camera_uint8_to_normalized_float32(processor):
    """Check that uint8 images are normalized to float32 in ``[0, 1]``.

    Both ends of the uint8 range are exercised: a frame filled with 255 must
    come back as all ones and a frame filled with 0 as all zeros.
    """
    img_key = f"{OBS_IMAGES}.{TEST_CAMERA_KEY}"
    frame_shape = (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3)

    # Checking only 255 -> 1.0 would leave the lower bound of the range untested.
    for pixel_value, expected_value in ((255, 1.0), (0, 0.0)):
        obs = {
            f"{OBS_STR}.camera_obs": {
                TEST_CAMERA_KEY: torch.full(frame_shape, pixel_value, dtype=torch.uint8),
            },
        }

        img = processor.observation(obs)[img_key]

        assert img.dtype == torch.float32
        assert torch.allclose(img, torch.full_like(img, expected_value))
|
|
|
|
|
|
|
|
def test_camera_float32_passthrough(processor):
    """Check float32 frames stay float32 and match the permuted input values."""
    source_frames = torch.rand(BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3, dtype=torch.float32)
    # A clone is passed in so ``source_frames`` remains an independent reference.
    raw_obs = {
        f"{OBS_STR}.camera_obs": {TEST_CAMERA_KEY: source_frames.clone()},
    }

    out = processor.observation(raw_obs)[f"{OBS_IMAGES}.{TEST_CAMERA_KEY}"]

    assert out.dtype == torch.float32
    channels_first = source_frames.permute(0, 3, 1, 2)
    assert torch.allclose(out, channels_first)
|
|
|
|
|
|
|
|
def test_camera_other_dtype_converted_to_float(processor):
    """Check that frames that are neither uint8 nor float32 (here int32) become float32."""
    int_frames = torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.int32)
    raw_obs = {f"{OBS_STR}.camera_obs": {TEST_CAMERA_KEY: int_frames}}

    out = processor.observation(raw_obs)[f"{OBS_IMAGES}.{TEST_CAMERA_KEY}"]

    assert out.dtype == torch.float32
|
|
|
|
|
|
|
|
def test_camera_filters_to_configured_keys():
    """Check that only cameras listed in ``camera_keys`` appear in the output."""
    step = IsaaclabArenaProcessorStep(state_keys=(), camera_keys=("included_cam",))
    frame_shape = (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3)
    camera_obs = {
        cam_name: torch.randint(0, 255, frame_shape, dtype=torch.uint8)
        for cam_name in ("included_cam", "excluded_cam")
    }

    result = step.observation({f"{OBS_STR}.camera_obs": camera_obs})

    assert f"{OBS_IMAGES}.included_cam" in result
    assert f"{OBS_IMAGES}.excluded_cam" not in result
|
|
|
|
|
|
|
|
def test_camera_key_preserved_exactly():
    """Check the camera key is used verbatim in the output key (no suffix stripping)."""
    step = IsaaclabArenaProcessorStep(state_keys=(), camera_keys=("my_cam_rgb",))
    frames = torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8)

    result = step.observation({f"{OBS_STR}.camera_obs": {"my_cam_rgb": frames}})

    # The full name must be present; a variant without the "_rgb" suffix must not.
    assert f"{OBS_IMAGES}.my_cam_rgb" in result
    assert f"{OBS_IMAGES}.my_cam" not in result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_missing_camera_obs_section(processor):
    """Check an observation with only a policy section yields state but no image keys."""
    raw_obs = {
        f"{OBS_STR}.policy": {TEST_STATE_KEY: torch.randn(BATCH_SIZE, STATE_DIM)},
    }

    result = processor.observation(raw_obs)

    assert OBS_STATE in result
    image_keys = [key for key in result if key.startswith(OBS_IMAGES)]
    assert not image_keys
|
|
|
|
|
|
|
|
def test_missing_policy_obs_section(processor):
    """Check an observation with only a camera section yields the image but no state."""
    frames = torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8)
    raw_obs = {f"{OBS_STR}.camera_obs": {TEST_CAMERA_KEY: frames}}

    result = processor.observation(raw_obs)

    expected_image_key = f"{OBS_IMAGES}.{TEST_CAMERA_KEY}"
    assert expected_image_key in result
    assert OBS_STATE not in result
|
|
|
|
|
|
|
|
def test_empty_observation(processor):
    """Check an empty observation dict produces neither a state nor any image key."""
    result = processor.observation({})

    assert OBS_STATE not in result
    image_keys = [key for key in result if key.startswith(OBS_IMAGES)]
    assert not image_keys
|
|
|
|
|
|
|
|
def test_no_matching_state_keys():
    """Check that no state is emitted when no configured state key is in the observation."""
    step = IsaaclabArenaProcessorStep(state_keys=("nonexistent_key",), camera_keys=())
    policy_obs = {"some_other_key": torch.randn(BATCH_SIZE, STATE_DIM)}

    result = step.observation({f"{OBS_STR}.policy": policy_obs})

    # Nothing matched, so the state entry must be absent altogether.
    assert OBS_STATE not in result
|
|
|
|
|
|
|
|
def test_no_matching_camera_keys():
    """Check that no image key is emitted when no configured camera key is in the observation."""
    step = IsaaclabArenaProcessorStep(state_keys=(), camera_keys=("nonexistent_cam",))
    frames = torch.randint(0, 255, (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=torch.uint8)

    result = step.observation({f"{OBS_STR}.camera_obs": {"some_other_cam": frames}})

    image_keys = [key for key in result if key.startswith(OBS_IMAGES)]
    assert not image_keys
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_default_keys():
    """Check the ``state_keys`` and ``camera_keys`` of a step built without arguments."""
    step = IsaaclabArenaProcessorStep()

    expected_state_keys = ("robot_joint_pos",)
    expected_camera_keys = ("robot_pov_cam_rgb",)

    assert step.state_keys == expected_state_keys
    assert step.camera_keys == expected_camera_keys
|
|
|
|
|
|
|
|
def test_custom_keys_configuration():
    """Check a step configured with three state keys and two camera keys end to end."""
    # Insertion order doubles as the ``state_keys`` order handed to the step.
    state_widths = {"pos_xyz": 3, "quat_wxyz": 4, "grip_val": 1}
    camera_names = ("front_view", "wrist_view")
    step = IsaaclabArenaProcessorStep(
        state_keys=tuple(state_widths),
        camera_keys=camera_names,
    )

    frame_shape = (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3)
    raw_obs = {
        f"{OBS_STR}.policy": {
            name: torch.randn(BATCH_SIZE, width) for name, width in state_widths.items()
        },
        f"{OBS_STR}.camera_obs": {
            name: torch.randint(0, 255, frame_shape, dtype=torch.uint8) for name in camera_names
        },
    }

    result = step.observation(raw_obs)

    # 3 + 4 + 1 state columns.
    assert result[OBS_STATE].shape == (BATCH_SIZE, 8)

    for name in camera_names:
        assert f"{OBS_IMAGES}.{name}" in result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_transform_features_passthrough(processor):
    """Check ``transform_features`` returns a mapping equal to the one it was given."""
    observation_features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(16,)),
        "observation.images.cam": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
    }
    action_features = {
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    }
    features_in = {
        PipelineFeatureType.OBSERVATION: observation_features,
        PipelineFeatureType.ACTION: action_features,
    }

    features_out = processor.transform_features(features_in)

    # Equality (not identity) is what is checked here.
    assert features_out == features_in
|
|
|