File size: 3,411 Bytes
e7d8e79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch

from pytorch3d.transforms import euler_angles_to_matrix
from train.trainer import build_policy


def _rotated_extrinsics(
    extrinsics: torch.Tensor,
    angle_radians: float,
    *,
    camera_index: int = 1,
) -> torch.Tensor:
    """Return a copy of ``extrinsics`` with one camera's rotation block replaced.

    The top-left 3x3 block of camera ``camera_index`` is overwritten, for every
    entry along dim 0, with the matrix produced by XYZ Euler angles
    ``(0, angle_radians, 0)``, i.e. a rotation about the Y axis. The original
    rotation is replaced, not composed. The input tensor is left untouched.

    Args:
        extrinsics: Camera extrinsics. Presumably shaped
            ``(batch, num_cameras, 4, 4)``; TODO confirm against the fixture.
        angle_radians: Rotation angle about the Y axis, in radians.
        camera_index: Index along dim 1 of the camera to perturb. Defaults to
            1, the camera the original helper always modified.

    Returns:
        A new tensor with the same shape, dtype and device as ``extrinsics``.
    """
    rotated = extrinsics.clone()
    # Build the matrix on the target dtype/device so the slice assignment below
    # does not depend on an implicit host-to-device copy.
    rotation = euler_angles_to_matrix(
        torch.tensor(
            [[0.0, angle_radians, 0.0]],
            dtype=rotated.dtype,
            device=rotated.device,
        ),
        "XYZ",
    )[0]
    rotated[:, camera_index, :3, :3] = rotation
    return rotated


def test_geometry_tokens_propagate(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Check that camera-pose changes reach the tokens only when geometry is enabled.

    Three properties are asserted after rotating one camera's extrinsics:

    1. The backbone's ``geometry_tokens`` change.
    2. The policy's ``scene_tokens`` change when geometry and camera-pose tokens
       are enabled.
    3. The policy's ``scene_tokens`` stay unchanged, within tolerance, when
       both are disabled.
    """
    config = tiny_policy_config()
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    policy = build_policy(config, tiny_trainer_config(policy_type="elastic_reveal"))
    policy.eval()

    original_extrinsics = batch["camera_extrinsics"]
    rotated_extrinsics = _rotated_extrinsics(original_extrinsics, angle_radians=0.7)
    # Guard against a vacuous test: if the perturbation were a no-op, the
    # "tokens differ" assertions below would be meaningless.
    assert not torch.allclose(original_extrinsics, rotated_extrinsics)

    def encode_images(camera_extrinsics):
        # Backbone image encoding; only the extrinsics vary between calls.
        return policy.backbone.encode_images(
            batch["images"],
            depths=batch["depths"],
            depth_valid=batch["depth_valid"],
            camera_intrinsics=batch["camera_intrinsics"],
            camera_extrinsics=camera_extrinsics,
            return_aux=True,
        )

    def encode_scene(camera_extrinsics, *, geometry_enabled):
        # Full scene encoding with depth on. Geometry and camera-pose tokens
        # are toggled together.
        return policy._encode_scene_with_optional_depth(
            images=batch["images"],
            proprio=batch["proprio"],
            texts=batch["texts"],
            depths=batch["depths"],
            depth_valid=batch["depth_valid"],
            camera_intrinsics=batch["camera_intrinsics"],
            camera_extrinsics=camera_extrinsics,
            use_depth=True,
            use_geometry_tokens=geometry_enabled,
            use_camera_pose_tokens=geometry_enabled,
        )["scene_tokens"]

    # Inference-only comparisons: no autograd graph is needed.
    with torch.no_grad():
        encoded_a = encode_images(original_extrinsics)
        encoded_b = encode_images(rotated_extrinsics)
        assert not torch.allclose(encoded_a["geometry_tokens"], encoded_b["geometry_tokens"])

        scene_a = encode_scene(original_extrinsics, geometry_enabled=True)
        scene_b = encode_scene(rotated_extrinsics, geometry_enabled=True)
        assert not torch.allclose(scene_a, scene_b)

        scene_disabled = encode_scene(original_extrinsics, geometry_enabled=False)
        scene_disabled_rotated = encode_scene(rotated_extrinsics, geometry_enabled=False)
        assert torch.allclose(scene_disabled, scene_disabled_rotated, atol=1e-6, rtol=1e-5)