# VLAarch — tests/conftest.py
# Origin metadata (from hub file listing): uploaded by lsnu,
# 2026-03-25 "runpod handoff update", commit e7d8e79 (verified).
import sys
from pathlib import Path
import pytest
import torch
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT / "code" / "reveal_vla_bimanual"))
from models.action_decoder import ChunkDecoderConfig
from models.backbones import FrozenVLBackboneConfig
from models.multiview_fusion import MultiViewFusionConfig
from models.observation_memory import ObservationMemoryConfig
from models.planner import PlannerConfig
from models.policy import PolicyConfig
from models.reveal_head import RevealHeadConfig
from models.world_model import RevealWMConfig
from train.trainer import TrainerConfig
@pytest.fixture
def tiny_policy_config():
    """Factory fixture producing a minimal `PolicyConfig`.

    Every sub-module is shrunk (tiny hidden dim, single layers, dummy
    backbone) so that policy construction and forward passes stay fast
    inside unit tests.  Call the returned factory to get a fresh config;
    the keyword arguments override the most commonly varied sizes.
    """

    def _factory(
        hidden_dim: int = 16,
        chunk_size: int = 2,
        num_candidates: int = 4,
        top_k: int = 2,
        field_size: int = 4,
        belief_map_size: int = 8,
    ) -> PolicyConfig:
        # Shared feed-forward width for the transformer-style sub-configs.
        ff_dim = hidden_dim * 4

        # Dummy frozen backbone: avoids loading real VL weights in tests.
        backbone = FrozenVLBackboneConfig(
            hidden_dim=hidden_dim,
            freeze_backbone=True,
            gradient_checkpointing=False,
            use_dummy_backbone=True,
            depth_patch_size=8,
        )
        fusion = MultiViewFusionConfig(
            hidden_dim=hidden_dim,
            num_layers=1,
            num_heads=4,
            ff_dim=ff_dim,
            dropout=0.0,
        )
        memory = ObservationMemoryConfig(
            hidden_dim=hidden_dim,
            num_heads=4,
            dropout=0.0,
            history_steps=2,
            scene_history_steps=2,
            belief_history_steps=3,
            max_history_steps=4,
            scene_bank_size=2,
            belief_bank_size=2,
        )
        decoder = ChunkDecoderConfig(
            hidden_dim=hidden_dim,
            num_heads=4,
            num_layers=1,
            ff_dim=ff_dim,
            dropout=0.0,
            chunk_size=chunk_size,
            num_candidates=num_candidates,
            num_proposal_modes=7,
            planner_top_k=top_k,
        )
        reveal_head = RevealHeadConfig(
            hidden_dim=hidden_dim,
            num_heads=4,
            field_size=field_size,
            belief_map_size=belief_map_size,
            predict_belief_map=True,
        )
        world_model = RevealWMConfig(
            hidden_dim=hidden_dim,
            num_heads=4,
            field_size=field_size,
            belief_map_size=belief_map_size,
            scene_bank_size=2,
            belief_bank_size=2,
        )
        planner = PlannerConfig(
            hidden_dim=hidden_dim,
            num_heads=4,
            num_layers=1,
            num_candidates=num_candidates,
            top_k=top_k,
        )
        return PolicyConfig(
            backbone=backbone,
            fusion=fusion,
            memory=memory,
            decoder=decoder,
            reveal_head=reveal_head,
            world_model=world_model,
            planner=planner,
        )

    return _factory
@pytest.fixture
def tiny_trainer_config():
    """Factory fixture producing a lightweight `TrainerConfig` for tests."""

    def _factory(policy_type: str = "elastic_reveal") -> TrainerConfig:
        # Keep training hardware-agnostic for CI: no bf16, no gradient
        # checkpointing, frozen backbone; planning is enabled in both
        # train and eval so those code paths get exercised.
        overrides = dict(
            use_bf16=False,
            gradient_checkpointing=False,
            freeze_backbone=True,
            plan_during_train=True,
            plan_during_eval=True,
        )
        return TrainerConfig(policy_type=policy_type, **overrides)

    return _factory
@pytest.fixture
def tiny_batch():
    """Factory fixture producing a random observation/action batch.

    The batch follows a 3-camera layout with RGB + depth per view and
    identity camera matrices.  Actions are 14-dimensional.
    # NOTE(review): action dim taken from the literal 14 in the original
    # fixture — presumably bimanual 2x7; confirm against the dataset spec.
    """

    def _factory(
        batch_size: int = 2,
        history_steps: int = 2,
        resolution: int = 16,
        chunk_size: int = 2,
    ) -> dict[str, torch.Tensor | list[str]]:
        num_views = 3
        rgb = torch.rand(batch_size, num_views, 3, resolution, resolution)
        depth = torch.rand(batch_size, num_views, 1, resolution, resolution)
        # One identity intrinsics (3x3) / extrinsics (4x4) per batch item
        # and per view.
        intrinsics = torch.eye(3).repeat(batch_size, num_views, 1, 1)
        extrinsics = torch.eye(4).repeat(batch_size, num_views, 1, 1)
        return {
            "images": rgb,
            "depths": depth,
            "depth_valid": torch.ones_like(depth),
            "camera_intrinsics": intrinsics,
            "camera_extrinsics": extrinsics,
            "proprio": torch.rand(batch_size, 32),
            "texts": ["test task"] * batch_size,
            "history_images": torch.rand(
                batch_size, history_steps, num_views, 3, resolution, resolution
            ),
            "history_depths": torch.rand(
                batch_size, history_steps, num_views, 1, resolution, resolution
            ),
            "history_depth_valid": torch.ones(
                batch_size, history_steps, num_views, 1, resolution, resolution
            ),
            "history_proprio": torch.rand(batch_size, history_steps, 32),
            "history_actions": torch.rand(batch_size, history_steps, 14),
            "action_chunk": torch.rand(batch_size, chunk_size, 14),
        }

    return _factory
@pytest.fixture
def tiny_state():
    """Factory fixture producing a dict of random planner-state tensors."""

    def _factory(batch_size: int = 2, field_size: int = 4) -> dict[str, torch.Tensor]:
        # The spatial fields all share the (B, C, field_size, field_size)
        # layout; only the channel count differs, so build them from a table.
        field_channels = {
            "target_belief_field": 1,
            "visibility_field": 1,
            "clearance_field": 2,
            "occluder_contact_field": 1,
            "grasp_affordance_field": 1,
            "support_stability_field": 1,
            "persistence_field": 1,
            "reocclusion_field": 1,
            "disturbance_field": 1,
            "risk_field": 1,
            "uncertainty_field": 1,
            "access_field": 3,
        }
        state: dict[str, torch.Tensor] = {
            name: torch.rand(batch_size, channels, field_size, field_size)
            for name, channels in field_channels.items()
        }
        # Non-spatial heads: logits, tokens, summaries and scalars.
        state.update(
            support_mode_logits=torch.rand(batch_size, 3),
            phase_logits=torch.rand(batch_size, 5),
            arm_role_logits=torch.rand(batch_size, 2, 4),
            interaction_tokens=torch.rand(batch_size, 8, 16),
            field_tokens=torch.rand(batch_size, field_size * field_size, 16),
            latent_summary=torch.rand(batch_size, 16),
            corridor_logits=torch.rand(batch_size, 3, 32),
            persistence_horizon=torch.rand(batch_size, 3),
            disturbance_cost=torch.rand(batch_size),
            reocclusion_logit=torch.rand(batch_size, 3),
            belief_map=torch.rand(batch_size, 1, 8, 8),
            compact_state=torch.rand(batch_size, 30),
        )
        return state

    return _factory