| from __future__ import annotations |
|
|
| import torch |
|
|
| from models.action_decoder import ChunkDecoderConfig |
| from models.backbones import FrozenVLBackbone, FrozenVLBackboneConfig |
| from models.rvt_backbone import RVTVisualEncoder |
| from models.multiview_fusion import MultiViewFusionConfig |
| from models.observation_memory import ObservationMemoryConfig |
| from models.planner import PlannerConfig |
| from models.policy import PolicyConfig |
| from models.reveal_head import RevealHeadConfig |
| from models.world_model import RevealWMConfig |
| from train.trainer import TrainerConfig, build_policy |
|
|
|
|
def _camera_batch() -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(intrinsics, extrinsics)`` for a single sample seen by three cameras.

    Returns:
        intrinsics: tensor of shape (1, 3, 3, 3). Every view shares the same
            pinhole matrix: focal entries 30.0 on the diagonal and principal
            point entries (16.0, 16.0) in the last column.
        extrinsics: tensor of shape (1, 3, 4, 4). Identity matrices, except the
            translation entry at row 0 / column 3 is -0.1 for view 1 and +0.1
            for view 2 (view 0 stays exactly identity).
    """
    num_views = 3

    pinhole = torch.tensor(
        [
            [30.0, 0.0, 16.0],
            [0.0, 30.0, 16.0],
            [0.0, 0.0, 1.0],
        ]
    )
    # clone() materialises the broadcast view so callers get an independent,
    # contiguous tensor per view that they may safely mutate.
    intrinsics = pinhole.expand(1, num_views, 3, 3).clone()

    extrinsics = torch.eye(4).expand(1, num_views, 4, 4).clone()
    # Per-view offset written into the (0, 3) translation slot; view 0 keeps 0.0.
    x_offsets = torch.tensor([0.0, -0.1, 0.1])
    extrinsics[:, :, 0, 3] = x_offsets

    return intrinsics, extrinsics
|
|
|
|
def test_rvt_backbone_emits_five_view_tokens() -> None:
    """The RVT-flavoured ``FrozenVLBackbone`` yields a 5-entry view axis.

    One instruction is tokenized and encoded. Its text features are then passed
    to ``encode_images`` together with three 32x32 RGB views, depth values
    drawn from [0.2, 0.5), and the synthetic cameras from ``_camera_batch``.
    The output is expected to have shape (1, 5, 400, 512). Per the test name,
    axis 1 holds the five views. The 400-token count comes from the backbone
    and is not derived here (presumably fixed by its internals; confirm).
    """
    config = FrozenVLBackboneConfig(
        backbone_type="rvt",
        hidden_dim=512,
        max_text_tokens=77,
        freeze_backbone=True,
        gradient_checkpointing=False,
        rvt_point_stride=4,
        rvt_max_points_per_view=128,
    )
    backbone = FrozenVLBackbone(config)

    cpu = torch.device("cpu")
    tokenized = backbone.tokenize_text(["move the box together"], device=cpu)
    text_features = backbone.encode_text(
        tokenized["input_ids"],
        tokenized["attention_mask"],
    )

    intrinsics, extrinsics = _camera_batch()
    # Random inputs are drawn after the backbone is built, in the same order
    # as before: images, proprio, depths.
    rgb = torch.rand(1, 3, 3, 32, 32)
    proprio = torch.rand(1, 32)
    depth = torch.rand(1, 3, 1, 32, 32) * 0.3 + 0.2

    view_tokens = backbone.encode_images(
        images=rgb,
        proprio=proprio,
        language_tokens=text_features,
        depths=depth,
        camera_intrinsics=intrinsics,
        camera_extrinsics=extrinsics,
    )

    assert tuple(view_tokens.shape) == (1, 5, 400, 512)
|
|
|
|
def test_rvt_backbone_uses_fixed_scene_bounds_for_normalization() -> None:
    """``RVTVisualEncoder._normalize_world_points`` maps fixed corners to +/-1.

    The test expects three mappings:

    * the low corner (-0.3, -0.5, 0.6) maps to -1 on every axis;
    * the high corner (0.7, 0.5, 1.6) maps to +1 on every axis;
    * their midpoint (0.2, 0.0, 1.1) maps to 0.

    These expectations are consistent with a per-axis linear map from
    x in [-0.3, 0.7], y in [-0.5, 0.5], z in [0.6, 1.6] onto [-1, 1].
    Presumably those are the encoder's fixed scene bounds; confirm against
    ``RVTVisualEncoder``.

    NOTE(review): the checkpoint and config paths below are absolute and
    machine-specific, so this test only runs where those files exist.
    """
    encoder = RVTVisualEncoder(
        checkpoint_path="/workspace/models/rvt_official/rvt/model_14.pth",
        mvt_cfg_path="/workspace/models/rvt_official/rvt/mvt_cfg.yaml",
        output_dim=512,
        input_proprio_dim=32,
        renderer_device="cpu",
        point_stride=4,
        max_points_per_view=128,
    )

    low_corner = [-0.3, -0.5, 0.6]
    high_corner = [0.7, 0.5, 1.6]
    center = [0.2, 0.0, 1.1]
    world_points = torch.tensor([low_corner, high_corner, center], dtype=torch.float32)

    normalized = encoder._normalize_world_points(world_points)

    assert torch.allclose(normalized[0], torch.tensor([-1.0, -1.0, -1.0]))
    assert torch.allclose(normalized[1], torch.tensor([1.0, 1.0, 1.0]))
    # The midpoint is compared against zero, where rtol has no effect, so an
    # explicit absolute tolerance absorbs float32 rounding.
    assert torch.allclose(normalized[2], torch.tensor([0.0, 0.0, 0.0]), atol=1e-6)
|
|
|
|
def test_backbone_only_policy_accepts_rvt_backbone() -> None:
    """``build_policy`` with ``policy_type="backbone_only"`` runs with an RVT backbone.

    Every sub-module shares a 512-wide hidden size. The fusion stage is
    configured for 5 cameras, while the policy is fed 3 RGB-D views. The
    decoder predicts chunks of 2 actions, each of dimension 14.

    After one forward pass the test checks two output shapes:
    ``scene_tokens`` must be (1, 2007, 512), and ``action_mean`` must be
    (1, 2, 14), presumably (batch, chunk_size, action_dim). How the 2007
    scene tokens are composed is decided by the policy and not derived here.
    """
    width = 512
    intrinsics, extrinsics = _camera_batch()

    backbone_cfg = FrozenVLBackboneConfig(
        backbone_type="rvt",
        hidden_dim=width,
        max_text_tokens=77,
        freeze_backbone=True,
        gradient_checkpointing=False,
        rvt_point_stride=4,
        rvt_max_points_per_view=128,
    )
    fusion_cfg = MultiViewFusionConfig(
        hidden_dim=width,
        num_cameras=5,
        num_layers=1,
        num_heads=8,
        ff_dim=1024,
        dropout=0.0,
        proprio_dim=32,
    )
    memory_cfg = ObservationMemoryConfig(
        hidden_dim=width,
        history_steps=1,
        num_layers=1,
        dropout=0.0,
    )
    decoder_cfg = ChunkDecoderConfig(
        hidden_dim=width,
        num_heads=8,
        num_layers=1,
        ff_dim=1024,
        dropout=0.0,
        chunk_size=2,
        action_dim=14,
        num_candidates=2,
    )
    policy_cfg = PolicyConfig(
        backbone=backbone_cfg,
        fusion=fusion_cfg,
        memory=memory_cfg,
        decoder=decoder_cfg,
        reveal_head=RevealHeadConfig(hidden_dim=width),
        world_model=RevealWMConfig(hidden_dim=width),
        planner=PlannerConfig(hidden_dim=width, num_candidates=2),
    )
    trainer_cfg = TrainerConfig(
        policy_type="backbone_only",
        use_bf16=False,
        freeze_backbone=True,
        gradient_checkpointing=False,
    )
    policy = build_policy(policy_cfg, trainer_cfg)

    # Random inputs are drawn after the policy is built, in the same order
    # as before: images, depths, proprio.
    rgb = torch.rand(1, 3, 3, 32, 32)
    depth = torch.rand(1, 3, 1, 32, 32) * 0.3 + 0.2
    proprio = torch.rand(1, 32)

    result = policy(
        images=rgb,
        depths=depth,
        camera_intrinsics=intrinsics,
        camera_extrinsics=extrinsics,
        proprio=proprio,
        texts=["move the box together"],
    )

    assert tuple(result["scene_tokens"].shape) == (1, 2007, 512)
    assert tuple(result["action_mean"].shape) == (1, 2, 14)
|
|