File size: 983 Bytes
31ade1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch

from models.world_model import LightweightRevealStateTransitionModel


def test_lightweight_transition_contract(tiny_policy_config, tiny_state):
    """Check the output-shape contract of ``LightweightRevealStateTransitionModel``.

    Builds a small policy config, feeds the world model a random action chunk
    for every candidate, and asserts the shapes of the returned
    ``visibility_summary``, ``access_field`` and ``clearance_field`` entries.
    """
    batch_size = 2
    num_candidates = 4

    # Construction order is kept as-is: model init and torch.rand may both
    # consume the global RNG, so reordering could change the sampled actions.
    config = tiny_policy_config(num_candidates=num_candidates, chunk_size=2)
    model = LightweightRevealStateTransitionModel(config.world_model)
    state = tiny_state(batch_size=batch_size, field_size=config.reveal_head.field_size)

    horizon = config.decoder.chunk_size
    # Tensor of shape (batch_size, num_candidates, chunk_size, action_dim).
    action_chunk = torch.rand(batch_size, num_candidates, horizon, config.decoder.action_dim)
    # Mode ids 0..num_candidates-1, identical for every batch row.
    proposal_mode_ids = torch.arange(num_candidates, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)

    rollout = model(
        interaction_state=state,
        action_chunk=action_chunk,
        proposal_mode_ids=proposal_mode_ids,
    )

    leading_dims = (batch_size, num_candidates, horizon)
    assert rollout["visibility_summary"].shape == leading_dims
    assert rollout["access_field"].shape[:4] == (*leading_dims, config.world_model.num_support_modes)
    # NOTE(review): the trailing (2, 1, 1) is the original hard-coded expectation;
    # what these dims mean is not visible here — confirm against the model.
    assert rollout["clearance_field"].shape == (*leading_dims, 2, 1, 1)