manansodha committed
Commit e8b2ea3 · verified · 1 Parent(s): d937e11

Add new method of reward clipping
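The scheme works on episode returns rather than raw per-step rewards: after every batch of 5 episodes the script recomputes the mean μ and standard deviation σ of the collected returns and clips each subsequent episode's return to [μ − α·σ, μ + α·σ], where α is drawn once per run from U(0, 2). A minimal standalone sketch of that rule, kept separate from the script below (the helper name and the sample returns are illustrative, not part of the file):

import numpy as np

def clip_range_from_batch(episode_returns, alpha):
    # Recompute the clipping band from the last batch of episode returns.
    mu = np.mean(episode_returns)
    sigma = np.std(episode_returns) + 1e-8
    return mu - alpha * sigma, mu + alpha * sigma

alpha = np.random.uniform(0, 2)                    # sampled once per run
low, high = clip_range_from_batch([120.0, 80.0, 200.0, 150.0, 90.0], alpha)
clipped_return = float(np.clip(175.0, low, high))  # applied to the next finished episode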

CNN_PPO/ppo_template_meanstd_clipping.py ADDED
@@ -0,0 +1,168 @@
+ import gymnasium as gym
+ import sys
+ import matplotlib.pyplot as plt
+ import ale_py
+ from ppo_helpers_cnn import *
+ from gymnasium.spaces import Box
+ import cv2
+ import numpy as np
+
+
+ def preprocess(obs):
+     obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
+     obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
+     return np.expand_dims(obs, axis=0).astype(np.float32) / 255.0
+
+
+ def main() -> int:
+     env = gym.make("ALE/Pacman-v5")
+
+     episode = 0
+     total_return = 0
+     ep_return = 0
+     steps = 2000
+     batches = 100
+
+     print("Observation space:", env.observation_space)
+     print("Action space:", env.action_space)
+
+     # Initialize CNN
+     obs, _ = env.reset()
+     dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape)
+
+     agent = Agent(
+         obs_space=dummy_obs_space, action_space=env.action_space,
+         hidden=64, lr=3e-4, gamma=0.99, clip_coef=0.2,
+         entropy_coef=0.01, value_coef=0.5, seed=70,
+         batch_size=64, ppo_epochs=4, lam=0.95
+     )
+
+     # === Return-Based Scaling stats ===
+     r_mean, r_var = 0.0, 1e-8
+     g2_mean = 1.0
+     agent.r_var = r_var
+     agent.g2_mean = g2_mean
+
+     # === New reward clipping system (batch mean/std based) ===
+     alpha = np.random.uniform(0, 2)
+     print(f"\n[INFO] α sampled = {alpha:.3f}\n")
+
+     reward_batch = []  # stores total rewards for 5 episodes
+     clip_low, clip_high = None, None
+     EPISODES_PER_BATCH = 5
+
+     try:
+         obs, info = env.reset()
+         state = preprocess(obs)
+
+         loss_history = []
+         reward_history = []
+
+         # === PPO outer updates ===
+         for update in range(1, batches + 1):
+
+             reward_batch.clear()
+
+             # === Collect 5 full episodes ===
+             for ep in range(EPISODES_PER_BATCH):
+                 ep_rewards_raw = []
+                 done = False
+
+                 while not done:
+                     action, logp, value = agent.choose_action(state)
+                     next_obs, reward, terminated, truncated, info = env.step(action)
+                     done = terminated or truncated
+                     next_state = preprocess(next_obs)
+
+                     # === APPLY REWARD CLIPPING TO RAW REWARD IF READY ===
+                     # if clip_low is not None:
+                     #     reward = np.clip(reward, clip_low, clip_high)
+
+                     agent.remember(state, action, reward, done, logp, value, next_state)
+
+                     ep_return += reward
+                     ep_rewards_raw.append(reward)
+                     # print("raw reward:", reward)
+                     state = next_state
+
+                     if done:
+                         # episode completed
+                         episode += 1
+                         total_return += ep_return
+
+                         reward_sum = sum(ep_rewards_raw)
+                         if clip_low is not None:
+                             reward_sum_clipped = np.clip(reward_sum, clip_low, clip_high)
+                             reward_batch.append(reward_sum_clipped)
+                             print(f"Episode {episode} | Reward (clipped): {reward_sum_clipped:.2f}")
+                         else:
+                             reward_batch.append(reward_sum)
+                             print(f"Episode {episode} | Reward (raw): {reward_sum:.2f}")
+
+
+                         ep_return = 0
+                         obs, info = env.reset()
+                         state = preprocess(obs)
+
+             # === After every 5 episodes → compute clipping range ===
+             mu = np.mean(reward_batch)
+             sigma = np.std(reward_batch) + 1e-8
+
+             # clip_low = 0  # when clipping the raw per-step reward
+             clip_low = mu - alpha * sigma   # when clipping the per-episode reward sum
+             clip_high = mu + alpha * sigma
+
+             print(f"[UPDATE {update}] New Reward Clip Range: [{clip_low:.2f}, {clip_high:.2f}]")
+
+             # === PPO UPDATE ===
+             avg_loss = agent.vanilla_ppo_update()
+             loss_history.append(avg_loss)
+
+             avg_ret = (total_return / episode) if episode else 0
+             reward_history.append(avg_ret)
+
+             print(f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}, avg_loss={avg_loss:.4f}")
+
+         # === PLOTS ===
+         fig = plt.figure(figsize=(12, 8))
+
+         ax2 = plt.subplot(221)
+         ax2.plot(loss_history, label="Avg Loss")
+         ax2.set_ylabel("Average PPO Loss")
+         ax2.set_xlabel("PPO Update")
+
+         ax3 = plt.subplot(222)
+         ax3.plot(reward_history, label="Reward")
+         ax3.set_ylabel("Reward")
+         ax3.set_xlabel("PPO Update")
+
+         ax4 = plt.subplot(223)
+         ax4.plot(agent.policy_loss_history, label="Policy Loss", alpha=0.7)
+         ax4.set_ylabel("Policy Loss")
+         ax4.set_xlabel("Training Step")
+         ax4.legend()
+
+         ax5 = plt.subplot(224)
+         ax5.plot(agent.value_loss_history, label="Value Loss", alpha=0.7)
+         ax5.set_ylabel("Value Loss")
+         ax5.set_xlabel("Training Step")
+         ax5.legend()
+
+         fig.suptitle("PPO Training Stability")
+         fig.tight_layout()
+         plt.show()
+
+     except Exception as e:
+         print(f"Error: {e}", file=sys.stderr)
+         return 1
+
+     finally:
+         avg = total_return / episode if episode else 0
+         print(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
+         env.close()
+
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
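
For reference, the frame preprocessing the script relies on can be checked in isolation; the function below is copied verbatim from the file, and the synthetic 210×160 RGB frame merely stands in for a real ALE observation:

import numpy as np
import cv2

def preprocess(obs):
    obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
    obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
    return np.expand_dims(obs, axis=0).astype(np.float32) / 255.0

frame = np.zeros((210, 160, 3), dtype=np.uint8)   # stand-in for an ALE RGB frame
print(preprocess(frame).shape)                    # -> (1, 84, 84)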