import torch
from pytorch3d.transforms import euler_angles_to_matrix

from train.trainer import build_policy


def _rotated_extrinsics(
    extrinsics: torch.Tensor,
    angle_radians: float,
    camera_index: int = 1,
) -> torch.Tensor:
    """Return a copy of ``extrinsics`` with one camera's rotation block replaced.

    For every batch entry, the top-left 3x3 block of camera ``camera_index``
    is overwritten (not composed) with a rotation of ``angle_radians`` about
    the Y axis, built with the "XYZ" Euler convention. All other entries,
    including other cameras, are copied unchanged. The input tensor is not
    modified.

    Args:
        extrinsics: Camera extrinsics indexed as ``[batch, camera, row, col]``.
            The shape is presumably ``(B, num_cameras, 4, 4)``; TODO confirm
            against the ``tiny_batch`` fixture.
        angle_radians: Rotation angle about the Y axis, in radians.
        camera_index: Index of the camera whose rotation block is overwritten.
            Defaults to ``1``, the camera this helper always modified.

    Returns:
        A new tensor with the same shape, dtype and device as ``extrinsics``.
    """
    rotated = extrinsics.clone()
    # Build the rotation on the extrinsics' own device/dtype so the slice
    # assignment below never mixes devices or silently casts.
    euler = torch.tensor(
        [[0.0, angle_radians, 0.0]],
        dtype=rotated.dtype,
        device=rotated.device,
    )
    rotation = euler_angles_to_matrix(euler, "XYZ")[0]
    rotated[:, camera_index, :3, :3] = rotation
    return rotated


def test_geometry_tokens_propagate(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Camera extrinsics must influence tokens only when pose/geometry tokens are on.

    The test checks three things:

    * Changing one camera's extrinsics changes the backbone's
      ``geometry_tokens``.
    * The same change alters the policy's ``scene_tokens`` when both geometry
      and camera-pose tokens are enabled.
    * The same change leaves ``scene_tokens`` unchanged (within tolerance)
      when both are disabled.
    """
    config = tiny_policy_config()
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    policy = build_policy(config, tiny_trainer_config(policy_type="elastic_reveal"))
    # Eval mode, so that train-only stochastic layers (if any, e.g. dropout)
    # cannot make the "disabled" equality check below flaky.
    policy.eval()

    original_extrinsics = batch["camera_extrinsics"]
    rotated_extrinsics = _rotated_extrinsics(original_extrinsics, angle_radians=0.7)
    # Precondition: the perturbation must actually change the inputs.
    # Otherwise the "not allclose" checks below would fail for a reason
    # unrelated to the model.
    assert not torch.allclose(original_extrinsics, rotated_extrinsics), (
        "fixture extrinsics already contain the test rotation; use another angle"
    )

    def encode_images(extrinsics):
        # Backbone image encoding; only the extrinsics vary between calls.
        return policy.backbone.encode_images(
            batch["images"],
            depths=batch["depths"],
            depth_valid=batch["depth_valid"],
            camera_intrinsics=batch["camera_intrinsics"],
            camera_extrinsics=extrinsics,
            return_aux=True,
        )

    def encode_scene(extrinsics, *, pose_tokens_enabled):
        # Scene encoding with depth on. Geometry and camera-pose tokens are
        # toggled together.
        return policy._encode_scene_with_optional_depth(
            images=batch["images"],
            proprio=batch["proprio"],
            texts=batch["texts"],
            depths=batch["depths"],
            depth_valid=batch["depth_valid"],
            camera_intrinsics=batch["camera_intrinsics"],
            camera_extrinsics=extrinsics,
            use_depth=True,
            use_geometry_tokens=pose_tokens_enabled,
            use_camera_pose_tokens=pose_tokens_enabled,
        )["scene_tokens"]

    # Gradients are never inspected here; skip building autograd graphs.
    with torch.no_grad():
        encoded_a = encode_images(original_extrinsics)
        encoded_b = encode_images(rotated_extrinsics)
        assert not torch.allclose(encoded_a["geometry_tokens"], encoded_b["geometry_tokens"])

        scene_a = encode_scene(original_extrinsics, pose_tokens_enabled=True)
        scene_b = encode_scene(rotated_extrinsics, pose_tokens_enabled=True)
        assert not torch.allclose(scene_a, scene_b)

        scene_disabled = encode_scene(original_extrinsics, pose_tokens_enabled=False)
        scene_disabled_rotated = encode_scene(rotated_extrinsics, pose_tokens_enabled=False)
        assert torch.allclose(scene_disabled, scene_disabled_rotated, atol=1e-6, rtol=1e-5)