import torch
from pytorch3d.transforms import euler_angles_to_matrix

from train.trainer import build_policy


def test_camera_rotation_geometry_changes_policy_representation(
    tiny_policy_config,
    tiny_trainer_config,
    tiny_batch,
):
    """Check that camera rotation affects scene tokens only via geometry inputs.

    Two extrinsics tensors are built from the batch's ``camera_extrinsics``.
    Both share the same translation at index 0 of the second dimension; only
    the second one has its 3x3 rotation block overwritten (Euler angles
    ``(0, 0, 0.9)`` under the ``"XYZ"`` convention).

    The policy is then run four times, in this order:

    1. unrotated extrinsics, geometry + camera-pose tokens enabled
    2. rotated extrinsics, geometry + camera-pose tokens enabled
    3. unrotated extrinsics, both token kinds disabled
    4. rotated extrinsics, both token kinds disabled

    Expectations:
        * runs 1 and 2 produce different ``scene_tokens``;
        * runs 3 and 4 produce matching ``scene_tokens`` (within tolerance).

    Args:
        tiny_policy_config: Fixture; called with no arguments to get the
            policy config.
        tiny_trainer_config: Fixture; called with ``policy_type`` to get the
            trainer config.
        tiny_batch: Fixture; called with ``chunk_size`` to get the input batch.
    """
    policy_config = tiny_policy_config()
    sample = tiny_batch(chunk_size=policy_config.decoder.chunk_size)

    # Unrotated reference pose: original extrinsics with a fixed translation
    # written into index 0 of the second dimension.
    # NOTE(review): index 0 is presumably the first camera -- confirm.
    base_extrinsics = sample["camera_extrinsics"].clone()
    base_extrinsics[:, 0, :3, 3] = torch.tensor(
        [0.12, -0.05, 0.08],
        dtype=base_extrinsics.dtype,
    )

    # Rotated pose: identical translation, different rotation block.
    rotated_extrinsics = base_extrinsics.clone()
    euler_angles = torch.tensor(
        [[0.0, 0.0, 0.9]],
        dtype=rotated_extrinsics.dtype,
    )
    rotation = euler_angles_to_matrix(euler_angles, "XYZ")[0]
    rotated_extrinsics[:, 0, :3, :3] = rotation

    model = build_policy(
        policy_config,
        tiny_trainer_config(policy_type="elastic_reveal"),
    )
    model.eval()

    # Every forward pass shares these inputs; only the extrinsics and the two
    # geometry switches vary between runs.
    shared_keys = (
        "images",
        "depths",
        "depth_valid",
        "camera_intrinsics",
        "proprio",
        "texts",
        "history_images",
        "history_depths",
        "history_depth_valid",
        "history_proprio",
        "history_actions",
    )
    shared_inputs = {key: sample[key] for key in shared_keys}

    def scene_tokens_for(extrinsics, geometry_enabled):
        """Run the policy once and return its ``scene_tokens`` output."""
        outputs = model(
            camera_extrinsics=extrinsics,
            plan=False,
            use_geometry_tokens=geometry_enabled,
            use_camera_pose_tokens=geometry_enabled,
            **shared_inputs,
        )
        return outputs["scene_tokens"]

    # Keep the original invocation order: enabled pair first, then disabled.
    tokens_enabled = scene_tokens_for(base_extrinsics, True)
    tokens_enabled_rotated = scene_tokens_for(rotated_extrinsics, True)
    tokens_disabled = scene_tokens_for(base_extrinsics, False)
    tokens_disabled_rotated = scene_tokens_for(rotated_extrinsics, False)

    # With geometry enabled, the rotation must change the representation.
    assert not torch.allclose(tokens_enabled, tokens_enabled_rotated)

    # With geometry disabled, the extrinsics must not influence the result.
    assert torch.allclose(
        tokens_disabled,
        tokens_disabled_rotated,
        atol=1e-6,
        rtol=1e-5,
    )