import numpy as np
import torch
import torch.nn.functional as F
import utils
from torch import nn
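

# NOTE: `utils.mlp` and `utils.weight_init` come from this repo's local
# `utils` module, which is not shown here. The sketch below is a minimal
# implementation consistent with how `utils.mlp` is used in this file
# (`log` indexes and iterates over the returned network's layers, so it is
# assumed to be an nn.Sequential of Linear layers with ReLU activations).
# It is illustrative only, not the repo's actual implementation.
def _mlp_sketch(input_dim, hidden_dim, output_dim, hidden_depth):
    if hidden_depth == 0:
        # No hidden layers: a single linear map from input to output.
        mods = [nn.Linear(input_dim, output_dim)]
    else:
        mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
        for _ in range(hidden_depth - 1):
            mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
        mods.append(nn.Linear(hidden_dim, output_dim))
    return nn.Sequential(*mods)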


class DoubleQCritic(nn.Module):
    """Critic network; employs double Q-learning."""

    def __init__(self, obs_dim, action_dim, hidden_dim, hidden_depth):
        super().__init__()

        # Two independent Q-networks over the concatenated (obs, action) pair.
        self.Q1 = utils.mlp(obs_dim + action_dim, hidden_dim, 1, hidden_depth)
        self.Q2 = utils.mlp(obs_dim + action_dim, hidden_dim, 1, hidden_depth)

        # Cache of the most recent Q-value outputs, consumed by `log` below.
        self.outputs = dict()
        self.apply(utils.weight_init)

    def forward(self, obs, action):
        assert obs.size(0) == action.size(0)

        # Concatenate along the feature dimension and evaluate both heads.
        obs_action = torch.cat([obs, action], dim=-1)
        q1 = self.Q1(obs_action)
        q2 = self.Q2(obs_action)

        self.outputs['q1'] = q1
        self.outputs['q2'] = q2

        return q1, q2

    def log(self, logger, step):
        # Histogram the most recent Q-value estimates.
        for k, v in self.outputs.items():
            logger.log_histogram(f'train_critic/{k}_hist', v, step)

        # Both heads share the same architecture, so their layers pair up.
        assert len(self.Q1) == len(self.Q2)
        for i, (m1, m2) in enumerate(zip(self.Q1, self.Q2)):
            assert type(m1) == type(m2)
            if isinstance(m1, nn.Linear):
                logger.log_param(f'train_critic/q1_fc{i}', m1, step)
                logger.log_param(f'train_critic/q2_fc{i}', m2, step)
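

# Minimal usage sketch: build a critic and run a batched forward pass. The
# dimensions below are arbitrary placeholders, and this assumes `utils.mlp`
# behaves like `_mlp_sketch` above.
if __name__ == '__main__':
    critic = DoubleQCritic(obs_dim=17, action_dim=6, hidden_dim=256,
                           hidden_depth=2)
    obs = torch.randn(32, 17)    # batch of 32 observations
    action = torch.randn(32, 6)  # batch of 32 actions
    q1, q2 = critic(obs, action)
    # Each head produces one scalar Q-value per batch element.
    assert q1.shape == (32, 1) and q2.shape == (32, 1)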