OmidSakaki commited on
Commit
ba725ea
·
verified ·
1 Parent(s): 85402cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +288 -908
app.py CHANGED
@@ -1,957 +1,337 @@
1
- """
2
- Advanced AI Trading Demo - Hugging Face Spaces
3
- Deep Q-Network (DQN) Reinforcement Learning for Financial Trading Simulation
4
-
5
- Author: AI Trading Team
6
- License: MIT
7
- """
8
-
9
  import gradio as gr
 
10
  import numpy as np
11
- import torch
12
- import torch.nn as nn
13
- import torch.optim as optim
14
- from collections import deque
15
- import random
16
- from typing import Dict, Tuple, Any, List, Optional, Generator
17
  import plotly.graph_objects as go
18
  from plotly.subplots import make_subplots
19
- import logging
20
- import os
21
- from datetime import datetime
22
- import json
23
- from dataclasses import dataclass # ✅ Added missing import
24
-
25
- # Configure logging
26
- logging.basicConfig(level=logging.INFO)
27
- logger = logging.getLogger(__name__)
28
-
29
- # ---- 1. Enhanced Configuration ----
30
- @dataclass
31
- class TradingConfig:
32
- """Central configuration for trading environment and agent."""
33
-
34
- # Financial Parameters
35
- initial_balance: float = 10000.0
36
- max_steps: int = 1000
37
- transaction_cost: float = 0.001 # 0.1%
38
-
39
- # Asset Settings
40
- asset_type: str = "Crypto"
41
- risk_level: str = "Medium"
42
-
43
- # DQN Parameters
44
- learning_rate: float = 0.0001
45
- gamma: float = 0.99
46
- epsilon_start: float = 1.0
47
- epsilon_min: float = 0.01
48
- epsilon_decay: float = 0.9995
49
- batch_size: int = 32
50
- memory_size: int = 10000
51
- target_update: int = 100
52
- hidden_size: int = 128
53
-
54
- # Risk Management
55
- risk_multipliers: Dict[str, float] = None
56
-
57
- def __post_init__(self):
58
- if self.risk_multipliers is None:
59
- self.risk_multipliers = {
60
- "Low": 0.5,
61
- "Medium": 1.0,
62
- "High": 2.0
63
- }
64
-
65
- # Asset-specific volatility
66
- self.volatility_map = {
67
- "Crypto": 0.03,
68
- "Stock": 0.015,
69
- "Forex": 0.008
70
- }
71
 
72
- # ---- 2. Professional Trading Environment ----
73
- class AdvancedTradingEnvironment:
74
- """
75
- Realistic financial market simulator with multi-asset support,
76
- technical indicators, sentiment analysis, and risk management.
77
- """
78
-
79
- def __init__(self, config: TradingConfig):
80
- self.config = config
81
- self._reset_state()
82
- self._initialize_market_data()
83
-
84
- # Action space: 0=Hold, 1=Buy, 2=Sell, 3=Close
85
- self.action_space = 4
86
- self.observation_space_dim = 15
87
 
88
- # Tracking
89
- self.portfolio_history = []
90
- self.action_history = []
91
- self.trade_log = []
92
-
93
- def _reset_state(self):
94
- """Reset internal state to initial conditions."""
95
- self.balance = self.config.initial_balance
96
- self.position = 0.0
97
- self.current_price = 100.0
98
- self.step_count = 0
99
- self.portfolio_history = []
100
- self.action_history = []
101
- self.trade_log = []
102
-
103
- def _initialize_market_data(self):
104
- """Generate initial market data with realistic dynamics."""
105
- n_points = 200
106
- volatility = self.config.volatility_map.get(
107
- self.config.asset_type, 0.02
108
- )
109
-
110
- self.price_history = []
111
- self.volume_history = []
112
- self.sentiment_history = []
113
-
114
- base_price = 100.0
115
- for i in range(n_points):
116
- # Generate price with momentum and noise
117
- momentum = np.sin(i * 0.05) * 2
118
- noise = np.random.normal(0, volatility)
119
- price = base_price * (1 + momentum * 0.01 + noise)
120
- self.price_history.append(max(10.0, price))
121
 
122
- # Volume correlated with price movement
123
- volume = 1000 + abs(price - base_price) * 50 + np.random.normal(0, 200)
124
- self.volume_history.append(max(100, volume))
125
 
126
- # Sentiment with mean reversion
127
- if i > 0:
128
- prev_sentiment = self.sentiment_history[-1]
129
- sentiment_change = np.random.normal(0, 0.08)
130
- sentiment = prev_sentiment + sentiment_change
131
- else:
132
- sentiment = 0.5 + np.random.normal(0, 0.1)
133
- self.sentiment_history.append(np.clip(sentiment, 0.0, 1.0))
134
 
135
- self.current_price = self.price_history[-1]
136
-
137
- def _calculate_technical_indicators(self) -> List[float]:
138
- """Compute comprehensive technical indicators."""
139
- prices = np.array(self.price_history[-50:])
140
- if len(prices) < 2:
141
- return [0.0] * 6
142
 
143
- # Price returns
144
- returns = np.diff(prices) / prices[:-1]
145
-
146
- # Moving averages
147
- sma_short = np.mean(prices[-10:]) if len(prices) >= 10 else prices[-1]
148
- sma_long = np.mean(prices[-20:]) if len(prices) >= 20 else prices[-1]
149
-
150
- # RSI
151
- if len(returns) >= 14:
152
- gains = returns[returns > 0]
153
- losses = -returns[returns < 0]
154
- avg_gain = np.mean(gains[-14:]) if len(gains) > 0 else 0.001
155
- avg_loss = np.mean(losses[-14:]) if len(losses) > 0 else 0.001
156
- rs = avg_gain / avg_loss if avg_loss != 0 else 100
157
- rsi = 100 - (100 / (1 + rs))
158
- else:
159
- rsi = 50.0
160
 
161
- # Volatility and momentum
162
- volatility = np.std(returns) * np.sqrt(252) if len(returns) > 1 else 0.1
163
- momentum = (prices[-1] / prices[-5] - 1) if len(prices) >= 5 else 0.0
164
-
165
- # Volume trend
166
- volumes = np.array(self.volume_history[-10:])
167
- volume_trend = (np.mean(volumes[-5:]) / np.mean(volumes[-10:]) - 1
168
- if len(volumes) >= 10 else 0.0)
169
-
170
- # Normalize features
171
- return [
172
- sma_short / 100.0,
173
- sma_long / 100.0,
174
- rsi / 100.0,
175
- volatility,
176
- momentum,
177
- volume_trend
178
- ]
179
-
180
- def reset(self) -> Tuple[np.ndarray, Dict]:
181
- """Reset environment to initial state."""
182
- self._reset_state()
183
- self._initialize_market_data() # ✅ Reinitialize market data on reset
184
- obs = self._get_observation()
185
- info = self._get_info()
186
- return obs, info
187
 
188
- def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict]:
189
- """Execute one step in the environment."""
190
- self.step_count += 1
191
- self._update_market_data()
192
- reward = self._execute_action(action)
193
 
194
- terminated = (self.balance <= 0 or
195
- self.step_count >= self.config.max_steps)
196
- truncated = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
- obs = self._get_observation()
199
- info = self._get_info()
 
 
200
 
201
- self.portfolio_history.append(info['net_worth'])
202
- self.action_history.append(action)
 
203
 
204
- return obs, reward, terminated, truncated, info
205
-
206
- def _update_market_data(self):
207
- """Update market data with realistic dynamics."""
208
- # Momentum from recent returns
209
- recent_prices = self.price_history[-5:]
210
- if len(recent_prices) >= 2:
211
- prev_returns = np.diff(recent_prices) / recent_prices[:-1]
212
- momentum = np.mean(prev_returns)
213
  else:
214
- momentum = 0
215
-
216
- # Asset-specific volatility with risk multiplier
217
- base_vol = self.config.volatility_map.get(self.config.asset_type, 0.015)
218
- volatility = base_vol * self.config.risk_multipliers[self.config.risk_level]
219
-
220
- # Price evolution
221
- price_change = momentum * 0.3 + np.random.normal(0, volatility)
222
- self.current_price = max(10.0, self.current_price * (1 + price_change))
223
- self.price_history.append(self.current_price)
224
-
225
- # Volume generation
226
- base_volume = 1000
227
- volume_noise = np.random.normal(0, 200)
228
- new_volume = max(100, base_volume + abs(price_change) * 5000 + volume_noise)
229
- self.volume_history.append(new_volume)
230
 
231
- # Sentiment with mean reversion
232
- current_sentiment = self.sentiment_history[-1]
233
- sentiment_reversion = (0.5 - current_sentiment) * 0.1
234
- sentiment_noise = np.random.normal(0, 0.08)
235
- new_sentiment = np.clip(current_sentiment + sentiment_reversion + sentiment_noise, 0.0, 1.0)
236
- self.sentiment_history.append(new_sentiment)
237
 
238
- def _execute_action(self, action: int) -> float:
239
- """Execute trading action and calculate reward."""
240
- prev_net_worth = self.balance + self.position * self.current_price
241
-
242
- # Risk-adjusted position sizing
243
- trade_size_multiplier = 0.2 * self.config.risk_multipliers[self.config.risk_level]
244
-
245
- if action == 1: # Buy
246
- if self.balance > 0:
247
- trade_amount = min(self.balance * trade_size_multiplier, self.balance)
248
- cost = trade_amount * (1 + self.config.transaction_cost)
249
- if cost <= self.balance:
250
- shares_bought = trade_amount / self.current_price
251
- self.position += shares_bought
252
- self.balance -= cost
253
- self.trade_log.append({
254
- 'type': 'BUY',
255
- 'shares': shares_bought,
256
- 'price': self.current_price,
257
- 'timestamp': self.step_count
258
- })
259
-
260
- elif action == 2: # Sell
261
- if self.position > 0:
262
- sell_fraction = trade_size_multiplier
263
- shares_to_sell = min(self.position * sell_fraction, self.position)
264
- proceeds = shares_to_sell * self.current_price * (1 - self.config.transaction_cost)
265
- self.position -= shares_to_sell
266
- self.balance += proceeds
267
- self.trade_log.append({
268
- 'type': 'SELL',
269
- 'shares': shares_to_sell,
270
- 'price': self.current_price,
271
- 'timestamp': self.step_count
272
- })
273
-
274
- elif action == 3: # Close Position
275
- if self.position > 0:
276
- proceeds = self.position * self.current_price * (1 - self.config.transaction_cost)
277
- self.balance += proceeds
278
- self.trade_log.append({
279
- 'type': 'CLOSE',
280
- 'shares': self.position,
281
- 'price': self.current_price,
282
- 'timestamp': self.step_count
283
- })
284
- self.position = 0
285
-
286
- # Reward calculation with risk management
287
- new_net_worth = self.balance + self.position * self.current_price
288
- raw_reward = (new_net_worth - prev_net_worth) / self.config.initial_balance * 100
289
-
290
- # Risk penalty for significant drawdowns
291
- risk_penalty = 0.0
292
- if new_net_worth < self.config.initial_balance * 0.8:
293
- risk_penalty = (self.config.initial_balance - new_net_worth) / self.config.initial_balance * 10
294
 
295
- final_reward = raw_reward - risk_penalty
296
- return final_reward
297
-
298
- def _get_observation(self) -> np.ndarray:
299
- """Generate observation vector for the agent."""
300
- # Price features
301
- recent_prices = (self.price_history[-20:] if len(self.price_history) >= 20
302
- else [self.current_price] * 20)
303
- price_features = [
304
- self.current_price / 100.0,
305
- np.mean(recent_prices) / 100.0,
306
- np.std(recent_prices) / 100.0,
307
- (self.current_price - np.min(recent_prices)) /
308
- (np.max(recent_prices) - np.min(recent_prices) + 1e-8)
309
- ]
310
-
311
- # Portfolio features
312
- portfolio_features = [
313
- self.balance / self.config.initial_balance,
314
- (self.position * self.current_price) / self.config.initial_balance,
315
- self.step_count / self.config.max_steps
316
- ]
317
-
318
- # Sentiment features
319
- recent_sentiments = (self.sentiment_history[-10:] if len(self.sentiment_history) >= 10
320
- else [0.5] * 10)
321
- sentiment_features = [
322
- np.mean(recent_sentiments),
323
- np.std(recent_sentiments),
324
- recent_sentiments[-1]
325
- ]
326
 
327
- # Technical indicators
328
- technical_features = self._calculate_technical_indicators()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
 
330
- # Combine all features (should be 4 + 3 + 3 + 6 = 16, take first 15)
331
- all_features = price_features + portfolio_features + sentiment_features + technical_features
332
- observation = np.array(all_features[:15], dtype=np.float32)
333
 
334
- return observation
 
 
 
 
 
335
 
336
- def _get_info(self) -> Dict[str, Any]:
337
- """Get current environment information."""
338
- net_worth = self.balance + self.position * self.current_price
339
- return_total = ((net_worth - self.config.initial_balance) /
340
- self.config.initial_balance * 100)
 
 
 
 
 
341
 
342
  return {
343
- 'net_worth': net_worth,
344
- 'return_percent': return_total,
345
- 'position_value': self.position * self.current_price,
346
- 'cash_balance': self.balance,
347
- 'current_price': self.current_price,
348
- 'steps': self.step_count,
349
- 'position': self.position,
350
- 'balance': self.balance
 
351
  }
352
 
353
- # ---- 3. Deep Q-Network Agent ----
354
- class DQNAgent:
355
- """Deep Q-Network agent for reinforcement learning trading."""
356
-
357
- ACTION_NAMES = {0: 'Hold', 1: 'Buy', 2: 'Sell', 3: 'Close'}
 
 
358
 
359
- def __init__(self, state_dim: int, action_dim: int, config: TradingConfig,
360
- device: str = None):
361
- self.state_dim = state_dim
362
- self.action_dim = action_dim
363
- self.config = config
364
- self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
365
-
366
- # Networks
367
- self.q_network = self._build_network()
368
- self.target_network = self._build_network()
369
- self.target_network.load_state_dict(self.q_network.state_dict())
370
- self.target_network.eval()
371
-
372
- # Training components
373
- self.optimizer = optim.Adam(self.q_network.parameters(), lr=config.learning_rate)
374
- self.criterion = nn.MSELoss()
375
- self.memory = deque(maxlen=config.memory_size)
376
-
377
- # Epsilon-greedy parameters
378
- self.epsilon = config.epsilon_start
379
- self.epsilon_min = config.epsilon_min
380
- self.epsilon_decay = config.epsilon_decay
381
-
382
- # Training state
383
- self.batch_size = config.batch_size
384
- self.gamma = config.gamma
385
- self.target_update = config.target_update
386
- self.steps = 0
387
 
388
- def _build_network(self) -> nn.Module:
389
- """Build deep neural network for Q-value approximation."""
390
- return nn.Sequential(
391
- nn.Linear(self.state_dim, self.config.hidden_size),
392
- nn.ReLU(),
393
- nn.Dropout(0.2),
394
- nn.Linear(self.config.hidden_size, self.config.hidden_size),
395
- nn.ReLU(),
396
- nn.Dropout(0.2),
397
- nn.Linear(self.config.hidden_size, self.config.hidden_size // 2),
398
- nn.ReLU(),
399
- nn.Linear(self.config.hidden_size // 2, self.action_dim)
400
- ).to(self.device)
401
 
402
- @torch.no_grad()
403
- def select_action(self, state: np.ndarray, training: bool = True) -> int:
404
- """Select action using epsilon-greedy policy."""
405
- if training and random.random() < self.epsilon:
406
- return random.randint(0, self.action_dim - 1)
407
-
408
- state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
409
- q_values = self.q_network(state_tensor)
410
- return q_values.argmax(1).item()
411
 
412
- def store_transition(self, state: np.ndarray, action: int,
413
- reward: float, next_state: np.ndarray, done: bool):
414
- """Store experience tuple in replay memory."""
415
- self.memory.append((state, action, reward, next_state, done))
 
416
 
417
- def update(self) -> Optional[float]:
418
- """Update Q-network using experience replay."""
419
- if len(self.memory) < self.batch_size:
420
- return None
421
-
422
- batch = random.sample(self.memory, self.batch_size)
423
- states, actions, rewards, next_states, dones = zip(*batch)
424
-
425
- # Convert to tensors
426
- states = torch.FloatTensor(np.array(states)).to(self.device)
427
- actions = torch.LongTensor(actions).to(self.device)
428
- rewards = torch.FloatTensor(rewards).to(self.device)
429
- next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
430
- dones = torch.BoolTensor(dones).to(self.device)
431
-
432
- # Current Q-values
433
- current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
434
-
435
- # Target Q-values
436
- with torch.no_grad():
437
- next_q_values = self.target_network(next_states).max(1)[0]
438
- target_q_values = rewards + self.gamma * next_q_values * (~dones).float()
439
-
440
- # Compute loss and update
441
- loss = self.criterion(current_q_values, target_q_values)
442
- self.optimizer.zero_grad()
443
- loss.backward()
444
- torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 1.0)
445
- self.optimizer.step()
446
-
447
- self.steps += 1
448
-
449
- # Update target network
450
- if self.steps % self.target_update == 0:
451
- self.target_network.load_state_dict(self.q_network.state_dict())
452
-
453
- # Decay epsilon
454
- self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
455
-
456
- return loss.item()
457
 
458
- def save_checkpoint(self, path: str):
459
- """Save agent checkpoint."""
460
- checkpoint = {
461
- 'q_network': self.q_network.state_dict(),
462
- 'target_network': self.target_network.state_dict(),
463
- 'optimizer': self.optimizer.state_dict(),
464
- 'epsilon': self.epsilon,
465
- 'steps': self.steps,
466
- 'config': self.config.__dict__
467
- }
468
- torch.save(checkpoint, path)
469
- logger.info(f"Model saved to {path}")
470
 
471
- def load_checkpoint(self, path: str):
472
- """Load agent checkpoint."""
473
- if os.path.exists(path):
474
- try:
475
- checkpoint = torch.load(path, map_location=self.device)
476
- self.q_network.load_state_dict(checkpoint['q_network'])
477
- self.target_network.load_state_dict(checkpoint['target_network'])
478
- self.optimizer.load_state_dict(checkpoint['optimizer'])
479
- self.epsilon = checkpoint['epsilon']
480
- self.steps = checkpoint['steps']
481
- logger.info(f"Model loaded from {path}")
482
- except Exception as e:
483
- logger.warning(f"Failed to load model from {path}: {e}")
484
 
485
- # ---- 4. Main Trading Application ----
486
- class TradingDemo:
487
- """Main application integrating environment, agent, and UI."""
 
488
 
489
- def __init__(self):
490
- self.config = TradingConfig()
491
- self.env: Optional[AdvancedTradingEnvironment] = None
492
- self.agent: Optional[DQNAgent] = None
493
- self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
494
-
495
- # Training history
496
- self.training_history = {
497
- 'rewards': [],
498
- 'losses': [],
499
- 'epsilons': [],
500
- 'net_worths': []
501
- }
502
-
503
- # Model persistence
504
- self.model_path = "dqn_trading_model.pth"
505
- self.load_model_if_exists()
506
 
507
- def load_model_if_exists(self):
508
- """Load existing model if available."""
509
- self.agent = None # Reset agent first
510
- if os.path.exists(self.model_path):
511
- try:
512
- # Create agent first, then load
513
- temp_config = TradingConfig()
514
- temp_env = AdvancedTradingEnvironment(temp_config)
515
- self.agent = DQNAgent(
516
- state_dim=temp_env.observation_space_dim,
517
- action_dim=temp_env.action_space,
518
- config=temp_config,
519
- device=self.device
520
- )
521
- self.agent.load_checkpoint(self.model_path)
522
- logger.info("✅ Loaded existing model checkpoint")
523
- except Exception as e:
524
- logger.warning(f"Failed to load model: {e}")
525
- self.agent = None
526
 
527
- def initialize(self, balance: float, risk: str, asset: str) -> str:
528
- """Initialize trading system with new parameters."""
529
- try:
530
- self.config.initial_balance = float(balance)
531
- self.config.risk_level = risk
532
- self.config.asset_type = asset
533
-
534
- self.env = AdvancedTradingEnvironment(self.config)
535
- self.agent = DQNAgent(
536
- state_dim=self.env.observation_space_dim,
537
- action_dim=self.env.action_space,
538
- config=self.config,
539
- device=self.device
540
- )
541
-
542
- # Reset training history
543
- self.training_history = {
544
- 'rewards': [],
545
- 'losses': [],
546
- 'epsilons': [],
547
- 'net_worths': []
548
- }
549
-
550
- return (f"✅ System initialized successfully!\n"
551
- f"💰 Balance: ${balance:,.2f}\n"
552
- f"🎯 Risk: {risk}\n"
553
- f"📈 Asset: {asset}\n"
554
- f"💻 Device: {self.device}")
555
-
556
- except Exception as e:
557
- logger.error(f"Initialization failed: {e}")
558
- return f"❌ Initialization failed: {str(e)}"
559
 
560
- def train(self, episodes: int) -> Generator[Tuple[str, Optional[go.Figure]], None, None]:
561
- """Train the DQN agent."""
562
- if not self.env or not self.agent:
563
- yield " Please initialize the system first!", None
564
- return
565
-
566
- try:
567
- episodes = int(episodes)
568
- logger.info(f"Starting training for {episodes} episodes")
569
-
570
- for episode in range(episodes):
571
- obs, _ = self.env.reset()
572
- total_reward = 0
573
- episode_losses = []
574
- done = False
575
-
576
- while not done:
577
- action = self.agent.select_action(obs)
578
- next_obs, reward, terminated, truncated, info = self.env.step(action)
579
- done = terminated or truncated
580
-
581
- self.agent.store_transition(obs, action, reward, next_obs, done)
582
- loss = self.agent.update()
583
- if loss is not None:
584
- episode_losses.append(loss)
585
-
586
- total_reward += reward
587
- obs = next_obs
588
-
589
- # Record episode metrics
590
- avg_loss = np.mean(episode_losses) if episode_losses else 0.0
591
- self.training_history['rewards'].append(total_reward)
592
- self.training_history['losses'].append(avg_loss)
593
- self.training_history['epsilons'].append(self.agent.epsilon)
594
- self.training_history['net_worths'].append(info['net_worth'])
595
-
596
- # Progress update
597
- final_return = info['return_percent']
598
- progress = (f"Episode {episode+1}/{episodes}\n"
599
- f"📈 Reward: {total_reward:.2f}\n"
600
- f"📉 Loss: {avg_loss:.4f}\n"
601
- f"🎯 Epsilon: {self.agent.epsilon:.3f}\n"
602
- f"💰 Net Worth: ${info['net_worth']:,.2f}\n"
603
- f"📊 Return: {final_return:.2f}%")
604
-
605
- # Generate plot every 10 episodes or at end
606
- if (episode + 1) % 10 == 0 or episode == episodes - 1:
607
- plot = self._create_training_plot()
608
- yield progress, plot
609
- else:
610
- yield progress, None
611
-
612
- # Save trained model
613
- if self.agent:
614
- self.agent.save_checkpoint(self.model_path)
615
- yield "✅ Training completed! Model saved.", self._create_training_plot()
616
-
617
- except Exception as e:
618
- logger.error(f"Training error: {e}")
619
- yield f"❌ Training failed: {str(e)}", None
620
 
621
- def simulate(self, steps: int) -> Tuple[str, Optional[go.Figure]]:
622
- """Run trading simulation with trained agent."""
623
- if not self.env or not self.agent:
624
- return "❌ Please initialize and train the system first!", None
625
-
626
- try:
627
- steps = int(steps)
628
- obs, _ = self.env.reset()
629
-
630
- prices, actions, net_worths = [], [], []
631
- portfolio_values, cash_balances = [], []
632
-
633
- for step in range(steps):
634
- action = self.agent.select_action(obs, training=False)
635
- next_obs, _, terminated, truncated, info = self.env.step(action)
636
-
637
- prices.append(self.env.current_price)
638
- actions.append(action)
639
- net_worths.append(info['net_worth'])
640
- portfolio_values.append(info['position_value'])
641
- cash_balances.append(info['cash_balance'])
642
-
643
- obs = next_obs
644
- if terminated or truncated:
645
- break
646
-
647
- plot = self._create_simulation_plot(
648
- prices, actions, net_worths, portfolio_values, cash_balances
649
- )
650
-
651
- final_return = ((net_worths[-1] - self.config.initial_balance) /
652
- self.config.initial_balance * 100)
653
-
654
- last_action_name = DQNAgent.ACTION_NAMES.get(actions[-1], 'Unknown')
655
- result = (f"✅ Simulation completed!\n"
656
- f"📈 Steps: {len(prices)}\n"
657
- f"💰 Final Net Worth: ${net_worths[-1]:,.2f}\n"
658
- f"📊 Total Return: {final_return:.2f}%\n"
659
- f"🎯 Final Action: {last_action_name}")
660
-
661
- return result, plot
662
-
663
- except Exception as e:
664
- logger.error(f"Simulation error: {e}")
665
- return f"❌ Simulation failed: {str(e)}", None
666
 
667
- def _create_training_plot(self) -> Optional[go.Figure]:
668
- """Create comprehensive training progress visualization."""
669
- if not self.training_history['rewards']:
670
- return None
671
-
672
- episodes = list(range(1, len(self.training_history['rewards']) + 1))
673
-
674
- fig = make_subplots(
675
- rows=2, cols=2,
676
- subplot_titles=('Episode Rewards', 'Training Loss', 'Epsilon Decay', 'Portfolio Performance'),
677
- vertical_spacing=0.12
678
- )
679
-
680
- # Rewards
681
- fig.add_trace(
682
- go.Scatter(x=episodes, y=self.training_history['rewards'],
683
- mode='lines+markers', name='Total Reward',
684
- line=dict(color='blue', width=2)),
685
- row=1, col=1
686
- )
687
-
688
- # Moving average
689
- window = min(20, len(episodes))
690
- ma_rewards = [np.mean(self.training_history['rewards'][max(0, i-window):i+1])
691
- for i in range(len(episodes))]
692
- fig.add_trace(
693
- go.Scatter(x=episodes, y=ma_rewards, mode='lines',
694
- name='MA Reward', line=dict(color='orange', width=3)),
695
- row=1, col=1
696
- )
697
-
698
- # Losses
699
- fig.add_trace(
700
- go.Scatter(x=episodes, y=self.training_history['losses'],
701
- mode='lines', name='Loss', line=dict(color='red')),
702
- row=1, col=2
703
- )
704
-
705
- # Epsilon
706
- fig.add_trace(
707
- go.Scatter(x=episodes, y=self.training_history['epsilons'],
708
- mode='lines', name='Epsilon', line=dict(color='green')),
709
- row=2, col=1
710
- )
711
-
712
- # Portfolio performance
713
- returns = [(nw - self.config.initial_balance) / self.config.initial_balance * 100
714
- for nw in self.training_history['net_worths']]
715
- fig.add_trace(
716
- go.Scatter(x=episodes, y=self.training_history['net_worths'],
717
- mode='lines', name='Net Worth',
718
- line=dict(color='blue'), yaxis='y'),
719
- row=2, col=2
720
- )
721
- fig.add_trace(
722
- go.Scatter(x=episodes, y=returns, mode='lines',
723
- name='Return %', line=dict(color='purple'), yaxis='y2'),
724
- row=2, col=2
725
- )
726
-
727
- fig.update_layout(
728
- height=700,
729
- showlegend=True,
730
- title_text="🧠 DQN Training Progress",
731
- hovermode='x unified'
732
- )
733
-
734
- fig.update_yaxes(title_text="Return (%)", secondary_y=True, row=2, col=2)
735
- fig.update_yaxes(title_text="Net Worth ($)", row=2, col=2)
736
-
737
- return fig
738
 
739
- def _create_simulation_plot(self, prices: List[float], actions: List[int],
740
- net_worths: List[float], portfolio_values: List[float],
741
- cash_balances: List[float]) -> go.Figure:
742
- """Create detailed simulation results visualization."""
743
- steps = list(range(len(prices)))
744
-
745
- fig = make_subplots(
746
- rows=2, cols=2,
747
- subplot_titles=('Price Action & Trading Signals', 'Portfolio Performance',
748
- 'Portfolio Allocation', 'Action Distribution'),
749
- vertical_spacing=0.12
750
- )
751
-
752
- # Price and actions
753
- fig.add_trace(
754
- go.Scatter(x=steps, y=prices, mode='lines', name='Asset Price',
755
- line=dict(color='blue', width=2)),
756
- row=1, col=1
757
- )
758
-
759
- # Action markers
760
- action_colors = ['gray', 'green', 'red', 'orange']
761
- action_names = list(DQNAgent.ACTION_NAMES.values())
762
-
763
- for action, color, name in zip(range(4), action_colors, action_names):
764
- action_steps = [i for i, a in enumerate(actions) if a == action]
765
- if action_steps:
766
- action_prices = [prices[i] for i in action_steps]
767
- fig.add_trace(
768
- go.Scatter(x=action_steps, y=action_prices, mode='markers',
769
- name=f'{name}',
770
- marker=dict(color=color, size=8, symbol='triangle-up')),
771
- row=1, col=1
772
- )
773
-
774
- # Portfolio performance
775
- initial_balance = self.config.initial_balance
776
- returns = [(nw - initial_balance) / initial_balance * 100 for nw in net_worths]
777
-
778
- fig.add_trace(
779
- go.Scatter(x=steps, y=net_worths, mode='lines', name='Net Worth',
780
- line=dict(color='purple', width=2)),
781
- row=1, col=2
782
- )
783
- fig.add_trace(
784
- go.Scatter(x=steps, y=returns, mode='lines', name='Returns %',
785
- line=dict(color='orange', width=2), yaxis='y2'),
786
- row=1, col=2
787
- )
788
-
789
- # Portfolio composition
790
- fig.add_trace(
791
- go.Scatter(x=steps, y=portfolio_values, mode='lines',
792
- name='Portfolio Value', line=dict(color='green')),
793
- row=2, col=1
794
- )
795
- fig.add_trace(
796
- go.Scatter(x=steps, y=cash_balances, mode='lines',
797
- name='Cash Balance', line=dict(color='blue')),
798
- row=2, col=1
799
- )
800
-
801
- # Action distribution
802
- action_counts = [actions.count(i) for i in range(4)]
803
- fig.add_trace(
804
- go.Bar(x=action_names, y=action_counts, name='Action Frequency',
805
- marker_color=action_colors),
806
- row=2, col=2
807
- )
808
-
809
- fig.update_layout(
810
- height=700,
811
- showlegend=True,
812
- title_text="📈 Trading Simulation Results",
813
- hovermode='x unified'
814
- )
815
-
816
- fig.update_yaxes(title_text="Returns (%)", secondary_y=True, row=1, col=2)
817
- fig.update_yaxes(title_text="Value ($)", row=1, col=2)
818
-
819
- return fig
820
-
821
- # ---- 5. Gradio Interface for Hugging Face ----
822
- def create_interface() -> gr.Blocks:
823
- """Create professional Gradio interface."""
824
- demo = TradingDemo()
825
 
826
- with gr.Blocks(
827
- theme=gr.themes.Soft(),
828
- title="🤖 Advanced AI Trading Demo",
829
- css="""
830
- .gradio-container {max-width: 1400px !important;}
831
- .status-box {background-color: #f0f9ff; padding: 1rem; border-radius: 8px; border-left: 4px solid #3b82f6;}
832
- """
833
- ) as interface:
834
-
835
- gr.Markdown("""
836
- # 🤖 Advanced AI Trading Demo
837
- **Deep Reinforcement Learning for Financial Markets**
838
 
839
- This demo showcases a **Deep Q-Network (DQN)** agent learning to trade in simulated financial markets with realistic market dynamics, technical indicators, and risk management.
840
- """)
841
-
842
- # Configuration Row
843
- with gr.Row():
844
- with gr.Column(scale=1):
845
- gr.Markdown("## 🎯 System Configuration")
846
- with gr.Group():
847
- balance = gr.Slider(
848
- 1000, 50000, 10000, step=1000,
849
- label="💰 Initial Balance ($)"
850
- )
851
- risk = gr.Radio(
852
- ["Low", "Medium", "High"], value="Medium",
853
- label="🎯 Risk Level"
854
- )
855
- asset = gr.Radio(
856
- ["Crypto", "Stock", "Forex"], value="Crypto",
857
- label="📈 Asset Type"
858
- )
859
- init_btn = gr.Button("🚀 Initialize Trading System", variant="primary")
860
-
861
- with gr.Column(scale=2):
862
- gr.Markdown("## 📊 System Status")
863
- status = gr.Markdown(
864
- value="👋 Welcome! Configure parameters and click **Initialize** to begin.",
865
- elem_classes=["status-box"]
866
- )
867
-
868
- # Training and Simulation Row
869
- with gr.Row():
870
- with gr.Column():
871
- gr.Markdown("## 🏋️‍♂️ Train AI Agent")
872
- with gr.Group():
873
- episodes = gr.Number(
874
- value=50, label="🎯 Training Episodes", precision=0
875
- )
876
- train_btn = gr.Button("🎓 Start Training", variant="primary")
877
- training_output = gr.Textbox(
878
- label="Training Progress", lines=6, interactive=False
879
- )
880
- train_plot = gr.Plot(label="📈 Training Progress")
881
-
882
- with gr.Column():
883
- gr.Markdown("## ▶️ Test Trained Agent")
884
- with gr.Group():
885
- sim_steps = gr.Number(
886
- value=200, label="📊 Simulation Steps", precision=0
887
- )
888
- sim_btn = gr.Button("🎮 Run Simulation", variant="primary")
889
- sim_output = gr.Textbox(
890
- label="Simulation Results", lines=4, interactive=False
891
- )
892
- sim_plot = gr.Plot(label="📊 Trading Results")
893
-
894
- # Event Handlers
895
- def initialize_wrapper(balance, risk, asset):
896
- return demo.initialize(balance, risk, asset)
897
-
898
- def simulate_wrapper(steps):
899
- return demo.simulate(steps)
900
-
901
- def train_generator(episodes):
902
- try:
903
- for status_text, plot in demo.train(int(episodes)):
904
- yield status_text, plot
905
- except Exception as e:
906
- yield f"❌ Training error: {str(e)}", None
907
-
908
- init_btn.click(
909
- fn=initialize_wrapper,
910
- inputs=[balance, risk, asset],
911
- outputs=status
912
- )
913
-
914
- train_btn.queue().click(
915
- fn=train_generator,
916
- inputs=episodes,
917
- outputs=[training_output, train_plot]
918
- )
919
-
920
- sim_btn.click(
921
- fn=simulate_wrapper,
922
- inputs=sim_steps,
923
- outputs=[sim_output, sim_plot]
924
- )
925
-
926
- gr.Markdown("""
927
- ## 📖 Usage Instructions
928
- 1. **Configure** your trading parameters
929
- 2. **Initialize** the trading system
930
- 3. **Train** the AI agent (50+ episodes recommended)
931
- 4. **Simulate** trading with the trained agent
932
 
933
- ## 🎮 Trading Actions
934
- - **Hold (0)**: Maintain current position
935
- - **Buy (1)**: Purchase assets (risk-adjusted)
936
- - **Sell (2)**: Sell portion of position
937
- - **Close (3)**: Liquidate entire position
938
- """)
 
 
 
 
 
939
 
940
- return interface
 
 
 
 
 
 
 
 
 
 
 
941
 
942
- # ---- 6. Hugging Face Spaces Entry Point ----
943
  if __name__ == "__main__":
944
- try:
945
- interface = create_interface()
946
- interface.launch(
947
- server_name="0.0.0.0",
948
- server_port=7860,
949
- share=False,
950
- show_error=True,
951
- enable_queue=True,
952
- max_threads=40,
953
- debug=False
954
- )
955
- except Exception as e:
956
- logger.error(f"Failed to launch application: {e}")
957
- raise
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import pandas as pd
3
  import numpy as np
4
+ import yfinance as yf
 
 
 
 
 
5
  import plotly.graph_objects as go
6
  from plotly.subplots import make_subplots
7
+ from datetime import datetime, timedelta
8
+ import warnings
9
+ warnings.filterwarnings('ignore')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
# Market Data Provider
class MarketDataProvider:
    """Provide daily price history for a set of stock symbols.

    Live data comes from yfinance; results are memoized per (symbol, period)
    for the lifetime of the process. If the download fails or returns no
    rows, a simulated random-walk series is returned instead so the demo
    keeps working offline.
    """

    # Default watchlist; a tuple so the class-level default is immutable.
    DEFAULT_SYMBOLS = ('AAPL', 'GOOGL', 'MSFT', 'TSLA')

    def __init__(self, symbols=None):
        # Fix: the original used a mutable list as the default argument,
        # which is shared across all instances. None-sentinel instead.
        self.symbols = list(symbols) if symbols is not None else list(self.DEFAULT_SYMBOLS)
        # cache_key -> data dict; never expires (acceptable for a demo session).
        self.data_cache = {}

    def get_stock_data(self, symbol, period='1mo'):
        """Return a dict with 'prices', 'dates', 'volume', 'current_price'
        and 'change' (percent move over the window) for *symbol*.

        Falls back to simulated data on any fetch error or empty history.
        """
        try:
            cache_key = f"{symbol}_{period}"
            if cache_key in self.data_cache:
                return self.data_cache[cache_key]

            ticker = yf.Ticker(symbol)
            hist_data = ticker.history(period=period)

            if hist_data.empty:
                # Unknown/delisted ticker: degrade to simulated data.
                return self.generate_simulated_data(symbol)

            closes = hist_data['Close']
            data = {
                'prices': closes.tolist(),
                'dates': hist_data.index.strftime('%Y-%m-%d').tolist(),
                'volume': hist_data['Volume'].tolist(),
                'current_price': closes.iloc[-1],
                'change': ((closes.iloc[-1] - closes.iloc[0]) / closes.iloc[0]) * 100,
            }

            self.data_cache[cache_key] = data
            return data

        except Exception as e:
            # Network errors, rate limits, schema changes, etc. -- degrade
            # gracefully rather than crashing the UI.
            print(f"Error fetching {symbol}: {e}")
            return self.generate_simulated_data(symbol)

    def generate_simulated_data(self, symbol):
        """Return 30 days of simulated data shaped like get_stock_data().

        Bug fix: the previous version drew each day's price independently
        around the base price, so consecutive prices were uncorrelated noise
        and the trend/volatility metrics computed downstream described
        nothing. Prices now follow a geometric random walk with a 2% daily
        standard deviation, matching the shape of a real price series.
        """
        base_price = np.random.uniform(100, 200)
        days = 30
        prices = [base_price]
        for _ in range(days - 1):
            prices.append(prices[-1] * (1 + np.random.normal(0, 0.02)))
        dates = [(datetime.now() - timedelta(days=i)).strftime('%Y-%m-%d')
                 for i in range(days, 0, -1)]

        return {
            'prices': prices,
            'dates': dates,
            'volume': [np.random.randint(1000000, 5000000) for _ in range(days)],
            'current_price': prices[-1],
            'change': ((prices[-1] - prices[0]) / prices[0]) * 100,
        }
57
+
58
# AI Trading Agents
class TradingAgents:
    """Simulated multi-agent analysis pipeline.

    Three "agents" (research, technical, risk) return canned, randomly
    selected narratives; a final decision comes from simple metric
    thresholds. No model is actually called -- this is demo scaffolding.
    """

    def __init__(self):
        # Per-agent display metadata and the prompt each would receive.
        self.agents_config = {
            'research': {
                'name': 'Financial Research Agent',
                'emoji': '📊',
                'prompt_template': """Analyze {symbol} stock:
Price: ${current_price:.2f}
Change: {change:+.2f}%
Trend: {trend}

Provide fundamental analysis and recommendation:"""
            },
            'technical': {
                'name': 'Technical Analysis Agent',
                'emoji': '📈',
                'prompt_template': """Technical analysis for {symbol}:
Price: ${current_price:.2f}
Trend: {trend}
Volatility: {volatility:.1f}%

Provide technical levels and trading signals:"""
            },
            'risk': {
                'name': 'Risk Management Agent',
                'emoji': '🛡️',
                'prompt_template': """Risk assessment for {symbol}:
Price: ${current_price:.2f}
Volatility: {volatility:.1f}%

Provide risk management strategy:"""
            }
        }

    def calculate_metrics(self, price_data):
        """Compute a trend label, annualized volatility (%) and total % change.

        Bug fix: the short-series branch previously omitted 'price_change',
        which made analyze_stock raise KeyError when formatting prompts for
        a series with fewer than two prices.
        """
        prices = price_data['prices']
        if len(prices) < 2:
            return {"trend": "Neutral", "volatility": 0, "price_change": 0}

        price_change = ((prices[-1] - prices[0]) / prices[0]) * 100
        returns = np.diff(prices) / prices[:-1]
        # Annualize daily return stddev (sqrt(252) trading days), in percent.
        volatility = np.std(returns) * np.sqrt(252) * 100

        if price_change > 5:
            trend = "Strong Bullish"
        elif price_change > 0:
            trend = "Mild Bullish"
        else:
            trend = "Bearish"

        return {
            "trend": f"{trend} ({price_change:+.1f}%)",
            "volatility": volatility,
            "price_change": price_change
        }

    def analyze_stock(self, symbol, price_data):
        """Run every agent on *price_data* and append a final decision.

        Returns a dict keyed by agent type (plus 'decision'), each value
        holding 'name', 'emoji' and 'analysis' text.
        """
        metrics = self.calculate_metrics(price_data)
        current_price = price_data['current_price']

        analyses = {}

        for agent_type, config in self.agents_config.items():
            prompt = config['prompt_template'].format(
                symbol=symbol,
                current_price=current_price,
                change=metrics['price_change'],
                trend=metrics['trend'],
                volatility=metrics['volatility']
            )

            analysis = self.get_agent_response(agent_type, prompt)
            analyses[agent_type] = {
                'name': config['name'],
                'emoji': config['emoji'],
                'analysis': analysis
            }

        analyses['decision'] = self.generate_final_decision(symbol, current_price, analyses, metrics)
        return analyses

    def get_agent_response(self, agent_type, prompt):
        """Return a canned narrative for *agent_type*.

        NOTE: *prompt* is intentionally unused and the price levels are drawn
        at random, independent of the real quote -- this simulates an agent's
        reply for the demo. The signature is kept so a real model call can be
        dropped in later.
        """
        responses = {
            'research': [
                "📊 STRONG FUNDAMENTALS: Positive earnings growth, market leadership position. Institutional accumulation visible. RECOMMENDATION: BUY with 80% confidence. Target 15-20% upside.",
                "📊 MIXED FUNDAMENTALS: Valuation concerns offset by solid cash flow. Competitive pressures increasing. RECOMMENDATION: HOLD and monitor for better entry.",
                "📊 EXCELLENT GROWTH: Innovative product pipeline, expanding market share. Strong balance sheet. RECOMMENDATION: STRONG BUY for long-term growth."
            ],
            'technical': [
                "📈 BULLISH PATTERN: Breakout above resistance. Support at ${support:.2f}. RSI: 65. ENTRY: Current levels. TARGET: ${target:.2f}.",
                "📈 CONSOLIDATION PHASE: Trading range ${support:.2f}-${resistance:.2f}. Wait for breakout confirmation. Volume declining.",
                "📈 STRONG UPTREND: Higher highs and higher lows. Volume confirmation. Fibonacci target: ${target:.2f}. Stop: ${stop:.2f}."
            ],
            'risk': [
                "🛡️ MODERATE RISK: Position size 3-4%. Stop-loss 8%. Risk-reward 1:2.5. Maximum drawdown 12%.",
                "🛡️ CONSERVATIVE: Position size 2-3%. Stop-loss 10% trailing. Monitor earnings date closely.",
                "🛡️ FAVORABLE: Position size 4-5%. Stop-loss 6%. Risk-reward 1:3.0. Low portfolio correlation."
            ]
        }

        response = np.random.choice(responses[agent_type])
        current_price = np.random.uniform(150, 250)

        # Only the 'technical' strings contain placeholders; .format() on the
        # others is a harmless no-op.
        return response.format(
            support=current_price * 0.95,
            resistance=current_price * 1.08,
            target=current_price * 1.15,
            stop=current_price * 0.92
        )

    def generate_final_decision(self, symbol, current_price, analyses, metrics):
        """Derive BUY/SELL/HOLD from simple thresholds on the metrics.

        Confidence is random within a band per decision; position size shrinks
        as volatility rises (clamped to 2-5%).
        """
        if metrics['price_change'] > 3 and metrics['volatility'] < 25:
            decision = "BUY"
            confidence = np.random.randint(75, 90)
        elif metrics['price_change'] < -3:
            decision = "SELL"
            confidence = np.random.randint(70, 85)
        else:
            decision = "HOLD"
            confidence = np.random.randint(60, 80)

        return {
            'name': 'Final Decision',
            'emoji': '🎯',
            'analysis': f"""🎯 FINAL DECISION: {decision}

Confidence: {confidence}%
Price: ${current_price:.2f}
Position Size: {max(2, min(5, 8 - metrics['volatility']/10))}%

Action: {'Enter long position' if decision == 'BUY' else 'Wait for better setup'}"""
        }
191
 
192
# Initialize components
# Module-level singletons shared by all Gradio callbacks; the provider's
# cache therefore persists across requests within one process.
market_data = MarketDataProvider()
trading_agents = TradingAgents()
195
+
196
# Gradio Interface Functions
def create_stock_chart(symbol, price_data):
    """Build a dark-themed 30-day closing-price line chart for *symbol*."""
    price_trace = go.Scatter(
        x=price_data['dates'],
        y=price_data['prices'],
        mode='lines+markers',
        name=f'{symbol} Price',
        line=dict(color='#00D4AA', width=3),
        marker=dict(size=6),
    )

    fig = go.Figure(data=[price_trace])
    fig.update_layout(
        title=f'{symbol} Price Chart - 30 Days',
        xaxis_title='Date',
        yaxis_title='Price ($)',
        template='plotly_dark',
        height=300,
        margin=dict(l=50, r=50, t=50, b=50),
    )
    return fig
219
+
220
def create_performance_dashboard(stocks_data):
    """Side-by-side bar charts: 30-day % change and current price per symbol."""
    tickers = list(stocks_data)
    pct_moves = []
    last_prices = []
    for ticker in tickers:
        pct_moves.append(stocks_data[ticker]['change'])
        last_prices.append(stocks_data[ticker]['current_price'])

    dashboard = make_subplots(
        rows=1, cols=2,
        subplot_titles=['30-Day Performance (%)', 'Current Prices ($)'],
        specs=[[{"type": "bar"}, {"type": "bar"}]],
    )

    # Green bars for gainers, red for losers.
    bar_colors = ['#00D4AA' if move >= 0 else '#FF6B6B' for move in pct_moves]
    dashboard.add_trace(go.Bar(x=tickers, y=pct_moves, marker_color=bar_colors), row=1, col=1)
    dashboard.add_trace(go.Bar(x=tickers, y=last_prices, marker_color='#636EFA'), row=1, col=2)

    dashboard.update_layout(
        title='Stock Performance Overview',
        template='plotly_dark',
        height=400,
        showlegend=False,
    )

    return dashboard
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
def analyze_single_stock(symbol):
    """Run the full agent pipeline on one symbol.

    Returns a (plotly figure, markdown report) pair for the Gradio outputs.
    """
    price_data = market_data.get_stock_data(symbol)
    analyses = trading_agents.analyze_stock(symbol, price_data)
    chart = create_stock_chart(symbol, price_data)

    sections = [
        f"# {symbol} Analysis Report\n\n",
        f"**Current Price:** ${price_data['current_price']:.2f}\n",
        f"**30-Day Change:** {price_data['change']:+.2f}%\n\n",
    ]
    # One markdown section per agent report (including the final decision).
    for report in analyses.values():
        sections.append(f"## {report['emoji']} {report['name']}\n")
        sections.append(f"{report['analysis']}\n\n")

    return chart, "".join(sections)
260
+
261
def analyze_all_stocks():
    """Summarize every tracked symbol: dashboard figure plus markdown digest."""
    stocks_data = {sym: market_data.get_stock_data(sym) for sym in market_data.symbols}

    dashboard = create_performance_dashboard(stocks_data)

    parts = ["# Multi-Agent Trading Analysis\n\n"]
    for sym in market_data.symbols:
        agent_reports = trading_agents.analyze_stock(sym, stocks_data[sym])
        # First line of the decision text, e.g. "🎯 FINAL DECISION: BUY".
        headline = agent_reports['decision']['analysis'].split('\n')[0]

        quote = stocks_data[sym]
        parts.append(f"## {sym}\n")
        parts.append(f"**Price:** ${quote['current_price']:.2f} | ")
        parts.append(f"**Change:** {quote['change']:+.2f}%\n")
        parts.append(f"**Decision:** {headline}\n\n")

    return dashboard, "".join(parts)
 
 
 
 
 
 
 
 
 
 
 
281
 
282
def update_analysis(symbol_input):
    """Dispatch: a non-empty textbox analyzes that one symbol, otherwise all."""
    if not symbol_input:
        return analyze_all_stocks()
    return analyze_single_stock(symbol_input.upper().strip())
289
+
290
# Gradio Interface
# NOTE(review): this layout is reconstructed from a diff that lost
# indentation; the nesting of Tabs inside the wide column should be
# confirmed against the deployed app.
with gr.Blocks(theme=gr.themes.Soft(), title="AI Trading Agents") as demo:
    gr.Markdown("""
    # 🤖 Multi-Agent AI Trading System
    **Professional stock analysis powered by AI agents**
    """)

    with gr.Row():
        # Narrow left column: user input controls.
        with gr.Column(scale=1):
            symbol_input = gr.Textbox(
                label="Enter Stock Symbol (e.g., AAPL, TSLA)",
                placeholder="Leave empty for all tracked stocks...",
                max_lines=1
            )
            analyze_btn = gr.Button("Analyze Stock", variant="primary")

        # Wide right column: charts and the markdown analysis report.
        with gr.Column(scale=2):
            gr.Markdown("### Live Market Analysis")

            with gr.Tabs():
                with gr.TabItem("📈 Charts"):
                    with gr.Row():
                        chart_output = gr.Plot(label="Price Chart")
                        dashboard_output = gr.Plot(label="Performance Dashboard")

                with gr.TabItem("📊 Analysis"):
                    analysis_output = gr.Markdown(label="AI Analysis Report")

    # Event handlers
    # NOTE(review): with an empty textbox, update_analysis returns the
    # dashboard figure, which is routed into chart_output here -- confirm
    # this wiring is intentional.
    analyze_btn.click(
        fn=update_analysis,
        inputs=[symbol_input],
        outputs=[chart_output, analysis_output]
    )

    # Load initial analysis
    # Populate the dashboard and report once when the page first opens.
    demo.load(
        fn=analyze_all_stocks,
        outputs=[dashboard_output, analysis_output]
    )
330
 
331
# For Hugging Face Spaces
if __name__ == "__main__":
    # Bind to all interfaces on the port Spaces routes to (7860).
    # Fix: share=True asks Gradio to open a public tunnel -- on Spaces the
    # flag is ignored with a warning, and for local runs it needlessly
    # exposes the app. Disabled here.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
+ )