File size: 897 Bytes
16405f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch

from models.observation_memory import DualObservationMemory


def test_dual_memory_contract(tiny_policy_config):
    """Check the output contract of ``DualObservationMemory``.

    Builds the memory module from ``config.memory``, feeds it random scene
    tokens plus random history scene tokens and history actions, and checks:

    * ``scene_memory_tokens`` has ``scene_bank_size`` entries along dim 1,
    * ``belief_memory_tokens`` has ``belief_bank_size`` entries along dim 1,
    * ``memory_tokens`` has the sum of both bank sizes along dim 1,
    * every element of ``memory_write_rate`` lies in the closed range [0, 1].
    """
    config = tiny_policy_config()
    memory = DualObservationMemory(config.memory)

    # Arbitrary small sizes used to build the random inputs.
    batch, history_len, num_tokens, action_dim = 2, 3, 12, 14
    hidden = config.backbone.hidden_dim

    # Inputs are drawn in a fixed order (scene, history scene, history actions).
    scene_tokens = torch.rand(batch, num_tokens, hidden)
    history_scene_tokens = torch.rand(batch, history_len, num_tokens, hidden)
    history_actions = torch.rand(batch, history_len, action_dim)

    output = memory(
        scene_tokens,
        history_scene_tokens=history_scene_tokens,
        history_actions=history_actions,
    )

    scene_bank = config.memory.scene_bank_size
    belief_bank = config.memory.belief_bank_size
    assert output["scene_memory_tokens"].shape[1] == scene_bank
    assert output["belief_memory_tokens"].shape[1] == belief_bank
    assert output["memory_tokens"].shape[1] == scene_bank + belief_bank

    # Lower and upper bounds are asserted separately, as in a [0, 1] range check.
    write_rate = output["memory_write_rate"]
    assert torch.all(write_rate >= 0.0)
    assert torch.all(write_rate <= 1.0)