VLAarchtests / tests /test_geometry_tokens_propagate.py
lsnu's picture
2026-03-25 runpod handoff update
e7d8e79 verified
import torch
from pytorch3d.transforms import euler_angles_to_matrix
from train.trainer import build_policy
def _rotated_extrinsics(extrinsics: torch.Tensor, angle_radians: float) -> torch.Tensor:
    """Return a copy of ``extrinsics`` with one rotation block overwritten.

    A 3x3 rotation matrix is built by ``euler_angles_to_matrix`` from the
    Euler angles ``(0, angle_radians, 0)`` under the ``"XYZ"`` convention, so
    only the second angle is non-zero. The matrix is written into
    ``[:, 1, :3, :3]`` of the copy; every other element keeps its value.

    Args:
        extrinsics: Tensor that supports the index ``[:, 1, :3, :3]``, i.e.
            four dimensions with at least two entries along the second one.
            Presumably ``(batch, num_cameras, 4, 4)`` camera poses -- confirm
            against the batch fixture. The input itself is not modified.
        angle_radians: Value used as the second Euler angle, in radians.

    Returns:
        A cloned tensor with the same dtype as ``extrinsics``.
    """
    # The angles tensor carries a leading dimension of size 1, which is
    # dropped again by the ``[0]`` on the resulting matrix.
    euler_angles = torch.tensor(
        [[0.0, angle_radians, 0.0]],
        dtype=extrinsics.dtype,
    )
    new_rotation = euler_angles_to_matrix(euler_angles, "XYZ")[0]

    # Work on a clone so the caller's tensor can still serve as the baseline.
    result = extrinsics.clone()
    result[:, 1, :3, :3] = new_rotation
    return result
def test_geometry_tokens_propagate(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Check that camera extrinsics influence geometry and scene tokens.

    An ``elastic_reveal`` policy is built in eval mode and fed the same batch
    twice: once with the batch's own extrinsics and once with a copy whose
    rotation block was replaced by ``_rotated_extrinsics``. The test asserts:

    1. the backbone's ``geometry_tokens`` differ between the two poses;
    2. the ``scene_tokens`` differ when geometry and camera-pose tokens are
       enabled;
    3. the ``scene_tokens`` match (within tolerance) when both are disabled.
    """
    config = tiny_policy_config()
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    policy = build_policy(config, tiny_trainer_config(policy_type="elastic_reveal"))
    policy.eval()

    original_pose = batch["camera_extrinsics"]
    rotated_pose = _rotated_extrinsics(original_pose, angle_radians=0.7)

    def encode_backbone(extrinsics):
        # Backbone-only encoding; only the extrinsics vary between calls.
        return policy.backbone.encode_images(
            batch["images"],
            depths=batch["depths"],
            depth_valid=batch["depth_valid"],
            camera_intrinsics=batch["camera_intrinsics"],
            camera_extrinsics=extrinsics,
            return_aux=True,
        )

    def encode_scene(extrinsics, *, geometry_enabled):
        # Full scene encoding; one flag drives both the geometry-token and
        # the camera-pose-token switches, which are always toggled together.
        outputs = policy._encode_scene_with_optional_depth(
            images=batch["images"],
            proprio=batch["proprio"],
            texts=batch["texts"],
            depths=batch["depths"],
            depth_valid=batch["depth_valid"],
            camera_intrinsics=batch["camera_intrinsics"],
            camera_extrinsics=extrinsics,
            use_depth=True,
            use_geometry_tokens=geometry_enabled,
            use_camera_pose_tokens=geometry_enabled,
        )
        return outputs["scene_tokens"]

    # 1) Geometry tokens from the backbone must react to the pose change.
    backbone_original = encode_backbone(original_pose)
    backbone_rotated = encode_backbone(rotated_pose)
    assert not torch.allclose(
        backbone_original["geometry_tokens"],
        backbone_rotated["geometry_tokens"],
    )

    # 2) With geometry enabled, the change must reach the scene tokens.
    scene_original = encode_scene(original_pose, geometry_enabled=True)
    scene_rotated = encode_scene(rotated_pose, geometry_enabled=True)
    assert not torch.allclose(scene_original, scene_rotated)

    # 3) With geometry disabled, the pose must have no effect on the scene.
    plain_original = encode_scene(original_pose, geometry_enabled=False)
    plain_rotated = encode_scene(rotated_pose, geometry_enabled=False)
    assert torch.allclose(plain_original, plain_rotated, atol=1e-6, rtol=1e-5)