OmidSakaki committed on
Commit
2965337
·
verified ·
1 Parent(s): ae2aacf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -13
app.py CHANGED
@@ -6,7 +6,7 @@ import torch.optim as optim
6
  from collections import deque
7
  import random
8
  from pathlib import Path
9
- from typing import Dict, Tuple, Any, List
10
  import plotly.graph_objects as go
11
  from plotly.subplots import make_subplots
12
  import yaml
@@ -71,21 +71,13 @@ class AdvancedTradingEnvironment:
71
  self._initialize_market_data()
72
 
73
  # Define action and observation spaces
74
- self.action_space = self._create_action_space()
75
- self.observation_space = self._create_observation_space()
76
 
77
  # Portfolio tracking
78
  self.portfolio_history = []
79
  self.action_history = []
80
 
81
- def _create_action_space(self) -> int:
82
- """Define available trading actions"""
83
- return 4 # 0: Hold, 1: Buy, 2: Sell, 3: Close Position
84
-
85
- def _create_observation_space(self) -> Tuple:
86
- """Define observation space dimensions"""
87
- return (15,) # Increased features for better state representation
88
-
89
  def _initialize_market_data(self):
90
  """Initialize synthetic market data based on asset type"""
91
  n_points = 200 # Longer history for better indicators
@@ -405,7 +397,7 @@ class DQNAgent:
405
  actions = torch.LongTensor(actions).to(self.device)
406
  rewards = torch.FloatTensor(rewards).to(self.device)
407
  next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
408
- dones = torch.BoolTensor(dones).to(self.device)
409
 
410
  # Current Q values
411
  current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
@@ -413,7 +405,8 @@ class DQNAgent:
413
  # Next Q values from target network
414
  with torch.no_grad():
415
  next_q_values = self.target_network(next_states).max(1)[0]
416
- target_q_values = rewards + self.gamma * next_q_values * (~dones)
 
417
 
418
  # Compute loss and update
419
  loss = self.criterion(current_q_values, target_q_values)
 
6
  from collections import deque
7
  import random
8
  from pathlib import Path
9
+ from typing import Dict, Tuple, Any, List, Optional
10
  import plotly.graph_objects as go
11
  from plotly.subplots import make_subplots
12
  import yaml
 
71
  self._initialize_market_data()
72
 
73
  # Define action and observation spaces
74
+ self.action_space = 4 # 0: Hold, 1: Buy, 2: Sell, 3: Close Position
75
+ self.observation_space = (15,) # Increased features for better state representation
76
 
77
  # Portfolio tracking
78
  self.portfolio_history = []
79
  self.action_history = []
80
 
 
 
 
 
 
 
 
 
81
  def _initialize_market_data(self):
82
  """Initialize synthetic market data based on asset type"""
83
  n_points = 200 # Longer history for better indicators
 
397
  actions = torch.LongTensor(actions).to(self.device)
398
  rewards = torch.FloatTensor(rewards).to(self.device)
399
  next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
400
+ dones = torch.BoolTensor(dones).to(self.device) # Fixed: Use BoolTensor instead of FloatTensor
401
 
402
  # Current Q values
403
  current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
 
405
  # Next Q values from target network
406
  with torch.no_grad():
407
  next_q_values = self.target_network(next_states).max(1)[0]
408
+ # Fixed: Use proper boolean masking
409
+ target_q_values = rewards + self.gamma * next_q_values * (~dones).float()
410
 
411
  # Compute loss and update
412
  loss = self.criterion(current_q_values, target_q_values)