OmidSakaki committed on
Commit
a12d9e9
·
verified ·
1 Parent(s): 3ee3593

Create src/agents/advanced_agent.py

Browse files
Files changed (1) hide show
  1. src/agents/advanced_agent.py +168 -0
src/agents/advanced_agent.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import numpy as np
5
+ from collections import deque
6
+ import random
7
+ from .visual_agent import VisualTradingAgent, SimpleTradingNetwork
8
+
9
class AdvancedTradingAgent(VisualTradingAgent):
    """DQN trading agent extended with an optional sentiment input channel.

    When ``use_sentiment`` is True, the inherited policy network is replaced
    with an EnhancedTradingNetwork that consumes a 2-feature sentiment vector
    (score, confidence) alongside the visual state, and the optimizer is
    rebuilt so it tracks the new network's parameters.
    """

    def __init__(self, state_dim, action_dim, learning_rate=0.001, use_sentiment=True):
        super().__init__(state_dim, action_dim, learning_rate)

        self.use_sentiment = use_sentiment
        # Rolling window of recent sentiment readings (not consumed in this
        # file; presumably used by callers — TODO confirm).
        self.sentiment_history = deque(maxlen=50)

        if use_sentiment:
            # Swap in the sentiment-aware network; the optimizer must be
            # recreated because it was bound to the old network's parameters.
            self.policy_net = EnhancedTradingNetwork(state_dim, action_dim).to(self.device)
            self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)

    def select_action(self, state, current_sentiment=0.5, sentiment_confidence=0.0):
        """Epsilon-greedy action selection, optionally conditioned on sentiment.

        Args:
            state: image-like observation with values in [0, 255] (scaled to
                [0, 1] before inference).
            current_sentiment: sentiment score; 0.5 is treated as neutral.
            sentiment_confidence: confidence of the sentiment estimate.

        Returns:
            int: action index in [0, action_dim).
        """
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)

        try:
            state_normalized = state.astype(np.float32) / 255.0
            state_tensor = torch.FloatTensor(state_normalized).unsqueeze(0).to(self.device)

            with torch.no_grad():
                if self.use_sentiment:
                    sentiment_tensor = torch.FloatTensor(
                        [current_sentiment, sentiment_confidence]
                    ).unsqueeze(0).to(self.device)
                    q_values = self.policy_net(state_tensor, sentiment_tensor)
                else:
                    q_values = self.policy_net(state_tensor)

            return int(q_values.argmax().item())

        except Exception as e:
            # Best-effort: fall back to a random action rather than crashing
            # the trading loop.
            print(f"Error in advanced action selection: {e}")
            return random.randint(0, self.action_dim - 1)

    def store_transition(self, state, action, reward, next_state, done, sentiment_data=None):
        """Store an experience tuple in replay memory.

        sentiment_data is an optional dict with 'sentiment' and 'confidence'
        keys (see update()); None marks a transition without sentiment.
        """
        self.memory.append((state, action, reward, next_state, done, sentiment_data))

    def update(self):
        """Sample a replay batch and perform one DQN optimization step.

        Returns:
            float: the training loss, or 0.0 when the buffer does not yet
            hold a full batch or an error occurred.
        """
        if len(self.memory) < self.batch_size:
            return 0.0

        try:
            batch = random.sample(self.memory, self.batch_size)
            states, actions, rewards, next_states, dones, sentiment_data = zip(*batch)

            # States are stored as uint8-range images; normalize to [0, 1].
            states_tensor = torch.FloatTensor(np.array(states)).to(self.device) / 255.0
            actions_tensor = torch.LongTensor(actions).to(self.device)
            rewards_tensor = torch.FloatTensor(rewards).to(self.device)
            next_states_tensor = torch.FloatTensor(np.array(next_states)).to(self.device) / 255.0
            dones_tensor = torch.BoolTensor(dones).to(self.device)

            if self.use_sentiment and sentiment_data[0] is not None:
                # Build a (batch, 2) sentiment feature tensor; missing entries
                # default to neutral sentiment with zero confidence.
                sentiment_features = [
                    [d.get('sentiment', 0.5), d.get('confidence', 0.0)] if d else [0.5, 0.0]
                    for d in sentiment_data
                ]
                sentiment_tensor = torch.FloatTensor(sentiment_features).to(self.device)
                # NOTE(review): next-state sentiment is approximated by the
                # current sentiment — TODO store the true next sentiment in
                # the transition.
                next_sentiment_tensor = sentiment_tensor

                current_q = self.policy_net(states_tensor, sentiment_tensor).gather(
                    1, actions_tensor.unsqueeze(1)
                )

                with torch.no_grad():
                    next_q = self.policy_net(next_states_tensor, next_sentiment_tensor).max(1)[0]
                    # ~done masks the bootstrap term for terminal transitions.
                    target_q = rewards_tensor + (self.gamma * next_q * ~dones_tensor)
            else:
                # Fallback to standard DQN without the sentiment branch.
                current_q = self.policy_net(states_tensor).gather(1, actions_tensor.unsqueeze(1))

                with torch.no_grad():
                    next_q = self.policy_net(next_states_tensor).max(1)[0]
                    target_q = rewards_tensor + (self.gamma * next_q * ~dones_tensor)

            # squeeze(1) — not squeeze() — so a batch of size 1 keeps its
            # batch dimension instead of collapsing to a scalar and
            # broadcasting incorrectly against target_q.
            loss = nn.MSELoss()(current_q.squeeze(1), target_q)

            self.optimizer.zero_grad()
            loss.backward()
            # Clip gradients to stabilize training.
            torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
            self.optimizer.step()

            # Decay exploration after each successful update.
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

            return float(loss.item())

        except Exception as e:
            # Best-effort: report and skip this update rather than crash.
            print(f"Error in advanced update: {e}")
            return 0.0
112
+ class EnhancedTradingNetwork(nn.Module):
113
+ def __init__(self, state_dim, action_dim, sentiment_dim=2):
114
+ super(EnhancedTradingNetwork, self).__init__()
115
+
116
+ # Visual processing branch (same as before)
117
+ self.visual_conv = nn.Sequential(
118
+ nn.Conv2d(4, 16, kernel_size=4, stride=2),
119
+ nn.ReLU(),
120
+ nn.Conv2d(16, 32, kernel_size=4, stride=2),
121
+ nn.ReLU(),
122
+ nn.Conv2d(32, 32, kernel_size=3, stride=1),
123
+ nn.ReLU(),
124
+ nn.AdaptiveAvgPool2d((8, 8))
125
+ )
126
+
127
+ self.visual_fc = nn.Sequential(
128
+ nn.Linear(32 * 8 * 8, 256),
129
+ nn.ReLU(),
130
+ nn.Dropout(0.3)
131
+ )
132
+
133
+ # Sentiment processing branch
134
+ self.sentiment_fc = nn.Sequential(
135
+ nn.Linear(sentiment_dim, 64),
136
+ nn.ReLU(),
137
+ nn.Dropout(0.2),
138
+ nn.Linear(64, 32),
139
+ nn.ReLU()
140
+ )
141
+
142
+ # Combined decision making
143
+ self.combined_fc = nn.Sequential(
144
+ nn.Linear(256 + 32, 128),
145
+ nn.ReLU(),
146
+ nn.Dropout(0.2),
147
+ nn.Linear(128, 64),
148
+ nn.ReLU(),
149
+ nn.Linear(64, action_dim)
150
+ )
151
+
152
+ def forward(self, x, sentiment=None):
153
+ # Visual processing
154
+ x = x.permute(0, 3, 1, 2) # (batch, 84, 84, 4) -> (batch, 4, 84, 84)
155
+ visual_features = self.visual_conv(x)
156
+ visual_features = visual_features.view(visual_features.size(0), -1)
157
+ visual_features = self.visual_fc(visual_features)
158
+
159
+ # Sentiment processing
160
+ if sentiment is not None:
161
+ sentiment_features = self.sentiment_fc(sentiment)
162
+ combined_features = torch.cat([visual_features, sentiment_features], dim=1)
163
+ else:
164
+ combined_features = visual_features
165
+
166
+ # Final decision
167
+ q_values = self.combined_fc(combined_features)
168
+ return q_values