# Reward Rush: CartPole DQN
This repository contains the cleaned weights for a Deep Q-Network agent trained for the CartPole-v1 environment.
## Model Architecture
The model utilizes a multi-layer perceptron designed for low-latency inference:

- Input: 4 state observations (cart position, cart velocity, pole angle, pole angular velocity)
- Output: 2 discrete actions (push cart left, push cart right)
- Network Structure:
  - `Linear(4, 128)` -> `ReLU`
  - `Linear(128, 128)` -> `ReLU`
  - `Linear(128, 2)`
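As a quick sanity check on the sizes above, the short sketch below counts the trainable parameters this structure implies; the count is derived purely from the listed layer shapes.

```python
import torch.nn as nn

# Same shapes as the structure above; layer names do not affect the count.
layers = [nn.Linear(4, 128), nn.Linear(128, 128), nn.Linear(128, 2)]
total = sum(p.numel() for layer in layers for p in layer.parameters())
print(total)  # 17410 = (4*128 + 128) + (128*128 + 128) + (128*2 + 2)
```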
## Common Implementation Mistakes to Avoid
- Variable Naming: The weights are mapped to specific names: `layer1`, `layer2`, and `layer3`. Using generic names like `fc1` or `nn.Sequential` will result in a loading error (see the key-name sketch after this list).
- Missing Batch Dimension: The model expects a batch dimension. Input states must be wrapped using `unsqueeze(0)` before inference.
- Inference Logic: The model outputs raw Q-values for both actions. Use `argmax(dim=1)` to select the correct action index for the environment.
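To see why generic names fail, the minimal sketch below compares `state_dict` keys. The `nn.Sequential` variant is hypothetical and shown only to illustrate the mismatch: `nn.Sequential` auto-numbers its submodules, so its keys can never line up with the checkpoint's `layer1.*`, `layer2.*`, and `layer3.*` entries.

```python
import torch.nn as nn

# Hypothetical mismatched model: submodules are auto-named "0", "2", "4"
wrong = nn.Sequential(
    nn.Linear(4, 128), nn.ReLU(),
    nn.Linear(128, 128), nn.ReLU(),
    nn.Linear(128, 2),
)
print(list(wrong.state_dict().keys()))
# ['0.weight', '0.bias', '2.weight', '2.bias', '4.weight', '4.bias']
# The checkpoint instead expects:
# ['layer1.weight', 'layer1.bias', 'layer2.weight', 'layer2.bias',
#  'layer3.weight', 'layer3.bias']
```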
## Download and Test Code
This script downloads the weights from the Hugging Face repository, initializes the environment, and evaluates the agent over 100 test episodes.
```python
import torch
import torch.nn as nn
import gymnasium as gym
import numpy as np
from huggingface_hub import hf_hub_download

class MatchedNet(nn.Module):
    """MLP whose layer names match the checkpoint keys (layer1/layer2/layer3)."""
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(4, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, 2)

    def forward(self, x):
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        return self.layer3(x)

def run_cartpole_test():
    # Download the checkpoint from the Hugging Face Hub
    path = hf_hub_download(repo_id="Nharen/Reward_Rush_DQN_Cart_Pole", filename="Cartpole.pth")
    model = MatchedNet()
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    # Unwrap checkpoints that store the policy net inside a training-state dict
    if isinstance(state_dict, dict) and "policy_net_state_dict" in state_dict:
        state_dict = state_dict["policy_net_state_dict"]
    model.load_state_dict(state_dict)
    model.eval()

    env = gym.make("CartPole-v1")
    total_rewards = []
    for _ in range(100):
        state, _ = env.reset()
        episode_reward = 0
        done = False
        while not done:
            # Add the batch dimension the network expects
            state_t = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                # Greedy policy: pick the action with the highest Q-value
                action = model(state_t).argmax(dim=1).item()
            state, reward, terminated, truncated, _ = env.step(action)
            episode_reward += reward
            done = terminated or truncated
        total_rewards.append(episode_reward)
    print(f"Average Reward: {np.mean(total_rewards)}")
    env.close()

if __name__ == "__main__":
    run_cartpole_test()
```
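Beyond the average reward, you can watch the trained agent balance the pole. A minimal sketch for one rendered episode, assuming `MatchedNet` is defined as above and reusing the same loading steps as `run_cartpole_test`; the `render_mode="human"` window is an addition for inspection, not part of the evaluation protocol.

```python
import torch
import gymnasium as gym
from huggingface_hub import hf_hub_download

# Load the checkpoint exactly as run_cartpole_test does
path = hf_hub_download(repo_id="Nharen/Reward_Rush_DQN_Cart_Pole", filename="Cartpole.pth")
model = MatchedNet()
state_dict = torch.load(path, map_location="cpu", weights_only=True)
if isinstance(state_dict, dict) and "policy_net_state_dict" in state_dict:
    state_dict = state_dict["policy_net_state_dict"]
model.load_state_dict(state_dict)
model.eval()

env = gym.make("CartPole-v1", render_mode="human")  # opens a live render window
state, _ = env.reset()
done = False
while not done:
    state_t = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        action = model(state_t).argmax(dim=1).item()
    state, _, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
env.close()
```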
## Evaluation results

- mean_reward on CartPole-v1: 500.000 (self-reported)
- n_evaluation_episodes on CartPole-v1: 100 (self-reported)