| import torch | |
| from models.world_model import ElasticOcclusionWorldModel | |
def test_world_model_identity_rollout(tiny_policy_config, tiny_state):
    """Check that ``identity_rollout`` mode repeats the input state per step.

    The world model is built with ``rollout_mode = "identity_rollout"`` and
    called on random scene tokens and a random action chunk. The returned
    ``arm_role_logits`` and ``visibility_field`` must equal the corresponding
    input-state tensors with a new dimension inserted at index 1 and expanded
    to ``config.decoder.chunk_size``.

    Args:
        tiny_policy_config: Fixture; called with no arguments to build the
            policy config used here.
        tiny_state: Fixture; called with ``field_size=...`` to build the
            interaction-state mapping passed to the model.
    """
    config = tiny_policy_config()
    config.world_model.rollout_mode = "identity_rollout"
    model = ElasticOcclusionWorldModel(config.world_model)
    state = tiny_state(field_size=config.reveal_head.field_size)
    chunk_size = config.decoder.chunk_size

    # NOTE(review): the leading 2 is presumably the batch size produced by
    # ``tiny_state``, and 14 the per-step action width -- confirm against
    # the fixtures.
    scene_tokens = torch.rand(2, 12, config.backbone.hidden_dim)
    action_chunk = torch.rand(2, chunk_size, 14)

    output = model(
        scene_tokens=scene_tokens,
        interaction_state=state,
        action_chunk=action_chunk,
    )

    # torch.allclose broadcasts its operands, so on its own it would accept
    # an output whose rollout dimension is missing or has size 1. Pin the
    # exact shape first, then compare values.
    expected_roles = state["arm_role_logits"].unsqueeze(1).expand(
        -1, chunk_size, -1, -1
    )
    actual_roles = output["arm_role_logits"]
    assert actual_roles.shape == expected_roles.shape, (
        f"arm_role_logits shape {tuple(actual_roles.shape)} != "
        f"expected {tuple(expected_roles.shape)}"
    )
    assert torch.allclose(actual_roles, expected_roles)

    expected_visibility = state["visibility_field"].unsqueeze(1).expand(
        -1, chunk_size, -1, -1, -1
    )
    actual_visibility = output["visibility_field"]
    assert actual_visibility.shape == expected_visibility.shape, (
        f"visibility_field shape {tuple(actual_visibility.shape)} != "
        f"expected {tuple(expected_visibility.shape)}"
    )
    assert torch.allclose(actual_visibility, expected_visibility)