# lsnu's picture
# Add files using upload-large-folder tool
# 912c7e2 verified
import wandb
from tqdm import tqdm
from pfp.envs.rlbench_env import RLBenchEnv
from pfp.policy.base_policy import BasePolicy
class RLBenchRunner:
    """Roll out a policy in an ``RLBenchEnv`` and log per-episode results.

    The runner executes ``num_episodes`` episodes, each capped at
    ``max_episode_length`` environment steps, and reports the outcome of
    every episode to Weights & Biases.
    """

    def __init__(
        self,
        num_episodes: int,
        max_episode_length: int,
        env_config: dict,
        verbose: bool = False,
    ) -> None:
        """Create the environment and store the rollout settings.

        Args:
            num_episodes: Number of episodes executed by ``run``.
            max_episode_length: Upper bound on environment steps per episode.
            env_config: Keyword arguments forwarded to ``RLBenchEnv``.
            verbose: When true, ``run`` prints the last step index and the
                success flag after every episode.
        """
        self.env: RLBenchEnv = RLBenchEnv(**env_config)
        self.num_episodes = num_episodes
        self.max_episode_length = max_episode_length
        self.verbose = verbose

    def run(self, policy: BasePolicy) -> tuple[list[bool], list[int]]:
        """Evaluate ``policy`` for ``self.num_episodes`` episodes.

        Every episode resets the policy's observation state and the
        environment, then alternates between predicting an action and stepping
        the environment until the step reports a truthy reward, the
        environment signals termination, or ``max_episode_length`` steps have
        been taken. One W&B record is logged per episode.

        Args:
            policy: Policy queried through ``reset_obs`` and
                ``predict_action``.

        Returns:
            A ``(success_list, steps_list)`` tuple. ``success_list`` holds one
            bool per episode. ``steps_list`` holds, for successful episodes
            only, the zero-based index of the step on which the episode ended.
        """
        # Have W&B summarise both logged metrics by their mean over the run.
        wandb.define_metric("success", summary="mean")
        wandb.define_metric("steps", summary="mean")

        success_list: list[bool] = []
        steps_list: list[int] = []

        self.env.reset_rng()
        for episode in tqdm(range(self.num_episodes)):
            policy.reset_obs()
            self.env.reset()

            # Bind defaults before the step loop: if max_episode_length <= 0
            # the loop body never runs, and without these the bookkeeping
            # below would raise UnboundLocalError. Such an episode is recorded
            # as a failure with step index 0.
            success = False
            step = 0
            for step in range(self.max_episode_length):
                robot_state, obs = self.env.get_obs()
                prediction = policy.predict_action(obs, robot_state)
                self.env.vis_step(robot_state, obs, prediction)
                next_robot_state = prediction[-1, 0]  # Last K step, first T step
                reward, terminate = self.env.step(next_robot_state)
                success = bool(reward)
                if success or terminate:
                    break

            success_list.append(success)
            if success:
                # Only successful episodes contribute to the step statistics.
                steps_list.append(step)
            if self.verbose:
                print(f"Steps: {step}")
                print(f"Success: {success}")
            wandb.log({"episode": episode, "success": int(success), "steps": step})

        return success_list, steps_list