Commit 7d2e753 · 0 Parent(s) · Initial commit

Browse files:

- .gitattributes +1 -0
- .gitignore +7 -0
- README.md +215 -0
- assets/banner.png +3 -0
- requirements.txt +19 -0
- results/baseline_results.png +3 -0
- results/final_performance_comparison_all_agents.png +3 -0
- results/ppo_portfolio_alocation.png +3 -0
- results/sac_portfolio_alocation.png +3 -0
- results/stress_test_comparison_2018.png +3 -0
- results/td3_portfolio_alocation.png +3 -0
- scripts/check_env.py +32 -0
- scripts/environment.py +174 -0
- scripts/evaluate.py +142 -0
- scripts/evaluate_baselines.py +134 -0
- scripts/fetch_data.py +75 -0
- scripts/fetch_market_data.py +78 -0
- scripts/stress_test.py +142 -0
- scripts/train.py +77 -0
- scripts/visualize_strategy.py +123 -0
.gitattributes ADDED
@@ -0,0 +1 @@

*.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,7 @@

__pycache__/
*.pyc
venv/
.venv/
.vscode/
.idea/
README.md ADDED
@@ -0,0 +1,215 @@



[](https://www.python.org/)[](https://pytorch.org/)[](LICENSE)

# Portfolio Optimization with Deep Reinforcement Learning

This project explores the use of Deep Reinforcement Learning to train autonomous agents for financial portfolio management. The goal was not just to create a single profitable agent, but to conduct a comparative study of different RL algorithms (PPO, SAC, TD3) to understand the emergent trading strategies and their robustness across various market conditions.

**The ultimate finding? A TD3-based agent learned a superior, risk-managed static asset allocation that consistently outperformed both active trading strategies and aggressive growth models, especially during market downturns.**

---
## Table of Contents

1. [The Data & Asset Selection](#the-data--asset-selection)
2. [Benchmarking Against Baselines](#benchmarking-against-baselines)
3. [Key Findings & The Champion Agent](#key-findings--the-champion-agent)
4. [Comparative Analysis of Agent Strategies](#comparative-analysis-of-agent-strategies)
   * [TD3: The Prudent Risk-Manager](#td3-the-prudent-risk-manager)
   * [SAC: The Aggressive Growth Engine](#sac-the-aggressive-growth-engine)
   * [PPO: The Active (but Inconsistent) Trader](#ppo-the-active-but-inconsistent-trader)
5. [Stress Testing: The Ultimate Test of Robustness](#stress-testing-the-ultimate-test-of-robustness)
6. [The Research Journey: Why Simplicity Won](#the-research-journey-why-simplicity-won)
7. [Conclusion](#conclusion)
8. [Project Structure](#project-structure)
9. [How to Run](#how-to-run)
   * [Setup](#setup)
   * [Data Fetching](#data-fetching)
   * [Training](#training)
   * [Evaluation & Visualization](#evaluation--visualization)

---
## The Data & Asset Selection

The foundation of any financial machine learning project is the data. This project uses daily closing price data sourced from **Yahoo Finance** via the `yfinance` library. The primary training period was **2015-2020**, with out-of-sample testing conducted on **2021-2023** and other periods for stress testing.

The selection of assets was crucial for creating a realistic decision-making environment for the agent. The portfolio consists of five assets, chosen to represent different classes and risk profiles:

* **Growth Equities (AAPL, MSFT):** Represent the high-growth, high-volatility technology sector.
* **Market Index (SPY):** An ETF tracking the S&P 500, representing the broader US stock market.
* **Safe Haven (TLT):** An ETF for 20+ Year US Treasury Bonds, which often acts as a "risk-off" asset during stock market downturns.
* **Alternative Asset (BTC-USD):** Represents a non-traditional, extremely volatile asset class with high potential returns.

This diverse mix forces the agent to learn not just about individual assets, but also about their correlations and how to balance risk across different economic regimes.
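As a concrete illustration of the correlation point, pairwise return correlations can be inspected directly with pandas. The prices below are made-up toy values (not the project's data), chosen so the bond proxy moves against the stock proxy:

```python
import pandas as pd

# Toy daily closes for a stock ETF and a bond ETF (illustrative values only).
prices = pd.DataFrame({
    "SPY": [100.0, 101.0, 102.0, 101.0, 103.0],
    "TLT": [100.0, 99.5, 99.0, 99.8, 98.9],
})

# Correlate daily returns, not raw prices, to avoid spurious trend correlation.
returns = prices.pct_change().dropna()
corr = returns.corr()
print(corr.loc["SPY", "TLT"])  # negative here: a "risk-off" relationship
```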
---
## Benchmarking Against Baselines

To prove that a reinforcement learning agent is truly "intelligent," its performance must be measured against simple, standard strategies. An agent is only successful if it can provide value beyond a naive approach.

Our primary benchmark was the **Buy and Hold** strategy, where an equal amount of capital is invested in each asset at the beginning of the period and never touched again. The goal for any trained RL agent was to achieve superior performance, especially on a **risk-adjusted basis** (e.g., higher Sharpe Ratio, lower Max Drawdown), compared to this baseline.

The chart below shows the performance of a simple Buy and Hold strategy during the 2021-2023 test period, setting a clear target for our agents to beat.

![](
---

## Key Findings & The Champion Agent

After extensive training, evaluation, and stress-testing, the **TD3 agent emerged as the clear winner** on a risk-adjusted basis. While other agents achieved higher raw returns, their strategies proved to be brittle and dangerously volatile during market crises. The TD3 agent's strategy was the most robust and reliable.

#### Final Performance Comparison (2021-2023)

This table summarizes the performance of the top-performing static agents against the baseline.

| Metric | **TD3 Agent** | SAC Agent | Buy & Hold |
| :--- | :--- | :--- | :--- |
| **Total Return** | 47.24% | **50.89%** | 34.91% |
| **CAGR** | 13.76% | **14.70%** | 10.50% |
| **Sharpe Ratio** | **0.62** | 0.51 | 0.45 |
| **Max Drawdown** | **-28.41%** | -44.61% | -40.81% |

The TD3 agent delivered strong returns while significantly reducing the maximum drawdown, proving its superior capital preservation strategy.

![](
---
## Comparative Analysis of Agent Strategies

A fascinating outcome of this project was observing three different RL algorithms independently discover three distinct and recognizable investment philosophies.

### TD3: The Prudent Risk-Manager

The TD3 agent concluded that the most effective strategy was not to trade frequently, but to find one **superior, risk-managed static asset allocation** and hold it.

* **Strategy:** "Smarter Buy and Hold."
* **Behavior:** The agent's allocation is completely static, indicating it focused on the initial strategic decision and ignored market noise to minimize transaction costs.
* **Result:** This approach led to the best risk-adjusted returns, showing that a robust initial setup is more valuable than reactive trading.

![](

### SAC: The Aggressive Growth Engine

The SAC agent also learned a static allocation strategy, but its portfolio was geared for **maximum growth**, accepting higher risk for higher potential returns.

* **Strategy:** High-risk, high-return static allocation.
* **Behavior:** Like TD3, it made one initial allocation and held firm. However, this allocation was far more aggressive.
* **Result:** It achieved the highest total return in some periods but suffered catastrophic drawdowns in stress tests, making its strategy unreliable and brittle.

![](

### PPO: The Active (but Inconsistent) Trader

Unlike the other two, the PPO agent learned an **active, dynamic trading strategy**, constantly adjusting its portfolio based on market conditions.

* **Strategy:** Tactical asset allocation.
* **Behavior:** The allocation chart clearly shows the agent rebalancing its portfolio over time, for example by increasing its bond (TLT) holdings during the 2022 downturn.
* **Result:** While it is impressive that the agent learned this behavior, its performance was inconsistent. It succeeded in some periods (2018) but failed in others (2025), highlighting the immense difficulty of successful market timing.

![](

---
## Stress Testing: The Ultimate Test of Robustness

A model is only as good as its performance during a crisis. We subjected the agents to multiple out-of-sample stress tests, with the 2018 period (featuring a crypto winter and a stock market flash crash) being the most revealing.

![](

* **TD3's Triumph:** The orange line shows the TD3 agent successfully navigating the downturn, preserving capital far better than the baseline.
* **SAC's Failure:** The green line shows the SAC agent's aggressive strategy failing catastrophically, resulting in a massive drawdown.

This test definitively proved that the **TD3 agent's risk-managed approach was truly robust**, while the SAC agent's strategy was fragile.

---
## The Research Journey: Why Simplicity Won

This project was also an exercise in scientific methodology. We initially hypothesized that more complex models and features would yield better results.

* **Hypothesis 1: More features are better.** We tested adding technical indicators (RSI, MACD) to the observation space. **Result:** Performance degraded. The indicators acted as noise, confusing the agents.
* **Hypothesis 2: Models with memory are better.** We tested an LSTM-based agent (`RecurrentPPO`). **Result:** Performance degraded. The added complexity led to overfitting on the training data.
* **Hypothesis 3: Regularization is better.** We tested both L1 and L2 regularization. **Result:** Performance degraded.
* **Hypothesis 4: A longer observation window is better.** We increased the window from 30 to 60 days. **Result:** Performance degraded. A larger context window is not always helpful; the extra history acted as additional noise for the model.

The conclusion was clear: for this problem, a simple and elegant model (a standard MLP fed with just normalized price data) was the most effective.
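For reference, the RSI feature from Hypothesis 1 can be computed with plain pandas. This is a simple-moving-average variant for illustration only; the pinned `pandas-ta` package applies its own smoothing, so values may differ slightly:

```python
import pandas as pd

def rsi(close: pd.Series, length: int = 14) -> pd.Series:
    """Relative Strength Index from simple moving averages of gains and losses."""
    delta = close.diff()
    avg_gain = delta.clip(lower=0).rolling(length).mean()
    avg_loss = (-delta).clip(lower=0).rolling(length).mean()
    rs = avg_gain / avg_loss
    return 100 - 100 / (1 + rs)

close = pd.Series(range(1, 31), dtype=float)  # 30 days of strictly rising prices
print(rsi(close).iloc[-1])  # saturates at 100.0 when there are no down days
```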
---
## Conclusion

This project successfully demonstrates that Deep Reinforcement Learning can be a powerful tool for discovering sophisticated investment strategies. The key insight is that the most robust and successful agent did not learn to be a hyperactive trader, but rather a prudent strategic allocator, emphasizing the timeless investment principle that effective risk management is the true key to long-term success.

---
## Project Structure

The codebase is organized into modular, reusable scripts.

```bash
├── assets/
├── checkpoints/              # Holds all saved model .zip files
├── results/                  # Holds all output plots and metrics
├── scripts/
│   ├── environment.py        # The custom Gymnasium environment for the simulation
│   ├── fetch_market_data.py  # A flexible script to download data for any period
│   ├── train.py              # The main training script with model selection
│   ├── evaluate.py           # The main evaluation script for generating metrics
│   ├── stress_test.py        # Runs a full comparison of all agents on a given dataset
│   └── visualize_strategy.py # Plots the asset allocation of a single trained agent
└── README.md                 # This file
```

---
## How to Run

### Setup

1. Clone the repository.
2. Create and activate a Python virtual environment.
3. Install the required packages:

```bash
pip install -r requirements.txt
```

### Data Fetching

Use the flexible `fetch_market_data.py` script to get any data you need.

```bash
# Fetch the default training data (2015-2020)
python scripts/fetch_market_data.py --start 2015-01-01 --end 2020-12-31 --filename data/train.csv

# Fetch data for a stress test (e.g., 2022)
python scripts/fetch_market_data.py --start 2022-01-01 --end 2022-12-31 --filename data/test_2022.csv
```

### Training

Use the `train.py` script to train any of the three main agents.

```bash
# Train the champion TD3 agent (default)
python scripts/train.py --agent td3

# Train a SAC agent for more timesteps
python scripts/train.py --agent sac --timesteps 100000
```

### Evaluation & Visualization

Use the dedicated scripts to analyze the results.

```bash
# Run a full stress test on the 2018 data
python scripts/stress_test.py --datafile data/stress_test_2018.csv

# Visualize the TD3 agent's strategy
python scripts/visualize_strategy.py --agent td3 --checkpoint td3_portfolio_model.zip
```
assets/banner.png ADDED (Git LFS)
requirements.txt ADDED
@@ -0,0 +1,19 @@

# Core RL and Simulation
stable-baselines3==2.7.0
sb3_contrib==2.7.0
gymnasium==1.2.1

# Data Handling and Numerics
pandas==2.3.3
numpy==2.2.6
scikit-learn==1.6.1

# Data Fetching
yfinance==0.2.66

# Financial Indicators
pandas-ta==0.4.71b0

# Plotting and Visualization
matplotlib==3.10.0
seaborn==0.13.2
results/baseline_results.png ADDED (Git LFS)
results/final_performance_comparison_all_agents.png ADDED (Git LFS)
results/ppo_portfolio_alocation.png ADDED (Git LFS)
results/sac_portfolio_alocation.png ADDED (Git LFS)
results/stress_test_comparison_2018.png ADDED (Git LFS)
results/td3_portfolio_alocation.png ADDED (Git LFS)
scripts/check_env.py ADDED
@@ -0,0 +1,32 @@

import pandas as pd
from stable_baselines3.common.env_checker import check_env
from environment import PortfolioEnv

def main():
    """
    Main function to create and check the custom portfolio environment.
    """
    print("--- Loading Data and Creating Environment ---")
    try:
        # Load your training data
        df = pd.read_csv('data/train.csv', index_col='Date', parse_dates=True)
        # Create an instance of your environment
        env = PortfolioEnv(df)
        print("Environment created successfully.")
    except FileNotFoundError:
        print("❌ Error: 'data/train.csv' not found. Make sure you've run the data fetching script.")
        return

    print("\n--- Checking Environment Compatibility ---")
    try:
        # The check_env function will raise an error if the environment is not compatible.
        check_env(env)
        print("✅ Environment check passed!")
    except Exception:
        print("❌ Environment check failed:")
        # It's helpful to print the full traceback for debugging complex errors.
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()
scripts/environment.py ADDED
@@ -0,0 +1,174 @@

import gymnasium as gym
import numpy as np
import pandas as pd
from gymnasium import spaces

class PortfolioEnv(gym.Env):
    """
    A custom reinforcement learning environment for portfolio management.

    This environment simulates the daily trading of multiple financial assets. The agent's
    goal is to learn a policy for allocating capital to maximize risk-adjusted returns.
    """
    metadata = {'render_modes': ['human']}

    def __init__(self, df, window_size=30, initial_balance=10000, transaction_cost_pct=0.001):
        """
        Initializes the portfolio management environment.

        Args:
            df (pd.DataFrame): A DataFrame containing the daily closing prices of the assets.
                The index should be dates and columns should be asset tickers.
            window_size (int): The number of past days of price data to include in the observation.
            initial_balance (float): The starting capital for the portfolio.
            transaction_cost_pct (float): The percentage cost for each trade (e.g., 0.001 for 0.1%).
        """
        super().__init__()

        # --- Basic Environment Parameters ---
        self.df = df
        self.window_size = window_size
        self.initial_balance = initial_balance
        self.transaction_cost_pct = transaction_cost_pct
        self.n_assets = len(df.columns)

        # --- Action Space ---
        # The agent outputs a vector of continuous values, one for each asset plus one for cash.
        # These raw outputs are then converted to portfolio weights via a softmax function.
        # The space is defined from -1 to 1 for better compatibility with standard RL algorithms.
        # Shape: (number of assets + 1 for cash)
        self.action_space = spaces.Box(
            low=-1, high=1, shape=(self.n_assets + 1,), dtype=np.float32
        )

        # --- Observation Space ---
        # The agent observes a window of past price data, flattened into a 1D vector.
        # Shape: (window_size * number of assets)
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf,
            shape=(self.window_size * self.n_assets,),
            dtype=np.float32
        )

        # --- Internal State Variables ---
        # These variables track the state of the simulation over time.
        self._current_step = 0
        self._portfolio_value = 0.0
        # Weights for each asset + cash, e.g., [w_aapl, w_msft, ..., w_cash]
        self._weights = np.zeros(self.n_assets + 1)

    def reset(self, seed=None, options=None):
        """
        Resets the environment to its initial state for a new episode.

        Returns:
            tuple: A tuple containing the initial observation and auxiliary info.
        """
        super().reset(seed=seed)

        # Start the simulation at the first point where a full window of data is available.
        self._current_step = self.window_size
        self._portfolio_value = self.initial_balance

        # Initialize weights to be 100% in cash.
        self._weights = np.zeros(self.n_assets + 1)
        self._weights[-1] = 1.0  # Last element represents cash

        observation = self._get_obs()
        info = self._get_info()

        return observation, info

    def step(self, action):
        """
        Executes one time step within the environment based on the agent's action.

        Args:
            action (np.ndarray): The raw output from the agent's policy network.

        Returns:
            tuple: A tuple containing the next observation, reward, terminated flag,
                truncated flag, and auxiliary info.
        """
        # 1. Store the portfolio value before taking the action.
        current_portfolio_value = self._portfolio_value

        # 2. Convert the raw action into portfolio weights using the softmax function.
        # This ensures the weights are positive and sum to 1.
        target_weights = np.exp(action) / np.sum(np.exp(action))

        # 3. Calculate the cost of rebalancing the portfolio.
        # The cost is based on the total value of assets bought or sold.
        trades = (target_weights[:-1] - self._weights[:-1]) * current_portfolio_value
        transaction_costs = np.sum(np.abs(trades)) * self.transaction_cost_pct

        # 4. Update the internal state: apply costs, set new weights, and advance time.
        self._balance = current_portfolio_value - transaction_costs
        self._weights = target_weights
        self._current_step += 1

        # 5. Calculate the new portfolio value based on the market's price movement.
        current_prices = self.df.iloc[self._current_step - 1].values
        next_prices = self.df.iloc[self._current_step].values
        price_ratio = next_prices / current_prices  # How much each asset's price changed.

        # The new value of our asset holdings.
        asset_values_after_price_change = (self._weights[:-1] * self._balance) * price_ratio

        # The new total portfolio value is the sum of the updated asset values plus the cash holding.
        new_portfolio_value = np.sum(asset_values_after_price_change) + (self._weights[-1] * self._balance)
        self._portfolio_value = new_portfolio_value

        # 6. Calculate the reward for the agent.
        # The reward is the log return of the portfolio value, which encourages geometric growth.
        reward = np.log(new_portfolio_value / current_portfolio_value)

        # 7. Check for end-of-episode conditions.
        # The episode terminates if the portfolio loses half its value, and is
        # truncated when the environment runs out of data.
        terminated = bool(self._portfolio_value <= self.initial_balance * 0.5)
        truncated = self._current_step >= len(self.df) - 1

        observation = self._get_obs()
        info = self._get_info()

        return observation, reward, terminated, truncated, info

    def _get_obs(self):
        """
        Constructs the observation for the agent at the current time step.

        Returns:
            np.ndarray: A flattened 1D array of the normalized price history.
        """
        # Get the window of historical price data.
        price_window = self.df.iloc[self._current_step - self.window_size : self._current_step].values

        # Normalize the window by dividing by the first price. This helps the agent
        # focus on relative price changes rather than absolute values.
        normalized_window = price_window / price_window[0]

        return normalized_window.flatten().astype(np.float32)

    def _get_info(self):
        """
        Returns a dictionary of auxiliary information about the current state.
        """
        return {
            'step': self._current_step,
            'portfolio_value': self._portfolio_value,
            'weights': self._weights,
        }

    def render(self, mode='human'):
        """
        Renders the environment's state (optional).
        """
        if mode == 'human':
            info = self._get_info()
            print(f"Step: {info['step']}, Portfolio Value: {info['portfolio_value']:.2f}")

    def close(self):
        """
        Cleans up the environment (optional).
        """
        pass
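Two details of `PortfolioEnv.step()` are worth sanity-checking in isolation: the softmax mapping produces valid portfolio weights for any raw action, and the log-return reward makes per-step rewards add up to the episode's total log growth. A standalone numpy check (toy numbers, not from the project):

```python
import numpy as np

# Raw action in [-1, 1]: one entry per asset plus one for cash.
action = np.array([0.5, -0.2, 0.1, 0.9, -1.0, 0.0])

# Softmax, as in PortfolioEnv.step(): weights are positive and sum to 1
# no matter what raw vector the policy emits.
weights = np.exp(action) / np.sum(np.exp(action))
print(bool(np.all(weights > 0)), round(float(weights.sum()), 10))

# Log-return rewards are additive: summing the per-step rewards gives the
# log return of the whole episode, so maximizing the sum maximizes growth.
values = np.array([10000.0, 10100.0, 9800.0, 10300.0])
step_rewards = np.log(values[1:] / values[:-1])
print(abs(float(step_rewards.sum()) - float(np.log(values[-1] / values[0]))) < 1e-12)
```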
scripts/evaluate.py ADDED
@@ -0,0 +1,142 @@

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3 import PPO, SAC, TD3
from evaluate_baselines import buy_and_hold
from environment import PortfolioEnv
from matplotlib.ticker import FuncFormatter

# --- Helper Function to Run the RL Agent ---

def evaluate_agent(env, model):
    """
    Runs the trained agent on the environment and returns portfolio values.
    """
    obs, info = env.reset()
    terminated, truncated = False, False

    portfolio_values = [env.initial_balance]

    while not (terminated or truncated):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        portfolio_values.append(info['portfolio_value'])

    return pd.Series(portfolio_values, index=env.df.index[:len(portfolio_values)])


def calculate_metrics(portfolio_values, freq=252, rf=0.0):
    """
    Calculates key performance metrics from a series of portfolio values.
    freq: number of trading periods in a year (252 for daily, 52 for weekly).
    rf: risk-free rate (default = 0 for simplicity).
    """
    returns = portfolio_values.pct_change().dropna()

    # Total Return
    total_return = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) - 1

    # CAGR
    num_years = len(portfolio_values) / freq
    cagr = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) ** (1 / num_years) - 1

    # Sharpe Ratio
    sharpe_ratio = np.sqrt(freq) * (returns.mean() - rf) / returns.std()

    # Sortino Ratio (downside risk only)
    downside_returns = returns[returns < 0]
    downside_std = downside_returns.std()
    sortino_ratio = np.sqrt(freq) * (returns.mean() - rf) / downside_std if downside_std > 0 else np.nan

    # Volatility (annualized std)
    volatility = returns.std() * np.sqrt(freq)

    # Max Drawdown (a fraction, e.g., -0.28 for -28%)
    rolling_max = portfolio_values.cummax()
    drawdown = portfolio_values / rolling_max - 1.0
    max_drawdown = drawdown.min()

    # Calmar Ratio (max_drawdown is already a fraction, so no extra scaling)
    calmar_ratio = cagr / abs(max_drawdown) if max_drawdown != 0 else np.nan

    return {
        "Total Return": f"{total_return:.2%}",
        "CAGR": f"{cagr:.2%}",
        "Sharpe Ratio": f"{sharpe_ratio:.2f}",
        "Sortino Ratio": f"{sortino_ratio:.2f}",
        "Volatility": f"{volatility:.2%}",
        "Max Drawdown": f"{max_drawdown:.2%}",
        "Calmar Ratio": f"{calmar_ratio:.2f}"
    }


def main(test_data_path='data/test.csv'):
    """
    Loads, evaluates, and plots the performance of PPO, SAC, and TD3 agents
    against a Buy and Hold baseline.
    """
    # --- Define Model Paths and Agent Types ---
    models_to_evaluate = {
        "PPO Agent": (PPO, 'checkpoints/ppo_portfolio_model'),
        "SAC Agent": (SAC, 'checkpoints/sac_portfolio_model'),
        "TD3 Agent": (TD3, 'checkpoints/td3_portfolio_model')
    }

    # Load test data
    test_df = pd.read_csv(test_data_path, index_col='Date', parse_dates=True)

    # Dictionaries to store results
    portfolio_values = {}
    metrics = {}

    # --- Run Evaluations for each RL Agent ---
    for name, (agent_type, model_path) in models_to_evaluate.items():
        print(f"--- Evaluating {name} ---")
        model = agent_type.load(model_path)
        env = PortfolioEnv(test_df)
        portfolio_values[name] = evaluate_agent(env, model)
        metrics[name] = calculate_metrics(portfolio_values[name])

    # --- Evaluate Buy and Hold Baseline ---
    print("\n--- Evaluating Buy and Hold Baseline ---")
    bnh_values = buy_and_hold(test_df)
    portfolio_values["Buy and Hold"] = bnh_values
    metrics["Buy and Hold"] = calculate_metrics(bnh_values)

    # --- Combine and Print Metrics ---
    print("\n--- Performance Metrics ---")
    metrics_df = pd.DataFrame(metrics)
    print(metrics_df)

    # --- Plotting All Strategies ---
    plt.style.use('seaborn-v0_8-darkgrid')
    fig, ax = plt.subplots(figsize=(14, 8))

    # Define colors for clarity
    colors = {
        "PPO Agent": "red",
        "SAC Agent": "green",
        "TD3 Agent": "orange",
        "Buy and Hold": "blue"
    }

    for name, values in portfolio_values.items():
        ax.plot(values.index, values, label=name, color=colors[name], linewidth=2)

    ax.set_title('Agent Performance Comparison', fontsize=16)
    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel('Portfolio Value ($)', fontsize=12)
    ax.legend(fontsize=12)

    formatter = FuncFormatter(lambda x, p: f'${x:,.0f}')
    ax.yaxis.set_major_formatter(formatter)

    plt.tight_layout()
    plt.savefig('results/final_performance_comparison_all_agents.png')
    plt.show()

# Example of how to run this main function
if __name__ == '__main__':
    # You can specify a different test file here if needed
    # e.g., main(test_data_path='data/stress_test_2018.csv')
|
| 142 |
+
main()
|
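As a quick sanity check of the metric arithmetic above, the drawdown and total-return formulas can be reproduced on a small synthetic value series (the numbers below are invented for illustration):

```python
import pandas as pd

# Synthetic portfolio: doubles over four steps, with one dip in the middle
values = pd.Series([100.0, 120.0, 90.0, 150.0, 200.0])

total_return = values.iloc[-1] / values.iloc[0] - 1   # +100%

# Max drawdown via running maximum, exactly as in calculate_metrics
rolling_max = values.cummax()                          # [100, 120, 120, 150, 200]
drawdown = values / rolling_max - 1.0
max_drawdown = drawdown.min()                          # -0.25, from the 120 -> 90 dip

print(f"Total Return: {total_return:.2%}")   # 100.00%
print(f"Max Drawdown: {max_drawdown:.2%}")   # -25.00%
```

Note that `max_drawdown` comes out as a fraction (-0.25), which is why the Calmar ratio divides CAGR by `abs(max_drawdown)` directly.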
scripts/evaluate_baselines.py
ADDED
|
@@ -0,0 +1,134 @@
# evaluate_baselines.py

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

def buy_and_hold(df, initial_balance=10000):
    """
    Simulates the Buy and Hold strategy.

    Args:
        df (pd.DataFrame): DataFrame with daily asset prices.
        initial_balance (int): The starting capital.

    Returns:
        pd.Series: A Series containing the portfolio value for each day.
    """
    print("--- Simulating Buy and Hold ---")
    n_assets = len(df.columns)

    # Invest an equal amount in each asset at the beginning
    initial_investment_per_asset = initial_balance / n_assets

    # Get the initial prices
    initial_prices = df.iloc[0]

    # Calculate the number of shares bought for each asset
    shares = initial_investment_per_asset / initial_prices

    # Calculate the portfolio value for each day
    portfolio_values = df.dot(shares)

    print(f"Initial Investment: ${initial_balance:.2f}")
    print(f"Final Portfolio Value: ${portfolio_values.iloc[-1]:.2f}")

    return portfolio_values

def equally_weighted_rebalanced(df, initial_balance=10000, rebalance_freq='M', transaction_cost_pct=0.001):
    """
    Simulates an Equally Weighted Portfolio with periodic rebalancing.

    Args:
        df (pd.DataFrame): DataFrame with daily asset prices.
        initial_balance (int): The starting capital.
        rebalance_freq (str): The rebalancing frequency ('M' for monthly, 'Q' for quarterly).
        transaction_cost_pct (float): The transaction cost as a percentage.

    Returns:
        pd.Series: A Series containing the portfolio value for each day.
    """
    print(f"\n--- Simulating Equally Weighted Portfolio (Rebalanced {rebalance_freq}) ---")
    n_assets = len(df.columns)

    # Set the initial weights to be equal
    weights = np.full(n_assets, 1 / n_assets)

    portfolio_value = initial_balance
    portfolio_values = pd.Series(index=df.index, dtype=float)

    last_rebalance_date = None

    for date, prices in df.iterrows():
        # Store the portfolio value for the day before any changes
        portfolio_values[date] = portfolio_value

        # Rebalance on the first trading day of each new period (month or quarter)
        is_new_period = (
            last_rebalance_date is None
            or (rebalance_freq == 'M' and date.month != last_rebalance_date.month)
            or (rebalance_freq == 'Q' and (date.month - 1) // 3 != (last_rebalance_date.month - 1) // 3)
        )
        if is_new_period:
            # Calculate the value of trades needed to return to equal weights
            target_asset_values = portfolio_value / n_assets
            current_asset_values = weights * portfolio_value
            trades = target_asset_values - current_asset_values

            # Apply transaction costs
            transaction_costs = np.sum(np.abs(trades)) * transaction_cost_pct
            portfolio_value -= transaction_costs

            # Reset weights to be equal
            weights = np.full(n_assets, 1 / n_assets)
            last_rebalance_date = date

        # Advance the portfolio to the next trading day using its price changes
        next_day_index = df.index.get_loc(date) + 1
        if next_day_index < len(df):
            next_day_prices = df.iloc[next_day_index]
            price_change_ratio = next_day_prices / prices

            # Apply the price change once, then derive the new total and drifted weights
            new_asset_values = (weights * portfolio_value) * price_change_ratio
            portfolio_value = np.sum(new_asset_values)
            weights = new_asset_values / portfolio_value

    print(f"Initial Investment: ${initial_balance:.2f}")
    print(f"Final Portfolio Value: ${portfolio_values.iloc[-1]:.2f}")

    return portfolio_values.dropna()


def main():
    # Load the test data
    test_df = pd.read_csv('data/test.csv', index_col='Date', parse_dates=True)

    # --- Run Baseline Strategies ---
    bnh_values = buy_and_hold(test_df)
    ewp_values = equally_weighted_rebalanced(test_df)

    # --- Plot the results ---
    plt.style.use('seaborn-v0_8-darkgrid')
    fig, ax = plt.subplots(figsize=(14, 8))

    ax.plot(bnh_values.index, bnh_values, label='Buy and Hold', color='blue', linewidth=2)
    ax.plot(ewp_values.index, ewp_values, label='Equally Weighted (Rebalanced Monthly)', color='green', linewidth=2)

    ax.set_title('Baseline Strategy Performance (2021-2023)', fontsize=16)
    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel('Portfolio Value ($)', fontsize=12)
    ax.legend(fontsize=12)

    # Format the y-axis to show currency
    from matplotlib.ticker import FuncFormatter
    formatter = FuncFormatter(lambda x, p: f'${x:,.0f}')
    ax.yaxis.set_major_formatter(formatter)

    plt.tight_layout()
    plt.savefig('baseline_performance.png')
    plt.show()

if __name__ == '__main__':
    main()
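On a tiny two-asset price table, `buy_and_hold` reduces to a single dot product of prices against a fixed share vector. The prices below are made up so the arithmetic is easy to follow by hand:

```python
import pandas as pd

# Hypothetical daily close prices for two assets
prices = pd.DataFrame({
    "A": [10.0, 11.0, 12.0],   # +20% over the period
    "B": [20.0, 20.0, 30.0],   # +50% over the period
})

initial_balance = 10000
# $5,000 into each asset at the first day's prices: 500 shares of A, 250 of B
shares = (initial_balance / len(prices.columns)) / prices.iloc[0]

# Daily portfolio value is simply prices · shares
portfolio_values = prices.dot(shares)
print(portfolio_values.iloc[-1])  # 500*12 + 250*30 = 13500.0
```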
scripts/fetch_data.py
ADDED
|
@@ -0,0 +1,75 @@
import yfinance as yf
import pandas as pd
import os

# --- Configuration ---
# Asset tickers
TICKERS = ["AAPL", "MSFT", "SPY", "TLT", "BTC-USD"]
# Time periods for training and testing
TRAIN_START_DATE = "2015-01-01"
TRAIN_END_DATE = "2020-12-31"
TEST_START_DATE = "2021-01-01"
TEST_END_DATE = "2023-12-31"

# Directory to save the data
DATA_DIR = "data"
TRAIN_DATA_PATH = os.path.join(DATA_DIR, "train.csv")
TEST_DATA_PATH = os.path.join(DATA_DIR, "test.csv")

# --- Data Fetching and Processing ---

def fetch_and_prepare_data(start_date, end_date, tickers):
    """
    Fetches historical data for the given tickers and processes it.
    Returns a DataFrame with 'Close' prices for each ticker.
    """
    print(f"Fetching data from {start_date} to {end_date} for {tickers}...")

    data = yf.download(tickers, start=start_date, end=end_date)

    # Use .copy() to explicitly create a new DataFrame and avoid warnings.
    close_data = data['Close'].copy()

    print("\nData Head:")
    print(close_data.head())

    print("\nMissing values before cleaning:")
    print(close_data.isnull().sum())

    # All inplace operations are safely performed on our own copy.
    close_data.ffill(inplace=True)
    close_data.bfill(inplace=True)

    print("\nMissing values after cleaning:")
    print(close_data.isnull().sum())

    for col in close_data.columns:
        close_data[col] = pd.to_numeric(close_data[col], errors='coerce')

    close_data.dropna(inplace=True)

    return close_data

def main():
    """Main function to run the data fetching process."""
    # Create data directory if it doesn't exist
    if not os.path.exists(DATA_DIR):
        os.makedirs(DATA_DIR)
        print(f"Created directory: {DATA_DIR}")

    # Fetch, process, and save training data
    print("--- Preparing Training Data ---")
    train_data = fetch_and_prepare_data(TRAIN_START_DATE, TRAIN_END_DATE, TICKERS)
    train_data.to_csv(TRAIN_DATA_PATH)
    print(f"Training data saved to {TRAIN_DATA_PATH}")

    print("\n" + "="*50 + "\n")

    # Fetch, process, and save testing data
    print("--- Preparing Testing Data ---")
    test_data = fetch_and_prepare_data(TEST_START_DATE, TEST_END_DATE, TICKERS)
    test_data.to_csv(TEST_DATA_PATH)
    print(f"Testing data saved to {TEST_DATA_PATH}")

if __name__ == "__main__":
    main()
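The ffill-then-bfill cleaning step above behaves like this on a toy frame with gaps at both ends (the values are synthetic):

```python
import numpy as np
import pandas as pd

raw = pd.DataFrame({
    "SPY": [np.nan, 101.0, np.nan, 103.0],   # leading gap + interior gap
    "TLT": [50.0, np.nan, 52.0, np.nan],     # interior gap + trailing gap
})

# Forward-fill handles interior/trailing gaps; back-fill then handles leading gaps
clean = raw.ffill().bfill()
print(clean["SPY"].tolist())  # [101.0, 101.0, 101.0, 103.0]
print(clean["TLT"].tolist())  # [50.0, 50.0, 52.0, 52.0]
```

The order matters: running `bfill` first would pull future prices backward across interior gaps instead of carrying the last known price forward.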
scripts/fetch_market_data.py
ADDED
|
@@ -0,0 +1,78 @@
import argparse
import os
import pandas as pd
import yfinance as yf
from datetime import date

def fetch_data(start_date, end_date, output_filename):
    """
    Fetches, cleans, and saves historical market data for a given date range.

    Args:
        start_date (str): The start date for the data in 'YYYY-MM-DD' format.
        end_date (str): The end date for the data in 'YYYY-MM-DD' format.
        output_filename (str): The path and name of the file to save the data.
    """
    print(f"--- Fetching data from {start_date} to {end_date} ---")

    # Define the base list of tickers
    tickers = ["AAPL", "MSFT", "SPY", "TLT", "BTC-USD"]

    # Remove Bitcoin if the period predates its price history (e.g., before 2013)
    if pd.to_datetime(start_date).year < 2013:
        print("Note: Bitcoin (BTC-USD) did not exist for the requested period and will be excluded.")
        tickers.remove("BTC-USD")

    # Download data from Yahoo Finance
    data = yf.download(tickers, start=start_date, end=end_date)
    close_data = data['Close'].copy()

    # Data Cleaning
    print(f"\nMissing values before cleaning:\n{close_data.isnull().sum()}")
    close_data.ffill(inplace=True)
    close_data.bfill(inplace=True)

    # Drop any columns that are still all NaN (like BTC in the 2008 data)
    close_data.dropna(axis=1, how='all', inplace=True)

    print(f"\nMissing values after cleaning:\n{close_data.isnull().sum()}")

    # Ensure the output directory exists
    output_dir = os.path.dirname(output_filename)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Save to CSV
    close_data.to_csv(output_filename)
    print(f"\n✅ Data successfully saved to {output_filename}")


if __name__ == "__main__":
    # Set up command-line argument parsing
    parser = argparse.ArgumentParser(description="Fetch historical market data for specified periods.")

    parser.add_argument(
        "--start",
        type=str,
        default="2018-01-01",
        help="Start date in YYYY-MM-DD format. Default is for the 2018 stress test."
    )
    parser.add_argument(
        "--end",
        type=str,
        default="2019-12-31",
        help="End date in YYYY-MM-DD format. Default is for the 2018 stress test."
    )
    parser.add_argument(
        "--filename",
        type=str,
        default="data/stress_test_2018.csv",
        help="Output file name (e.g., 'data/my_data.csv')."
    )

    args = parser.parse_args()

    # Use 'today' as the end date if specified
    end_date = date.today().strftime('%Y-%m-%d') if args.end.lower() == 'today' else args.end

    fetch_data(start_date=args.start, end_date=end_date, output_filename=args.filename)
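The ticker-exclusion branch above can be exercised without hitting Yahoo Finance at all, since the cutoff year only gates a `list.remove`. The helper below is a hypothetical stand-in that mirrors that logic:

```python
import pandas as pd

def select_tickers(start_date, tickers=None):
    """Drop BTC-USD when the requested window starts before 2013 (mirrors fetch_data)."""
    tickers = list(tickers or ["AAPL", "MSFT", "SPY", "TLT", "BTC-USD"])
    if pd.to_datetime(start_date).year < 2013:
        tickers.remove("BTC-USD")
    return tickers

print(select_tickers("2008-01-01"))  # ['AAPL', 'MSFT', 'SPY', 'TLT']
print(select_tickers("2018-01-01"))  # BTC-USD is kept
```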
scripts/stress_test.py
ADDED
|
@@ -0,0 +1,142 @@
import argparse
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter

# Import all agent classes and the environment
from stable_baselines3 import PPO, SAC, TD3
from environment import PortfolioEnv

# --- Helper Functions ---
def evaluate_agent(env, model):
    """Runs a trained agent on a given environment."""
    obs, info = env.reset()
    terminated, truncated = False, False
    portfolio_values = [env.initial_balance]
    while not (terminated or truncated):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        portfolio_values.append(info['portfolio_value'])
    return pd.Series(portfolio_values, index=env.df.index[:len(portfolio_values)])

def buy_and_hold(df, initial_balance=10000):
    """Simulates the Buy and Hold strategy."""
    n_assets = len(df.columns)
    initial_investment_per_asset = initial_balance / n_assets
    initial_prices = df.iloc[0]
    shares = initial_investment_per_asset / initial_prices
    portfolio_values = df.dot(shares)
    return portfolio_values

def calculate_metrics(portfolio_values):
    """Calculates performance metrics from a portfolio value series."""
    total_return = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) - 1
    num_years = (portfolio_values.index[-1] - portfolio_values.index[0]).days / 365.25
    cagr = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) ** (1 / num_years) - 1 if num_years > 0 else 0
    daily_returns = portfolio_values.pct_change().dropna()
    sharpe_ratio = np.sqrt(252) * (daily_returns.mean() / daily_returns.std()) if daily_returns.std() != 0 else 0
    rolling_max = portfolio_values.cummax()
    daily_drawdown = portfolio_values / rolling_max - 1.0
    max_drawdown = daily_drawdown.min()
    return {
        "Total Return": f"{total_return:.2%}", "CAGR": f"{cagr:.2%}",
        "Sharpe Ratio": f"{sharpe_ratio:.2f}", "Max Drawdown": f"{max_drawdown:.2%}"
    }

# --- Main Stress Test Function ---
def run_stress_test(datafile_path, ppo_path, sac_path, td3_path, output_path):
    """
    Loads data and models, runs evaluations, and plots the comparison.
    """
    print(f"--- Running Stress Test on {datafile_path} ---")

    # 1. Load Data
    try:
        test_df = pd.read_csv(datafile_path, index_col='Date', parse_dates=True)
    except FileNotFoundError:
        print(f"❌ Error: Data file not found at {datafile_path}")
        return

    # Check for asset mismatch (e.g., 4 assets in 2008 data vs 5-asset models)
    # The standard models were trained on 5 assets (e.g., observation shape = 30 * 5 = 150)
    expected_assets = 5
    if test_df.shape[1] != expected_assets:
        print(f"⚠️ Warning: Models were trained on {expected_assets} assets, but this dataset has {test_df.shape[1]}.")
        print("Skipping agent evaluation for this dataset.")
        return

    # 2. Define Models to Evaluate
    models_to_evaluate = {
        "PPO Agent": (PPO, ppo_path),
        "SAC Agent": (SAC, sac_path),
        "TD3 Agent": (TD3, td3_path)
    }

    portfolio_values = {}
    metrics = {}

    # 3. Run Evaluations
    for name, (agent_type, model_path) in models_to_evaluate.items():
        if os.path.exists(model_path):
            print(f"--- Evaluating {name} ---")
            model = agent_type.load(model_path)
            env = PortfolioEnv(test_df)
            portfolio_values[name] = evaluate_agent(env, model)
            metrics[name] = calculate_metrics(portfolio_values[name])
        else:
            print(f"⚠️ Warning: Model file not found at {model_path}. Skipping.")

    # Evaluate Buy and Hold Baseline
    print("\n--- Evaluating Buy and Hold Baseline ---")
    bnh_values = buy_and_hold(test_df)
    portfolio_values["Buy and Hold"] = bnh_values
    metrics["Buy and Hold"] = calculate_metrics(bnh_values)

    # 4. Display Results
    print("\n--- Stress Test Performance Metrics ---")
    metrics_df = pd.DataFrame(metrics)
    print(metrics_df)

    # 5. Plotting
    plt.style.use('seaborn-v0_8-darkgrid')
    fig, ax = plt.subplots(figsize=(14, 8))

    colors = {"PPO Agent": "red", "SAC Agent": "green", "TD3 Agent": "orange", "Buy and Hold": "blue"}
    for name, values in portfolio_values.items():
        ax.plot(values.index, values, label=name, color=colors.get(name, 'black'), linewidth=2)

    plot_title = f"Agent Stress Test: {os.path.basename(datafile_path).replace('.csv', '')}"
    ax.set_title(plot_title, fontsize=16)
    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel('Portfolio Value ($)', fontsize=12)
    ax.legend(fontsize=12)

    formatter = FuncFormatter(lambda x, p: f'${x:,.0f}')
    ax.yaxis.set_major_formatter(formatter)

    plt.tight_layout()
    plt.savefig(output_path)
    print(f"\n✅ Plot saved to {output_path}")
    plt.show()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Run a stress test on trained RL portfolio agents.")

    parser.add_argument("--datafile", type=str, default="data/stress_test_2018.csv", help="Path to the market data CSV file for the test.")
    parser.add_argument("--ppopath", type=str, default="checkpoints/ppo_portfolio_model.zip", help="Path to the trained PPO model.")
    parser.add_argument("--sacpath", type=str, default="checkpoints/sac_portfolio_model.zip", help="Path to the trained SAC model.")
    parser.add_argument("--td3path", type=str, default="checkpoints/td3_portfolio_model.zip", help="Path to the trained TD3 model.")
    parser.add_argument("--output", type=str, default="results/stress_test_comparison.png", help="Path to save the output plot.")

    args = parser.parse_args()

    run_stress_test(
        datafile_path=args.datafile,
        ppo_path=args.ppopath,
        sac_path=args.sacpath,
        td3_path=args.td3path,
        output_path=args.output
    )
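The Gymnasium-style loop in `evaluate_agent` (reset, then step until `terminated` or `truncated`) can be checked in isolation with stub classes. Both `StubEnv` and `StubModel` below are invented stand-ins for `PortfolioEnv` and a stable-baselines3 model, kept just detailed enough to satisfy the same interface:

```python
import pandas as pd

class StubEnv:
    """Minimal env exposing the reset/step contract evaluate_agent relies on."""
    def __init__(self, df, initial_balance=10000):
        self.df = df
        self.initial_balance = initial_balance
        self._t = 0

    def reset(self):
        self._t = 0
        return None, {"portfolio_value": self.initial_balance}

    def step(self, action):
        self._t += 1
        terminated = self._t >= len(self.df) - 1
        value = self.initial_balance * (1.01 ** self._t)  # deterministic 1% growth per step
        return None, 0.0, terminated, False, {"portfolio_value": value}

class StubModel:
    def predict(self, obs, deterministic=True):
        return 0, None

def evaluate_agent(env, model):
    obs, info = env.reset()
    terminated, truncated = False, False
    portfolio_values = [env.initial_balance]
    while not (terminated or truncated):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        portfolio_values.append(info["portfolio_value"])
    return pd.Series(portfolio_values, index=env.df.index[:len(portfolio_values)])

df = pd.DataFrame({"A": [1.0, 1.0, 1.0, 1.0]})
series = evaluate_agent(StubEnv(df), StubModel())
print(len(series))  # 4: the initial balance plus three steps
```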
scripts/train.py
ADDED
|
@@ -0,0 +1,77 @@
import argparse
import os
import pandas as pd
from stable_baselines3 import PPO, SAC, TD3
from environment import PortfolioEnv

def train_agent(agent_name="td3", timesteps=100000):
    """
    Main function to train a specified RL agent.

    Args:
        agent_name (str): The RL algorithm to use ('ppo', 'sac', or 'td3').
        timesteps (int): The total number of timesteps for training.
    """
    # 1. Map agent names to their corresponding classes
    AGENT_CLASSES = {
        "ppo": PPO,
        "sac": SAC,
        "td3": TD3
    }
    agent_class = AGENT_CLASSES.get(agent_name.lower())
    if agent_class is None:
        raise ValueError(f"Unknown agent: {agent_name}. Choose from {list(AGENT_CLASSES.keys())}")

    model_name = agent_name.lower()

    # 2. Load data and create the environment
    print("--- Loading Data and Creating Environment ---")
    try:
        df = pd.read_csv('data/train.csv', index_col='Date', parse_dates=True)
        env = PortfolioEnv(df)
        print("Environment created successfully.")
    except FileNotFoundError:
        print("❌ Error: 'data/train.csv' not found. Make sure to run a data fetching script first.")
        return

    # 3. Create the RL Agent
    print(f"--- Creating {agent_name.upper()} Agent ---")
    model = agent_class(
        "MlpPolicy",
        env,
        verbose=1,
        tensorboard_log="./tensorboard_logs/"
    )

    # 4. Train the Agent
    print(f"--- Starting Agent Training for {timesteps} timesteps ---")
    model.learn(total_timesteps=timesteps)
    print("--- Agent Training Complete ---")

    # 5. Save the Trained Model
    os.makedirs("checkpoints", exist_ok=True)
    save_path = f"checkpoints/{model_name}_portfolio_model"
    model.save(save_path)
    print(f"✅ Model saved to {save_path}.zip")


if __name__ == "__main__":
    # 6. Set up command-line argument parsing
    parser = argparse.ArgumentParser(description="Train a Reinforcement Learning agent for portfolio management.")

    parser.add_argument(
        "--agent",
        type=str,
        default="td3",
        choices=["ppo", "sac", "td3"],
        help="The RL algorithm to use for training (default: td3)."
    )
    parser.add_argument(
        "--timesteps",
        type=int,
        default=100000,
        help="The total number of timesteps for training (default: 100000)."
    )

    args = parser.parse_args()

    # Call the main training function with the parsed arguments
    train_agent(agent_name=args.agent, timesteps=args.timesteps)
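The name-to-class dispatch used in `train_agent` is a small pattern worth isolating. The sketch below reproduces it with placeholder classes instead of the real stable-baselines3 algorithms:

```python
# Placeholder classes standing in for stable_baselines3.PPO / SAC / TD3
class PPO: ...
class SAC: ...
class TD3: ...

AGENT_CLASSES = {"ppo": PPO, "sac": SAC, "td3": TD3}

def resolve_agent(name):
    """Look up an agent class by (case-insensitive) name, as train_agent does."""
    agent_class = AGENT_CLASSES.get(name.lower())
    if agent_class is None:
        raise ValueError(f"Unknown agent: {name}. Choose from {list(AGENT_CLASSES)}")
    return agent_class

print(resolve_agent("TD3").__name__)  # TD3
```

Using `dict.get` plus an explicit `ValueError` gives a clearer failure message than letting a bare `KeyError` escape for a typo like `--agent ddpg`.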
scripts/visualize_strategy.py
ADDED
|
@@ -0,0 +1,123 @@
import argparse
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from stable_baselines3 import PPO, SAC, TD3
from environment import PortfolioEnv

def visualize_strategy(agent_name, checkpoint_path, datafile_path, output_path):
    """
    Loads a trained agent, runs a simulation, and plots its portfolio allocation strategy.

    Args:
        agent_name (str): The type of agent to load ('ppo', 'sac', 'td3').
        checkpoint_path (str): The path to the saved model checkpoint file (.zip).
        datafile_path (str): The path to the CSV market data for the simulation.
        output_path (str): The path to save the output plot image.
    """
    print(f"--- Visualizing strategy for {agent_name.upper()} agent ---")

    # 1. Define a mapping from agent names to their classes
    AGENT_CLASSES = {
        "ppo": PPO,
        "sac": SAC,
        "td3": TD3
    }
    agent_class = AGENT_CLASSES[agent_name.lower()]

    # 2. Load Data and Model
    try:
        test_df = pd.read_csv(datafile_path, index_col='Date', parse_dates=True)
        model = agent_class.load(checkpoint_path)
    except FileNotFoundError as e:
        print(f"❌ Error: Could not find a required file. {e}")
        return
    except Exception as e:
        print(f"❌ An error occurred: {e}")
        return

    # 3. Create Environment and Run Simulation
    env = PortfolioEnv(test_df)
    obs, info = env.reset()
    terminated, truncated = False, False

    weights_history = [info['weights']]
    while not (terminated or truncated):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        weights_history.append(info['weights'])
    print("✅ Simulation complete.")

    # 4. Prepare Data for Plotting
    weights_df = pd.DataFrame(weights_history)
    asset_names = test_df.columns.tolist() + ['Cash']
    weights_df.columns = asset_names
    weights_df.index = test_df.index[:len(weights_df)]

    # 5. Plotting the Stacked Area Chart
    print("📊 Generating plot...")
    plt.style.use('seaborn-v0_8-darkgrid')
    fig, ax = plt.subplots(figsize=(15, 8))

    ax.stackplot(weights_df.index, weights_df.T, labels=weights_df.columns, alpha=0.8)

    ax.set_title(f'Agent Portfolio Allocation Over Time ({agent_name.upper()})', fontsize=16)
    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel('Portfolio Allocation (%)', fontsize=12)
    ax.legend(loc='upper left', fontsize=10)

    formatter = FuncFormatter(lambda y, p: f'{y:.0%}')
    ax.yaxis.set_major_formatter(formatter)

    plt.tight_layout()

    # Ensure output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    plt.savefig(output_path)
    print(f"✅ Plot saved to {output_path}")
    plt.show()


if __name__ == "__main__":
    # Set up command-line argument parsing
    parser = argparse.ArgumentParser(description="Visualize a trained RL agent's portfolio allocation strategy.")

    parser.add_argument(
        "--agent",
        type=str,
        required=True,
        choices=["ppo", "sac", "td3"],
        help="The RL algorithm of the trained agent."
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the saved model checkpoint .zip file (e.g., 'td3_portfolio_model.zip')."
    )
    parser.add_argument(
        "--datafile",
        type=str,
        default="data/test.csv",
        help="Path to the market data CSV file to run the simulation on."
    )
    parser.add_argument(
        "--output",
        type=str,
        default="results/agent_allocation.png",
        help="Path to save the output plot image."
    )

    args = parser.parse_args()

    visualize_strategy(
        agent_name=args.agent,
        checkpoint_path=args.checkpoint,
        datafile_path=args.datafile,
        output_path=args.output
    )
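The "Prepare Data for Plotting" step in visualize_strategy turns a list of per-step weight vectors into a dated DataFrame. The sketch below reproduces that step with hypothetical allocations for two assets plus cash; each row should sum to 1:

```python
import pandas as pd

# Hypothetical per-step allocations recorded during a 3-step simulation
weights_history = [
    [0.5, 0.3, 0.2],
    [0.4, 0.4, 0.2],
    [0.6, 0.2, 0.2],
]
# A price index longer than the simulation, to show the truncation
dates = pd.date_range("2021-01-04", periods=5, freq="B")

weights_df = pd.DataFrame(weights_history, columns=["AAPL", "MSFT", "Cash"])
# Align with the price index, truncated to the number of simulated steps
weights_df.index = dates[:len(weights_df)]

print(weights_df.shape)                 # (3, 3)
print(weights_df.sum(axis=1).tolist())  # each row sums to 1.0
```

Rows summing to 1 is what makes the stacked area chart fill the full 0-100% band at every date.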