| | import gymnasium as gym |
| | import json |
| | from stable_baselines3 import SAC |
| | from stable_baselines3.common.evaluation import evaluate_policy |
| | from safetensors.torch import save_model |
| |
|
if __name__ == "__main__":
    # Script entry point: load a saved SAC agent, re-export its policy as
    # safetensors, evaluate it on HalfCheetah-v5, and write the reward
    # statistics to results.json.

    # Evaluation settings; both are also recorded in results.json.
    n_eval_episodes = 100
    deterministic = True

    env = gym.make("HalfCheetah-v5")
    try:
        # Load the trained agent from the Stable-Baselines3 checkpoint.
        agent = SAC.load("model.zip")

        # Re-export the policy module's weights in safetensors format.
        save_model(agent.policy, "model.safetensors")

        mean_reward, std_reward = evaluate_policy(
            agent,
            env,
            n_eval_episodes=n_eval_episodes,
            deterministic=deterministic,
        )
    finally:
        # Release the environment even if loading, exporting, or evaluation
        # raises; the original script never closed it.
        env.close()

    print(f"reward : {mean_reward} +/- {std_reward}")

    results = {
        # Cast to built-in floats so serialization does not depend on the
        # exact scalar type returned by evaluate_policy (presumably NumPy
        # scalars -- the cast is a no-op for plain floats).
        "mean_reward": float(mean_reward),
        "std_reward": float(std_reward),
        "episodes": n_eval_episodes,
        "is_deterministic": deterministic,
    }

    # Explicit encoding keeps the output independent of the platform locale.
    with open("results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)
| |
|