# VisualTradingAI / src/agents/advanced_agent.py
# Last updated by OmidSakaki in commit 6097bc7 ("Update src/agents/advanced_agent.py").
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
class EnhancedTradingNetwork(nn.Module):
    """Q-network that fuses a convolutional image branch with a small sentiment MLP.

    Args:
        state_dim: Accepted for interface compatibility; the layers do not use it.
        action_dim: Number of Q-values produced per sample.
        sentiment_dim: Width of the sentiment input vector (default 2; the agent
            in this file passes ``[sentiment, confidence]``).
    """

    def __init__(self, state_dim, action_dim, sentiment_dim=2):
        super(EnhancedTradingNetwork, self).__init__()
        # Widths of the two feature vectors concatenated before the decision head.
        self.visual_feature_dim = 256
        self.sentiment_feature_dim = 32

        # Visual branch: 4 input channels (frames arrive channels-last and are
        # permuted in forward). Adaptive pooling fixes the spatial size at 8x8,
        # so the flattened width does not depend on the input resolution as long
        # as the input survives the three convolutions.
        self.visual_conv = nn.Sequential(
            nn.Conv2d(4, 16, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 8))
        )
        # Flattened size of the pooled conv output (channels * 8 * 8).
        self.conv_output_size = 32 * 8 * 8
        self.visual_fc = nn.Sequential(
            nn.Linear(self.conv_output_size, self.visual_feature_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # Sentiment branch
        self.sentiment_fc = nn.Sequential(
            nn.Linear(sentiment_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, self.sentiment_feature_dim),
            nn.ReLU()
        )

        # Decision head over the concatenated [visual | sentiment] features.
        self.combined_fc = nn.Sequential(
            nn.Linear(self.visual_feature_dim + self.sentiment_feature_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim)
        )

    def forward(self, x, sentiment=None):
        """Compute Q-values.

        Args:
            x: Channels-last frames ``(batch, H, W, 4)``, or a single
                ``(H, W, 4)`` frame, which is treated as a batch of one.
            sentiment: Optional ``(batch, sentiment_dim)`` or ``(sentiment_dim,)``
                tensor. A single row is broadcast across the whole batch. When
                ``None``, the sentiment features are zero-filled so the decision
                head still receives its expected input width.

        Returns:
            Tensor of shape ``(batch, action_dim)``. On any error the exception is
            printed and a zero tensor is returned instead (best-effort fallback).
        """
        try:
            if x.dim() == 3:
                x = x.unsqueeze(0)
            x = x.permute(0, 3, 1, 2).contiguous()  # -> (batch, C, H, W)

            visual_features = self.visual_conv(x)
            batch_size = visual_features.size(0)
            visual_features = self.visual_fc(visual_features.reshape(batch_size, -1))

            if sentiment is None:
                # combined_fc expects visual + sentiment width. Zero-fill the
                # sentiment slot so the visual-only path runs instead of failing
                # the shape check and silently returning the fallback zeros.
                sentiment_features = visual_features.new_zeros(
                    (batch_size, self.sentiment_feature_dim)
                )
            else:
                if sentiment.dim() == 1:
                    sentiment = sentiment.unsqueeze(0)
                sentiment_features = self.sentiment_fc(sentiment)
                # Allow one sentiment row to be shared by every sample in the batch.
                if sentiment_features.size(0) == 1 and batch_size > 1:
                    sentiment_features = sentiment_features.expand(batch_size, -1)

            combined_features = torch.cat([visual_features, sentiment_features], dim=1)
            return self.combined_fc(combined_features)
        except Exception as e:
            print(f"Error in network forward: {e}")
            out_dim = self.combined_fc[-1].out_features
            # Match the input's leading size and device when x is a usable tensor.
            if torch.is_tensor(x) and x.dim() > 0:
                return torch.zeros((x.size(0), out_dim), device=x.device)
            return torch.zeros((1, out_dim))
class AdvancedTradingAgent:
    """DQN-style trading agent whose Q-network can also consume a sentiment signal.

    Observations are cast to float32 and divided by 255 before reaching the
    network (presumably uint8 pixel frames — confirm with the environment).
    When ``use_sentiment`` is true, a ``[sentiment, confidence]`` pair is fed
    alongside each observation.

    Args:
        state_dim: Forwarded to the network constructor.
        action_dim: Number of discrete actions.
        learning_rate: Adam learning rate.
        use_sentiment: Whether to feed sentiment features to the network.
        memory_size: Capacity of the replay buffer; the oldest transitions are
            evicted first.
        batch_size: Number of transitions sampled per update.
    """

    def __init__(self, state_dim, action_dim, learning_rate=0.001, use_sentiment=True,
                 memory_size=500, batch_size=16):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.learning_rate = learning_rate
        self.use_sentiment = use_sentiment

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # A single online network; it also produces the bootstrap targets
        # (there is no separate target network).
        self.policy_net = EnhancedTradingNetwork(state_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)

        # Experience replay
        self.memory = deque(maxlen=memory_size)
        self.batch_size = batch_size

        # Discount factor and epsilon-greedy schedule; epsilon decays once per
        # successful update().
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.1
        self.epsilon_decay = 0.995
        self.steps_done = 0

    def _forward(self, states, sentiment):
        """Run the policy network, passing sentiment only when it is available."""
        # Omitting the argument (rather than passing None) keeps this working
        # with networks whose forward() takes only the observation.
        if sentiment is None:
            return self.policy_net(states)
        return self.policy_net(states, sentiment)

    def select_action(self, state, current_sentiment=0.5, sentiment_confidence=0.0):
        """Select an action epsilon-greedily, optionally conditioning on sentiment.

        Args:
            state: Observation array (must support ``.astype``); it is scaled
                by 1/255 before the forward pass.
            current_sentiment: Sentiment score fed to the network when
                ``use_sentiment`` is enabled.
            sentiment_confidence: Confidence paired with ``current_sentiment``.

        Returns:
            An int in ``[0, action_dim)``. Falls back to a random action if the
            forward pass raises.
        """
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)

        was_training = self.policy_net.training
        try:
            state_tensor = torch.FloatTensor(state.astype(np.float32) / 255.0).to(self.device)
            sentiment_tensor = None
            if self.use_sentiment:
                sentiment_tensor = torch.FloatTensor(
                    [current_sentiment, sentiment_confidence]
                ).to(self.device)
            # The greedy choice should not be perturbed by dropout.
            self.policy_net.eval()
            with torch.no_grad():
                q_values = self._forward(state_tensor, sentiment_tensor)
            return int(q_values.argmax().item())
        except Exception as e:
            print(f"Error in advanced action selection: {e}")
            return random.randint(0, self.action_dim - 1)
        finally:
            self.policy_net.train(was_training)

    def store_transition(self, state, action, reward, next_state, done, sentiment_data=None):
        """Append one transition to the replay buffer.

        ``sentiment_data`` is read by update() as a mapping with ``'sentiment'``
        and ``'confidence'`` keys; anything else falls back to neutral defaults.
        """
        try:
            experience = (state, action, reward, next_state, done, sentiment_data)
            self.memory.append(experience)
        except Exception as e:
            print(f"Error storing transition: {e}")

    def _sentiment_batch(self, sentiment_data):
        """Build a ``(batch, 2)`` ``[sentiment, confidence]`` tensor, or None if unused.

        Returns None when sentiment is disabled or no sampled transition carries
        sentiment data. Entries missing either key default to ``[0.5, 0.0]``,
        the same defaults select_action() uses.
        """
        if not self.use_sentiment or all(data is None for data in sentiment_data):
            return None
        rows = [
            [data['sentiment'], data['confidence']]
            if data and 'sentiment' in data and 'confidence' in data
            else [0.5, 0.0]
            for data in sentiment_data
        ]
        return torch.FloatTensor(rows).to(self.device)

    def update(self):
        """Run one Q-learning gradient step on a random replay minibatch.

        Returns:
            The MSE loss as a float, or 0.0 when the buffer holds fewer than
            ``batch_size`` transitions or the step raises.
        """
        if len(self.memory) < self.batch_size:
            return 0.0

        was_training = self.policy_net.training
        try:
            batch = random.sample(self.memory, self.batch_size)
            states, actions, rewards, next_states, dones, sentiment_data = zip(*batch)

            states_tensor = torch.FloatTensor(
                np.array(states, dtype=np.float32) / 255.0
            ).to(self.device)
            next_states_tensor = torch.FloatTensor(
                np.array(next_states, dtype=np.float32) / 255.0
            ).to(self.device)
            actions_tensor = torch.LongTensor(actions).to(self.device)
            rewards_tensor = torch.FloatTensor(rewards).to(self.device)
            dones_tensor = torch.BoolTensor(dones).to(self.device)
            # Only one sentiment record is stored per transition, so the same
            # sentiment is used for both the current and the next state.
            sentiment_tensor = self._sentiment_batch(sentiment_data)

            # Bootstrap target from the same network, evaluated without dropout
            # so the regression target is not randomly perturbed.
            self.policy_net.eval()
            with torch.no_grad():
                next_q = self._forward(next_states_tensor, sentiment_tensor).max(1)[0]
            target_q = rewards_tensor + self.gamma * next_q * ~dones_tensor

            self.policy_net.train()
            current_q = self._forward(states_tensor, sentiment_tensor)
            # squeeze(1) keeps the batch dimension even when batch_size == 1.
            current_q = current_q.gather(1, actions_tensor.unsqueeze(1)).squeeze(1)

            loss = nn.MSELoss()(current_q, target_q)

            self.optimizer.zero_grad()
            loss.backward()
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
            self.optimizer.step()

            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
            self.steps_done += 1
            return float(loss.item())
        except Exception as e:
            print(f"Error in advanced update: {e}")
            return 0.0
        finally:
            self.policy_net.train(was_training)
# Simpler fallback network (visual branch only, no sentiment input) for use if the advanced agent fails
class SimpleTradingNetwork(nn.Module):
    """Visual-only Q-network: a conv trunk plus an MLP head, with no sentiment input.

    Args:
        state_dim: Accepted for interface parity; the layers do not use it.
        action_dim: Number of Q-values produced per sample.
    """

    def __init__(self, state_dim, action_dim):
        super(SimpleTradingNetwork, self).__init__()
        # 4 input channels; adaptive pooling fixes the spatial size at 8x8 so the
        # flattened width (32 * 8 * 8) does not depend on the input resolution.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(4, 16, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 8))
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(32 * 8 * 8, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim)
        )

    def forward(self, x):
        """Compute Q-values of shape ``(batch, action_dim)``.

        ``x`` is channels-last: ``(batch, H, W, 4)``, or a single ``(H, W, 4)``
        frame, which is treated as a batch of one. On any error the exception is
        printed and a zero tensor is returned instead (best-effort fallback).
        """
        try:
            if x.dim() == 3:
                x = x.unsqueeze(0)
            x = x.permute(0, 3, 1, 2).contiguous()  # -> (batch, C, H, W)
            x = self.conv_layers(x)
            x = x.reshape(x.size(0), -1)
            return self.fc_layers(x)
        except Exception as e:
            print(f"Error in simple network: {e}")
            out_dim = self.fc_layers[-1].out_features
            # Only read size/device from x when it is a usable tensor; otherwise
            # the fallback itself would raise from inside this handler.
            if torch.is_tensor(x) and x.dim() > 0:
                return torch.zeros((x.size(0), out_dim), device=x.device)
            return torch.zeros((1, out_dim))