# lucasschott — "update safetensors" (commit 48eeee2)
import gymnasium as gym
import json
from stable_baselines3 import SAC
from stable_baselines3.common.evaluation import evaluate_policy
from safetensors.torch import save_model
def build_results(mean_reward, std_reward, n_eval_episodes, deterministic):
    """Return the evaluation summary as a JSON-serialisable dict.

    Args:
        mean_reward: Mean episodic return reported by ``evaluate_policy``.
        std_reward: Standard deviation of the episodic return.
        n_eval_episodes: Number of evaluation episodes that were run.
        deterministic: Whether the policy acted deterministically.

    Returns:
        A dict with keys ``mean_reward``, ``std_reward``, ``episodes`` and
        ``is_deterministic``, in that order.
    """
    # Cast to built-in float: some NumPy scalar types (e.g. float32) are not
    # float subclasses and would be rejected by ``json.dump``.
    return {
        "mean_reward": float(mean_reward),
        "std_reward": float(std_reward),
        "episodes": n_eval_episodes,
        "is_deterministic": deterministic,
    }


def main(
    env_id="HalfCheetah-v5",
    model_path="model.zip",
    n_eval_episodes=100,
    deterministic=True,
):
    """Export a trained SAC policy to safetensors and evaluate it.

    Loads the agent from *model_path*, writes its policy weights to
    ``model.safetensors``, evaluates it on a freshly created *env_id*
    environment for *n_eval_episodes* episodes, prints the reward statistics
    and writes them to ``results.json``.

    Args:
        env_id: Gymnasium environment id used for evaluation.
        model_path: Path of the saved SAC model archive.
        n_eval_episodes: Number of evaluation episodes.
        deterministic: Whether to use deterministic actions during evaluation.
    """
    agent = SAC.load(model_path)
    # Save the model as a safetensors file
    save_model(agent.policy, "model.safetensors")

    # Create the environment only for evaluation, so it is not leaked if
    # loading or exporting the model fails.
    env = gym.make(env_id)
    try:
        mean_reward, std_reward = evaluate_policy(
            agent,
            env,
            n_eval_episodes=n_eval_episodes,
            deterministic=deterministic,
        )
    finally:
        # Close the environment even if evaluation raises.
        env.close()

    print(f"reward : {mean_reward} +/- {std_reward}")

    results = build_results(mean_reward, std_reward, n_eval_episodes, deterministic)
    # Dump the results into a JSON file with pretty printing (indentation)
    with open("results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)


if __name__ == "__main__":
    main()