Reinforce Agent playing CartPole-v1

This is a trained model of a custom Reinforce agent playing CartPole-v1.

This was created for Unit 4 of the Hugging Face Deep RL Course. I have added some features, such as an entropy bonus in the loss, and updated the code to use Gymnasium instead of the deprecated Gym package.

Hyperparameters used for training were optimized with Optuna:

{
  "n_training_episodes": 1000,
  "n_evaluation_episodes": 100,
  "h_size": 9,
  "gamma": 0.9965597864582326,
  "learning_rate": 0.008939700003204954
}

Usage

You will first need the custom class definition, listed here:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Policy(nn.Module):
    def __init__(self, s_size, a_size, h_size):
        super(Policy, self).__init__()
        # our fully connected layers
        self.fc1 = nn.Linear(s_size, h_size) # start, hidden
        self.fc2 = nn.Linear(h_size, a_size) # hidden, output


    def forward(self, x):
        # forward pass of the model
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim=1) # fc1 -> relu -> fc2 -> softmax

    def act(self, state):
        """
        Sample an action for the given state; also return its
        log-probability and the policy entropy at that state.
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        probs = self.forward(state).cpu()
        entropy = (-probs * probs.log()).sum() # compute entropy loss (see https://fosterelli.co/entropy-loss-for-reinforcement-learning for more details on why this helps)
        m = Categorical(probs)
        action = m.sample()
        return action.item(), m.log_prob(action), entropy
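`act` returns the log-probability and entropy needed for the policy-gradient update. The training loop itself is not part of this card; the following is a sketch of a REINFORCE loss with an entropy bonus, where the return normalization and the entropy weight `beta` are my own assumptions, not values from the card:

```python
import torch

def discounted_returns(rewards, gamma):
    # compute G_t = r_t + gamma * G_{t+1}, iterating backwards
    returns = []
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.insert(0, g)
    return torch.tensor(returns)

def reinforce_loss(log_probs, entropies, rewards, gamma, beta=0.01):
    # policy-gradient loss minus an entropy bonus (weighted by beta,
    # a hypothetical coefficient) to encourage exploration
    returns = discounted_returns(rewards, gamma)
    # normalize returns for numerical stability (an assumed choice)
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    pg = -(torch.stack(log_probs) * returns).sum()
    return pg - beta * torch.stack(entropies).sum()
```

At each episode you would collect the `log_prob` and `entropy` values emitted by `act`, build this loss, and call `loss.backward()` before an optimizer step.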
import gymnasium as gym
import numpy as np
import torch
import imageio
from huggingface_hub import hf_hub_download
from IPython.display import Image

env_id = "CartPole-v1"
env = gym.make(env_id, render_mode="rgb_array")
device = "cuda" if torch.cuda.is_available() else "cpu"

# load policy from huggingface
model_fp = hf_hub_download(
  repo_id="reeeemo/Reinforce-CartPole-v1",
  filename="model.pt",
)

policy = torch.load(model_fp, map_location=device, weights_only=False)

# run an episode
images = []
state, _ = env.reset()
done = False
while not done:
  action, _, _ = policy.act(state)
  state, reward, terminated, truncated, info = env.step(action)
  images.append(env.render())
  done = terminated or truncated

# save images as video and present (must be in notebook)
imageio.mimsave('./cartpole_replay.gif', [np.array(img) for img in images], duration=1000/30)
Image(open('./cartpole_replay.gif','rb').read())
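The card's hyperparameters mention 100 evaluation episodes. To reproduce that kind of evaluation, here is a sketch of a mean-return loop; the `evaluate` helper is hypothetical and not part of the card, but it uses the same `reset`/`step`/`act` interfaces as the replay code above:

```python
def evaluate(env, policy, n_episodes):
    """Return the mean undiscounted return over n_episodes."""
    totals = []
    for _ in range(n_episodes):
        state, _ = env.reset()
        done, total = False, 0.0
        while not done:
            action, _, _ = policy.act(state)
            state, reward, terminated, truncated, _ = env.step(action)
            total += reward
            done = terminated or truncated
        totals.append(total)
    return sum(totals) / n_episodes
```

For example, `evaluate(env, policy, 100)` would report the mean episode reward over 100 runs.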