# VisualTradingAI / src/agents/visual_agent.py
# Author: OmidSakaki — commit 1f5a715 ("Update src/agents/visual_agent.py")
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
class VisualTradingAgent:
    """Epsilon-greedy, DQN-style agent acting on image observations.

    A single ``SimpleTradingNetwork`` is used both to pick actions and to
    compute bootstrap targets (there is no separate target network).
    Transitions are kept in a bounded replay buffer and sampled uniformly
    for each update.
    """

    def __init__(self, state_dim, action_dim, learning_rate=0.001):
        """Build the policy network, optimizer and replay buffer.

        Args:
            state_dim: Passed through to ``SimpleTradingNetwork``.
            action_dim: Number of discrete actions the agent chooses from.
            learning_rate: Learning rate for the Adam optimizer.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.learning_rate = learning_rate
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        # Neural network - simplified for stability
        self.policy_net = SimpleTradingNetwork(state_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
        # Experience replay; the oldest transitions are evicted beyond 500.
        self.memory = deque(maxlen=500)  # Smaller memory for stability
        self.batch_size = 16
        # Training parameters
        self.gamma = 0.99  # discount factor for bootstrapped targets
        self.epsilon = 1.0  # current exploration probability
        self.epsilon_min = 0.1  # floor for epsilon
        self.epsilon_decay = 0.995  # multiplicative decay applied per update()

    def _greedy_q(self, x):
        """Return Q-values for ``x`` with dropout disabled and no autograd.

        The network contains Dropout layers, so inference must run in eval
        mode to get deterministic Q-values. The network's previous
        training/eval mode is restored afterwards, even on error.
        """
        was_training = self.policy_net.training
        self.policy_net.eval()
        try:
            with torch.no_grad():
                return self.policy_net(x)
        finally:
            self.policy_net.train(was_training)

    def select_action(self, state):
        """Select action using epsilon-greedy policy.

        With probability ``self.epsilon`` a uniformly random action is
        returned; otherwise the argmax of the network's Q-values.

        Args:
            state: Array-like with ``.astype``; divided by 255, so presumably
                raw uint8 pixel data (TODO confirm against the environment).

        Returns:
            An int action index in ``[0, action_dim)``. If inference fails,
            the error is printed and a random action is returned instead.
        """
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        try:
            # Normalize state and add a batch dimension of 1.
            state_normalized = state.astype(np.float32) / 255.0
            state_tensor = torch.FloatTensor(state_normalized).unsqueeze(0).to(self.device)
            q_values = self._greedy_q(state_tensor)
            return q_values.argmax().item()
        except Exception as e:
            # Deliberate best-effort fallback: keep acting even if inference fails.
            print(f"Error in action selection: {e}")
            return random.randint(0, self.action_dim - 1)

    def store_transition(self, state, action, reward, next_state, done):
        """Store experience in replay memory"""
        self.memory.append((state, action, reward, next_state, done))

    def update(self):
        """Run one gradient step on a random minibatch from replay memory.

        Returns:
            The MSE loss as a float, or 0 when memory holds fewer than
            ``batch_size`` transitions or when the update raised (the error
            is printed). Epsilon is decayed after every successful step.
        """
        if len(self.memory) < self.batch_size:
            return 0
        try:
            # Sample batch from memory
            batch = random.sample(self.memory, self.batch_size)
            states, actions, rewards, next_states, dones = zip(*batch)
            # Convert to tensors with normalization
            states = torch.FloatTensor(np.array(states)).to(self.device) / 255.0
            actions = torch.LongTensor(actions).to(self.device)
            rewards = torch.FloatTensor(rewards).to(self.device)
            next_states = torch.FloatTensor(np.array(next_states)).to(self.device) / 255.0
            dones = torch.BoolTensor(dones).to(self.device)
            # Current Q values for the taken actions; squeeze(1) keeps a
            # (batch,) shape even when batch_size == 1.
            current_q = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
            # Bootstrap targets computed without dropout noise; terminal
            # transitions (done=True) contribute only their reward.
            next_q = self._greedy_q(next_states).max(1)[0]
            target_q = rewards + self.gamma * next_q * (~dones)
            # Compute loss
            loss = nn.MSELoss()(current_q, target_q)
            # Optimize
            self.optimizer.zero_grad()
            loss.backward()
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
            self.optimizer.step()
            # Decay epsilon
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
            return loss.item()
        except Exception as e:
            # Deliberate best-effort: a failed update is reported, not raised.
            print(f"Error in update: {e}")
            return 0
class SimpleTradingNetwork(nn.Module):
    """Small CNN mapping channels-last image observations to Q-values.

    Args:
        state_dim: Accepted for interface compatibility; not used to size the
            network (the adaptive pool fixes the flattened feature size).
        action_dim: Number of discrete actions, i.e. the output width.
        in_channels: Number of channels in the last input axis. Defaults to
            4, the value this network originally hard-coded.
    """

    def __init__(self, state_dim, action_dim, in_channels=4):
        super().__init__()
        # Simplified CNN for faster training
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=4, stride=2),  # 84x84 -> 41x41x16
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2),  # 41x41x16 -> 19x19x32
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1),  # 19x19x32 -> 17x17x32
            nn.ReLU(),
            # Always 8x8 output, so the linear head below does not depend on
            # the exact input resolution.
            nn.AdaptiveAvgPool2d((8, 8)),  # 17x17x32 -> 8x8x32
        )
        # Calculate flattened size
        self.flattened_size = 32 * 8 * 8
        # Fully connected layers
        self.fc_layers = nn.Sequential(
            nn.Linear(self.flattened_size, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, action_dim),
        )

    def forward(self, x):
        """Compute Q-values.

        Args:
            x: Channels-last tensor, either a batch ``(batch, H, W, C)`` or a
                single observation ``(H, W, C)``.

        Returns:
            Tensor of shape ``(batch, action_dim)`` (batch is 1 for a single
            observation).
        """
        if x.dim() == 3:  # single observation without a batch axis
            x = x.unsqueeze(0)
        # (batch, H, W, C) -> (batch, C, H, W) as expected by Conv2d
        x = x.permute(0, 3, 1, 2)
        x = self.conv_layers(x)
        # The permuted input is channels-last in memory, so the conv output
        # may be too; flatten (unlike view) handles any stride layout.
        x = torch.flatten(x, 1)
        return self.fc_layers(x)