---
license: mit
language:
- en
library_name: stable-baselines3
tags:
- reinforcement-learning
- BreakoutNoFrameskip-v4
model-index:
- name: PPO
  results:
  - task:
      type: reinforcement-learning
      name: reinforcement-learning
    dataset:
      name: BreakoutNoFrameskip-v4
      type: BreakoutNoFrameskip-v4
    metrics:
    - type: mean_reward
      value: 187.80 +/- 114.62
      name: mean_reward
      verified: false
---

# **PPO** Agent playing **BreakoutNoFrameskip-v4**

- [GitHub Repository](https://github.com/kuds/rl-atari-breakout)
- [Google Colab Notebook](https://colab.research.google.com/github/kuds/rl-atari-breakout/blob/main/%5BAtari%20Breakout%5D%20Single-Agent%20Reinforcement%20Learning%20PPO.ipynb)
- [Finding Theta - Blog Post](https://www.findingtheta.com/blog/beginners-guide-to-model-based-reinforcement-learning-mbrl-with-ataris-breakout)

### Local File

You can load the trained model from a local copy of `best-model.zip` using the following Python code:

```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage

# Load the trained model
model = PPO.load("best-model.zip")

# Create the environment with the same preprocessing used during training:
# Atari wrappers, 4-frame stacking, and channel-first image transposition.
# render_mode="human" enables on-screen rendering with Gymnasium.
env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=1,
                     env_kwargs={"render_mode": "human"})
env = VecFrameStack(env, n_stack=4)
env = VecTransposeImage(env)

# Reset the environment (SB3 VecEnvs return only the observation)
obs = env.reset()

# Enjoy the trained agent (VecEnvs auto-reset when an episode ends)
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, infos = env.step(action)
    env.render()
env.close()
```
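
To sanity-check the reported score, you can evaluate the loaded model with Stable-Baselines3's `evaluate_policy` helper. This is a minimal sketch continuing from the snippet above (`model` and the imports are already in scope); `n_eval_episodes=10` is an assumption, since the card does not state how many episodes the reported 187.80 +/- 114.62 was averaged over:

```python
from stable_baselines3.common.evaluation import evaluate_policy

# Fresh evaluation env without rendering (faster than the display env above)
eval_env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=1)
eval_env = VecFrameStack(eval_env, n_stack=4)
eval_env = VecTransposeImage(eval_env)

# n_eval_episodes=10 is an assumption, not necessarily the setting
# used to produce the score reported on this card
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")
```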

### Hugging Face Hub

You can also load the model directly from the Hugging Face Hub. First, install the `huggingface_hub` library:

```bash
pip install huggingface_hub
```

Then, you can load the model from the Hub using the following code:

```python
from huggingface_hub import hf_hub_download
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage

# Download the model checkpoint from the Hub
model_path = hf_hub_download(repo_id="kuds/atari-breakout-v4-ppo", filename="best-model.zip")

# Load the model
model = PPO.load(model_path)

# Create the environment with the same preprocessing used during training;
# render_mode="human" enables on-screen rendering with Gymnasium
env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=1,
                     env_kwargs={"render_mode": "human"})
env = VecFrameStack(env, n_stack=4)
env = VecTransposeImage(env)

# Enjoy the trained agent (VecEnvs auto-reset when an episode ends)
obs = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, infos = env.step(action)
    env.render()
env.close()
```
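
If you want to save a rollout instead of rendering it live, one option is Stable-Baselines3's `VecVideoRecorder` wrapper. This is a minimal sketch, not part of the original card: it assumes `render_mode="rgb_array"` can be passed through `env_kwargs` for frame capture, and the `videos` folder and `ppo-breakout` prefix are arbitrary choices:

```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage, VecVideoRecorder

model = PPO.load("best-model.zip")

# render_mode="rgb_array" lets the recorder grab frames off-screen
env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=1,
                     env_kwargs={"render_mode": "rgb_array"})
env = VecFrameStack(env, n_stack=4)
env = VecTransposeImage(env)

# Record the first 1000 steps to ./videos
env = VecVideoRecorder(env, "videos",
                       record_video_trigger=lambda step: step == 0,
                       video_length=1000, name_prefix="ppo-breakout")

obs = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, infos = env.step(action)
env.close()
```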