# ppo-lunarlanding-v2 / test_model.py
# sam522's picture
# Upload PPO LunarLander model
# 18cc826 verified
import torch
import gymnasium as gym
import torch.nn as nn
class Actor(nn.Module):
    """Feed-forward policy network producing one raw score per action.

    The model is a stack of three ``nn.Linear`` layers with a ``nn.Tanh``
    after each of the first two. The stack is stored as ``self.network`` (an
    ``nn.Sequential``), so parameter names in the state dict are
    ``network.0.*``, ``network.2.*`` and ``network.4.*``.

    Args:
        state_dim: Number of input features of the first linear layer.
        action_dim: Number of output features of the last linear layer.
        hidden_size: Width of both hidden layers (default 64).
    """

    def __init__(self, state_dim, action_dim, hidden_size=64):
        super().__init__()
        # Layer widths from input to output; each consecutive pair defines
        # one Linear layer.
        widths = [state_dim, hidden_size, hidden_size, action_dim]
        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.Tanh())
        # The output layer is not followed by an activation.
        layers.pop()
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the raw outputs."""
        logits = self.network(x)
        return logits
def test_model(model_path="model.pt", max_steps=1000, render_mode="human"):
    """Load a saved actor checkpoint and play one greedy LunarLander-v2 episode.

    The checkpoint is expected to be a dict containing
    ``checkpoint['config']['hidden_size']`` and
    ``checkpoint['actor_state_dict']``. At every step the action with the
    highest actor output is taken (argmax, no sampling). The episode ends
    when the environment reports ``terminated`` or ``truncated``, or after
    ``max_steps`` steps. The accumulated reward is printed; nothing is
    returned.

    Args:
        model_path: Path of the checkpoint file passed to ``torch.load``.
        max_steps: Upper bound on the number of environment steps.
        render_mode: ``render_mode`` forwarded to ``gym.make``.
    """
    # Load the model.
    # NOTE: torch.load unpickles the file, so only load checkpoints that
    # come from a trusted source.
    checkpoint = torch.load(model_path, map_location='cpu')
    # Input/output sizes are fixed to the values this script uses for
    # LunarLander-v2; only the hidden width comes from the checkpoint.
    actor = Actor(state_dim=8, action_dim=4, hidden_size=checkpoint['config']['hidden_size'])
    actor.load_state_dict(checkpoint['actor_state_dict'])
    actor.eval()

    # Test the agent.
    env = gym.make("LunarLander-v2", render_mode=render_mode)
    total_reward = 0
    try:
        state, _ = env.reset()
        for _ in range(max_steps):
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                # Add a leading batch dimension of size 1 for the network.
                state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
                logits = actor(state_tensor)
            action = torch.argmax(logits, dim=-1).item()
            state, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            if terminated or truncated:
                break
    finally:
        # Always release the environment (and its render window), even if
        # the episode loop raises.
        env.close()
    print(f"Total reward: {total_reward:.2f}")
# Script entry point: run the evaluation episode only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    test_model()