File size: 946 Bytes
ca82368 f8a96b8 ca82368 48eeee2 ca82368 f8a96b8 48eeee2 e6867e3 f8a96b8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 | import gymnasium as gym
import json
from stable_baselines3 import SAC
from stable_baselines3.common.evaluation import evaluate_policy
from safetensors.torch import save_model
def build_results(mean_reward, std_reward, n_eval_episodes, deterministic):
    """Assemble the evaluation summary that is written to the results JSON file.

    Args:
        mean_reward: Mean episode return as reported by ``evaluate_policy``.
        std_reward: Standard deviation of the episode returns.
        n_eval_episodes: Number of evaluation episodes that were run.
        deterministic: Whether the policy was evaluated deterministically.

    Returns:
        A dict with keys ``mean_reward``, ``std_reward``, ``episodes`` and
        ``is_deterministic``. Both rewards are coerced to built-in ``float``
        so ``json.dump`` never receives a NumPy scalar type it cannot encode
        (e.g. ``numpy.float32``).
    """
    return {
        "mean_reward": float(mean_reward),
        "std_reward": float(std_reward),
        "episodes": n_eval_episodes,
        "is_deterministic": deterministic,
    }


def write_results(results, path="results.json"):
    """Write *results* to *path* as pretty-printed JSON (4-space indentation)."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)


def main(
    model_path="model.zip",
    env_id="HalfCheetah-v5",
    n_eval_episodes=100,
    deterministic=True,
    safetensors_path="model.safetensors",
    results_path="results.json",
):
    """Load a SAC agent, export its policy weights, evaluate it and save the scores.

    Args:
        model_path: Path of the saved SB3 SAC model to load.
        env_id: Gymnasium environment id used for evaluation.
        n_eval_episodes: Number of episodes passed to ``evaluate_policy``.
        deterministic: Whether actions are chosen deterministically.
        safetensors_path: Destination of the exported policy weights.
        results_path: Destination of the JSON evaluation summary.
    """
    env = gym.make(env_id)
    try:
        agent = SAC.load(model_path)
        # Export the policy module's weights as a safetensors file.
        save_model(agent.policy, safetensors_path)
        mean_reward, std_reward = evaluate_policy(
            agent, env, n_eval_episodes=n_eval_episodes, deterministic=deterministic
        )
    finally:
        # Release the environment even if loading, export or evaluation fails.
        env.close()

    print(f"reward : {mean_reward} +/- {std_reward}")
    write_results(
        build_results(mean_reward, std_reward, n_eval_episodes, deterministic),
        results_path,
    )


if __name__ == "__main__":
    main()
|