import sys

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np

import ale_py  # ships the ALE/* environments
from ppo_helpers_v2 import *

gym.register_envs(ale_py)  # required on gymnasium >= 1.0 to register ALE ids


def preprocess(obs):
    """Flatten the RGB frame to a 1-D vector and scale pixels to [0, 1]."""
    return obs.astype(np.float32).ravel() / 255.0
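
# A minimal sketch of the scale estimate assumed to drive Agent.update_rbs
# (hypothetical helper; the real logic lives in ppo_helpers_v2). Under
# return-based scaling (Schaul et al., 2021), TD errors are divided by sigma
# with sigma^2 ≈ Var(r) + Var(gamma_t) * E[G^2], where gamma_t is 0 at episode
# boundaries and gamma elsewhere.
def rbs_sigma_sketch(rewards, dones, returns, gamma=0.99, eps=1e-8):
    rewards = np.asarray(rewards, dtype=np.float32)
    returns = np.asarray(returns, dtype=np.float32)
    # Per-step discount: 0 where the episode ended, gamma otherwise.
    gammas = gamma * (1.0 - np.asarray(dones, dtype=np.float32))
    sigma2 = rewards.var() + gammas.var() * np.mean(returns ** 2)
    return float(np.sqrt(sigma2 + eps))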


def main() -> int:
    # env = gym.make("ALE/SpaceInvaders-v5", render_mode='human')
    env = gym.make("ALE/Pacman-v5")

    episode = 0
    total_return = 0.0
    ep_return = 0.0
    steps = 2000   # environment steps collected per PPO update
    batches = 100  # number of PPO updates

    print("Observation space:", env.observation_space)
    print("Action space:", env.action_space)

    agent = Agent(obs_space=env.observation_space, action_space=env.action_space,
                  hidden=64, lr=3e-4, gamma=0.99, clip_coef=0.2,
                  entropy_coef=0.01, value_coef=0.5, seed=70,
                  batch_size=64, ppo_epochs=4, lam=0.95)
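
    # The Agent API is assumed from ppo_helpers_v2, inferred from the calls
    # in this script:
    #   choose_action(state) -> (action, log_prob, value_estimate)
    #   remember(state, action, reward, done, logp, value, next_state)
    #   update_rbs() -> average PPO loss over the update's epochs
    #   sigma_history -> per-update return-scale estimates (plotted below)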

    # === Return-Based Scaling (RBS) statistics ===
    r_var = 1e-8    # running variance of rewards (epsilon avoids div-by-zero)
    g2_mean = 1.0   # running mean of squared returns

    agent.r_var = r_var
    agent.g2_mean = g2_mean
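    # These seed the agent's running estimates of Var(r) and E[G^2]; see
    # rbs_sigma_sketch above for the scale they are assumed to feed.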

    try:
        obs, info = env.reset(seed=42)
        state = preprocess(obs)

        loss_history = []
        reward_history = []

        for update in range(1, batches + 1):
            # --- Collect a fixed-length rollout ---
            for t in range(steps):
                action, logp, value = agent.choose_action(state)
                next_obs, reward, terminated, truncated, info = env.step(action)
                done = terminated or truncated
                next_state = preprocess(next_obs)

                agent.remember(state, action, reward, done, logp, value, next_state)

                ep_return += reward
                state = next_state

                if done:
                    # Episode finished: log its return and start a new one.
                    episode += 1
                    total_return += ep_return
                    print(f"Episode {episode} return: {ep_return:.2f}")
                    ep_return = 0.0
                    obs, info = env.reset()
                    state = preprocess(obs)

            avg_loss = agent.update_rbs()  # PPO update with RBS; returns mean loss
            loss_history.append(avg_loss)

            # Cumulative average return over all episodes so far (smooth curve).
            avg_ret = (total_return / episode) if episode else 0.0
            reward_history.append(avg_ret)
            print(f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}, avg_loss={avg_loss:.4f}")

        # --- Diagnostics per PPO update ---
        fig, (ax1, ax2, ax3) = plt.subplots(3, 1)

        ax1.plot(agent.sigma_history)
        ax1.set_xlabel("PPO Update")
        ax1.set_ylabel("σ (Return Std)")

        ax2.plot(loss_history)
        ax2.set_xlabel("PPO Update")
        ax2.set_ylabel("Average PPO Loss")

        ax3.plot(reward_history)
        ax3.set_xlabel("PPO Update")
        ax3.set_ylabel("Avg Return")

        fig.suptitle("PPO Training Stability")
        fig.tight_layout()
        plt.show()

    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1
    finally:
        avg = total_return / episode if episode else 0
        print(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
        env.close()

    return 0


if __name__ == "__main__":
    raise SystemExit(main())