# VLAarchtests / tests/test_world_model_state_contract.py
# Uploaded by lsnu ("Add files using upload-large-folder tool"), commit 16405f2 (verified).
import torch
from models.world_model import ElasticOcclusionWorldModel
def test_world_model_state_contract(tiny_policy_config, tiny_state):
    """Check the output-shape contract of ``ElasticOcclusionWorldModel``.

    Builds the world model from a tiny policy config and feeds it random inputs:
    scene tokens, an action chunk, and scene/belief memory banks. It then verifies
    the leading dimensions of the predicted ``target_belief_field`` and the
    memory-bank axis of the returned scene/belief memory tokens.
    """
    cfg = tiny_policy_config()
    world_model = ElasticOcclusionWorldModel(cfg.world_model)
    interaction_state = tiny_state(field_size=cfg.reveal_head.field_size)

    batch = 2
    horizon = cfg.decoder.chunk_size
    hidden = cfg.backbone.hidden_dim

    # 14 is a hard-coded action width.
    # NOTE(review): confirm it matches the action dimension the model expects.
    action_chunk = torch.rand(batch, horizon, 14)
    scene_tokens = torch.rand(batch, 12, hidden)

    # The memory banks are stored on the state dict AND passed explicitly below,
    # so the model receives the same tensors through both routes.
    interaction_state["scene_memory_tokens"] = torch.rand(
        batch, cfg.memory.scene_bank_size, hidden
    )
    interaction_state["belief_memory_tokens"] = torch.rand(
        batch, cfg.memory.belief_bank_size, hidden
    )

    rollout = world_model(
        scene_tokens=scene_tokens,
        interaction_state=interaction_state,
        action_chunk=action_chunk,
        scene_memory_tokens=interaction_state["scene_memory_tokens"],
        belief_memory_tokens=interaction_state["belief_memory_tokens"],
    )

    # The first three axes are presumably (batch, chunk step, channel).
    # TODO: confirm against the model.
    assert rollout["target_belief_field"].shape[:3] == (batch, horizon, 1)
    # Axis 2 is compared with the bank size; the layout is presumably
    # (batch, step, bank, hidden). Verify against the model.
    assert rollout["scene_memory_tokens"].shape[2] == cfg.memory.scene_bank_size
    assert rollout["belief_memory_tokens"].shape[2] == cfg.memory.belief_bank_size