OmidSakaki commited on
Commit
d663209
·
verified ·
1 Parent(s): 3ca7d69

Delete src

Browse files
src/agents/advanced_agent.py DELETED
@@ -1,310 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.optim as optim
4
- import numpy as np
5
- from collections import deque
6
- import random
7
- import warnings
8
- warnings.filterwarnings('ignore')
9
-
10
- class EnhancedTradingNetwork(nn.Module):
11
- def __init__(self, state_dim, action_dim, sentiment_dim=2):
12
- super(EnhancedTradingNetwork, self).__init__()
13
-
14
- # Visual processing branch
15
- self.visual_conv = nn.Sequential(
16
- nn.Conv2d(4, 16, kernel_size=4, stride=2),
17
- nn.ReLU(),
18
- nn.Conv2d(16, 32, kernel_size=4, stride=2),
19
- nn.ReLU(),
20
- nn.Conv2d(32, 32, kernel_size=3, stride=1),
21
- nn.ReLU(),
22
- nn.AdaptiveAvgPool2d((8, 8))
23
- )
24
-
25
- # Calculate the output size after conv layers
26
- self.conv_output_size = 32 * 8 * 8
27
-
28
- self.visual_fc = nn.Sequential(
29
- nn.Linear(self.conv_output_size, 256),
30
- nn.ReLU(),
31
- nn.Dropout(0.3)
32
- )
33
-
34
- # Sentiment processing branch
35
- self.sentiment_fc = nn.Sequential(
36
- nn.Linear(sentiment_dim, 64),
37
- nn.ReLU(),
38
- nn.Dropout(0.2),
39
- nn.Linear(64, 32),
40
- nn.ReLU()
41
- )
42
-
43
- # Combined decision making
44
- self.combined_fc = nn.Sequential(
45
- nn.Linear(256 + 32, 128),
46
- nn.ReLU(),
47
- nn.Dropout(0.2),
48
- nn.Linear(128, 64),
49
- nn.ReLU(),
50
- nn.Linear(64, action_dim)
51
- )
52
-
53
- # Store action_dim for error handling
54
- self.action_dim = action_dim
55
-
56
- def forward(self, x, sentiment=None):
57
- try:
58
- # Ensure input has batch dimension
59
- if len(x.shape) == 3: # (H, W, C)
60
- x = x.unsqueeze(0)
61
- elif len(x.shape) == 4: # (batch, H, W, C)
62
- pass
63
- else:
64
- raise ValueError(f"Invalid input shape: {x.shape}")
65
-
66
- # Permute to (batch, C, H, W)
67
- x = x.permute(0, 3, 1, 2).contiguous().float()
68
-
69
- # Check if channels match expected input
70
- if x.size(1) != 4:
71
- raise ValueError(f"Expected 4 channels, got {x.size(1)}")
72
-
73
- visual_features = self.visual_conv(x)
74
- batch_size = visual_features.size(0)
75
- visual_features = visual_features.reshape(batch_size, -1)
76
- visual_features = self.visual_fc(visual_features)
77
-
78
- # Sentiment processing
79
- if sentiment is not None and self.sentiment_fc is not None:
80
- if len(sentiment.shape) == 1:
81
- sentiment = sentiment.unsqueeze(0)
82
- sentiment = sentiment.float()
83
- sentiment_features = self.sentiment_fc(sentiment)
84
- combined_features = torch.cat([visual_features, sentiment_features], dim=1)
85
- else:
86
- # Pad with zeros if no sentiment
87
- sentiment_features = torch.zeros(batch_size, 32, device=visual_features.device)
88
- combined_features = torch.cat([visual_features, sentiment_features], dim=1)
89
-
90
- q_values = self.combined_fc(combined_features)
91
- return q_values
92
-
93
- except Exception as e:
94
- print(f"Error in network forward: {e}")
95
- print(f"Input shape: {getattr(x, 'shape', 'Unknown')}")
96
- # Return safe default with correct shape
97
- batch_size = x.size(0) if hasattr(x, 'size') else 1
98
- return torch.zeros(batch_size, self.action_dim, device=(x.device if hasattr(x, 'device') else 'cpu'))
99
-
100
- class AdvancedTradingAgent:
101
- def __init__(self, state_dim, action_dim, learning_rate=0.001, use_sentiment=True):
102
- self.state_dim = state_dim # Should be (84, 84, 4) or similar
103
- self.action_dim = action_dim
104
- self.learning_rate = learning_rate
105
- self.use_sentiment = use_sentiment
106
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
107
- print(f"Using device: {self.device}")
108
-
109
- # Neural network
110
- self.policy_net = EnhancedTradingNetwork(state_dim, action_dim).to(self.device)
111
- self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
112
- self.loss_fn = nn.MSELoss()
113
-
114
- # Experience replay
115
- self.memory = deque(maxlen=10000) # Increased buffer size
116
- self.batch_size = min(32, state_dim[0]//2) # Dynamic batch size
117
-
118
- # Training parameters
119
- self.gamma = 0.99
120
- self.epsilon = 1.0
121
- self.epsilon_min = 0.01 # More aggressive exploration decay
122
- self.epsilon_decay = 0.9995 # Slower decay
123
- self.steps_done = 0
124
- self.target_update_freq = 100 # Target network update frequency
125
- self.steps_since_target_update = 0
126
-
127
- def select_action(self, state, current_sentiment=None, sentiment_confidence=None):
128
- """Select action with epsilon-greedy policy"""
129
- if random.random() < self.epsilon:
130
- return random.randint(0, self.action_dim - 1)
131
-
132
- try:
133
- # Validate and normalize state
134
- if not isinstance(state, np.ndarray):
135
- state = np.array(state)
136
-
137
- if state.dtype != np.float32:
138
- state = state.astype(np.float32)
139
-
140
- # Normalize pixel values
141
- if state.max() > 1.0:
142
- state = state / 255.0
143
-
144
- state_tensor = torch.FloatTensor(state).to(self.device)
145
-
146
- # Prepare sentiment input
147
- if self.use_sentiment and current_sentiment is not None:
148
- sentiment = np.array([float(current_sentiment), float(sentiment_confidence or 0.0)])
149
- sentiment_tensor = torch.FloatTensor(sentiment).to(self.device)
150
- with torch.no_grad():
151
- q_values = self.policy_net(state_tensor, sentiment_tensor)
152
- else:
153
- with torch.no_grad():
154
- q_values = self.policy_net(state_tensor)
155
-
156
- action = int(q_values.argmax().item())
157
- return action
158
-
159
- except Exception as e:
160
- print(f"Error in action selection: {e}")
161
- return random.randint(0, self.action_dim - 1)
162
-
163
- def store_transition(self, state, action, reward, next_state, done, sentiment_data=None):
164
- """Store experience tuple safely"""
165
- try:
166
- # Ensure all inputs are numpy arrays
167
- if not isinstance(state, np.ndarray):
168
- state = np.array(state, dtype=np.float32)
169
- if not isinstance(next_state, np.ndarray):
170
- next_state = np.array(next_state, dtype=np.float32)
171
-
172
- # Normalize before storing
173
- if state.max() > 1.0:
174
- state = state / 255.0
175
- if next_state.max() > 1.0:
176
- next_state = next_state / 255.0
177
-
178
- # Handle sentiment data
179
- if sentiment_data is None:
180
- sentiment_data = {'sentiment': 0.5, 'confidence': 0.0}
181
-
182
- experience = (state, action, float(reward), next_state, bool(done), sentiment_data)
183
- self.memory.append(experience)
184
-
185
- except Exception as e:
186
- print(f"Error storing transition: {e}")
187
-
188
- def update(self):
189
- """DQN update with improved stability"""
190
- if len(self.memory) < self.batch_size:
191
- return 0.0
192
-
193
- try:
194
- batch = random.sample(self.memory, self.batch_size)
195
- states, actions, rewards, next_states, dones, sentiments = zip(*batch)
196
-
197
- # Convert to tensors
198
- states = np.stack(states)
199
- next_states = np.stack(next_states)
200
- actions = np.array(actions)
201
- rewards = np.array(rewards)
202
- dones = np.array(dones)
203
-
204
- states_tensor = torch.FloatTensor(states).to(self.device)
205
- next_states_tensor = torch.FloatTensor(next_states).to(self.device)
206
- actions_tensor = torch.LongTensor(actions).to(self.device)
207
- rewards_tensor = torch.FloatTensor(rewards).to(self.device)
208
- dones_tensor = torch.BoolTensor(dones).to(self.device)
209
-
210
- # Compute current Q values
211
- if self.use_sentiment:
212
- # Use sentiment from current state
213
- sentiment_batch = []
214
- for sentiment_data in sentiments:
215
- sentiment = [sentiment_data.get('sentiment', 0.5),
216
- sentiment_data.get('confidence', 0.0)]
217
- sentiment_batch.append(sentiment)
218
- sentiment_tensor = torch.FloatTensor(sentiment_batch).to(self.device)
219
- current_q = self.policy_net(states_tensor, sentiment_tensor)
220
- else:
221
- current_q = self.policy_net(states_tensor)
222
-
223
- current_q = current_q.gather(1, actions_tensor.unsqueeze(1)).squeeze(1)
224
-
225
- # Compute target Q values
226
- with torch.no_grad():
227
- if self.use_sentiment:
228
- next_sentiment_batch = []
229
- for sentiment_data in sentiments:
230
- next_sentiment = [sentiment_data.get('sentiment', 0.5),
231
- sentiment_data.get('confidence', 0.0)]
232
- next_sentiment_batch.append(next_sentiment)
233
- next_sentiment_tensor = torch.FloatTensor(next_sentiment_batch).to(self.device)
234
- next_q = self.policy_net(next_states_tensor, next_sentiment_tensor)
235
- else:
236
- next_q = self.policy_net(next_states_tensor)
237
-
238
- next_q_max = next_q.max(1)[0]
239
- target_q = rewards_tensor + (self.gamma * next_q_max * ~dones_tensor)
240
-
241
- # Compute loss and optimize
242
- loss = self.loss_fn(current_q, target_q)
243
-
244
- self.optimizer.zero_grad()
245
- loss.backward()
246
- torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
247
- self.optimizer.step()
248
-
249
- # Update epsilon
250
- if self.epsilon > self.epsilon_min:
251
- self.epsilon *= self.epsilon_decay
252
-
253
- self.steps_done += 1
254
- self.steps_since_target_update += 1
255
-
256
- # Update target network periodically (if implemented)
257
- if self.steps_since_target_update % self.target_update_freq == 0:
258
- self._update_target_network()
259
-
260
- return float(loss.item())
261
-
262
- except Exception as e:
263
- print(f"Error in update: {e}")
264
- import traceback
265
- traceback.print_exc()
266
- return 0.0
267
-
268
- def _update_target_network(self):
269
- """Update target network (placeholder for double DQN)"""
270
- pass # Implement target network update here
271
-
272
- # Simple fallback network
273
- class SimpleTradingNetwork(nn.Module):
274
- def __init__(self, state_dim, action_dim):
275
- super(SimpleTradingNetwork, self).__init__()
276
- self.action_dim = action_dim
277
-
278
- self.conv_layers = nn.Sequential(
279
- nn.Conv2d(4, 16, kernel_size=4, stride=2),
280
- nn.ReLU(),
281
- nn.Conv2d(16, 32, kernel_size=4, stride=2),
282
- nn.ReLU(),
283
- nn.Conv2d(32, 32, kernel_size=3, stride=1),
284
- nn.ReLU(),
285
- nn.AdaptiveAvgPool2d((8, 8))
286
- )
287
-
288
- self.fc_layers = nn.Sequential(
289
- nn.Linear(32 * 8 * 8, 128),
290
- nn.ReLU(),
291
- nn.Dropout(0.2),
292
- nn.Linear(128, 64),
293
- nn.ReLU(),
294
- nn.Linear(64, action_dim)
295
- )
296
-
297
- def forward(self, x):
298
- try:
299
- if len(x.shape) == 3:
300
- x = x.unsqueeze(0)
301
- x = x.permute(0, 3, 1, 2).contiguous().float()
302
-
303
- x = self.conv_layers(x)
304
- x = x.reshape(x.size(0), -1)
305
- x = self.fc_layers(x)
306
- return x
307
- except Exception as e:
308
- print(f"Error in simple network: {e}")
309
- batch_size = x.size(0) if hasattr(x, 'size') else 1
310
- return torch.zeros(batch_size, self.action_dim, device=(x.device if hasattr(x, 'device') else 'cpu'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/agents/visual_agent.py DELETED
@@ -1,153 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.optim as optim
4
- import numpy as np
5
- from collections import deque
6
- import random
7
-
8
- class SimpleTradingNetwork(nn.Module):
9
- def __init__(self, state_dim, action_dim):
10
- super(SimpleTradingNetwork, self).__init__()
11
-
12
- # Simplified CNN for faster training
13
- self.conv_layers = nn.Sequential(
14
- nn.Conv2d(4, 16, kernel_size=4, stride=2), # Input: 84x84x4 -> 41x41x16
15
- nn.ReLU(),
16
- nn.Conv2d(16, 32, kernel_size=4, stride=2), # 41x41x16 -> 19x19x32
17
- nn.ReLU(),
18
- nn.Conv2d(32, 32, kernel_size=3, stride=1), # 19x19x32 -> 17x17x32
19
- nn.ReLU(),
20
- nn.AdaptiveAvgPool2d((8, 8)) # 17x17x32 -> 8x8x32
21
- )
22
-
23
- # Calculate flattened size
24
- self.flattened_size = 32 * 8 * 8
25
-
26
- # Fully connected layers
27
- self.fc_layers = nn.Sequential(
28
- nn.Linear(self.flattened_size, 128),
29
- nn.ReLU(),
30
- nn.Dropout(0.2),
31
- nn.Linear(128, 64),
32
- nn.ReLU(),
33
- nn.Dropout(0.2),
34
- nn.Linear(64, action_dim)
35
- )
36
-
37
- def forward(self, x):
38
- try:
39
- # x shape: (batch_size, 84, 84, 4) -> (batch_size, 4, 84, 84)
40
- if len(x.shape) == 4: # Single observation
41
- x = x.permute(0, 3, 1, 2)
42
- else: # Batch of observations
43
- x = x.permute(0, 3, 1, 2)
44
-
45
- x = self.conv_layers(x)
46
- x = x.view(x.size(0), -1)
47
- x = self.fc_layers(x)
48
- return x
49
- except Exception as e:
50
- print(f"Error in network forward: {e}")
51
- # Return zeros in case of error
52
- return torch.zeros((x.size(0), self.fc_layers[-1].out_features))
53
-
54
- class VisualTradingAgent:
55
- def __init__(self, state_dim, action_dim, learning_rate=0.001):
56
- self.state_dim = state_dim
57
- self.action_dim = action_dim
58
- self.learning_rate = learning_rate
59
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60
- print(f"Using device: {self.device}")
61
-
62
- # Neural network
63
- self.policy_net = SimpleTradingNetwork(state_dim, action_dim).to(self.device)
64
- self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
65
-
66
- # Experience replay
67
- self.memory = deque(maxlen=500) # Smaller memory for stability
68
- self.batch_size = 16
69
-
70
- # Training parameters
71
- self.gamma = 0.99
72
- self.epsilon = 1.0
73
- self.epsilon_min = 0.1
74
- self.epsilon_decay = 0.995
75
- self.update_target_every = 100
76
- self.steps_done = 0
77
-
78
- def select_action(self, state):
79
- """Select action using epsilon-greedy policy"""
80
- if random.random() < self.epsilon:
81
- return random.randint(0, self.action_dim - 1)
82
-
83
- try:
84
- # Normalize state and convert to tensor
85
- state_normalized = state.astype(np.float32) / 255.0
86
- state_tensor = torch.FloatTensor(state_normalized).unsqueeze(0).to(self.device)
87
-
88
- with torch.no_grad():
89
- q_values = self.policy_net(state_tensor)
90
- return int(q_values.argmax().item())
91
- except Exception as e:
92
- print(f"Error in action selection: {e}")
93
- return random.randint(0, self.action_dim - 1)
94
-
95
- def store_transition(self, state, action, reward, next_state, done):
96
- """Store experience in replay memory"""
97
- try:
98
- self.memory.append((state, action, reward, next_state, done))
99
- except Exception as e:
100
- print(f"Error storing transition: {e}")
101
-
102
- def update(self):
103
- """Update the neural network"""
104
- if len(self.memory) < self.batch_size:
105
- return 0.0
106
-
107
- try:
108
- # Sample batch from memory
109
- batch = random.sample(self.memory, self.batch_size)
110
- states, actions, rewards, next_states, dones = zip(*batch)
111
-
112
- # Convert to tensors with normalization
113
- states_array = np.array(states, dtype=np.float32) / 255.0
114
- next_states_array = np.array(next_states, dtype=np.float32) / 255.0
115
-
116
- states_tensor = torch.FloatTensor(states_array).to(self.device)
117
- actions_tensor = torch.LongTensor(actions).to(self.device)
118
- rewards_tensor = torch.FloatTensor(rewards).to(self.device)
119
- next_states_tensor = torch.FloatTensor(next_states_array).to(self.device)
120
- dones_tensor = torch.BoolTensor(dones).to(self.device)
121
-
122
- # Current Q values
123
- current_q = self.policy_net(states_tensor).gather(1, actions_tensor.unsqueeze(1))
124
-
125
- # Next Q values
126
- with torch.no_grad():
127
- next_q = self.policy_net(next_states_tensor).max(1)[0]
128
- target_q = rewards_tensor + (self.gamma * next_q * ~dones_tensor)
129
-
130
- # Compute loss
131
- loss = nn.MSELoss()(current_q.squeeze(), target_q)
132
-
133
- # Optimize
134
- self.optimizer.zero_grad()
135
- loss.backward()
136
-
137
- # Gradient clipping for stability
138
- torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
139
- self.optimizer.step()
140
-
141
- # Update steps and decay epsilon
142
- self.steps_done += 1
143
- if self.steps_done % self.update_target_every == 0:
144
- # For simplicity, we're using the same network
145
- pass
146
-
147
- self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
148
-
149
- return float(loss.item())
150
-
151
- except Exception as e:
152
- print(f"Error in update: {e}")
153
- return 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/environments/advanced_trading_env.py DELETED
@@ -1,293 +0,0 @@
1
- import numpy as np
2
- import logging
3
- from typing import Dict, Any, Optional, Tuple
4
- from .visual_trading_env import VisualTradingEnvironment
5
- from src.sentiment.twitter_analyzer import AdvancedSentimentAnalyzer
6
-
7
- # Setup logging
8
- logging.basicConfig(level=logging.INFO)
9
- logger = logging.getLogger(__name__)
10
-
11
- class AdvancedTradingEnvironment(VisualTradingEnvironment):
12
- def __init__(self, initial_balance=10000, risk_level="Medium", asset_type="Crypto",
13
- use_sentiment=True, sentiment_influence=0.3, sentiment_update_freq=5):
14
- super().__init__(initial_balance, risk_level, asset_type)
15
-
16
- # Validate inputs
17
- if not 0.0 <= sentiment_influence <= 1.0:
18
- raise ValueError("sentiment_influence must be between 0.0 and 1.0")
19
- if sentiment_update_freq < 1:
20
- raise ValueError("sentiment_update_freq must be at least 1")
21
-
22
- self.use_sentiment = use_sentiment
23
- self.sentiment_influence = sentiment_influence
24
- self.sentiment_update_freq = sentiment_update_freq
25
- self.sentiment_history = deque(maxlen=100) # Limited history
26
- self.current_step = 0
27
-
28
- # Sentiment analyzer with error handling
29
- self.sentiment_analyzer = None
30
- self.current_sentiment = 0.5
31
- self.sentiment_confidence = 0.0
32
-
33
- if use_sentiment:
34
- try:
35
- self.sentiment_analyzer = AdvancedSentimentAnalyzer()
36
- self.sentiment_analyzer.initialize_models()
37
- logger.info("Sentiment analyzer initialized successfully")
38
- except Exception as e:
39
- logger.warning(f"Failed to initialize sentiment analyzer: {e}. Disabling sentiment.")
40
- self.use_sentiment = False
41
-
42
- def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]:
43
- """Execute trading step with sentiment influence"""
44
- if not isinstance(action, int) or action < 0:
45
- logger.warning(f"Invalid action {action}, defaulting to hold")
46
- action = 0 # Hold action as default
47
-
48
- # Update sentiment periodically
49
- if self.use_sentiment and self.current_step % self.sentiment_update_freq == 0:
50
- self._update_sentiment()
51
-
52
- self.current_step += 1
53
-
54
- # Execute base environment step
55
- try:
56
- observation, reward, done, info = super().step(action)
57
- except Exception as e:
58
- logger.error(f"Error in base environment step: {e}")
59
- # Return safe defaults
60
- observation = self._get_safe_observation()
61
- reward = 0.0
62
- done = False
63
- info = {}
64
-
65
- # Apply sentiment modification to reward
66
- if self.use_sentiment:
67
- try:
68
- reward = self._apply_sentiment_to_reward(reward, action, info)
69
- except Exception as e:
70
- logger.warning(f"Error applying sentiment to reward: {e}")
71
-
72
- # Enhance observation with sentiment (optional)
73
- enhanced_observation = self._enhance_observation(observation)
74
-
75
- # Add sentiment info to info dict
76
- info.update({
77
- 'sentiment': float(self.current_sentiment),
78
- 'sentiment_confidence': float(self.sentiment_confidence),
79
- 'sentiment_influence': float(self.sentiment_influence),
80
- 'step': self.current_step
81
- })
82
-
83
- return enhanced_observation, float(reward), bool(done), info
84
-
85
- def _update_sentiment(self):
86
- """Update current market sentiment with robust error handling"""
87
- if not self.sentiment_analyzer:
88
- return
89
-
90
- try:
91
- sentiment_data = self.sentiment_analyzer.get_influencer_sentiment()
92
-
93
- # Validate sentiment data
94
- if not isinstance(sentiment_data, dict):
95
- raise ValueError("Invalid sentiment data format")
96
-
97
- market_sentiment = sentiment_data.get('market_sentiment')
98
- confidence = sentiment_data.get('confidence')
99
-
100
- if market_sentiment is None or not (-1.0 <= market_sentiment <= 1.0):
101
- raise ValueError("Invalid market_sentiment value")
102
- if confidence is None or not (0.0 <= confidence <= 1.0):
103
- raise ValueError("Invalid confidence value")
104
-
105
- self.current_sentiment = float(market_sentiment)
106
- self.sentiment_confidence = float(confidence)
107
-
108
- # Normalize sentiment to 0-1 range for consistency
109
- self.current_sentiment = (self.current_sentiment + 1.0) / 2.0
110
-
111
- # Update history
112
- self.sentiment_history.append({
113
- 'sentiment': self.current_sentiment,
114
- 'confidence': self.sentiment_confidence,
115
- 'timestamp': self.current_step
116
- })
117
-
118
- logger.debug(f"Updated sentiment: {self.current_sentiment:.3f} (conf: {self.sentiment_confidence:.3f})")
119
-
120
- except Exception as e:
121
- logger.warning(f"Error updating sentiment: {e}")
122
- # Fallback to neutral sentiment
123
- self.current_sentiment = 0.5
124
- self.sentiment_confidence = 0.0
125
- self.sentiment_history.append({
126
- 'sentiment': 0.5,
127
- 'confidence': 0.0,
128
- 'timestamp': self.current_step
129
- })
130
-
131
- def _apply_sentiment_to_reward(self, original_reward: float, action: int,
132
- info: Dict[str, Any]) -> float:
133
- """Modify reward based on sentiment analysis with bounds checking"""
134
- if self.sentiment_confidence < 0.3:
135
- return original_reward
136
-
137
- try:
138
- sentiment_multiplier = 1.0
139
- sentiment_score = self.current_sentiment # 0-1 normalized
140
-
141
- # Define action mappings (adjust based on your action space)
142
- # Assuming: 0=hold, 1=buy, 2=sell, 3=close
143
- bullish_threshold = 0.6
144
- bearish_threshold = 0.4
145
-
146
- if sentiment_score > bullish_threshold: # Bullish
147
- if action == 1: # Buy
148
- sentiment_multiplier += self.sentiment_influence * self.sentiment_confidence
149
- elif action == 2: # Sell short
150
- sentiment_multiplier -= self.sentiment_influence * 0.3 * self.sentiment_confidence
151
- elif action == 3: # Close
152
- sentiment_multiplier -= self.sentiment_influence * 0.2 * self.sentiment_confidence
153
-
154
- elif sentiment_score < bearish_threshold: # Bearish
155
- if action == 2: # Sell short
156
- sentiment_multiplier += self.sentiment_influence * self.sentiment_confidence
157
- elif action == 1: # Buy
158
- sentiment_multiplier -= self.sentiment_influence * 0.5 * self.sentiment_confidence
159
- elif action == 3: # Close
160
- sentiment_multiplier += self.sentiment_influence * 0.3 * self.sentiment_confidence
161
-
162
- # Apply trend momentum if enough history
163
- trend_multiplier = self._calculate_sentiment_trend_multiplier()
164
- sentiment_multiplier += trend_multiplier
165
-
166
- # Clamp multiplier to reasonable bounds
167
- sentiment_multiplier = np.clip(sentiment_multiplier, 0.5, 2.0)
168
-
169
- enhanced_reward = original_reward * sentiment_multiplier
170
-
171
- # Ensure reward doesn't become extreme
172
- max_reward = abs(original_reward) * 2.5 if original_reward != 0 else 10.0
173
- return np.clip(enhanced_reward, -max_reward, max_reward)
174
-
175
- except Exception as e:
176
- logger.error(f"Error in sentiment reward calculation: {e}")
177
- return original_reward
178
-
179
- def _calculate_sentiment_trend_multiplier(self) -> float:
180
- """Calculate trend-based multiplier from sentiment history"""
181
- if len(self.sentiment_history) < 10:
182
- return 0.0
183
-
184
- try:
185
- # Get recent and previous sentiment values
186
- recent_sentiments = [h['sentiment'] for h in list(self.sentiment_history)[-5:]]
187
- prev_sentiments = [h['sentiment'] for h in list(self.sentiment_history)[-10:-5]]
188
-
189
- recent_avg = np.mean(recent_sentiments)
190
- prev_avg = np.mean(prev_sentiments)
191
-
192
- trend = recent_avg - prev_avg
193
- # Scale trend influence
194
- trend_multiplier = np.tanh(trend * 5) * self.sentiment_influence * 0.3
195
- return float(trend_multiplier)
196
-
197
- except Exception as e:
198
- logger.warning(f"Error calculating trend multiplier: {e}")
199
- return 0.0
200
-
201
- def _enhance_observation(self, original_observation: np.ndarray) -> np.ndarray:
202
- """Enhance observation with sentiment information"""
203
- if not self.use_sentiment or original_observation is None:
204
- return original_observation
205
-
206
- try:
207
- # For now, return original observation
208
- # Future: could concatenate sentiment as additional channels or metadata
209
- return original_observation.copy()
210
- except Exception as e:
211
- logger.warning(f"Error enhancing observation: {e}")
212
- return original_observation
213
-
214
- def _get_safe_observation(self) -> np.ndarray:
215
- """Get a safe default observation"""
216
- try:
217
- # Try to get current observation from base env
218
- if hasattr(self, 'current_observation'):
219
- return self.current_observation.copy()
220
- # Return zero observation of expected shape
221
- return np.zeros((84, 84, 4), dtype=np.float32)
222
- except:
223
- return np.zeros((84, 84, 4), dtype=np.float32)
224
-
225
- def get_sentiment_analysis(self) -> Dict[str, Any]:
226
- """Get detailed sentiment analysis with safety checks"""
227
- if not self.use_sentiment:
228
- return {"error": "Sentiment analysis disabled", "sentiment": 0.5, "confidence": 0.0}
229
-
230
- try:
231
- trend_direction = self._calculate_sentiment_trend_direction()
232
- return {
233
- "current_sentiment": float(self.current_sentiment),
234
- "sentiment_confidence": float(self.sentiment_confidence),
235
- "sentiment_trend": trend_direction,
236
- "influence_level": float(self.sentiment_influence),
237
- "history_length": len(self.sentiment_history),
238
- "update_freq": self.sentiment_update_freq,
239
- "last_update_step": self.current_step
240
- }
241
- except Exception as e:
242
- logger.error(f"Error in get_sentiment_analysis: {e}")
243
- return {
244
- "error": str(e),
245
- "sentiment": 0.5,
246
- "confidence": 0.0,
247
- "trend": "unknown"
248
- }
249
-
250
- def _calculate_sentiment_trend_direction(self) -> str:
251
- """Calculate sentiment trend direction"""
252
- if len(self.sentiment_history) < 5:
253
- return "insufficient_data"
254
-
255
- try:
256
- recent_avg = np.mean([h['sentiment'] for h in list(self.sentiment_history)[-5:]])
257
- prev_avg = np.mean([h['sentiment'] for h in list(self.sentiment_history)[-10:-5]]) if len(self.sentiment_history) >= 10 else recent_avg
258
-
259
- diff = recent_avg - prev_avg
260
- if diff > 0.05:
261
- return "bullish"
262
- elif diff < -0.05:
263
- return "bearish"
264
- else:
265
- return "neutral"
266
- except:
267
- return "error"
268
-
269
- def reset(self) -> np.ndarray:
270
- """Reset environment with sentiment state"""
271
- try:
272
- observation = super().reset()
273
- self.current_step = 0
274
- self.sentiment_history.clear()
275
- self.current_sentiment = 0.5
276
- self.sentiment_confidence = 0.0
277
- logger.info("Environment reset with sentiment state")
278
- return observation
279
- except Exception as e:
280
- logger.error(f"Error in reset: {e}")
281
- # Force safe reset
282
- super().reset()
283
- self.current_step = 0
284
- self.sentiment_history.clear()
285
- return np.zeros((84, 84, 4), dtype=np.float32)
286
-
287
- @property
288
- def action_space_size(self) -> int:
289
- """Get action space size from base environment"""
290
- try:
291
- return super().action_space.n if hasattr(super(), 'action_space') else 4
292
- except:
293
- return 4 # Default assumption
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/environments/visual_trading_env.py DELETED
@@ -1,228 +0,0 @@
1
- import numpy as np
2
- import matplotlib.pyplot as plt
3
- from PIL import Image
4
- import io
5
-
6
- class VisualTradingEnvironment:
7
- def __init__(self, initial_balance=10000, risk_level="Medium", asset_type="Stock"):
8
- self.initial_balance = float(initial_balance)
9
- self.risk_level = risk_level
10
- self.asset_type = asset_type
11
-
12
- # Risk multipliers
13
- risk_multipliers = {"Low": 0.5, "Medium": 1.0, "High": 2.0}
14
- self.risk_multiplier = risk_multipliers.get(risk_level, 1.0)
15
-
16
- # Generate market data
17
- self._generate_market_data()
18
-
19
- # Initialize state
20
- self.reset()
21
-
22
- def _generate_market_data(self, num_points=1000):
23
- """Generate realistic synthetic market data"""
24
- np.random.seed(42)
25
-
26
- # Base parameters based on asset type
27
- base_params = {
28
- "Stock": {"volatility": 0.01, "trend": 0.0005},
29
- "Crypto": {"volatility": 0.02, "trend": 0.001},
30
- "Forex": {"volatility": 0.005, "trend": 0.0002}
31
- }
32
-
33
- params = base_params.get(self.asset_type, base_params["Stock"])
34
- volatility = params["volatility"] * self.risk_multiplier
35
- trend = params["trend"]
36
-
37
- prices = [100.0]
38
- for i in range(1, num_points):
39
- # Random walk with trend and some mean reversion
40
- change = np.random.normal(trend, volatility)
41
- # Add some mean reversion
42
- mean_reversion = (100 - prices[-1]) * 0.001
43
- price = max(1.0, prices[-1] * (1 + change) + mean_reversion)
44
- prices.append(price)
45
-
46
- self.price_data = np.array(prices)
47
-
48
- def _get_visual_observation(self):
49
- """Generate visual representation of current market state"""
50
- try:
51
- # Get recent price window
52
- window_size = 50
53
- start_idx = max(0, self.current_step - window_size)
54
- end_idx = self.current_step + 1
55
-
56
- if end_idx > len(self.price_data):
57
- end_idx = len(self.price_data)
58
-
59
- prices = self.price_data[start_idx:end_idx]
60
-
61
- # Create matplotlib figure with fixed size
62
- fig, ax = plt.subplots(figsize=(4.2, 4.2), dpi=20, facecolor='black')
63
- ax.set_facecolor('black')
64
-
65
- # Plot price if we have data
66
- if len(prices) > 0:
67
- ax.plot(prices, color='cyan', linewidth=1.5)
68
-
69
- # Remove axes for cleaner visual
70
- ax.set_xticks([])
71
- ax.set_yticks([])
72
- ax.spines['top'].set_visible(False)
73
- ax.spines['right'].set_visible(False)
74
- ax.spines['bottom'].set_visible(False)
75
- ax.spines['left'].set_visible(False)
76
-
77
- # Set fixed limits to ensure consistent size
78
- ax.set_xlim(0, 50)
79
- if len(prices) > 0:
80
- price_min, price_max = min(prices), max(prices)
81
- price_range = price_max - price_min
82
- if price_range == 0:
83
- price_range = 1
84
- ax.set_ylim(price_min - price_range * 0.1, price_max + price_range * 0.1)
85
- else:
86
- ax.set_ylim(0, 100)
87
-
88
- # Convert to numpy array with consistent size
89
- buf = io.BytesIO()
90
- plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, facecolor='black', dpi=20)
91
- buf.seek(0)
92
- img = Image.open(buf).convert('RGB')
93
-
94
- # Resize to consistent dimensions
95
- img = img.resize((84, 84), Image.Resampling.LANCZOS)
96
- img_array = np.array(img)
97
-
98
- plt.close(fig)
99
-
100
- # Create attention map with same dimensions
101
- attention_map = np.zeros((84, 84), dtype=np.uint8)
102
- if len(prices) > 1:
103
- recent_change = (prices[-1] - prices[-2]) / prices[-2]
104
- intensity = min(255, abs(recent_change) * 5000)
105
-
106
- # Simple attention based on price movement
107
- center_x, center_y = 42, 42
108
- size = max(5, int(intensity / 50))
109
-
110
- for i in range(max(0, center_x-size), min(84, center_x+size)):
111
- for j in range(max(0, center_y-size), min(84, center_y+size)):
112
- distance = np.sqrt((i-center_x)**2 + (j-center_y)**2)
113
- if distance <= size:
114
- attention_value = intensity * (1 - distance/size)
115
- attention_map[i, j] = max(attention_map[i, j], int(attention_value))
116
-
117
- # Combine RGB with attention map
118
- visual_obs = np.concatenate([
119
- img_array,
120
- attention_map[:, :, np.newaxis] # Add channel dimension
121
- ], axis=2)
122
-
123
- return visual_obs
124
-
125
- except Exception as e:
126
- print(f"Error in visual observation: {e}")
127
- # Return default observation in case of error
128
- return np.zeros((84, 84, 4), dtype=np.uint8)
129
-
130
- def reset(self):
131
- """Reset environment to initial state"""
132
- self.current_step = 50 # Start with some history
133
- self.balance = self.initial_balance
134
- self.position_size = 0.0
135
- self.entry_price = 0.0
136
- self.net_worth = self.initial_balance
137
- self.total_trades = 0
138
- self.done = False
139
-
140
- return self._get_visual_observation()
141
-
142
- def step(self, action):
143
- """Execute one trading step"""
144
- try:
145
- current_price = self.price_data[self.current_step]
146
- prev_net_worth = self.net_worth
147
-
148
- reward = 0.0
149
-
150
- # Execute action
151
- if action == 1 and self.position_size == 0: # Buy
152
- # Risk-adjusted position sizing
153
- position_value = self.balance * 0.1 * self.risk_multiplier
154
- self.position_size = position_value / current_price
155
- self.entry_price = current_price
156
- self.balance -= position_value
157
- self.total_trades += 1
158
- reward = -0.01 # Small penalty for transaction
159
-
160
- elif action == 2 and self.position_size > 0: # Sell (increase position)
161
- additional_value = self.balance * 0.05 * self.risk_multiplier
162
- additional_size = additional_value / current_price
163
- self.position_size += additional_size
164
- self.balance -= additional_value
165
- self.total_trades += 1
166
- reward = -0.005
167
-
168
- elif action == 3 and self.position_size > 0: # Close position
169
- close_value = self.position_size * current_price
170
- self.balance += close_value
171
- if self.entry_price > 0:
172
- profit_loss = (current_price - self.entry_price) / self.entry_price
173
- reward = profit_loss * 10 # Scale profit/loss
174
- self.position_size = 0.0
175
- self.entry_price = 0.0
176
- self.total_trades += 1
177
-
178
- # Update net worth
179
- position_value = self.position_size * current_price if self.position_size > 0 else 0.0
180
- self.net_worth = self.balance + position_value
181
-
182
- # Add small reward for portfolio growth
183
- if prev_net_worth > 0:
184
- portfolio_change = (self.net_worth - prev_net_worth) / prev_net_worth
185
- reward += portfolio_change * 5
186
-
187
- # Move to next step
188
- self.current_step += 1
189
-
190
- # Check if episode is done
191
- if self.current_step >= len(self.price_data) - 1:
192
- self.done = True
193
- # Final reward based on overall performance
194
- if self.initial_balance > 0:
195
- final_return = (self.net_worth - self.initial_balance) / self.initial_balance
196
- reward += final_return * 20
197
-
198
- info = {
199
- 'net_worth': float(self.net_worth),
200
- 'balance': float(self.balance),
201
- 'position_size': float(self.position_size),
202
- 'current_price': float(current_price),
203
- 'total_trades': int(self.total_trades),
204
- 'step': int(self.current_step)
205
- }
206
-
207
- obs = self._get_visual_observation()
208
- return obs, float(reward), bool(self.done), info
209
-
210
- except Exception as e:
211
- print(f"Error in step execution: {e}")
212
- # Return safe default values in case of error
213
- default_info = {
214
- 'net_worth': float(self.initial_balance),
215
- 'balance': float(self.initial_balance),
216
- 'position_size': 0.0,
217
- 'current_price': 100.0,
218
- 'total_trades': 0,
219
- 'step': int(self.current_step)
220
- }
221
- return self._get_visual_observation(), 0.0, True, default_info
222
-
223
- def get_price_history(self):
224
- """Get recent price history for visualization"""
225
- window_size = min(50, self.current_step)
226
- start_idx = max(0, self.current_step - window_size)
227
- prices = self.price_data[start_idx:self.current_step]
228
- return [float(price) for price in prices]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/sentiment/twitter_analyzer.py DELETED
@@ -1,495 +0,0 @@
1
- import torch
2
- import numpy as np
3
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
4
- from textblob import TextBlob
5
- from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
6
- from typing import Dict, List, Tuple, Optional
7
- import time
8
- from datetime import datetime, timedelta
9
- import re
10
- import logging
11
- from functools import lru_cache
12
- import warnings
13
- warnings.filterwarnings('ignore')
14
-
15
- # Setup logging
16
- logging.basicConfig(level=logging.INFO)
17
- logger = logging.getLogger(__name__)
18
-
19
- class AdvancedSentimentAnalyzer:
20
- def __init__(self, max_model_retries=3, cache_size=100):
21
- self.sentiment_models = {}
22
- self.vader_analyzer = None
23
- self.max_model_retries = max_model_retries
24
- self.cache = {} # Simple cache for expensive operations
25
-
26
- # Influencers with validation
27
- self.influencers = self._validate_influencers({
28
- 'elonmusk': {'name': 'Elon Musk', 'weight': 0.9, 'sector': 'all'},
29
- 'cz_binance': {'name': 'Changpeng Zhao', 'weight': 0.8, 'sector': 'crypto'},
30
- 'saylor': {'name': 'Michael Saylor', 'weight': 0.7, 'sector': 'bitcoin'},
31
- 'crypto_bitlord': {'name': 'Crypto Bitlord', 'weight': 0.6, 'sector': 'crypto'},
32
- 'aantonop': {'name': 'Andreas Antonopoulos', 'weight': 0.7, 'sector': 'bitcoin'},
33
- 'peterlbrandt': {'name': 'Peter Brandt', 'weight': 0.8, 'sector': 'trading'},
34
- 'nic__carter': {'name': 'Nic Carter', 'weight': 0.7, 'sector': 'crypto'},
35
- 'avalancheavax': {'name': 'Avalanche', 'weight': 0.6, 'sector': 'defi'}
36
- })
37
-
38
- self._initialize_vader()
39
-
40
- def _validate_influencers(self, influencers: Dict) -> Dict:
41
- """Validate and normalize influencer weights"""
42
- validated = {}
43
- total_weight = 0
44
-
45
- for username, data in influencers.items():
46
- if 0.0 <= data.get('weight', 0) <= 1.0:
47
- validated[username] = data
48
- total_weight += data['weight']
49
-
50
- # Normalize weights to sum to 1
51
- if total_weight > 0:
52
- for username in validated:
53
- validated[username]['weight'] /= total_weight
54
-
55
- logger.info(f"Validated {len(validated)} influencers with total weight {total_weight:.2f}")
56
- return validated
57
-
58
- def _initialize_vader(self):
59
- """Initialize VADER safely"""
60
- try:
61
- self.vader_analyzer = SentimentIntensityAnalyzer()
62
- logger.info("VADER analyzer initialized")
63
- except Exception as e:
64
- logger.warning(f"Failed to initialize VADER: {e}")
65
- self.vader_analyzer = None
66
-
67
- @lru_cache(maxsize=128)
68
- def _safe_pipeline_load(self, model_name: str):
69
- """Safely load pipeline with caching and retries"""
70
- for attempt in range(self.max_model_retries):
71
- try:
72
- pipeline_obj = pipeline(
73
- "sentiment-analysis",
74
- model=model_name,
75
- tokenizer=model_name,
76
- device=-1, # CPU only for stability
77
- return_all_scores=False
78
- )
79
- logger.info(f"Successfully loaded model: {model_name}")
80
- return pipeline_obj
81
- except Exception as e:
82
- logger.warning(f"Attempt {attempt + 1} failed for {model_name}: {e}")
83
- if attempt == self.max_model_retries - 1:
84
- return None
85
- time.sleep(1) # Brief delay before retry
86
-
87
- def initialize_models(self) -> bool:
88
- """Initialize all sentiment analysis models with fallback"""
89
- success_count = 0
90
-
91
- try:
92
- # Financial sentiment model
93
- financial_model = self._safe_pipeline_load(
94
- "mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis"
95
- )
96
- if financial_model:
97
- self.sentiment_models['financial'] = financial_model
98
- success_count += 1
99
-
100
- # General sentiment model with fallback
101
- general_model = self._safe_pipeline_load("distilbert-base-uncased-finetuned-sst-2-english")
102
- if general_model:
103
- self.sentiment_models['general'] = general_model
104
- success_count += 1
105
- else:
106
- # Fallback to basic pipeline
107
- try:
108
- self.sentiment_models['general'] = pipeline("sentiment-analysis")
109
- success_count += 1
110
- except:
111
- pass
112
-
113
- # Crypto-specific model with fallback
114
- crypto_model = self._safe_pipeline_load("ElKulako/cryptobert")
115
- if crypto_model:
116
- self.sentiment_models['crypto'] = crypto_model
117
- success_count += 1
118
- else:
119
- self.sentiment_models['crypto'] = self.sentiment_models.get('financial',
120
- self.sentiment_models.get('general'))
121
- success_count += 1 if self.sentiment_models['crypto'] else 0
122
-
123
- # At least one model should be available
124
- if success_count > 0:
125
- logger.info(f"✅ Loaded {success_count} sentiment models successfully!")
126
- return True
127
- else:
128
- logger.error("❌ No sentiment models could be loaded")
129
- return False
130
-
131
- except Exception as e:
132
- logger.error(f"❌ Critical error loading models: {e}")
133
- return False
134
-
135
- def analyze_text_sentiment(self, text: str) -> Dict:
136
- """Comprehensive sentiment analysis with robust error handling"""
137
- if not text or len(text.strip()) < 5:
138
- return self._default_sentiment()
139
-
140
- cache_key = hash(text.strip()[:100]) # Simple cache key
141
- if cache_key in self.cache:
142
- return self.cache[cache_key]
143
-
144
- try:
145
- cleaned_text = self._clean_text(text)
146
-
147
- # Analyze with available models
148
- model_results = []
149
-
150
- # Financial model
151
- if 'financial' in self.sentiment_models:
152
- model_results.append(self._analyze_model(cleaned_text, 'financial'))
153
-
154
- # General model
155
- if 'general' in self.sentiment_models:
156
- model_results.append(self._analyze_model(cleaned_text, 'general'))
157
-
158
- # Crypto model
159
- if 'crypto' in self.sentiment_models:
160
- model_results.append(self._analyze_model(cleaned_text, 'crypto'))
161
-
162
- # Rule-based models
163
- if self.vader_analyzer:
164
- model_results.append(self._analyze_vader(cleaned_text))
165
-
166
- model_results.append(self._analyze_textblob(cleaned_text))
167
-
168
- # Filter valid results
169
- valid_results = [r for r in model_results if r['score'] is not None]
170
-
171
- if not valid_results:
172
- return self._default_sentiment()
173
-
174
- # Weighted combination (prioritize financial/crypto models)
175
- weights = {
176
- 'financial': 0.35, 'crypto': 0.30, 'general': 0.20,
177
- 'vader': 0.10, 'textblob': 0.05
178
- }
179
-
180
- weighted_score = 0.0
181
- total_weight = 0.0
182
- confidences = []
183
-
184
- for result in valid_results:
185
- model_type = result.get('model_type', 'unknown')
186
- weight = weights.get(model_type, 0.1)
187
- weighted_score += result['score'] * weight
188
- total_weight += weight
189
- if 'confidence' in result:
190
- confidences.append(result['confidence'])
191
-
192
- if total_weight > 0:
193
- final_score = weighted_score / total_weight
194
- final_confidence = np.mean(confidences) if confidences else 0.0
195
- else:
196
- final_score = 0.5
197
- final_confidence = 0.0
198
-
199
- # Determine sentiment label
200
- sentiment_label = self._score_to_label(final_score)
201
-
202
- result = {
203
- "sentiment": sentiment_label,
204
- "score": float(final_score),
205
- "confidence": float(final_confidence),
206
- "urgency": self._detect_urgency(cleaned_text),
207
- "keywords": self._extract_keywords(cleaned_text),
208
- "models_used": len(valid_results),
209
- "text_snippet": cleaned_text[:100] + "..." if len(cleaned_text) > 100 else cleaned_text
210
- }
211
-
212
- # Cache result
213
- self.cache[cache_key] = result
214
- if len(self.cache) > 50: # Limit cache size
215
- self.cache.pop(next(iter(self.cache)))
216
-
217
- return result
218
-
219
- except Exception as e:
220
- logger.error(f"Error in sentiment analysis: {e}")
221
- return self._default_sentiment()
222
-
223
- def _analyze_model(self, text: str, model_type: str) -> Dict:
224
- """Generic model analysis with error handling"""
225
- try:
226
- model = self.sentiment_models[model_type]
227
- result = model(text[:512], truncation=True, max_length=512)[0] # Limit text length
228
-
229
- score_map = {
230
- 'negative': 0.0, 'NEGATIVE': 0.0,
231
- 'neutral': 0.5, 'NEUTRAL': 0.5,
232
- 'positive': 1.0, 'POSITIVE': 1.0
233
- }
234
-
235
- score = score_map.get(result['label'].upper(), 0.5)
236
- return {
237
- 'score': score,
238
- 'confidence': result['score'],
239
- 'model_type': model_type
240
- }
241
- except Exception as e:
242
- logger.debug(f"Model {model_type} failed: {e}")
243
- return {'score': None, 'confidence': 0.0, 'model_type': model_type}
244
-
245
- def _score_to_label(self, score: float) -> str:
246
- """Convert score to sentiment label"""
247
- if score > 0.6:
248
- return "bullish"
249
- elif score > 0.4:
250
- return "neutral"
251
- else:
252
- return "bearish"
253
-
254
- def _analyze_vader(self, text: str) -> Dict:
255
- """VADER analysis with error handling"""
256
- if not self.vader_analyzer:
257
- return {'score': None, 'confidence': 0.0, 'model_type': 'vader'}
258
-
259
- try:
260
- scores = self.vader_analyzer.polarity_scores(text)
261
- compound = (scores['compound'] + 1) / 2 # Normalize to 0-1
262
- return {
263
- 'score': compound,
264
- 'confidence': abs(scores['compound']),
265
- 'model_type': 'vader'
266
- }
267
- except Exception:
268
- return {'score': None, 'confidence': 0.0, 'model_type': 'vader'}
269
-
270
- def _analyze_textblob(self, text: str) -> Dict:
271
- """TextBlob analysis with error handling"""
272
- try:
273
- analysis = TextBlob(text)
274
- polarity = (analysis.sentiment.polarity + 1) / 2 # Normalize to 0-1
275
- return {
276
- 'score': polarity,
277
- 'confidence': abs(analysis.sentiment.polarity),
278
- 'model_type': 'textblob'
279
- }
280
- except Exception:
281
- return {'score': None, 'confidence': 0.0, 'model_type': 'textblob'}
282
-
283
- def _clean_text(self, text: str) -> str:
284
- """Enhanced text cleaning"""
285
- try:
286
- # Remove URLs
287
- text = re.sub(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', '', text)
288
- # Remove mentions
289
- text = re.sub(r'@\w+', '', text)
290
- # Remove hashtags but keep text
291
- text = re.sub(r'#\w+', '', text)
292
- # Remove extra whitespace and normalize
293
- text = ' '.join(text.split())
294
- return text.strip()
295
- except:
296
- return text[:200] if len(text) > 200 else text
297
-
298
- def _extract_keywords(self, text: str) -> List[str]:
299
- """Extract financial keywords with better matching"""
300
- keyword_categories = {
301
- 'bullish': ['moon', 'rocket', 'bull', 'buy', 'long', 'growth', 'opportunity', 'bullrun'],
302
- 'bearish': ['crash', 'bear', 'sell', 'short', 'drop', 'dump', 'warning', 'risk', 'fud'],
303
- 'crypto': ['bitcoin', 'btc', 'ethereum', 'eth', 'crypto', 'blockchain', 'defi', 'nft'],
304
- 'urgency': ['now', 'urgent', 'immediately', 'alert', 'breaking', 'huge']
305
- }
306
-
307
- found = []
308
- text_lower = text.lower()
309
-
310
- for category, keywords in keyword_categories.items():
311
- for keyword in keywords:
312
- if re.search(rf'\b{keyword}\b', text_lower):
313
- found.append(f"{category}:{keyword}")
314
-
315
- return found[:5]
316
-
317
- def _detect_urgency(self, text: str) -> float:
318
- """Improved urgency detection"""
319
- urgency_indicators = ['!', 'urgent', 'breaking', 'alert', 'immediately', 'now', 'huge', 'massive']
320
- text_lower = text.lower()
321
-
322
- score = 0.0
323
- for indicator in urgency_indicators:
324
- if re.search(rf'\b{indicator}\b', text_lower):
325
- score += 0.15
326
-
327
- # Exclamation and question marks
328
- punctuation_count = text.count('!') + text.count('?')
329
- score += min(punctuation_count * 0.1, 0.3)
330
-
331
- # Caps lock indicator
332
- caps_ratio = sum(1 for c in text if c.isupper()) / len([c for c in text if c.isalpha()])
333
- score += min(caps_ratio * 0.5, 0.2)
334
-
335
- return min(score, 1.0)
336
-
337
- def _default_sentiment(self) -> Dict:
338
- """Safe default sentiment"""
339
- return {
340
- "sentiment": "neutral",
341
- "score": 0.5,
342
- "confidence": 0.0,
343
- "urgency": 0.0,
344
- "keywords": [],
345
- "models_used": 0,
346
- "text_snippet": ""
347
- }
348
-
349
- def get_influencer_sentiment(self, hours_back: int = 24) -> Dict:
350
- """Get weighted influencer sentiment with caching"""
351
- try:
352
- # Generate synthetic tweets (in production, replace with real API)
353
- tweets = self._generate_synthetic_tweets(hours_back)
354
- influencer_sentiments = {}
355
-
356
- for username, tweet_batch in tweets.items():
357
- if username not in self.influencers:
358
- continue
359
-
360
- tweet_sentiments = []
361
- for tweet in tweet_batch:
362
- sentiment = self.analyze_text_sentiment(tweet['text'])
363
- sentiment.update({
364
- 'timestamp': tweet['timestamp'],
365
- 'username': username
366
- })
367
- tweet_sentiments.append(sentiment)
368
-
369
- if tweet_sentiments:
370
- # Weighted average by confidence
371
- total_weighted = sum(s['score'] * s['confidence'] for s in tweet_sentiments)
372
- total_confidence = sum(s['confidence'] for s in tweet_sentiments)
373
-
374
- avg_score = total_weighted / total_confidence if total_confidence > 0 else 0.5
375
- avg_confidence = np.mean([s['confidence'] for s in tweet_sentiments])
376
-
377
- influencer_sentiments[username] = {
378
- 'score': float(avg_score),
379
- 'confidence': float(avg_confidence),
380
- 'weight': self.influencers[username]['weight'],
381
- 'tweet_count': len(tweet_sentiments),
382
- 'tweets': tweet_sentiments[:3]
383
- }
384
-
385
- # Calculate market sentiment
386
- if influencer_sentiments:
387
- total_weighted_score = sum(
388
- data['score'] * data['weight'] * data['confidence']
389
- for data in influencer_sentiments.values()
390
- )
391
- total_weight = sum(
392
- data['weight'] * data['confidence']
393
- for data in influencer_sentiments.values()
394
- )
395
-
396
- market_sentiment = (total_weighted_score / total_weight
397
- if total_weight > 0 else 0.5)
398
- avg_confidence = np.mean([d['confidence'] for d in influencer_sentiments.values()])
399
- else:
400
- market_sentiment = 0.5
401
- avg_confidence = 0.0
402
-
403
- return {
404
- "market_sentiment": float(market_sentiment),
405
- "confidence": float(avg_confidence),
406
- "influencer_count": len(influencer_sentiments),
407
- "total_tweets": sum(d['tweet_count'] for d in influencer_sentiments.values()),
408
- "timestamp": datetime.now().isoformat(),
409
- "influencers": influencer_sentiments
410
- }
411
-
412
- except Exception as e:
413
- logger.error(f"Error in get_influencer_sentiment: {e}")
414
- return {
415
- "market_sentiment": 0.5,
416
- "confidence": 0.0,
417
- "error": str(e),
418
- "timestamp": datetime.now().isoformat()
419
- }
420
-
421
- def _generate_synthetic_tweets(self, hours_back: int) -> Dict:
422
- """Generate realistic synthetic tweets for testing"""
423
- current_time = time.time()
424
- tweets = {}
425
- np.random.seed(int(current_time) % 10000) # Reproducible randomness
426
-
427
- # Simulate market conditions
428
- market_trend = np.sin(current_time / 3600) * 0.3 + 0.5
429
-
430
- for username in self.influencers:
431
- user_tweets = []
432
- base_sentiment = np.clip(market_trend + np.random.normal(0, 0.15), 0.1, 0.9)
433
-
434
- templates = self._get_user_templates(username, base_sentiment)
435
-
436
- for i in range(np.random.randint(1, 4)): # 1-3 tweets
437
- template = np.random.choice(templates)
438
- tweet_text = template.format(**self._get_template_vars(base_sentiment))
439
-
440
- # Add emojis occasionally
441
- if np.random.random() < 0.4:
442
- emojis = self._get_relevant_emojis(base_sentiment)
443
- tweet_text += " " + np.random.choice(emojis)
444
-
445
- user_tweets.append({
446
- 'text': tweet_text,
447
- 'timestamp': current_time - (i * 3600 * np.random.uniform(0.5, hours_back))
448
- })
449
-
450
- tweets[username] = user_tweets
451
-
452
- return tweets
453
-
454
- def _get_user_templates(self, username: str, sentiment: float) -> List[str]:
455
- """Get appropriate templates based on sentiment"""
456
- templates = {
457
- 'bullish': [
458
- "{action} looking strong! {emoji}",
459
- "Great {topic} developments ahead 🚀",
460
- "Bullish on {topic} {emoji}"
461
- ],
462
- 'bearish': [
463
- "Caution on {topic} {emoji}",
464
- "{action} facing challenges 📉",
465
- "Bearish signals for {topic}"
466
- ],
467
- 'neutral': [
468
- "Watching {topic} developments 👀",
469
- "{action} market update 📊",
470
- "Interesting {topic} news"
471
- ]
472
- }
473
-
474
- category = 'bullish' if sentiment > 0.6 else 'bearish' if sentiment < 0.4 else 'neutral'
475
- return templates[category]
476
-
477
- def _get_template_vars(self, sentiment: float) -> Dict:
478
- """Get variables for tweet templates"""
479
- topics = ['BTC', 'crypto', 'market', 'DeFi']
480
- actions = ['Bitcoin', 'ETH', 'market', 'altcoins']
481
-
482
- return {
483
- 'topic': np.random.choice(topics),
484
- 'action': np.random.choice(actions),
485
- 'emoji': np.random.choice(['📈', '📉', '🚀', '💎'])
486
- }
487
-
488
- def _get_relevant_emojis(self, sentiment: float) -> List[str]:
489
- """Get sentiment-relevant emojis"""
490
- if sentiment > 0.6:
491
- return ['🚀', '📈', '💎', '🔥']
492
- elif sentiment < 0.4:
493
- return ['📉', '😬', '⚠️', '💥']
494
- else:
495
- return ['📊', '👀', '🤔', '💭']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utils/config.py DELETED
@@ -1,290 +0,0 @@
1
- import json
2
- import os
3
- from typing import Dict, Any, Optional
4
- from dataclasses import dataclass, asdict, field
5
- from pathlib import Path
6
- import logging
7
-
8
- logger = logging.getLogger(__name__)
9
-
10
- @dataclass
11
- class TradingConfig:
12
- """Comprehensive trading configuration with validation and persistence"""
13
-
14
- # Environment settings
15
- initial_balance: float = 10000.0
16
- max_steps: int = 1000
17
- transaction_cost: float = 0.001
18
- risk_level: str = "Medium"
19
- asset_type: str = "Crypto"
20
-
21
- # AI Agent settings
22
- learning_rate: float = 0.001
23
- gamma: float = 0.99
24
- epsilon_start: float = 1.0
25
- epsilon_min: float = 0.01
26
- epsilon_decay: float = 0.9995
27
- memory_size: int = 10000
28
- batch_size: int = 32
29
- target_update_freq: int = 100
30
- gradient_clip: float = 1.0
31
-
32
- # Sentiment settings
33
- use_sentiment: bool = True
34
- sentiment_influence: float = 0.3
35
- sentiment_update_freq: int = 5
36
-
37
- # Visualization settings
38
- chart_width: int = 800
39
- chart_height: int = 600
40
- update_interval: int = 100
41
- enable_visualization: bool = True
42
-
43
- # Training settings
44
- max_episodes: int = 1000
45
- eval_episodes: int = 10
46
- eval_freq: int = 100
47
- save_freq: int = 500
48
- log_level: str = "INFO"
49
-
50
- # Paths
51
- model_dir: str = "models"
52
- log_dir: str = "logs"
53
- data_dir: str = "data"
54
-
55
- # Device settings
56
- use_cuda: bool = True
57
- device: str = "auto"
58
-
59
- def __post_init__(self):
60
- """Validate and initialize configuration"""
61
- self._validate()
62
- self._setup_paths()
63
- self._setup_device()
64
- self._setup_logging()
65
-
66
- def _validate(self):
67
- """Validate configuration parameters"""
68
- errors = []
69
-
70
- # Balance validation
71
- if self.initial_balance <= 0:
72
- errors.append("initial_balance must be positive")
73
-
74
- # Steps validation
75
- if self.max_steps <= 0:
76
- errors.append("max_steps must be positive")
77
-
78
- # Costs validation
79
- if not 0.0 <= self.transaction_cost <= 0.1:
80
- errors.append("transaction_cost should be between 0 and 0.1")
81
-
82
- # Learning rate validation
83
- if not 0.0001 <= self.learning_rate <= 0.1:
84
- errors.append("learning_rate should be between 0.0001 and 0.1")
85
-
86
- # Discount factor validation
87
- if not 0.0 <= self.gamma <= 1.0:
88
- errors.append("gamma must be between 0 and 1")
89
-
90
- # Epsilon validation
91
- if not 0.0 <= self.epsilon_min <= self.epsilon_start <= 1.0:
92
- errors.append("epsilon values must satisfy 0 <= epsilon_min <= epsilon_start <= 1")
93
-
94
- # Batch size validation
95
- if self.batch_size > self.memory_size:
96
- errors.append("batch_size cannot exceed memory_size")
97
-
98
- # Risk level validation
99
- valid_risks = ["Low", "Medium", "High"]
100
- if self.risk_level not in valid_risks:
101
- errors.append(f"risk_level must be one of {valid_risks}")
102
-
103
- # Asset type validation
104
- valid_assets = ["Crypto", "Stocks", "Forex", "Commodities"]
105
- if self.asset_type not in valid_assets:
106
- errors.append(f"asset_type must be one of {valid_assets}")
107
-
108
- # Sentiment influence validation
109
- if not 0.0 <= self.sentiment_influence <= 1.0:
110
- errors.append("sentiment_influence must be between 0 and 1")
111
-
112
- if errors:
113
- logger.error(f"Configuration validation errors: {errors}")
114
- raise ValueError(f"Invalid configuration: {'; '.join(errors)}")
115
-
116
- logger.info("Configuration validation passed")
117
-
118
- def _setup_paths(self):
119
- """Create necessary directories"""
120
- for path_attr in ['model_dir', 'log_dir', 'data_dir']:
121
- path = Path(getattr(self, path_attr))
122
- path.mkdir(parents=True, exist_ok=True)
123
- setattr(self, f"{path_attr}_path", path)
124
-
125
- def _setup_device(self):
126
- """Setup device configuration"""
127
- import torch
128
- if self.device == "auto":
129
- self.device = "cuda" if self.use_cuda and torch.cuda.is_available() else "cpu"
130
- else:
131
- if self.device not in ["cpu", "cuda", "mps"]:
132
- logger.warning(f"Unknown device {self.device}, defaulting to CPU")
133
- self.device = "cpu"
134
-
135
- logger.info(f"Using device: {self.device}")
136
-
137
- def _setup_logging(self):
138
- """Setup logging configuration"""
139
- import logging
140
- log_level = getattr(logging, self.log_level.upper())
141
- logging.getLogger().setLevel(log_level)
142
-
143
- def to_dict(self) -> Dict[str, Any]:
144
- """Convert config to dictionary, excluding sensitive paths"""
145
- config_dict = asdict(self)
146
- # Remove absolute paths for serialization
147
- for key in list(config_dict.keys()):
148
- if key.endswith('_path') or 'dir' in key:
149
- config_dict[key] = str(getattr(self, key)) if isinstance(getattr(self, key), Path) else getattr(self, key)
150
- return config_dict
151
-
152
- def to_json(self, filepath: Optional[str] = None) -> str:
153
- """Serialize config to JSON"""
154
- config_dict = self.to_dict()
155
- json_str = json.dumps(config_dict, indent=2, default=str)
156
-
157
- if filepath:
158
- with open(filepath, 'w') as f:
159
- f.write(json_str)
160
- logger.info(f"Config saved to {filepath}")
161
-
162
- return json_str
163
-
164
- @classmethod
165
- def from_json(cls, filepath: str) -> 'TradingConfig':
166
- """Load config from JSON file"""
167
- try:
168
- with open(filepath, 'r') as f:
169
- config_dict = json.load(f)
170
-
171
- # Create dataclass instance
172
- config = cls(**config_dict)
173
- logger.info(f"Config loaded from {filepath}")
174
- return config
175
- except Exception as e:
176
- logger.error(f"Error loading config from {filepath}: {e}")
177
- raise
178
-
179
- @classmethod
180
- def from_dict(cls, config_dict: Dict[str, Any]) -> 'TradingConfig':
181
- """Create config from dictionary"""
182
- return cls(**config_dict)
183
-
184
- def save(self, filepath: str):
185
- """Save config to file"""
186
- self.to_json(filepath)
187
-
188
- @staticmethod
189
- def load(filepath: str) -> 'TradingConfig':
190
- """Static method to load config"""
191
- return TradingConfig.from_json(filepath)
192
-
193
- def update(self, **kwargs):
194
- """Update config parameters and revalidate"""
195
- for key, value in kwargs.items():
196
- if hasattr(self, key):
197
- setattr(self, key, value)
198
- else:
199
- logger.warning(f"Unknown config parameter: {key}")
200
-
201
- self._validate()
202
- logger.info("Config updated and validated")
203
-
204
- def get_agent_params(self) -> Dict[str, Any]:
205
- """Get parameters specific to agent"""
206
- return {
207
- 'learning_rate': self.learning_rate,
208
- 'gamma': self.gamma,
209
- 'epsilon_start': self.epsilon_start,
210
- 'epsilon_min': self.epsilon_min,
211
- 'epsilon_decay': self.epsilon_decay,
212
- 'memory_size': self.memory_size,
213
- 'batch_size': self.batch_size,
214
- 'target_update_freq': self.target_update_freq,
215
- 'gradient_clip': self.gradient_clip,
216
- 'device': self.device
217
- }
218
-
219
- def get_env_params(self) -> Dict[str, Any]:
220
- """Get parameters specific to environment"""
221
- return {
222
- 'initial_balance': self.initial_balance,
223
- 'max_steps': self.max_steps,
224
- 'transaction_cost': self.transaction_cost,
225
- 'risk_level': self.risk_level,
226
- 'asset_type': self.asset_type,
227
- 'use_sentiment': self.use_sentiment,
228
- 'sentiment_influence': self.sentiment_influence,
229
- 'sentiment_update_freq': self.sentiment_update_freq
230
- }
231
-
232
- def __str__(self) -> str:
233
- """String representation of config"""
234
- return json.dumps(self.to_dict(), indent=2)
235
-
236
-
237
- # Legacy compatibility
238
- class LegacyTradingConfig:
239
- """Wrapper for backward compatibility"""
240
-
241
- def __init__(self, config_file: Optional[str] = None):
242
- if config_file and os.path.exists(config_file):
243
- self.config = TradingConfig.from_json(config_file)
244
- else:
245
- self.config = TradingConfig()
246
-
247
- def __getattr__(self, name):
248
- return getattr(self.config, name)
249
-
250
- def to_dict(self):
251
- return self.config.to_dict()
252
-
253
-
254
- # Default config instance
255
- DEFAULT_CONFIG = TradingConfig()
256
-
257
- # Example usage and config loading
258
- def create_config_from_env() -> TradingConfig:
259
- """Create config from environment variables"""
260
- import os
261
- config_dict = {}
262
-
263
- env_mappings = {
264
- 'INITIAL_BALANCE': 'initial_balance',
265
- 'MAX_STEPS': 'max_steps',
266
- 'LEARNING_RATE': 'learning_rate',
267
- 'BATCH_SIZE': 'batch_size',
268
- 'USE_CUDA': 'use_cuda'
269
- }
270
-
271
- for env_var, config_key in env_mappings.items():
272
- env_value = os.getenv(env_var)
273
- if env_value is not None:
274
- try:
275
- # Try to convert to appropriate type
276
- if config_key in ['initial_balance', 'learning_rate']:
277
- config_dict[config_key] = float(env_value)
278
- elif config_key in ['max_steps', 'batch_size']:
279
- config_dict[config_key] = int(env_value)
280
- elif config_key == 'use_cuda':
281
- config_dict[config_key] = env_value.lower() in ('true', '1', 'yes')
282
- except ValueError:
283
- logger.warning(f"Invalid environment variable {env_var}: {env_value}")
284
-
285
- if config_dict:
286
- base_config = TradingConfig()
287
- base_config.update(**config_dict)
288
- return base_config
289
-
290
- return DEFAULT_CONFIG
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/visualizers/chart_renderer.py DELETED
@@ -1,410 +0,0 @@
1
- import plotly.graph_objects as go
2
- from plotly.subplots import make_subplots
3
- import plotly.express as px
4
- import numpy as np
5
- import pandas as pd
6
- from typing import List, Dict, Any, Optional, Union
7
- import logging
8
- from datetime import datetime
9
- import warnings
10
- warnings.filterwarnings('ignore')
11
-
12
- logger = logging.getLogger(__name__)
13
-
14
- class ChartRenderer:
15
- """Advanced chart renderer for trading visualizations with error handling"""
16
-
17
- def __init__(self, theme: str = "plotly_white", default_height: int = 400):
18
- self.theme = theme
19
- self.default_height = default_height
20
- self._validate_plotly()
21
-
22
- def _validate_plotly(self):
23
- """Validate Plotly installation and capabilities"""
24
- try:
25
- import plotly
26
- logger.info(f"Plotly version: {plotly.__version__}")
27
- except ImportError:
28
- raise ImportError("Plotly is required for ChartRenderer")
29
-
30
- def _safe_data_validation(self, data, expected_len: Optional[int] = None,
31
- data_type: str = "data") -> bool:
32
- """Validate input data safely"""
33
- if data is None or len(data) == 0:
34
- logger.warning(f"No {data_type} provided")
35
- return False
36
-
37
- if expected_len and len(data) != expected_len:
38
- logger.warning(f"{data_type} length mismatch: expected {expected_len}, got {len(data)}")
39
-
40
- if isinstance(data, (list, np.ndarray)):
41
- if np.any(np.isnan(data)) or np.any(np.isinf(data)):
42
- logger.warning(f"{data_type} contains NaN or Inf values")
43
- return False
44
-
45
- return True
46
-
47
- def render_price_chart(self, prices: Union[List[float], np.ndarray],
48
- actions: Optional[List[int]] = None,
49
- current_step: int = 0,
50
- title: Optional[str] = None,
51
- height: Optional[int] = None) -> go.Figure:
52
- """Render interactive price chart with trading actions"""
53
- fig = go.Figure()
54
- height = height or self.default_height
55
-
56
- # Validate data
57
- if not self._safe_data_validation(prices, data_type="prices"):
58
- return self._create_empty_figure("No Price Data", height)
59
-
60
- try:
61
- # Convert to numpy for consistency
62
- prices = np.array(prices, dtype=np.float64)
63
- time_steps = np.arange(len(prices))
64
-
65
- # Add main price trace
66
- fig.add_trace(go.Scatter(
67
- x=time_steps,
68
- y=prices,
69
- mode='lines',
70
- name='Price',
71
- line=dict(color='#1f77b4', width=2),
72
- hovertemplate='<b>Step %{x}</b><br>Price: $%{y:.2f}<extra></extra>'
73
- ))
74
-
75
- # Add action markers with validation
76
- if actions and self._safe_data_validation(actions, len(prices), "actions"):
77
- self._add_action_markers(fig, prices, actions, time_steps)
78
-
79
- # Add current step indicator
80
- if 0 <= current_step < len(prices):
81
- fig.add_vline(
82
- x=current_step,
83
- line_dash="dash",
84
- line_color="orange",
85
- annotation_text=f"Current Step ({current_step})",
86
- annotation_position="top right"
87
- )
88
-
89
- # Calculate and add key metrics
90
- self._add_price_metrics(fig, prices)
91
-
92
- title = title or f"Asset Price Evolution (Step: {current_step})"
93
- fig.update_layout(
94
- title={
95
- 'text': title,
96
- 'x': 0.5,
97
- 'xanchor': 'center',
98
- 'font': {'size': 16}
99
- },
100
- xaxis_title="Time Step",
101
- yaxis_title="Price ($)",
102
- height=height + 100,
103
- showlegend=True,
104
- template=self.theme,
105
- hovermode='x unified'
106
- )
107
-
108
- fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
109
- fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
110
-
111
- return fig
112
-
113
- except Exception as e:
114
- logger.error(f"Error rendering price chart: {e}")
115
- return self._create_empty_figure("Error Rendering Price Chart", height)
116
-
117
- def _add_action_markers(self, fig: go.Figure, prices: np.ndarray,
118
- actions: List[int], time_steps: np.ndarray):
119
- """Add buy/sell/close action markers to figure"""
120
- action_configs = {
121
- 1: {'name': 'Buy', 'color': '#2ca02c', 'symbol': 'triangle-up'},
122
- 2: {'name': 'Sell', 'color': '#d62728', 'symbol': 'triangle-down'},
123
- 3: {'name': 'Close', 'color': '#ff7f0e', 'symbol': 'x'}
124
- }
125
-
126
- for action_id, config in action_configs.items():
127
- indices = [i for i, a in enumerate(actions) if a == action_id]
128
- if indices:
129
- action_prices = prices[indices]
130
- fig.add_trace(go.Scatter(
131
- x=[time_steps[i] for i in indices],
132
- y=action_prices,
133
- mode='markers',
134
- name=config['name'],
135
- marker=dict(
136
- color=config['color'],
137
- size=12,
138
- symbol=config['symbol'],
139
- line=dict(width=2, color='white')
140
- ),
141
- hovertemplate=f'<b>{config["name"]}</b><br>Step: %{{x}}<br>Price: $%{{y:.2f}}<extra></extra>',
142
- showlegend=True
143
- ))
144
-
145
- def _add_price_metrics(self, fig: go.Figure, prices: np.ndarray):
146
- """Add price statistics as annotations"""
147
- if len(prices) < 2:
148
- return
149
-
150
- max_price = np.max(prices)
151
- min_price = np.min(prices)
152
- avg_price = np.mean(prices)
153
-
154
- # Add horizontal reference lines
155
- fig.add_hline(y=max_price, line_dash="dot", line_color="green",
156
- annotation_text=f"Max: ${max_price:.2f}")
157
- fig.add_hline(y=min_price, line_dash="dot", line_color="red",
158
- annotation_text=f"Min: ${min_price:.2f}")
159
- fig.add_hline(y=avg_price, line_dash="dash", line_color="blue",
160
- annotation_text=f"Avg: ${avg_price:.2f}")
161
-
162
- def create_performance_chart(self, net_worth_history: List[float],
163
- reward_history: Optional[List[float]] = None,
164
- initial_balance: float = 10000,
165
- height: Optional[int] = None) -> go.Figure:
166
- """Create comprehensive performance dashboard"""
167
- height = height or 600
168
-
169
- if not self._safe_data_validation(net_worth_history, data_type="net worth history"):
170
- return self._create_empty_figure("No Performance Data", height)
171
-
172
- try:
173
- fig = make_subplots(
174
- rows=2, cols=2,
175
- subplot_titles=['Portfolio Value', 'Returns vs Initial Balance',
176
- 'Cumulative Reward', 'Reward Distribution'],
177
- vertical_spacing=0.1,
178
- horizontal_spacing=0.1,
179
- specs=[[{"secondary_y": False}, {"secondary_y": False}],
180
- [{"secondary_y": False}, {"secondary_y": False}]]
181
- )
182
-
183
- steps = np.arange(len(net_worth_history))
184
- net_worth = np.array(net_worth_history)
185
-
186
- # Portfolio value
187
- fig.add_trace(
188
- go.Scatter(x=steps, y=net_worth, mode='lines', name='Net Worth',
189
- line=dict(color='#2ca02c', width=3)),
190
- row=1, col=1
191
- )
192
-
193
- # Initial balance reference
194
- fig.add_hline(y=initial_balance, line_dash="dash", line_color="red",
195
- annotation_text=f"Initial: ${initial_balance:.2f}",
196
- row=1, col=1)
197
-
198
- # Returns comparison
199
- returns = (net_worth - initial_balance) / initial_balance * 100
200
- fig.add_trace(
201
- go.Scatter(x=steps, y=returns, mode='lines', name='Returns %',
202
- line=dict(color='#ff7f0e', width=2)),
203
- row=1, col=2
204
- )
205
- fig.add_hline(y=0, line_dash="solid", line_color="gray", row=1, col=2)
206
-
207
- # Cumulative reward
208
- if reward_history and self._safe_data_validation(reward_history):
209
- cum_reward = np.cumsum(reward_history)
210
- fig.add_trace(
211
- go.Scatter(x=steps[:len(cum_reward)], y=cum_reward, mode='lines',
212
- name='Cumulative Reward', line=dict(color='#9467bd', width=2)),
213
- row=2, col=1
214
- )
215
-
216
- # Reward distribution
217
- if reward_history:
218
- fig.add_trace(
219
- go.Histogram(x=reward_history, name='Reward Distribution',
220
- marker_color='#1f77b4', opacity=0.7),
221
- row=2, col=2
222
- )
223
-
224
- fig.update_layout(
225
- height=height,
226
- showlegend=True,
227
- title_text="Trading Performance Dashboard",
228
- template=self.theme
229
- )
230
-
231
- # Update axis titles
232
- fig.update_yaxes(title_text="Value ($)", row=1, col=1)
233
- fig.update_yaxes(title_text="Returns (%)", row=1, col=2)
234
- fig.update_yaxes(title_text="Cumulative Reward", row=2, col=1)
235
- fig.update_xaxes(title_text="Steps", row=2, col=1)
236
- fig.update_xaxes(title_text="Reward Value", row=2, col=2)
237
-
238
- return fig
239
-
240
- except Exception as e:
241
- logger.error(f"Error creating performance chart: {e}")
242
- return self._create_empty_figure("Error in Performance Chart", height)
243
-
244
- def create_action_distribution(self, actions: List[int],
245
- title: Optional[str] = None,
246
- height: Optional[int] = None) -> go.Figure:
247
- """Create interactive action distribution visualization"""
248
- height = height or 350
249
-
250
- if not self._safe_data_validation(actions, data_type="actions"):
251
- return self._create_empty_figure("No Actions Data", height)
252
-
253
- try:
254
- action_names = ['Hold', 'Buy', 'Sell', 'Close']
255
- action_counts = [actions.count(i) for i in range(4)]
256
- total_actions = sum(action_counts)
257
-
258
- colors = ['#1f77b4', '#2ca02c', '#d62728', '#ff7f0e']
259
-
260
- fig = go.Figure(data=[go.Pie(
261
- labels=action_names,
262
- values=action_counts,
263
- hole=0.4,
264
- marker_colors=colors,
265
- textinfo='label+percent+value',
266
- hovertemplate='<b>%{label}</b><br>Count: %{value}<br>Percentage: %{percent}<extra></extra>',
267
- pull=[0, 0, 0, 0] # Equal spacing
268
- )])
269
-
270
- title = title or f"Action Distribution (Total: {total_actions} actions)"
271
- fig.update_layout(
272
- title={
273
- 'text': title,
274
- 'x': 0.5,
275
- 'xanchor': 'center'
276
- },
277
- height=height,
278
- showlegend=True,
279
- template=self.theme,
280
- annotations=[dict(
281
- text='Trading Actions',
282
- x=0.5, y=0.5,
283
- font_size=16,
284
- showarrow=False
285
- )]
286
- )
287
-
288
- return fig
289
-
290
- except Exception as e:
291
- logger.error(f"Error creating action distribution: {e}")
292
- return self._create_empty_figure("Error in Action Distribution", height)
293
-
294
- def create_training_progress(self, training_history: List[Dict],
295
- window_size: int = 10,
296
- height: Optional[int] = None) -> go.Figure:
297
- """Create comprehensive training progress dashboard"""
298
- height = height or 700
299
-
300
- if not training_history:
301
- return self._create_empty_figure("No Training Data", height)
302
-
303
- try:
304
- # Extract data safely
305
- episodes = [h.get('episode', i) for i, h in enumerate(training_history)]
306
- rewards = [h.get('reward', 0) for h in training_history]
307
- net_worths = [h.get('net_worth', 0) for h in training_history]
308
- losses = [h.get('loss', 0) for h in training_history]
309
-
310
- fig = make_subplots(
311
- rows=2, cols=2,
312
- subplot_titles=['Total Reward per Episode', 'Final Net Worth',
313
- 'Training Loss', 'Moving Average Reward'],
314
- specs=[[{"secondary_y": False}, {"secondary_y": False}],
315
- [{"secondary_y": False}, {"secondary_y": False}]]
316
- )
317
-
318
- # Rewards
319
- fig.add_trace(go.Scatter(
320
- x=episodes, y=rewards, mode='lines+markers',
321
- name='Episode Reward', line=dict(color='#1f77b4', width=2),
322
- marker=dict(size=4)
323
- ), row=1, col=1)
324
-
325
- # Net worth
326
- fig.add_trace(go.Scatter(
327
- x=episodes, y=net_worths, mode='lines+markers',
328
- name='Final Net Worth', line=dict(color='#2ca02c', width=2),
329
- marker=dict(size=4)
330
- ), row=1, col=2)
331
-
332
- # Loss (only if we have meaningful loss values)
333
- valid_losses = [l for l in losses if l > 0]
334
- if valid_losses:
335
- fig.add_trace(go.Scatter(
336
- x=episodes, y=losses, mode='lines',
337
- name='Training Loss', line=dict(color='#d62728', width=2)
338
- ), row=2, col=1)
339
-
340
- # Moving average
341
- if len(rewards) >= window_size:
342
- ma_rewards = pd.Series(rewards).rolling(window=window_size, min_periods=1).mean()
343
- fig.add_trace(go.Scatter(
344
- x=episodes, y=ma_rewards, mode='lines',
345
- name=f'MA Reward ({window_size})',
346
- line=dict(color='#ff7f0e', width=3, dash='dash')
347
- ), row=2, col=2)
348
-
349
- fig.update_layout(
350
- height=height,
351
- showlegend=True,
352
- title_text=f"Training Progress - {len(episodes)} Episodes",
353
- template=self.theme
354
- )
355
-
356
- # Update axes
357
- fig.update_yaxes(title_text="Reward", row=1, col=1)
358
- fig.update_yaxes(title_text="Net Worth ($)", row=1, col=2)
359
- fig.update_yaxes(title_text="Loss", row=2, col=1)
360
- fig.update_xaxes(title_text="Episodes", row=2, col=1)
361
-
362
- return fig
363
-
364
- except Exception as e:
365
- logger.error(f"Error creating training progress chart: {e}")
366
- return self._create_empty_figure("Error in Training Progress", height)
367
-
368
- def _create_empty_figure(self, title: str, height: int) -> go.Figure:
369
- """Create a safe empty figure"""
370
- fig = go.Figure()
371
- fig.update_layout(
372
- title=title,
373
- height=height,
374
- template=self.theme
375
- )
376
- return fig
377
-
378
- def save_chart(self, fig: go.Figure, filename: str, format: str = 'html'):
379
- """Save chart to file"""
380
- try:
381
- if format == 'html':
382
- fig.write_html(filename)
383
- elif format == 'png':
384
- fig.write_image(filename)
385
- elif format == 'pdf':
386
- fig.write_image(filename, width=1200, height=800)
387
- logger.info(f"Chart saved as {filename}")
388
- except Exception as e:
389
- logger.error(f"Error saving chart: {e}")
390
-
391
- def show(self, fig: go.Figure):
392
- """Display chart (if in interactive environment)"""
393
- try:
394
- fig.show()
395
- except Exception as e:
396
- logger.warning(f"Could not display chart: {e}")
397
-
398
-
399
- # Utility functions for batch rendering
400
- def render_dashboard(prices, actions, net_worth, rewards, config):
401
- """Create a complete trading dashboard"""
402
- renderer = ChartRenderer()
403
-
404
- figs = {
405
- 'price': renderer.render_price_chart(prices, actions),
406
- 'performance': renderer.create_performance_chart(net_worth, rewards, config.initial_balance),
407
- 'actions': renderer.create_action_distribution(actions)
408
- }
409
-
410
- return figs