privateboss's picture
Upload 8 files
a063d15 verified
Raw
History Blame Contribute Delete
3.68 kB
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import config
class ActorCritic(tf.keras.Model):
    """Actor-critic network with separate (non-shared) actor and critic towers.

    The actor maps a state to the mean of a Gaussian policy: a tanh output
    scaled by ``action_bounds``. It also has a single trainable log standard
    deviation that does not depend on the state. The critic maps the same
    state to one value estimate.
    """

    def __init__(self, action_bounds, action_dim=1):
        """Create the actor and critic layers.

        Args:
            action_bounds: Multiplier applied to the tanh-squashed mean. For a
                positive bound the mean lies in [-action_bounds, action_bounds].
            action_dim: Width of the policy-mean output. Defaults to 1, which
                was the previously hard-coded size.
        """
        super().__init__()
        self.action_bounds = action_bounds

        # Actor tower.
        self.actor_dense1 = tf.keras.layers.Dense(128, activation='tanh')
        self.actor_dense2 = tf.keras.layers.Dense(64, activation='tanh')
        self.mu = tf.keras.layers.Dense(action_dim, activation='tanh')
        # Pin float32 so an integer config value (e.g. 0) cannot create an
        # int32 variable, which tf.exp in call() would reject.
        self.log_std = tf.Variable(
            initial_value=config.INITIAL_LOG_STD,
            trainable=True,
            dtype=tf.float32,
        )

        # Critic tower.
        self.critic_dense1 = tf.keras.layers.Dense(128, activation='tanh')
        self.critic_dense2 = tf.keras.layers.Dense(64, activation='tanh')
        self.value = tf.keras.layers.Dense(1, activation=None)

    def call(self, state):
        """Run both towers on ``state``.

        Args:
            state: Batch of observations. Presumably (batch, obs_dim); TODO:
                confirm with callers.

        Returns:
            Tuple ``(mean, std, value)``:
            - ``mean`` has shape (batch, action_dim).
            - ``std`` is ``exp(log_std)``, with the same shape as ``log_std``
              (it broadcasts against ``mean``).
            - ``value`` has shape (batch, 1).
        """
        x_act = self.actor_dense1(state)
        x_act = self.actor_dense2(x_act)
        mean = self.mu(x_act) * self.action_bounds
        std = tf.exp(self.log_std)

        x_crit = self.critic_dense1(state)
        x_crit = self.critic_dense2(x_crit)
        value = self.value(x_crit)
        return mean, std, value
class PPOAgent:
    """PPO agent wrapping an ``ActorCritic`` with separate actor/critic optimizers.

    Two Adam optimizers are used:
    - ``actor_opt`` (rate ``config.POLICY_LR``) updates the policy-mean
      layers and ``log_std``.
    - ``critic_opt`` (rate ``config.VALUE_LR``) updates the value layers.
    """

    def __init__(self, action_bounds):
        """Build the network and both optimizers.

        Args:
            action_bounds: Passed to ``ActorCritic``. It scales the policy mean
                and sets the symmetric clip range for sampled actions.
        """
        self.ac = ActorCritic(action_bounds)
        self.actor_opt = tf.keras.optimizers.Adam(learning_rate=config.POLICY_LR)
        self.critic_opt = tf.keras.optimizers.Adam(learning_rate=config.VALUE_LR)

    def get_vector_actions(self, states):
        """Sample clipped actions for a batch of states.

        Args:
            states: Array-like batch of observations, converted to float32.

        Returns:
            Tuple of NumPy arrays ``(actions, log_probs, values)``:
            - ``actions`` and ``log_probs`` have the policy-mean shape
              (batch, action_dim).
            - ``values`` has shape (batch,), including when batch == 1.
        """
        states_tensor = tf.convert_to_tensor(states, dtype=tf.float32)
        means, stds, values = self.ac(states_tensor)
        dist = tfp.distributions.Normal(means, stds)
        actions = dist.sample()
        actions = tf.clip_by_value(actions, -self.ac.action_bounds, self.ac.action_bounds)
        # NOTE(review): the log-prob is evaluated at the *clipped* action, so it
        # is not exactly the density the sample was drawn from. train_step must
        # be fed these same clipped actions for the PPO ratio to be consistent.
        log_probs = dist.log_prob(actions)
        # Squeeze only the trailing value axis. A bare squeeze would collapse
        # a batch of one into a 0-d scalar.
        return actions.numpy(), log_probs.numpy(), tf.squeeze(values, axis=-1).numpy()

    @tf.function
    def train_step(self, states, actions, old_log_probs, returns, advantages):
        """Run one PPO clipped-objective update on a minibatch.

        Per-sample inputs may be passed flat, as (batch,), or with a trailing
        singleton axis, as (batch, 1). They are normalized internally so the
        loss terms never broadcast into a (batch, batch) matrix.

        Args:
            states: Batch of observations.
            actions: Actions taken, as returned by ``get_vector_actions``.
            old_log_probs: Per-dimension log-probs from ``get_vector_actions``.
            returns: Value targets, one per sample.
            advantages: Advantage estimates, one per sample.

        Returns:
            Tuple ``(actor_loss, critic_loss)`` of scalar tensors.
        """
        # Gather the variables of each network explicitly so each optimizer
        # only touches its own parameters.
        actor_vars = (
            self.ac.actor_dense1.trainable_variables
            + self.ac.actor_dense2.trainable_variables
            + self.ac.mu.trainable_variables
            + [self.ac.log_std]
        )
        critic_vars = (
            self.ac.critic_dense1.trainable_variables
            + self.ac.critic_dense2.trainable_variables
            + self.ac.value.trainable_variables
        )

        states = tf.cast(states, tf.float32)
        returns = tf.reshape(tf.cast(returns, tf.float32), [-1])
        advantages = tf.reshape(tf.cast(advantages, tf.float32), [-1])

        # Persistent: two separate gradient calls are made on the same tape.
        with tf.GradientTape(persistent=True) as tape:
            mean, std, values = self.ac(states)
            # Reshape the stored arrays to the policy-mean shape
            # (batch, action_dim) so log_prob cannot broadcast across the batch.
            taken = tf.reshape(tf.cast(actions, tf.float32), tf.shape(mean))
            old_lp = tf.reshape(tf.cast(old_log_probs, tf.float32), tf.shape(mean))

            dist = tfp.distributions.Normal(mean, std)
            # Joint log-prob of independent action dims = sum over the last axis
            # (identical to a squeeze when action_dim == 1) -> shape (batch,).
            new_log_probs = tf.reduce_sum(dist.log_prob(taken), axis=-1)
            old_lp = tf.reduce_sum(old_lp, axis=-1)

            # PPO clipped surrogate objective.
            ratio = tf.exp(new_log_probs - old_lp)
            surr1 = ratio * advantages
            surr2 = tf.clip_by_value(
                ratio, 1.0 - config.CLIP_RATIO, 1.0 + config.CLIP_RATIO
            ) * advantages
            actor_loss = -tf.reduce_mean(tf.minimum(surr1, surr2))

            # Value-function MSE against flat (batch,) targets.
            critic_loss = tf.reduce_mean(tf.square(returns - tf.reshape(values, [-1])))

        actor_grads = tape.gradient(actor_loss, actor_vars)
        critic_grads = tape.gradient(critic_loss, critic_vars)
        del tape  # release the persistent tape's resources

        self.actor_opt.apply_gradients(zip(actor_grads, actor_vars))
        self.critic_opt.apply_gradients(zip(critic_grads, critic_vars))
        return actor_loss, critic_loss