# Listing header (extraction residue): File size: 4,900 Bytes; id 9c74dfe.
from __future__ import annotations
import torch
from models.action_decoder import ChunkDecoderConfig
from models.backbones import FrozenVLBackbone, FrozenVLBackboneConfig
from models.rvt_backbone import RVTVisualEncoder
from models.multiview_fusion import MultiViewFusionConfig
from models.observation_memory import ObservationMemoryConfig
from models.planner import PlannerConfig
from models.policy import PolicyConfig
from models.reveal_head import RevealHeadConfig
from models.world_model import RevealWMConfig
from train.trainer import TrainerConfig, build_policy
def _camera_batch() -> tuple[torch.Tensor, torch.Tensor]:
intrinsics = torch.eye(3).view(1, 1, 3, 3).expand(1, 3, 3, 3).clone()
intrinsics[:, :, 0, 0] = 30.0
intrinsics[:, :, 1, 1] = 30.0
intrinsics[:, :, 0, 2] = 16.0
intrinsics[:, :, 1, 2] = 16.0
extrinsics = torch.eye(4).view(1, 1, 4, 4).expand(1, 3, 4, 4).clone()
extrinsics[:, 1, 0, 3] = -0.1
extrinsics[:, 2, 0, 3] = 0.1
return intrinsics, extrinsics
def test_rvt_backbone_emits_five_view_tokens() -> None:
    """Check the shape of the image tokens produced by the "rvt" backbone.

    A ``FrozenVLBackbone`` configured with ``backbone_type="rvt"`` encodes a
    single text prompt, then encodes random images, depths and proprio
    together with the cameras from ``_camera_batch``. The resulting tokens
    must have shape ``(1, 5, 400, 512)``.
    """
    rvt_config = FrozenVLBackboneConfig(
        backbone_type="rvt",
        hidden_dim=512,
        max_text_tokens=77,
        freeze_backbone=True,
        gradient_checkpointing=False,
        rvt_point_stride=4,
        rvt_max_points_per_view=128,
    )
    backbone = FrozenVLBackbone(rvt_config)

    prompt_tokens = backbone.tokenize_text(
        ["move the box together"],
        device=torch.device("cpu"),
    )
    text_features = backbone.encode_text(
        prompt_tokens["input_ids"],
        prompt_tokens["attention_mask"],
    )

    intrinsics, extrinsics = _camera_batch()

    # Random inputs are drawn in a fixed order (images, proprio, depths) so
    # the RNG stream is consumed deterministically under a fixed seed.
    # Layout is presumably (batch, view, channel, height, width) -- confirm.
    rgb = torch.rand(1, 3, 3, 32, 32)
    proprio = torch.rand(1, 32)
    # Depth values fall in [0.2, 0.5).
    depth = torch.rand(1, 3, 1, 32, 32) * 0.3 + 0.2

    tokens = backbone.encode_images(
        images=rgb,
        proprio=proprio,
        language_tokens=text_features,
        depths=depth,
        camera_intrinsics=intrinsics,
        camera_extrinsics=extrinsics,
    )

    expected_shape = (1, 5, 400, 512)
    assert tuple(tokens.shape) == expected_shape
def test_rvt_backbone_uses_fixed_scene_bounds_for_normalization() -> None:
    """Check how ``RVTVisualEncoder._normalize_world_points`` maps three points.

    The point ``(-0.3, -0.5, 0.6)`` must map to ``(-1, -1, -1)``, the point
    ``(0.7, 0.5, 1.6)`` to ``(1, 1, 1)``, and their midpoint
    ``(0.2, 0.0, 1.1)`` to the origin (within an absolute tolerance of 1e-6).

    NOTE(review): the checkpoint and config paths are absolute; presumably
    the encoder needs those files to exist on the machine -- confirm.
    """
    encoder = RVTVisualEncoder(
        checkpoint_path="/workspace/models/rvt_official/rvt/model_14.pth",
        mvt_cfg_path="/workspace/models/rvt_official/rvt/mvt_cfg.yaml",
        output_dim=512,
        input_proprio_dim=32,
        renderer_device="cpu",
        point_stride=4,
        max_points_per_view=128,
    )

    lower_corner = [-0.3, -0.5, 0.6]
    upper_corner = [0.7, 0.5, 1.6]
    midpoint = [0.2, 0.0, 1.1]
    world_points = torch.tensor(
        [lower_corner, upper_corner, midpoint],
        dtype=torch.float32,
    )

    mapped = encoder._normalize_world_points(world_points)
    mapped_lower, mapped_upper, mapped_mid = mapped[0], mapped[1], mapped[2]

    assert torch.allclose(mapped_lower, torch.full((3,), -1.0))
    assert torch.allclose(mapped_upper, torch.full((3,), 1.0))
    # The default relative tolerance is meaningless against zero, so the
    # midpoint is compared with an explicit absolute tolerance.
    assert torch.allclose(mapped_mid, torch.zeros(3), atol=1e-6)
def test_backbone_only_policy_accepts_rvt_backbone() -> None:
    """Run a "backbone_only" policy built on the "rvt" backbone end to end.

    The policy is built with ``build_policy`` from a ``PolicyConfig`` whose
    sub-modules all use a hidden size of 512, then called once on random
    images, depths and proprio plus the cameras from ``_camera_batch``.
    The test checks that ``outputs["scene_tokens"]`` has shape
    ``(1, 2007, 512)`` and ``outputs["action_mean"]`` has shape ``(1, 2, 14)``.
    """
    intrinsics, extrinsics = _camera_batch()

    hidden_dim = 512

    backbone_cfg = FrozenVLBackboneConfig(
        backbone_type="rvt",
        hidden_dim=hidden_dim,
        max_text_tokens=77,
        freeze_backbone=True,
        gradient_checkpointing=False,
        rvt_point_stride=4,
        rvt_max_points_per_view=128,
    )
    fusion_cfg = MultiViewFusionConfig(
        hidden_dim=hidden_dim,
        num_cameras=5,
        num_layers=1,
        num_heads=8,
        ff_dim=1024,
        dropout=0.0,
        proprio_dim=32,
    )
    memory_cfg = ObservationMemoryConfig(
        hidden_dim=hidden_dim,
        history_steps=1,
        num_layers=1,
        dropout=0.0,
    )
    decoder_cfg = ChunkDecoderConfig(
        hidden_dim=hidden_dim,
        num_heads=8,
        num_layers=1,
        ff_dim=1024,
        dropout=0.0,
        chunk_size=2,
        action_dim=14,
        num_candidates=2,
    )
    policy_cfg = PolicyConfig(
        backbone=backbone_cfg,
        fusion=fusion_cfg,
        memory=memory_cfg,
        decoder=decoder_cfg,
        reveal_head=RevealHeadConfig(hidden_dim=hidden_dim),
        world_model=RevealWMConfig(hidden_dim=hidden_dim),
        planner=PlannerConfig(hidden_dim=hidden_dim, num_candidates=2),
    )
    trainer_cfg = TrainerConfig(
        policy_type="backbone_only",
        use_bf16=False,
        freeze_backbone=True,
        gradient_checkpointing=False,
    )
    policy = build_policy(policy_cfg, trainer_cfg)

    # Random inputs are drawn in a fixed order (images, depths, proprio) so
    # the RNG stream is consumed deterministically under a fixed seed.
    rgb = torch.rand(1, 3, 3, 32, 32)
    # Depth values fall in [0.2, 0.5).
    depth = torch.rand(1, 3, 1, 32, 32) * 0.3 + 0.2
    proprio = torch.rand(1, 32)

    outputs = policy(
        images=rgb,
        depths=depth,
        camera_intrinsics=intrinsics,
        camera_extrinsics=extrinsics,
        proprio=proprio,
        texts=["move the box together"],
    )

    expected_shapes = {
        "scene_tokens": (1, 2007, 512),
        # Matches chunk_size=2 and action_dim=14 in the decoder config.
        "action_mean": (1, 2, 14),
    }
    for key, shape in expected_shapes.items():
        assert tuple(outputs[key].shape) == shape
|