File size: 822 Bytes
16405f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch

from models.action_decoder import SymmetricCoordinatedChunkDecoder


def test_action_decoder_equivariance(tiny_policy_config, tiny_state):
    """Check that the decoder's equivariance probe agrees with its target.

    Builds a ``SymmetricCoordinatedChunkDecoder`` from the tiny test config.
    Runs it on random scene and memory tokens with
    ``compute_equivariance_probe=True``. Asserts that the mean absolute
    difference between the ``equivariance_probe_action_mean`` and
    ``equivariance_target_action_mean`` outputs is below 0.1.
    """
    # Seed before constructing the decoder so that both its parameter
    # initialisation and the random inputs below are reproducible. Without a
    # seed, a fixed-threshold assertion on random data can pass or fail from
    # run to run.
    torch.manual_seed(0)

    config = tiny_policy_config()
    decoder = SymmetricCoordinatedChunkDecoder(config.decoder)
    state = tiny_state(field_size=config.reveal_head.field_size)

    # NOTE(review): batch size presumably has to match whatever batch
    # tiny_state produces -- confirm against the fixture.
    batch_size = 2
    hidden_dim = config.backbone.hidden_dim
    num_memory_tokens = config.memory.scene_bank_size + config.memory.belief_bank_size
    scene_tokens = torch.rand(batch_size, 10, hidden_dim)
    memory_tokens = torch.rand(batch_size, num_memory_tokens, hidden_dim)

    output = decoder(
        scene_tokens=scene_tokens,
        interaction_state=state,
        memory_tokens=memory_tokens,
        compute_equivariance_probe=True,
    )
    error = (
        output["equivariance_probe_action_mean"]
        - output["equivariance_target_action_mean"]
    ).abs().mean()
    assert float(error) < 0.1, f"equivariance error {float(error):.4f} is not below 0.1"