# Spaces: Runtime error
import argparse
from datetime import datetime

from sim.robomimic.robomimic_runner import RolloutRunner
from sim.policy import GeniePolicy
# Wall-clock (local, naive) time at which this module was loaded, rendered as
# "YYYY-MM-DD_HH-MM-SS"; it is appended as the last column of each success.csv row.
current_date = f"{datetime.now():%Y-%m-%d_%H-%M-%S}"
def parse_args(argv=None):
    """Parse the command-line options for one robomimic policy evaluation.

    Args:
        argv: Argument list to parse; ``None`` (the default) parses ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="policy to evaluate")
    # Data
    parser.add_argument("--env_name", type=str, default="lift")
    parser.add_argument("--num_runs", type=int, default=1)
    parser.add_argument("--save_video", action="store_true")
    # Alternative checkpoints that used to be hard-coded (commented out) in this
    # script; pass one of them via --model.
    #   with --is_full_dynamics:
    #     data/mar_policy_dynamics2/final2_robomimic_scratch_mar_forward_dynamics_gpu_8_nodes_2_16g/step_50000
    #     data/final2_robomimic_scratch_mar_full_dynamics_new_gpu_8_nodes_4_16g/step_10000
    #     data/final2_robomimic_scratch_mar_dynamics_fullpastmask_new_gpu_8_nodes_4_16g/step_10000
    #     data/final2_robomimic_scratch_mar_full_dynamics_fixed_new_gpu_8_nodes_4_16g/step_20000
    #   without --is_full_dynamics (action only):
    #     data/mar_policy2/final2_robomimic_scratch_mar_actiononly_gpu_8_nodes_4_16g/final_checkpt
    #     data/final2_robomimic_scratch_mar_actiononly_new_gpu_8_nodes_4_16g/final2_robomimic_scratch_mar_actiononly_new_gpu_8_nodes_4_16g/step_10000
    #     data/final2_robomimic_scratch_mar_fullpastmask_actiononly_fixed_gpu_8_nodes_4_16g/step_10000
    #     data/mar_policy_actiononly3/step_10000
    parser.add_argument("--model", type=str, default="data/mar_policy_dynamics/step_30000")
    parser.add_argument("--use_magvit", action="store_true")
    parser.add_argument("--is_full_dynamics", action="store_true")
    parser.add_argument("--use_raw_image", action="store_true")
    parser.add_argument("--execution_horizon", type=int, default=4)
    parser.add_argument("--diffusion_steps", type=int, default=100)
    parser.add_argument("--inference_iterations", type=int, default=1)
    parser.add_argument("--prompt_horizon", type=int, default=1)
    return parser.parse_args(argv)


def build_model_suffix(model, is_full_dynamics, execution_horizon):
    """Return the run tag used as the video postfix and as the CSV row key.

    Format: ``{kind}_{parent}_{name}_horizon{execution_horizon}``, where
    ``kind`` is ``"dynamics"`` (full dynamics) or ``"actiononly"`` and
    ``parent``/``name`` are the last two components of the checkpoint path.

    Empty and ``"."`` path components are skipped. A trailing slash (as left
    by shell tab-completion of a checkpoint directory) therefore no longer
    yields an empty ``name``. A single-component path no longer raises
    ``IndexError``; only that one component is used.
    """
    kind = "dynamics" if is_full_dynamics else "actiononly"
    parts = [part for part in model.split("/") if part not in ("", ".")]
    return "_".join([kind, *parts[-2:]]) + f"_horizon{execution_horizon}"


def build_policy(args):
    """Construct the GeniePolicy described by the parsed CLI options.

    ``--use_magvit`` switches the image encoder, its checkpoint, quantization
    and the backbone together; otherwise the temporal VAE with the ``stmar``
    backbone is used.
    """
    if args.use_magvit:
        encoder_type, encoder_ckpt, backbone_type = "magvit", "data/magvit2.ckpt", "stmaskgit"
    else:
        encoder_type = "temporalvae"
        encoder_ckpt = "stabilityai/stable-video-diffusion-img2vid"
        backbone_type = "stmar"
    return GeniePolicy(
        image_encoder_type=encoder_type,
        image_encoder_ckpt=encoder_ckpt,
        quantize=args.use_magvit,  # quantize only with the magvit encoder
        backbone_type=backbone_type,
        backbone_ckpt=args.model,
        prompt_horizon=args.prompt_horizon,  # history step
        prediction_horizon=args.execution_horizon,  # future step
        execution_horizon=args.execution_horizon,  # open loop step
        inference_iterations=args.inference_iterations,  # maskgit step
        diffusion_steps=args.diffusion_steps,  # diffusion steps
        action_stride=1,
        domain="robomimic",
        is_full_dynamics=args.is_full_dynamics,
        use_raw_image=args.use_raw_image,
    )


def append_result(path, model_suffix, success, reward, args, timestamp):
    """Append one ", "-separated result row to ``path`` (no header is written).

    Column order: model_suffix, success, reward, execution_horizon,
    diffusion_steps, inference_iterations, prompt_horizon, num_runs, timestamp.
    """
    row = [
        model_suffix, success, reward, args.execution_horizon, args.diffusion_steps,
        args.inference_iterations, args.prompt_horizon, args.num_runs, timestamp,
    ]
    # NOTE(review): fields are not quoted, so a value whose text contains a
    # comma would shift the columns of its row.
    with open(path, "a", encoding="utf-8") as f:
        f.write(", ".join(f"{value}" for value in row) + "\n")


def main():
    """Evaluate one policy checkpoint on one environment and log the outcome."""
    args = parse_args()
    env_name = args.env_name
    # initialize environment
    rollout_runner = RolloutRunner(
        env_names=[env_name], episode_num=args.num_runs, save_video=args.save_video
    )
    model_suffix = build_model_suffix(args.model, args.is_full_dynamics, args.execution_horizon)
    # initialize policy
    policy = build_policy(args)
    # run the rollouts; the env is passed as a one-element list, as in the runner's env_names
    success, reward = rollout_runner.run(
        policy=policy, env_name=[env_name], video_postfix=model_suffix
    )
    print(f"success: {success}, reward: {reward}")
    # dump the success with model name to csv
    append_result("success.csv", model_suffix, success, reward, args, current_date)


if __name__ == "__main__":
    main()