File size: 1,563 Bytes
e7d8e79 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 | from models.action_decoder import proposal_mode_vocab
from train.trainer import build_policy
def test_candidate_macro_coverage(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Check that each task's proposal-mode vocabulary is covered by the decoded candidates.

    Builds an ``elastic_reveal`` policy from a tiny config with
    ``num_candidates=7``, switches it to eval mode, and runs one forward pass
    (``plan=False``) per task prompt on the first sample of a tiny batch. For
    every task, all names in
    ``proposal_mode_vocab(task_name, config.decoder.num_proposal_modes)`` must
    appear in ``output["proposal_mode_names"][0]``.
    """
    config = tiny_policy_config(num_candidates=7)
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    policy = build_policy(config, tiny_trainer_config(policy_type="elastic_reveal"))
    policy.eval()

    task_to_text = {
        "foliage": ["create a gap in the foliage and retrieve the target"],
        "bag": ["open the bag mouth and retrieve the target object"],
        "cloth": ["lift the top layer enough to retrieve the hidden object"],
    }

    # The observation inputs are identical for every task; only the text prompt
    # changes. Slice the first sample once instead of once per iteration. The
    # policy keyword names are exactly the batch keys, so they can be splatted.
    sample_keys = (
        "images",
        "depths",
        "depth_valid",
        "camera_intrinsics",
        "camera_extrinsics",
        "proprio",
        "history_images",
        "history_depths",
        "history_depth_valid",
        "history_proprio",
        "history_actions",
    )
    single_sample = {key: batch[key][:1] for key in sample_keys}

    for task_name, texts in task_to_text.items():
        # task_name is only used to look up the expected vocabulary; the policy
        # receives just the text (presumably it infers the task from it — verify).
        output = policy(**single_sample, texts=texts, plan=False)
        expected = set(proposal_mode_vocab(task_name, config.decoder.num_proposal_modes))
        observed = set(output["proposal_mode_names"][0])
        # Equivalent to expected.issubset(observed), but reports what is missing.
        missing = expected - observed
        assert not missing, (
            f"task {task_name!r}: proposal modes {sorted(missing, key=str)} "
            f"missing from candidates {sorted(observed, key=str)}"
        )
|