OmidSakaki commited on
Commit
ae2aacf
·
verified ·
1 Parent(s): f77d216

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +654 -139
app.py CHANGED
@@ -1,279 +1,794 @@
1
  import gradio as gr
2
  import numpy as np
3
  import torch
 
 
 
 
4
  from pathlib import Path
5
- from typing import Dict, Tuple, Any
6
- from loguru import logger
 
7
  import yaml
8
- from gymnasium import spaces
9
 
 
10
  class TradingConfig:
 
 
 
 
11
  def __init__(self):
 
12
  self.initial_balance = 10000.0
13
  self.max_steps = 1000
14
  self.transaction_cost = 0.001
15
  self.risk_level = "Medium"
16
  self.asset_type = "Crypto"
 
 
17
  self.learning_rate = 0.0001
18
- self.gamma = 0.99
19
  self.epsilon_start = 1.0
20
  self.epsilon_min = 0.01
21
  self.epsilon_decay = 0.9995
22
  self.batch_size = 32
23
  self.memory_size = 10000
24
  self.target_update = 100
 
 
 
 
 
 
 
 
25
 
26
  class AdvancedTradingEnvironment:
27
- def __init__(self, config):
 
 
 
 
 
 
28
  self.initial_balance = config.initial_balance
29
  self.balance = self.initial_balance
30
  self.position = 0.0
31
  self.current_price = 100.0
32
  self.step_count = 0
33
  self.max_steps = config.max_steps
 
 
 
34
  self.price_history = []
 
35
  self.sentiment_history = []
36
- self._initialize_data()
37
- self.action_space = spaces.Discrete(4)
38
- self.observation_space = spaces.Box(low=-2.0, high=2.0, shape=(12,), dtype=np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- def _initialize_data(self):
41
- n_points = 100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  base_price = 100.0
 
43
  for i in range(n_points):
44
- price = base_price + np.sin(i * 0.1) * 10 + np.random.normal(0, 2)
45
- self.price_history.append(max(10.0, price))
46
- sentiment = 0.5 + np.random.normal(0, 0.1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  self.sentiment_history.append(np.clip(sentiment, 0.0, 1.0))
 
48
  self.current_price = self.price_history[-1]
49
 
50
- def reset(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  self.balance = self.initial_balance
52
  self.position = 0.0
53
  self.step_count = 0
 
 
 
 
54
  self.price_history = [100.0 + np.random.normal(0, 5)]
55
- self.sentiment_history = [0.5]
 
 
 
56
  obs = self._get_observation()
57
  info = self._get_info()
 
58
  return obs, info
59
 
60
- def step(self, action):
 
61
  self.step_count += 1
62
- price_change = np.random.normal(0, 0.02)
63
- self.current_price = max(10.0, self.current_price * (1 + price_change))
64
- self.price_history.append(self.current_price)
65
 
66
- sentiment_change = np.random.normal(0, 0.05)
67
- new_sentiment = np.clip(self.sentiment_history[-1] + sentiment_change, 0.0, 1.0)
68
- self.sentiment_history.append(new_sentiment)
69
 
 
70
  reward = self._execute_action(action)
71
 
 
72
  terminated = self.balance <= 0 or self.step_count >= self.max_steps
73
  truncated = False
74
 
 
75
  obs = self._get_observation()
76
  info = self._get_info()
77
 
 
 
 
 
78
  return obs, reward, terminated, truncated, info
79
 
80
- def _execute_action(self, action):
81
- reward = 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  prev_net_worth = self.balance + self.position * self.current_price
 
83
 
84
  if action == 1: # Buy
85
- trade_amount = min(self.balance * 0.2, self.balance)
86
- cost = trade_amount
87
- if cost <= self.balance:
88
- self.position += trade_amount / self.current_price
89
- self.balance -= cost
 
 
90
 
91
  elif action == 2: # Sell
92
  if self.position > 0:
93
- sell_amount = min(self.position * 0.2, self.position)
94
- proceeds = sell_amount * self.current_price
95
- self.position -= sell_amount
 
96
  self.balance += proceeds
97
 
98
- elif action == 3: # Close
99
  if self.position > 0:
100
- proceeds = self.position * self.current_price
101
  self.balance += proceeds
102
  self.position = 0
103
 
104
- net_worth = self.balance + self.position * self.current_price
105
- reward = (net_worth - prev_net_worth) / self.initial_balance * 100
 
106
 
107
- return reward
108
-
109
- def _get_observation(self):
110
- recent_prices = self.price_history[-10:] if len(self.price_history) >= 10 else [self.current_price] * 10
111
- recent_sentiments = self.sentiment_history[-10:] if len(self.sentiment_history) >= 10 else [0.5] * 10
112
 
113
- features = [
114
- self.balance / self.initial_balance,
115
- self.position * self.current_price / self.initial_balance,
 
 
 
 
 
 
116
  self.current_price / 100.0,
117
  np.mean(recent_prices) / 100.0,
118
  np.std(recent_prices) / 100.0,
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  np.mean(recent_sentiments),
120
  np.std(recent_sentiments),
121
- self.step_count / self.max_steps,
122
- 0.0, 0.0, 0.0, 0.0 # Padding
123
  ]
124
 
125
- return np.array(features[:12], dtype=np.float32)
 
 
 
 
 
 
 
 
 
126
 
127
- def _get_info(self):
 
128
  net_worth = self.balance + self.position * self.current_price
129
- return {'net_worth': net_worth}
 
 
 
 
 
 
 
 
 
130
 
131
  class DQNAgent:
132
- def __init__(self, state_dim, action_dim, config, device='cpu'):
 
 
 
 
 
133
  self.device = torch.device(device)
134
- self.q_network = torch.nn.Sequential(
135
- torch.nn.Linear(state_dim, 128),
136
- torch.nn.ReLU(),
137
- torch.nn.Linear(128, 128),
138
- torch.nn.ReLU(),
139
- torch.nn.Linear(128, action_dim)
140
- ).to(self.device)
141
-
142
- self.target_network = torch.nn.Sequential(
143
- torch.nn.Linear(state_dim, 128),
144
- torch.nn.ReLU(),
145
- torch.nn.Linear(128, 128),
146
- torch.nn.ReLU(),
147
- torch.nn.Linear(128, action_dim)
148
- ).to(self.device)
149
 
 
 
 
150
  self.target_network.load_state_dict(self.q_network.state_dict())
151
 
152
- self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=config.learning_rate)
 
 
 
 
153
  self.memory = deque(maxlen=config.memory_size)
154
- self.gamma = config.gamma
 
155
  self.epsilon = config.epsilon_start
156
  self.epsilon_min = config.epsilon_min
157
  self.epsilon_decay = config.epsilon_decay
 
 
158
  self.batch_size = config.batch_size
 
159
  self.target_update = config.target_update
160
  self.steps = 0
161
 
162
- def select_action(self, state, training=True):
163
- state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
164
  if training and random.random() < self.epsilon:
165
- return random.randint(0, 3)
 
 
 
166
  with torch.no_grad():
167
- return self.q_network(state).argmax(1).item()
 
168
 
169
- def store_transition(self, state, action, reward, next_state, done):
 
 
170
  self.memory.append((state, action, reward, next_state, done))
171
 
172
- def update(self):
 
173
  if len(self.memory) < self.batch_size:
174
  return 0.0
175
 
 
176
  batch = random.sample(self.memory, self.batch_size)
177
  states, actions, rewards, next_states, dones = zip(*batch)
178
 
 
179
  states = torch.FloatTensor(np.array(states)).to(self.device)
180
  actions = torch.LongTensor(actions).to(self.device)
181
  rewards = torch.FloatTensor(rewards).to(self.device)
182
  next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
183
- dones = torch.FloatTensor(dones).to(self.device)
 
 
 
184
 
185
- current_q = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
186
- next_q = self.target_network(next_states).max(1)[0]
187
- target_q = rewards + self.gamma * next_q * (1 - dones)
 
188
 
189
- loss = torch.nn.MSELoss()(current_q, target_q)
 
190
 
191
  self.optimizer.zero_grad()
192
  loss.backward()
 
 
 
193
  self.optimizer.step()
194
 
 
195
  self.steps += 1
196
  if self.steps % self.target_update == 0:
197
  self.target_network.load_state_dict(self.q_network.state_dict())
198
 
 
199
  self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
200
 
201
  return loss.item()
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  class TradingDemo:
 
 
 
 
 
204
  def __init__(self):
205
  self.config = TradingConfig()
206
  self.env = None
207
  self.agent = None
208
- self.device = 'cpu'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
- def initialize(self, balance, risk, asset):
211
- self.config.initial_balance = balance
212
- self.config.risk_level = risk
213
- self.config.asset_type = asset
214
- self.env = AdvancedTradingEnvironment(self.config)
215
- self.agent = DQNAgent(12, 4, self.config, self.device)
216
- return "โœ… Initialized!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
- def train(self, episodes):
219
- for ep in range(episodes):
 
 
 
 
 
220
  obs, _ = self.env.reset()
221
- total_reward = 0
222
- done = False
223
- while not done:
224
- action = self.agent.select_action(obs)
 
 
 
 
 
 
225
  next_obs, reward, done, _, info = self.env.step(action)
226
- self.agent.store_transition(obs, action, reward, next_obs, done)
 
 
 
 
 
 
 
227
  obs = next_obs
228
- total_reward += reward
229
- self.agent.update()
230
- yield f"Episode {ep+1}/{episodes} | Reward: {total_reward:.2f}", None
231
- yield "โœ… Training complete!", None
 
 
 
 
 
 
 
 
 
 
232
 
233
- def simulate(self, steps):
234
- obs, _ = self.env.reset()
235
- prices = []
236
- actions = []
237
- net_worths = []
238
- for _ in range(steps):
239
- action = self.agent.select_action(obs, training=False)
240
- next_obs, reward, done, _, info = self.env.step(action)
241
- prices.append(self.env.current_price)
242
- actions.append(action)
243
- net_worths.append(info['net_worth'])
244
- obs = next_obs
245
- if done:
246
- break
247
-
248
- import plotly.graph_objects as go
249
- fig = go.Figure()
250
- fig.add_trace(go.Scatter(y=prices, mode='lines', name='Price'))
251
- fig.add_trace(go.Scatter(y=net_worths, mode='lines', name='Net Worth'))
252
- return "โœ… Simulation complete!", fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
- demo = TradingDemo()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
- with gr.Blocks() as interface:
257
- gr.Markdown("# Trading AI Demo")
258
-
259
- with gr.Row():
260
- balance = gr.Slider(1000, 50000, 10000, label="Balance")
261
- risk = gr.Radio(["Low", "Medium", "High"], value="Medium", label="Risk")
262
- asset = gr.Radio(["Crypto", "Stock", "Forex"], value="Crypto", label="Asset")
263
- init_btn = gr.Button("Initialize")
264
-
265
- status = gr.Textbox(label="Status")
266
 
267
- episodes = gr.Number(value=50, label="Episodes")
268
- train_btn = gr.Button("Train")
269
- train_plot = gr.Plot()
270
-
271
- steps = gr.Number(value=100, label="Simulation Steps")
272
- sim_btn = gr.Button("Simulate")
273
- sim_plot = gr.Plot()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
- init_btn.click(demo.initialize, [balance, risk, asset], status)
276
- train_btn.click(demo.train, episodes, [status, train_plot])
277
- sim_btn.click(demo.simulate, steps, [status, sim_plot])
278
 
279
- interface.launch()
 
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ from collections import deque
7
+ import random
8
  from pathlib import Path
9
+ from typing import Dict, Tuple, Any, List
10
+ import plotly.graph_objects as go
11
+ from plotly.subplots import make_subplots
12
  import yaml
 
13
 
14
# Configuration class for trading parameters
class TradingConfig:
    """
    Configuration class for trading environment and agent parameters.
    Centralizes all configurable parameters for easy modification.
    """

    def __init__(self):
        # Environment parameters
        self.initial_balance = 10000.0
        self.max_steps = 1000
        self.transaction_cost = 0.001
        self.risk_level = "Medium"
        self.asset_type = "Crypto"

        # DQN agent parameters
        self.learning_rate = 0.0001
        self.gamma = 0.99  # Discount factor
        self.epsilon_start = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.9995
        self.batch_size = 32
        self.memory_size = 10000
        self.target_update = 100
        self.hidden_size = 128

        # Risk adjustment factors based on risk level
        self.risk_multipliers = {"Low": 0.5, "Medium": 1.0, "High": 2.0}
 
46
class AdvancedTradingEnvironment:
    """
    Trading environment simulating a single-asset market with synthetic
    prices, volumes and sentiment.

    Follows the gymnasium-style API: reset() -> (obs, info) and
    step(action) -> (obs, reward, terminated, truncated, info).
    Supports "Crypto", "Stock" and "Forex" asset profiles via different
    volatility settings, and scales trade sizing by the configured risk level.
    """

    def __init__(self, config: "TradingConfig"):
        self.config = config
        self.initial_balance = config.initial_balance
        self.balance = self.initial_balance
        self.position = 0.0  # units of the asset currently held
        self.current_price = 100.0
        self.step_count = 0
        self.max_steps = config.max_steps
        self.transaction_cost = config.transaction_cost

        # Market data history
        self.price_history = []
        self.volume_history = []
        self.sentiment_history = []

        # Risk adjustment (scales trade sizing and price volatility)
        self.risk_multiplier = config.risk_multipliers[config.risk_level]

        # Initialize synthetic market data
        self._initialize_market_data()

        # Action and observation space descriptions
        self.action_space = self._create_action_space()
        self.observation_space = self._create_observation_space()

        # Portfolio tracking
        self.portfolio_history = []
        self.action_history = []

    def _create_action_space(self) -> int:
        """Number of discrete trading actions."""
        return 4  # 0: Hold, 1: Buy, 2: Sell, 3: Close Position

    def _create_observation_space(self) -> Tuple:
        """Shape of the observation vector."""
        return (15,)

    def _initialize_market_data(self):
        """Generate synthetic price/volume/sentiment history for the asset."""
        n_points = 200  # longer history for better indicators

        # Different volatility based on asset type
        volatility_map = {
            "Crypto": 0.03,
            "Stock": 0.015,
            "Forex": 0.008
        }
        volatility = volatility_map.get(self.config.asset_type, 0.02)
        base_price = 100.0

        for i in range(n_points):
            # Price generation with a slow sinusoidal momentum component
            momentum = np.sin(i * 0.05) * 2
            noise = np.random.normal(0, volatility)
            price = max(10.0, base_price * (1 + momentum * 0.01 + noise))
            self.price_history.append(price)

            # Volume loosely correlated with distance from the base price
            volume = 1000 + abs(price - base_price) * 50 + np.random.normal(0, 200)
            self.volume_history.append(max(100, volume))

            # Sentiment with persistence (random walk from previous value)
            if i > 0:
                sentiment = self.sentiment_history[-1] + np.random.normal(0, 0.08)
            else:
                sentiment = 0.5 + np.random.normal(0, 0.1)
            self.sentiment_history.append(np.clip(sentiment, 0.0, 1.0))

        self.current_price = self.price_history[-1]

    def _calculate_technical_indicators(self) -> List[float]:
        """Compute six normalized technical indicators from recent history."""
        prices = np.array(self.price_history[-50:])

        if len(prices) < 2:
            return [0.0] * 6  # not enough data yet

        returns = np.diff(prices) / prices[:-1]

        # Simple moving averages (short/long window)
        sma_short = np.mean(prices[-10:]) if len(prices) >= 10 else prices[-1]
        sma_long = np.mean(prices[-20:]) if len(prices) >= 20 else prices[-1]

        # RSI (Relative Strength Index) over the last 14 returns
        if len(returns) >= 14:
            gains = returns[returns > 0]
            losses = -returns[returns < 0]
            avg_gain = np.mean(gains[-14:]) if len(gains) > 0 else 0.001
            avg_loss = np.mean(losses[-14:]) if len(losses) > 0 else 0.001
            rsi = 100 - (100 / (1 + avg_gain / avg_loss))
        else:
            rsi = 50.0

        # Annualized volatility of returns
        volatility = np.std(returns) * np.sqrt(252) if len(returns) > 1 else 0.1

        # 5-step price momentum
        momentum = (prices[-1] / prices[-5] - 1) if len(prices) >= 5 else 0.0

        # Short-term vs. longer-term volume trend
        volumes = np.array(self.volume_history[-10:])
        volume_trend = np.mean(volumes[-5:]) / np.mean(volumes[-10:]) - 1 if len(volumes) >= 10 else 0.0

        return [sma_short / 100, sma_long / 100, rsi / 100, volatility, momentum, volume_trend]

    def reset(self) -> Tuple[np.ndarray, Dict]:
        """Reset portfolio and market state; return (observation, info)."""
        self.balance = self.initial_balance
        self.position = 0.0
        self.step_count = 0
        self.portfolio_history = []
        self.action_history = []

        # Reinitialize market history from a fresh random starting point
        self.price_history = [100.0 + np.random.normal(0, 5)]
        self.volume_history = [1000 + np.random.normal(0, 200)]
        self.sentiment_history = [0.5 + np.random.normal(0, 0.1)]
        self.current_price = self.price_history[-1]

        obs = self._get_observation()
        info = self._get_info()

        return obs, info

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """Execute one trading step: update market, apply action, build transition."""
        self.step_count += 1

        # Generate new market data with more realistic dynamics
        self._update_market_data()

        # Execute trading action
        reward = self._execute_action(action)

        # Check termination conditions
        terminated = self.balance <= 0 or self.step_count >= self.max_steps
        truncated = False

        # Get new observation and info
        obs = self._get_observation()
        info = self._get_info()

        # Track portfolio value
        self.portfolio_history.append(info['net_worth'])
        self.action_history.append(action)

        return obs, reward, terminated, truncated, info

    def _update_market_data(self):
        """Generate the next price/volume/sentiment tick."""
        # Momentum from the last few returns (volatility clustering).
        # BUG FIX: the original evaluated `if prev_returns` on a multi-element
        # numpy array, which raises "truth value of an array is ambiguous"
        # as soon as the price history reaches 6 entries (crashing every
        # episode at ~step 6); use an explicit length check instead.
        if len(self.price_history) >= 6:
            recent = self.price_history[-5:]
            prev_returns = np.diff(recent) / np.array(recent[:-1])
        else:
            prev_returns = np.array([0.0])
        momentum = float(np.mean(prev_returns)) if len(prev_returns) > 0 else 0.0

        volatility_map = {
            "Crypto": 0.025,
            "Stock": 0.012,
            "Forex": 0.006
        }
        base_volatility = volatility_map.get(self.config.asset_type, 0.015)

        # Volatility scaling based on risk level
        volatility = base_volatility * self.risk_multiplier
        price_change = momentum * 0.3 + np.random.normal(0, volatility)

        self.current_price = max(10.0, self.current_price * (1 + price_change))
        self.price_history.append(self.current_price)

        # Volume follows the size of the price move, with noise
        volume_noise = np.random.normal(0, 200)
        new_volume = max(100, 1000 + abs(price_change) * 5000 + volume_noise)
        self.volume_history.append(new_volume)

        # Sentiment with mean reversion toward neutral (0.5)
        current_sentiment = self.sentiment_history[-1]
        sentiment_reversion = (0.5 - current_sentiment) * 0.1
        sentiment_noise = np.random.normal(0, 0.08)
        new_sentiment = current_sentiment + sentiment_reversion + sentiment_noise
        self.sentiment_history.append(np.clip(new_sentiment, 0.0, 1.0))

    def _execute_action(self, action: int) -> float:
        """Apply the trading action; return risk-adjusted reward (% of initial balance)."""
        prev_net_worth = self.balance + self.position * self.current_price
        trade_size_multiplier = 0.2 * self.risk_multiplier  # risk-adjusted position sizing

        if action == 1:  # Buy
            if self.balance > 0:
                trade_amount = min(self.balance * trade_size_multiplier, self.balance)
                cost = trade_amount * (1 + self.transaction_cost)
                if cost <= self.balance:
                    self.position += trade_amount / self.current_price
                    self.balance -= cost

        elif action == 2:  # Sell part of the position
            if self.position > 0:
                shares_to_sell = min(self.position * trade_size_multiplier, self.position)
                proceeds = shares_to_sell * self.current_price * (1 - self.transaction_cost)
                self.position -= shares_to_sell
                self.balance += proceeds

        elif action == 3:  # Close the whole position
            if self.position > 0:
                proceeds = self.position * self.current_price * (1 - self.transaction_cost)
                self.balance += proceeds
                self.position = 0

        # Reward is the step-over-step change in net worth, as % of initial balance
        new_net_worth = self.balance + self.position * self.current_price
        raw_reward = (new_net_worth - prev_net_worth) / self.initial_balance * 100

        # Penalize drawdowns deeper than 20% of the initial balance
        risk_penalty = 0.0
        if new_net_worth < self.initial_balance * 0.8:
            risk_penalty = (self.initial_balance - new_net_worth) / self.initial_balance * 10

        return raw_reward - risk_penalty

    def _get_observation(self) -> np.ndarray:
        """Build the 15-dimensional state vector."""
        # Price-based features
        recent_prices = self.price_history[-20:] if len(self.price_history) >= 20 else [self.current_price] * 20

        lo, hi = np.min(recent_prices), np.max(recent_prices)
        # BUG FIX: right after reset() the padded price window is constant,
        # so hi == lo and the original normalization divided by zero,
        # producing NaN observations; fall back to the midpoint instead.
        price_position = (self.current_price - lo) / (hi - lo) if hi > lo else 0.5

        price_features = [
            self.current_price / 100.0,
            np.mean(recent_prices) / 100.0,
            np.std(recent_prices) / 100.0,
            price_position,
        ]

        # Portfolio features
        portfolio_features = [
            self.balance / self.initial_balance,
            self.position * self.current_price / self.initial_balance,
            self.step_count / self.max_steps,
        ]

        # Sentiment features
        recent_sentiments = self.sentiment_history[-10:] if len(self.sentiment_history) >= 10 else [0.5] * 10
        sentiment_features = [
            np.mean(recent_sentiments),
            np.std(recent_sentiments),
            recent_sentiments[-1],  # latest sentiment
        ]

        # Technical indicators
        technical_features = self._calculate_technical_indicators()

        # NOTE: 4 + 3 + 3 + 6 = 16 features are computed, but the observation
        # is truncated to 15 (dropping the last technical indicator,
        # volume_trend) to match the declared observation space and the
        # agent's input dimension.
        all_features = price_features + portfolio_features + sentiment_features + technical_features
        return np.array(all_features[:15], dtype=np.float32)

    def _get_info(self) -> Dict[str, Any]:
        """Snapshot of portfolio metrics for logging/UI."""
        net_worth = self.balance + self.position * self.current_price
        return_total = (net_worth - self.initial_balance) / self.initial_balance * 100

        return {
            'net_worth': net_worth,
            'return_percent': return_total,
            'position_value': self.position * self.current_price,
            'cash_balance': self.balance,
            'current_price': self.current_price,
            'steps': self.step_count
        }
330
 
331
class DQNAgent:
    """
    Deep Q-Network agent for trading decisions.
    Uses an experience-replay buffer plus a periodically-synced target
    network for stable temporal-difference learning.
    """

    def __init__(self, state_dim: int, action_dim: int, config: "TradingConfig", device: str = 'cpu'):
        self.device = torch.device(device)
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.config = config

        # Online network and its frozen copy (same architecture).
        self.q_network = self._build_network(state_dim, action_dim)
        self.target_network = self._build_network(state_dim, action_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())

        # Optimization setup
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=config.learning_rate)
        self.criterion = nn.MSELoss()

        # Experience replay buffer
        self.memory = deque(maxlen=config.memory_size)

        # Epsilon-greedy exploration schedule
        self.epsilon = config.epsilon_start
        self.epsilon_min = config.epsilon_min
        self.epsilon_decay = config.epsilon_decay

        # Training bookkeeping
        self.batch_size = config.batch_size
        self.gamma = config.gamma
        self.target_update = config.target_update
        self.steps = 0

    def _build_network(self, state_dim: int, action_dim: int) -> nn.Sequential:
        """Three-hidden-layer MLP mapping state -> per-action Q-values."""
        hidden = self.config.hidden_size
        layers = [
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(),
            nn.Linear(hidden // 2, action_dim),
        ]
        return nn.Sequential(*layers).to(self.device)

    def select_action(self, state: np.ndarray, training: bool = True) -> int:
        """Epsilon-greedy action choice; pure greedy when training=False."""
        if training and random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)

        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)

        with torch.no_grad():
            scores = self.q_network(state_tensor)
        return scores.argmax(1).item()

    def store_transition(self, state: np.ndarray, action: int, reward: float,
                         next_state: np.ndarray, done: bool):
        """Append one (s, a, r, s', done) tuple to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def update(self) -> float:
        """One TD-learning step from a random minibatch; returns the loss."""
        if len(self.memory) < self.batch_size:
            return 0.0

        # Draw a random minibatch of transitions
        minibatch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*minibatch)

        # Move everything onto the training device
        states = torch.FloatTensor(np.array(states)).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
        dones = torch.BoolTensor(dones).to(self.device)

        # Q(s, a) for the actions actually taken
        q_pred = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Bootstrapped targets from the frozen network; terminal states
        # contribute only their immediate reward.
        with torch.no_grad():
            q_next = self.target_network(next_states).max(1)[0]
            q_target = rewards + self.gamma * q_next * (~dones)

        loss = self.criterion(q_pred, q_target)

        self.optimizer.zero_grad()
        loss.backward()

        # Clip gradients for training stability
        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 1.0)
        self.optimizer.step()

        # Periodically sync the target network with the online one
        self.steps += 1
        if self.steps % self.target_update == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

        # Anneal exploration
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        return loss.item()

    def save(self, path: str):
        """Persist networks, optimizer state and exploration progress."""
        torch.save({
            'q_network_state_dict': self.q_network.state_dict(),
            'target_network_state_dict': self.target_network.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epsilon': self.epsilon,
            'steps': self.steps
        }, path)

    def load(self, path: str):
        """Restore a checkpoint produced by save()."""
        checkpoint = torch.load(path, map_location=self.device)
        self.q_network.load_state_dict(checkpoint['q_network_state_dict'])
        self.target_network.load_state_dict(checkpoint['target_network_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.epsilon = checkpoint['epsilon']
        self.steps = checkpoint['steps']
+
457
  class TradingDemo:
458
+ """
459
+ Main demonstration class integrating trading environment and DQN agent.
460
+ Provides interface for training, simulation, and visualization.
461
+ """
462
+
463
  def __init__(self):
464
  self.config = TradingConfig()
465
  self.env = None
466
  self.agent = None
467
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
468
+ print(f"Using device: {self.device}")
469
+
470
+ # Training history
471
+ self.training_history = {
472
+ 'episode_rewards': [],
473
+ 'episode_losses': [],
474
+ 'epsilon_history': []
475
+ }
476
+
477
+ def initialize(self, balance: float, risk: str, asset: str) -> str:
478
+ """Initialize trading environment and agent"""
479
+ try:
480
+ self.config.initial_balance = float(balance)
481
+ self.config.risk_level = risk
482
+ self.config.asset_type = asset
483
+
484
+ # Create environment and agent
485
+ self.env = AdvancedTradingEnvironment(self.config)
486
+ self.agent = DQNAgent(15, 4, self.config, self.device)
487
+
488
+ # Reset training history
489
+ self.training_history = {
490
+ 'episode_rewards': [],
491
+ 'episode_losses': [],
492
+ 'epsilon_history': []
493
+ }
494
+
495
+ return f"โœ… System initialized! Balance: ${balance}, Risk: {risk}, Asset: {asset}"
496
+
497
+ except Exception as e:
498
+ return f"โŒ Initialization failed: {str(e)}"
499
 
500
+ def train(self, episodes: int):
501
+ """Train the DQN agent"""
502
+ if self.env is None or self.agent is None:
503
+ yield "โŒ Please initialize the system first!", None
504
+ return
505
+
506
+ try:
507
+ episodes = int(episodes)
508
+ for episode in range(episodes):
509
+ # Reset environment
510
+ obs, _ = self.env.reset()
511
+ total_reward = 0
512
+ episode_loss = 0
513
+ update_count = 0
514
+ done = False
515
+
516
+ while not done:
517
+ # Select and execute action
518
+ action = self.agent.select_action(obs)
519
+ next_obs, reward, done, _, info = self.env.step(action)
520
+
521
+ # Store experience and update
522
+ self.agent.store_transition(obs, action, reward, next_obs, done)
523
+ loss = self.agent.update()
524
+
525
+ if loss > 0:
526
+ episode_loss += loss
527
+ update_count += 1
528
+
529
+ total_reward += reward
530
+ obs = next_obs
531
+
532
+ # Calculate average loss
533
+ avg_loss = episode_loss / max(update_count, 1)
534
+
535
+ # Update history
536
+ self.training_history['episode_rewards'].append(total_reward)
537
+ self.training_history['episode_losses'].append(avg_loss)
538
+ self.training_history['epsilon_history'].append(self.agent.epsilon)
539
+
540
+ # Yield progress
541
+ progress = f"Episode {episode+1}/{episodes} | " \
542
+ f"Reward: {total_reward:.2f} | " \
543
+ f"Loss: {avg_loss:.4f} | " \
544
+ f"Epsilon: {self.agent.epsilon:.3f} | " \
545
+ f"Net Worth: ${info['net_worth']:.2f}"
546
+
547
+ # Create training plot every 10 episodes or at the end
548
+ if (episode + 1) % 10 == 0 or episode == episodes - 1:
549
+ plot = self._create_training_plot()
550
+ yield progress, plot
551
+ else:
552
+ yield progress, None
553
+
554
+ yield "โœ… Training completed successfully!", self._create_training_plot()
555
+
556
+ except Exception as e:
557
+ yield f"โŒ Training error: {str(e)}", None
558
 
559
+ def simulate(self, steps: int):
560
+ """Run trading simulation with current policy"""
561
+ if self.env is None or self.agent is None:
562
+ return "โŒ Please initialize and train the system first!", None
563
+
564
+ try:
565
+ steps = int(steps)
566
  obs, _ = self.env.reset()
567
+
568
+ # Tracking data
569
+ prices = []
570
+ actions = []
571
+ net_worths = []
572
+ portfolio_values = []
573
+ cash_balances = []
574
+
575
+ for step in range(steps):
576
+ action = self.agent.select_action(obs, training=False)
577
  next_obs, reward, done, _, info = self.env.step(action)
578
+
579
+ # Track metrics
580
+ prices.append(self.env.current_price)
581
+ actions.append(action)
582
+ net_worths.append(info['net_worth'])
583
+ portfolio_values.append(info['position_value'])
584
+ cash_balances.append(info['cash_balance'])
585
+
586
  obs = next_obs
587
+ if done:
588
+ break
589
+
590
+ # Create comprehensive visualization
591
+ fig = self._create_simulation_plot(prices, actions, net_worths, portfolio_values, cash_balances)
592
+
593
+ final_return = (net_worths[-1] - self.config.initial_balance) / self.config.initial_balance * 100
594
+ result_text = f"โœ… Simulation completed! Final Return: {final_return:.2f}% | " \
595
+ f"Final Net Worth: ${net_worths[-1]:.2f}"
596
+
597
+ return result_text, fig
598
+
599
+ except Exception as e:
600
+ return f"โŒ Simulation error: {str(e)}", None
601
 
602
+ def _create_training_plot(self):
603
+ """Create training progress visualization"""
604
+ if not self.training_history['episode_rewards']:
605
+ return None
606
+
607
+ episodes = list(range(1, len(self.training_history['episode_rewards']) + 1))
608
+
609
+ fig = make_subplots(rows=2, cols=2,
610
+ subplot_titles=('Episode Rewards', 'Training Loss',
611
+ 'Epsilon Decay', 'Moving Average Reward'),
612
+ vertical_spacing=0.12)
613
+
614
+ # Rewards
615
+ fig.add_trace(
616
+ go.Scatter(x=episodes, y=self.training_history['episode_rewards'],
617
+ mode='lines', name='Reward', line=dict(color='blue')),
618
+ row=1, col=1
619
+ )
620
+
621
+ # Loss
622
+ fig.add_trace(
623
+ go.Scatter(x=episodes, y=self.training_history['episode_losses'],
624
+ mode='lines', name='Loss', line=dict(color='red')),
625
+ row=1, col=2
626
+ )
627
+
628
+ # Epsilon
629
+ fig.add_trace(
630
+ go.Scatter(x=episodes, y=self.training_history['epsilon_history'],
631
+ mode='lines', name='Epsilon', line=dict(color='green')),
632
+ row=2, col=1
633
+ )
634
+
635
+ # Moving average reward
636
+ window = min(20, len(episodes))
637
+ moving_avg = [np.mean(self.training_history['episode_rewards'][max(0, i-window):i+1])
638
+ for i in range(len(episodes))]
639
+ fig.add_trace(
640
+ go.Scatter(x=episodes, y=moving_avg,
641
+ mode='lines', name='MA Reward', line=dict(color='orange', width=2)),
642
+ row=2, col=2
643
+ )
644
+
645
+ fig.update_layout(height=600, showlegend=True, title_text="Training Progress")
646
+ return fig
647
 
648
+ def _create_simulation_plot(self, prices, actions, net_worths, portfolio_values, cash_balances):
649
+ """Create comprehensive simulation results visualization"""
650
+ fig = make_subplots(rows=2, cols=2,
651
+ subplot_titles=('Price & Actions', 'Portfolio Performance',
652
+ 'Portfolio Composition', 'Action Distribution'),
653
+ vertical_spacing=0.12,
654
+ horizontal_spacing=0.1)
655
+
656
+ steps = list(range(len(prices)))
657
+
658
+ # Price and actions
659
+ fig.add_trace(
660
+ go.Scatter(x=steps, y=prices, mode='lines', name='Price', line=dict(color='blue')),
661
+ row=1, col=1
662
+ )
663
+
664
+ # Add action markers
665
+ action_colors = ['gray', 'green', 'red', 'orange'] # Hold, Buy, Sell, Close
666
+ action_names = ['Hold', 'Buy', 'Sell', 'Close']
667
+ for action in range(4):
668
+ action_indices = [i for i, a in enumerate(actions) if a == action]
669
+ if action_indices:
670
+ action_prices = [prices[i] for i in action_indices]
671
+ fig.add_trace(
672
+ go.Scatter(x=action_indices, y=action_prices,
673
+ mode='markers', name=action_names[action],
674
+ marker=dict(color=action_colors[action], size=8)),
675
+ row=1, col=1
676
+ )
677
+
678
+ # Portfolio performance
679
+ initial_balance = self.config.initial_balance
680
+ returns = [(nw - initial_balance) / initial_balance * 100 for nw in net_worths]
681
+
682
+ fig.add_trace(
683
+ go.Scatter(x=steps, y=net_worths, mode='lines', name='Net Worth', line=dict(color='purple')),
684
+ row=1, col=2
685
+ )
686
+ fig.add_trace(
687
+ go.Scatter(x=steps, y=returns, mode='lines', name='Return %', line=dict(color='orange'), yaxis='y2'),
688
+ row=1, col=2
689
+ )
690
+
691
+ # Portfolio composition
692
+ fig.add_trace(
693
+ go.Scatter(x=steps, y=portfolio_values, mode='lines', name='Portfolio Value', line=dict(color='green')),
694
+ row=2, col=1
695
+ )
696
+ fig.add_trace(
697
+ go.Scatter(x=steps, y=cash_balances, mode='lines', name='Cash Balance', line=dict(color='blue')),
698
+ row=2, col=1
699
+ )
700
+
701
+ # Action distribution
702
+ action_counts = [actions.count(i) for i in range(4)]
703
+ fig.add_trace(
704
+ go.Bar(x=action_names, y=action_counts,
705
+ marker_color=action_colors, name='Action Count'),
706
+ row=2, col=2
707
+ )
708
+
709
+ # Update layout
710
+ fig.update_layout(height=700, showlegend=True, title_text="Trading Simulation Results")
711
+ fig.update_yaxes(title_text="Return (%)", row=1, col=2, secondary_y=True)
712
+ fig.update_yaxes(title_text="Value ($)", row=1, col=2, secondary_y=False)
713
+
714
+ return fig
715
 
716
+ # Create and launch Gradio interface
717
def create_interface():
    """Create Gradio interface for the trading demo.

    Builds a Blocks UI with three sections -- configuration, training and
    simulation -- all wired to one shared TradingDemo instance, and returns
    the (not yet launched) interface.
    """
    # Single demo instance shared by every event handler below, so training
    # state persists between the Train and Simulate buttons.
    demo = TradingDemo()

    with gr.Blocks(theme=gr.themes.Soft(), title="AI Trading Demo") as interface:
        gr.Markdown("""
        # ๐Ÿค– Advanced AI Trading Demo
        **Deep Reinforcement Learning for Financial Markets**
        
        This demo shows a DQN agent learning to trade in simulated financial markets.
        The agent learns optimal trading strategies through reinforcement learning.
        """)

        # --- Configuration panel (left) and status readout (right) ---------
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("## ๐ŸŽฏ Configuration")

                balance = gr.Slider(1000, 50000, 10000, step=1000, label="Initial Balance ($)")
                risk = gr.Radio(["Low", "Medium", "High"], value="Medium", label="Risk Level")
                asset = gr.Radio(["Crypto", "Stock", "Forex"], value="Crypto", label="Asset Type")
                init_btn = gr.Button("๐Ÿš€ Initialize System", variant="primary")

            with gr.Column(scale=2):
                gr.Markdown("## ๐Ÿ“Š System Status")
                status = gr.Textbox(label="Status", value="Click 'Initialize System' to start", interactive=False)

        # --- Training (left) and simulation (right) controls + plots -------
        with gr.Row():
            with gr.Column():
                gr.Markdown("## ๐Ÿ‹๏ธ Training")
                episodes = gr.Number(value=100, label="Training Episodes", precision=0)
                train_btn = gr.Button("๐ŸŽฏ Start Training", variant="primary")
                train_plot = gr.Plot(label="Training Progress")

            with gr.Column():
                gr.Markdown("## ๐Ÿ“ˆ Simulation")
                steps = gr.Number(value=200, label="Simulation Steps", precision=0)
                sim_btn = gr.Button("โ–ถ๏ธ Run Simulation", variant="primary")
                sim_plot = gr.Plot(label="Simulation Results")

        # Event handlers: each button routes its inputs into the shared demo
        # and writes back into the status box (and plot, where applicable).
        # NOTE(review): demo.train appears to be a generator, so the train
        # click streams progress updates -- confirm against TradingDemo.
        init_btn.click(
            demo.initialize,
            inputs=[balance, risk, asset],
            outputs=status
        )

        train_btn.click(
            demo.train,
            inputs=episodes,
            outputs=[status, train_plot]
        )

        sim_btn.click(
            demo.simulate,
            inputs=steps,
            outputs=[status, sim_plot]
        )

        gr.Markdown("""
        ## ๐Ÿ“– How to Use:
        1. **Configure**: Set your initial balance, risk level, and asset type
        2. **Initialize**: Click 'Initialize System' to set up the trading environment
        3. **Train**: Start training the AI agent (recommended: 100+ episodes)
        4. **Simulate**: Run a trading simulation to see the trained agent in action
        
        ## ๐ŸŽฎ Actions:
        - **0: Hold** - Maintain current position
        - **1: Buy** - Purchase asset (20% of balance)
        - **2: Sell** - Sell portion of position (20%)
        - **3: Close** - Liquidate entire position
        """)

    return interface
 
 
790
 
791
# Launch the application
if __name__ == "__main__":
    interface = create_interface()
    # share=True requests a public gradio.live URL; binding 0.0.0.0 makes
    # the server reachable from outside a container, and 7860 is the
    # conventional Gradio/HF Spaces port.
    interface.launch(share=True, server_name="0.0.0.0", server_port=7860)