lucasschott commited on
Commit
43cdcd3
·
1 Parent(s): f2c934b

add safetensors

Browse files
Files changed (3) hide show
  1. eval.py +5 -4
  2. model.safetensors +3 -0
  3. requirements.txt +2 -1
eval.py CHANGED
@@ -6,6 +6,7 @@ from stable_baselines3.common.evaluation import evaluate_policy
6
  from stable_baselines3.common.atari_wrappers import AtariWrapper
7
  from stable_baselines3.common.env_util import make_vec_env
8
  from stable_baselines3.common.vec_env import VecFrameStack
 
9
 
10
  if __name__ == "__main__":
11
 
@@ -18,10 +19,10 @@ if __name__ == "__main__":
18
 
19
  agent = PPO.load("model.zip")
20
 
21
- print(agent.policy)
22
- total_params = sum(p.numel() for p in agent.policy.parameters())
23
- print("\nTotal number of parameters: {:,}".format(total_params))
24
-
25
  mean_reward, std_reward = evaluate_policy(agent, env, n_eval_episodes=n_eval_episodes, deterministic=deterministic)
26
 
27
  print(f"reward : {mean_reward} +/- {std_reward}")
 
6
  from stable_baselines3.common.atari_wrappers import AtariWrapper
7
  from stable_baselines3.common.env_util import make_vec_env
8
  from stable_baselines3.common.vec_env import VecFrameStack
9
+ from safetensors.torch import save_model
10
 
11
  if __name__ == "__main__":
12
 
 
19
 
20
  agent = PPO.load("model.zip")
21
 
22
+ # Save the state dict as a safetensors file
23
+ policy_state_dict = agent.policy.state_dict()
24
+ save_model(agent.policy, "model.safetensors")
25
+
26
  mean_reward, std_reward = evaluate_policy(agent, env, n_eval_episodes=n_eval_episodes, deterministic=deterministic)
27
 
28
  print(f"reward : {mean_reward} +/- {std_reward}")
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32f14418431437b7d5f5a4e2ae80b4454939ccaab5cf811387a30409082e8baa
3
+ size 8978048
requirements.txt CHANGED
@@ -3,4 +3,5 @@ ale_py==0.10.1
3
  gymnasium[atari]
4
  torch==2.4.0
5
  stable_baselines3==2.4.1
6
- opencv-python==25.0.1
 
 
3
  gymnasium[atari]
4
  torch==2.4.0
5
  stable_baselines3==2.4.1
6
+ opencv-python==25.0.1
7
+ safetensors==0.5.3