File size: 4,900 Bytes
9c74dfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from __future__ import annotations

import torch

from models.action_decoder import ChunkDecoderConfig
from models.backbones import FrozenVLBackbone, FrozenVLBackboneConfig
from models.rvt_backbone import RVTVisualEncoder
from models.multiview_fusion import MultiViewFusionConfig
from models.observation_memory import ObservationMemoryConfig
from models.planner import PlannerConfig
from models.policy import PolicyConfig
from models.reveal_head import RevealHeadConfig
from models.world_model import RevealWMConfig
from train.trainer import TrainerConfig, build_policy


def _camera_batch() -> tuple[torch.Tensor, torch.Tensor]:
    """Return intrinsics and extrinsics for one sample observed by three cameras.

    Returns:
        A pair ``(intrinsics, extrinsics)``:

        * ``intrinsics`` has shape ``(1, 3, 3, 3)``; every camera gets 30.0 on
          both focal-length diagonal entries and 16.0 in both principal-point
          entries (``[0, 2]`` and ``[1, 2]``).
        * ``extrinsics`` has shape ``(1, 3, 4, 4)``; camera 0 is the identity,
          camera 1 has -0.1 and camera 2 has +0.1 in entry ``[0, 3]``
          (presumably the x translation of a 4x4 pose matrix).
    """
    views = 3

    # ``repeat`` prepends the missing leading dims and materializes separate,
    # writable copies (an un-cloned ``expand`` would alias a single matrix).
    pinhole = torch.eye(3).repeat(1, views, 1, 1)
    pinhole[..., 0, 0] = 30.0
    pinhole[..., 1, 1] = 30.0
    pinhole[..., :2, 2] = 16.0

    poses = torch.eye(4).repeat(1, views, 1, 1)
    # Cameras 1 and 2 are shifted symmetrically around the identity camera 0.
    poses[:, 1:, 0, 3] = torch.tensor([-0.1, 0.1])
    return pinhole, poses


def test_rvt_backbone_emits_five_view_tokens() -> None:
    """Encoding one RGB-D observation with the RVT backbone yields a 5-view token tensor.

    The input images have shape ``(1, 3, 3, 32, 32)`` and the depths ``(1, 3, 1, 32, 32)``
    (presumably batch, camera, channel, height, width). They are encoded together with
    a tokenized instruction. The resulting tokens must have shape ``(1, 5, 400, 512)``.
    """
    config = FrozenVLBackboneConfig(
        backbone_type="rvt",
        hidden_dim=512,
        max_text_tokens=77,
        freeze_backbone=True,
        gradient_checkpointing=False,
        rvt_point_stride=4,
        rvt_max_points_per_view=128,
    )
    backbone = FrozenVLBackbone(config)

    cpu = torch.device("cpu")
    tokenized = backbone.tokenize_text(["move the box together"], device=cpu)
    text_features = backbone.encode_text(tokenized["input_ids"], tokenized["attention_mask"])

    cam_k, cam_pose = _camera_batch()
    # Random draws happen in the same order as the keyword arguments below
    # (images, proprio, depths), keeping the RNG stream unchanged.
    rgb = torch.rand(1, 3, 3, 32, 32)
    state = torch.rand(1, 32)
    depth = torch.rand(1, 3, 1, 32, 32) * 0.3 + 0.2

    view_tokens = backbone.encode_images(
        images=rgb,
        proprio=state,
        language_tokens=text_features,
        depths=depth,
        camera_intrinsics=cam_k,
        camera_extrinsics=cam_pose,
    )

    # NOTE(review): 5 presumably counts RVT's rendered virtual views rather than
    # the 3 physical cameras supplied here — confirm against the backbone.
    assert tuple(view_tokens.shape) == (1, 5, 400, 512)


def test_rvt_backbone_uses_fixed_scene_bounds_for_normalization() -> None:
    """``_normalize_world_points`` maps the expected scene corners to ±1 and their midpoint to 0.

    The point ``(-0.3, -0.5, 0.6)`` must normalize to -1 on every axis. The point
    ``(0.7, 0.5, 1.6)`` must normalize to +1, and their midpoint ``(0.2, 0.0, 1.1)``
    must normalize to the origin (within 1e-6).
    """
    encoder = RVTVisualEncoder(
        checkpoint_path="/workspace/models/rvt_official/rvt/model_14.pth",
        mvt_cfg_path="/workspace/models/rvt_official/rvt/mvt_cfg.yaml",
        output_dim=512,
        input_proprio_dim=32,
        renderer_device="cpu",
        point_stride=4,
        max_points_per_view=128,
    )

    low_corner = [-0.3, -0.5, 0.6]
    high_corner = [0.7, 0.5, 1.6]
    midpoint = [0.2, 0.0, 1.1]
    world_points = torch.tensor([low_corner, high_corner, midpoint], dtype=torch.float32)

    normalized = encoder._normalize_world_points(world_points)

    ones = torch.ones(3)
    assert torch.allclose(normalized[0], -ones)
    assert torch.allclose(normalized[1], ones)
    # The midpoint lands near zero, where relative tolerance is useless, so use atol.
    assert torch.allclose(normalized[2], torch.zeros(3), atol=1e-6)


def test_backbone_only_policy_accepts_rvt_backbone() -> None:
    """A ``backbone_only`` policy built on the RVT backbone completes a forward pass.

    The policy is assembled through ``build_policy`` from 512-wide component configs.
    It is run on one sample of images ``(1, 3, 3, 32, 32)``, depths ``(1, 3, 1, 32, 32)``,
    camera matrices, proprioception and an instruction. The test then checks the shapes
    of ``scene_tokens`` and ``action_mean`` in the returned mapping.
    """
    intrinsics, extrinsics = _camera_batch()
    width = 512

    # Component configs are built in the same order PolicyConfig received them.
    backbone_cfg = FrozenVLBackboneConfig(
        backbone_type="rvt",
        hidden_dim=width,
        max_text_tokens=77,
        freeze_backbone=True,
        gradient_checkpointing=False,
        rvt_point_stride=4,
        rvt_max_points_per_view=128,
    )
    # NOTE(review): num_cameras=5 presumably matches the five RVT views the
    # backbone emits, not the three physical cameras — confirm.
    fusion_cfg = MultiViewFusionConfig(
        hidden_dim=width,
        num_cameras=5,
        num_layers=1,
        num_heads=8,
        ff_dim=1024,
        dropout=0.0,
        proprio_dim=32,
    )
    memory_cfg = ObservationMemoryConfig(
        hidden_dim=width,
        history_steps=1,
        num_layers=1,
        dropout=0.0,
    )
    decoder_cfg = ChunkDecoderConfig(
        hidden_dim=width,
        num_heads=8,
        num_layers=1,
        ff_dim=1024,
        dropout=0.0,
        chunk_size=2,
        action_dim=14,
        num_candidates=2,
    )
    reveal_cfg = RevealHeadConfig(hidden_dim=width)
    world_model_cfg = RevealWMConfig(hidden_dim=width)
    planner_cfg = PlannerConfig(hidden_dim=width, num_candidates=2)

    policy_cfg = PolicyConfig(
        backbone=backbone_cfg,
        fusion=fusion_cfg,
        memory=memory_cfg,
        decoder=decoder_cfg,
        reveal_head=reveal_cfg,
        world_model=world_model_cfg,
        planner=planner_cfg,
    )
    trainer_cfg = TrainerConfig(
        policy_type="backbone_only",
        use_bf16=False,
        freeze_backbone=True,
        gradient_checkpointing=False,
    )
    policy = build_policy(policy_cfg, trainer_cfg)

    # Dict entries are evaluated top to bottom, so the random draws keep the
    # original order: images, depths, proprio.
    observation = {
        "images": torch.rand(1, 3, 3, 32, 32),
        "depths": torch.rand(1, 3, 1, 32, 32) * 0.3 + 0.2,
        "camera_intrinsics": intrinsics,
        "camera_extrinsics": extrinsics,
        "proprio": torch.rand(1, 32),
        "texts": ["move the box together"],
    }
    outputs = policy(**observation)

    # NOTE(review): 2007 is presumably 5 views x 400 tokens plus 7 extra context
    # tokens, and action_mean presumably (batch, chunk_size, action_dim) — confirm.
    assert tuple(outputs["scene_tokens"].shape) == (1, 2007, 512)
    assert tuple(outputs["action_mean"].shape) == (1, 2, 14)