| import torch | |
| from models.observation_memory import DualObservationMemory | |
def test_dual_memory_contract(tiny_policy_config):
    """Check the output contract of ``DualObservationMemory``.

    Builds the memory module from the config returned by the
    ``tiny_policy_config`` factory fixture. It then feeds the module random
    current scene tokens plus a short history of scene tokens and actions,
    and verifies that:

    * ``scene_memory_tokens`` has ``scene_bank_size`` entries along dim 1,
    * ``belief_memory_tokens`` has ``belief_bank_size`` entries along dim 1,
    * ``memory_tokens`` has the sum of both bank sizes along dim 1,
    * every value of ``memory_write_rate`` lies in ``[0, 1]``.
    """
    config = tiny_policy_config()
    memory = DualObservationMemory(config.memory)

    hidden_dim = config.backbone.hidden_dim
    batch_size = 2
    num_scene_tokens = 12
    history_len = 3
    # NOTE(review): the action width is hard-coded; confirm it matches what
    # DualObservationMemory expects (ideally read it from the config).
    action_dim = 14

    # Seeded local generator: the inputs are reproducible across runs without
    # mutating the global RNG state that other tests may rely on.
    gen = torch.Generator().manual_seed(0)
    scene_tokens = torch.rand(
        batch_size, num_scene_tokens, hidden_dim, generator=gen
    )
    history_scene_tokens = torch.rand(
        batch_size, history_len, num_scene_tokens, hidden_dim, generator=gen
    )
    history_actions = torch.rand(
        batch_size, history_len, action_dim, generator=gen
    )

    output = memory(
        scene_tokens,
        history_scene_tokens=history_scene_tokens,
        history_actions=history_actions,
    )

    scene_bank = config.memory.scene_bank_size
    belief_bank = config.memory.belief_bank_size
    assert output["scene_memory_tokens"].shape[1] == scene_bank
    assert output["belief_memory_tokens"].shape[1] == belief_bank
    assert output["memory_tokens"].shape[1] == scene_bank + belief_bank

    write_rate = output["memory_write_rate"]
    # NaN compares False against both bounds, so non-finite rates also fail.
    assert torch.all(write_rate >= 0.0)
    assert torch.all(write_rate <= 1.0)