| import wandb |
| from tqdm import tqdm |
| from pfp.envs.rlbench_env import RLBenchEnv |
| from pfp.policy.base_policy import BasePolicy |
|
|
|
|
class RLBenchRunner:
    """Roll out a policy in an RLBench environment and log each episode to wandb."""

    def __init__(
        self,
        num_episodes: int,
        max_episode_length: int,
        env_config: dict,
        verbose=False,
    ) -> None:
        """Build the environment and store the rollout settings.

        Args:
            num_episodes: Number of episodes executed by ``run``.
            max_episode_length: Upper bound on environment steps per episode.
            env_config: Keyword arguments forwarded to ``RLBenchEnv``.
            verbose: When true, ``run`` prints the step index and success flag
                after every episode.
        """
        self.env: RLBenchEnv = RLBenchEnv(**env_config)
        self.num_episodes = num_episodes
        self.max_episode_length = max_episode_length
        self.verbose = verbose

    def run(self, policy: BasePolicy):
        """Evaluate ``policy`` for ``num_episodes`` episodes.

        After every episode the episode index, the success flag (as an int) and
        the zero-based index of the last executed step are logged to wandb.

        Args:
            policy: Policy providing ``reset_obs()`` and
                ``predict_action(obs, robot_state)``.

        Returns:
            A tuple ``(success_list, steps_list)``. ``success_list`` holds one
            bool per episode. ``steps_list`` holds the zero-based index of the
            final step, but only for the episodes that succeeded.
        """
        # Ask wandb to keep the mean of both metrics in the run summary.
        wandb.define_metric("success", summary="mean")
        wandb.define_metric("steps", summary="mean")
        success_list: list[bool] = []
        steps_list: list[int] = []
        self.env.reset_rng()
        for episode in tqdm(range(self.num_episodes)):
            policy.reset_obs()
            self.env.reset()
            # Bind the per-episode results up front: if the step loop below runs
            # zero times (max_episode_length <= 0) these names would otherwise be
            # unbound on the first episode, or stale from the previous episode.
            success = False
            step = 0
            for step in range(self.max_episode_length):
                robot_state, obs = self.env.get_obs()
                prediction = policy.predict_action(obs, robot_state)
                self.env.vis_step(robot_state, obs, prediction)
                # Only the entry at index [-1, 0] of the prediction is executed.
                # NOTE(review): presumably the last refinement step and the first
                # action of the predicted horizon - confirm against the policy.
                next_robot_state = prediction[-1, 0]
                reward, terminate = self.env.step(next_robot_state)
                success = bool(reward)
                if success or terminate:
                    break
            success_list.append(success)
            if success:
                steps_list.append(step)
            if self.verbose:
                print(f"Steps: {step}")
                print(f"Success: {success}")
            # "steps" is logged for every episode, including failed ones.
            wandb.log({"episode": episode, "success": int(success), "steps": step})
        return success_list, steps_list
|
|