# VLAarchtests / tests/test_camera_rotation_geometry.py
# Committed by lsnu: "2026-03-25 runpod handoff update" (e7d8e79, verified)
import torch
from pytorch3d.transforms import euler_angles_to_matrix
from train.trainer import build_policy
def _scene_tokens(policy, batch, camera_extrinsics, *, geometry_enabled):
    """Run ``policy`` on ``batch`` with the given extrinsics and return its scene tokens.

    Every input except ``camera_extrinsics`` is taken from ``batch`` unchanged, so
    two calls differ only in camera pose and in whether geometry and camera-pose
    tokens are enabled. ``geometry_enabled`` toggles both flags together.
    """
    output = policy(
        images=batch["images"],
        depths=batch["depths"],
        depth_valid=batch["depth_valid"],
        camera_intrinsics=batch["camera_intrinsics"],
        camera_extrinsics=camera_extrinsics,
        proprio=batch["proprio"],
        texts=batch["texts"],
        history_images=batch["history_images"],
        history_depths=batch["history_depths"],
        history_depth_valid=batch["history_depth_valid"],
        history_proprio=batch["history_proprio"],
        history_actions=batch["history_actions"],
        plan=False,
        use_geometry_tokens=geometry_enabled,
        use_camera_pose_tokens=geometry_enabled,
    )
    return output["scene_tokens"]


def test_camera_rotation_geometry_changes_policy_representation(
    tiny_policy_config,
    tiny_trainer_config,
    tiny_batch,
):
    """Camera rotation must affect scene tokens only through geometry/pose tokens.

    Two copies of the batch extrinsics give index 0 along dim 1 (presumably
    camera 0 -- confirm against the fixture) the same translation. Copy B also
    has its 3x3 rotation block replaced by a 0.9 rad rotation about Z (XYZ Euler
    angles with only the Z component set).

    With geometry and camera-pose tokens enabled, the scene tokens for A and B
    must differ. With both disabled, they must match to within
    ``atol=1e-6, rtol=1e-5``.
    """
    config = tiny_policy_config()
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)

    extrinsics_a = batch["camera_extrinsics"].clone()
    extrinsics_b = batch["camera_extrinsics"].clone()
    # Pin the translation identically in both copies so the only difference
    # between A and B is the rotation written into B below.
    shared_translation = torch.tensor(
        [0.12, -0.05, 0.08],
        dtype=extrinsics_a.dtype,
        device=extrinsics_a.device,
    )
    extrinsics_a[:, 0, :3, 3] = shared_translation
    extrinsics_b[:, 0, :3, 3] = shared_translation
    # [0] drops the leading dim of the single-angle-triplet result; the
    # remaining 3x3 matrix broadcasts over the batch dimension on assignment.
    extrinsics_b[:, 0, :3, :3] = euler_angles_to_matrix(
        torch.tensor(
            [[0.0, 0.0, 0.9]],
            dtype=extrinsics_b.dtype,
            device=extrinsics_b.device,
        ),
        "XYZ",
    )[0]

    policy = build_policy(config, tiny_trainer_config(policy_type="elastic_reveal"))
    policy.eval()

    # Inference only: skip building autograd graphs for the four forward passes.
    with torch.no_grad():
        tokens_a = _scene_tokens(policy, batch, extrinsics_a, geometry_enabled=True)
        tokens_b = _scene_tokens(policy, batch, extrinsics_b, geometry_enabled=True)
        tokens_disabled = _scene_tokens(
            policy, batch, extrinsics_a, geometry_enabled=False
        )
        tokens_disabled_rotated = _scene_tokens(
            policy, batch, extrinsics_b, geometry_enabled=False
        )

    assert not torch.allclose(tokens_a, tokens_b), (
        "rotating the camera did not change scene tokens even though "
        "geometry and camera-pose tokens were enabled"
    )
    assert torch.allclose(
        tokens_disabled,
        tokens_disabled_rotated,
        atol=1e-6,
        rtol=1e-5,
    ), (
        "rotating the camera changed scene tokens even though geometry and "
        "camera-pose tokens were disabled"
    )