OmidSakaki committed on
Commit
0f853b8
·
verified ·
1 Parent(s): feda71d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -597
app.py CHANGED
@@ -22,20 +22,10 @@ os.makedirs('src/visualizers', exist_ok=True)
22
  os.makedirs('src/utils', exist_ok=True)
23
 
24
  # Create __init__.py files
25
- with open('src/__init__.py', 'w') as f:
26
- f.write('')
27
-
28
- with open('src/environments/__init__.py', 'w') as f:
29
- f.write('')
30
-
31
- with open('src/agents/__init__.py', 'w') as f:
32
- f.write('')
33
-
34
- with open('src/visualizers/__init__.py', 'w') as f:
35
- f.write('')
36
-
37
- with open('src/utils/__init__.py', 'w') as f:
38
- f.write('')
39
 
40
  # Now import our custom modules
41
  sys.path.append('src')
@@ -43,117 +33,16 @@ sys.path.append('src')
43
  # Import our custom modules
44
  from src.environments.visual_trading_env import VisualTradingEnvironment
45
  from src.agents.visual_agent import VisualTradingAgent
46
-
47
- class ChartRenderer:
48
- """Simple chart renderer for visualization"""
49
- def __init__(self):
50
- pass
51
-
52
- def render_price_chart(self, prices, actions=None, current_step=0):
53
- """Render price chart with actions"""
54
- fig = go.Figure()
55
-
56
- if not prices:
57
- # Return empty figure if no data
58
- fig.update_layout(
59
- title="Price Chart - No Data",
60
- xaxis_title="Time Step",
61
- yaxis_title="Price",
62
- height=300
63
- )
64
- return fig
65
-
66
- # Add price line
67
- fig.add_trace(go.Scatter(
68
- x=list(range(len(prices))),
69
- y=prices,
70
- mode='lines',
71
- name='Price',
72
- line=dict(color='blue', width=2)
73
- ))
74
-
75
- # Add action markers if provided
76
- if actions and len(actions) == len(prices):
77
- buy_indices = [i for i, action in enumerate(actions) if action == 1]
78
- sell_indices = [i for i, action in enumerate(actions) if action == 2]
79
- close_indices = [i for i, action in enumerate(actions) if action == 3]
80
-
81
- if buy_indices:
82
- fig.add_trace(go.Scatter(
83
- x=buy_indices,
84
- y=[prices[i] for i in buy_indices],
85
- mode='markers',
86
- name='Buy',
87
- marker=dict(color='green', size=10, symbol='triangle-up', line=dict(width=2, color='darkgreen'))
88
- ))
89
-
90
- if sell_indices:
91
- fig.add_trace(go.Scatter(
92
- x=sell_indices,
93
- y=[prices[i] for i in sell_indices],
94
- mode='markers',
95
- name='Sell',
96
- marker=dict(color='red', size=10, symbol='triangle-down', line=dict(width=2, color='darkred'))
97
- ))
98
-
99
- if close_indices:
100
- fig.add_trace(go.Scatter(
101
- x=close_indices,
102
- y=[prices[i] for i in close_indices],
103
- mode='markers',
104
- name='Close',
105
- marker=dict(color='orange', size=8, symbol='x', line=dict(width=2, color='darkorange'))
106
- ))
107
-
108
- fig.update_layout(
109
- title=f"Price Chart (Step: {current_step})",
110
- xaxis_title="Time Step",
111
- yaxis_title="Price",
112
- height=300,
113
- showlegend=True
114
- )
115
-
116
- return fig
117
-
118
- class DataLoader:
119
- """Data loader for synthetic market data"""
120
- def __init__(self):
121
- pass
122
-
123
- def generate_synthetic_data(self, num_points=1000, trend=0.0005, volatility=0.02):
124
- """Generate realistic synthetic market data"""
125
- np.random.seed(42)
126
-
127
- prices = [100.0]
128
- for i in range(1, num_points):
129
- # Random walk with trend and some mean reversion
130
- change = np.random.normal(trend, volatility)
131
- # Add some mean reversion
132
- mean_reversion = (100 - prices[-1]) * 0.001
133
- price = max(1.0, prices[-1] * (1 + change) + mean_reversion)
134
- prices.append(price)
135
-
136
- return np.array(prices)
137
-
138
- class TradingConfig:
139
- """Configuration class for trading parameters"""
140
- def __init__(self):
141
- self.initial_balance = 10000
142
- self.max_steps = 1000
143
- self.transaction_cost = 0.001
144
- self.learning_rate = 0.001
145
- self.gamma = 0.99
146
 
147
  class TradingAIDemo:
148
  def __init__(self):
149
- self.config = TradingConfig()
150
  self.env = None
151
  self.agent = None
152
  self.current_state = None
153
  self.is_training = False
154
  self.episode_history = []
155
  self.chart_renderer = ChartRenderer()
156
- self.data_loader = DataLoader()
157
  self.initialized = False
158
 
159
  def initialize_environment(self, initial_balance, risk_level, asset_type):
@@ -169,7 +58,7 @@ class TradingAIDemo:
169
 
170
  # Initialize agent with correct dimensions
171
  self.agent = VisualTradingAgent(
172
- state_dim=(84, 84, 4), # Fixed dimensions
173
  action_dim=4
174
  )
175
 
@@ -290,7 +179,9 @@ class TradingAIDemo:
290
  # Calculate performance metrics
291
  initial_balance = self.env.initial_balance
292
  final_net_worth = info['net_worth']
293
- total_return = (final_net_worth - initial_balance) / initial_balance * 100
 
 
294
 
295
  summary = (
296
  f"🎯 Episode Completed!\n"
@@ -318,13 +209,14 @@ class TradingAIDemo:
318
  training_history = []
319
 
320
  try:
321
- for episode in range(int(num_episodes)):
 
322
  state = self.env.reset()
323
- episode_reward = 0
324
  done = False
325
  steps = 0
326
 
327
- while not done and steps < 100: # Limit steps per episode
328
  action = self.agent.select_action(state)
329
  next_state, reward, done, info = self.env.step(action)
330
  self.agent.store_transition(state, action, reward, next_state, done)
@@ -339,506 +231,40 @@ class TradingAIDemo:
339
  'episode': episode,
340
  'reward': episode_reward,
341
  'net_worth': info['net_worth'],
342
- 'loss': loss if loss else 0,
343
  'steps': steps
344
  })
345
 
346
  # Yield progress every 5 episodes or at the end
347
  if episode % 5 == 0 or episode == num_episodes - 1:
348
  progress_chart = self.create_training_progress(training_history)
 
349
  status = (
350
  f"🔄 Training Progress: {episode+1}/{num_episodes}\n"
351
  f"• Episode Reward: {episode_reward:.2f}\n"
352
  f"• Final Net Worth: ${info['net_worth']:.2f}\n"
353
- f"• Loss: {loss:.4f if loss else 0:.4f}\n"
354
  f"• Epsilon: {self.agent.epsilon:.3f}"
355
  )
356
  yield progress_chart, status
357
 
358
- # Small delay to make training visible
359
  time.sleep(0.01)
360
 
361
  self.is_training = False
 
 
 
 
 
362
  final_status = (
363
  f"✅ Training Completed!\n"
364
  f"• Total Episodes: {num_episodes}\n"
365
  f"• Final Epsilon: {self.agent.epsilon:.3f}\n"
366
- f"• Average Reward: {np.mean([h['reward'] for h in training_history]):.2f}"
367
  )
368
  yield self.create_training_progress(training_history), final_status
369
 
370
  except Exception as e:
371
  self.is_training = False
372
  error_msg = f"❌ Training error: {str(e)}"
373
- print(error_msg)
374
- yield None, error_msg
375
-
376
- def create_price_chart(self, info):
377
- """Create price chart with actions"""
378
- if not self.episode_history:
379
- # Return empty chart with message
380
- fig = go.Figure()
381
- fig.update_layout(
382
- title="Price Chart - No Data Available",
383
- xaxis_title="Time Step",
384
- yaxis_title="Price",
385
- height=300
386
- )
387
- return fig
388
-
389
- prices = [h['price'] for h in self.episode_history]
390
- actions = [h['action'] for h in self.episode_history]
391
-
392
- fig = go.Figure()
393
-
394
- # Price line
395
- fig.add_trace(go.Scatter(
396
- x=list(range(len(prices))),
397
- y=prices,
398
- mode='lines',
399
- name='Price',
400
- line=dict(color='blue', width=3)
401
- ))
402
-
403
- # Action markers
404
- buy_indices = [i for i, action in enumerate(actions) if action == 1]
405
- sell_indices = [i for i, action in enumerate(actions) if action == 2]
406
- close_indices = [i for i, action in enumerate(actions) if action == 3]
407
-
408
- if buy_indices:
409
- fig.add_trace(go.Scatter(
410
- x=buy_indices,
411
- y=[prices[i] for i in buy_indices],
412
- mode='markers',
413
- name='Buy',
414
- marker=dict(color='green', size=12, symbol='triangle-up',
415
- line=dict(width=2, color='darkgreen'))
416
- ))
417
-
418
- if sell_indices:
419
- fig.add_trace(go.Scatter(
420
- x=sell_indices,
421
- y=[prices[i] for i in sell_indices],
422
- mode='markers',
423
- name='Sell',
424
- marker=dict(color='red', size=12, symbol='triangle-down',
425
- line=dict(width=2, color='darkred'))
426
- ))
427
-
428
- if close_indices:
429
- fig.add_trace(go.Scatter(
430
- x=close_indices,
431
- y=[prices[i] for i in close_indices],
432
- mode='markers',
433
- name='Close',
434
- marker=dict(color='orange', size=10, symbol='x',
435
- line=dict(width=2, color='darkorange'))
436
- ))
437
-
438
- fig.update_layout(
439
- title="Price Chart with Trading Actions",
440
- xaxis_title="Step",
441
- yaxis_title="Price",
442
- height=350,
443
- showlegend=True,
444
- template="plotly_white"
445
- )
446
-
447
- return fig
448
-
449
- def create_performance_chart(self):
450
- """Create portfolio performance chart"""
451
- if not self.episode_history:
452
- fig = go.Figure()
453
- fig.update_layout(
454
- title="Portfolio Performance - No Data Available",
455
- height=400
456
- )
457
- return fig
458
-
459
- net_worth = [h['net_worth'] for h in self.episode_history]
460
- rewards = [h['reward'] for h in self.episode_history]
461
-
462
- fig = make_subplots(
463
- rows=2, cols=1,
464
- subplot_titles=['Portfolio Value Over Time', 'Step Rewards'],
465
- vertical_spacing=0.15
466
- )
467
-
468
- # Portfolio value
469
- fig.add_trace(go.Scatter(
470
- x=list(range(len(net_worth))),
471
- y=net_worth,
472
- mode='lines+markers',
473
- name='Net Worth',
474
- line=dict(color='green', width=3),
475
- marker=dict(size=4)
476
- ), row=1, col=1)
477
-
478
- # Add initial balance reference line
479
- if self.env:
480
- fig.add_hline(y=self.env.initial_balance, line_dash="dash",
481
- line_color="red", annotation_text="Initial Balance",
482
- row=1, col=1)
483
-
484
- # Rewards as bar chart
485
- fig.add_trace(go.Bar(
486
- x=list(range(len(rewards))),
487
- y=rewards,
488
- name='Reward',
489
- marker_color=['green' if r >= 0 else 'red' for r in rewards],
490
- opacity=0.7
491
- ), row=2, col=1)
492
-
493
- fig.update_layout(height=500, showlegend=False, template="plotly_white")
494
- fig.update_yaxes(title_text="Value ($)", row=1, col=1)
495
- fig.update_yaxes(title_text="Reward", row=2, col=1)
496
- fig.update_xaxes(title_text="Step", row=2, col=1)
497
-
498
- return fig
499
-
500
- def create_action_chart(self):
501
- """Create action distribution chart"""
502
- if not self.episode_history:
503
- fig = go.Figure()
504
- fig.update_layout(
505
- title="Action Distribution - No Data Available",
506
- height=300
507
- )
508
- return fig
509
-
510
- actions = [h['action'] for h in self.episode_history]
511
- action_names = ['Hold', 'Buy', 'Sell', 'Close']
512
- action_counts = [actions.count(i) for i in range(4)]
513
-
514
- colors = ['blue', 'green', 'red', 'orange']
515
-
516
- fig = go.Figure(data=[go.Pie(
517
- labels=action_names,
518
- values=action_counts,
519
- hole=.4,
520
- marker_colors=colors,
521
- textinfo='label+percent+value',
522
- hoverinfo='label+percent+value'
523
- )])
524
-
525
- fig.update_layout(
526
- title="Action Distribution",
527
- height=350,
528
- annotations=[dict(text='Actions', x=0.5, y=0.5, font_size=16, showarrow=False)]
529
- )
530
-
531
- return fig
532
-
533
- def create_training_progress(self, training_history):
534
- """Create training progress visualization"""
535
- if not training_history:
536
- fig = go.Figure()
537
- fig.update_layout(
538
- title="Training Progress - No Data Available",
539
- height=500
540
- )
541
- return fig
542
-
543
- df = pd.DataFrame(training_history)
544
-
545
- fig = make_subplots(
546
- rows=2, cols=2,
547
- subplot_titles=['Episode Rewards', 'Portfolio Value',
548
- 'Training Loss', 'Moving Average Reward (5)'],
549
- specs=[[{}, {}], [{}, {}]]
550
- )
551
-
552
- # Rewards
553
- fig.add_trace(go.Scatter(
554
- x=df['episode'], y=df['reward'], mode='lines+markers',
555
- name='Reward', line=dict(color='blue', width=2),
556
- marker=dict(size=4)
557
- ), row=1, col=1)
558
-
559
- # Portfolio value
560
- fig.add_trace(go.Scatter(
561
- x=df['episode'], y=df['net_worth'], mode='lines+markers',
562
- name='Net Worth', line=dict(color='green', width=2),
563
- marker=dict(size=4)
564
- ), row=1, col=2)
565
-
566
- # Add initial balance reference
567
- if self.env:
568
- fig.add_hline(y=self.env.initial_balance, line_dash="dash",
569
- line_color="red", annotation_text="Initial Balance",
570
- row=1, col=2)
571
-
572
- # Loss
573
- if 'loss' in df.columns and df['loss'].notna().any() and df['loss'].sum() > 0:
574
- fig.add_trace(go.Scatter(
575
- x=df['episode'], y=df['loss'], mode='lines+markers',
576
- name='Loss', line=dict(color='red', width=2),
577
- marker=dict(size=4)
578
- ), row=2, col=1)
579
-
580
- # Moving average reward
581
- if len(df) > 5:
582
- df['ma_reward'] = df['reward'].rolling(window=5).mean()
583
- fig.add_trace(go.Scatter(
584
- x=df['episode'], y=df['ma_reward'], mode='lines',
585
- name='MA Reward (5)', line=dict(color='orange', width=3, dash='dash')
586
- ), row=2, col=2)
587
-
588
- fig.update_layout(
589
- height=600,
590
- showlegend=True,
591
- title_text="Training Progress Over Episodes",
592
- template="plotly_white"
593
- )
594
-
595
- return fig
596
-
597
- # Initialize the demo
598
- demo = TradingAIDemo()
599
-
600
- # Create Gradio interface
601
- def create_interface():
602
- with gr.Blocks(theme=gr.themes.Soft(), title="Visual Trading AI") as interface:
603
- gr.Markdown("""
604
- # 🚀 Visual Trading AI
605
- **هوش مصنوعی معامله‌گر بصری - تحلیل چارت‌های قیمت با یادگیری تقویتی عمیق**
606
-
607
- *این پروژه از شبکه‌های عصبی کانولوشن برای تحلیل بصری نمودارهای قیمت و یادگیری تقویتی برای تصمیم‌گیری معاملاتی استفاده می‌کند.*
608
- """)
609
-
610
- with gr.Row():
611
- with gr.Column(scale=1):
612
- # Configuration section
613
- gr.Markdown("## ⚙️ پیکربندی محیط")
614
-
615
- with gr.Row():
616
- initial_balance = gr.Slider(
617
- minimum=1000, maximum=50000, value=10000, step=1000,
618
- label="موجودی اولیه ($)", info="میزان سرمایه اولیه برای معامله"
619
- )
620
-
621
- with gr.Row():
622
- risk_level = gr.Radio(
623
- ["Low", "Medium", "High"],
624
- value="Medium",
625
- label="سطح ریسک",
626
- info="سطح ریسک پذیری در معاملات"
627
- )
628
-
629
- with gr.Row():
630
- asset_type = gr.Radio(
631
- ["Stock", "Crypto", "Forex"],
632
- value="Stock",
633
- label="نوع دارایی",
634
- info="نوع بازار مالی برای شبیه‌سازی"
635
- )
636
-
637
- with gr.Row():
638
- init_btn = gr.Button(
639
- "🚀 راه‌اندازی محیط معاملاتی",
640
- variant="primary",
641
- size="lg"
642
- )
643
-
644
- with gr.Row():
645
- init_status = gr.Textbox(
646
- label="وضعیت راه‌اندازی",
647
- interactive=False,
648
- placeholder="برای شروع، محیط را راه‌اندازی کنید...",
649
- lines=2
650
- )
651
-
652
- with gr.Column(scale=2):
653
- # Status output
654
- gr.Markdown("## 📊 وضعیت معاملات")
655
- status_output = gr.Textbox(
656
- label="وضعیت اجرا",
657
- interactive=False,
658
- placeholder="وضعیت معاملات اینجا نمایش داده می‌شود...",
659
- lines=4
660
- )
661
-
662
- with gr.Row():
663
- gr.Markdown("## 🎮 کنترل معاملات")
664
-
665
- with gr.Row():
666
- # Action controls
667
- action_choice = gr.Radio(
668
- ["AI Decision", "Buy", "Sell", "Hold", "Close"],
669
- value="AI Decision",
670
- label="انتخاب اقدام",
671
- info="AI Decision: تصمیم خودکار هوش مصنوعی"
672
- )
673
-
674
- with gr.Row():
675
- with gr.Column(scale=1):
676
- step_btn = gr.Button(
677
- "▶️ اجرای یک قدم",
678
- variant="secondary",
679
- size="lg"
680
- )
681
-
682
- with gr.Column(scale=1):
683
- episode_btn = gr.Button(
684
- "🎯 اجرای یک اپیزود (20 قدم)",
685
- variant="secondary",
686
- size="lg"
687
- )
688
-
689
- with gr.Row():
690
- # Visualization outputs
691
- with gr.Column(scale=1):
692
- price_chart = gr.Plot(
693
- label="📈 نمودار قیمت و اقدامات"
694
- )
695
-
696
- with gr.Column(scale=1):
697
- performance_chart = gr.Plot(
698
- label="💰 عملکرد پرتفولیو"
699
- )
700
-
701
- with gr.Row():
702
- with gr.Column(scale=1):
703
- action_chart = gr.Plot(
704
- label="🎯 توزیع اقدامات"
705
- )
706
-
707
- with gr.Row():
708
- gr.Markdown("## 🎓 آموزش هوش مصنوعی")
709
-
710
- with gr.Row():
711
- with gr.Column(scale=1):
712
- num_episodes = gr.Slider(
713
- minimum=10, maximum=200, value=50, step=10,
714
- label="تعداد اپیزودهای آموزش",
715
- info="تعداد دوره‌های آموزشی"
716
- )
717
-
718
- learning_rate = gr.Slider(
719
- minimum=0.0001, maximum=0.01, value=0.001, step=0.0001,
720
- label="نرخ یادگیری",
721
- info="سرعت یادگیری الگوریتم"
722
- )
723
-
724
- train_btn = gr.Button(
725
- "🤖 شروع آموزش",
726
- variant="primary",
727
- size="lg"
728
- )
729
-
730
- with gr.Column(scale=2):
731
- training_plot = gr.Plot(
732
- label="📊 پیشرفت آموزش"
733
- )
734
-
735
- training_status = gr.Textbox(
736
- label="وضعیت آموزش",
737
- interactive=False,
738
- placeholder="وضعیت آموزش اینجا نمایش داده می‌شود...",
739
- lines=3
740
- )
741
-
742
- with gr.Row():
743
- gr.Markdown("## ℹ️ راهنمای استفاده")
744
-
745
- with gr.Row():
746
- with gr.Column(scale=1):
747
- gr.Markdown("""
748
- **🎯 اقدامات ممکن:**
749
- - **Hold (0)**: حفظ وضعیت فعلی
750
- - **Buy (1)**: باز کردن پوزیشن خرید
751
- - **Sell (2)**: افزایش سایز پوزیشن
752
- - **Close (3)**: بستن پوزیشن فعلی
753
-
754
- **📈 معیارهای عملکرد:**
755
- - **Reward**: امتیاز دریافتی از محیط
756
- - **Net Worth**: ارزش کل پرتفولیو
757
- - **Balance**: موجودی نقدی
758
- - **Position**: سایز پوزیشن فعلی
759
- """)
760
-
761
- with gr.Column(scale=1):
762
- gr.Markdown("""
763
- **🔧 نحوه استفاده:**
764
- 1. محیط را راه‌اندازی کنید
765
- 2. اقدامات تکی یا اپیزودها را اجرا کنید
766
- 3. عملکرد را در نمودارها مشاهده کنید
767
- 4. هوش مصنوعی را آموزش دهید
768
- 5. نتایج را تحلیل کنید
769
-
770
- **⚠️ توجه:**
771
- این یک شبیه‌ساز آموزشی است و برای معاملات واقعی طراحی نشده است.
772
- """)
773
-
774
- # Event handlers
775
- init_btn.click(
776
- demo.initialize_environment,
777
- inputs=[initial_balance, risk_level, asset_type],
778
- outputs=[init_status]
779
- )
780
-
781
- step_btn.click(
782
- demo.run_single_step,
783
- inputs=[action_choice],
784
- outputs=[price_chart, performance_chart, action_chart, status_output]
785
- )
786
-
787
- episode_btn.click(
788
- demo.run_episode,
789
- inputs=[],
790
- outputs=[price_chart, performance_chart, action_chart, status_output]
791
- )
792
-
793
- train_btn.click(
794
- demo.train_agent,
795
- inputs=[num_episodes, learning_rate],
796
- outputs=[training_plot, training_status]
797
- )
798
-
799
- gr.Markdown("""
800
- ## 🏗 معماری فنی
801
-
802
- **🎯 هسته هوش مصنوعی:**
803
- - **پردازش بصری**: شبکه عصبی کانولوشن (CNN) برای تحلیل نمودارهای قیمت
804
- - **یادگیری تقویتی**: الگوریتم Deep Q-Network (DQN) برای تصمیم‌گیری
805
- - **تجربه replay**: ذخیره و بازیابی تجربیات برای یادگیری پایدار
806
-
807
- **🛠 فناوری‌ها:**
808
- - **یادگیری عمیق**: PyTorch
809
- - **محیط شبیه‌سازی**: محیط اختصاصی معاملاتی
810
- - **رابط کاربری**: Gradio
811
- - **ویژوالیزیشن**: Plotly, Matplotlib
812
- - **پردازش داده**: NumPy, Pandas
813
-
814
- **📊 ویژگی‌های کلیدی:**
815
- - تحلیل بصری نمودارهای قیمت
816
- - یادگیری خودکار استراتژی‌های معاملاتی
817
- - نمایش زنده عملکرد و تصمیم‌ها
818
- - کنترل دستی و خودکار
819
- - آنالیز جامع عملکرد
820
-
821
- *توسعه داده شده توسط Omid Sakaki - 2024*
822
- """)
823
-
824
- return interface
825
-
826
- # Create and launch interface
827
- if __name__ == "__main__":
828
- print("🚀 Starting Visual Trading AI Application...")
829
- print("📊 Initializing components...")
830
-
831
- interface = create_interface()
832
-
833
- print("✅ Application initialized successfully!")
834
- print("🌐 Starting server on http://0.0.0.0:7860")
835
- print("📱 You can now access the application in your browser")
836
-
837
- # Launch with better configuration
838
- interface.launch(
839
- server_name="0.0.0.0",
840
- server_port=7860,
841
- share=False,
842
- show_error=True,
843
- debug=True
844
- )
 
22
  os.makedirs('src/utils', exist_ok=True)
23
 
24
  # Create __init__.py files
25
+ for dir_path in ['src', 'src/environments', 'src/agents', 'src/visualizers', 'src/utils']:
26
+ init_file = os.path.join(dir_path, '__init__.py')
27
+ with open(init_file, 'w') as f:
28
+ f.write('')
 
 
 
 
 
 
 
 
 
 
29
 
30
  # Now import our custom modules
31
  sys.path.append('src')
 
33
  # Import our custom modules
34
  from src.environments.visual_trading_env import VisualTradingEnvironment
35
  from src.agents.visual_agent import VisualTradingAgent
36
+ from src.visualizers.chart_renderer import ChartRenderer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  class TradingAIDemo:
39
  def __init__(self):
 
40
  self.env = None
41
  self.agent = None
42
  self.current_state = None
43
  self.is_training = False
44
  self.episode_history = []
45
  self.chart_renderer = ChartRenderer()
 
46
  self.initialized = False
47
 
48
  def initialize_environment(self, initial_balance, risk_level, asset_type):
 
58
 
59
  # Initialize agent with correct dimensions
60
  self.agent = VisualTradingAgent(
61
+ state_dim=(84, 84, 4),
62
  action_dim=4
63
  )
64
 
 
179
  # Calculate performance metrics
180
  initial_balance = self.env.initial_balance
181
  final_net_worth = info['net_worth']
182
+ total_return = 0.0
183
+ if initial_balance > 0:
184
+ total_return = (final_net_worth - initial_balance) / initial_balance * 100
185
 
186
  summary = (
187
  f"🎯 Episode Completed!\n"
 
209
  training_history = []
210
 
211
  try:
212
+ num_episodes = int(num_episodes)
213
+ for episode in range(num_episodes):
214
  state = self.env.reset()
215
+ episode_reward = 0.0
216
  done = False
217
  steps = 0
218
 
219
+ while not done and steps < 100:
220
  action = self.agent.select_action(state)
221
  next_state, reward, done, info = self.env.step(action)
222
  self.agent.store_transition(state, action, reward, next_state, done)
 
231
  'episode': episode,
232
  'reward': episode_reward,
233
  'net_worth': info['net_worth'],
234
+ 'loss': loss,
235
  'steps': steps
236
  })
237
 
238
  # Yield progress every 5 episodes or at the end
239
  if episode % 5 == 0 or episode == num_episodes - 1:
240
  progress_chart = self.create_training_progress(training_history)
241
+
242
  status = (
243
  f"🔄 Training Progress: {episode+1}/{num_episodes}\n"
244
  f"• Episode Reward: {episode_reward:.2f}\n"
245
  f"• Final Net Worth: ${info['net_worth']:.2f}\n"
246
+ f"• Loss: {loss:.4f}\n"
247
  f"• Epsilon: {self.agent.epsilon:.3f}"
248
  )
249
  yield progress_chart, status
250
 
 
251
  time.sleep(0.01)
252
 
253
  self.is_training = False
254
+
255
+ # Calculate average reward
256
+ rewards = [h['reward'] for h in training_history]
257
+ avg_reward = np.mean(rewards) if rewards else 0.0
258
+
259
  final_status = (
260
  f"✅ Training Completed!\n"
261
  f"• Total Episodes: {num_episodes}\n"
262
  f"• Final Epsilon: {self.agent.epsilon:.3f}\n"
263
+ f"• Average Reward: {avg_reward:.2f}"
264
  )
265
  yield self.create_training_progress(training_history), final_status
266
 
267
  except Exception as e:
268
  self.is_training = False
269
  error_msg = f"❌ Training error: {str(e)}"
270
+ print(f"Training error