import torch
import torch.nn as nn
import torch.nn.functional as F
from layer import LinearSVDO
# Define a simple three-layer network (two hidden layers) built from
# LinearSVDO sparse-variational-dropout layers
class Net(nn.Module):
def __init__(self, state_dim, action_dim, action_bound, hidden_dims, alpha_threshold, theta_threshold, device):
super(Net, self).__init__()
self.fc1 = LinearSVDO(state_dim, hidden_dims[0], alpha_threshold, theta_threshold, device)
self.fc2 = LinearSVDO(hidden_dims[0], hidden_dims[1], alpha_threshold, theta_threshold, device)
self.fc3 = LinearSVDO(hidden_dims[1], action_dim, alpha_threshold, theta_threshold, device)
        # Rescale the tanh output from [-1, 1] to [action_bound[0], action_bound[1]];
        # create the constants on the target device to avoid mixing CPU and GPU tensors.
        self.action_rescale = torch.as_tensor(
            (action_bound[1] - action_bound[0]) / 2., dtype=torch.float32, device=device)
        self.action_rescale_bias = torch.as_tensor(
            (action_bound[1] + action_bound[0]) / 2., dtype=torch.float32, device=device)
self.device = device
self.alpha_threshold = alpha_threshold
    def _format(self, state):
        # Accept raw (e.g. NumPy or list) states: convert to a float32 tensor
        # on the network's device and add a batch dimension.
        x = state
if not isinstance(x, torch.Tensor):
x = torch.tensor(x, device=self.device, dtype=torch.float32)
x = x.unsqueeze(0)
return x
def forward(self, x):
x = self._format(x)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
        # torch.tanh replaces the deprecated F.tanh; the affine rescale below
        # maps the squashed output into [action_bound[0], action_bound[1]].
        x = torch.tanh(self.fc3(x))
x = x * self.action_rescale + self.action_rescale_bias
return x
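
# A minimal usage sketch for Net, assuming layer.LinearSVDO as imported above.
# All values below are hypothetical placeholders, not tuned settings:
#
#   net = Net(state_dim=3, action_dim=1, action_bound=(-2.0, 2.0),
#             hidden_dims=(400, 300), alpha_threshold=3.0,
#             theta_threshold=1e-3, device='cpu')
#   action = net([0.0, 0.0, 0.0])  # shape (1, 1), values in [-2.0, 2.0]
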
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, student_hidden_dims, max_action, device='cuda'):
        super(Actor, self).__init__()
        self.l1 = nn.Linear(state_dim, student_hidden_dims[0])
        self.l2 = nn.Linear(student_hidden_dims[0], student_hidden_dims[1])
        self.l3 = nn.Linear(student_hidden_dims[1], action_dim)
        # Device used by _format; defaults to CUDA but can be overridden.
        self.device = device
        self.max_action = max_action
def _format(self, state):
x = state
if not isinstance(x, torch.Tensor):
x = torch.tensor(x, device=self.device, dtype=torch.float32)
x = x.unsqueeze(0)
return x
def forward(self, state):
x = self._format(state)
a = F.relu(self.l1(x))
a = F.relu(self.l2(a))
return self.max_action * torch.tanh(self.l3(a))
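
# A minimal usage sketch for Actor (placeholder values; 'cpu' relies on the
# device argument introduced above):
#
#   actor = Actor(state_dim=3, action_dim=1,
#                 student_hidden_dims=(256, 256), max_action=2.0,
#                 device='cpu')
#   a = actor([0.0, 0.0, 0.0])  # shape (1, 1), values in [-2.0, 2.0]
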
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, device='cuda'):
        super(Critic, self).__init__()
        # Device used by _format; defaults to CUDA but can be overridden.
        self.device = device
# Q1 architecture
self.l1 = nn.Linear(state_dim + action_dim, 256)
self.l2 = nn.Linear(256, 256)
self.l3 = nn.Linear(256, 1)
# Q2 architecture
self.l4 = nn.Linear(state_dim + action_dim, 256)
self.l5 = nn.Linear(256, 256)
self.l6 = nn.Linear(256, 1)
def _format(self, state, action):
x, u = state, action
if not isinstance(x, torch.Tensor):
x = torch.tensor(x, device=self.device, dtype=torch.float32)
x = x.unsqueeze(0)
if not isinstance(u, torch.Tensor):
u = torch.tensor(u, device=self.device, dtype=torch.float32)
u = u.unsqueeze(0)
return x, u
def forward(self, state, action):
x, u = self._format(state, action)
sa = torch.cat([x, u], 1)
q1 = F.relu(self.l1(sa))
q1 = F.relu(self.l2(q1))
q1 = self.l3(q1)
q2 = F.relu(self.l4(sa))
q2 = F.relu(self.l5(q2))
q2 = self.l6(q2)
return q1, q2
    def Q1(self, state, action):
        # Q1-only pass, used for the policy (actor) update in TD3-style
        # training; unlike forward(), it expects already-batched tensors.
        sa = torch.cat([state, action], 1)
q1 = F.relu(self.l1(sa))
q1 = F.relu(self.l2(q1))
q1 = self.l3(q1)
return q1
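

# Minimal smoke test, assuming only this file and PyTorch (Net is omitted here
# because it depends on layer.LinearSVDO). Every dimension and bound below is
# a placeholder, not a setting used in training.
if __name__ == "__main__":
    state_dim, action_dim = 3, 1
    actor = Actor(state_dim, action_dim, student_hidden_dims=(256, 256),
                  max_action=2.0, device='cpu')
    critic = Critic(state_dim, action_dim, device='cpu')

    state = [0.0] * state_dim
    action = actor(state)            # shape (1, action_dim), in [-2.0, 2.0]
    q1, q2 = critic(state, action)   # twin Q-value estimates, each (1, 1)
    # A TD3-style target would take the pessimistic (minimum) estimate:
    target = torch.min(q1, q2)
    print(action.shape, q1.shape, q2.shape, target.shape)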