|
|
import itertools |
|
|
import numpy as np |
|
|
import os |
|
|
|
|
|
from copy import deepcopy |
|
|
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper |
|
|
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize |
|
|
from torch.utils.tensorboard.writer import SummaryWriter |
|
|
from typing import List, Optional, Union |
|
|
|
|
|
from shared.callbacks.callback import Callback |
|
|
from shared.policy.policy import Policy |
|
|
from shared.stats import Episode, EpisodeAccumulator, EpisodesStats |
|
|
from wrappers.vec_episode_recorder import VecEpisodeRecorder |
|
|
|
|
|
|
|
|
class EvaluateAccumulator(EpisodeAccumulator):
    """Accumulates completed episodes from a vectorized env until a goal count is met.

    Each of the ``num_envs`` sub-environments contributes up to
    ``ceil(goal_episodes / num_envs)`` episodes, so at least ``goal_episodes``
    episodes are collected in total. Optionally the first episode finished by
    each sub-environment is discarded (useful when envs may start mid-episode).
    """

    def __init__(
        self,
        num_envs: int,
        goal_episodes: int,
        print_returns: bool = True,
        ignore_first_episode: bool = False,
    ):
        super().__init__(num_envs)
        # completed_episodes_by_env_idx[i] holds the episodes recorded for env i.
        self.completed_episodes_by_env_idx: List[List[Episode]] = [
            [] for _ in range(num_envs)
        ]
        # Per-env quota; rounding up guarantees at least goal_episodes overall.
        self.goal_episodes_per_env = int(np.ceil(goal_episodes / num_envs))
        self.print_returns = print_returns
        if ignore_first_episode:
            first_done = set()

            def should_record_done(idx: int) -> bool:
                # Returns False exactly once per env index (its first completed
                # episode); every subsequent episode for that env is recorded.
                has_done_first_episode = idx in first_done
                first_done.add(idx)
                return has_done_first_episode

            self.should_record_done = should_record_done
        else:
            self.should_record_done = lambda idx: True

    def on_done(self, ep_idx: int, episode: Episode) -> None:
        """Record a finished episode for env ``ep_idx`` unless it must be skipped.

        An episode is skipped when it is the env's ignored first episode or when
        that env has already collected its per-env quota.

        BUGFIX: the guard previously read
        ``if self.should_record_done(ep_idx) and len(...) >= quota: return``,
        which fell through and *recorded* the episode whenever
        ``should_record_done`` returned False — i.e. the first episode that
        ``ignore_first_episode`` is meant to discard was kept, and the extra
        entries could push an env past its quota so that ``is_done()``'s ``==``
        check never became true.
        """
        if (
            not self.should_record_done(ep_idx)
            or len(self.completed_episodes_by_env_idx[ep_idx])
            >= self.goal_episodes_per_env
        ):
            return
        self.completed_episodes_by_env_idx[ep_idx].append(episode)
        if self.print_returns:
            print(
                f"Episode {len(self)} | "
                f"Score {episode.score} | "
                f"Length {episode.length}"
            )

    def __len__(self) -> int:
        """Total number of episodes recorded so far across all envs."""
        return sum(len(ce) for ce in self.completed_episodes_by_env_idx)

    @property
    def episodes(self) -> List[Episode]:
        """All recorded episodes, grouped by env index and flattened."""
        return list(itertools.chain(*self.completed_episodes_by_env_idx))

    def is_done(self) -> bool:
        """True once every env has recorded exactly its per-env quota."""
        return all(
            len(ce) == self.goal_episodes_per_env
            for ce in self.completed_episodes_by_env_idx
        )
|
|
|
|
|
|
|
|
def evaluate(
    env: VecEnv,
    policy: Policy,
    n_episodes: int,
    render: bool = False,
    deterministic: bool = True,
    print_returns: bool = True,
    ignore_first_episode: bool = False,
) -> EpisodesStats:
    """Roll out ``policy`` in ``env`` until ``n_episodes`` episodes complete.

    Switches the policy into eval mode, steps the vectorized env until the
    episode accumulator reports completion, and returns aggregate statistics
    over the collected episodes. Note: the policy is left in eval mode; the
    caller is responsible for restoring train mode if needed.
    """
    policy.eval()
    accumulator = EvaluateAccumulator(
        env.num_envs, n_episodes, print_returns, ignore_first_episode
    )

    observations = env.reset()
    while not accumulator.is_done():
        actions = policy.act(observations, deterministic=deterministic)
        observations, rewards, dones, _ = env.step(actions)
        accumulator.step(rewards, dones)
        if render:
            env.render()

    aggregated = EpisodesStats(accumulator.episodes)
    if print_returns:
        print(aggregated)
    return aggregated
|
|
|
|
|
|
|
|
class EvalCallback(Callback):
    """Periodically evaluates a policy during training.

    Every ``step_freq`` timesteps it syncs normalization stats into the eval
    env, runs ``evaluate``, logs the result to TensorBoard, and — when the new
    stats are at least as good as the best seen so far — optionally saves the
    model and records a video of a single episode.
    """

    def __init__(
        self,
        policy: Policy,
        env: VecEnv,
        tb_writer: SummaryWriter,
        best_model_path: Optional[str] = None,
        step_freq: Union[int, float] = 50_000,
        n_episodes: int = 10,
        save_best: bool = True,
        deterministic: bool = True,
        record_best_videos: bool = True,
        video_env: Optional[VecEnv] = None,
        best_video_dir: Optional[str] = None,
        max_video_length: int = 3600,
        ignore_first_episode: bool = False,
    ) -> None:
        """Store evaluation settings and prepare output directories.

        :param policy: policy to evaluate (and save when it improves).
        :param env: vectorized environment used for evaluation rollouts.
        :param tb_writer: TensorBoard writer for "eval" / "best_eval" scalars.
        :param best_model_path: where to save the best model (required if
            ``save_best``).
        :param step_freq: evaluate roughly every this many timesteps; floats
            are accepted (e.g. 5e4) and truncated to int.
        :param n_episodes: episodes to collect per evaluation.
        :param record_best_videos: record a video whenever the eval result is
            strictly better than the previous best; requires ``video_env``
            and ``best_video_dir``.
        :param max_video_length: cap (in steps) for recorded videos.
        :param ignore_first_episode: forwarded to ``evaluate``.
        """
        super().__init__()
        self.policy = policy
        self.env = env
        self.tb_writer = tb_writer
        self.best_model_path = best_model_path
        self.step_freq = int(step_freq)
        self.n_episodes = n_episodes
        self.save_best = save_best
        self.deterministic = deterministic
        # History of eval results, one entry per evaluation run.
        self.stats: List[EpisodesStats] = []
        # Best eval result seen so far (None until first evaluation).
        self.best = None

        self.record_best_videos = record_best_videos
        # NOTE(review): asserts are stripped under `python -O`; raising
        # ValueError would be more robust — left as-is to preserve behavior.
        assert video_env or not record_best_videos
        self.video_env = video_env
        assert best_video_dir or not record_best_videos
        self.best_video_dir = best_video_dir
        if best_video_dir:
            os.makedirs(best_video_dir, exist_ok=True)
        self.max_video_length = max_video_length
        self.best_video_base_path = None

        self.ignore_first_episode = ignore_first_episode

    def on_step(self, timesteps_elapsed: int = 1) -> bool:
        """Advance the timestep counter; evaluate if a step_freq boundary passed.

        Using ``elapsed // step_freq >= len(self.stats)`` triggers one
        evaluation per step_freq window even if a single call jumps across
        several windows (at most one catch-up eval per call).
        """
        super().on_step(timesteps_elapsed)
        if self.timesteps_elapsed // self.step_freq >= len(self.stats):
            # Copy current normalization statistics into the eval env so the
            # policy sees observations scaled as in training.
            sync_vec_normalize(self.policy.vec_normalize, self.env)
            self.evaluate()
        return True

    def evaluate(
        self, n_episodes: Optional[int] = None, print_returns: Optional[bool] = None
    ) -> EpisodesStats:
        """Run one evaluation, log results, and handle best-model bookkeeping.

        :param n_episodes: overrides the configured episode count when truthy.
        :param print_returns: per-episode printing; ``None`` is treated as
            False (``print_returns or False``).
        :return: the stats for this evaluation run.
        """
        eval_stat = evaluate(
            self.env,
            self.policy,
            n_episodes or self.n_episodes,
            deterministic=self.deterministic,
            print_returns=print_returns or False,
            ignore_first_episode=self.ignore_first_episode,
        )
        # ``evaluate`` put the policy in eval mode; restore train mode.
        self.policy.train(True)
        print(f"Eval Timesteps: {self.timesteps_elapsed} | {eval_stat}")

        self.stats.append(eval_stat)

        # NOTE(review): relies on EpisodesStats defining >=/> ordering —
        # presumably by mean score; confirm in shared.stats.
        if not self.best or eval_stat >= self.best:
            # Ties update `best` (and may re-save) but do not re-record video.
            strictly_better = not self.best or eval_stat > self.best
            self.best = eval_stat
            if self.save_best:
                assert self.best_model_path
                self.policy.save(self.best_model_path)
                print("Saved best model")
            self.best.write_to_tensorboard(
                self.tb_writer, "best_eval", self.timesteps_elapsed
            )
            if strictly_better and self.record_best_videos:
                assert self.video_env and self.best_video_dir
                sync_vec_normalize(self.policy.vec_normalize, self.video_env)
                # Video files are named after the timestep at which the new
                # best was achieved.
                self.best_video_base_path = os.path.join(
                    self.best_video_dir, str(self.timesteps_elapsed)
                )
                video_wrapped = VecEpisodeRecorder(
                    self.video_env,
                    self.best_video_base_path,
                    max_video_length=self.max_video_length,
                )
                # Record a single episode with the (possibly new) best policy.
                video_stats = evaluate(
                    video_wrapped,
                    self.policy,
                    1,
                    deterministic=self.deterministic,
                    print_returns=False,
                )
                print(f"Saved best video: {video_stats}")

        eval_stat.write_to_tensorboard(self.tb_writer, "eval", self.timesteps_elapsed)

        return eval_stat
|
|
|
|
|
|
|
|
def sync_vec_normalize(
    origin_vec_normalize: Optional[VecNormalize], destination_env: VecEnv
) -> None:
    """Copy normalization statistics into every VecNormalize layer of an env.

    Walks the wrapper chain of ``destination_env`` and deep-copies
    ``obs_rms`` (when present on the origin) and ``ret_rms`` into each
    ``VecNormalize`` wrapper found. No-op when ``origin_vec_normalize`` is
    None.
    """
    if origin_vec_normalize is None:
        return
    wrapper = destination_env
    while isinstance(wrapper, VecEnvWrapper):
        if isinstance(wrapper, VecNormalize):
            if hasattr(origin_vec_normalize, "obs_rms"):
                wrapper.obs_rms = deepcopy(origin_vec_normalize.obs_rms)
            wrapper.ret_rms = deepcopy(origin_vec_normalize.ret_rms)
        wrapper = wrapper.venv
|
|
|