kuds committed on
Commit bfc9a77 · verified · 1 Parent(s): cc7d468

Update README.md

  1. README.md +98 -3
README.md CHANGED
@@ -1,3 +1,98 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ language:
+ - en
+ library_name: stable-baselines3
+ tags:
+ - reinforcement-learning
+ - PongNoFrameskip-v4
+ model-index:
+ - name: PPO
+   results:
+   - task:
+       type: reinforcement-learning
+       name: reinforcement-learning
+     dataset:
+       name: PongNoFrameskip-v4
+       type: PongNoFrameskip-v4
+     metrics:
+     - type: mean_reward
+       value: 21.00 +/- 00.00
+       name: mean_reward
+       verified: false
+ ---
+
+
+ # **PPO** Agent playing **PongNoFrameskip-v4**
+ - [GitHub Repository](https://github.com/kuds/rl-atari-pong)
+ - [Google Colab Notebook](https://colab.research.google.com/github/kuds/rl-atari-pong/blob/main/%5BAtari%20Pong%5D%20Single-Agent%20Reinforcement%20Learning%20PPO.ipynb)
+ - [Finding Theta - Blog Post](https://www.findingtheta.com/blog/mastering-ataris-pong-with-reinforcement-learning-overcoming-sparse-rewards-and-optimizing-performance)
+
+
+ With `best-model.zip` in your working directory, you can load the model using the following Python code:
+
+ ```python
+ from stable_baselines3 import PPO
+ from stable_baselines3.common.env_util import make_atari_env
+ from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage
+
+ # Load the trained model
+ model = PPO.load("best-model.zip")
+
+ # Create the environment (a vectorized env with the standard Atari wrappers)
+ env = make_atari_env("PongNoFrameskip-v4", n_envs=1)
+ env = VecFrameStack(env, n_stack=4)
+ env = VecTransposeImage(env)
+
+ # Reset the environment; a VecEnv returns only the observations
+ obs = env.reset()
+
+ # Enjoy the trained agent
+ for _ in range(1000):
+     action, _states = model.predict(obs, deterministic=True)
+     # A VecEnv step returns a 4-tuple and resets finished episodes automatically
+     obs, rewards, dones, infos = env.step(action)
+     env.render()
+ env.close()
+ ```
+
+ ### Hugging Face Hub
+
+ You can also use the Hugging Face Hub to load the model. First, you need to install the Hugging Face Hub library:
+
+ ```bash
+ pip install huggingface_hub
+ ```
+
+ Then, you can load the model from the Hub using the following code:
+
+ ```python
+ from huggingface_hub import hf_hub_download
+ from stable_baselines3 import PPO
+ from stable_baselines3.common.env_util import make_atari_env
+ from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage
+
+ # Download the model from the Hub
+ model_path = hf_hub_download(repo_id="kuds/atari-pong-v4-ppo", filename="best-model.zip")
+
+ # Load the model
+ model = PPO.load(model_path)
+
+ # Create the environment
+ env = make_atari_env("PongNoFrameskip-v4", n_envs=1)
+ env = VecFrameStack(env, n_stack=4)
+ env = VecTransposeImage(env)
+
+ # Enjoy the trained agent
+ obs = env.reset()
+ for i in range(1000):
+     action, _states = model.predict(obs, deterministic=True)
+     obs, rewards, dones, info = env.step(action)
+     env.render("human")
+ env.close()
+ ```
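
The model card metadata in this change reports `mean_reward` as a mean and a spread over evaluation episodes. As a minimal sketch of how such a summary could be produced from per-episode returns, the snippet below uses only the standard library. The helper names `summarize_returns` and `format_mean_reward` and the sample returns are hypothetical, and the choice of population standard deviation is an assumption (it matches NumPy's default `std`), not something the commit states.

```python
from statistics import fmean, pstdev


def summarize_returns(episode_returns):
    """Return (mean, std) of per-episode returns.

    Assumption: population standard deviation, as numpy.std computes by default.
    """
    if not episode_returns:
        raise ValueError("need at least one episode return")
    return fmean(episode_returns), pstdev(episode_returns)


def format_mean_reward(episode_returns):
    """Format the summary as 'mean +/- std' with two decimals."""
    mean, std = summarize_returns(episode_returns)
    return f"{mean:.2f} +/- {std:.2f}"


# Hypothetical evaluation: ten Pong episodes, each won 21 to 0.
perfect_returns = [21.0] * 10
print(format_mean_reward(perfect_returns))  # 21.00 +/- 0.00
```

In Pong an episode ends when one side reaches 21 points, so a return of 21 in every episode gives a spread of zero, which is what a perfectly consistent agent would show.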