| import gymnasium as gym |
| import ale_py |
| import json |
| from stable_baselines3 import PPO |
| from stable_baselines3.common.evaluation import evaluate_policy |
| from stable_baselines3.common.atari_wrappers import AtariWrapper |
| from stable_baselines3.common.env_util import make_vec_env |
| from stable_baselines3.common.vec_env import VecFrameStack |
| from safetensors.torch import save_model |
|
|
def main(
    model_path: str = "model.zip",
    safetensors_path: str = "model.safetensors",
    results_path: str = "results.json",
    env_id: str = "ALE/Enduro-v5",
    n_stack: int = 4,
    n_eval_episodes: int = 100,
    deterministic: bool = True,
    seed: "int | None" = None,
    env_kwargs: "dict | None" = None,
) -> dict:
    """Export a trained PPO policy to safetensors and evaluate it.

    Loads the PPO agent from ``model_path``, writes its policy weights to
    ``safetensors_path``, runs ``n_eval_episodes`` evaluation episodes on a
    frame-stacked Atari env, prints the mean/std episode reward and writes a
    JSON summary to ``results_path``.

    Args:
        model_path: Path of the Stable-Baselines3 ``.zip`` checkpoint.
        safetensors_path: Destination file for the exported policy weights.
        results_path: Destination file for the JSON evaluation summary.
        env_id: Gymnasium environment id to evaluate on.
        n_stack: Number of frames stacked per observation. Presumably this
            must match the value used at training time -- TODO confirm.
        n_eval_episodes: Number of episodes passed to ``evaluate_policy``.
        deterministic: Whether the policy acts deterministically.
        seed: Optional seed forwarded to ``make_vec_env``; ``None`` keeps the
            original unseeded behaviour.
        env_kwargs: Optional keyword arguments forwarded to the env
            constructor by ``make_vec_env`` (e.g. ``{"frameskip": 1}``).

    Returns:
        The results dictionary that was written to ``results_path``.
    """
    gym.register_envs(ale_py)

    # NOTE(review): SB3 documents AtariWrapper for frame-skip-free envs
    # (e.g. "*NoFrameskip-v4"), whereas ALE v5 ids apply their own frameskip,
    # so the default setup may skip frames twice. Pass
    # env_kwargs={"frameskip": 1} only if that matches how the model was
    # trained -- TODO confirm.
    env = make_vec_env(
        env_id=env_id,
        seed=seed,
        wrapper_class=AtariWrapper,
        env_kwargs=env_kwargs,
    )
    env = VecFrameStack(env, n_stack=n_stack)

    # Close the env even if loading, exporting or evaluating raises, so the
    # environment's resources are always released.
    try:
        agent = PPO.load(model_path)
        save_model(agent.policy, safetensors_path)
        mean_reward, std_reward = evaluate_policy(
            agent,
            env,
            n_eval_episodes=n_eval_episodes,
            deterministic=deterministic,
        )
    finally:
        env.close()

    print(f"reward : {mean_reward} +/- {std_reward}")

    # Convert to built-in floats so the JSON output does not depend on how
    # the rewards' (possibly numpy) scalar type is serialized.
    results = {
        "mean_reward": float(mean_reward),
        "std_reward": float(std_reward),
        "episodes": n_eval_episodes,
        "is_deterministic": deterministic,
    }

    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)

    return results


if __name__ == "__main__":
    main()