OmidSakaki commited on
Commit
dd46062
·
verified ·
1 Parent(s): cb87d3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +246 -625
app.py CHANGED
@@ -1,658 +1,279 @@
1
  import gradio as gr
2
  import numpy as np
3
- import pandas as pd
4
  import torch
5
- import time
6
- import sys
7
- import os
8
- import threading
9
- import logging
10
- from datetime import datetime, timedelta
11
- from typing import Dict, Any, Optional, Tuple
12
- import warnings
13
- warnings.filterwarnings('ignore')
14
 
15
- # Configure logging
16
- logging.basicConfig(level=logging.INFO)
17
- logger = logging.getLogger(__name__)
18
-
19
- # Create directories safely
20
- def setup_directories():
21
- """Setup project directories with error handling"""
22
- directories = ['src', 'src/environments', 'src/agents', 'src/sentiment', 'src/visualizers', 'src/utils']
23
- for dir_path in directories:
24
- try:
25
- os.makedirs(dir_path, exist_ok=True)
26
- init_file = os.path.join(dir_path, '__init__.py')
27
- if not os.path.exists(init_file):
28
- with open(init_file, 'w') as f:
29
- f.write('# Auto-generated init file\n')
30
- except Exception as e:
31
- logger.warning(f"Could not create directory {dir_path}: {e}")
32
 
33
- setup_directories()
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- # Add src to path safely
36
- if 'src' not in sys.path:
37
- sys.path.insert(0, 'src')
 
 
 
 
 
 
38
 
39
- # Safe imports with fallbacks
40
- try:
41
- from src.environments.advanced_trading_env import AdvancedTradingEnvironment
42
- from src.agents.advanced_agent import AdvancedTradingAgent
43
- from src.utils.config import TradingConfig
44
- from src.visualizers.chart_renderer import ChartRenderer
45
- CUSTOM_MODULES_AVAILABLE = True
46
- except ImportError as e:
47
- logger.warning(f"Custom modules not available: {e}. Using fallback mode.")
48
- CUSTOM_MODULES_AVAILABLE = False
49
- # Fallback imports will be defined below
50
 
51
- class SafeTradingDemo:
52
- """Safe trading demo with comprehensive error handling"""
53
-
54
- def __init__(self):
55
- self.env = None
56
- self.agent = None
57
- self.config = TradingConfig() if CUSTOM_MODULES_AVAILABLE else None
58
- self.renderer = ChartRenderer() if CUSTOM_MODULES_AVAILABLE else None
59
- self.current_state = None
60
- self.is_training = False
61
- self.training_complete = False
62
- self.live_trading = False
63
- self.trading_thread = None
64
- self.lock = threading.Lock()
65
- self.live_data: list = []
66
- self.performance_data: list = []
67
- self.action_history: list = []
68
- self.training_history: list = []
69
- self.initialized = False
70
- self.start_time = None
71
- self.last_update = None
72
 
73
- # Fallback environment and agent if custom modules unavailable
74
- if not CUSTOM_MODULES_AVAILABLE:
75
- self._setup_fallback_components()
76
-
77
- def _setup_fallback_components(self):
78
- """Setup basic fallback components"""
79
- class FallbackEnvironment:
80
- def __init__(self, initial_balance, risk_level, asset_type):
81
- self.initial_balance = initial_balance
82
- self.current_balance = initial_balance
83
- self.position = 0
84
- self.current_price = 100.0
85
-
86
- def reset(self):
87
- self.current_balance = self.initial_balance
88
- self.position = 0
89
- self.current_price = 100.0 + np.random.normal(0, 5)
90
- return np.random.rand(84, 84, 4).astype(np.float32)
91
-
92
- def step(self, action):
93
- self.current_price += np.random.normal(0, 1)
94
- reward = np.random.normal(0, 10)
95
- self.current_balance += reward * 0.1
96
- done = False
97
- info = {'net_worth': self.current_balance}
98
- next_state = np.random.rand(84, 84, 4).astype(np.float32)
99
- return next_state, reward, done, info
100
 
101
- class FallbackAgent:
102
- def __init__(self, state_dim, action_dim):
103
- self.epsilon = 1.0
104
- self.action_dim = action_dim
105
-
106
- def select_action(self, state):
107
- if np.random.random() < self.epsilon:
108
- return np.random.randint(0, self.action_dim)
109
- return 0
110
-
111
- def store_transition(self, *args):
112
- pass
113
-
114
- def update(self):
115
- self.epsilon = max(0.01, self.epsilon * 0.999)
116
- return np.random.random()
117
 
118
- self.FallbackEnvironment = FallbackEnvironment
119
- self.FallbackAgent = FallbackAgent
120
-
121
- def initialize_environment(self, initial_balance: float, risk_level: str,
122
- asset_type: str) -> str:
123
- """Initialize trading environment with comprehensive validation"""
124
- try:
125
- with self.lock:
126
- if self.live_trading:
127
- return "⚠️ لطفاً ابتدا معاملات را متوقف کنید"
128
-
129
- # Validate inputs
130
- if initial_balance < 1000:
131
- return "❌ سرمایه اولیه باید حداقل 1000 دلار باشد"
132
- if risk_level not in ["Low", "Medium", "High"]:
133
- return "❌ سطح ریسک نامعتبر"
134
- if asset_type not in ["Crypto", "Stock", "Forex"]:
135
- return "❌ نوع دارایی نامعتبر"
136
-
137
- logger.info(f"Initializing environment: balance={initial_balance}, "
138
- f"risk={risk_level}, asset={asset_type}")
139
-
140
- if CUSTOM_MODULES_AVAILABLE:
141
- self.env = AdvancedTradingEnvironment(
142
- initial_balance=float(initial_balance),
143
- risk_level=risk_level,
144
- asset_type=asset_type,
145
- use_sentiment=False # Disable for demo stability
146
- )
147
- self.agent = AdvancedTradingAgent(
148
- state_dim=(84, 84, 4),
149
- action_dim=4,
150
- learning_rate=self.config.learning_rate
151
- )
152
- else:
153
- self.env = self.FallbackEnvironment(initial_balance, risk_level, asset_type)
154
- self.agent = self.FallbackAgent((84, 84, 4), 4)
155
-
156
- self.current_state = self.env.reset()
157
- self._reset_data()
158
- self.initialized = True
159
- self.start_time = datetime.now()
160
-
161
- return (f"✅ محیط معاملاتی با موفقیت راه‌اندازی شد!\n\n"
162
- f"💰 سرمایه: ${initial_balance:,.2f}\n"
163
- f"🎯 نوع دارایی: {asset_type}\n"
164
- f"⚡ سطح ریسک: {risk_level}\n\n"
165
- f"🚀 آماده برای آموزش...")
166
-
167
- except Exception as e:
168
- logger.error(f"Environment initialization error: {e}", exc_info=True)
169
- return f"❌ خطا در راه‌اندازی: {str(e)}"
170
-
171
- def _reset_data(self):
172
- """Reset all data structures"""
173
- self.live_data.clear()
174
- self.performance_data.clear()
175
- self.action_history.clear()
176
- self.training_history.clear()
177
- self.training_complete = False
178
- self.live_trading = False
179
-
180
- def train_agent(self, num_episodes: int):
181
- """Train agent with progress updates and safety checks"""
182
- if not self.initialized:
183
- yield "❌ ابتدا محیط را راه‌اندازی کنید", None
184
- return
185
-
186
- if self.live_trading:
187
- yield "⚠️ ابت��ا معاملات را متوقف کنید", None
188
- return
189
 
190
- try:
191
- num_episodes = max(1, min(100, int(num_episodes))) # Limit episodes
192
- self.is_training = True
193
-
194
- for episode in range(num_episodes):
195
- if not self.is_training:
196
- break
197
-
198
- episode_start = time.time()
199
- state = self.env.reset()
200
- episode_reward = 0.0
201
- done = False
202
- step_count = 0
203
- max_steps = 200 # Safety limit
204
-
205
- while not done and step_count < max_steps:
206
- action = self.agent.select_action(state)
207
- next_state, reward, done, info = self.env.step(action)
208
-
209
- try:
210
- self.agent.store_transition(state, action, reward, next_state, done)
211
- except:
212
- pass # Ignore storage errors in demo
213
-
214
- state = next_state
215
- episode_reward += reward
216
- step_count += 1
217
-
218
- # Update agent
219
- try:
220
- loss = self.agent.update()
221
- except:
222
- loss = 0.0
223
-
224
- # Store episode data
225
- self.training_history.append({
226
- 'episode': episode,
227
- 'reward': episode_reward,
228
- 'net_worth': info.get('net_worth', 10000),
229
- 'loss': loss,
230
- 'steps': step_count,
231
- 'duration': time.time() - episode_start
232
- })
233
-
234
- # Create progress visualization
235
- try:
236
- progress_fig = self._create_training_chart()
237
- except:
238
- progress_fig = None
239
-
240
- # Progress status
241
- progress = (episode + 1) / num_episodes * 100
242
- status = (f"🔄 آموزش در حال انجام...\n"
243
- f"📊 اپیزود {episode+1}/{num_episodes} ({progress:.1f}%)\n"
244
- f"🎯 پاداش: {episode_reward:.2f}\n"
245
- f"💰 پرتفولیو: ${info.get('net_worth', 0):.2f}\n"
246
- f"📉 Loss: {loss:.4f}")
247
-
248
- yield status, progress_fig
249
- time.sleep(0.05) # Brief pause for UI responsiveness
250
-
251
- self.training_complete = True
252
- final_stats = self._calculate_training_stats()
253
- yield final_stats, self._create_training_chart()
254
-
255
- except Exception as e:
256
- logger.error(f"Training error: {e}", exc_info=True)
257
- self.is_training = False
258
- yield f"❌ خطا در آموزش: {str(e)}", None
259
- finally:
260
- self.is_training = False
261
-
262
- def _calculate_training_stats(self) -> str:
263
- """Calculate and format training statistics"""
264
- if not self.training_history:
265
- return "آمار آموزش در دسترس نیست"
266
 
267
- rewards = [h['reward'] for h in self.training_history]
268
- net_worths = [h['net_worth'] for h in self.training_history]
 
 
 
269
 
270
- return (f"✅ آموزش تکمیل شد!\n\n"
271
- f"📊 آمار نهایی:\n"
272
- f"• اپیزودها: {len(rewards)}\n"
273
- f"• میانگین پاداش: {np.mean(rewards):.2f}\n"
274
- f"• پاداش نهایی: {rewards[-1]:.2f}\n"
275
- f"• ارزش نهایی: ${net_worths[-1]:.2f}\n"
276
- f"🚀 آماده معامله Real-Time!")
277
-
278
- def _create_training_chart(self):
279
- """Create training progress chart"""
280
- try:
281
- if not self.training_history:
282
- return None
283
-
284
- import plotly.graph_objects as go
285
- from plotly.subplots import make_subplots
286
-
287
- episodes = [h['episode'] for h in self.training_history]
288
- rewards = [h['reward'] for h in self.training_history]
289
- net_worths = [h['net_worth'] for h in self.training_history]
290
-
291
- fig = make_subplots(rows=2, cols=1, subplot_titles=['پاداش اپیزود', 'ارزش پرتفولیو'])
292
-
293
- fig.add_trace(go.Scatter(x=episodes, y=rewards, mode='lines+markers',
294
- name='پاداش', line=dict(color='blue')), row=1, col=1)
295
- fig.add_trace(go.Scatter(x=episodes, y=net_worths, mode='lines+markers',
296
- name='پرتفولیو', line=dict(color='green')), row=2, col=1)
297
-
298
- fig.update_layout(height=400, title="📈 پیشرفت آموزش", template="plotly_white")
299
- return fig
300
-
301
- except:
302
- return None
303
-
304
- def start_live_trading(self) -> Tuple[str, Any, Any, Any]:
305
- """Start live trading with safety checks"""
306
- try:
307
- with self.lock:
308
- if not self.training_complete and CUSTOM_MODULES_AVAILABLE:
309
- return "⚠️ لطفاً ابتدا آموزش را کامل کنید", None, None, None
310
- if self.live_trading:
311
- return "⚠️ معاملات در حال اجراست", None, None, None
312
-
313
- self.live_trading = True
314
- self._reset_data()
315
- self._initialize_demo_data()
316
-
317
- # Start trading thread
318
- self.trading_thread = threading.Thread(target=self._trading_loop, daemon=True)
319
- self.trading_thread.start()
320
-
321
- time.sleep(0.5) # Allow thread to initialize
322
- return self._get_live_status()
323
-
324
- except Exception as e:
325
- logger.error(f"Live trading start error: {e}")
326
- return f"❌ خطا در شروع معاملات: {str(e)}", None, None, None
327
-
328
- def _trading_loop(self):
329
- """Safe trading loop with error handling"""
330
- max_steps = 500
331
- step = 0
332
 
333
- while self.live_trading and step < max_steps:
334
- try:
335
- with self.lock:
336
- if not self.initialized or self.env is None:
337
- break
338
-
339
- # Get action
340
- action = self.agent.select_action(self.current_state)
341
-
342
- # Execute step
343
- next_state, reward, done, info = self.env.step(action)
344
- self.current_state = next_state
345
-
346
- # Generate demo data
347
- self._generate_demo_step(action, reward, info)
348
-
349
- step += 1
350
- time.sleep(1) # 1 second intervals
351
-
352
- except Exception as e:
353
- logger.error(f"Trading loop error: {e}")
354
- time.sleep(2)
355
- continue
356
 
357
- self.live_trading = False
358
-
359
- def _generate_demo_step(self, action: int, reward: float, info: Dict):
360
- """Generate realistic demo data"""
361
- current_time = datetime.now()
362
- last_price = self.live_data[-1]['price'] if self.live_data else 100.0
363
 
364
- # Simulate price movement
365
- base_change = np.random.normal(0, 0.5)
366
- action_bias = {0: 0, 1: 0.3, 2: -0.3, 3: 0}[action]
367
- new_price = max(50, last_price + base_change + action_bias)
368
 
369
- # Update net worth
370
- net_worth = info.get('net_worth', self.env.initial_balance + reward * 10)
 
 
 
371
 
372
- self.live_data.append({
373
- 'timestamp': current_time,
374
- 'price': new_price,
375
- 'action': action,
376
- 'net_worth': net_worth,
377
- 'reward': reward,
378
- 'volume': np.random.randint(1000, 10000)
379
- })
 
 
 
380
 
381
- # Keep recent data only
382
- if len(self.live_data) > 100:
383
- self.live_data.pop(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
384
 
385
- self.action_history.append({
386
- 'step': len(self.action_history),
387
- 'action': action,
388
- 'reward': reward,
389
- 'price': new_price,
390
- 'timestamp': current_time
391
- })
392
-
393
- def _initialize_demo_data(self):
394
- """Initialize demo data"""
395
- base_price = 100.0
396
- for i in range(10):
397
- self.live_data.append({
398
- 'timestamp': datetime.now() - timedelta(seconds=10-i),
399
- 'price': base_price + np.random.normal(0, 2),
400
- 'action': 0,
401
- 'net_worth': self.env.initial_balance if self.env else 10000,
402
- 'reward': 0,
403
- 'volume': np.random.randint(1000, 5000)
404
- })
405
-
406
- def _get_live_status(self) -> Tuple[str, Any, Any, pd.DataFrame]:
407
- """Get current live trading status"""
408
- try:
409
- if not self.live_data:
410
- return "📊 در حال آماده‌سازی...", None, None, self._create_empty_stats()
411
-
412
- current = self.live_data[-1]
413
- initial = self.env.initial_balance if self.env else 10000
414
-
415
- profit = current['net_worth'] - initial
416
- profit_pct = (profit / initial) * 100
417
-
418
- action_names = ["نگهداری", "خرید", "فروش", "بستن"]
419
- status = (f"🎯 معاملات Real-Time فعال\n"
420
- f"💰 قیمت: ${current['price']:.2f}\n"
421
- f"🎪 اقدام: {action_names[current['action']]}\n"
422
- f"💼 پرتفولیو: ${current['net_worth']:.2f}\n"
423
- f"📈 P&L: ${profit:+.2f} ({profit_pct:+.2f}%)")
424
-
425
- live_fig = self._create_live_chart()
426
- perf_fig = self._create_performance_chart()
427
- stats_df = self._create_stats_table()
428
-
429
- return status, live_fig, perf_fig, stats_df
430
-
431
- except Exception as e:
432
- logger.error(f"Status update error: {e}")
433
- return "❌ خطا در به‌روزرسانی", None, None, self._create_empty_stats()
434
-
435
- def get_live_update(self) -> Tuple[str, Any, Any, pd.DataFrame]:
436
- """Manual live update trigger"""
437
- return self._get_live_status()
438
-
439
- def stop_live_trading(self) -> Tuple[str, Any, Any, pd.DataFrame]:
440
- """Stop live trading safely"""
441
- try:
442
- with self.lock:
443
- self.live_trading = False
444
- if self.trading_thread and self.trading_thread.is_alive():
445
- self.trading_thread.join(timeout=2.0)
446
-
447
- if self.live_data:
448
- final = self.live_data[-1]
449
- initial = self.env.initial_balance if self.env else 10000
450
- profit = final['net_worth'] - initial
451
- profit_pct = (profit / initial) * 100
452
-
453
- actions = [h['action'] for h in self.action_history]
454
- action_counts = {i: actions.count(i) for i in range(4)}
455
-
456
- status = (f"🛑 معاملات متوقف شد\n\n"
457
- f"📊 نتایج نهایی:\n"
458
- f"• سرمایه نهایی: ${final['net_worth']:.2f}\n"
459
- f"• سود/زیان: ${profit:+.2f} ({profit_pct:+.2f}%)\n"
460
- f"• کل اقدامات: {len(actions)}\n"
461
- f"• خرید: {action_counts[1]} | فروش: {action_counts[2]}")
462
- else:
463
- status = "معاملات متوقف شد - داده‌ای ثبت نشده"
464
-
465
- return status, self._create_live_chart(), self._create_performance_chart(), self._create_stats_table()
466
-
467
- except Exception as e:
468
- logger.error(f"Stop trading error: {e}")
469
- return f"❌ خطا در توقف: {str(e)}", None, None, self._create_empty_stats()
470
-
471
- def _create_live_chart(self):
472
- """Create live price chart"""
473
- try:
474
- if not self.live_data:
475
- import plotly.graph_objects as go
476
- fig = go.Figure()
477
- fig.update_layout(title="در حال آماده‌سازی...", height=400)
478
- return fig
479
-
480
- import plotly.graph_objects as go
481
- from plotly.subplots import make_subplots
482
-
483
- data = self.live_data[-50:] # Last 50 points
484
- times = [d['timestamp'] for d in data]
485
- prices = [d['price'] for d in data]
486
- volumes = [d['volume'] for d in data]
487
-
488
- fig = make_subplots(rows=2, cols=1, row_heights=[0.7, 0.3],
489
- subplot_titles=['قیمت', 'حجم'])
490
-
491
- fig.add_trace(go.Scatter(x=times, y=prices, mode='lines', name='قیمت',
492
- line=dict(color='cyan', width=2)), row=1, col=1)
493
-
494
- # Action markers
495
- for action, color, name in [(1, 'green', 'خرید'), (2, 'red', 'فروش')]:
496
- action_times = [d['timestamp'] for d in data if d['action'] == action]
497
- action_prices = [d['price'] for d in data if d['action'] == action]
498
- if action_times:
499
- fig.add_trace(go.Scatter(x=action_times, y=action_prices, mode='markers',
500
- marker=dict(color=color, size=10),
501
- name=name), row=1, col=1)
502
-
503
- fig.add_trace(go.Bar(x=times, y=volumes, name='حجم', marker_color='blue',
504
- opacity=0.6), row=2, col=1)
505
-
506
- fig.update_layout(height=450, template="plotly_dark", showlegend=True)
507
- return fig
508
-
509
- except:
510
- return None
511
-
512
- def _create_performance_chart(self):
513
- """Create performance chart"""
514
- try:
515
- if not self.live_data:
516
- import plotly.graph_objects as go
517
- fig = go.Figure()
518
- fig.update_layout(title="در حال آماده‌سازی...", height=300)
519
- return fig
520
-
521
- import plotly.graph_objects as go
522
- times = [d['timestamp'] for d in self.live_data]
523
- net_worths = [d['net_worth'] for d in self.live_data]
524
-
525
- fig = go.Figure()
526
- fig.add_trace(go.Scatter(x=times, y=net_worths, mode='lines', name='پرتفولیو',
527
- line=dict(color='green', width=3)))
528
-
529
- initial = self.env.initial_balance if self.env else 10000
530
- fig.add_hline(y=initial, line_dash="dash", line_color="red",
531
- annotation_text=f"سرمایه اولیه: ${initial:.2f}")
532
-
533
- fig.update_layout(height=350, title="عملکرد پرتفولیو", template="plotly_dark")
534
- return fig
535
-
536
- except:
537
- return None
538
-
539
- def _create_stats_table(self) -> pd.DataFrame:
540
- """Create statistics table"""
541
- try:
542
- if not self.live_data:
543
- return self._create_empty_stats()
544
-
545
- current = self.live_data[-1]
546
- initial = self.env.initial_balance if self.env else 10000
547
- profit = current['net_worth'] - initial
548
- profit_pct = (profit / initial) * 100
549
-
550
- stats = {
551
- 'متریک': ['💰 قیمت فعلی', '💼 پرتفولیو', '📈 P&L', '🎯 اقدام اخیر', '⏰ گام‌ها'],
552
- 'مقدار': [
553
- f"${current['price']:.2f}",
554
- f"${current['net_worth']:.2f}",
555
- f"${profit:+.2f} ({profit_pct:+.2f}%)",
556
- {0: 'نگهداری', 1: 'خرید', 2: 'فروش', 3: 'بستن'}[current['action']],
557
- str(len(self.action_history))
558
- ]
559
- }
560
- return pd.DataFrame(stats)
561
-
562
- except:
563
- return self._create_empty_stats()
564
-
565
- def _create_empty_stats(self) -> pd.DataFrame:
566
- """Create empty stats table"""
567
- return pd.DataFrame({
568
- 'متریک': ['وضعیت'],
569
- 'مقدار': ['در حال آماده‌سازی...']
570
- })
571
 
572
- def create_interface():
573
- """Create Gradio interface with proper error handling"""
574
- demo = SafeTradingDemo()
575
-
576
- with gr.Blocks(theme=gr.themes.Soft(), title="🤖 AI Trading Demo") as interface:
577
- gr.Markdown("# 🚀 هوش مصنوعی معامله‌گر هوشمند\n**آموزش و معاملات Real-Time**")
 
 
 
 
 
 
 
578
 
579
- with gr.Row():
580
- with gr.Column(scale=1):
581
- gr.Markdown("## ⚙️ تنظیمات")
582
- balance = gr.Slider(1000, 50000, value=10000, step=1000, label="سرمایه اولیه ($)")
583
- risk = gr.Radio(["Low", "Medium", "High"], value="Medium", label="سطح ریسک")
584
- asset = gr.Radio(["Crypto", "Stock", "Forex"], value="Crypto", label="نوع دارایی")
585
- init_btn = gr.Button("🚀 راه‌اندازی", variant="primary")
586
- init_status = gr.Textbox(label="وضعیت", interactive=False)
587
-
588
- with gr.Column(scale=2):
589
- status = gr.Textbox(label="وضعیت کلی", interactive=False, lines=4)
590
 
591
- with gr.Row():
592
- with gr.Column(scale=1):
593
- gr.Markdown("## 🎓 آموزش")
594
- episodes = gr.Slider(10, 100, value=20, step=5, label="اپیزودها")
595
- train_btn = gr.Button("🤖 شروع آموزش", variant="primary")
596
-
597
- with gr.Column(scale=2):
598
- train_plot = gr.Plot(label="پیشرفت آموزش")
599
 
600
- with gr.Row():
601
- with gr.Column(scale=1):
602
- gr.Markdown("## 🎯 معاملات زنده")
603
- start_btn = gr.Button("▶️ شروع معاملات", variant="secondary")
604
- update_btn = gr.Button("🔄 به‌روزرسانی", variant="secondary")
605
- stop_btn = gr.Button("⏹️ توقف", variant="stop")
606
-
607
- with gr.Column(scale=3):
608
- live_chart = gr.Plot(label="نمودار زنده")
609
 
610
- with gr.Row():
611
- perf_chart = gr.Plot(label="عملکرد")
612
- stats_table = gr.DataFrame(label="آمار", headers=["متریک", "مقدار"])
613
 
614
- # Event handlers
615
- init_btn.click(
616
- demo.initialize_environment,
617
- inputs=[balance, risk, asset],
618
- outputs=[init_status]
619
- )
620
 
621
- train_btn.click(
622
- demo.train_agent,
623
- inputs=[episodes],
624
- outputs=[status, train_plot]
625
- )
626
 
627
- start_btn.click(
628
- demo.start_live_trading,
629
- outputs=[status, live_chart, perf_chart, stats_table]
630
- )
631
 
632
- update_btn.click(
633
- demo.get_live_update,
634
- outputs=[status, live_chart, perf_chart, stats_table]
635
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
636
 
637
- stop_btn.click(
638
- demo.stop_live_trading,
639
- outputs=[status, live_chart, perf_chart, stats_table]
640
- )
641
-
642
- return interface, demo
643
 
644
- if __name__ == "__main__":
645
- logger.info("Starting AI Trading Demo...")
646
- interface, demo = create_interface()
 
 
 
 
 
 
 
 
 
647
 
648
- try:
649
- interface.launch(
650
- server_name="0.0.0.0",
651
- server_port=7860,
652
- share=False,
653
- show_error=True,
654
- quiet=False
655
- )
656
- except Exception as e:
657
- logger.error(f"Failed to launch interface: {e}")
658
- print(f"خطا در راه‌اندازی: {e}")
 
 
 
1
import random
from collections import deque
from pathlib import Path
from typing import Dict, Tuple, Any

import gradio as gr
import numpy as np
import torch
import yaml
from gymnasium import spaces
from loguru import logger
 
 
 
 
9
 
10
class TradingConfig:
    """Container for environment settings and DQN hyper-parameters."""

    def __init__(self):
        # --- environment settings ---
        self.initial_balance = 10000.0    # starting cash (USD)
        self.max_steps = 1000             # hard cap on steps per episode
        self.transaction_cost = 0.001     # fractional fee per trade
        self.risk_level = "Medium"
        self.asset_type = "Crypto"
        # --- DQN hyper-parameters ---
        self.learning_rate = 0.0001
        self.gamma = 0.99                 # discount factor
        self.epsilon_start = 1.0          # initial exploration rate
        self.epsilon_min = 0.01           # exploration floor
        self.epsilon_decay = 0.9995       # multiplicative decay per update
        self.batch_size = 32
        self.memory_size = 10000          # replay-buffer capacity
        self.target_update = 100          # gradient steps between target syncs
 
 
25
 
26
class AdvancedTradingEnvironment:
    """Single-asset trading environment with a Gymnasium-style API.

    Prices and sentiment are synthetic (random walks); the agent picks one
    of 4 discrete actions: 0=hold, 1=buy, 2=sell, 3=close position.
    Observations are 12-element float32 feature vectors.
    """

    def __init__(self, config):
        """`config` must provide `initial_balance` and `max_steps`."""
        self.initial_balance = config.initial_balance
        self.balance = self.initial_balance
        self.position = 0.0          # units of the asset currently held
        self.current_price = 100.0
        self.step_count = 0
        self.max_steps = config.max_steps
        self.price_history = []
        self.sentiment_history = []
        self._initialize_data()
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Box(low=-2.0, high=2.0, shape=(12,), dtype=np.float32)

    def _initialize_data(self):
        """Seed the price/sentiment histories with a noisy sine wave."""
        n_points = 100
        base_price = 100.0
        for i in range(n_points):
            price = base_price + np.sin(i * 0.1) * 10 + np.random.normal(0, 2)
            self.price_history.append(max(10.0, price))
            sentiment = 0.5 + np.random.normal(0, 0.1)
            self.sentiment_history.append(np.clip(sentiment, 0.0, 1.0))
        self.current_price = self.price_history[-1]

    def reset(self):
        """Start a new episode; returns (observation, info)."""
        self.balance = self.initial_balance
        self.position = 0.0
        self.step_count = 0
        self.price_history = [100.0 + np.random.normal(0, 5)]
        self.sentiment_history = [0.5]
        # BUG FIX: keep current_price in sync with the freshly re-seeded
        # history. Previously it retained its pre-reset value, so the first
        # observation and net-worth mixed stale and new state.
        self.current_price = self.price_history[-1]
        obs = self._get_observation()
        info = self._get_info()
        return obs, info

    def step(self, action):
        """Advance one tick; returns (obs, reward, terminated, truncated, info)."""
        self.step_count += 1

        # Geometric random walk for the price (~2% per-step volatility),
        # floored at 10.0 so observations stay bounded.
        price_change = np.random.normal(0, 0.02)
        self.current_price = max(10.0, self.current_price * (1 + price_change))
        self.price_history.append(self.current_price)

        # Sentiment drifts randomly, clamped to [0, 1].
        sentiment_change = np.random.normal(0, 0.05)
        new_sentiment = np.clip(self.sentiment_history[-1] + sentiment_change, 0.0, 1.0)
        self.sentiment_history.append(new_sentiment)

        reward = self._execute_action(action)

        # Episode ends on bankruptcy or when the step budget is exhausted.
        terminated = self.balance <= 0 or self.step_count >= self.max_steps
        truncated = False

        obs = self._get_observation()
        info = self._get_info()
        return obs, reward, terminated, truncated, info

    def _execute_action(self, action):
        """Apply the trade; returns the net-worth change as % of initial balance."""
        prev_net_worth = self.balance + self.position * self.current_price

        if action == 1:  # Buy with 20% of available cash
            trade_amount = min(self.balance * 0.2, self.balance)
            if trade_amount <= self.balance:
                self.position += trade_amount / self.current_price
                self.balance -= trade_amount
        elif action == 2:  # Sell 20% of the current position
            if self.position > 0:
                sell_amount = min(self.position * 0.2, self.position)
                self.position -= sell_amount
                self.balance += sell_amount * self.current_price
        elif action == 3:  # Close the entire position
            if self.position > 0:
                self.balance += self.position * self.current_price
                self.position = 0
        # action == 0 (hold) changes nothing.

        net_worth = self.balance + self.position * self.current_price
        return (net_worth - prev_net_worth) / self.initial_balance * 100

    def _get_observation(self):
        """Build the 12-dim feature vector (last four entries are zero padding)."""
        recent_prices = self.price_history[-10:] if len(self.price_history) >= 10 else [self.current_price] * 10
        recent_sentiments = self.sentiment_history[-10:] if len(self.sentiment_history) >= 10 else [0.5] * 10

        features = [
            self.balance / self.initial_balance,
            self.position * self.current_price / self.initial_balance,
            self.current_price / 100.0,
            np.mean(recent_prices) / 100.0,
            np.std(recent_prices) / 100.0,
            np.mean(recent_sentiments),
            np.std(recent_sentiments),
            self.step_count / self.max_steps,
            0.0, 0.0, 0.0, 0.0,  # padding up to 12 features
        ]
        return np.array(features[:12], dtype=np.float32)

    def _get_info(self):
        """Auxiliary episode info: current net worth (cash + position value)."""
        net_worth = self.balance + self.position * self.current_price
        return {'net_worth': net_worth}
130
+
131
class DQNAgent:
    """Minimal DQN: online + target MLPs, replay buffer, epsilon-greedy policy."""

    def __init__(self, state_dim, action_dim, config, device='cpu'):
        """`config` must provide learning_rate, memory_size, gamma,
        epsilon_start/min/decay, batch_size and target_update."""
        self.device = torch.device(device)
        self.action_dim = action_dim

        # Online and target networks share one architecture; the target
        # starts as an exact copy (previously the same Sequential literal
        # was duplicated inline).
        self.q_network = self._build_network(state_dim, action_dim)
        self.target_network = self._build_network(state_dim, action_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())

        self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=config.learning_rate)
        self.memory = deque(maxlen=config.memory_size)
        self.gamma = config.gamma
        self.epsilon = config.epsilon_start
        self.epsilon_min = config.epsilon_min
        self.epsilon_decay = config.epsilon_decay
        self.batch_size = config.batch_size
        self.target_update = config.target_update
        self.steps = 0

    def _build_network(self, state_dim, action_dim):
        """Two-hidden-layer MLP mapping a state to per-action Q-values."""
        return torch.nn.Sequential(
            torch.nn.Linear(state_dim, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, action_dim),
        ).to(self.device)

    def select_action(self, state, training=True):
        """Epsilon-greedy action in [0, action_dim); greedy when training=False."""
        # Check the exploration branch first so random moves skip the
        # (useless) tensor conversion; also generalized from the hard-coded
        # randint(0, 3) to the configured action_dim.
        if training and random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            return self.q_network(state_t).argmax(1).item()

    def store_transition(self, state, action, reward, next_state, done):
        """Append one (s, a, r, s', done) tuple to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def update(self):
        """One gradient step on a random minibatch.

        Returns the scalar loss, or 0.0 when the buffer is still smaller
        than one batch (no update performed, epsilon unchanged).
        """
        if len(self.memory) < self.batch_size:
            return 0.0

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.FloatTensor(np.array(states)).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)

        current_q = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        # Standard DQN target: r + gamma * max_a' Q_target(s', a') for
        # non-terminal s'. no_grad() stops gradients from being accumulated
        # into the (never-optimized) target network.
        with torch.no_grad():
            next_q = self.target_network(next_states).max(1)[0]
        target_q = rewards + self.gamma * next_q * (1 - dones)

        loss = torch.nn.MSELoss()(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Periodically sync the target network with the online network.
        self.steps += 1
        if self.steps % self.target_update == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

        # Anneal exploration toward the floor.
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        return loss.item()
202
+
203
class TradingDemo:
    """Glue object wiring the environment and agent to the Gradio UI."""

    def __init__(self):
        self.config = TradingConfig()
        self.env = None
        self.agent = None
        self.device = 'cpu'

    def initialize(self, balance, risk, asset):
        """Build a fresh environment/agent pair from the UI settings."""
        self.config.initial_balance = balance
        self.config.risk_level = risk
        self.config.asset_type = asset
        self.env = AdvancedTradingEnvironment(self.config)
        self.agent = DQNAgent(12, 4, self.config, self.device)
        return "✅ Initialized!"

    def train(self, episodes):
        """Generator yielding one (status, plot) update per training episode."""
        for episode_idx in range(episodes):
            state, _ = self.env.reset()
            episode_return = 0
            finished = False
            while not finished:
                chosen = self.agent.select_action(state)
                next_state, step_reward, finished, _, info = self.env.step(chosen)
                self.agent.store_transition(state, chosen, step_reward, next_state, finished)
                state = next_state
                episode_return += step_reward
                self.agent.update()
            yield f"Episode {episode_idx+1}/{episodes} | Reward: {episode_return:.2f}", None
        yield "✅ Training complete!", None

    def simulate(self, steps):
        """Run the greedy (non-exploring) policy for up to `steps` ticks and plot it."""
        state, _ = self.env.reset()
        prices, taken_actions, worth_curve = [], [], []
        for _ in range(steps):
            chosen = self.agent.select_action(state, training=False)
            state, _step_reward, finished, _, info = self.env.step(chosen)
            prices.append(self.env.current_price)
            taken_actions.append(chosen)
            worth_curve.append(info['net_worth'])
            if finished:
                break

        # Imported lazily so the class stays usable without plotly installed.
        import plotly.graph_objects as go
        fig = go.Figure()
        fig.add_trace(go.Scatter(y=prices, mode='lines', name='Price'))
        fig.add_trace(go.Scatter(y=worth_curve, mode='lines', name='Net Worth'))
        return "✅ Simulation complete!", fig
 
253
 
254
demo = TradingDemo()

with gr.Blocks() as interface:
    gr.Markdown("# Trading AI Demo")

    # Environment settings.
    with gr.Row():
        balance = gr.Slider(1000, 50000, 10000, label="Balance")
        risk = gr.Radio(["Low", "Medium", "High"], value="Medium", label="Risk")
        asset = gr.Radio(["Crypto", "Stock", "Forex"], value="Crypto", label="Asset")
        init_btn = gr.Button("Initialize")

    status = gr.Textbox(label="Status")

    # Training controls.
    episodes = gr.Number(value=50, label="Episodes")
    train_btn = gr.Button("Train")
    train_plot = gr.Plot()

    # Greedy-policy simulation controls.
    steps = gr.Number(value=100, label="Simulation Steps")
    sim_btn = gr.Button("Simulate")
    sim_plot = gr.Plot()

    # Wire the buttons: train streams generator updates into status/plot.
    init_btn.click(demo.initialize, [balance, risk, asset], status)
    train_btn.click(demo.train, episodes, [status, train_plot])
    sim_btn.click(demo.simulate, steps, [status, sim_plot])

# BUG FIX: launch only when run as a script. Previously launch() fired
# unconditionally at import time, blocking any module that imported this file.
if __name__ == "__main__":
    interface.launch()