# VLAarchtests/tests/test_action_decoder_equivariance.py
# Uploaded by lsnu ("Add files using upload-large-folder tool"), commit 16405f2.
import torch
from models.action_decoder import SymmetricCoordinatedChunkDecoder
def test_action_decoder_equivariance(tiny_policy_config, tiny_state):
    """Check that the decoder's equivariance probe matches its target.

    Builds a ``SymmetricCoordinatedChunkDecoder`` from the tiny policy config.
    It runs the decoder on random scene and memory tokens with
    ``compute_equivariance_probe=True``. It then asserts that the mean absolute
    difference between the ``equivariance_probe_action_mean`` and
    ``equivariance_target_action_mean`` outputs is below ``tolerance``.

    Args:
        tiny_policy_config: pytest fixture (presumably from conftest.py)
            returning a factory for a small policy config.
        tiny_state: pytest fixture returning a factory that builds an
            interaction state for a given ``field_size``.
    """
    # Seed before building the decoder so that both its initial weights and
    # the random inputs below are reproducible. Without this, a tolerance
    # failure could not be replayed.
    torch.manual_seed(0)

    batch_size = 2
    num_scene_tokens = 10
    tolerance = 0.1

    config = tiny_policy_config()
    hidden_dim = config.backbone.hidden_dim
    decoder = SymmetricCoordinatedChunkDecoder(config.decoder)
    state = tiny_state(field_size=config.reveal_head.field_size)

    scene_tokens = torch.rand(batch_size, num_scene_tokens, hidden_dim)
    # Token count is scene-bank size + belief-bank size; presumably the two
    # banks are concatenated along the token axis -- TODO confirm in decoder.
    num_memory_tokens = (
        config.memory.scene_bank_size + config.memory.belief_bank_size
    )
    memory_tokens = torch.rand(batch_size, num_memory_tokens, hidden_dim)

    output = decoder(
        scene_tokens=scene_tokens,
        interaction_state=state,
        memory_tokens=memory_tokens,
        compute_equivariance_probe=True,
    )

    probe = output["equivariance_probe_action_mean"]
    target = output["equivariance_target_action_mean"]
    error = float((probe - target).abs().mean())
    # A NaN error also fails this comparison, and the message then shows "nan".
    assert error < tolerance, (
        f"equivariance error {error:.4g} exceeds tolerance {tolerance}"
    )