Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import numpy as np | |
| from collections import deque | |
| import random | |
class VisualTradingAgent:
    """Epsilon-greedy Q-learning agent driven by a single convolutional Q-network.

    Transitions are kept in a bounded replay buffer and the network is trained
    on uniformly sampled mini-batches with a one-step TD target. States are
    divided by 255 before being fed to the network (assumes 0-255 image-like
    observations -- TODO confirm against the environment).

    Args:
        state_dim: Observation shape descriptor; stored and forwarded to the
            network constructor.
        action_dim: Number of discrete actions.
        learning_rate: Step size for the Adam optimizer.
    """

    def __init__(self, state_dim, action_dim, learning_rate=0.001):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.learning_rate = learning_rate
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Q-network and its optimizer.
        self.policy_net = SimpleTradingNetwork(state_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)

        # Experience replay: small buffer, small batches.
        self.memory = deque(maxlen=500)
        self.batch_size = 16

        # Discounting and exploration schedule.
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.1
        self.epsilon_decay = 0.995

    def _predict_q(self, states):
        """Return Q-values for ``states`` without gradients or dropout noise.

        The network is switched to eval mode for the forward pass and its
        previous train/eval mode is restored afterwards, even on error.
        """
        was_training = self.policy_net.training
        self.policy_net.eval()
        try:
            with torch.no_grad():
                return self.policy_net(states)
        finally:
            self.policy_net.train(was_training)

    def select_action(self, state):
        """Select an action index using an epsilon-greedy policy.

        With probability ``epsilon`` a uniformly random action is returned.
        Otherwise the action with the highest predicted Q-value is returned.
        If the greedy path raises, the error is printed and a random action is
        returned instead (best-effort behaviour).

        Args:
            state: A single observation (array-like); scaled by 1/255.

        Returns:
            int: Action index in ``[0, action_dim)``.
        """
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)

        try:
            # np.asarray accepts lists as well as ndarrays.
            state_normalized = np.asarray(state, dtype=np.float32) / 255.0
            state_tensor = torch.as_tensor(
                state_normalized, dtype=torch.float32, device=self.device
            ).unsqueeze(0)
            # Dropout must be off here, otherwise the "greedy" choice is noisy.
            q_values = self._predict_q(state_tensor)
            return q_values.argmax().item()
        except Exception as e:
            print(f"Error in action selection: {e}")
            return random.randint(0, self.action_dim - 1)

    def store_transition(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` tuple to replay memory."""
        self.memory.append((state, action, reward, next_state, done))

    def update(self):
        """Run one gradient step on a random mini-batch from replay memory.

        Epsilon is decayed once per successful update.

        Returns:
            The scalar MSE loss of the step, or ``0`` when the memory holds
            fewer than ``batch_size`` transitions or the update raised (the
            error is printed in that case).
        """
        if len(self.memory) < self.batch_size:
            return 0

        try:
            batch = random.sample(self.memory, self.batch_size)
            states, actions, rewards, next_states, dones = zip(*batch)

            # Convert to tensors; image-like states are scaled by 1/255.
            states = torch.as_tensor(
                np.asarray(states, dtype=np.float32), device=self.device
            ) / 255.0
            next_states = torch.as_tensor(
                np.asarray(next_states, dtype=np.float32), device=self.device
            ) / 255.0
            actions = torch.as_tensor(
                np.asarray(actions), dtype=torch.long, device=self.device
            )
            rewards = torch.as_tensor(
                np.asarray(rewards, dtype=np.float32), device=self.device
            )
            dones = torch.as_tensor(
                np.asarray(dones, dtype=np.bool_), device=self.device
            )

            # One-step TD target, evaluated without dropout or gradients.
            # Terminal transitions contribute only their reward.
            next_q = self._predict_q(next_states).max(1)[0]
            not_done = (~dones).to(next_q.dtype)
            target_q = rewards + self.gamma * next_q * not_done

            # Q(s, a) for the actions actually taken. squeeze(1) keeps the
            # batch axis even when batch_size == 1.
            current_q = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

            loss = nn.MSELoss()(current_q, target_q)

            self.optimizer.zero_grad()
            loss.backward()
            # Gradient clipping for stability.
            torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
            self.optimizer.step()

            # Decay exploration rate, never below epsilon_min.
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
            return loss.item()
        except Exception as e:
            print(f"Error in update: {e}")
            return 0
class SimpleTradingNetwork(nn.Module):
    """Small CNN mapping channels-last image observations to per-action Q-values.

    Three convolutions are followed by adaptive average pooling to a fixed
    8x8 grid, so the fully connected head always receives ``32 * 8 * 8``
    features regardless of the exact spatial size of the input.

    Args:
        state_dim: Accepted for interface compatibility; the layers do not
            read it.
        action_dim: Number of output Q-values.
        in_channels: Number of channels in the last axis of the input
            (defaults to 4, the previously hard-coded value).
    """

    def __init__(self, state_dim, action_dim, in_channels=4):
        super().__init__()

        # Spatial sizes in the comments are for an 84x84 input.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=4, stride=2),  # 84 -> 41
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2),  # 41 -> 19
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1),  # 19 -> 17
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 8)),  # 17 -> 8
        )

        # 32 channels on the pooled 8x8 grid.
        self.flattened_size = 32 * 8 * 8

        self.fc_layers = nn.Sequential(
            nn.Linear(self.flattened_size, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, action_dim),
        )

    def forward(self, x):
        """Compute Q-values for a batch of channels-last observations.

        Args:
            x: Tensor of shape ``(batch, H, W, C)``, or a single observation
                of shape ``(H, W, C)``, which is given a batch axis of size 1.

        Returns:
            Tensor of shape ``(batch, action_dim)``.
        """
        if x.dim() == 3:
            # Single unbatched observation: add the batch axis.
            x = x.unsqueeze(0)
        # Channels-last -> channels-first, as Conv2d expects.
        x = x.permute(0, 3, 1, 2)
        x = self.conv_layers(x)
        x = x.flatten(start_dim=1)
        return self.fc_layers(x)