File size: 3,677 Bytes
a063d15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import config

class ActorCritic(tf.keras.Model):
    """Actor and critic networks held in a single Keras model.

    The two networks share no layers. Each one is a stack of two tanh Dense
    layers (128 then 64 units) followed by a single-unit head. The actor
    additionally owns ``log_std``, a trainable variable that does not depend
    on the input state.
    """

    def __init__(self, action_bounds):
        """Create the layers and the trainable log standard deviation.

        Args:
            action_bounds: Factor multiplied onto the tanh output of the
                mean head in ``call``.
        """
        super().__init__()
        self.action_bounds = action_bounds

        self.actor_dense1 = tf.keras.layers.Dense(128, activation='tanh')
        self.actor_dense2 = tf.keras.layers.Dense(64, activation='tanh')
        self.mu = tf.keras.layers.Dense(1, activation='tanh')
        # Pin the dtype: without it the variable inherits whatever type the
        # config constant happens to have (an int gives int32, a NumPy float
        # gives float64), which would not match the float32 layer outputs
        # this value is combined with downstream.
        self.log_std = tf.Variable(
            initial_value=config.INITIAL_LOG_STD,
            dtype=tf.float32,
            trainable=True,
        )

        self.critic_dense1 = tf.keras.layers.Dense(128, activation='tanh')
        self.critic_dense2 = tf.keras.layers.Dense(64, activation='tanh')
        self.value = tf.keras.layers.Dense(1, activation=None)

    def call(self, state):
        """Run both networks on ``state``.

        Args:
            state: Input passed to the first Dense layer of each network.

        Returns:
            A ``(mean, std, value)`` tuple: the tanh mean-head output scaled
            by ``action_bounds``, ``exp(log_std)``, and the linear output of
            the value head.
        """
        x_act = self.actor_dense1(state)
        x_act = self.actor_dense2(x_act)
        # tanh keeps the head output in [-1, 1]; the multiplication rescales it.
        mean = self.mu(x_act) * self.action_bounds
        # Exponentiating keeps the standard deviation strictly positive.
        std = tf.exp(self.log_std)

        x_crit = self.critic_dense1(state)
        x_crit = self.critic_dense2(x_crit)
        value = self.value(x_crit)

        return mean, std, value

class PPOAgent:
    """PPO agent: an ``ActorCritic`` model plus one Adam optimizer per network."""

    def __init__(self, action_bounds):
        """Build the model and its two optimizers.

        Args:
            action_bounds: Forwarded unchanged to ``ActorCritic``.
        """
        self.ac = ActorCritic(action_bounds)
        self.actor_opt = tf.keras.optimizers.Adam(learning_rate=config.POLICY_LR)
        self.critic_opt = tf.keras.optimizers.Adam(learning_rate=config.VALUE_LR)

    @staticmethod
    def _with_feature_axis(tensor, reference):
        """Append a trailing axis to ``tensor`` if it is one rank below ``reference``.

        Inputs whose rank already matches (or whose rank is unknown) are
        returned unchanged, so callers that pass matching shapes are
        unaffected.
        """
        tensor = tf.convert_to_tensor(tensor)
        rank = tensor.shape.rank
        ref_rank = reference.shape.rank
        if rank is not None and ref_rank is not None and rank == ref_rank - 1:
            return tf.expand_dims(tensor, axis=-1)
        return tensor

    def get_vector_actions(self, states):
        """Sample one action per state and evaluate the critic.

        Args:
            states: Batch of states; converted to a float32 tensor before
                being fed to the model.

        Returns:
            A tuple of NumPy arrays ``(actions, log_probs, values)``.
            Actions are sampled from ``Normal(mean, std)`` and clipped to
            ``[-action_bounds, action_bounds]``. The log-probabilities are
            evaluated at the clipped actions. ``values`` is the critic
            output with every size-1 dimension removed.
        """
        states_tensor = tf.convert_to_tensor(states, dtype=tf.float32)
        means, stds, values = self.ac(states_tensor)
        dist = tfp.distributions.Normal(means, stds)

        actions = dist.sample()
        actions = tf.clip_by_value(actions, -self.ac.action_bounds, self.ac.action_bounds)
        log_probs = dist.log_prob(actions)

        # NOTE(review): tf.squeeze without an axis also drops the batch
        # dimension when the batch holds a single state, giving a 0-d array.
        # Left as-is to keep the return shape callers currently receive.
        return actions.numpy(), log_probs.numpy(), tf.squeeze(values).numpy()

    @tf.function
    def train_step(self, states, actions, old_log_probs, returns, advantages):
        """Run one clipped-surrogate actor update and one MSE critic update.

        Args:
            states: Batch of states fed to the model.
            actions: Actions whose log-probability is re-evaluated under the
                current policy.
            old_log_probs: Log-probabilities of ``actions`` recorded when
                they were sampled.
            returns: Regression targets for the critic.
            advantages: Advantage estimates weighting the policy ratio.

        ``actions``, ``old_log_probs``, ``returns`` and ``advantages`` may
        each be passed either with or without the trailing size-1 axis that
        the model outputs carry; a missing axis is added here.

        Returns:
            ``(actor_loss, critic_loss)`` as scalar tensors.
        """
        with tf.GradientTape(persistent=True) as tape:
            mean, std, values = self.ac(states)

            # The model outputs end in a size-1 axis. A per-sample input of
            # one rank less would broadcast against them into a
            # (batch, batch) matrix and silently corrupt both losses.
            actions = self._with_feature_axis(actions, mean)
            dist = tfp.distributions.Normal(mean, std)
            new_log_probs = dist.log_prob(actions)

            old_log_probs = self._with_feature_axis(old_log_probs, new_log_probs)
            advantages = self._with_feature_axis(advantages, new_log_probs)

            # PPO clipped surrogate objective.
            ratio = tf.exp(new_log_probs - old_log_probs)
            surr1 = ratio * advantages
            surr2 = tf.clip_by_value(
                ratio, 1.0 - config.CLIP_RATIO, 1.0 + config.CLIP_RATIO
            ) * advantages
            actor_loss = -tf.reduce_mean(tf.minimum(surr1, surr2))

            # Mean squared error between the targets and the critic output,
            # compared with the trailing axis aligned on both sides.
            returns = self._with_feature_axis(returns, values)
            critic_loss = tf.reduce_mean(tf.square(returns - values))

        # Collect the variables only after the forward pass: Dense layers
        # create their weights on first call, so gathering them earlier can
        # yield empty lists that the traced graph would then keep using.
        actor_vars = (
            self.ac.actor_dense1.trainable_variables +
            self.ac.actor_dense2.trainable_variables +
            self.ac.mu.trainable_variables +
            [self.ac.log_std]
        )
        critic_vars = (
            self.ac.critic_dense1.trainable_variables +
            self.ac.critic_dense2.trainable_variables +
            self.ac.value.trainable_variables
        )

        # Each loss is differentiated only with respect to its own network.
        actor_grads = tape.gradient(actor_loss, actor_vars)
        critic_grads = tape.gradient(critic_loss, critic_vars)
        # A persistent tape holds its resources until it is dropped.
        del tape

        self.actor_opt.apply_gradients(zip(actor_grads, actor_vars))
        self.critic_opt.apply_gradients(zip(critic_grads, critic_vars))
        return actor_loss, critic_loss