| import torch |
|
|
| from pytorch3d.transforms import euler_angles_to_matrix |
| from train.trainer import build_policy |
|
|
|
|
def _rotated_extrinsics(extrinsics: torch.Tensor, angle_radians: float) -> torch.Tensor:
    """Return a copy of ``extrinsics`` with a new rotation at index 1 of dim 1.

    A rotation matrix is built from the Euler angles ``(0, angle_radians, 0)``
    using the ``"XYZ"`` convention and written into the leading 3x3 block of
    the entry at index 1 along the second dimension, for every element of the
    first dimension. The input tensor itself is not modified.

    Args:
        extrinsics: Tensor indexed as ``[:, 1, :3, :3]``; presumably shaped
            (batch, num_cameras, 4, 4) with at least two cameras -- TODO confirm.
        angle_radians: Value of the middle Euler angle, in radians.

    Returns:
        A cloned tensor with the same dtype as ``extrinsics`` and the rotation
        block at index 1 overwritten.
    """
    # Single set of Euler angles; only the middle component is non-zero.
    euler_angles = torch.tensor(
        [[0.0, angle_radians, 0.0]],
        dtype=extrinsics.dtype,
    )
    # The leading dimension of size 1 is dropped to obtain one 3x3 matrix.
    rotation_matrix = euler_angles_to_matrix(euler_angles, "XYZ")[0]

    result = extrinsics.clone()
    # Overwrites (does not compose with) whatever rotation was stored there.
    result[:, 1, :3, :3] = rotation_matrix
    return result
|
|
|
|
def test_geometry_tokens_propagate(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Check that camera extrinsics influence geometry and scene tokens.

    Three properties are asserted, in this order:

    1. ``policy.backbone.encode_images`` yields different ``"geometry_tokens"``
       when the camera extrinsics are replaced by a rotated copy.
    2. With geometry tokens and camera-pose tokens enabled, the
       ``"scene_tokens"`` from ``_encode_scene_with_optional_depth`` differ
       between the original and rotated extrinsics.
    3. With both of those flags disabled, the ``"scene_tokens"`` match for the
       two sets of extrinsics (``atol=1e-6``, ``rtol=1e-5``).
    """
    policy_config = tiny_policy_config()
    sample = tiny_batch(chunk_size=policy_config.decoder.chunk_size)
    trainer_config = tiny_trainer_config(policy_type="elastic_reveal")
    model = build_policy(policy_config, trainer_config)
    model.eval()

    perturbed_pose = _rotated_extrinsics(sample["camera_extrinsics"], angle_radians=0.7)

    def encode_backbone(pose):
        # Everything except the extrinsics is taken unchanged from the batch.
        return model.backbone.encode_images(
            sample["images"],
            depths=sample["depths"],
            depth_valid=sample["depth_valid"],
            camera_intrinsics=sample["camera_intrinsics"],
            camera_extrinsics=pose,
            return_aux=True,
        )

    def encode_scene(pose, geometry_enabled):
        # A single switch drives both the geometry-token and pose-token flags,
        # which are always toggled together in this test.
        outputs = model._encode_scene_with_optional_depth(
            images=sample["images"],
            proprio=sample["proprio"],
            texts=sample["texts"],
            depths=sample["depths"],
            depth_valid=sample["depth_valid"],
            camera_intrinsics=sample["camera_intrinsics"],
            camera_extrinsics=pose,
            use_depth=True,
            use_geometry_tokens=geometry_enabled,
            use_camera_pose_tokens=geometry_enabled,
        )
        return outputs["scene_tokens"]

    # 1. Backbone geometry tokens must react to the changed extrinsics.
    backbone_reference = encode_backbone(sample["camera_extrinsics"])
    backbone_perturbed = encode_backbone(perturbed_pose)
    assert not torch.allclose(
        backbone_reference["geometry_tokens"],
        backbone_perturbed["geometry_tokens"],
    )

    # 2. With geometry/pose tokens on, the change must reach the scene tokens.
    enabled_reference = encode_scene(sample["camera_extrinsics"], True)
    enabled_perturbed = encode_scene(perturbed_pose, True)
    assert not torch.allclose(enabled_reference, enabled_perturbed)

    # 3. With geometry/pose tokens off, the extrinsics must have no effect.
    disabled_reference = encode_scene(sample["camera_extrinsics"], False)
    disabled_perturbed = encode_scene(perturbed_pose, False)
    assert torch.allclose(disabled_reference, disabled_perturbed, atol=1e-6, rtol=1e-5)
|
|