# apple/models/categorical_policy.py
# Author: New Author Name
# Commit: 4b714e2 ("init")
import torch
import torch.nn as nn
from torch.distributions import Categorical
class CategoricalPolicy(nn.Module):
    """Two-action stochastic policy.

    A bias-free linear layer maps the state to a logit; a sigmoid turns it
    into p = P(action 0), and the policy distribution is [p, 1 - p].

    Parameters
    ----------
    state_dim : int
        Size of the input state vector.
    act_dim : int
        Output size of the linear layer. The forward pass only uses a
        1-dimensional logit (see ``forward``), so this is typically 1.
    weight1, weight2 : float, optional
        If given, pin ``weight[0][0]`` / ``weight[0][1]`` to these values
        (handy for deterministic tests). ``weight2`` requires
        ``state_dim >= 2``. All other entries keep the default random init.
    """

    def __init__(self, state_dim, act_dim, weight1=None, weight2=None):
        super().__init__()
        self.model = nn.Linear(state_dim, act_dim, bias=False)
        # nn.init.constant_ fills the indexed element in-place under no_grad.
        if weight1 is not None:
            nn.init.constant_(self.model.weight[0][0], weight1)
        if weight2 is not None:
            nn.init.constant_(self.model.weight[0][1], weight2)

    def forward(self, state):
        """Return action probabilities of shape (1, 2): ``[[p, 1 - p]]``.

        Generalized from ``torch.from_numpy``: accepts a numpy array
        (original behavior) or any array-like ``torch.as_tensor`` handles
        (lists, tensors), converted to float32 and given a batch dim.
        """
        x = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        x = self.model(x)
        # We only consider the 1-dimensional probability of the first action.
        p = torch.sigmoid(x)
        return torch.cat([p, 1 - p], dim=1)

    def act(self, state):
        """Sample an action for ``state``; return (action, log_prob(action)).

        Delegates to ``sample`` — the original duplicated its body verbatim.
        """
        return self.sample(self.forward(state))

    def sample(self, probs):
        """Sample from pre-computed probabilities; return (action, log_prob)."""
        dist = Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action)

    def log_prob(self, probs, target_action):
        """Return (sampled_action, log_prob(target_action)).

        NOTE(review): the sampled action in the first slot is unrelated to
        ``target_action``; it is kept only for backward compatibility with
        existing callers. The meaningful output is the second element.
        """
        dist = Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(target_action)

    @torch.no_grad()
    def get_action(self, state):
        """Sample an action for ``state`` without tracking gradients.

        Returns the action as a plain Python int.
        """
        probs = self.forward(state)
        return Categorical(probs).sample().item()