from train.trainer import build_policy


def test_rgb_backward_compat(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Smoke-test both policy types on the same tiny batch.

    An ``interaction_state`` policy and an ``elastic_reveal`` policy are
    built from one shared policy config. Each is run with ``plan=True``;
    the elastic policy additionally gets ``use_depth=False``. The test then
    checks the shapes of the returned outputs.

    NOTE(review): the name suggests this guards RGB-only (no-depth)
    compatibility; only the elastic call explicitly disables depth here.
    """
    cfg = tiny_policy_config()
    batch = tiny_batch(chunk_size=cfg.decoder.chunk_size)

    # Both policies are fed exactly these batch-derived inputs.
    shared_inputs = {
        name: batch[name]
        for name in (
            "images",
            "proprio",
            "texts",
            "history_images",
            "history_proprio",
            "history_actions",
        )
    }

    # --- interaction_state policy ---
    interaction = build_policy(cfg, tiny_trainer_config(policy_type="interaction_state"))
    out_interaction = interaction(**shared_inputs, plan=True)
    assert out_interaction["action_mean"].shape[-1] == 14
    assert out_interaction["candidate_chunks"].shape[1] == cfg.decoder.num_candidates

    # --- elastic_reveal policy, explicitly without depth ---
    elastic = build_policy(cfg, tiny_trainer_config(policy_type="elastic_reveal"))
    out_elastic = elastic(**shared_inputs, plan=True, use_depth=False)
    assert out_elastic["action_mean"].shape[-1] == 14
    assert out_elastic["planned_chunk"].shape == out_elastic["action_mean"].shape