"""Smoke tests for the RVT visual backbone and the backbone-only policy that uses it."""

from __future__ import annotations

import torch

from models.action_decoder import ChunkDecoderConfig
from models.backbones import FrozenVLBackbone, FrozenVLBackboneConfig
from models.rvt_backbone import RVTVisualEncoder
from models.multiview_fusion import MultiViewFusionConfig
from models.observation_memory import ObservationMemoryConfig
from models.planner import PlannerConfig
from models.policy import PolicyConfig
from models.reveal_head import RevealHeadConfig
from models.world_model import RevealWMConfig
from train.trainer import TrainerConfig, build_policy

# Values shared by every test in this module.
_HIDDEN_DIM = 512
_PROPRIO_DIM = 32
_INSTRUCTION = "move the box together"


def _rvt_backbone_config() -> FrozenVLBackboneConfig:
    """Return the frozen RVT backbone configuration used by the tests below."""
    return FrozenVLBackboneConfig(
        backbone_type="rvt",
        hidden_dim=_HIDDEN_DIM,
        max_text_tokens=77,
        freeze_backbone=True,
        gradient_checkpointing=False,
        rvt_point_stride=4,
        rvt_max_points_per_view=128,
    )


def _camera_batch() -> tuple[torch.Tensor, torch.Tensor]:
    """Build camera parameters for a batch of one sample with three views.

    Returns:
        A ``(intrinsics, extrinsics)`` pair shaped ``(1, 3, 3, 3)`` and
        ``(1, 3, 4, 4)``. Every view shares the same intrinsic matrix; the
        extrinsics are identity matrices except that views 1 and 2 carry
        -0.1 and +0.1 in entry ``[0, 3]``.
    """
    shared_intrinsic = torch.tensor(
        [
            [30.0, 0.0, 16.0],
            [0.0, 30.0, 16.0],
            [0.0, 0.0, 1.0],
        ]
    )
    # ``repeat`` prepends the batch/view axes and materialises a fresh copy.
    intrinsics = shared_intrinsic.repeat(1, 3, 1, 1)

    extrinsics = torch.eye(4).repeat(1, 3, 1, 1)
    for view_index, offset in ((1, -0.1), (2, 0.1)):
        extrinsics[:, view_index, 0, 3] = offset

    return intrinsics, extrinsics


def test_rvt_backbone_emits_five_view_tokens() -> None:
    """Encoding three RGB-D views must yield a ``(1, 5, 400, 512)`` token tensor."""
    backbone = FrozenVLBackbone(_rvt_backbone_config())

    tokenized = backbone.tokenize_text([_INSTRUCTION], device=torch.device("cpu"))
    text_features = backbone.encode_text(
        tokenized["input_ids"],
        tokenized["attention_mask"],
    )
    intrinsics, extrinsics = _camera_batch()

    # Random inputs are drawn in the same order as the keyword arguments below.
    rgb = torch.rand(1, 3, 3, 32, 32)
    proprio = torch.rand(1, _PROPRIO_DIM)
    depth = torch.rand(1, 3, 1, 32, 32) * 0.3 + 0.2

    tokens = backbone.encode_images(
        images=rgb,
        proprio=proprio,
        language_tokens=text_features,
        depths=depth,
        camera_intrinsics=intrinsics,
        camera_extrinsics=extrinsics,
    )

    assert tuple(tokens.shape) == (1, 5, 400, _HIDDEN_DIM)


def test_rvt_backbone_uses_fixed_scene_bounds_for_normalization() -> None:
    """Check that three reference world points normalize to -1, +1 and 0 on every axis."""
    encoder = RVTVisualEncoder(
        checkpoint_path="/workspace/models/rvt_official/rvt/model_14.pth",
        mvt_cfg_path="/workspace/models/rvt_official/rvt/mvt_cfg.yaml",
        output_dim=_HIDDEN_DIM,
        input_proprio_dim=_PROPRIO_DIM,
        renderer_device="cpu",
        point_stride=4,
        max_points_per_view=128,
    )
    world_points = torch.tensor(
        [
            [-0.3, -0.5, 0.6],
            [0.7, 0.5, 1.6],
            [0.2, 0.0, 1.1],
        ],
        dtype=torch.float32,
    )

    normalized = encoder._normalize_world_points(world_points)

    expected_first = torch.tensor([-1.0, -1.0, -1.0])
    expected_second = torch.tensor([1.0, 1.0, 1.0])
    expected_third = torch.tensor([0.0, 0.0, 0.0])
    assert torch.allclose(normalized[0], expected_first)
    assert torch.allclose(normalized[1], expected_second)
    # Only the third point is compared with an explicit absolute tolerance.
    assert torch.allclose(normalized[2], expected_third, atol=1e-6)


def test_backbone_only_policy_accepts_rvt_backbone() -> None:
    """A ``backbone_only`` policy built on the RVT backbone runs and returns the expected shapes."""
    intrinsics, extrinsics = _camera_batch()

    policy_config = PolicyConfig(
        backbone=_rvt_backbone_config(),
        fusion=MultiViewFusionConfig(
            hidden_dim=_HIDDEN_DIM,
            num_cameras=5,
            num_layers=1,
            num_heads=8,
            ff_dim=1024,
            dropout=0.0,
            proprio_dim=_PROPRIO_DIM,
        ),
        memory=ObservationMemoryConfig(
            hidden_dim=_HIDDEN_DIM,
            history_steps=1,
            num_layers=1,
            dropout=0.0,
        ),
        decoder=ChunkDecoderConfig(
            hidden_dim=_HIDDEN_DIM,
            num_heads=8,
            num_layers=1,
            ff_dim=1024,
            dropout=0.0,
            chunk_size=2,
            action_dim=14,
            num_candidates=2,
        ),
        reveal_head=RevealHeadConfig(hidden_dim=_HIDDEN_DIM),
        world_model=RevealWMConfig(hidden_dim=_HIDDEN_DIM),
        planner=PlannerConfig(hidden_dim=_HIDDEN_DIM, num_candidates=2),
    )
    trainer_config = TrainerConfig(
        policy_type="backbone_only",
        use_bf16=False,
        freeze_backbone=True,
        gradient_checkpointing=False,
    )
    policy = build_policy(policy_config, trainer_config)

    # Random inputs are drawn in the same order as the keyword arguments below.
    rgb = torch.rand(1, 3, 3, 32, 32)
    depth = torch.rand(1, 3, 1, 32, 32) * 0.3 + 0.2
    proprio = torch.rand(1, _PROPRIO_DIM)

    outputs = policy(
        images=rgb,
        depths=depth,
        camera_intrinsics=intrinsics,
        camera_extrinsics=extrinsics,
        proprio=proprio,
        texts=[_INSTRUCTION],
    )

    assert tuple(outputs["scene_tokens"].shape) == (1, 2007, _HIDDEN_DIM)
    assert tuple(outputs["action_mean"].shape) == (1, 2, 14)