OmidSakaki committed on
Commit
de24124
·
verified ·
1 Parent(s): 8f54cbf

Update src/agents/visual_agent.py

Browse files
Files changed (1) hide show
  1. src/agents/visual_agent.py +74 -56
src/agents/visual_agent.py CHANGED
@@ -5,6 +5,52 @@ import numpy as np
5
  from collections import deque
6
  import random
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  class VisualTradingAgent:
9
  def __init__(self, state_dim, action_dim, learning_rate=0.001):
10
  self.state_dim = state_dim
@@ -13,7 +59,7 @@ class VisualTradingAgent:
13
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  print(f"Using device: {self.device}")
15
 
16
- # Neural network - simplified for stability
17
  self.policy_net = SimpleTradingNetwork(state_dim, action_dim).to(self.device)
18
  self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
19
 
@@ -26,6 +72,8 @@ class VisualTradingAgent:
26
  self.epsilon = 1.0
27
  self.epsilon_min = 0.1
28
  self.epsilon_decay = 0.995
 
 
29
 
30
  def select_action(self, state):
31
  """Select action using epsilon-greedy policy"""
@@ -39,19 +87,22 @@ class VisualTradingAgent:
39
 
40
  with torch.no_grad():
41
  q_values = self.policy_net(state_tensor)
42
- return q_values.argmax().item()
43
  except Exception as e:
44
  print(f"Error in action selection: {e}")
45
  return random.randint(0, self.action_dim - 1)
46
 
47
  def store_transition(self, state, action, reward, next_state, done):
48
  """Store experience in replay memory"""
49
- self.memory.append((state, action, reward, next_state, done))
 
 
 
50
 
51
  def update(self):
52
  """Update the neural network"""
53
  if len(self.memory) < self.batch_size:
54
- return 0
55
 
56
  try:
57
  # Sample batch from memory
@@ -59,19 +110,22 @@ class VisualTradingAgent:
59
  states, actions, rewards, next_states, dones = zip(*batch)
60
 
61
  # Convert to tensors with normalization
62
- states = torch.FloatTensor(np.array(states)).to(self.device) / 255.0
63
- actions = torch.LongTensor(actions).to(self.device)
64
- rewards = torch.FloatTensor(rewards).to(self.device)
65
- next_states = torch.FloatTensor(np.array(next_states)).to(self.device) / 255.0
66
- dones = torch.BoolTensor(dones).to(self.device)
 
 
 
67
 
68
  # Current Q values
69
- current_q = self.policy_net(states).gather(1, actions.unsqueeze(1))
70
 
71
  # Next Q values
72
  with torch.no_grad():
73
- next_q = self.policy_net(next_states).max(1)[0]
74
- target_q = rewards + (self.gamma * next_q * ~dones)
75
 
76
  # Compute loss
77
  loss = nn.MSELoss()(current_q.squeeze(), target_q)
@@ -84,52 +138,16 @@ class VisualTradingAgent:
84
  torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
85
  self.optimizer.step()
86
 
87
- # Decay epsilon
 
 
 
 
 
88
  self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
89
 
90
- return loss.item()
91
 
92
  except Exception as e:
93
  print(f"Error in update: {e}")
94
- return 0
95
-
96
- class SimpleTradingNetwork(nn.Module):
97
- def __init__(self, state_dim, action_dim):
98
- super(SimpleTradingNetwork, self).__init__()
99
-
100
- # Simplified CNN for faster training
101
- self.conv_layers = nn.Sequential(
102
- nn.Conv2d(4, 16, kernel_size=4, stride=2), # Input: 84x84x4
103
- nn.ReLU(),
104
- nn.Conv2d(16, 32, kernel_size=4, stride=2), # 41x41x16 -> 19x19x32
105
- nn.ReLU(),
106
- nn.Conv2d(32, 32, kernel_size=3, stride=1), # 19x19x32 -> 17x17x32
107
- nn.ReLU(),
108
- nn.AdaptiveAvgPool2d((8, 8)) # 17x17x32 -> 8x8x32
109
- )
110
-
111
- # Calculate flattened size
112
- self.flattened_size = 32 * 8 * 8
113
-
114
- # Fully connected layers
115
- self.fc_layers = nn.Sequential(
116
- nn.Linear(self.flattened_size, 128),
117
- nn.ReLU(),
118
- nn.Dropout(0.2),
119
- nn.Linear(128, 64),
120
- nn.ReLU(),
121
- nn.Dropout(0.2),
122
- nn.Linear(64, action_dim)
123
- )
124
-
125
- def forward(self, x):
126
- # x shape: (batch_size, 84, 84, 4) -> (batch_size, 4, 84, 84)
127
- if len(x.shape) == 4: # Single observation
128
- x = x.permute(0, 3, 1, 2)
129
- else: # Batch of observations
130
- x = x.permute(0, 3, 1, 2)
131
-
132
- x = self.conv_layers(x)
133
- x = x.view(x.size(0), -1)
134
- x = self.fc_layers(x)
135
- return x
 
5
  from collections import deque
6
  import random
7
 
8
class SimpleTradingNetwork(nn.Module):
    """Small CNN Q-network mapping stacked frames to per-action Q-values.

    Input is expected in NHWC layout -- presumably (batch, 84, 84, 4) frame
    stacks, judging by the conv comments; TODO confirm against the caller.

    Args:
        state_dim: unused by the layers; kept for interface compatibility.
        action_dim: number of discrete actions (size of the output layer).
    """

    def __init__(self, state_dim, action_dim):
        super(SimpleTradingNetwork, self).__init__()

        # Simplified CNN for faster training.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(4, 16, kernel_size=4, stride=2),   # 84x84x4  -> 41x41x16
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2),  # 41x41x16 -> 19x19x32
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1),  # 19x19x32 -> 17x17x32
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 8)),                # 17x17x32 -> 8x8x32
        )

        # Adaptive pooling fixes the spatial size, so the flattened width is
        # constant regardless of the exact input resolution.
        self.flattened_size = 32 * 8 * 8

        # Fully connected head producing one Q-value per action.
        self.fc_layers = nn.Sequential(
            nn.Linear(self.flattened_size, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, action_dim),
        )

    def forward(self, x):
        """Return Q-values of shape (batch, action_dim) for NHWC input `x`.

        Accepts either a batch (B, H, W, 4) or a single unbatched (H, W, 4)
        observation; the latter is promoted to a batch of one.
        """
        try:
            # BUGFIX: the original branched on len(x.shape) == 4 vs else, but
            # both branches applied the identical permute, so an unbatched 3-D
            # observation always errored into the fallback. Promote it instead.
            if x.dim() == 3:
                x = x.unsqueeze(0)

            x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW for Conv2d
            x = self.conv_layers(x)
            x = x.view(x.size(0), -1)
            return self.fc_layers(x)
        except Exception as e:
            # Best-effort fallback kept from the original, but the zeros are
            # now created on the input's device/dtype -- the original returned
            # CPU float32 zeros, which raises a device-mismatch error
            # downstream whenever the model runs on CUDA.
            print(f"Error in network forward: {e}")
            batch = x.size(0) if x.dim() >= 1 else 1
            return torch.zeros(
                (batch, self.fc_layers[-1].out_features),
                device=x.device,
                dtype=x.dtype,
            )
53
+
54
  class VisualTradingAgent:
55
  def __init__(self, state_dim, action_dim, learning_rate=0.001):
56
  self.state_dim = state_dim
 
59
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60
  print(f"Using device: {self.device}")
61
 
62
+ # Neural network
63
  self.policy_net = SimpleTradingNetwork(state_dim, action_dim).to(self.device)
64
  self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
65
 
 
72
  self.epsilon = 1.0
73
  self.epsilon_min = 0.1
74
  self.epsilon_decay = 0.995
75
+ self.update_target_every = 100
76
+ self.steps_done = 0
77
 
78
  def select_action(self, state):
79
  """Select action using epsilon-greedy policy"""
 
87
 
88
  with torch.no_grad():
89
  q_values = self.policy_net(state_tensor)
90
+ return int(q_values.argmax().item())
91
  except Exception as e:
92
  print(f"Error in action selection: {e}")
93
  return random.randint(0, self.action_dim - 1)
94
 
95
  def store_transition(self, state, action, reward, next_state, done):
96
  """Store experience in replay memory"""
97
+ try:
98
+ self.memory.append((state, action, reward, next_state, done))
99
+ except Exception as e:
100
+ print(f"Error storing transition: {e}")
101
 
102
  def update(self):
103
  """Update the neural network"""
104
  if len(self.memory) < self.batch_size:
105
+ return 0.0
106
 
107
  try:
108
  # Sample batch from memory
 
110
  states, actions, rewards, next_states, dones = zip(*batch)
111
 
112
  # Convert to tensors with normalization
113
+ states_array = np.array(states, dtype=np.float32) / 255.0
114
+ next_states_array = np.array(next_states, dtype=np.float32) / 255.0
115
+
116
+ states_tensor = torch.FloatTensor(states_array).to(self.device)
117
+ actions_tensor = torch.LongTensor(actions).to(self.device)
118
+ rewards_tensor = torch.FloatTensor(rewards).to(self.device)
119
+ next_states_tensor = torch.FloatTensor(next_states_array).to(self.device)
120
+ dones_tensor = torch.BoolTensor(dones).to(self.device)
121
 
122
  # Current Q values
123
+ current_q = self.policy_net(states_tensor).gather(1, actions_tensor.unsqueeze(1))
124
 
125
  # Next Q values
126
  with torch.no_grad():
127
+ next_q = self.policy_net(next_states_tensor).max(1)[0]
128
+ target_q = rewards_tensor + (self.gamma * next_q * ~dones_tensor)
129
 
130
  # Compute loss
131
  loss = nn.MSELoss()(current_q.squeeze(), target_q)
 
138
  torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
139
  self.optimizer.step()
140
 
141
+ # Update steps and decay epsilon
142
+ self.steps_done += 1
143
+ if self.steps_done % self.update_target_every == 0:
144
+ # For simplicity, we're using the same network
145
+ pass
146
+
147
  self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
148
 
149
+ return float(loss.item())
150
 
151
  except Exception as e:
152
  print(f"Error in update: {e}")
153
+ return 0.0