Commit 349ad65
Parent(s): 1b637c6
Version 1.0 release

Changed files:
- README.md +174 -94
- requirements.txt +23 -14
- results/final_performance_comparison_all_agents.png +2 -2
- results/{td3_portfolio_alocation.png → ppo_allocation.png} +2 -2
- results/sac_allocation.png +3 -0
- results/{ppo_portfolio_alocation.png → td3_allocation.png} +2 -2
- results/{sac_portfolio_alocation.png → td3_transformer_allocation.png} +2 -2
- scripts/app.py +662 -0
- scripts/custom_policy.py +80 -0
- scripts/environment.py +52 -94
- scripts/evaluate.py +62 -51
- scripts/evaluate_baselines.py +50 -29
- scripts/fetch_data.py +0 -75
- scripts/fetch_market_data.py +90 -64
- scripts/llm_analysis_rag.py +243 -0
- scripts/predict_tomorrow.py +123 -0
- scripts/tune_sac.py +198 -0
- scripts/visualize_strategy.py +0 -123
README.md
CHANGED
Removed lines (previous version):

@@ -1,11 +1,21 @@
- # 🤖 Portfolio Optimization with Deep Reinforcement Learning
- This project explores the use of Deep Reinforcement Learning to train autonomous agents for financial portfolio management. The goal […]
@@ -13,16 +23,12 @@
- 3. [🏆 Key Findings & The Champion […]
- 6. [🔬 The Research Journey: Why Simplicity Won](#-the-research-journey-why-simplicity-won)
- 7. [✅ Conclusion](#-conclusion)
- 8. [📂 Project Structure](#-project-structure)
- 9. [🚀 How to Run](#-how-to-run)
@@ -32,16 +38,23 @@
- The foundation of any financial machine learning project is the data. […]
- The selection of assets was crucial for creating a realistic decision-making environment for the agent. The portfolio consists of five assets, chosen to represent different classes and risk profiles:
- This diverse mix forces the agent to learn not just about individual […]
@@ -57,73 +70,83 @@
- ## 🏆 Key Findings & The Champion
- | :--- | :--- | :--- | :--- |
- | **Total Return** | 47.24% | **50.89%** | 34.91% |
- | **CAGR** | 13.76% | **14.70%** | 10.50% |
- | **Sharpe Ratio** | **0.62** | 0.51 | 0.45 |
- | **Max Drawdown** | **-28.41%** | -44.61% | -40.81% |
- * **Result:** This approach led to the best risk-adjusted returns, proving that a robust initial setup is more valuable than reactive trading.
- * **Behavior:** Like TD3, it made one initial allocation and held firm. However, this allocation was far more aggressive.
- * **Result:** It achieved the highest total return in some periods but suffered catastrophic drawdowns in stress tests, making its strategy unreliable and brittle.
- * **Behavior:** The allocation chart clearly shows the agent rebalancing its portfolio over time, for example, by increasing its bond (TLT) holdings during the 2022 downturn.
- * **Result:** While impressive that it learned this behavior, its performance was inconsistent. It succeeded in some periods (2018) but failed in others (2025), highlighting the immense difficulty of successful market timing.
- * **SAC's Failure:** The green line shows the SAC agent's aggressive strategy failing catastrophically, resulting in a massive drawdown.
@@ -148,68 +202,94 @@
- The codebase is organized into modular, reusable scripts.
- 1. Clone the repository
- 2. Create and activate a Python virtual environment.
- 3. Install the required packages:
-    pip install -r requirements.txt
- Use the `train.py` script to train […]

Updated file (new version):

[](https://www.python.org/)[](https://pytorch.org/)[](LICENSE)

# 🤖 Portfolio Optimization with Deep Reinforcement Learning (v1.0)

This project explores the use of Deep Reinforcement Learning (DRL) to train autonomous agents for financial portfolio management. The goal is to create agents that can dynamically allocate capital across a diverse set of assets to maximize returns while managing risk.

This is **Version 1.0** of the project, which moves beyond initial exploration to a more robust and comparative study. Building on the foundation of v0.1, this version introduces:

* **Comparative Analysis:** We train and evaluate three state-of-the-art DRL algorithms: **Proximal Policy Optimization (PPO)**, **Soft Actor-Critic (SAC)**, and **Twin Delayed DDPG (TD3)**. This allows us to understand the different emergent strategies and trade-offs of each approach.
* **Robust Benchmarking:** Agents' performance is rigorously compared against a standard **Buy and Hold** baseline, using a comprehensive set of financial metrics including Total Return, CAGR, Sharpe Ratio, Sortino Ratio, and Max Drawdown.
* **Modular Codebase:** The project has been refactored into a clean, modular structure with separate scripts for data fetching, training, evaluation, and visualization, making it easier to understand, extend, and reproduce results.
* **In-Depth Analysis:** We delve into *why* certain agents perform better, visualizing their asset allocation strategies over time to uncover their "investment philosophy."
* **Deep RL & LLM Portfolio Manager (Web App):** A key feature of v1.0 is the interactive web application built with **Gradio**. This dashboard bridges the gap between complex backend models and user-friendly analysis, allowing for live tracking, forward-looking strategy generation, and historical backtesting. It integrates **Large Language Models (LLMs)**, specifically Qwen, to act as an AI Risk Analyst, providing textual justification and risk assessments for the RL agent's proposed strategies.

*You can try the web app here:* [Gradio web app](https://huggingface.co/spaces/DanielKiani/Portfolio-Optimization-with-Deep-Reinforcement-Learning)

---

1. [📊 The Data & Asset Selection](#-the-data--asset-selection)
2. [🎯 Benchmarking Against Baselines](#-benchmarking-against-baselines)
3. [🏆 Key Findings & The New Champion](#-key-findings--the-new-champion)
4. [🔬 The Research Journey: Why Simplicity Won](#-the-research-journey-why-simplicity-won)
5. [🖥️ Deep RL & LLM Portfolio Manager (Web App)](#️-deep-rl--llm-portfolio-manager-web-app)
6. [✅ Conclusion](#-conclusion)
7. [📂 Project Structure](#-project-structure)
8. [🚀 How to Run](#-how-to-run)
    * [Setup](#setup)
    * [Data Fetching](#data-fetching)
    * [Training](#training)

## 📊 The Data & Asset Selection

The foundation of any financial machine learning project is the data. The primary source for daily closing price data of the portfolio assets is **Yahoo Finance**, accessed via the `yfinance` library.

To provide the agents with broader economic context beyond just price history, the observation space is augmented with key macroeconomic indicators sourced from **FRED (Federal Reserve Economic Data)**. These indicators include data points such as the CBOE Volatility Index (VIX), various Treasury bill yields, and inflation expectations. This allows the agents to learn strategies that adapt to different market regimes, such as high volatility or rising interest rate environments.
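The sketch below illustrates the kind of download this corresponds to, combining `yfinance` prices with FRED series via `pandas-datareader`. The actual logic lives in `scripts/fetch_market_data.py`; the FRED series IDs and the date range shown here are illustrative assumptions, not the script's exact configuration.

```python
import yfinance as yf
from pandas_datareader import data as pdr

ASSETS = ["AAPL", "MSFT", "SPY", "TLT", "BTC-USD"]
# Illustrative FRED IDs: VIX, 3-month T-bill yield, 10-year inflation expectations
FRED_SERIES = ["VIXCLS", "DGS3MO", "T10YIE"]

prices = yf.download(ASSETS, start="2015-01-01", end="2020-12-31")["Close"]
macro = pdr.DataReader(FRED_SERIES, "fred", "2015-01-01", "2020-12-31")

# Join on the trading calendar and forward-fill macro series published at lower frequency
dataset = prices.join(macro, how="left").ffill().dropna()
dataset.to_csv("data/train_data.csv")
```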

**Environment & Realistic Constraints:**

To ensure realistic simulation results, the trading environment incorporates transaction costs.

* **Transaction Cost:** A fee of **0.001%** is applied to the notional value of every trade (both buys and sells). This forces the agents to learn strategies that generate returns net of fees, discouraging excessive, unprofitable trading.
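In code terms this amounts to something like the following simplified sketch; the actual fee logic lives in `scripts/environment.py` and may differ in detail.

```python
TRANSACTION_COST = 0.00001  # 0.001% of the notional value of each trade

def apply_trade_cost(portfolio_value, old_weights, new_weights):
    # Fraction of the portfolio that changes hands when rebalancing (buys + sells)
    turnover = sum(abs(new - old) for new, old in zip(new_weights, old_weights))
    fee = TRANSACTION_COST * turnover * portfolio_value
    return portfolio_value - fee
```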

The portfolio itself consists of five assets, chosen to represent different asset classes and risk profiles, creating a challenging decision-making environment:

* **Growth Equities (AAPL, MSFT):** Represent the high-growth, high-volatility technology sector.
* **Market Index (SPY):** An ETF tracking the S&P 500, representing the broader US stock market.
* **Safe Haven (TLT):** An ETF for 20+ Year US Treasury Bonds, which often acts as a "risk-off" asset during stock market downturns.
* **Alternative Asset (BTC-USD):** Represents a non-traditional, extremely volatile asset class with high potential returns.

This diverse mix forces the agent to learn not just about individual asset price movements, but also about their correlations and how to balance risk across different economic conditions.

---

## 🏆 Key Findings & The New Champion

Our latest evaluation on out-of-sample data from **2021-2023** has yielded surprising and significant results, challenging our initial assumptions and highlighting the impact of neural network architecture on agent performance.

The **TD3 agent powered by a Transformer architecture** has emerged as the undisputed champion in terms of risk-adjusted returns and capital preservation, while the **SAC agent** demonstrated the highest absolute growth potential.

#### Final Performance Comparison (2021-2023)

This table summarizes the performance of our key agents against the Buy & Hold baseline.
| Metric | **TD3 (Transformer)** | SAC (MLP) | Buy & Hold | PPO (MLP) | TD3 (MLP) |
| :--- | :--- | :--- | :--- | :--- | :--- |
| **Total Return** | 25.34% | **39.23%** | 32.76% | 22.85% | 22.07% |
| **CAGR** | 8.20% | **12.25%** | 9.96% | 7.45% | 7.21% |
| **Sharpe Ratio** | **0.61** | 0.56 | 0.59 | 0.41 | 0.42 |
| **Volatility** | **14.77%** | 27.47% | 19.06% | 25.90% | 23.00% |
| **Max Drawdown** | **-20.01%** | -29.08% | -28.82% | -44.26% | -40.50% |



***Note:** Bitcoin was excluded from the performance comparison.*
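For reference, the sketch below shows how these metrics are derived from a daily portfolio-value series. It mirrors the `calculate_metrics_pro()` helper added in `scripts/app.py` in this commit (252 trading days per year, 0% risk-free rate); it is a condensed illustration, not the full helper.

```python
import numpy as np
import pandas as pd

def summary_metrics(values: pd.Series, freq: int = 252) -> dict:
    # Assumes `values` holds at least two daily portfolio values
    returns = values.pct_change().dropna()
    years = (len(values) - 1) / freq
    drawdown = values / values.cummax() - 1.0
    return {
        "Total Return": values.iloc[-1] / values.iloc[0] - 1,
        "CAGR": (values.iloc[-1] / values.iloc[0]) ** (1 / years) - 1,
        "Sharpe Ratio": np.sqrt(freq) * returns.mean() / returns.std(),
        "Volatility": returns.std() * np.sqrt(freq),
        "Max Drawdown": drawdown.min(),
    }
```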

### 🥇 TD3 (Transformer): The Master of Risk Management

The most notable finding is the superior performance of the TD3 agent when equipped with a **Transformer-based policy network**. This agent achieved the best risk-adjusted metrics across the board.

* **Lowest Volatility (14.77%):** It provided a significantly smoother ride than even the passive Buy & Hold baseline.
* **Best Capital Preservation:** Its maximum drawdown of **-20.01%** was drastically lower than other agents and the baseline, proving its ability to protect capital during severe market downturns like the 2022 bear market.
* **Conclusion:** The Transformer's attention mechanism likely allowed the agent to better identify and react to long-term market shifts and regime changes, leading to a highly robust and defensive strategy.

### 🚀 SAC (MLP): The Aggressive Growth Engine

The **Soft Actor-Critic (SAC)** agent confirmed its role as the high-growth strategist.

* **Highest Returns:** It achieved the highest Total Return (**39.23%**) and CAGR (**12.25%**), outperforming the Buy & Hold baseline by a significant margin.
* **Higher Risk:** This performance came at the cost of the highest volatility (**27.47%**), making it a strategy suited for aggressive investors willing to tolerate larger price swings for maximum gain.

### 📉 The Failure of Standard Architectures

Interestingly, the standard Multi-Layer Perceptron (MLP) versions of PPO and TD3 failed to beat the simple Buy & Hold baseline. They suffered the lowest returns and the deepest drawdowns. This stark contrast with the success of the Transformer model highlights that for complex financial time series, **network architecture is just as critical, if not more so, than the choice of RL algorithm itself.**

---

## 🧠 Comparative Analysis of Agent Strategies

A fascinating outcome of this project was observing how different combinations of RL algorithms and network architectures led to distinct investment philosophies. We can visualize this by looking at how each agent allocated its portfolio over time.

### TD3 (Transformer): The Dynamic Hedger

The Transformer-based TD3 agent did not learn a static allocation. Instead, it developed a sophisticated, **dynamic hedging strategy**. By leveraging the Transformer's attention mechanism to process the 30-day lookback window, the agent could identify market trends and adapt its portfolio accordingly.



As shown in the chart, the agent maintains a core position in equities (AAPL, MSFT, SPY) but actively manages its exposure. During the volatile bear market of 2022, the agent significantly increased its allocation to the safe-haven asset **TLT (US Treasury Bonds)**, effectively "smoothing out" its equity curve and avoiding the deep losses suffered by the baseline. This ability to dynamically shift into defensive assets is the key to its superior risk-adjusted performance.
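For readers curious how a Transformer slots into a Stable-Baselines3 agent, here is a minimal sketch of a custom features extractor. The project's actual implementation lives in `scripts/custom_policy.py` (added in this commit); the class name, layer sizes, and the assumption that the flattened observation reshapes to a `(30, n_features)` window are illustrative, not the repo's code.

```python
import torch
import torch.nn as nn
from gymnasium import spaces
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class TransformerExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: spaces.Box, seq_len: int = 30, n_features: int = 10, d_model: int = 64):
        super().__init__(observation_space, features_dim=d_model)
        self.seq_len, self.n_features = seq_len, n_features
        self.embed = nn.Linear(n_features, d_model)  # per-day feature embedding
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # The environment flattens the (seq_len, n_features) window; restore the time axis
        x = obs.view(obs.shape[0], self.seq_len, self.n_features)
        x = self.encoder(self.embed(x))  # attention over the 30-day window
        return x[:, -1, :]               # use the last time step as the state summary

# Illustrative wiring into TD3 via policy_kwargs:
# model = TD3("MlpPolicy", env, policy_kwargs=dict(
#     features_extractor_class=TransformerExtractor,
#     features_extractor_kwargs=dict(seq_len=30, n_features=10),
# ))
```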
### SAC (MLP): The High-Conviction Aggressor

The SAC agent learned a strategy that is nearly the polar opposite of the Transformer. It converged to a **high-risk, high-return static allocation strategy**. Its portfolio is heavily weighted towards high-growth assets, likely with a substantial allocation to Bitcoin (BTC-USD) and tech stocks, with very little exposure to defensive assets like bonds or cash.



The allocation chart reveals a strategy with minimal changes over time, indicating a "set-and-forget" approach. While this high-conviction bet paid off with the highest total return, it also exposed the portfolio to significant volatility.

### PPO (MLP): The Failed Active Trader

Unlike the other MLP-based agents, which converged to static allocations, the PPO agent attempted a **dynamic, active trading strategy**.



As seen in the chart, the agent frequently rebalances its portfolio, shifting weights between equities, bonds, and cash. However, the performance metrics indicate that this activity was detrimental. With poor returns and the deepest maximum drawdown (-44.26%) among all agents, the PPO agent's attempts at market timing were unsuccessful, churning the portfolio without generating alpha or managing risk.

### TD3 (MLP): The Failed Static Allocator

The standard MLP version of the TD3 agent also converged to a static allocation, similar to the SAC agent, but chose a clearly suboptimal portfolio.



The chart shows a relatively fixed allocation that failed to perform well. Unlike the SAC agent, it did not capture high-growth opportunities, and unlike the Transformer agent, it lacked the dynamic capability to manage risk. This resulted in near-bottom performance across all metrics.

---

* **Hypothesis 2: Models with memory are better.** We tested an LSTM-based agent (`RecurrentPPO`). **Result:** Performance degraded. The added complexity led to overfitting on the training data.
* **Hypothesis 3: Using regularization is better.** We tested both L1 and L2 regularization. **Result:** Performance degraded.
* **Hypothesis 4: Increasing the window from 30 days is better.** We tested increasing the window to 60 days. **Result:** Performance degraded. A longer context window is not always helpful; the extra history can act as additional noise for the model.
* **Hypothesis 5: A Transformer-based architecture is superior.** We replaced the standard Multi-Layer Perceptron (MLP) policy network with a more powerful Transformer model, hypothesizing its attention mechanism would better capture complex temporal relationships. **Result:** Performance degraded. Similar to the LSTM experiment, the Transformer model was too complex for the amount of data available. It suffered from significant overfitting, performing well on training data but failing to generalize to unseen market scenarios.

The conclusion was clear: a simple MLP (Multi-Layer Perceptron) policy network, fed with just normalized price data and a concise 30-day window, was the most effective and robust architecture.
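For concreteness, this is how the 30-day observation window is assembled and normalized before it reaches the policy network. It mirrors `prepare_observation()` in `scripts/app.py` (added in this commit); the column lists are supplied by the caller.

```python
import numpy as np

def build_observation(window_df, asset_cols, macro_cols):
    prices = window_df[asset_cols].values  # shape: (30, n_assets)
    macro = window_df[macro_cols].values   # shape: (30, n_macro)
    # Divide each column by its value on the first day of the window,
    # so the agent sees relative moves rather than absolute levels
    norm_prices = prices / (prices[0] + 1e-8)
    norm_macro = macro / (macro[0] + 1e-8)
    return np.concatenate([norm_prices, norm_macro], axis=1).flatten().astype(np.float32)
```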

---

## 🖥️ Deep RL & LLM Portfolio Manager (Web App)

A key feature of v1.0 is the interactive web application built with **Gradio**. This dashboard bridges the gap between complex backend models and user-friendly analysis, allowing for live tracking, forward-looking strategy generation, and historical backtesting.

The dashboard integrates **Large Language Models (LLMs)**, specifically Qwen, to act as an AI Risk Analyst, providing textual justification and risk assessments for the RL agent's proposed strategies.

### Key Features

#### 1. Live Dashboard & Net Worth Tracking

Track the current portfolio holdings, recent transactions, and the overall net worth evolution in real time.



#### 2. AI-Powered Strategy Forecast & Risk Analysis

Generate tomorrow's optimal portfolio allocation using the trained RL agents. The integrated LLM analyzes the proposed allocation, current market volatility (VIX), and asset concentration to provide a comprehensive **Risk Analyst Report** with a confidence score and justifications.

It also includes **Explainable AI (XAI)** feature importance plots to show which market factors most influenced the agent's decision.


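Under the hood, the forecast tab turns a raw agent action into portfolio weights with a softmax, so allocations are non-negative and sum to 1, with the final slot treated as cash. The snippet below mirrors `predict_and_analyze()` in `scripts/app.py`; it assumes `observation` is a flattened 30-day window built as in the earlier sketch.

```python
import numpy as np
from stable_baselines3 import SAC
from fetch_market_data import ASSETS  # asset tickers, as imported in scripts/app.py

model = SAC.load("checkpoints/sac_portfolio_model.zip")
action, _ = model.predict(observation, deterministic=True)

# Softmax over the action vector -> proposed allocation per asset plus cash
exp_action = np.exp(np.asarray(action).flatten())
weights = exp_action / exp_action.sum()
allocation = dict(zip(ASSETS + ["Cash"], weights))
```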
#### 3. Historical Simulation & Backtesting

Run dynamic backtests of the trained RL agents against baselines over any historical period. This tool is essential for validating performance across different market cycles.



---

## 📂 Project Structure

```bash
├── assets/                     # Images for the README
├── checkpoints/                # Stores trained model weights (.zip files)
├── data/                       # Stores fetched CSV data files
├── results/                    # Stores generated plots and metrics logs
├── scripts/                    # Contains all the executable scripts
│   ├── app.py                  # The Gradio web application
│   ├── check_env.py            # Simple script to verify the custom environment
│   ├── custom_policy.py        # Custom policy network definitions
│   ├── environment.py          # The custom Gymnasium environment class
│   ├── evaluate_baselines.py   # Calculates performance of baseline strategies
│   ├── evaluate.py             # Main script to evaluate a trained agent
│   ├── fetch_market_data.py    # Script to download historical data from YFinance
│   ├── llm_analysis_rag.py     # Script for LLM-based analysis and RAG
│   ├── predict_tomorrow.py     # Script to generate predictions for the next day
│   ├── stress_test.py          # Compares all agents on a specific dataset
│   ├── train.py                # Main script to train an RL agent
│   ├── tune_sac.py             # Script for hyperparameter tuning of the SAC agent
│   └── visualize_strategy.py   # Plots the asset allocation of a trained agent
├── requirements.txt            # List of Python dependencies
└── README.md                   # This file
```

## 🚀 How to Run

### Setup

1. Clone the repository:

```bash
git clone https://github.com/DanielKiani/Portfolio-Optimization-with-Deep-Reinforcement-Learning
```

2. Install the required packages:

```bash
pip install -r requirements.txt
```

### Data Fetching

Before training or evaluation, you need to download the historical market data. Use the `fetch_market_data.py` script.

```bash
# Fetch training data (e.g., 2015-2020)
python scripts/fetch_market_data.py --start 2015-01-01 --end 2020-12-31 --filename data/train_data.csv

# Fetch evaluation data (e.g., 2021-2023)
python scripts/fetch_market_data.py --start 2021-01-01 --end 2023-12-31 --filename data/eval_data.csv
```

### Training

Use the `train.py` script to train an agent. You can specify the algorithm (ppo, sac, or td3) and the number of training timesteps.

```bash
# Train a TD3 agent (default timesteps: 20000)
python scripts/train.py --agent td3 --datafile data/train_data.csv

# Train a SAC agent for more timesteps
python scripts/train.py --agent sac --datafile data/train_data.csv --timesteps 50000
```

The trained model will be saved in the `checkpoints/` directory (e.g., `sac_portfolio_model.zip`).

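Under the hood, a training run boils down to a few Stable-Baselines3 calls. The sketch below shows the general shape; the real entry point is `scripts/train.py`, and the `PortfolioEnv` constructor arguments follow the call used in `scripts/app.py`, so they may differ slightly from the actual script.

```python
import pandas as pd
from stable_baselines3 import SAC
from environment import PortfolioEnv

df = pd.read_csv("data/train_data.csv", index_col=0, parse_dates=True)
env = PortfolioEnv(df, 30, initial_amount=10_000)  # 30-day observation window

model = SAC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=50_000)
model.save("checkpoints/sac_portfolio_model.zip")
```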
### Evaluation & Visualization

Once you have trained models and evaluation data, you can use the other scripts to analyze performance.

* **Compare all agents** (`stress_test.py`): This script loads all available models in `checkpoints/` and compares them against the baseline on a given dataset.

```bash
python scripts/stress_test.py --datafile data/eval_data.csv
```

This will generate `results/agent_performance_comparison.png` and print a metrics table.

* **Evaluate a single agent** (`evaluate.py`): This script calculates detailed metrics for a specific agent and plots its portfolio value.

```bash
python scripts/evaluate.py --agent td3 --checkpoint checkpoints/td3_portfolio_model.zip --datafile data/eval_data.csv
```

* **Visualize an agent's strategy** (`visualize_strategy.py`): This script creates a stacked area chart showing how the agent's asset allocation changed over time.

```bash
python scripts/visualize_strategy.py --agent ppo --checkpoint checkpoints/ppo_portfolio_model.zip --datafile data/eval_data.csv
```

This will save the plot to `results/ppo_portfolio_allocation.png`.
requirements.txt
CHANGED
@@ -1,19 +1,28 @@

Removed lines (previous version):

- # Core
- stable-baselines3==2.7.0
- sb3_contrib==2.7.0
- gymnasium==1.2.1
- # Data Handling and Numerics
- pandas==2.3.3

Updated file (new version):

# Core Data Science & Mathematics
numpy==2.2.6
pandas==2.3.3
scipy==1.16.3

# Visualization
matplotlib==3.10.0
plotly==5.24.1

# Financial Data
yfinance==0.2.66
pandas-datareader==0.10.0

# Reinforcement Learning
gymnasium==1.2.2
shimmy==2.0.0
stable-baselines3==2.7.0
sb3-contrib==2.7.0

# Deep Learning Framework
torch==2.9.0

# Utilities & Other
gradio==5.50.0
python-dotenv==1.2.1
tabulate==0.9.0
quantstats==0.0.62
pandas-ta==0.4.71b0
results/final_performance_comparison_all_agents.png
CHANGED
Git LFS Details

results/{td3_portfolio_alocation.png → ppo_allocation.png}
RENAMED
File without changes

results/sac_allocation.png
ADDED
Git LFS Details

results/{ppo_portfolio_alocation.png → td3_allocation.png}
RENAMED
File without changes

results/{sac_portfolio_alocation.png → td3_transformer_allocation.png}
RENAMED
File without changes
scripts/app.py
ADDED
@@ -0,0 +1,662 @@
# scripts/app.py

import gradio as gr
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from datetime import datetime, timedelta
import os
import sys
import json
import torch
from fetch_market_data import fetch_market_data, ASSETS, FRED_IDS
from llm_analysis_rag import analyze_agent_decision, analyze_historical_segment
from stable_baselines3 import SAC
from environment import PortfolioEnv
from scripts.evaluate_baselines import buy_and_hold, equally_weighted_rebalanced

# --- Configuration ---
# Resolve the repository root (scripts/ sits one level below it) so the paths
# below work regardless of the current working directory
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

MODEL_PATH = os.path.join(project_root, "checkpoints", "sac_portfolio_model.zip")
WINDOW_SIZE = 30
MACRO_COLS = list(FRED_IDS.values())
DASHBOARD_DATA_PATH = os.path.join(project_root, "data", "historical_dashboard_data.csv")

# *** UPDATE THESE DATES TO MATCH YOUR ACTUAL TRAINING PERIOD ***
TRAIN_START_DATE = "2015-01-01"
TRAIN_END_DATE = "2023-01-01"

# Global variable for dashboard data needed for Tabs 3 & 4
DASHBOARD_DATA_DF = None

# Define Time Period mappings for the dropdown
TIME_PERIODS = {
    "6 Months": 180,
    "1 Year": 365,
    "2 Years": 730,
    "5 Years": 1825,
    "Max Available": 9999  # Sentinel value for max
}

# =========================================
# Initialization Functions
# =========================================

def initialize_dashboard_data():
    """Fetches and loads historical data at startup for Tabs 3 & 4."""
    global DASHBOARD_DATA_DF
    print("--- Initializing Historical Data for Analyst/Simulation Tabs ---")

    # Fetching last 6 years to support longer analysis periods and simulation
    end_date = (datetime.now() - timedelta(days=1)).strftime('%Y-%m-%d')
    start_date = (datetime.now() - timedelta(days=365*6)).strftime('%Y-%m-%d')

    print(f"Fetching historical data from {start_date} to {end_date}...")
    # This might take a minute on first run
    fetch_market_data(start_date, end_date, DASHBOARD_DATA_PATH)

    if os.path.exists(DASHBOARD_DATA_PATH):
        DASHBOARD_DATA_DF = pd.read_csv(DASHBOARD_DATA_PATH, index_col=0, parse_dates=True)
        # Basic cleaning
        DASHBOARD_DATA_DF.dropna(how='all', inplace=True)
        # Calculate equal weight return for dashboard metrics
        asset_cols = [c for c in ASSETS if c in DASHBOARD_DATA_DF.columns]
        if asset_cols:
            DASHBOARD_DATA_DF['Daily_Ret_Eq'] = DASHBOARD_DATA_DF[asset_cols].pct_change().mean(axis=1)
        print(f"Data loaded successfully. Shape: {DASHBOARD_DATA_DF.shape}")
        print(f"Data range: {DASHBOARD_DATA_DF.index.min().date()} to {DASHBOARD_DATA_DF.index.max().date()}")
    else:
        print("❌ Failed to initialize historical data.")

# Initialize data at startup
try:
    initialize_dashboard_data()
except Exception as e:
    print(f"Warning: Data initialization failed. Error: {e}")


# =========================================
# Professional Metrics & Evaluation Functions
# =========================================

def evaluate_agent_pro(env, model):
    """
    Runs the trained agent on the environment and returns portfolio values.
    """
    obs, info = env.reset()
    terminated, truncated = False, False
    portfolio_values = [env.initial_amount]

    while not (terminated or truncated):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        portfolio_values.append(info['portfolio_value'])

    # Align index with the actual steps taken
    valid_dates = env.df.index[env.window_size-1:]
    return pd.Series(portfolio_values, index=valid_dates[:len(portfolio_values)])

def calculate_metrics_pro(portfolio_values, freq=252, rf=0.0):
    """
    Calculates key professional performance metrics from a series of portfolio values.
    """
    if len(portfolio_values) < 2:
        return {k: "N/A" for k in ["Total Return", "CAGR", "Sharpe Ratio", "Sortino Ratio", "Volatility", "Max Drawdown", "Calmar Ratio"]}

    returns = portfolio_values.pct_change().dropna()
    if returns.empty:
        return {k: "0.00%" if "%" in k else "0.00" for k in ["Total Return", "CAGR", "Sharpe Ratio", "Sortino Ratio", "Volatility", "Max Drawdown", "Calmar Ratio"]}

    total_return = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) - 1
    num_years = (len(portfolio_values) - 1) / freq
    cagr = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) ** (1/num_years) - 1 if num_years > 0 else 0.0

    sharpe_ratio = np.sqrt(freq) * (returns.mean() - rf) / returns.std() if returns.std() > 0 else np.nan

    downside_returns = returns[returns < 0]
    downside_std = downside_returns.std()
    sortino_ratio = np.sqrt(freq) * (returns.mean() - rf) / downside_std if downside_std > 0 else np.nan

    volatility = returns.std() * np.sqrt(freq)

    rolling_max = portfolio_values.cummax()
    drawdown = portfolio_values / rolling_max - 1.0
    max_drawdown = drawdown.min()

    calmar_ratio = cagr / abs(max_drawdown) if max_drawdown != 0 and cagr != 0 else np.nan

    return {
        "Total Return": total_return,
        "CAGR": cagr,
        "Sharpe Ratio": sharpe_ratio,
        "Sortino Ratio": sortino_ratio,
        "Volatility": volatility,
        "Max Drawdown": max_drawdown,
        "Calmar Ratio": calmar_ratio
    }

# =========================================
# XAI: Feature Importance Function
# =========================================
def calculate_feature_importance(model, obs):
    """
    Calculates feature importance using Integrated Gradients on the RL agent's policy network.
    """
    # Convert observation to torch tensor and enable gradient tracking
    obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=model.device)
    obs_tensor.requires_grad_()

    # Get the policy network (actor)
    actor = model.policy.actor

    # Define a baseline (e.g., a zero observation)
    baseline = torch.zeros_like(obs_tensor)

    # Number of steps for integral approximation
    steps = 50

    # Generate scaled inputs along the path from baseline to input
    scaled_inputs = [baseline + (float(i) / steps) * (obs_tensor - baseline) for i in range(steps + 1)]

    grads = []
    for scaled_input in scaled_inputs:
        # Forward pass to get action distribution parameters (mean)
        action_mean = actor(scaled_input)

        # We need a scalar output to calculate gradients against.
        # Here we sum, representing overall sensitivity of the action vector.
        target_output = action_mean.sum()

        # Calculate gradients of the target output with respect to the input features
        grad = torch.autograd.grad(outputs=target_output, inputs=scaled_input)[0]
        grads.append(grad)

    # Average the gradients using the trapezoidal rule approximation
    # (stack into a tensor first; adding the Python lists directly would just concatenate them)
    grads = torch.stack(grads)
    avg_grads = ((grads[:-1] + grads[1:]) / 2.0).mean(dim=0)

    # Calculate Integrated Gradients: (input - baseline) * average_gradients
    integrated_grads = (obs_tensor - baseline) * avg_grads

    # Detach, move to cpu, and convert to numpy array
    importance_scores = integrated_grads.detach().cpu().numpy().flatten()

    # Feature Names mapping
    num_assets = len(ASSETS)
    num_macro = len(MACRO_COLS)

    # Create feature names based on the observation structure
    feature_names = []
    for i in range(WINDOW_SIZE):
        for asset in ASSETS:
            feature_names.append(f"{asset}_t-{WINDOW_SIZE-1-i}")
    for i in range(WINDOW_SIZE):
        for macro in MACRO_COLS:
            feature_names.append(f"{macro}_t-{WINDOW_SIZE-1-i}")

    # Combine into a dictionary and sort by absolute importance
    feature_importance_dict = dict(zip(feature_names, importance_scores))

    # Aggregate importance by feature type (sum of absolute values across time steps)
    aggregated_importance = {}
    for base_feature in ASSETS + MACRO_COLS:
        total_imp = sum(abs(val) for key, val in feature_importance_dict.items() if key.startswith(base_feature))
        aggregated_importance[base_feature] = total_imp

    # Sort and take top N for display
    top_features = dict(sorted(aggregated_importance.items(), key=lambda item: item[1], reverse=True)[:8])

    # Create a Plotly bar chart
    fig = px.bar(
        x=list(top_features.values()),
        y=list(top_features.keys()),
        orientation='h',
        title="Top Influential Features (XAI)",
        labels={'x': 'Relative Importance Score', 'y': 'Feature'},
        color=list(top_features.values()),
        color_continuous_scale=px.colors.sequential.Viridis
    )
    fig.update_layout(
        template="plotly_dark",
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        yaxis={'categoryorder':'total ascending'},
        coloraxis_showscale=False,
        margin=dict(l=10, r=10, t=40, b=10),
        height=300  # Keep it compact
    )

    return fig

|
| 232 |
+
# Tab 4 Logic: Historical Simulation (UPDATED)
|
| 233 |
+
# =========================================
|
| 234 |
+
|
| 235 |
+
def run_historical_simulation(start_date_str, end_date_str):
|
| 236 |
+
"""
|
| 237 |
+
Runs the RL agent on historical data and compares to baselines using professional metrics.
|
| 238 |
+
"""
|
| 239 |
+
if DASHBOARD_DATA_DF is None:
|
| 240 |
+
return go.Figure(), "Data not initialized. Please restart app.", gr.update(visible=False)
|
| 241 |
+
|
| 242 |
+
status_msg = "Preparing simulation..."
|
| 243 |
+
yield go.Figure(), status_msg, gr.update(visible=False)
|
| 244 |
+
|
| 245 |
+
try:
|
| 246 |
+
# 1. Validate and Slice Data
|
| 247 |
+
try:
|
| 248 |
+
start_date = pd.to_datetime(start_date_str)
|
| 249 |
+
end_date = pd.to_datetime(end_date_str)
|
| 250 |
+
except ValueError:
|
| 251 |
+
yield go.Figure(), "Error: Invalid date format. Use YYYY-MM-DD.", gr.update(visible=False)
|
| 252 |
+
return
|
| 253 |
+
|
| 254 |
+
if start_date < DASHBOARD_DATA_DF.index.min() or end_date > DASHBOARD_DATA_DF.index.max():
|
| 255 |
+
avail_start = DASHBOARD_DATA_DF.index.min().date()
|
| 256 |
+
avail_end = DASHBOARD_DATA_DF.index.max().date()
|
| 257 |
+
yield go.Figure(), f"Error: Selected dates outside available range ({avail_start} to {avail_end}).", gr.update(visible=False)
|
| 258 |
+
return
|
| 259 |
+
|
| 260 |
+
df_slice = DASHBOARD_DATA_DF.loc[start_date:end_date].copy()
|
| 261 |
+
asset_cols_only = [c for c in ASSETS if c in df_slice.columns]
|
| 262 |
+
|
| 263 |
+
if len(df_slice) < WINDOW_SIZE + 10:
|
| 264 |
+
yield go.Figure(), "Error: Time period too short for simulation.", gr.update(visible=False)
|
| 265 |
+
return
|
| 266 |
+
|
| 267 |
+
# 2. Setup Environment and Agent
|
| 268 |
+
status_msg = "Running RL Agent simulation..."
|
| 269 |
+
yield go.Figure(), status_msg, gr.update(visible=False)
|
| 270 |
+
|
| 271 |
+
env = PortfolioEnv(df_slice, WINDOW_SIZE, initial_amount=10000)
|
| 272 |
+
|
| 273 |
+
if not os.path.exists(MODEL_PATH):
|
| 274 |
+
raise FileNotFoundError(f"Model not found: {MODEL_PATH}")
|
| 275 |
+
model = SAC.load(MODEL_PATH)
|
| 276 |
+
|
| 277 |
+
# 3. Run Simulation Loop & Get Values using Pro Function
|
| 278 |
+
rl_portfolio_series = evaluate_agent_pro(env, model)
|
| 279 |
+
|
| 280 |
+
# 4. Calculate Baselines using Pro Functions
|
| 281 |
+
status_msg = "Calculating baselines and metrics..."
|
| 282 |
+
yield go.Figure(), status_msg, gr.update(visible=False)
|
| 283 |
+
|
| 284 |
+
# Pass only asset columns to baseline functions
|
| 285 |
+
bnh_portfolio_series = buy_and_hold(df_slice[asset_cols_only], initial_amount=10000)
|
| 286 |
+
# Realign B&H index to match RL agent's start date
|
| 287 |
+
bnh_portfolio_series = bnh_portfolio_series.loc[rl_portfolio_series.index[0]:]
|
| 288 |
+
# Normalize B&H starting value to match RL agent's start
|
| 289 |
+
bnh_portfolio_series = bnh_portfolio_series / bnh_portfolio_series.iloc[0] * 10000
|
| 290 |
+
|
| 291 |
+
eq_portfolio_series = equally_weighted_rebalanced(df_slice[asset_cols_only], initial_amount=10000)
|
| 292 |
+
eq_portfolio_series = eq_portfolio_series.loc[rl_portfolio_series.index[0]:]
|
| 293 |
+
eq_portfolio_series = eq_portfolio_series / eq_portfolio_series.iloc[0] * 10000
|
| 294 |
+
|
| 295 |
+
# 5. Generate Plot
|
| 296 |
+
fig = go.Figure()
|
| 297 |
+
fig.add_trace(go.Scatter(x=rl_portfolio_series.index, y=rl_portfolio_series, mode='lines', name='RL Agent (SAC)', line=dict(color='#10b981', width=3)))
|
| 298 |
+
fig.add_trace(go.Scatter(x=bnh_portfolio_series.index, y=bnh_portfolio_series, mode='lines', name='Buy & Hold (SPY)', line=dict(color='#6b7280', dash='dash')))
|
| 299 |
+
fig.add_trace(go.Scatter(x=eq_portfolio_series.index, y=eq_portfolio_series, mode='lines', name='Equal Weighted', line=dict(color='#a855f7', dash='dot')))
|
| 300 |
+
|
| 301 |
+
fig.update_layout(
|
| 302 |
+
title="Simulation: Strategy Performance Comparison ($10k Start)",
|
| 303 |
+
xaxis_title="Date",
|
| 304 |
+
yaxis_title="Portfolio Value ($)",
|
| 305 |
+
template="plotly_dark",
|
| 306 |
+
paper_bgcolor='rgba(0,0,0,0)',
|
| 307 |
+
plot_bgcolor='rgba(0,0,0,0)',
|
| 308 |
+
hovermode="x unified",
|
| 309 |
+
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# 6. Calculate Professional Metrics Table
|
| 313 |
+
rl_m = calculate_metrics_pro(rl_portfolio_series)
|
| 314 |
+
bnh_m = calculate_metrics_pro(bnh_portfolio_series)
|
| 315 |
+
eq_m = calculate_metrics_pro(eq_portfolio_series)
|
| 316 |
+
|
| 317 |
+
# Helper to format based on metric type
|
| 318 |
+
def fmt(val, is_pct=True):
|
| 319 |
+
if pd.isna(val): return "N/A"
|
| 320 |
+
return f"{val:.2%}" if is_pct else f"{val:.2f}"
|
| 321 |
+
|
| 322 |
+
metrics_data = {
|
| 323 |
+
"Metric": ["Total Return", "CAGR", "Sharpe Ratio", "Sortino Ratio", "Volatility (Ann.)", "Max Drawdown", "Calmar Ratio"],
|
| 324 |
+
"RL Agent (SAC)": [fmt(rl_m["Total Return"]), fmt(rl_m["CAGR"]), fmt(rl_m["Sharpe Ratio"], False), fmt(rl_m["Sortino Ratio"], False), fmt(rl_m["Volatility"]), fmt(rl_m["Max Drawdown"]), fmt(rl_m["Calmar Ratio"], False)],
|
| 325 |
+
"Buy & Hold (SPY)": [fmt(bnh_m["Total Return"]), fmt(bnh_m["CAGR"]), fmt(bnh_m["Sharpe Ratio"], False), fmt(bnh_m["Sortino Ratio"], False), fmt(bnh_m["Volatility"]), fmt(bnh_m["Max Drawdown"]), fmt(bnh_m["Calmar Ratio"], False)],
|
| 326 |
+
"Equal Weighted": [fmt(eq_m["Total Return"]), fmt(eq_m["CAGR"]), fmt(eq_m["Sharpe Ratio"], False), fmt(eq_m["Sortino Ratio"], False), fmt(eq_m["Volatility"]), fmt(eq_m["Max Drawdown"]), fmt(eq_m["Calmar Ratio"], False)],
|
| 327 |
+
}
|
| 328 |
+
metrics_df = pd.DataFrame(metrics_data)
|
| 329 |
+
|
| 330 |
+
# Format the dataframe as a markdown table for cleaner display
|
| 331 |
+
metrics_md = metrics_df.to_markdown(index=False)
|
| 332 |
+
final_metrics_display = f"### 📊 Professional Performance Metrics\n\n{metrics_md}"
|
| 333 |
+
|
| 334 |
+
yield fig, "Simulation Complete.", final_metrics_display
|
| 335 |
+
|
| 336 |
+
except Exception as e:
|
| 337 |
+
import traceback
|
| 338 |
+
traceback.print_exc()
|
| 339 |
+
yield go.Figure(), f"Error during simulation: {str(e)}", gr.update(visible=False)
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
# =========================================
|
| 343 |
+
# Tab 3 Logic: Historical Data Analyst
|
| 344 |
+
# =========================================
|
| 345 |
+
|
| 346 |
+
def run_historical_analysis(selected_assets, period_name):
|
| 347 |
+
"""Backend for Tab 3."""
|
| 348 |
+
if DASHBOARD_DATA_DF is None or not selected_assets:
|
| 349 |
+
return go.Figure(), "Please wait for data initialization or select assets."
|
| 350 |
+
|
| 351 |
+
status_html = """<div style="color: #9ca3af;">🔄 Processing data and running AI analysis...</div>"""
|
| 352 |
+
yield go.Figure(), status_html
|
| 353 |
+
|
| 354 |
+
try:
|
| 355 |
+
# 1. Filter Data by Time Period
|
| 356 |
+
days = TIME_PERIODS.get(period_name, 365)
|
| 357 |
+
cutoff_date = datetime.now() - timedelta(days=days)
|
| 358 |
+
valid_assets = [a for a in selected_assets if a in DASHBOARD_DATA_DF.columns]
|
| 359 |
+
if not valid_assets:
|
| 360 |
+
yield go.Figure(), "Error: Selected assets not found in available data."
|
| 361 |
+
return
|
| 362 |
+
df_filtered = DASHBOARD_DATA_DF.loc[cutoff_date:, valid_assets].copy()
|
| 363 |
+
if df_filtered.empty:
|
| 364 |
+
yield go.Figure(), f"No data found for the selected period: {period_name}"
|
| 365 |
+
return
|
| 366 |
+
|
| 367 |
+
# 2. Generate Normalized Price Plot
|
| 368 |
+
df_normalized = df_filtered / df_filtered.iloc[0] * 100
|
| 369 |
+
fig = px.line(df_normalized, x=df_normalized.index, y=df_normalized.columns,
|
| 370 |
+
title=f"Performance Comparison: {period_name} (Base=100)",
|
| 371 |
+
color_discrete_sequence=px.colors.qualitative.Bold)
|
| 372 |
+
fig.update_layout(template="plotly_dark", paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)',
|
| 373 |
+
yaxis_title="Normalized Price", xaxis_title="Date", legend_title_text="", hovermode="x unified")
|
| 374 |
+
|
| 375 |
+
# 3. Run AI Analysis
|
| 376 |
+
analysis_text = analyze_historical_segment(df_filtered, valid_assets, period_name)
|
| 377 |
+
formatted_analysis = f"### 🤖 AI Analyst Report: {period_name}\n\n{analysis_text}"
|
| 378 |
+
yield fig, formatted_analysis
|
| 379 |
+
|
| 380 |
+
except Exception as e:
|
| 381 |
+
import traceback
|
| 382 |
+
traceback.print_exc()
|
| 383 |
+
yield go.Figure(), f"### Error during analysis\n\n{str(e)}"
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
# =========================================
|
| 387 |
+
# Tab 2 Logic: Forecast & Analysis (XAI)
|
| 388 |
+
# =========================================
|
| 389 |
+
|
| 390 |
+
def get_latest_data_window(window_size=30):
|
| 391 |
+
"""Fetches latest data needed for prediction."""
|
| 392 |
+
print("Fetching prediction data...")
|
| 393 |
+
lookback_days = window_size + 150
|
| 394 |
+
end_date = datetime.now().strftime('%Y-%m-%d')
|
| 395 |
+
start_date = (datetime.now() - timedelta(days=lookback_days)).strftime('%Y-%m-%d')
|
| 396 |
+
temp_filename = os.path.join(project_root, "data", "temp_gradio_prediction_data.csv")
|
| 397 |
+
fetch_market_data(start_date, end_date, temp_filename)
|
| 398 |
+
if not os.path.exists(temp_filename): raise Exception("Failed to fetch market data file.")
|
| 399 |
+
df = pd.read_csv(temp_filename, index_col=0, parse_dates=True)
|
| 400 |
+
df.dropna(inplace=True)
|
| 401 |
+
if len(df) < window_size: raise Exception(f"Not enough clean data fetched for prediction.")
|
| 402 |
+
return df.iloc[-window_size:].copy()
|
| 403 |
+
|
| 404 |
+
def prepare_observation(data_window):
|
| 405 |
+
price_data = data_window[ASSETS].values
|
| 406 |
+
macro_data = data_window[MACRO_COLS].values
|
| 407 |
+
norm_prices = price_data / (price_data[0] + 1e-8)
|
| 408 |
+
norm_macro = macro_data / (macro_data[0] + 1e-8)
|
| 409 |
+
obs = np.concatenate([norm_prices, norm_macro], axis=1)
|
| 410 |
+
# Return both flattened obs for model and raw obs for XAI
|
| 411 |
+
return obs.flatten().astype(np.float32), obs.astype(np.float32), data_window
|
| 412 |
+
|
| 413 |
+
def predict_and_analyze():
|
| 414 |
+
"""Main function for Forecast Tab."""
|
| 415 |
+
status_msg = "Starting process..."
|
| 416 |
+
loading_html = """<div style="color: #9ca3af;">🔄 Fetching data & running prediction...</div>"""
|
| 417 |
+
# Update to yield an empty plot for the XAI chart initially
|
| 418 |
+
yield status_msg, None, go.Figure(), loading_html
|
| 419 |
+
|
| 420 |
+
try:
|
| 421 |
+
data_window = get_latest_data_window(WINDOW_SIZE)
|
| 422 |
+
# Get flattened obs for prediction and raw obs for XAI
|
| 423 |
+
flat_obs, raw_obs, df_window_for_analyst = prepare_observation(data_window)
|
| 424 |
+
|
| 425 |
+
if not os.path.exists(MODEL_PATH): raise FileNotFoundError(f"Model not found: {MODEL_PATH}")
|
| 426 |
+
model = SAC.load(MODEL_PATH)
|
| 427 |
+
|
| 428 |
+
# --- XAI: Calculate Feature Importance ---
|
| 429 |
+
status_msg = "Calculating feature importance..."
|
| 430 |
+
yield status_msg, None, go.Figure(), loading_html
|
| 431 |
+
xai_plot = calculate_feature_importance(model, raw_obs)
|
| 432 |
+
|
| 433 |
+
# --- Prediction ---
|
| 434 |
+
action, _ = model.predict(flat_obs, deterministic=True)
|
| 435 |
+
exp_action = np.exp(np.asarray(action).flatten())
|
| 436 |
+
weights = exp_action / np.sum(exp_action)
|
| 437 |
+
allocations_dict = {asset: weights[i] for i, asset in enumerate(ASSETS)}
|
| 438 |
+
allocations_dict['Cash'] = weights[-1]
|
| 439 |
+
alloc_df = pd.DataFrame(list(allocations_dict.items()), columns=['Asset', 'Proposed Allocation'])
|
| 440 |
+
alloc_df['Proposed Allocation'] = alloc_df['Proposed Allocation'].apply(lambda x: f"{x:.2%}")
|
| 441 |
+
|
| 442 |
+
status_msg = "Prediction done. Running AI Risk Analysis..."
|
| 443 |
+
analysing_html = """<div style="color: #9ca3af;">🤖 Running Qwen-2.5-3B Risk Analysis...</div>"""
|
| 444 |
+
# Yield XAI plot along with other outputs
|
| 445 |
+
yield status_msg, alloc_df, xai_plot, analysing_html
|
| 446 |
+
|
| 447 |
+
allocations_for_llm = {k: float(v) for k, v in allocations_dict.items()}
|
| 448 |
+
analysis_result = analyze_agent_decision(df_window_for_analyst, allocations_for_llm)
|
| 449 |
+
status_msg = "Analysis complete!"
|
| 450 |
+
|
| 451 |
+
if isinstance(analysis_result, dict):
|
| 452 |
+
strat = analysis_result.get('strategy_summary', 'N/A')
|
| 453 |
+
risk = analysis_result.get('risk_level', 'N/A').upper()
|
| 454 |
+
just = analysis_result.get('justification', 'N/A')
|
| 455 |
+
conf = analysis_result.get('confidence_score', 'N/A')
|
| 456 |
+
if 'HIGH' in risk:
|
| 457 |
+
risk_css = "color: #ef4444; font-weight: bold;"
|
| 458 |
+
status_bg = "#7f1d1d"
|
| 459 |
+
status_border = "#ef4444"
|
| 460 |
+
status_icon = "⛔"
|
| 461 |
+
status_text = "TRADE BLOCKED: High Risk Detected"
|
| 462 |
+
else:
|
| 463 |
+
risk_css = "color: #10b981; font-weight: bold;"
|
| 464 |
+
status_bg = "#064e3b"
|
| 465 |
+
status_border = "#10b981"
|
| 466 |
+
status_icon = "🚀"
|
| 467 |
+
status_text = "TRADE APPROVED"
|
| 468 |
+
|
| 469 |
+
report_html = f"""
|
| 470 |
+
<div style="background-color: #1f2937; padding: 20px; border-radius: 12px 12px 0 0; border: 1px solid #374151; border-bottom: none;">
|
| 471 |
+
<h3 style="margin-top: 0; color: #e5e7eb;">🤖 AI Risk Analyst Report</h3>
|
| 472 |
+
<div style="margin-bottom: 15px;"><strong style="color: #9ca3af;">Strategy:</strong><br><span style="color: #d1d5db;">{strat}</span></div>
|
| 473 |
+
<div style="margin-bottom: 15px;"><strong style="color: #9ca3af;">Risk Level:</strong><span style="margin-left: 8px; {risk_css}">{risk}</span></div>
|
| 474 |
+
<div style="margin-bottom: 15px;"><strong style="color: #9ca3af;">Justification:</strong><br><span style="color: #d1d5db;">{just}</span></div>
|
| 475 |
+
<div><strong style="color: #9ca3af;">Confidence:</strong> <span style="color: #d1d5db;">{conf}/10</span></div>
|
| 476 |
+
</div>
|
| 477 |
+
<div style="background-color: {status_bg}; color: white; padding: 15px; border-radius: 0 0 12px 12px; border: 2px solid {status_border}; text-align: center; font-size: 1.2em; font-weight: bold; display: flex; align-items: center; justify-content: center;">
|
| 478 |
+
<span style="margin-right: 10px; font-size: 1.4em;">{status_icon}</span>{status_text}
|
| 479 |
+
</div>"""
|
| 480 |
+
else:
|
| 481 |
+
report_html = f"""<div style="padding: 20px; background-color: #7f1d1d; color: #fca5a5; border-radius: 12px;"><h3>❌ Analysis Failed to Parse</h3><p>{str(analysis_result)}</p></div>"""
|
| 482 |
+
# Final yield with all outputs including XAI plot
|
| 483 |
+
yield status_msg, alloc_df, xai_plot, report_html
|
| 484 |
+
except Exception as e:
|
| 485 |
+
import traceback
|
| 486 |
+
traceback.print_exc()
|
| 487 |
+
status_msg = f"Error: {str(e)}"
|
| 488 |
+
error_html = f"""<div style="padding: 20px; background-color: #7f1d1d; color: #fca5a5; border-radius: 12px;"><h3>❌ Process Error</h3><p>{str(e)}</p></div>"""
|
| 489 |
+
# Final yield in case of error
|
| 490 |
+
yield status_msg, None, go.Figure(), error_html
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
# =========================================
|
| 494 |
+
# Tab 1 Logic: Live Dashboard (DUMMY DATA)
|
| 495 |
+
# =========================================
|
| 496 |
+
def get_dashboard_metrics():
|
| 497 |
+
return "$135,400", "+3.07%"
|
| 498 |
+
|
| 499 |
+
def get_portfolio_history_plot():
|
| 500 |
+
dates = pd.date_range(start="2023-01-01", periods=100)
|
| 501 |
+
np.random.seed(42)
|
| 502 |
+
rl_returns = np.random.normal(0.001, 0.01, 100)
|
| 503 |
+
bnh_returns = np.random.normal(0.0005, 0.012, 100)
|
| 504 |
+
rl_value = 10000 * np.cumprod(1 + rl_returns)
|
| 505 |
+
bnh_value = 10000 * np.cumprod(1 + bnh_returns)
|
| 506 |
+
fig = go.Figure()
|
| 507 |
+
fig.add_trace(go.Scatter(x=dates, y=rl_value, mode='lines', name='RL Agent (Live)', line=dict(color='#10b981', width=3)))
|
| 508 |
+
fig.add_trace(go.Scatter(x=dates, y=bnh_value, mode='lines', name='Benchmark', line=dict(color='#6b7280', dash='dash')))
|
| 509 |
+
fig.update_layout(title="Portfolio Net Worth (Live Tracking)", xaxis_title="Date", yaxis_title="Net Worth ($)", template="plotly_dark", paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)', legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1))
|
| 510 |
+
return fig
|
| 511 |
+
|
| 512 |
+
def get_current_allocation_plot():
|
| 513 |
+
labels = ASSETS + ['Cash']
|
| 514 |
+
values = [0.25, 0.10, 0.30, 0.15, 0.05, 0.15]
|
| 515 |
+
fig = px.pie(values=values, names=labels, title="Current Holdings Breakdown", color_discrete_sequence=px.colors.qualitative.Bold)
|
| 516 |
+
fig.update_traces(textposition='inside', textinfo='percent+label', hole=.4)
|
| 517 |
+
fig.update_layout(template="plotly_dark", paper_bgcolor='rgba(0,0,0,0)', legend=dict(orientation="h", yanchor="bottom", y=-0.1))
|
| 518 |
+
return fig
|
| 519 |
+
|
| 520 |
+
def get_recent_transactions():
|
| 521 |
+
data = [["2025-11-24", "Rebalance", "MULTIPLE", "N/A"], ["2025-11-24", "SELL", "SPY", "$4,500"], ["2025-11-24", "BUY", "TLT", "$4,200"], ["2025-11-21", "BUY", "BTC-USD", "$1,000"]]
|
| 522 |
+
return pd.DataFrame(data, columns=["Date", "Type", "Asset", "Approx. Value"])
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
# =========================================
|
| 526 |
+
# Gradio Interface
|
| 527 |
+
# =========================================
|
| 528 |
+
|
| 529 |
+
custom_css = """
|
| 530 |
+
.metric-box { background-color: #1f2937; padding: 20px; border-radius: 12px; border: 1px solid #374151; text-align: center; }
|
| 531 |
+
.metric-label { font-size: 1.1em; color: #9ca3af; margin-bottom: 5px; }
|
| 532 |
+
.metric-value { font-size: 2.2em; font-weight: 700; color: #e5e7eb; }
|
| 533 |
+
.disclaimer-box { background-color: #374151; padding: 15px; border-radius: 8px; border-left: 4px solid #f59e0b; color: #d1d5db; font-size: 0.9em; margin-bottom: 20px; }
|
| 534 |
+
"""
|
| 535 |
+
|
| 536 |
+
theme = gr.themes.Soft(primary_hue="emerald", secondary_hue="slate", neutral_hue="zinc").set(
|
| 537 |
+
body_background_fill="#111827", block_background_fill="#1f2937", block_border_width="1px", block_border_color="#374151"
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
with gr.Blocks(theme=theme, css=custom_css, title="Deep RL Portfolio Manager") as demo:
|
| 541 |
+
gr.HTML("""<script>function forceDark(){document.body.classList.add('dark');} forceDark(); setTimeout(forceDark, 500);</script>""")
|
| 542 |
+
|
| 543 |
+
gr.Markdown("# 🧠 Deep RL & LLM Portfolio Manager")
|
| 544 |
+
|
| 545 |
+
with gr.Tabs():
|
| 546 |
+
# ================= TAB 1: DASHBOARD (RESTORED) =================
|
| 547 |
+
with gr.TabItem("📊 Live Dashboard"):
|
| 548 |
+
# Metrics Row
|
| 549 |
+
with gr.Row():
|
| 550 |
+
# Fetch the dashboard metric values inside the tab so they render with the layout
|
| 551 |
+
nw_val, dc_val = get_dashboard_metrics()
|
| 552 |
+
with gr.Column(elem_classes=["metric-box"]):
|
| 553 |
+
gr.HTML(f"<div class='metric-label'>Current Net Worth</div><div class='metric-value'>{nw_val}</div>")
|
| 554 |
+
with gr.Column(elem_classes=["metric-box"]):
|
| 555 |
+
gr.HTML(f"<div class='metric-label'>24h Change</div><div class='metric-value' style='color: #10b981;'>{daily_change}</div>")
|
| 556 |
+
|
| 557 |
+
# Main Chart row
|
| 558 |
+
with gr.Row():
|
| 559 |
+
with gr.Column(scale=3):
|
| 560 |
+
history_chart = gr.Plot(value=get_portfolio_history_plot(), label="Net Worth History")
|
| 561 |
+
|
| 562 |
+
# Bottom Row: Allocations and Transactions
|
| 563 |
+
with gr.Row():
|
| 564 |
+
with gr.Column(scale=1):
|
| 565 |
+
allocation_chart = gr.Plot(value=get_current_allocation_plot(), label="Current Allocation")
|
| 566 |
+
with gr.Column(scale=2):
|
| 567 |
+
gr.Markdown("### Recent Transactions")
|
| 568 |
+
transactions_table = gr.Dataframe(value=get_recent_transactions(), interactive=False, wrap=True)
|
| 569 |
+
|
| 570 |
+
# ================= TAB 2: FORECAST (UPDATED with XAI) =================
|
| 571 |
+
with gr.TabItem("🔮 Forecast & AI Analysis"):
|
| 572 |
+
gr.Markdown("### Generate Tomorrow's Portfolio Strategy")
|
| 573 |
+
run_btn = gr.Button("🚀 Run Overnight Analysis", variant="primary", size="lg")
|
| 574 |
+
status_output = gr.Textbox(label="System Status", placeholder="Ready...", interactive=False, lines=1)
|
| 575 |
+
gr.Markdown("---")
|
| 576 |
+
|
| 577 |
+
with gr.Row():
|
| 578 |
+
# Left Column: Allocations & XAI Plot
|
| 579 |
+
with gr.Column(scale=2):
|
| 580 |
+
gr.Markdown("### 📈 Suggested Position")
|
| 581 |
+
allocation_output = gr.Dataframe(headers=["Asset", "Allocation"], datatype=["str", "str"], interactive=False)
|
| 582 |
+
|
| 583 |
+
# NEW: XAI Feature Importance Plot
|
| 584 |
+
gr.Markdown("### 🧠 Why did the agent choose this?")
|
| 585 |
+
xai_output_plot = gr.Plot(label="Top Influential Factors (XAI)", show_label=False)
|
| 586 |
+
|
| 587 |
+
# Right Column: AI Analysis Report
|
| 588 |
+
with gr.Column(scale=3):
|
| 589 |
+
analysis_report_html = gr.HTML(label="AI Risk Analysis Report")
|
| 590 |
+
|
| 591 |
+
# Updated click event with new XAI output
|
| 592 |
+
run_btn.click(
|
| 593 |
+
fn=predict_and_analyze,
|
| 594 |
+
inputs=None,
|
| 595 |
+
outputs=[status_output, allocation_output, xai_output_plot, analysis_report_html]
|
| 596 |
+
)
|
| 597 |
+
|
| 598 |
+
# ================= TAB 3: HISTORICAL DATA ANALYST =================
|
| 599 |
+
with gr.TabItem("📅 Historical Data Analyst"):
|
| 600 |
+
gr.Markdown("### Analyze Past Market Performance with AI")
|
| 601 |
+
|
| 602 |
+
with gr.Row():
|
| 603 |
+
with gr.Column(scale=1):
|
| 604 |
+
all_tickers_hist = ASSETS + list(FRED_IDS.values())
|
| 605 |
+
if DASHBOARD_DATA_DF is not None:
|
| 606 |
+
available_tickers_hist = [t for t in all_tickers_hist if t in DASHBOARD_DATA_DF.columns]
|
| 607 |
+
else:
|
| 608 |
+
available_tickers_hist = []
|
| 609 |
+
default_tickers_hist = available_tickers_hist[:3] if available_tickers_hist else []
|
| 610 |
+
|
| 611 |
+
asset_selector = gr.Dropdown(choices=available_tickers_hist, value=default_tickers_hist, multiselect=True, label="1. Select Assets")
|
| 612 |
+
period_selector = gr.Dropdown(choices=list(TIME_PERIODS.keys()), value="1 Year", label="2. Select Period")
|
| 613 |
+
analyze_btn = gr.Button("🔎 Run Analysis", variant="primary")
|
| 614 |
+
|
| 615 |
+
with gr.Column(scale=3):
|
| 616 |
+
historical_plot = gr.Plot(label="Performance Plot")
|
| 617 |
+
|
| 618 |
+
gr.Markdown("---")
|
| 619 |
+
historical_analysis_md = gr.Markdown("### 🤖 AI Analyst Report\n\n*Click 'Run Analysis' to generate.*")
|
| 620 |
+
|
| 621 |
+
analyze_btn.click(
|
| 622 |
+
fn=run_historical_analysis,
|
| 623 |
+
inputs=[asset_selector, period_selector],
|
| 624 |
+
outputs=[historical_plot, historical_analysis_md]
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
# ================= TAB 4: HISTORICAL SIMULATION (UPDATED with Pro Metrics) =================
|
| 628 |
+
with gr.TabItem("🔙 Historical Simulation"):
|
| 629 |
+
gr.Markdown("### Backtest the RL Agent against Baselines")
|
| 630 |
+
|
| 631 |
+
# Disclaimer Box
|
| 632 |
+
gr.HTML(f"""
|
| 633 |
+
<div class='disclaimer-box'>
|
| 634 |
+
<strong>⚠️ IMPORTANT DISCLAIMER:</strong> The RL model was trained on data from approximately
|
| 635 |
+
<strong>{TRAIN_START_DATE} to {TRAIN_END_DATE}</strong>. Running simulations outside or overlapping significantly
|
| 636 |
+
with this period may not accurately reflect real-world performance (lookahead bias or out-of-distribution data).
|
| 637 |
+
Use for educational purposes only.
|
| 638 |
+
</div>
|
| 639 |
+
""")
|
| 640 |
+
|
| 641 |
+
with gr.Row():
|
| 642 |
+
with gr.Column(scale=1):
|
| 643 |
+
start_date_input = gr.Textbox(label="Start Date (YYYY-MM-DD)", value=(datetime.now() - timedelta(days=365)).strftime('%Y-%m-%d'))
|
| 644 |
+
end_date_input = gr.Textbox(label="End Date (YYYY-MM-DD)", value=(datetime.now() - timedelta(days=1)).strftime('%Y-%m-%d'))
|
| 645 |
+
sim_btn = gr.Button("▶️ Run Simulation", variant="primary")
|
| 646 |
+
sim_status = gr.Textbox(label="Status", interactive=False, lines=1)
|
| 647 |
+
|
| 648 |
+
with gr.Column(scale=3):
|
| 649 |
+
sim_plot = gr.Plot(label="Simulation Performance")
|
| 650 |
+
|
| 651 |
+
gr.Markdown("---")
|
| 652 |
+
# Updated to Markdown component for better table formatting
|
| 653 |
+
sim_metrics_md = gr.Markdown("### 📊 Professional Performance Metrics\n\n*Run simulation to see metrics.*")
|
| 654 |
+
|
| 655 |
+
sim_btn.click(
|
| 656 |
+
fn=run_historical_simulation,
|
| 657 |
+
inputs=[start_date_input, end_date_input],
|
| 658 |
+
outputs=[sim_plot, sim_status, sim_metrics_md]
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
if __name__ == "__main__":
|
| 662 |
+
demo.queue().launch(server_name="0.0.0.0", server_port=7860, debug=True, share=True)
|
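For reference, the forecast tab converts the agent's raw action into portfolio weights with a plain softmax over five asset logits plus one cash logit. A minimal, self-contained sketch of that step (the action values below are made up for illustration):

```python
import numpy as np

# Order matches ASSETS in scripts/fetch_market_data.py, plus a trailing cash slot.
ASSETS = ['AAPL', 'MSFT', 'SPY', 'TLT', 'BTC-USD']

# Hypothetical raw output of model.predict(): 5 asset logits + 1 cash logit.
action = np.array([0.3, -0.1, 0.8, -0.5, 0.2, 0.0])

exp_action = np.exp(action)
weights = exp_action / exp_action.sum()   # strictly positive and sums to 1

for name, w in zip(ASSETS + ['Cash'], weights):
    print(f"{name}: {w:.2%}")
```

The same transformation is applied inside `PortfolioEnv.step`, so the app and the training environment interpret actions identically.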
scripts/custom_policy.py
ADDED
|
@@ -0,0 +1,80 @@
|
| 1 |
+
# custom_policy.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from gymnasium import spaces
|
| 6 |
+
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
| 7 |
+
|
| 8 |
+
class TransformerFeatureExtractor(BaseFeaturesExtractor):
|
| 9 |
+
"""
|
| 10 |
+
A custom feature extractor that uses a Transformer Encoder.
|
| 11 |
+
|
| 12 |
+
It takes a flattened observation (window_size * n_features_per_step) and processes
|
| 13 |
+
it as a sequence.
|
| 14 |
+
"""
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
observation_space: spaces.Box,
|
| 18 |
+
features_dim: int = 256, # The final output dimension
|
| 19 |
+
n_features_per_step: int = 8, # <--- CRITICAL CHANGE: Matches 5 assets + 3 macro
|
| 20 |
+
window_size: int = 30,
|
| 21 |
+
d_model: int = 64, # Transformer's internal embedding dimension
|
| 22 |
+
n_head: int = 4, # Number of attention heads
|
| 23 |
+
n_layers: int = 2, # Number of transformer encoder layers
|
| 24 |
+
dropout: float = 0.1
|
| 25 |
+
):
|
| 26 |
+
|
| 27 |
+
super().__init__(observation_space, features_dim)
|
| 28 |
+
|
| 29 |
+
self.window_size = window_size
|
| 30 |
+
self.n_features_per_step = n_features_per_step
|
| 31 |
+
|
| 32 |
+
# Input shape check
|
| 33 |
+
expected_flat_dim = window_size * n_features_per_step
|
| 34 |
+
if observation_space.shape[0] != expected_flat_dim:
|
| 35 |
+
raise ValueError(
|
| 36 |
+
f"Observation space flat dimension {observation_space.shape[0]} "
|
| 37 |
+
f"does not match expected {expected_flat_dim} "
|
| 38 |
+
f"(window_size={window_size}, n_features_per_step={n_features_per_step})."
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# 1. Input Projection:
|
| 42 |
+
self.input_projection = nn.Linear(n_features_per_step, d_model)
|
| 43 |
+
|
| 44 |
+
# 2. Positional Encoding:
|
| 45 |
+
self.positional_encoding = nn.Parameter(torch.randn(1, window_size, d_model))
|
| 46 |
+
|
| 47 |
+
# 3. Transformer Encoder:
|
| 48 |
+
encoder_layer = nn.TransformerEncoderLayer(
|
| 49 |
+
d_model=d_model,
|
| 50 |
+
nhead=n_head,
|
| 51 |
+
dropout=dropout,
|
| 52 |
+
batch_first=True
|
| 53 |
+
)
|
| 54 |
+
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
|
| 55 |
+
|
| 56 |
+
# 4. Output Layers:
|
| 57 |
+
self.flatten = nn.Flatten()
|
| 58 |
+
self.linear_out = nn.Linear(window_size * d_model, features_dim)
|
| 59 |
+
self.relu = nn.ReLU()
|
| 60 |
+
|
| 61 |
+
def forward(self, observations: torch.Tensor) -> torch.Tensor:
|
| 62 |
+
# Input shape: (batch_size, window_size * n_features_per_step)
|
| 63 |
+
|
| 64 |
+
# 1. Reshape to (batch_size, window_size, n_features_per_step)
|
| 65 |
+
x = observations.reshape(-1, self.window_size, self.n_features_per_step)
|
| 66 |
+
|
| 67 |
+
# 2. Project input features to d_model
|
| 68 |
+
x = self.input_projection(x)
|
| 69 |
+
|
| 70 |
+
# 3. Add positional encoding
|
| 71 |
+
x = x + self.positional_encoding
|
| 72 |
+
|
| 73 |
+
# 4. Pass through Transformer
|
| 74 |
+
x = self.transformer_encoder(x)
|
| 75 |
+
|
| 76 |
+
# 5. Flatten and project to final output
|
| 77 |
+
x = self.flatten(x)
|
| 78 |
+
x = self.relu(self.linear_out(x))
|
| 79 |
+
|
| 80 |
+
return x
|
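The extractor above only defines the network; a training script would normally attach it to a Stable-Baselines3 agent through `policy_kwargs`. A minimal sketch, assuming a training CSV at `data/train.csv` and an illustrative timestep budget (neither is fixed by this diff):

```python
import pandas as pd
from stable_baselines3 import TD3

from environment import PortfolioEnv
from custom_policy import TransformerFeatureExtractor

# Assumed data path; the CSV needs the 5 asset and 3 macro columns the env expects.
df = pd.read_csv("data/train.csv", index_col="Date", parse_dates=True)
env = PortfolioEnv(df, window_size=30)

policy_kwargs = dict(
    features_extractor_class=TransformerFeatureExtractor,
    features_extractor_kwargs=dict(
        features_dim=256,
        n_features_per_step=8,   # 5 assets + 3 macro indicators
        window_size=30,
        d_model=64,
        n_head=4,
        n_layers=2,
    ),
)

model = TD3("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=100_000)   # illustrative budget only
model.save("checkpoints/td3_transformer_model")
```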
scripts/environment.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import gymnasium as gym
|
| 2 |
import numpy as np
|
| 3 |
import pandas as pd
|
|
@@ -5,126 +7,89 @@ from gymnasium import spaces
|
|
| 5 |
|
| 6 |
class PortfolioEnv(gym.Env):
|
| 7 |
"""
|
| 8 |
-
A custom
|
| 9 |
-
|
| 10 |
-
This environment simulates the daily trading of multiple financial assets. The agent's
|
| 11 |
-
goal is to learn a policy for allocating capital to maximize risk-adjusted returns.
|
| 12 |
"""
|
| 13 |
metadata = {'render_modes': ['human']}
|
| 14 |
|
| 15 |
def __init__(self, df, window_size=30, initial_balance=10000, transaction_cost_pct=0.001):
|
| 16 |
-
"""
|
| 17 |
-
Initializes the portfolio management environment.
|
| 18 |
-
|
| 19 |
-
Args:
|
| 20 |
-
df (pd.DataFrame): A DataFrame containing the daily closing prices of the assets.
|
| 21 |
-
The index should be dates and columns should be asset tickers.
|
| 22 |
-
window_size (int): The number of past days of price data to include in the observation.
|
| 23 |
-
initial_balance (float): The starting capital for the portfolio.
|
| 24 |
-
transaction_cost_pct (float): The percentage cost for each trade (e.g., 0.001 for 0.1%).
|
| 25 |
-
"""
|
| 26 |
super(PortfolioEnv, self).__init__()
|
| 27 |
|
| 28 |
-
# ---
|
| 29 |
self.df = df
|
| 30 |
self.window_size = window_size
|
| 31 |
self.initial_balance = initial_balance
|
| 32 |
self.transaction_cost_pct = transaction_cost_pct
|
| 33 |
-
|
|
|
| 34 |
|
| 35 |
# --- Action Space ---
|
| 36 |
-
# The agent outputs a vector of continuous values, one for each asset plus one for cash.
|
| 37 |
-
# These raw outputs are then converted to portfolio weights via a softmax function.
|
| 38 |
-
# The space is defined from -1 to 1 for better compatibility with standard RL algorithms.
|
| 39 |
-
# Shape: (number of assets + 1 for cash)
|
| 40 |
self.action_space = spaces.Box(
|
| 41 |
low=-1, high=1, shape=(self.n_assets + 1,), dtype=np.float32
|
| 42 |
)
|
| 43 |
|
| 44 |
# --- Observation Space ---
|
| 45 |
-
#
|
| 46 |
-
# Shape: (window_size * number of assets)
|
| 47 |
self.observation_space = spaces.Box(
|
| 48 |
low=-np.inf, high=np.inf,
|
| 49 |
-
shape=(self.window_size * self.
|
| 50 |
dtype=np.float32
|
| 51 |
)
|
| 52 |
|
| 53 |
-
# --- Internal State
|
| 54 |
-
# These variables track the state of the simulation over time.
|
| 55 |
self._current_step = 0
|
| 56 |
-
self._portfolio_value = 0
|
| 57 |
-
# Weights for each asset + cash, e.g., [w_aapl, w_msft, ..., w_cash]
|
| 58 |
self._weights = np.zeros(self.n_assets + 1)
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
def reset(self, seed=None):
|
| 61 |
-
"""
|
| 62 |
-
Resets the environment to its initial state for a new episode.
|
| 63 |
-
|
| 64 |
-
Returns:
|
| 65 |
-
tuple: A tuple containing the initial observation and auxiliary info.
|
| 66 |
-
"""
|
| 67 |
super().reset(seed=seed)
|
| 68 |
-
|
| 69 |
-
# Start the simulation at the first point where a full window of data is available.
|
| 70 |
self._current_step = self.window_size
|
| 71 |
self._portfolio_value = self.initial_balance
|
| 72 |
-
|
| 73 |
-
# Initialize weights to be 100% in cash.
|
| 74 |
self._weights = np.zeros(self.n_assets + 1)
|
| 75 |
-
self._weights[-1] = 1.0
|
| 76 |
|
| 77 |
observation = self._get_obs()
|
| 78 |
info = self._get_info()
|
| 79 |
-
|
| 80 |
return observation, info
|
| 81 |
|
| 82 |
def step(self, action):
|
| 83 |
-
"""
|
| 84 |
-
Executes one time step within the environment based on the agent's action.
|
| 85 |
-
|
| 86 |
-
Args:
|
| 87 |
-
action (np.ndarray): The raw output from the agent's policy network.
|
| 88 |
-
|
| 89 |
-
Returns:
|
| 90 |
-
tuple: A tuple containing the next observation, reward, terminated flag,
|
| 91 |
-
truncated flag, and auxiliary info.
|
| 92 |
-
"""
|
| 93 |
-
# 1. Store the portfolio value before taking the action.
|
| 94 |
current_portfolio_value = self._portfolio_value
|
| 95 |
|
| 96 |
-
|
| 97 |
-
# This ensures the weights are positive and sum to 1.
|
| 98 |
-
target_weights = np.exp(action) / np.sum(np.exp(action))
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
trades =
|
| 103 |
transaction_costs = np.sum(np.abs(trades)) * self.transaction_cost_pct
|
| 104 |
|
| 105 |
-
# 4. Update the internal state: apply costs, set new weights, and advance time.
|
| 106 |
self._balance = current_portfolio_value - transaction_costs
|
| 107 |
self._weights = target_weights
|
|
|
|
| 108 |
self._current_step += 1
|
| 109 |
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
price_ratio = next_prices / current_prices
|
| 114 |
-
|
| 115 |
-
# The new value of our asset holdings.
|
| 116 |
asset_values_after_price_change = (self._weights[:-1] * self._balance) * price_ratio
|
| 117 |
-
|
| 118 |
-
# The new total portfolio value is the sum of the updated asset values plus the cash holding.
|
| 119 |
new_portfolio_value = np.sum(asset_values_after_price_change) + (self._weights[-1] * self._balance)
|
| 120 |
self._portfolio_value = new_portfolio_value
|
| 121 |
|
| 122 |
-
|
| 123 |
-
# The reward is the log return of the portfolio value, which encourages geometric growth.
|
| 124 |
-
reward = np.log(new_portfolio_value / current_portfolio_value)
|
| 125 |
|
| 126 |
-
# 7. Check for termination conditions.
|
| 127 |
-
# The episode ends if the agent goes broke or runs out of data.
|
| 128 |
terminated = bool(self._portfolio_value <= self.initial_balance * 0.5)
|
| 129 |
truncated = self._current_step >= len(self.df) - 1
|
| 130 |
|
|
@@ -135,24 +100,25 @@ class PortfolioEnv(gym.Env):
|
|
| 135 |
|
| 136 |
def _get_obs(self):
|
| 137 |
"""
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
Returns:
|
| 141 |
-
np.ndarray: A flattened 1D array of the normalized price history.
|
| 142 |
"""
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
# Normalize the
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
def _get_info(self):
|
| 153 |
-
"""
|
| 154 |
-
Returns a dictionary of auxiliary information about the current state.
|
| 155 |
-
"""
|
| 156 |
return {
|
| 157 |
'step': self._current_step,
|
| 158 |
'portfolio_value': self._portfolio_value,
|
|
@@ -160,15 +126,7 @@ class PortfolioEnv(gym.Env):
|
|
| 160 |
}
|
| 161 |
|
| 162 |
def render(self, mode='human'):
|
| 163 |
-
|
| 164 |
-
Renders the environment's state (optional).
|
| 165 |
-
"""
|
| 166 |
-
if mode == 'human':
|
| 167 |
-
info = self._get_info()
|
| 168 |
-
print(f"Step: {info['step']}, Portfolio Value: {info['portfolio_value']:.2f}")
|
| 169 |
|
| 170 |
def close(self):
|
| 171 |
-
"""
|
| 172 |
-
Cleans up the environment (optional).
|
| 173 |
-
"""
|
| 174 |
pass
|
|
|
|
| 1 |
+
# scripts/environment.py — portfolio environment with macro features (8 features per step)
|
| 2 |
+
|
| 3 |
import gymnasium as gym
|
| 4 |
import numpy as np
|
| 5 |
import pandas as pd
|
|
|
|
| 7 |
|
| 8 |
class PortfolioEnv(gym.Env):
|
| 9 |
"""
|
| 10 |
+
A custom environment for portfolio management that includes macroeconomic data.
|
|
|
|
|
|
|
|
|
|
| 11 |
"""
|
| 12 |
metadata = {'render_modes': ['human']}
|
| 13 |
|
| 14 |
def __init__(self, df, window_size=30, initial_balance=10000, transaction_cost_pct=0.001):
|
|
|
| 15 |
super(PortfolioEnv, self).__init__()
|
| 16 |
|
| 17 |
+
# --- Data Handling ---
|
| 18 |
self.df = df
|
| 19 |
self.window_size = window_size
|
| 20 |
self.initial_balance = initial_balance
|
| 21 |
self.transaction_cost_pct = transaction_cost_pct
|
| 22 |
+
|
| 23 |
+
# --- IMPORTANT: Define asset and macro columns ---
|
| 24 |
+
self.asset_columns = ['AAPL', 'BTC-USD', 'MSFT', 'SPY', 'TLT']
|
| 25 |
+
self.macro_columns = ['Federal Funds Rate', 'CPI', 'VIX']
|
| 26 |
+
|
| 27 |
+
self.n_assets = len(self.asset_columns)
|
| 28 |
+
self.n_macro_features = len(self.macro_columns)
|
| 29 |
+
|
| 30 |
+
# Total number of features per time step (assets + macro)
|
| 31 |
+
self.n_features_per_step = self.n_assets + self.n_macro_features # Should be 8
|
| 32 |
|
| 33 |
# --- Action Space ---
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
self.action_space = spaces.Box(
|
| 35 |
low=-1, high=1, shape=(self.n_assets + 1,), dtype=np.float32
|
| 36 |
)
|
| 37 |
|
| 38 |
# --- Observation Space ---
|
| 39 |
+
# Shape: (window_size * total_features) = (30 * 8) = 240
|
|
|
|
| 40 |
self.observation_space = spaces.Box(
|
| 41 |
low=-np.inf, high=np.inf,
|
| 42 |
+
shape=(self.window_size * self.n_features_per_step,),
|
| 43 |
dtype=np.float32
|
| 44 |
)
|
| 45 |
|
| 46 |
+
# --- Internal State ---
|
|
|
|
| 47 |
self._current_step = 0
|
| 48 |
+
self._portfolio_value = 0
|
|
|
|
| 49 |
self._weights = np.zeros(self.n_assets + 1)
|
| 50 |
|
| 51 |
+
# Separate dataframes for prices and macro for easier handling
|
| 52 |
+
self.price_df = self.df[self.asset_columns]
|
| 53 |
+
self.macro_df = self.df[self.macro_columns]
|
| 54 |
+
|
| 55 |
def reset(self, seed=None):
|
|
|
| 56 |
super().reset(seed=seed)
|
|
|
|
|
|
|
| 57 |
self._current_step = self.window_size
|
| 58 |
self._portfolio_value = self.initial_balance
|
| 59 |
+
|
|
|
|
| 60 |
self._weights = np.zeros(self.n_assets + 1)
|
| 61 |
+
self._weights[-1] = 1.0 # 100% in cash
|
| 62 |
|
| 63 |
observation = self._get_obs()
|
| 64 |
info = self._get_info()
|
|
|
|
| 65 |
return observation, info
|
| 66 |
|
| 67 |
def step(self, action):
|
|
|
| 68 |
current_portfolio_value = self._portfolio_value
|
| 69 |
|
| 70 |
+
target_weights = np.exp(action) / np.sum(np.exp(action)) # Softmax
|
|
|
|
|
|
|
| 71 |
|
| 72 |
+
current_asset_values = self._weights[:-1] * current_portfolio_value
|
| 73 |
+
target_asset_values = target_weights[:-1] * current_portfolio_value
|
| 74 |
+
trades = target_asset_values - current_asset_values
|
| 75 |
transaction_costs = np.sum(np.abs(trades)) * self.transaction_cost_pct
|
| 76 |
|
|
|
|
| 77 |
self._balance = current_portfolio_value - transaction_costs
|
| 78 |
self._weights = target_weights
|
| 79 |
+
|
| 80 |
self._current_step += 1
|
| 81 |
|
| 82 |
+
current_prices = self.price_df.iloc[self._current_step - 1].values
|
| 83 |
+
next_prices = self.price_df.iloc[self._current_step].values
|
| 84 |
+
|
| 85 |
+
price_ratio = next_prices / (current_prices + 1e-8) # Add epsilon for safety
|
| 86 |
+
|
|
|
|
| 87 |
asset_values_after_price_change = (self._weights[:-1] * self._balance) * price_ratio
|
|
|
|
|
|
|
| 88 |
new_portfolio_value = np.sum(asset_values_after_price_change) + (self._weights[-1] * self._balance)
|
| 89 |
self._portfolio_value = new_portfolio_value
|
| 90 |
|
| 91 |
+
reward = np.log(new_portfolio_value / (current_portfolio_value + 1e-8)) # Add epsilon
|
|
|
|
|
|
|
| 92 |
|
|
|
|
|
|
|
| 93 |
terminated = bool(self._portfolio_value <= self.initial_balance * 0.5)
|
| 94 |
truncated = self._current_step >= len(self.df) - 1
|
| 95 |
|
|
|
|
| 100 |
|
| 101 |
def _get_obs(self):
|
| 102 |
"""
|
| 103 |
+
Gets the observation for the current time step.
|
| 104 |
+
This includes a window of prices AND a window of macro data.
|
|
|
|
|
|
|
| 105 |
"""
|
| 106 |
+
price_window = self.price_df.iloc[self._current_step - self.window_size : self._current_step].values
|
| 107 |
+
macro_window = self.macro_df.iloc[self._current_step - self.window_size : self._current_step].values
|
| 108 |
+
|
| 109 |
+
# Normalize the price window (relative changes)
|
| 110 |
+
normalized_price_window = price_window / (price_window[0] + 1e-8)
|
| 111 |
+
|
| 112 |
+
# Normalize the macro window
|
| 113 |
+
normalized_macro_window = macro_window / (macro_window[0] + 1e-8)
|
| 114 |
+
|
| 115 |
+
# Combine the normalized windows
|
| 116 |
+
observation_window = np.concatenate([normalized_price_window, normalized_macro_window], axis=1)
|
| 117 |
+
|
| 118 |
+
# Flatten into a 1D vector
|
| 119 |
+
return observation_window.flatten().astype(np.float32)
|
| 120 |
|
| 121 |
def _get_info(self):
|
|
|
|
|
|
|
|
|
|
| 122 |
return {
|
| 123 |
'step': self._current_step,
|
| 124 |
'portfolio_value': self._portfolio_value,
|
|
|
|
| 126 |
}
|
| 127 |
|
| 128 |
def render(self, mode='human'):
|
| 129 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
def close(self):
|
|
|
|
|
|
|
|
|
|
| 132 |
pass
|
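A quick way to sanity-check the updated environment is a short random-action rollout. A minimal sketch, assuming a CSV whose columns match the asset and macro names hard-coded in `PortfolioEnv`:

```python
import pandas as pd
from environment import PortfolioEnv

# Assumed data path for illustration.
df = pd.read_csv("data/train.csv", index_col="Date", parse_dates=True)
env = PortfolioEnv(df, window_size=30)

obs, info = env.reset()
for _ in range(5):
    action = env.action_space.sample()   # random raw action in [-1, 1]
    obs, reward, terminated, truncated, info = env.step(action)
    print(f"step={info['step']}  value=${info['portfolio_value']:,.2f}  reward={reward:.5f}")
    if terminated or truncated:
        break
```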
scripts/evaluate.py
CHANGED
|
@@ -1,12 +1,16 @@
|
|
|
|
|
|
|
|
| 1 |
import pandas as pd
|
| 2 |
import numpy as np
|
| 3 |
import matplotlib.pyplot as plt
|
| 4 |
-
|
| 5 |
-
from
|
| 6 |
-
from
|
| 7 |
from matplotlib.ticker import FuncFormatter
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
# --- Helper Function to Run the RL Agent ---
|
| 10 |
|
| 11 |
def evaluate_agent(env, model):
|
| 12 |
"""
|
|
@@ -14,7 +18,6 @@ def evaluate_agent(env, model):
|
|
| 14 |
"""
|
| 15 |
obs, info = env.reset()
|
| 16 |
terminated, truncated = False, False
|
| 17 |
-
|
| 18 |
portfolio_values = [env.initial_balance]
|
| 19 |
|
| 20 |
while not (terminated or truncated):
|
|
@@ -22,121 +25,129 @@ def evaluate_agent(env, model):
|
|
| 22 |
obs, reward, terminated, truncated, info = env.step(action)
|
| 23 |
portfolio_values.append(info['portfolio_value'])
|
| 24 |
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
def calculate_metrics(portfolio_values, freq=252, rf=0.0):
|
| 29 |
"""
|
| 30 |
Calculates key performance metrics from a series of portfolio values.
|
| 31 |
-
freq: number of trading periods in a year (252 for daily, 52 for weekly).
|
| 32 |
-
rf: risk-free rate (default = 0 for simplicity).
|
| 33 |
"""
|
|
|
|
|
|
|
|
|
|
| 34 |
returns = portfolio_values.pct_change().dropna()
|
|
|
|
|
|
|
| 35 |
|
| 36 |
-
# Total Return
|
| 37 |
total_return = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) - 1
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
|
| 40 |
-
num_years = (len(portfolio_values) / freq)
|
| 41 |
-
cagr = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) ** (1/num_years) - 1
|
| 42 |
-
|
| 43 |
-
# Sharpe Ratio
|
| 44 |
-
sharpe_ratio = np.sqrt(freq) * (returns.mean() - rf) / returns.std()
|
| 45 |
|
| 46 |
-
# Sortino Ratio (downside risk only)
|
| 47 |
downside_returns = returns[returns < 0]
|
| 48 |
downside_std = downside_returns.std()
|
| 49 |
sortino_ratio = np.sqrt(freq) * (returns.mean() - rf) / downside_std if downside_std > 0 else np.nan
|
| 50 |
|
| 51 |
-
# Volatility (annualized std)
|
| 52 |
volatility = returns.std() * np.sqrt(freq)
|
| 53 |
|
| 54 |
-
# Max Drawdown
|
| 55 |
rolling_max = portfolio_values.cummax()
|
| 56 |
drawdown = portfolio_values / rolling_max - 1.0
|
| 57 |
max_drawdown = drawdown.min()
|
| 58 |
|
| 59 |
-
|
| 60 |
-
calmar_ratio = cagr / abs(max_drawdown / 100) if max_drawdown != 0 else np.nan
|
| 61 |
|
| 62 |
return {
|
| 63 |
-
"Total Return": f"{total_return:.2%}",
|
| 64 |
-
"
|
| 65 |
-
"
|
| 66 |
-
"Sortino Ratio": f"{sortino_ratio:.2f}",
|
| 67 |
-
"Volatility": f"{volatility:.2%}",
|
| 68 |
-
"Max Drawdown": f"{max_drawdown:.2%}",
|
| 69 |
"Calmar Ratio": f"{calmar_ratio:.2f}"
|
| 70 |
}
|
| 71 |
|
| 72 |
|
| 73 |
-
def main(test_data_path='data/
|
| 74 |
"""
|
| 75 |
-
Loads, evaluates, and plots
|
| 76 |
-
against a Buy and Hold baseline.
|
| 77 |
"""
|
| 78 |
-
#
|
| 79 |
models_to_evaluate = {
|
| 80 |
-
"
|
| 81 |
-
"
|
| 82 |
-
"TD3 Agent": (TD3, 'checkpoints/td3_portfolio_model')
|
|
|
|
| 83 |
}
|
| 84 |
|
| 85 |
-
# Load test data
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
-
# Dictionary to store results
|
| 89 |
portfolio_values = {}
|
| 90 |
metrics = {}
|
| 91 |
|
| 92 |
# --- Run Evaluations for each RL Agent---
|
| 93 |
for name, (agent_type, model_path) in models_to_evaluate.items():
|
| 94 |
print(f"--- Evaluating {name} ---")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
model = agent_type.load(model_path)
|
| 96 |
-
env = PortfolioEnv(
|
| 97 |
portfolio_values[name] = evaluate_agent(env, model)
|
| 98 |
metrics[name] = calculate_metrics(portfolio_values[name])
|
| 99 |
|
| 100 |
# --- Evaluate Buy and Hold Baseline ---
|
| 101 |
print("\n--- Evaluating Buy and Hold Baseline ---")
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
| 103 |
portfolio_values["Buy and Hold"] = bnh_values
|
| 104 |
metrics["Buy and Hold"] = calculate_metrics(bnh_values)
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
| 106 |
# --- Combine and Print Metrics ---
|
| 107 |
print("\n--- Performance Metrics ---")
|
| 108 |
metrics_df = pd.DataFrame(metrics)
|
| 109 |
-
print(metrics_df)
|
| 110 |
|
| 111 |
# --- Plotting All Strategies ---
|
| 112 |
plt.style.use('seaborn-v0_8-darkgrid')
|
| 113 |
fig, ax = plt.subplots(figsize=(14, 8))
|
| 114 |
|
| 115 |
-
# Define colors for clarity
|
| 116 |
colors = {
|
| 117 |
-
"PPO Agent": "red",
|
| 118 |
-
"SAC Agent": "green",
|
| 119 |
-
"TD3 Agent": "orange",
|
| 120 |
-
"
|
|
|
|
|
|
|
| 121 |
}
|
| 122 |
|
| 123 |
for name, values in portfolio_values.items():
|
| 124 |
-
|
|
|
|
| 125 |
|
| 126 |
ax.set_title('Agent Performance Comparison', fontsize=16)
|
| 127 |
ax.set_xlabel('Date', fontsize=12)
|
| 128 |
ax.set_ylabel('Portfolio Value ($)', fontsize=12)
|
| 129 |
ax.legend(fontsize=12)
|
| 130 |
-
|
| 131 |
formatter = FuncFormatter(lambda x, p: f'${x:,.0f}')
|
| 132 |
ax.yaxis.set_major_formatter(formatter)
|
| 133 |
|
| 134 |
plt.tight_layout()
|
| 135 |
-
|
|
|
|
|
|
|
| 136 |
plt.show()
|
| 137 |
|
| 138 |
-
# Example of how to run this main function
|
| 139 |
if __name__ == '__main__':
|
| 140 |
-
# You can specify a different test file here if needed
|
| 141 |
-
# e.g., main(test_data_path='data/stress_test_2018.csv')
|
| 142 |
main()
|
|
|
|
| 1 |
+
# scripts/evaluate.py
|
| 2 |
+
|
| 3 |
import pandas as pd
|
| 4 |
import numpy as np
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
+
import os
|
| 7 |
+
from stable_baselines3 import TD3, PPO, SAC
|
| 8 |
+
from gymnasium import spaces
|
| 9 |
from matplotlib.ticker import FuncFormatter
|
| 10 |
+
from environment import PortfolioEnv
|
| 11 |
+
from evaluate_baselines import buy_and_hold, equally_weighted_rebalanced
|
| 12 |
+
from custom_policy import TransformerFeatureExtractor
|
| 13 |
|
|
|
|
| 14 |
|
| 15 |
def evaluate_agent(env, model):
|
| 16 |
"""
|
|
|
|
| 18 |
"""
|
| 19 |
obs, info = env.reset()
|
| 20 |
terminated, truncated = False, False
|
|
|
|
| 21 |
portfolio_values = [env.initial_balance]
|
| 22 |
|
| 23 |
while not (terminated or truncated):
|
|
|
|
| 25 |
obs, reward, terminated, truncated, info = env.step(action)
|
| 26 |
portfolio_values.append(info['portfolio_value'])
|
| 27 |
|
| 28 |
+
# Align index with the actual steps taken
|
| 29 |
+
# The first obs is at window_size, so index should start one step before
|
| 30 |
+
valid_dates = env.df.index[env.window_size-1:]
|
| 31 |
+
return pd.Series(portfolio_values, index=valid_dates[:len(portfolio_values)])
|
| 32 |
|
| 33 |
|
| 34 |
def calculate_metrics(portfolio_values, freq=252, rf=0.0):
|
| 35 |
"""
|
| 36 |
Calculates key performance metrics from a series of portfolio values.
|
|
|
|
|
|
|
| 37 |
"""
|
| 38 |
+
if len(portfolio_values) < 2:
|
| 39 |
+
return { "Total Return": "N/A", "CAGR": "N/A", "Sharpe Ratio": "N/A", "Max Drawdown": "N/A" }
|
| 40 |
+
|
| 41 |
returns = portfolio_values.pct_change().dropna()
|
| 42 |
+
if returns.empty:
|
| 43 |
+
return { "Total Return": "0.00%", "CAGR": "0.00%", "Sharpe Ratio": "0.00", "Max Drawdown": "0.00%" }
|
| 44 |
|
|
|
|
| 45 |
total_return = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) - 1
|
| 46 |
+
num_years = (len(portfolio_values) - 1) / freq
|
| 47 |
+
cagr = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) ** (1/num_years) - 1 if num_years > 0 else 0.0
|
| 48 |
|
| 49 |
+
sharpe_ratio = np.sqrt(freq) * (returns.mean() - rf) / returns.std() if returns.std() > 0 else np.nan
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
|
|
|
| 51 |
downside_returns = returns[returns < 0]
|
| 52 |
downside_std = downside_returns.std()
|
| 53 |
sortino_ratio = np.sqrt(freq) * (returns.mean() - rf) / downside_std if downside_std > 0 else np.nan
|
| 54 |
|
|
|
|
| 55 |
volatility = returns.std() * np.sqrt(freq)
|
| 56 |
|
|
|
|
| 57 |
rolling_max = portfolio_values.cummax()
|
| 58 |
drawdown = portfolio_values / rolling_max - 1.0
|
| 59 |
max_drawdown = drawdown.min()
|
| 60 |
|
| 61 |
+
calmar_ratio = cagr / abs(max_drawdown) if max_drawdown != 0 and cagr != 0 else np.nan
|
|
|
|
| 62 |
|
| 63 |
return {
|
| 64 |
+
"Total Return": f"{total_return:.2%}", "CAGR": f"{cagr:.2%}",
|
| 65 |
+
"Sharpe Ratio": f"{sharpe_ratio:.2f}", "Sortino Ratio": f"{sortino_ratio:.2f}",
|
| 66 |
+
"Volatility": f"{volatility:.2%}", "Max Drawdown": f"{max_drawdown:.2%}",
|
|
|
|
|
|
|
|
|
|
| 67 |
"Calmar Ratio": f"{calmar_ratio:.2f}"
|
| 68 |
}
|
| 69 |
|
| 70 |
|
| 71 |
+
def main(test_data_path='data/eval.csv'):
|
| 72 |
"""
|
| 73 |
+
Loads, evaluates, and plots all agent performances against baselines.
|
|
|
|
| 74 |
"""
|
| 75 |
+
# Define Model Paths and Agent Types
|
| 76 |
models_to_evaluate = {
|
| 77 |
+
"SAC Agent Default (MLP)": (SAC, 'checkpoints/sac_portfolio_model.zip'),
|
| 78 |
+
"PPO Agent (MLP)": (PPO, 'checkpoints/ppo_portfolio_model.zip'),
|
| 79 |
+
"TD3 Agent (MLP)": (TD3, 'checkpoints/td3_portfolio_model.zip'),
|
| 80 |
+
"TD3 Agent (Transformer)": (TD3, 'checkpoints/td3_transformer_model.zip')
|
| 81 |
}
|
| 82 |
|
| 83 |
+
# Load test data (this contains ALL columns - assets + macro)
|
| 84 |
+
full_eval_df = pd.read_csv(test_data_path, index_col='Date', parse_dates=True)
|
| 85 |
+
|
| 86 |
+
# Define your actual tradable asset columns
|
| 87 |
+
asset_columns = ['AAPL', 'BTC-USD', 'MSFT', 'SPY', 'TLT']
|
| 88 |
|
|
|
|
| 89 |
portfolio_values = {}
|
| 90 |
metrics = {}
|
| 91 |
|
| 92 |
# --- Run Evaluations for each RL Agent---
|
| 93 |
for name, (agent_type, model_path) in models_to_evaluate.items():
|
| 94 |
print(f"--- Evaluating {name} ---")
|
| 95 |
+
if not os.path.exists(model_path):
|
| 96 |
+
print(f"⚠️ Warning: Model file not found at {model_path}. Skipping.")
|
| 97 |
+
continue
|
| 98 |
+
|
| 99 |
model = agent_type.load(model_path)
|
| 100 |
+
env = PortfolioEnv(full_eval_df) # Pass the full DataFrame to the RL env
|
| 101 |
portfolio_values[name] = evaluate_agent(env, model)
|
| 102 |
metrics[name] = calculate_metrics(portfolio_values[name])
|
| 103 |
|
| 104 |
# --- Evaluate Buy and Hold Baseline ---
|
| 105 |
print("\n--- Evaluating Buy and Hold Baseline ---")
|
| 106 |
+
|
| 107 |
+
bnh_values = buy_and_hold(full_eval_df[asset_columns])
|
| 108 |
+
ewp_values = equally_weighted_rebalanced(full_eval_df[asset_columns])
|
| 109 |
+
|
| 110 |
portfolio_values["Buy and Hold"] = bnh_values
|
| 111 |
metrics["Buy and Hold"] = calculate_metrics(bnh_values)
|
| 112 |
+
|
| 113 |
+
portfolio_values["Equally Weighted"] = ewp_values
|
| 114 |
+
metrics["Equally Weighted"] = calculate_metrics(ewp_values)
|
| 115 |
+
|
| 116 |
# --- Combine and Print Metrics ---
|
| 117 |
print("\n--- Performance Metrics ---")
|
| 118 |
metrics_df = pd.DataFrame(metrics)
|
| 119 |
+
print(metrics_df.to_markdown(numalign="left", stralign="left"))
|
| 120 |
|
| 121 |
# --- Plotting All Strategies ---
|
| 122 |
plt.style.use('seaborn-v0_8-darkgrid')
|
| 123 |
fig, ax = plt.subplots(figsize=(14, 8))
|
| 124 |
|
|
|
|
| 125 |
colors = {
|
| 126 |
+
"PPO Agent (MLP)": "red",
|
| 127 |
+
"SAC Agent Default (MLP)": "green",
|
| 128 |
+
"TD3 Agent (MLP)": "orange",
|
| 129 |
+
"TD3 Agent (Transformer)": "cyan",
|
| 130 |
+
"Buy and Hold": "blue",
|
| 131 |
+
"Equally Weighted": "purple"
|
| 132 |
}
|
| 133 |
|
| 134 |
for name, values in portfolio_values.items():
|
| 135 |
+
|
| 136 |
+
ax.plot(values.index, values, label=name, color=colors.get(name, 'gray'), linewidth=2)
|
| 137 |
|
| 138 |
ax.set_title('Agent Performance Comparison', fontsize=16)
|
| 139 |
ax.set_xlabel('Date', fontsize=12)
|
| 140 |
ax.set_ylabel('Portfolio Value ($)', fontsize=12)
|
| 141 |
ax.legend(fontsize=12)
|
| 142 |
+
|
| 143 |
formatter = FuncFormatter(lambda x, p: f'${x:,.0f}')
|
| 144 |
ax.yaxis.set_major_formatter(formatter)
|
| 145 |
|
| 146 |
plt.tight_layout()
|
| 147 |
+
results_dir = 'results'
|
| 148 |
+
os.makedirs(results_dir, exist_ok=True)
|
| 149 |
+
plt.savefig(os.path.join(results_dir, 'final_performance_comparison_all_agents.png'))
|
| 150 |
plt.show()
|
| 151 |
|
|
|
|
| 152 |
if __name__ == '__main__':
|
|
|
|
|
|
|
| 153 |
main()
|
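To see the shape of the dictionary `calculate_metrics` produces, it can be fed a synthetic equity curve; the random returns below are purely illustrative:

```python
import numpy as np
import pandas as pd

from evaluate import calculate_metrics

rng = np.random.default_rng(42)
dates = pd.date_range("2021-01-01", periods=252, freq="B")
daily_returns = rng.normal(0.0005, 0.01, size=len(dates))
equity_curve = pd.Series(10_000 * np.cumprod(1 + daily_returns), index=dates)

print(calculate_metrics(equity_curve))
# -> {'Total Return': '...', 'CAGR': '...', 'Sharpe Ratio': '...', 'Sortino Ratio': '...',
#     'Volatility': '...', 'Max Drawdown': '...', 'Calmar Ratio': '...'}
```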
scripts/evaluate_baselines.py
CHANGED
|
@@ -1,46 +1,46 @@
|
|
| 1 |
-
# evaluate_baselines.py
|
| 2 |
-
|
| 3 |
import pandas as pd
|
| 4 |
import numpy as np
|
| 5 |
import matplotlib.pyplot as plt
|
|
|
|
| 6 |
|
| 7 |
-
def buy_and_hold(
|
| 8 |
"""
|
| 9 |
Simulates the Buy and Hold strategy.
|
| 10 |
|
| 11 |
Args:
|
| 12 |
-
|
| 13 |
initial_balance (int): The starting capital.
|
| 14 |
|
| 15 |
Returns:
|
| 16 |
pd.Series: A Series containing the portfolio value for each day.
|
| 17 |
"""
|
| 18 |
print("--- Simulating Buy and Hold ---")
|
| 19 |
-
n_assets = len(
|
| 20 |
|
| 21 |
# Invest an equal amount in each asset at the beginning
|
| 22 |
initial_investment_per_asset = initial_balance / n_assets
|
| 23 |
|
| 24 |
# Get the initial prices
|
| 25 |
-
initial_prices =
|
| 26 |
|
| 27 |
# Calculate the number of shares bought for each asset
|
| 28 |
-
|
|
|
|
| 29 |
|
| 30 |
# Calculate the portfolio value for each day
|
| 31 |
-
portfolio_values =
|
| 32 |
|
| 33 |
print(f"Initial Investment: ${initial_balance:.2f}")
|
| 34 |
-
print(f"Final Portfolio Value: ${portfolio_values.iloc[-1]:.2f}")
|
| 35 |
|
| 36 |
return portfolio_values
|
| 37 |
|
| 38 |
-
def equally_weighted_rebalanced(
|
| 39 |
"""
|
| 40 |
Simulates an Equally Weighted Portfolio with periodic rebalancing.
|
| 41 |
|
| 42 |
Args:
|
| 43 |
-
|
| 44 |
initial_balance (int): The starting capital.
|
| 45 |
rebalance_freq (str): The rebalancing frequency ('M' for monthly, 'Q' for quarterly).
|
| 46 |
transaction_cost_pct (float): The transaction cost as a percentage.
|
|
@@ -49,24 +49,30 @@ def equally_weighted_rebalanced(df, initial_balance=10000, rebalance_freq='M', t
|
|
| 49 |
pd.Series: A Series containing the portfolio value for each day.
|
| 50 |
"""
|
| 51 |
print(f"\n--- Simulating Equally Weighted Portfolio (Rebalanced {rebalance_freq}) ---")
|
| 52 |
-
n_assets = len(
|
| 53 |
|
| 54 |
# Set the initial weights to be equal
|
| 55 |
weights = np.full(n_assets, 1/n_assets)
|
| 56 |
|
| 57 |
portfolio_value = initial_balance
|
| 58 |
-
portfolio_values = pd.Series(index=
|
| 59 |
|
| 60 |
last_rebalance_date = None
|
| 61 |
|
| 62 |
-
for date, prices in
|
| 63 |
# Store the portfolio value for the day before any changes
|
| 64 |
portfolio_values[date] = portfolio_value
|
| 65 |
|
| 66 |
# Determine if it's a rebalancing day
|
| 67 |
-
# Rebalance on the first day of the new period (month, quarter)
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
# Calculate the value of trades to rebalance
|
| 71 |
target_asset_values = portfolio_value * (1/n_assets)
|
| 72 |
current_asset_values = weights * portfolio_value
|
|
@@ -82,32 +88,43 @@ def equally_weighted_rebalanced(df, initial_balance=10000, rebalance_freq='M', t
|
|
| 82 |
|
| 83 |
# Calculate portfolio value for the *next* day before the market opens
|
| 84 |
# Get price changes from today to the next trading day
|
| 85 |
-
today_prices =
|
| 86 |
-
next_day_index =
|
| 87 |
-
if next_day_index < len(
|
| 88 |
-
next_day_prices =
|
| 89 |
-
|
|
|
|
|
|
|
| 90 |
|
| 91 |
# Update portfolio value based on price changes
|
| 92 |
portfolio_value = np.sum( (weights * portfolio_value) * price_change_ratio )
|
| 93 |
|
| 94 |
# Update weights due to market drift
|
| 95 |
new_asset_values = (weights * portfolio_value) * price_change_ratio
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
print(f"Initial Investment: ${initial_balance:.2f}")
|
| 99 |
-
print(f"Final Portfolio Value: ${portfolio_values.iloc[-1]:.2f}")
|
| 100 |
|
| 101 |
return portfolio_values.dropna()
|
| 102 |
|
| 103 |
|
| 104 |
def main():
|
| 105 |
-
# Load the
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
# --- Run Baseline Strategies ---
|
| 109 |
-
bnh_values =
|
| 110 |
-
ewp_values = equally_weighted_rebalanced(
|
| 111 |
|
| 112 |
# --- Plot the results ---
|
| 113 |
plt.style.use('seaborn-v0_8-darkgrid')
|
|
@@ -127,7 +144,11 @@ def main():
|
|
| 127 |
ax.yaxis.set_major_formatter(formatter)
|
| 128 |
|
| 129 |
plt.tight_layout()
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
plt.show()
|
| 132 |
|
| 133 |
if __name__ == '__main__':
|
|
|
|
|
|
|
|
|
|
| 1 |
import pandas as pd
|
| 2 |
import numpy as np
|
| 3 |
import matplotlib.pyplot as plt
|
| 4 |
+
import os # Import os for directory creation
|
| 5 |
|
| 6 |
+
def buy_and_hold(df_assets, initial_balance=10000): # Renamed df to df_assets for clarity
|
| 7 |
"""
|
| 8 |
Simulates the Buy and Hold strategy.
|
| 9 |
|
| 10 |
Args:
|
| 11 |
+
df_assets (pd.DataFrame): DataFrame with daily tradable asset prices ONLY.
|
| 12 |
initial_balance (int): The starting capital.
|
| 13 |
|
| 14 |
Returns:
|
| 15 |
pd.Series: A Series containing the portfolio value for each day.
|
| 16 |
"""
|
| 17 |
print("--- Simulating Buy and Hold ---")
|
| 18 |
+
n_assets = len(df_assets.columns)
|
| 19 |
|
| 20 |
# Invest an equal amount in each asset at the beginning
|
| 21 |
initial_investment_per_asset = initial_balance / n_assets
|
| 22 |
|
| 23 |
# Get the initial prices
|
| 24 |
+
initial_prices = df_assets.iloc[0]
|
| 25 |
|
| 26 |
# Calculate the number of shares bought for each asset
|
| 27 |
+
# Handle potential division by zero if an asset price is 0 (though unlikely with real data)
|
| 28 |
+
shares = initial_investment_per_asset / (initial_prices + 1e-8)
|
| 29 |
|
| 30 |
# Calculate the portfolio value for each day
|
| 31 |
+
portfolio_values = df_assets.dot(shares)
|
| 32 |
|
| 33 |
print(f"Initial Investment: ${initial_balance:.2f}")
|
| 34 |
+
print(f"Final Portfolio Value (Buy and Hold): ${portfolio_values.iloc[-1]:.2f}")
|
| 35 |
|
| 36 |
return portfolio_values
|
| 37 |
|
| 38 |
+
def equally_weighted_rebalanced(df_assets, initial_balance=10000, rebalance_freq='M', transaction_cost_pct=0.001): # Renamed df to df_assets
|
| 39 |
"""
|
| 40 |
Simulates an Equally Weighted Portfolio with periodic rebalancing.
|
| 41 |
|
| 42 |
Args:
|
| 43 |
+
df_assets (pd.DataFrame): DataFrame with daily tradable asset prices ONLY.
|
| 44 |
initial_balance (int): The starting capital.
|
| 45 |
rebalance_freq (str): The rebalancing frequency ('M' for monthly, 'Q' for quarterly).
|
| 46 |
transaction_cost_pct (float): The transaction cost as a percentage.
|
|
|
|
| 49 |
pd.Series: A Series containing the portfolio value for each day.
|
| 50 |
"""
|
| 51 |
print(f"\n--- Simulating Equally Weighted Portfolio (Rebalanced {rebalance_freq}) ---")
|
| 52 |
+
n_assets = len(df_assets.columns)
|
| 53 |
|
| 54 |
# Set the initial weights to be equal
|
| 55 |
weights = np.full(n_assets, 1/n_assets)
|
| 56 |
|
| 57 |
portfolio_value = initial_balance
|
| 58 |
+
portfolio_values = pd.Series(index=df_assets.index, dtype=float) # Explicitly set dtype
|
| 59 |
|
| 60 |
last_rebalance_date = None
|
| 61 |
|
| 62 |
+
for i, (date, prices) in enumerate(df_assets.iterrows()):
|
| 63 |
# Store the portfolio value for the day before any changes
|
| 64 |
portfolio_values[date] = portfolio_value
|
| 65 |
|
| 66 |
# Determine if it's a rebalancing day
|
| 67 |
+
# Rebalance on the first day of the new period (month, quarter) or if it's the very first day
|
| 68 |
+
rebalance_this_day = False
|
| 69 |
+
if i == 0: # Rebalance on the very first day
|
| 70 |
+
rebalance_this_day = True
|
| 71 |
+
elif rebalance_freq == 'M' and date.month != df_assets.index[i-1].month:
|
| 72 |
+
rebalance_this_day = True
|
| 73 |
+
# Add 'Q' for quarterly if needed, similar logic
|
| 74 |
+
|
| 75 |
+
if rebalance_this_day:
|
| 76 |
# Calculate the value of trades to rebalance
|
| 77 |
target_asset_values = portfolio_value * (1/n_assets)
|
| 78 |
current_asset_values = weights * portfolio_value
|
|
|
|
| 88 |
|
| 89 |
# Calculate portfolio value for the *next* day before the market opens
|
| 90 |
# Get price changes from today to the next trading day
|
| 91 |
+
today_prices = prices # Already have prices for the current date
|
| 92 |
+
next_day_index = df_assets.index.get_loc(date) + 1
|
| 93 |
+
if next_day_index < len(df_assets):
|
| 94 |
+
next_day_prices = df_assets.iloc[next_day_index]
|
| 95 |
+
|
| 96 |
+
# Avoid division by zero
|
| 97 |
+
price_change_ratio = next_day_prices / (today_prices + 1e-8)
|
| 98 |
|
| 99 |
# Update portfolio value based on price changes
|
| 100 |
portfolio_value = np.sum( (weights * portfolio_value) * price_change_ratio )
|
| 101 |
|
| 102 |
# Update weights due to market drift
|
| 103 |
new_asset_values = (weights * portfolio_value) * price_change_ratio
|
| 104 |
+
# Avoid division by zero for total portfolio value
|
| 105 |
+
if np.sum(new_asset_values) > 1e-8: # Check if total value is effectively non-zero
|
| 106 |
+
weights = new_asset_values / np.sum(new_asset_values)
|
| 107 |
+
else:
|
| 108 |
+
weights = np.full(n_assets, 1/n_assets) # Default to equal or handle as error
|
| 109 |
+
|
| 110 |
|
| 111 |
print(f"Initial Investment: ${initial_balance:.2f}")
|
| 112 |
+
print(f"Final Portfolio Value (Equally Weighted): ${portfolio_values.iloc[-1]:.2f}")
|
| 113 |
|
| 114 |
return portfolio_values.dropna()
|
| 115 |
|
| 116 |
|
| 117 |
def main():
|
| 118 |
+
# Load the evaluation data (which contains both assets and macro data)
|
| 119 |
+
full_eval_df = pd.read_csv('data/eval.csv', index_col='Date', parse_dates=True)
|
| 120 |
+
|
| 121 |
+
# --- IMPORTANT: Filter ONLY asset columns for baselines ---
|
| 122 |
+
asset_columns = ['AAPL', 'BTC-USD', 'MSFT', 'SPY', 'TLT'] # Define your actual tradable assets
|
| 123 |
+
test_df_assets_only = full_eval_df[asset_columns]
|
| 124 |
|
| 125 |
# --- Run Baseline Strategies ---
|
| 126 |
+
bnh_values = buy_and_hold(test_df_assets_only)
|
| 127 |
+
ewp_values = equally_weighted_rebalanced(test_df_assets_only)
|
| 128 |
|
| 129 |
# --- Plot the results ---
|
| 130 |
plt.style.use('seaborn-v0_8-darkgrid')
|
|
|
|
| 144 |
ax.yaxis.set_major_formatter(formatter)
|
| 145 |
|
| 146 |
plt.tight_layout()
|
| 147 |
+
|
| 148 |
+
# Ensure results directory exists for saving plot
|
| 149 |
+
results_dir = 'results'
|
| 150 |
+
os.makedirs(results_dir, exist_ok=True)
|
| 151 |
+
plt.savefig(os.path.join(results_dir, 'baseline_performance.png'))
|
| 152 |
plt.show()
|
| 153 |
|
| 154 |
if __name__ == '__main__':
|
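The buy-and-hold logic above reduces to fixing share counts on day one and taking a dot product with prices thereafter. A toy example with made-up numbers, assuming two assets and a $10,000 starting balance:

```python
import pandas as pd

prices = pd.DataFrame(
    {"AAA": [100.0, 110.0, 105.0], "BBB": [50.0, 48.0, 55.0]},
    index=pd.date_range("2021-01-01", periods=3),
)

initial_balance = 10_000
shares = (initial_balance / len(prices.columns)) / prices.iloc[0]   # 50 AAA, 100 BBB
portfolio_values = prices.dot(shares)
print(portfolio_values)   # 10000.0, 10300.0, 10750.0
```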
scripts/fetch_data.py
DELETED
|
@@ -1,75 +0,0 @@
|
|
| 1 |
-
import yfinance as yf
|
| 2 |
-
import pandas as pd
|
| 3 |
-
import os
|
| 4 |
-
|
| 5 |
-
# --- Configuration ---
|
| 6 |
-
# Asset tickers
|
| 7 |
-
TICKERS = ["AAPL", "MSFT", "SPY", "TLT", "BTC-USD"]
|
| 8 |
-
# Time periods for training and testing
|
| 9 |
-
TRAIN_START_DATE = "2015-01-01"
|
| 10 |
-
TRAIN_END_DATE = "2020-12-31"
|
| 11 |
-
TEST_START_DATE = "2021-01-01"
|
| 12 |
-
TEST_END_DATE = "2023-12-31"
|
| 13 |
-
|
| 14 |
-
# Directory to save the data
|
| 15 |
-
DATA_DIR = "data"
|
| 16 |
-
TRAIN_DATA_PATH = os.path.join(DATA_DIR, "train.csv")
|
| 17 |
-
TEST_DATA_PATH = os.path.join(DATA_DIR, "test.csv")
|
| 18 |
-
|
| 19 |
-
# --- Data Fetching and Processing ---
|
| 20 |
-
|
| 21 |
-
def fetch_and_prepare_data(start_date, end_date, tickers):
|
| 22 |
-
"""
|
| 23 |
-
Fetches historical data for the given tickers and processes it.
|
| 24 |
-
Returns a DataFrame with 'Close' prices for each ticker.
|
| 25 |
-
"""
|
| 26 |
-
print(f"Fetching data from {start_date} to {end_date} for {tickers}...")
|
| 27 |
-
|
| 28 |
-
data = yf.download(tickers, start=start_date, end=end_date)
|
| 29 |
-
|
| 30 |
-
# CHANGE: Add .copy() to explicitly create a new DataFrame and avoid warnings.
|
| 31 |
-
close_data = data['Close'].copy()
|
| 32 |
-
|
| 33 |
-
print("\nData Head:")
|
| 34 |
-
print(close_data.head())
|
| 35 |
-
|
| 36 |
-
print("\nMissing values before cleaning:")
|
| 37 |
-
print(close_data.isnull().sum())
|
| 38 |
-
|
| 39 |
-
# Now, all inplace operations are safely performed on our own copy.
|
| 40 |
-
close_data.ffill(inplace=True)
|
| 41 |
-
close_data.bfill(inplace=True)
|
| 42 |
-
|
| 43 |
-
print("\nMissing values after cleaning:")
|
| 44 |
-
print(close_data.isnull().sum())
|
| 45 |
-
|
| 46 |
-
for col in close_data.columns:
|
| 47 |
-
close_data[col] = pd.to_numeric(close_data[col], errors='coerce')
|
| 48 |
-
|
| 49 |
-
close_data.dropna(inplace=True)
|
| 50 |
-
|
| 51 |
-
return close_data
|
| 52 |
-
|
| 53 |
-
def main():
|
| 54 |
-
"""Main function to run the data fetching process."""
|
| 55 |
-
# Create data directory if it doesn't exist
|
| 56 |
-
if not os.path.exists(DATA_DIR):
|
| 57 |
-
os.makedirs(DATA_DIR)
|
| 58 |
-
print(f"Created directory: {DATA_DIR}")
|
| 59 |
-
|
| 60 |
-
# Fetch, process, and save training data
|
| 61 |
-
print("--- Preparing Training Data ---")
|
| 62 |
-
train_data = fetch_and_prepare_data(TRAIN_START_DATE, TRAIN_END_DATE, TICKERS)
|
| 63 |
-
train_data.to_csv(TRAIN_DATA_PATH)
|
| 64 |
-
print(f"Training data saved to {TRAIN_DATA_PATH}")
|
| 65 |
-
|
| 66 |
-
print("\n" + "="*50 + "\n")
|
| 67 |
-
|
| 68 |
-
# Fetch, process, and save testing data
|
| 69 |
-
print("--- Preparing Testing Data ---")
|
| 70 |
-
test_data = fetch_and_prepare_data(TEST_START_DATE, TEST_END_DATE, TICKERS)
|
| 71 |
-
test_data.to_csv(TEST_DATA_PATH)
|
| 72 |
-
print(f"Testing data saved to {TEST_DATA_PATH}")
|
| 73 |
-
|
| 74 |
-
if __name__ == "__main__":
|
| 75 |
-
main()
|
scripts/fetch_market_data.py
CHANGED
|
@@ -1,78 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import argparse
|
| 2 |
import os
|
| 3 |
-
import
|
| 4 |
-
import
|
| 5 |
-
from datetime import date
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
start_date (str): The start date for the data in 'YYYY-MM-DD' format.
|
| 13 |
-
end_date (str): The end date for the data in 'YYYY-MM-DD' format.
|
| 14 |
-
output_filename (str): The path and name of the file to save the data.
|
| 15 |
"""
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
# Data Cleaning
|
| 31 |
-
print(f"\nMissing values before cleaning:\n{close_data.isnull().sum()}")
|
| 32 |
-
close_data.ffill(inplace=True)
|
| 33 |
-
close_data.bfill(inplace=True)
|
| 34 |
-
|
| 35 |
-
# Drop any columns that are still all NaN (like BTC in the 2008 data)
|
| 36 |
-
close_data.dropna(axis=1, how='all', inplace=True)
|
| 37 |
-
|
| 38 |
-
print(f"\nMissing values after cleaning:\n{close_data.isnull().sum()}")
|
| 39 |
|
| 40 |
-
#
|
| 41 |
-
|
| 42 |
-
if output_dir and not os.path.exists(output_dir):
|
| 43 |
-
os.makedirs(output_dir)
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
|
| 55 |
-
"--start",
|
| 56 |
-
type=str,
|
| 57 |
-
default="2018-01-01",
|
| 58 |
-
help="Start date in YYYY-MM-DD format. Default is for the 2018 stress test."
|
| 59 |
-
)
|
| 60 |
-
parser.add_argument(
|
| 61 |
-
"--end",
|
| 62 |
-
type=str,
|
| 63 |
-
default="2019-12-31",
|
| 64 |
-
help="End date in YYYY-MM-DD format. Default is for the 2018 stress test."
|
| 65 |
-
)
|
| 66 |
-
parser.add_argument(
|
| 67 |
-
"--filename",
|
| 68 |
-
type=str,
|
| 69 |
-
default="data/stress_test_2018.csv",
|
| 70 |
-
help="Output file name (e.g., 'data/my_data.csv')."
|
| 71 |
-
)
|
| 72 |
|
| 73 |
-
args = parser.parse_args()
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
-
|
|
|
|
| 1 |
+
# scripts/fetch_market_data.py
|
| 2 |
+
|
| 3 |
+
import yfinance as yf_lib
|
| 4 |
+
import pandas as pd
|
| 5 |
import argparse
|
| 6 |
import os
|
| 7 |
+
from datetime import datetime, timedelta
|
| 8 |
+
from pandas_datareader import data as pdr
|
|
|
|
| 9 |
|
| 10 |
+
# --- Module-level constants so other scripts can import them ---
|
| 11 |
+
# Define your assets (Global variable, importable)
|
| 12 |
+
ASSETS = ['AAPL', 'MSFT', 'SPY', 'TLT', 'BTC-USD']
|
| 13 |
+
|
| 14 |
+
# Define FRED IDs for macroeconomic data (Global variable, importable)
|
| 15 |
+
FRED_IDS = {
|
| 16 |
+
'DFF': 'Federal Funds Rate', # Daily Federal Funds Rate
|
| 17 |
+
'CPIAUCSL': 'CPI', # Consumer Price Index (All Urban Consumers, Seasonally Adjusted, Monthly)
|
| 18 |
+
'VIXCLS': 'VIX' # CBOE Volatility Index (VIX) from FRED
|
| 19 |
+
}
|
| 20 |
+
# ---------------------------------------
|
| 21 |
|
| 22 |
+
def fetch_market_data(start_date, end_date, filename):
|
|
|
|
|
|
|
|
|
|
| 23 |
"""
|
| 24 |
+
Fetches market data, macroeconomic indicators (including VIX from FRED),
|
| 25 |
+
for specified assets and time period, then saves it to a CSV file.
|
| 26 |
+
"""
|
| 27 |
+
# No need to re-define assets and fred_ids here.
|
| 28 |
+
# The function will use the global ASSETS and FRED_IDS defined above.
|
| 29 |
|
| 30 |
+
print(f"--- Fetching market data for {ASSETS} from {start_date} to {end_date} ---")
|
| 31 |
+
|
| 32 |
+
# 1. Fetch Asset Prices (Daily) using yf_lib.download()
|
| 33 |
+
try:
|
| 34 |
+
# Use the global ASSETS variable
|
| 35 |
+
df_prices = yf_lib.download(ASSETS, start=start_date, end=end_date)['Close']
|
| 36 |
+
df_prices.dropna(inplace=True)
|
| 37 |
+
print(f"✅ Fetched {len(ASSETS)} asset prices.")
|
| 38 |
+
except Exception as e:
|
| 39 |
+
print(f"❌ Error fetching asset prices: {e}")
|
| 40 |
+
return None # Return None on failure
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
+
# 2. Fetch Macro Data (VIX, Federal Funds Rate, CPI) from FRED using pandas_datareader
|
| 43 |
+
print("--- Fetching macroeconomic data from FRED ---")
|
|
|
|
|
|
|
| 44 |
|
| 45 |
+
try:
|
| 46 |
+
# FRED data can be tricky with exact date ranges, fetching a bit more to ensure coverage
|
| 47 |
+
fred_start_date = (datetime.strptime(start_date, '%Y-%m-%d') - timedelta(days=365)).strftime('%Y-%m-%d')
|
| 48 |
|
| 49 |
+
# Use the global FRED_IDS variable
|
| 50 |
+
df_fred = pdr.DataReader(list(FRED_IDS.keys()), 'fred', start=fred_start_date, end=end_date)
|
| 51 |
+
df_fred.rename(columns=FRED_IDS, inplace=True)
|
| 52 |
+
print("✅ Fetched Federal Funds Rate, CPI, and VIX data from FRED.")
|
| 53 |
+
except Exception as e:
|
| 54 |
+
print(f"❌ Error fetching FRED data: {e}. Check FRED API access or ticker validity.")
|
| 55 |
+
df_fred = pd.DataFrame() # Create empty dataframe if fetch fails
|
| 56 |
|
| 57 |
+
# Combine all dataframes
|
| 58 |
+
df_combined = df_prices.copy()
|
| 59 |
+
|
| 60 |
+
# Merge FRED data (now includes VIX)
|
| 61 |
+
if not df_fred.empty:
|
| 62 |
+
df_combined = df_combined.merge(df_fred, left_index=True, right_index=True, how='left')
|
| 63 |
+
|
| 64 |
+
# Handle missing macro data: forward-fill and then back-fill for initial NaNs
|
| 65 |
+
# This loop now covers all FRED columns
|
| 66 |
+
# Use the global FRED_IDS variable
|
| 67 |
+
for col_name in FRED_IDS.values():
|
| 68 |
+
if col_name in df_combined.columns:
|
| 69 |
+
df_combined[col_name] = df_combined[col_name].ffill().bfill()
|
| 70 |
+
# Drop rows if they still have NaN for macro data after fill
|
| 71 |
+
df_combined.dropna(subset=[col_name], inplace=True) # Added dropna for robustness
|
| 72 |
+
|
| 73 |
+
# Ensure all data is within the requested date range after merging/filling
|
| 74 |
+
df_combined = df_combined.loc[start_date:end_date]
|
| 75 |
+
df_combined.dropna(inplace=True) # Final dropna for any remaining NaNs
|
| 76 |
+
|
| 77 |
+
if df_combined.empty:
|
| 78 |
+
print("❌ Final combined dataframe is empty after merging and cleaning. Check date ranges and data availability.")
|
| 79 |
+
return None # Return None on failure
|
| 80 |
+
|
| 81 |
+
# Save to CSV if a filename is provided
|
| 82 |
+
if filename:
|
| 83 |
+
output_dir = os.path.dirname(filename)
|
| 84 |
+
if output_dir and not os.path.dirname(filename) == "":
|
| 85 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 86 |
+
|
| 87 |
+
df_combined.to_csv(filename, index=True)
|
| 88 |
+
print(f"\n✅ Data saved successfully to {filename}")
|
| 89 |
+
|
| 90 |
+
print(f"Final data shape: {df_combined.shape}")
|
| 91 |
+
print("Columns:", df_combined.columns.tolist())
|
| 92 |
|
| 93 |
+
return df_combined # Return the DataFrame
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
|
|
|
| 95 |
|
| 96 |
+
if __name__ == '__main__':
|
| 97 |
+
parser = argparse.ArgumentParser(description="Fetch market and macroeconomic data.")
|
| 98 |
+
parser.add_argument("--start", type=str, default="2015-01-01", help="Start date (YYYY-MM-DD).")
|
| 99 |
+
parser.add_argument("--end", type=str, default="2020-12-31", help="End date (YYYY-MM-DD).")
|
| 100 |
+
parser.add_argument("--filename", type=str, default="data/train.csv", help="Output CSV filename.")
|
| 101 |
+
|
| 102 |
+
args = parser.parse_args()
|
| 103 |
|
| 104 |
+
fetch_market_data(args.start, args.end, args.filename)
|
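
The refactor makes `fetch_market_data`, `ASSETS`, and `FRED_IDS` importable by the other scripts. A minimal sketch of calling it directly (assuming the `scripts` directory is on the Python path; the dates simply mirror the argparse defaults above):

```python
# Sketch: build the training CSV programmatically instead of via the CLI.
from fetch_market_data import fetch_market_data, ASSETS

df = fetch_market_data("2015-01-01", "2020-12-31", filename="data/train.csv")
if df is not None:
    print(ASSETS)               # tickers the pipeline expects
    print(df.columns.tolist())  # prices plus 'Federal Funds Rate', 'CPI', 'VIX'
```
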
scripts/llm_analysis_rag.py
ADDED
@@ -0,0 +1,243 @@

# scripts/llm_analysis_rag.py
import gc
import time
import os
import shutil
import torch
import pandas as pd
import numpy as np
import json
import re
from datetime import datetime, timedelta

# LangChain components
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFacePipeline
from langchain_classic.chains import RetrievalQA
from langchain_classic.prompts import PromptTemplate
from langchain_classic.docstore.document import Document
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# --- Configuration ---
# HF_EMBEDDING_MODEL is no longer used in this reduced scope
HF_GENERATION_MODEL = "Qwen/Qwen2.5-3B-Instruct"

# Global variables
llm_pipeline_hf_instance = None

# --- Helper: Robust JSON Extractor ---
def extract_clean_json(response_text):
    """Robust JSON extractor handling Python booleans and Markdown."""
    json_match = re.search(r'```json\s*(.*?)\s*```', response_text, re.DOTALL)
    if json_match:
        text_to_parse = json_match.group(1)
    else:
        start_idx = response_text.find('{')
        end_idx = response_text.rfind('}')
        if start_idx != -1 and end_idx != -1:
            text_to_parse = response_text[start_idx:end_idx+1]
        else:
            # print(f"❌ PARSE ERROR: No JSON found: {response_text[:100]}...")
            return None

    text_to_parse = text_to_parse.replace(": True", ": true").replace(": False", ": false")

    try:
        return json.loads(text_to_parse)
    except json.JSONDecodeError:
        return None

# --- Shared LLM Setup (Singleton Pattern) ---
def setup_llm_pipeline():
    global llm_pipeline_hf_instance
    if llm_pipeline_hf_instance is None:
        print(f"--- Loading Model: {HF_GENERATION_MODEL} ---")
        tokenizer = AutoTokenizer.from_pretrained(HF_GENERATION_MODEL, trust_remote_code=True)
        # 4-bit quantization config for efficient loading
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
        )
        model = AutoModelForCausalLM.from_pretrained(
            HF_GENERATION_MODEL, trust_remote_code=True,
            quantization_config=bnb_config, device_map="auto"
        )
        # Create the HF pipeline
        pipe = pipeline(
            "text-generation", model=model, tokenizer=tokenizer,
            max_new_tokens=1024,  # Increased token limit for detailed historical analysis
            do_sample=False, temperature=0.1,  # Low temp for factual responses
            return_full_text=False
        )
        llm_pipeline_hf_instance = HuggingFacePipeline(pipeline=pipe)
    return llm_pipeline_hf_instance

# =========================================
# NEW FUNCTION: Structured Historical Analysis
# =========================================
def analyze_historical_segment(df_segment, selected_assets, period_name):
    """
    Analyzes a specific segment of historical data directly without RAG.
    Takes a DataFrame slice, calculates summary stats, and prompts the LLM.
    """
    llm = setup_llm_pipeline()
    print(f"--- Running Historical Analysis for {period_name} ---")

    # 1. Create quantitative summary of the data segment for the prompt
    if df_segment.empty:
        return "No data available for this period to analyze."

    start_date = df_segment.index.min().date()
    end_date = df_segment.index.max().date()

    start_vals = df_segment.iloc[0]
    end_vals = df_segment.iloc[-1]
    # Calculate percentage change over the period, handling potential zeros
    pct_changes = ((end_vals - start_vals) / (start_vals.replace(0, np.nan)) * 100).fillna(0)

    # Build the context string
    data_summary = f"Analysis Period: {period_name} ({start_date} to {end_date})\n\n"
    data_summary += "Performance Summary over Period:\n"
    for asset in selected_assets:
        if asset in df_segment.columns:
            change = pct_changes[asset]
            direction = "gained" if change > 0 else "lost"
            data_summary += f"- {asset}: {direction} {abs(change):.2f}%\n"

    # Add volatility context (standard deviation of daily returns)
    data_summary += "\nVolatility Context (Daily Return Std Dev):\n"
    daily_rets = df_segment.pct_change()
    std_devs = daily_rets.std() * 100
    for asset in selected_assets:
        if asset in std_devs.index:
            data_summary += f"- {asset}: {std_devs[asset]:.2f}%\n"

    # 2. Create the Prompt
    # We use Qwen's chat template format (<|im_start|>...)
    prompt_template = """<|im_start|>system
You are a senior financial analyst. Your job is to analyze historical market data trends for selected assets over a specific time period.
Provide a concise, professional, and insightful summary of the performance, key trends, and comparative movements based *only* on the provided data summary.
Highlight significant gains, losses, or differences in volatility between the assets.

### DATA CONTEXT:
{data_summary}
<|im_end|>
<|im_start|>user
Generate the historical analysis report.
<|im_end|>
<|im_start|>assistant
"""
    pt = PromptTemplate(template=prompt_template, input_variables=["data_summary"])
    formatted_prompt = pt.format(data_summary=data_summary)

    # 3. Invoke LLM
    response = llm.invoke(formatted_prompt)
    return response.strip()


# --- Decision Analysis (Kept for Forecast Tab) ---
def analyze_agent_decision(current_market_data_window, proposed_allocations):
    """
    HYBRID ANALYZER: Python does the math, LLM does the talking.
    """
    llm = setup_llm_pipeline()

    # --- 1. PREPARE DATA ---
    vix_level = current_market_data_window['VIX'].iloc[-1] if 'VIX' in current_market_data_window else 0

    # Identify largest position
    risky_assets = {k: v for k, v in proposed_allocations.items() if k not in ['Cash', 'TLT']}
    if risky_assets:
        max_asset = max(risky_assets, key=risky_assets.get)
        max_val = risky_assets[max_asset] * 100
    else:
        max_asset = "None"
        max_val = 0.0

    safe_haven_pct = (proposed_allocations.get('Cash', 0) + proposed_allocations.get('TLT', 0)) * 100

    # --- 2. PYTHON LOGIC CORE ---
    trigger_safe_haven = safe_haven_pct > 80.0
    trigger_crash_rule = vix_level > 20.0 and safe_haven_pct < 30.0
    trigger_concentration = vix_level > 15.0 and max_val > 40.0

    # Determine Verdict Programmatically
    calculated_risk = "MODERATE"
    reason_code = "Standard market conditions."

    if trigger_safe_haven:
        calculated_risk = "LOW"
        reason_code = f"Safe Haven Exception triggered (Safe Assets: {safe_haven_pct:.1f}% > 80%)."
    elif trigger_crash_rule:
        calculated_risk = "HIGH"
        reason_code = f"Crash Protocol triggered (VIX {vix_level:.1f} > 20 and Safe Haven < 30%)."
    elif trigger_concentration:
        calculated_risk = "HIGH"
        reason_code = f"Concentration Rule triggered (VIX {vix_level:.1f} > 15 and {max_asset} > 40%)."

    # --- 3. THE "NARRATOR" PROMPT ---
    prompt_template = """<|im_start|>system
You are a Senior Risk Analyst.
The Quantitative Engine has already processed the data and determined the Risk Level.
Your job is to summarize the strategy and explain the risk verdict to the user.

### QUANTITATIVE ENGINE OUTPUT:
- **Determined Risk Level:** {calculated_risk}
- **Primary Logic Trigger:** {reason_code}

### DATA CONTEXT:
- VIX: {vix:.2f}
- Largest Position: {max_asset} ({max_val:.1f}%)
- Safe Haven Allocation: {safe_pct:.1f}%

### INSTRUCTIONS:
1. **Strategy Summary:** Describe the allocation style (e.g., "Aggressive Tech", "Defensive Cash").
2. **Justification:** Explain the Risk Level using the "Primary Logic Trigger" provided above. Do not invent new math.

Return ONLY raw JSON:
{{
  "strategy_summary": "string",
  "risk_level": "{calculated_risk}",
  "justification": "string",
  "confidence_score": 10
}}
<|im_end|>
<|im_start|>user
Generate the report.
<|im_end|>
<|im_start|>assistant
"""
    pt = PromptTemplate(template=prompt_template, input_variables=["calculated_risk", "reason_code", "vix", "max_asset", "max_val", "safe_pct"])

    formatted = pt.format(
        calculated_risk=calculated_risk,
        reason_code=reason_code,
        vix=vix_level,
        max_asset=max_asset,
        max_val=max_val,
        safe_pct=safe_haven_pct
    )

    res = llm.invoke(formatted)
    return extract_clean_json(res)

# --- MAIN (for testing) ---
if __name__ == '__main__':
    print("Running test...")
    # Generate Dummy Data
    dates = pd.date_range(start="2023-01-01", periods=180, freq='D')
    df_dummy = pd.DataFrame({
        'SPY': np.linspace(400, 450, 180) + np.random.normal(0, 5, 180),
        'BTC-USD': np.linspace(30000, 40000, 180) + np.random.normal(0, 1000, 180),
        'VIX': np.linspace(20, 15, 180)
    }, index=dates)

    # Test the new historical analysis function
    selected = ['SPY', 'BTC-USD']
    period = "6 Months"
    print(f"\nTesting analysis for {selected} over {period}...")
    analysis = analyze_historical_segment(df_dummy, selected, period)
    print("\n--- LLM Analysis Output ---")
    print(analysis)
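
A quick sketch of how the hybrid risk analyzer can be exercised on its own (the inputs below are made up for illustration; loading Qwen/Qwen2.5-3B-Instruct in 4-bit still requires a CUDA GPU and a sizeable download):

```python
# Sketch: call the hybrid risk analyzer directly with illustrative inputs.
import pandas as pd
from llm_analysis_rag import analyze_agent_decision

# A minimal 30-day window; only the 'VIX' column is read by the rule engine.
window = pd.DataFrame({'VIX': [18.0] * 30, 'SPY': [450.0] * 30})

# Illustrative allocation: 50% AAPL trips the concentration rule (VIX > 15 and a position > 40%).
allocations = {'AAPL': 0.50, 'MSFT': 0.10, 'SPY': 0.10, 'TLT': 0.10, 'BTC-USD': 0.10, 'Cash': 0.10}

verdict = analyze_agent_decision(window, allocations)
print(verdict)  # dict with strategy_summary, risk_level, justification, confidence_score (or None if parsing fails)
```
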
scripts/predict_tomorrow.py
ADDED
@@ -0,0 +1,123 @@

# scripts/predict_tomorrow.py

import pandas as pd
import numpy as np
import os
import sys
from datetime import datetime, timedelta
from stable_baselines3 import SAC

# --- Imports ---
try:
    # Ensure we can find local scripts
    sys.path.append(os.getcwd())
except:
    pass

from fetch_market_data import fetch_market_data, ASSETS, FRED_IDS
from llm_analysis_rag import analyze_agent_decision

# --- Configuration ---
MODEL_PATH = "checkpoints/sac_portfolio_model.zip"
WINDOW_SIZE = 30
MACRO_COLS = list(FRED_IDS.values())  # ['Federal Funds Rate', 'CPI', 'VIX']

def get_latest_data_window(window_size=30):
    """
    Fetches live data and returns the last 'window_size' rows.
    """
    print("--- 🔄 Fetching Real-Time Data for Prediction ---")

    # Fetch a buffer to ensure we have enough data after cleaning
    lookback_days = window_size + 100
    end_date = datetime.now().strftime('%Y-%m-%d')
    start_date = (datetime.now() - timedelta(days=lookback_days)).strftime('%Y-%m-%d')

    # We don't need to save to a file for prediction, so filename=None
    df = fetch_market_data(start_date, end_date, filename=None)

    if df is None or len(df) < window_size:
        print(f"❌ Not enough data fetched. Got {len(df) if df is not None else 0} rows, needed {window_size}.")
        return None

    # Return exactly the last N rows (Observation Window)
    return df.iloc[-window_size:].copy()

def prepare_observation(data_window):
    """
    Normalizes data: Window / First_Row_of_Window
    """
    # Extract specific columns to guarantee order
    price_data = data_window[ASSETS].values
    macro_data = data_window[MACRO_COLS].values

    # Normalize
    norm_prices = price_data / (price_data[0] + 1e-8)
    norm_macro = macro_data / (macro_data[0] + 1e-8)

    # Concatenate and flatten for MLP input
    obs = np.concatenate([norm_prices, norm_macro], axis=1)
    return obs.flatten().astype(np.float32)

def get_allocations(action):
    """Applies Softmax to convert raw action to weights"""
    action = np.asarray(action).flatten()
    exp_action = np.exp(action)
    return exp_action / np.sum(exp_action)

def main():
    print(f"🚀 Prediction Job: {datetime.now().strftime('%Y-%m-%d')}")

    # 1. Get Data
    data_window = get_latest_data_window(WINDOW_SIZE)
    if data_window is None:
        return

    # 2. Prepare Obs
    obs = prepare_observation(data_window)

    # 3. Load MLP Model
    if not os.path.exists(MODEL_PATH):
        print(f"❌ Model not found at {MODEL_PATH}")
        return

    print("Loading MLP SAC model...")
    model = SAC.load(MODEL_PATH)

    # 4. Predict
    action, _ = model.predict(obs, deterministic=True)
    weights = get_allocations(action)

    # 5. Format Allocations (explicit float() casts so the values are plain Python floats downstream)
    allocations = {}
    for i, asset in enumerate(ASSETS):
        allocations[asset] = float(weights[i])
    allocations['Cash'] = float(weights[-1])

    # 6. Output Results
    print("\n" + "=" * 40)
    print("🤖 SAC MLP MODEL RECOMMENDATION")
    print("=" * 40)
    for asset, weight in allocations.items():
        print(f"{asset:<10} : {weight:6.2%}")
    print("=" * 40)

    # 7. AI Risk Analyst
    print("\n🧠 Running AI Risk Analysis...")

    # This works because all numbers are standard floats
    analysis = analyze_agent_decision(data_window, allocations)

    if isinstance(analysis, dict):
        print(f"\nStrategy: {analysis.get('strategy_summary')}")
        print(f"Risk Level: {analysis.get('risk_level')}")
        print(f"Justification: {analysis.get('justification')}")

        # The analyzer returns upper-case levels (HIGH / MODERATE / LOW), so compare case-insensitively
        if str(analysis.get('risk_level', '')).upper() == 'HIGH':
            print("\n⛔ BLOCKING TRADE: High Risk detected by AI Guardrail.")
        else:
            print("\n✅ TRADE APPROVED.")
    else:
        print(analysis)

if __name__ == "__main__":
    main()
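
For intuition, `get_allocations` is just a softmax, so any raw action vector from the agent maps to non-negative weights that sum to one. A tiny standalone check (illustrative numbers):

```python
# Sketch: softmax turns an arbitrary action vector into portfolio weights.
import numpy as np

def get_allocations(action):
    action = np.asarray(action).flatten()
    exp_action = np.exp(action)
    return exp_action / np.sum(exp_action)

w = get_allocations([0.5, 0.1, -0.3, 0.0, 0.2, -1.0])  # 5 assets + cash
print(w.round(3), round(float(w.sum()), 6))             # all positive, sums to 1.0
```
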
scripts/tune_sac.py
ADDED
@@ -0,0 +1,198 @@

# scripts/tune_sac.py

import os
import sys
import pandas as pd
import numpy as np
import optuna
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import DummyVecEnv  # Use DummyVecEnv
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.logger import configure

from environment import PortfolioEnv

# ==============================================================================
# 1. Configuration & Data Loading
# ==============================================================================

TRAIN_DATA_PATH = 'data/train.csv'
EVAL_DATA_PATH = 'data/eval.csv'
OPTUNA_LOG_DIR = 'optuna_logs'
CHECKPOINT_DIR = 'checkpoints/optuna_sac_trials'

# Create directories if they don't exist
os.makedirs(OPTUNA_LOG_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Load data once
df_full_train = pd.read_csv(TRAIN_DATA_PATH, index_col='Date', parse_dates=True)
df_eval = pd.read_csv(EVAL_DATA_PATH, index_col='Date', parse_dates=True)

# Split df_full_train for tuning
train_split_point = int(len(df_full_train) * 0.8)
df_train_tune = df_full_train.iloc[:train_split_point]
df_validation_tune = df_full_train.iloc[train_split_point:]

print(f"Total training data points: {len(df_full_train)}")
print(f"Optuna training data points: {len(df_train_tune)}")
print(f"Optuna validation data points: {len(df_validation_tune)}")


# ==============================================================================
# 2. Environment Creation Helper
# ==============================================================================

def make_env(df, window_size=30, initial_balance=10000, transaction_cost_pct=0.001):
    """
    Helper function to create a PortfolioEnv instance.
    """
    def _init():
        env = PortfolioEnv(
            df=df,
            initial_balance=initial_balance,
            window_size=window_size,
            transaction_cost_pct=transaction_cost_pct
        )
        return env
    return _init

# ==============================================================================
# 3. Optuna Objective Function
# ==============================================================================

def objective(trial: optuna.Trial) -> float:
    """
    Objective function for Optuna to optimize hyperparameters for SAC.
    """
    # Hyperparameter search space
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True)
    gamma = trial.suggest_float('gamma', 0.9, 0.999)
    tau = trial.suggest_float('tau', 0.005, 0.02)
    buffer_size = trial.suggest_int('buffer_size', 50000, 1000000, log=True)
    batch_size = trial.suggest_categorical('batch_size', [64, 128, 256, 512])
    ent_coef = trial.suggest_float('ent_coef', 0.001, 0.1, log=True)  # Use log scale for ent_coef

    # Network architecture
    n_layers = trial.suggest_int('n_layers', 1, 3)
    net_arch = []
    for i in range(n_layers):
        layer_size = trial.suggest_categorical(f'layer_size_{i}', [64, 128, 256])
        net_arch.append(layer_size)

    policy_kwargs = dict(net_arch=net_arch)  # SAC uses a shared network or separate [pi, qf]

    # Create environments for this trial
    train_env = DummyVecEnv([make_env(df_train_tune)])
    eval_env = DummyVecEnv([make_env(df_validation_tune)])

    # Set up logger for the trial
    trial_log_path = os.path.join(OPTUNA_LOG_DIR, f"trial_{trial.number}")
    new_logger = configure(trial_log_path, ["stdout", "csv", "tensorboard"])

    # Create SAC model
    model = SAC(
        "MlpPolicy",
        train_env,
        learning_rate=learning_rate,
        gamma=gamma,
        tau=tau,
        buffer_size=buffer_size,
        batch_size=batch_size,
        ent_coef=ent_coef,  # Pass the sampled value
        policy_kwargs=policy_kwargs,
        verbose=0,
        seed=42,  # Use a fixed seed for reproducibility within a trial
        tensorboard_log=OPTUNA_LOG_DIR
    )
    model.set_logger(new_logger)

    # Callback for evaluation
    eval_callback = EvalCallback(
        eval_env,
        best_model_save_path=os.path.join(CHECKPOINT_DIR, f"best_sac_trial_{trial.number}"),
        log_path=trial_log_path,
        eval_freq=5000,
        deterministic=True,
        render=False,
        n_eval_episodes=1
    )

    try:
        # Train for a set number of steps per trial
        total_timesteps_per_trial = 50000
        model.learn(total_timesteps=total_timesteps_per_trial, callback=eval_callback, progress_bar=False)

        # Load the best model found during this trial's training
        best_model_path = os.path.join(CHECKPOINT_DIR, f"best_sac_trial_{trial.number}", "best_model.zip")
        if os.path.exists(best_model_path):
            model = SAC.load(best_model_path, env=eval_env)
        else:
            print(f"Warning: No best model saved for trial {trial.number}, using last model.")

        # --- Final evaluation on the validation set ---
        obs = eval_env.reset()
        portfolio_values = [eval_env.envs[0].initial_balance]
        done = False
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done, info = eval_env.step(action)
            portfolio_values.append(info[0]['portfolio_value'])

        final_portfolio_value = portfolio_values[-1]
        initial_portfolio_value = portfolio_values[0]
        total_return = (final_portfolio_value / initial_portfolio_value) - 1

        print(f"Trial {trial.number} finished. Total Return on validation: {total_return:.4f}")

    except Exception as e:
        print(f"Trial {trial.number} failed due to: {e}")
        return float('nan')  # Optuna handles NaN as a failure

    finally:
        train_env.close()
        eval_env.close()

    return total_return  # Optuna aims to maximize this metric


# ==============================================================================
# 4. Run Optuna Study
# ==============================================================================

if __name__ == '__main__':
    study = optuna.create_study(
        direction='maximize',
        sampler=optuna.samplers.TPESampler(seed=42)
    )

    n_trials_to_run = 50
    study.optimize(objective, n_trials=n_trials_to_run, n_jobs=1)  # n_jobs=1 is safer for Colab

    print("\n--- Optimization finished. ---")
    print("Best trial:")
    trial = study.best_trial

    print(f"  Value: {trial.value:.4f}")
    print("  Params: ")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")

    # Save the best parameters to a file
    best_params = trial.params
    with open('checkpoints/best_sac_params.txt', 'w') as f:
        f.write(str(best_params))
    print("\n✅ Best parameters saved to checkpoints/best_sac_params.txt")

    # Plotting results
    try:
        import plotly
        from optuna.visualization import plot_optimization_history, plot_param_importances

        fig1 = plot_optimization_history(study)
        fig1.show()

        fig2 = plot_param_importances(study)
        fig2.show()
    except ImportError:
        print("\nInstall plotly and kaleido to visualize Optuna results: !pip install plotly kaleido")
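
The best parameters are written out as a plain Python dict repr, so a follow-up training run could reload them with `ast.literal_eval`; a minimal sketch under that assumption (the final-training step itself is not part of this script):

```python
# Sketch: reload the tuned hyperparameters saved by tune_sac.py.
import ast

with open('checkpoints/best_sac_params.txt') as f:
    best_params = ast.literal_eval(f.read())

# Layer sizes were sampled as layer_size_0, layer_size_1, ... in the objective.
net_arch = [v for k, v in sorted(best_params.items()) if k.startswith('layer_size_')]
sac_kwargs = {k: v for k, v in best_params.items()
              if k in ('learning_rate', 'gamma', 'tau', 'buffer_size', 'batch_size', 'ent_coef')}
print(net_arch, sac_kwargs)
```
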
scripts/visualize_strategy.py
DELETED
@@ -1,123 +0,0 @@

import argparse
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from stable_baselines3 import PPO, SAC, TD3
from environment import PortfolioEnv

def visualize_strategy(agent_name, checkpoint_path, datafile_path, output_path):
    """
    Loads a trained agent, runs a simulation, and plots its portfolio allocation strategy.

    Args:
        agent_name (str): The type of agent to load ('ppo', 'sac', 'td3').
        checkpoint_path (str): The path to the saved model checkpoint file (.zip).
        datafile_path (str): The path to the CSV market data for the simulation.
        output_path (str): The path to save the output plot image.
    """
    print(f"--- Visualizing strategy for {agent_name.upper()} agent ---")

    # 1. Define a mapping from agent names to their classes
    AGENT_CLASSES = {
        "ppo": PPO,
        "sac": SAC,
        "td3": TD3
    }
    agent_class = AGENT_CLASSES[agent_name.lower()]

    # 2. Load Data and Model
    try:
        test_df = pd.read_csv(datafile_path, index_col='Date', parse_dates=True)
        model = agent_class.load(checkpoint_path)
    except FileNotFoundError as e:
        print(f"❌ Error: Could not find a required file. {e}")
        return
    except Exception as e:
        print(f"❌ An error occurred: {e}")
        return

    # 3. Create Environment and Run Simulation
    env = PortfolioEnv(test_df)
    obs, info = env.reset()
    terminated, truncated = False, False

    weights_history = [info['weights']]
    while not (terminated or truncated):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        weights_history.append(info['weights'])
    print("✅ Simulation complete.")

    # 4. Prepare Data for Plotting
    weights_df = pd.DataFrame(weights_history)
    asset_names = test_df.columns.tolist() + ['Cash']
    weights_df.columns = asset_names
    weights_df.index = test_df.index[:len(weights_df)]

    # 5. Plotting the Stacked Area Chart
    print("📊 Generating plot...")
    plt.style.use('seaborn-v0_8-darkgrid')
    fig, ax = plt.subplots(figsize=(15, 8))

    ax.stackplot(weights_df.index, weights_df.T, labels=weights_df.columns, alpha=0.8)

    ax.set_title(f'Agent Portfolio Allocation Over Time ({agent_name.upper()})', fontsize=16)
    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel('Portfolio Allocation (%)', fontsize=12)
    ax.legend(loc='upper left', fontsize=10)

    formatter = FuncFormatter(lambda y, p: f'{y:.0%}')
    ax.yaxis.set_major_formatter(formatter)

    plt.tight_layout()

    # Ensure output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    plt.savefig(output_path)
    print(f"✅ Plot saved to {output_path}")
    plt.show()


if __name__ == "__main__":
    # Set up command-line argument parsing
    parser = argparse.ArgumentParser(description="Visualize a trained RL agent's portfolio allocation strategy.")

    parser.add_argument(
        "--agent",
        type=str,
        required=True,
        choices=["ppo", "sac", "td3"],
        help="The RL algorithm of the trained agent."
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the saved model checkpoint .zip file (e.g., 'td3_portfolio_model.zip')."
    )
    parser.add_argument(
        "--datafile",
        type=str,
        default="data/test.csv",
        help="Path to the market data CSV file to run the simulation on."
    )
    parser.add_argument(
        "--output",
        type=str,
        default="results/agent_allocation.png",
        help="Path to save the output plot image."
    )

    args = parser.parse_args()

    visualize_strategy(
        agent_name=args.agent,
        checkpoint_path=args.checkpoint,
        datafile_path=args.datafile,
        output_path=args.output
    )