File size: 951 Bytes
e7d8e79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch

from models.world_model import ElasticOcclusionWorldModel


def test_world_model_identity_rollout(tiny_policy_config, tiny_state):
    """Check that ``identity_rollout`` mode echoes the input state at every step.

    The world model is built with ``rollout_mode = "identity_rollout"`` and
    called once on random scene tokens and a random action chunk.  For the
    ``arm_role_logits`` and ``visibility_field`` entries, the output must
    match the input state tensor repeated ``config.decoder.chunk_size``
    times along a new dimension inserted at index 1.
    """
    cfg = tiny_policy_config()
    cfg.world_model.rollout_mode = "identity_rollout"
    world_model = ElasticOcclusionWorldModel(cfg.world_model)
    initial_state = tiny_state(field_size=cfg.reveal_head.field_size)
    horizon = cfg.decoder.chunk_size

    # Batch of 2, 12 scene tokens and 14 action dims are fixed by this test;
    # presumably they agree with what the fixtures produce -- TODO confirm.
    rollout = world_model(
        scene_tokens=torch.rand(2, 12, cfg.backbone.hidden_dim),
        interaction_state=initial_state,
        action_chunk=torch.rand(2, horizon, 14),
    )

    # (state key, number of dims kept as-is after the inserted step axis)
    checked_fields = (("arm_role_logits", 2), ("visibility_field", 3))
    for key, trailing_dims in checked_fields:
        # expand() only creates a view, so the reference is the untouched
        # input state repeated once per step of the action chunk.
        tiled = initial_state[key].unsqueeze(1).expand(-1, horizon, *([-1] * trailing_dims))
        assert torch.allclose(rollout[key], tiled)