# VLAarchtests/tests/test_reocclusion_memory_regression.py
# lsnu's picture
# 2026-03-25 runpod handoff update
# e7d8e79 verified
import torch
from models.observation_memory import DualObservationMemory
def test_reocclusion_memory_regression(tiny_policy_config):
    """Check that belief memory tokens depend on the supplied scene history.

    The memory module is run twice on the same all-zero scene: once with a
    history whose first three token slots are set to one, and once with an
    all-zero history of the same shape. The test passes when the belief
    tokens from the first run are non-zero and differ from the second run's
    belief tokens by a norm greater than 1e-3.

    Args:
        tiny_policy_config: Fixture called with ``hidden_dim=16``; the result
            must expose ``memory`` and ``backbone.hidden_dim``.
    """
    belief_key = "belief_memory_tokens"

    cfg = tiny_policy_config(hidden_dim=16)
    cfg.memory.scene_bank_size = 4
    cfg.memory.belief_bank_size = 4
    module = DualObservationMemory(cfg.memory)

    # (1, 12, hidden_dim) scene with only the first three token slots active.
    # NOTE(review): presumably this models the un-occluded view -- confirm.
    visible_scene = torch.zeros(1, 12, cfg.backbone.hidden_dim)
    visible_scene[:, :3] = 1.0
    # Current observation: same shape, every token zero.
    occluded_scene = torch.zeros_like(visible_scene)

    # Two identical past frames stacked into a (1, 2, 12, hidden_dim) history.
    past_frame = visible_scene[0]
    past_scenes = torch.stack([past_frame, past_frame], dim=0).unsqueeze(0)
    # One all-zero action vector per history step.
    # NOTE(review): 14 is assumed to be the action dimension -- confirm.
    past_actions = torch.zeros(1, 2, 14)

    with_history = module(
        occluded_scene,
        history_scene_tokens=past_scenes,
        history_actions=past_actions,
    )
    blank_scenes = torch.zeros_like(past_scenes)
    without_history = module(
        occluded_scene,
        history_scene_tokens=blank_scenes,
        history_actions=past_actions,
    )

    belief_with = with_history[belief_key]
    belief_without = without_history[belief_key]
    belief_norm = belief_with.norm()
    belief_delta = (belief_with - belief_without).norm()

    # Belief tokens must be non-trivial and must react to the history input.
    assert belief_norm > 0.0
    assert belief_delta > 1e-3