Update Agent_class.py
Browse files - Agent_class.py +4 -4
Agent_class.py
CHANGED
|
@@ -2,7 +2,7 @@ import torch
|
|
| 2 |
|
| 3 |
class ParameterisedPolicy(torch.nn.Module):
|
| 4 |
"""
|
| 5 |
-
REINFORCE RL agent class. Returns action when the ParameterisedPolicy.act(observation) is used.
|
| 6 |
observation is a gym state vector.
|
| 7 |
obs_len - length of the state vector
|
| 8 |
act_space_len - length of the action vector
|
|
@@ -10,8 +10,6 @@ class ParameterisedPolicy(torch.nn.Module):
|
|
| 10 |
"""
|
| 11 |
def __init__(self, obs_len=8, act_space_len=2):
|
| 12 |
super().__init__()
|
| 13 |
-
self.deterministic = False
|
| 14 |
-
self.continuous = True
|
| 15 |
self.obs_len = obs_len
|
| 16 |
self.act_space_len = act_space_len
|
| 17 |
self.lin_1 = torch.nn.Linear(self.obs_len, 256)
|
|
@@ -41,7 +39,9 @@ class ParameterisedPolicy(torch.nn.Module):
|
|
| 41 |
return mu, sigma
|
| 42 |
|
| 43 |
def act(self, observation):
|
| 44 |
-
|
|
|
|
|
|
|
| 45 |
(mus, sigmas) = self.forward(torch.tensor(observation, dtype=torch.float32))
|
| 46 |
m = torch.distributions.normal.Normal(mus, sigmas)
|
| 47 |
action = m.sample().detach().numpy()
|
|
|
|
| 2 |
|
| 3 |
class ParameterisedPolicy(torch.nn.Module):
|
| 4 |
"""
|
| 5 |
+
REINFORCE RL agent class. Returns an action when the ParameterisedPolicy.act(observation) method is called.
|
| 6 |
observation is a gym state vector.
|
| 7 |
obs_len - length of the state vector
|
| 8 |
act_space_len - length of the action vector
|
|
|
|
| 10 |
"""
|
| 11 |
def __init__(self, obs_len=8, act_space_len=2):
|
| 12 |
super().__init__()
|
|
|
|
|
|
|
| 13 |
self.obs_len = obs_len
|
| 14 |
self.act_space_len = act_space_len
|
| 15 |
self.lin_1 = torch.nn.Linear(self.obs_len, 256)
|
|
|
|
| 39 |
return mu, sigma
|
| 40 |
|
| 41 |
def act(self, observation):
|
| 42 |
+
"""
|
| 43 |
+
Method returns an action when a gym state vector is passed.
|
| 44 |
+
"""
|
| 45 |
(mus, sigmas) = self.forward(torch.tensor(observation, dtype=torch.float32))
|
| 46 |
m = torch.distributions.normal.Normal(mus, sigmas)
|
| 47 |
action = m.sample().detach().numpy()
|