Spaces:
Sleeping
Sleeping
Delete src
Browse files- src/agents/advanced_agent.py +0 -310
- src/agents/visual_agent.py +0 -153
- src/environments/advanced_trading_env.py +0 -293
- src/environments/visual_trading_env.py +0 -228
- src/sentiment/twitter_analyzer.py +0 -495
- src/utils/config.py +0 -290
- src/visualizers/chart_renderer.py +0 -410
src/agents/advanced_agent.py
DELETED
|
@@ -1,310 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
import torch.optim as optim
|
| 4 |
-
import numpy as np
|
| 5 |
-
from collections import deque
|
| 6 |
-
import random
|
| 7 |
-
import warnings
|
| 8 |
-
warnings.filterwarnings('ignore')
|
| 9 |
-
|
| 10 |
-
class EnhancedTradingNetwork(nn.Module):
|
| 11 |
-
def __init__(self, state_dim, action_dim, sentiment_dim=2):
|
| 12 |
-
super(EnhancedTradingNetwork, self).__init__()
|
| 13 |
-
|
| 14 |
-
# Visual processing branch
|
| 15 |
-
self.visual_conv = nn.Sequential(
|
| 16 |
-
nn.Conv2d(4, 16, kernel_size=4, stride=2),
|
| 17 |
-
nn.ReLU(),
|
| 18 |
-
nn.Conv2d(16, 32, kernel_size=4, stride=2),
|
| 19 |
-
nn.ReLU(),
|
| 20 |
-
nn.Conv2d(32, 32, kernel_size=3, stride=1),
|
| 21 |
-
nn.ReLU(),
|
| 22 |
-
nn.AdaptiveAvgPool2d((8, 8))
|
| 23 |
-
)
|
| 24 |
-
|
| 25 |
-
# Calculate the output size after conv layers
|
| 26 |
-
self.conv_output_size = 32 * 8 * 8
|
| 27 |
-
|
| 28 |
-
self.visual_fc = nn.Sequential(
|
| 29 |
-
nn.Linear(self.conv_output_size, 256),
|
| 30 |
-
nn.ReLU(),
|
| 31 |
-
nn.Dropout(0.3)
|
| 32 |
-
)
|
| 33 |
-
|
| 34 |
-
# Sentiment processing branch
|
| 35 |
-
self.sentiment_fc = nn.Sequential(
|
| 36 |
-
nn.Linear(sentiment_dim, 64),
|
| 37 |
-
nn.ReLU(),
|
| 38 |
-
nn.Dropout(0.2),
|
| 39 |
-
nn.Linear(64, 32),
|
| 40 |
-
nn.ReLU()
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
# Combined decision making
|
| 44 |
-
self.combined_fc = nn.Sequential(
|
| 45 |
-
nn.Linear(256 + 32, 128),
|
| 46 |
-
nn.ReLU(),
|
| 47 |
-
nn.Dropout(0.2),
|
| 48 |
-
nn.Linear(128, 64),
|
| 49 |
-
nn.ReLU(),
|
| 50 |
-
nn.Linear(64, action_dim)
|
| 51 |
-
)
|
| 52 |
-
|
| 53 |
-
# Store action_dim for error handling
|
| 54 |
-
self.action_dim = action_dim
|
| 55 |
-
|
| 56 |
-
def forward(self, x, sentiment=None):
|
| 57 |
-
try:
|
| 58 |
-
# Ensure input has batch dimension
|
| 59 |
-
if len(x.shape) == 3: # (H, W, C)
|
| 60 |
-
x = x.unsqueeze(0)
|
| 61 |
-
elif len(x.shape) == 4: # (batch, H, W, C)
|
| 62 |
-
pass
|
| 63 |
-
else:
|
| 64 |
-
raise ValueError(f"Invalid input shape: {x.shape}")
|
| 65 |
-
|
| 66 |
-
# Permute to (batch, C, H, W)
|
| 67 |
-
x = x.permute(0, 3, 1, 2).contiguous().float()
|
| 68 |
-
|
| 69 |
-
# Check if channels match expected input
|
| 70 |
-
if x.size(1) != 4:
|
| 71 |
-
raise ValueError(f"Expected 4 channels, got {x.size(1)}")
|
| 72 |
-
|
| 73 |
-
visual_features = self.visual_conv(x)
|
| 74 |
-
batch_size = visual_features.size(0)
|
| 75 |
-
visual_features = visual_features.reshape(batch_size, -1)
|
| 76 |
-
visual_features = self.visual_fc(visual_features)
|
| 77 |
-
|
| 78 |
-
# Sentiment processing
|
| 79 |
-
if sentiment is not None and self.sentiment_fc is not None:
|
| 80 |
-
if len(sentiment.shape) == 1:
|
| 81 |
-
sentiment = sentiment.unsqueeze(0)
|
| 82 |
-
sentiment = sentiment.float()
|
| 83 |
-
sentiment_features = self.sentiment_fc(sentiment)
|
| 84 |
-
combined_features = torch.cat([visual_features, sentiment_features], dim=1)
|
| 85 |
-
else:
|
| 86 |
-
# Pad with zeros if no sentiment
|
| 87 |
-
sentiment_features = torch.zeros(batch_size, 32, device=visual_features.device)
|
| 88 |
-
combined_features = torch.cat([visual_features, sentiment_features], dim=1)
|
| 89 |
-
|
| 90 |
-
q_values = self.combined_fc(combined_features)
|
| 91 |
-
return q_values
|
| 92 |
-
|
| 93 |
-
except Exception as e:
|
| 94 |
-
print(f"Error in network forward: {e}")
|
| 95 |
-
print(f"Input shape: {getattr(x, 'shape', 'Unknown')}")
|
| 96 |
-
# Return safe default with correct shape
|
| 97 |
-
batch_size = x.size(0) if hasattr(x, 'size') else 1
|
| 98 |
-
return torch.zeros(batch_size, self.action_dim, device=(x.device if hasattr(x, 'device') else 'cpu'))
|
| 99 |
-
|
| 100 |
-
class AdvancedTradingAgent:
|
| 101 |
-
def __init__(self, state_dim, action_dim, learning_rate=0.001, use_sentiment=True):
|
| 102 |
-
self.state_dim = state_dim # Should be (84, 84, 4) or similar
|
| 103 |
-
self.action_dim = action_dim
|
| 104 |
-
self.learning_rate = learning_rate
|
| 105 |
-
self.use_sentiment = use_sentiment
|
| 106 |
-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 107 |
-
print(f"Using device: {self.device}")
|
| 108 |
-
|
| 109 |
-
# Neural network
|
| 110 |
-
self.policy_net = EnhancedTradingNetwork(state_dim, action_dim).to(self.device)
|
| 111 |
-
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
|
| 112 |
-
self.loss_fn = nn.MSELoss()
|
| 113 |
-
|
| 114 |
-
# Experience replay
|
| 115 |
-
self.memory = deque(maxlen=10000) # Increased buffer size
|
| 116 |
-
self.batch_size = min(32, state_dim[0]//2) # Dynamic batch size
|
| 117 |
-
|
| 118 |
-
# Training parameters
|
| 119 |
-
self.gamma = 0.99
|
| 120 |
-
self.epsilon = 1.0
|
| 121 |
-
self.epsilon_min = 0.01 # More aggressive exploration decay
|
| 122 |
-
self.epsilon_decay = 0.9995 # Slower decay
|
| 123 |
-
self.steps_done = 0
|
| 124 |
-
self.target_update_freq = 100 # Target network update frequency
|
| 125 |
-
self.steps_since_target_update = 0
|
| 126 |
-
|
| 127 |
-
def select_action(self, state, current_sentiment=None, sentiment_confidence=None):
|
| 128 |
-
"""Select action with epsilon-greedy policy"""
|
| 129 |
-
if random.random() < self.epsilon:
|
| 130 |
-
return random.randint(0, self.action_dim - 1)
|
| 131 |
-
|
| 132 |
-
try:
|
| 133 |
-
# Validate and normalize state
|
| 134 |
-
if not isinstance(state, np.ndarray):
|
| 135 |
-
state = np.array(state)
|
| 136 |
-
|
| 137 |
-
if state.dtype != np.float32:
|
| 138 |
-
state = state.astype(np.float32)
|
| 139 |
-
|
| 140 |
-
# Normalize pixel values
|
| 141 |
-
if state.max() > 1.0:
|
| 142 |
-
state = state / 255.0
|
| 143 |
-
|
| 144 |
-
state_tensor = torch.FloatTensor(state).to(self.device)
|
| 145 |
-
|
| 146 |
-
# Prepare sentiment input
|
| 147 |
-
if self.use_sentiment and current_sentiment is not None:
|
| 148 |
-
sentiment = np.array([float(current_sentiment), float(sentiment_confidence or 0.0)])
|
| 149 |
-
sentiment_tensor = torch.FloatTensor(sentiment).to(self.device)
|
| 150 |
-
with torch.no_grad():
|
| 151 |
-
q_values = self.policy_net(state_tensor, sentiment_tensor)
|
| 152 |
-
else:
|
| 153 |
-
with torch.no_grad():
|
| 154 |
-
q_values = self.policy_net(state_tensor)
|
| 155 |
-
|
| 156 |
-
action = int(q_values.argmax().item())
|
| 157 |
-
return action
|
| 158 |
-
|
| 159 |
-
except Exception as e:
|
| 160 |
-
print(f"Error in action selection: {e}")
|
| 161 |
-
return random.randint(0, self.action_dim - 1)
|
| 162 |
-
|
| 163 |
-
def store_transition(self, state, action, reward, next_state, done, sentiment_data=None):
|
| 164 |
-
"""Store experience tuple safely"""
|
| 165 |
-
try:
|
| 166 |
-
# Ensure all inputs are numpy arrays
|
| 167 |
-
if not isinstance(state, np.ndarray):
|
| 168 |
-
state = np.array(state, dtype=np.float32)
|
| 169 |
-
if not isinstance(next_state, np.ndarray):
|
| 170 |
-
next_state = np.array(next_state, dtype=np.float32)
|
| 171 |
-
|
| 172 |
-
# Normalize before storing
|
| 173 |
-
if state.max() > 1.0:
|
| 174 |
-
state = state / 255.0
|
| 175 |
-
if next_state.max() > 1.0:
|
| 176 |
-
next_state = next_state / 255.0
|
| 177 |
-
|
| 178 |
-
# Handle sentiment data
|
| 179 |
-
if sentiment_data is None:
|
| 180 |
-
sentiment_data = {'sentiment': 0.5, 'confidence': 0.0}
|
| 181 |
-
|
| 182 |
-
experience = (state, action, float(reward), next_state, bool(done), sentiment_data)
|
| 183 |
-
self.memory.append(experience)
|
| 184 |
-
|
| 185 |
-
except Exception as e:
|
| 186 |
-
print(f"Error storing transition: {e}")
|
| 187 |
-
|
| 188 |
-
def update(self):
|
| 189 |
-
"""DQN update with improved stability"""
|
| 190 |
-
if len(self.memory) < self.batch_size:
|
| 191 |
-
return 0.0
|
| 192 |
-
|
| 193 |
-
try:
|
| 194 |
-
batch = random.sample(self.memory, self.batch_size)
|
| 195 |
-
states, actions, rewards, next_states, dones, sentiments = zip(*batch)
|
| 196 |
-
|
| 197 |
-
# Convert to tensors
|
| 198 |
-
states = np.stack(states)
|
| 199 |
-
next_states = np.stack(next_states)
|
| 200 |
-
actions = np.array(actions)
|
| 201 |
-
rewards = np.array(rewards)
|
| 202 |
-
dones = np.array(dones)
|
| 203 |
-
|
| 204 |
-
states_tensor = torch.FloatTensor(states).to(self.device)
|
| 205 |
-
next_states_tensor = torch.FloatTensor(next_states).to(self.device)
|
| 206 |
-
actions_tensor = torch.LongTensor(actions).to(self.device)
|
| 207 |
-
rewards_tensor = torch.FloatTensor(rewards).to(self.device)
|
| 208 |
-
dones_tensor = torch.BoolTensor(dones).to(self.device)
|
| 209 |
-
|
| 210 |
-
# Compute current Q values
|
| 211 |
-
if self.use_sentiment:
|
| 212 |
-
# Use sentiment from current state
|
| 213 |
-
sentiment_batch = []
|
| 214 |
-
for sentiment_data in sentiments:
|
| 215 |
-
sentiment = [sentiment_data.get('sentiment', 0.5),
|
| 216 |
-
sentiment_data.get('confidence', 0.0)]
|
| 217 |
-
sentiment_batch.append(sentiment)
|
| 218 |
-
sentiment_tensor = torch.FloatTensor(sentiment_batch).to(self.device)
|
| 219 |
-
current_q = self.policy_net(states_tensor, sentiment_tensor)
|
| 220 |
-
else:
|
| 221 |
-
current_q = self.policy_net(states_tensor)
|
| 222 |
-
|
| 223 |
-
current_q = current_q.gather(1, actions_tensor.unsqueeze(1)).squeeze(1)
|
| 224 |
-
|
| 225 |
-
# Compute target Q values
|
| 226 |
-
with torch.no_grad():
|
| 227 |
-
if self.use_sentiment:
|
| 228 |
-
next_sentiment_batch = []
|
| 229 |
-
for sentiment_data in sentiments:
|
| 230 |
-
next_sentiment = [sentiment_data.get('sentiment', 0.5),
|
| 231 |
-
sentiment_data.get('confidence', 0.0)]
|
| 232 |
-
next_sentiment_batch.append(next_sentiment)
|
| 233 |
-
next_sentiment_tensor = torch.FloatTensor(next_sentiment_batch).to(self.device)
|
| 234 |
-
next_q = self.policy_net(next_states_tensor, next_sentiment_tensor)
|
| 235 |
-
else:
|
| 236 |
-
next_q = self.policy_net(next_states_tensor)
|
| 237 |
-
|
| 238 |
-
next_q_max = next_q.max(1)[0]
|
| 239 |
-
target_q = rewards_tensor + (self.gamma * next_q_max * ~dones_tensor)
|
| 240 |
-
|
| 241 |
-
# Compute loss and optimize
|
| 242 |
-
loss = self.loss_fn(current_q, target_q)
|
| 243 |
-
|
| 244 |
-
self.optimizer.zero_grad()
|
| 245 |
-
loss.backward()
|
| 246 |
-
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
|
| 247 |
-
self.optimizer.step()
|
| 248 |
-
|
| 249 |
-
# Update epsilon
|
| 250 |
-
if self.epsilon > self.epsilon_min:
|
| 251 |
-
self.epsilon *= self.epsilon_decay
|
| 252 |
-
|
| 253 |
-
self.steps_done += 1
|
| 254 |
-
self.steps_since_target_update += 1
|
| 255 |
-
|
| 256 |
-
# Update target network periodically (if implemented)
|
| 257 |
-
if self.steps_since_target_update % self.target_update_freq == 0:
|
| 258 |
-
self._update_target_network()
|
| 259 |
-
|
| 260 |
-
return float(loss.item())
|
| 261 |
-
|
| 262 |
-
except Exception as e:
|
| 263 |
-
print(f"Error in update: {e}")
|
| 264 |
-
import traceback
|
| 265 |
-
traceback.print_exc()
|
| 266 |
-
return 0.0
|
| 267 |
-
|
| 268 |
-
def _update_target_network(self):
|
| 269 |
-
"""Update target network (placeholder for double DQN)"""
|
| 270 |
-
pass # Implement target network update here
|
| 271 |
-
|
| 272 |
-
# Simple fallback network
|
| 273 |
-
class SimpleTradingNetwork(nn.Module):
|
| 274 |
-
def __init__(self, state_dim, action_dim):
|
| 275 |
-
super(SimpleTradingNetwork, self).__init__()
|
| 276 |
-
self.action_dim = action_dim
|
| 277 |
-
|
| 278 |
-
self.conv_layers = nn.Sequential(
|
| 279 |
-
nn.Conv2d(4, 16, kernel_size=4, stride=2),
|
| 280 |
-
nn.ReLU(),
|
| 281 |
-
nn.Conv2d(16, 32, kernel_size=4, stride=2),
|
| 282 |
-
nn.ReLU(),
|
| 283 |
-
nn.Conv2d(32, 32, kernel_size=3, stride=1),
|
| 284 |
-
nn.ReLU(),
|
| 285 |
-
nn.AdaptiveAvgPool2d((8, 8))
|
| 286 |
-
)
|
| 287 |
-
|
| 288 |
-
self.fc_layers = nn.Sequential(
|
| 289 |
-
nn.Linear(32 * 8 * 8, 128),
|
| 290 |
-
nn.ReLU(),
|
| 291 |
-
nn.Dropout(0.2),
|
| 292 |
-
nn.Linear(128, 64),
|
| 293 |
-
nn.ReLU(),
|
| 294 |
-
nn.Linear(64, action_dim)
|
| 295 |
-
)
|
| 296 |
-
|
| 297 |
-
def forward(self, x):
|
| 298 |
-
try:
|
| 299 |
-
if len(x.shape) == 3:
|
| 300 |
-
x = x.unsqueeze(0)
|
| 301 |
-
x = x.permute(0, 3, 1, 2).contiguous().float()
|
| 302 |
-
|
| 303 |
-
x = self.conv_layers(x)
|
| 304 |
-
x = x.reshape(x.size(0), -1)
|
| 305 |
-
x = self.fc_layers(x)
|
| 306 |
-
return x
|
| 307 |
-
except Exception as e:
|
| 308 |
-
print(f"Error in simple network: {e}")
|
| 309 |
-
batch_size = x.size(0) if hasattr(x, 'size') else 1
|
| 310 |
-
return torch.zeros(batch_size, self.action_dim, device=(x.device if hasattr(x, 'device') else 'cpu'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/agents/visual_agent.py
DELETED
|
@@ -1,153 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
import torch.optim as optim
|
| 4 |
-
import numpy as np
|
| 5 |
-
from collections import deque
|
| 6 |
-
import random
|
| 7 |
-
|
| 8 |
-
class SimpleTradingNetwork(nn.Module):
|
| 9 |
-
def __init__(self, state_dim, action_dim):
|
| 10 |
-
super(SimpleTradingNetwork, self).__init__()
|
| 11 |
-
|
| 12 |
-
# Simplified CNN for faster training
|
| 13 |
-
self.conv_layers = nn.Sequential(
|
| 14 |
-
nn.Conv2d(4, 16, kernel_size=4, stride=2), # Input: 84x84x4 -> 41x41x16
|
| 15 |
-
nn.ReLU(),
|
| 16 |
-
nn.Conv2d(16, 32, kernel_size=4, stride=2), # 41x41x16 -> 19x19x32
|
| 17 |
-
nn.ReLU(),
|
| 18 |
-
nn.Conv2d(32, 32, kernel_size=3, stride=1), # 19x19x32 -> 17x17x32
|
| 19 |
-
nn.ReLU(),
|
| 20 |
-
nn.AdaptiveAvgPool2d((8, 8)) # 17x17x32 -> 8x8x32
|
| 21 |
-
)
|
| 22 |
-
|
| 23 |
-
# Calculate flattened size
|
| 24 |
-
self.flattened_size = 32 * 8 * 8
|
| 25 |
-
|
| 26 |
-
# Fully connected layers
|
| 27 |
-
self.fc_layers = nn.Sequential(
|
| 28 |
-
nn.Linear(self.flattened_size, 128),
|
| 29 |
-
nn.ReLU(),
|
| 30 |
-
nn.Dropout(0.2),
|
| 31 |
-
nn.Linear(128, 64),
|
| 32 |
-
nn.ReLU(),
|
| 33 |
-
nn.Dropout(0.2),
|
| 34 |
-
nn.Linear(64, action_dim)
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
def forward(self, x):
|
| 38 |
-
try:
|
| 39 |
-
# x shape: (batch_size, 84, 84, 4) -> (batch_size, 4, 84, 84)
|
| 40 |
-
if len(x.shape) == 4: # Single observation
|
| 41 |
-
x = x.permute(0, 3, 1, 2)
|
| 42 |
-
else: # Batch of observations
|
| 43 |
-
x = x.permute(0, 3, 1, 2)
|
| 44 |
-
|
| 45 |
-
x = self.conv_layers(x)
|
| 46 |
-
x = x.view(x.size(0), -1)
|
| 47 |
-
x = self.fc_layers(x)
|
| 48 |
-
return x
|
| 49 |
-
except Exception as e:
|
| 50 |
-
print(f"Error in network forward: {e}")
|
| 51 |
-
# Return zeros in case of error
|
| 52 |
-
return torch.zeros((x.size(0), self.fc_layers[-1].out_features))
|
| 53 |
-
|
| 54 |
-
class VisualTradingAgent:
|
| 55 |
-
def __init__(self, state_dim, action_dim, learning_rate=0.001):
|
| 56 |
-
self.state_dim = state_dim
|
| 57 |
-
self.action_dim = action_dim
|
| 58 |
-
self.learning_rate = learning_rate
|
| 59 |
-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 60 |
-
print(f"Using device: {self.device}")
|
| 61 |
-
|
| 62 |
-
# Neural network
|
| 63 |
-
self.policy_net = SimpleTradingNetwork(state_dim, action_dim).to(self.device)
|
| 64 |
-
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
|
| 65 |
-
|
| 66 |
-
# Experience replay
|
| 67 |
-
self.memory = deque(maxlen=500) # Smaller memory for stability
|
| 68 |
-
self.batch_size = 16
|
| 69 |
-
|
| 70 |
-
# Training parameters
|
| 71 |
-
self.gamma = 0.99
|
| 72 |
-
self.epsilon = 1.0
|
| 73 |
-
self.epsilon_min = 0.1
|
| 74 |
-
self.epsilon_decay = 0.995
|
| 75 |
-
self.update_target_every = 100
|
| 76 |
-
self.steps_done = 0
|
| 77 |
-
|
| 78 |
-
def select_action(self, state):
|
| 79 |
-
"""Select action using epsilon-greedy policy"""
|
| 80 |
-
if random.random() < self.epsilon:
|
| 81 |
-
return random.randint(0, self.action_dim - 1)
|
| 82 |
-
|
| 83 |
-
try:
|
| 84 |
-
# Normalize state and convert to tensor
|
| 85 |
-
state_normalized = state.astype(np.float32) / 255.0
|
| 86 |
-
state_tensor = torch.FloatTensor(state_normalized).unsqueeze(0).to(self.device)
|
| 87 |
-
|
| 88 |
-
with torch.no_grad():
|
| 89 |
-
q_values = self.policy_net(state_tensor)
|
| 90 |
-
return int(q_values.argmax().item())
|
| 91 |
-
except Exception as e:
|
| 92 |
-
print(f"Error in action selection: {e}")
|
| 93 |
-
return random.randint(0, self.action_dim - 1)
|
| 94 |
-
|
| 95 |
-
def store_transition(self, state, action, reward, next_state, done):
|
| 96 |
-
"""Store experience in replay memory"""
|
| 97 |
-
try:
|
| 98 |
-
self.memory.append((state, action, reward, next_state, done))
|
| 99 |
-
except Exception as e:
|
| 100 |
-
print(f"Error storing transition: {e}")
|
| 101 |
-
|
| 102 |
-
def update(self):
|
| 103 |
-
"""Update the neural network"""
|
| 104 |
-
if len(self.memory) < self.batch_size:
|
| 105 |
-
return 0.0
|
| 106 |
-
|
| 107 |
-
try:
|
| 108 |
-
# Sample batch from memory
|
| 109 |
-
batch = random.sample(self.memory, self.batch_size)
|
| 110 |
-
states, actions, rewards, next_states, dones = zip(*batch)
|
| 111 |
-
|
| 112 |
-
# Convert to tensors with normalization
|
| 113 |
-
states_array = np.array(states, dtype=np.float32) / 255.0
|
| 114 |
-
next_states_array = np.array(next_states, dtype=np.float32) / 255.0
|
| 115 |
-
|
| 116 |
-
states_tensor = torch.FloatTensor(states_array).to(self.device)
|
| 117 |
-
actions_tensor = torch.LongTensor(actions).to(self.device)
|
| 118 |
-
rewards_tensor = torch.FloatTensor(rewards).to(self.device)
|
| 119 |
-
next_states_tensor = torch.FloatTensor(next_states_array).to(self.device)
|
| 120 |
-
dones_tensor = torch.BoolTensor(dones).to(self.device)
|
| 121 |
-
|
| 122 |
-
# Current Q values
|
| 123 |
-
current_q = self.policy_net(states_tensor).gather(1, actions_tensor.unsqueeze(1))
|
| 124 |
-
|
| 125 |
-
# Next Q values
|
| 126 |
-
with torch.no_grad():
|
| 127 |
-
next_q = self.policy_net(next_states_tensor).max(1)[0]
|
| 128 |
-
target_q = rewards_tensor + (self.gamma * next_q * ~dones_tensor)
|
| 129 |
-
|
| 130 |
-
# Compute loss
|
| 131 |
-
loss = nn.MSELoss()(current_q.squeeze(), target_q)
|
| 132 |
-
|
| 133 |
-
# Optimize
|
| 134 |
-
self.optimizer.zero_grad()
|
| 135 |
-
loss.backward()
|
| 136 |
-
|
| 137 |
-
# Gradient clipping for stability
|
| 138 |
-
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
|
| 139 |
-
self.optimizer.step()
|
| 140 |
-
|
| 141 |
-
# Update steps and decay epsilon
|
| 142 |
-
self.steps_done += 1
|
| 143 |
-
if self.steps_done % self.update_target_every == 0:
|
| 144 |
-
# For simplicity, we're using the same network
|
| 145 |
-
pass
|
| 146 |
-
|
| 147 |
-
self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
|
| 148 |
-
|
| 149 |
-
return float(loss.item())
|
| 150 |
-
|
| 151 |
-
except Exception as e:
|
| 152 |
-
print(f"Error in update: {e}")
|
| 153 |
-
return 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/environments/advanced_trading_env.py
DELETED
|
@@ -1,293 +0,0 @@
|
|
| 1 |
-
import numpy as np
|
| 2 |
-
import logging
|
| 3 |
-
from typing import Dict, Any, Optional, Tuple
|
| 4 |
-
from .visual_trading_env import VisualTradingEnvironment
|
| 5 |
-
from src.sentiment.twitter_analyzer import AdvancedSentimentAnalyzer
|
| 6 |
-
|
| 7 |
-
# Setup logging
|
| 8 |
-
logging.basicConfig(level=logging.INFO)
|
| 9 |
-
logger = logging.getLogger(__name__)
|
| 10 |
-
|
| 11 |
-
class AdvancedTradingEnvironment(VisualTradingEnvironment):
|
| 12 |
-
def __init__(self, initial_balance=10000, risk_level="Medium", asset_type="Crypto",
|
| 13 |
-
use_sentiment=True, sentiment_influence=0.3, sentiment_update_freq=5):
|
| 14 |
-
super().__init__(initial_balance, risk_level, asset_type)
|
| 15 |
-
|
| 16 |
-
# Validate inputs
|
| 17 |
-
if not 0.0 <= sentiment_influence <= 1.0:
|
| 18 |
-
raise ValueError("sentiment_influence must be between 0.0 and 1.0")
|
| 19 |
-
if sentiment_update_freq < 1:
|
| 20 |
-
raise ValueError("sentiment_update_freq must be at least 1")
|
| 21 |
-
|
| 22 |
-
self.use_sentiment = use_sentiment
|
| 23 |
-
self.sentiment_influence = sentiment_influence
|
| 24 |
-
self.sentiment_update_freq = sentiment_update_freq
|
| 25 |
-
self.sentiment_history = deque(maxlen=100) # Limited history
|
| 26 |
-
self.current_step = 0
|
| 27 |
-
|
| 28 |
-
# Sentiment analyzer with error handling
|
| 29 |
-
self.sentiment_analyzer = None
|
| 30 |
-
self.current_sentiment = 0.5
|
| 31 |
-
self.sentiment_confidence = 0.0
|
| 32 |
-
|
| 33 |
-
if use_sentiment:
|
| 34 |
-
try:
|
| 35 |
-
self.sentiment_analyzer = AdvancedSentimentAnalyzer()
|
| 36 |
-
self.sentiment_analyzer.initialize_models()
|
| 37 |
-
logger.info("Sentiment analyzer initialized successfully")
|
| 38 |
-
except Exception as e:
|
| 39 |
-
logger.warning(f"Failed to initialize sentiment analyzer: {e}. Disabling sentiment.")
|
| 40 |
-
self.use_sentiment = False
|
| 41 |
-
|
| 42 |
-
def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]:
|
| 43 |
-
"""Execute trading step with sentiment influence"""
|
| 44 |
-
if not isinstance(action, int) or action < 0:
|
| 45 |
-
logger.warning(f"Invalid action {action}, defaulting to hold")
|
| 46 |
-
action = 0 # Hold action as default
|
| 47 |
-
|
| 48 |
-
# Update sentiment periodically
|
| 49 |
-
if self.use_sentiment and self.current_step % self.sentiment_update_freq == 0:
|
| 50 |
-
self._update_sentiment()
|
| 51 |
-
|
| 52 |
-
self.current_step += 1
|
| 53 |
-
|
| 54 |
-
# Execute base environment step
|
| 55 |
-
try:
|
| 56 |
-
observation, reward, done, info = super().step(action)
|
| 57 |
-
except Exception as e:
|
| 58 |
-
logger.error(f"Error in base environment step: {e}")
|
| 59 |
-
# Return safe defaults
|
| 60 |
-
observation = self._get_safe_observation()
|
| 61 |
-
reward = 0.0
|
| 62 |
-
done = False
|
| 63 |
-
info = {}
|
| 64 |
-
|
| 65 |
-
# Apply sentiment modification to reward
|
| 66 |
-
if self.use_sentiment:
|
| 67 |
-
try:
|
| 68 |
-
reward = self._apply_sentiment_to_reward(reward, action, info)
|
| 69 |
-
except Exception as e:
|
| 70 |
-
logger.warning(f"Error applying sentiment to reward: {e}")
|
| 71 |
-
|
| 72 |
-
# Enhance observation with sentiment (optional)
|
| 73 |
-
enhanced_observation = self._enhance_observation(observation)
|
| 74 |
-
|
| 75 |
-
# Add sentiment info to info dict
|
| 76 |
-
info.update({
|
| 77 |
-
'sentiment': float(self.current_sentiment),
|
| 78 |
-
'sentiment_confidence': float(self.sentiment_confidence),
|
| 79 |
-
'sentiment_influence': float(self.sentiment_influence),
|
| 80 |
-
'step': self.current_step
|
| 81 |
-
})
|
| 82 |
-
|
| 83 |
-
return enhanced_observation, float(reward), bool(done), info
|
| 84 |
-
|
| 85 |
-
def _update_sentiment(self):
|
| 86 |
-
"""Update current market sentiment with robust error handling"""
|
| 87 |
-
if not self.sentiment_analyzer:
|
| 88 |
-
return
|
| 89 |
-
|
| 90 |
-
try:
|
| 91 |
-
sentiment_data = self.sentiment_analyzer.get_influencer_sentiment()
|
| 92 |
-
|
| 93 |
-
# Validate sentiment data
|
| 94 |
-
if not isinstance(sentiment_data, dict):
|
| 95 |
-
raise ValueError("Invalid sentiment data format")
|
| 96 |
-
|
| 97 |
-
market_sentiment = sentiment_data.get('market_sentiment')
|
| 98 |
-
confidence = sentiment_data.get('confidence')
|
| 99 |
-
|
| 100 |
-
if market_sentiment is None or not (-1.0 <= market_sentiment <= 1.0):
|
| 101 |
-
raise ValueError("Invalid market_sentiment value")
|
| 102 |
-
if confidence is None or not (0.0 <= confidence <= 1.0):
|
| 103 |
-
raise ValueError("Invalid confidence value")
|
| 104 |
-
|
| 105 |
-
self.current_sentiment = float(market_sentiment)
|
| 106 |
-
self.sentiment_confidence = float(confidence)
|
| 107 |
-
|
| 108 |
-
# Normalize sentiment to 0-1 range for consistency
|
| 109 |
-
self.current_sentiment = (self.current_sentiment + 1.0) / 2.0
|
| 110 |
-
|
| 111 |
-
# Update history
|
| 112 |
-
self.sentiment_history.append({
|
| 113 |
-
'sentiment': self.current_sentiment,
|
| 114 |
-
'confidence': self.sentiment_confidence,
|
| 115 |
-
'timestamp': self.current_step
|
| 116 |
-
})
|
| 117 |
-
|
| 118 |
-
logger.debug(f"Updated sentiment: {self.current_sentiment:.3f} (conf: {self.sentiment_confidence:.3f})")
|
| 119 |
-
|
| 120 |
-
except Exception as e:
|
| 121 |
-
logger.warning(f"Error updating sentiment: {e}")
|
| 122 |
-
# Fallback to neutral sentiment
|
| 123 |
-
self.current_sentiment = 0.5
|
| 124 |
-
self.sentiment_confidence = 0.0
|
| 125 |
-
self.sentiment_history.append({
|
| 126 |
-
'sentiment': 0.5,
|
| 127 |
-
'confidence': 0.0,
|
| 128 |
-
'timestamp': self.current_step
|
| 129 |
-
})
|
| 130 |
-
|
| 131 |
-
def _apply_sentiment_to_reward(self, original_reward: float, action: int,
|
| 132 |
-
info: Dict[str, Any]) -> float:
|
| 133 |
-
"""Modify reward based on sentiment analysis with bounds checking"""
|
| 134 |
-
if self.sentiment_confidence < 0.3:
|
| 135 |
-
return original_reward
|
| 136 |
-
|
| 137 |
-
try:
|
| 138 |
-
sentiment_multiplier = 1.0
|
| 139 |
-
sentiment_score = self.current_sentiment # 0-1 normalized
|
| 140 |
-
|
| 141 |
-
# Define action mappings (adjust based on your action space)
|
| 142 |
-
# Assuming: 0=hold, 1=buy, 2=sell, 3=close
|
| 143 |
-
bullish_threshold = 0.6
|
| 144 |
-
bearish_threshold = 0.4
|
| 145 |
-
|
| 146 |
-
if sentiment_score > bullish_threshold: # Bullish
|
| 147 |
-
if action == 1: # Buy
|
| 148 |
-
sentiment_multiplier += self.sentiment_influence * self.sentiment_confidence
|
| 149 |
-
elif action == 2: # Sell short
|
| 150 |
-
sentiment_multiplier -= self.sentiment_influence * 0.3 * self.sentiment_confidence
|
| 151 |
-
elif action == 3: # Close
|
| 152 |
-
sentiment_multiplier -= self.sentiment_influence * 0.2 * self.sentiment_confidence
|
| 153 |
-
|
| 154 |
-
elif sentiment_score < bearish_threshold: # Bearish
|
| 155 |
-
if action == 2: # Sell short
|
| 156 |
-
sentiment_multiplier += self.sentiment_influence * self.sentiment_confidence
|
| 157 |
-
elif action == 1: # Buy
|
| 158 |
-
sentiment_multiplier -= self.sentiment_influence * 0.5 * self.sentiment_confidence
|
| 159 |
-
elif action == 3: # Close
|
| 160 |
-
sentiment_multiplier += self.sentiment_influence * 0.3 * self.sentiment_confidence
|
| 161 |
-
|
| 162 |
-
# Apply trend momentum if enough history
|
| 163 |
-
trend_multiplier = self._calculate_sentiment_trend_multiplier()
|
| 164 |
-
sentiment_multiplier += trend_multiplier
|
| 165 |
-
|
| 166 |
-
# Clamp multiplier to reasonable bounds
|
| 167 |
-
sentiment_multiplier = np.clip(sentiment_multiplier, 0.5, 2.0)
|
| 168 |
-
|
| 169 |
-
enhanced_reward = original_reward * sentiment_multiplier
|
| 170 |
-
|
| 171 |
-
# Ensure reward doesn't become extreme
|
| 172 |
-
max_reward = abs(original_reward) * 2.5 if original_reward != 0 else 10.0
|
| 173 |
-
return np.clip(enhanced_reward, -max_reward, max_reward)
|
| 174 |
-
|
| 175 |
-
except Exception as e:
|
| 176 |
-
logger.error(f"Error in sentiment reward calculation: {e}")
|
| 177 |
-
return original_reward
|
| 178 |
-
|
| 179 |
-
def _calculate_sentiment_trend_multiplier(self) -> float:
|
| 180 |
-
"""Calculate trend-based multiplier from sentiment history"""
|
| 181 |
-
if len(self.sentiment_history) < 10:
|
| 182 |
-
return 0.0
|
| 183 |
-
|
| 184 |
-
try:
|
| 185 |
-
# Get recent and previous sentiment values
|
| 186 |
-
recent_sentiments = [h['sentiment'] for h in list(self.sentiment_history)[-5:]]
|
| 187 |
-
prev_sentiments = [h['sentiment'] for h in list(self.sentiment_history)[-10:-5]]
|
| 188 |
-
|
| 189 |
-
recent_avg = np.mean(recent_sentiments)
|
| 190 |
-
prev_avg = np.mean(prev_sentiments)
|
| 191 |
-
|
| 192 |
-
trend = recent_avg - prev_avg
|
| 193 |
-
# Scale trend influence
|
| 194 |
-
trend_multiplier = np.tanh(trend * 5) * self.sentiment_influence * 0.3
|
| 195 |
-
return float(trend_multiplier)
|
| 196 |
-
|
| 197 |
-
except Exception as e:
|
| 198 |
-
logger.warning(f"Error calculating trend multiplier: {e}")
|
| 199 |
-
return 0.0
|
| 200 |
-
|
| 201 |
-
def _enhance_observation(self, original_observation: np.ndarray) -> np.ndarray:
|
| 202 |
-
"""Enhance observation with sentiment information"""
|
| 203 |
-
if not self.use_sentiment or original_observation is None:
|
| 204 |
-
return original_observation
|
| 205 |
-
|
| 206 |
-
try:
|
| 207 |
-
# For now, return original observation
|
| 208 |
-
# Future: could concatenate sentiment as additional channels or metadata
|
| 209 |
-
return original_observation.copy()
|
| 210 |
-
except Exception as e:
|
| 211 |
-
logger.warning(f"Error enhancing observation: {e}")
|
| 212 |
-
return original_observation
|
| 213 |
-
|
| 214 |
-
def _get_safe_observation(self) -> np.ndarray:
|
| 215 |
-
"""Get a safe default observation"""
|
| 216 |
-
try:
|
| 217 |
-
# Try to get current observation from base env
|
| 218 |
-
if hasattr(self, 'current_observation'):
|
| 219 |
-
return self.current_observation.copy()
|
| 220 |
-
# Return zero observation of expected shape
|
| 221 |
-
return np.zeros((84, 84, 4), dtype=np.float32)
|
| 222 |
-
except:
|
| 223 |
-
return np.zeros((84, 84, 4), dtype=np.float32)
|
| 224 |
-
|
| 225 |
-
def get_sentiment_analysis(self) -> Dict[str, Any]:
    """Summarize the current sentiment state as a plain dict.

    Returns an error payload (with neutral defaults) when sentiment is
    disabled or when reading the state fails.
    """
    if not self.use_sentiment:
        return {"error": "Sentiment analysis disabled", "sentiment": 0.5, "confidence": 0.0}

    try:
        report = {
            "current_sentiment": float(self.current_sentiment),
            "sentiment_confidence": float(self.sentiment_confidence),
            "sentiment_trend": self._calculate_sentiment_trend_direction(),
            "influence_level": float(self.sentiment_influence),
            "history_length": len(self.sentiment_history),
            "update_freq": self.sentiment_update_freq,
            "last_update_step": self.current_step,
        }
        return report
    except Exception as exc:
        logger.error(f"Error in get_sentiment_analysis: {exc}")
        return {
            "error": str(exc),
            "sentiment": 0.5,
            "confidence": 0.0,
            "trend": "unknown",
        }
|
| 249 |
-
|
| 250 |
-
def _calculate_sentiment_trend_direction(self) -> str:
|
| 251 |
-
"""Calculate sentiment trend direction"""
|
| 252 |
-
if len(self.sentiment_history) < 5:
|
| 253 |
-
return "insufficient_data"
|
| 254 |
-
|
| 255 |
-
try:
|
| 256 |
-
recent_avg = np.mean([h['sentiment'] for h in list(self.sentiment_history)[-5:]])
|
| 257 |
-
prev_avg = np.mean([h['sentiment'] for h in list(self.sentiment_history)[-10:-5]]) if len(self.sentiment_history) >= 10 else recent_avg
|
| 258 |
-
|
| 259 |
-
diff = recent_avg - prev_avg
|
| 260 |
-
if diff > 0.05:
|
| 261 |
-
return "bullish"
|
| 262 |
-
elif diff < -0.05:
|
| 263 |
-
return "bearish"
|
| 264 |
-
else:
|
| 265 |
-
return "neutral"
|
| 266 |
-
except:
|
| 267 |
-
return "error"
|
| 268 |
-
|
| 269 |
-
def reset(self) -> np.ndarray:
    """Reset the wrapped environment and clear all sentiment state.

    Returns:
        The initial observation from the base environment, or an all-zero
        (84, 84, 4) float32 frame when the base reset itself fails.
    """
    try:
        observation = super().reset()
        self.current_step = 0
        self.sentiment_history.clear()
        self.current_sentiment = 0.5  # neutral midpoint of the 0-1 scale
        self.sentiment_confidence = 0.0
        logger.info("Environment reset with sentiment state")
        return observation
    except Exception as e:
        logger.error(f"Error in reset: {e}")
        # Force safe reset: retry the base reset and clear our counters so
        # callers always receive a valid-shaped observation afterwards.
        super().reset()
        self.current_step = 0
        self.sentiment_history.clear()
        return np.zeros((84, 84, 4), dtype=np.float32)
|
| 286 |
-
|
| 287 |
-
@property
def action_space_size(self) -> int:
    """Number of discrete actions exposed by the wrapped environment."""
    try:
        parent = super()
        if hasattr(parent, 'action_space'):
            return parent.action_space.n
        return 4
    except:
        return 4  # Default assumption
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/environments/visual_trading_env.py
DELETED
|
@@ -1,228 +0,0 @@
|
|
| 1 |
-
import numpy as np
|
| 2 |
-
import matplotlib.pyplot as plt
|
| 3 |
-
from PIL import Image
|
| 4 |
-
import io
|
| 5 |
-
|
| 6 |
-
class VisualTradingEnvironment:
    """Synthetic trading environment with image observations.

    Observations are (84, 84, 4) uint8 arrays: an RGB chart of the last
    ~50 prices plus one synthetic "attention" channel derived from the
    most recent price move.  Actions: 0 hold, 1 open a position, 2 add
    to it, 3 close it.
    """

    def __init__(self, initial_balance=10000, risk_level="Medium", asset_type="Stock"):
        self.initial_balance = float(initial_balance)
        self.risk_level = risk_level
        self.asset_type = asset_type

        # Risk appetite scales both position sizing and simulated volatility.
        risk_multipliers = {"Low": 0.5, "Medium": 1.0, "High": 2.0}
        self.risk_multiplier = risk_multipliers.get(risk_level, 1.0)

        # Generate market data, then initialize episode state.
        self._generate_market_data()
        self.reset()

    def _generate_market_data(self, num_points=1000):
        """Generate a deterministic synthetic price series for the asset."""
        np.random.seed(42)  # fixed seed: identical series across instances

        # Per-asset drift/volatility; volatility is scaled by risk level.
        base_params = {
            "Stock": {"volatility": 0.01, "trend": 0.0005},
            "Crypto": {"volatility": 0.02, "trend": 0.001},
            "Forex": {"volatility": 0.005, "trend": 0.0002}
        }
        params = base_params.get(self.asset_type, base_params["Stock"])
        volatility = params["volatility"] * self.risk_multiplier
        trend = params["trend"]

        prices = [100.0]
        for _ in range(1, num_points):
            # Random walk with drift plus weak mean reversion toward 100;
            # floor at 1.0 so prices never go non-positive.
            change = np.random.normal(trend, volatility)
            mean_reversion = (100 - prices[-1]) * 0.001
            prices.append(max(1.0, prices[-1] * (1 + change) + mean_reversion))

        self.price_data = np.array(prices)

    def _get_visual_observation(self):
        """Render the recent price window as an (84, 84, 4) uint8 frame."""
        try:
            window_size = 50
            start_idx = max(0, self.current_step - window_size)
            end_idx = min(self.current_step + 1, len(self.price_data))
            prices = self.price_data[start_idx:end_idx]

            fig, ax = plt.subplots(figsize=(4.2, 4.2), dpi=20, facecolor='black')
            try:
                ax.set_facecolor('black')
                if len(prices) > 0:
                    ax.plot(prices, color='cyan', linewidth=1.5)

                # Strip axes/spines for a clean chart image.
                ax.set_xticks([])
                ax.set_yticks([])
                for side in ('top', 'right', 'bottom', 'left'):
                    ax.spines[side].set_visible(False)

                # Fixed limits keep the rendered geometry consistent.
                ax.set_xlim(0, 50)
                if len(prices) > 0:
                    price_min, price_max = min(prices), max(prices)
                    price_range = price_max - price_min
                    if price_range == 0:
                        price_range = 1
                    ax.set_ylim(price_min - price_range * 0.1, price_max + price_range * 0.1)
                else:
                    ax.set_ylim(0, 100)

                buf = io.BytesIO()
                plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, facecolor='black', dpi=20)
            finally:
                # BUGFIX: close the figure even when rendering raises —
                # previously figures leaked on the exception path.
                plt.close(fig)

            buf.seek(0)
            img = Image.open(buf).convert('RGB')
            img = img.resize((84, 84), Image.Resampling.LANCZOS)
            img_array = np.array(img)

            # Attention channel: a radial blob whose size and intensity
            # track the magnitude of the latest price change.
            attention_map = np.zeros((84, 84), dtype=np.uint8)
            if len(prices) > 1:
                recent_change = (prices[-1] - prices[-2]) / prices[-2]
                intensity = min(255, abs(recent_change) * 5000)
                center_x, center_y = 42, 42
                size = max(5, int(intensity / 50))
                for i in range(max(0, center_x - size), min(84, center_x + size)):
                    for j in range(max(0, center_y - size), min(84, center_y + size)):
                        distance = np.sqrt((i - center_x) ** 2 + (j - center_y) ** 2)
                        if distance <= size:
                            attention_value = intensity * (1 - distance / size)
                            attention_map[i, j] = max(attention_map[i, j], int(attention_value))

            # RGB + attention -> 4-channel observation.
            return np.concatenate([img_array, attention_map[:, :, np.newaxis]], axis=2)

        except Exception as e:
            print(f"Error in visual observation: {e}")
            # Return default observation in case of error
            return np.zeros((84, 84, 4), dtype=np.uint8)

    def reset(self):
        """Reset environment to initial state and return the first frame."""
        self.current_step = 50  # Start with some price history on screen
        self.balance = self.initial_balance
        self.position_size = 0.0
        self.entry_price = 0.0
        self.net_worth = self.initial_balance
        self.total_trades = 0
        self.done = False

        return self._get_visual_observation()

    def step(self, action):
        """Execute one trading step.

        Args:
            action: 0 hold, 1 open a position, 2 add to it, 3 close it.

        Returns:
            (observation, reward, done, info) in classic gym style.
        """
        try:
            current_price = self.price_data[self.current_step]
            prev_net_worth = self.net_worth

            reward = 0.0

            if action == 1 and self.position_size == 0:  # Open position
                # Risk-adjusted position sizing: 10% of balance, scaled.
                position_value = self.balance * 0.1 * self.risk_multiplier
                self.position_size = position_value / current_price
                self.entry_price = current_price
                self.balance -= position_value
                self.total_trades += 1
                reward = -0.01  # Small penalty for transaction

            elif action == 2 and self.position_size > 0:  # Add to position
                # BUGFIX (comment only): this branch was mislabeled "Sell";
                # it actually BUYS an additional 5% of balance.
                additional_value = self.balance * 0.05 * self.risk_multiplier
                additional_size = additional_value / current_price
                self.position_size += additional_size
                self.balance -= additional_value
                self.total_trades += 1
                reward = -0.005

            elif action == 3 and self.position_size > 0:  # Close position
                close_value = self.position_size * current_price
                self.balance += close_value
                if self.entry_price > 0:
                    profit_loss = (current_price - self.entry_price) / self.entry_price
                    reward = profit_loss * 10  # Scale profit/loss
                self.position_size = 0.0
                self.entry_price = 0.0
                self.total_trades += 1

            # Update net worth (cash plus marked-to-market position).
            position_value = self.position_size * current_price if self.position_size > 0 else 0.0
            self.net_worth = self.balance + position_value

            # Add small reward for portfolio growth
            if prev_net_worth > 0:
                portfolio_change = (self.net_worth - prev_net_worth) / prev_net_worth
                reward += portfolio_change * 5

            # Move to next step
            self.current_step += 1

            # Episode ends at the tail of the generated series.
            if self.current_step >= len(self.price_data) - 1:
                self.done = True
                # Final reward based on overall performance
                if self.initial_balance > 0:
                    final_return = (self.net_worth - self.initial_balance) / self.initial_balance
                    reward += final_return * 20

            info = {
                'net_worth': float(self.net_worth),
                'balance': float(self.balance),
                'position_size': float(self.position_size),
                'current_price': float(current_price),
                'total_trades': int(self.total_trades),
                'step': int(self.current_step)
            }

            return self._get_visual_observation(), float(reward), bool(self.done), info

        except Exception as e:
            print(f"Error in step execution: {e}")
            # Return safe default values in case of error
            default_info = {
                'net_worth': float(self.initial_balance),
                'balance': float(self.initial_balance),
                'position_size': 0.0,
                'current_price': 100.0,
                'total_trades': 0,
                'step': int(self.current_step)
            }
            return self._get_visual_observation(), 0.0, True, default_info

    def get_price_history(self):
        """Return the most recent (<=50) prices as plain floats."""
        window_size = min(50, self.current_step)
        start_idx = max(0, self.current_step - window_size)
        prices = self.price_data[start_idx:self.current_step]
        return [float(price) for price in prices]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/sentiment/twitter_analyzer.py
DELETED
|
@@ -1,495 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import numpy as np
|
| 3 |
-
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
|
| 4 |
-
from textblob import TextBlob
|
| 5 |
-
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
|
| 6 |
-
from typing import Dict, List, Tuple, Optional
|
| 7 |
-
import time
|
| 8 |
-
from datetime import datetime, timedelta
|
| 9 |
-
import re
|
| 10 |
-
import logging
|
| 11 |
-
from functools import lru_cache
|
| 12 |
-
import warnings
|
| 13 |
-
warnings.filterwarnings('ignore')
|
| 14 |
-
|
| 15 |
-
# Setup logging
|
| 16 |
-
logging.basicConfig(level=logging.INFO)
|
| 17 |
-
logger = logging.getLogger(__name__)
|
| 18 |
-
|
| 19 |
-
class AdvancedSentimentAnalyzer:
|
| 20 |
-
def __init__(self, max_model_retries=3, cache_size=100):
    """Set up model registries, the influencer table, and VADER.

    Args:
        max_model_retries: attempts per pipeline load before giving up.
        cache_size: NOTE(review): currently unused here — the result cache
            is hard-capped elsewhere; confirm the intended size.
    """
    self.sentiment_models = {}  # name -> loaded sentiment pipeline
    self.vader_analyzer = None  # set by _initialize_vader below
    self.max_model_retries = max_model_retries
    self.cache = {}  # Simple cache for expensive operations

    # Influencers with validation; _validate_influencers drops entries with
    # out-of-range weights and normalizes the rest to sum to 1.
    self.influencers = self._validate_influencers({
        'elonmusk': {'name': 'Elon Musk', 'weight': 0.9, 'sector': 'all'},
        'cz_binance': {'name': 'Changpeng Zhao', 'weight': 0.8, 'sector': 'crypto'},
        'saylor': {'name': 'Michael Saylor', 'weight': 0.7, 'sector': 'bitcoin'},
        'crypto_bitlord': {'name': 'Crypto Bitlord', 'weight': 0.6, 'sector': 'crypto'},
        'aantonop': {'name': 'Andreas Antonopoulos', 'weight': 0.7, 'sector': 'bitcoin'},
        'peterlbrandt': {'name': 'Peter Brandt', 'weight': 0.8, 'sector': 'trading'},
        'nic__carter': {'name': 'Nic Carter', 'weight': 0.7, 'sector': 'crypto'},
        'avalancheavax': {'name': 'Avalanche', 'weight': 0.6, 'sector': 'defi'}
    })

    self._initialize_vader()
|
| 39 |
-
|
| 40 |
-
def _validate_influencers(self, influencers: Dict) -> Dict:
    """Keep entries whose weight lies in [0, 1] and normalize weights to sum to 1."""
    validated = {
        username: data
        for username, data in influencers.items()
        if 0.0 <= data.get('weight', 0) <= 1.0
    }
    total_weight = sum(data['weight'] for data in validated.values())

    # Normalize in place so downstream weighted averages stay in [0, 1].
    if total_weight > 0:
        for username in validated:
            validated[username]['weight'] /= total_weight

    logger.info(f"Validated {len(validated)} influencers with total weight {total_weight:.2f}")
    return validated
|
| 57 |
-
|
| 58 |
-
def _initialize_vader(self):
    """Initialize VADER safely.

    Sets ``self.vader_analyzer`` to a SentimentIntensityAnalyzer, or to
    None when construction fails so callers can skip VADER scoring.
    """
    try:
        self.vader_analyzer = SentimentIntensityAnalyzer()
        logger.info("VADER analyzer initialized")
    except Exception as e:
        logger.warning(f"Failed to initialize VADER: {e}")
        self.vader_analyzer = None
|
| 66 |
-
|
| 67 |
-
def _safe_pipeline_load(self, model_name: str):
    """Load a sentiment pipeline with retries; return None when all fail.

    Results (including failures) are memoized per instance.

    BUGFIX: previously decorated with ``@lru_cache``, which keys on
    ``self`` and keeps every instance alive for the cache's lifetime
    (flake8-bugbear B019).  A lazily created per-instance dict gives the
    same memoization without the leak.

    Args:
        model_name: Hugging Face model identifier.
    """
    cache = getattr(self, '_pipeline_cache', None)
    if cache is None:
        cache = self._pipeline_cache = {}
    if model_name in cache:
        return cache[model_name]

    for attempt in range(self.max_model_retries):
        try:
            pipeline_obj = pipeline(
                "sentiment-analysis",
                model=model_name,
                tokenizer=model_name,
                device=-1,  # CPU only for stability
                return_all_scores=False
            )
            logger.info(f"Successfully loaded model: {model_name}")
            cache[model_name] = pipeline_obj
            return pipeline_obj
        except Exception as e:
            logger.warning(f"Attempt {attempt + 1} failed for {model_name}: {e}")
            if attempt == self.max_model_retries - 1:
                cache[model_name] = None  # remember the failure too
                return None
            time.sleep(1)  # Brief delay before retry
|
| 86 |
-
|
| 87 |
-
def initialize_models(self) -> bool:
    """Initialize all sentiment analysis models with fallback.

    Loads three transformer pipelines (financial, general, crypto), each
    degrading gracefully: the general model falls back to the default HF
    pipeline, and the crypto model reuses financial/general when its own
    load fails.

    Returns:
        True when at least one model is available, False otherwise.
    """
    success_count = 0

    try:
        # Financial sentiment model
        financial_model = self._safe_pipeline_load(
            "mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis"
        )
        if financial_model:
            self.sentiment_models['financial'] = financial_model
            success_count += 1

        # General sentiment model with fallback
        general_model = self._safe_pipeline_load("distilbert-base-uncased-finetuned-sst-2-english")
        if general_model:
            self.sentiment_models['general'] = general_model
            success_count += 1
        else:
            # Fallback to basic pipeline (HF's default sentiment model)
            try:
                self.sentiment_models['general'] = pipeline("sentiment-analysis")
                success_count += 1
            except:
                pass

        # Crypto-specific model with fallback to financial, then general.
        crypto_model = self._safe_pipeline_load("ElKulako/cryptobert")
        if crypto_model:
            self.sentiment_models['crypto'] = crypto_model
            success_count += 1
        else:
            self.sentiment_models['crypto'] = self.sentiment_models.get('financial',
                self.sentiment_models.get('general'))
            # Only count the alias when the fallback actually exists.
            success_count += 1 if self.sentiment_models['crypto'] else 0

        # At least one model should be available
        if success_count > 0:
            logger.info(f"✅ Loaded {success_count} sentiment models successfully!")
            return True
        else:
            logger.error("❌ No sentiment models could be loaded")
            return False

    except Exception as e:
        logger.error(f"❌ Critical error loading models: {e}")
        return False
|
| 134 |
-
|
| 135 |
-
def analyze_text_sentiment(self, text: str) -> Dict:
    """Score *text* with a weighted ensemble of ML and rule-based models.

    Combines the available transformer pipelines (financial/general/crypto)
    with VADER and TextBlob, weighting domain models highest.  Results are
    memoized in ``self.cache`` (capped at 50 entries, FIFO eviction).

    Returns:
        Dict with keys: sentiment (label), score (0-1), confidence,
        urgency, keywords, models_used, text_snippet.  A neutral default
        is returned for trivial input or on any failure.
    """
    # Too short to be meaningful — skip all the expensive work.
    if not text or len(text.strip()) < 5:
        return self._default_sentiment()

    cache_key = hash(text.strip()[:100])  # Simple cache key
    if cache_key in self.cache:
        return self.cache[cache_key]

    try:
        cleaned_text = self._clean_text(text)

        # Analyze with available models
        model_results = []

        # Financial model
        if 'financial' in self.sentiment_models:
            model_results.append(self._analyze_model(cleaned_text, 'financial'))

        # General model
        if 'general' in self.sentiment_models:
            model_results.append(self._analyze_model(cleaned_text, 'general'))

        # Crypto model
        if 'crypto' in self.sentiment_models:
            model_results.append(self._analyze_model(cleaned_text, 'crypto'))

        # Rule-based models
        if self.vader_analyzer:
            model_results.append(self._analyze_vader(cleaned_text))

        model_results.append(self._analyze_textblob(cleaned_text))

        # Filter valid results (score is None marks a failed analyzer).
        valid_results = [r for r in model_results if r['score'] is not None]

        if not valid_results:
            return self._default_sentiment()

        # Weighted combination (prioritize financial/crypto models)
        weights = {
            'financial': 0.35, 'crypto': 0.30, 'general': 0.20,
            'vader': 0.10, 'textblob': 0.05
        }

        weighted_score = 0.0
        total_weight = 0.0
        confidences = []

        for result in valid_results:
            model_type = result.get('model_type', 'unknown')
            weight = weights.get(model_type, 0.1)  # small default for unknowns
            weighted_score += result['score'] * weight
            total_weight += weight
            if 'confidence' in result:
                confidences.append(result['confidence'])

        if total_weight > 0:
            final_score = weighted_score / total_weight
            final_confidence = np.mean(confidences) if confidences else 0.0
        else:
            final_score = 0.5
            final_confidence = 0.0

        # Determine sentiment label
        sentiment_label = self._score_to_label(final_score)

        result = {
            "sentiment": sentiment_label,
            "score": float(final_score),
            "confidence": float(final_confidence),
            "urgency": self._detect_urgency(cleaned_text),
            "keywords": self._extract_keywords(cleaned_text),
            "models_used": len(valid_results),
            "text_snippet": cleaned_text[:100] + "..." if len(cleaned_text) > 100 else cleaned_text
        }

        # Cache result
        self.cache[cache_key] = result
        if len(self.cache) > 50:  # Limit cache size (drop oldest entry)
            self.cache.pop(next(iter(self.cache)))

        return result

    except Exception as e:
        logger.error(f"Error in sentiment analysis: {e}")
        return self._default_sentiment()
|
| 222 |
-
|
| 223 |
-
def _analyze_model(self, text: str, model_type: str) -> Dict:
|
| 224 |
-
"""Generic model analysis with error handling"""
|
| 225 |
-
try:
|
| 226 |
-
model = self.sentiment_models[model_type]
|
| 227 |
-
result = model(text[:512], truncation=True, max_length=512)[0] # Limit text length
|
| 228 |
-
|
| 229 |
-
score_map = {
|
| 230 |
-
'negative': 0.0, 'NEGATIVE': 0.0,
|
| 231 |
-
'neutral': 0.5, 'NEUTRAL': 0.5,
|
| 232 |
-
'positive': 1.0, 'POSITIVE': 1.0
|
| 233 |
-
}
|
| 234 |
-
|
| 235 |
-
score = score_map.get(result['label'].upper(), 0.5)
|
| 236 |
-
return {
|
| 237 |
-
'score': score,
|
| 238 |
-
'confidence': result['score'],
|
| 239 |
-
'model_type': model_type
|
| 240 |
-
}
|
| 241 |
-
except Exception as e:
|
| 242 |
-
logger.debug(f"Model {model_type} failed: {e}")
|
| 243 |
-
return {'score': None, 'confidence': 0.0, 'model_type': model_type}
|
| 244 |
-
|
| 245 |
-
def _score_to_label(self, score: float) -> str:
|
| 246 |
-
"""Convert score to sentiment label"""
|
| 247 |
-
if score > 0.6:
|
| 248 |
-
return "bullish"
|
| 249 |
-
elif score > 0.4:
|
| 250 |
-
return "neutral"
|
| 251 |
-
else:
|
| 252 |
-
return "bearish"
|
| 253 |
-
|
| 254 |
-
def _analyze_vader(self, text: str) -> Dict:
|
| 255 |
-
"""VADER analysis with error handling"""
|
| 256 |
-
if not self.vader_analyzer:
|
| 257 |
-
return {'score': None, 'confidence': 0.0, 'model_type': 'vader'}
|
| 258 |
-
|
| 259 |
-
try:
|
| 260 |
-
scores = self.vader_analyzer.polarity_scores(text)
|
| 261 |
-
compound = (scores['compound'] + 1) / 2 # Normalize to 0-1
|
| 262 |
-
return {
|
| 263 |
-
'score': compound,
|
| 264 |
-
'confidence': abs(scores['compound']),
|
| 265 |
-
'model_type': 'vader'
|
| 266 |
-
}
|
| 267 |
-
except Exception:
|
| 268 |
-
return {'score': None, 'confidence': 0.0, 'model_type': 'vader'}
|
| 269 |
-
|
| 270 |
-
def _analyze_textblob(self, text: str) -> Dict:
    """TextBlob polarity normalized from [-1, 1] to a 0-1 score."""
    try:
        polarity = TextBlob(text).sentiment.polarity
        return {
            'score': (polarity + 1) / 2,  # [-1, 1] -> [0, 1]
            'confidence': abs(polarity),
            'model_type': 'textblob'
        }
    except Exception:
        return {'score': None, 'confidence': 0.0, 'model_type': 'textblob'}
|
| 282 |
-
|
| 283 |
-
def _clean_text(self, text: str) -> str:
|
| 284 |
-
"""Enhanced text cleaning"""
|
| 285 |
-
try:
|
| 286 |
-
# Remove URLs
|
| 287 |
-
text = re.sub(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', '', text)
|
| 288 |
-
# Remove mentions
|
| 289 |
-
text = re.sub(r'@\w+', '', text)
|
| 290 |
-
# Remove hashtags but keep text
|
| 291 |
-
text = re.sub(r'#\w+', '', text)
|
| 292 |
-
# Remove extra whitespace and normalize
|
| 293 |
-
text = ' '.join(text.split())
|
| 294 |
-
return text.strip()
|
| 295 |
-
except:
|
| 296 |
-
return text[:200] if len(text) > 200 else text
|
| 297 |
-
|
| 298 |
-
def _extract_keywords(self, text: str) -> List[str]:
|
| 299 |
-
"""Extract financial keywords with better matching"""
|
| 300 |
-
keyword_categories = {
|
| 301 |
-
'bullish': ['moon', 'rocket', 'bull', 'buy', 'long', 'growth', 'opportunity', 'bullrun'],
|
| 302 |
-
'bearish': ['crash', 'bear', 'sell', 'short', 'drop', 'dump', 'warning', 'risk', 'fud'],
|
| 303 |
-
'crypto': ['bitcoin', 'btc', 'ethereum', 'eth', 'crypto', 'blockchain', 'defi', 'nft'],
|
| 304 |
-
'urgency': ['now', 'urgent', 'immediately', 'alert', 'breaking', 'huge']
|
| 305 |
-
}
|
| 306 |
-
|
| 307 |
-
found = []
|
| 308 |
-
text_lower = text.lower()
|
| 309 |
-
|
| 310 |
-
for category, keywords in keyword_categories.items():
|
| 311 |
-
for keyword in keywords:
|
| 312 |
-
if re.search(rf'\b{keyword}\b', text_lower):
|
| 313 |
-
found.append(f"{category}:{keyword}")
|
| 314 |
-
|
| 315 |
-
return found[:5]
|
| 316 |
-
|
| 317 |
-
def _detect_urgency(self, text: str) -> float:
|
| 318 |
-
"""Improved urgency detection"""
|
| 319 |
-
urgency_indicators = ['!', 'urgent', 'breaking', 'alert', 'immediately', 'now', 'huge', 'massive']
|
| 320 |
-
text_lower = text.lower()
|
| 321 |
-
|
| 322 |
-
score = 0.0
|
| 323 |
-
for indicator in urgency_indicators:
|
| 324 |
-
if re.search(rf'\b{indicator}\b', text_lower):
|
| 325 |
-
score += 0.15
|
| 326 |
-
|
| 327 |
-
# Exclamation and question marks
|
| 328 |
-
punctuation_count = text.count('!') + text.count('?')
|
| 329 |
-
score += min(punctuation_count * 0.1, 0.3)
|
| 330 |
-
|
| 331 |
-
# Caps lock indicator
|
| 332 |
-
caps_ratio = sum(1 for c in text if c.isupper()) / len([c for c in text if c.isalpha()])
|
| 333 |
-
score += min(caps_ratio * 0.5, 0.2)
|
| 334 |
-
|
| 335 |
-
return min(score, 1.0)
|
| 336 |
-
|
| 337 |
-
def _default_sentiment(self) -> Dict:
|
| 338 |
-
"""Safe default sentiment"""
|
| 339 |
-
return {
|
| 340 |
-
"sentiment": "neutral",
|
| 341 |
-
"score": 0.5,
|
| 342 |
-
"confidence": 0.0,
|
| 343 |
-
"urgency": 0.0,
|
| 344 |
-
"keywords": [],
|
| 345 |
-
"models_used": 0,
|
| 346 |
-
"text_snippet": ""
|
| 347 |
-
}
|
| 348 |
-
|
| 349 |
-
def get_influencer_sentiment(self, hours_back: int = 24) -> Dict:
    """Aggregate per-influencer tweet sentiment into a market-wide score.

    Tweets are synthetic for now (see the note below).  Each influencer's
    tweets are averaged weighted by per-tweet confidence; influencers are
    then combined weighted by influencer weight * confidence.

    Args:
        hours_back: lookback window used when generating/fetching tweets.

    Returns:
        Dict with market_sentiment (0-1), confidence, influencer_count,
        total_tweets, timestamp, and per-influencer breakdowns; on failure
        a neutral payload carrying the error string.
    """
    try:
        # Generate synthetic tweets (in production, replace with real API)
        tweets = self._generate_synthetic_tweets(hours_back)
        influencer_sentiments = {}

        for username, tweet_batch in tweets.items():
            # Ignore anyone not in the validated influencer table.
            if username not in self.influencers:
                continue

            tweet_sentiments = []
            for tweet in tweet_batch:
                sentiment = self.analyze_text_sentiment(tweet['text'])
                sentiment.update({
                    'timestamp': tweet['timestamp'],
                    'username': username
                })
                tweet_sentiments.append(sentiment)

            if tweet_sentiments:
                # Weighted average by confidence
                total_weighted = sum(s['score'] * s['confidence'] for s in tweet_sentiments)
                total_confidence = sum(s['confidence'] for s in tweet_sentiments)

                avg_score = total_weighted / total_confidence if total_confidence > 0 else 0.5
                avg_confidence = np.mean([s['confidence'] for s in tweet_sentiments])

                influencer_sentiments[username] = {
                    'score': float(avg_score),
                    'confidence': float(avg_confidence),
                    'weight': self.influencers[username]['weight'],
                    'tweet_count': len(tweet_sentiments),
                    'tweets': tweet_sentiments[:3]  # keep a small sample only
                }

        # Calculate market sentiment across influencers.
        if influencer_sentiments:
            total_weighted_score = sum(
                data['score'] * data['weight'] * data['confidence']
                for data in influencer_sentiments.values()
            )
            total_weight = sum(
                data['weight'] * data['confidence']
                for data in influencer_sentiments.values()
            )

            market_sentiment = (total_weighted_score / total_weight
                                if total_weight > 0 else 0.5)
            avg_confidence = np.mean([d['confidence'] for d in influencer_sentiments.values()])
        else:
            market_sentiment = 0.5
            avg_confidence = 0.0

        return {
            "market_sentiment": float(market_sentiment),
            "confidence": float(avg_confidence),
            "influencer_count": len(influencer_sentiments),
            "total_tweets": sum(d['tweet_count'] for d in influencer_sentiments.values()),
            "timestamp": datetime.now().isoformat(),
            "influencers": influencer_sentiments
        }

    except Exception as e:
        logger.error(f"Error in get_influencer_sentiment: {e}")
        return {
            "market_sentiment": 0.5,
            "confidence": 0.0,
            "error": str(e),
            "timestamp": datetime.now().isoformat()
        }
|
| 420 |
-
|
| 421 |
-
def _generate_synthetic_tweets(self, hours_back: int) -> Dict:
|
| 422 |
-
"""Generate realistic synthetic tweets for testing"""
|
| 423 |
-
current_time = time.time()
|
| 424 |
-
tweets = {}
|
| 425 |
-
np.random.seed(int(current_time) % 10000) # Reproducible randomness
|
| 426 |
-
|
| 427 |
-
# Simulate market conditions
|
| 428 |
-
market_trend = np.sin(current_time / 3600) * 0.3 + 0.5
|
| 429 |
-
|
| 430 |
-
for username in self.influencers:
|
| 431 |
-
user_tweets = []
|
| 432 |
-
base_sentiment = np.clip(market_trend + np.random.normal(0, 0.15), 0.1, 0.9)
|
| 433 |
-
|
| 434 |
-
templates = self._get_user_templates(username, base_sentiment)
|
| 435 |
-
|
| 436 |
-
for i in range(np.random.randint(1, 4)): # 1-3 tweets
|
| 437 |
-
template = np.random.choice(templates)
|
| 438 |
-
tweet_text = template.format(**self._get_template_vars(base_sentiment))
|
| 439 |
-
|
| 440 |
-
# Add emojis occasionally
|
| 441 |
-
if np.random.random() < 0.4:
|
| 442 |
-
emojis = self._get_relevant_emojis(base_sentiment)
|
| 443 |
-
tweet_text += " " + np.random.choice(emojis)
|
| 444 |
-
|
| 445 |
-
user_tweets.append({
|
| 446 |
-
'text': tweet_text,
|
| 447 |
-
'timestamp': current_time - (i * 3600 * np.random.uniform(0.5, hours_back))
|
| 448 |
-
})
|
| 449 |
-
|
| 450 |
-
tweets[username] = user_tweets
|
| 451 |
-
|
| 452 |
-
return tweets
|
| 453 |
-
|
| 454 |
-
def _get_user_templates(self, username: str, sentiment: float) -> List[str]:
|
| 455 |
-
"""Get appropriate templates based on sentiment"""
|
| 456 |
-
templates = {
|
| 457 |
-
'bullish': [
|
| 458 |
-
"{action} looking strong! {emoji}",
|
| 459 |
-
"Great {topic} developments ahead 🚀",
|
| 460 |
-
"Bullish on {topic} {emoji}"
|
| 461 |
-
],
|
| 462 |
-
'bearish': [
|
| 463 |
-
"Caution on {topic} {emoji}",
|
| 464 |
-
"{action} facing challenges 📉",
|
| 465 |
-
"Bearish signals for {topic}"
|
| 466 |
-
],
|
| 467 |
-
'neutral': [
|
| 468 |
-
"Watching {topic} developments 👀",
|
| 469 |
-
"{action} market update 📊",
|
| 470 |
-
"Interesting {topic} news"
|
| 471 |
-
]
|
| 472 |
-
}
|
| 473 |
-
|
| 474 |
-
category = 'bullish' if sentiment > 0.6 else 'bearish' if sentiment < 0.4 else 'neutral'
|
| 475 |
-
return templates[category]
|
| 476 |
-
|
| 477 |
-
def _get_template_vars(self, sentiment: float) -> Dict:
|
| 478 |
-
"""Get variables for tweet templates"""
|
| 479 |
-
topics = ['BTC', 'crypto', 'market', 'DeFi']
|
| 480 |
-
actions = ['Bitcoin', 'ETH', 'market', 'altcoins']
|
| 481 |
-
|
| 482 |
-
return {
|
| 483 |
-
'topic': np.random.choice(topics),
|
| 484 |
-
'action': np.random.choice(actions),
|
| 485 |
-
'emoji': np.random.choice(['📈', '📉', '🚀', '💎'])
|
| 486 |
-
}
|
| 487 |
-
|
| 488 |
-
def _get_relevant_emojis(self, sentiment: float) -> List[str]:
|
| 489 |
-
"""Get sentiment-relevant emojis"""
|
| 490 |
-
if sentiment > 0.6:
|
| 491 |
-
return ['🚀', '📈', '💎', '🔥']
|
| 492 |
-
elif sentiment < 0.4:
|
| 493 |
-
return ['📉', '😬', '⚠️', '💥']
|
| 494 |
-
else:
|
| 495 |
-
return ['📊', '👀', '🤔', '💭']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/utils/config.py
DELETED
|
@@ -1,290 +0,0 @@
|
|
| 1 |
-
import json
|
| 2 |
-
import os
|
| 3 |
-
from typing import Dict, Any, Optional
|
| 4 |
-
from dataclasses import dataclass, asdict, field
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
import logging
|
| 7 |
-
|
| 8 |
-
logger = logging.getLogger(__name__)
|
| 9 |
-
|
| 10 |
-
@dataclass
|
| 11 |
-
class TradingConfig:
|
| 12 |
-
"""Comprehensive trading configuration with validation and persistence"""
|
| 13 |
-
|
| 14 |
-
# Environment settings
|
| 15 |
-
initial_balance: float = 10000.0
|
| 16 |
-
max_steps: int = 1000
|
| 17 |
-
transaction_cost: float = 0.001
|
| 18 |
-
risk_level: str = "Medium"
|
| 19 |
-
asset_type: str = "Crypto"
|
| 20 |
-
|
| 21 |
-
# AI Agent settings
|
| 22 |
-
learning_rate: float = 0.001
|
| 23 |
-
gamma: float = 0.99
|
| 24 |
-
epsilon_start: float = 1.0
|
| 25 |
-
epsilon_min: float = 0.01
|
| 26 |
-
epsilon_decay: float = 0.9995
|
| 27 |
-
memory_size: int = 10000
|
| 28 |
-
batch_size: int = 32
|
| 29 |
-
target_update_freq: int = 100
|
| 30 |
-
gradient_clip: float = 1.0
|
| 31 |
-
|
| 32 |
-
# Sentiment settings
|
| 33 |
-
use_sentiment: bool = True
|
| 34 |
-
sentiment_influence: float = 0.3
|
| 35 |
-
sentiment_update_freq: int = 5
|
| 36 |
-
|
| 37 |
-
# Visualization settings
|
| 38 |
-
chart_width: int = 800
|
| 39 |
-
chart_height: int = 600
|
| 40 |
-
update_interval: int = 100
|
| 41 |
-
enable_visualization: bool = True
|
| 42 |
-
|
| 43 |
-
# Training settings
|
| 44 |
-
max_episodes: int = 1000
|
| 45 |
-
eval_episodes: int = 10
|
| 46 |
-
eval_freq: int = 100
|
| 47 |
-
save_freq: int = 500
|
| 48 |
-
log_level: str = "INFO"
|
| 49 |
-
|
| 50 |
-
# Paths
|
| 51 |
-
model_dir: str = "models"
|
| 52 |
-
log_dir: str = "logs"
|
| 53 |
-
data_dir: str = "data"
|
| 54 |
-
|
| 55 |
-
# Device settings
|
| 56 |
-
use_cuda: bool = True
|
| 57 |
-
device: str = "auto"
|
| 58 |
-
|
| 59 |
-
def __post_init__(self):
|
| 60 |
-
"""Validate and initialize configuration"""
|
| 61 |
-
self._validate()
|
| 62 |
-
self._setup_paths()
|
| 63 |
-
self._setup_device()
|
| 64 |
-
self._setup_logging()
|
| 65 |
-
|
| 66 |
-
def _validate(self):
|
| 67 |
-
"""Validate configuration parameters"""
|
| 68 |
-
errors = []
|
| 69 |
-
|
| 70 |
-
# Balance validation
|
| 71 |
-
if self.initial_balance <= 0:
|
| 72 |
-
errors.append("initial_balance must be positive")
|
| 73 |
-
|
| 74 |
-
# Steps validation
|
| 75 |
-
if self.max_steps <= 0:
|
| 76 |
-
errors.append("max_steps must be positive")
|
| 77 |
-
|
| 78 |
-
# Costs validation
|
| 79 |
-
if not 0.0 <= self.transaction_cost <= 0.1:
|
| 80 |
-
errors.append("transaction_cost should be between 0 and 0.1")
|
| 81 |
-
|
| 82 |
-
# Learning rate validation
|
| 83 |
-
if not 0.0001 <= self.learning_rate <= 0.1:
|
| 84 |
-
errors.append("learning_rate should be between 0.0001 and 0.1")
|
| 85 |
-
|
| 86 |
-
# Discount factor validation
|
| 87 |
-
if not 0.0 <= self.gamma <= 1.0:
|
| 88 |
-
errors.append("gamma must be between 0 and 1")
|
| 89 |
-
|
| 90 |
-
# Epsilon validation
|
| 91 |
-
if not 0.0 <= self.epsilon_min <= self.epsilon_start <= 1.0:
|
| 92 |
-
errors.append("epsilon values must satisfy 0 <= epsilon_min <= epsilon_start <= 1")
|
| 93 |
-
|
| 94 |
-
# Batch size validation
|
| 95 |
-
if self.batch_size > self.memory_size:
|
| 96 |
-
errors.append("batch_size cannot exceed memory_size")
|
| 97 |
-
|
| 98 |
-
# Risk level validation
|
| 99 |
-
valid_risks = ["Low", "Medium", "High"]
|
| 100 |
-
if self.risk_level not in valid_risks:
|
| 101 |
-
errors.append(f"risk_level must be one of {valid_risks}")
|
| 102 |
-
|
| 103 |
-
# Asset type validation
|
| 104 |
-
valid_assets = ["Crypto", "Stocks", "Forex", "Commodities"]
|
| 105 |
-
if self.asset_type not in valid_assets:
|
| 106 |
-
errors.append(f"asset_type must be one of {valid_assets}")
|
| 107 |
-
|
| 108 |
-
# Sentiment influence validation
|
| 109 |
-
if not 0.0 <= self.sentiment_influence <= 1.0:
|
| 110 |
-
errors.append("sentiment_influence must be between 0 and 1")
|
| 111 |
-
|
| 112 |
-
if errors:
|
| 113 |
-
logger.error(f"Configuration validation errors: {errors}")
|
| 114 |
-
raise ValueError(f"Invalid configuration: {'; '.join(errors)}")
|
| 115 |
-
|
| 116 |
-
logger.info("Configuration validation passed")
|
| 117 |
-
|
| 118 |
-
def _setup_paths(self):
|
| 119 |
-
"""Create necessary directories"""
|
| 120 |
-
for path_attr in ['model_dir', 'log_dir', 'data_dir']:
|
| 121 |
-
path = Path(getattr(self, path_attr))
|
| 122 |
-
path.mkdir(parents=True, exist_ok=True)
|
| 123 |
-
setattr(self, f"{path_attr}_path", path)
|
| 124 |
-
|
| 125 |
-
def _setup_device(self):
|
| 126 |
-
"""Setup device configuration"""
|
| 127 |
-
import torch
|
| 128 |
-
if self.device == "auto":
|
| 129 |
-
self.device = "cuda" if self.use_cuda and torch.cuda.is_available() else "cpu"
|
| 130 |
-
else:
|
| 131 |
-
if self.device not in ["cpu", "cuda", "mps"]:
|
| 132 |
-
logger.warning(f"Unknown device {self.device}, defaulting to CPU")
|
| 133 |
-
self.device = "cpu"
|
| 134 |
-
|
| 135 |
-
logger.info(f"Using device: {self.device}")
|
| 136 |
-
|
| 137 |
-
def _setup_logging(self):
|
| 138 |
-
"""Setup logging configuration"""
|
| 139 |
-
import logging
|
| 140 |
-
log_level = getattr(logging, self.log_level.upper())
|
| 141 |
-
logging.getLogger().setLevel(log_level)
|
| 142 |
-
|
| 143 |
-
def to_dict(self) -> Dict[str, Any]:
|
| 144 |
-
"""Convert config to dictionary, excluding sensitive paths"""
|
| 145 |
-
config_dict = asdict(self)
|
| 146 |
-
# Remove absolute paths for serialization
|
| 147 |
-
for key in list(config_dict.keys()):
|
| 148 |
-
if key.endswith('_path') or 'dir' in key:
|
| 149 |
-
config_dict[key] = str(getattr(self, key)) if isinstance(getattr(self, key), Path) else getattr(self, key)
|
| 150 |
-
return config_dict
|
| 151 |
-
|
| 152 |
-
def to_json(self, filepath: Optional[str] = None) -> str:
|
| 153 |
-
"""Serialize config to JSON"""
|
| 154 |
-
config_dict = self.to_dict()
|
| 155 |
-
json_str = json.dumps(config_dict, indent=2, default=str)
|
| 156 |
-
|
| 157 |
-
if filepath:
|
| 158 |
-
with open(filepath, 'w') as f:
|
| 159 |
-
f.write(json_str)
|
| 160 |
-
logger.info(f"Config saved to {filepath}")
|
| 161 |
-
|
| 162 |
-
return json_str
|
| 163 |
-
|
| 164 |
-
@classmethod
|
| 165 |
-
def from_json(cls, filepath: str) -> 'TradingConfig':
|
| 166 |
-
"""Load config from JSON file"""
|
| 167 |
-
try:
|
| 168 |
-
with open(filepath, 'r') as f:
|
| 169 |
-
config_dict = json.load(f)
|
| 170 |
-
|
| 171 |
-
# Create dataclass instance
|
| 172 |
-
config = cls(**config_dict)
|
| 173 |
-
logger.info(f"Config loaded from {filepath}")
|
| 174 |
-
return config
|
| 175 |
-
except Exception as e:
|
| 176 |
-
logger.error(f"Error loading config from {filepath}: {e}")
|
| 177 |
-
raise
|
| 178 |
-
|
| 179 |
-
@classmethod
|
| 180 |
-
def from_dict(cls, config_dict: Dict[str, Any]) -> 'TradingConfig':
|
| 181 |
-
"""Create config from dictionary"""
|
| 182 |
-
return cls(**config_dict)
|
| 183 |
-
|
| 184 |
-
def save(self, filepath: str):
|
| 185 |
-
"""Save config to file"""
|
| 186 |
-
self.to_json(filepath)
|
| 187 |
-
|
| 188 |
-
@staticmethod
|
| 189 |
-
def load(filepath: str) -> 'TradingConfig':
|
| 190 |
-
"""Static method to load config"""
|
| 191 |
-
return TradingConfig.from_json(filepath)
|
| 192 |
-
|
| 193 |
-
def update(self, **kwargs):
|
| 194 |
-
"""Update config parameters and revalidate"""
|
| 195 |
-
for key, value in kwargs.items():
|
| 196 |
-
if hasattr(self, key):
|
| 197 |
-
setattr(self, key, value)
|
| 198 |
-
else:
|
| 199 |
-
logger.warning(f"Unknown config parameter: {key}")
|
| 200 |
-
|
| 201 |
-
self._validate()
|
| 202 |
-
logger.info("Config updated and validated")
|
| 203 |
-
|
| 204 |
-
def get_agent_params(self) -> Dict[str, Any]:
|
| 205 |
-
"""Get parameters specific to agent"""
|
| 206 |
-
return {
|
| 207 |
-
'learning_rate': self.learning_rate,
|
| 208 |
-
'gamma': self.gamma,
|
| 209 |
-
'epsilon_start': self.epsilon_start,
|
| 210 |
-
'epsilon_min': self.epsilon_min,
|
| 211 |
-
'epsilon_decay': self.epsilon_decay,
|
| 212 |
-
'memory_size': self.memory_size,
|
| 213 |
-
'batch_size': self.batch_size,
|
| 214 |
-
'target_update_freq': self.target_update_freq,
|
| 215 |
-
'gradient_clip': self.gradient_clip,
|
| 216 |
-
'device': self.device
|
| 217 |
-
}
|
| 218 |
-
|
| 219 |
-
def get_env_params(self) -> Dict[str, Any]:
|
| 220 |
-
"""Get parameters specific to environment"""
|
| 221 |
-
return {
|
| 222 |
-
'initial_balance': self.initial_balance,
|
| 223 |
-
'max_steps': self.max_steps,
|
| 224 |
-
'transaction_cost': self.transaction_cost,
|
| 225 |
-
'risk_level': self.risk_level,
|
| 226 |
-
'asset_type': self.asset_type,
|
| 227 |
-
'use_sentiment': self.use_sentiment,
|
| 228 |
-
'sentiment_influence': self.sentiment_influence,
|
| 229 |
-
'sentiment_update_freq': self.sentiment_update_freq
|
| 230 |
-
}
|
| 231 |
-
|
| 232 |
-
def __str__(self) -> str:
|
| 233 |
-
"""String representation of config"""
|
| 234 |
-
return json.dumps(self.to_dict(), indent=2)
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
# Legacy compatibility
|
| 238 |
-
class LegacyTradingConfig:
|
| 239 |
-
"""Wrapper for backward compatibility"""
|
| 240 |
-
|
| 241 |
-
def __init__(self, config_file: Optional[str] = None):
|
| 242 |
-
if config_file and os.path.exists(config_file):
|
| 243 |
-
self.config = TradingConfig.from_json(config_file)
|
| 244 |
-
else:
|
| 245 |
-
self.config = TradingConfig()
|
| 246 |
-
|
| 247 |
-
def __getattr__(self, name):
|
| 248 |
-
return getattr(self.config, name)
|
| 249 |
-
|
| 250 |
-
def to_dict(self):
|
| 251 |
-
return self.config.to_dict()
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
# Default config instance
|
| 255 |
-
DEFAULT_CONFIG = TradingConfig()
|
| 256 |
-
|
| 257 |
-
# Example usage and config loading
|
| 258 |
-
def create_config_from_env() -> TradingConfig:
|
| 259 |
-
"""Create config from environment variables"""
|
| 260 |
-
import os
|
| 261 |
-
config_dict = {}
|
| 262 |
-
|
| 263 |
-
env_mappings = {
|
| 264 |
-
'INITIAL_BALANCE': 'initial_balance',
|
| 265 |
-
'MAX_STEPS': 'max_steps',
|
| 266 |
-
'LEARNING_RATE': 'learning_rate',
|
| 267 |
-
'BATCH_SIZE': 'batch_size',
|
| 268 |
-
'USE_CUDA': 'use_cuda'
|
| 269 |
-
}
|
| 270 |
-
|
| 271 |
-
for env_var, config_key in env_mappings.items():
|
| 272 |
-
env_value = os.getenv(env_var)
|
| 273 |
-
if env_value is not None:
|
| 274 |
-
try:
|
| 275 |
-
# Try to convert to appropriate type
|
| 276 |
-
if config_key in ['initial_balance', 'learning_rate']:
|
| 277 |
-
config_dict[config_key] = float(env_value)
|
| 278 |
-
elif config_key in ['max_steps', 'batch_size']:
|
| 279 |
-
config_dict[config_key] = int(env_value)
|
| 280 |
-
elif config_key == 'use_cuda':
|
| 281 |
-
config_dict[config_key] = env_value.lower() in ('true', '1', 'yes')
|
| 282 |
-
except ValueError:
|
| 283 |
-
logger.warning(f"Invalid environment variable {env_var}: {env_value}")
|
| 284 |
-
|
| 285 |
-
if config_dict:
|
| 286 |
-
base_config = TradingConfig()
|
| 287 |
-
base_config.update(**config_dict)
|
| 288 |
-
return base_config
|
| 289 |
-
|
| 290 |
-
return DEFAULT_CONFIG
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/visualizers/chart_renderer.py
DELETED
|
@@ -1,410 +0,0 @@
|
|
| 1 |
-
import plotly.graph_objects as go
|
| 2 |
-
from plotly.subplots import make_subplots
|
| 3 |
-
import plotly.express as px
|
| 4 |
-
import numpy as np
|
| 5 |
-
import pandas as pd
|
| 6 |
-
from typing import List, Dict, Any, Optional, Union
|
| 7 |
-
import logging
|
| 8 |
-
from datetime import datetime
|
| 9 |
-
import warnings
|
| 10 |
-
warnings.filterwarnings('ignore')
|
| 11 |
-
|
| 12 |
-
logger = logging.getLogger(__name__)
|
| 13 |
-
|
| 14 |
-
class ChartRenderer:
|
| 15 |
-
"""Advanced chart renderer for trading visualizations with error handling"""
|
| 16 |
-
|
| 17 |
-
def __init__(self, theme: str = "plotly_white", default_height: int = 400):
|
| 18 |
-
self.theme = theme
|
| 19 |
-
self.default_height = default_height
|
| 20 |
-
self._validate_plotly()
|
| 21 |
-
|
| 22 |
-
def _validate_plotly(self):
|
| 23 |
-
"""Validate Plotly installation and capabilities"""
|
| 24 |
-
try:
|
| 25 |
-
import plotly
|
| 26 |
-
logger.info(f"Plotly version: {plotly.__version__}")
|
| 27 |
-
except ImportError:
|
| 28 |
-
raise ImportError("Plotly is required for ChartRenderer")
|
| 29 |
-
|
| 30 |
-
def _safe_data_validation(self, data, expected_len: Optional[int] = None,
|
| 31 |
-
data_type: str = "data") -> bool:
|
| 32 |
-
"""Validate input data safely"""
|
| 33 |
-
if data is None or len(data) == 0:
|
| 34 |
-
logger.warning(f"No {data_type} provided")
|
| 35 |
-
return False
|
| 36 |
-
|
| 37 |
-
if expected_len and len(data) != expected_len:
|
| 38 |
-
logger.warning(f"{data_type} length mismatch: expected {expected_len}, got {len(data)}")
|
| 39 |
-
|
| 40 |
-
if isinstance(data, (list, np.ndarray)):
|
| 41 |
-
if np.any(np.isnan(data)) or np.any(np.isinf(data)):
|
| 42 |
-
logger.warning(f"{data_type} contains NaN or Inf values")
|
| 43 |
-
return False
|
| 44 |
-
|
| 45 |
-
return True
|
| 46 |
-
|
| 47 |
-
def render_price_chart(self, prices: Union[List[float], np.ndarray],
|
| 48 |
-
actions: Optional[List[int]] = None,
|
| 49 |
-
current_step: int = 0,
|
| 50 |
-
title: Optional[str] = None,
|
| 51 |
-
height: Optional[int] = None) -> go.Figure:
|
| 52 |
-
"""Render interactive price chart with trading actions"""
|
| 53 |
-
fig = go.Figure()
|
| 54 |
-
height = height or self.default_height
|
| 55 |
-
|
| 56 |
-
# Validate data
|
| 57 |
-
if not self._safe_data_validation(prices, data_type="prices"):
|
| 58 |
-
return self._create_empty_figure("No Price Data", height)
|
| 59 |
-
|
| 60 |
-
try:
|
| 61 |
-
# Convert to numpy for consistency
|
| 62 |
-
prices = np.array(prices, dtype=np.float64)
|
| 63 |
-
time_steps = np.arange(len(prices))
|
| 64 |
-
|
| 65 |
-
# Add main price trace
|
| 66 |
-
fig.add_trace(go.Scatter(
|
| 67 |
-
x=time_steps,
|
| 68 |
-
y=prices,
|
| 69 |
-
mode='lines',
|
| 70 |
-
name='Price',
|
| 71 |
-
line=dict(color='#1f77b4', width=2),
|
| 72 |
-
hovertemplate='<b>Step %{x}</b><br>Price: $%{y:.2f}<extra></extra>'
|
| 73 |
-
))
|
| 74 |
-
|
| 75 |
-
# Add action markers with validation
|
| 76 |
-
if actions and self._safe_data_validation(actions, len(prices), "actions"):
|
| 77 |
-
self._add_action_markers(fig, prices, actions, time_steps)
|
| 78 |
-
|
| 79 |
-
# Add current step indicator
|
| 80 |
-
if 0 <= current_step < len(prices):
|
| 81 |
-
fig.add_vline(
|
| 82 |
-
x=current_step,
|
| 83 |
-
line_dash="dash",
|
| 84 |
-
line_color="orange",
|
| 85 |
-
annotation_text=f"Current Step ({current_step})",
|
| 86 |
-
annotation_position="top right"
|
| 87 |
-
)
|
| 88 |
-
|
| 89 |
-
# Calculate and add key metrics
|
| 90 |
-
self._add_price_metrics(fig, prices)
|
| 91 |
-
|
| 92 |
-
title = title or f"Asset Price Evolution (Step: {current_step})"
|
| 93 |
-
fig.update_layout(
|
| 94 |
-
title={
|
| 95 |
-
'text': title,
|
| 96 |
-
'x': 0.5,
|
| 97 |
-
'xanchor': 'center',
|
| 98 |
-
'font': {'size': 16}
|
| 99 |
-
},
|
| 100 |
-
xaxis_title="Time Step",
|
| 101 |
-
yaxis_title="Price ($)",
|
| 102 |
-
height=height + 100,
|
| 103 |
-
showlegend=True,
|
| 104 |
-
template=self.theme,
|
| 105 |
-
hovermode='x unified'
|
| 106 |
-
)
|
| 107 |
-
|
| 108 |
-
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
|
| 109 |
-
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
|
| 110 |
-
|
| 111 |
-
return fig
|
| 112 |
-
|
| 113 |
-
except Exception as e:
|
| 114 |
-
logger.error(f"Error rendering price chart: {e}")
|
| 115 |
-
return self._create_empty_figure("Error Rendering Price Chart", height)
|
| 116 |
-
|
| 117 |
-
def _add_action_markers(self, fig: go.Figure, prices: np.ndarray,
|
| 118 |
-
actions: List[int], time_steps: np.ndarray):
|
| 119 |
-
"""Add buy/sell/close action markers to figure"""
|
| 120 |
-
action_configs = {
|
| 121 |
-
1: {'name': 'Buy', 'color': '#2ca02c', 'symbol': 'triangle-up'},
|
| 122 |
-
2: {'name': 'Sell', 'color': '#d62728', 'symbol': 'triangle-down'},
|
| 123 |
-
3: {'name': 'Close', 'color': '#ff7f0e', 'symbol': 'x'}
|
| 124 |
-
}
|
| 125 |
-
|
| 126 |
-
for action_id, config in action_configs.items():
|
| 127 |
-
indices = [i for i, a in enumerate(actions) if a == action_id]
|
| 128 |
-
if indices:
|
| 129 |
-
action_prices = prices[indices]
|
| 130 |
-
fig.add_trace(go.Scatter(
|
| 131 |
-
x=[time_steps[i] for i in indices],
|
| 132 |
-
y=action_prices,
|
| 133 |
-
mode='markers',
|
| 134 |
-
name=config['name'],
|
| 135 |
-
marker=dict(
|
| 136 |
-
color=config['color'],
|
| 137 |
-
size=12,
|
| 138 |
-
symbol=config['symbol'],
|
| 139 |
-
line=dict(width=2, color='white')
|
| 140 |
-
),
|
| 141 |
-
hovertemplate=f'<b>{config["name"]}</b><br>Step: %{{x}}<br>Price: $%{{y:.2f}}<extra></extra>',
|
| 142 |
-
showlegend=True
|
| 143 |
-
))
|
| 144 |
-
|
| 145 |
-
def _add_price_metrics(self, fig: go.Figure, prices: np.ndarray):
|
| 146 |
-
"""Add price statistics as annotations"""
|
| 147 |
-
if len(prices) < 2:
|
| 148 |
-
return
|
| 149 |
-
|
| 150 |
-
max_price = np.max(prices)
|
| 151 |
-
min_price = np.min(prices)
|
| 152 |
-
avg_price = np.mean(prices)
|
| 153 |
-
|
| 154 |
-
# Add horizontal reference lines
|
| 155 |
-
fig.add_hline(y=max_price, line_dash="dot", line_color="green",
|
| 156 |
-
annotation_text=f"Max: ${max_price:.2f}")
|
| 157 |
-
fig.add_hline(y=min_price, line_dash="dot", line_color="red",
|
| 158 |
-
annotation_text=f"Min: ${min_price:.2f}")
|
| 159 |
-
fig.add_hline(y=avg_price, line_dash="dash", line_color="blue",
|
| 160 |
-
annotation_text=f"Avg: ${avg_price:.2f}")
|
| 161 |
-
|
| 162 |
-
def create_performance_chart(self, net_worth_history: List[float],
|
| 163 |
-
reward_history: Optional[List[float]] = None,
|
| 164 |
-
initial_balance: float = 10000,
|
| 165 |
-
height: Optional[int] = None) -> go.Figure:
|
| 166 |
-
"""Create comprehensive performance dashboard"""
|
| 167 |
-
height = height or 600
|
| 168 |
-
|
| 169 |
-
if not self._safe_data_validation(net_worth_history, data_type="net worth history"):
|
| 170 |
-
return self._create_empty_figure("No Performance Data", height)
|
| 171 |
-
|
| 172 |
-
try:
|
| 173 |
-
fig = make_subplots(
|
| 174 |
-
rows=2, cols=2,
|
| 175 |
-
subplot_titles=['Portfolio Value', 'Returns vs Initial Balance',
|
| 176 |
-
'Cumulative Reward', 'Reward Distribution'],
|
| 177 |
-
vertical_spacing=0.1,
|
| 178 |
-
horizontal_spacing=0.1,
|
| 179 |
-
specs=[[{"secondary_y": False}, {"secondary_y": False}],
|
| 180 |
-
[{"secondary_y": False}, {"secondary_y": False}]]
|
| 181 |
-
)
|
| 182 |
-
|
| 183 |
-
steps = np.arange(len(net_worth_history))
|
| 184 |
-
net_worth = np.array(net_worth_history)
|
| 185 |
-
|
| 186 |
-
# Portfolio value
|
| 187 |
-
fig.add_trace(
|
| 188 |
-
go.Scatter(x=steps, y=net_worth, mode='lines', name='Net Worth',
|
| 189 |
-
line=dict(color='#2ca02c', width=3)),
|
| 190 |
-
row=1, col=1
|
| 191 |
-
)
|
| 192 |
-
|
| 193 |
-
# Initial balance reference
|
| 194 |
-
fig.add_hline(y=initial_balance, line_dash="dash", line_color="red",
|
| 195 |
-
annotation_text=f"Initial: ${initial_balance:.2f}",
|
| 196 |
-
row=1, col=1)
|
| 197 |
-
|
| 198 |
-
# Returns comparison
|
| 199 |
-
returns = (net_worth - initial_balance) / initial_balance * 100
|
| 200 |
-
fig.add_trace(
|
| 201 |
-
go.Scatter(x=steps, y=returns, mode='lines', name='Returns %',
|
| 202 |
-
line=dict(color='#ff7f0e', width=2)),
|
| 203 |
-
row=1, col=2
|
| 204 |
-
)
|
| 205 |
-
fig.add_hline(y=0, line_dash="solid", line_color="gray", row=1, col=2)
|
| 206 |
-
|
| 207 |
-
# Cumulative reward
|
| 208 |
-
if reward_history and self._safe_data_validation(reward_history):
|
| 209 |
-
cum_reward = np.cumsum(reward_history)
|
| 210 |
-
fig.add_trace(
|
| 211 |
-
go.Scatter(x=steps[:len(cum_reward)], y=cum_reward, mode='lines',
|
| 212 |
-
name='Cumulative Reward', line=dict(color='#9467bd', width=2)),
|
| 213 |
-
row=2, col=1
|
| 214 |
-
)
|
| 215 |
-
|
| 216 |
-
# Reward distribution
|
| 217 |
-
if reward_history:
|
| 218 |
-
fig.add_trace(
|
| 219 |
-
go.Histogram(x=reward_history, name='Reward Distribution',
|
| 220 |
-
marker_color='#1f77b4', opacity=0.7),
|
| 221 |
-
row=2, col=2
|
| 222 |
-
)
|
| 223 |
-
|
| 224 |
-
fig.update_layout(
|
| 225 |
-
height=height,
|
| 226 |
-
showlegend=True,
|
| 227 |
-
title_text="Trading Performance Dashboard",
|
| 228 |
-
template=self.theme
|
| 229 |
-
)
|
| 230 |
-
|
| 231 |
-
# Update axis titles
|
| 232 |
-
fig.update_yaxes(title_text="Value ($)", row=1, col=1)
|
| 233 |
-
fig.update_yaxes(title_text="Returns (%)", row=1, col=2)
|
| 234 |
-
fig.update_yaxes(title_text="Cumulative Reward", row=2, col=1)
|
| 235 |
-
fig.update_xaxes(title_text="Steps", row=2, col=1)
|
| 236 |
-
fig.update_xaxes(title_text="Reward Value", row=2, col=2)
|
| 237 |
-
|
| 238 |
-
return fig
|
| 239 |
-
|
| 240 |
-
except Exception as e:
|
| 241 |
-
logger.error(f"Error creating performance chart: {e}")
|
| 242 |
-
return self._create_empty_figure("Error in Performance Chart", height)
|
| 243 |
-
|
| 244 |
-
def create_action_distribution(self, actions: List[int],
|
| 245 |
-
title: Optional[str] = None,
|
| 246 |
-
height: Optional[int] = None) -> go.Figure:
|
| 247 |
-
"""Create interactive action distribution visualization"""
|
| 248 |
-
height = height or 350
|
| 249 |
-
|
| 250 |
-
if not self._safe_data_validation(actions, data_type="actions"):
|
| 251 |
-
return self._create_empty_figure("No Actions Data", height)
|
| 252 |
-
|
| 253 |
-
try:
|
| 254 |
-
action_names = ['Hold', 'Buy', 'Sell', 'Close']
|
| 255 |
-
action_counts = [actions.count(i) for i in range(4)]
|
| 256 |
-
total_actions = sum(action_counts)
|
| 257 |
-
|
| 258 |
-
colors = ['#1f77b4', '#2ca02c', '#d62728', '#ff7f0e']
|
| 259 |
-
|
| 260 |
-
fig = go.Figure(data=[go.Pie(
|
| 261 |
-
labels=action_names,
|
| 262 |
-
values=action_counts,
|
| 263 |
-
hole=0.4,
|
| 264 |
-
marker_colors=colors,
|
| 265 |
-
textinfo='label+percent+value',
|
| 266 |
-
hovertemplate='<b>%{label}</b><br>Count: %{value}<br>Percentage: %{percent}<extra></extra>',
|
| 267 |
-
pull=[0, 0, 0, 0] # Equal spacing
|
| 268 |
-
)])
|
| 269 |
-
|
| 270 |
-
title = title or f"Action Distribution (Total: {total_actions} actions)"
|
| 271 |
-
fig.update_layout(
|
| 272 |
-
title={
|
| 273 |
-
'text': title,
|
| 274 |
-
'x': 0.5,
|
| 275 |
-
'xanchor': 'center'
|
| 276 |
-
},
|
| 277 |
-
height=height,
|
| 278 |
-
showlegend=True,
|
| 279 |
-
template=self.theme,
|
| 280 |
-
annotations=[dict(
|
| 281 |
-
text='Trading Actions',
|
| 282 |
-
x=0.5, y=0.5,
|
| 283 |
-
font_size=16,
|
| 284 |
-
showarrow=False
|
| 285 |
-
)]
|
| 286 |
-
)
|
| 287 |
-
|
| 288 |
-
return fig
|
| 289 |
-
|
| 290 |
-
except Exception as e:
|
| 291 |
-
logger.error(f"Error creating action distribution: {e}")
|
| 292 |
-
return self._create_empty_figure("Error in Action Distribution", height)
|
| 293 |
-
|
| 294 |
-
def create_training_progress(self, training_history: List[Dict],
                             window_size: int = 10,
                             height: Optional[int] = None) -> go.Figure:
    """Create a comprehensive 2x2 training-progress dashboard.

    Args:
        training_history: one dict per episode; 'episode', 'reward',
            'net_worth' and 'loss' keys are read, each defaulting to 0
            (or the list index for 'episode') when missing.
        window_size: rolling window for the moving-average reward panel.
        height: figure height in pixels (defaults to 700).

    Returns:
        A plotly figure with reward, net-worth, loss and moving-average
        panels, or a placeholder empty figure on missing data / errors.
    """
    height = height or 700

    if not training_history:
        return self._create_empty_figure("No Training Data", height)

    try:
        # Extract data defensively: 'episode' falls back to the list index
        # so partially populated histories still plot.
        episodes = [h.get('episode', i) for i, h in enumerate(training_history)]
        rewards = [h.get('reward', 0) for h in training_history]
        net_worths = [h.get('net_worth', 0) for h in training_history]
        losses = [h.get('loss', 0) for h in training_history]

        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=['Total Reward per Episode', 'Final Net Worth',
                            'Training Loss', 'Moving Average Reward'],
            specs=[[{"secondary_y": False}, {"secondary_y": False}],
                   [{"secondary_y": False}, {"secondary_y": False}]]
        )

        # Per-episode reward
        fig.add_trace(go.Scatter(
            x=episodes, y=rewards, mode='lines+markers',
            name='Episode Reward', line=dict(color='#1f77b4', width=2),
            marker=dict(size=4)
        ), row=1, col=1)

        # Final net worth per episode
        fig.add_trace(go.Scatter(
            x=episodes, y=net_worths, mode='lines+markers',
            name='Final Net Worth', line=dict(color='#2ca02c', width=2),
            marker=dict(size=4)
        ), row=1, col=2)

        # Loss panel only when at least one meaningful (positive) value exists.
        if any(loss > 0 for loss in losses):
            fig.add_trace(go.Scatter(
                x=episodes, y=losses, mode='lines',
                name='Training Loss', line=dict(color='#d62728', width=2)
            ), row=2, col=1)

        # Moving average: rolling(min_periods=1) copes with histories shorter
        # than window_size, so plot whenever any rewards exist (the previous
        # len(rewards) >= window_size guard left this panel blank early on).
        if rewards:
            ma_rewards = pd.Series(rewards).rolling(window=window_size, min_periods=1).mean()
            fig.add_trace(go.Scatter(
                x=episodes, y=ma_rewards, mode='lines',
                name=f'MA Reward ({window_size})',
                line=dict(color='#ff7f0e', width=3, dash='dash')
            ), row=2, col=2)

        fig.update_layout(
            height=height,
            showlegend=True,
            title_text=f"Training Progress - {len(episodes)} Episodes",
            template=self.theme
        )

        # Label every panel's axes (previously only row 2 / col 1 had an
        # x-axis title).
        fig.update_yaxes(title_text="Reward", row=1, col=1)
        fig.update_yaxes(title_text="Net Worth ($)", row=1, col=2)
        fig.update_yaxes(title_text="Loss", row=2, col=1)
        fig.update_yaxes(title_text="MA Reward", row=2, col=2)
        fig.update_xaxes(title_text="Episodes", row=2, col=1)
        fig.update_xaxes(title_text="Episodes", row=2, col=2)

        return fig

    except Exception as e:
        logger.error(f"Error creating training progress chart: {e}")
        return self._create_empty_figure("Error in Training Progress", height)
|
| 367 |
-
|
| 368 |
-
def _create_empty_figure(self, title: str, height: int) -> go.Figure:
    """Return a bare placeholder figure carrying only a title and theme.

    Used as the safe fallback whenever input validation fails or a
    rendering step raises.
    """
    placeholder = go.Figure()
    placeholder.update_layout(title=title, height=height, template=self.theme)
    return placeholder
|
| 377 |
-
|
| 378 |
-
def save_chart(self, fig: go.Figure, filename: str, format: str = 'html'):
    """Persist a figure to disk.

    Args:
        fig: the plotly figure to save.
        filename: destination path.
        format: one of 'html', 'png' or 'pdf'; anything else is logged
            as a warning and nothing is written.

    Errors are logged rather than raised so a failed export never
    interrupts a training run.
    """
    try:
        if format == 'html':
            fig.write_html(filename)
        elif format == 'png':
            fig.write_image(filename)
        elif format == 'pdf':
            fig.write_image(filename, width=1200, height=800)
        else:
            # Previously an unknown format fell through silently and the
            # success message was still logged.
            logger.warning(f"Unsupported chart format: {format}")
            return
        # Bug fix: the f-string had no placeholder and logged the literal
        # text "(unknown)" instead of the actual file name.
        logger.info(f"Chart saved as {filename}")
    except Exception as e:
        logger.error(f"Error saving chart: {e}")
|
| 390 |
-
|
| 391 |
-
def show(self, fig: go.Figure):
    """Render the figure on screen when an interactive backend is available.

    Failures (e.g. headless environments) are downgraded to a warning
    instead of propagating.
    """
    try:
        fig.show()
    except Exception as e:
        logger.warning(f"Could not display chart: {e}")
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
# Utility functions for batch rendering
|
| 400 |
-
def render_dashboard(prices, actions, net_worth, rewards, config):
    """Create a complete trading dashboard.

    Builds the price, performance and action-distribution figures with a
    fresh ChartRenderer and returns them keyed by chart name.
    """
    renderer = ChartRenderer()

    price_fig = renderer.render_price_chart(prices, actions)
    performance_fig = renderer.create_performance_chart(
        net_worth, rewards, config.initial_balance
    )
    actions_fig = renderer.create_action_distribution(actions)

    return {
        'price': price_fig,
        'performance': performance_fig,
        'actions': actions_fig,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|