OmidSakaki committed on
Commit
144fc70
·
verified ·
1 Parent(s): bdb3f57

Update src/agents/advanced_agent.py

Browse files
Files changed (1) hide show
  1. src/agents/advanced_agent.py +165 -113
src/agents/advanced_agent.py CHANGED
@@ -4,11 +4,13 @@ import torch.optim as optim
4
  import numpy as np
5
  from collections import deque
6
  import random
 
 
7
 
8
  class EnhancedTradingNetwork(nn.Module):
9
  def __init__(self, state_dim, action_dim, sentiment_dim=2):
10
  super(EnhancedTradingNetwork, self).__init__()
11
-
12
  # Visual processing branch
13
  self.visual_conv = nn.Sequential(
14
  nn.Conv2d(4, 16, kernel_size=4, stride=2),
@@ -19,16 +21,16 @@ class EnhancedTradingNetwork(nn.Module):
19
  nn.ReLU(),
20
  nn.AdaptiveAvgPool2d((8, 8))
21
  )
22
-
23
  # Calculate the output size after conv layers
24
  self.conv_output_size = 32 * 8 * 8
25
-
26
  self.visual_fc = nn.Sequential(
27
  nn.Linear(self.conv_output_size, 256),
28
  nn.ReLU(),
29
  nn.Dropout(0.3)
30
  )
31
-
32
  # Sentiment processing branch
33
  self.sentiment_fc = nn.Sequential(
34
  nn.Linear(sentiment_dim, 64),
@@ -37,7 +39,7 @@ class EnhancedTradingNetwork(nn.Module):
37
  nn.Linear(64, 32),
38
  nn.ReLU()
39
  )
40
-
41
  # Combined decision making
42
  self.combined_fc = nn.Sequential(
43
  nn.Linear(256 + 32, 128),
@@ -47,179 +49,232 @@ class EnhancedTradingNetwork(nn.Module):
47
  nn.ReLU(),
48
  nn.Linear(64, action_dim)
49
  )
50
-
 
 
 
51
  def forward(self, x, sentiment=None):
52
  try:
53
- # Visual processing with proper reshaping
54
- # x shape: (batch_size, 84, 84, 4) -> (batch_size, 4, 84, 84)
55
- if len(x.shape) == 4: # (batch, H, W, C)
56
- x = x.permute(0, 3, 1, 2).contiguous()
 
57
  else:
58
- # Handle single sample case
59
- x = x.unsqueeze(0) if len(x.shape) == 3 else x
60
- x = x.permute(0, 3, 1, 2).contiguous()
61
 
62
- visual_features = self.visual_conv(x)
 
 
 
 
 
63
 
64
- # Use reshape instead of view for safety
65
  batch_size = visual_features.size(0)
66
  visual_features = visual_features.reshape(batch_size, -1)
67
-
68
  visual_features = self.visual_fc(visual_features)
69
-
70
  # Sentiment processing
71
- if sentiment is not None:
72
  if len(sentiment.shape) == 1:
73
  sentiment = sentiment.unsqueeze(0)
 
74
  sentiment_features = self.sentiment_fc(sentiment)
75
  combined_features = torch.cat([visual_features, sentiment_features], dim=1)
76
  else:
77
- combined_features = visual_features
78
-
79
- # Final decision
 
80
  q_values = self.combined_fc(combined_features)
81
  return q_values
82
-
83
  except Exception as e:
84
  print(f"Error in network forward: {e}")
85
- # Return safe default
86
- return torch.zeros((x.size(0) if hasattr(x, 'size') else 1, self.combined_fc[-1].out_features))
 
 
87
 
88
  class AdvancedTradingAgent:
89
  def __init__(self, state_dim, action_dim, learning_rate=0.001, use_sentiment=True):
90
- self.state_dim = state_dim
91
  self.action_dim = action_dim
92
  self.learning_rate = learning_rate
93
  self.use_sentiment = use_sentiment
94
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
95
  print(f"Using device: {self.device}")
96
-
97
  # Neural network
98
  self.policy_net = EnhancedTradingNetwork(state_dim, action_dim).to(self.device)
99
  self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
100
-
 
101
  # Experience replay
102
- self.memory = deque(maxlen=500)
103
- self.batch_size = 16
104
-
105
  # Training parameters
106
  self.gamma = 0.99
107
  self.epsilon = 1.0
108
- self.epsilon_min = 0.1
109
- self.epsilon_decay = 0.995
110
  self.steps_done = 0
111
-
112
- def select_action(self, state, current_sentiment=0.5, sentiment_confidence=0.0):
113
- """Select action with sentiment consideration"""
 
 
114
  if random.random() < self.epsilon:
115
  return random.randint(0, self.action_dim - 1)
116
-
117
  try:
118
- # Normalize state
119
- state_normalized = state.astype(np.float32) / 255.0
120
- state_tensor = torch.FloatTensor(state_normalized).to(self.device)
121
 
122
- if self.use_sentiment:
123
- # Add sentiment to the decision process
124
- sentiment_tensor = torch.FloatTensor([current_sentiment, sentiment_confidence]).to(self.device)
 
 
 
 
 
 
 
 
 
 
125
  with torch.no_grad():
126
  q_values = self.policy_net(state_tensor, sentiment_tensor)
127
  else:
128
  with torch.no_grad():
129
  q_values = self.policy_net(state_tensor)
130
-
131
- return int(q_values.argmax().item())
132
-
 
133
  except Exception as e:
134
- print(f"Error in advanced action selection: {e}")
135
  return random.randint(0, self.action_dim - 1)
136
-
137
  def store_transition(self, state, action, reward, next_state, done, sentiment_data=None):
138
- """Store experience with sentiment data"""
139
  try:
140
- experience = (state, action, reward, next_state, done, sentiment_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  self.memory.append(experience)
 
142
  except Exception as e:
143
  print(f"Error storing transition: {e}")
144
-
145
  def update(self):
146
- """Update network with sentiment-enhanced learning"""
147
  if len(self.memory) < self.batch_size:
148
  return 0.0
149
-
150
  try:
151
- # Sample batch from memory
152
  batch = random.sample(self.memory, self.batch_size)
153
- states, actions, rewards, next_states, dones, sentiment_data = zip(*batch)
154
-
155
- # Convert to tensors with proper shape handling
156
- states_array = np.array(states, dtype=np.float32) / 255.0
157
- next_states_array = np.array(next_states, dtype=np.float32) / 255.0
158
-
159
- # Ensure proper tensor shapes
160
- states_tensor = torch.FloatTensor(states_array).to(self.device)
161
- next_states_tensor = torch.FloatTensor(next_states_array).to(self.device)
162
 
 
 
163
  actions_tensor = torch.LongTensor(actions).to(self.device)
164
  rewards_tensor = torch.FloatTensor(rewards).to(self.device)
165
  dones_tensor = torch.BoolTensor(dones).to(self.device)
166
-
167
- if self.use_sentiment and sentiment_data[0] is not None:
168
- # Extract sentiment features safely
169
- sentiment_features = []
170
- for data in sentiment_data:
171
- if data and 'sentiment' in data and 'confidence' in data:
172
- sentiment_features.append([data['sentiment'], data['confidence']])
173
- else:
174
- sentiment_features.append([0.5, 0.0])
175
-
176
- sentiment_tensor = torch.FloatTensor(sentiment_features).to(self.device)
177
-
178
- # Current Q values with sentiment
179
  current_q = self.policy_net(states_tensor, sentiment_tensor)
180
- current_q = current_q.gather(1, actions_tensor.unsqueeze(1))
181
-
182
- # Next Q values with sentiment
183
- with torch.no_grad():
184
- next_q = self.policy_net(next_states_tensor, sentiment_tensor)
185
- next_q = next_q.max(1)[0]
186
- target_q = rewards_tensor + (self.gamma * next_q * ~dones_tensor)
187
  else:
188
- # Fallback to standard DQN without sentiment
189
  current_q = self.policy_net(states_tensor)
190
- current_q = current_q.gather(1, actions_tensor.unsqueeze(1))
191
-
192
- with torch.no_grad():
193
- next_q = self.policy_net(next_states_tensor)
194
- next_q = next_q.max(1)[0]
195
- target_q = rewards_tensor + (self.gamma * next_q * ~dones_tensor)
196
-
197
- # Compute loss
198
- loss = nn.MSELoss()(current_q.squeeze(), target_q)
199
 
200
- # Optimize
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  self.optimizer.zero_grad()
202
  loss.backward()
203
-
204
- # Gradient clipping for stability
205
  torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
206
  self.optimizer.step()
207
-
208
- # Update exploration
209
- self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
 
 
210
  self.steps_done += 1
 
211
 
 
 
 
 
212
  return float(loss.item())
213
-
214
  except Exception as e:
215
- print(f"Error in advanced update: {e}")
 
 
216
  return 0.0
 
 
 
 
217
 
218
- # Fallback to simple agent if advanced one fails
219
  class SimpleTradingNetwork(nn.Module):
220
  def __init__(self, state_dim, action_dim):
221
  super(SimpleTradingNetwork, self).__init__()
222
-
 
223
  self.conv_layers = nn.Sequential(
224
  nn.Conv2d(4, 16, kernel_size=4, stride=2),
225
  nn.ReLU(),
@@ -229,7 +284,7 @@ class SimpleTradingNetwork(nn.Module):
229
  nn.ReLU(),
230
  nn.AdaptiveAvgPool2d((8, 8))
231
  )
232
-
233
  self.fc_layers = nn.Sequential(
234
  nn.Linear(32 * 8 * 8, 128),
235
  nn.ReLU(),
@@ -238,21 +293,18 @@ class SimpleTradingNetwork(nn.Module):
238
  nn.ReLU(),
239
  nn.Linear(64, action_dim)
240
  )
241
-
242
  def forward(self, x):
243
  try:
244
- # Handle input shape
245
- if len(x.shape) == 4: # (batch, H, W, C)
246
- x = x.permute(0, 3, 1, 2).contiguous()
247
- else:
248
- x = x.unsqueeze(0) if len(x.shape) == 3 else x
249
- x = x.permute(0, 3, 1, 2).contiguous()
250
 
251
  x = self.conv_layers(x)
252
- batch_size = x.size(0)
253
- x = x.reshape(batch_size, -1)
254
  x = self.fc_layers(x)
255
  return x
256
  except Exception as e:
257
  print(f"Error in simple network: {e}")
258
- return torch.zeros((x.size(0), self.fc_layers[-1].out_features))
 
 
4
  import numpy as np
5
  from collections import deque
6
  import random
7
+ import warnings
8
+ warnings.filterwarnings('ignore')
9
 
10
  class EnhancedTradingNetwork(nn.Module):
11
  def __init__(self, state_dim, action_dim, sentiment_dim=2):
12
  super(EnhancedTradingNetwork, self).__init__()
13
+
14
  # Visual processing branch
15
  self.visual_conv = nn.Sequential(
16
  nn.Conv2d(4, 16, kernel_size=4, stride=2),
 
21
  nn.ReLU(),
22
  nn.AdaptiveAvgPool2d((8, 8))
23
  )
24
+
25
  # Calculate the output size after conv layers
26
  self.conv_output_size = 32 * 8 * 8
27
+
28
  self.visual_fc = nn.Sequential(
29
  nn.Linear(self.conv_output_size, 256),
30
  nn.ReLU(),
31
  nn.Dropout(0.3)
32
  )
33
+
34
  # Sentiment processing branch
35
  self.sentiment_fc = nn.Sequential(
36
  nn.Linear(sentiment_dim, 64),
 
39
  nn.Linear(64, 32),
40
  nn.ReLU()
41
  )
42
+
43
  # Combined decision making
44
  self.combined_fc = nn.Sequential(
45
  nn.Linear(256 + 32, 128),
 
49
  nn.ReLU(),
50
  nn.Linear(64, action_dim)
51
  )
52
+
53
+ # Store action_dim for error handling
54
+ self.action_dim = action_dim
55
+
56
  def forward(self, x, sentiment=None):
57
  try:
58
+ # Ensure input has batch dimension
59
+ if len(x.shape) == 3: # (H, W, C)
60
+ x = x.unsqueeze(0)
61
+ elif len(x.shape) == 4: # (batch, H, W, C)
62
+ pass
63
  else:
64
+ raise ValueError(f"Invalid input shape: {x.shape}")
 
 
65
 
66
+ # Permute to (batch, C, H, W)
67
+ x = x.permute(0, 3, 1, 2).contiguous().float()
68
+
69
+ # Check if channels match expected input
70
+ if x.size(1) != 4:
71
+ raise ValueError(f"Expected 4 channels, got {x.size(1)}")
72
 
73
+ visual_features = self.visual_conv(x)
74
  batch_size = visual_features.size(0)
75
  visual_features = visual_features.reshape(batch_size, -1)
 
76
  visual_features = self.visual_fc(visual_features)
77
+
78
  # Sentiment processing
79
+ if sentiment is not None and self.sentiment_fc is not None:
80
  if len(sentiment.shape) == 1:
81
  sentiment = sentiment.unsqueeze(0)
82
+ sentiment = sentiment.float()
83
  sentiment_features = self.sentiment_fc(sentiment)
84
  combined_features = torch.cat([visual_features, sentiment_features], dim=1)
85
  else:
86
+ # Pad with zeros if no sentiment
87
+ sentiment_features = torch.zeros(batch_size, 32, device=visual_features.device)
88
+ combined_features = torch.cat([visual_features, sentiment_features], dim=1)
89
+
90
  q_values = self.combined_fc(combined_features)
91
  return q_values
92
+
93
  except Exception as e:
94
  print(f"Error in network forward: {e}")
95
+ print(f"Input shape: {getattr(x, 'shape', 'Unknown')}")
96
+ # Return safe default with correct shape
97
+ batch_size = x.size(0) if hasattr(x, 'size') else 1
98
+ return torch.zeros(batch_size, self.action_dim, device=(x.device if hasattr(x, 'device') else 'cpu'))
99
 
100
  class AdvancedTradingAgent:
101
  def __init__(self, state_dim, action_dim, learning_rate=0.001, use_sentiment=True):
102
+ self.state_dim = state_dim # Should be (84, 84, 4) or similar
103
  self.action_dim = action_dim
104
  self.learning_rate = learning_rate
105
  self.use_sentiment = use_sentiment
106
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
107
  print(f"Using device: {self.device}")
108
+
109
  # Neural network
110
  self.policy_net = EnhancedTradingNetwork(state_dim, action_dim).to(self.device)
111
  self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
112
+ self.loss_fn = nn.MSELoss()
113
+
114
  # Experience replay
115
+ self.memory = deque(maxlen=10000) # Increased buffer size
116
+ self.batch_size = min(32, state_dim[0]//2) # Dynamic batch size
117
+
118
  # Training parameters
119
  self.gamma = 0.99
120
  self.epsilon = 1.0
121
+ self.epsilon_min = 0.01 # More aggressive exploration decay
122
+ self.epsilon_decay = 0.9995 # Slower decay
123
  self.steps_done = 0
124
+ self.target_update_freq = 100 # Target network update frequency
125
+ self.steps_since_target_update = 0
126
+
127
+ def select_action(self, state, current_sentiment=None, sentiment_confidence=None):
128
+ """Select action with epsilon-greedy policy"""
129
  if random.random() < self.epsilon:
130
  return random.randint(0, self.action_dim - 1)
131
+
132
  try:
133
+ # Validate and normalize state
134
+ if not isinstance(state, np.ndarray):
135
+ state = np.array(state)
136
 
137
+ if state.dtype != np.float32:
138
+ state = state.astype(np.float32)
139
+
140
+ # Normalize pixel values
141
+ if state.max() > 1.0:
142
+ state = state / 255.0
143
+
144
+ state_tensor = torch.FloatTensor(state).to(self.device)
145
+
146
+ # Prepare sentiment input
147
+ if self.use_sentiment and current_sentiment is not None:
148
+ sentiment = np.array([float(current_sentiment), float(sentiment_confidence or 0.0)])
149
+ sentiment_tensor = torch.FloatTensor(sentiment).to(self.device)
150
  with torch.no_grad():
151
  q_values = self.policy_net(state_tensor, sentiment_tensor)
152
  else:
153
  with torch.no_grad():
154
  q_values = self.policy_net(state_tensor)
155
+
156
+ action = int(q_values.argmax().item())
157
+ return action
158
+
159
  except Exception as e:
160
+ print(f"Error in action selection: {e}")
161
  return random.randint(0, self.action_dim - 1)
162
+
163
  def store_transition(self, state, action, reward, next_state, done, sentiment_data=None):
164
+ """Store experience tuple safely"""
165
  try:
166
+ # Ensure all inputs are numpy arrays
167
+ if not isinstance(state, np.ndarray):
168
+ state = np.array(state, dtype=np.float32)
169
+ if not isinstance(next_state, np.ndarray):
170
+ next_state = np.array(next_state, dtype=np.float32)
171
+
172
+ # Normalize before storing
173
+ if state.max() > 1.0:
174
+ state = state / 255.0
175
+ if next_state.max() > 1.0:
176
+ next_state = next_state / 255.0
177
+
178
+ # Handle sentiment data
179
+ if sentiment_data is None:
180
+ sentiment_data = {'sentiment': 0.5, 'confidence': 0.0}
181
+
182
+ experience = (state, action, float(reward), next_state, bool(done), sentiment_data)
183
  self.memory.append(experience)
184
+
185
  except Exception as e:
186
  print(f"Error storing transition: {e}")
187
+
188
  def update(self):
189
+ """DQN update with improved stability"""
190
  if len(self.memory) < self.batch_size:
191
  return 0.0
192
+
193
  try:
 
194
  batch = random.sample(self.memory, self.batch_size)
195
+ states, actions, rewards, next_states, dones, sentiments = zip(*batch)
196
+
197
+ # Convert to tensors
198
+ states = np.stack(states)
199
+ next_states = np.stack(next_states)
200
+ actions = np.array(actions)
201
+ rewards = np.array(rewards)
202
+ dones = np.array(dones)
 
203
 
204
+ states_tensor = torch.FloatTensor(states).to(self.device)
205
+ next_states_tensor = torch.FloatTensor(next_states).to(self.device)
206
  actions_tensor = torch.LongTensor(actions).to(self.device)
207
  rewards_tensor = torch.FloatTensor(rewards).to(self.device)
208
  dones_tensor = torch.BoolTensor(dones).to(self.device)
209
+
210
+ # Compute current Q values
211
+ if self.use_sentiment:
212
+ # Use sentiment from current state
213
+ sentiment_batch = []
214
+ for sentiment_data in sentiments:
215
+ sentiment = [sentiment_data.get('sentiment', 0.5),
216
+ sentiment_data.get('confidence', 0.0)]
217
+ sentiment_batch.append(sentiment)
218
+ sentiment_tensor = torch.FloatTensor(sentiment_batch).to(self.device)
 
 
 
219
  current_q = self.policy_net(states_tensor, sentiment_tensor)
 
 
 
 
 
 
 
220
  else:
 
221
  current_q = self.policy_net(states_tensor)
 
 
 
 
 
 
 
 
 
222
 
223
+ current_q = current_q.gather(1, actions_tensor.unsqueeze(1)).squeeze(1)
224
+
225
+ # Compute target Q values
226
+ with torch.no_grad():
227
+ if self.use_sentiment:
228
+ next_sentiment_batch = []
229
+ for sentiment_data in sentiments:
230
+ next_sentiment = [sentiment_data.get('sentiment', 0.5),
231
+ sentiment_data.get('confidence', 0.0)]
232
+ next_sentiment_batch.append(next_sentiment)
233
+ next_sentiment_tensor = torch.FloatTensor(next_sentiment_batch).to(self.device)
234
+ next_q = self.policy_net(next_states_tensor, next_sentiment_tensor)
235
+ else:
236
+ next_q = self.policy_net(next_states_tensor)
237
+
238
+ next_q_max = next_q.max(1)[0]
239
+ target_q = rewards_tensor + (self.gamma * next_q_max * ~dones_tensor)
240
+
241
+ # Compute loss and optimize
242
+ loss = self.loss_fn(current_q, target_q)
243
+
244
  self.optimizer.zero_grad()
245
  loss.backward()
 
 
246
  torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
247
  self.optimizer.step()
248
+
249
+ # Update epsilon
250
+ if self.epsilon > self.epsilon_min:
251
+ self.epsilon *= self.epsilon_decay
252
+
253
  self.steps_done += 1
254
+ self.steps_since_target_update += 1
255
 
256
+ # Update target network periodically (if implemented)
257
+ if self.steps_since_target_update % self.target_update_freq == 0:
258
+ self._update_target_network()
259
+
260
  return float(loss.item())
261
+
262
  except Exception as e:
263
+ print(f"Error in update: {e}")
264
+ import traceback
265
+ traceback.print_exc()
266
  return 0.0
267
+
268
+ def _update_target_network(self):
269
+ """Update target network (placeholder for double DQN)"""
270
+ pass # Implement target network update here
271
 
272
+ # Simple fallback network
273
  class SimpleTradingNetwork(nn.Module):
274
  def __init__(self, state_dim, action_dim):
275
  super(SimpleTradingNetwork, self).__init__()
276
+ self.action_dim = action_dim
277
+
278
  self.conv_layers = nn.Sequential(
279
  nn.Conv2d(4, 16, kernel_size=4, stride=2),
280
  nn.ReLU(),
 
284
  nn.ReLU(),
285
  nn.AdaptiveAvgPool2d((8, 8))
286
  )
287
+
288
  self.fc_layers = nn.Sequential(
289
  nn.Linear(32 * 8 * 8, 128),
290
  nn.ReLU(),
 
293
  nn.ReLU(),
294
  nn.Linear(64, action_dim)
295
  )
296
+
297
  def forward(self, x):
298
  try:
299
+ if len(x.shape) == 3:
300
+ x = x.unsqueeze(0)
301
+ x = x.permute(0, 3, 1, 2).contiguous().float()
 
 
 
302
 
303
  x = self.conv_layers(x)
304
+ x = x.reshape(x.size(0), -1)
 
305
  x = self.fc_layers(x)
306
  return x
307
  except Exception as e:
308
  print(f"Error in simple network: {e}")
309
+ batch_size = x.size(0) if hasattr(x, 'size') else 1
310
+ return torch.zeros(batch_size, self.action_dim, device=(x.device if hasattr(x, 'device') else 'cpu'))