OmidSakaki committed on
Commit
6097bc7
·
verified ·
1 Parent(s): 52bdffe

Update src/agents/advanced_agent.py

Browse files
Files changed (1) hide show
  1. src/agents/advanced_agent.py +157 -67
src/agents/advanced_agent.py CHANGED
@@ -4,33 +4,124 @@ import torch.optim as optim
4
  import numpy as np
5
  from collections import deque
6
  import random
7
- from .visual_agent import VisualTradingAgent, SimpleTradingNetwork
8
 
9
- class AdvancedTradingAgent(VisualTradingAgent):
10
- def __init__(self, state_dim, action_dim, learning_rate=0.001, use_sentiment=True):
11
- super().__init__(state_dim, action_dim, learning_rate)
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  self.use_sentiment = use_sentiment
14
- self.sentiment_history = deque(maxlen=50)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # Enhanced network architecture for sentiment analysis
17
- if use_sentiment:
18
- self.policy_net = EnhancedTradingNetwork(state_dim, action_dim)
19
- self.policy_net = self.policy_net.to(self.device)
20
- self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
21
-
22
  def select_action(self, state, current_sentiment=0.5, sentiment_confidence=0.0):
23
  """Select action with sentiment consideration"""
24
  if random.random() < self.epsilon:
25
  return random.randint(0, self.action_dim - 1)
26
 
27
  try:
 
28
  state_normalized = state.astype(np.float32) / 255.0
29
- state_tensor = torch.FloatTensor(state_normalized).unsqueeze(0).to(self.device)
30
 
31
  if self.use_sentiment:
32
  # Add sentiment to the decision process
33
- sentiment_tensor = torch.FloatTensor([current_sentiment, sentiment_confidence]).unsqueeze(0).to(self.device)
34
  with torch.no_grad():
35
  q_values = self.policy_net(state_tensor, sentiment_tensor)
36
  else:
@@ -45,8 +136,11 @@ class AdvancedTradingAgent(VisualTradingAgent):
45
 
46
  def store_transition(self, state, action, reward, next_state, done, sentiment_data=None):
47
  """Store experience with sentiment data"""
48
- experience = (state, action, reward, next_state, done, sentiment_data)
49
- self.memory.append(experience)
 
 
 
50
 
51
  def update(self):
52
  """Update network with sentiment-enhanced learning"""
@@ -54,41 +148,50 @@ class AdvancedTradingAgent(VisualTradingAgent):
54
  return 0.0
55
 
56
  try:
 
57
  batch = random.sample(self.memory, self.batch_size)
58
  states, actions, rewards, next_states, dones, sentiment_data = zip(*batch)
59
 
60
- # Convert to tensors
61
- states_tensor = torch.FloatTensor(np.array(states)).to(self.device) / 255.0
 
 
 
 
 
 
62
  actions_tensor = torch.LongTensor(actions).to(self.device)
63
  rewards_tensor = torch.FloatTensor(rewards).to(self.device)
64
- next_states_tensor = torch.FloatTensor(np.array(next_states)).to(self.device) / 255.0
65
  dones_tensor = torch.BoolTensor(dones).to(self.device)
66
 
67
  if self.use_sentiment and sentiment_data[0] is not None:
68
- # Extract sentiment features
69
  sentiment_features = []
70
  for data in sentiment_data:
71
- if data:
72
- sentiment_features.append([data.get('sentiment', 0.5), data.get('confidence', 0.0)])
73
  else:
74
  sentiment_features.append([0.5, 0.0])
75
 
76
  sentiment_tensor = torch.FloatTensor(sentiment_features).to(self.device)
77
- next_sentiment_tensor = sentiment_tensor # Simplified
78
 
79
  # Current Q values with sentiment
80
- current_q = self.policy_net(states_tensor, sentiment_tensor).gather(1, actions_tensor.unsqueeze(1))
 
81
 
82
  # Next Q values with sentiment
83
  with torch.no_grad():
84
- next_q = self.policy_net(next_states_tensor, next_sentiment_tensor).max(1)[0]
 
85
  target_q = rewards_tensor + (self.gamma * next_q * ~dones_tensor)
86
  else:
87
- # Fallback to standard DQN
88
- current_q = self.policy_net(states_tensor).gather(1, actions_tensor.unsqueeze(1))
 
89
 
90
  with torch.no_grad():
91
- next_q = self.policy_net(next_states_tensor).max(1)[0]
 
92
  target_q = rewards_tensor + (self.gamma * next_q * ~dones_tensor)
93
 
94
  # Compute loss
@@ -97,11 +200,14 @@ class AdvancedTradingAgent(VisualTradingAgent):
97
  # Optimize
98
  self.optimizer.zero_grad()
99
  loss.backward()
 
 
100
  torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
101
  self.optimizer.step()
102
 
103
  # Update exploration
104
  self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
 
105
 
106
  return float(loss.item())
107
 
@@ -109,12 +215,12 @@ class AdvancedTradingAgent(VisualTradingAgent):
109
  print(f"Error in advanced update: {e}")
110
  return 0.0
111
 
112
- class EnhancedTradingNetwork(nn.Module):
113
- def __init__(self, state_dim, action_dim, sentiment_dim=2):
114
- super(EnhancedTradingNetwork, self).__init__()
 
115
 
116
- # Visual processing branch (same as before)
117
- self.visual_conv = nn.Sequential(
118
  nn.Conv2d(4, 16, kernel_size=4, stride=2),
119
  nn.ReLU(),
120
  nn.Conv2d(16, 32, kernel_size=4, stride=2),
@@ -124,24 +230,8 @@ class EnhancedTradingNetwork(nn.Module):
124
  nn.AdaptiveAvgPool2d((8, 8))
125
  )
126
 
127
- self.visual_fc = nn.Sequential(
128
- nn.Linear(32 * 8 * 8, 256),
129
- nn.ReLU(),
130
- nn.Dropout(0.3)
131
- )
132
-
133
- # Sentiment processing branch
134
- self.sentiment_fc = nn.Sequential(
135
- nn.Linear(sentiment_dim, 64),
136
- nn.ReLU(),
137
- nn.Dropout(0.2),
138
- nn.Linear(64, 32),
139
- nn.ReLU()
140
- )
141
-
142
- # Combined decision making
143
- self.combined_fc = nn.Sequential(
144
- nn.Linear(256 + 32, 128),
145
  nn.ReLU(),
146
  nn.Dropout(0.2),
147
  nn.Linear(128, 64),
@@ -149,20 +239,20 @@ class EnhancedTradingNetwork(nn.Module):
149
  nn.Linear(64, action_dim)
150
  )
151
 
152
- def forward(self, x, sentiment=None):
153
- # Visual processing
154
- x = x.permute(0, 3, 1, 2) # (batch, 84, 84, 4) -> (batch, 4, 84, 84)
155
- visual_features = self.visual_conv(x)
156
- visual_features = visual_features.view(visual_features.size(0), -1)
157
- visual_features = self.visual_fc(visual_features)
158
-
159
- # Sentiment processing
160
- if sentiment is not None:
161
- sentiment_features = self.sentiment_fc(sentiment)
162
- combined_features = torch.cat([visual_features, sentiment_features], dim=1)
163
- else:
164
- combined_features = visual_features
165
-
166
- # Final decision
167
- q_values = self.combined_fc(combined_features)
168
- return q_values
 
4
  import numpy as np
5
  from collections import deque
6
  import random
 
7
 
8
+ class EnhancedTradingNetwork(nn.Module):
9
+ def __init__(self, state_dim, action_dim, sentiment_dim=2):
10
+ super(EnhancedTradingNetwork, self).__init__()
11
 
12
+ # Visual processing branch
13
+ self.visual_conv = nn.Sequential(
14
+ nn.Conv2d(4, 16, kernel_size=4, stride=2),
15
+ nn.ReLU(),
16
+ nn.Conv2d(16, 32, kernel_size=4, stride=2),
17
+ nn.ReLU(),
18
+ nn.Conv2d(32, 32, kernel_size=3, stride=1),
19
+ nn.ReLU(),
20
+ nn.AdaptiveAvgPool2d((8, 8))
21
+ )
22
+
23
+ # Calculate the output size after conv layers
24
+ self.conv_output_size = 32 * 8 * 8
25
+
26
+ self.visual_fc = nn.Sequential(
27
+ nn.Linear(self.conv_output_size, 256),
28
+ nn.ReLU(),
29
+ nn.Dropout(0.3)
30
+ )
31
+
32
+ # Sentiment processing branch
33
+ self.sentiment_fc = nn.Sequential(
34
+ nn.Linear(sentiment_dim, 64),
35
+ nn.ReLU(),
36
+ nn.Dropout(0.2),
37
+ nn.Linear(64, 32),
38
+ nn.ReLU()
39
+ )
40
+
41
+ # Combined decision making
42
+ self.combined_fc = nn.Sequential(
43
+ nn.Linear(256 + 32, 128),
44
+ nn.ReLU(),
45
+ nn.Dropout(0.2),
46
+ nn.Linear(128, 64),
47
+ nn.ReLU(),
48
+ nn.Linear(64, action_dim)
49
+ )
50
+
51
+ def forward(self, x, sentiment=None):
52
+ try:
53
+ # Visual processing with proper reshaping
54
+ # x shape: (batch_size, 84, 84, 4) -> (batch_size, 4, 84, 84)
55
+ if len(x.shape) == 4: # (batch, H, W, C)
56
+ x = x.permute(0, 3, 1, 2).contiguous()
57
+ else:
58
+ # Handle single sample case
59
+ x = x.unsqueeze(0) if len(x.shape) == 3 else x
60
+ x = x.permute(0, 3, 1, 2).contiguous()
61
+
62
+ visual_features = self.visual_conv(x)
63
+
64
+ # Use reshape instead of view for safety
65
+ batch_size = visual_features.size(0)
66
+ visual_features = visual_features.reshape(batch_size, -1)
67
+
68
+ visual_features = self.visual_fc(visual_features)
69
+
70
+ # Sentiment processing
71
+ if sentiment is not None:
72
+ if len(sentiment.shape) == 1:
73
+ sentiment = sentiment.unsqueeze(0)
74
+ sentiment_features = self.sentiment_fc(sentiment)
75
+ combined_features = torch.cat([visual_features, sentiment_features], dim=1)
76
+ else:
77
+ combined_features = visual_features
78
+
79
+ # Final decision
80
+ q_values = self.combined_fc(combined_features)
81
+ return q_values
82
+
83
+ except Exception as e:
84
+ print(f"Error in network forward: {e}")
85
+ # Return safe default
86
+ return torch.zeros((x.size(0) if hasattr(x, 'size') else 1, self.combined_fc[-1].out_features))
87
+
88
+ class AdvancedTradingAgent:
89
+ def __init__(self, state_dim, action_dim, learning_rate=0.001, use_sentiment=True):
90
+ self.state_dim = state_dim
91
+ self.action_dim = action_dim
92
+ self.learning_rate = learning_rate
93
  self.use_sentiment = use_sentiment
94
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
95
+ print(f"Using device: {self.device}")
96
+
97
+ # Neural network
98
+ self.policy_net = EnhancedTradingNetwork(state_dim, action_dim).to(self.device)
99
+ self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
100
+
101
+ # Experience replay
102
+ self.memory = deque(maxlen=500)
103
+ self.batch_size = 16
104
+
105
+ # Training parameters
106
+ self.gamma = 0.99
107
+ self.epsilon = 1.0
108
+ self.epsilon_min = 0.1
109
+ self.epsilon_decay = 0.995
110
+ self.steps_done = 0
111
 
 
 
 
 
 
 
112
  def select_action(self, state, current_sentiment=0.5, sentiment_confidence=0.0):
113
  """Select action with sentiment consideration"""
114
  if random.random() < self.epsilon:
115
  return random.randint(0, self.action_dim - 1)
116
 
117
  try:
118
+ # Normalize state
119
  state_normalized = state.astype(np.float32) / 255.0
120
+ state_tensor = torch.FloatTensor(state_normalized).to(self.device)
121
 
122
  if self.use_sentiment:
123
  # Add sentiment to the decision process
124
+ sentiment_tensor = torch.FloatTensor([current_sentiment, sentiment_confidence]).to(self.device)
125
  with torch.no_grad():
126
  q_values = self.policy_net(state_tensor, sentiment_tensor)
127
  else:
 
136
 
137
  def store_transition(self, state, action, reward, next_state, done, sentiment_data=None):
138
  """Store experience with sentiment data"""
139
+ try:
140
+ experience = (state, action, reward, next_state, done, sentiment_data)
141
+ self.memory.append(experience)
142
+ except Exception as e:
143
+ print(f"Error storing transition: {e}")
144
 
145
  def update(self):
146
  """Update network with sentiment-enhanced learning"""
 
148
  return 0.0
149
 
150
  try:
151
+ # Sample batch from memory
152
  batch = random.sample(self.memory, self.batch_size)
153
  states, actions, rewards, next_states, dones, sentiment_data = zip(*batch)
154
 
155
+ # Convert to tensors with proper shape handling
156
+ states_array = np.array(states, dtype=np.float32) / 255.0
157
+ next_states_array = np.array(next_states, dtype=np.float32) / 255.0
158
+
159
+ # Ensure proper tensor shapes
160
+ states_tensor = torch.FloatTensor(states_array).to(self.device)
161
+ next_states_tensor = torch.FloatTensor(next_states_array).to(self.device)
162
+
163
  actions_tensor = torch.LongTensor(actions).to(self.device)
164
  rewards_tensor = torch.FloatTensor(rewards).to(self.device)
 
165
  dones_tensor = torch.BoolTensor(dones).to(self.device)
166
 
167
  if self.use_sentiment and sentiment_data[0] is not None:
168
+ # Extract sentiment features safely
169
  sentiment_features = []
170
  for data in sentiment_data:
171
+ if data and 'sentiment' in data and 'confidence' in data:
172
+ sentiment_features.append([data['sentiment'], data['confidence']])
173
  else:
174
  sentiment_features.append([0.5, 0.0])
175
 
176
  sentiment_tensor = torch.FloatTensor(sentiment_features).to(self.device)
 
177
 
178
  # Current Q values with sentiment
179
+ current_q = self.policy_net(states_tensor, sentiment_tensor)
180
+ current_q = current_q.gather(1, actions_tensor.unsqueeze(1))
181
 
182
  # Next Q values with sentiment
183
  with torch.no_grad():
184
+ next_q = self.policy_net(next_states_tensor, sentiment_tensor)
185
+ next_q = next_q.max(1)[0]
186
  target_q = rewards_tensor + (self.gamma * next_q * ~dones_tensor)
187
  else:
188
+ # Fallback to standard DQN without sentiment
189
+ current_q = self.policy_net(states_tensor)
190
+ current_q = current_q.gather(1, actions_tensor.unsqueeze(1))
191
 
192
  with torch.no_grad():
193
+ next_q = self.policy_net(next_states_tensor)
194
+ next_q = next_q.max(1)[0]
195
  target_q = rewards_tensor + (self.gamma * next_q * ~dones_tensor)
196
 
197
  # Compute loss
 
200
  # Optimize
201
  self.optimizer.zero_grad()
202
  loss.backward()
203
+
204
+ # Gradient clipping for stability
205
  torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
206
  self.optimizer.step()
207
 
208
  # Update exploration
209
  self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
210
+ self.steps_done += 1
211
 
212
  return float(loss.item())
213
 
 
215
  print(f"Error in advanced update: {e}")
216
  return 0.0
217
 
218
+ # Fallback to simple agent if advanced one fails
219
+ class SimpleTradingNetwork(nn.Module):
220
+ def __init__(self, state_dim, action_dim):
221
+ super(SimpleTradingNetwork, self).__init__()
222
 
223
+ self.conv_layers = nn.Sequential(
 
224
  nn.Conv2d(4, 16, kernel_size=4, stride=2),
225
  nn.ReLU(),
226
  nn.Conv2d(16, 32, kernel_size=4, stride=2),
 
230
  nn.AdaptiveAvgPool2d((8, 8))
231
  )
232
 
233
+ self.fc_layers = nn.Sequential(
234
+ nn.Linear(32 * 8 * 8, 128),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  nn.ReLU(),
236
  nn.Dropout(0.2),
237
  nn.Linear(128, 64),
 
239
  nn.Linear(64, action_dim)
240
  )
241
 
242
+ def forward(self, x):
243
+ try:
244
+ # Handle input shape
245
+ if len(x.shape) == 4: # (batch, H, W, C)
246
+ x = x.permute(0, 3, 1, 2).contiguous()
247
+ else:
248
+ x = x.unsqueeze(0) if len(x.shape) == 3 else x
249
+ x = x.permute(0, 3, 1, 2).contiguous()
250
+
251
+ x = self.conv_layers(x)
252
+ batch_size = x.size(0)
253
+ x = x.reshape(batch_size, -1)
254
+ x = self.fc_layers(x)
255
+ return x
256
+ except Exception as e:
257
+ print(f"Error in simple network: {e}")
258
+ return torch.zeros((x.size(0), self.fc_layers[-1].out_features))