File size: 983 Bytes
31ade1f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 | import torch
from models.world_model import LightweightRevealStateTransitionModel
def test_lightweight_transition_contract(tiny_policy_config, tiny_state):
    """Check the output shapes of ``LightweightRevealStateTransitionModel``.

    Builds a small config (4 candidates, chunk size 2) and runs the transition
    model once on a batch of 2 states. It then checks the shapes of the
    ``visibility_summary``, ``access_field`` and ``clearance_field`` entries of
    the returned rollout. Only shapes are asserted, not values.
    """
    batch_size = 2
    num_candidates = 4

    cfg = tiny_policy_config(num_candidates=num_candidates, chunk_size=2)
    transition_model = LightweightRevealStateTransitionModel(cfg.world_model)
    interaction_state = tiny_state(
        batch_size=batch_size,
        field_size=cfg.reveal_head.field_size,
    )

    horizon = cfg.decoder.chunk_size
    # Random actions suffice: the assertions below only look at shapes.
    actions = torch.rand(batch_size, num_candidates, horizon, cfg.decoder.action_dim)
    # Mode ids 0..num_candidates-1, one per candidate slot, repeated per batch row.
    mode_ids = torch.arange(num_candidates, dtype=torch.long).expand(batch_size, -1)

    outputs = transition_model(
        interaction_state=interaction_state,
        action_chunk=actions,
        proposal_mode_ids=mode_ids,
    )

    # Every checked output starts with (batch, candidate, chunk step).
    leading = (batch_size, num_candidates, horizon)
    assert outputs["visibility_summary"].shape == leading
    assert outputs["access_field"].shape[:4] == (
        *leading,
        cfg.world_model.num_support_modes,
    )
    assert outputs["clearance_field"].shape == (*leading, 2, 1, 1)
|