| import torch | |
| from models.world_model import ElasticOcclusionWorldModel | |
def test_world_model_task_adapter(tiny_policy_config, tiny_state):
    """Check that the world model's task adapter conditions outputs on the task name.

    Builds an ``ElasticOcclusionWorldModel`` in ``"spatial_rollout"`` mode and
    feeds it identical inputs with two different ``task_names`` ("foliage" and
    "bag"). The resulting ``support_mode_logits`` must differ.

    A control forward pass repeats the "foliage" task and must reproduce the
    first output. Without it, a difference between tasks could come from
    run-to-run randomness (e.g. stochastic layers in train mode) rather than
    from the task conditioning itself.
    """
    # Make the random inputs (and any random weight init) reproducible.
    torch.manual_seed(0)

    config = tiny_policy_config()
    config.world_model.rollout_mode = "spatial_rollout"
    model = ElasticOcclusionWorldModel(config.world_model)
    # Disable stochastic layers so repeated forward passes are comparable.
    # NOTE(review): assumes the model is a torch.nn.Module — confirm.
    model.eval()

    hidden_dim = config.backbone.hidden_dim
    state = tiny_state(field_size=config.reveal_head.field_size)
    state["scene_memory_tokens"] = torch.rand(1, config.memory.scene_bank_size, hidden_dim)
    state["belief_memory_tokens"] = torch.rand(1, config.memory.belief_bank_size, hidden_dim)
    # Trim every state entry to a leading dimension of 1 (presumably the batch axis).
    state = {key: value[:1] for key, value in state.items()}

    scene_tokens = torch.rand(1, 12, hidden_dim)
    # NOTE(review): 14 is assumed to be the action dimension — confirm against config.
    action_chunk = torch.zeros(1, config.decoder.chunk_size, 14)

    def run(task_name):
        """Forward the shared inputs through the model for a single task name."""
        with torch.no_grad():
            return model(
                scene_tokens=scene_tokens,
                interaction_state=state,
                action_chunk=action_chunk,
                scene_memory_tokens=state["scene_memory_tokens"],
                belief_memory_tokens=state["belief_memory_tokens"],
                task_names=[task_name],
            )

    foliage = run("foliage")
    foliage_again = run("foliage")
    bag = run("bag")

    # Control: the same task must reproduce the same logits; otherwise the
    # task comparison below would not isolate the effect of task conditioning.
    assert torch.allclose(foliage["support_mode_logits"], foliage_again["support_mode_logits"])
    assert not torch.allclose(foliage["support_mode_logits"], bag["support_mode_logits"])