OmidSakaki committed on
Commit
5a78d94
·
verified ·
1 Parent(s): 3457ff1

Update src/agents/visual_agent.py

Browse files
Files changed (1) hide show
  1. src/agents/visual_agent.py +50 -45
src/agents/visual_agent.py CHANGED
@@ -11,14 +11,14 @@ class VisualTradingAgent:
11
  self.action_dim = action_dim
12
  self.learning_rate = learning_rate
13
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
14
 
15
  # Neural network
16
  self.policy_net = TradingCNN(state_dim, action_dim).to(self.device)
17
- self.target_net = TradingCNN(state_dim, action_dim).to(self.device)
18
  self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
19
 
20
  # Experience replay
21
- self.memory = deque(maxlen=10000)
22
  self.batch_size = 32
23
 
24
  # Training parameters
@@ -26,18 +26,19 @@ class VisualTradingAgent:
26
  self.epsilon = 1.0
27
  self.epsilon_min = 0.01
28
  self.epsilon_decay = 0.995
29
- self.update_target_every = 1000
30
- self.steps_done = 0
31
 
32
  def select_action(self, state):
33
  """Select action using epsilon-greedy policy"""
34
  if random.random() < self.epsilon:
35
  return random.randint(0, self.action_dim - 1)
36
 
37
- state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
38
- with torch.no_grad():
39
- q_values = self.policy_net(state_tensor)
40
- return q_values.argmax().item()
 
 
 
41
 
42
  def store_transition(self, state, action, reward, next_state, done):
43
  """Store experience in replay memory"""
@@ -48,42 +49,42 @@ class VisualTradingAgent:
48
  if len(self.memory) < self.batch_size:
49
  return 0
50
 
51
- # Sample batch from memory
52
- batch = random.sample(self.memory, self.batch_size)
53
- states, actions, rewards, next_states, dones = zip(*batch)
54
-
55
- # Convert to tensors
56
- states = torch.FloatTensor(np.array(states)).to(self.device)
57
- actions = torch.LongTensor(actions).to(self.device)
58
- rewards = torch.FloatTensor(rewards).to(self.device)
59
- next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
60
- dones = torch.BoolTensor(dones).to(self.device)
61
-
62
- # Current Q values
63
- current_q = self.policy_net(states).gather(1, actions.unsqueeze(1))
64
-
65
- # Next Q values
66
- with torch.no_grad():
67
- next_q = self.target_net(next_states).max(1)[0]
68
- target_q = rewards + (self.gamma * next_q * ~dones)
69
-
70
- # Compute loss
71
- loss = nn.MSELoss()(current_q.squeeze(), target_q)
72
-
73
- # Optimize
74
- self.optimizer.zero_grad()
75
- loss.backward()
76
- self.optimizer.step()
77
-
78
- # Update target network
79
- self.steps_done += 1
80
- if self.steps_done % self.update_target_every == 0:
81
- self.target_net.load_state_dict(self.policy_net.state_dict())
82
-
83
- # Decay epsilon
84
- self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
85
-
86
- return loss.item()
87
 
88
  class TradingCNN(nn.Module):
89
  def __init__(self, state_dim, action_dim):
@@ -113,7 +114,11 @@ class TradingCNN(nn.Module):
113
 
114
  def forward(self, x):
115
  # x shape: (batch_size, 84, 84, 4) -> (batch_size, 4, 84, 84)
116
- x = x.permute(0, 3, 1, 2)
 
 
 
 
117
  x = self.conv_layers(x)
118
  x = x.view(x.size(0), -1)
119
  x = self.fc_layers(x)
 
11
  self.action_dim = action_dim
12
  self.learning_rate = learning_rate
13
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ print(f"Using device: {self.device}")
15
 
16
  # Neural network
17
  self.policy_net = TradingCNN(state_dim, action_dim).to(self.device)
 
18
  self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
19
 
20
  # Experience replay
21
+ self.memory = deque(maxlen=1000)
22
  self.batch_size = 32
23
 
24
  # Training parameters
 
26
  self.epsilon = 1.0
27
  self.epsilon_min = 0.01
28
  self.epsilon_decay = 0.995
 
 
29
 
30
  def select_action(self, state):
31
  """Select action using epsilon-greedy policy"""
32
  if random.random() < self.epsilon:
33
  return random.randint(0, self.action_dim - 1)
34
 
35
+ try:
36
+ state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
37
+ with torch.no_grad():
38
+ q_values = self.policy_net(state_tensor)
39
+ return q_values.argmax().item()
40
+ except:
41
+ return random.randint(0, self.action_dim - 1)
42
 
43
  def store_transition(self, state, action, reward, next_state, done):
44
  """Store experience in replay memory"""
 
49
  if len(self.memory) < self.batch_size:
50
  return 0
51
 
52
+ try:
53
+ # Sample batch from memory
54
+ batch = random.sample(self.memory, self.batch_size)
55
+ states, actions, rewards, next_states, dones = zip(*batch)
56
+
57
+ # Convert to tensors
58
+ states = torch.FloatTensor(np.array(states)).to(self.device)
59
+ actions = torch.LongTensor(actions).to(self.device)
60
+ rewards = torch.FloatTensor(rewards).to(self.device)
61
+ next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
62
+ dones = torch.BoolTensor(dones).to(self.device)
63
+
64
+ # Current Q values
65
+ current_q = self.policy_net(states).gather(1, actions.unsqueeze(1))
66
+
67
+ # Next Q values
68
+ with torch.no_grad():
69
+ next_q = self.policy_net(next_states).max(1)[0]
70
+ target_q = rewards + (self.gamma * next_q * ~dones)
71
+
72
+ # Compute loss
73
+ loss = nn.MSELoss()(current_q.squeeze(), target_q)
74
+
75
+ # Optimize
76
+ self.optimizer.zero_grad()
77
+ loss.backward()
78
+ self.optimizer.step()
79
+
80
+ # Decay epsilon
81
+ self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
82
+
83
+ return loss.item()
84
+
85
+ except Exception as e:
86
+ print(f"Error in update: {e}")
87
+ return 0
88
 
89
  class TradingCNN(nn.Module):
90
  def __init__(self, state_dim, action_dim):
 
114
 
115
  def forward(self, x):
116
  # x shape: (batch_size, 84, 84, 4) -> (batch_size, 4, 84, 84)
117
+ if len(x.shape) == 4: # Single observation
118
+ x = x.permute(0, 3, 1, 2)
119
+ else: # Batch of observations
120
+ x = x.permute(0, 3, 1, 2)
121
+
122
  x = self.conv_layers(x)
123
  x = x.view(x.size(0), -1)
124
  x = self.fc_layers(x)