| import sys |
| from pathlib import Path |
|
|
| import pytest |
| import torch |
|
|
# REPO_ROOT is the parent of the directory that contains this file.
# The package directory under it is pushed to the front of sys.path so that
# the `models.*` and `train.*` imports that follow can be found there; this
# must therefore run before those imports.
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT / "code" / "reveal_vla_bimanual"))
|
|
| from models.action_decoder import ChunkDecoderConfig |
| from models.backbones import FrozenVLBackboneConfig |
| from models.multiview_fusion import MultiViewFusionConfig |
| from models.observation_memory import ObservationMemoryConfig |
| from models.planner import PlannerConfig |
| from models.policy import PolicyConfig |
| from models.reveal_head import RevealHeadConfig |
| from models.world_model import RevealWMConfig |
| from train.trainer import TrainerConfig |
|
|
|
|
@pytest.fixture
def tiny_policy_config():
    """Provide a factory that assembles a miniature ``PolicyConfig``.

    The fixture value is a callable. Each call builds fresh sub-configs that
    all share the requested ``hidden_dim`` and use the dummy, frozen backbone
    with dropout disabled.
    """
    attention_heads = 4
    bank_size = 2

    def _factory(
        hidden_dim: int = 16,
        chunk_size: int = 2,
        num_candidates: int = 4,
        top_k: int = 2,
        field_size: int = 4,
        belief_map_size: int = 8,
    ) -> PolicyConfig:
        """Build the policy config.

        Args:
            hidden_dim: Width passed to every sub-config.
            chunk_size: Forwarded to the chunk decoder.
            num_candidates: Forwarded to both the decoder and the planner.
            top_k: Used as the decoder's ``planner_top_k`` and the planner's
                ``top_k``.
            field_size: Forwarded to the reveal head and the world model.
            belief_map_size: Forwarded to the reveal head and the world model.

        Returns:
            A ``PolicyConfig`` wired from the sub-configs below.
        """
        feedforward_dim = hidden_dim * 4

        backbone_cfg = FrozenVLBackboneConfig(
            hidden_dim=hidden_dim,
            freeze_backbone=True,
            gradient_checkpointing=False,
            use_dummy_backbone=True,
            depth_patch_size=8,
        )
        fusion_cfg = MultiViewFusionConfig(
            hidden_dim=hidden_dim,
            num_layers=1,
            num_heads=attention_heads,
            ff_dim=feedforward_dim,
            dropout=0.0,
        )
        memory_cfg = ObservationMemoryConfig(
            hidden_dim=hidden_dim,
            num_heads=attention_heads,
            dropout=0.0,
            history_steps=2,
            scene_history_steps=2,
            belief_history_steps=3,
            max_history_steps=4,
            scene_bank_size=bank_size,
            belief_bank_size=bank_size,
        )
        decoder_cfg = ChunkDecoderConfig(
            hidden_dim=hidden_dim,
            num_heads=attention_heads,
            num_layers=1,
            ff_dim=feedforward_dim,
            dropout=0.0,
            chunk_size=chunk_size,
            num_candidates=num_candidates,
            num_proposal_modes=7,
            planner_top_k=top_k,
        )
        reveal_cfg = RevealHeadConfig(
            hidden_dim=hidden_dim,
            num_heads=attention_heads,
            field_size=field_size,
            belief_map_size=belief_map_size,
            predict_belief_map=True,
        )
        world_model_cfg = RevealWMConfig(
            hidden_dim=hidden_dim,
            num_heads=attention_heads,
            field_size=field_size,
            belief_map_size=belief_map_size,
            scene_bank_size=bank_size,
            belief_bank_size=bank_size,
        )
        planner_cfg = PlannerConfig(
            hidden_dim=hidden_dim,
            num_heads=attention_heads,
            num_layers=1,
            num_candidates=num_candidates,
            top_k=top_k,
        )

        return PolicyConfig(
            backbone=backbone_cfg,
            fusion=fusion_cfg,
            memory=memory_cfg,
            decoder=decoder_cfg,
            reveal_head=reveal_cfg,
            world_model=world_model_cfg,
            planner=planner_cfg,
        )

    return _factory
|
|
|
|
@pytest.fixture
def tiny_trainer_config():
    """Provide a factory for a ``TrainerConfig`` with fixed test settings.

    The fixture value is a callable taking only ``policy_type``; all other
    options are pinned: bf16 and gradient checkpointing off, backbone frozen,
    planning enabled for both train and eval.
    """

    def _factory(policy_type: str = "elastic_reveal") -> TrainerConfig:
        """Return a ``TrainerConfig`` for the given ``policy_type``."""
        options = dict(
            policy_type=policy_type,
            use_bf16=False,
            gradient_checkpointing=False,
            freeze_backbone=True,
            plan_during_train=True,
            plan_during_eval=True,
        )
        return TrainerConfig(**options)

    return _factory
|
|
|
|
| @pytest.fixture |
| def tiny_batch(): |
| def _factory( |
| batch_size: int = 2, |
| history_steps: int = 2, |
| resolution: int = 16, |
| chunk_size: int = 2, |
| ) -> dict[str, torch.Tensor | list[str]]: |
| images = torch.rand(batch_size, 3, 3, resolution, resolution) |
| depths = torch.rand(batch_size, 3, 1, resolution, resolution) |
| batch = { |
| "images": images, |
| "depths": depths, |
| "depth_valid": torch.ones_like(depths), |
| "camera_intrinsics": torch.eye(3).view(1, 1, 3, 3).expand(batch_size, 3, 3, 3).clone(), |
| "camera_extrinsics": torch.eye(4).view(1, 1, 4, 4).expand(batch_size, 3, 4, 4).clone(), |
| "proprio": torch.rand(batch_size, 32), |
| "texts": ["test task"] * batch_size, |
| "history_images": torch.rand(batch_size, history_steps, 3, 3, resolution, resolution), |
| "history_depths": torch.rand(batch_size, history_steps, 3, 1, resolution, resolution), |
| "history_depth_valid": torch.ones(batch_size, history_steps, 3, 1, resolution, resolution), |
| "history_proprio": torch.rand(batch_size, history_steps, 32), |
| "history_actions": torch.rand(batch_size, history_steps, 14), |
| "action_chunk": torch.rand(batch_size, chunk_size, 14), |
| } |
| return batch |
|
|
| return _factory |
|
|
|
|
@pytest.fixture
def tiny_state():
    """Provide a factory that fabricates a random state dict for tests.

    The fixture value is a callable returning a dict of random tensors keyed
    by field/logit name. The token/latent width and the belief-map side are
    parameters (previously fixed at 16 and 8) so the state can be sized to
    match a policy built with non-default dimensions.
    """

    def _factory(
        batch_size: int = 2,
        field_size: int = 4,
        hidden_dim: int = 16,
        belief_map_size: int = 8,
    ) -> dict[str, torch.Tensor]:
        """Build one random state.

        Args:
            batch_size: Leading dimension of every tensor.
            field_size: Side length of every ``*_field`` tensor; also sets the
                token count of ``"field_tokens"`` to ``field_size ** 2``.
            hidden_dim: Last dimension of ``"interaction_tokens"``,
                ``"field_tokens"`` and ``"latent_summary"``.
                NOTE(review): presumably should equal the policy's
                ``hidden_dim`` -- confirm against the consumers.
            belief_map_size: Side length of ``"belief_map"``.

        Returns:
            Dict with 24 entries; defaults give the same shapes as before.
        """

        def field(channels: int = 1) -> torch.Tensor:
            # Spatial field: (batch, channels, field_size, field_size).
            return torch.rand(batch_size, channels, field_size, field_size)

        # Entries are created in a fixed order so seeded draws stay stable.
        return {
            "target_belief_field": field(),
            "visibility_field": field(),
            "clearance_field": field(2),
            "occluder_contact_field": field(),
            "grasp_affordance_field": field(),
            "support_stability_field": field(),
            "persistence_field": field(),
            "reocclusion_field": field(),
            "disturbance_field": field(),
            "risk_field": field(),
            "uncertainty_field": field(),
            "access_field": field(3),
            "support_mode_logits": torch.rand(batch_size, 3),
            "phase_logits": torch.rand(batch_size, 5),
            "arm_role_logits": torch.rand(batch_size, 2, 4),
            "interaction_tokens": torch.rand(batch_size, 8, hidden_dim),
            "field_tokens": torch.rand(batch_size, field_size * field_size, hidden_dim),
            "latent_summary": torch.rand(batch_size, hidden_dim),
            "corridor_logits": torch.rand(batch_size, 3, 32),
            "persistence_horizon": torch.rand(batch_size, 3),
            "disturbance_cost": torch.rand(batch_size),
            "reocclusion_logit": torch.rand(batch_size, 3),
            "belief_map": torch.rand(batch_size, 1, belief_map_size, belief_map_size),
            "compact_state": torch.rand(batch_size, 30),
        }

    return _factory
|
|