| import torch |
|
|
| from pytorch3d.transforms import euler_angles_to_matrix |
| from train.trainer import build_policy |
|
|
|
|
def _scene_tokens(policy, batch, camera_extrinsics, *, geometry_enabled):
    """Run ``policy`` on ``batch`` with overridden extrinsics; return ``scene_tokens``.

    Every input except ``camera_extrinsics`` is taken from ``batch`` unchanged,
    so two calls differ only in the extrinsics and the geometry switch.
    ``geometry_enabled`` drives both ``use_geometry_tokens`` and
    ``use_camera_pose_tokens`` together.
    """
    # Pure inference: eval() does not disable autograd, so avoid building
    # graphs that the comparison never uses.
    with torch.no_grad():
        output = policy(
            images=batch["images"],
            depths=batch["depths"],
            depth_valid=batch["depth_valid"],
            camera_intrinsics=batch["camera_intrinsics"],
            camera_extrinsics=camera_extrinsics,
            proprio=batch["proprio"],
            texts=batch["texts"],
            history_images=batch["history_images"],
            history_depths=batch["history_depths"],
            history_depth_valid=batch["history_depth_valid"],
            history_proprio=batch["history_proprio"],
            history_actions=batch["history_actions"],
            plan=False,
            use_geometry_tokens=geometry_enabled,
            use_camera_pose_tokens=geometry_enabled,
        )
    return output["scene_tokens"]


def test_camera_rotation_geometry_changes_policy_representation(
    tiny_policy_config,
    tiny_trainer_config,
    tiny_batch,
):
    """Camera rotation must affect scene tokens only when geometry is enabled.

    Two extrinsics variants share the same translation for index 0 along
    dim 1 (presumably camera 0). Only variant B gets its rotation block
    replaced, with a 0.9 rad rotation about Z.

    With geometry and camera-pose tokens enabled, the resulting
    ``scene_tokens`` must differ. With both disabled, they must match within
    tolerance.
    """
    config = tiny_policy_config()
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)

    extrinsics_a = batch["camera_extrinsics"].clone()
    extrinsics_b = batch["camera_extrinsics"].clone()
    # Assumes extrinsics are (batch, camera, 4, 4)-style homogeneous transforms
    # with rotation in [:3, :3] and translation in [:3, 3] -- TODO confirm.
    # Pinning the same translation on both variants isolates the rotation change.
    shared_translation = torch.tensor(
        [0.12, -0.05, 0.08],
        dtype=extrinsics_a.dtype,
        device=extrinsics_a.device,
    )
    extrinsics_a[:, 0, :3, 3] = shared_translation
    extrinsics_b[:, 0, :3, 3] = shared_translation
    # Variant A keeps the fixture's rotation; variant B is overridden with a
    # pure Z rotation (the X and Y angles are zero).
    extrinsics_b[:, 0, :3, :3] = euler_angles_to_matrix(
        torch.tensor(
            [[0.0, 0.0, 0.9]],
            dtype=extrinsics_b.dtype,
            device=extrinsics_b.device,
        ),
        "XYZ",
    )[0]

    policy = build_policy(config, tiny_trainer_config(policy_type="elastic_reveal"))
    policy.eval()

    # Same call order as the four explicit forward passes this replaces.
    tokens_a = _scene_tokens(policy, batch, extrinsics_a, geometry_enabled=True)
    tokens_b = _scene_tokens(policy, batch, extrinsics_b, geometry_enabled=True)
    tokens_a_disabled = _scene_tokens(
        policy, batch, extrinsics_a, geometry_enabled=False
    )
    tokens_b_disabled = _scene_tokens(
        policy, batch, extrinsics_b, geometry_enabled=False
    )

    assert not torch.allclose(tokens_a, tokens_b), (
        "rotating the camera should change scene tokens when geometry and "
        "camera-pose tokens are enabled"
    )
    assert torch.allclose(
        tokens_a_disabled,
        tokens_b_disabled,
        atol=1e-6,
        rtol=1e-5,
    ), (
        "scene tokens should not depend on camera rotation when geometry and "
        "camera-pose tokens are disabled"
    )
|
|