|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from layer import LinearSVDO
|
|
|
|
|
|
|
|
|
|
|
|
class Net(nn.Module):
    """Three-layer policy network built from ``LinearSVDO`` layers.

    The output of ``fc3`` is squashed with ``tanh`` into [-1, 1] and then
    affinely mapped onto ``[action_bound[0], action_bound[1]]``: -1 maps to the
    lower bound and +1 maps to the upper bound.

    NOTE(review): ``LinearSVDO`` lives in the project-local ``layer`` module;
    presumably a sparse-variational-dropout linear layer whose pruning is
    governed by ``alpha_threshold`` / ``theta_threshold`` -- confirm there.
    """

    def __init__(self, state_dim, action_dim, action_bound, hidden_dims, alpha_threshold, theta_threshold, device):
        """Build the network.

        Args:
            state_dim: size of the input (state) feature axis.
            action_dim: number of action outputs.
            action_bound: pair ``(low, high)``; scalars or per-action arrays
                supporting ``-``, ``+`` and ``/``.
            hidden_dims: sequence whose first two entries are the hidden widths.
            alpha_threshold: forwarded to every ``LinearSVDO`` layer and stored.
            theta_threshold: forwarded to every ``LinearSVDO`` layer.
            device: forwarded to every ``LinearSVDO`` layer; also the device of
                the action-bound buffers and of tensors built from non-tensor
                inputs in ``_format``.
        """
        super().__init__()

        self.fc1 = LinearSVDO(state_dim, hidden_dims[0], alpha_threshold, theta_threshold, device)
        self.fc2 = LinearSVDO(hidden_dims[0], hidden_dims[1], alpha_threshold, theta_threshold, device)
        self.fc3 = LinearSVDO(hidden_dims[1], action_dim, alpha_threshold, theta_threshold, device)

        # The affine map from [-1, 1] onto the action range is stored as buffers
        # (not plain attributes) so that ``.to()`` / ``.cuda()`` move it together
        # with the rest of the module. A plain CPU tensor breaks ``forward`` on
        # GPU as soon as the bounds are per-action vectors (only 0-dim CPU
        # tensors may be mixed with CUDA tensors). ``persistent=False`` keeps
        # the ``state_dict`` keys unchanged, so existing checkpoints still load.
        self.register_buffer(
            'action_rescale',
            torch.as_tensor((action_bound[1] - action_bound[0]) / 2., dtype=torch.float32, device=device),
            persistent=False,
        )
        self.register_buffer(
            'action_rescale_bias',
            torch.as_tensor((action_bound[1] + action_bound[0]) / 2., dtype=torch.float32, device=device),
            persistent=False,
        )

        self.device = device
        # Not read inside this class; presumably used by training/pruning code -- verify.
        self.alpha_threshold = alpha_threshold

    def _format(self, state):
        """Convert a raw (non-tensor) state into a batched float32 tensor.

        Non-tensor input is treated as a single unbatched sample: it is turned
        into a float32 tensor on ``self.device`` and given a leading batch axis
        of size 1. Tensor input is assumed to be already batched and on the
        right device, and is returned unchanged.
        """
        x = state
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            x = x.unsqueeze(0)
        return x

    def forward(self, x):
        """Map states to actions inside ``[action_bound[0], action_bound[1]]``."""
        x = self._format(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # torch.tanh replaces the deprecated F.tanh (identical values).
        x = torch.tanh(self.fc3(x))
        return x * self.action_rescale + self.action_rescale_bias
|
|
|
|
|
|
|
|
|
class Actor(nn.Module):
    """Deterministic policy MLP.

    Two ReLU hidden layers followed by a ``tanh`` output that is scaled by
    ``max_action``, so each action component lies in
    ``[-max_action, max_action]``.
    """

    def __init__(self, state_dim, action_dim, student_hidden_dims, max_action, device=None):
        """Build the policy.

        Args:
            state_dim: size of the input (state) feature axis.
            action_dim: number of action outputs.
            student_hidden_dims: sequence whose first two entries are the hidden
                layer widths.
            max_action: multiplier applied to the ``tanh`` output.
            device: device on which ``_format`` creates tensors from non-tensor
                inputs. Defaults to ``'cuda'`` when CUDA is available and
                ``'cpu'`` otherwise. The original hard-coded ``'cuda'`` crashed
                on CPU-only machines.
        """
        super().__init__()

        self.l1 = nn.Linear(state_dim, student_hidden_dims[0])
        self.l2 = nn.Linear(student_hidden_dims[0], student_hidden_dims[1])
        self.l3 = nn.Linear(student_hidden_dims[1], action_dim)

        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        self.max_action = max_action

    def _format(self, state):
        """Convert a raw (non-tensor) state into a batched float32 tensor.

        Non-tensor input is treated as a single unbatched sample: it is turned
        into a float32 tensor on ``self.device`` and given a leading batch axis
        of size 1. Tensor input is assumed to be already batched and on the
        right device, and is returned unchanged.
        """
        x = state
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            x = x.unsqueeze(0)
        return x

    def forward(self, state):
        """Return ``max_action * tanh(MLP(state))``."""
        x = self._format(state)
        a = F.relu(self.l1(x))
        a = F.relu(self.l2(a))
        return self.max_action * torch.tanh(self.l3(a))
|
|
|
|
|
|
|
|
|
class Critic(nn.Module):
    """Two independent Q-value heads over the concatenated (state, action).

    ``forward`` returns both heads' outputs; ``Q1`` evaluates only the first
    head. Each head is Linear -> ReLU -> Linear -> ReLU -> Linear(…, 1).
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256, device=None):
        """Build both Q heads.

        Args:
            state_dim: size of the state feature axis.
            action_dim: size of the action feature axis.
            hidden_dim: width of both hidden layers in each head (was a
                hard-coded 256, which remains the default).
            device: device on which ``_format`` creates tensors from non-tensor
                inputs. Defaults to ``'cuda'`` when CUDA is available and
                ``'cpu'`` otherwise. The original hard-coded ``'cuda'`` crashed
                on CPU-only machines.
        """
        super().__init__()

        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        # Layers are created in the original order (l1..l6), so seeded
        # initialization and state_dict keys are unchanged.
        # First Q head.
        self.l1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.l3 = nn.Linear(hidden_dim, 1)

        # Second Q head.
        self.l4 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.l5 = nn.Linear(hidden_dim, hidden_dim)
        self.l6 = nn.Linear(hidden_dim, 1)

    def _format(self, state, action):
        """Convert raw (non-tensor) state/action into batched float32 tensors.

        Each argument is handled independently. A non-tensor value is treated
        as a single unbatched sample: it is turned into a float32 tensor on
        ``self.device`` and given a leading batch axis of size 1. A tensor
        value is assumed already batched and on the right device, and is
        returned unchanged.
        """
        x, u = state, action
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            x = x.unsqueeze(0)
        if not isinstance(u, torch.Tensor):
            u = torch.tensor(u, device=self.device, dtype=torch.float32)
            u = u.unsqueeze(0)
        return x, u

    def _state_action(self, state, action):
        """Format the inputs and concatenate them along the feature (last) axis."""
        x, u = self._format(state, action)
        # dim=-1 equals dim=1 for the usual (batch, features) inputs and always
        # targets the feature axis for other ranks.
        return torch.cat([x, u], dim=-1)

    def _q1_head(self, sa):
        """Evaluate the first Q head on a concatenated state-action tensor."""
        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        return self.l3(q1)

    def _q2_head(self, sa):
        """Evaluate the second Q head on a concatenated state-action tensor."""
        q2 = F.relu(self.l4(sa))
        q2 = F.relu(self.l5(q2))
        return self.l6(q2)

    def forward(self, state, action):
        """Return ``(q1, q2)``, one value per sample from each head."""
        sa = self._state_action(state, action)
        return self._q1_head(sa), self._q2_head(sa)

    def Q1(self, state, action):
        """Return only the first head's Q-value.

        Goes through the same ``_format`` step as ``forward``. The original
        skipped it, so non-tensor inputs crashed here but not in ``forward``.
        """
        sa = self._state_action(state, action)
        return self._q1_head(sa)