Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,7 +6,7 @@ import torch.optim as optim
|
|
| 6 |
from collections import deque
|
| 7 |
import random
|
| 8 |
from pathlib import Path
|
| 9 |
-
from typing import Dict, Tuple, Any, List
|
| 10 |
import plotly.graph_objects as go
|
| 11 |
from plotly.subplots import make_subplots
|
| 12 |
import yaml
|
|
@@ -71,21 +71,13 @@ class AdvancedTradingEnvironment:
|
|
| 71 |
self._initialize_market_data()
|
| 72 |
|
| 73 |
# Define action and observation spaces
|
| 74 |
-
self.action_space =
|
| 75 |
-
self.observation_space =
|
| 76 |
|
| 77 |
# Portfolio tracking
|
| 78 |
self.portfolio_history = []
|
| 79 |
self.action_history = []
|
| 80 |
|
| 81 |
-
def _create_action_space(self) -> int:
|
| 82 |
-
"""Define available trading actions"""
|
| 83 |
-
return 4 # 0: Hold, 1: Buy, 2: Sell, 3: Close Position
|
| 84 |
-
|
| 85 |
-
def _create_observation_space(self) -> Tuple:
|
| 86 |
-
"""Define observation space dimensions"""
|
| 87 |
-
return (15,) # Increased features for better state representation
|
| 88 |
-
|
| 89 |
def _initialize_market_data(self):
|
| 90 |
"""Initialize synthetic market data based on asset type"""
|
| 91 |
n_points = 200 # Longer history for better indicators
|
|
@@ -405,7 +397,7 @@ class DQNAgent:
|
|
| 405 |
actions = torch.LongTensor(actions).to(self.device)
|
| 406 |
rewards = torch.FloatTensor(rewards).to(self.device)
|
| 407 |
next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
|
| 408 |
-
dones = torch.BoolTensor(dones).to(self.device)
|
| 409 |
|
| 410 |
# Current Q values
|
| 411 |
current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
|
|
@@ -413,7 +405,8 @@ class DQNAgent:
|
|
| 413 |
# Next Q values from target network
|
| 414 |
with torch.no_grad():
|
| 415 |
next_q_values = self.target_network(next_states).max(1)[0]
|
| 416 |
-
|
|
|
|
| 417 |
|
| 418 |
# Compute loss and update
|
| 419 |
loss = self.criterion(current_q_values, target_q_values)
|
|
|
|
| 6 |
from collections import deque
|
| 7 |
import random
|
| 8 |
from pathlib import Path
|
| 9 |
+
from typing import Dict, Tuple, Any, List, Optional
|
| 10 |
import plotly.graph_objects as go
|
| 11 |
from plotly.subplots import make_subplots
|
| 12 |
import yaml
|
|
|
|
| 71 |
self._initialize_market_data()
|
| 72 |
|
| 73 |
# Define action and observation spaces
|
| 74 |
+
self.action_space = 4 # 0: Hold, 1: Buy, 2: Sell, 3: Close Position
|
| 75 |
+
self.observation_space = (15,) # Increased features for better state representation
|
| 76 |
|
| 77 |
# Portfolio tracking
|
| 78 |
self.portfolio_history = []
|
| 79 |
self.action_history = []
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
def _initialize_market_data(self):
|
| 82 |
"""Initialize synthetic market data based on asset type"""
|
| 83 |
n_points = 200 # Longer history for better indicators
|
|
|
|
| 397 |
actions = torch.LongTensor(actions).to(self.device)
|
| 398 |
rewards = torch.FloatTensor(rewards).to(self.device)
|
| 399 |
next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
|
| 400 |
+
dones = torch.BoolTensor(dones).to(self.device) # Fixed: Use BoolTensor instead of FloatTensor
|
| 401 |
|
| 402 |
# Current Q values
|
| 403 |
current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
|
|
|
|
| 405 |
# Next Q values from target network
|
| 406 |
with torch.no_grad():
|
| 407 |
next_q_values = self.target_network(next_states).max(1)[0]
|
| 408 |
+
# Fixed: Use proper boolean masking
|
| 409 |
+
target_q_values = rewards + self.gamma * next_q_values * (~dones).float()
|
| 410 |
|
| 411 |
# Compute loss and update
|
| 412 |
loss = self.criterion(current_q_values, target_q_values)
|