---
license: mit
language:
- en
library_name: stable-baselines3
tags:
- reinforcement-learning
- BreakoutNoFrameskip-v4
model-index:
- name: PPO
  results:
  - task:
      type: reinforcement-learning
      name: reinforcement-learning
    dataset:
      name: BreakoutNoFrameskip-v4
      type: BreakoutNoFrameskip-v4
    metrics:
    - type: mean_reward
      value: 187.80 +/- 114.62
      name: mean_reward
      verified: false
---
# **PPO** Agent playing **BreakoutNoFrameskip-v4**
- [Github Repository](https://github.com/kuds/rl-atari-breakout)
- [Google Colab Notebook](https://colab.research.google.com/github/kuds/rl-atari-breakout/blob/main/%5BAtari%20Breakout%5D%20Single-Agent%20Reinforcement%20Learning%20PPO.ipynb)
- [Finding Theta - Blog Post](https://www.findingtheta.com/blog/beginners-guide-to-model-based-reinforcement-learning-mbrl-with-ataris-breakout)

You can load the trained model with the following Python code. It assumes `best-model.zip` is in the current working directory:
```python
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage

# Recent gymnasium releases (1.0+) no longer auto-register the Atari environments.
# If the environment ID cannot be found, uncomment the two lines below:
# import ale_py
# gym.register_envs(ale_py)

# Load the trained model
model = PPO.load("best-model.zip")

# Create the environment: Atari preprocessing, 4-frame stacking and
# channel-first image observations; render_mode="human" opens a game window
env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=1, env_kwargs={"render_mode": "human"})
env = VecFrameStack(env, n_stack=4)
env = VecTransposeImage(env)

# A VecEnv's reset() returns only the observations (no info dict)
obs = env.reset()

# Enjoy the trained agent
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    # A VecEnv's step() returns 4 values and resets finished episodes automatically
    obs, rewards, dones, infos = env.step(action)
    env.render()
env.close()
```
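The environment set-up above stacks the four most recent preprocessed frames (`VecFrameStack(env, n_stack=4)`) and then moves the channel axis to the front (`VecTransposeImage`), so the policy receives a `(n_envs, 4, 84, 84)` array. The NumPy sketch below imitates that pipeline to make the observation shapes concrete. `FrameStackSketch` and `to_channel_first` are hypothetical helpers, not Stable-Baselines3 classes, and the 84x84 grayscale frame size assumes the default `AtariWrapper` settings applied by `make_atari_env`.

```python
import numpy as np


class FrameStackSketch:
    """Hypothetical, simplified stand-in for VecFrameStack (not the SB3 code).

    Keeps the last n_stack frames along the channel (last) axis of
    channel-last observations shaped (n_envs, H, W, C).
    """

    def __init__(self, n_envs: int, frame_shape: tuple, n_stack: int = 4):
        h, w, c = frame_shape
        self.c = c
        self.stacked = np.zeros((n_envs, h, w, c * n_stack), dtype=np.uint8)

    def reset(self, frames: np.ndarray) -> np.ndarray:
        # Clear the history and put the first frame in the newest slot
        self.stacked[...] = 0
        self.stacked[..., -self.c:] = frames
        return self.stacked.copy()

    def step(self, frames: np.ndarray, dones: np.ndarray) -> np.ndarray:
        # Shift every stored frame one slot towards the oldest position
        self.stacked = np.roll(self.stacked, shift=-self.c, axis=-1)
        # Environments whose episode just ended start from an empty history
        self.stacked[dones] = 0
        # Write the newest frame into the last slot
        self.stacked[..., -self.c:] = frames
        return self.stacked.copy()


def to_channel_first(obs: np.ndarray) -> np.ndarray:
    # Mimics VecTransposeImage: (n_envs, H, W, C) -> (n_envs, C, H, W)
    return np.transpose(obs, (0, 3, 1, 2))


n_envs = 1
frame_shape = (84, 84, 1)  # assumed AtariWrapper default: 84x84 grayscale
stacker = FrameStackSketch(n_envs, frame_shape, n_stack=4)

# Fill each synthetic frame with its time step so the ordering is visible
obs = stacker.reset(np.full((n_envs, *frame_shape), 1, dtype=np.uint8))
for t in range(2, 5):
    obs = stacker.step(np.full((n_envs, *frame_shape), t, dtype=np.uint8), dones=np.array([False]))

policy_input = to_channel_first(obs)
print(obs.shape)                 # (1, 84, 84, 4)
print(policy_input.shape)        # (1, 4, 84, 84)
print(policy_input[0, :, 0, 0])  # oldest -> newest frame: [1 2 3 4]
```

A short history of frames is what lets a feed-forward policy infer the ball's direction and speed, which a single still frame cannot convey.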
### Hugging Face Hub
You can also download the model directly from the Hugging Face Hub. First, install the `huggingface_hub` library:
```bash
pip install huggingface_hub
```
Then, download and load the model from the Hub with the following code:
```python
from huggingface_hub import hf_hub_download
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage

# Download the model from the Hub (cached locally after the first call)
model_path = hf_hub_download(repo_id="kuds/atari-breakout-v4-ppo", filename="best-model.zip")

# Load the model
model = PPO.load(model_path)

# Create the environment; render_mode="human" is required for render() to display the game
env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=1, env_kwargs={"render_mode": "human"})
env = VecFrameStack(env, n_stack=4)
env = VecTransposeImage(env)

# Enjoy the trained agent
obs = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, info = env.step(action)
    env.render("human")
env.close()
```
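The `mean_reward` in the model card metadata (187.80 +/- 114.62) is a mean and standard deviation over evaluation episodes; the number of episodes and the evaluation settings are not stated in this card. Assuming it follows the Stable-Baselines3 convention, where `evaluate_policy` reports `np.mean` and `np.std` (population standard deviation) of the per-episode returns, the string can be produced as in the sketch below. `format_mean_reward` is a hypothetical helper, and the returns are made-up numbers, not results from this model.

```python
from statistics import fmean, pstdev


def format_mean_reward(episode_returns: list[float]) -> str:
    """Summarise per-episode returns as 'mean +/- std' with two decimals.

    pstdev is the population standard deviation, which matches the
    default of np.std (ddof=0).
    """
    mean = fmean(episode_returns)
    std = pstdev(episode_returns)
    return f"{mean:.2f} +/- {std:.2f}"


# Hypothetical episode returns, for illustration only
returns = [100.0, 200.0, 300.0]
print(format_mean_reward(returns))  # 200.00 +/- 81.65
```

With Stable-Baselines3 itself, the per-episode returns could be collected with `stable_baselines3.common.evaluation.evaluate_policy(model, env, n_eval_episodes=..., return_episode_rewards=True)`, using an environment built as in the examples above.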