Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from torch.distributions import Categorical | |
class CategoricalPolicy(nn.Module):
    """Linear-sigmoid policy over a pair of complementary actions.

    A bias-free ``nn.Linear(state_dim, act_dim)`` produces one logit per
    output unit. ``forward`` returns ``[sigmoid(logit), 1 - sigmoid(logit)]``
    concatenated along dim 1, and the remaining methods wrap that tensor in a
    ``torch.distributions.Categorical``.

    NOTE(review): this layout only forms a proper two-action distribution when
    ``act_dim == 1`` (as the comment in ``forward`` assumes). With a larger
    ``act_dim`` the output has ``2 * act_dim`` columns, which ``Categorical``
    silently renormalizes. Confirm that callers always pass ``act_dim=1``.
    """

    def __init__(self, state_dim, act_dim, weight1=None, weight2=None):
        """Build the linear layer and optionally pin two of its weights.

        Args:
            state_dim: number of input features.
            act_dim: number of linear outputs (see the class note; expected 1).
            weight1: if not None, initial value written to ``weight[0, 0]``.
            weight2: if not None, initial value written to ``weight[0, 1]``;
                this requires ``state_dim >= 2``.
        """
        super().__init__()
        self.model = nn.Linear(state_dim, act_dim, bias=False)
        # Overwrite the requested entries in place without recording the write
        # in autograd, so ``weight`` remains a trainable leaf parameter.
        with torch.no_grad():
            if weight1 is not None:
                self.model.weight[0, 0] = weight1
            if weight2 is not None:
                self.model.weight[0, 1] = weight2

    def forward(self, state):
        """Return action probabilities for a single, unbatched state.

        Args:
            state: a NumPy array, a list/tuple of numbers, or a tensor. It is
                cast to float32 and given a leading batch dimension of 1.
                Presumably of shape ``(state_dim,)``.

        Returns:
            For a 1-D state, a tensor of shape ``(1, 2 * act_dim)`` holding
            ``[p, 1 - p]`` with ``p = sigmoid(state @ W.T)``.
        """
        # as_tensor accepts every NumPy array from_numpy did (sharing memory
        # for float32 input, like the old ``.float()``), plus lists and tensors.
        x = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        x = self.model(x)
        # We only consider a 1-dimensional probability of action.
        p = torch.sigmoid(x)
        return torch.cat([p, 1 - p], dim=1)

    def act(self, state):
        """Sample an action for ``state``.

        Returns:
            ``(action, log_prob)``: the sampled action index as a Python int,
            and its log-probability as a tensor. The log-probability is still
            attached to the autograd graph of the layer's weights.
        """
        return self.sample(self.forward(state))

    def sample(self, probs):
        """Sample an action from a probability tensor such as ``forward`` returns.

        ``probs`` must describe a single distribution (batch of one), because
        the sampled action is converted with ``.item()``.

        Returns:
            ``(action, log_prob)``: a Python int and a tensor.
        """
        dist = Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action)

    def log_prob(self, probs, target_action):
        """Evaluate the log-probability of ``target_action`` under ``probs``.

        Returns:
            ``(sampled_action, log_prob)``, where ``log_prob`` is
            ``Categorical(probs).log_prob(target_action)``.

        NOTE(review): ``sampled_action`` is a fresh random draw from ``probs``.
        It is unrelated to ``target_action``, and drawing it consumes random
        numbers. It is kept only for backward compatibility; confirm whether
        any caller actually uses it.
        """
        dist = Categorical(probs)
        sampled = dist.sample()
        return sampled.item(), dist.log_prob(target_action)

    def get_action(self, state):
        """Sample an action for ``state`` and return only its index as an int."""
        # Only the integer index escapes, so there is no need for an autograd graph.
        with torch.no_grad():
            probs = self.forward(state)
        return Categorical(probs).sample().item()