import torch

from models.world_model import ElasticOcclusionWorldModel


def test_world_model_null_rollout(tiny_policy_config, tiny_state):
    """Check that ``null_rollout`` mode repeats the current state over the action chunk.

    With ``config.world_model.rollout_mode = "null_rollout"`` the model's
    predicted ``target_belief_field`` and ``support_mode_logits`` must equal
    the input interaction state, tiled along a new time axis (dim 1) of
    length ``config.decoder.chunk_size``.
    """
    config = tiny_policy_config()
    config.world_model.rollout_mode = "null_rollout"
    model = ElasticOcclusionWorldModel(config.world_model)
    state = tiny_state(field_size=config.reveal_head.field_size)

    chunk_size = config.decoder.chunk_size

    # The outputs are asserted to depend only on `state`, so these inputs only
    # need the right shapes. A local seeded generator keeps the test
    # reproducible without touching the global RNG.
    # NOTE(review): batch size 2 presumably matches tiny_state's batch — confirm.
    generator = torch.Generator().manual_seed(0)
    scene_tokens = torch.rand(2, 12, config.backbone.hidden_dim, generator=generator)
    # NOTE(review): 14 is presumably the per-step action dimension — confirm.
    action_chunk = torch.rand(2, chunk_size, 14, generator=generator)

    output = model(
        scene_tokens=scene_tokens,
        interaction_state=state,
        action_chunk=action_chunk,
    )

    # Expected rollout: the current state tiled along a new time axis (dim 1).
    expected = state["target_belief_field"].unsqueeze(1).expand(-1, chunk_size, -1, -1, -1)
    # torch.allclose broadcasts its arguments, so check shapes explicitly.
    # Otherwise an output lacking the time expansion (e.g. size 1 on dim 1)
    # would still compare as "close".
    assert output["target_belief_field"].shape == expected.shape
    assert torch.allclose(output["target_belief_field"], expected)

    expected_support = state["support_mode_logits"].unsqueeze(1).expand(-1, chunk_size, -1)
    assert output["support_mode_logits"].shape == expected_support.shape
    assert torch.allclose(output["support_mode_logits"], expected_support)