Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
Advanced AI Trading Demo - Hugging Face Spaces
|
| 3 |
Deep Q-Network (DQN) Reinforcement Learning for Financial Trading Simulation
|
| 4 |
|
| 5 |
-
Author:
|
| 6 |
License: MIT
|
| 7 |
"""
|
| 8 |
|
|
@@ -20,6 +20,7 @@ import logging
|
|
| 20 |
import os
|
| 21 |
from datetime import datetime
|
| 22 |
import json
|
|
|
|
| 23 |
|
| 24 |
# Configure logging
|
| 25 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -152,7 +153,8 @@ class AdvancedTradingEnvironment:
|
|
| 152 |
losses = -returns[returns < 0]
|
| 153 |
avg_gain = np.mean(gains[-14:]) if len(gains) > 0 else 0.001
|
| 154 |
avg_loss = np.mean(losses[-14:]) if len(losses) > 0 else 0.001
|
| 155 |
-
|
|
|
|
| 156 |
else:
|
| 157 |
rsi = 50.0
|
| 158 |
|
|
@@ -178,6 +180,7 @@ class AdvancedTradingEnvironment:
|
|
| 178 |
def reset(self) -> Tuple[np.ndarray, Dict]:
|
| 179 |
"""Reset environment to initial state."""
|
| 180 |
self._reset_state()
|
|
|
|
| 181 |
obs = self._get_observation()
|
| 182 |
info = self._get_info()
|
| 183 |
return obs, info
|
|
@@ -302,7 +305,7 @@ class AdvancedTradingEnvironment:
|
|
| 302 |
np.mean(recent_prices) / 100.0,
|
| 303 |
np.std(recent_prices) / 100.0,
|
| 304 |
(self.current_price - np.min(recent_prices)) /
|
| 305 |
-
(np.max(recent_prices) - np.min(recent_prices) + 1e-8)
|
| 306 |
]
|
| 307 |
|
| 308 |
# Portfolio features
|
|
@@ -324,9 +327,8 @@ class AdvancedTradingEnvironment:
|
|
| 324 |
# Technical indicators
|
| 325 |
technical_features = self._calculate_technical_indicators()
|
| 326 |
|
| 327 |
-
# Combine
|
| 328 |
-
all_features =
|
| 329 |
-
technical_features[:6]) # Ensure exactly 15 features
|
| 330 |
observation = np.array(all_features[:15], dtype=np.float32)
|
| 331 |
|
| 332 |
return observation
|
|
@@ -464,15 +466,21 @@ class DQNAgent:
|
|
| 464 |
'config': self.config.__dict__
|
| 465 |
}
|
| 466 |
torch.save(checkpoint, path)
|
|
|
|
| 467 |
|
| 468 |
def load_checkpoint(self, path: str):
|
| 469 |
"""Load agent checkpoint."""
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
|
| 477 |
# ---- 4. Main Trading Application ----
|
| 478 |
class TradingDemo:
|
|
@@ -498,12 +506,23 @@ class TradingDemo:
|
|
| 498 |
|
| 499 |
def load_model_if_exists(self):
|
| 500 |
"""Load existing model if available."""
|
|
|
|
| 501 |
if os.path.exists(self.model_path):
|
| 502 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
self.agent.load_checkpoint(self.model_path)
|
| 504 |
-
logger.info("Loaded existing model checkpoint")
|
| 505 |
except Exception as e:
|
| 506 |
logger.warning(f"Failed to load model: {e}")
|
|
|
|
| 507 |
|
| 508 |
def initialize(self, balance: float, risk: str, asset: str) -> str:
|
| 509 |
"""Initialize trading system with new parameters."""
|
|
@@ -556,7 +575,8 @@ class TradingDemo:
|
|
| 556 |
|
| 557 |
while not done:
|
| 558 |
action = self.agent.select_action(obs)
|
| 559 |
-
next_obs, reward,
|
|
|
|
| 560 |
|
| 561 |
self.agent.store_transition(obs, action, reward, next_obs, done)
|
| 562 |
loss = self.agent.update()
|
|
@@ -590,7 +610,8 @@ class TradingDemo:
|
|
| 590 |
yield progress, None
|
| 591 |
|
| 592 |
# Save trained model
|
| 593 |
-
self.agent
|
|
|
|
| 594 |
yield "✅ Training completed! Model saved.", self._create_training_plot()
|
| 595 |
|
| 596 |
except Exception as e:
|
|
@@ -611,7 +632,7 @@ class TradingDemo:
|
|
| 611 |
|
| 612 |
for step in range(steps):
|
| 613 |
action = self.agent.select_action(obs, training=False)
|
| 614 |
-
next_obs, _,
|
| 615 |
|
| 616 |
prices.append(self.env.current_price)
|
| 617 |
actions.append(action)
|
|
@@ -620,7 +641,7 @@ class TradingDemo:
|
|
| 620 |
cash_balances.append(info['cash_balance'])
|
| 621 |
|
| 622 |
obs = next_obs
|
| 623 |
-
if
|
| 624 |
break
|
| 625 |
|
| 626 |
plot = self._create_simulation_plot(
|
|
@@ -630,11 +651,12 @@ class TradingDemo:
|
|
| 630 |
final_return = ((net_worths[-1] - self.config.initial_balance) /
|
| 631 |
self.config.initial_balance * 100)
|
| 632 |
|
|
|
|
| 633 |
result = (f"✅ Simulation completed!\n"
|
| 634 |
f"๐ Steps: {len(prices)}\n"
|
| 635 |
f"💰 Final Net Worth: ${net_worths[-1]:,.2f}\n"
|
| 636 |
f"๐ Total Return: {final_return:.2f}%\n"
|
| 637 |
-
f"๐ฏ Final Action: {
|
| 638 |
|
| 639 |
return result, plot
|
| 640 |
|
|
@@ -652,9 +674,7 @@ class TradingDemo:
|
|
| 652 |
fig = make_subplots(
|
| 653 |
rows=2, cols=2,
|
| 654 |
subplot_titles=('Episode Rewards', 'Training Loss', 'Epsilon Decay', 'Portfolio Performance'),
|
| 655 |
-
vertical_spacing=0.12
|
| 656 |
-
specs=[[{"secondary_y": False}, {"secondary_y": False}],
|
| 657 |
-
[{"secondary_y": True}, {"secondary_y": False}]]
|
| 658 |
)
|
| 659 |
|
| 660 |
# Rewards
|
|
@@ -692,16 +712,16 @@ class TradingDemo:
|
|
| 692 |
# Portfolio performance
|
| 693 |
returns = [(nw - self.config.initial_balance) / self.config.initial_balance * 100
|
| 694 |
for nw in self.training_history['net_worths']]
|
| 695 |
-
fig.add_trace(
|
| 696 |
-
go.Scatter(x=episodes, y=returns, mode='lines',
|
| 697 |
-
name='Return %', line=dict(color='purple')),
|
| 698 |
-
row=2, col=2, secondary_y=True
|
| 699 |
-
)
|
| 700 |
fig.add_trace(
|
| 701 |
go.Scatter(x=episodes, y=self.training_history['net_worths'],
|
| 702 |
mode='lines', name='Net Worth',
|
| 703 |
line=dict(color='blue'), yaxis='y'),
|
| 704 |
-
row=2, col=2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 705 |
)
|
| 706 |
|
| 707 |
fig.update_layout(
|
|
@@ -712,7 +732,7 @@ class TradingDemo:
|
|
| 712 |
)
|
| 713 |
|
| 714 |
fig.update_yaxes(title_text="Return (%)", secondary_y=True, row=2, col=2)
|
| 715 |
-
fig.update_yaxes(title_text="Net Worth ($)",
|
| 716 |
|
| 717 |
return fig
|
| 718 |
|
|
@@ -726,9 +746,7 @@ class TradingDemo:
|
|
| 726 |
rows=2, cols=2,
|
| 727 |
subplot_titles=('Price Action & Trading Signals', 'Portfolio Performance',
|
| 728 |
'Portfolio Allocation', 'Action Distribution'),
|
| 729 |
-
vertical_spacing=0.12
|
| 730 |
-
specs=[[{"secondary_y": False}, {"secondary_y": True}],
|
| 731 |
-
[{"secondary_y": False}, {"secondary_y": False}]]
|
| 732 |
)
|
| 733 |
|
| 734 |
# Price and actions
|
|
@@ -748,7 +766,7 @@ class TradingDemo:
|
|
| 748 |
action_prices = [prices[i] for i in action_steps]
|
| 749 |
fig.add_trace(
|
| 750 |
go.Scatter(x=action_steps, y=action_prices, mode='markers',
|
| 751 |
-
name=f'{name}
|
| 752 |
marker=dict(color=color, size=8, symbol='triangle-up')),
|
| 753 |
row=1, col=1
|
| 754 |
)
|
|
@@ -760,12 +778,12 @@ class TradingDemo:
|
|
| 760 |
fig.add_trace(
|
| 761 |
go.Scatter(x=steps, y=net_worths, mode='lines', name='Net Worth',
|
| 762 |
line=dict(color='purple', width=2)),
|
| 763 |
-
row=1, col=2
|
| 764 |
)
|
| 765 |
fig.add_trace(
|
| 766 |
go.Scatter(x=steps, y=returns, mode='lines', name='Returns %',
|
| 767 |
-
line=dict(color='orange', width=2)),
|
| 768 |
-
row=1, col=2
|
| 769 |
)
|
| 770 |
|
| 771 |
# Portfolio composition
|
|
@@ -796,7 +814,7 @@ class TradingDemo:
|
|
| 796 |
)
|
| 797 |
|
| 798 |
fig.update_yaxes(title_text="Returns (%)", secondary_y=True, row=1, col=2)
|
| 799 |
-
fig.update_yaxes(title_text="Value ($)",
|
| 800 |
|
| 801 |
return fig
|
| 802 |
|
|
@@ -810,23 +828,15 @@ def create_interface() -> gr.Blocks:
|
|
| 810 |
title="🤖 Advanced AI Trading Demo",
|
| 811 |
css="""
|
| 812 |
.gradio-container {max-width: 1400px !important;}
|
| 813 |
-
.status-box {background-color: #f0f9ff; padding: 1rem; border-radius: 8px;}
|
| 814 |
"""
|
| 815 |
) as interface:
|
| 816 |
|
| 817 |
gr.Markdown("""
|
| 818 |
# 🤖 Advanced AI Trading Demo
|
| 819 |
**Deep Reinforcement Learning for Financial Markets**
|
| 820 |
-
|
| 821 |
-
This demo showcases a **Deep Q-Network (DQN)** agent learning to trade in simulated financial markets.
|
| 822 |
-
The agent uses technical indicators, sentiment analysis, and risk management to optimize trading strategies.
|
| 823 |
-
|
| 824 |
-
**Key Features:**
|
| 825 |
-
- ๐ Multi-asset support (Crypto, Stocks, Forex)
|
| 826 |
-
- 🎯 Risk-adjusted position sizing
|
| 827 |
-
- 🧠 Deep Q-Network with experience replay
|
| 828 |
-
- ๐ Real-time training visualization
|
| 829 |
-
- 💾 Model persistence across sessions
|
| 830 |
""")
|
| 831 |
|
| 832 |
# Configuration Row
|
|
@@ -836,20 +846,17 @@ def create_interface() -> gr.Blocks:
|
|
| 836 |
with gr.Group():
|
| 837 |
balance = gr.Slider(
|
| 838 |
1000, 50000, 10000, step=1000,
|
| 839 |
-
label="💰 Initial Balance ($)"
|
| 840 |
-
info="Starting capital for trading"
|
| 841 |
)
|
| 842 |
risk = gr.Radio(
|
| 843 |
["Low", "Medium", "High"], value="Medium",
|
| 844 |
-
label="🎯 Risk Level"
|
| 845 |
-
info="Affects position sizing and volatility"
|
| 846 |
)
|
| 847 |
asset = gr.Radio(
|
| 848 |
["Crypto", "Stock", "Forex"], value="Crypto",
|
| 849 |
-
label="๐ Asset Type"
|
| 850 |
-
info="Different volatility characteristics"
|
| 851 |
)
|
| 852 |
-
init_btn = gr.Button("๐ Initialize Trading System", variant="primary"
|
| 853 |
|
| 854 |
with gr.Column(scale=2):
|
| 855 |
gr.Markdown("## ๐ System Status")
|
|
@@ -864,8 +871,7 @@ def create_interface() -> gr.Blocks:
|
|
| 864 |
gr.Markdown("## 🏋️‍♂️ Train AI Agent")
|
| 865 |
with gr.Group():
|
| 866 |
episodes = gr.Number(
|
| 867 |
-
value=
|
| 868 |
-
minimum=10, maximum=1000, precision=0
|
| 869 |
)
|
| 870 |
train_btn = gr.Button("๐ Start Training", variant="primary")
|
| 871 |
training_output = gr.Textbox(
|
|
@@ -877,8 +883,7 @@ def create_interface() -> gr.Blocks:
|
|
| 877 |
gr.Markdown("## ▶️ Test Trained Agent")
|
| 878 |
with gr.Group():
|
| 879 |
sim_steps = gr.Number(
|
| 880 |
-
value=200, label="๐ Simulation Steps",
|
| 881 |
-
minimum=50, maximum=1000, precision=0
|
| 882 |
)
|
| 883 |
sim_btn = gr.Button("๐ฎ Run Simulation", variant="primary")
|
| 884 |
sim_output = gr.Textbox(
|
|
@@ -887,78 +892,66 @@ def create_interface() -> gr.Blocks:
|
|
| 887 |
sim_plot = gr.Plot(label="๐ Trading Results")
|
| 888 |
|
| 889 |
# Event Handlers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 890 |
init_btn.click(
|
| 891 |
-
fn=
|
| 892 |
inputs=[balance, risk, asset],
|
| 893 |
outputs=status
|
| 894 |
)
|
| 895 |
|
| 896 |
-
|
| 897 |
-
for status_text, plot in demo.train(episodes):
|
| 898 |
-
yield status_text, plot
|
| 899 |
-
|
| 900 |
-
train_btn.click(
|
| 901 |
fn=train_generator,
|
| 902 |
inputs=episodes,
|
| 903 |
outputs=[training_output, train_plot]
|
| 904 |
)
|
| 905 |
|
| 906 |
sim_btn.click(
|
| 907 |
-
fn=
|
| 908 |
inputs=sim_steps,
|
| 909 |
outputs=[sim_output, sim_plot]
|
| 910 |
)
|
| 911 |
|
| 912 |
-
# Instructions
|
| 913 |
gr.Markdown("""
|
| 914 |
-
## ๐
|
| 915 |
-
|
| 916 |
-
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
- Different assets have unique volatility characteristics
|
| 920 |
-
|
| 921 |
-
### 2. **Training**
|
| 922 |
-
- Click **Initialize** to setup the environment
|
| 923 |
-
- Start training with 100+ episodes for good results
|
| 924 |
-
- Monitor training progress through real-time charts
|
| 925 |
-
- Model is automatically saved for future use
|
| 926 |
-
|
| 927 |
-
### 3. **Simulation**
|
| 928 |
-
- Run simulations to test the trained agent's performance
|
| 929 |
-
- Observe trading decisions and portfolio evolution
|
| 930 |
-
- Analyze action distribution and performance metrics
|
| 931 |
|
| 932 |
## ๐ฎ Trading Actions
|
| 933 |
-
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
| **Sell** | 2 | Sell portion of position |
|
| 938 |
-
| **Close** | 3 | Liquidate entire position |
|
| 939 |
-
|
| 940 |
-
## 🔬 Technical Details
|
| 941 |
-
- **Algorithm**: Deep Q-Network (DQN) with experience replay
|
| 942 |
-
- **State Space**: 15 features (price, technical indicators, portfolio)
|
| 943 |
-
- **Reward**: Risk-adjusted returns with drawdown penalties
|
| 944 |
-
- **Exploration**: Epsilon-greedy with decay
|
| 945 |
""")
|
| 946 |
|
| 947 |
return interface
|
| 948 |
|
| 949 |
# ---- 6. Hugging Face Spaces Entry Point ----
|
| 950 |
if __name__ == "__main__":
|
| 951 |
-
|
| 952 |
-
|
| 953 |
-
|
| 954 |
-
|
| 955 |
-
|
| 956 |
-
|
| 957 |
-
|
| 958 |
-
|
| 959 |
-
|
| 960 |
-
|
| 961 |
-
|
| 962 |
-
|
| 963 |
-
|
| 964 |
-
|
|
|
|
| 2 |
Advanced AI Trading Demo - Hugging Face Spaces
|
| 3 |
Deep Q-Network (DQN) Reinforcement Learning for Financial Trading Simulation
|
| 4 |
|
| 5 |
+
Author: AI Trading Team
|
| 6 |
License: MIT
|
| 7 |
"""
|
| 8 |
|
|
|
|
| 20 |
import os
|
| 21 |
from datetime import datetime
|
| 22 |
import json
|
| 23 |
+
from dataclasses import dataclass # ✅ Added missing import
|
| 24 |
|
| 25 |
# Configure logging
|
| 26 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 153 |
losses = -returns[returns < 0]
|
| 154 |
avg_gain = np.mean(gains[-14:]) if len(gains) > 0 else 0.001
|
| 155 |
avg_loss = np.mean(losses[-14:]) if len(losses) > 0 else 0.001
|
| 156 |
+
rs = avg_gain / avg_loss if avg_loss != 0 else 100
|
| 157 |
+
rsi = 100 - (100 / (1 + rs))
|
| 158 |
else:
|
| 159 |
rsi = 50.0
|
| 160 |
|
|
|
|
| 180 |
def reset(self) -> Tuple[np.ndarray, Dict]:
|
| 181 |
"""Reset environment to initial state."""
|
| 182 |
self._reset_state()
|
| 183 |
+
self._initialize_market_data() # ✅ Reinitialize market data on reset
|
| 184 |
obs = self._get_observation()
|
| 185 |
info = self._get_info()
|
| 186 |
return obs, info
|
|
|
|
| 305 |
np.mean(recent_prices) / 100.0,
|
| 306 |
np.std(recent_prices) / 100.0,
|
| 307 |
(self.current_price - np.min(recent_prices)) /
|
| 308 |
+
(np.max(recent_prices) - np.min(recent_prices) + 1e-8)
|
| 309 |
]
|
| 310 |
|
| 311 |
# Portfolio features
|
|
|
|
| 327 |
# Technical indicators
|
| 328 |
technical_features = self._calculate_technical_indicators()
|
| 329 |
|
| 330 |
+
# Combine all features (should be 4 + 3 + 3 + 6 = 16, take first 15)
|
| 331 |
+
all_features = price_features + portfolio_features + sentiment_features + technical_features
|
|
|
|
| 332 |
observation = np.array(all_features[:15], dtype=np.float32)
|
| 333 |
|
| 334 |
return observation
|
|
|
|
| 466 |
'config': self.config.__dict__
|
| 467 |
}
|
| 468 |
torch.save(checkpoint, path)
|
| 469 |
+
logger.info(f"Model saved to {path}")
|
| 470 |
|
| 471 |
def load_checkpoint(self, path: str):
|
| 472 |
"""Load agent checkpoint."""
|
| 473 |
+
if os.path.exists(path):
|
| 474 |
+
try:
|
| 475 |
+
checkpoint = torch.load(path, map_location=self.device)
|
| 476 |
+
self.q_network.load_state_dict(checkpoint['q_network'])
|
| 477 |
+
self.target_network.load_state_dict(checkpoint['target_network'])
|
| 478 |
+
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
| 479 |
+
self.epsilon = checkpoint['epsilon']
|
| 480 |
+
self.steps = checkpoint['steps']
|
| 481 |
+
logger.info(f"Model loaded from {path}")
|
| 482 |
+
except Exception as e:
|
| 483 |
+
logger.warning(f"Failed to load model from {path}: {e}")
|
| 484 |
|
| 485 |
# ---- 4. Main Trading Application ----
|
| 486 |
class TradingDemo:
|
|
|
|
| 506 |
|
| 507 |
def load_model_if_exists(self):
|
| 508 |
"""Load existing model if available."""
|
| 509 |
+
self.agent = None # Reset agent first
|
| 510 |
if os.path.exists(self.model_path):
|
| 511 |
try:
|
| 512 |
+
# Create agent first, then load
|
| 513 |
+
temp_config = TradingConfig()
|
| 514 |
+
temp_env = AdvancedTradingEnvironment(temp_config)
|
| 515 |
+
self.agent = DQNAgent(
|
| 516 |
+
state_dim=temp_env.observation_space_dim,
|
| 517 |
+
action_dim=temp_env.action_space,
|
| 518 |
+
config=temp_config,
|
| 519 |
+
device=self.device
|
| 520 |
+
)
|
| 521 |
self.agent.load_checkpoint(self.model_path)
|
| 522 |
+
logger.info("✅ Loaded existing model checkpoint")
|
| 523 |
except Exception as e:
|
| 524 |
logger.warning(f"Failed to load model: {e}")
|
| 525 |
+
self.agent = None
|
| 526 |
|
| 527 |
def initialize(self, balance: float, risk: str, asset: str) -> str:
|
| 528 |
"""Initialize trading system with new parameters."""
|
|
|
|
| 575 |
|
| 576 |
while not done:
|
| 577 |
action = self.agent.select_action(obs)
|
| 578 |
+
next_obs, reward, terminated, truncated, info = self.env.step(action)
|
| 579 |
+
done = terminated or truncated
|
| 580 |
|
| 581 |
self.agent.store_transition(obs, action, reward, next_obs, done)
|
| 582 |
loss = self.agent.update()
|
|
|
|
| 610 |
yield progress, None
|
| 611 |
|
| 612 |
# Save trained model
|
| 613 |
+
if self.agent:
|
| 614 |
+
self.agent.save_checkpoint(self.model_path)
|
| 615 |
yield "✅ Training completed! Model saved.", self._create_training_plot()
|
| 616 |
|
| 617 |
except Exception as e:
|
|
|
|
| 632 |
|
| 633 |
for step in range(steps):
|
| 634 |
action = self.agent.select_action(obs, training=False)
|
| 635 |
+
next_obs, _, terminated, truncated, info = self.env.step(action)
|
| 636 |
|
| 637 |
prices.append(self.env.current_price)
|
| 638 |
actions.append(action)
|
|
|
|
| 641 |
cash_balances.append(info['cash_balance'])
|
| 642 |
|
| 643 |
obs = next_obs
|
| 644 |
+
if terminated or truncated:
|
| 645 |
break
|
| 646 |
|
| 647 |
plot = self._create_simulation_plot(
|
|
|
|
| 651 |
final_return = ((net_worths[-1] - self.config.initial_balance) /
|
| 652 |
self.config.initial_balance * 100)
|
| 653 |
|
| 654 |
+
last_action_name = DQNAgent.ACTION_NAMES.get(actions[-1], 'Unknown')
|
| 655 |
result = (f"✅ Simulation completed!\n"
|
| 656 |
f"๐ Steps: {len(prices)}\n"
|
| 657 |
f"💰 Final Net Worth: ${net_worths[-1]:,.2f}\n"
|
| 658 |
f"๐ Total Return: {final_return:.2f}%\n"
|
| 659 |
+
f"🎯 Final Action: {last_action_name}")
|
| 660 |
|
| 661 |
return result, plot
|
| 662 |
|
|
|
|
| 674 |
fig = make_subplots(
|
| 675 |
rows=2, cols=2,
|
| 676 |
subplot_titles=('Episode Rewards', 'Training Loss', 'Epsilon Decay', 'Portfolio Performance'),
|
| 677 |
+
vertical_spacing=0.12
|
|
|
|
|
|
|
| 678 |
)
|
| 679 |
|
| 680 |
# Rewards
|
|
|
|
| 712 |
# Portfolio performance
|
| 713 |
returns = [(nw - self.config.initial_balance) / self.config.initial_balance * 100
|
| 714 |
for nw in self.training_history['net_worths']]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 715 |
fig.add_trace(
|
| 716 |
go.Scatter(x=episodes, y=self.training_history['net_worths'],
|
| 717 |
mode='lines', name='Net Worth',
|
| 718 |
line=dict(color='blue'), yaxis='y'),
|
| 719 |
+
row=2, col=2
|
| 720 |
+
)
|
| 721 |
+
fig.add_trace(
|
| 722 |
+
go.Scatter(x=episodes, y=returns, mode='lines',
|
| 723 |
+
name='Return %', line=dict(color='purple'), yaxis='y2'),
|
| 724 |
+
row=2, col=2
|
| 725 |
)
|
| 726 |
|
| 727 |
fig.update_layout(
|
|
|
|
| 732 |
)
|
| 733 |
|
| 734 |
fig.update_yaxes(title_text="Return (%)", secondary_y=True, row=2, col=2)
|
| 735 |
+
fig.update_yaxes(title_text="Net Worth ($)", row=2, col=2)
|
| 736 |
|
| 737 |
return fig
|
| 738 |
|
|
|
|
| 746 |
rows=2, cols=2,
|
| 747 |
subplot_titles=('Price Action & Trading Signals', 'Portfolio Performance',
|
| 748 |
'Portfolio Allocation', 'Action Distribution'),
|
| 749 |
+
vertical_spacing=0.12
|
|
|
|
|
|
|
| 750 |
)
|
| 751 |
|
| 752 |
# Price and actions
|
|
|
|
| 766 |
action_prices = [prices[i] for i in action_steps]
|
| 767 |
fig.add_trace(
|
| 768 |
go.Scatter(x=action_steps, y=action_prices, mode='markers',
|
| 769 |
+
name=f'{name}',
|
| 770 |
marker=dict(color=color, size=8, symbol='triangle-up')),
|
| 771 |
row=1, col=1
|
| 772 |
)
|
|
|
|
| 778 |
fig.add_trace(
|
| 779 |
go.Scatter(x=steps, y=net_worths, mode='lines', name='Net Worth',
|
| 780 |
line=dict(color='purple', width=2)),
|
| 781 |
+
row=1, col=2
|
| 782 |
)
|
| 783 |
fig.add_trace(
|
| 784 |
go.Scatter(x=steps, y=returns, mode='lines', name='Returns %',
|
| 785 |
+
line=dict(color='orange', width=2), yaxis='y2'),
|
| 786 |
+
row=1, col=2
|
| 787 |
)
|
| 788 |
|
| 789 |
# Portfolio composition
|
|
|
|
| 814 |
)
|
| 815 |
|
| 816 |
fig.update_yaxes(title_text="Returns (%)", secondary_y=True, row=1, col=2)
|
| 817 |
+
fig.update_yaxes(title_text="Value ($)", row=1, col=2)
|
| 818 |
|
| 819 |
return fig
|
| 820 |
|
|
|
|
| 828 |
title="🤖 Advanced AI Trading Demo",
|
| 829 |
css="""
|
| 830 |
.gradio-container {max-width: 1400px !important;}
|
| 831 |
+
.status-box {background-color: #f0f9ff; padding: 1rem; border-radius: 8px; border-left: 4px solid #3b82f6;}
|
| 832 |
"""
|
| 833 |
) as interface:
|
| 834 |
|
| 835 |
gr.Markdown("""
|
| 836 |
# 🤖 Advanced AI Trading Demo
|
| 837 |
**Deep Reinforcement Learning for Financial Markets**
|
| 838 |
+
|
| 839 |
+
This demo showcases a **Deep Q-Network (DQN)** agent learning to trade in simulated financial markets with realistic market dynamics, technical indicators, and risk management.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 840 |
""")
|
| 841 |
|
| 842 |
# Configuration Row
|
|
|
|
| 846 |
with gr.Group():
|
| 847 |
balance = gr.Slider(
|
| 848 |
1000, 50000, 10000, step=1000,
|
| 849 |
+
label="💰 Initial Balance ($)"
|
|
|
|
| 850 |
)
|
| 851 |
risk = gr.Radio(
|
| 852 |
["Low", "Medium", "High"], value="Medium",
|
| 853 |
+
label="🎯 Risk Level"
|
|
|
|
| 854 |
)
|
| 855 |
asset = gr.Radio(
|
| 856 |
["Crypto", "Stock", "Forex"], value="Crypto",
|
| 857 |
+
label="๐ Asset Type"
|
|
|
|
| 858 |
)
|
| 859 |
+
init_btn = gr.Button("๐ Initialize Trading System", variant="primary")
|
| 860 |
|
| 861 |
with gr.Column(scale=2):
|
| 862 |
gr.Markdown("## ๐ System Status")
|
|
|
|
| 871 |
gr.Markdown("## 🏋️‍♂️ Train AI Agent")
|
| 872 |
with gr.Group():
|
| 873 |
episodes = gr.Number(
|
| 874 |
+
value=50, label="🎯 Training Episodes", precision=0
|
|
|
|
| 875 |
)
|
| 876 |
train_btn = gr.Button("๐ Start Training", variant="primary")
|
| 877 |
training_output = gr.Textbox(
|
|
|
|
| 883 |
gr.Markdown("## ▶️ Test Trained Agent")
|
| 884 |
with gr.Group():
|
| 885 |
sim_steps = gr.Number(
|
| 886 |
+
value=200, label="๐ Simulation Steps", precision=0
|
|
|
|
| 887 |
)
|
| 888 |
sim_btn = gr.Button("๐ฎ Run Simulation", variant="primary")
|
| 889 |
sim_output = gr.Textbox(
|
|
|
|
| 892 |
sim_plot = gr.Plot(label="๐ Trading Results")
|
| 893 |
|
| 894 |
# Event Handlers
|
| 895 |
+
def initialize_wrapper(balance, risk, asset):
|
| 896 |
+
return demo.initialize(balance, risk, asset)
|
| 897 |
+
|
| 898 |
+
def simulate_wrapper(steps):
|
| 899 |
+
return demo.simulate(steps)
|
| 900 |
+
|
| 901 |
+
def train_generator(episodes):
|
| 902 |
+
try:
|
| 903 |
+
for status_text, plot in demo.train(int(episodes)):
|
| 904 |
+
yield status_text, plot
|
| 905 |
+
except Exception as e:
|
| 906 |
+
yield f"❌ Training error: {str(e)}", None
|
| 907 |
+
|
| 908 |
init_btn.click(
|
| 909 |
+
fn=initialize_wrapper,
|
| 910 |
inputs=[balance, risk, asset],
|
| 911 |
outputs=status
|
| 912 |
)
|
| 913 |
|
| 914 |
+
train_btn.queue().click(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 915 |
fn=train_generator,
|
| 916 |
inputs=episodes,
|
| 917 |
outputs=[training_output, train_plot]
|
| 918 |
)
|
| 919 |
|
| 920 |
sim_btn.click(
|
| 921 |
+
fn=simulate_wrapper,
|
| 922 |
inputs=sim_steps,
|
| 923 |
outputs=[sim_output, sim_plot]
|
| 924 |
)
|
| 925 |
|
|
|
|
| 926 |
gr.Markdown("""
|
| 927 |
+
## ๐ Usage Instructions
|
| 928 |
+
1. **Configure** your trading parameters
|
| 929 |
+
2. **Initialize** the trading system
|
| 930 |
+
3. **Train** the AI agent (50+ episodes recommended)
|
| 931 |
+
4. **Simulate** trading with the trained agent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 932 |
|
| 933 |
## ๐ฎ Trading Actions
|
| 934 |
+
- **Hold (0)**: Maintain current position
|
| 935 |
+
- **Buy (1)**: Purchase assets (risk-adjusted)
|
| 936 |
+
- **Sell (2)**: Sell portion of position
|
| 937 |
+
- **Close (3)**: Liquidate entire position
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 938 |
""")
|
| 939 |
|
| 940 |
return interface
|
| 941 |
|
| 942 |
# ---- 6. Hugging Face Spaces Entry Point ----
|
| 943 |
if __name__ == "__main__":
|
| 944 |
+
try:
|
| 945 |
+
interface = create_interface()
|
| 946 |
+
interface.launch(
|
| 947 |
+
server_name="0.0.0.0",
|
| 948 |
+
server_port=7860,
|
| 949 |
+
share=False,
|
| 950 |
+
show_error=True,
|
| 951 |
+
enable_queue=True,
|
| 952 |
+
max_threads=40,
|
| 953 |
+
debug=False
|
| 954 |
+
)
|
| 955 |
+
except Exception as e:
|
| 956 |
+
logger.error(f"Failed to launch application: {e}")
|
| 957 |
+
raise
|