rl-project-7Oct committed
Commit bcb0c1c · verified · 1 Parent(s): 741396a

Upload ppo_template_cnn.py

Files changed (1)
  1. CNN_PPO/ppo_template_cnn.py +123 -0
CNN_PPO/ppo_template_cnn.py ADDED
@@ -0,0 +1,123 @@
+ import sys
+
+ import ale_py  # importing ale_py registers the ALE/* Atari environments with Gymnasium
+ import cv2
+ import gymnasium as gym
+ import matplotlib.pyplot as plt
+ import numpy as np  # used by preprocess()
+
+ from gymnasium.spaces import Box
+ from ppo_helpers_cnn import *  # provides the Agent class used below
+
+ def preprocess(obs):
+     """Convert a raw RGB Atari frame to a normalized (1, 84, 84) float32 array."""
+     # Convert to grayscale
+     obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
+     # Resize to the standard 84x84 Atari input resolution
+     obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
+     # Add a channel dimension and normalize pixel values to [0, 1]
+     return np.expand_dims(obs, axis=0).astype(np.float32) / 255.0
+
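+ # e.g. a raw 210x160x3 uint8 ALE frame becomes a float32 array of shape
+ # (1, 84, 84) with values in [0, 1]:
+ #
+ #     frame = np.zeros((210, 160, 3), dtype=np.uint8)
+ #     preprocess(frame).shape   # -> (1, 84, 84)
+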
+ def main() -> int:
+     # env = gym.make("ALE/SpaceInvaders-v5", render_mode="human")
+     # env = gym.make("ALE/Pacman-v5", render_mode="human")
+     env = gym.make("ALE/Pacman-v5")
+
+     episode = 0
+     total_return = 0.0
+     ep_return = 0.0
+     steps = 100    # environment steps collected per PPO update
+     batches = 100  # number of PPO updates
+
+     print("Observation space:", env.observation_space)
+     print("Action space:", env.action_space)
+
+     # An earlier variant built the Agent directly on env.observation_space
+     # (raw 210x160x3 RGB frames); the agent instead trains on preprocessed
+     # frames, so construct a Box space matching the preprocessed shape.
+     obs, _ = env.reset()
+     dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape)
+
+     agent = Agent(obs_space=dummy_obs_space, action_space=env.action_space,
+                   hidden=64, lr=3e-4, gamma=0.99, clip_coef=0.2,
+                   entropy_coef=0.01, value_coef=0.5, seed=70,
+                   batch_size=64, ppo_epochs=4, lam=0.95)
+
+     # === Return-Based Scaling stats ===
+     # Running statistics for reward scaling (r_mean is currently unused).
+     r_mean, r_var = 0.0, 1e-8
+     g2_mean = 1.0
+
+     agent.r_var = r_var
+     agent.g2_mean = g2_mean
+
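+     # NOTE (assumption): the helper module is expected to update these stats
+     # during training and derive the reward scale σ from them;
+     # agent.sigma_history, plotted below, records that scale per update.
+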
+     try:
+         obs, info = env.reset(seed=42)
+         state = preprocess(obs)
+
+         loss_history = []
+         reward_history = []
+
+         for update in range(1, batches + 1):
+             # Collect a fixed-length rollout of `steps` transitions.
+             for t in range(steps):
+                 action, logp, value = agent.choose_action(state)
+                 next_obs, reward, terminated, truncated, info = env.step(action)
+                 done = terminated or truncated
+                 next_state = preprocess(next_obs)
+
+                 agent.remember(state, action, reward, done, logp, value, next_state)
+
+                 ep_return += reward
+                 state = next_state
+
+                 if done:
+                     episode += 1
+                     total_return += ep_return
+                     print(f"Episode {episode} return: {ep_return:.2f}")
+                     ep_return = 0.0
+                     obs, info = env.reset()
+                     state = preprocess(obs)
+
+             # One PPO update over the stored rollout.
+             avg_loss = agent.update_reward_gradient_clipping()
+             loss_history.append(avg_loss)
+
+             # Running average return over all episodes completed so far.
+             avg_ret = (total_return / episode) if episode else 0.0
+             reward_history.append(avg_ret)
+             print(f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}, avg_loss={avg_loss:.4f}")
+
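+         # NOTE (assumption): given the constructor arguments (gamma, lam,
+         # ppo_epochs, batch_size, clip_coef), update_reward_gradient_clipping()
+         # presumably computes GAE(λ) advantages and runs several epochs of
+         # clipped-objective minibatch updates, returning the mean loss.
+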
+         # Diagnostic plots: return scale, average loss, and average return.
+         fig = plt.figure()
+
+         ax1 = plt.subplot(321)
+         ax1.plot(agent.sigma_history, label="Return σ")
+         ax1.set_xlabel("PPO Update")
+         ax1.set_ylabel("σ (Return Std)")
+
+         ax2 = plt.subplot(322)
+         ax2.plot(loss_history, label="Avg Loss")
+         ax2.set_xlabel("PPO Update")
+         ax2.set_ylabel("Average PPO Loss")
+
+         ax3 = plt.subplot(323)
+         ax3.plot(reward_history, label="Avg Return")
+         ax3.set_xlabel("PPO Update")
+         ax3.set_ylabel("Average Return")
+
+         fig.suptitle("PPO Training Stability")
+         fig.tight_layout()
+         plt.show()
+
+     except Exception as e:
+         print(f"Error: {e}", file=sys.stderr)
+         return 1
+     finally:
+         # Always report totals and release the environment.
+         avg = total_return / episode if episode else 0.0
+         print(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
+         env.close()
+
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
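
The helper module ppo_helpers_cnn is referenced but not included in this commit. As a reading aid, here is a minimal sketch of the Agent interface the template appears to assume, inferred solely from the calls above; every name and behavior in it is an assumption, not the actual helper code.

# Hypothetical interface for ppo_helpers_cnn.Agent, inferred from usage above.
class Agent:
    def __init__(self, obs_space, action_space, hidden, lr, gamma, clip_coef,
                 entropy_coef, value_coef, seed, batch_size, ppo_epochs, lam):
        self.sigma_history = []  # return-scale σ per update, plotted by the template
        self.r_var = 1e-8        # reward-variance statistic, set by the template
        self.g2_mean = 1.0       # mean-squared-return statistic, set by the template

    def choose_action(self, state):
        """Return (action, log_prob, value) for one preprocessed frame."""
        ...

    def remember(self, state, action, reward, done, logp, value, next_state):
        """Store one transition in the rollout buffer."""
        ...

    def update_reward_gradient_clipping(self):
        """Run the PPO update on the stored rollout and return the average loss."""
        ...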