# NOTE: the original paste carried a non-code header ("Spaces:" /
# "Runtime error") from the hosting page; preserved here as a comment
# so the file parses.
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import numpy as np | |
| from collections import deque | |
| import random | |
class EnhancedTradingNetwork(nn.Module):
    """DQN head combining a CNN over stacked frames with a sentiment MLP.

    Inputs:
        x: channels-last image batch (batch, H, W, 4) or a single (H, W, 4)
           sample; permuted internally to channels-first for Conv2d.
           assumes 4 stacked frames in the last axis — TODO confirm caller.
        sentiment: optional (batch, sentiment_dim) or (sentiment_dim,)
           tensor of [sentiment, confidence]-style features.

    Output: (batch, action_dim) Q-values.
    """

    # Width of the sentiment branch output; the combined head is sized
    # against this constant (256 visual + 32 sentiment features).
    SENTIMENT_FEATURES = 32

    def __init__(self, state_dim, action_dim, sentiment_dim=2):
        super(EnhancedTradingNetwork, self).__init__()
        # Visual processing branch: three conv stages, then an adaptive
        # pool so any input resolution collapses to a fixed 8x8 grid.
        self.visual_conv = nn.Sequential(
            nn.Conv2d(4, 16, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 8)),
        )
        # Flattened size after the adaptive pool: 32 channels * 8 * 8.
        self.conv_output_size = 32 * 8 * 8
        self.visual_fc = nn.Sequential(
            nn.Linear(self.conv_output_size, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
        )
        # Sentiment processing branch.
        self.sentiment_fc = nn.Sequential(
            nn.Linear(sentiment_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, self.SENTIMENT_FEATURES),
            nn.ReLU(),
        )
        # Combined decision head mapping fused features to Q-values.
        self.combined_fc = nn.Sequential(
            nn.Linear(256 + self.SENTIMENT_FEATURES, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim),
        )

    def forward(self, x, sentiment=None):
        try:
            # Accept channels-last input; Conv2d needs (batch, C, H, W).
            if x.dim() == 3:  # single (H, W, C) sample -> add batch dim
                x = x.unsqueeze(0)
            x = x.permute(0, 3, 1, 2).contiguous()
            visual_features = self.visual_conv(x)
            batch_size = visual_features.size(0)
            # reshape (not view) tolerates non-contiguous layouts.
            visual_features = visual_features.reshape(batch_size, -1)
            visual_features = self.visual_fc(visual_features)
            if sentiment is not None:
                if sentiment.dim() == 1:
                    sentiment = sentiment.unsqueeze(0)
                sentiment_features = self.sentiment_fc(sentiment)
            else:
                # BUG FIX: the combined head expects 256 + 32 inputs, so the
                # original no-sentiment path raised a matmul shape error that
                # the except below silently turned into all-zero Q-values.
                # Substitute a neutral (all-zero) sentiment embedding instead;
                # new_zeros keeps device/dtype aligned with the visual branch.
                sentiment_features = visual_features.new_zeros(
                    batch_size, self.SENTIMENT_FEATURES)
            combined_features = torch.cat(
                [visual_features, sentiment_features], dim=1)
            return self.combined_fc(combined_features)
        except Exception as e:
            # Best-effort fallback kept from the original design.
            print(f"Error in network forward: {e}")
            # BUG FIX: allocate the fallback on the input's device so CUDA
            # callers do not receive a CPU tensor.
            is_tensor = isinstance(x, torch.Tensor)
            rows = x.size(0) if is_tensor and x.dim() >= 1 else 1
            device = x.device if is_tensor else None
            return torch.zeros(
                (rows, self.combined_fc[-1].out_features), device=device)
class AdvancedTradingAgent:
    """Sentiment-aware DQN agent wrapping EnhancedTradingNetwork.

    Holds an epsilon-greedy policy, a bounded replay buffer, and a single
    online network used both for action values and (under no_grad) for
    bootstrap targets.
    """

    def __init__(self, state_dim, action_dim, learning_rate=0.001, use_sentiment=True):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.learning_rate = learning_rate
        self.use_sentiment = use_sentiment
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        # NOTE(review): no separate target network — the online net also
        # produces the bootstrap targets, which can destabilise training.
        self.policy_net = EnhancedTradingNetwork(state_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
        # Experience replay buffer; oldest transitions evicted FIFO.
        self.memory = deque(maxlen=500)
        self.batch_size = 16
        # Training hyper-parameters.
        self.gamma = 0.99          # discount factor
        self.epsilon = 1.0         # exploration rate, decayed per update()
        self.epsilon_min = 0.1
        self.epsilon_decay = 0.995
        self.steps_done = 0

    def select_action(self, state, current_sentiment=0.5, sentiment_confidence=0.0):
        """Epsilon-greedy action selection, optionally sentiment-conditioned.

        `state` is a numpy array scaled by /255 here — presumably uint8
        frames; verify against the environment. Falls back to a random
        action on any failure.
        """
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        try:
            state_normalized = state.astype(np.float32) / 255.0
            state_tensor = torch.FloatTensor(state_normalized).to(self.device)
            with torch.no_grad():
                if self.use_sentiment:
                    sentiment_tensor = torch.FloatTensor(
                        [current_sentiment, sentiment_confidence]).to(self.device)
                    q_values = self.policy_net(state_tensor, sentiment_tensor)
                else:
                    q_values = self.policy_net(state_tensor)
            return int(q_values.argmax().item())
        except Exception as e:
            print(f"Error in advanced action selection: {e}")
            return random.randint(0, self.action_dim - 1)

    def store_transition(self, state, action, reward, next_state, done, sentiment_data=None):
        """Store experience with sentiment data"""
        try:
            self.memory.append((state, action, reward, next_state, done, sentiment_data))
        except Exception as e:
            print(f"Error storing transition: {e}")

    def update(self):
        """Run one DQN optimisation step over a random minibatch.

        Returns the scalar loss, or 0.0 when the buffer is still warming
        up or an error occurred.
        """
        if len(self.memory) < self.batch_size:
            return 0.0
        try:
            batch = random.sample(self.memory, self.batch_size)
            states, actions, rewards, next_states, dones, sentiment_data = zip(*batch)
            # Normalise frames to [0, 1] before tensor conversion.
            states_tensor = torch.FloatTensor(
                np.array(states, dtype=np.float32) / 255.0).to(self.device)
            next_states_tensor = torch.FloatTensor(
                np.array(next_states, dtype=np.float32) / 255.0).to(self.device)
            actions_tensor = torch.LongTensor(actions).to(self.device)
            rewards_tensor = torch.FloatTensor(rewards).to(self.device)
            dones_tensor = torch.BoolTensor(dones).to(self.device)

            use_sent = self.use_sentiment and sentiment_data[0] is not None
            if use_sent:
                # Missing/malformed entries fall back to a neutral reading.
                sentiment_features = [
                    [d['sentiment'], d['confidence']]
                    if d and 'sentiment' in d and 'confidence' in d
                    else [0.5, 0.0]
                    for d in sentiment_data
                ]
                sentiment_tensor = torch.FloatTensor(sentiment_features).to(self.device)
                current_q = self.policy_net(states_tensor, sentiment_tensor)
                with torch.no_grad():
                    # NOTE(review): the *current* sentiment is reused for the
                    # next states — next-step sentiment is not stored.
                    next_q = self.policy_net(next_states_tensor, sentiment_tensor).max(1)[0]
            else:
                current_q = self.policy_net(states_tensor)
                with torch.no_grad():
                    next_q = self.policy_net(next_states_tensor).max(1)[0]

            # BUG FIX: squeeze(1) instead of a bare squeeze() — squeeze()
            # would collapse a size-1 batch to a 0-d tensor and break the
            # loss shape.
            current_q = current_q.gather(1, actions_tensor.unsqueeze(1)).squeeze(1)
            # Terminal transitions bootstrap from zero (~done masks next_q).
            target_q = rewards_tensor + self.gamma * next_q * (~dones_tensor)

            loss = nn.MSELoss()(current_q, target_q)
            self.optimizer.zero_grad()
            loss.backward()
            # Gradient clipping for stability.
            torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
            self.optimizer.step()
            # Decay exploration once per successful update.
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
            self.steps_done += 1
            return float(loss.item())
        except Exception as e:
            print(f"Error in advanced update: {e}")
            return 0.0
| # Fallback to simple agent if advanced one fails | |
class SimpleTradingNetwork(nn.Module):
    """Minimal CNN Q-network fallback: maps a channels-last frame stack
    (batch, H, W, 4) — or a single (H, W, 4) sample — to (batch, action_dim)
    Q-values."""

    def __init__(self, state_dim, action_dim):
        super(SimpleTradingNetwork, self).__init__()
        # state_dim is accepted for interface parity; the conv stack is fixed
        # to 4 input channels, and the adaptive pool makes the FC input size
        # independent of frame resolution.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(4, 16, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 8)),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(32 * 8 * 8, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim),
        )

    def forward(self, x):
        try:
            # Accept (batch, H, W, C) or a single (H, W, C) sample; Conv2d
            # needs channels-first.
            if x.dim() == 3:
                x = x.unsqueeze(0)
            x = x.permute(0, 3, 1, 2).contiguous()
            x = self.conv_layers(x)
            # reshape (not view) tolerates non-contiguous layouts.
            x = x.reshape(x.size(0), -1)
            return self.fc_layers(x)
        except Exception as e:
            print(f"Error in simple network: {e}")
            # BUG FIX: build the fallback on the input's device so a CUDA
            # caller doesn't get a CPU tensor, and guard x.size(0) so the
            # handler itself cannot raise on a non-tensor input.
            is_tensor = isinstance(x, torch.Tensor)
            rows = x.size(0) if is_tensor and x.dim() >= 1 else 1
            device = x.device if is_tensor else None
            return torch.zeros(
                (rows, self.fc_layers[-1].out_features), device=device)