File size: 946 Bytes
ca82368
f8a96b8
ca82368
 
48eeee2
ca82368
 
 
f8a96b8
 
 
 
 
 
48eeee2
 
e6867e3
f8a96b8
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import gymnasium as gym
import json
from stable_baselines3 import SAC
from stable_baselines3.common.evaluation import evaluate_policy
from safetensors.torch import save_model

def main(
    model_path: str = "model.zip",
    env_id: str = "HalfCheetah-v5",
    n_eval_episodes: int = 100,
    deterministic: bool = True,
    safetensors_path: str = "model.safetensors",
    results_path: str = "results.json",
) -> dict:
    """Export a trained SAC policy to safetensors and evaluate it.

    Loads the SAC agent from ``model_path``, writes its policy weights to
    ``safetensors_path``, then evaluates the agent for ``n_eval_episodes``
    episodes on an ``env_id`` environment. It prints the mean/std reward and
    writes them to ``results_path`` as indented JSON.

    Args:
        model_path: Path of the saved Stable-Baselines3 SAC model.
        env_id: Gymnasium environment id used for evaluation.
        n_eval_episodes: Number of evaluation episodes.
        deterministic: Whether the policy acts deterministically.
        safetensors_path: Output path for the exported policy weights.
        results_path: Output path for the JSON results.

    Returns:
        The results dict that was written to ``results_path``.
    """
    env = gym.make(env_id)
    try:
        agent = SAC.load(model_path)

        # Save the policy network weights as a safetensors file.
        save_model(agent.policy, safetensors_path)

        mean_reward, std_reward = evaluate_policy(
            agent,
            env,
            n_eval_episodes=n_eval_episodes,
            deterministic=deterministic,
        )
    finally:
        # Always release the environment, even if loading/export/evaluation fails.
        env.close()

    print(f"reward : {mean_reward} +/- {std_reward}")

    # The evaluation statistics may be NumPy scalars; cast to plain floats so
    # JSON serialization does not depend on the NumPy scalar type.
    results = {
        "mean_reward": float(mean_reward),
        "std_reward": float(std_reward),
        "episodes": n_eval_episodes,
        "is_deterministic": deterministic,
    }

    # Dump the results into a JSON file with pretty printing (indentation).
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)

    return results


if __name__ == "__main__":
    main()