Anoozh-Akileswaran committed on
Commit fc2ab64 · 1 Parent(s): 20989d1

Observation, Advantage and Return normalization for SAC and PPO
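This commit adds observation normalization to the discrete-action SAC agent (two variants: per-batch standardization and a running-average normalizer) and moves the existing PPO observation/advantage/return-normalization experiments under Observation_norm_PPO/. As a rough orientation, the two SAC variants differ only in where the normalizing statistics come from; a minimal sketch of the two schemes (standalone helper names chosen here for illustration, not the exact classes from the diffs below):

import torch as T

# In-batch (Obser_norm_in_batch): standardize each observation by its own statistics.
def normalize_in_batch(x: T.Tensor) -> T.Tensor:
    return (x - x.mean()) / (x.std(unbiased=False) + 1e-8)

# Running average (Obser_norm_running_average): normalize with per-channel statistics
# accumulated over all observations seen so far (see ObservationNorm.update /
# _update_from_moments in the running-average helpers below).
def normalize_running(x: T.Tensor, mean: T.Tensor, var: T.Tensor) -> T.Tensor:
    return (x - mean) / (T.sqrt(var) + 1e-8)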

Files changed (39)
  1. {Observation_Advantage_Norm_diff_combo → Observation_norm_PPO/Observation_Advantage_Norm_diff_combo}/ppo__rew_norm_obs_diff_combo.py +0 -0
  2. {Observation_Advantage_Norm_diff_combo → Observation_norm_PPO/Observation_Advantage_Norm_diff_combo}/ppo_rew_norm_obs_env_diff_combo.py +0 -0
  3. {Observation_Advantage_Norm_diff_env → Observation_norm_PPO/Observation_Advantage_Norm_diff_env}/ppo__rew_norm_obs_diff_env.py +0 -0
  4. {Observation_Advantage_Norm_diff_env → Observation_norm_PPO/Observation_Advantage_Norm_diff_env}/ppo_rew_norm_obs_env_diff_env.py +0 -0
  5. {Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for Learning Rate of update_advantage_norm.png +0 -0
  6. {Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for Learning Rate of update_observation_norm.png +0 -0
  7. {Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for Learning Rate of update_return_norm.png +0 -0
  8. {Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for Learning Rate of vanilla_ppo_update.png +0 -0
  9. {Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for entropy coefficient of update_advantage_norm.png +0 -0
  10. {Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for entropy coefficient of update_observation_norm.png +0 -0
  11. {Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for entropy coefficient of update_return_norm.png +0 -0
  12. {Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for entropy coefficient of vanilla_ppo_update.png +0 -0
  13. {Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for gamma value of update_advantage_norm.png +0 -0
  14. {Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for gamma value of update_observation_norm.png +0 -0
  15. {Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for gamma value of update_return_norm.png +0 -0
  16. {Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for gamma value of vanilla_ppo_update.png +0 -0
  17. {Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/ppo__rew_norm_obs_diff_hyp.py +0 -0
  18. {Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/ppo_rew_norm_obs_env_diff_hypo.py +0 -0
  19. {Observation_Advantage_Norm_in_batch → Observation_norm_PPO/Observation_Advantage_Norm_in_batch}/ppo__rew_norm_obs_in_batch.py +0 -0
  20. {Observation_Advantage_Norm_in_batch → Observation_norm_PPO/Observation_Advantage_Norm_in_batch}/ppo_rew_norm_obs_env_in_batch.py +0 -0
  21. {Observation_Advantage_Norm_in_batch → Observation_norm_PPO/Observation_Advantage_Norm_in_batch}/update_advantage_norm_in_batch.png +0 -0
  22. {Observation_Advantage_Norm_in_batch → Observation_norm_PPO/Observation_Advantage_Norm_in_batch}/update_observation_norm_in_batch.png +0 -0
  23. {Observation_Advantage_Norm_in_batch → Observation_norm_PPO/Observation_Advantage_Norm_in_batch}/update_return_norm_in_batch.png +0 -0
  24. {Observation_Advantage_Norm_in_batch → Observation_norm_PPO/Observation_Advantage_Norm_in_batch}/vanilla_ppo_update_in_batch.png +0 -0
  25. {Observation_Advantage_Norm_running_averages → Observation_norm_PPO/Observation_Advantage_Norm_running_averages}/ppo__rew_norm_obs_running_average.py +0 -0
  26. {Observation_Advantage_Norm_running_averages → Observation_norm_PPO/Observation_Advantage_Norm_running_averages}/ppo_rew_norm_obs_env_running_average.py +0 -0
  27. {Observation_Advantage_Norm_running_averages → Observation_norm_PPO/Observation_Advantage_Norm_running_averages}/update_advantage_norm_running_average_.png +0 -0
  28. {Observation_Advantage_Norm_running_averages → Observation_norm_PPO/Observation_Advantage_Norm_running_averages}/update_observation_norm_running_average_.png +0 -0
  29. {Observation_Advantage_Norm_running_averages → Observation_norm_PPO/Observation_Advantage_Norm_running_averages}/update_return_norm_running_average_.png +0 -0
  30. {Observation_Advantage_Norm_running_averages → Observation_norm_PPO/Observation_Advantage_Norm_running_averages}/vanilla_ppo_update_running_average_.png +0 -0
  31. Observation_norm_SAC/Obser_norm_in_batch/Observation norm._in_batch_.png +0 -0
  32. Observation_norm_SAC/Obser_norm_in_batch/sac_helpers_cnn_in_batch.py +303 -0
  33. Observation_norm_SAC/Obser_norm_in_batch/sac_model_cnn_in_batch.py +206 -0
  34. Observation_norm_SAC/Obser_norm_running_average/Observation norm._running_average_.png +0 -0
  35. Observation_norm_SAC/Obser_norm_running_average/sac_helpers_cnn_running_average.py +350 -0
  36. Observation_norm_SAC/Obser_norm_running_average/sac_model_cnn_running_average.py +207 -0
  37. Observation_norm_SAC/sac_helpers_cnn.py +274 -0
  38. Observation_norm_SAC/sac_model_cnn.py +206 -0
  39. SAC-2/sac-project/sac_helpers_cnn.py +1 -0
{Observation_Advantage_Norm_diff_combo → Observation_norm_PPO/Observation_Advantage_Norm_diff_combo}/ppo__rew_norm_obs_diff_combo.py RENAMED
File without changes
{Observation_Advantage_Norm_diff_combo → Observation_norm_PPO/Observation_Advantage_Norm_diff_combo}/ppo_rew_norm_obs_env_diff_combo.py RENAMED
File without changes
{Observation_Advantage_Norm_diff_env → Observation_norm_PPO/Observation_Advantage_Norm_diff_env}/ppo__rew_norm_obs_diff_env.py RENAMED
File without changes
{Observation_Advantage_Norm_diff_env → Observation_norm_PPO/Observation_Advantage_Norm_diff_env}/ppo_rew_norm_obs_env_diff_env.py RENAMED
File without changes
{Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for Learning Rate of update_advantage_norm.png RENAMED
File without changes
{Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for Learning Rate of update_observation_norm.png RENAMED
File without changes
{Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for Learning Rate of update_return_norm.png RENAMED
File without changes
{Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for Learning Rate of vanilla_ppo_update.png RENAMED
File without changes
{Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for entropy coefficient of update_advantage_norm.png RENAMED
File without changes
{Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for entropy coefficient of update_observation_norm.png RENAMED
File without changes
{Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for entropy coefficient of update_return_norm.png RENAMED
File without changes
{Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for entropy coefficient of vanilla_ppo_update.png RENAMED
File without changes
{Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for gamma value of update_advantage_norm.png RENAMED
File without changes
{Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for gamma value of update_observation_norm.png RENAMED
File without changes
{Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for gamma value of update_return_norm.png RENAMED
File without changes
{Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/Performance config for gamma value of vanilla_ppo_update.png RENAMED
File without changes
{Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/ppo__rew_norm_obs_diff_hyp.py RENAMED
File without changes
{Observation_Advantage_Norm_diff_hypo → Observation_norm_PPO/Observation_Advantage_Norm_diff_hypo}/ppo_rew_norm_obs_env_diff_hypo.py RENAMED
File without changes
{Observation_Advantage_Norm_in_batch → Observation_norm_PPO/Observation_Advantage_Norm_in_batch}/ppo__rew_norm_obs_in_batch.py RENAMED
File without changes
{Observation_Advantage_Norm_in_batch → Observation_norm_PPO/Observation_Advantage_Norm_in_batch}/ppo_rew_norm_obs_env_in_batch.py RENAMED
File without changes
{Observation_Advantage_Norm_in_batch → Observation_norm_PPO/Observation_Advantage_Norm_in_batch}/update_advantage_norm_in_batch.png RENAMED
File without changes
{Observation_Advantage_Norm_in_batch → Observation_norm_PPO/Observation_Advantage_Norm_in_batch}/update_observation_norm_in_batch.png RENAMED
File without changes
{Observation_Advantage_Norm_in_batch → Observation_norm_PPO/Observation_Advantage_Norm_in_batch}/update_return_norm_in_batch.png RENAMED
File without changes
{Observation_Advantage_Norm_in_batch → Observation_norm_PPO/Observation_Advantage_Norm_in_batch}/vanilla_ppo_update_in_batch.png RENAMED
File without changes
{Observation_Advantage_Norm_running_averages → Observation_norm_PPO/Observation_Advantage_Norm_running_averages}/ppo__rew_norm_obs_running_average.py RENAMED
File without changes
{Observation_Advantage_Norm_running_averages → Observation_norm_PPO/Observation_Advantage_Norm_running_averages}/ppo_rew_norm_obs_env_running_average.py RENAMED
File without changes
{Observation_Advantage_Norm_running_averages → Observation_norm_PPO/Observation_Advantage_Norm_running_averages}/update_advantage_norm_running_average_.png RENAMED
File without changes
{Observation_Advantage_Norm_running_averages → Observation_norm_PPO/Observation_Advantage_Norm_running_averages}/update_observation_norm_running_average_.png RENAMED
File without changes
{Observation_Advantage_Norm_running_averages → Observation_norm_PPO/Observation_Advantage_Norm_running_averages}/update_return_norm_running_average_.png RENAMED
File without changes
{Observation_Advantage_Norm_running_averages → Observation_norm_PPO/Observation_Advantage_Norm_running_averages}/vanilla_ppo_update_running_average_.png RENAMED
File without changes
Observation_norm_SAC/Obser_norm_in_batch/Observation norm._in_batch_.png ADDED
Observation_norm_SAC/Obser_norm_in_batch/sac_helpers_cnn_in_batch.py ADDED
@@ -0,0 +1,303 @@
+import numpy as np
+import torch as T
+import torch.nn as nn
+import torch.optim as optim
+from torch.distributions import Categorical
+
+class Agent:
+    def __init__(self, obs_space, action_space, hidden, gamma, lr, alpha, seed, batch_size, tau=0.005):
+        if seed is not None:
+            np.random.seed(seed)
+            T.manual_seed(seed)
+
+        # Use GPU if available
+        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
+        self.action_dim = int(getattr(action_space, "n", action_space.n))  # Use .n for Discrete
+        self.obs_shape = obs_space.shape
+        self.observeNorm = ObservationNorm()
+        self.gamma, self.tau, self.batch_size = gamma, tau, batch_size
+        # Make alpha learnable (adjust entropy based on reward magnitude)
+        self.target_entropy = -float(self.action_dim)
+        self.log_alpha = T.tensor(np.log(alpha), requires_grad=True, device=self.device)
+        self.alpha = np.exp(self.log_alpha.item())
+        self.alpha_opt = optim.Adam([self.log_alpha], lr=lr)
+
+        self.policy = CategoricalActor(self.obs_shape, self.action_dim, hidden).to(self.device)
+        self.q1 = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
+        self.q2 = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
+        self.q1_target = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
+        self.q2_target = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
+        self.q1_target.load_state_dict(self.q1.state_dict())
+        self.q2_target.load_state_dict(self.q2.state_dict())
+
+        self.policy_opt = optim.Adam(self.policy.parameters(), lr=lr)
+        self.q1_opt = optim.Adam(self.q1.parameters(), lr=lr)
+        self.q2_opt = optim.Adam(self.q2.parameters(), lr=lr)
+        self.memory = Memory()
+
+    def choose_action(self, observation, eval_mode=False):
+        state = T.as_tensor(observation, dtype=T.float32, device=self.device)
+        state = self.observeNorm.normalize(state.unsqueeze(0))
+        state = state.squeeze(0)
+        with T.no_grad():
+            logits = self.policy(state.unsqueeze(0))
+            dist = Categorical(logits=logits)
+            if eval_mode:
+                action = logits.argmax(dim=-1)
+            else:
+                action = dist.sample()
+        return int(action.item())
+
+    def remember(self, state, action, reward, done, next_state):
+        state = T.as_tensor(state, dtype=T.float32, device=self.device)
+        if state.dim() == 3:  # [C,H,W]
+            state = state.unsqueeze(0)  # [1,C,H,W]
+
+        state = self.observeNorm.normalize(state)
+        state = state.squeeze(0)
+
+        # next_state also needs normalization
+        next_state = T.as_tensor(next_state, dtype=T.float32, device=self.device)
+        if next_state.dim() == 3:  # [C,H,W]
+            next_state = next_state.unsqueeze(0)  # [1,C,H,W]
+
+        next_state = self.observeNorm.normalize(next_state)
+        next_state = next_state.squeeze(0)
+
+        self.memory.store(state, action, reward, done, next_state)
+
+    def vanilla_sac_update(self):
+        if len(self.memory.states) < self.batch_size:
+            return 0.0
+
+        # Mini-batch sampling
+        idxs = np.random.choice(len(self.memory.states), self.batch_size, replace=False)
+        states = T.as_tensor(np.array([self.memory.states[i] for i in idxs]), dtype=T.float32, device=self.device)
+        actions = T.as_tensor(np.array([self.memory.actions[i] for i in idxs]), dtype=T.int64, device=self.device).unsqueeze(-1)
+        rewards = T.as_tensor(np.array([self.memory.rewards[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
+        dones = T.as_tensor(np.array([self.memory.dones[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
+        next_states = T.as_tensor(np.array([self.memory.next_states[i] for i in idxs]), dtype=T.float32, device=self.device)
+
+        # Critic update, Soft Q-Learning Objective: to ensure high-entropy actions for exploration
+        with T.no_grad():
+            next_logits = self.policy(next_states)
+            next_dist = Categorical(logits=next_logits)
+            next_probs = next_dist.probs
+            next_log_probs = next_dist.logits - T.logsumexp(next_dist.logits, dim=-1, keepdim=True)
+            q1_next = self.q1_target(next_states)
+            q2_next = self.q2_target(next_states)
+            # Soft Policy Evaluation
+            min_q_next = T.min(q1_next, q2_next)
+            next_value = (next_probs * (min_q_next - self.alpha * next_log_probs)).sum(dim=-1, keepdim=True)
+            target = rewards + self.gamma * (1 - dones) * next_value
+
+        q1 = self.q1(states).gather(1, actions)
+        q2 = self.q2(states).gather(1, actions)
+
+        # Losses of both Q-functions
+        q1_loss = nn.MSELoss()(q1, target)
+        q2_loss = nn.MSELoss()(q2, target)
+
+        self.q1_opt.zero_grad()
+        q1_loss.backward()
+        self.q1_opt.step()
+        self.q2_opt.zero_grad()
+        q2_loss.backward()
+        self.q2_opt.step()
+
+        # Policy/Actor Objective
+        logits = self.policy(states)
+        dist = Categorical(logits=logits)
+        probs = dist.probs
+        log_probs = dist.logits - T.logsumexp(dist.logits, dim=-1, keepdim=True)
+        q1_policy = self.q1(states)
+        q2_policy = self.q2(states)
+        min_q_policy = T.min(q1_policy, q2_policy)
+        # Slightly different policy loss for discrete actions
+        policy_loss = (probs * (self.alpha * log_probs - min_q_policy)).sum(dim=-1).mean()
+
+        # Temperature to update Alpha
+        alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()
+        self.alpha_opt.zero_grad()
+        alpha_loss.backward()
+        self.alpha_opt.step()
+        self.alpha = self.log_alpha.exp().item()
+
+        self.policy_opt.zero_grad()
+        policy_loss.backward()
+        self.policy_opt.step()
+
+        # Target network update
+        for target_param, param in zip(self.q1_target.parameters(), self.q1.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+        for target_param, param in zip(self.q2_target.parameters(), self.q2.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        return policy_loss.item()
+
+    def update_reward_gradient_clipping(self):
+        if len(self.memory.states) < self.batch_size:
+            return 0.0
+
+        # Mini-batch sampling
+        idxs = np.random.choice(len(self.memory.states), self.batch_size, replace=False)
+        states = T.as_tensor(np.array([self.memory.states[i] for i in idxs]), dtype=T.float32, device=self.device)
+        actions = T.as_tensor(np.array([self.memory.actions[i] for i in idxs]), dtype=T.int64, device=self.device).unsqueeze(-1)
+        rewards = T.as_tensor(np.array([self.memory.rewards[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
+        dones = T.as_tensor(np.array([self.memory.dones[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
+        next_states = T.as_tensor(np.array([self.memory.next_states[i] for i in idxs]), dtype=T.float32, device=self.device)
+
+        """
+        # Min-max normalization and tanh scaling to [-1, 1]
+        rewards_np = np.array([self.memory.rewards[i] for i in idxs])
+        r_min = rewards_np.min()
+        r_max = rewards_np.max()
+        # Avoid division by zero
+        r_scaled = 2 * (rewards_np - r_min) / (r_max - r_min + 1e-8) - 1
+        normalized_rewards = np.tanh(r_scaled)
+        rewards = T.as_tensor(normalized_rewards, dtype=T.float32, device=self.device).unsqueeze(-1)
+        """
+
+        # Critic update, Soft Q-Learning Objective: to ensure high-entropy actions for exploration
+        with T.no_grad():
+            next_logits = self.policy(next_states)
+            next_dist = Categorical(logits=next_logits)
+            next_probs = next_dist.probs
+            next_log_probs = next_dist.logits - T.logsumexp(next_dist.logits, dim=-1, keepdim=True)
+            q1_next = self.q1_target(next_states)
+            q2_next = self.q2_target(next_states)
+            # Soft Policy Evaluation
+            min_q_next = T.min(q1_next, q2_next)
+            next_value = (next_probs * (min_q_next - self.alpha * next_log_probs)).sum(dim=-1, keepdim=True)
+            target = rewards + self.gamma * (1 - dones) * next_value
+
+        q1 = self.q1(states).gather(1, actions)
+        q2 = self.q2(states).gather(1, actions)
+
+        # Losses of both Q-functions
+        q1_loss = nn.MSELoss()(q1, target)
+        q2_loss = nn.MSELoss()(q2, target)
+
+        self.q1_opt.zero_grad()
+        q1_loss.backward()
+        self.q1_opt.step()
+        self.q2_opt.zero_grad()
+        q2_loss.backward()
+        self.q2_opt.step()
+
+        # Policy/Actor Objective
+        logits = self.policy(states)
+        dist = Categorical(logits=logits)
+        probs = dist.probs
+        log_probs = dist.logits - T.logsumexp(dist.logits, dim=-1, keepdim=True)
+        q1_policy = self.q1(states)
+        q2_policy = self.q2(states)
+        min_q_policy = T.min(q1_policy, q2_policy)
+        # Slightly different policy loss for discrete actions
+        policy_loss = (probs * (self.alpha * log_probs - min_q_policy)).sum(dim=-1).mean()
+
+        # Temperature to update Alpha
+        alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()
+        self.alpha_opt.zero_grad()
+        alpha_loss.backward()
+        T.nn.utils.clip_grad_norm_([self.log_alpha], max_norm=1.0)  # Gradient clipping
+        self.alpha_opt.step()
+        self.alpha = self.log_alpha.exp().item()
+
+        self.policy_opt.zero_grad()
+        policy_loss.backward()
+        T.nn.utils.clip_grad_norm_([self.log_alpha], max_norm=1.0)  # Gradient clipping
+        self.policy_opt.step()
+
+        # Target network update
+        for target_param, param in zip(self.q1_target.parameters(), self.q1.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+        for target_param, param in zip(self.q2_target.parameters(), self.q2.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        return policy_loss.item()
+
+# Actor/Policy network
+# A typical SAC actor network outputs a Gaussian distribution for a state.
+# Here, we adapt it for discrete actions using a Categorical distribution, as the ATARI environment is discrete.
+# The policy outputs logits for each discrete action.
+
+# From: https://ch.mathworks.com/help/reinforcement-learning/ug/soft-actor-critic-agents.html
+# The actor takes the current observation and generates a categorical distribution, in which each possible action is associated with a probability.
+
+class CategoricalActor(nn.Module):
+    def __init__(self, obs_shape: tuple, action_dim: int, hidden: int):
+        super().__init__()
+        c, h, w = obs_shape
+        self.cnn = nn.Sequential(
+            nn.Conv2d(c, 16, kernel_size=8, stride=4),
+            nn.ReLU(),
+            nn.Conv2d(16, 32, kernel_size=4, stride=2),
+            nn.ReLU(),
+            nn.Flatten()
+        )
+        with T.no_grad():
+            cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
+        self.fc = nn.Sequential(
+            nn.Linear(cnn_output_dim, hidden),
+            nn.ReLU(),
+            nn.Linear(hidden, action_dim)
+        )
+
+    def forward(self, state: T.Tensor):
+        if state.dim() == 3:
+            state = state.unsqueeze(0)
+        cnn_out = self.cnn(state)
+        logits = self.fc(cnn_out)
+        return logits
+
+# Q-network for discrete actions
+class QNetwork(nn.Module):
+    def __init__(self, obs_shape: tuple, action_dim: int, hidden: int):
+        super().__init__()
+        c, h, w = obs_shape
+        self.cnn = nn.Sequential(
+            nn.Conv2d(c, 16, kernel_size=8, stride=4),
+            nn.ReLU(),
+            nn.Conv2d(16, 32, kernel_size=4, stride=2),
+            nn.ReLU(),
+            nn.Flatten()
+        )
+        with T.no_grad():
+            cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
+        self.net = nn.Sequential(
+            nn.Linear(cnn_output_dim, hidden),
+            nn.ReLU(),
+            nn.Linear(hidden, action_dim)
+        )
+
+    def forward(self, state: T.Tensor):
+        if state.dim() == 3:
+            state = state.unsqueeze(0)
+        cnn_out = self.cnn(state)
+        return self.net(cnn_out)
+
+class Memory:
+    def __init__(self):
+        self.states, self.actions, self.rewards, self.dones, self.next_states = [], [], [], [], []
+    def store(self, s, a, r, d, ns):
+        self.states.append(np.asarray(s, dtype=np.float32))
+        self.actions.append(np.asarray(a, dtype=np.float32))
+        self.rewards.append(float(r))
+        self.dones.append(float(d))
+        self.next_states.append(np.asarray(ns, dtype=np.float32))
+    def clear(self):
+        self.states, self.actions, self.rewards, self.dones, self.next_states = [], [], [], [], []
+
+
+class ObservationNorm:
+
+    def normalize(self, x):
+        # We add epsilon to make sure that we don't divide through zero.
+        return (x - x.mean()) / (x.std(unbiased=False) + 1e-8)
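Note on the in-batch variant above: ObservationNorm is stateless, so every preprocessed frame is standardized by its own mean and standard deviation, both when an action is chosen and before the transition is stored in the replay memory. A quick sanity-check sketch (assumes the class above is importable; the fake frame values are illustrative):

import torch as T
x = T.rand(1, 1, 84, 84) * 255.0                    # an un-scaled fake frame
y = ObservationNorm().normalize(x)
print(round(y.mean().item(), 3), round(y.std(unbiased=False).item(), 3))  # ~0.0, ~1.0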
Observation_norm_SAC/Obser_norm_in_batch/sac_model_cnn_in_batch.py ADDED
@@ -0,0 +1,206 @@
+import ale_py
+import gymnasium as gym
+import sys
+import numpy as np
+from sac_helpers_cnn_in_batch import *
+from gymnasium.spaces import Box
+import cv2
+import matplotlib.pyplot as plt
+
+def preprocess(obs):
+    obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
+    obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
+    obs = np.expand_dims(obs, axis=0)
+    return obs.astype(np.float32) / 255.0
+
+"""
+def main() -> int:
+    episode = 0
+    total_return = 0
+    ep_return = 0
+    steps = 100
+    batches = 100
+    avg_returns = []
+    avg_losses = []
+
+    env = gym.make("ALE/Pacman-v5")
+    # Initialize CNN with a dummy observation (to get correct input shape)
+    obs, _ = env.reset()
+    dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape)
+
+    agent = Agent(
+        obs_space=dummy_obs_space,
+        action_space=env.action_space,
+        hidden=64,
+        gamma=0.99,
+        lr=3e-4,
+        alpha=0.2,
+        seed=70,
+        batch_size=32,
+        tau=0.005
+    )
+
+    try:
+        obs, info = env.reset(seed=42)
+        state = preprocess(obs)
+
+        for update in range(1, batches + 1):
+            batch_loss = []
+
+            for t in range(steps):
+                action = agent.choose_action(state)
+                next_obs, reward, terminated, truncated, info = env.step(action)
+                done = terminated or truncated
+                next_state = preprocess(next_obs)
+
+                agent.remember(state, action, reward, done, next_state)
+
+                ep_return += reward
+                state = next_state
+
+                if done:
+                    episode += 1
+                    total_return += ep_return
+                    print(f"Episode {episode} return: {ep_return:.2f}")
+                    ep_return = 0
+                    obs, info = env.reset()
+                    state = preprocess(obs)
+
+                loss = agent.vanilla_sac_update()
+                batch_loss.append(loss)
+
+            avg_ret = (total_return / episode) if episode else 0
+            avg_returns.append(avg_ret)
+            avg_losses.append(np.mean(batch_loss))
+
+            print(f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}, avg_loss={np.mean(batch_loss):.4f}")
+
+    except Exception as e:
+        print(f"Error: {e}", file=sys.stderr)
+        return 1
+    finally:
+        avg = total_return / episode if episode else 0
+        print(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
+
+        # Plot learning curve
+        plt.figure(figsize=(10, 4))
+        plt.subplot(1, 2, 1)
+        plt.plot(avg_returns)
+        plt.xlabel("Update")
+        plt.ylabel("Average Return")
+        plt.title("SAC Learning Curve")
+        plt.grid()
+
+        plt.subplot(1, 2, 2)
+        plt.plot(avg_losses)
+        plt.xlabel("Update")
+        plt.ylabel("Average Loss")
+        plt.title("Average Loss Curve")
+        plt.grid()
+
+        plt.tight_layout()
+        plt.show()
+        env.close()
+
+    return 0
+"""
+
+def run_training(seed=42):
+    episode = 0
+    total_return = 0
+    ep_return = 0
+    steps = 100
+    batches = 100
+    avg_returns = []
+    avg_losses = []
+
+    env = gym.make("ALE/Pacman-v5")
+    obs, _ = env.reset()
+    dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape)
+
+    agent = Agent(
+        obs_space=dummy_obs_space,
+        action_space=env.action_space,
+        hidden=64,
+        gamma=0.99,
+        lr=3e-4,
+        alpha=0.2,
+        seed=seed,
+        batch_size=32,
+        tau=0.005
+    )
+
+    obs, info = env.reset(seed=seed)
+    state = preprocess(obs)
+
+    for update in range(1, batches + 1):
+        batch_loss = []
+        for t in range(steps):
+            action = agent.choose_action(state)
+            next_obs, reward, terminated, truncated, info = env.step(action)
+            done = terminated or truncated
+            next_state = preprocess(next_obs)
+
+            agent.remember(state, action, reward, done, next_state)
+
+            ep_return += reward
+            state = next_state
+
+            if done:
+                episode += 1
+                total_return += ep_return
+                ep_return = 0
+                obs, info = env.reset()
+                state = preprocess(obs)
+
+            loss = agent.vanilla_sac_update()
+            # loss = agent.update_reward_gradient_clipping()
+            batch_loss.append(loss)
+
+        avg_ret = (total_return / episode) if episode else 0
+        avg_returns.append(avg_ret)
+        avg_losses.append(np.mean(batch_loss))
+
+    env.close()
+    return avg_returns, avg_losses
+
+def main() -> int:
+    num_runs = 5
+    all_returns = []
+    all_losses = []
+
+    for run in range(num_runs):
+        print(f"Starting run {run+1}/{num_runs}")
+        avg_returns, avg_losses = run_training(seed=42 + run)
+        all_returns.append(avg_returns)
+        all_losses.append(avg_losses)
+
+    # Convert to numpy arrays for easy averaging
+    all_returns = np.array(all_returns)
+    all_losses = np.array(all_losses)
+
+    mean_returns = np.mean(all_returns, axis=0)
+    mean_losses = np.mean(all_losses, axis=0)
+
+    # Plot averaged learning curves
+    plt.figure(figsize=(10, 4))
+    plt.subplot(1, 2, 1)
+    plt.plot(mean_returns)
+    plt.xlabel("Update")
+    plt.ylabel("Average Return")
+    plt.title(f"SAC Learning Curve (avg over {num_runs} runs)")
+    plt.grid()
+
+    plt.subplot(1, 2, 2)
+    plt.plot(mean_losses)
+    plt.xlabel("Update")
+    plt.ylabel("Average Loss")
+    plt.title(f"Average Loss Curve (avg over {num_runs} runs)")
+    plt.grid()
+
+    plt.tight_layout()
+    plt.savefig("Observation norm." + "_in_batch_.png")
+
+    return 0
+
+if __name__ == "__main__":
+    raise SystemExit(main())
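For reference, preprocess() above converts a raw RGB ALE frame into a single grayscale 84x84 channel scaled to [0, 1]; the agent's CNNs are then built from that shape. A small usage sketch (the frame here is a placeholder with the usual Atari observation shape, an assumption for illustration):

import numpy as np
frame = np.random.randint(0, 256, size=(210, 160, 3), dtype=np.uint8)  # fake RGB frame
out = preprocess(frame)
print(out.shape, out.dtype, float(out.min()), float(out.max()))  # (1, 84, 84) float32, values within [0, 1]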
Observation_norm_SAC/Obser_norm_running_average/Observation norm._running_average_.png ADDED
Observation_norm_SAC/Obser_norm_running_average/sac_helpers_cnn_running_average.py ADDED
@@ -0,0 +1,350 @@
+import numpy as np
+import torch as T
+import torch.nn as nn
+import torch.optim as optim
+from torch.distributions import Categorical
+
+class Agent:
+    def __init__(self, obs_space, action_space, hidden, gamma, lr, alpha, seed, batch_size, tau=0.005):
+        if seed is not None:
+            np.random.seed(seed)
+            T.manual_seed(seed)
+
+        # Use GPU if available
+        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
+        self.action_dim = int(getattr(action_space, "n", action_space.n))  # Use .n for Discrete
+        self.obs_shape = obs_space.shape
+
+        self.gamma, self.tau, self.batch_size = gamma, tau, batch_size
+        # Make alpha learnable (adjust entropy based on reward magnitude)
+        self.target_entropy = -float(self.action_dim)
+        self.log_alpha = T.tensor(np.log(alpha), requires_grad=True, device=self.device)
+        self.alpha = np.exp(self.log_alpha.item())
+        self.alpha_opt = optim.Adam([self.log_alpha], lr=lr)
+        self.observeNorm = ObservationNorm(self.obs_shape)
+
+        self.policy = CategoricalActor(self.obs_shape, self.action_dim, hidden).to(self.device)
+        self.q1 = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
+        self.q2 = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
+        self.q1_target = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
+        self.q2_target = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
+        self.q1_target.load_state_dict(self.q1.state_dict())
+        self.q2_target.load_state_dict(self.q2.state_dict())
+
+        self.policy_opt = optim.Adam(self.policy.parameters(), lr=lr)
+        self.q1_opt = optim.Adam(self.q1.parameters(), lr=lr)
+        self.q2_opt = optim.Adam(self.q2.parameters(), lr=lr)
+        self.memory = Memory()
+
+    def choose_action(self, observation, eval_mode=False):
+        state = T.as_tensor(observation, dtype=T.float32, device=self.device)
+
+        state = self.observeNorm.normalize(state.unsqueeze(0))
+        state = state.squeeze(0)
+        with T.no_grad():
+            logits = self.policy(state.unsqueeze(0))
+            dist = Categorical(logits=logits)
+            if eval_mode:
+                action = logits.argmax(dim=-1)
+            else:
+                action = dist.sample()
+        return int(action.item())
+
+    def remember(self, state, action, reward, done, next_state):
+        state = T.as_tensor(state, dtype=T.float32, device=self.device)
+        if state.dim() == 3:  # [C,H,W]
+            state = state.unsqueeze(0)  # [1,C,H,W]
+        self.observeNorm.update(state)
+        state = self.observeNorm.normalize(state)
+        state = state.squeeze(0)
+
+        # next_state also needs normalization
+        next_state = T.as_tensor(next_state, dtype=T.float32, device=self.device)
+        if next_state.dim() == 3:  # [C,H,W]
+            next_state = next_state.unsqueeze(0)  # [1,C,H,W]
+        self.observeNorm.update(next_state)
+        next_state = self.observeNorm.normalize(next_state)
+        next_state = next_state.squeeze(0)
+
+        self.memory.store(state, action, reward, done, next_state)
+
+    def vanilla_sac_update(self):
+        if len(self.memory.states) < self.batch_size:
+            return 0.0
+
+        # Mini-batch sampling
+        idxs = np.random.choice(len(self.memory.states), self.batch_size, replace=False)
+        states = T.as_tensor(np.array([self.memory.states[i] for i in idxs]), dtype=T.float32, device=self.device)
+        actions = T.as_tensor(np.array([self.memory.actions[i] for i in idxs]), dtype=T.int64, device=self.device).unsqueeze(-1)
+        rewards = T.as_tensor(np.array([self.memory.rewards[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
+        dones = T.as_tensor(np.array([self.memory.dones[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
+        next_states = T.as_tensor(np.array([self.memory.next_states[i] for i in idxs]), dtype=T.float32, device=self.device)
+
+        # Critic update, Soft Q-Learning Objective: to ensure high-entropy actions for exploration
+        with T.no_grad():
+            next_logits = self.policy(next_states)
+            next_dist = Categorical(logits=next_logits)
+            next_probs = next_dist.probs
+            next_log_probs = next_dist.logits - T.logsumexp(next_dist.logits, dim=-1, keepdim=True)
+            q1_next = self.q1_target(next_states)
+            q2_next = self.q2_target(next_states)
+            # Soft Policy Evaluation
+            min_q_next = T.min(q1_next, q2_next)
+            next_value = (next_probs * (min_q_next - self.alpha * next_log_probs)).sum(dim=-1, keepdim=True)
+            target = rewards + self.gamma * (1 - dones) * next_value
+
+        q1 = self.q1(states).gather(1, actions)
+        q2 = self.q2(states).gather(1, actions)
+
+        # Losses of both Q-functions
+        q1_loss = nn.MSELoss()(q1, target)
+        q2_loss = nn.MSELoss()(q2, target)
+
+        self.q1_opt.zero_grad()
+        q1_loss.backward()
+        self.q1_opt.step()
+        self.q2_opt.zero_grad()
+        q2_loss.backward()
+        self.q2_opt.step()
+
+        # Policy/Actor Objective
+        logits = self.policy(states)
+        dist = Categorical(logits=logits)
+        probs = dist.probs
+        log_probs = dist.logits - T.logsumexp(dist.logits, dim=-1, keepdim=True)
+        q1_policy = self.q1(states)
+        q2_policy = self.q2(states)
+        min_q_policy = T.min(q1_policy, q2_policy)
+        # Slightly different policy loss for discrete actions
+        policy_loss = (probs * (self.alpha * log_probs - min_q_policy)).sum(dim=-1).mean()
+
+        # Temperature to update Alpha
+        alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()
+        self.alpha_opt.zero_grad()
+        alpha_loss.backward()
+        self.alpha_opt.step()
+        self.alpha = self.log_alpha.exp().item()
+
+        self.policy_opt.zero_grad()
+        policy_loss.backward()
+        self.policy_opt.step()
+
+        # Target network update
+        for target_param, param in zip(self.q1_target.parameters(), self.q1.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+        for target_param, param in zip(self.q2_target.parameters(), self.q2.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        return policy_loss.item()
+
+    def update_reward_gradient_clipping(self):
+        if len(self.memory.states) < self.batch_size:
+            return 0.0
+
+        # Mini-batch sampling
+        idxs = np.random.choice(len(self.memory.states), self.batch_size, replace=False)
+        states = T.as_tensor(np.array([self.memory.states[i] for i in idxs]), dtype=T.float32, device=self.device)
+        actions = T.as_tensor(np.array([self.memory.actions[i] for i in idxs]), dtype=T.int64, device=self.device).unsqueeze(-1)
+        rewards = T.as_tensor(np.array([self.memory.rewards[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
+        dones = T.as_tensor(np.array([self.memory.dones[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
+        next_states = T.as_tensor(np.array([self.memory.next_states[i] for i in idxs]), dtype=T.float32, device=self.device)
+
+        """
+        # Min-max normalization and tanh scaling to [-1, 1]
+        rewards_np = np.array([self.memory.rewards[i] for i in idxs])
+        r_min = rewards_np.min()
+        r_max = rewards_np.max()
+        # Avoid division by zero
+        r_scaled = 2 * (rewards_np - r_min) / (r_max - r_min + 1e-8) - 1
+        normalized_rewards = np.tanh(r_scaled)
+        rewards = T.as_tensor(normalized_rewards, dtype=T.float32, device=self.device).unsqueeze(-1)
+        """
+
+        # Critic update, Soft Q-Learning Objective: to ensure high-entropy actions for exploration
+        with T.no_grad():
+            next_logits = self.policy(next_states)
+            next_dist = Categorical(logits=next_logits)
+            next_probs = next_dist.probs
+            next_log_probs = next_dist.logits - T.logsumexp(next_dist.logits, dim=-1, keepdim=True)
+            q1_next = self.q1_target(next_states)
+            q2_next = self.q2_target(next_states)
+            # Soft Policy Evaluation
+            min_q_next = T.min(q1_next, q2_next)
+            next_value = (next_probs * (min_q_next - self.alpha * next_log_probs)).sum(dim=-1, keepdim=True)
+            target = rewards + self.gamma * (1 - dones) * next_value
+
+        q1 = self.q1(states).gather(1, actions)
+        q2 = self.q2(states).gather(1, actions)
+
+        # Losses of both Q-functions
+        q1_loss = nn.MSELoss()(q1, target)
+        q2_loss = nn.MSELoss()(q2, target)
+
+        self.q1_opt.zero_grad()
+        q1_loss.backward()
+        self.q1_opt.step()
+        self.q2_opt.zero_grad()
+        q2_loss.backward()
+        self.q2_opt.step()
+
+        # Policy/Actor Objective
+        logits = self.policy(states)
+        dist = Categorical(logits=logits)
+        probs = dist.probs
+        log_probs = dist.logits - T.logsumexp(dist.logits, dim=-1, keepdim=True)
+        q1_policy = self.q1(states)
+        q2_policy = self.q2(states)
+        min_q_policy = T.min(q1_policy, q2_policy)
+        # Slightly different policy loss for discrete actions
+        policy_loss = (probs * (self.alpha * log_probs - min_q_policy)).sum(dim=-1).mean()
+
+        # Temperature to update Alpha
+        alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()
+        self.alpha_opt.zero_grad()
+        alpha_loss.backward()
+        T.nn.utils.clip_grad_norm_([self.log_alpha], max_norm=1.0)  # Gradient clipping
+        self.alpha_opt.step()
+        self.alpha = self.log_alpha.exp().item()
+
+        self.policy_opt.zero_grad()
+        policy_loss.backward()
+        T.nn.utils.clip_grad_norm_([self.log_alpha], max_norm=1.0)  # Gradient clipping
+        self.policy_opt.step()
+
+        # Target network update
+        for target_param, param in zip(self.q1_target.parameters(), self.q1.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+        for target_param, param in zip(self.q2_target.parameters(), self.q2.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        return policy_loss.item()
+
+# Actor/Policy network
+# A typical SAC actor network outputs a Gaussian distribution for a state.
+# Here, we adapt it for discrete actions using a Categorical distribution, as the ATARI environment is discrete.
+# The policy outputs logits for each discrete action.
+
+# From: https://ch.mathworks.com/help/reinforcement-learning/ug/soft-actor-critic-agents.html
+# The actor takes the current observation and generates a categorical distribution, in which each possible action is associated with a probability.
+
+class CategoricalActor(nn.Module):
+    def __init__(self, obs_shape: tuple, action_dim: int, hidden: int):
+        super().__init__()
+        c, h, w = obs_shape
+        self.cnn = nn.Sequential(
+            nn.Conv2d(c, 16, kernel_size=8, stride=4),
+            nn.ReLU(),
+            nn.Conv2d(16, 32, kernel_size=4, stride=2),
+            nn.ReLU(),
+            nn.Flatten()
+        )
+        with T.no_grad():
+            cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
+        self.fc = nn.Sequential(
+            nn.Linear(cnn_output_dim, hidden),
+            nn.ReLU(),
+            nn.Linear(hidden, action_dim)
+        )
+
+    def forward(self, state: T.Tensor):
+        if state.dim() == 3:
+            state = state.unsqueeze(0)
+        cnn_out = self.cnn(state)
+        logits = self.fc(cnn_out)
+        return logits
+
+# Q-network for discrete actions
+class QNetwork(nn.Module):
+    def __init__(self, obs_shape: tuple, action_dim: int, hidden: int):
+        super().__init__()
+        c, h, w = obs_shape
+        self.cnn = nn.Sequential(
+            nn.Conv2d(c, 16, kernel_size=8, stride=4),
+            nn.ReLU(),
+            nn.Conv2d(16, 32, kernel_size=4, stride=2),
+            nn.ReLU(),
+            nn.Flatten()
+        )
+        with T.no_grad():
+            cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
+        self.net = nn.Sequential(
+            nn.Linear(cnn_output_dim, hidden),
+            nn.ReLU(),
+            nn.Linear(hidden, action_dim)
+        )
+
+    def forward(self, state: T.Tensor):
+        if state.dim() == 3:
+            state = state.unsqueeze(0)
+        cnn_out = self.cnn(state)
+        return self.net(cnn_out)
+
+class Memory:
+    def __init__(self):
+        self.states, self.actions, self.rewards, self.dones, self.next_states = [], [], [], [], []
+    def store(self, s, a, r, d, ns):
+        self.states.append(np.asarray(s, dtype=np.float32))
+        self.actions.append(np.asarray(a, dtype=np.float32))
+        self.rewards.append(float(r))
+        self.dones.append(float(d))
+        self.next_states.append(np.asarray(ns, dtype=np.float32))
+    def clear(self):
+        self.states, self.actions, self.rewards, self.dones, self.next_states = [], [], [], [], []
+
+
+class ObservationNorm:
+    def __init__(self, shape):
+        c, h, w = shape
+        self.mean = T.zeros((c, 1, 1))
+        self.var = T.ones((c, 1, 1))
+        self.count = 1e-4
+
+    def update(self, x: T.Tensor):
+        batch_mean = x.mean(dim=[0, 2, 3], keepdim=True)
+        batch_var = x.var(dim=[0, 2, 3], keepdim=True, unbiased=False)
+        batch_count = x.shape[0]
+        self._update_from_moments(batch_mean, batch_var, batch_count)
+
+    def _update_from_moments(self, batch_mean, batch_var, batch_count):
+        delta = batch_mean - self.mean
+        total_count = self.count + batch_count
+
+        # Update mean
+        new_mean = self.mean + delta * batch_count / total_count
+
+        # Update variance
+        m_a = self.var * self.count      # scaled old variance
+        m_b = batch_var * batch_count    # scaled batch variance
+        M2 = m_a + m_b + (delta ** 2) * (self.count * batch_count / total_count)
+        new_var = M2 / total_count
+
+        # Assign updates
+        self.mean = new_mean
+        self.var = new_var
+        self.count = total_count
+
+    def normalize(self, x: T.Tensor):
+        # We add epsilon to make sure that we don't divide through zero.
+        return (x - self.mean) / (T.sqrt(self.var) + 1e-8)
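The running-average ObservationNorm above combines batch moments with the parallel (Chan et al.) update. A quick sanity-check sketch, assuming the class above is importable: feeding two batches through update() should reproduce, up to the 1e-4 initial count and floating-point error, the mean and variance of the concatenated data.

import torch as T
a, b = T.randn(8, 1, 4, 4), T.randn(16, 1, 4, 4)
norm = ObservationNorm((1, 4, 4))
norm.update(a)
norm.update(b)
both = T.cat([a, b], dim=0)
print(T.allclose(norm.mean.flatten(), both.mean(dim=[0, 2, 3]), atol=1e-3))                   # expected: True
print(T.allclose(norm.var.flatten(), both.var(dim=[0, 2, 3], unbiased=False), atol=1e-3))      # expected: True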
Observation_norm_SAC/Obser_norm_running_average/sac_model_cnn_running_average.py ADDED
@@ -0,0 +1,207 @@
+import ale_py
+import gymnasium as gym
+import sys
+import numpy as np
+from sac_helpers_cnn_running_average import *
+from gymnasium.spaces import Box
+import cv2
+import matplotlib.pyplot as plt
+
+def preprocess(obs):
+    obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
+    obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
+    obs = np.expand_dims(obs, axis=0)
+    return obs.astype(np.float32) / 255.0
+
+"""
+def main() -> int:
+    episode = 0
+    total_return = 0
+    ep_return = 0
+    steps = 100
+    batches = 100
+    avg_returns = []
+    avg_losses = []
+
+    env = gym.make("ALE/Pacman-v5")
+    # Initialize CNN with a dummy observation (to get correct input shape)
+    obs, _ = env.reset()
+    dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape)
+
+    agent = Agent(
+        obs_space=dummy_obs_space,
+        action_space=env.action_space,
+        hidden=64,
+        gamma=0.99,
+        lr=3e-4,
+        alpha=0.2,
+        seed=70,
+        batch_size=32,
+        tau=0.005
+    )
+
+    try:
+        obs, info = env.reset(seed=42)
+        state = preprocess(obs)
+
+        for update in range(1, batches + 1):
+            batch_loss = []
+
+            for t in range(steps):
+                action = agent.choose_action(state)
+                next_obs, reward, terminated, truncated, info = env.step(action)
+                done = terminated or truncated
+                next_state = preprocess(next_obs)
+
+                agent.remember(state, action, reward, done, next_state)
+
+                ep_return += reward
+                state = next_state
+
+                if done:
+                    episode += 1
+                    total_return += ep_return
+                    print(f"Episode {episode} return: {ep_return:.2f}")
+                    ep_return = 0
+                    obs, info = env.reset()
+                    state = preprocess(obs)
+
+                loss = agent.vanilla_sac_update()
+                batch_loss.append(loss)
+
+            avg_ret = (total_return / episode) if episode else 0
+            avg_returns.append(avg_ret)
+            avg_losses.append(np.mean(batch_loss))
+
+            print(f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}, avg_loss={np.mean(batch_loss):.4f}")
+
+    except Exception as e:
+        print(f"Error: {e}", file=sys.stderr)
+        return 1
+    finally:
+        avg = total_return / episode if episode else 0
+        print(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
+
+        # Plot learning curve
+        plt.figure(figsize=(10, 4))
+        plt.subplot(1, 2, 1)
+        plt.plot(avg_returns)
+        plt.xlabel("Update")
+        plt.ylabel("Average Return")
+        plt.title("SAC Learning Curve")
+        plt.grid()
+
+        plt.subplot(1, 2, 2)
+        plt.plot(avg_losses)
+        plt.xlabel("Update")
+        plt.ylabel("Average Loss")
+        plt.title("Average Loss Curve")
+        plt.grid()
+
+        plt.tight_layout()
+        plt.show()
+        env.close()
+
+    return 0
+"""
+
+def run_training(seed=42):
+    episode = 0
+    total_return = 0
+    ep_return = 0
+    steps = 100
+    batches = 100
+    avg_returns = []
+    avg_losses = []
+
+    env = gym.make("ALE/Pacman-v5")
+    obs, _ = env.reset()
+    dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape)
+
+    agent = Agent(
+        obs_space=dummy_obs_space,
+        action_space=env.action_space,
+        hidden=64,
+        gamma=0.99,
+        lr=3e-4,
+        alpha=0.2,
+        seed=seed,
+        batch_size=32,
+        tau=0.005
+    )
+
+    obs, info = env.reset(seed=seed)
+    state = preprocess(obs)
+
+    for update in range(1, batches + 1):
+        batch_loss = []
+        for t in range(steps):
+            action = agent.choose_action(state)
+            next_obs, reward, terminated, truncated, info = env.step(action)
+            done = terminated or truncated
+            next_state = preprocess(next_obs)
+
+            agent.remember(state, action, reward, done, next_state)
+
+            ep_return += reward
+            state = next_state
+
+            if done:
+                episode += 1
+                total_return += ep_return
+                ep_return = 0
+                obs, info = env.reset()
+                state = preprocess(obs)
+
+            loss = agent.vanilla_sac_update()
+            # loss = agent.update_reward_gradient_clipping()
+            batch_loss.append(loss)
+
+        avg_ret = (total_return / episode) if episode else 0
+        avg_returns.append(avg_ret)
+        avg_losses.append(np.mean(batch_loss))
+
+    env.close()
+    return avg_returns, avg_losses
+
+def main() -> int:
+    num_runs = 5
+    all_returns = []
+    all_losses = []
+
+    for run in range(num_runs):
+        print(f"Starting run {run+1}/{num_runs}")
+        avg_returns, avg_losses = run_training(seed=42 + run)
+        all_returns.append(avg_returns)
+        all_losses.append(avg_losses)
+
+    # Convert to numpy arrays for easy averaging
+    all_returns = np.array(all_returns)
+    all_losses = np.array(all_losses)
+
+    mean_returns = np.mean(all_returns, axis=0)
+    mean_losses = np.mean(all_losses, axis=0)
+
+    # Plot averaged learning curves
+    plt.figure(figsize=(10, 4))
+    plt.subplot(1, 2, 1)
+    plt.plot(mean_returns)
+    plt.xlabel("Update")
+    plt.ylabel("Average Return")
+    plt.title(f"SAC Learning Curve (avg over {num_runs} runs)")
+    plt.grid()
+
+    plt.subplot(1, 2, 2)
+    plt.plot(mean_losses)
+    plt.xlabel("Update")
+    plt.ylabel("Average Loss")
+    plt.title(f"Average Loss Curve (avg over {num_runs} runs)")
+    plt.grid()
+
+    plt.tight_layout()
+    plt.savefig("Observation norm." + "_running_average_.png")
+
+    return 0
+
+if __name__ == "__main__":
+    raise SystemExit(main())
Observation_norm_SAC/sac_helpers_cnn.py ADDED
@@ -0,0 +1,274 @@
1
+ import numpy as np
2
+ import torch as T
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.distributions import Categorical
6
+
7
+ class Agent:
8
+ def __init__(self, obs_space, action_space, hidden, gamma, lr, alpha, seed, batch_size, tau=0.005):
9
+ if seed is not None:
10
+ np.random.seed(seed)
11
+ T.manual_seed(seed)
12
+
13
+ # Use GPU if available
14
+
15
+ self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
16
+ self.action_dim = int(getattr(action_space, "n", action_space.n)) # Use .n for Discrete
17
+ self.obs_shape = obs_space.shape
18
+
19
+ self.gamma, self.tau, self.batch_size = gamma, tau, batch_size
20
+ # Make alpha learnable (adjust entropy based on reward magnitude)
21
+ self.target_entropy = -float(self.action_dim)
22
+ self.log_alpha = T.tensor(np.log(alpha), requires_grad=True, device=self.device)
23
+ self.alpha = np.exp(self.log_alpha.item())
24
+ self.alpha_opt = optim.Adam([self.log_alpha], lr=lr)
25
+
26
+ self.policy = CategoricalActor(self.obs_shape, self.action_dim, hidden).to(self.device)
27
+ self.q1 = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
28
+ self.q2 = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
29
+ self.q1_target = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
30
+ self.q2_target = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
31
+ self.q1_target.load_state_dict(self.q1.state_dict())
32
+ self.q2_target.load_state_dict(self.q2.state_dict())
33
+
34
+ self.policy_opt = optim.Adam(self.policy.parameters(), lr=lr)
35
+ self.q1_opt = optim.Adam(self.q1.parameters(), lr=lr)
36
+ self.q2_opt = optim.Adam(self.q2.parameters(), lr=lr)
37
+ self.memory = Memory()
38
+
39
+ def choose_action(self, observation, eval_mode=False):
40
+ state = T.as_tensor(observation, dtype=T.float32, device=self.device)
41
+ with T.no_grad():
42
+ logits = self.policy(state.unsqueeze(0))
43
+ dist = Categorical(logits=logits)
44
+ if eval_mode:
45
+ action = logits.argmax(dim=-1)
46
+ else:
47
+ action = dist.sample()
48
+ return int(action.item())
49
+
50
+ def remember(self, state, action, reward, done, next_state):
51
+ self.memory.store(state, action, reward, done, next_state)
52
+
53
+ def vanilla_sac_update(self):
54
+ if len(self.memory.states) < self.batch_size:
55
+ return 0.0
56
+
57
+ # Mini-batch sampling
58
+ idxs = np.random.choice(len(self.memory.states), self.batch_size, replace=False)
59
+ states = T.as_tensor(np.array([self.memory.states[i] for i in idxs]), dtype=T.float32, device=self.device)
60
+ actions = T.as_tensor(np.array([self.memory.actions[i] for i in idxs]), dtype=T.int64, device=self.device).unsqueeze(-1)
61
+ rewards = T.as_tensor(np.array([self.memory.rewards[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
62
+ dones = T.as_tensor(np.array([self.memory.dones[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
63
+ next_states = T.as_tensor(np.array([self.memory.next_states[i] for i in idxs]), dtype=T.float32, device=self.device)
64
+
65
+ # Critic update, Soft Q-Learning Objective: to ensure high-entropy actions for exploration
66
+        with T.no_grad():
+            next_logits = self.policy(next_states)
+            next_dist = Categorical(logits=next_logits)
+            next_probs = next_dist.probs
+            next_log_probs = next_dist.logits - T.logsumexp(next_dist.logits, dim=-1, keepdim=True)
+            q1_next = self.q1_target(next_states)
+            q2_next = self.q2_target(next_states)
+            # Soft Policy Evaluation
+            min_q_next = T.min(q1_next, q2_next)
+            next_value = (next_probs * (min_q_next - self.alpha * next_log_probs)).sum(dim=-1, keepdim=True)
+            target = rewards + self.gamma * (1 - dones) * next_value
+
+        q1 = self.q1(states).gather(1, actions)
+        q2 = self.q2(states).gather(1, actions)
+
+        # Losses of both Q-functions
+        q1_loss = nn.MSELoss()(q1, target)
+        q2_loss = nn.MSELoss()(q2, target)
+
+        self.q1_opt.zero_grad()
+        q1_loss.backward()
+        self.q1_opt.step()
+        self.q2_opt.zero_grad()
+        q2_loss.backward()
+        self.q2_opt.step()
+
+        # Policy/Actor Objective
+        logits = self.policy(states)
+        dist = Categorical(logits=logits)
+        probs = dist.probs
+        log_probs = dist.logits - T.logsumexp(dist.logits, dim=-1, keepdim=True)
+        q1_policy = self.q1(states)
+        q2_policy = self.q2(states)
+        min_q_policy = T.min(q1_policy, q2_policy)
+        # Slightly different policy loss for discrete actions
+        policy_loss = (probs * (self.alpha * log_probs - min_q_policy)).sum(dim=-1).mean()
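+        # With discrete actions the expectation over the policy is taken exactly
+        # via the action probabilities, instead of the reparameterization trick
+        # used for Gaussian policies in continuous SAC.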
+
+        # Temperature to update Alpha
+        alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()
+        self.alpha_opt.zero_grad()
+        alpha_loss.backward()
+        self.alpha_opt.step()
+        self.alpha = self.log_alpha.exp().item()
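+        # alpha is tuned automatically so that the policy entropy is pushed
+        # towards self.target_entropy (assumed to be set in __init__, not shown here).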
+
+        self.policy_opt.zero_grad()
+        policy_loss.backward()
+        self.policy_opt.step()
+
+        # Target network update
+        for target_param, param in zip(self.q1_target.parameters(), self.q1.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+        for target_param, param in zip(self.q2_target.parameters(), self.q2.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
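+        # Soft (Polyak) update of the target critics: theta_target <- tau * theta + (1 - tau) * theta_target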
+
+        return policy_loss.item()
+
+    def update_reward_gradient_clipping(self):
+        if len(self.memory.states) < self.batch_size:
+            return 0.0
+
+        # Mini-batch sampling
+        idxs = np.random.choice(len(self.memory.states), self.batch_size, replace=False)
+        states = T.as_tensor(np.array([self.memory.states[i] for i in idxs]), dtype=T.float32, device=self.device)
+        actions = T.as_tensor(np.array([self.memory.actions[i] for i in idxs]), dtype=T.int64, device=self.device).unsqueeze(-1)
+        rewards = T.as_tensor(np.array([self.memory.rewards[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
+        dones = T.as_tensor(np.array([self.memory.dones[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
+        next_states = T.as_tensor(np.array([self.memory.next_states[i] for i in idxs]), dtype=T.float32, device=self.device)
+
+        """
+        # Min-max normalization and tanh scaling to [-1, 1]
+        rewards_np = np.array([self.memory.rewards[i] for i in idxs])
+        r_min = rewards_np.min()
+        r_max = rewards_np.max()
+        # Avoid division by zero
+        r_scaled = 2 * (rewards_np - r_min) / (r_max - r_min + 1e-8) - 1
+        normalized_rewards = np.tanh(r_scaled)
+        rewards = T.as_tensor(normalized_rewards, dtype=T.float32, device=self.device).unsqueeze(-1)
+        """
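+        # The block above is a disabled min-max / tanh reward-scaling experiment;
+        # this variant instead relies on the gradient clipping applied below.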
+
+        # Critic update, Soft Q-Learning Objective: to ensure high-entropy actions for exploration
+        with T.no_grad():
+            next_logits = self.policy(next_states)
+            next_dist = Categorical(logits=next_logits)
+            next_probs = next_dist.probs
+            next_log_probs = next_dist.logits - T.logsumexp(next_dist.logits, dim=-1, keepdim=True)
+            q1_next = self.q1_target(next_states)
+            q2_next = self.q2_target(next_states)
+            # Soft Policy Evaluation
+            min_q_next = T.min(q1_next, q2_next)
+            next_value = (next_probs * (min_q_next - self.alpha * next_log_probs)).sum(dim=-1, keepdim=True)
+            target = rewards + self.gamma * (1 - dones) * next_value
+
+        q1 = self.q1(states).gather(1, actions)
+        q2 = self.q2(states).gather(1, actions)
+
+        # Losses of both Q-functions
+        q1_loss = nn.MSELoss()(q1, target)
+        q2_loss = nn.MSELoss()(q2, target)
+
+        self.q1_opt.zero_grad()
+        q1_loss.backward()
+        self.q1_opt.step()
+        self.q2_opt.zero_grad()
+        q2_loss.backward()
+        self.q2_opt.step()
+
+        # Policy/Actor Objective
+        logits = self.policy(states)
+        dist = Categorical(logits=logits)
+        probs = dist.probs
+        log_probs = dist.logits - T.logsumexp(dist.logits, dim=-1, keepdim=True)
+        q1_policy = self.q1(states)
+        q2_policy = self.q2(states)
+        min_q_policy = T.min(q1_policy, q2_policy)
+        # Slightly different policy loss for discrete actions
+        policy_loss = (probs * (self.alpha * log_probs - min_q_policy)).sum(dim=-1).mean()
+
+        # Temperature to update Alpha
+        alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()
+        self.alpha_opt.zero_grad()
+        alpha_loss.backward()
+        T.nn.utils.clip_grad_norm_([self.log_alpha], max_norm=1.0)  # Gradient clipping
+        self.alpha_opt.step()
+        self.alpha = self.log_alpha.exp().item()
+
+        self.policy_opt.zero_grad()
+        policy_loss.backward()
+        T.nn.utils.clip_grad_norm_(self.policy.parameters(), max_norm=1.0)  # Gradient clipping on the policy parameters
+        self.policy_opt.step()
+
+        # Target network update
+        for target_param, param in zip(self.q1_target.parameters(), self.q1.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+        for target_param, param in zip(self.q2_target.parameters(), self.q2.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        return policy_loss.item()
+
+# Actor/Policy network
+# A typical SAC actor outputs a Gaussian distribution over continuous actions for a given state.
+# Here we adapt it to discrete actions with a Categorical distribution, since the Atari action space is discrete.
+# The policy outputs one logit per discrete action.
+
+# From: https://ch.mathworks.com/help/reinforcement-learning/ug/soft-actor-critic-agents.html
+# "The actor takes the current observation and generates a categorical distribution, in which each possible action is associated with a probability."
+
+class CategoricalActor(nn.Module):
+    def __init__(self, obs_shape: tuple, action_dim: int, hidden: int):
+        super().__init__()
+        c, h, w = obs_shape
+        self.cnn = nn.Sequential(
+            nn.Conv2d(c, 16, kernel_size=8, stride=4),
+            nn.ReLU(),
+            nn.Conv2d(16, 32, kernel_size=4, stride=2),
+            nn.ReLU(),
+            nn.Flatten()
+        )
+        with T.no_grad():
+            cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
+        self.fc = nn.Sequential(
+            nn.Linear(cnn_output_dim, hidden),
+            nn.ReLU(),
+            nn.Linear(hidden, action_dim)
+        )
+
+    def forward(self, state: T.Tensor):
+        if state.dim() == 3:
+            state = state.unsqueeze(0)
+        cnn_out = self.cnn(state)
+        logits = self.fc(cnn_out)
+        return logits
+
+# Q-network for discrete actions
+class QNetwork(nn.Module):
+    def __init__(self, obs_shape: tuple, action_dim: int, hidden: int):
+        super().__init__()
+        c, h, w = obs_shape
+        self.cnn = nn.Sequential(
+            nn.Conv2d(c, 16, kernel_size=8, stride=4),
+            nn.ReLU(),
+            nn.Conv2d(16, 32, kernel_size=4, stride=2),
+            nn.ReLU(),
+            nn.Flatten()
+        )
+        with T.no_grad():
+            cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
+        self.net = nn.Sequential(
+            nn.Linear(cnn_output_dim, hidden),
+            nn.ReLU(),
+            nn.Linear(hidden, action_dim)
+        )
+
+    def forward(self, state: T.Tensor):
+        if state.dim() == 3:
+            state = state.unsqueeze(0)
+        cnn_out = self.cnn(state)
+        return self.net(cnn_out)
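+# QNetwork returns one Q-value per discrete action in a single forward pass;
+# the update methods select the value of the taken action with .gather(1, actions).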
+
+class Memory:
+    def __init__(self):
+        self.states, self.actions, self.rewards, self.dones, self.next_states = [], [], [], [], []
+    def store(self, s, a, r, d, ns):
+        self.states.append(np.asarray(s, dtype=np.float32))
+        self.actions.append(np.asarray(a, dtype=np.float32))
+        self.rewards.append(float(r))
+        self.dones.append(float(d))
+        self.next_states.append(np.asarray(ns, dtype=np.float32))
+    def clear(self):
+        self.states, self.actions, self.rewards, self.dones, self.next_states = [], [], [], [], []
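+# Memory is a simple unbounded buffer: transitions are appended and never evicted
+# unless clear() is called; the updates sample mini-batches uniformly without replacement.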
Observation_norm_SAC/sac_model_cnn.py ADDED
@@ -0,0 +1,206 @@
+import ale_py
+import gymnasium as gym
+import sys
+import numpy as np
+from sac_helpers_cnn import *
+from gymnasium.spaces import Box
+import cv2
+import matplotlib.pyplot as plt
+
+def preprocess(obs):
+    obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
+    obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
+    obs = np.expand_dims(obs, axis=0)
+    return obs.astype(np.float32) / 255.0
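+# preprocess() returns a (1, 84, 84) float array in [0, 1]; run_training() below
+# builds the agent's observation space (a Box) from this shape.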
+"""
+def main() -> int:
+    episode = 0
+    total_return = 0
+    ep_return = 0
+    steps = 100
+    batches = 100
+    avg_returns = []
+    avg_losses = []
+
+    env = gym.make("ALE/Pacman-v5")
+    # Initialize CNN with a dummy observation (to get correct input shape)
+    obs, _ = env.reset()
+    dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape)
+
+    agent = Agent(
+        obs_space=dummy_obs_space,
+        action_space=env.action_space,
+        hidden=64,
+        gamma=0.99,
+        lr=3e-4,
+        alpha=0.2,
+        seed=70,
+        batch_size=32,
+        tau=0.005
+    )
+
+    try:
+        obs, info = env.reset(seed=42)
+        state = preprocess(obs)
+
+        for update in range(1, batches + 1):
+            batch_loss = []
+
+            for t in range(steps):
+                action = agent.choose_action(state)
+                next_obs, reward, terminated, truncated, info = env.step(action)
+                done = terminated or truncated
+                next_state = preprocess(next_obs)
+
+                agent.remember(state, action, reward, done, next_state)
+
+                ep_return += reward
+                state = next_state
+
+                if done:
+                    episode += 1
+                    total_return += ep_return
+                    print(f"Episode {episode} return: {ep_return:.2f}")
+                    ep_return = 0
+                    obs, info = env.reset()
+                    state = preprocess(obs)
+
+            loss = agent.vanilla_sac_update()
+            batch_loss.append(loss)
+
+            avg_ret = (total_return / episode) if episode else 0
+            avg_returns.append(avg_ret)
+            avg_losses.append(np.mean(batch_loss))
+
+            print(f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}, avg_loss={np.mean(batch_loss):.4f}")
+
+    except Exception as e:
+        print(f"Error: {e}", file=sys.stderr)
+        return 1
+    finally:
+        avg = total_return / episode if episode else 0
+        print(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
+
+        # Plot learning curve
+        plt.figure(figsize=(10, 4))
+        plt.subplot(1, 2, 1)
+        plt.plot(avg_returns)
+        plt.xlabel("Update")
+        plt.ylabel("Average Return")
+        plt.title("SAC Learning Curve")
+        plt.grid()
+
+        plt.subplot(1, 2, 2)
+        plt.plot(avg_losses)
+        plt.xlabel("Update")
+        plt.ylabel("Average Loss")
+        plt.title("Average Loss Curve")
+        plt.grid()
+
+        plt.tight_layout()
+        plt.show()
+        env.close()
+
+    return 0
+"""
+
+def run_training(seed=42):
+    episode = 0
+    total_return = 0
+    ep_return = 0
+    steps = 100
+    batches = 100
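+    # 100 updates x 100 environment steps = 10,000 interactions per run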
+    avg_returns = []
+    avg_losses = []
+
+    env = gym.make("ALE/Pacman-v5")
+    obs, _ = env.reset()
+    dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape)
+
+    agent = Agent(
+        obs_space=dummy_obs_space,
+        action_space=env.action_space,
+        hidden=64,
+        gamma=0.99,
+        lr=3e-4,
+        alpha=0.2,
+        seed=seed,
+        batch_size=32,
+        tau=0.005
+    )
+
+    obs, info = env.reset(seed=seed)
+    state = preprocess(obs)
+
+    for update in range(1, batches + 1):
+        batch_loss = []
+        for t in range(steps):
+            action = agent.choose_action(state)
+            next_obs, reward, terminated, truncated, info = env.step(action)
+            done = terminated or truncated
+            next_state = preprocess(next_obs)
+
+            agent.remember(state, action, reward, done, next_state)
+
+            ep_return += reward
+            state = next_state
+
+            if done:
+                episode += 1
+                total_return += ep_return
+                ep_return = 0
+                obs, info = env.reset()
+                state = preprocess(obs)
+
+        #loss = agent.vanilla_sac_update()
+        loss = agent.update_reward_gradient_clipping()
+        batch_loss.append(loss)
+
+        avg_ret = (total_return / episode) if episode else 0
+        avg_returns.append(avg_ret)
+        avg_losses.append(np.mean(batch_loss))
+
+    env.close()
+    return avg_returns, avg_losses
+
+def main() -> int:
+    num_runs = 5
+    all_returns = []
+    all_losses = []
+
+    for run in range(num_runs):
+        print(f"Starting run {run+1}/{num_runs}")
+        avg_returns, avg_losses = run_training(seed=42 + run)
+        all_returns.append(avg_returns)
+        all_losses.append(avg_losses)
+
+    # Convert to numpy arrays for easy averaging
+    all_returns = np.array(all_returns)
+    all_losses = np.array(all_losses)
+
+    mean_returns = np.mean(all_returns, axis=0)
+    mean_losses = np.mean(all_losses, axis=0)
+
+    # Plot averaged learning curves
+    plt.figure(figsize=(10, 4))
+    plt.subplot(1, 2, 1)
+    plt.plot(mean_returns)
+    plt.xlabel("Update")
+    plt.ylabel("Average Return")
+    plt.title(f"SAC Learning Curve (avg over {num_runs} runs)")
+    plt.grid()
+
+    plt.subplot(1, 2, 2)
+    plt.plot(mean_losses)
+    plt.xlabel("Update")
+    plt.ylabel("Average Loss")
+    plt.title(f"Average Loss Curve (avg over {num_runs} runs)")
+    plt.grid()
+
+    plt.tight_layout()
+    plt.show()
+
+    return 0
+
+if __name__ == "__main__":
+    raise SystemExit(main())
SAC-2/sac-project/sac_helpers_cnn.py CHANGED
@@ -11,6 +11,7 @@ class Agent:
         T.manual_seed(seed)
 
         # Use GPU if available
+
         self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
         self.action_dim = int(getattr(action_space, "n", action_space.n)) # Use .n for Discrete
         self.obs_shape = obs_space.shape