# VLAarchtests / tests/test_candidate_macro_coverage.py
# Author: lsnu
# Commit e7d8e79 (verified): 2026-03-25 runpod handoff update
from models.action_decoder import proposal_mode_vocab
from train.trainer import build_policy
def test_candidate_macro_coverage(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Check that the policy's proposal modes cover each task's mode vocabulary.

    Builds an ``elastic_reveal`` policy configured with seven candidates, puts
    it in eval mode, and runs it with ``plan=False`` on the first sample of a
    tiny batch once per task prompt. For every task, each mode name returned by
    ``proposal_mode_vocab(task_name, config.decoder.num_proposal_modes)`` must
    appear in the policy output's ``proposal_mode_names`` for that sample.
    """
    config = tiny_policy_config(num_candidates=7)
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    policy = build_policy(config, tiny_trainer_config(policy_type="elastic_reveal"))
    policy.eval()

    # Only the text prompt varies between tasks, so slice the first sample's
    # observation and history inputs once. Each batch key doubles as the
    # keyword argument name the policy expects.
    input_keys = (
        "images",
        "depths",
        "depth_valid",
        "camera_intrinsics",
        "camera_extrinsics",
        "proprio",
        "history_images",
        "history_depths",
        "history_depth_valid",
        "history_proprio",
        "history_actions",
    )
    sample_inputs = {key: batch[key][:1] for key in input_keys}

    task_to_text = {
        "foliage": ["create a gap in the foliage and retrieve the target"],
        "bag": ["open the bag mouth and retrieve the target object"],
        "cloth": ["lift the top layer enough to retrieve the hidden object"],
    }
    for task_name, texts in task_to_text.items():
        output = policy(**sample_inputs, texts=texts, plan=False)
        expected = set(proposal_mode_vocab(task_name, config.decoder.num_proposal_modes))
        observed = set(output["proposal_mode_names"][0])
        # Equivalent to expected.issubset(observed), but on failure the
        # message names the task and the modes that were never proposed.
        missing = expected - observed
        assert not missing, (
            f"task {task_name!r} is missing proposal modes: {sorted(missing, key=str)}"
        )