DanielKiani committed on
Commit
7d2e753
·
0 Parent(s):

Initial commit

.gitattributes ADDED
@@ -0,0 +1 @@
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,7 @@
+ __pycache__/
+ *.pyc
+ venv/
+ .venv/
+ .vscode/
+ .idea/
+
README.md ADDED
@@ -0,0 +1,215 @@
+ ![Banner](assets/banner.png)
+ [![Python](https://img.shields.io/badge/Python-3.12.11-blue?logo=python)](https://www.python.org/) [![PyTorch](https://img.shields.io/badge/PyTorch-2.8-EE4C2C?logo=pytorch)](https://pytorch.org/) ![Made with ML](https://img.shields.io/badge/Made%20with-ML-blueviolet?logo=openai) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
+
+ # 🤖 Portfolio Optimization with Deep Reinforcement Learning
+
+ This project explores the use of Deep Reinforcement Learning to train autonomous agents for financial portfolio management. The goal was not just to create a single profitable agent, but to conduct a comparative study of three RL algorithms (PPO, SAC, TD3) to understand their emergent trading strategies and their robustness across varied market conditions.
+
+ **The key finding: a TD3-based agent learned a superior, risk-managed static asset allocation that consistently outperformed both active trading strategies and aggressive growth models, especially during market downturns.**
+
+ ---
+
+ ## 📜 Table of Contents
+
+ 1. [📊 The Data & Asset Selection](#-the-data--asset-selection)
+ 2. [🎯 Benchmarking Against Baselines](#-benchmarking-against-baselines)
+ 3. [🏆 Key Findings & The Champion Agent](#-key-findings--the-champion-agent)
+ 4. [🧠 Comparative Analysis of Agent Strategies](#-comparative-analysis-of-agent-strategies)
+     * [🥇 TD3: The Prudent Risk-Manager](#-td3-the-prudent-risk-manager)
+     * [🚀 SAC: The Aggressive Growth Engine](#-sac-the-aggressive-growth-engine)
+     * [📈 PPO: The Active (but Inconsistent) Trader](#-ppo-the-active-but-inconsistent-trader)
+ 5. [🌪️ Stress Testing: The Ultimate Test of Robustness](#️-stress-testing-the-ultimate-test-of-robustness)
+ 6. [🔬 The Research Journey: Why Simplicity Won](#-the-research-journey-why-simplicity-won)
+ 7. [✅ Conclusion](#-conclusion)
+ 8. [📂 Project Structure](#-project-structure)
+ 9. [🚀 How to Run](#-how-to-run)
+     * [Setup](#setup)
+     * [Data Fetching](#data-fetching)
+     * [Training](#training)
+     * [Evaluation & Visualization](#evaluation--visualization)
+
+ ---
+
+ ## 📊 The Data & Asset Selection
+
+ The foundation of any financial machine learning project is its data. This project uses daily closing prices sourced from **Yahoo Finance** via the `yfinance` library. The primary training period was **2015-2020**, with out-of-sample testing on **2021-2023** and additional periods reserved for stress testing.
+
+ Asset selection was crucial for creating a realistic decision-making environment. The portfolio consists of five assets, chosen to represent different classes and risk profiles:
+
+ * **Growth Equities (AAPL, MSFT):** Represent the high-growth, high-volatility technology sector.
+ * **Market Index (SPY):** An ETF tracking the S&P 500, representing the broader US stock market.
+ * **Safe Haven (TLT):** An ETF holding 20+ Year US Treasury Bonds, which often acts as a "risk-off" asset during stock market downturns.
+ * **Alternative Asset (BTC-USD):** A non-traditional, extremely volatile asset class with high potential returns.
+
+ This diverse mix forces the agent to learn not just about individual assets, but also about their correlations and how to balance risk across different economic regimes.
+
+ ---
+
+ ## 🎯 Benchmarking Against Baselines
+
+ To show that a reinforcement learning agent is truly adding value, its performance must be measured against simple, standard strategies; an agent is only successful if it beats a naive approach.
+
+ Our primary benchmark was the **Buy and Hold** strategy, in which an equal amount of capital is invested in each asset at the beginning of the period and never touched again. The goal for every trained RL agent was to outperform this baseline, especially on a **risk-adjusted basis** (higher Sharpe Ratio, lower Max Drawdown).
+
+ The chart below shows the performance of a simple Buy and Hold strategy during the 2021-2023 test period, setting a clear target for our agents to beat.
+
+ ![Baseline Performance Chart](results/baseline_results.png)
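The Buy and Hold baseline is simulated in `scripts/evaluate_baselines.py`; its core reduces to a few lines of pandas, shown here on a tiny made-up price table:

```python
import pandas as pd

def buy_and_hold(df: pd.DataFrame, initial_balance: float = 10000) -> pd.Series:
    """Split the capital equally across assets on day one, then never trade again."""
    shares = (initial_balance / len(df.columns)) / df.iloc[0]  # shares bought per asset
    return df.dot(shares)                                      # portfolio value per day

# Toy two-asset price table: asset A rallies, asset B fades.
prices = pd.DataFrame({"A": [10.0, 11.0, 12.0], "B": [20.0, 20.0, 18.0]})
values = buy_and_hold(prices, initial_balance=1000)
print(values.tolist())  # [1000.0, 1050.0, 1050.0]
```

Because the share counts are fixed at day one, the daily portfolio value is just a dot product of prices and shares.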
+
+ ---
+
+ ## 🏆 Key Findings & The Champion Agent
+
+ After extensive training, evaluation, and stress testing, the **TD3 agent emerged as the clear winner** on a risk-adjusted basis. While other agents achieved higher raw returns, their strategies proved brittle and dangerously volatile during market crises. The TD3 agent's strategy was the most robust and reliable.
+
+ #### Final Performance Comparison (2021-2023)
+
+ This table summarizes the performance of the top-performing static agents against the baseline.
+
+ | Metric | **TD3 Agent** | SAC Agent | Buy & Hold |
+ | :--- | :--- | :--- | :--- |
+ | **Total Return** | 47.24% | **50.89%** | 34.91% |
+ | **CAGR** | 13.76% | **14.70%** | 10.50% |
+ | **Sharpe Ratio** | **0.62** | 0.51 | 0.45 |
+ | **Max Drawdown** | **-28.41%** | -44.61% | -40.81% |
+
+ The TD3 agent delivered strong returns while significantly reducing the maximum drawdown, proving its superior capital-preservation strategy.
+
+ ![Main Performance Chart](results/final_performance_comparison_all_agents.png)
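The Sharpe Ratio and Max Drawdown in the table come from the daily portfolio value series; a condensed sketch of the calculation in `scripts/evaluate.py` (the toy numbers are illustrative only):

```python
import numpy as np
import pandas as pd

def sharpe_and_drawdown(values: pd.Series, freq: int = 252):
    """Annualized Sharpe ratio (risk-free rate assumed 0) and maximum drawdown."""
    returns = values.pct_change().dropna()
    sharpe = np.sqrt(freq) * returns.mean() / returns.std()
    drawdown = values / values.cummax() - 1.0  # distance from the running peak
    return sharpe, drawdown.min()

values = pd.Series([100.0, 110.0, 99.0, 120.0])
sharpe, mdd = sharpe_and_drawdown(values)
print(round(mdd, 2))  # -0.1 (the dip from 110 to 99)
```

Max drawdown measures the worst peak-to-trough loss, which is why it captures the risk that raw returns hide.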
+
+ ---
+
+ ## 🧠 Comparative Analysis of Agent Strategies
+
+ A fascinating outcome of this project was watching three RL algorithms independently discover three distinct, recognizable investment philosophies.
+
+ ### 🥇 TD3: The Prudent Risk-Manager
+
+ The TD3 agent concluded that the most effective strategy was not to trade frequently, but to find one **superior, risk-managed static asset allocation** and hold it.
+
+ * **Strategy:** A "smarter Buy and Hold."
+ * **Behavior:** The agent's allocation is completely static, indicating it focused on the initial strategic decision and ignored market noise to minimize transaction costs.
+ * **Result:** This approach produced the best risk-adjusted returns, showing that a robust initial setup can be more valuable than reactive trading.
+
+ ![TD3 Allocation Chart](results/td3_portfolio_alocation.png)
+
+ ### 🚀 SAC: The Aggressive Growth Engine
+
+ The SAC agent also learned a static allocation strategy, but its portfolio was geared for **maximum growth**, accepting higher risk for higher potential returns.
+
+ * **Strategy:** High-risk, high-return static allocation.
+ * **Behavior:** Like TD3, it made one initial allocation and held firm, but that allocation was far more aggressive.
+ * **Result:** It achieved the highest total return in some periods but suffered catastrophic drawdowns in stress tests, making its strategy unreliable and brittle.
+
+ ![SAC Allocation Chart](results/sac_portfolio_alocation.png)
+
+ ### 📈 PPO: The Active (but Inconsistent) Trader
+
+ Unlike the other two, the PPO agent learned an **active, dynamic trading strategy**, continually adjusting its portfolio to market conditions.
+
+ * **Strategy:** Tactical asset allocation.
+ * **Behavior:** The allocation chart clearly shows the agent rebalancing over time, for example by increasing its bond (TLT) holdings during the 2022 downturn.
+ * **Result:** While it is impressive that the agent learned this behavior at all, its performance was inconsistent: it succeeded in some periods (2018) but failed in others (2025), highlighting the immense difficulty of successful market timing.
+
+ ![PPO Allocation Chart](results/ppo_portfolio_alocation.png)
+
+ ---
+
+ ## 🌪️ Stress Testing: The Ultimate Test of Robustness
+
+ A model is only as good as its performance during a crisis. We subjected the agents to multiple out-of-sample stress tests, with the 2018 period (featuring a crypto winter and a stock market flash crash) being the most revealing.
+
+ ![2018 Stress Test Chart](results/stress_test_comparison_2018.png)
+
+ * **TD3's Triumph:** The orange line shows the TD3 agent navigating the downturn and preserving capital far better than the baseline.
+ * **SAC's Failure:** The green line shows the SAC agent's aggressive strategy failing catastrophically, resulting in a massive drawdown.
+
+ This test demonstrated that the **TD3 agent's risk-managed approach was truly robust**, while the SAC agent's strategy was fragile.
+
+ ---
+
+ ## 🔬 The Research Journey: Why Simplicity Won
+
+ This project was also an exercise in scientific methodology. We initially hypothesized that more complex models and features would yield better results.
+
+ * **Hypothesis 1: More features are better.** We added technical indicators (RSI, MACD) to the observation space. **Result:** Performance degraded; the indicators acted as noise and confused the agents.
+ * **Hypothesis 2: Models with memory are better.** We tested an LSTM-based agent (`RecurrentPPO`). **Result:** Performance degraded; the added complexity led to overfitting on the training data.
+ * **Hypothesis 3: Regularization is better.** We tested both L1 and L2 regularization. **Result:** Performance degraded.
+ * **Hypothesis 4: A longer observation window is better.** We increased the window from 30 to 60 days. **Result:** Performance degraded; a larger context window is not always beneficial and can act as additional noise for the model.
+
+ The conclusion was clear: for this problem, a simple and elegant model (a standard MLP fed with just normalized price data) was the most effective.
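The "normalized price data" here is simply each 30-day price window divided by its first row, as done in `environment.py`'s `_get_obs`; a minimal illustration on a 3-day, 2-asset window:

```python
import numpy as np

# A small window of raw prices: shape (window_size, n_assets).
window = np.array([[100.0, 50.0],
                   [110.0, 45.0],
                   [120.0, 55.0]])

normalized = window / window[0]  # each column now starts at 1.0
obs = normalized.flatten().astype(np.float32)  # the flattened vector the agent sees
print(normalized[0].tolist())  # [1.0, 1.0]
```

Normalizing this way lets the agent see relative price movement within the window rather than absolute price levels, so AAPL at $150 and BTC at $40,000 are on comparable scales.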
+
+ ---
+
+ ## ✅ Conclusion
+
+ This project demonstrates that Deep Reinforcement Learning can be a powerful tool for discovering sophisticated investment strategies. The key insight is that the most robust and successful agent did not learn to be a hyperactive trader, but rather a prudent strategic allocator, underscoring the timeless investment principle that effective risk management is the true key to long-term success.
+
+ ---
+
+ ## 📂 Project Structure
+
+ The codebase is organized into modular, reusable scripts.
+
+ ```bash
+ ├── assets/
+ ├── checkpoints/               # Holds all saved model .zip files
+ ├── results/                   # Holds all output plots and metrics
+ ├── scripts/
+ │   ├── check_env.py               # Verifies the custom env against the Gymnasium API
+ │   ├── environment.py             # The custom Gymnasium environment for the simulation
+ │   ├── fetch_market_data.py       # A flexible script to download data for any period
+ │   ├── train.py                   # The main training script with model selection
+ │   ├── evaluate.py                # The main evaluation script for generating metrics
+ │   ├── evaluate_baselines.py      # Simulates the Buy and Hold and equal-weight baselines
+ │   ├── stress_test.py             # Runs a full comparison of all agents on a given dataset
+ │   └── visualize_strategy.py      # Plots the asset allocation of a single trained agent
+ └── README.md                  # This file
+ ```
+
+ ---
+
+ ## 🚀 How to Run
+
+ ### Setup
+
+ 1. Clone the repository.
+ 2. Create and activate a Python virtual environment.
+ 3. Install the required packages:
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ### Data Fetching
+
+ Use the flexible `fetch_market_data.py` script to download data for any period.
+
+ ```bash
+ # Fetch the default training data (2015-2020)
+ python scripts/fetch_market_data.py --start 2015-01-01 --end 2020-12-31 --filename data/train.csv
+
+ # Fetch data for a stress test (e.g., 2022)
+ python scripts/fetch_market_data.py --start 2022-01-01 --end 2022-12-31 --filename data/test_2022.csv
+ ```
+
+ ### Training
+
+ Use the `train.py` script to train any of the three main agents.
+
+ ```bash
+ # Train the champion TD3 agent (default)
+ python scripts/train.py --agent td3
+
+ # Train a SAC agent for more timesteps
+ python scripts/train.py --agent sac --timesteps 100000
+ ```
+
+ ### Evaluation & Visualization
+
+ Use the dedicated scripts to analyze the results.
+
+ ```bash
+ # Run a full stress test on the 2018 data
+ python scripts/stress_test.py --datafile data/stress_test_2018.csv
+
+ # Visualize the TD3 agent's strategy
+ python scripts/visualize_strategy.py --agent td3 --checkpoint td3_portfolio_model.zip
+ ```
assets/banner.png ADDED

Git LFS Details

  • SHA256: 6f3455d5f88a8eb82affe16263753b2ee5cfaa6c6adf2e55bf0b650e8f4701ab
  • Pointer size: 132 Bytes
  • Size of remote file: 1.67 MB
requirements.txt ADDED
@@ -0,0 +1,19 @@
+ # Core RL and Simulation
+ stable-baselines3==2.7.0
+ sb3_contrib==2.7.0
+ gymnasium==1.2.1
+
+ # Data Handling and Numerics
+ pandas==2.3.3
+ numpy==2.2.6
+ scikit-learn==1.6.1
+
+ # Data Fetching
+ yfinance==0.2.66
+
+ # Financial Indicators
+ pandas-ta==0.4.71b0
+
+ # Plotting and Visualization
+ matplotlib==3.10.0
+ seaborn==0.13.2
results/baseline_results.png ADDED

Git LFS Details

  • SHA256: 8ade77274352ad37706e9bb7076b225bc784ddbbd51ada1bb4b6b983a0eb9cf2
  • Pointer size: 131 Bytes
  • Size of remote file: 161 kB
results/final_performance_comparison_all_agents.png ADDED

Git LFS Details

  • SHA256: 75b085d93cd947906c2f6f5fdf4f1fbc0b53cdb56d2ad3a77011b6cae8c787ab
  • Pointer size: 131 Bytes
  • Size of remote file: 225 kB
results/ppo_portfolio_alocation.png ADDED

Git LFS Details

  • SHA256: 3e280bd69972c63cd3ff6cb4c3e0ee80d7e8858c7afc0f251164052e3be6323a
  • Pointer size: 130 Bytes
  • Size of remote file: 86.2 kB
results/sac_portfolio_alocation.png ADDED

Git LFS Details

  • SHA256: 8462e8dffc0a9dfe6562f89c529ea763e45779fb25c6e3bf97ac4aa39d669454
  • Pointer size: 130 Bytes
  • Size of remote file: 39.2 kB
results/stress_test_comparison_2018.png ADDED

Git LFS Details

  • SHA256: 0240c233639499468ded8d767f7ba1f9b0dea256807db8dbc36bcf1136ceb731
  • Pointer size: 131 Bytes
  • Size of remote file: 197 kB
results/td3_portfolio_alocation.png ADDED

Git LFS Details

  • SHA256: 142858e04e50d46a9eeff9a60c3fdf699c5ff337de88164b17865d1fa900fdfd
  • Pointer size: 130 Bytes
  • Size of remote file: 38.1 kB
scripts/check_env.py ADDED
@@ -0,0 +1,32 @@
+ import pandas as pd
+ from stable_baselines3.common.env_checker import check_env
+ from environment import PortfolioEnv
+
+ def main():
+     """
+     Create the custom portfolio environment and check its API compliance.
+     """
+     print("--- Loading Data and Creating Environment ---")
+     try:
+         # Load the training data
+         df = pd.read_csv('data/train.csv', index_col='Date', parse_dates=True)
+         # Create an instance of the environment
+         env = PortfolioEnv(df)
+         print("Environment created successfully.")
+     except FileNotFoundError:
+         print("❌ Error: 'data/train.csv' not found. Make sure you've run the data fetching script.")
+         return
+
+     print("\n--- Checking Environment Compatibility ---")
+     try:
+         # check_env raises an error if the environment violates the Gymnasium API.
+         check_env(env)
+         print("✅ Environment check passed!")
+     except Exception:
+         print("❌ Environment check failed:")
+         # Print the full traceback for debugging complex errors.
+         import traceback
+         traceback.print_exc()
+
+ if __name__ == "__main__":
+     main()
scripts/environment.py ADDED
@@ -0,0 +1,174 @@
+ import gymnasium as gym
+ import numpy as np
+ import pandas as pd
+ from gymnasium import spaces
+
+ class PortfolioEnv(gym.Env):
+     """
+     A custom reinforcement learning environment for portfolio management.
+
+     This environment simulates the daily trading of multiple financial assets. The agent's
+     goal is to learn a policy for allocating capital to maximize risk-adjusted returns.
+     """
+     metadata = {'render_modes': ['human']}
+
+     def __init__(self, df, window_size=30, initial_balance=10000, transaction_cost_pct=0.001):
+         """
+         Initializes the portfolio management environment.
+
+         Args:
+             df (pd.DataFrame): A DataFrame containing the daily closing prices of the assets.
+                 The index should be dates and the columns should be asset tickers.
+             window_size (int): The number of past days of price data to include in the observation.
+             initial_balance (float): The starting capital for the portfolio.
+             transaction_cost_pct (float): The percentage cost for each trade (e.g., 0.001 for 0.1%).
+         """
+         super().__init__()
+
+         # --- Basic Environment Parameters ---
+         self.df = df
+         self.window_size = window_size
+         self.initial_balance = initial_balance
+         self.transaction_cost_pct = transaction_cost_pct
+         self.n_assets = len(df.columns)
+
+         # --- Action Space ---
+         # The agent outputs a vector of continuous values, one for each asset plus one for cash.
+         # These raw outputs are converted to portfolio weights via a softmax function.
+         # The space is defined from -1 to 1 for better compatibility with standard RL algorithms.
+         # Shape: (number of assets + 1 for cash)
+         self.action_space = spaces.Box(
+             low=-1, high=1, shape=(self.n_assets + 1,), dtype=np.float32
+         )
+
+         # --- Observation Space ---
+         # The agent observes a window of past price data, flattened into a 1D vector.
+         # Shape: (window_size * number of assets)
+         self.observation_space = spaces.Box(
+             low=-np.inf, high=np.inf,
+             shape=(self.window_size * self.n_assets,),
+             dtype=np.float32
+         )
+
+         # --- Internal State Variables ---
+         # These variables track the state of the simulation over time.
+         self._current_step = 0
+         self._portfolio_value = 0.0
+         # Weights for each asset + cash, e.g., [w_aapl, w_msft, ..., w_cash]
+         self._weights = np.zeros(self.n_assets + 1)
+
+     def reset(self, seed=None, options=None):
+         """
+         Resets the environment to its initial state for a new episode.
+
+         Returns:
+             tuple: A tuple containing the initial observation and auxiliary info.
+         """
+         super().reset(seed=seed)
+
+         # Start the simulation at the first point where a full window of data is available.
+         self._current_step = self.window_size
+         self._portfolio_value = self.initial_balance
+
+         # Initialize weights to be 100% in cash.
+         self._weights = np.zeros(self.n_assets + 1)
+         self._weights[-1] = 1.0  # The last element represents cash
+
+         observation = self._get_obs()
+         info = self._get_info()
+
+         return observation, info
+
+     def step(self, action):
+         """
+         Executes one time step within the environment based on the agent's action.
+
+         Args:
+             action (np.ndarray): The raw output from the agent's policy network.
+
+         Returns:
+             tuple: A tuple containing the next observation, reward, terminated flag,
+                 truncated flag, and auxiliary info.
+         """
+         # 1. Store the portfolio value before taking the action.
+         current_portfolio_value = self._portfolio_value
+
+         # 2. Convert the raw action into portfolio weights using the softmax function.
+         #    This ensures the weights are positive and sum to 1.
+         target_weights = np.exp(action) / np.sum(np.exp(action))
+
+         # 3. Calculate the cost of rebalancing the portfolio.
+         #    The cost is based on the total value of assets bought or sold.
+         trades = (target_weights[:-1] - self._weights[:-1]) * current_portfolio_value
+         transaction_costs = np.sum(np.abs(trades)) * self.transaction_cost_pct
+
+         # 4. Update the internal state: apply costs, set new weights, and advance time.
+         self._balance = current_portfolio_value - transaction_costs
+         self._weights = target_weights
+         self._current_step += 1
+
+         # 5. Calculate the new portfolio value based on the market's price movement.
+         current_prices = self.df.iloc[self._current_step - 1].values
+         next_prices = self.df.iloc[self._current_step].values
+         price_ratio = next_prices / current_prices  # How much each asset's price changed.
+
+         # The new value of the asset holdings.
+         asset_values_after_price_change = (self._weights[:-1] * self._balance) * price_ratio
+
+         # The new total portfolio value is the sum of the updated asset values plus the cash holding.
+         new_portfolio_value = np.sum(asset_values_after_price_change) + (self._weights[-1] * self._balance)
+         self._portfolio_value = new_portfolio_value
+
+         # 6. Calculate the reward for the agent.
+         #    The reward is the log return of the portfolio value, which encourages geometric growth.
+         reward = np.log(new_portfolio_value / current_portfolio_value)
+
+         # 7. Check for termination conditions.
+         #    The episode ends early if the portfolio loses half its value, and is
+         #    truncated when the price data runs out.
+         terminated = bool(self._portfolio_value <= self.initial_balance * 0.5)
+         truncated = self._current_step >= len(self.df) - 1
+
+         observation = self._get_obs()
+         info = self._get_info()
+
+         return observation, reward, terminated, truncated, info
+
+     def _get_obs(self):
+         """
+         Constructs the observation for the agent at the current time step.
+
+         Returns:
+             np.ndarray: A flattened 1D array of the normalized price history.
+         """
+         # Get the window of historical price data.
+         price_window = self.df.iloc[self._current_step - self.window_size : self._current_step].values
+
+         # Normalize the window by dividing by the first row. This helps the agent
+         # focus on relative price changes rather than absolute values.
+         normalized_window = price_window / price_window[0]
+
+         return normalized_window.flatten().astype(np.float32)
+
+     def _get_info(self):
+         """
+         Returns a dictionary of auxiliary information about the current state.
+         """
+         return {
+             'step': self._current_step,
+             'portfolio_value': self._portfolio_value,
+             'weights': self._weights,
+         }
+
+     def render(self, mode='human'):
+         """
+         Renders the environment's state (optional).
+         """
+         if mode == 'human':
+             info = self._get_info()
+             print(f"Step: {info['step']}, Portfolio Value: {info['portfolio_value']:.2f}")
+
+     def close(self):
+         """
+         Cleans up the environment (optional).
+         """
+         pass
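A quick standalone check of the softmax conversion that `step()` applies to raw actions (toy numbers, five assets plus cash):

```python
import numpy as np

# step() maps the raw action vector (one value per asset + one for cash) to
# portfolio weights with a softmax, so every weight is positive and they sum to one.
action = np.array([0.5, -0.2, 0.1, 0.0, 0.3, -1.0])  # 5 assets + cash
weights = np.exp(action) / np.sum(np.exp(action))
print(weights.round(3))  # positive weights summing to 1
```

Because the action space is bounded to [-1, 1], the exponentials stay well-behaved and no numerical-stability shift is needed here.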
scripts/evaluate.py ADDED
@@ -0,0 +1,142 @@
+ import pandas as pd
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from stable_baselines3 import PPO, SAC, TD3
+ from evaluate_baselines import buy_and_hold
+ from environment import PortfolioEnv
+ from matplotlib.ticker import FuncFormatter
+
+ # --- Helper Function to Run the RL Agent ---
+
+ def evaluate_agent(env, model):
+     """
+     Runs the trained agent on the environment and returns the daily portfolio values.
+     """
+     obs, info = env.reset()
+     terminated, truncated = False, False
+
+     portfolio_values = [env.initial_balance]
+
+     while not (terminated or truncated):
+         action, _states = model.predict(obs, deterministic=True)
+         obs, reward, terminated, truncated, info = env.step(action)
+         portfolio_values.append(info['portfolio_value'])
+
+     return pd.Series(portfolio_values, index=env.df.index[:len(portfolio_values)])
+
+
+ def calculate_metrics(portfolio_values, freq=252, rf=0.0):
+     """
+     Calculates key performance metrics from a series of portfolio values.
+     freq: number of trading periods in a year (252 for daily, 52 for weekly).
+     rf: risk-free rate (default = 0 for simplicity).
+     """
+     returns = portfolio_values.pct_change().dropna()
+
+     # Total Return
+     total_return = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) - 1
+
+     # CAGR
+     num_years = len(portfolio_values) / freq
+     cagr = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) ** (1 / num_years) - 1
+
+     # Sharpe Ratio
+     sharpe_ratio = np.sqrt(freq) * (returns.mean() - rf) / returns.std()
+
+     # Sortino Ratio (downside risk only)
+     downside_returns = returns[returns < 0]
+     downside_std = downside_returns.std()
+     sortino_ratio = np.sqrt(freq) * (returns.mean() - rf) / downside_std if downside_std > 0 else np.nan
+
+     # Volatility (annualized std)
+     volatility = returns.std() * np.sqrt(freq)
+
+     # Max Drawdown
+     rolling_max = portfolio_values.cummax()
+     drawdown = portfolio_values / rolling_max - 1.0
+     max_drawdown = drawdown.min()
+
+     # Calmar Ratio (max_drawdown is already a fraction, so no division by 100)
+     calmar_ratio = cagr / abs(max_drawdown) if max_drawdown != 0 else np.nan
+
+     return {
+         "Total Return": f"{total_return:.2%}",
+         "CAGR": f"{cagr:.2%}",
+         "Sharpe Ratio": f"{sharpe_ratio:.2f}",
+         "Sortino Ratio": f"{sortino_ratio:.2f}",
+         "Volatility": f"{volatility:.2%}",
+         "Max Drawdown": f"{max_drawdown:.2%}",
+         "Calmar Ratio": f"{calmar_ratio:.2f}"
+     }
+
+
+ def main(test_data_path='data/test.csv'):
+     """
+     Loads, evaluates, and plots the performance of the PPO, SAC, and TD3 agents
+     against a Buy and Hold baseline.
+     """
+     # --- Define Model Paths and Agent Types ---
+     models_to_evaluate = {
+         "PPO Agent": (PPO, 'checkpoints/ppo_portfolio_model'),
+         "SAC Agent": (SAC, 'checkpoints/sac_portfolio_model'),
+         "TD3 Agent": (TD3, 'checkpoints/td3_portfolio_model')
+     }
+
+     # Load test data
+     test_df = pd.read_csv(test_data_path, index_col='Date', parse_dates=True)
+
+     # Dictionaries to store results
+     portfolio_values = {}
+     metrics = {}
+
+     # --- Run Evaluations for each RL Agent ---
+     for name, (agent_type, model_path) in models_to_evaluate.items():
+         print(f"--- Evaluating {name} ---")
+         model = agent_type.load(model_path)
+         env = PortfolioEnv(test_df)
+         portfolio_values[name] = evaluate_agent(env, model)
+         metrics[name] = calculate_metrics(portfolio_values[name])
+
+     # --- Evaluate Buy and Hold Baseline ---
+     print("\n--- Evaluating Buy and Hold Baseline ---")
+     bnh_values = buy_and_hold(test_df)
+     portfolio_values["Buy and Hold"] = bnh_values
+     metrics["Buy and Hold"] = calculate_metrics(bnh_values)
+
+     # --- Combine and Print Metrics ---
+     print("\n--- Performance Metrics ---")
+     metrics_df = pd.DataFrame(metrics)
+     print(metrics_df)
+
+     # --- Plot All Strategies ---
+     plt.style.use('seaborn-v0_8-darkgrid')
+     fig, ax = plt.subplots(figsize=(14, 8))
+
+     # Define colors for clarity
+     colors = {
+         "PPO Agent": "red",
+         "SAC Agent": "green",
+         "TD3 Agent": "orange",
+         "Buy and Hold": "blue"
+     }
+
+     for name, values in portfolio_values.items():
+         ax.plot(values.index, values, label=name, color=colors[name], linewidth=2)
+
+     ax.set_title('Agent Performance Comparison', fontsize=16)
+     ax.set_xlabel('Date', fontsize=12)
+     ax.set_ylabel('Portfolio Value ($)', fontsize=12)
+     ax.legend(fontsize=12)
+
+     formatter = FuncFormatter(lambda x, p: f'${x:,.0f}')
+     ax.yaxis.set_major_formatter(formatter)
+
+     plt.tight_layout()
+     plt.savefig('results/final_performance_comparison_all_agents.png')
+     plt.show()
+
+ if __name__ == '__main__':
+     # A different test file can be passed in here if needed,
+     # e.g., main(test_data_path='data/stress_test_2018.csv')
+     main()
scripts/evaluate_baselines.py ADDED
@@ -0,0 +1,134 @@
+ # evaluate_baselines.py
+
+ import pandas as pd
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ def buy_and_hold(df, initial_balance=10000):
+     """
+     Simulates the Buy and Hold strategy.
+
+     Args:
+         df (pd.DataFrame): DataFrame with daily asset prices.
+         initial_balance (int): The starting capital.
+
+     Returns:
+         pd.Series: A Series containing the portfolio value for each day.
+     """
+     print("--- Simulating Buy and Hold ---")
+     n_assets = len(df.columns)
+
+     # Invest an equal amount in each asset at the beginning
+     initial_investment_per_asset = initial_balance / n_assets
+
+     # Get the initial prices
+     initial_prices = df.iloc[0]
+
+     # Calculate the number of shares bought for each asset
+     shares = initial_investment_per_asset / initial_prices
+
+     # Calculate the portfolio value for each day
+     portfolio_values = df.dot(shares)
+
+     print(f"Initial Investment: ${initial_balance:.2f}")
+     print(f"Final Portfolio Value: ${portfolio_values.iloc[-1]:.2f}")
+
+     return portfolio_values
+
+ def equally_weighted_rebalanced(df, initial_balance=10000, rebalance_freq='M', transaction_cost_pct=0.001):
+     """
+     Simulates an Equally Weighted Portfolio with periodic rebalancing.
+
+     Args:
+         df (pd.DataFrame): DataFrame with daily asset prices.
+         initial_balance (int): The starting capital.
+         rebalance_freq (str): The rebalancing frequency ('M' for monthly, 'Q' for quarterly).
+         transaction_cost_pct (float): The transaction cost as a percentage.
+
+     Returns:
+         pd.Series: A Series containing the portfolio value for each day.
+     """
+     print(f"\n--- Simulating Equally Weighted Portfolio (Rebalanced {rebalance_freq}) ---")
+     n_assets = len(df.columns)
+
+     # Set the initial weights to be equal
+     weights = np.full(n_assets, 1 / n_assets)
+
+     portfolio_value = initial_balance
+     portfolio_values = pd.Series(index=df.index, dtype=float)
+
+     last_rebalance_date = None
+
+     for date, prices in df.iterrows():
+         # Store the portfolio value for the day before any changes
+         portfolio_values[date] = portfolio_value
+
+         # Determine whether it's a rebalancing day:
+         # rebalance on the first trading day of each new month.
+         if last_rebalance_date is None or (rebalance_freq == 'M' and date.month != last_rebalance_date.month):
+
+             # Calculate the value of trades needed to rebalance
+             target_asset_values = portfolio_value * (1 / n_assets)
+             current_asset_values = weights * portfolio_value
+             trades = target_asset_values - current_asset_values
+
+             # Apply transaction costs
+             transaction_costs = np.sum(np.abs(trades)) * transaction_cost_pct
+             portfolio_value -= transaction_costs
+
+             # Reset weights to be equal
+             weights = np.full(n_assets, 1 / n_assets)
+             last_rebalance_date = date
+
+         # Advance the portfolio to the next trading day using the price changes
+         today_prices = df.loc[date]
+         next_day_index = df.index.get_loc(date) + 1
+         if next_day_index < len(df):
+             next_day_prices = df.iloc[next_day_index]
+             price_change_ratio = next_day_prices / today_prices
+
+             # Update portfolio value based on price changes
+             portfolio_value = np.sum((weights * portfolio_value) * price_change_ratio)
+
+             # Update weights due to market drift
+             new_asset_values = (weights * portfolio_value) * price_change_ratio
+             weights = new_asset_values / np.sum(new_asset_values)
+
+     print(f"Initial Investment: ${initial_balance:.2f}")
+     print(f"Final Portfolio Value: ${portfolio_values.iloc[-1]:.2f}")
+
+     return portfolio_values.dropna()
+
+
+ def main():
+     # Load the test data
+     test_df = pd.read_csv('data/test.csv', index_col='Date', parse_dates=True)
+
+     # --- Run Baseline Strategies ---
+     bnh_values = buy_and_hold(test_df)
+     ewp_values = equally_weighted_rebalanced(test_df)
+
+     # --- Plot the results ---
+     plt.style.use('seaborn-v0_8-darkgrid')
+     fig, ax = plt.subplots(figsize=(14, 8))
+
+     ax.plot(bnh_values.index, bnh_values, label='Buy and Hold', color='blue', linewidth=2)
+     ax.plot(ewp_values.index, ewp_values, label='Equally Weighted (Rebalanced Monthly)', color='green', linewidth=2)
+
+     ax.set_title('Baseline Strategy Performance (2021-2023)', fontsize=16)
120
+ ax.set_xlabel('Date', fontsize=12)
121
+ ax.set_ylabel('Portfolio Value ($)', fontsize=12)
122
+ ax.legend(fontsize=12)
123
+
124
+ # Format the y-axis to show currency
125
+ from matplotlib.ticker import FuncFormatter
126
+ formatter = FuncFormatter(lambda x, p: f'${x:,.0f}')
127
+ ax.yaxis.set_major_formatter(formatter)
128
+
129
+ plt.tight_layout()
130
+ plt.savefig('baseline_performance.png')
131
+ plt.show()
132
+
133
+ if __name__ == '__main__':
134
+ main()
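As a quick sanity check on the buy-and-hold arithmetic above, the same share-and-dot-product computation can be run on a tiny synthetic price frame (the tickers and prices here are invented for illustration, not real data):

```python
import pandas as pd

# Two assets over two days: A gains 10%, B loses 10%
prices = pd.DataFrame(
    {"A": [100.0, 110.0], "B": [50.0, 45.0]},
    index=pd.to_datetime(["2021-01-04", "2021-01-05"]),
)

initial_balance = 10_000
# Equal stakes: $5,000 buys 50 shares of A and 100 shares of B
shares = (initial_balance / len(prices.columns)) / prices.iloc[0]
portfolio_values = prices.dot(shares)

# With equal initial stakes the +10% and -10% moves cancel,
# so the portfolio starts and ends at $10,000.
```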
scripts/fetch_data.py ADDED
@@ -0,0 +1,75 @@
+ import yfinance as yf
+ import pandas as pd
+ import os
+
+ # --- Configuration ---
+ # Asset tickers
+ TICKERS = ["AAPL", "MSFT", "SPY", "TLT", "BTC-USD"]
+ # Time periods for training and testing
+ TRAIN_START_DATE = "2015-01-01"
+ TRAIN_END_DATE = "2020-12-31"
+ TEST_START_DATE = "2021-01-01"
+ TEST_END_DATE = "2023-12-31"
+
+ # Directory to save the data
+ DATA_DIR = "data"
+ TRAIN_DATA_PATH = os.path.join(DATA_DIR, "train.csv")
+ TEST_DATA_PATH = os.path.join(DATA_DIR, "test.csv")
+
+ # --- Data Fetching and Processing ---
+
+ def fetch_and_prepare_data(start_date, end_date, tickers):
+     """
+     Fetches historical data for the given tickers and processes it.
+     Returns a DataFrame with 'Close' prices for each ticker.
+     """
+     print(f"Fetching data from {start_date} to {end_date} for {tickers}...")
+
+     data = yf.download(tickers, start=start_date, end=end_date)
+
+     # Take an explicit copy so the in-place operations below run on our own
+     # DataFrame rather than a view, avoiding SettingWithCopy warnings.
+     close_data = data['Close'].copy()
+
+     print("\nData Head:")
+     print(close_data.head())
+
+     print("\nMissing values before cleaning:")
+     print(close_data.isnull().sum())
+
+     # Forward-fill gaps (e.g. market holidays), then back-fill any leading NaNs
+     close_data.ffill(inplace=True)
+     close_data.bfill(inplace=True)
+
+     print("\nMissing values after cleaning:")
+     print(close_data.isnull().sum())
+
+     for col in close_data.columns:
+         close_data[col] = pd.to_numeric(close_data[col], errors='coerce')
+
+     close_data.dropna(inplace=True)
+
+     return close_data
+
+ def main():
+     """Main function to run the data fetching process."""
+     # Create the data directory if it doesn't exist
+     if not os.path.exists(DATA_DIR):
+         os.makedirs(DATA_DIR)
+         print(f"Created directory: {DATA_DIR}")
+
+     # Fetch, process, and save training data
+     print("--- Preparing Training Data ---")
+     train_data = fetch_and_prepare_data(TRAIN_START_DATE, TRAIN_END_DATE, TICKERS)
+     train_data.to_csv(TRAIN_DATA_PATH)
+     print(f"Training data saved to {TRAIN_DATA_PATH}")
+
+     print("\n" + "="*50 + "\n")
+
+     # Fetch, process, and save testing data
+     print("--- Preparing Testing Data ---")
+     test_data = fetch_and_prepare_data(TEST_START_DATE, TEST_END_DATE, TICKERS)
+     test_data.to_csv(TEST_DATA_PATH)
+     print(f"Testing data saved to {TEST_DATA_PATH}")
+
+ if __name__ == "__main__":
+     main()
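The two-pass gap filling used in `fetch_and_prepare_data` behaves like this on a toy series (values invented): forward-fill carries the last known price across gaps, and the subsequent back-fill closes any NaNs left at the very start, which forward-fill cannot reach.

```python
import numpy as np
import pandas as pd

s = pd.Series([np.nan, 10.0, np.nan, np.nan, 12.0])
cleaned = s.ffill().bfill()  # ffill leaves the leading NaN; bfill closes it
```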
scripts/fetch_market_data.py ADDED
@@ -0,0 +1,78 @@
+ import argparse
+ import os
+ import pandas as pd
+ import yfinance as yf
+ from datetime import date
+
+ def fetch_data(start_date, end_date, output_filename):
+     """
+     Fetches, cleans, and saves historical market data for a given date range.
+
+     Args:
+         start_date (str): The start date for the data in 'YYYY-MM-DD' format.
+         end_date (str): The end date for the data in 'YYYY-MM-DD' format.
+         output_filename (str): The path and name of the file to save the data.
+     """
+     print(f"--- Fetching data from {start_date} to {end_date} ---")
+
+     # Define the base list of tickers
+     tickers = ["AAPL", "MSFT", "SPY", "TLT", "BTC-USD"]
+
+     # Exclude Bitcoin for periods before reliable BTC-USD price history exists
+     if pd.to_datetime(start_date).year < 2013:
+         print("Note: Bitcoin (BTC-USD) has no data for the requested period and will be excluded.")
+         tickers.remove("BTC-USD")
+
+     # Download data from Yahoo Finance
+     data = yf.download(tickers, start=start_date, end=end_date)
+     close_data = data['Close'].copy()
+
+     # Data cleaning
+     print(f"\nMissing values before cleaning:\n{close_data.isnull().sum()}")
+     close_data.ffill(inplace=True)
+     close_data.bfill(inplace=True)
+
+     # Drop any columns that are still entirely NaN (e.g. BTC in the 2008 data)
+     close_data.dropna(axis=1, how='all', inplace=True)
+
+     print(f"\nMissing values after cleaning:\n{close_data.isnull().sum()}")
+
+     # Ensure the output directory exists
+     output_dir = os.path.dirname(output_filename)
+     if output_dir and not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+
+     # Save to CSV
+     close_data.to_csv(output_filename)
+     print(f"\nβœ… Data successfully saved to {output_filename}")
+
+
+ if __name__ == "__main__":
+     # Set up command-line argument parsing
+     parser = argparse.ArgumentParser(description="Fetch historical market data for specified periods.")
+
+     parser.add_argument(
+         "--start",
+         type=str,
+         default="2018-01-01",
+         help="Start date in YYYY-MM-DD format. Default is for the 2018 stress test."
+     )
+     parser.add_argument(
+         "--end",
+         type=str,
+         default="2019-12-31",
+         help="End date in YYYY-MM-DD format, or 'today'. Default is for the 2018 stress test."
+     )
+     parser.add_argument(
+         "--filename",
+         type=str,
+         default="data/stress_test_2018.csv",
+         help="Output file name (e.g., 'data/my_data.csv')."
+     )
+
+     args = parser.parse_args()
+
+     # Resolve 'today' to the current date if requested
+     end_date = date.today().strftime('%Y-%m-%d') if args.end.lower() == 'today' else args.end
+
+     fetch_data(start_date=args.start, end_date=end_date, output_filename=args.filename)
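The `--end today` convenience above can be factored into a small testable helper; `resolve_end` is a name introduced here for illustration, not a function in the script.

```python
from datetime import date

def resolve_end(end_arg: str) -> str:
    """Return today's date in YYYY-MM-DD form when asked for 'today', else pass through."""
    return date.today().strftime('%Y-%m-%d') if end_arg.lower() == 'today' else end_arg
```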
scripts/stress_test.py ADDED
@@ -0,0 +1,142 @@
+ import argparse
+ import os
+ import pandas as pd
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from matplotlib.ticker import FuncFormatter
+
+ # Import all agent classes and the environment
+ from stable_baselines3 import PPO, SAC, TD3
+ from src.environment import PortfolioEnv
+
+ # --- Helper Functions ---
+ def evaluate_agent(env, model):
+     """Runs a trained agent on a given environment."""
+     obs, info = env.reset()
+     terminated, truncated = False, False
+     portfolio_values = [env.initial_balance]
+     while not (terminated or truncated):
+         action, _states = model.predict(obs, deterministic=True)
+         obs, reward, terminated, truncated, info = env.step(action)
+         portfolio_values.append(info['portfolio_value'])
+     return pd.Series(portfolio_values, index=env.df.index[:len(portfolio_values)])
+
+ def buy_and_hold(df, initial_balance=10000):
+     """Simulates the Buy and Hold strategy."""
+     n_assets = len(df.columns)
+     initial_investment_per_asset = initial_balance / n_assets
+     initial_prices = df.iloc[0]
+     shares = initial_investment_per_asset / initial_prices
+     portfolio_values = df.dot(shares)
+     return portfolio_values
+
+ def calculate_metrics(portfolio_values):
+     """Calculates performance metrics from a portfolio value series."""
+     total_return = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) - 1
+     num_years = (portfolio_values.index[-1] - portfolio_values.index[0]).days / 365.25
+     cagr = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) ** (1 / num_years) - 1 if num_years > 0 else 0
+     daily_returns = portfolio_values.pct_change().dropna()
+     sharpe_ratio = np.sqrt(252) * (daily_returns.mean() / daily_returns.std()) if daily_returns.std() != 0 else 0
+     rolling_max = portfolio_values.cummax()
+     daily_drawdown = portfolio_values / rolling_max - 1.0
+     max_drawdown = daily_drawdown.min()
+     return {
+         "Total Return": f"{total_return:.2%}", "CAGR": f"{cagr:.2%}",
+         "Sharpe Ratio": f"{sharpe_ratio:.2f}", "Max Drawdown": f"{max_drawdown:.2%}"
+     }
+
+ # --- Main Stress Test Function ---
+ def run_stress_test(datafile_path, ppo_path, sac_path, td3_path, output_path):
+     """
+     Loads data and models, runs evaluations, and plots the comparison.
+     """
+     print(f"--- Running Stress Test on {datafile_path} ---")
+
+     # 1. Load Data
+     try:
+         test_df = pd.read_csv(datafile_path, index_col='Date', parse_dates=True)
+     except FileNotFoundError:
+         print(f"❌ Error: Data file not found at {datafile_path}")
+         return
+
+     # Check for an asset mismatch (e.g., 4 assets in the 2008 data vs. 5-asset models).
+     # The standard models were trained on 5 assets (e.g., observation shape = 30 * 5 = 150).
+     expected_assets = 5
+     if test_df.shape[1] != expected_assets:
+         print(f"⚠️ Warning: Models were trained on {expected_assets} assets, but this dataset has {test_df.shape[1]}.")
+         print("Skipping agent evaluation for this dataset.")
+         return
+
+     # 2. Define Models to Evaluate
+     models_to_evaluate = {
+         "PPO Agent": (PPO, ppo_path),
+         "SAC Agent": (SAC, sac_path),
+         "TD3 Agent": (TD3, td3_path)
+     }
+
+     portfolio_values = {}
+     metrics = {}
+
+     # 3. Run Evaluations
+     for name, (agent_type, model_path) in models_to_evaluate.items():
+         if os.path.exists(model_path):
+             print(f"--- Evaluating {name} ---")
+             model = agent_type.load(model_path)
+             env = PortfolioEnv(test_df)
+             portfolio_values[name] = evaluate_agent(env, model)
+             metrics[name] = calculate_metrics(portfolio_values[name])
+         else:
+             print(f"⚠️ Warning: Model file not found at {model_path}. Skipping.")
+
+     # Evaluate the Buy and Hold baseline
+     print("\n--- Evaluating Buy and Hold Baseline ---")
+     bnh_values = buy_and_hold(test_df)
+     portfolio_values["Buy and Hold"] = bnh_values
+     metrics["Buy and Hold"] = calculate_metrics(bnh_values)
+
+     # 4. Display Results
+     print("\n--- Stress Test Performance Metrics ---")
+     metrics_df = pd.DataFrame(metrics)
+     print(metrics_df)
+
+     # 5. Plotting
+     plt.style.use('seaborn-v0_8-darkgrid')
+     fig, ax = plt.subplots(figsize=(14, 8))
+
+     colors = {"PPO Agent": "red", "SAC Agent": "green", "TD3 Agent": "orange", "Buy and Hold": "blue"}
+     for name, values in portfolio_values.items():
+         ax.plot(values.index, values, label=name, color=colors.get(name, 'black'), linewidth=2)
+
+     plot_title = f"Agent Stress Test: {os.path.basename(datafile_path).replace('.csv', '')}"
+     ax.set_title(plot_title, fontsize=16)
+     ax.set_xlabel('Date', fontsize=12)
+     ax.set_ylabel('Portfolio Value ($)', fontsize=12)
+     ax.legend(fontsize=12)
+
+     formatter = FuncFormatter(lambda x, p: f'${x:,.0f}')
+     ax.yaxis.set_major_formatter(formatter)
+
+     plt.tight_layout()
+
+     # Ensure the output directory exists before saving
+     output_dir = os.path.dirname(output_path)
+     if output_dir and not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+
+     plt.savefig(output_path)
+     print(f"\nβœ… Plot saved to {output_path}")
+     plt.show()
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description="Run a stress test on trained RL portfolio agents.")
+
+     parser.add_argument("--datafile", type=str, default="data/stress_test_2018.csv", help="Path to the market data CSV file for the test.")
+     parser.add_argument("--ppopath", type=str, default="checkpoints/ppo_portfolio_model.zip", help="Path to the trained PPO model.")
+     parser.add_argument("--sacpath", type=str, default="checkpoints/sac_portfolio_model.zip", help="Path to the trained SAC model.")
+     parser.add_argument("--td3path", type=str, default="checkpoints/td3_portfolio_model.zip", help="Path to the trained TD3 model.")
+     parser.add_argument("--output", type=str, default="results/stress_test_comparison.png", help="Path to save the output plot.")
+
+     args = parser.parse_args()
+
+     run_stress_test(
+         datafile_path=args.datafile,
+         ppo_path=args.ppopath,
+         sac_path=args.sacpath,
+         td3_path=args.td3path,
+         output_path=args.output
+     )
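The max-drawdown computation in `calculate_metrics` can be verified by hand on a short synthetic equity curve (numbers chosen so the arithmetic is easy to follow): the curve peaks at 120, falls to 90, and the worst peak-to-trough loss is 90/120 - 1 = -25%.

```python
import pandas as pd

values = pd.Series([100.0, 120.0, 90.0, 110.0])
rolling_max = values.cummax()                 # running peak: 100, 120, 120, 120
daily_drawdown = values / rolling_max - 1.0   # fall from the prior peak each day
max_drawdown = daily_drawdown.min()           # -0.25 at the trough
```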
scripts/train.py ADDED
@@ -0,0 +1,77 @@
+ import argparse
+ import pandas as pd
+ from stable_baselines3 import PPO, SAC, TD3
+ from environment import PortfolioEnv
+
+ def train_agent(agent_name="td3", timesteps=100000):
+     """
+     Main function to train a specified RL agent.
+
+     Args:
+         agent_name (str): The RL algorithm to use ('ppo', 'sac', or 'td3').
+         timesteps (int): The total number of timesteps for training.
+     """
+     # 1. Map agent names to their corresponding classes
+     AGENT_CLASSES = {
+         "ppo": PPO,
+         "sac": SAC,
+         "td3": TD3
+     }
+     agent_class = AGENT_CLASSES.get(agent_name.lower())
+     if agent_class is None:
+         raise ValueError(f"Unknown agent: {agent_name}. Choose from {list(AGENT_CLASSES.keys())}")
+
+     model_name = agent_name.lower()
+
+     # 2. Load data and create the environment
+     print("--- Loading Data and Creating Environment ---")
+     try:
+         df = pd.read_csv('data/train.csv', index_col='Date', parse_dates=True)
+         env = PortfolioEnv(df)
+         print("Environment created successfully.")
+     except FileNotFoundError:
+         print("❌ Error: 'data/train.csv' not found. Make sure to run a data fetching script first.")
+         return
+
+     # 3. Create the RL Agent
+     print(f"--- Creating {agent_name.upper()} Agent ---")
+     model = agent_class(
+         "MlpPolicy",
+         env,
+         verbose=1,
+         tensorboard_log="./tensorboard_logs/"
+     )
+
+     # 4. Train the Agent
+     print(f"--- Starting Agent Training for {timesteps} timesteps ---")
+     model.learn(total_timesteps=timesteps)
+     print("--- Agent Training Complete ---")
+
+     # 5. Save the Trained Model (stable-baselines3 appends the .zip extension)
+     save_path = f"checkpoints/{model_name}_portfolio_model"
+     model.save(save_path)
+     print(f"βœ… Model saved to {save_path}.zip")
+
+
+ if __name__ == "__main__":
+     # 6. Set up command-line argument parsing
+     parser = argparse.ArgumentParser(description="Train a Reinforcement Learning agent for portfolio management.")
+
+     parser.add_argument(
+         "--agent",
+         type=str,
+         default="td3",
+         choices=["ppo", "sac", "td3"],
+         help="The RL algorithm to use for training (default: td3)."
+     )
+     parser.add_argument(
+         "--timesteps",
+         type=int,
+         default=100000,
+         help="The total number of timesteps for training (default: 100000)."
+     )
+
+     args = parser.parse_args()
+
+     # Call the main training function with the parsed arguments
+     train_agent(agent_name=args.agent, timesteps=args.timesteps)
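The name-to-class dispatch at the top of `train_agent` can be exercised in isolation with stand-in classes (`PPO`, `SAC`, and `TD3` below are empty placeholders, not the stable-baselines3 imports, and `resolve_agent` is a name introduced here for illustration):

```python
class PPO: pass
class SAC: pass
class TD3: pass

AGENT_CLASSES = {"ppo": PPO, "sac": SAC, "td3": TD3}

def resolve_agent(agent_name: str):
    """Case-insensitive lookup that fails loudly on an unknown algorithm name."""
    agent_class = AGENT_CLASSES.get(agent_name.lower())
    if agent_class is None:
        raise ValueError(f"Unknown agent: {agent_name}. Choose from {list(AGENT_CLASSES)}")
    return agent_class
```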
scripts/visualize_strategy.py ADDED
@@ -0,0 +1,123 @@
+ import argparse
+ import os
+ import pandas as pd
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from matplotlib.ticker import FuncFormatter
+ from stable_baselines3 import PPO, SAC, TD3
+ from environment import PortfolioEnv
+
+ def visualize_strategy(agent_name, checkpoint_path, datafile_path, output_path):
+     """
+     Loads a trained agent, runs a simulation, and plots its portfolio allocation strategy.
+
+     Args:
+         agent_name (str): The type of agent to load ('ppo', 'sac', 'td3').
+         checkpoint_path (str): The path to the saved model checkpoint file (.zip).
+         datafile_path (str): The path to the CSV market data for the simulation.
+         output_path (str): The path to save the output plot image.
+     """
+     print(f"--- Visualizing strategy for {agent_name.upper()} agent ---")
+
+     # 1. Define a mapping from agent names to their classes
+     AGENT_CLASSES = {
+         "ppo": PPO,
+         "sac": SAC,
+         "td3": TD3
+     }
+     agent_class = AGENT_CLASSES[agent_name.lower()]
+
+     # 2. Load Data and Model
+     try:
+         test_df = pd.read_csv(datafile_path, index_col='Date', parse_dates=True)
+         model = agent_class.load(checkpoint_path)
+     except FileNotFoundError as e:
+         print(f"❌ Error: Could not find a required file. {e}")
+         return
+     except Exception as e:
+         print(f"❌ An error occurred: {e}")
+         return
+
+     # 3. Create Environment and Run Simulation
+     env = PortfolioEnv(test_df)
+     obs, info = env.reset()
+     terminated, truncated = False, False
+
+     weights_history = [info['weights']]
+     while not (terminated or truncated):
+         action, _states = model.predict(obs, deterministic=True)
+         obs, reward, terminated, truncated, info = env.step(action)
+         weights_history.append(info['weights'])
+     print("βœ… Simulation complete.")
+
+     # 4. Prepare Data for Plotting
+     weights_df = pd.DataFrame(weights_history)
+     asset_names = test_df.columns.tolist() + ['Cash']
+     weights_df.columns = asset_names
+     weights_df.index = test_df.index[:len(weights_df)]
+
+     # 5. Plot the Stacked Area Chart
+     print("πŸ“Š Generating plot...")
+     plt.style.use('seaborn-v0_8-darkgrid')
+     fig, ax = plt.subplots(figsize=(15, 8))
+
+     ax.stackplot(weights_df.index, weights_df.T, labels=weights_df.columns, alpha=0.8)
+
+     ax.set_title(f'Agent Portfolio Allocation Over Time ({agent_name.upper()})', fontsize=16)
+     ax.set_xlabel('Date', fontsize=12)
+     ax.set_ylabel('Portfolio Allocation (%)', fontsize=12)
+     ax.legend(loc='upper left', fontsize=10)
+
+     formatter = FuncFormatter(lambda y, p: f'{y:.0%}')
+     ax.yaxis.set_major_formatter(formatter)
+
+     plt.tight_layout()
+
+     # Ensure the output directory exists
+     output_dir = os.path.dirname(output_path)
+     if output_dir and not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+
+     plt.savefig(output_path)
+     print(f"βœ… Plot saved to {output_path}")
+     plt.show()
+
+
+ if __name__ == "__main__":
+     # Set up command-line argument parsing
+     parser = argparse.ArgumentParser(description="Visualize a trained RL agent's portfolio allocation strategy.")
+
+     parser.add_argument(
+         "--agent",
+         type=str,
+         required=True,
+         choices=["ppo", "sac", "td3"],
+         help="The RL algorithm of the trained agent."
+     )
+     parser.add_argument(
+         "--checkpoint",
+         type=str,
+         required=True,
+         help="Path to the saved model checkpoint .zip file (e.g., 'td3_portfolio_model.zip')."
+     )
+     parser.add_argument(
+         "--datafile",
+         type=str,
+         default="data/test.csv",
+         help="Path to the market data CSV file to run the simulation on."
+     )
+     parser.add_argument(
+         "--output",
+         type=str,
+         default="results/agent_allocation.png",
+         help="Path to save the output plot image."
+     )
+
+     args = parser.parse_args()
+
+     visualize_strategy(
+         agent_name=args.agent,
+         checkpoint_path=args.checkpoint,
+         datafile_path=args.datafile,
+         output_path=args.output
+     )
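The reshaping step in `visualize_strategy` (weight history into a labelled DataFrame) looks like this on synthetic allocations; every row, assets plus cash, should sum to one, which is what makes the stacked area chart fill the full vertical axis. Ticker names and numbers below are invented for illustration.

```python
import pandas as pd

weights_history = [
    [0.20, 0.30, 0.10, 0.40],   # AAPL, MSFT, SPY, Cash at step 0
    [0.25, 0.25, 0.10, 0.40],   # allocation one step later
]
weights_df = pd.DataFrame(weights_history, columns=["AAPL", "MSFT", "SPY", "Cash"])
row_sums = weights_df.sum(axis=1)  # each row is a full allocation, so sums to 1.0
```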