OmidSakaki committed on
Commit
ef2e897
·
verified ·
1 Parent(s): ffc1f4e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +452 -0
app.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ import torch
6
+ import io
7
+ import base64
8
+ from PIL import Image
9
+ import plotly.graph_objects as go
10
+ from plotly.subplots import make_subplots
11
+ import time
12
+ import sys
13
+ import os
14
+
15
+ # Add src to path
16
+ sys.path.append('src')
17
+
18
+ from src.environments.visual_trading_env import VisualTradingEnvironment
19
+ from src.agents.visual_agent import VisualTradingAgent
20
+ from src.visualizers.chart_renderer import ChartRenderer
21
+ from src.utils.data_loader import DataLoader
22
+ from src.utils.config import TradingConfig
23
+
24
class TradingAIDemo:
    """Controller that connects the trading environment, the agent and the UI.

    It owns one environment and one agent, advances them one step or one
    episode at a time, trains the agent, and turns the collected history
    into Plotly figures for the Gradio front end.
    """

    # Display labels indexed by the integer action id (0..3) used everywhere
    # in this class.
    ACTION_NAMES = ['Hold', 'Buy', 'Sell', 'Close']

    def __init__(self):
        """Create the controller; no environment or agent exists yet."""
        self.config = TradingConfig()
        self.env = None  # set by initialize_environment()
        self.agent = None  # set by initialize_environment()
        self.current_state = None  # latest observation returned by the env
        self.is_training = False  # True only while train_agent() is running
        self.episode_history = []  # one dict per step taken via run_single_step()
        self.chart_renderer = ChartRenderer()

    @classmethod
    def _action_label(cls, action):
        """Return the display name for an action id, or 'Unknown (<id>)'."""
        try:
            index = int(action)
        except (TypeError, ValueError):
            return f"Unknown ({action})"
        if 0 <= index < len(cls.ACTION_NAMES):
            return cls.ACTION_NAMES[index]
        return f"Unknown ({index})"

    @classmethod
    def _action_frequency(cls, actions):
        """Count action ids and return ``(labels, counts)`` sorted by id.

        Labels are looked up per id, so a missing id (for example no 'Hold'
        so far) does not shift the names of the remaining bars.
        """
        counts = pd.Series(list(actions)).value_counts().sort_index()
        labels = [cls._action_label(action_id) for action_id in counts.index]
        return labels, counts.tolist()

    def _apply_learning_rate(self, learning_rate):
        """Best-effort: push the UI learning rate into the agent's optimizer."""
        # NOTE(review): assumes the agent exposes a torch-style ``optimizer``
        # with ``param_groups``; confirm against VisualTradingAgent. When it
        # does not, the value is ignored (which is what happened before).
        optimizer = getattr(self.agent, "optimizer", None)
        for group in getattr(optimizer, "param_groups", None) or []:
            group["lr"] = learning_rate

    def initialize_environment(self, initial_balance, risk_level, asset_type):
        """Build a fresh environment and agent from the UI settings.

        Args:
            initial_balance: Forwarded to ``VisualTradingEnvironment``.
            risk_level: Forwarded to ``VisualTradingEnvironment``.
            asset_type: Forwarded to ``VisualTradingEnvironment``.

        Returns:
            A status string for the UI: a success message, or the error text
            if construction failed.
        """
        try:
            env = VisualTradingEnvironment(
                initial_balance=initial_balance,
                risk_level=risk_level,
                asset_type=asset_type,
            )
            agent = VisualTradingAgent(
                state_dim=env.observation_space.shape[0],
                action_dim=env.action_space.n,
            )
            state = env.reset()
        except Exception as e:
            return f"❌ Error initializing environment: {str(e)}"

        # Commit only after everything succeeded so a failed re-initialization
        # never leaves a new env paired with an old (or missing) agent.
        self.env = env
        self.agent = agent
        self.current_state = state
        # Recorded steps belong to the previous environment; start clean.
        self.episode_history = []
        return "✅ Environment initialized successfully!"

    def run_single_step(self, action_choice):
        """Advance the environment by one step and refresh the dashboard.

        Args:
            action_choice: ``"AI Decision"`` to let the agent choose, or one
                of the names in ``ACTION_NAMES``.

        Returns:
            ``(figure, status)``; ``figure`` is ``None`` when the environment
            is not initialized or the step failed.
        """
        if self.env is None or self.agent is None:
            return None, "Please initialize environment first!"

        try:
            if action_choice == "AI Decision":
                action = self.agent.select_action(self.current_state)
            else:
                action = self.ACTION_NAMES.index(action_choice)

            next_state, reward, done, info = self.env.step(action)
            self.current_state = next_state

            # Plain int so labels and the frequency count behave the same
            # whatever numeric type the agent returned.
            action_id = int(action)

            # Record the step before drawing: the dashboard reads
            # episode_history, so drawing first would lag one step behind.
            self.episode_history.append({
                'step': len(self.episode_history),
                'action': action_id,
                'reward': reward,
                'net_worth': info['net_worth'],
                'balance': info['balance'],
                'position': info['position_size'],
            })

            fig = self.create_visualization(info, action_id, reward)

            status = (
                f"Action: {self._action_label(action_id)} | "
                f"Reward: {reward:.3f} | Net Worth: ${info['net_worth']:.2f}"
            )
            if done:
                status += " | Episode Completed!"

            return fig, status

        except Exception as e:
            return None, f"❌ Error during step: {str(e)}"

    def run_episode(self, num_steps):
        """Reset the environment and let the agent act for up to ``num_steps``.

        Args:
            num_steps: Maximum number of steps; the loop stops earlier when
                the environment reports ``done``.

        Returns:
            ``(figure, summary)``; ``figure`` is ``None`` when the environment
            is not initialized or the episode failed.
        """
        if self.env is None or self.agent is None:
            return None, "Please initialize environment first!"

        try:
            # Keep the observation from reset(); otherwise the first action
            # would be chosen from whatever state was left over before.
            self.current_state = self.env.reset()
            total_reward = 0
            step_data = []
            info = None

            for step in range(int(num_steps)):
                action = self.agent.select_action(self.current_state)
                next_state, reward, done, info = self.env.step(action)
                self.current_state = next_state
                total_reward += reward

                step_data.append({
                    'step': step,
                    'action': action,
                    'reward': reward,
                    'net_worth': info['net_worth'],
                    'price': info['current_price'],
                })

                if done:
                    break

            fig = self.create_episode_summary(step_data)
            if info is None:
                # Zero requested steps: there is no final net worth to report.
                return fig, "No steps were executed."

            summary = (
                f"Episode completed! Total Reward: {total_reward:.2f} | "
                f"Final Net Worth: ${info['net_worth']:.2f}"
            )
            return fig, summary

        except Exception as e:
            return None, f"❌ Error during episode: {str(e)}"

    def train_agent(self, num_episodes, learning_rate):
        """Train the agent, yielding a progress figure after every episode.

        This is a generator: each item is ``(figure, status)``.

        Args:
            num_episodes: Number of episodes to run.
            learning_rate: Applied to the agent's optimizer when one is
                exposed (see ``_apply_learning_rate``).
        """
        if self.env is None or self.agent is None:
            # A generator's ``return value`` never reaches the consumer, so
            # the message has to be yielded.
            yield None, "Please initialize environment first!"
            return

        self.is_training = True
        progress = []

        try:
            num_episodes = int(num_episodes)
            self._apply_learning_rate(learning_rate)

            for episode in range(num_episodes):
                state = self.env.reset()
                episode_reward = 0
                done = False

                while not done:
                    action = self.agent.select_action(state)
                    next_state, reward, done, info = self.env.step(action)
                    self.agent.store_transition(state, action, reward, next_state, done)
                    state = next_state
                    # Keep the stored observation in sync with the env so a
                    # manual step after training does not act on stale state.
                    self.current_state = state
                    episode_reward += reward

                loss = self.agent.update()

                progress.append({
                    'episode': episode,
                    'reward': episode_reward,
                    'net_worth': info['net_worth'],
                    'loss': loss,
                })

                yield (
                    self.create_training_progress(progress),
                    f"Training... Episode {episode+1}/{num_episodes}",
                )

            yield self.create_training_progress(progress), "✅ Training completed!"

        except Exception as e:
            yield None, f"❌ Training error: {str(e)}"
        finally:
            # Also runs when the consumer closes the generator early.
            self.is_training = False

    def create_visualization(self, info, action, reward):
        """Build the 2x2 live dashboard shown after a manual step.

        Args:
            info: Step info dict from the environment (currently unused here).
            action: Action id of the step, shown in the title.
            reward: Reward of the step, shown in the title.

        Returns:
            A Plotly figure with the price history, the last 50 net-worth
            values, the action distribution and the last 20 rewards.
        """
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=['Price Chart & Actions', 'Portfolio Performance',
                            'Action Distribution', 'Reward History'],
            specs=[[{"secondary_y": True}, {}],
                   [{}, {}]],
            vertical_spacing=0.1,
            horizontal_spacing=0.1,
        )

        # Price chart
        price_data = self.env.get_price_history()
        fig.add_trace(
            go.Scatter(x=list(range(len(price_data))), y=price_data,
                       mode='lines', name='Price', line=dict(color='blue')),
            row=1, col=1,
        )

        # Portfolio value over the most recent 50 recorded steps
        portfolio_history = [h['net_worth'] for h in self.episode_history[-50:]]
        if portfolio_history:
            fig.add_trace(
                go.Scatter(x=list(range(len(portfolio_history))), y=portfolio_history,
                           mode='lines', name='Portfolio', line=dict(color='green')),
                row=1, col=2,
            )

        # Action distribution; labels follow the ids that actually occurred.
        if self.episode_history:
            labels, counts = self._action_frequency(
                h['action'] for h in self.episode_history
            )
            fig.add_trace(
                go.Bar(x=labels, y=counts, name='Actions'),
                row=2, col=1,
            )

        # Rewards of the most recent 20 recorded steps
        rewards = [h['reward'] for h in self.episode_history[-20:]]
        if rewards:
            fig.add_trace(
                go.Scatter(x=list(range(len(rewards))), y=rewards,
                           mode='lines+markers', name='Rewards', line=dict(color='orange')),
                row=2, col=2,
            )

        fig.update_layout(
            height=600,
            showlegend=True,
            title_text=(
                f"Trading Dashboard | Action: {self._action_label(action)} "
                f"| Reward: {reward:.3f}"
            ),
        )

        return fig

    def create_episode_summary(self, step_data):
        """Build the 2x2 summary figure for one finished episode.

        Args:
            step_data: List of dicts with the keys ``step``, ``action``,
                ``reward``, ``net_worth`` and ``price``.

        Returns:
            A Plotly figure; an empty figure when ``step_data`` is empty.
        """
        if not step_data:
            return go.Figure()

        df = pd.DataFrame(step_data)

        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=['Portfolio Value Over Time', 'Cumulative Rewards',
                            'Action Frequency', 'Price vs Actions'],
            specs=[[{}, {}], [{}, {}]],
        )

        # Portfolio value
        fig.add_trace(
            go.Scatter(x=df['step'], y=df['net_worth'], mode='lines',
                       name='Portfolio Value', line=dict(color='green')),
            row=1, col=1,
        )

        # Cumulative rewards
        df['cumulative_reward'] = df['reward'].cumsum()
        fig.add_trace(
            go.Scatter(x=df['step'], y=df['cumulative_reward'], mode='lines',
                       name='Cumulative Reward', line=dict(color='orange')),
            row=1, col=2,
        )

        # Action frequency
        labels, counts = self._action_frequency(df['action'])
        fig.add_trace(
            go.Bar(x=labels, y=counts, name='Actions'),
            row=2, col=1,
        )

        # Price line, with markers where the agent bought (1) or sold (2)
        fig.add_trace(
            go.Scatter(x=df['step'], y=df['price'], mode='lines',
                       name='Price', line=dict(color='blue')),
            row=2, col=2,
        )

        buy_actions = df[df['action'] == 1]
        sell_actions = df[df['action'] == 2]

        if not buy_actions.empty:
            fig.add_trace(
                go.Scatter(x=buy_actions['step'], y=buy_actions['price'],
                           mode='markers', name='Buy',
                           marker=dict(color='green', size=10, symbol='triangle-up')),
                row=2, col=2,
            )

        if not sell_actions.empty:
            fig.add_trace(
                go.Scatter(x=sell_actions['step'], y=sell_actions['price'],
                           mode='markers', name='Sell',
                           marker=dict(color='red', size=10, symbol='triangle-down')),
                row=2, col=2,
            )

        fig.update_layout(height=600, showlegend=True, title_text="Episode Summary")
        return fig

    def create_training_progress(self, progress):
        """Build the 2x2 training-progress figure.

        Args:
            progress: List of dicts with the keys ``episode``, ``reward``,
                ``net_worth`` and ``loss``.

        Returns:
            A Plotly figure; an empty figure when ``progress`` is empty.
        """
        if not progress:
            return go.Figure()

        df = pd.DataFrame(progress)

        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=['Episode Rewards', 'Portfolio Value',
                            'Training Loss', 'Performance Metrics'],
            specs=[[{}, {}], [{}, {}]],
        )

        # Rewards
        fig.add_trace(
            go.Scatter(x=df['episode'], y=df['reward'], mode='lines+markers',
                       name='Reward', line=dict(color='blue')),
            row=1, col=1,
        )

        # Portfolio value
        fig.add_trace(
            go.Scatter(x=df['episode'], y=df['net_worth'], mode='lines+markers',
                       name='Net Worth', line=dict(color='green')),
            row=1, col=2,
        )

        # Loss
        if 'loss' in df.columns:
            fig.add_trace(
                go.Scatter(x=df['episode'], y=df['loss'], mode='lines+markers',
                           name='Loss', line=dict(color='red')),
                row=2, col=1,
            )

        # 10-episode moving average; drawn only once more than 10 episodes exist
        if len(df) > 10:
            df['ma_reward'] = df['reward'].rolling(window=10).mean()
            fig.add_trace(
                go.Scatter(x=df['episode'], y=df['ma_reward'], mode='lines',
                           name='MA Reward (10)', line=dict(color='orange', dash='dash')),
                row=2, col=2,
            )

        fig.update_layout(height=600, showlegend=True, title_text="Training Progress")
        return fig
334
+
335
# Single module-level controller instance. create_interface() binds this
# object's methods as its event handlers, so all UI events operate on the
# state held by this one object.
demo = TradingAIDemo()
337
+
338
+ # Create Gradio interface
339
def create_interface():
    """Build the Gradio Blocks UI and wire it to the module-level ``demo``.

    The layout has a configuration column, a dashboard column, step/episode
    controls, a training section and a static metrics table. Button clicks
    are bound to methods of ``demo``.

    Returns:
        The assembled ``gr.Blocks`` interface (not yet launched).
    """
    # One source of truth for the episode length so the button label and the
    # handler cannot drift apart.
    episode_steps = 50

    with gr.Blocks(theme=gr.themes.Soft(), title="Visual Trading AI") as interface:
        gr.Markdown("""
        # 🚀 Visual Trading AI
        *Intelligent Trading Agent with Visual Market Analysis*

        This AI agent learns to trade by analyzing price charts visually using Deep Reinforcement Learning.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                # Configuration section
                gr.Markdown("## ⚙️ Configuration")

                initial_balance = gr.Slider(1000, 50000, value=10000, step=1000,
                                            label="Initial Balance ($)")
                risk_level = gr.Radio(["Low", "Medium", "High"], value="Medium",
                                      label="Risk Level")
                asset_type = gr.Radio(["Stock", "Crypto", "Forex"], value="Stock",
                                      label="Asset Type")

                init_btn = gr.Button("🚀 Initialize Environment", variant="primary")
                init_status = gr.Textbox(label="Status", interactive=False)

            with gr.Column(scale=2):
                # Visualization output
                plot_output = gr.Plot(label="Trading Dashboard")
                status_output = gr.Textbox(label="Step Status", interactive=False)

        with gr.Row():
            # Action controls
            action_choice = gr.Radio(["AI Decision", "Buy", "Sell", "Hold", "Close"],
                                     value="AI Decision", label="Action Selection")
            step_btn = gr.Button("▶️ Execute Step", variant="secondary")
            episode_btn = gr.Button(f"🎯 Run Episode ({episode_steps} steps)",
                                    variant="secondary")

        with gr.Row():
            # Training section
            gr.Markdown("## 🎓 AI Training")

            with gr.Column():
                num_episodes = gr.Slider(10, 1000, value=100, step=10,
                                         label="Training Episodes")
                learning_rate = gr.Slider(0.0001, 0.01, value=0.001, step=0.0001,
                                          label="Learning Rate")
                train_btn = gr.Button("🤖 Start Training", variant="primary")

            with gr.Column():
                training_plot = gr.Plot(label="Training Progress")
                training_status = gr.Textbox(label="Training Status")

        with gr.Row():
            # Information section
            gr.Markdown("## 📊 Performance Metrics")
            # NOTE(review): no handler in this function writes to this table,
            # so it keeps showing these placeholder values.
            metrics = gr.DataFrame(
                headers=["Metric", "Value"],
                value=[["Total Steps", "0"], ["Total Reward", "0"],
                       ["Current Net Worth", "$10,000"], ["Best Action", "Hold"]],
                row_count=4, col_count=2, interactive=False
            )

        # Event handlers
        init_btn.click(
            demo.initialize_environment,
            inputs=[initial_balance, risk_level, asset_type],
            outputs=[init_status]
        )

        step_btn.click(
            demo.run_single_step,
            inputs=[action_choice],
            outputs=[plot_output, status_output]
        )

        episode_btn.click(
            lambda: demo.run_episode(episode_steps),
            outputs=[plot_output, status_output]
        )

        # train_agent is a generator; each yielded pair updates both outputs.
        train_btn.click(
            demo.train_agent,
            inputs=[num_episodes, learning_rate],
            outputs=[training_plot, training_status]
        )

        gr.Markdown("""
        ## 🔧 How It Works

        **Architecture:**
        - **Visual Processing**: CNN analyzes price charts
        - **Reinforcement Learning**: PPO algorithm learns trading strategies
        - **Real-time Visualization**: Interactive dashboard shows agent decisions

        **Features:**
        - 🎯 Visual market analysis
        - 🤖 Deep RL-based decision making
        - 📊 Real-time performance tracking
        - 🎮 Interactive control
        - 📈 Professional visualization

        *Built with PyTorch, Gym, and Gradio*
        """)

    return interface
443
+
444
+ # Create and launch interface
445
if __name__ == "__main__":
    # Launch settings gathered in one place: public share link, listen on
    # all interfaces on port 7860, and surface handler errors in the UI.
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    interface = create_interface()
    interface.launch(**launch_options)