import torch
from pytorch3d.transforms import euler_angles_to_matrix
from train.trainer import build_policy
def _rotated_extrinsics(extrinsics: torch.Tensor, angle_radians: float) -> torch.Tensor:
    """Return a copy of *extrinsics* whose second camera gets a fresh Y-axis rotation.

    A rotation of *angle_radians* about Y (XYZ Euler convention) overwrites the
    3x3 rotation part of camera index 1, leaving translation and all other
    cameras untouched.  Assumes extrinsics is (..., num_cameras, 4, 4) with at
    least two cameras — TODO confirm against the fixture layout.
    """
    angles = torch.tensor([[0.0, angle_radians, 0.0]], dtype=extrinsics.dtype)
    rotation = euler_angles_to_matrix(angles, "XYZ").squeeze(0)
    result = extrinsics.clone()
    result[:, 1, :3, :3] = rotation
    return result
def test_geometry_tokens_propagate(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Camera-extrinsics changes must reach scene tokens only via geometry/pose tokens.

    Three checks:
      1. Backbone geometry tokens differ when one camera's extrinsics are rotated.
      2. Scene tokens differ under that rotation when geometry + camera-pose
         tokens are enabled (the change propagates).
      3. Scene tokens are invariant to the rotation when those tokens are
         disabled (no leakage through another path).
    """
    config = tiny_policy_config()
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    policy = build_policy(config, tiny_trainer_config(policy_type="elastic_reveal"))
    policy.eval()
    rotated_extrinsics = _rotated_extrinsics(batch["camera_extrinsics"], angle_radians=0.7)

    def encode_backbone(extrinsics):
        # Backbone-level image encoding; return_aux=True exposes geometry_tokens.
        return policy.backbone.encode_images(
            batch["images"],
            depths=batch["depths"],
            depth_valid=batch["depth_valid"],
            camera_intrinsics=batch["camera_intrinsics"],
            camera_extrinsics=extrinsics,
            return_aux=True,
        )

    def scene_tokens(extrinsics, tokens_enabled):
        # Policy-level scene encoding; `tokens_enabled` toggles geometry and
        # camera-pose tokens together, exactly as the assertions below need.
        return policy._encode_scene_with_optional_depth(
            images=batch["images"],
            proprio=batch["proprio"],
            texts=batch["texts"],
            depths=batch["depths"],
            depth_valid=batch["depth_valid"],
            camera_intrinsics=batch["camera_intrinsics"],
            camera_extrinsics=extrinsics,
            use_depth=True,
            use_geometry_tokens=tokens_enabled,
            use_camera_pose_tokens=tokens_enabled,
        )["scene_tokens"]

    # 1. Rotating one camera must change the geometry tokens themselves.
    encoded_a = encode_backbone(batch["camera_extrinsics"])
    encoded_b = encode_backbone(rotated_extrinsics)
    assert not torch.allclose(encoded_a["geometry_tokens"], encoded_b["geometry_tokens"])

    # 2. With geometry/pose tokens enabled, the change must reach scene tokens.
    scene_a = scene_tokens(batch["camera_extrinsics"], tokens_enabled=True)
    scene_b = scene_tokens(rotated_extrinsics, tokens_enabled=True)
    assert not torch.allclose(scene_a, scene_b)

    # 3. With them disabled, extrinsics must not leak into scene tokens at all
    #    (loose tolerances absorb nondeterministic float accumulation order).
    scene_disabled = scene_tokens(batch["camera_extrinsics"], tokens_enabled=False)
    scene_disabled_rotated = scene_tokens(rotated_extrinsics, tokens_enabled=False)
    assert torch.allclose(scene_disabled, scene_disabled_rotated, atol=1e-6, rtol=1e-5)
|