# UD7 / network.py
# Uploaded by jangwon-kim-cocel ("Upload 10 files", commit 0e2f05d, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import weight_init, AvgL1Norm
class EnsembleQNet(nn.Module):
    """Ensemble of independent Q-value critics.

    Each critic embeds the raw (state, action) pair through an
    AvgL1Norm-normalized input layer, concatenates the (zsa, zs)
    embeddings, and regresses a scalar Q-value through a small MLP.
    ``forward`` stacks the per-critic values into one tensor.
    """

    def __init__(self, num_critics, state_dim, action_dim, device, zs_dim=256, hidden_dims=(256, 256), activation_fc=F.elu):
        super(EnsembleQNet, self).__init__()
        self.device = device
        self.activation_fc = activation_fc
        self.num_critics = num_critics
        # One independent layer-dict per ensemble member.
        self.q_nets = nn.ModuleList(
            self._build_q_net(state_dim, action_dim, zs_dim, hidden_dims)
            for _ in range(num_critics)
        )
        self.apply(weight_init)

    def _build_q_net(self, state_dim, action_dim, zs_dim, hidden_dims):
        """Create the layers of a single critic head.

        Layers are constructed in the same order as they are used in
        ``forward`` so that seeded initialization is reproducible.
        """
        return nn.ModuleDict({
            's_input_layer': nn.Linear(state_dim + action_dim, hidden_dims[0]),
            # Takes the normalized state-action features plus both embeddings.
            'emb_input_layer': nn.Linear(2 * zs_dim + hidden_dims[0], hidden_dims[0]),
            'emb_hidden_layers': nn.ModuleList(
                nn.Linear(d_in, d_out)
                for d_in, d_out in zip(hidden_dims[:-1], hidden_dims[1:])
            ),
            'output_layer': nn.Linear(hidden_dims[-1], 1),
        })

    def _format(self, state, action):
        """Coerce raw (e.g. numpy) inputs to batched float tensors on self.device."""
        if not isinstance(state, torch.Tensor):
            state = torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(0)
        if not isinstance(action, torch.Tensor):
            action = torch.tensor(action, device=self.device, dtype=torch.float32).unsqueeze(0)
        return state, action

    def forward(self, state, action, zsa, zs):
        """Return the per-critic Q-values, shape (batch, num_critics)."""
        state, action = self._format(state, action)
        sa = torch.cat([state, action], dim=1)
        emb = torch.cat([zsa, zs], dim=1)
        values = []
        for net in self.q_nets:
            h = AvgL1Norm(net['s_input_layer'](sa))
            h = self.activation_fc(net['emb_input_layer'](torch.cat([h, emb], dim=1)))
            for hidden in net['emb_hidden_layers']:
                h = self.activation_fc(hidden(h))
            values.append(net['output_layer'](h))
        return torch.cat(values, dim=1)
class Policy(nn.Module):
    """Deterministic actor: maps (state, zs embedding) to a tanh-squashed action.

    The raw state is embedded through an AvgL1Norm-normalized input layer,
    concatenated with the state embedding ``zs``, passed through an MLP, and
    squashed with ``tanh`` so actions lie in [-1, 1].
    """

    def __init__(self, state_dim, action_dim, device, zs_dim=256, hidden_dims=(256, 256), activation_fc=F.relu):
        super(Policy, self).__init__()
        self.device = device
        self.activation_fc = activation_fc
        self.s_input_layer = nn.Linear(state_dim, hidden_dims[0])
        self.zss_input_layer = nn.Linear(zs_dim + hidden_dims[0], hidden_dims[0])
        self.zss_hidden_layers = nn.ModuleList()
        for i in range(len(hidden_dims)-1):
            hidden_layer = nn.Linear(hidden_dims[i], hidden_dims[i+1])
            self.zss_hidden_layers.append(hidden_layer)
        self.zss_output_layer = nn.Linear(hidden_dims[-1], action_dim)
        # BUG FIX: this was previously called at the very top of __init__,
        # before any layer existed, so nn.Module.apply recursed over zero
        # submodules and weight_init never ran. It must come after all layers
        # are constructed (matching EnsembleQNet).
        self.apply(weight_init)

    def _format(self, state):
        """Coerce a raw (e.g. numpy) state to a batched float tensor on self.device."""
        x = state
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            x = x.unsqueeze(0)
        return x

    def forward(self, state, zs):
        """Return actions in [-1, 1], shape (batch, action_dim)."""
        state = self._format(state)
        state = AvgL1Norm(self.s_input_layer(state))
        zss = torch.cat([state, zs], 1)
        zss = self.activation_fc(self.zss_input_layer(zss))
        for hidden_layer in self.zss_hidden_layers:
            zss = self.activation_fc(hidden_layer(zss))
        zss = self.zss_output_layer(zss)
        action = torch.tanh(zss)
        return action
class Encoder(nn.Module):
    """Two-headed embedding network.

    ``zs(state)`` produces an AvgL1Norm-normalized state embedding;
    ``zsa(zs, action)`` produces an (unnormalized) state-action embedding
    from a state embedding and an action.
    """

    def __init__(self, state_dim, action_dim, device, zs_dim=256, hidden_dims=(256, 256), activation_fc=F.elu):
        super(Encoder, self).__init__()
        self.device = device
        self.activation_fc = activation_fc
        # State encoder: state -> zs.
        self.s_encoder_input_layer = nn.Linear(state_dim, hidden_dims[0])
        self.s_encoder_hidden_layers = nn.ModuleList(
            nn.Linear(d_in, d_out)
            for d_in, d_out in zip(hidden_dims[:-1], hidden_dims[1:])
        )
        self.s_encoder_output_layer = nn.Linear(hidden_dims[-1], zs_dim)
        # State-action encoder: (zs, action) -> zsa.
        self.zsa_encoder_input_layer = nn.Linear(zs_dim + action_dim, hidden_dims[0])
        self.zsa_encoder_hidden_layers = nn.ModuleList(
            nn.Linear(d_in, d_out)
            for d_in, d_out in zip(hidden_dims[:-1], hidden_dims[1:])
        )
        self.zsa_encoder_output_layer = nn.Linear(hidden_dims[-1], zs_dim)

    def _format(self, state):
        """Coerce a raw (e.g. numpy) input to a batched float tensor on self.device."""
        if isinstance(state, torch.Tensor):
            return state
        return torch.tensor(state, device=self.device, dtype=torch.float32).unsqueeze(0)

    def zs(self, state):
        """Embed a state; output is AvgL1Norm-normalized, shape (batch, zs_dim)."""
        h = self.activation_fc(self.s_encoder_input_layer(self._format(state)))
        for layer in self.s_encoder_hidden_layers:
            h = self.activation_fc(layer(h))
        return AvgL1Norm(self.s_encoder_output_layer(h))

    def zsa(self, zs, action):
        """Embed a (state embedding, action) pair; output shape (batch, zs_dim)."""
        action = self._format(action)
        h = self.activation_fc(self.zsa_encoder_input_layer(torch.cat([zs, action], 1)))
        for layer in self.zsa_encoder_hidden_layers:
            h = self.activation_fc(layer(h))
        return self.zsa_encoder_output_layer(h)