import sys
from pathlib import Path

import pytest
import torch

# Make the project's source package importable before the `models` / `train`
# imports below are resolved.
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT / "code" / "reveal_vla_bimanual"))

from models.action_decoder import ChunkDecoderConfig
from models.backbones import FrozenVLBackboneConfig
from models.multiview_fusion import MultiViewFusionConfig
from models.observation_memory import ObservationMemoryConfig
from models.planner import PlannerConfig
from models.policy import PolicyConfig
from models.reveal_head import RevealHeadConfig
from models.world_model import RevealWMConfig
from train.trainer import TrainerConfig


@pytest.fixture
def tiny_policy_config():
    """Provide a factory that assembles a small ``PolicyConfig`` for tests.

    The returned callable accepts ``hidden_dim``, ``chunk_size``,
    ``num_candidates``, ``top_k``, ``field_size`` and ``belief_map_size`` and
    forwards them to the sub-configs that take them. Every other setting is
    fixed: the dummy backbone is used and frozen, dropout is zero, attention
    modules use 4 heads, and stacks use a single layer.
    """

    def _factory(
        hidden_dim: int = 16,
        chunk_size: int = 2,
        num_candidates: int = 4,
        top_k: int = 2,
        field_size: int = 4,
        belief_map_size: int = 8,
    ) -> PolicyConfig:
        # Values shared by several sub-configs.
        heads = 4
        feed_forward = hidden_dim * 4
        bank_size = 2

        backbone_cfg = FrozenVLBackboneConfig(
            hidden_dim=hidden_dim,
            freeze_backbone=True,
            gradient_checkpointing=False,
            use_dummy_backbone=True,
            depth_patch_size=8,
        )
        fusion_cfg = MultiViewFusionConfig(
            hidden_dim=hidden_dim,
            num_layers=1,
            num_heads=heads,
            ff_dim=feed_forward,
            dropout=0.0,
        )
        memory_cfg = ObservationMemoryConfig(
            hidden_dim=hidden_dim,
            num_heads=heads,
            dropout=0.0,
            history_steps=2,
            scene_history_steps=2,
            belief_history_steps=3,
            max_history_steps=4,
            scene_bank_size=bank_size,
            belief_bank_size=bank_size,
        )
        decoder_cfg = ChunkDecoderConfig(
            hidden_dim=hidden_dim,
            num_heads=heads,
            num_layers=1,
            ff_dim=feed_forward,
            dropout=0.0,
            chunk_size=chunk_size,
            num_candidates=num_candidates,
            num_proposal_modes=7,
            planner_top_k=top_k,
        )
        reveal_cfg = RevealHeadConfig(
            hidden_dim=hidden_dim,
            num_heads=heads,
            field_size=field_size,
            belief_map_size=belief_map_size,
            predict_belief_map=True,
        )
        world_model_cfg = RevealWMConfig(
            hidden_dim=hidden_dim,
            num_heads=heads,
            field_size=field_size,
            belief_map_size=belief_map_size,
            scene_bank_size=bank_size,
            belief_bank_size=bank_size,
        )
        planner_cfg = PlannerConfig(
            hidden_dim=hidden_dim,
            num_heads=heads,
            num_layers=1,
            num_candidates=num_candidates,
            top_k=top_k,
        )

        return PolicyConfig(
            backbone=backbone_cfg,
            fusion=fusion_cfg,
            memory=memory_cfg,
            decoder=decoder_cfg,
            reveal_head=reveal_cfg,
            world_model=world_model_cfg,
            planner=planner_cfg,
        )

    return _factory


@pytest.fixture
def tiny_trainer_config():
    """Provide a factory that builds a ``TrainerConfig`` for tests.

    The returned callable takes ``policy_type`` (default
    ``"elastic_reveal"``). The config it builds has bf16 and gradient
    checkpointing disabled, the backbone frozen, and planning enabled for
    both training and evaluation.
    """

    def _factory(policy_type: str = "elastic_reveal") -> TrainerConfig:
        return TrainerConfig(
            policy_type=policy_type,
            use_bf16=False,
            gradient_checkpointing=False,
            freeze_backbone=True,
            plan_during_train=True,
            plan_during_eval=True,
        )

    return _factory


@pytest.fixture
def tiny_batch():
    """Provide a factory that builds a small random input batch.

    The returned callable accepts ``batch_size``, ``history_steps``,
    ``resolution`` and ``chunk_size`` and returns a dict of tensors plus a
    ``"texts"`` list of ``batch_size`` strings. Image-like tensors are filled
    with ``torch.rand``, validity masks are all ones, and camera matrices are
    identity.
    """

    def _factory(
        batch_size: int = 2,
        history_steps: int = 2,
        resolution: int = 16,
        chunk_size: int = 2,
    ) -> dict[str, torch.Tensor | list[str]]:
        # Fixed sizes used throughout the batch. The size-3 axis directly
        # after the batch axis is presumably the camera/view axis (it is the
        # axis the per-camera intrinsics/extrinsics are expanded over).
        num_cameras = 3
        proprio_dim = 32
        action_dim = 14
        frame = (resolution, resolution)

        # NOTE: the order of torch.rand calls is kept stable so that a seeded
        # RNG yields the same tensors for the same arguments.
        images = torch.rand(batch_size, num_cameras, 3, *frame)
        depths = torch.rand(batch_size, num_cameras, 1, *frame)

        # expand() only creates a broadcast view; clone() materialises an
        # independent tensor per batch entry / camera.
        intrinsics = (
            torch.eye(3).view(1, 1, 3, 3).expand(batch_size, num_cameras, 3, 3).clone()
        )
        extrinsics = (
            torch.eye(4).view(1, 1, 4, 4).expand(batch_size, num_cameras, 4, 4).clone()
        )

        return {
            "images": images,
            "depths": depths,
            "depth_valid": torch.ones_like(depths),
            "camera_intrinsics": intrinsics,
            "camera_extrinsics": extrinsics,
            "proprio": torch.rand(batch_size, proprio_dim),
            "texts": ["test task"] * batch_size,
            "history_images": torch.rand(
                batch_size, history_steps, num_cameras, 3, *frame
            ),
            "history_depths": torch.rand(
                batch_size, history_steps, num_cameras, 1, *frame
            ),
            "history_depth_valid": torch.ones(
                batch_size, history_steps, num_cameras, 1, *frame
            ),
            "history_proprio": torch.rand(batch_size, history_steps, proprio_dim),
            "history_actions": torch.rand(batch_size, history_steps, action_dim),
            "action_chunk": torch.rand(batch_size, chunk_size, action_dim),
        }

    return _factory


@pytest.fixture
def tiny_state():
    """Provide a factory that builds a dict of random state tensors.

    The returned callable accepts:

    * ``batch_size`` -- leading dimension of every tensor.
    * ``field_size`` -- side length of the square ``*_field`` tensors; also
      determines the ``field_size * field_size`` token count of
      ``"field_tokens"``.
    * ``hidden_dim`` -- last dimension of ``"interaction_tokens"``,
      ``"field_tokens"`` and ``"latent_summary"``. Defaults to 16, the value
      that used to be hard-coded.
    * ``belief_map_size`` -- side length of the square ``"belief_map"``.
      Defaults to 8, the value that used to be hard-coded.

    NOTE(review): the two defaults equal the ``hidden_dim`` /
    ``belief_map_size`` defaults of ``tiny_policy_config``; presumably tests
    using a non-default policy config should pass the same values here --
    confirm against the consumers of this state dict.
    """

    def _factory(
        batch_size: int = 2,
        field_size: int = 4,
        hidden_dim: int = 16,
        belief_map_size: int = 8,
    ) -> dict[str, torch.Tensor]:
        # (key, channel count) for every (batch, channels, field, field)
        # tensor. Order matters: it fixes both the dict key order and the
        # sequence of torch.rand calls.
        spatial_channels = (
            ("target_belief_field", 1),
            ("visibility_field", 1),
            ("clearance_field", 2),
            ("occluder_contact_field", 1),
            ("grasp_affordance_field", 1),
            ("support_stability_field", 1),
            ("persistence_field", 1),
            ("reocclusion_field", 1),
            ("disturbance_field", 1),
            ("risk_field", 1),
            ("uncertainty_field", 1),
            ("access_field", 3),
        )
        state = {
            key: torch.rand(batch_size, channels, field_size, field_size)
            for key, channels in spatial_channels
        }

        # Non-spatial entries, appended in their original order.
        state.update(
            {
                "support_mode_logits": torch.rand(batch_size, 3),
                "phase_logits": torch.rand(batch_size, 5),
                "arm_role_logits": torch.rand(batch_size, 2, 4),
                "interaction_tokens": torch.rand(batch_size, 8, hidden_dim),
                "field_tokens": torch.rand(
                    batch_size, field_size * field_size, hidden_dim
                ),
                "latent_summary": torch.rand(batch_size, hidden_dim),
                "corridor_logits": torch.rand(batch_size, 3, 32),
                "persistence_horizon": torch.rand(batch_size, 3),
                "disturbance_cost": torch.rand(batch_size),
                "reocclusion_logit": torch.rand(batch_size, 3),
                "belief_map": torch.rand(
                    batch_size, 1, belief_map_size, belief_map_size
                ),
                "compact_state": torch.rand(batch_size, 30),
            }
        )
        return state

    return _factory