| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from agent.helpers import SinusoidalPosEmb, init_weights | |
class Critic(nn.Module):
    """Twin Q-network (clipped double-Q critic).

    Holds two independent Q-value heads with identical architecture; the
    minimum of the two estimates (``q_min``) is the standard mitigation for
    Q-value overestimation in actor-critic methods.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        """
        Args:
            state_dim: dimensionality of the state vector.
            action_dim: dimensionality of the action vector.
            hidden_dim: width of each hidden layer (default 256).
        """
        super().__init__()
        # Two structurally identical but independently parameterized heads.
        self.q1_model = self._build_q_head(state_dim + action_dim, hidden_dim)
        self.q2_model = self._build_q_head(state_dim + action_dim, hidden_dim)
        self.apply(init_weights)

    @staticmethod
    def _build_q_head(input_dim, hidden_dim):
        """Return a 3-hidden-layer Mish MLP mapping a (state, action) concat to a scalar Q."""
        return nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, state, action):
        """Return both Q estimates ``(q1, q2)`` for the given state-action pair."""
        x = torch.cat([state, action], dim=-1)
        return self.q1_model(x), self.q2_model(x)

    def q1(self, state, action):
        """Return only the first head's Q estimate (used for the policy gradient)."""
        x = torch.cat([state, action], dim=-1)
        return self.q1_model(x)

    def q_min(self, state, action):
        """Return the elementwise minimum of the two Q estimates."""
        q1, q2 = self.forward(state, action)
        return torch.min(q1, q2)
class Model(nn.Module):
    """Conditional denoising network for a diffusion policy.

    Maps a noisy action ``x``, a diffusion timestep ``time``, and a
    conditioning ``state`` to an ``action_dim``-sized output, embedding the
    timestep with a sinusoidal positional encoding first.
    """

    def __init__(self, state_dim, action_dim, hidden_size=256, time_dim=32):
        super().__init__()
        # Timestep -> dense embedding of width time_dim.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, hidden_size),
            nn.Mish(),
            nn.Linear(hidden_size, time_dim),
        )
        # Trunk: three Mish-activated hidden layers, built with distinct
        # Linear instances, followed by the action-sized output layer.
        trunk = [nn.Linear(state_dim + action_dim + time_dim, hidden_size), nn.Mish()]
        for _ in range(2):
            trunk.append(nn.Linear(hidden_size, hidden_size))
            trunk.append(nn.Mish())
        trunk.append(nn.Linear(hidden_size, action_dim))
        self.layer = nn.Sequential(*trunk)
        self.apply(init_weights)

    def forward(self, x, time, state):
        """Predict the denoising target from (noisy action, timestep, state)."""
        time_emb = self.time_mlp(time)
        features = torch.cat([x, time_emb, state], dim=-1)
        return self.layer(features)
class MLP(nn.Module):
    """Deterministic state-to-action MLP policy with optional training noise.

    ``forward`` adds Gaussian exploration noise unless ``eval=True``; the
    noise magnitude is now configurable via ``noise_scale`` (previously a
    hard-coded 0.1, kept as the default for backward compatibility).
    """

    def __init__(self, state_dim, action_dim, hidden_size=256, noise_scale=0.1):
        """
        Args:
            state_dim: dimensionality of the state vector.
            action_dim: dimensionality of the action vector.
            hidden_size: width of each hidden layer (default 256).
            noise_scale: std of the Gaussian noise added during training
                forward passes (default 0.1, matching prior behavior).
        """
        super().__init__()
        self.noise_scale = noise_scale
        self.mid_layer = nn.Sequential(
            nn.Linear(state_dim, hidden_size),
            nn.Mish(),
            nn.Linear(hidden_size, hidden_size),
            nn.Mish(),
            nn.Linear(hidden_size, hidden_size),
            nn.Mish(),
        )
        self.final_layer = nn.Linear(hidden_size, action_dim)
        self.apply(init_weights)

    def forward(self, state, eval=False):
        """Return the action for ``state``; noiseless when ``eval`` is true.

        NOTE: the parameter name ``eval`` shadows the builtin but is kept so
        existing keyword callers (``forward(s, eval=True)``) keep working.
        """
        out = self.final_layer(self.mid_layer(state))
        if not eval:
            # Single randn_like draw keeps RNG consumption identical to the
            # original implementation for a fixed seed.
            out = out + torch.randn_like(out) * self.noise_scale
        return out

    def loss(self, action, state):
        """Behavior-cloning MSE between the (noisy) policy output and ``action``."""
        return F.mse_loss(self.forward(state), action, reduction='mean')