File size: 1,406 Bytes
e7d8e79 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 | import torch
from models.world_model import ElasticOcclusionWorldModel
def test_world_model_task_adapter(tiny_policy_config, tiny_state):
    """Check that the world model's output depends on the task name.

    Builds an ``ElasticOcclusionWorldModel`` in ``"spatial_rollout"`` mode,
    runs it twice on identical inputs with ``task_names=["foliage"]`` and
    ``task_names=["bag"]``, and asserts that the resulting
    ``"support_mode_logits"`` differ.

    Args:
        tiny_policy_config: Fixture; called with no arguments to obtain the
            policy config object.
        tiny_state: Fixture; called with ``field_size=...`` to obtain the
            interaction-state mapping of tensors.
    """
    # Fix the global RNG so model initialisation and the random inputs below
    # are reproducible from run to run.
    torch.manual_seed(0)

    config = tiny_policy_config()
    config.world_model.rollout_mode = "spatial_rollout"
    model = ElasticOcclusionWorldModel(config.world_model)
    # Disable train-time stochasticity (e.g. dropout) so that any difference
    # between the two forward passes can only come from ``task_names``;
    # otherwise the inequality assertion could pass vacuously.
    # NOTE(review): assumes the model is a torch.nn.Module (it is invoked via
    # ``model(...)``) — confirm.
    model.eval()

    hidden_dim = config.backbone.hidden_dim
    state = tiny_state(field_size=config.reveal_head.field_size)
    state["scene_memory_tokens"] = torch.rand(
        1, config.memory.scene_bank_size, hidden_dim
    )
    state["belief_memory_tokens"] = torch.rand(
        1, config.memory.belief_bank_size, hidden_dim
    )
    # Keep only the first entry along dim 0 of every state tensor so all
    # inputs share a leading size of 1.
    state = {key: value[:1] for key, value in state.items()}

    # NOTE(review): 12 (token count) and 14 (presumably the action dimension)
    # are hard-coded here — confirm against the model's expectations.
    scene_tokens = torch.rand(1, 12, hidden_dim)
    action_chunk = torch.zeros(1, config.decoder.chunk_size, 14)

    def run_for_task(task_name):
        # Identical inputs on every call; only the task name varies.
        return model(
            scene_tokens=scene_tokens,
            interaction_state=state,
            action_chunk=action_chunk,
            scene_memory_tokens=state["scene_memory_tokens"],
            belief_memory_tokens=state["belief_memory_tokens"],
            task_names=[task_name],
        )

    foliage = run_for_task("foliage")
    bag = run_for_task("bag")

    assert not torch.allclose(
        foliage["support_mode_logits"], bag["support_mode_logits"]
    ), "support_mode_logits should differ between the 'foliage' and 'bag' tasks"
|