File size: 897 Bytes
import torch
from models.observation_memory import DualObservationMemory
def test_dual_memory_contract(tiny_policy_config):
    """Check the output contract of ``DualObservationMemory``.

    Builds the memory from the ``memory`` section of the config produced
    by the ``tiny_policy_config`` fixture, feeds it random current scene
    tokens plus a random history of scene tokens and actions, and then
    verifies that:

    * dimension 1 of each returned token bank matches the configured
      bank size (and the combined bank matches the sum of both), and
    * every ``memory_write_rate`` value lies in the closed range [0, 1].
    """
    cfg = tiny_policy_config()
    module = DualObservationMemory(cfg.memory)

    hidden_dim = cfg.backbone.hidden_dim
    # Keep the sampling order (current tokens, history tokens, history
    # actions) so the random stream is consumed in a fixed sequence.
    current_tokens = torch.rand(2, 12, hidden_dim)
    past_tokens = torch.rand(2, 3, 12, hidden_dim)
    past_actions = torch.rand(2, 3, 14)

    result = module(
        current_tokens,
        history_scene_tokens=past_tokens,
        history_actions=past_actions,
    )

    # Expected size of dimension 1 for each token bank in the output.
    scene_slots = cfg.memory.scene_bank_size
    belief_slots = cfg.memory.belief_bank_size
    expected_bank_sizes = (
        ("scene_memory_tokens", scene_slots),
        ("belief_memory_tokens", belief_slots),
        ("memory_tokens", scene_slots + belief_slots),
    )
    for key, expected_size in expected_bank_sizes:
        assert result[key].shape[1] == expected_size

    # The write rate must stay within [0, 1] element-wise.
    write_rate = result["memory_write_rate"]
    assert torch.all(write_rate >= 0.0)
    assert torch.all(write_rate <= 1.0)
|