OmidSakaki committed on
Commit
ed6cc8e
·
verified ·
1 Parent(s): 0f09ca5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -413
app.py CHANGED
@@ -1,38 +1,30 @@
1
  import gradio as gr
2
  import numpy as np
3
  import torch
4
- import random
5
- from collections import deque
6
- from typing import Tuple, Iterator, Any
7
- import plotly.graph_objects as go
8
- from plotly.subplots import make_subplots
9
- import logging
10
- from datetime import datetime
11
-
12
- # Setup logging
13
- logging.basicConfig(level=logging.INFO)
14
- logger = logging.getLogger(__name__)
15
 
16
  class TradingConfig:
17
- """Configuration for trading environment"""
18
- def __init__(self, initial_balance: float = 10000.0):
19
- self.initial_balance = initial_balance
20
  self.max_steps = 1000
21
  self.transaction_cost = 0.001
 
 
22
  self.learning_rate = 0.0001
23
  self.gamma = 0.99
24
  self.epsilon_start = 1.0
25
  self.epsilon_min = 0.01
26
- self.epsilon_decay = 0.999
27
  self.batch_size = 32
28
- self.memory_size = 5000 # Reduced for Spaces memory limits
29
- self.target_update = 50
30
 
31
  class AdvancedTradingEnvironment:
32
- """Simplified trading environment compatible with HF Spaces"""
33
-
34
- def __init__(self, config: TradingConfig):
35
- self.config = config
36
  self.initial_balance = config.initial_balance
37
  self.balance = self.initial_balance
38
  self.position = 0.0
@@ -41,218 +33,147 @@ class AdvancedTradingEnvironment:
41
  self.max_steps = config.max_steps
42
  self.price_history = []
43
  self.sentiment_history = []
44
- self.action_history = []
45
-
46
- # Initialize data
47
  self._initialize_data()
48
-
49
- # Spaces: Discrete actions (0=Hold, 1=Buy, 2=Sell, 3=Close)
50
- self.action_space = gr.utils.Discrete(4)
51
- self.observation_space = None # For compatibility
52
-
53
  def _initialize_data(self):
54
- """Initialize price and sentiment history"""
55
- n_points = 50 # Reduced for faster init
56
  base_price = 100.0
57
-
58
  for i in range(n_points):
59
- # Realistic price simulation
60
- trend = np.sin(i * 0.2) * 5
61
- noise = np.random.normal(0, 3)
62
- price = max(50.0, base_price + trend + noise)
63
- self.price_history.append(price)
64
-
65
- # Correlated sentiment
66
- sentiment = 0.5 + 0.3 * np.tanh((price - base_price) / base_price) + np.random.normal(0, 0.1)
67
  self.sentiment_history.append(np.clip(sentiment, 0.0, 1.0))
68
-
69
  self.current_price = self.price_history[-1]
70
-
71
- def reset(self) -> Tuple[np.ndarray, dict]:
72
- """Reset environment to initial state"""
73
  self.balance = self.initial_balance
74
  self.position = 0.0
75
  self.step_count = 0
76
- self.action_history = []
77
-
78
- # Reset price series
79
- self._initialize_data()
80
-
81
  obs = self._get_observation()
82
  info = self._get_info()
83
-
84
- logger.info(f"Environment reset: balance=${self.balance}, price=${self.current_price}")
85
  return obs, info
86
-
87
- def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, dict]:
88
- """Execute one step"""
89
  self.step_count += 1
90
- self.action_history.append(action)
 
 
91
 
92
- # Market evolution
93
- self._market_step()
 
94
 
95
- # Execute action
96
  reward = self._execute_action(action)
97
 
98
- # Check termination
99
- terminated = (self.balance <= 0 or self.step_count >= self.max_steps)
100
  truncated = False
101
 
102
  obs = self._get_observation()
103
  info = self._get_info()
104
 
105
  return obs, reward, terminated, truncated, info
106
-
107
- def _market_step(self):
108
- """Simulate market movement"""
109
- # Price evolution (geometric brownian motion approximation)
110
- drift = 0.0001
111
- volatility = 0.02
112
- price_change = drift + volatility * np.random.normal()
113
- self.current_price = max(10.0, self.current_price * np.exp(price_change))
114
-
115
- self.price_history.append(self.current_price)
116
- if len(self.price_history) > 100:
117
- self.price_history.pop(0)
118
-
119
- # Update sentiment
120
- sentiment_change = 0.1 * price_change + np.random.normal(0, 0.05)
121
- new_sentiment = np.clip(self.sentiment_history[-1] + sentiment_change, 0.0, 1.0)
122
- self.sentiment_history.append(new_sentiment)
123
- if len(self.sentiment_history) > 100:
124
- self.sentiment_history.pop(0)
125
-
126
- def _execute_action(self, action: int) -> float:
127
- """Execute trading action and return reward"""
128
  reward = 0.0
129
  prev_net_worth = self.balance + self.position * self.current_price
130
 
131
- # Transaction cost
132
- cost_factor = self.config.transaction_cost
133
-
134
  if action == 1: # Buy
135
- if self.balance > 100: # Minimum trade size
136
- trade_value = min(self.balance * 0.1, self.balance - 100)
137
- cost = trade_value * (1 + cost_factor)
138
- if cost <= self.balance:
139
- shares = trade_value / self.current_price
140
- self.position += shares
141
- self.balance -= cost
142
- reward -= cost_factor * trade_value
143
 
144
  elif action == 2: # Sell
145
  if self.position > 0:
146
- sell_shares = min(self.position * 0.1, self.position)
147
- proceeds = sell_shares * self.current_price * (1 - cost_factor)
148
- self.position -= sell_shares
149
  self.balance += proceeds
150
- reward -= cost_factor * sell_shares * self.current_price
151
 
152
- elif action == 3: # Close position
153
  if self.position > 0:
154
- proceeds = self.position * self.current_price * (1 - cost_factor)
155
  self.balance += proceeds
156
  self.position = 0
157
- reward -= cost_factor * abs(self.position) * self.current_price * self.current_price
158
 
159
- # Portfolio reward
160
- current_net_worth = self.balance + self.position * self.current_price
161
- pnl = (current_net_worth - prev_net_worth) / self.initial_balance
162
- reward += pnl * 100 # Scale reward
163
 
164
  return reward
165
-
166
- def _get_observation(self) -> np.ndarray:
167
- """Get current state observation"""
168
- # Use recent market data
169
- recent_prices = self.price_history[-10:] if len(self.price_history) >= 10 else self.price_history
170
- recent_sentiments = self.sentiment_history[-10:] if len(self.sentiment_history) >= 10 else self.sentiment_history
171
 
172
- # Features
173
  features = [
174
- self.balance / self.initial_balance, # Normalized balance
175
- (self.position * self.current_price) / self.initial_balance, # Position value
176
- self.current_price / 200.0, # Normalized price
177
- np.mean(recent_prices) / 200.0, # Short-term average
178
- np.std(recent_prices) / 200.0 if len(recent_prices) > 1 else 0, # Volatility
179
- np.mean(recent_sentiments), # Average sentiment
180
- np.std(recent_sentiments) if len(recent_sentiments) > 1 else 0, # Sentiment volatility
181
- self.step_count / self.max_steps, # Progress
182
- len([a for a in self.action_history[-5:] if a == 1]) / 5.0, # Recent buy ratio
183
- len([a for a in self.action_history[-5:] if a == 2]) / 5.0, # Recent sell ratio
184
- np.mean(np.diff(recent_prices)) if len(recent_prices) > 1 else 0, # Price momentum
185
- 0.0 # Padding
186
  ]
187
 
188
  return np.array(features[:12], dtype=np.float32)
189
-
190
- def _get_info(self) -> dict:
191
- """Get environment info"""
192
  net_worth = self.balance + self.position * self.current_price
193
- return {
194
- 'net_worth': net_worth,
195
- 'balance': self.balance,
196
- 'position': self.position,
197
- 'current_price': self.current_price,
198
- 'step': self.step_count,
199
- 'pnl': (net_worth - self.initial_balance) / self.initial_balance * 100
200
- }
201
 
202
  class DQNAgent:
203
- """Deep Q-Network agent"""
204
-
205
- def __init__(self, state_dim: int, action_dim: int, config: TradingConfig, device: str = 'cpu'):
206
- self.state_dim = state_dim
207
- self.action_dim = action_dim
208
  self.device = torch.device(device)
209
- self.config = config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
- # Neural networks
212
- self.q_network = self._build_network().to(self.device)
213
- self.target_network = self._build_network().to(self.device)
214
  self.target_network.load_state_dict(self.q_network.state_dict())
215
 
216
  self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=config.learning_rate)
217
  self.memory = deque(maxlen=config.memory_size)
 
218
  self.epsilon = config.epsilon_start
 
 
 
 
219
  self.steps = 0
220
-
221
- def _build_network(self) -> torch.nn.Module:
222
- return torch.nn.Sequential(
223
- torch.nn.Linear(self.state_dim, 128),
224
- torch.nn.ReLU(),
225
- torch.nn.Linear(128, 128),
226
- torch.nn.ReLU(),
227
- torch.nn.Linear(128, self.action_dim)
228
- )
229
-
230
- def select_action(self, state: np.ndarray, training: bool = True) -> int:
231
- """Epsilon-greedy action selection"""
232
  if training and random.random() < self.epsilon:
233
- return random.randint(0, self.action_dim - 1)
234
-
235
- state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
236
  with torch.no_grad():
237
- q_values = self.q_network(state_tensor)
238
- action = q_values.argmax().item()
239
-
240
- if training:
241
- self.epsilon = max(self.config.epsilon_min, self.epsilon * self.config.epsilon_decay)
242
-
243
- return action
244
-
245
  def store_transition(self, state, action, reward, next_state, done):
246
- """Store experience"""
247
- self.memory.append((state, action, float(reward), next_state, done))
248
-
249
- def update(self) -> float:
250
- """Train the network"""
251
- if len(self.memory) < self.config.batch_size:
252
  return 0.0
253
 
254
- # Sample batch
255
- batch = random.sample(self.memory, self.config.batch_size)
256
  states, actions, rewards, next_states, dones = zip(*batch)
257
 
258
  states = torch.FloatTensor(np.array(states)).to(self.device)
@@ -261,275 +182,98 @@ class DQNAgent:
261
  next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
262
  dones = torch.FloatTensor(dones).to(self.device)
263
 
264
- # Q-learning
265
  current_q = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
266
  next_q = self.target_network(next_states).max(1)[0]
267
- target_q = rewards + self.config.gamma * next_q * (1 - dones)
268
 
269
  loss = torch.nn.MSELoss()(current_q, target_q)
270
 
271
  self.optimizer.zero_grad()
272
  loss.backward()
273
- torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 1.0)
274
  self.optimizer.step()
275
 
276
  self.steps += 1
277
-
278
- # Update target network
279
- if self.steps % self.config.target_update == 0:
280
  self.target_network.load_state_dict(self.q_network.state_dict())
281
 
 
 
282
  return loss.item()
283
 
284
  class TradingDemo:
285
- """Main demo class with proper state management"""
286
-
287
  def __init__(self):
288
- self.config = None
289
  self.env = None
290
  self.agent = None
291
  self.device = 'cpu'
292
- self.is_initialized = False
293
- self.training_history = []
294
-
295
- def initialize(self, balance: float) -> str:
296
- """Initialize environment and agent"""
297
- try:
298
- self.config = TradingConfig(initial_balance=float(balance))
299
- self.env = AdvancedTradingEnvironment(self.config)
300
- state_dim = 12 # Fixed observation size
301
- self.agent = DQNAgent(state_dim, 4, self.config, self.device)
302
- self.is_initialized = True
303
- self.training_history = []
304
-
305
- logger.info(f"Initialized: balance=${balance}")
306
- return f"✅ Environment initialized! Balance: ${balance:,.2f}"
307
-
308
- except Exception as e:
309
- logger.error(f"Initialization failed: {e}")
310
- self.is_initialized = False
311
- return f"❌ Initialization failed: {str(e)}"
312
-
313
- def train(self, episodes: int) -> Iterator[Tuple[str, go.Figure]]:
314
- """Train the agent with progress updates"""
315
- if not self.is_initialized or self.env is None or self.agent is None:
316
- yield "❌ Please initialize first!", None
317
- return
318
-
319
- try:
320
- episodes = int(episodes)
321
- episode_rewards = []
322
-
323
- for ep in range(episodes):
324
- obs, _ = self.env.reset()
325
- total_reward = 0
326
- done = False
327
-
328
- while not done:
329
- action = self.agent.select_action(obs, training=True)
330
- next_obs, reward, done, _, info = self.env.step(action)
331
-
332
- self.agent.store_transition(obs, action, reward, next_obs, done)
333
- obs = next_obs
334
- total_reward += reward
335
-
336
- # Update agent
337
- loss = self.agent.update()
338
- episode_rewards.append(total_reward)
339
- self.training_history.append({
340
- 'episode': ep + 1,
341
- 'reward': total_reward,
342
- 'loss': loss
343
- })
344
-
345
- # Progress update every 10 episodes
346
- if (ep + 1) % 10 == 0 or ep == episodes - 1:
347
- avg_reward = np.mean(episode_rewards[-10:])
348
- progress = (ep + 1) / episodes * 100
349
-
350
- status = (f"📈 Training Progress: {ep+1}/{episodes} ({progress:.1f}%)\n"
351
- f"🎯 Episode Reward: {total_reward:.2f}\n"
352
- f"📊 Avg Reward (last 10): {avg_reward:.2f}\n"
353
- f"📉 Loss: {loss:.4f}")
354
-
355
- # Create progress chart
356
- chart = self._create_training_chart()
357
- yield status, chart
358
-
359
- final_status = f"✅ Training completed! Final Avg Reward: {np.mean(episode_rewards):.2f}"
360
- yield final_status, self._create_training_chart()
361
-
362
- except Exception as e:
363
- logger.error(f"Training error: {e}")
364
- yield f"❌ Training failed: {str(e)}", None
365
-
366
- def simulate(self, steps: int) -> Tuple[str, go.Figure, go.Figure]:
367
- """Run trading simulation"""
368
- if not self.is_initialized or self.env is None or self.agent is None:
369
- return "❌ Please initialize and train first!", None, None
370
-
371
- try:
372
- steps = int(steps)
373
  obs, _ = self.env.reset()
374
-
375
- prices = []
376
- actions = []
377
- net_worths = []
378
- rewards = []
379
-
380
- for step in range(steps):
381
- action = self.agent.select_action(obs, training=False)
382
  next_obs, reward, done, _, info = self.env.step(action)
383
-
384
- prices.append(self.env.current_price)
385
- actions.append(action)
386
- net_worths.append(info['net_worth'])
387
- rewards.append(reward)
388
-
389
  obs = next_obs
390
- if done:
391
- break
392
-
393
- # Create charts
394
- price_chart = self._create_price_chart(prices, actions)
395
- performance_chart = self._create_performance_chart(net_worths, rewards)
396
-
397
- final_pnl = (net_worths[-1] - self.config.initial_balance) / self.config.initial_balance * 100
398
- status = f"✅ Simulation completed! Steps: {len(prices)}, Final P&L: {final_pnl:+.2f}%"
399
-
400
- return status, price_chart, performance_chart
401
-
402
- except Exception as e:
403
- logger.error(f"Simulation error: {e}")
404
- return f"❌ Simulation failed: {str(e)}", None, None
405
-
406
- def _create_training_chart(self) -> go.Figure:
407
- """Create training progress chart"""
408
- if not self.training_history:
409
- fig = go.Figure()
410
- fig.update_layout(title="Training in progress...", height=400)
411
- return fig
412
-
413
- episodes = [h['episode'] for h in self.training_history]
414
- rewards = [h['reward'] for h in self.training_history]
415
- losses = [h['loss'] for h in self.training_history]
416
-
417
- fig = make_subplots(rows=2, cols=1, subplot_titles=["Rewards", "Loss"])
418
-
419
- fig.add_trace(go.Scatter(x=episodes, y=rewards, mode='lines+markers', name='Reward'), row=1, col=1)
420
- fig.add_trace(go.Scatter(x=episodes, y=losses, mode='lines', name='Loss'), row=2, col=1)
421
-
422
- fig.update_layout(height=500, title="Training Progress", showlegend=True)
423
- return fig
424
-
425
- def _create_price_chart(self, prices: list, actions: list) -> go.Figure:
426
- """Create price action chart"""
427
  fig = go.Figure()
428
-
429
- # Price line
430
- fig.add_trace(go.Scatter(x=list(range(len(prices))), y=prices, mode='lines',
431
- name='Price', line=dict(color='blue', width=2)))
432
-
433
- # Action markers
434
- buy_indices = [i for i, a in enumerate(actions) if a == 1]
435
- sell_indices = [i for i, a in enumerate(actions) if a == 2]
436
-
437
- if buy_indices:
438
- buy_prices = [prices[i] for i in buy_indices]
439
- fig.add_trace(go.Scatter(x=buy_indices, y=buy_prices, mode='markers',
440
- name='Buy', marker=dict(color='green', size=10, symbol='triangle-up')))
441
-
442
- if sell_indices:
443
- sell_prices = [prices[i] for i in sell_indices]
444
- fig.add_trace(go.Scatter(x=sell_indices, y=sell_prices, mode='markers',
445
- name='Sell', marker=dict(color='red', size=10, symbol='triangle-down')))
446
-
447
- fig.update_layout(title="Price Action Simulation", height=400, showlegend=True)
448
- return fig
449
-
450
- def _create_performance_chart(self, net_worths: list, rewards: list) -> go.Figure:
451
- """Create performance dashboard"""
452
- fig = make_subplots(rows=2, cols=1, subplot_titles=["Net Worth", "Rewards"])
453
-
454
- fig.add_trace(go.Scatter(x=list(range(len(net_worths))), y=net_worths,
455
- mode='lines', name='Net Worth'), row=1, col=1)
456
- fig.add_trace(go.Bar(x=list(range(len(rewards))), y=rewards,
457
- name='Rewards'), row=2, col=1)
458
-
459
- fig.update_layout(title="Performance", height=500, showlegend=True)
460
- return fig
461
 
462
- # Initialize demo
463
  demo = TradingDemo()
464
 
465
- # Create Gradio interface
466
- with gr.Blocks(theme=gr.themes.Soft(), title="🤖 AI Trading Demo") as interface:
467
- gr.Markdown("# 🚀 AI Trading Assistant\nReinforcement Learning Trading Demo")
468
 
469
- # Configuration
470
  with gr.Row():
471
- with gr.Column(scale=1):
472
- gr.Markdown("## ⚙️ Configuration")
473
- balance = gr.Slider(1000, 50000, value=10000, step=1000, label="Initial Balance ($)")
474
- init_btn = gr.Button("🚀 Initialize Environment", variant="primary")
475
- status = gr.Textbox(label="Status", interactive=False, lines=2)
476
-
477
- # Training
478
- with gr.Row():
479
- with gr.Column(scale=1):
480
- gr.Markdown("## 🎓 Training")
481
- episodes = gr.Slider(10, 100, value=30, step=5, label="Training Episodes")
482
- train_btn = gr.Button("🤖 Start Training", variant="primary")
483
-
484
- with gr.Column(scale=2):
485
- train_status = gr.Textbox(label="Training Progress", interactive=False, lines=3)
486
- train_chart = gr.Plot(label="Training Metrics")
487
-
488
- # Simulation
489
- with gr.Row():
490
- with gr.Column(scale=1):
491
- gr.Markdown("## 🎯 Simulation")
492
- sim_steps = gr.Slider(20, 200, value=50, step=10, label="Simulation Steps")
493
- sim_btn = gr.Button("▶️ Run Simulation", variant="secondary")
494
-
495
- with gr.Column(scale=2):
496
- sim_status = gr.Textbox(label="Simulation Status", interactive=False)
497
- price_chart = gr.Plot(label="Price Chart")
498
-
499
- performance_chart = gr.Plot(label="Performance Dashboard")
500
-
501
- # Event handlers with proper error handling
502
- init_btn.click(
503
- demo.initialize,
504
- inputs=[balance],
505
- outputs=[status]
506
- )
507
 
508
- train_btn.click(
509
- demo.train,
510
- inputs=[episodes],
511
- outputs=[train_status, train_chart]
512
- )
513
 
514
- sim_btn.click(
515
- demo.simulate,
516
- inputs=[sim_steps],
517
- outputs=[sim_status, price_chart, performance_chart]
518
- )
519
 
520
- gr.Markdown("""
521
- ## 📝 Usage Instructions:
522
- 1. **Initialize**: Set balance and click "Initialize Environment"
523
- 2. **Train**: Adjust episodes and click "Start Training"
524
- 3. **Simulate**: Set steps and click "Run Simulation"
525
 
526
- **Note**: Training must complete before simulation!
527
- """)
 
528
 
529
- if __name__ == "__main__":
530
- interface.launch(
531
- server_name="0.0.0.0",
532
- server_port=7860,
533
- share=False,
534
- debug=True
535
- )
 
1
import random
from collections import deque
from pathlib import Path
from typing import Dict, Tuple, Any

import gradio as gr
import numpy as np
import torch
import yaml
from gymnasium import spaces
from loguru import logger
 
 
 
 
 
 
9
 
10
class TradingConfig:
    """Default settings for the trading environment and the DQN agent."""

    # attribute name -> default value; applied verbatim in __init__
    _DEFAULTS = {
        'initial_balance': 10000.0,
        'max_steps': 1000,
        'transaction_cost': 0.001,
        'risk_level': "Medium",
        'asset_type': "Crypto",
        'learning_rate': 0.0001,
        'gamma': 0.99,
        'epsilon_start': 1.0,
        'epsilon_min': 0.01,
        'epsilon_decay': 0.9995,
        'batch_size': 32,
        'memory_size': 10000,
        'target_update': 100,
    }

    def __init__(self):
        # Materialize every default as a plain instance attribute so callers
        # can read and overwrite them exactly as before.
        for name, value in self._DEFAULTS.items():
            setattr(self, name, value)
25
 
26
class AdvancedTradingEnvironment:
    """Minimal gymnasium-style trading environment over a synthetic market.

    Actions: 0 = Hold, 1 = Buy with 20% of cash, 2 = Sell 20% of the
    holding, 3 = Close the whole position.
    Observations: 12-dimensional float32 feature vector.
    """

    def __init__(self, config):
        self.initial_balance = config.initial_balance
        self.balance = self.initial_balance
        self.position = 0.0          # number of units currently held
        self.step_count = 0          # fix: define before reset() is ever called
        self.max_steps = config.max_steps
        self.price_history = []
        self.sentiment_history = []
        self._initialize_data()
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Box(low=-2.0, high=2.0, shape=(12,), dtype=np.float32)

    def _initialize_data(self):
        """Seed price/sentiment history with a noisy sine-wave price path."""
        n_points = 100
        base_price = 100.0
        for i in range(n_points):
            price = base_price + np.sin(i * 0.1) * 10 + np.random.normal(0, 2)
            self.price_history.append(max(10.0, price))  # floor keeps prices positive
            sentiment = 0.5 + np.random.normal(0, 0.1)
            self.sentiment_history.append(np.clip(sentiment, 0.0, 1.0))
        self.current_price = self.price_history[-1]

    def reset(self):
        """Reset account and market state; return (observation, info)."""
        self.balance = self.initial_balance
        self.position = 0.0
        self.step_count = 0
        self.price_history = [100.0 + np.random.normal(0, 5)]
        self.sentiment_history = [0.5]
        # Bug fix: keep current_price in sync with the freshly reset history;
        # previously it retained the stale value from the previous episode.
        self.current_price = self.price_history[-1]
        return self._get_observation(), self._get_info()

    def step(self, action):
        """Advance the market one tick, apply `action`, and return the
        gymnasium 5-tuple (obs, reward, terminated, truncated, info)."""
        self.step_count += 1
        # Random-walk price update (~2% per-step volatility), floored at 10.0.
        price_change = np.random.normal(0, 0.02)
        self.current_price = max(10.0, self.current_price * (1 + price_change))
        self.price_history.append(self.current_price)
        # Sentiment drifts independently, clipped to [0, 1].
        sentiment_change = np.random.normal(0, 0.05)
        self.sentiment_history.append(
            np.clip(self.sentiment_history[-1] + sentiment_change, 0.0, 1.0))
        reward = self._execute_action(action)
        terminated = self.balance <= 0 or self.step_count >= self.max_steps
        truncated = False
        return self._get_observation(), reward, terminated, truncated, self._get_info()

    def _execute_action(self, action):
        """Apply a trade and return the net-worth change as a scaled reward.

        NOTE(review): config.transaction_cost is declared but not applied
        here, so trades are frictionless — confirm whether that is intended.
        """
        prev_net_worth = self.balance + self.position * self.current_price
        if action == 1:  # Buy with 20% of available cash
            trade_amount = min(self.balance * 0.2, self.balance)
            if trade_amount <= self.balance:
                self.position += trade_amount / self.current_price
                self.balance -= trade_amount
        elif action == 2:  # Sell 20% of the current holding
            if self.position > 0:
                sell_amount = min(self.position * 0.2, self.position)
                self.position -= sell_amount
                self.balance += sell_amount * self.current_price
        elif action == 3:  # Close the entire position
            if self.position > 0:
                self.balance += self.position * self.current_price
                self.position = 0
        net_worth = self.balance + self.position * self.current_price
        # Percentage change in net worth relative to starting capital.
        return (net_worth - prev_net_worth) / self.initial_balance * 100

    def _get_observation(self):
        """Build the 12-dim float32 feature vector describing current state."""
        recent_prices = (self.price_history[-10:] if len(self.price_history) >= 10
                         else [self.current_price] * 10)
        recent_sentiments = (self.sentiment_history[-10:] if len(self.sentiment_history) >= 10
                             else [0.5] * 10)
        features = [
            self.balance / self.initial_balance,                         # normalized cash
            self.position * self.current_price / self.initial_balance,   # position value
            self.current_price / 100.0,                                  # normalized price
            np.mean(recent_prices) / 100.0,                              # short-term mean
            np.std(recent_prices) / 100.0,                               # short-term volatility
            np.mean(recent_sentiments),                                  # sentiment mean
            np.std(recent_sentiments),                                   # sentiment volatility
            self.step_count / self.max_steps,                            # episode progress
            0.0, 0.0, 0.0, 0.0,                                          # padding to 12 features
        ]
        return np.array(features[:12], dtype=np.float32)

    def _get_info(self):
        """Auxiliary info dict; net worth = cash + market value of position."""
        return {'net_worth': self.balance + self.position * self.current_price}
 
 
 
 
 
 
 
130
 
131
class DQNAgent:
    """DQN agent with a target network, replay memory and epsilon-greedy policy."""

    def __init__(self, state_dim, action_dim, config, device='cpu'):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = torch.device(device)
        # Online and target networks share one architecture; the target starts
        # as an exact copy and is refreshed every `target_update` steps.
        self.q_network = self._build_network().to(self.device)
        self.target_network = self._build_network().to(self.device)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=config.learning_rate)
        self.memory = deque(maxlen=config.memory_size)
        self.gamma = config.gamma
        self.epsilon = config.epsilon_start
        self.epsilon_min = config.epsilon_min
        self.epsilon_decay = config.epsilon_decay
        self.batch_size = config.batch_size
        self.target_update = config.target_update
        self.steps = 0

    def _build_network(self):
        """Two-hidden-layer MLP mapping state features to per-action Q-values."""
        return torch.nn.Sequential(
            torch.nn.Linear(self.state_dim, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, self.action_dim),
        )

    def select_action(self, state, training=True):
        """Epsilon-greedy action selection; purely greedy when training=False."""
        # Explore first so random actions skip the tensor conversion entirely.
        if training and random.random() < self.epsilon:
            # Bug fix: respect action_dim instead of a hard-coded 4 actions.
            return random.randint(0, self.action_dim - 1)
        state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            return self.q_network(state_t).argmax(1).item()

    def store_transition(self, state, action, reward, next_state, done):
        """Append one (s, a, r, s', done) transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def update(self):
        """Run one TD-learning step on a random minibatch; return the loss.

        Returns 0.0 without training while the buffer holds fewer than
        `batch_size` transitions.
        """
        if len(self.memory) < self.batch_size:
            return 0.0

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.FloatTensor(np.array(states)).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)

        current_q = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        next_q = self.target_network(next_states).max(1)[0]
        # Bellman target; (1 - dones) zeroes the bootstrap on terminal steps.
        target_q = rewards + self.gamma * next_q * (1 - dones)

        loss = torch.nn.MSELoss()(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.steps += 1
        if self.steps % self.target_update == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

        # Decay exploration rate once per successful training step.
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        return loss.item()
202
 
203
class TradingDemo:
    """Glue object connecting the Gradio UI callbacks to the env and agent."""

    def __init__(self):
        self.config = TradingConfig()
        self.env = None
        self.agent = None
        self.device = 'cpu'

    def initialize(self, balance, risk, asset):
        """Create a fresh environment/agent pair from the UI settings."""
        self.config.initial_balance = balance
        self.config.risk_level = risk
        self.config.asset_type = asset
        self.env = AdvancedTradingEnvironment(self.config)
        self.agent = DQNAgent(12, 4, self.config, self.device)
        return "✅ Initialized!"

    def train(self, episodes):
        """Run training episodes, yielding (status, plot) progress updates."""
        # Robustness fix: guard against "Train" clicked before "Initialize"
        # (previously raised AttributeError on None.reset()).
        if self.env is None or self.agent is None:
            yield "❌ Please initialize first!", None
            return
        # Bug fix: gr.Number delivers a float; range() requires an int.
        n_episodes = int(episodes)
        for ep in range(n_episodes):
            obs, _ = self.env.reset()
            total_reward = 0
            done = False
            while not done:
                action = self.agent.select_action(obs)
                next_obs, reward, done, _, info = self.env.step(action)
                self.agent.store_transition(obs, action, reward, next_obs, done)
                obs = next_obs
                total_reward += reward
                self.agent.update()
            yield f"Episode {ep+1}/{n_episodes} | Reward: {total_reward:.2f}", None
        yield "✅ Training complete!", None

    def simulate(self, steps):
        """Run one greedy-policy episode and plot price vs. net worth."""
        if self.env is None or self.agent is None:
            return "❌ Please initialize first!", None
        obs, _ = self.env.reset()
        prices = []
        actions = []
        net_worths = []
        # Bug fix: steps arrives as a float from gr.Number.
        for _ in range(int(steps)):
            action = self.agent.select_action(obs, training=False)
            next_obs, reward, done, _, info = self.env.step(action)
            prices.append(self.env.current_price)
            actions.append(action)
            net_worths.append(info['net_worth'])
            obs = next_obs
            if done:
                break

        # Imported locally so the module still loads if plotly is missing.
        import plotly.graph_objects as go
        fig = go.Figure()
        fig.add_trace(go.Scatter(y=prices, mode='lines', name='Price'))
        fig.add_trace(go.Scatter(y=net_worths, mode='lines', name='Net Worth'))
        return "✅ Simulation complete!", fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
 
254
# Build the Gradio UI: a single shared demo object drives all three callbacks.
demo = TradingDemo()

with gr.Blocks() as interface:
    gr.Markdown("# Trading AI Demo")

    # Account configuration row plus the initialize trigger.
    with gr.Row():
        balance_slider = gr.Slider(1000, 50000, 10000, label="Balance")
        risk_radio = gr.Radio(["Low", "Medium", "High"], value="Medium", label="Risk")
        asset_radio = gr.Radio(["Crypto", "Stock", "Forex"], value="Crypto", label="Asset")
        init_button = gr.Button("Initialize")

    status_box = gr.Textbox(label="Status")

    # Training controls and output plot.
    episode_count = gr.Number(value=50, label="Episodes")
    train_button = gr.Button("Train")
    training_plot = gr.Plot()

    # Simulation controls and output plot.
    step_count = gr.Number(value=100, label="Simulation Steps")
    sim_button = gr.Button("Simulate")
    simulation_plot = gr.Plot()

    # Event wiring: each button maps straight onto one TradingDemo method.
    init_button.click(demo.initialize, [balance_slider, risk_radio, asset_radio], status_box)
    train_button.click(demo.train, episode_count, [status_box, training_plot])
    sim_button.click(demo.simulate, step_count, [status_box, simulation_plot])

interface.launch()