OmidSakaki commited on
Commit
1f5a715
·
verified ·
1 Parent(s): 208b262

Update src/agents/visual_agent.py

Browse files
Files changed (1) hide show
  1. src/agents/visual_agent.py +30 -20
src/agents/visual_agent.py CHANGED
@@ -13,18 +13,18 @@ class VisualTradingAgent:
13
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  print(f"Using device: {self.device}")
15
 
16
- # Neural network
17
- self.policy_net = TradingCNN(state_dim, action_dim).to(self.device)
18
  self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
19
 
20
  # Experience replay
21
- self.memory = deque(maxlen=1000)
22
- self.batch_size = 32
23
 
24
  # Training parameters
25
  self.gamma = 0.99
26
  self.epsilon = 1.0
27
- self.epsilon_min = 0.01
28
  self.epsilon_decay = 0.995
29
 
30
  def select_action(self, state):
@@ -33,11 +33,15 @@ class VisualTradingAgent:
33
  return random.randint(0, self.action_dim - 1)
34
 
35
  try:
36
- state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
 
 
 
37
  with torch.no_grad():
38
  q_values = self.policy_net(state_tensor)
39
  return q_values.argmax().item()
40
- except:
 
41
  return random.randint(0, self.action_dim - 1)
42
 
43
  def store_transition(self, state, action, reward, next_state, done):
@@ -54,11 +58,11 @@ class VisualTradingAgent:
54
  batch = random.sample(self.memory, self.batch_size)
55
  states, actions, rewards, next_states, dones = zip(*batch)
56
 
57
- # Convert to tensors
58
- states = torch.FloatTensor(np.array(states)).to(self.device)
59
  actions = torch.LongTensor(actions).to(self.device)
60
  rewards = torch.FloatTensor(rewards).to(self.device)
61
- next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
62
  dones = torch.BoolTensor(dones).to(self.device)
63
 
64
  # Current Q values
@@ -75,6 +79,9 @@ class VisualTradingAgent:
75
  # Optimize
76
  self.optimizer.zero_grad()
77
  loss.backward()
 
 
 
78
  self.optimizer.step()
79
 
80
  # Decay epsilon
@@ -86,30 +93,33 @@ class VisualTradingAgent:
86
  print(f"Error in update: {e}")
87
  return 0
88
 
89
- class TradingCNN(nn.Module):
90
  def __init__(self, state_dim, action_dim):
91
- super(TradingCNN, self).__init__()
92
 
93
- # CNN for visual processing
94
  self.conv_layers = nn.Sequential(
95
- nn.Conv2d(4, 32, kernel_size=8, stride=4),
96
  nn.ReLU(),
97
- nn.Conv2d(32, 64, kernel_size=4, stride=2),
98
  nn.ReLU(),
99
- nn.Conv2d(64, 64, kernel_size=3, stride=1),
100
  nn.ReLU(),
101
- nn.AdaptiveAvgPool2d((6, 6))
102
  )
103
 
 
 
 
104
  # Fully connected layers
105
  self.fc_layers = nn.Sequential(
106
- nn.Linear(64 * 6 * 6, 512),
107
  nn.ReLU(),
108
  nn.Dropout(0.2),
109
- nn.Linear(512, 256),
110
  nn.ReLU(),
111
  nn.Dropout(0.2),
112
- nn.Linear(256, action_dim)
113
  )
114
 
115
  def forward(self, x):
 
13
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  print(f"Using device: {self.device}")
15
 
16
+ # Neural network - simplified for stability
17
+ self.policy_net = SimpleTradingNetwork(state_dim, action_dim).to(self.device)
18
  self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
19
 
20
  # Experience replay
21
+ self.memory = deque(maxlen=500) # Smaller memory for stability
22
+ self.batch_size = 16
23
 
24
  # Training parameters
25
  self.gamma = 0.99
26
  self.epsilon = 1.0
27
+ self.epsilon_min = 0.1
28
  self.epsilon_decay = 0.995
29
 
30
  def select_action(self, state):
 
33
  return random.randint(0, self.action_dim - 1)
34
 
35
  try:
36
+ # Normalize state and convert to tensor
37
+ state_normalized = state.astype(np.float32) / 255.0
38
+ state_tensor = torch.FloatTensor(state_normalized).unsqueeze(0).to(self.device)
39
+
40
  with torch.no_grad():
41
  q_values = self.policy_net(state_tensor)
42
  return q_values.argmax().item()
43
+ except Exception as e:
44
+ print(f"Error in action selection: {e}")
45
  return random.randint(0, self.action_dim - 1)
46
 
47
  def store_transition(self, state, action, reward, next_state, done):
 
58
  batch = random.sample(self.memory, self.batch_size)
59
  states, actions, rewards, next_states, dones = zip(*batch)
60
 
61
+ # Convert to tensors with normalization
62
+ states = torch.FloatTensor(np.array(states)).to(self.device) / 255.0
63
  actions = torch.LongTensor(actions).to(self.device)
64
  rewards = torch.FloatTensor(rewards).to(self.device)
65
+ next_states = torch.FloatTensor(np.array(next_states)).to(self.device) / 255.0
66
  dones = torch.BoolTensor(dones).to(self.device)
67
 
68
  # Current Q values
 
79
  # Optimize
80
  self.optimizer.zero_grad()
81
  loss.backward()
82
+
83
+ # Gradient clipping for stability
84
+ torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
85
  self.optimizer.step()
86
 
87
  # Decay epsilon
 
93
  print(f"Error in update: {e}")
94
  return 0
95
 
96
+ class SimpleTradingNetwork(nn.Module):
97
  def __init__(self, state_dim, action_dim):
98
+ super(SimpleTradingNetwork, self).__init__()
99
 
100
+ # Simplified CNN for faster training
101
  self.conv_layers = nn.Sequential(
102
+ nn.Conv2d(4, 16, kernel_size=4, stride=2), # Input: 84x84x4
103
  nn.ReLU(),
104
+ nn.Conv2d(16, 32, kernel_size=4, stride=2), # 41x41x16 -> 19x19x32
105
  nn.ReLU(),
106
+ nn.Conv2d(32, 32, kernel_size=3, stride=1), # 19x19x32 -> 17x17x32
107
  nn.ReLU(),
108
+ nn.AdaptiveAvgPool2d((8, 8)) # 17x17x32 -> 8x8x32
109
  )
110
 
111
+ # Calculate flattened size
112
+ self.flattened_size = 32 * 8 * 8
113
+
114
  # Fully connected layers
115
  self.fc_layers = nn.Sequential(
116
+ nn.Linear(self.flattened_size, 128),
117
  nn.ReLU(),
118
  nn.Dropout(0.2),
119
+ nn.Linear(128, 64),
120
  nn.ReLU(),
121
  nn.Dropout(0.2),
122
+ nn.Linear(64, action_dim)
123
  )
124
 
125
  def forward(self, x):