File size: 950 Bytes
e7d8e79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch

from models.reveal_head import ElasticOcclusionStateHead


def test_task_conditioned_head_shapes(tiny_policy_config):
    """Check output keys and shapes of the task-conditioned reveal head.

    Builds an ``ElasticOcclusionStateHead`` from the tiny test config and runs it
    on random scene and memory tokens for a batch of three samples, passing three
    task names (presumably one per sample). Then it asserts that:

    * every per-sample scalar output is present and has shape ``(batch,)``;
    * the two ``*_field`` outputs have size 1 along dimension 1.
    """
    cfg = tiny_policy_config()
    # Build the head before drawing any random tokens, so that any RNG use
    # during construction happens first.
    reveal_head = ElasticOcclusionStateHead(cfg.reveal_head)

    batch = 3
    num_scene_tokens = 12
    hidden = cfg.backbone.hidden_dim
    # The memory input concatenates the scene bank and the belief bank.
    num_memory_tokens = cfg.memory.scene_bank_size + cfg.memory.belief_bank_size

    scene = torch.rand(batch, num_scene_tokens, hidden)
    memory = torch.rand(batch, num_memory_tokens, hidden)

    result = reveal_head(
        scene,
        memory_tokens=memory,
        task_names=["foliage", "bag", "cloth"],
    )

    scalar_keys = (
        "opening_quality",
        "actor_feasibility_score",
        "gap_width",
        "hold_quality",
        "fold_preservation",
        "lift_too_much_risk",
    )
    for name in scalar_keys:
        assert name in result
        assert result[name].shape == (batch,)

    field_keys = ("opening_quality_field", "newly_revealed_field")
    for name in field_keys:
        assert result[name].shape[1] == 1