rl-project-7Oct committed · Commit d937e11 · verified · 1 Parent(s): 9763567

Updated plot

Files changed (1):
  1. CNN_PPO/ppo_template_cnn.py +146 -123
CNN_PPO/ppo_template_cnn.py CHANGED
@@ -1,123 +1,146 @@
- 
- import gymnasium as gym
- import sys
- import matplotlib.pyplot as plt
- import ale_py
- from ppo_helpers_cnn import *
- from gymnasium.spaces import Box
- import cv2
- 
- def preprocess(obs):
-     # Convert to grayscale
-     obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
-     # Resize
-     obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
-     # Add channel dimension and normalize
-     return np.expand_dims(obs, axis=0).astype(np.float32) / 255.0
- 
- def main() -> int:
-     # env = gym.make("ALE/SpaceInvaders-v5", render_mode='human')
-     # env = gym.make("ALE/Pacman-v5", render_mode="human")
-     env = gym.make("ALE/Pacman-v5")
- 
-     episode = 0
-     total_return = 0
-     ep_return = 0
-     steps = 100
-     batches = 100
- 
-     print("Observation space:", env.observation_space)
-     print("Action space:", env.action_space)
-     """
-     agent = Agent(obs_space=env.observation_space, action_space=env.action_space,
-                   hidden=64, lr=3e-4, gamma=0.99, clip_coef=0.2,
-                   entropy_coef=0.01, value_coef=0.5, seed=70,
-                   batch_size=64, ppo_epochs=4, lam=0.95)
- 
-     """
-     obs, _ = env.reset()
-     dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape)
- 
-     agent = Agent(obs_space=dummy_obs_space, action_space=env.action_space,
-                   hidden=64, lr=3e-4, gamma=0.99, clip_coef=0.2,
-                   entropy_coef=0.01, value_coef=0.5, seed=70,
-                   batch_size=64, ppo_epochs=4, lam=0.95)
- 
-     # === Return-Based Scaling stats ===
-     r_mean, r_var = 0.0, 1e-8
-     g2_mean = 1.0
- 
-     agent.r_var = r_var
-     agent.g2_mean = g2_mean
- 
-     try:
-         obs, info = env.reset(seed=42)
-         state = preprocess(obs)
- 
-         loss_history = []
-         reward_history = []
- 
-         for update in range(1, batches + 1):
-             for t in range(steps):
-                 action, logp, value = agent.choose_action(state)
-                 next_obs, reward, terminated, truncated, info = env.step(action)
-                 done = terminated or truncated
-                 next_state = preprocess(next_obs)
- 
-                 agent.remember(state, action, reward, done, logp, value, next_state)
- 
-                 ep_return += reward
-                 state = next_state
- 
-                 if done:
-                     episode += 1
-                     total_return += ep_return
-                     print(f"Episode {episode} return: {ep_return:.2f}")
-                     ep_return = 0
-                     obs, info = env.reset()
-                     state = preprocess(obs)
- 
-             avg_loss = agent.update_reward_gradient_clipping()
-             loss_history.append(avg_loss)
- 
-             avg_ret = (total_return / episode) if episode else 0
-             reward_history.append(avg_ret)
-             print(f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}, avg_loss={avg_loss:.4f}")
- 
-         fig = plt.figure()
- 
-         ax1 = plt.subplot(321)
-         ax1.plot(agent.sigma_history, label="Return σ")
-         ax1.set_xlabel("PPO Update")
-         ax1.set_ylabel("σ (Return Std)")
- 
-         ax2 = plt.subplot(322)
-         ax2.plot(loss_history, label="Avg Loss")
-         ax2.set_ylabel("Average PPO Loss")
-         ax2.set_xlabel("PPO Update")
- 
-         ax3 = plt.subplot(323)
-         ax3.plot(reward_history, label="Reward")
-         ax3.set_ylabel("Reward")
-         ax3.set_xlabel("PPO Update")
- 
-         fig.suptitle("PPO Training Stability")
-         fig.tight_layout()
-         plt.show()
- 
-     except Exception as e:
-         print(f"Error: {e}", file=sys.stderr)
-         return 1
-     finally:
-         avg = total_return / episode if episode else 0
-         print(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
-         env.close()
- 
-     return 0
- 
- 
- if __name__ == "__main__":
-     raise SystemExit(main())
+ 
+ import gymnasium as gym
+ import sys
+ import matplotlib.pyplot as plt
+ import ale_py
+ from ppo_helpers_cnn import *
+ from gymnasium.spaces import Box
+ import cv2
+ import numpy as np  # used by preprocess(); don't rely on the wildcard import above
+ 
+ def preprocess(obs):
+     # Convert to grayscale
+     obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
+     # Resize
+     obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
+     # Add channel dimension and normalize
+     return np.expand_dims(obs, axis=0).astype(np.float32) / 255.0
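+ 
+ # Note: preprocess() yields a (1, 84, 84) float32 array scaled to [0, 1],
+ # i.e. channel-first; this is the same shape used to build dummy_obs_space
+ # for the CNN agent below.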
+ 
+ def main() -> int:
+     # env = gym.make("ALE/SpaceInvaders-v5", render_mode='human')
+     # env = gym.make("ALE/Pacman-v5", render_mode="human")
+     env = gym.make("ALE/Pacman-v5")
+ 
+     episode = 0
+     total_return = 0
+     ep_return = 0
+     steps = 100
+     batches = 100
+ 
+     print("Observation space:", env.observation_space)
+     print("Action space:", env.action_space)
+     """
+     agent = Agent(obs_space=env.observation_space, action_space=env.action_space,
+                   hidden=64, lr=3e-4, gamma=0.99, clip_coef=0.2,
+                   entropy_coef=0.01, value_coef=0.5, seed=70,
+                   batch_size=64, ppo_epochs=4, lam=0.95)
+ 
+     """
+     # Initialize the CNN with a dummy observation (to get the correct input shape)
+     obs, _ = env.reset()
+     dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape)
+ 
+     agent = Agent(obs_space=dummy_obs_space, action_space=env.action_space,
+                   hidden=64, lr=3e-4, gamma=0.99, clip_coef=0.2,
+                   entropy_coef=0.01, value_coef=0.5, seed=70,
+                   batch_size=64, ppo_epochs=4, lam=0.95)
+     """
+     # Stats for Return-Based Scaling only
+     r_mean, r_var = 0.0, 1e-8
+     g2_mean = 1.0
+ 
+     agent.r_var = r_var
+     agent.g2_mean = g2_mean
+     """
+ 
+     try:
+         obs, info = env.reset(seed=42)
+         state = preprocess(obs)
+ 
+         loss_history = []
+         reward_history = []
+ 
+         for update in range(1, batches + 1):
+             # Collect a fixed-length rollout before each PPO update
+             for t in range(steps):
+                 action, logp, value = agent.choose_action(state)
+                 next_obs, reward, terminated, truncated, info = env.step(action)
+                 done = terminated or truncated
+                 next_state = preprocess(next_obs)
+ 
+                 agent.remember(state, action, reward, done, logp, value, next_state)
+ 
+                 ep_return += reward
+                 state = next_state
+ 
+                 if done:
+                     episode += 1
+                     total_return += ep_return
+                     print(f"Episode {episode} return: {ep_return:.2f}")
+                     ep_return = 0
+                     obs, info = env.reset()
+                     state = preprocess(obs)
+ 
+             # Using reward gradient clipping
+             avg_loss = agent.update_reward_gradient_clipping()
+ 
+             # Vanilla PPO (no normalization)
+             # avg_loss = agent.vanilla_ppo_update()
+             loss_history.append(avg_loss)
+ 
+             avg_ret = (total_return / episode) if episode else 0
+             reward_history.append(avg_ret)
+             print(f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}, avg_loss={avg_loss:.4f}")
+ 
+         fig = plt.figure(figsize=(12, 8))
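+         # 2x2 diagnostic grid: avg loss, avg return, policy loss, value loss
+         # (the Return σ panel stays disabled along with return-based scaling).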
+ 
+         """
+         # Plot for Return-Based Scaling only. If re-enabled, renumber every
+         # panel onto a 2x3 grid (231-235); subplot indices are 1-based, so
+         # the original index 220 was invalid.
+         ax1 = plt.subplot(231)
+         ax1.plot(agent.sigma_history, label="Return σ")
+         ax1.set_xlabel("PPO Update")
+         ax1.set_ylabel("σ (Return Std)")
+         """
+ 
+         ax2 = plt.subplot(221)
+         ax2.plot(loss_history, label="Avg Loss")
+         ax2.set_ylabel("Average PPO Loss")
+         ax2.set_xlabel("PPO Update")
+ 
+         ax3 = plt.subplot(222)
+         ax3.plot(reward_history, label="Reward")
+         ax3.set_ylabel("Reward")
+         ax3.set_xlabel("PPO Update")
+ 
+         # Details about value loss and policy loss
+         ax4 = plt.subplot(223)
+         ax4.plot(agent.policy_loss_history, label="Policy Loss", alpha=0.7)
+         ax4.set_ylabel("Policy Loss")
+         ax4.set_xlabel("Training Step")
+         ax4.legend()
+ 
+         ax5 = plt.subplot(224)
+         ax5.plot(agent.value_loss_history, label="Value Loss", alpha=0.7)
+         ax5.set_ylabel("Value Loss")
+         ax5.set_xlabel("Training Step")
+         ax5.legend()
+ 
+         fig.suptitle("PPO Training Stability")
+         fig.tight_layout()
+         plt.show()
+ 
+     except Exception as e:
+         print(f"Error: {e}", file=sys.stderr)
+         return 1
+     finally:
+         avg = total_return / episode if episode else 0
+         print(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
+         env.close()
+ 
+     return 0
+ 
+ 
+ if __name__ == "__main__":
+     raise SystemExit(main())
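
For a quick preview of the figure this commit introduces, the following is a minimal standalone sketch of the new 2x2 layout. All series are placeholder data, and policy_loss_history / value_loss_history stand in for the attributes the Agent from ppo_helpers_cnn is assumed to expose in the diff above; it is an illustration, not part of the commit.

# Standalone sketch of the updated 2x2 diagnostic figure.
# All series below are placeholder data, not real training results.
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
loss_history = rng.normal(1.0, 0.1, 100)           # one point per PPO update
reward_history = np.linspace(5, 50, 100)           # running average return
policy_loss_history = rng.normal(0.0, 0.05, 400)   # one point per training step
value_loss_history = rng.normal(0.5, 0.1, 400)

fig = plt.figure(figsize=(12, 8))

ax2 = plt.subplot(221)
ax2.plot(loss_history, label="Avg Loss")
ax2.set_xlabel("PPO Update")
ax2.set_ylabel("Average PPO Loss")

ax3 = plt.subplot(222)
ax3.plot(reward_history, label="Reward")
ax3.set_xlabel("PPO Update")
ax3.set_ylabel("Reward")

ax4 = plt.subplot(223)
ax4.plot(policy_loss_history, label="Policy Loss", alpha=0.7)
ax4.set_xlabel("Training Step")
ax4.set_ylabel("Policy Loss")
ax4.legend()

ax5 = plt.subplot(224)
ax5.plot(value_loss_history, label="Value Loss", alpha=0.7)
ax5.set_xlabel("Training Step")
ax5.set_ylabel("Value Loss")
ax5.legend()

fig.suptitle("PPO Training Stability")
fig.tight_layout()
plt.show()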