|
|
import torch |
|
|
import gymnasium as gym |
|
|
import torch.nn as nn |
|
|
|
|
|
class Actor(nn.Module):
    """Policy network mapping a state vector to one logit per action.

    Architecture: two Tanh-activated hidden layers of `hidden_size`
    units, followed by a linear output layer of `action_dim` logits
    (no softmax — callers apply argmax or a distribution themselves).
    """

    def __init__(self, state_dim, action_dim, hidden_size=64):
        super().__init__()
        # Assemble the MLP as a list first, then splat into Sequential.
        layers = [
            nn.Linear(state_dim, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, action_dim),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw action logits for input batch `x`."""
        return self.network(x)
|
|
|
|
|
def test_model(checkpoint_path="model.pt", max_steps=1000):
    """Load a trained actor checkpoint and run one greedy, rendered episode.

    Args:
        checkpoint_path: Path to a torch checkpoint containing
            'actor_state_dict' and a 'config' dict with 'hidden_size'.
        max_steps: Maximum number of environment steps before giving up.

    Side effects: opens a human-rendered LunarLander window, prints the
    episode's total reward.
    """
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    env = gym.make("LunarLander-v2", render_mode="human")
    try:
        # Derive network dimensions from the env rather than hard-coding
        # 8/4, so the same code works if the env's spaces ever change.
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n

        actor = Actor(
            state_dim=state_dim,
            action_dim=action_dim,
            hidden_size=checkpoint['config']['hidden_size'],
        )
        actor.load_state_dict(checkpoint['actor_state_dict'])
        actor.eval()  # disable dropout/batchnorm training behavior

        state, _ = env.reset()
        total_reward = 0.0

        for _ in range(max_steps):
            # Greedy (deterministic) action: argmax over the logits.
            with torch.no_grad():
                state_tensor = torch.FloatTensor(state).unsqueeze(0)
                logits = actor(state_tensor)
                action = torch.argmax(logits, dim=-1).item()

            state, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward

            if terminated or truncated:
                break
    finally:
        # Always release the render window, even if loading/stepping raised.
        env.close()

    print(f"Total reward: {total_reward:.2f}")
|
|
|
|
|
# Script entry point: run one rendered evaluation episode when executed
# directly (not on import).
if __name__ == "__main__":


    test_model()
|
|
|