File size: 1,406 Bytes
e7d8e79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch

from models.world_model import ElasticOcclusionWorldModel


def test_world_model_task_adapter(tiny_policy_config, tiny_state):
    """Distinct task names must yield distinct support-mode logits.

    Runs the world model twice on byte-identical inputs, varying only
    ``task_names``, and checks the task adapter actually conditions the
    ``support_mode_logits`` output.
    """
    cfg = tiny_policy_config()
    cfg.world_model.rollout_mode = "spatial_rollout"
    world_model = ElasticOcclusionWorldModel(cfg.world_model)

    hidden = cfg.backbone.hidden_dim

    # Single-sample interaction state with both memory banks attached.
    # NOTE: torch.rand calls are kept in the original order (scene bank,
    # belief bank, scene tokens) so the RNG stream is unchanged.
    sample = tiny_state(field_size=cfg.reveal_head.field_size)
    sample["scene_memory_tokens"] = torch.rand(1, cfg.memory.scene_bank_size, hidden)
    sample["belief_memory_tokens"] = torch.rand(1, cfg.memory.belief_bank_size, hidden)
    sample = {name: tensor[:1] for name, tensor in sample.items()}

    tokens = torch.rand(1, 12, hidden)
    actions = torch.zeros(1, cfg.decoder.chunk_size, 14)

    def run(task):
        # Identical inputs on every call; only the task name varies.
        return world_model(
            scene_tokens=tokens,
            interaction_state=sample,
            action_chunk=actions,
            scene_memory_tokens=sample["scene_memory_tokens"],
            belief_memory_tokens=sample["belief_memory_tokens"],
            task_names=[task],
        )

    foliage_out = run("foliage")
    bag_out = run("bag")
    assert not torch.allclose(foliage_out["support_mode_logits"], bag_out["support_mode_logits"])