# VLAarchtests / tests/test_task_conditioned_head_shapes.py
# Last commit by lsnu: "2026-03-25 runpod handoff update" (e7d8e79, verified)
import torch
from models.reveal_head import ElasticOcclusionStateHead
def test_task_conditioned_head_shapes(tiny_policy_config):
    """Check the output shapes of a task-conditioned ElasticOcclusionStateHead.

    Builds the head from the ``reveal_head`` section of the tiny policy config,
    feeds it random scene and memory tokens plus one task name per batch
    element, and verifies that:

    * each scalar task-conditioned output is present and has shape
      ``(batch_size,)``;
    * each field output is present and has size 1 along dimension 1.
    """
    config = tiny_policy_config()
    head = ElasticOcclusionStateHead(config.reveal_head)

    # One task name per batch element. The batch size is derived from this
    # list so the input tensors and the shape assertions cannot drift apart.
    task_names = ["foliage", "bag", "cloth"]
    batch_size = len(task_names)
    hidden_dim = config.backbone.hidden_dim
    num_scene_tokens = 12  # arbitrary token count chosen for this shape test
    # The memory token count is the sum of the scene-bank and belief-bank
    # sizes taken from the config.
    num_memory_tokens = (
        config.memory.scene_bank_size + config.memory.belief_bank_size
    )

    scene_tokens = torch.rand(batch_size, num_scene_tokens, hidden_dim)
    memory_tokens = torch.rand(batch_size, num_memory_tokens, hidden_dim)

    output = head(
        scene_tokens,
        memory_tokens=memory_tokens,
        task_names=task_names,
    )

    # Scalar outputs: exactly one value per batch element.
    scalar_keys = (
        "opening_quality",
        "actor_feasibility_score",
        "gap_width",
        "hold_quality",
        "fold_preservation",
        "lift_too_much_risk",
    )
    for key in scalar_keys:
        assert key in output, f"missing output {key!r}"
        assert output[key].shape == (batch_size,), (
            f"{key}: expected shape {(batch_size,)}, "
            f"got {tuple(output[key].shape)}"
        )

    # Field outputs: only the size of dimension 1 is checked here
    # (presumably a single-channel map — confirm against the head).
    field_keys = ("opening_quality_field", "newly_revealed_field")
    for key in field_keys:
        assert key in output, f"missing output {key!r}"
        assert output[key].shape[1] == 1, (
            f"{key}: expected size 1 along dim 1, "
            f"got shape {tuple(output[key].shape)}"
        )