File size: 1,357 Bytes
b14c4b7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 | from train.trainer import build_policy
def test_rgb_backward_compat(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Run both policy variants on the same tiny batch and check output shapes.

    An ``interaction_state`` policy and an ``elastic_reveal`` policy are built
    from one tiny policy config. Each is called with planning enabled on a batch
    whose chunk size matches ``config.decoder.chunk_size``. The elastic policy
    is additionally called with ``use_depth=False``, so it must produce a plan
    without its depth path.
    """
    config = tiny_policy_config()
    batch = tiny_batch(chunk_size=config.decoder.chunk_size)
    # Expected size of the last axis of ``action_mean`` for both policies.
    expected_action_dim = 14

    def forward(policy, **extra_kwargs):
        # Batch entries are read here, at call time, and passed in the same
        # keyword order for every policy. Policy-specific switches go last.
        return policy(
            images=batch["images"],
            proprio=batch["proprio"],
            texts=batch["texts"],
            history_images=batch["history_images"],
            history_proprio=batch["history_proprio"],
            history_actions=batch["history_actions"],
            plan=True,
            **extra_kwargs,
        )

    interaction_output = forward(
        build_policy(config, tiny_trainer_config(policy_type="interaction_state"))
    )
    assert interaction_output["action_mean"].shape[-1] == expected_action_dim
    # One candidate chunk per configured decoder candidate (axis 1).
    assert (
        interaction_output["candidate_chunks"].shape[1]
        == config.decoder.num_candidates
    )

    elastic_output = forward(
        build_policy(config, tiny_trainer_config(policy_type="elastic_reveal")),
        use_depth=False,
    )
    assert elastic_output["action_mean"].shape[-1] == expected_action_dim
    # The planned chunk must match the mean action prediction exactly in shape.
    assert elastic_output["planned_chunk"].shape == elastic_output["action_mean"].shape
|