rl-project-7Oct committed

Commit e6a9a76 · verified · 1 Parent(s): 8c8edd8

Upload 2 files
SAC-2/sac-project/sac_helpers_cnn.py ADDED
@@ -0,0 +1,273 @@
+ import numpy as np
+ import torch as T
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.distributions import Categorical
+
+ class Agent:
+     def __init__(self, obs_space, action_space, hidden, gamma, lr, alpha, seed, batch_size, tau=0.005):
+         if seed is not None:
+             np.random.seed(seed)
+             T.manual_seed(seed)
+
+         # Use GPU if available
+         self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
+         self.action_dim = int(action_space.n)  # .n is the number of actions of a Discrete space
+         self.obs_shape = obs_space.shape
+
+         self.gamma, self.tau, self.batch_size = gamma, tau, batch_size
+         # Make alpha learnable (adjusts the entropy bonus to the reward magnitude)
+         self.target_entropy = -float(self.action_dim)
+         self.log_alpha = T.tensor(np.log(alpha), requires_grad=True, device=self.device)
+         self.alpha = np.exp(self.log_alpha.item())
+         self.alpha_opt = optim.Adam([self.log_alpha], lr=lr)
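+         # The temperature is optimised through log_alpha so that alpha = exp(log_alpha) stays positive;
+         # target_entropy = -|A| follows the heuristic used for continuous-action SAC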
+
+         self.policy = CategoricalActor(self.obs_shape, self.action_dim, hidden).to(self.device)
+         self.q1 = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
+         self.q2 = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
+         self.q1_target = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
+         self.q2_target = QNetwork(self.obs_shape, self.action_dim, hidden).to(self.device)
+         self.q1_target.load_state_dict(self.q1.state_dict())
+         self.q2_target.load_state_dict(self.q2.state_dict())
+
+         self.policy_opt = optim.Adam(self.policy.parameters(), lr=lr)
+         self.q1_opt = optim.Adam(self.q1.parameters(), lr=lr)
+         self.q2_opt = optim.Adam(self.q2.parameters(), lr=lr)
+         self.memory = Memory()
+
+     def choose_action(self, observation, eval_mode=False):
+         state = T.as_tensor(observation, dtype=T.float32, device=self.device)
+         with T.no_grad():
+             logits = self.policy(state.unsqueeze(0))
+             dist = Categorical(logits=logits)
+             if eval_mode:
+                 action = logits.argmax(dim=-1)
+             else:
+                 action = dist.sample()
+         return int(action.item())
+
+     def remember(self, state, action, reward, done, next_state):
+         self.memory.store(state, action, reward, done, next_state)
+
+     def vanilla_sac_update(self):
+         if len(self.memory.states) < self.batch_size:
+             return 0.0
+
+         # Mini-batch sampling
+         idxs = np.random.choice(len(self.memory.states), self.batch_size, replace=False)
+         states = T.as_tensor(np.array([self.memory.states[i] for i in idxs]), dtype=T.float32, device=self.device)
+         actions = T.as_tensor(np.array([self.memory.actions[i] for i in idxs]), dtype=T.int64, device=self.device).unsqueeze(-1)
+         rewards = T.as_tensor(np.array([self.memory.rewards[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
+         dones = T.as_tensor(np.array([self.memory.dones[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
+         next_states = T.as_tensor(np.array([self.memory.next_states[i] for i in idxs]), dtype=T.float32, device=self.device)
+
+         # Critic update (soft Q-learning objective): the entropy term encourages high-entropy, exploratory behaviour
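+         # Discrete-action soft Bellman target computed below:
+         #   y = r + gamma * (1 - done) * sum_a pi(a|s') * [ min(Q1_tgt(s',a), Q2_tgt(s',a)) - alpha * log pi(a|s') ]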
+         with T.no_grad():
+             next_logits = self.policy(next_states)
+             next_dist = Categorical(logits=next_logits)
+             next_probs = next_dist.probs
+             next_log_probs = next_dist.logits - T.logsumexp(next_dist.logits, dim=-1, keepdim=True)
+             q1_next = self.q1_target(next_states)
+             q2_next = self.q2_target(next_states)
+             # Soft Policy Evaluation
+             min_q_next = T.min(q1_next, q2_next)
+             next_value = (next_probs * (min_q_next - self.alpha * next_log_probs)).sum(dim=-1, keepdim=True)
+             target = rewards + self.gamma * (1 - dones) * next_value
+
+         q1 = self.q1(states).gather(1, actions)
+         q2 = self.q2(states).gather(1, actions)
+
+         # Losses of both Q-functions
+         q1_loss = nn.MSELoss()(q1, target)
+         q2_loss = nn.MSELoss()(q2, target)
+
+         self.q1_opt.zero_grad()
+         q1_loss.backward()
+         self.q1_opt.step()
+         self.q2_opt.zero_grad()
+         q2_loss.backward()
+         self.q2_opt.step()
+
+         # Policy/Actor objective
+         logits = self.policy(states)
+         dist = Categorical(logits=logits)
+         probs = dist.probs
+         log_probs = dist.logits - T.logsumexp(dist.logits, dim=-1, keepdim=True)
+         q1_policy = self.q1(states)
+         q2_policy = self.q2(states)
+         min_q_policy = T.min(q1_policy, q2_policy)
+         # Discrete-action policy loss: the expectation over actions is computed exactly from the
+         # categorical probabilities instead of with the reparameterisation trick used in continuous SAC
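+         #   J_pi = E_s[ sum_a pi(a|s) * ( alpha * log pi(a|s) - min(Q1(s,a), Q2(s,a)) ) ]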
+         policy_loss = (probs * (self.alpha * log_probs - min_q_policy)).sum(dim=-1).mean()
+
+         # Temperature (alpha) update: adjust the entropy coefficient toward the target entropy
+         alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()
+         self.alpha_opt.zero_grad()
+         alpha_loss.backward()
+         self.alpha_opt.step()
+         self.alpha = self.log_alpha.exp().item()
+
+         self.policy_opt.zero_grad()
+         policy_loss.backward()
+         self.policy_opt.step()
+
+         # Soft (Polyak) target network update
+         for target_param, param in zip(self.q1_target.parameters(), self.q1.parameters()):
+             target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+         for target_param, param in zip(self.q2_target.parameters(), self.q2.parameters()):
+             target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+         return policy_loss.item()
+
+     def update_reward_gradient_clipping(self):
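+         # Variant of vanilla_sac_update: identical losses, but with gradient-norm clipping on the
+         # temperature and policy updates (an optional reward-normalisation block is left commented out below)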
+         if len(self.memory.states) < self.batch_size:
+             return 0.0
+
+         # Mini-batch sampling
+         idxs = np.random.choice(len(self.memory.states), self.batch_size, replace=False)
+         states = T.as_tensor(np.array([self.memory.states[i] for i in idxs]), dtype=T.float32, device=self.device)
+         actions = T.as_tensor(np.array([self.memory.actions[i] for i in idxs]), dtype=T.int64, device=self.device).unsqueeze(-1)
+         rewards = T.as_tensor(np.array([self.memory.rewards[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
+         dones = T.as_tensor(np.array([self.memory.dones[i] for i in idxs]), dtype=T.float32, device=self.device).unsqueeze(-1)
+         next_states = T.as_tensor(np.array([self.memory.next_states[i] for i in idxs]), dtype=T.float32, device=self.device)
+
+         """
+         # Min-max normalization and tanh scaling to [-1, 1]
+         rewards_np = np.array([self.memory.rewards[i] for i in idxs])
+         r_min = rewards_np.min()
+         r_max = rewards_np.max()
+         # Avoid division by zero
+         r_scaled = 2 * (rewards_np - r_min) / (r_max - r_min + 1e-8) - 1
+         normalized_rewards = np.tanh(r_scaled)
+         rewards = T.as_tensor(normalized_rewards, dtype=T.float32, device=self.device).unsqueeze(-1)
+         """
+
+         # Critic update (soft Q-learning objective): the entropy term encourages high-entropy, exploratory behaviour
+         with T.no_grad():
+             next_logits = self.policy(next_states)
+             next_dist = Categorical(logits=next_logits)
+             next_probs = next_dist.probs
+             next_log_probs = next_dist.logits - T.logsumexp(next_dist.logits, dim=-1, keepdim=True)
+             q1_next = self.q1_target(next_states)
+             q2_next = self.q2_target(next_states)
+             # Soft Policy Evaluation
+             min_q_next = T.min(q1_next, q2_next)
+             next_value = (next_probs * (min_q_next - self.alpha * next_log_probs)).sum(dim=-1, keepdim=True)
+             target = rewards + self.gamma * (1 - dones) * next_value
+
+         q1 = self.q1(states).gather(1, actions)
+         q2 = self.q2(states).gather(1, actions)
+
+         # Losses of both Q-functions
+         q1_loss = nn.MSELoss()(q1, target)
+         q2_loss = nn.MSELoss()(q2, target)
+
+         self.q1_opt.zero_grad()
+         q1_loss.backward()
+         self.q1_opt.step()
+         self.q2_opt.zero_grad()
+         q2_loss.backward()
+         self.q2_opt.step()
+
+         # Policy/Actor objective
+         logits = self.policy(states)
+         dist = Categorical(logits=logits)
+         probs = dist.probs
+         log_probs = dist.logits - T.logsumexp(dist.logits, dim=-1, keepdim=True)
+         q1_policy = self.q1(states)
+         q2_policy = self.q2(states)
+         min_q_policy = T.min(q1_policy, q2_policy)
+         # Discrete-action policy loss: exact expectation over the categorical action probabilities
+         policy_loss = (probs * (self.alpha * log_probs - min_q_policy)).sum(dim=-1).mean()
+
+         # Temperature (alpha) update with gradient clipping
+         alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()
+         self.alpha_opt.zero_grad()
+         alpha_loss.backward()
+         T.nn.utils.clip_grad_norm_([self.log_alpha], max_norm=1.0)  # Gradient clipping
+         self.alpha_opt.step()
+         self.alpha = self.log_alpha.exp().item()
+
+         self.policy_opt.zero_grad()
+         policy_loss.backward()
+         T.nn.utils.clip_grad_norm_(self.policy.parameters(), max_norm=1.0)  # Gradient clipping on the policy parameters
+         self.policy_opt.step()
+
+         # Soft (Polyak) target network update
+         for target_param, param in zip(self.q1_target.parameters(), self.q1.parameters()):
+             target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+         for target_param, param in zip(self.q2_target.parameters(), self.q2.parameters()):
+             target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
+
+         return policy_loss.item()
+
+ # Actor/Policy network
+ # The typical SAC actor outputs a Gaussian distribution over continuous actions given a state.
+ # Here we adapt it to discrete actions using a Categorical distribution, since the Atari action space is discrete.
+ # The policy outputs logits for each discrete action.
+
+ # From: https://ch.mathworks.com/help/reinforcement-learning/ug/soft-actor-critic-agents.html
+ # The actor takes the current observation and generates a categorical distribution, in which each possible action is associated with a probability.
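+ # During training, Agent.choose_action() samples an action from this categorical distribution;
+ # with eval_mode=True it acts greedily by taking the argmax over the logits.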
+
+ class CategoricalActor(nn.Module):
+     def __init__(self, obs_shape: tuple, action_dim: int, hidden: int):
+         super().__init__()
+         c, h, w = obs_shape
+         self.cnn = nn.Sequential(
+             nn.Conv2d(c, 16, kernel_size=8, stride=4),
+             nn.ReLU(),
+             nn.Conv2d(16, 32, kernel_size=4, stride=2),
+             nn.ReLU(),
+             nn.Flatten()
+         )
+         with T.no_grad():
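+             # Dummy forward pass to infer the flattened feature size; for the 84x84 grayscale
+             # frames produced by preprocess() in sac_model_cnn.py this is 32 * 9 * 9 = 2592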
+             cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
+         self.fc = nn.Sequential(
+             nn.Linear(cnn_output_dim, hidden),
+             nn.ReLU(),
+             nn.Linear(hidden, action_dim)
+         )
+
+     def forward(self, state: T.Tensor):
+         if state.dim() == 3:
+             state = state.unsqueeze(0)
+         cnn_out = self.cnn(state)
+         logits = self.fc(cnn_out)
+         return logits
+
+ # Q-network for discrete actions
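+ # It returns one Q-value per action in a single forward pass, i.e. Q(s, .) with shape (batch, action_dim),
+ # which is what the .gather(1, actions) calls and the expectations over actions above rely on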
+ class QNetwork(nn.Module):
+     def __init__(self, obs_shape: tuple, action_dim: int, hidden: int):
+         super().__init__()
+         c, h, w = obs_shape
+         self.cnn = nn.Sequential(
+             nn.Conv2d(c, 16, kernel_size=8, stride=4),
+             nn.ReLU(),
+             nn.Conv2d(16, 32, kernel_size=4, stride=2),
+             nn.ReLU(),
+             nn.Flatten()
+         )
+         with T.no_grad():
+             cnn_output_dim = self.cnn(T.zeros(1, c, h, w)).shape[1]
+         self.net = nn.Sequential(
+             nn.Linear(cnn_output_dim, hidden),
+             nn.ReLU(),
+             nn.Linear(hidden, action_dim)
+         )
+
+     def forward(self, state: T.Tensor):
+         if state.dim() == 3:
+             state = state.unsqueeze(0)
+         cnn_out = self.cnn(state)
+         return self.net(cnn_out)
+
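+ # Simple list-based replay buffer. Note that it is unbounded: transitions are appended for the whole
+ # run and clear() is never called by the training loop, so memory usage grows with the number of steps.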
+ class Memory:
+     def __init__(self):
+         self.states, self.actions, self.rewards, self.dones, self.next_states = [], [], [], [], []
+     def store(self, s, a, r, d, ns):
+         self.states.append(np.asarray(s, dtype=np.float32))
+         self.actions.append(np.asarray(a, dtype=np.int64))  # discrete action indices
+         self.rewards.append(float(r))
+         self.dones.append(float(d))
+         self.next_states.append(np.asarray(ns, dtype=np.float32))
+     def clear(self):
+         self.states, self.actions, self.rewards, self.dones, self.next_states = [], [], [], [], []
SAC-2/sac-project/sac_model_cnn.py ADDED
@@ -0,0 +1,206 @@
+ import ale_py
+ import gymnasium as gym
+ import sys
+ import numpy as np
+ from sac_helpers_cnn import *
+ from gymnasium.spaces import Box
+ import cv2
+ import matplotlib.pyplot as plt
+
+ def preprocess(obs):
+     obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
+     obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
+     obs = np.expand_dims(obs, axis=0)
+     return obs.astype(np.float32) / 255.0
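+ # preprocess() turns a raw RGB frame into a single grayscale channel of shape (1, 84, 84), scaled to
+ # [0, 1], matching the (c, h, w) layout expected by the CNN networks in sac_helpers_cnn.py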
+ """
+ def main() -> int:
+     episode = 0
+     total_return = 0
+     ep_return = 0
+     steps = 100
+     batches = 100
+     avg_returns = []
+     avg_losses = []
+
+     env = gym.make("ALE/Pacman-v5")
+     # Initialize CNN with a dummy observation (to get correct input shape)
+     obs, _ = env.reset()
+     dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape)
+
+     agent = Agent(
+         obs_space=dummy_obs_space,
+         action_space=env.action_space,
+         hidden=64,
+         gamma=0.99,
+         lr=3e-4,
+         alpha=0.2,
+         seed=70,
+         batch_size=32,
+         tau=0.005
+     )
+
+     try:
+         obs, info = env.reset(seed=42)
+         state = preprocess(obs)
+
+         for update in range(1, batches + 1):
+             batch_loss = []
+
+             for t in range(steps):
+                 action = agent.choose_action(state)
+                 next_obs, reward, terminated, truncated, info = env.step(action)
+                 done = terminated or truncated
+                 next_state = preprocess(next_obs)
+
+                 agent.remember(state, action, reward, done, next_state)
+
+                 ep_return += reward
+                 state = next_state
+
+                 if done:
+                     episode += 1
+                     total_return += ep_return
+                     print(f"Episode {episode} return: {ep_return:.2f}")
+                     ep_return = 0
+                     obs, info = env.reset()
+                     state = preprocess(obs)
+
+                 loss = agent.vanilla_sac_update()
+                 batch_loss.append(loss)
+
+             avg_ret = (total_return / episode) if episode else 0
+             avg_returns.append(avg_ret)
+             avg_losses.append(np.mean(batch_loss))
+
+             print(f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}, avg_loss={np.mean(batch_loss):.4f}")
+
+     except Exception as e:
+         print(f"Error: {e}", file=sys.stderr)
+         return 1
+     finally:
+         avg = total_return / episode if episode else 0
+         print(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
+
+         # Plot learning curve
+         plt.figure(figsize=(10, 4))
+         plt.subplot(1, 2, 1)
+         plt.plot(avg_returns)
+         plt.xlabel("Update")
+         plt.ylabel("Average Return")
+         plt.title("SAC Learning Curve")
+         plt.grid()
+
+         plt.subplot(1, 2, 2)
+         plt.plot(avg_losses)
+         plt.xlabel("Update")
+         plt.ylabel("Average Loss")
+         plt.title("Average Loss Curve")
+         plt.grid()
+
+         plt.tight_layout()
+         plt.show()
+         env.close()
+
+     return 0
+ """
+
+ def run_training(seed=42):
+     episode = 0
+     total_return = 0
+     ep_return = 0
+     steps = 100
+     batches = 100
+     avg_returns = []
+     avg_losses = []
+
+     env = gym.make("ALE/Pacman-v5")
+     obs, _ = env.reset()
+     dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape)
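+     # dummy_obs_space only carries the preprocessed frame shape (1, 84, 84); the Agent reads obs_space.shape
+     # to size its CNNs, so the Box bounds themselves are never used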
+
+     agent = Agent(
+         obs_space=dummy_obs_space,
+         action_space=env.action_space,
+         hidden=64,
+         gamma=0.99,
+         lr=3e-4,
+         alpha=0.2,
+         seed=seed,
+         batch_size=32,
+         tau=0.005
+     )
+
+     obs, info = env.reset(seed=seed)
+     state = preprocess(obs)
+
+     for update in range(1, batches + 1):
+         batch_loss = []
+         for t in range(steps):
+             action = agent.choose_action(state)
+             next_obs, reward, terminated, truncated, info = env.step(action)
+             done = terminated or truncated
+             next_state = preprocess(next_obs)
+
+             agent.remember(state, action, reward, done, next_state)
+
+             ep_return += reward
+             state = next_state
+
+             if done:
+                 episode += 1
+                 total_return += ep_return
+                 ep_return = 0
+                 obs, info = env.reset()
+                 state = preprocess(obs)
+
+             #loss = agent.vanilla_sac_update()
+             loss = agent.update_reward_gradient_clipping()
+             batch_loss.append(loss)
+
+         avg_ret = (total_return / episode) if episode else 0
+         avg_returns.append(avg_ret)
+         avg_losses.append(np.mean(batch_loss))
+
+     env.close()
+     return avg_returns, avg_losses
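+ # Each call to run_training() interacts with the environment for batches * steps = 100 * 100 = 10,000 steps,
+ # attempting one SAC update per step once the replay buffer holds at least batch_size transitions.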
+
+ def main() -> int:
+     num_runs = 5
+     all_returns = []
+     all_losses = []
+
+     for run in range(num_runs):
+         print(f"Starting run {run+1}/{num_runs}")
+         avg_returns, avg_losses = run_training(seed=42 + run)
+         all_returns.append(avg_returns)
+         all_losses.append(avg_losses)
+
+     # Convert to numpy arrays for easy averaging
+     all_returns = np.array(all_returns)
+     all_losses = np.array(all_losses)
+
+     mean_returns = np.mean(all_returns, axis=0)
+     mean_losses = np.mean(all_losses, axis=0)
+
+     # Plot averaged learning curves
+     plt.figure(figsize=(10, 4))
+     plt.subplot(1, 2, 1)
+     plt.plot(mean_returns)
+     plt.xlabel("Update")
+     plt.ylabel("Average Return")
+     plt.title(f"SAC Learning Curve (avg over {num_runs} runs)")
+     plt.grid()
+
+     plt.subplot(1, 2, 2)
+     plt.plot(mean_losses)
+     plt.xlabel("Update")
+     plt.ylabel("Average Loss")
+     plt.title(f"Average Loss Curve (avg over {num_runs} runs)")
+     plt.grid()
+
+     plt.tight_layout()
+     plt.show()
+
+     return 0
+
+ if __name__ == "__main__":
+     raise SystemExit(main())