File size: 3,709 Bytes
e7d8e79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch

from pytorch3d.transforms import euler_angles_to_matrix
from train.trainer import build_policy


def test_camera_rotation_geometry_changes_policy_representation(
    tiny_policy_config,
    tiny_trainer_config,
    tiny_batch,
):
    """Rotating camera 0 must change scene tokens only when geometry is enabled.

    Two extrinsics variants share the same translation for camera 0 and differ
    only in its rotation: identity versus the matrix produced by
    ``euler_angles_to_matrix`` for angles ``(0.0, 0.0, 0.9)`` in ``"XYZ"``
    convention. With geometry and camera-pose tokens enabled, the policy's
    ``scene_tokens`` must differ between the two variants. With both disabled,
    they must match (within ``atol=1e-6, rtol=1e-5``).
    """
    config = tiny_policy_config()
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    base_extrinsics = batch["camera_extrinsics"]
    dtype, device = base_extrinsics.dtype, base_extrinsics.device

    shared_translation = torch.tensor(
        [0.12, -0.05, 0.08], dtype=dtype, device=device
    )
    rotated = euler_angles_to_matrix(
        torch.tensor([[0.0, 0.0, 0.9]], dtype=dtype, device=device),
        "XYZ",
    )[0]

    extrinsics_a = base_extrinsics.clone()
    extrinsics_b = base_extrinsics.clone()
    # Pin camera 0's baseline rotation explicitly. Otherwise the comparison
    # would depend on whatever rotation the fixture happens to provide, and
    # could spuriously fail if that already equalled `rotated`.
    extrinsics_a[:, 0, :3, :3] = torch.eye(3, dtype=dtype, device=device)
    extrinsics_b[:, 0, :3, :3] = rotated
    extrinsics_a[:, 0, :3, 3] = shared_translation
    extrinsics_b[:, 0, :3, 3] = shared_translation

    policy = build_policy(config, tiny_trainer_config(policy_type="elastic_reveal"))
    policy.eval()

    def scene_tokens(camera_extrinsics, geometry_enabled):
        """Run one forward pass on `batch` and return its ``scene_tokens``."""
        # Only output values are inspected, so skip building autograd graphs.
        with torch.no_grad():
            output = policy(
                images=batch["images"],
                depths=batch["depths"],
                depth_valid=batch["depth_valid"],
                camera_intrinsics=batch["camera_intrinsics"],
                camera_extrinsics=camera_extrinsics,
                proprio=batch["proprio"],
                texts=batch["texts"],
                history_images=batch["history_images"],
                history_depths=batch["history_depths"],
                history_depth_valid=batch["history_depth_valid"],
                history_proprio=batch["history_proprio"],
                history_actions=batch["history_actions"],
                plan=False,
                use_geometry_tokens=geometry_enabled,
                use_camera_pose_tokens=geometry_enabled,
            )
        return output["scene_tokens"]

    # Same call order as the four explicit forward passes this replaces.
    tokens_a = scene_tokens(extrinsics_a, geometry_enabled=True)
    tokens_b = scene_tokens(extrinsics_b, geometry_enabled=True)
    tokens_disabled = scene_tokens(extrinsics_a, geometry_enabled=False)
    tokens_disabled_rotated = scene_tokens(extrinsics_b, geometry_enabled=False)

    assert not torch.allclose(tokens_a, tokens_b), (
        "rotating camera 0 did not change scene_tokens with geometry enabled"
    )
    assert torch.allclose(
        tokens_disabled,
        tokens_disabled_rotated,
        atol=1e-6,
        rtol=1e-5,
    ), "rotating camera 0 changed scene_tokens with geometry disabled"