import tensorflow as tf
import tensorflow.contrib.layers as layers


def _hidden_stack(inputs, hiddens, layer_norm):
    """Apply the hidden-layer stack shared by the action and state streams.

    For every width in ``hiddens`` this adds a linear fully connected layer,
    optionally followed by layer normalization, and then a ReLU.

    Args:
        inputs: tensor fed to the first fully connected layer.
        hiddens: iterable of layer widths, one per hidden layer.
        layer_norm: when truthy, ``layers.layer_norm`` is applied between
            the linear layer and the ReLU.

    Returns:
        The output of the last ReLU, or ``inputs`` itself when ``hiddens``
        is empty.
    """
    out = inputs
    for width in hiddens:
        # activation_fn=None keeps the layer linear so that the optional
        # normalization runs before the non-linearity.
        out = layers.fully_connected(out, num_outputs=width, activation_fn=None)
        if layer_norm:
            out = layers.layer_norm(out, center=True, scale=True)
        out = tf.nn.relu(out)
    return out


def build_q_func(network, hiddens=(256,), dueling=True, layer_norm=False, **network_kwargs):
    """Create a function that builds a Q-value head on top of ``network``.

    Args:
        network: either a callable mapping the input placeholder to a latent
            tensor (or to a tuple whose first element is that tensor), or a
            string that is resolved through
            ``baselines.common.models.get_network_builder`` and instantiated
            with ``network_kwargs``.
        hiddens: iterable of hidden-layer widths used by the action-value
            stream and, when ``dueling`` is set, by the state-value stream.
            It is copied into a tuple when this function is called.
        dueling: when truthy, a separate state-value stream is built and
            combined with the mean-centered action scores.
        layer_norm: when truthy, layer normalization is applied after every
            hidden fully connected layer.
        **network_kwargs: forwarded to the network builder; only used when
            ``network`` is given as a string.

    Returns:
        ``q_func_builder(input_placeholder, num_actions, scope, reuse=False)``,
        which builds the graph inside ``scope`` and returns the Q-value
        tensor.
    """
    if isinstance(network, str):
        from baselines.common.models import get_network_builder
        network = get_network_builder(network)(**network_kwargs)

    # Snapshot the widths once. The closure below iterates them for each
    # stream and on every call, so a one-shot iterator, or a list the caller
    # mutates later, would otherwise yield inconsistent architectures.
    hiddens = tuple(hiddens)

    def q_func_builder(input_placeholder, num_actions, scope, reuse=False):
        """Build the Q-value graph under ``scope`` and return its output.

        Raises:
            NotImplementedError: if ``network`` returns a tuple whose second
                element is not ``None`` (reported as a recurrent policy).
        """
        with tf.compat.v1.variable_scope(scope, reuse=reuse):
            latent = network(input_placeholder)
            if isinstance(latent, tuple):
                if latent[1] is not None:
                    raise NotImplementedError("DQN is not compatible with recurrent policies yet")
                latent = latent[0]

            latent = layers.flatten(latent)

            with tf.compat.v1.variable_scope("action_value"):
                action_out = _hidden_stack(latent, hiddens, layer_norm)
                action_scores = layers.fully_connected(
                    action_out, num_outputs=num_actions, activation_fn=None)

            if not dueling:
                return action_scores

            with tf.compat.v1.variable_scope("state_value"):
                state_out = _hidden_stack(latent, hiddens, layer_norm)
                state_score = layers.fully_connected(
                    state_out, num_outputs=1, activation_fn=None)

            # Dueling combination: subtract the per-row mean over axis 1 from
            # the action scores, then add the single state score.
            action_scores_mean = tf.reduce_mean(input_tensor=action_scores, axis=1)
            action_scores_centered = action_scores - tf.expand_dims(action_scores_mean, 1)
            return state_score + action_scores_centered

    return q_func_builder