Commit ·
43cdcd3
1
Parent(s): f2c934b
add safetensors
Browse files- eval.py +5 -4
- model.safetensors +3 -0
- requirements.txt +2 -1
eval.py
CHANGED
|
@@ -6,6 +6,7 @@ from stable_baselines3.common.evaluation import evaluate_policy
|
|
| 6 |
from stable_baselines3.common.atari_wrappers import AtariWrapper
|
| 7 |
from stable_baselines3.common.env_util import make_vec_env
|
| 8 |
from stable_baselines3.common.vec_env import VecFrameStack
|
|
|
|
| 9 |
|
| 10 |
if __name__ == "__main__":
|
| 11 |
|
|
@@ -18,10 +19,10 @@ if __name__ == "__main__":
|
|
| 18 |
|
| 19 |
agent = PPO.load("model.zip")
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
mean_reward, std_reward = evaluate_policy(agent, env, n_eval_episodes=n_eval_episodes, deterministic=deterministic)
|
| 26 |
|
| 27 |
print(f"reward : {mean_reward} +/- {std_reward}")
|
|
|
|
| 6 |
from stable_baselines3.common.atari_wrappers import AtariWrapper
|
| 7 |
from stable_baselines3.common.env_util import make_vec_env
|
| 8 |
from stable_baselines3.common.vec_env import VecFrameStack
|
| 9 |
+
from safetensors.torch import save_model
|
| 10 |
|
| 11 |
if __name__ == "__main__":
|
| 12 |
|
|
|
|
| 19 |
|
| 20 |
agent = PPO.load("model.zip")
|
| 21 |
|
| 22 |
+
# Save the state dict as a safetensors file
|
| 23 |
+
policy_state_dict = agent.policy.state_dict()
|
| 24 |
+
save_model(agent.policy, "model.safetensors")
|
| 25 |
+
|
| 26 |
mean_reward, std_reward = evaluate_policy(agent, env, n_eval_episodes=n_eval_episodes, deterministic=deterministic)
|
| 27 |
|
| 28 |
print(f"reward : {mean_reward} +/- {std_reward}")
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:32f14418431437b7d5f5a4e2ae80b4454939ccaab5cf811387a30409082e8baa
|
| 3 |
+
size 8978048
|
requirements.txt
CHANGED
|
@@ -3,4 +3,5 @@ ale_py==0.10.1
|
|
| 3 |
gymnasium[atari]
|
| 4 |
torch==2.4.0
|
| 5 |
stable_baselines3==2.4.1
|
| 6 |
-
opencv-python==25.0.1
|
|
|
|
|
|
| 3 |
gymnasium[atari]
|
| 4 |
torch==2.4.0
|
| 5 |
stable_baselines3==2.4.1
|
| 6 |
+
opencv-python==25.0.1
|
| 7 |
+
safetensors==0.5.3
|