# Enduro-v5-PPO / eval.py
# Last commit: d7ef094 "update safetensors" (lucasschott)
import gymnasium as gym
import ale_py
import json
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack
from safetensors.torch import save_model
if __name__ == "__main__":
gym.register_envs(ale_py)
env = make_vec_env(env_id="ALE/Enduro-v5", wrapper_class=AtariWrapper)
env = VecFrameStack(env, n_stack=4)
n_eval_episodes = 100
deterministic = True
agent = PPO.load("model.zip")
# Save the model as a safetensors file
save_model(agent.policy, "model.safetensors")
mean_reward, std_reward = evaluate_policy(agent, env, n_eval_episodes=n_eval_episodes, deterministic=deterministic)
print(f"reward : {mean_reward} +/- {std_reward}")
results = {
"mean_reward": mean_reward,
"std_reward": std_reward,
"episodes": n_eval_episodes,
"is_deterministic": deterministic
}
# Dump the results into a JSON file with pretty printing (indentation)
with open("results.json", "w") as f:
json.dump(results, f, indent=4)