| import torch | |
class ParameterisedPolicy(torch.nn.Module):
    """Gaussian policy network for a REINFORCE RL agent.

    A shared two-layer ReLU trunk (obs_len -> 256 -> 128) feeds two linear
    heads of width ``act_space_len``: one yields the mean and the other the
    standard deviation of an independent Normal distribution per action
    component.

    Use ``ParameterisedPolicy.act(observation)`` to sample an action for an
    input state.

    Args:
        obs_len: Length of the state (observation) vector.
        act_space_len: Length of the action vector.
    """

    def __init__(self, obs_len: int = 8, act_space_len: int = 2) -> None:
        super().__init__()
        self.obs_len = obs_len
        self.act_space_len = act_space_len

        # Shared feature trunk.
        self.lin_1 = torch.nn.Linear(self.obs_len, 256)
        self.rel_1 = torch.nn.ReLU()
        self.lin_2 = torch.nn.Linear(256, 128)
        self.rel_2 = torch.nn.ReLU()

        # Output heads: lin_3 -> mean, lin_4 (followed by ELU) -> std-dev.
        self.lin_3 = torch.nn.Linear(128, self.act_space_len)
        self.lin_4 = torch.nn.Linear(128, self.act_space_len)
        self.elu = torch.nn.ELU()

    def forward(self, x: torch.Tensor):
        """Compute the parameters of the action distribution.

        Args:
            x: Observation tensor whose last dimension is ``obs_len``.

        Returns:
            A ``(mu, sigma)`` tuple of tensors whose last dimension is
            ``act_space_len``.
        """
        x = self.lin_1(x)
        x = self.rel_1(x)
        x = self.lin_2(x)
        x = self.rel_2(x)
        mu = self.lin_3(x)
        x = self.lin_4(x)
        # ELU (default alpha=1) is bounded below by -1, so this offset keeps
        # sigma positive (about 1e-6 at the lower limit), as Normal requires.
        sigma = self.elu(x) + 1.000001
        return mu, sigma

    def act(self, observation):
        """Sample an action for a gym state vector.

        Args:
            observation: State as an array-like or tensor whose last
                dimension is ``obs_len``.

        Returns:
            A NumPy array holding an action drawn from
            ``Normal(mu, sigma)``; its last dimension is ``act_space_len``.
        """
        # Follow the network's own dtype/device so the policy keeps working
        # after ``.to(device)`` or ``.double()``; as_tensor also avoids the
        # copy (and warning) torch.tensor() incurs for tensor inputs.
        weight = self.lin_1.weight
        obs = torch.as_tensor(
            observation, dtype=weight.dtype, device=weight.device
        )
        # Only the sampled value is returned, so there is no need to record
        # an autograd graph for this forward pass.
        with torch.no_grad():
            # Call the module (not .forward) so registered hooks still run.
            mus, sigmas = self(obs)
            action = torch.distributions.normal.Normal(mus, sigmas).sample()
        # .cpu() is a no-op for CPU tensors and is required before .numpy()
        # when the policy lives on an accelerator.
        return action.cpu().numpy()