OmidSakaki committed on
Commit
92a27f9
·
verified ·
1 Parent(s): 093afbf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +498 -9
app.py CHANGED
@@ -5,19 +5,513 @@ import torch.nn as nn
5
  import torch.optim as optim
6
  from collections import deque
7
  import random
8
- from typing import Dict, Tuple, Any, List, Optional
9
  import plotly.graph_objects as go
10
  from plotly.subplots import make_subplots
11
 
12
- # (تمام کلاس‌های TradingConfig, AdvancedTradingEnvironment, DQNAgent, TradingDemo عیناً حفظ شده‌اند...)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- # این تابع همانند قبل است
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def create_interface():
16
  demo = TradingDemo()
17
  with gr.Blocks(theme=gr.themes.Soft(), title="AI Trading Demo") as interface:
18
  gr.Markdown("""
19
  # 🤖 Advanced AI Trading Demo
20
  **Deep Reinforcement Learning for Financial Markets**
 
21
  This demo shows a DQN agent learning to trade in simulated financial markets.
22
  The agent learns optimal trading strategies through reinforcement learning.
23
  """)
@@ -29,7 +523,6 @@ def create_interface():
29
  risk = gr.Radio(["Low", "Medium", "High"], value="Medium", label="Risk Level")
30
  asset = gr.Radio(["Crypto", "Stock", "Forex"], value="Crypto", label="Asset Type")
31
  init_btn = gr.Button("🚀 Initialize System", variant="primary")
32
-
33
  with gr.Column(scale=2):
34
  gr.Markdown("## 📊 System Status")
35
  status = gr.Textbox(label="Status", value="Click 'Initialize System' to start", interactive=False)
@@ -40,26 +533,22 @@ def create_interface():
40
  episodes = gr.Number(value=100, label="Training Episodes", precision=0)
41
  train_btn = gr.Button("🎯 Start Training", variant="primary")
42
  train_plot = gr.Plot(label="Training Progress")
43
-
44
  with gr.Column():
45
  gr.Markdown("## 📈 Simulation")
46
  steps = gr.Number(value=200, label="Simulation Steps", precision=0)
47
  sim_btn = gr.Button("▶️ Run Simulation", variant="primary")
48
  sim_plot = gr.Plot(label="Simulation Results")
49
 
50
- # Event handlers
51
  init_btn.click(
52
  demo.initialize,
53
  inputs=[balance, risk, asset],
54
  outputs=status
55
  )
56
-
57
  train_btn.click(
58
  demo.train,
59
  inputs=episodes,
60
  outputs=[status, train_plot]
61
  )
62
-
63
  sim_btn.click(
64
  demo.simulate,
65
  inputs=steps,
@@ -72,6 +561,7 @@ def create_interface():
72
  2. **Initialize**: Click 'Initialize System' to set up the trading environment
73
  3. **Train**: Start training the AI agent (recommended: 100+ episodes)
74
  4. **Simulate**: Run a trading simulation to see the trained agent in action
 
75
  ## 🎮 Actions:
76
  - **0: Hold** - Maintain current position
77
  - **1: Buy** - Purchase asset (20% of balance)
@@ -80,5 +570,4 @@ def create_interface():
80
  """)
81
  return interface
82
 
83
- # نکته مهم: فقط این خط باید اجرا شود و نام متغیر باید demo باشد
84
  demo = create_interface()
 
5
  import torch.optim as optim
6
  from collections import deque
7
  import random
8
+ from typing import Dict, Tuple, Any, List
9
  import plotly.graph_objects as go
10
  from plotly.subplots import make_subplots
11
 
12
# ==== 1. Configuration Class ====
class TradingConfig:
    """
    Central configuration for the trading environment and the DQN agent.

    All values are plain attributes so callers may override them after
    construction (the UI mutates initial_balance / risk_level / asset_type
    before building the environment).
    """

    def __init__(self):
        # --- environment knobs ---
        self.initial_balance = 10000.0   # starting cash in dollars
        self.max_steps = 1000            # hard cap on steps per episode
        self.transaction_cost = 0.001    # proportional fee per trade (0.1%)
        self.risk_level = "Medium"       # must be a key of risk_multipliers
        self.asset_type = "Crypto"       # selects the volatility profile

        # --- DQN hyper-parameters ---
        self.learning_rate = 0.0001
        self.gamma = 0.99                # reward discount factor
        self.epsilon_start = 1.0         # initial exploration rate
        self.epsilon_min = 0.01          # exploration floor
        self.epsilon_decay = 0.9995      # multiplicative decay per update
        self.batch_size = 32
        self.memory_size = 10000         # replay-buffer capacity
        self.target_update = 100         # sync target net every N updates
        self.hidden_size = 128           # width of hidden layers

        # --- scaling applied to volatility and trade size per risk level ---
        self.risk_multipliers = {"Low": 0.5, "Medium": 1.0, "High": 2.0}
43
# ==== 2. Trading Environment ====
class AdvancedTradingEnvironment:
    """
    Simulates a financial market with synthetic data, multi-asset support,
    and technical/sentiment indicators.

    Gym-style API:
        reset() -> (observation, info)
        step(action) -> (observation, reward, terminated, truncated, info)
    Actions: 0 = Hold, 1 = Buy, 2 = Sell, 3 = Close position.
    """

    def __init__(self, config: "TradingConfig"):
        self.config = config
        self.initial_balance = config.initial_balance
        self.balance = self.initial_balance
        self.position = 0.0            # units of the asset currently held
        self.current_price = 100.0
        self.step_count = 0
        self.max_steps = config.max_steps
        self.transaction_cost = config.transaction_cost

        # Synthetic market time series (appended to on every step).
        self.price_history = []
        self.volume_history = []
        self.sentiment_history = []

        # Risk multiplier scales both market volatility and trade size.
        self.risk_multiplier = config.risk_multipliers[config.risk_level]
        self._initialize_market_data()

        self.action_space = 4          # Hold, Buy, Sell, Close
        self.observation_space = (15,)

        # For plotting
        self.portfolio_history = []
        self.action_history = []

    def _initialize_market_data(self):
        """Seed 200 synthetic (price, volume, sentiment) points so indicators
        have history before the first episode."""
        n_points = 200
        volatility_map = {"Crypto": 0.03, "Stock": 0.015, "Forex": 0.008}
        volatility = volatility_map.get(self.config.asset_type, 0.02)
        base_price = 100.0
        for i in range(n_points):
            # Slow sinusoidal momentum plus Gaussian noise, floored at 10.
            momentum = np.sin(i * 0.05) * 2
            noise = np.random.normal(0, volatility)
            price = base_price * (1 + momentum * 0.01 + noise)
            price = max(10.0, price)
            self.price_history.append(price)
            # Volume loosely tracks how far price sits from its base.
            volume = 1000 + abs(price - base_price) * 50 + np.random.normal(0, 200)
            self.volume_history.append(max(100, volume))
            # Sentiment is a random walk clipped to [0, 1].
            if i > 0:
                prev_sentiment = self.sentiment_history[-1]
                sentiment_change = np.random.normal(0, 0.08)
                sentiment = prev_sentiment + sentiment_change
            else:
                sentiment = 0.5 + np.random.normal(0, 0.1)
            self.sentiment_history.append(np.clip(sentiment, 0.0, 1.0))
        self.current_price = self.price_history[-1]

    def _calculate_technical_indicators(self) -> List[float]:
        """Return six normalized indicators over recent history:
        [sma_short, sma_long, rsi, volatility, momentum, volume_trend]."""
        prices = np.array(self.price_history[-50:])
        if len(prices) < 2:
            return [0.0] * 6
        returns = np.diff(prices) / prices[:-1]
        sma_short = np.mean(prices[-10:]) if len(prices) >= 10 else prices[-1]
        sma_long = np.mean(prices[-20:]) if len(prices) >= 20 else prices[-1]
        # RSI over a 14-period window; neutral 50 when data is too short.
        if len(returns) >= 14:
            gains = returns[returns > 0]
            losses = -returns[returns < 0]
            avg_gain = np.mean(gains[-14:]) if len(gains) > 0 else 0.001
            avg_loss = np.mean(losses[-14:]) if len(losses) > 0 else 0.001
            rsi = 100 - (100 / (1 + avg_gain / avg_loss))
        else:
            rsi = 50.0
        # Annualized volatility (252 trading days).
        volatility = np.std(returns) * np.sqrt(252) if len(returns) > 1 else 0.1
        momentum = (prices[-1] / prices[-5] - 1) if len(prices) >= 5 else 0.0
        volumes = np.array(self.volume_history[-10:])
        volume_trend = np.mean(volumes[-5:]) / np.mean(volumes[-10:]) - 1 if len(volumes) >= 10 else 0.0
        return [sma_short / 100, sma_long / 100, rsi / 100, volatility, momentum, volume_trend]

    def reset(self) -> Tuple[np.ndarray, Dict]:
        """Start a fresh episode with a new short random history."""
        self.balance = self.initial_balance
        self.position = 0.0
        self.step_count = 0
        self.portfolio_history = []
        self.action_history = []
        self.price_history = [100.0 + np.random.normal(0, 5)]
        self.volume_history = [1000 + np.random.normal(0, 200)]
        self.sentiment_history = [0.5 + np.random.normal(0, 0.1)]
        self.current_price = self.price_history[-1]
        obs = self._get_observation()
        info = self._get_info()
        return obs, info

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """Advance one tick: update the market, execute `action`, and return
        (obs, reward, terminated, truncated, info)."""
        self.step_count += 1
        self._update_market_data()
        reward = self._execute_action(action)
        # Episode ends on bankruptcy or when the step budget is exhausted.
        terminated = self.balance <= 0 or self.step_count >= self.max_steps
        truncated = False
        obs = self._get_observation()
        info = self._get_info()
        self.portfolio_history.append(info['net_worth'])
        self.action_history.append(action)
        return obs, reward, terminated, truncated, info

    def _update_market_data(self):
        """Generate the next synthetic (price, volume, sentiment) tick."""
        # Momentum from the last few returns.
        # BUG FIX: the original evaluated `if prev_returns` on a NumPy array,
        # which raises "truth value of an array with more than one element is
        # ambiguous" once price_history reaches 6 entries, so every call after
        # that crashed (and was silently swallowed by the UI's try/except).
        # Use an explicit length check instead.
        if len(self.price_history) >= 6:
            prev_returns = np.diff(self.price_history[-5:]) / np.array(self.price_history[-5:-1])
            momentum = np.mean(prev_returns) if len(prev_returns) > 0 else 0
        else:
            momentum = 0
        volatility_map = {"Crypto": 0.025, "Stock": 0.012, "Forex": 0.006}
        base_volatility = volatility_map.get(self.config.asset_type, 0.015)
        volatility = base_volatility * self.risk_multiplier
        price_change = momentum * 0.3 + np.random.normal(0, volatility)
        self.current_price = max(10.0, self.current_price * (1 + price_change))
        self.price_history.append(self.current_price)
        # Volume spikes with large absolute price moves.
        base_volume = 1000
        volume_noise = np.random.normal(0, 200)
        new_volume = max(100, base_volume + abs(price_change) * 5000 + volume_noise)
        self.volume_history.append(new_volume)
        # Mean-reverting sentiment random walk clipped to [0, 1].
        current_sentiment = self.sentiment_history[-1]
        sentiment_reversion = (0.5 - current_sentiment) * 0.1
        sentiment_noise = np.random.normal(0, 0.08)
        new_sentiment = current_sentiment + sentiment_reversion + sentiment_noise
        self.sentiment_history.append(np.clip(new_sentiment, 0.0, 1.0))

    def _execute_action(self, action: int) -> float:
        """Apply the trade and return the reward: the step's change in net
        worth (as % of initial balance) minus a drawdown penalty below 80%."""
        prev_net_worth = self.balance + self.position * self.current_price
        # Trade 20% of the relevant side, scaled by the risk level.
        trade_size_multiplier = 0.2 * self.risk_multiplier
        if action == 1:  # Buy
            if self.balance > 0:
                trade_amount = min(self.balance * trade_size_multiplier, self.balance)
                cost = trade_amount * (1 + self.transaction_cost)
                if cost <= self.balance:
                    shares_bought = trade_amount / self.current_price
                    self.position += shares_bought
                    self.balance -= cost
        elif action == 2:  # Sell part of the position
            if self.position > 0:
                sell_fraction = trade_size_multiplier
                shares_to_sell = min(self.position * sell_fraction, self.position)
                proceeds = shares_to_sell * self.current_price * (1 - self.transaction_cost)
                self.position -= shares_to_sell
                self.balance += proceeds
        elif action == 3:  # Close the whole position
            if self.position > 0:
                proceeds = self.position * self.current_price * (1 - self.transaction_cost)
                self.balance += proceeds
                self.position = 0
        new_net_worth = self.balance + self.position * self.current_price
        raw_reward = (new_net_worth - prev_net_worth) / self.initial_balance * 100
        # Penalize drawdowns beyond 20% of the starting balance.
        risk_penalty = 0.0
        if new_net_worth < self.initial_balance * 0.8:
            risk_penalty = (self.initial_balance - new_net_worth) / self.initial_balance * 10
        final_reward = raw_reward - risk_penalty
        return final_reward

    def _get_observation(self) -> np.ndarray:
        """Assemble the 15-dim state vector: 4 price features, 3 portfolio
        features, 3 sentiment features and 6 technical indicators."""
        recent_prices = self.price_history[-20:] if len(self.price_history) >= 20 else [self.current_price] * 20
        price_low = np.min(recent_prices)
        price_high = np.max(recent_prices)
        # BUG FIX: the original divided by (max - min) whenever the window had
        # more than one entry; right after reset() the window is padded with
        # identical prices, so 0/0 put NaN into the observation. Guard on a
        # non-zero range and fall back to the neutral midpoint 0.5.
        if price_high > price_low:
            range_position = (self.current_price - price_low) / (price_high - price_low)
        else:
            range_position = 0.5
        price_features = [
            self.current_price / 100.0,
            np.mean(recent_prices) / 100.0,
            np.std(recent_prices) / 100.0,
            range_position
        ]
        portfolio_features = [
            self.balance / self.initial_balance,
            self.position * self.current_price / self.initial_balance,
            self.step_count / self.max_steps
        ]
        recent_sentiments = self.sentiment_history[-10:] if len(self.sentiment_history) >= 10 else [0.5] * 10
        sentiment_features = [
            np.mean(recent_sentiments),
            np.std(recent_sentiments),
            recent_sentiments[-1]
        ]
        technical_features = self._calculate_technical_indicators()
        all_features = price_features + portfolio_features + sentiment_features + technical_features
        # Truncate defensively to exactly 15 features.
        observation = np.array(all_features[:15], dtype=np.float32)
        return observation

    def _get_info(self) -> Dict[str, Any]:
        """Current portfolio snapshot for logging/plotting."""
        net_worth = self.balance + self.position * self.current_price
        return_total = (net_worth - self.initial_balance) / self.initial_balance * 100
        return {
            'net_worth': net_worth,
            'return_percent': return_total,
            'position_value': self.position * self.current_price,
            'cash_balance': self.balance,
            'current_price': self.current_price,
            'steps': self.step_count
        }
231
# ==== 3. DQN Agent ====
class DQNAgent:
    """
    Deep Q-Network agent for trading.

    Keeps an online network and a periodically-synced target network, an
    epsilon-greedy behavior policy and a uniform replay buffer.
    """

    def __init__(self, state_dim: int, action_dim: int, config: "TradingConfig", device: str = 'cpu'):
        self.device = torch.device(device)
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.config = config
        # Online and target networks start from identical weights.
        self.q_network = self._build_network(state_dim, action_dim)
        self.target_network = self._build_network(state_dim, action_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=config.learning_rate)
        self.criterion = nn.MSELoss()
        # Replay buffer with bounded capacity.
        self.memory = deque(maxlen=config.memory_size)
        # Exploration schedule.
        self.epsilon = config.epsilon_start
        self.epsilon_min = config.epsilon_min
        self.epsilon_decay = config.epsilon_decay
        self.batch_size = config.batch_size
        self.gamma = config.gamma
        self.target_update = config.target_update
        self.steps = 0  # number of gradient updates performed

    def _build_network(self, state_dim: int, action_dim: int) -> nn.Sequential:
        """MLP with three ReLU hidden layers of width (h, h, h//2)."""
        h = self.config.hidden_size
        layers = [
            nn.Linear(state_dim, h), nn.ReLU(),
            nn.Linear(h, h), nn.ReLU(),
            nn.Linear(h, h // 2), nn.ReLU(),
            nn.Linear(h // 2, action_dim),
        ]
        return nn.Sequential(*layers).to(self.device)

    def select_action(self, state: np.ndarray, training: bool = True) -> int:
        """Epsilon-greedy action; purely greedy when training=False."""
        if training and random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        batched_state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            q_values = self.q_network(batched_state)
        return q_values.argmax(1).item()

    def store_transition(self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool):
        """Append one (s, a, r, s', done) tuple to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def update(self) -> float:
        """One gradient step on a random minibatch.

        Returns the minibatch loss, or 0.0 while the buffer is still smaller
        than batch_size."""
        if len(self.memory) < self.batch_size:
            return 0.0
        minibatch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*minibatch)
        states = torch.FloatTensor(np.array(states)).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
        dones = torch.BoolTensor(dones).to(self.device)

        # Q(s, a) for the actions actually taken.
        chosen_q = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        # Bootstrapped targets from the frozen target network; terminal
        # transitions contribute the reward only.
        with torch.no_grad():
            best_next_q = self.target_network(next_states).max(1)[0]
        targets = rewards + self.gamma * best_next_q * (~dones).float()

        loss = self.criterion(chosen_q, targets)
        self.optimizer.zero_grad()
        loss.backward()
        # Clip gradients to stabilize training.
        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 1.0)
        self.optimizer.step()

        self.steps += 1
        if self.steps % self.target_update == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        return loss.item()

    def save(self, path: str):
        """Persist networks, optimizer state and exploration progress."""
        torch.save({
            'q_network_state_dict': self.q_network.state_dict(),
            'target_network_state_dict': self.target_network.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epsilon': self.epsilon,
            'steps': self.steps
        }, path)

    def load(self, path: str):
        """Restore a checkpoint written by save()."""
        checkpoint = torch.load(path, map_location=self.device)
        self.q_network.load_state_dict(checkpoint['q_network_state_dict'])
        self.target_network.load_state_dict(checkpoint['target_network_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.epsilon = checkpoint['epsilon']
        self.steps = checkpoint['steps']
319
# ==== 4. Main Application ====
class TradingDemo:
    """
    Glue object between the UI, the trading environment and the DQN agent:
    owns the configuration, runs the training/simulation loops and renders
    Plotly figures for both.
    """

    def __init__(self):
        self.config = TradingConfig()
        self.env = None     # created by initialize()
        self.agent = None   # created by initialize()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Per-episode metrics accumulated during train().
        self.training_history = {
            'episode_rewards': [],
            'episode_losses': [],
            'epsilon_history': []
        }

    def initialize(self, balance: float, risk: str, asset: str) -> str:
        """(Re)build the environment and agent from the UI settings and reset
        the training history; returns a status string for the UI."""
        try:
            self.config.initial_balance = float(balance)
            self.config.risk_level = risk
            self.config.asset_type = asset
            self.env = AdvancedTradingEnvironment(self.config)
            # 15-dim observation, 4 discrete actions (Hold/Buy/Sell/Close).
            self.agent = DQNAgent(15, 4, self.config, self.device)
            self.training_history = {
                'episode_rewards': [],
                'episode_losses': [],
                'epsilon_history': []
            }
            return f"✅ System initialized! Balance: ${balance}, Risk: {risk}, Asset: {asset}"
        except Exception as e:
            return f"❌ Initialization failed: {str(e)}"

    def train(self, episodes: int):
        """Generator that trains for `episodes` episodes, yielding
        (status_text, figure_or_None) pairs so the UI can stream progress."""
        if self.env is None or self.agent is None:
            yield "❌ Please initialize the system first!", None
            return
        try:
            episodes = int(episodes)
            for ep in range(episodes):
                state, _ = self.env.reset()
                ep_reward = 0
                loss_sum = 0
                n_updates = 0
                finished = False
                while not finished:
                    chosen = self.agent.select_action(state)
                    nxt, reward, finished, _, info = self.env.step(chosen)
                    self.agent.store_transition(state, chosen, reward, nxt, finished)
                    step_loss = self.agent.update()
                    if step_loss > 0:
                        loss_sum += step_loss
                        n_updates += 1
                    ep_reward += reward
                    state = nxt
                mean_loss = loss_sum / max(n_updates, 1)
                self.training_history['episode_rewards'].append(ep_reward)
                self.training_history['episode_losses'].append(mean_loss)
                self.training_history['epsilon_history'].append(self.agent.epsilon)
                progress = (f"Episode {ep+1}/{episodes} | "
                            f"Reward: {ep_reward:.2f} | "
                            f"Loss: {mean_loss:.4f} | "
                            f"Epsilon: {self.agent.epsilon:.3f} | "
                            f"Net Worth: ${info['net_worth']:.2f}")
                # Redraw the (fairly expensive) figure only every 10 episodes.
                if (ep + 1) % 10 == 0 or ep == episodes - 1:
                    yield progress, self._create_training_plot()
                else:
                    yield progress, None
            yield "✅ Training completed successfully!", self._create_training_plot()
        except Exception as e:
            yield f"❌ Training error: {str(e)}", None

    def simulate(self, steps: int):
        """Run one greedy (no-exploration) episode of up to `steps` steps and
        return (status_text, figure)."""
        if self.env is None or self.agent is None:
            return "❌ Please initialize and train the system first!", None
        try:
            steps = int(steps)
            state, _ = self.env.reset()
            prices, actions, net_worths = [], [], []
            portfolio_values, cash_balances = [], []
            for _ in range(steps):
                chosen = self.agent.select_action(state, training=False)
                nxt, reward, finished, _, info = self.env.step(chosen)
                prices.append(self.env.current_price)
                actions.append(chosen)
                net_worths.append(info['net_worth'])
                portfolio_values.append(info['position_value'])
                cash_balances.append(info['cash_balance'])
                state = nxt
                if finished:
                    break
            fig = self._create_simulation_plot(prices, actions, net_worths, portfolio_values, cash_balances)
            final_return = (net_worths[-1] - self.config.initial_balance) / self.config.initial_balance * 100
            result_text = (f"✅ Simulation completed! Final Return: {final_return:.2f}% | "
                           f"Final Net Worth: ${net_worths[-1]:.2f}")
            return result_text, fig
        except Exception as e:
            return f"❌ Simulation error: {str(e)}", None

    def _create_training_plot(self):
        """Build the 2x2 training-progress figure, or None if no data yet."""
        rewards = self.training_history['episode_rewards']
        if not rewards:
            return None
        xs = list(range(1, len(rewards) + 1))
        fig = make_subplots(rows=2, cols=2,
                            subplot_titles=('Episode Rewards', 'Training Loss',
                                            'Epsilon Decay', 'Moving Average Reward'),
                            vertical_spacing=0.12)
        fig.add_trace(go.Scatter(x=xs, y=rewards, mode='lines',
                                 name='Reward', line=dict(color='blue')),
                      row=1, col=1)
        fig.add_trace(go.Scatter(x=xs, y=self.training_history['episode_losses'],
                                 mode='lines', name='Loss', line=dict(color='red')),
                      row=1, col=2)
        fig.add_trace(go.Scatter(x=xs, y=self.training_history['epsilon_history'],
                                 mode='lines', name='Epsilon', line=dict(color='green')),
                      row=2, col=1)
        # Trailing moving average of the rewards (window capped at 20).
        window = min(20, len(xs))
        ma = [np.mean(rewards[max(0, i - window):i + 1]) for i in range(len(xs))]
        fig.add_trace(go.Scatter(x=xs, y=ma, mode='lines',
                                 name='MA Reward', line=dict(color='orange', width=2)),
                      row=2, col=2)
        fig.update_layout(height=600, showlegend=True, title_text="Training Progress")
        return fig

    def _create_simulation_plot(self, prices, actions, net_worths, portfolio_values, cash_balances):
        """Build the 2x2 simulation-results figure."""
        fig = make_subplots(rows=2, cols=2,
                            subplot_titles=('Price & Actions', 'Portfolio Performance',
                                            'Portfolio Composition', 'Action Distribution'),
                            vertical_spacing=0.12,
                            horizontal_spacing=0.1)
        xs = list(range(len(prices)))
        # Price line plus one marker trace per action type actually taken.
        fig.add_trace(go.Scatter(x=xs, y=prices, mode='lines', name='Price',
                                 line=dict(color='blue')),
                      row=1, col=1)
        action_colors = ['gray', 'green', 'red', 'orange']
        action_names = ['Hold', 'Buy', 'Sell', 'Close']
        for act in range(4):
            idxs = [i for i, a in enumerate(actions) if a == act]
            if idxs:
                fig.add_trace(go.Scatter(x=idxs, y=[prices[i] for i in idxs],
                                         mode='markers', name=action_names[act],
                                         marker=dict(color=action_colors[act], size=8)),
                              row=1, col=1)
        # Net worth and percentage return.
        initial_balance = self.config.initial_balance
        returns = [(nw - initial_balance) / initial_balance * 100 for nw in net_worths]
        fig.add_trace(go.Scatter(x=xs, y=net_worths, mode='lines', name='Net Worth',
                                 line=dict(color='purple')),
                      row=1, col=2)
        # NOTE(review): make_subplots was not created with secondary_y specs,
        # so this yaxis='y2' assignment and the secondary_y axis-title updates
        # below likely have no visible effect — confirm against plotly docs.
        fig.add_trace(go.Scatter(x=xs, y=returns, mode='lines', name='Return %',
                                 line=dict(color='orange'), yaxis='y2'),
                      row=1, col=2)
        fig.add_trace(go.Scatter(x=xs, y=portfolio_values, mode='lines',
                                 name='Portfolio Value', line=dict(color='green')),
                      row=2, col=1)
        fig.add_trace(go.Scatter(x=xs, y=cash_balances, mode='lines',
                                 name='Cash Balance', line=dict(color='blue')),
                      row=2, col=1)
        # Histogram of how often each action was taken.
        action_counts = [actions.count(i) for i in range(4)]
        fig.add_trace(go.Bar(x=action_names, y=action_counts,
                             marker_color=action_colors, name='Action Count'),
                      row=2, col=2)
        fig.update_layout(height=700, showlegend=True, title_text="Trading Simulation Results")
        fig.update_yaxes(title_text="Return (%)", row=1, col=2, secondary_y=True)
        fig.update_yaxes(title_text="Value ($)", row=1, col=2, secondary_y=False)
        return fig
507
+ # ==== 5. Gradio Interface ====
508
  def create_interface():
509
  demo = TradingDemo()
510
  with gr.Blocks(theme=gr.themes.Soft(), title="AI Trading Demo") as interface:
511
  gr.Markdown("""
512
  # 🤖 Advanced AI Trading Demo
513
  **Deep Reinforcement Learning for Financial Markets**
514
+
515
  This demo shows a DQN agent learning to trade in simulated financial markets.
516
  The agent learns optimal trading strategies through reinforcement learning.
517
  """)
 
523
  risk = gr.Radio(["Low", "Medium", "High"], value="Medium", label="Risk Level")
524
  asset = gr.Radio(["Crypto", "Stock", "Forex"], value="Crypto", label="Asset Type")
525
  init_btn = gr.Button("🚀 Initialize System", variant="primary")
 
526
  with gr.Column(scale=2):
527
  gr.Markdown("## 📊 System Status")
528
  status = gr.Textbox(label="Status", value="Click 'Initialize System' to start", interactive=False)
 
533
  episodes = gr.Number(value=100, label="Training Episodes", precision=0)
534
  train_btn = gr.Button("🎯 Start Training", variant="primary")
535
  train_plot = gr.Plot(label="Training Progress")
 
536
  with gr.Column():
537
  gr.Markdown("## 📈 Simulation")
538
  steps = gr.Number(value=200, label="Simulation Steps", precision=0)
539
  sim_btn = gr.Button("▶️ Run Simulation", variant="primary")
540
  sim_plot = gr.Plot(label="Simulation Results")
541
 
 
542
  init_btn.click(
543
  demo.initialize,
544
  inputs=[balance, risk, asset],
545
  outputs=status
546
  )
 
547
  train_btn.click(
548
  demo.train,
549
  inputs=episodes,
550
  outputs=[status, train_plot]
551
  )
 
552
  sim_btn.click(
553
  demo.simulate,
554
  inputs=steps,
 
561
  2. **Initialize**: Click 'Initialize System' to set up the trading environment
562
  3. **Train**: Start training the AI agent (recommended: 100+ episodes)
563
  4. **Simulate**: Run a trading simulation to see the trained agent in action
564
+
565
  ## 🎮 Actions:
566
  - **0: Hold** - Maintain current position
567
  - **1: Buy** - Purchase asset (20% of balance)
 
570
  """)
571
  return interface
572
 
 
573
  demo = create_interface()