DanielKiani committed on
Commit
349ad65
·
1 Parent(s): 1b637c6

Version 1.0 release

README.md CHANGED
@@ -1,11 +1,21 @@
1
  ![Banner](assets/banner.png)
2
  [![Python](https://img.shields.io/badge/Python-3.12.11-blue?logo=python)](https://www.python.org/)[![PyTorch](https://img.shields.io/badge/PyTorch-2.8-EE4C2C?logo=pytorch)](https://pytorch.org/)![Made with ML](https://img.shields.io/badge/Made%20with-ML-blueviolet?logo=openai)[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
3
 
4
- # 🤖 Portfolio Optimization with Deep Reinforcement Learning
5
 
6
- This project explores the use of Deep Reinforcement Learning to train autonomous agents for financial portfolio management. The goal was not just to create a single profitable agent, but to conduct a comparative study of different RL algorithms (PPO, SAC, TD3) to understand the emergent trading strategies and their robustness across various market conditions.
7
 
8
- **The ultimate finding? A TD3-based agent learned a superior, risk-managed static asset allocation that consistently outperformed both active trading strategies and aggressive growth models, especially during market downturns.**
9
 
10
  ---
11
 
@@ -13,16 +23,12 @@ This project explores the use of Deep Reinforcement Learning to train autonomous
13
 
14
  1. [📊 The Data & Asset Selection](#-the-data--asset-selection)
15
  2. [🎯 Benchmarking Against Baselines](#-benchmarking-against-baselines)
16
- 3. [🏆 Key Findings & The Champion Agent](#-key-findings--the-champion-agent)
17
- 4. [🧠 Comparative Analysis of Agent Strategies](#-comparative-analysis-of-agent-strategies)
18
- * [🥇 TD3: The Prudent Risk-Manager](#-td3-the-prudent-risk-manager)
19
- * [🚀 SAC: The Aggressive Growth Engine](#-sac-the-aggressive-growth-engine)
20
- * [📈 PPO: The Active (but Inconsistent) Trader](#-ppo-the-active-but-inconsistent-trader)
21
- 5. [🌪️ Stress Testing: The Ultimate Test of Robustness](#️-stress-testing-the-ultimate-test-of-robustness)
22
- 6. [🔬 The Research Journey: Why Simplicity Won](#-the-research-journey-why-simplicity-won)
23
- 7. [✅ Conclusion](#-conclusion)
24
- 8. [📂 Project Structure](#-project-structure)
25
- 9. [🚀 How to Run](#-how-to-run)
26
  * [Setup](#setup)
27
  * [Data Fetching](#data-fetching)
28
  * [Training](#training)
@@ -32,16 +38,23 @@ This project explores the use of Deep Reinforcement Learning to train autonomous
32
 
33
  ## 📊 The Data & Asset Selection
34
 
35
- The foundation of any financial machine learning project is the data. This project uses daily closing price data sourced from **Yahoo Finance** via the `yfinance` library. The primary training period was **2015-2020**, with out-of-sample testing conducted on **2021-2023** and other periods for stress testing.
36
 
37
- The selection of assets was crucial for creating a realistic decision-making environment for the agent. The portfolio consists of five assets, chosen to represent different classes and risk profiles:
38
 
39
  * **Growth Equities (AAPL, MSFT):** Represent the high-growth, high-volatility technology sector.
40
  * **Market Index (SPY):** An ETF tracking the S&P 500, representing the broader US stock market.
41
  * **Safe Haven (TLT):** An ETF for 20+ Year US Treasury Bonds, which often acts as a "risk-off" asset during stock market downturns.
42
  * **Alternative Asset (BTC-USD):** Represents a non-traditional, extremely volatile asset class with high potential returns.
43
 
44
- This diverse mix forces the agent to learn not just about individual assets, but also about their correlations and how to balance risk across different economic regimes.
45
 
46
  ---
47
 
@@ -57,73 +70,83 @@ The chart below shows the performance of a simple Buy and Hold strategy during t
57
 
58
  ---
59
 
60
- ## 🏆 Key Findings & The Champion Agent
61
 
62
- After extensive training, evaluation, and stress-testing, the **TD3 agent emerged as the clear winner** on a risk-adjusted basis. While other agents achieved higher raw returns, their strategies proved to be brittle and dangerously volatile during market crises. The TD3 agent's strategy was the most robust and reliable.
63
 
64
- #### Final Performance Comparison (2021-2023)
65
 
66
- This table summarizes the performance of the top-performing static agents against the baseline.
67
 
68
- | Metric | **TD3 Agent** | SAC Agent | Buy & Hold |
69
- | :--- | :--- | :--- | :--- |
70
- | **Total Return** | 47.24% | **50.89%** | 34.91% |
71
- | **CAGR** | 13.76% | **14.70%** | 10.50% |
72
- | **Sharpe Ratio** | **0.62** | 0.51 | 0.45 |
73
- | **Max Drawdown** | **-28.41%** | -44.61% | -40.81% |
74
 
75
- The TD3 agent delivered strong returns while significantly reducing the maximum drawdown, proving its superior capital preservation strategy.
 
 
 
 
 
 
76
 
77
  ![Main Performance Chart](results/final_performance_comparison_all_agents.png)
 
78
 
79
- ---
80
 
81
- ## 🧠 Comparative Analysis of Agent Strategies
82
 
83
- A fascinating outcome of this project was observing three different RL algorithms independently discover three distinct and recognizable investment philosophies.
 
 
84
 
85
- ### 🥇 TD3: The Prudent Risk-Manager
86
 
87
- The TD3 agent concluded that the most effective strategy was not to trade frequently, but to find one **superior, risk-managed static asset allocation** and hold it.
88
 
89
- * **Strategy:** "Smarter Buy and Hold".
90
- * **Behavior:** The agent's allocation is completely static, indicating it focused on the initial strategic decision and ignored market noise to minimize transaction costs.
91
- * **Result:** This approach led to the best risk-adjusted returns, proving that a robust initial setup is more valuable than reactive trading.
92
 
93
- ![TD3 Allocation Chart](results/td3_portfolio_alocation.png)
94
 
95
- ### 🚀 SAC: The Aggressive Growth Engine
96
 
97
- The SAC agent also learned a static allocation strategy, but its portfolio was geared for **maximum growth**, accepting higher risk for higher potential returns.
98
 
99
- * **Strategy:** High-risk, high-return static allocation.
100
- * **Behavior:** Like TD3, it made one initial allocation and held firm. However, this allocation was far more aggressive.
101
- * **Result:** It achieved the highest total return in some periods but suffered catastrophic drawdowns in stress tests, making its strategy unreliable and brittle.
102
 
103
- ![SAC Performance Chart](results/sac_portfolio_alocation.png)
104
 
105
- ### 📈 PPO: The Active (but Inconsistent) Trader
106
 
107
- Unlike the other two, the PPO agent learned an **active, dynamic trading strategy**, constantly adjusting its portfolio based on market conditions.
108
 
109
- * **Strategy:** Tactical asset allocation.
110
- * **Behavior:** The allocation chart clearly shows the agent rebalancing its portfolio over time, for example, by increasing its bond (TLT) holdings during the 2022 downturn.
111
- * **Result:** While impressive that it learned this behavior, its performance was inconsistent. It succeeded in some periods (2018) but failed in others (2025), highlighting the immense difficulty of successful market timing.
112
 
113
- ![PPO Allocation Chart](results/ppo_portfolio_alocation.png)
114
 
115
- ---
 
 
116
 
117
- ## 🌪️ Stress Testing: The Ultimate Test of Robustness
118
 
119
- A model is only as good as its performance during a crisis. We subjected the agents to multiple out-of-sample stress tests, with the 2018 period (featuring a crypto winter and a stock market flash crash) being the most revealing.
120
 
121
- ![2018 Stress Test Chart](results/stress_test_comparison_2018.png)
122
 
123
- * **TD3's Triumph:** The orange line shows the TD3 agent successfully navigating the downturn, preserving capital far better than the baseline.
124
- * **SAC's Failure:** The green line shows the SAC agent's aggressive strategy failing catastrophically, resulting in a massive drawdown.
125
 
126
- This test definitively proved that the **TD3 agent's risk-managed approach was truly robust**, while the SAC agent's strategy was fragile.
 
 
 
 
 
 
 
 
 
 
127
 
128
  ---
129
 
@@ -135,8 +158,39 @@ This project was also an exercise in scientific methodology. We initially hypoth
135
  * **Hypothesis 2: Models with memory are better.** We tested an LSTM-based agent (`RecurrentPPO`). **Result:** Performance degraded. The added complexity led to overfitting on the training data.
136
  * **Hypothesis 3: Using Regularization is better.** We tested both L1 and L2 regularization. **Results:** Performance degraded.
137
  * **Hypothesis 4: Increasing the window from 30 days is better.** We tested increasing the window to 60 days. **Results:** Performance degraded. Increasing the context window is not always beneficial; the longer history appears to act as extra noise for the model.
138
 
139
- The conclusion was clear: for this problem, a simple and elegant model (a standard MLP fed with just normalized price data) was the most effective.
140
 
141
  ---
142
 
@@ -148,68 +202,94 @@ This project successfully demonstrates that Deep Reinforcement Learning can be a
148
 
149
  ## 📂 Project Structure
150
 
151
- The codebase is organized into modular, reusable scripts.
152
-
153
  ```bash
154
- ├── assets/
155
- ├── checkpoints/ # Holds all saved model .zip files
156
- ├── results/ # Holds all output plots and metrics
157
- ├── scripts/
158
- ├── environment.py # The custom Gymnasium environment for the simulation
159
- ├── fetch_market_data.py# A flexible script to download data for any period
160
- ├── train.py # The main training script with model selection
161
- ├── evaluate.py # The main evaluation script for generating metrics
162
- ├── stress_test.py # Runs a full comparison of all agents on a given dataset
163
- └── visualize_strategy.py # Plots the asset allocation of a single trained agent
164
- └── README.md # This file
 
 
 
 
 
 
 
 
 
165
  ```
166
 
167
- ---
168
-
169
  ## 🚀 How to Run
170
 
171
  ### Setup
172
 
173
- 1. Clone the repository.
174
- 2. Create and activate a Python virtual environment.
175
- 3. Install the required packages:
176
 
177
- ```bash
178
- pip install -r requirements.txt
179
- ```
180
 
181
- ### Data Fetching
 
182
 
183
- Use the flexible `fetch_market_data.py` script to get any data you need.
184
 
185
  ```bash
186
- # Fetch the default training data (2015-2021)
187
- python fetch_market_data.py --start 2015-01-01 --end 2020-12-31 --filename data/train.csv
 
 
 
 
188
 
189
- # Fetch data for a stress test (e.g., 2022)
190
- python fetch_market_data.py --start 2022-01-01 --end 2022-12-31 --filename data/test_2022.csv
 
 
 
 
191
  ```
192
 
193
  ### Training
194
 
195
- Use the `train.py` script to train any of the three main agents.
196
 
197
- ```bash
198
- # Train the champion TD3 agent (default)
199
- python src/train.py --agent td3
200
 
201
  # Train a SAC agent for more timesteps
202
- python src/train.py --agent sac --timesteps 100000
203
  ```
204
 
 
 
205
  ### Evaluation & Visualization
206
 
207
- Use the dedicated scripts to analyze the results.
208
 
209
- ```bash
210
- # Run a full stress test on the 2018 data
211
- python stress_test.py --datafile data/stress_test_2018.csv
 
 
 
 
212
 
213
- # Visualize the TD3 agent's strategy
214
- python visualize_strategy.py --agent td3 --checkpoint td3_portfolio_model.zip
 
 
215
  ```
 
 
 
 
 
 
 
 
 
1
  ![Banner](assets/banner.png)
2
  [![Python](https://img.shields.io/badge/Python-3.12.11-blue?logo=python)](https://www.python.org/)[![PyTorch](https://img.shields.io/badge/PyTorch-2.8-EE4C2C?logo=pytorch)](https://pytorch.org/)![Made with ML](https://img.shields.io/badge/Made%20with-ML-blueviolet?logo=openai)[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
3
 
4
+ # 🤖 Portfolio Optimization with Deep Reinforcement Learning (v1.0)
5
 
6
+ This project explores the use of Deep Reinforcement Learning (DRL) to train autonomous agents for financial portfolio management. The goal is to create agents that can dynamically allocate capital across a diverse set of assets to maximize returns while managing risk.
7
 
8
+ This is **Version 1.0** of the project, which moves beyond initial exploration to a more robust and comparative study. Building on the foundation of v0.1, this version introduces:
9
+
10
+ * **Comparative Analysis:** We train and evaluate three state-of-the-art DRL algorithms: **Proximal Policy Optimization (PPO)**, **Soft Actor-Critic (SAC)**, and **Twin Delayed DDPG (TD3)**. This allows us to understand the different emergent strategies and trade-offs of each approach.
11
+ * **Robust Benchmarking:** Agents' performance is rigorously compared against a standard **Buy and Hold** baseline, using a comprehensive set of financial metrics including Total Return, CAGR, Sharpe Ratio, Sortino Ratio, and Max Drawdown.
12
+ * **Modular Codebase:** The project has been refactored into a clean, modular structure with separate scripts for data fetching, training, evaluation, and visualization, making it easier to understand, extend, and reproduce results.
13
+ * **In-Depth Analysis:** We delve into *why* certain agents perform better, visualizing their asset allocation strategies over time to uncover their "investment philosophy."
14
+
15
+ * **Deep RL & LLM Portfolio Manager (Web App):** A key feature of v1.0 is the interactive web application built with **Gradio**. This dashboard bridges the gap between complex backend models and user-friendly analysis, allowing for live tracking, forward-looking strategy generation, and historical backtesting.
16
+ The dashboard integrates **Large Language Models (LLMs)**, specifically Qwen, to act as an AI Risk Analyst, providing textual justification and risk assessments for the RL agent's proposed strategies.
17
+
18
+ *You can try the web app here ->* [Gradio webapp](https://huggingface.co/spaces/DanielKiani/Portfolio-Optimization-with-Deep-Reinforcement-Learning)
19
 
20
  ---
21
 
 
23
 
24
  1. [📊 The Data & Asset Selection](#-the-data--asset-selection)
25
  2. [🎯 Benchmarking Against Baselines](#-benchmarking-against-baselines)
26
+ 3. [🏆 Key Findings & The New Champion](#-key-findings--the-new-champion)
27
+ 4. [🔬 The Research Journey: Why Simplicity Won](#-the-research-journey-why-simplicity-won)
28
+ 5. [🖥️ Deep RL & LLM Portfolio Manager (Web App)](#️-deep-rl--llm-portfolio-manager-web-app)
29
+ 6. [✅ Conclusion](#-conclusion)
30
+ 7. [📂 Project Structure](#-project-structure)
31
+ 8. [🚀 How to Run](#-how-to-run)
 
 
 
 
32
  * [Setup](#setup)
33
  * [Data Fetching](#data-fetching)
34
  * [Training](#training)
 
38
 
39
  ## 📊 The Data & Asset Selection
40
 
41
+ The foundation of any financial machine learning project is the data. The primary source for daily closing price data of the portfolio assets is **Yahoo Finance**, accessed via the `yfinance` library.
42
+
43
+ To provide the agents with broader economic context beyond just price history, the observation space is augmented with key macroeconomic indicators sourced from **FRED (Federal Reserve Economic Data)**. These indicators include data points such as the CBOE Volatility Index (VIX), various Treasury bill yields, and inflation expectations. This allows the agents to learn strategies that adapt to different market regimes, such as high volatility or rising interest rate environments.
44
+
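+ As a rough illustration, macro series like these can be pulled from FRED with `pandas-datareader` (listed in `requirements.txt`); the series IDs below are common FRED codes and are only an assumption about what the project's `fetch_market_data.py` actually requests:
+
+ ```python
+ import pandas_datareader.data as web
+
+ # Assumed FRED series: VIX, 3-month T-bill yield, 5-year breakeven inflation
+ FRED_IDS = {"VIXCLS": "VIX", "DTB3": "TBill_3M", "T5YIE": "Inflation_5Y"}
+
+ macro = web.DataReader(list(FRED_IDS.keys()), "fred", "2015-01-01", "2020-12-31")
+ macro = macro.rename(columns=FRED_IDS).ffill()  # forward-fill non-trading days
+ print(macro.tail())
+ ```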
45
+ **Environment & Realistic Constraints:**
46
+ To ensure realistic simulation results, the trading environment incorporates transaction costs.
47
+ * **Transaction Cost:** A fee of **0.001%** is applied to the notional value of every trade (both buys and sells). This forces the agents to learn strategies that generate returns net of fees, discouraging excessive, unprofitable trading.
48
+
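+ To make the fee above concrete, here is a minimal sketch of how a proportional transaction cost can be charged on each rebalance; the function and variable names are illustrative, and the project's `environment.py` may implement this differently:
+
+ ```python
+ import numpy as np
+
+ FEE_RATE = 0.001 / 100  # 0.001% of traded notional, as quoted above
+
+ def apply_rebalance(portfolio_value, current_weights, target_weights, fee_rate=FEE_RATE):
+     """Return the portfolio value after paying fees on the traded notional."""
+     turnover = np.abs(np.asarray(target_weights) - np.asarray(current_weights)).sum()
+     traded_notional = turnover * portfolio_value  # dollars bought plus dollars sold
+     return portfolio_value - traded_notional * fee_rate
+
+ # Moving 20% of a $10,000 portfolio from one asset to another trades $4,000 of notional
+ print(apply_rebalance(10_000, [0.5, 0.5], [0.3, 0.7]))
+ ```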
49
+ The portfolio itself consists of five assets, chosen to represent different asset classes and risk profiles, creating a challenging decision-making environment:
50
 
 
51
 
52
  * **Growth Equities (AAPL, MSFT):** Represent the high-growth, high-volatility technology sector.
53
  * **Market Index (SPY):** An ETF tracking the S&P 500, representing the broader US stock market.
54
  * **Safe Haven (TLT):** An ETF for 20+ Year US Treasury Bonds, which often acts as a "risk-off" asset during stock market downturns.
55
  * **Alternative Asset (BTC-USD):** Represents a non-traditional, extremely volatile asset class with high potential returns.
56
 
57
+ This diverse mix forces the agent to learn not just about individual asset price movements, but also about their correlations and how to balance risk across different economic conditions.
58
 
59
  ---
60
 
 
70
 
71
  ---
72
 
73
+ ## 🏆 Key Findings & The New Champion
74
 
75
+ Our latest evaluation on out-of-sample data from **2021-2023** has yielded surprising and significant results, challenging our initial assumptions and highlighting the impact of neural network architecture on agent performance.
76
 
77
+ The **TD3 agent powered by a Transformer architecture** has emerged as the undisputed champion in terms of risk-adjusted returns and capital preservation, while the **SAC agent** demonstrated the highest absolute growth potential.
78
 
79
+ #### Final Performance Comparison (2021-2023)
80
 
81
+ This table summarizes the performance of our key agents against the Buy & Hold baseline.
 
 
 
 
 
82
 
83
+ | Metric | **TD3 (Transformer)** | SAC (MLP) | Buy & Hold | PPO (MLP) | TD3 (MLP) |
84
+ | :--- | :--- | :--- | :--- | :--- | :--- |
85
+ | **Total Return** | 25.34% | **39.23%** | 32.76% | 22.85% | 22.07% |
86
+ | **CAGR** | 8.20% | **12.25%** | 9.96% | 7.45% | 7.21% |
87
+ | **Sharpe Ratio** | **0.61** | 0.56 | 0.59 | 0.41 | 0.42 |
88
+ | **Volatility** | **14.77%** | 27.47% | 19.06% | 25.90% | 23.00% |
89
+ | **Max Drawdown** | **-20.01%** | -29.08% | -28.82% | -44.26% | -40.50% |
90
 
91
  ![Main Performance Chart](results/final_performance_comparison_all_agents.png)
92
+ **Note:** Bitcoin (BTC-USD) was excluded from the performance comparison.
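+ For reference, the risk metrics reported in the table are standard quantities that can be computed from a daily portfolio-value series roughly as follows (a sketch; the project's evaluation scripts may differ in the details):
+
+ ```python
+ import numpy as np
+ import pandas as pd
+
+ def summarize(values: pd.Series, periods_per_year: int = 252) -> dict:
+     """Headline metrics from a series of daily portfolio values."""
+     returns = values.pct_change().dropna()
+     total_return = values.iloc[-1] / values.iloc[0] - 1
+     cagr = (1 + total_return) ** (periods_per_year / len(returns)) - 1
+     volatility = returns.std() * np.sqrt(periods_per_year)
+     sharpe = np.sqrt(periods_per_year) * returns.mean() / returns.std()
+     max_drawdown = (values / values.cummax() - 1).min()
+     return {"Total Return": total_return, "CAGR": cagr, "Volatility": volatility,
+             "Sharpe Ratio": sharpe, "Max Drawdown": max_drawdown}
+ ```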
93
 
94
+ ### 🥇 TD3 (Transformer): The Master of Risk Management
95
 
96
+ The most notable finding is the superior performance of the TD3 agent when equipped with a **Transformer-based policy network**. This agent achieved the best risk-adjusted metrics across the board.
97
 
98
+ * **Lowest Volatility (14.77%):** It provided a significantly smoother ride than even the passive Buy & Hold baseline.
99
+ * **Best Capital Preservation:** Its maximum drawdown of **-20.01%** was far shallower than that of the other agents and the baseline, proving its ability to protect capital during severe market downturns such as the 2022 bear market.
100
+ * **Conclusion:** The Transformer's attention mechanism likely allowed the agent to better identify and react to long-term market shifts and regime changes, leading to a highly robust and defensive strategy.
101
 
102
+ ### 🚀 SAC (MLP): The Aggressive Growth Engine
103
 
104
+ The **Soft Actor-Critic (SAC)** agent confirmed its role as the high-growth strategist.
105
 
106
+ * **Highest Returns:** It achieved the highest Total Return (**39.23%**) and CAGR (**12.25%**), outperforming the Buy & Hold baseline by a significant margin.
107
+ * **Higher Risk:** This performance came at the cost of the highest volatility (**27.47%**), making it a strategy suited for aggressive investors willing to tolerate larger price swings for maximum gain.
 
108
 
109
+ ### 📉 The Failure of Standard Architectures
110
 
111
+ Interestingly, the standard Multi-Layer Perceptron (MLP) versions of PPO and TD3 failed to beat the simple Buy & Hold baseline. They suffered the lowest returns and the deepest drawdowns. This stark contrast with the success of the Transformer model highlights that for complex financial time series, **network architecture is just as critical as, if not more critical than, the choice of RL algorithm itself.**
112
 
113
+ ---
114
 
115
+ ## 🧠 Comparative Analysis of Agent Strategies
 
 
116
 
117
+ A fascinating outcome of this project was observing how different combinations of RL algorithms and network architectures led to distinct investment philosophies. We can visualize this by looking at how each agent allocated its portfolio over time.
118
 
119
+ ### TD3 (Transformer): The Dynamic Hedger
120
 
121
+ The Transformer-based TD3 agent did not learn a static allocation. Instead, it developed a sophisticated, **dynamic hedging strategy**. By leveraging the Transformer's attention mechanism to process the 30-day lookback window, the agent could identify market trends and adapt its portfolio accordingly.
122
 
123
+ ![TD3 Transformer Allocation Chart](results/td3_transformer_allocation.png)
 
 
124
 
125
+ As shown in the chart, the agent maintains a core position in equities (AAPL, MSFT, SPY) but actively manages its exposure. During the volatile bear market of 2022, the agent significantly increased its allocation to the safe-haven asset **TLT (US Treasury Bonds)**, effectively "smoothing out" its equity curve and avoiding the deep losses suffered by the baseline. This ability to dynamically shift into defensive assets is the key to its superior risk-adjusted performance.
126
 
127
+ ### SAC (MLP): The High-Conviction Aggressor
128
+
129
+ The SAC agent learned a strategy that is nearly the polar opposite of the Transformer. It converged to a **high-risk, high-return static allocation strategy**. Its portfolio is heavily weighted towards high-growth assets, likely with a substantial allocation to Bitcoin (BTC-USD) and tech stocks, with very little exposure to defensive assets like bonds or cash.
130
 
131
+ ![SAC Allocation Chart](results/sac_allocation.png)
132
 
133
+ The allocation chart reveals a strategy with minimal changes over time, indicating a "set-and-forget" approach. While this high-conviction bet paid off with the highest total return, it also exposed the portfolio to significant volatility.
134
 
135
+ ### PPO (MLP): The Failed Active Trader
136
 
137
+ Unlike the other MLP-based agents, which converged to static allocations, the PPO agent attempted a **dynamic, active trading strategy**.
 
138
 
139
+ ![PPO Allocation Chart](results/ppo_allocation.png)
140
+
141
+ As seen in the chart, the agent frequently rebalances its portfolio, shifting weights between equities, bonds, and cash. However, the performance metrics indicate that this activity was detrimental. With poor returns and the deepest maximum drawdown (-44.26%) among all agents, the PPO agent's attempts at market timing were unsuccessful, churning the portfolio without generating alpha or managing risk.
142
+
143
+ ### TD3 (MLP): The Failed Static Allocator
144
+
145
+ The standard MLP version of the TD3 agent also converged to a static allocation, similar to the SAC agent, but chose a clearly suboptimal portfolio.
146
+
147
+ ![TD3 MLP Allocation Chart](results/td3_allocation.png)
148
+
149
+ The chart shows a relatively fixed allocation that failed to perform well. Unlike the SAC agent, it did not capture high-growth opportunities, and unlike the Transformer agent, it lacked the dynamic capability to manage risk. This resulted in near-bottom performance across all metrics.
150
 
151
  ---
152
 
 
158
  * **Hypothesis 2: Models with memory are better.** We tested an LSTM-based agent (`RecurrentPPO`). **Result:** Performance degraded. The added complexity led to overfitting on the training data.
159
  * **Hypothesis 3: Using Regularization is better.** We tested both L1 and L2 regularization. **Results:** Performance degraded.
160
  * **Hypothesis 4: Increasing the window from 30 days is better.** We tested increasing the window to 60 days. **Results:** Performance degraded. Increasing the context window is not always beneficial; the longer history appears to act as extra noise for the model.
161
+ * **Hypothesis 5: A Transformer-based architecture is superior.** We replaced the standard Multi-Layer Perceptron (MLP) policy network with a more powerful Transformer model, hypothesizing its attention mechanism would better capture complex temporal relationships. **Result**: Performance degraded. Similar to the LSTM experiment, the Transformer model was too complex for the amount of data available. It suffered from significant overfitting, performing well on training data but failing to generalize to unseen market scenarios.
162
+
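+ For context, swapping in a custom architecture like the Transformer in Hypothesis 5 is typically done in Stable-Baselines3 through a custom features extractor passed via `policy_kwargs`; the sketch below is illustrative only, and the project's `custom_policy.py` may be structured differently:
+
+ ```python
+ import torch
+ import torch.nn as nn
+ from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
+
+ class TransformerExtractor(BaseFeaturesExtractor):
+     """Attention-based extractor over a flattened (window x features) observation."""
+     def __init__(self, observation_space, window_size=30, n_features=11, d_model=64):
+         super().__init__(observation_space, features_dim=d_model)
+         self.window_size, self.n_features = window_size, n_features
+         self.embed = nn.Linear(n_features, d_model)
+         layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True)
+         self.encoder = nn.TransformerEncoder(layer, num_layers=2)
+
+     def forward(self, observations: torch.Tensor) -> torch.Tensor:
+         x = observations.view(-1, self.window_size, self.n_features)  # (batch, time, features)
+         return self.encoder(self.embed(x))[:, -1, :]  # representation of the last time step
+
+ # model = TD3("MlpPolicy", env, policy_kwargs=dict(features_extractor_class=TransformerExtractor))
+ ```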
163
+ The conclusion was clear: a simple MLP (Multi-Layer Perceptron) policy network, fed with just normalized price data and a concise 30-day window, was the most effective and robust architecture.
164
+
165
+ ---
166
+
167
+ ## 🖥️ Deep RL & LLM Portfolio Manager (Web App)
168
+
169
+ A key feature of v1.0 is the interactive web application built with **Gradio**. This dashboard bridges the gap between complex backend models and user-friendly analysis, allowing for live tracking, forward-looking strategy generation, and historical backtesting.
170
+
171
+ The dashboard integrates **Large Language Models (LLMs)**, specifically Qwen, to act as an AI Risk Analyst, providing textual justification and risk assessments for the RL agent's proposed strategies.
172
+
173
+ ### Key Features:
174
+
175
+ #### 1. Live Dashboard & Net Worth Tracking
176
+
177
+ Track the current portfolio holdings, recent transactions, and the overall net worth evolution in real-time.
178
 
179
+ ![Live Dashboard](results/tab1.png)
180
+
181
+ #### 2. AI-Powered Strategy Forecast & Risk Analysis
182
+
183
+ Generate tomorrow's optimal portfolio allocation using the trained RL agents. The integrated LLM analyzes the proposed allocation, current market volatility (VIX), and asset concentration to provide a comprehensive **Risk Analyst Report** with a confidence score and justifications.
184
+
185
+ It also includes **Explainable AI (XAI)** feature importance plots to show which market factors most influenced the agent's decision.
186
+
187
+ ![AI Forecast and Risk Analysis](assets/tab2.png)
188
+
189
+ #### 3. Historical Simulation & Backtesting
190
+
191
+ Run dynamic backtests of the trained RL agents against baselines over any historical period. This tool is essential for validating performance across different market cycles.
192
+
193
+ ![Historical Simulation](assets/tab2.png)
194
 
195
  ---
196
 
 
202
 
203
  ## 📂 Project Structure
204
 
 
 
205
  ```bash
206
+ ├── assets/             # Images for the README
207
+ ├── checkpoints/        # Stores trained model weights (.zip files)
208
+ ├── data/               # Stores fetched CSV data files
209
+ ├── results/            # Stores generated plots and metrics logs
210
+ ├── scripts/            # Contains all the executable scripts
211
+ │   ├── app.py                # The Gradio web application
212
+ │   ├── check_env.py          # Simple script to verify the custom environment
213
+ │   ├── custom_policy.py      # Custom policy network definitions
214
+ │   ├── environment.py        # The custom Gymnasium environment class
215
+ │   ├── evaluate_baselines.py # Calculates performance of baseline strategies
216
+ │   ├── evaluate.py         # Main script to evaluate a trained agent
217
+ │   ├── fetch_market_data.py # Script to download historical data from YFinance
218
+ │   ├── llm_analysis_rag.py # Script for LLM-based analysis and RAG
219
+ │   ├── predict_tomorrow.py # Script to generate predictions for the next day
220
+ │   ├── stress_test.py      # Compares all agents on a specific dataset
221
+ │   ├── train.py            # Main script to train an RL agent
222
+ │   ├── tune_sac.py         # Script for hyperparameter tuning of the SAC agent
223
+ │   └── visualize_strategy.py # Plots the asset allocation of a trained agent
224
+ ├── requirements.txt    # List of Python dependencies
225
+ └── README.md           # This file
226
  ```
227
 
 
 
228
  ## 🚀 How to Run
229
 
230
  ### Setup
231
 
232
+ 1. Clone the repository:
 
 
233
 
234
+ ```bash
 
 
235
 
236
+ git clone https://github.com/DanielKiani/Portfolio-Optimization-with-Deep-Reinforcement-Learning
237
+ ```
238
 
239
+ 2. Install the required packages:
240
 
241
  ```bash
242
+ pip install -r requirements.txt
243
+ ```
244
+
245
+ ### Data Fetching
246
+
247
+ Before training or evaluation, you need to download the historical market data. Use the `fetch_market_data.py` script.
248
 
249
+ ```bash
250
+ # Fetch training data (e.g., 2015-2020)
251
+ python scripts/fetch_market_data.py --start 2015-01-01 --end 2020-12-31 --filename data/train_data.csv
252
+
253
+ # Fetch evaluation data (e.g., 2021-2023)
254
+ python scripts/fetch_market_data.py --start 2021-01-01 --end 2023-12-31 --filename data/eval_data.csv
255
  ```
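+ Under the hood the script relies on `yfinance`; fetching the five portfolio assets yourself looks roughly like this (a sketch, not the script's exact code):
+
+ ```python
+ import yfinance as yf
+
+ ASSETS = ["AAPL", "MSFT", "SPY", "TLT", "BTC-USD"]
+
+ # Daily closing prices for the training period
+ prices = yf.download(ASSETS, start="2015-01-01", end="2020-12-31")["Close"]
+ prices.dropna(how="all").to_csv("data/train_data.csv")
+ ```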
256
 
257
  ### Training
258
 
259
+ Use the `train.py` script to train an agent. You can specify the algorithm (ppo, sac, or td3) and the number of training timesteps.
260
 
261
+ ```bash
262
+ # Train a TD3 agent (default timesteps: 20000)
263
+ python scripts/train.py --agent td3 --datafile data/train_data.csv
264
 
265
  # Train a SAC agent for more timesteps
266
+ python scripts/train.py --agent sac --datafile data/train_data.csv --timesteps 50000
267
  ```
268
 
269
+ The trained model will be saved in the `checkpoints/` directory (e.g., `sac_portfolio_model.zip`).
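+ In essence, `train.py` wraps the standard Stable-Baselines3 training loop; a stripped-down equivalent looks like the sketch below (the `PortfolioEnv` constructor arguments are assumed from how the environment is used elsewhere in the repo):
+
+ ```python
+ import pandas as pd
+ from stable_baselines3 import SAC
+ from environment import PortfolioEnv  # the project's custom Gymnasium environment
+
+ df = pd.read_csv("data/train_data.csv", index_col=0, parse_dates=True)
+ env = PortfolioEnv(df, 30, initial_amount=10_000)  # 30-day observation window
+
+ model = SAC("MlpPolicy", env, verbose=1)
+ model.learn(total_timesteps=50_000)
+ model.save("checkpoints/sac_portfolio_model.zip")
+ ```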
270
+
271
  ### Evaluation & Visualization
272
 
273
+ Once you have trained models and evaluation data, you can use the other scripts to analyze performance.
274
 
275
+ * **Compare all agents** (`stress_test.py`): This script loads all available models in `checkpoints/` and compares them against the baseline on a given dataset.
276
+
277
+ ```bash
278
+ python scripts/stress_test.py --datafile data/eval_data.csv
279
+ ```
280
+
281
+ This will generate `results/agent_performance_comparison.png` and print a metrics table.
282
 
283
+ * **Evaluate a single agent** (`evaluate.py`): This script calculates detailed metrics for a specific agent and plots its portfolio value.
284
+
285
+ ```bash
286
+ python scripts/evaluate.py --agent td3 --checkpoint checkpoints/td3_portfolio_model.zip --datafile data/eval_data.csv
287
  ```
288
+
289
+ * **Visualize an agent's strategy** (`visualize_strategy.py`): This script creates a stacked area chart showing how the agent's asset allocation changed over time.
290
+
291
+ ```bash
292
+ python scripts/visualize_strategy.py --agent ppo --checkpoint checkpoints/ppo_portfolio_model.zip --datafile data/eval_data.csv
293
+ ```
294
+
295
+ This will save the plot to `results/ppo_portfolio_allocation.png`.
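+ For completeness, evaluating a saved checkpoint reduces to loading the model and stepping the environment with a deterministic policy, collecting the portfolio value from the step info (a sketch mirroring the evaluation code in `scripts/app.py`):
+
+ ```python
+ import pandas as pd
+ from stable_baselines3 import SAC
+ from environment import PortfolioEnv
+
+ eval_df = pd.read_csv("data/eval_data.csv", index_col=0, parse_dates=True)
+ env = PortfolioEnv(eval_df, 30, initial_amount=10_000)
+ model = SAC.load("checkpoints/sac_portfolio_model.zip")
+
+ obs, info = env.reset()
+ values, done = [10_000], False
+ while not done:
+     action, _ = model.predict(obs, deterministic=True)  # deterministic policy for evaluation
+     obs, reward, terminated, truncated, info = env.step(action)
+     values.append(info["portfolio_value"])
+     done = terminated or truncated
+ ```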
requirements.txt CHANGED
@@ -1,19 +1,28 @@
1
- # Core RL and Simulation
2
- stable-baselines3==2.7.0
3
- sb3_contrib==2.7.0
4
- gymnasium==1.2.1
5
-
6
- # Data Handling and Numerics
7
- pandas==2.3.3
8
  numpy==2.2.6
9
- scikit-learn==1.6.1
 
10
 
11
- # Data Fetching
 
 
 
 
12
  yfinance==0.2.66
 
 
 
 
 
 
 
13
 
14
- # Financial Indicators
15
- pandas-ta==0.4.71b0
16
 
17
- # Plotting and Visualization
18
- matplotlib==3.10.0
19
- seaborn==0.13.2
 
 
 
 
1
+ # Core Data Science & Mathematics
 
 
 
 
 
 
2
  numpy==2.2.6
3
+ pandas==2.3.3
4
+ scipy==1.16.3
5
 
6
+ # Visualization
7
+ matplotlib==3.10.0
8
+ plotly==5.24.1
9
+
10
+ # Financial Data
11
  yfinance==0.2.66
12
+ pandas-datareader==0.10.0
13
+
14
+ # Reinforcement Learning
15
+ gymnasium==1.2.2
16
+ shimmy==2.0.0
17
+ stable-baselines3==2.7.0
18
+ sb3-contrib==2.7.0
19
 
20
+ # Deep Learning Framework
21
+ torch==2.9.0
22
 
23
+ # Utilities & Other
24
+ gradio==5.50.0
25
+ python-dotenv==1.2.1
26
+ tabulate==0.9.0
27
+ quantstats==0.0.62
28
+ pandas-ta==0.4.71b0
results/final_performance_comparison_all_agents.png CHANGED

Git LFS Details

  • SHA256: 75b085d93cd947906c2f6f5fdf4f1fbc0b53cdb56d2ad3a77011b6cae8c787ab
  • Pointer size: 131 Bytes
  • Size of remote file: 225 kB

Git LFS Details

  • SHA256: 0fd8cba927f0f2bed5373fb5fb44bf20e2ed4d22219196e48763fdd1f5f6787d
  • Pointer size: 131 Bytes
  • Size of remote file: 298 kB
results/{td3_portfolio_alocation.png → ppo_allocation.png} RENAMED
File without changes
results/sac_allocation.png ADDED

Git LFS Details

  • SHA256: 6975e520974d024567499f1605269abccdc2f56f04dd1ab97b32df85fc78f27e
  • Pointer size: 131 Bytes
  • Size of remote file: 335 kB
results/{ppo_portfolio_alocation.png → td3_allocation.png} RENAMED
File without changes
results/{sac_portfolio_alocation.png → td3_transformer_allocation.png} RENAMED
File without changes
scripts/app.py ADDED
@@ -0,0 +1,662 @@
1
+ # scripts/app.py
2
+
3
+ import gradio as gr
4
+ import pandas as pd
5
+ import numpy as np
6
+ import plotly.graph_objects as go
7
+ import plotly.express as px
8
+ from datetime import datetime, timedelta
9
+ import os
10
+ import sys
11
+ import json
12
+ import torch
13
+ from fetch_market_data import fetch_market_data, ASSETS, FRED_IDS
14
+ from llm_analysis_rag import analyze_agent_decision, analyze_historical_segment
15
+ from stable_baselines3 import SAC
16
+ from environment import PortfolioEnv
17
+ from evaluate_baselines import buy_and_hold, equally_weighted_rebalanced
18
+
19
+ # --- Configuration ---
20
+ # Resolve the project root from this script's location (assumes the layout <repo>/scripts/app.py)
+ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ MODEL_PATH = os.path.join(project_root, "checkpoints", "sac_portfolio_model.zip")
21
+ WINDOW_SIZE = 30
22
+ MACRO_COLS = list(FRED_IDS.values())
23
+ DASHBOARD_DATA_PATH = os.path.join(project_root, "data", "historical_dashboard_data.csv")
24
+
25
+ # *** UPDATE THESE DATES TO MATCH YOUR ACTUAL TRAINING PERIOD ***
26
+ TRAIN_START_DATE = "2015-01-01"
27
+ TRAIN_END_DATE = "2023-01-01"
28
+
29
+ # Global variable for dashboard data needed for Tabs 3 & 4
30
+ DASHBOARD_DATA_DF = None
31
+
32
+ # Define Time Period mappings for the dropdown
33
+ TIME_PERIODS = {
34
+ "6 Months": 180,
35
+ "1 Year": 365,
36
+ "2 Years": 730,
37
+ "5 Years": 1825,
38
+ "Max Available": 9999 # Sentinel value for max
39
+ }
40
+
41
+ # =========================================
42
+ # Initialization Functions
43
+ # =========================================
44
+
45
+ def initialize_dashboard_data():
46
+ """Fetches and loads historical data at startup for Tabs 3 & 4."""
47
+ global DASHBOARD_DATA_DF
48
+ print("--- Initializing Historical Data for Analyst/Simulation Tabs ---")
49
+
50
+ # Fetching last 6 years to support longer analysis periods and simulation
51
+ end_date = (datetime.now() - timedelta(days=1)).strftime('%Y-%m-%d')
52
+ start_date = (datetime.now() - timedelta(days=365*6)).strftime('%Y-%m-%d')
53
+
54
+ print(f"Fetching historical data from {start_date} to {end_date}...")
55
+ # This might take a minute on first run
56
+ fetch_market_data(start_date, end_date, DASHBOARD_DATA_PATH)
57
+
58
+ if os.path.exists(DASHBOARD_DATA_PATH):
59
+ DASHBOARD_DATA_DF = pd.read_csv(DASHBOARD_DATA_PATH, index_col=0, parse_dates=True)
60
+ # Basic cleaning
61
+ DASHBOARD_DATA_DF.dropna(how='all', inplace=True)
62
+ # Calculate equal weight return for dashboard metrics
63
+ asset_cols = [c for c in ASSETS if c in DASHBOARD_DATA_DF.columns]
64
+ if asset_cols:
65
+ DASHBOARD_DATA_DF['Daily_Ret_Eq'] = DASHBOARD_DATA_DF[asset_cols].pct_change().mean(axis=1)
66
+ print(f"Data loaded successfully. Shape: {DASHBOARD_DATA_DF.shape}")
67
+ print(f"Data range: {DASHBOARD_DATA_DF.index.min().date()} to {DASHBOARD_DATA_DF.index.max().date()}")
68
+ else:
69
+ print("❌ Failed to initialize historical data.")
70
+
71
+ # Initialize data at startup
72
+ try:
73
+ initialize_dashboard_data()
74
+ except Exception as e:
75
+ print(f"Warning: Data initialization failed. Error: {e}")
76
+
77
+
78
+ # =========================================
79
+ # Professional Metrics & Evaluation Functions
80
+ # =========================================
81
+
82
+ def evaluate_agent_pro(env, model):
83
+ """
84
+ Runs the trained agent on the environment and returns portfolio values.
85
+ """
86
+ obs, info = env.reset()
87
+ terminated, truncated = False, False
88
+ portfolio_values = [env.initial_amount]
89
+
90
+ while not (terminated or truncated):
91
+ action, _states = model.predict(obs, deterministic=True)
92
+ obs, reward, terminated, truncated, info = env.step(action)
93
+ portfolio_values.append(info['portfolio_value'])
94
+
95
+ # Align index with the actual steps taken
96
+ valid_dates = env.df.index[env.window_size-1:]
97
+ return pd.Series(portfolio_values, index=valid_dates[:len(portfolio_values)])
98
+
99
+ def calculate_metrics_pro(portfolio_values, freq=252, rf=0.0):
100
+ """
101
+ Calculates key professional performance metrics from a series of portfolio values.
102
+ """
103
+ if len(portfolio_values) < 2:
104
+ return {k: "N/A" for k in ["Total Return", "CAGR", "Sharpe Ratio", "Sortino Ratio", "Volatility", "Max Drawdown", "Calmar Ratio"]}
105
+
106
+ returns = portfolio_values.pct_change().dropna()
107
+ if returns.empty:
108
+ return {k: "0.00%" if "%" in k else "0.00" for k in ["Total Return", "CAGR", "Sharpe Ratio", "Sortino Ratio", "Volatility", "Max Drawdown", "Calmar Ratio"]}
109
+
110
+ total_return = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) - 1
111
+ num_years = (len(portfolio_values) - 1) / freq
112
+ cagr = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) ** (1/num_years) - 1 if num_years > 0 else 0.0
113
+
114
+ sharpe_ratio = np.sqrt(freq) * (returns.mean() - rf) / returns.std() if returns.std() > 0 else np.nan
115
+
116
+ downside_returns = returns[returns < 0]
117
+ downside_std = downside_returns.std()
118
+ sortino_ratio = np.sqrt(freq) * (returns.mean() - rf) / downside_std if downside_std > 0 else np.nan
119
+
120
+ volatility = returns.std() * np.sqrt(freq)
121
+
122
+ rolling_max = portfolio_values.cummax()
123
+ drawdown = portfolio_values / rolling_max - 1.0
124
+ max_drawdown = drawdown.min()
125
+
126
+ calmar_ratio = cagr / abs(max_drawdown) if max_drawdown != 0 and cagr != 0 else np.nan
127
+
128
+ return {
129
+ "Total Return": total_return,
130
+ "CAGR": cagr,
131
+ "Sharpe Ratio": sharpe_ratio,
132
+ "Sortino Ratio": sortino_ratio,
133
+ "Volatility": volatility,
134
+ "Max Drawdown": max_drawdown,
135
+ "Calmar Ratio": calmar_ratio
136
+ }
137
+
138
+ # =========================================
139
+ # XAI: Feature Importance Function
140
+ # =========================================
141
+ def calculate_feature_importance(model, obs):
142
+ """
143
+ Calculates feature importance using Integrated Gradients on the RL agent's policy network.
144
+ """
145
+ # Convert observation to torch tensor and enable gradient tracking
146
+ obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=model.device)
147
+ obs_tensor.requires_grad_()
148
+
149
+ # Get the policy network (actor)
150
+ actor = model.policy.actor
151
+
152
+ # Define a baseline (e.g., a zero observation)
153
+ baseline = torch.zeros_like(obs_tensor)
154
+
155
+ # Number of steps for integral approximation
156
+ steps = 50
157
+
158
+ # Generate scaled inputs along the path from baseline to input
159
+ scaled_inputs = [baseline + (float(i) / steps) * (obs_tensor - baseline) for i in range(steps + 1)]
160
+
161
+ grads = []
162
+ for scaled_input in scaled_inputs:
163
+ # Forward pass to get action distribution parameters (mean)
164
+ action_mean = actor(scaled_input)
165
+
166
+ # We need a scalar output to calculate gradients against.
167
+ # Here we sum, representing overall sensitivity of the action vector.
168
+ target_output = action_mean.sum()
169
+
170
+ # Calculate gradients of the target output with respect to the input features
171
+ grad = torch.autograd.grad(outputs=target_output, inputs=scaled_input)[0]
172
+ grads.append(grad)
173
+
174
+ # Average the gradients using the trapezoidal rule approximation
175
+ avg_grads = [(g_prev + g_next) / 2.0 for g_prev, g_next in zip(grads[:-1], grads[1:])]  # element-wise trapezoid average of successive gradients
176
+ avg_grads = torch.stack(avg_grads).mean(dim=0)
177
+
178
+ # Calculate Integrated Gradients: (input - baseline) * average_gradients
179
+ integrated_grads = (obs_tensor - baseline) * avg_grads
180
+
181
+ # Detach, move to cpu, and convert to numpy array
182
+ importance_scores = integrated_grads.detach().cpu().numpy().flatten()
183
+
184
+ # Feature Names mapping
185
+ num_assets = len(ASSETS)
186
+ num_macro = len(MACRO_COLS)
187
+
188
+ # Create feature names based on the observation structure
189
+ feature_names = []
190
+ for i in range(WINDOW_SIZE):
+ # Match the row-major flatten order of the observation: each time step lists asset prices first, then macro indicators
+ for name in ASSETS + MACRO_COLS:
+ feature_names.append(f"{name}_t-{WINDOW_SIZE-1-i}")
196
+
197
+ # Combine into a dictionary and sort by absolute importance
198
+ feature_importance_dict = dict(zip(feature_names, importance_scores))
199
+
200
+ # Aggregate importance by feature type (sum of absolute values across time steps)
201
+ aggregated_importance = {}
202
+ for base_feature in ASSETS + MACRO_COLS:
203
+ total_imp = sum(abs(val) for key, val in feature_importance_dict.items() if key.startswith(base_feature))
204
+ aggregated_importance[base_feature] = total_imp
205
+
206
+ # Sort and take top N for display
207
+ top_features = dict(sorted(aggregated_importance.items(), key=lambda item: item[1], reverse=True)[:8])
208
+
209
+ # Create a Plotly bar chart
210
+ fig = px.bar(
211
+ x=list(top_features.values()),
212
+ y=list(top_features.keys()),
213
+ orientation='h',
214
+ title="Top Influential Features (XAI)",
215
+ labels={'x': 'Relative Importance Score', 'y': 'Feature'},
216
+ color=list(top_features.values()),
217
+ color_continuous_scale=px.colors.sequential.Viridis
218
+ )
219
+ fig.update_layout(
220
+ template="plotly_dark",
221
+ paper_bgcolor='rgba(0,0,0,0)',
222
+ plot_bgcolor='rgba(0,0,0,0)',
223
+ yaxis={'categoryorder':'total ascending'},
224
+ coloraxis_showscale=False,
225
+ margin=dict(l=10, r=10, t=40, b=10),
226
+ height=300 # Keep it compact
227
+ )
228
+
229
+ return fig
230
+
231
+ # =========================================
232
+ # Tab 4 Logic: Historical Simulation (UPDATED)
233
+ # =========================================
234
+
235
+ def run_historical_simulation(start_date_str, end_date_str):
236
+ """
237
+ Runs the RL agent on historical data and compares to baselines using professional metrics.
238
+ """
239
+ if DASHBOARD_DATA_DF is None:
240
+ return go.Figure(), "Data not initialized. Please restart app.", gr.update(visible=False)
241
+
242
+ status_msg = "Preparing simulation..."
243
+ yield go.Figure(), status_msg, gr.update(visible=False)
244
+
245
+ try:
246
+ # 1. Validate and Slice Data
247
+ try:
248
+ start_date = pd.to_datetime(start_date_str)
249
+ end_date = pd.to_datetime(end_date_str)
250
+ except ValueError:
251
+ yield go.Figure(), "Error: Invalid date format. Use YYYY-MM-DD.", gr.update(visible=False)
252
+ return
253
+
254
+ if start_date < DASHBOARD_DATA_DF.index.min() or end_date > DASHBOARD_DATA_DF.index.max():
255
+ avail_start = DASHBOARD_DATA_DF.index.min().date()
256
+ avail_end = DASHBOARD_DATA_DF.index.max().date()
257
+ yield go.Figure(), f"Error: Selected dates outside available range ({avail_start} to {avail_end}).", gr.update(visible=False)
258
+ return
259
+
260
+ df_slice = DASHBOARD_DATA_DF.loc[start_date:end_date].copy()
261
+ asset_cols_only = [c for c in ASSETS if c in df_slice.columns]
262
+
263
+ if len(df_slice) < WINDOW_SIZE + 10:
264
+ yield go.Figure(), "Error: Time period too short for simulation.", gr.update(visible=False)
265
+ return
266
+
267
+ # 2. Setup Environment and Agent
268
+ status_msg = "Running RL Agent simulation..."
269
+ yield go.Figure(), status_msg, gr.update(visible=False)
270
+
271
+ env = PortfolioEnv(df_slice, WINDOW_SIZE, initial_amount=10000)
272
+
273
+ if not os.path.exists(MODEL_PATH):
274
+ raise FileNotFoundError(f"Model not found: {MODEL_PATH}")
275
+ model = SAC.load(MODEL_PATH)
276
+
277
+ # 3. Run Simulation Loop & Get Values using Pro Function
278
+ rl_portfolio_series = evaluate_agent_pro(env, model)
279
+
280
+ # 4. Calculate Baselines using Pro Functions
281
+ status_msg = "Calculating baselines and metrics..."
282
+ yield go.Figure(), status_msg, gr.update(visible=False)
283
+
284
+ # Pass only asset columns to baseline functions
285
+ bnh_portfolio_series = buy_and_hold(df_slice[asset_cols_only], initial_amount=10000)
286
+ # Realign B&H index to match RL agent's start date
287
+ bnh_portfolio_series = bnh_portfolio_series.loc[rl_portfolio_series.index[0]:]
288
+ # Normalize B&H starting value to match RL agent's start
289
+ bnh_portfolio_series = bnh_portfolio_series / bnh_portfolio_series.iloc[0] * 10000
290
+
291
+ eq_portfolio_series = equally_weighted_rebalanced(df_slice[asset_cols_only], initial_amount=10000)
292
+ eq_portfolio_series = eq_portfolio_series.loc[rl_portfolio_series.index[0]:]
293
+ eq_portfolio_series = eq_portfolio_series / eq_portfolio_series.iloc[0] * 10000
294
+
295
+ # 5. Generate Plot
296
+ fig = go.Figure()
297
+ fig.add_trace(go.Scatter(x=rl_portfolio_series.index, y=rl_portfolio_series, mode='lines', name='RL Agent (SAC)', line=dict(color='#10b981', width=3)))
298
+ fig.add_trace(go.Scatter(x=bnh_portfolio_series.index, y=bnh_portfolio_series, mode='lines', name='Buy & Hold (SPY)', line=dict(color='#6b7280', dash='dash')))
299
+ fig.add_trace(go.Scatter(x=eq_portfolio_series.index, y=eq_portfolio_series, mode='lines', name='Equal Weighted', line=dict(color='#a855f7', dash='dot')))
300
+
301
+ fig.update_layout(
302
+ title="Simulation: Strategy Performance Comparison ($10k Start)",
303
+ xaxis_title="Date",
304
+ yaxis_title="Portfolio Value ($)",
305
+ template="plotly_dark",
306
+ paper_bgcolor='rgba(0,0,0,0)',
307
+ plot_bgcolor='rgba(0,0,0,0)',
308
+ hovermode="x unified",
309
+ legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
310
+ )
311
+
312
+ # 6. Calculate Professional Metrics Table
313
+ rl_m = calculate_metrics_pro(rl_portfolio_series)
314
+ bnh_m = calculate_metrics_pro(bnh_portfolio_series)
315
+ eq_m = calculate_metrics_pro(eq_portfolio_series)
316
+
317
+ # Helper to format based on metric type
318
+ def fmt(val, is_pct=True):
319
+ if pd.isna(val): return "N/A"
320
+ return f"{val:.2%}" if is_pct else f"{val:.2f}"
321
+
322
+ metrics_data = {
323
+ "Metric": ["Total Return", "CAGR", "Sharpe Ratio", "Sortino Ratio", "Volatility (Ann.)", "Max Drawdown", "Calmar Ratio"],
324
+ "RL Agent (SAC)": [fmt(rl_m["Total Return"]), fmt(rl_m["CAGR"]), fmt(rl_m["Sharpe Ratio"], False), fmt(rl_m["Sortino Ratio"], False), fmt(rl_m["Volatility"]), fmt(rl_m["Max Drawdown"]), fmt(rl_m["Calmar Ratio"], False)],
325
+ "Buy & Hold (SPY)": [fmt(bnh_m["Total Return"]), fmt(bnh_m["CAGR"]), fmt(bnh_m["Sharpe Ratio"], False), fmt(bnh_m["Sortino Ratio"], False), fmt(bnh_m["Volatility"]), fmt(bnh_m["Max Drawdown"]), fmt(bnh_m["Calmar Ratio"], False)],
326
+ "Equal Weighted": [fmt(eq_m["Total Return"]), fmt(eq_m["CAGR"]), fmt(eq_m["Sharpe Ratio"], False), fmt(eq_m["Sortino Ratio"], False), fmt(eq_m["Volatility"]), fmt(eq_m["Max Drawdown"]), fmt(eq_m["Calmar Ratio"], False)],
327
+ }
328
+ metrics_df = pd.DataFrame(metrics_data)
329
+
330
+ # Format the dataframe as a markdown table for cleaner display
331
+ metrics_md = metrics_df.to_markdown(index=False)
332
+ final_metrics_display = f"### 📊 Professional Performance Metrics\n\n{metrics_md}"
333
+
334
+ yield fig, "Simulation Complete.", final_metrics_display
335
+
336
+ except Exception as e:
337
+ import traceback
338
+ traceback.print_exc()
339
+ yield go.Figure(), f"Error during simulation: {str(e)}", gr.update(visible=False)
340
+
341
+
342
+ # =========================================
343
+ # Tab 3 Logic: Historical Data Analyst
344
+ # =========================================
345
+
346
+ def run_historical_analysis(selected_assets, period_name):
347
+ """Backend for Tab 3."""
348
+ if DASHBOARD_DATA_DF is None or not selected_assets:
349
+ return go.Figure(), "Please wait for data initialization or select assets."
350
+
351
+ status_html = """<div style="color: #9ca3af;">🔄 Processing data and running AI analysis...</div>"""
352
+ yield go.Figure(), status_html
353
+
354
+ try:
355
+ # 1. Filter Data by Time Period
356
+ days = TIME_PERIODS.get(period_name, 365)
357
+ cutoff_date = datetime.now() - timedelta(days=days)
358
+ valid_assets = [a for a in selected_assets if a in DASHBOARD_DATA_DF.columns]
359
+ if not valid_assets:
360
+ yield go.Figure(), "Error: Selected assets not found in available data."
361
+ return
362
+ df_filtered = DASHBOARD_DATA_DF.loc[cutoff_date:, valid_assets].copy()
363
+ if df_filtered.empty:
364
+ yield go.Figure(), f"No data found for the selected period: {period_name}"
365
+ return
366
+
367
+ # 2. Generate Normalized Price Plot
368
+ df_normalized = df_filtered / df_filtered.iloc[0] * 100
369
+ fig = px.line(df_normalized, x=df_normalized.index, y=df_normalized.columns,
370
+ title=f"Performance Comparison: {period_name} (Base=100)",
371
+ color_discrete_sequence=px.colors.qualitative.Bold)
372
+ fig.update_layout(template="plotly_dark", paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)',
373
+ yaxis_title="Normalized Price", xaxis_title="Date", legend_title_text="", hovermode="x unified")
374
+
375
+ # 3. Run AI Analysis
376
+ analysis_text = analyze_historical_segment(df_filtered, valid_assets, period_name)
377
+ formatted_analysis = f"### 🤖 AI Analyst Report: {period_name}\n\n{analysis_text}"
378
+ yield fig, formatted_analysis
379
+
380
+ except Exception as e:
381
+ import traceback
382
+ traceback.print_exc()
383
+ yield go.Figure(), f"### Error during analysis\n\n{str(e)}"
384
+
385
+
386
+ # =========================================
387
+ # Tab 2 Logic: Forecast & Analysis (XAI)
388
+ # =========================================
389
+
390
+ def get_latest_data_window(window_size=30):
391
+ """Fetches latest data needed for prediction."""
392
+ print("Fetching prediction data...")
393
+ lookback_days = window_size + 150
394
+ end_date = datetime.now().strftime('%Y-%m-%d')
395
+ start_date = (datetime.now() - timedelta(days=lookback_days)).strftime('%Y-%m-%d')
396
+ temp_filename = os.path.join(project_root, "data", "temp_gradio_prediction_data.csv")
397
+ fetch_market_data(start_date, end_date, temp_filename)
398
+ if not os.path.exists(temp_filename): raise Exception("Failed to fetch market data file.")
399
+ df = pd.read_csv(temp_filename, index_col=0, parse_dates=True)
400
+ df.dropna(inplace=True)
401
+ if len(df) < window_size: raise Exception(f"Not enough clean data fetched for prediction.")
402
+ return df.iloc[-window_size:].copy()
403
+
404
+ def prepare_observation(data_window):
405
+ price_data = data_window[ASSETS].values
406
+ macro_data = data_window[MACRO_COLS].values
407
+ norm_prices = price_data / (price_data[0] + 1e-8)
408
+ norm_macro = macro_data / (macro_data[0] + 1e-8)
409
+ obs = np.concatenate([norm_prices, norm_macro], axis=1)
410
+ # Return both flattened obs for model and raw obs for XAI
411
+ return obs.flatten().astype(np.float32), obs.astype(np.float32), data_window
412
+
413
+ def predict_and_analyze():
414
+ """Main function for Forecast Tab."""
415
+ status_msg = "Starting process..."
416
+ loading_html = """<div style="color: #9ca3af;">🔄 Fetching data & running prediction...</div>"""
417
+ # Update to yield an empty plot for the XAI chart initially
418
+ yield status_msg, None, go.Figure(), loading_html
419
+
420
+ try:
421
+ data_window = get_latest_data_window(WINDOW_SIZE)
422
+ # Get flattened obs for prediction and raw obs for XAI
423
+ flat_obs, raw_obs, df_window_for_analyst = prepare_observation(data_window)
424
+
425
+ if not os.path.exists(MODEL_PATH): raise FileNotFoundError(f"Model not found: {MODEL_PATH}")
426
+ model = SAC.load(MODEL_PATH)
427
+
428
+ # --- XAI: Calculate Feature Importance ---
429
+ status_msg = "Calculating feature importance..."
430
+ yield status_msg, None, go.Figure(), loading_html
431
+ xai_plot = calculate_feature_importance(model, raw_obs)
432
+
433
+ # --- Prediction ---
434
+ action, _ = model.predict(flat_obs, deterministic=True)
435
+ exp_action = np.exp(np.asarray(action).flatten())
436
+ weights = exp_action / np.sum(exp_action)
437
+ allocations_dict = {asset: weights[i] for i, asset in enumerate(ASSETS)}
438
+ allocations_dict['Cash'] = weights[-1]
439
+ alloc_df = pd.DataFrame(list(allocations_dict.items()), columns=['Asset', 'Proposed Allocation'])
440
+ alloc_df['Proposed Allocation'] = alloc_df['Proposed Allocation'].apply(lambda x: f"{x:.2%}")
441
+
442
+ status_msg = "Prediction done. Running AI Risk Analysis..."
443
+ analysing_html = """<div style="color: #9ca3af;">🤖 Running Qwen-2.5-3B Risk Analysis...</div>"""
444
+ # Yield XAI plot along with other outputs
445
+ yield status_msg, alloc_df, xai_plot, analysing_html
446
+
447
+ allocations_for_llm = {k: float(v) for k, v in allocations_dict.items()}
448
+ analysis_result = analyze_agent_decision(df_window_for_analyst, allocations_for_llm)
449
+ status_msg = "Analysis complete!"
450
+
451
+ if isinstance(analysis_result, dict):
452
+ strat = analysis_result.get('strategy_summary', 'N/A')
453
+ risk = analysis_result.get('risk_level', 'N/A').upper()
454
+ just = analysis_result.get('justification', 'N/A')
455
+ conf = analysis_result.get('confidence_score', 'N/A')
456
+ if 'HIGH' in risk:
457
+ risk_css = "color: #ef4444; font-weight: bold;"
458
+ status_bg = "#7f1d1d"
459
+ status_border = "#ef4444"
460
+ status_icon = "⛔"
461
+ status_text = "TRADE BLOCKED: High Risk Detected"
462
+ else:
463
+ risk_css = "color: #10b981; font-weight: bold;"
464
+ status_bg = "#064e3b"
465
+ status_border = "#10b981"
466
+ status_icon = "🚀"
467
+ status_text = "TRADE APPROVED"
468
+
469
+ report_html = f"""
470
+ <div style="background-color: #1f2937; padding: 20px; border-radius: 12px 12px 0 0; border: 1px solid #374151; border-bottom: none;">
471
+ <h3 style="margin-top: 0; color: #e5e7eb;">🤖 AI Risk Analyst Report</h3>
472
+ <div style="margin-bottom: 15px;"><strong style="color: #9ca3af;">Strategy:</strong><br><span style="color: #d1d5db;">{strat}</span></div>
473
+ <div style="margin-bottom: 15px;"><strong style="color: #9ca3af;">Risk Level:</strong><span style="margin-left: 8px; {risk_css}">{risk}</span></div>
474
+ <div style="margin-bottom: 15px;"><strong style="color: #9ca3af;">Justification:</strong><br><span style="color: #d1d5db;">{just}</span></div>
475
+ <div><strong style="color: #9ca3af;">Confidence:</strong> <span style="color: #d1d5db;">{conf}/10</span></div>
476
+ </div>
477
+ <div style="background-color: {status_bg}; color: white; padding: 15px; border-radius: 0 0 12px 12px; border: 2px solid {status_border}; text-align: center; font-size: 1.2em; font-weight: bold; display: flex; align-items: center; justify-content: center;">
478
+ <span style="margin-right: 10px; font-size: 1.4em;">{status_icon}</span>{status_text}
479
+ </div>"""
480
+ else:
481
+ report_html = f"""<div style="padding: 20px; background-color: #7f1d1d; color: #fca5a5; border-radius: 12px;"><h3>❌ Analysis Failed to Parse</h3><p>{str(analysis_result)}</p></div>"""
482
+ # Final yield with all outputs including XAI plot
483
+ yield status_msg, alloc_df, xai_plot, report_html
484
+ except Exception as e:
485
+ import traceback
486
+ traceback.print_exc()
487
+ status_msg = f"Error: {str(e)}"
488
+ error_html = f"""<div style="padding: 20px; background-color: #7f1d1d; color: #fca5a5; border-radius: 12px;"><h3>❌ Process Error</h3><p>{str(e)}</p></div>"""
489
+ # Final yield in case of error
490
+ yield status_msg, None, go.Figure(), error_html
491
+
492
+
493
+ # =========================================
494
+ # Tab 1 Logic: Live Dashboard (DUMMY DATA)
495
+ # =========================================
496
+ def get_dashboard_metrics():
497
+ return "$135,400", "+3.07%"
498
+
499
+ def get_portfolio_history_plot():
500
+ dates = pd.date_range(start="2023-01-01", periods=100)
501
+ np.random.seed(42)
502
+ rl_returns = np.random.normal(0.001, 0.01, 100)
503
+ bnh_returns = np.random.normal(0.0005, 0.012, 100)
504
+ rl_value = 10000 * np.cumprod(1 + rl_returns)
505
+ bnh_value = 10000 * np.cumprod(1 + bnh_returns)
506
+ fig = go.Figure()
507
+ fig.add_trace(go.Scatter(x=dates, y=rl_value, mode='lines', name='RL Agent (Live)', line=dict(color='#10b981', width=3)))
508
+ fig.add_trace(go.Scatter(x=dates, y=bnh_value, mode='lines', name='Benchmark', line=dict(color='#6b7280', dash='dash')))
509
+ fig.update_layout(title="Portfolio Net Worth (Live Tracking)", xaxis_title="Date", yaxis_title="Net Worth ($)", template="plotly_dark", paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)', legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1))
510
+ return fig
511
+
512
+ def get_current_allocation_plot():
513
+ labels = ASSETS + ['Cash']
514
+ values = [0.25, 0.10, 0.30, 0.15, 0.05, 0.15]
515
+ fig = px.pie(values=values, names=labels, title="Current Holdings Breakdown", color_discrete_sequence=px.colors.qualitative.Bold)
516
+ fig.update_traces(textposition='inside', textinfo='percent+label', hole=.4)
517
+ fig.update_layout(template="plotly_dark", paper_bgcolor='rgba(0,0,0,0)', legend=dict(orientation="h", yanchor="bottom", y=-0.1))
518
+ return fig
519
+
520
+ def get_recent_transactions():
521
+ data = [["2025-11-24", "Rebalance", "MULTIPLE", "N/A"], ["2025-11-24", "SELL", "SPY", "$4,500"], ["2025-11-24", "BUY", "TLT", "$4,200"], ["2025-11-21", "BUY", "BTC-USD", "$1,000"]]
522
+ return pd.DataFrame(data, columns=["Date", "Type", "Asset", "Approx. Value"])
523
+
524
+
525
+ # =========================================
526
+ # Gradio Interface
527
+ # =========================================
528
+
529
+ custom_css = """
530
+ .metric-box { background-color: #1f2937; padding: 20px; border-radius: 12px; border: 1px solid #374151; text-align: center; }
531
+ .metric-label { font-size: 1.1em; color: #9ca3af; margin-bottom: 5px; }
532
+ .metric-value { font-size: 2.2em; font-weight: 700; color: #e5e7eb; }
533
+ .disclaimer-box { background-color: #374151; padding: 15px; border-radius: 8px; border-left: 4px solid #f59e0b; color: #d1d5db; font-size: 0.9em; margin-bottom: 20px; }
534
+ """
535
+
536
+ theme = gr.themes.Soft(primary_hue="emerald", secondary_hue="slate", neutral_hue="zinc").set(
537
+ body_background_fill="#111827", block_background_fill="#1f2937", block_border_width="1px", block_border_color="#374151"
538
+ )
539
+
540
+ with gr.Blocks(theme=theme, css=custom_css, title="Deep RL Portfolio Manager") as demo:
541
+ gr.HTML("""<script>function forceDark(){document.body.classList.add('dark');} forceDark(); setTimeout(forceDark, 500);</script>""")
542
+
543
+ gr.Markdown("# 🧠 Deep RL & LLM Portfolio Manager")
544
+
545
+ with gr.Tabs():
546
+ # ================= TAB 1: DASHBOARD (RESTORED) =================
547
+ with gr.TabItem("📊 Live Dashboard"):
548
+ # Metrics Row
549
+ with gr.Row():
550
+ # Fetch the headline metrics inside the tab so they render with the dashboard
551
+ nw_val, dc_val = get_dashboard_metrics()
552
+ with gr.Column(elem_classes=["metric-box"]):
553
+ gr.HTML(f"<div class='metric-label'>Current Net Worth</div><div class='metric-value'>{nw_val}</div>")
554
+ with gr.Column(elem_classes=["metric-box"]):
555
+ gr.HTML(f"<div class='metric-label'>24h Change</div><div class='metric-value' style='color: #10b981;'>{dc_val}</div>")
556
+
557
+ # Main Chart row
558
+ with gr.Row():
559
+ with gr.Column(scale=3):
560
+ history_chart = gr.Plot(value=get_portfolio_history_plot(), label="Net Worth History")
561
+
562
+ # Bottom Row: Allocations and Transactions
563
+ with gr.Row():
564
+ with gr.Column(scale=1):
565
+ allocation_chart = gr.Plot(value=get_current_allocation_plot(), label="Current Allocation")
566
+ with gr.Column(scale=2):
567
+ gr.Markdown("### Recent Transactions")
568
+ transactions_table = gr.Dataframe(value=get_recent_transactions(), interactive=False, wrap=True)
569
+
570
+ # ================= TAB 2: FORECAST (UPDATED with XAI) =================
571
+ with gr.TabItem("🔮 Forecast & AI Analysis"):
572
+ gr.Markdown("### Generate Tomorrow's Portfolio Strategy")
573
+ run_btn = gr.Button("🚀 Run Overnight Analysis", variant="primary", size="lg")
574
+ status_output = gr.Textbox(label="System Status", placeholder="Ready...", interactive=False, lines=1)
575
+ gr.Markdown("---")
576
+
577
+ with gr.Row():
578
+ # Left Column: Allocations & XAI Plot
579
+ with gr.Column(scale=2):
580
+ gr.Markdown("### 📈 Suggested Position")
581
+ allocation_output = gr.Dataframe(headers=["Asset", "Allocation"], datatype=["str", "str"], interactive=False)
582
+
583
+ # NEW: XAI Feature Importance Plot
584
+ gr.Markdown("### 🧠 Why did the agent choose this?")
585
+ xai_output_plot = gr.Plot(label="Top Influential Factors (XAI)", show_label=False)
586
+
587
+ # Right Column: AI Analysis Report
588
+ with gr.Column(scale=3):
589
+ analysis_report_html = gr.HTML(label="AI Risk Analysis Report")
590
+
591
+ # Updated click event with new XAI output
592
+ run_btn.click(
593
+ fn=predict_and_analyze,
594
+ inputs=None,
595
+ outputs=[status_output, allocation_output, xai_output_plot, analysis_report_html]
596
+ )
597
+
598
+ # ================= TAB 3: HISTORICAL DATA ANALYST =================
599
+ with gr.TabItem("📅 Historical Data Analyst"):
600
+ gr.Markdown("### Analyze Past Market Performance with AI")
601
+
602
+ with gr.Row():
603
+ with gr.Column(scale=1):
604
+ all_tickers_hist = ASSETS + list(FRED_IDS.values())
605
+ if DASHBOARD_DATA_DF is not None:
606
+ available_tickers_hist = [t for t in all_tickers_hist if t in DASHBOARD_DATA_DF.columns]
607
+ else:
608
+ available_tickers_hist = []
609
+ default_tickers_hist = available_tickers_hist[:3] if available_tickers_hist else []
610
+
611
+ asset_selector = gr.Dropdown(choices=available_tickers_hist, value=default_tickers_hist, multiselect=True, label="1. Select Assets")
612
+ period_selector = gr.Dropdown(choices=list(TIME_PERIODS.keys()), value="1 Year", label="2. Select Period")
613
+ analyze_btn = gr.Button("🔎 Run Analysis", variant="primary")
614
+
615
+ with gr.Column(scale=3):
616
+ historical_plot = gr.Plot(label="Performance Plot")
617
+
618
+ gr.Markdown("---")
619
+ historical_analysis_md = gr.Markdown("### 🤖 AI Analyst Report\n\n*Click 'Run Analysis' to generate.*")
620
+
621
+ analyze_btn.click(
622
+ fn=run_historical_analysis,
623
+ inputs=[asset_selector, period_selector],
624
+ outputs=[historical_plot, historical_analysis_md]
625
+ )
626
+
627
+ # ================= TAB 4: HISTORICAL SIMULATION (UPDATED with Pro Metrics) =================
628
+ with gr.TabItem("🔙 Historical Simulation"):
629
+ gr.Markdown("### Backtest the RL Agent against Baselines")
630
+
631
+ # Disclaimer Box
632
+ gr.HTML(f"""
633
+ <div class='disclaimer-box'>
634
+ <strong>⚠️ IMPORTANT DISCLAIMER:</strong> The RL model was trained on data from approximately
635
+ <strong>{TRAIN_START_DATE} to {TRAIN_END_DATE}</strong>. Simulations that overlap this period are subject to lookahead bias,
636
+ while periods far outside it are out-of-distribution, so results may not reflect real-world performance.
637
+ Use for educational purposes only.
638
+ </div>
639
+ """)
640
+
641
+ with gr.Row():
642
+ with gr.Column(scale=1):
643
+ start_date_input = gr.Textbox(label="Start Date (YYYY-MM-DD)", value=(datetime.now() - timedelta(days=365)).strftime('%Y-%m-%d'))
644
+ end_date_input = gr.Textbox(label="End Date (YYYY-MM-DD)", value=(datetime.now() - timedelta(days=1)).strftime('%Y-%m-%d'))
645
+ sim_btn = gr.Button("▶️ Run Simulation", variant="primary")
646
+ sim_status = gr.Textbox(label="Status", interactive=False, lines=1)
647
+
648
+ with gr.Column(scale=3):
649
+ sim_plot = gr.Plot(label="Simulation Performance")
650
+
651
+ gr.Markdown("---")
652
+ # Updated to Markdown component for better table formatting
653
+ sim_metrics_md = gr.Markdown("### 📊 Professional Performance Metrics\n\n*Run simulation to see metrics.*")
654
+
655
+ sim_btn.click(
656
+ fn=run_historical_simulation,
657
+ inputs=[start_date_input, end_date_input],
658
+ outputs=[sim_plot, sim_status, sim_metrics_md]
659
+ )
660
+
661
+ if __name__ == "__main__":
662
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860, debug=True, share=True)
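
For reference, the softmax step at the top of `predict_and_analyze` is what turns the agent's raw action vector into the allocation table shown in the UI. Below is a minimal standalone sketch of that conversion; the example action values are made up, and the `ASSETS` list simply mirrors the app's five tickers:

```python
import numpy as np
import pandas as pd

ASSETS = ['AAPL', 'MSFT', 'SPY', 'TLT', 'BTC-USD']  # mirrors the app's asset list

def action_to_allocations(raw_action: np.ndarray) -> pd.DataFrame:
    """Map a raw (n_assets + 1)-dim action to softmax weights; the last slot is cash."""
    exp_action = np.exp(raw_action - np.max(raw_action))  # subtract the max for numerical stability
    weights = exp_action / np.sum(exp_action)
    rows = {asset: weights[i] for i, asset in enumerate(ASSETS)}
    rows['Cash'] = weights[-1]
    alloc_df = pd.DataFrame(list(rows.items()), columns=['Asset', 'Proposed Allocation'])
    alloc_df['Proposed Allocation'] = alloc_df['Proposed Allocation'].apply(lambda x: f"{x:.2%}")
    return alloc_df

# Hypothetical raw action: 5 asset logits + 1 cash logit
print(action_to_allocations(np.array([0.4, -0.2, 0.9, 0.1, -0.5, 0.0])))
```
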
scripts/custom_policy.py ADDED
@@ -0,0 +1,80 @@
1
+ # custom_policy.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from gymnasium import spaces
6
+ from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
7
+
8
+ class TransformerFeatureExtractor(BaseFeaturesExtractor):
9
+ """
10
+ A custom feature extractor that uses a Transformer Encoder.
11
+
12
+ It takes a flattened observation (window_size * n_features_per_step) and processes
13
+ it as a sequence.
14
+ """
15
+ def __init__(
16
+ self,
17
+ observation_space: spaces.Box,
18
+ features_dim: int = 256, # The final output dimension
19
+ n_features_per_step: int = 8, # Features per time step: 5 asset prices + 3 macro series
20
+ window_size: int = 30,
21
+ d_model: int = 64, # Transformer's internal embedding dimension
22
+ n_head: int = 4, # Number of attention heads
23
+ n_layers: int = 2, # Number of transformer encoder layers
24
+ dropout: float = 0.1
25
+ ):
26
+
27
+ super().__init__(observation_space, features_dim)
28
+
29
+ self.window_size = window_size
30
+ self.n_features_per_step = n_features_per_step
31
+
32
+ # Input shape check
33
+ expected_flat_dim = window_size * n_features_per_step
34
+ if observation_space.shape[0] != expected_flat_dim:
35
+ raise ValueError(
36
+ f"Observation space flat dimension {observation_space.shape[0]} "
37
+ f"does not match expected {expected_flat_dim} "
38
+ f"(window_size={window_size}, n_features_per_step={n_features_per_step})."
39
+ )
40
+
41
+ # 1. Input Projection:
42
+ self.input_projection = nn.Linear(n_features_per_step, d_model)
43
+
44
+ # 2. Positional Encoding:
45
+ self.positional_encoding = nn.Parameter(torch.randn(1, window_size, d_model))
46
+
47
+ # 3. Transformer Encoder:
48
+ encoder_layer = nn.TransformerEncoderLayer(
49
+ d_model=d_model,
50
+ nhead=n_head,
51
+ dropout=dropout,
52
+ batch_first=True
53
+ )
54
+ self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
55
+
56
+ # 4. Output Layers:
57
+ self.flatten = nn.Flatten()
58
+ self.linear_out = nn.Linear(window_size * d_model, features_dim)
59
+ self.relu = nn.ReLU()
60
+
61
+ def forward(self, observations: torch.Tensor) -> torch.Tensor:
62
+ # Input shape: (batch_size, window_size * n_features_per_step)
63
+
64
+ # 1. Reshape to (batch_size, window_size, n_features_per_step)
65
+ x = observations.reshape(-1, self.window_size, self.n_features_per_step)
66
+
67
+ # 2. Project input features to d_model
68
+ x = self.input_projection(x)
69
+
70
+ # 3. Add positional encoding
71
+ x = x + self.positional_encoding
72
+
73
+ # 4. Pass through Transformer
74
+ x = self.transformer_encoder(x)
75
+
76
+ # 5. Flatten and project to final output
77
+ x = self.flatten(x)
78
+ x = self.relu(self.linear_out(x))
79
+
80
+ return x
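
A hedged sketch of how this extractor would typically be plugged into a Stable-Baselines3 agent via `policy_kwargs`; the data path, checkpoint name, and training budget below are illustrative rather than the exact settings used for the released checkpoints:

```python
# Sketch only: assumes it is run from scripts/ so PortfolioEnv and the extractor import directly.
import pandas as pd
from stable_baselines3 import TD3
from environment import PortfolioEnv
from custom_policy import TransformerFeatureExtractor

train_df = pd.read_csv('data/train.csv', index_col='Date', parse_dates=True)
env = PortfolioEnv(train_df, window_size=30)

policy_kwargs = dict(
    features_extractor_class=TransformerFeatureExtractor,
    features_extractor_kwargs=dict(
        features_dim=256,
        n_features_per_step=8,  # 5 asset prices + 3 macro series
        window_size=30,
    ),
)

model = TD3("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=10_000)  # illustrative budget only
model.save('checkpoints/td3_transformer_model')
```
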
scripts/environment.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import gymnasium as gym
2
  import numpy as np
3
  import pandas as pd
@@ -5,126 +7,89 @@ from gymnasium import spaces
5
 
6
  class PortfolioEnv(gym.Env):
7
  """
8
- A custom reinforcement learning environment for portfolio management.
9
-
10
- This environment simulates the daily trading of multiple financial assets. The agent's
11
- goal is to learn a policy for allocating capital to maximize risk-adjusted returns.
12
  """
13
  metadata = {'render_modes': ['human']}
14
 
15
  def __init__(self, df, window_size=30, initial_balance=10000, transaction_cost_pct=0.001):
16
- """
17
- Initializes the portfolio management environment.
18
-
19
- Args:
20
- df (pd.DataFrame): A DataFrame containing the daily closing prices of the assets.
21
- The index should be dates and columns should be asset tickers.
22
- window_size (int): The number of past days of price data to include in the observation.
23
- initial_balance (float): The starting capital for the portfolio.
24
- transaction_cost_pct (float): The percentage cost for each trade (e.g., 0.001 for 0.1%).
25
- """
26
  super(PortfolioEnv, self).__init__()
27
 
28
- # --- Basic Environment Parameters ---
29
  self.df = df
30
  self.window_size = window_size
31
  self.initial_balance = initial_balance
32
  self.transaction_cost_pct = transaction_cost_pct
33
- self.n_assets = len(df.columns)
 
 
 
 
 
 
 
 
 
34
 
35
  # --- Action Space ---
36
- # The agent outputs a vector of continuous values, one for each asset plus one for cash.
37
- # These raw outputs are then converted to portfolio weights via a softmax function.
38
- # The space is defined from -1 to 1 for better compatibility with standard RL algorithms.
39
- # Shape: (number of assets + 1 for cash)
40
  self.action_space = spaces.Box(
41
  low=-1, high=1, shape=(self.n_assets + 1,), dtype=np.float32
42
  )
43
 
44
  # --- Observation Space ---
45
- # The agent observes a window of past price data, flattened into a 1D vector.
46
- # Shape: (window_size * number of assets)
47
  self.observation_space = spaces.Box(
48
  low=-np.inf, high=np.inf,
49
- shape=(self.window_size * self.n_assets,),
50
  dtype=np.float32
51
  )
52
 
53
- # --- Internal State Variables ---
54
- # These variables track the state of the simulation over time.
55
  self._current_step = 0
56
- self._portfolio_value = 0.0
57
- # Weights for each asset + cash, e.g., [w_aapl, w_msft, ..., w_cash]
58
  self._weights = np.zeros(self.n_assets + 1)
59
 
 
 
 
 
60
  def reset(self, seed=None):
61
- """
62
- Resets the environment to its initial state for a new episode.
63
-
64
- Returns:
65
- tuple: A tuple containing the initial observation and auxiliary info.
66
- """
67
  super().reset(seed=seed)
68
-
69
- # Start the simulation at the first point where a full window of data is available.
70
  self._current_step = self.window_size
71
  self._portfolio_value = self.initial_balance
72
-
73
- # Initialize weights to be 100% in cash.
74
  self._weights = np.zeros(self.n_assets + 1)
75
- self._weights[-1] = 1.0 # Last element represents cash
76
 
77
  observation = self._get_obs()
78
  info = self._get_info()
79
-
80
  return observation, info
81
 
82
  def step(self, action):
83
- """
84
- Executes one time step within the environment based on the agent's action.
85
-
86
- Args:
87
- action (np.ndarray): The raw output from the agent's policy network.
88
-
89
- Returns:
90
- tuple: A tuple containing the next observation, reward, terminated flag,
91
- truncated flag, and auxiliary info.
92
- """
93
- # 1. Store the portfolio value before taking the action.
94
  current_portfolio_value = self._portfolio_value
95
 
96
- # 2. Convert the raw action into portfolio weights using the softmax function.
97
- # This ensures the weights are positive and sum to 1.
98
- target_weights = np.exp(action) / np.sum(np.exp(action))
99
 
100
- # 3. Calculate the cost of rebalancing the portfolio.
101
- # The cost is based on the total value of assets bought or sold.
102
- trades = (target_weights[:-1] - self._weights[:-1]) * current_portfolio_value
103
  transaction_costs = np.sum(np.abs(trades)) * self.transaction_cost_pct
104
 
105
- # 4. Update the internal state: apply costs, set new weights, and advance time.
106
  self._balance = current_portfolio_value - transaction_costs
107
  self._weights = target_weights
 
108
  self._current_step += 1
109
 
110
- # 5. Calculate the new portfolio value based on the market's price movement.
111
- current_prices = self.df.iloc[self._current_step - 1].values
112
- next_prices = self.df.iloc[self._current_step].values
113
- price_ratio = next_prices / current_prices # How much each asset's price changed.
114
-
115
- # The new value of our asset holdings.
116
  asset_values_after_price_change = (self._weights[:-1] * self._balance) * price_ratio
117
-
118
- # The new total portfolio value is the sum of the updated asset values plus the cash holding.
119
  new_portfolio_value = np.sum(asset_values_after_price_change) + (self._weights[-1] * self._balance)
120
  self._portfolio_value = new_portfolio_value
121
 
122
- # 6. Calculate the reward for the agent.
123
- # The reward is the log return of the portfolio value, which encourages geometric growth.
124
- reward = np.log(new_portfolio_value / current_portfolio_value)
125
 
126
- # 7. Check for termination conditions.
127
- # The episode ends if the agent goes broke or runs out of data.
128
  terminated = bool(self._portfolio_value <= self.initial_balance * 0.5)
129
  truncated = self._current_step >= len(self.df) - 1
130
 
@@ -135,24 +100,25 @@ class PortfolioEnv(gym.Env):
135
 
136
  def _get_obs(self):
137
  """
138
- Constructs the observation for the agent at the current time step.
139
-
140
- Returns:
141
- np.ndarray: A flattened 1D array of the normalized price history.
142
  """
143
- # Get the window of historical price data.
144
- price_window = self.df.iloc[self._current_step - self.window_size : self._current_step].values
145
-
146
- # Normalize the window by dividing by the first price. This helps the agent
147
- # focus on relative price changes rather than absolute values.
148
- normalized_window = price_window / price_window[0]
149
-
150
- return normalized_window.flatten().astype(np.float32)
 
 
 
 
 
 
151
 
152
  def _get_info(self):
153
- """
154
- Returns a dictionary of auxiliary information about the current state.
155
- """
156
  return {
157
  'step': self._current_step,
158
  'portfolio_value': self._portfolio_value,
@@ -160,15 +126,7 @@ class PortfolioEnv(gym.Env):
160
  }
161
 
162
  def render(self, mode='human'):
163
- """
164
- Renders the environment's state (optional).
165
- """
166
- if mode == 'human':
167
- info = self._get_info()
168
- print(f"Step: {info['step']}, Portfolio Value: {info['portfolio_value']:.2f}")
169
 
170
  def close(self):
171
- """
172
- Cleans up the environment (optional).
173
- """
174
  pass
 
1
+ # scripts/environment.py: portfolio environment with 8 features per step (5 assets + 3 macro)
2
+
3
  import gymnasium as gym
4
  import numpy as np
5
  import pandas as pd
 
7
 
8
  class PortfolioEnv(gym.Env):
9
  """
10
+ A custom environment for portfolio management that includes macroeconomic data.
 
 
 
11
  """
12
  metadata = {'render_modes': ['human']}
13
 
14
  def __init__(self, df, window_size=30, initial_balance=10000, transaction_cost_pct=0.001):
 
 
 
 
 
 
 
 
 
 
15
  super(PortfolioEnv, self).__init__()
16
 
17
+ # --- Data Handling ---
18
  self.df = df
19
  self.window_size = window_size
20
  self.initial_balance = initial_balance
21
  self.transaction_cost_pct = transaction_cost_pct
22
+
23
+ # --- IMPORTANT: Define asset and macro columns ---
24
+ self.asset_columns = ['AAPL', 'BTC-USD', 'MSFT', 'SPY', 'TLT']
25
+ self.macro_columns = ['Federal Funds Rate', 'CPI', 'VIX']
26
+
27
+ self.n_assets = len(self.asset_columns)
28
+ self.n_macro_features = len(self.macro_columns)
29
+
30
+ # --- Total number of features observed per time step (assets + macro) ---
31
+ self.n_features_per_step = self.n_assets + self.n_macro_features # Should be 8
32
 
33
  # --- Action Space ---
 
 
 
 
34
  self.action_space = spaces.Box(
35
  low=-1, high=1, shape=(self.n_assets + 1,), dtype=np.float32
36
  )
37
 
38
  # --- Observation Space ---
39
+ # Shape: (window_size * total_features) = (30 * 8) = 240
 
40
  self.observation_space = spaces.Box(
41
  low=-np.inf, high=np.inf,
42
+ shape=(self.window_size * self.n_features_per_step,),
43
  dtype=np.float32
44
  )
45
 
46
+ # --- Internal State ---
 
47
  self._current_step = 0
48
+ self._portfolio_value = 0
 
49
  self._weights = np.zeros(self.n_assets + 1)
50
 
51
+ # Separate dataframes for prices and macro for easier handling
52
+ self.price_df = self.df[self.asset_columns]
53
+ self.macro_df = self.df[self.macro_columns]
54
+
55
  def reset(self, seed=None):
 
 
 
 
 
 
56
  super().reset(seed=seed)
 
 
57
  self._current_step = self.window_size
58
  self._portfolio_value = self.initial_balance
59
+
 
60
  self._weights = np.zeros(self.n_assets + 1)
61
+ self._weights[-1] = 1.0 # 100% in cash
62
 
63
  observation = self._get_obs()
64
  info = self._get_info()
 
65
  return observation, info
66
 
67
  def step(self, action):
 
 
 
 
 
 
 
 
 
 
 
68
  current_portfolio_value = self._portfolio_value
69
 
70
+ target_weights = np.exp(action) / np.sum(np.exp(action)) # Softmax
 
 
71
 
72
+ current_asset_values = self._weights[:-1] * current_portfolio_value
73
+ target_asset_values = target_weights[:-1] * current_portfolio_value
74
+ trades = target_asset_values - current_asset_values
75
  transaction_costs = np.sum(np.abs(trades)) * self.transaction_cost_pct
76
 
 
77
  self._balance = current_portfolio_value - transaction_costs
78
  self._weights = target_weights
79
+
80
  self._current_step += 1
81
 
82
+ current_prices = self.price_df.iloc[self._current_step - 1].values
83
+ next_prices = self.price_df.iloc[self._current_step].values
84
+
85
+ price_ratio = next_prices / (current_prices + 1e-8) # Add epsilon for safety
86
+
 
87
  asset_values_after_price_change = (self._weights[:-1] * self._balance) * price_ratio
 
 
88
  new_portfolio_value = np.sum(asset_values_after_price_change) + (self._weights[-1] * self._balance)
89
  self._portfolio_value = new_portfolio_value
90
 
91
+ reward = np.log(new_portfolio_value / (current_portfolio_value + 1e-8)) # Add epsilon
 
 
92
 
 
 
93
  terminated = bool(self._portfolio_value <= self.initial_balance * 0.5)
94
  truncated = self._current_step >= len(self.df) - 1
95
 
 
100
 
101
  def _get_obs(self):
102
  """
103
+ Gets the observation for the current time step.
104
+ This includes a window of prices AND a window of macro data.
 
 
105
  """
106
+ price_window = self.price_df.iloc[self._current_step - self.window_size : self._current_step].values
107
+ macro_window = self.macro_df.iloc[self._current_step - self.window_size : self._current_step].values
108
+
109
+ # Normalize the price window (relative changes)
110
+ normalized_price_window = price_window / (price_window[0] + 1e-8)
111
+
112
+ # Normalize the macro window
113
+ normalized_macro_window = macro_window / (macro_window[0] + 1e-8)
114
+
115
+ # Combine the normalized windows
116
+ observation_window = np.concatenate([normalized_price_window, normalized_macro_window], axis=1)
117
+
118
+ # Flatten into a 1D vector
119
+ return observation_window.flatten().astype(np.float32)
120
 
121
  def _get_info(self):
 
 
 
122
  return {
123
  'step': self._current_step,
124
  'portfolio_value': self._portfolio_value,
 
126
  }
127
 
128
  def render(self, mode='human'):
129
+ pass
 
 
 
 
 
130
 
131
  def close(self):
 
 
 
132
  pass
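
To make the reward mechanics above concrete, here is a small worked sketch of a single rebalancing step using the same arithmetic as `PortfolioEnv.step`; all numbers are invented for illustration:

```python
import numpy as np

portfolio_value = 10_000.0
current_weights = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 1.0])  # 5 assets + cash, starting fully in cash
raw_action = np.array([0.5, 0.1, -0.3, 0.2, -0.8, 0.0])     # hypothetical policy output

# 1. Softmax turns the raw action into target weights (positive, sum to 1)
target_weights = np.exp(raw_action) / np.sum(np.exp(raw_action))

# 2. Transaction costs are charged on the absolute dollar amount traded per asset
trades = (target_weights[:-1] - current_weights[:-1]) * portfolio_value
balance = portfolio_value - np.sum(np.abs(trades)) * 0.001   # 0.1% per trade

# 3. Apply one day of price moves to the asset sleeve; the cash sleeve is unchanged
price_ratio = np.array([1.01, 0.99, 1.02, 1.00, 0.97])       # hypothetical next-day price ratios
new_value = np.sum(target_weights[:-1] * balance * price_ratio) + target_weights[-1] * balance

# 4. The reward is the log return of total portfolio value over the step
reward = np.log(new_value / portfolio_value)
print(f"new value: {new_value:.2f}, reward: {reward:+.5f}")
```
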
scripts/evaluate.py CHANGED
@@ -1,12 +1,16 @@
 
 
1
  import pandas as pd
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
- from stable_baselines3 import SAC ,PPO , TD3
5
- from evaluate_baselines import buy_and_hold
6
- from environment import PortfolioEnv
7
  from matplotlib.ticker import FuncFormatter
 
 
 
8
 
9
- # --- Helper Function to Run the RL Agent ---
10
 
11
  def evaluate_agent(env, model):
12
  """
@@ -14,7 +18,6 @@ def evaluate_agent(env, model):
14
  """
15
  obs, info = env.reset()
16
  terminated, truncated = False, False
17
-
18
  portfolio_values = [env.initial_balance]
19
 
20
  while not (terminated or truncated):
@@ -22,121 +25,129 @@ def evaluate_agent(env, model):
22
  obs, reward, terminated, truncated, info = env.step(action)
23
  portfolio_values.append(info['portfolio_value'])
24
 
25
- return pd.Series(portfolio_values, index=env.df.index[:len(portfolio_values)])
 
 
 
26
 
27
 
28
  def calculate_metrics(portfolio_values, freq=252, rf=0.0):
29
  """
30
  Calculates key performance metrics from a series of portfolio values.
31
- freq: number of trading periods in a year (252 for daily, 52 for weekly).
32
- rf: risk-free rate (default = 0 for simplicity).
33
  """
 
 
 
34
  returns = portfolio_values.pct_change().dropna()
 
 
35
 
36
- # Total Return
37
  total_return = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) - 1
 
 
38
 
39
- # CAGR
40
- num_years = (len(portfolio_values) / freq)
41
- cagr = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) ** (1/num_years) - 1
42
-
43
- # Sharpe Ratio
44
- sharpe_ratio = np.sqrt(freq) * (returns.mean() - rf) / returns.std()
45
 
46
- # Sortino Ratio (downside risk only)
47
  downside_returns = returns[returns < 0]
48
  downside_std = downside_returns.std()
49
  sortino_ratio = np.sqrt(freq) * (returns.mean() - rf) / downside_std if downside_std > 0 else np.nan
50
 
51
- # Volatility (annualized std)
52
  volatility = returns.std() * np.sqrt(freq)
53
 
54
- # Max Drawdown
55
  rolling_max = portfolio_values.cummax()
56
  drawdown = portfolio_values / rolling_max - 1.0
57
  max_drawdown = drawdown.min()
58
 
59
- # Calmar Ratio
60
- calmar_ratio = cagr / abs(max_drawdown / 100) if max_drawdown != 0 else np.nan
61
 
62
  return {
63
- "Total Return": f"{total_return:.2%}",
64
- "CAGR": f"{cagr:.2%}",
65
- "Sharpe Ratio": f"{sharpe_ratio:.2f}",
66
- "Sortino Ratio": f"{sortino_ratio:.2f}",
67
- "Volatility": f"{volatility:.2%}",
68
- "Max Drawdown": f"{max_drawdown:.2%}",
69
  "Calmar Ratio": f"{calmar_ratio:.2f}"
70
  }
71
 
72
 
73
- def main(test_data_path='data/test.csv'):
74
  """
75
- Loads, evaluates, and plots the performance of PPO, SAC, and TD3 agents
76
- against a Buy and Hold baseline.
77
  """
78
- # --- Define Model Paths and Agent Types ---
79
  models_to_evaluate = {
80
- "PPO Agent": (PPO, 'checkpoints/ppo_portfolio_model'),
81
- "SAC Agent": (SAC, 'checkpoints/sac_portfolio_model'),
82
- "TD3 Agent": (TD3, 'checkpoints/td3_portfolio_model')
 
83
  }
84
 
85
- # Load test data
86
- test_df = pd.read_csv(test_data_path, index_col='Date', parse_dates=True)
 
 
 
87
 
88
- # Dictionary to store results
89
  portfolio_values = {}
90
  metrics = {}
91
 
92
  # --- Run Evaluations for each RL Agent---
93
  for name, (agent_type, model_path) in models_to_evaluate.items():
94
  print(f"--- Evaluating {name} ---")
 
 
 
 
95
  model = agent_type.load(model_path)
96
- env = PortfolioEnv(test_df)
97
  portfolio_values[name] = evaluate_agent(env, model)
98
  metrics[name] = calculate_metrics(portfolio_values[name])
99
 
100
  # --- Evaluate Buy and Hold Baseline ---
101
  print("\n--- Evaluating Buy and Hold Baseline ---")
102
- bnh_values = buy_and_hold(test_df)
 
 
 
103
  portfolio_values["Buy and Hold"] = bnh_values
104
  metrics["Buy and Hold"] = calculate_metrics(bnh_values)
105
-
 
 
 
106
  # --- Combine and Print Metrics ---
107
  print("\n--- Performance Metrics ---")
108
  metrics_df = pd.DataFrame(metrics)
109
- print(metrics_df)
110
 
111
  # --- Plotting All Strategies ---
112
  plt.style.use('seaborn-v0_8-darkgrid')
113
  fig, ax = plt.subplots(figsize=(14, 8))
114
 
115
- # Define colors for clarity
116
  colors = {
117
- "PPO Agent": "red",
118
- "SAC Agent": "green",
119
- "TD3 Agent": "orange",
120
- "Buy and Hold": "blue"
 
 
121
  }
122
 
123
  for name, values in portfolio_values.items():
124
- ax.plot(values.index, values, label=name, color=colors[name], linewidth=2)
 
125
 
126
  ax.set_title('Agent Performance Comparison', fontsize=16)
127
  ax.set_xlabel('Date', fontsize=12)
128
  ax.set_ylabel('Portfolio Value ($)', fontsize=12)
129
  ax.legend(fontsize=12)
130
-
131
  formatter = FuncFormatter(lambda x, p: f'${x:,.0f}')
132
  ax.yaxis.set_major_formatter(formatter)
133
 
134
  plt.tight_layout()
135
- plt.savefig('results/final_performance_comparison_all_agents.png')
 
 
136
  plt.show()
137
 
138
- # Example of how to run this main function
139
  if __name__ == '__main__':
140
- # You can specify a different test file here if needed
141
- # e.g., main(test_data_path='data/stress_test_2018.csv')
142
  main()
 
1
+ # scripts/evaluate.py
2
+
3
  import pandas as pd
4
  import numpy as np
5
  import matplotlib.pyplot as plt
6
+ import os
7
+ from stable_baselines3 import TD3, PPO, SAC
8
+ from gymnasium import spaces
9
  from matplotlib.ticker import FuncFormatter
10
+ from environment import PortfolioEnv
11
+ from evaluate_baselines import buy_and_hold, equally_weighted_rebalanced
12
+ from custom_policy import TransformerFeatureExtractor
13
 
 
14
 
15
  def evaluate_agent(env, model):
16
  """
 
18
  """
19
  obs, info = env.reset()
20
  terminated, truncated = False, False
 
21
  portfolio_values = [env.initial_balance]
22
 
23
  while not (terminated or truncated):
 
25
  obs, reward, terminated, truncated, info = env.step(action)
26
  portfolio_values.append(info['portfolio_value'])
27
 
28
+ # Align index with the actual steps taken
29
+ # The first obs is at window_size, so index should start one step before
30
+ valid_dates = env.df.index[env.window_size-1:]
31
+ return pd.Series(portfolio_values, index=valid_dates[:len(portfolio_values)])
32
 
33
 
34
  def calculate_metrics(portfolio_values, freq=252, rf=0.0):
35
  """
36
  Calculates key performance metrics from a series of portfolio values.
 
 
37
  """
38
+ if len(portfolio_values) < 2:
39
+ return { "Total Return": "N/A", "CAGR": "N/A", "Sharpe Ratio": "N/A", "Max Drawdown": "N/A" }
40
+
41
  returns = portfolio_values.pct_change().dropna()
42
+ if returns.empty:
43
+ return { "Total Return": "0.00%", "CAGR": "0.00%", "Sharpe Ratio": "0.00", "Max Drawdown": "0.00%" }
44
 
 
45
  total_return = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) - 1
46
+ num_years = (len(portfolio_values) - 1) / freq
47
+ cagr = (portfolio_values.iloc[-1] / portfolio_values.iloc[0]) ** (1/num_years) - 1 if num_years > 0 else 0.0
48
 
49
+ sharpe_ratio = np.sqrt(freq) * (returns.mean() - rf) / returns.std() if returns.std() > 0 else np.nan
 
 
 
 
 
50
 
 
51
  downside_returns = returns[returns < 0]
52
  downside_std = downside_returns.std()
53
  sortino_ratio = np.sqrt(freq) * (returns.mean() - rf) / downside_std if downside_std > 0 else np.nan
54
 
 
55
  volatility = returns.std() * np.sqrt(freq)
56
 
 
57
  rolling_max = portfolio_values.cummax()
58
  drawdown = portfolio_values / rolling_max - 1.0
59
  max_drawdown = drawdown.min()
60
 
61
+ calmar_ratio = cagr / abs(max_drawdown) if max_drawdown != 0 and cagr != 0 else np.nan
 
62
 
63
  return {
64
+ "Total Return": f"{total_return:.2%}", "CAGR": f"{cagr:.2%}",
65
+ "Sharpe Ratio": f"{sharpe_ratio:.2f}", "Sortino Ratio": f"{sortino_ratio:.2f}",
66
+ "Volatility": f"{volatility:.2%}", "Max Drawdown": f"{max_drawdown:.2%}",
 
 
 
67
  "Calmar Ratio": f"{calmar_ratio:.2f}"
68
  }
69
 
70
 
71
+ def main(test_data_path='data/eval.csv'):
72
  """
73
+ Loads, evaluates, and plots all agent performances against baselines.
 
74
  """
75
+ # Define Model Paths and Agent Types
76
  models_to_evaluate = {
77
+ "SAC Agent Default (MLP)": (SAC, 'checkpoints/sac_portfolio_model.zip'),
78
+ "PPO Agent (MLP)": (PPO, 'checkpoints/ppo_portfolio_model.zip'),
79
+ "TD3 Agent (MLP)": (TD3, 'checkpoints/td3_portfolio_model.zip'),
80
+ "TD3 Agent (Transformer)": (TD3, 'checkpoints/td3_transformer_model.zip')
81
  }
82
 
83
+ # Load test data (this contains ALL columns - assets + macro)
84
+ full_eval_df = pd.read_csv(test_data_path, index_col='Date', parse_dates=True)
85
+
86
+ # Define your actual tradable asset columns
87
+ asset_columns = ['AAPL', 'BTC-USD', 'MSFT', 'SPY', 'TLT']
88
 
 
89
  portfolio_values = {}
90
  metrics = {}
91
 
92
  # --- Run Evaluations for each RL Agent---
93
  for name, (agent_type, model_path) in models_to_evaluate.items():
94
  print(f"--- Evaluating {name} ---")
95
+ if not os.path.exists(model_path):
96
+ print(f"⚠️ Warning: Model file not found at {model_path}. Skipping.")
97
+ continue
98
+
99
  model = agent_type.load(model_path)
100
+ env = PortfolioEnv(full_eval_df) # Pass the full DataFrame to the RL env
101
  portfolio_values[name] = evaluate_agent(env, model)
102
  metrics[name] = calculate_metrics(portfolio_values[name])
103
 
104
  # --- Evaluate Buy and Hold Baseline ---
105
  print("\n--- Evaluating Buy and Hold Baseline ---")
106
+
107
+ bnh_values = buy_and_hold(full_eval_df[asset_columns])
108
+ ewp_values = equally_weighted_rebalanced(full_eval_df[asset_columns])
109
+
110
  portfolio_values["Buy and Hold"] = bnh_values
111
  metrics["Buy and Hold"] = calculate_metrics(bnh_values)
112
+
113
+ portfolio_values["Equally Weighted"] = ewp_values
114
+ metrics["Equally Weighted"] = calculate_metrics(ewp_values)
115
+
116
  # --- Combine and Print Metrics ---
117
  print("\n--- Performance Metrics ---")
118
  metrics_df = pd.DataFrame(metrics)
119
+ print(metrics_df.to_markdown(numalign="left", stralign="left"))
120
 
121
  # --- Plotting All Strategies ---
122
  plt.style.use('seaborn-v0_8-darkgrid')
123
  fig, ax = plt.subplots(figsize=(14, 8))
124
 
 
125
  colors = {
126
+ "PPO Agent (MLP)": "red",
127
+ "SAC Agent Default (MLP)": "green",
128
+ "TD3 Agent (MLP)": "orange",
129
+ "TD3 Agent (Transformer)": "cyan",
130
+ "Buy and Hold": "blue",
131
+ "Equally Weighted": "purple"
132
  }
133
 
134
  for name, values in portfolio_values.items():
135
+ if name in portfolio_values: # Check if it was successfully evaluated
136
+ ax.plot(values.index, values, label=name, color=colors.get(name, 'gray'), linewidth=2)
137
 
138
  ax.set_title('Agent Performance Comparison', fontsize=16)
139
  ax.set_xlabel('Date', fontsize=12)
140
  ax.set_ylabel('Portfolio Value ($)', fontsize=12)
141
  ax.legend(fontsize=12)
142
+
143
  formatter = FuncFormatter(lambda x, p: f'${x:,.0f}')
144
  ax.yaxis.set_major_formatter(formatter)
145
 
146
  plt.tight_layout()
147
+ results_dir = 'results'
148
+ os.makedirs(results_dir, exist_ok=True)
149
+ plt.savefig(os.path.join(results_dir, 'final_performance_comparison_all_agents.png'))
150
  plt.show()
151
 
 
152
  if __name__ == '__main__':
 
 
153
  main()
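
As a quick sanity check of the formulas in `calculate_metrics`, the snippet below applies the same Sharpe and max-drawdown arithmetic to a synthetic equity curve; the random-walk series and seed are illustrative only:

```python
import numpy as np
import pandas as pd

# Synthetic daily equity curve: roughly 2 years of mildly positive, noisy returns
rng = np.random.default_rng(0)
daily_returns = rng.normal(0.0005, 0.01, 504)
curve = pd.Series(10_000 * np.cumprod(1 + daily_returns),
                  index=pd.bdate_range('2021-01-01', periods=504))

# Same core formulas as calculate_metrics (annualised with freq=252, rf=0)
rets = curve.pct_change().dropna()
sharpe = np.sqrt(252) * rets.mean() / rets.std()
max_drawdown = (curve / curve.cummax() - 1.0).min()
print(f"Sharpe: {sharpe:.2f}, Max Drawdown: {max_drawdown:.2%}")
```
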
scripts/evaluate_baselines.py CHANGED
@@ -1,46 +1,46 @@
1
- # evaluate_baselines.py
2
-
3
  import pandas as pd
4
  import numpy as np
5
  import matplotlib.pyplot as plt
 
6
 
7
- def buy_and_hold(df, initial_balance=10000):
8
  """
9
  Simulates the Buy and Hold strategy.
10
 
11
  Args:
12
- df (pd.DataFrame): DataFrame with daily asset prices.
13
  initial_balance (int): The starting capital.
14
 
15
  Returns:
16
  pd.Series: A Series containing the portfolio value for each day.
17
  """
18
  print("--- Simulating Buy and Hold ---")
19
- n_assets = len(df.columns)
20
 
21
  # Invest an equal amount in each asset at the beginning
22
  initial_investment_per_asset = initial_balance / n_assets
23
 
24
  # Get the initial prices
25
- initial_prices = df.iloc[0]
26
 
27
  # Calculate the number of shares bought for each asset
28
- shares = initial_investment_per_asset / initial_prices
 
29
 
30
  # Calculate the portfolio value for each day
31
- portfolio_values = df.dot(shares)
32
 
33
  print(f"Initial Investment: ${initial_balance:.2f}")
34
- print(f"Final Portfolio Value: ${portfolio_values.iloc[-1]:.2f}")
35
 
36
  return portfolio_values
37
 
38
- def equally_weighted_rebalanced(df, initial_balance=10000, rebalance_freq='M', transaction_cost_pct=0.001):
39
  """
40
  Simulates an Equally Weighted Portfolio with periodic rebalancing.
41
 
42
  Args:
43
- df (pd.DataFrame): DataFrame with daily asset prices.
44
  initial_balance (int): The starting capital.
45
  rebalance_freq (str): The rebalancing frequency ('M' for monthly, 'Q' for quarterly).
46
  transaction_cost_pct (float): The transaction cost as a percentage.
@@ -49,24 +49,30 @@ def equally_weighted_rebalanced(df, initial_balance=10000, rebalance_freq='M', t
49
  pd.Series: A Series containing the portfolio value for each day.
50
  """
51
  print(f"\n--- Simulating Equally Weighted Portfolio (Rebalanced {rebalance_freq}) ---")
52
- n_assets = len(df.columns)
53
 
54
  # Set the initial weights to be equal
55
  weights = np.full(n_assets, 1/n_assets)
56
 
57
  portfolio_value = initial_balance
58
- portfolio_values = pd.Series(index=df.index)
59
 
60
  last_rebalance_date = None
61
 
62
- for date, prices in df.iterrows():
63
  # Store the portfolio value for the day before any changes
64
  portfolio_values[date] = portfolio_value
65
 
66
  # Determine if it's a rebalancing day
67
- # Rebalance on the first day of the new period (month, quarter)
68
- if last_rebalance_date is None or (date.month != last_rebalance_date.month and rebalance_freq == 'M'):
69
-
 
 
 
 
 
 
70
  # Calculate the value of trades to rebalance
71
  target_asset_values = portfolio_value * (1/n_assets)
72
  current_asset_values = weights * portfolio_value
@@ -82,32 +88,43 @@ def equally_weighted_rebalanced(df, initial_balance=10000, rebalance_freq='M', t
82
 
83
  # Calculate portfolio value for the *next* day before the market opens
84
  # Get price changes from today to the next trading day
85
- today_prices = df.loc[date]
86
- next_day_index = df.index.get_loc(date) + 1
87
- if next_day_index < len(df):
88
- next_day_prices = df.iloc[next_day_index]
89
- price_change_ratio = next_day_prices / today_prices
 
 
90
 
91
  # Update portfolio value based on price changes
92
  portfolio_value = np.sum( (weights * portfolio_value) * price_change_ratio )
93
 
94
  # Update weights due to market drift
95
  new_asset_values = (weights * portfolio_value) * price_change_ratio
96
- weights = new_asset_values / np.sum(new_asset_values)
 
 
 
 
 
97
 
98
  print(f"Initial Investment: ${initial_balance:.2f}")
99
- print(f"Final Portfolio Value: ${portfolio_values.iloc[-1]:.2f}")
100
 
101
  return portfolio_values.dropna()
102
 
103
 
104
  def main():
105
- # Load the test data
106
- test_df = pd.read_csv('data/test.csv', index_col='Date', parse_dates=True)
 
 
 
 
107
 
108
  # --- Run Baseline Strategies ---
109
- bnh_values = buy_and_hold(test_df)
110
- ewp_values = equally_weighted_rebalanced(test_df)
111
 
112
  # --- Plot the results ---
113
  plt.style.use('seaborn-v0_8-darkgrid')
@@ -127,7 +144,11 @@ def main():
127
  ax.yaxis.set_major_formatter(formatter)
128
 
129
  plt.tight_layout()
130
- plt.savefig('baseline_performance.png')
 
 
 
 
131
  plt.show()
132
 
133
  if __name__ == '__main__':
 
 
 
1
  import pandas as pd
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
+ import os # Import os for directory creation
5
 
6
+ def buy_and_hold(df_assets, initial_balance=10000): # Renamed df to df_assets for clarity
7
  """
8
  Simulates the Buy and Hold strategy.
9
 
10
  Args:
11
+ df_assets (pd.DataFrame): DataFrame with daily tradable asset prices ONLY.
12
  initial_balance (int): The starting capital.
13
 
14
  Returns:
15
  pd.Series: A Series containing the portfolio value for each day.
16
  """
17
  print("--- Simulating Buy and Hold ---")
18
+ n_assets = len(df_assets.columns)
19
 
20
  # Invest an equal amount in each asset at the beginning
21
  initial_investment_per_asset = initial_balance / n_assets
22
 
23
  # Get the initial prices
24
+ initial_prices = df_assets.iloc[0]
25
 
26
  # Calculate the number of shares bought for each asset
27
+ # Handle potential division by zero if an asset price is 0 (though unlikely with real data)
28
+ shares = initial_investment_per_asset / (initial_prices + 1e-8)
29
 
30
  # Calculate the portfolio value for each day
31
+ portfolio_values = df_assets.dot(shares)
32
 
33
  print(f"Initial Investment: ${initial_balance:.2f}")
34
+ print(f"Final Portfolio Value (Buy and Hold): ${portfolio_values.iloc[-1]:.2f}")
35
 
36
  return portfolio_values
37
 
38
+ def equally_weighted_rebalanced(df_assets, initial_balance=10000, rebalance_freq='M', transaction_cost_pct=0.001): # Renamed df to df_assets
39
  """
40
  Simulates an Equally Weighted Portfolio with periodic rebalancing.
41
 
42
  Args:
43
+ df_assets (pd.DataFrame): DataFrame with daily tradable asset prices ONLY.
44
  initial_balance (int): The starting capital.
45
  rebalance_freq (str): The rebalancing frequency ('M' for monthly, 'Q' for quarterly).
46
  transaction_cost_pct (float): The transaction cost as a percentage.
 
49
  pd.Series: A Series containing the portfolio value for each day.
50
  """
51
  print(f"\n--- Simulating Equally Weighted Portfolio (Rebalanced {rebalance_freq}) ---")
52
+ n_assets = len(df_assets.columns)
53
 
54
  # Set the initial weights to be equal
55
  weights = np.full(n_assets, 1/n_assets)
56
 
57
  portfolio_value = initial_balance
58
+ portfolio_values = pd.Series(index=df_assets.index, dtype=float) # Explicitly set dtype
59
 
60
  last_rebalance_date = None
61
 
62
+ for i, (date, prices) in enumerate(df_assets.iterrows()):
63
  # Store the portfolio value for the day before any changes
64
  portfolio_values[date] = portfolio_value
65
 
66
  # Determine if it's a rebalancing day
67
+ # Rebalance on the first day of the new period (month, quarter) or if it's the very first day
68
+ rebalance_this_day = False
69
+ if i == 0: # Rebalance on the very first day
70
+ rebalance_this_day = True
71
+ elif rebalance_freq == 'M' and date.month != df_assets.index[i-1].month:
72
+ rebalance_this_day = True
73
+ # Add 'Q' for quarterly if needed, similar logic
74
+
75
+ if rebalance_this_day:
76
  # Calculate the value of trades to rebalance
77
  target_asset_values = portfolio_value * (1/n_assets)
78
  current_asset_values = weights * portfolio_value
 
88
 
89
  # Calculate portfolio value for the *next* day before the market opens
90
  # Get price changes from today to the next trading day
91
+ today_prices = prices # Already have prices for the current date
92
+ next_day_index = df_assets.index.get_loc(date) + 1
93
+ if next_day_index < len(df_assets):
94
+ next_day_prices = df_assets.iloc[next_day_index]
95
+
96
+ # Avoid division by zero
97
+ price_change_ratio = next_day_prices / (today_prices + 1e-8)
98
 
99
  # Update portfolio value based on price changes
100
  portfolio_value = np.sum( (weights * portfolio_value) * price_change_ratio )
101
 
102
  # Update weights due to market drift
103
  new_asset_values = (weights * portfolio_value) * price_change_ratio
104
+ # Avoid division by zero for total portfolio value
105
+ if np.sum(new_asset_values) > 1e-8: # Check if total value is effectively non-zero
106
+ weights = new_asset_values / np.sum(new_asset_values)
107
+ else:
108
+ weights = np.full(n_assets, 1/n_assets) # Default to equal or handle as error
109
+
110
 
111
  print(f"Initial Investment: ${initial_balance:.2f}")
112
+ print(f"Final Portfolio Value (Equally Weighted): ${portfolio_values.iloc[-1]:.2f}")
113
 
114
  return portfolio_values.dropna()
115
 
116
 
117
  def main():
118
+ # Load the evaluation data (which contains both assets and macro data)
119
+ full_eval_df = pd.read_csv('data/eval.csv', index_col='Date', parse_dates=True)
120
+
121
+ # --- IMPORTANT: Filter ONLY asset columns for baselines ---
122
+ asset_columns = ['AAPL', 'BTC-USD', 'MSFT', 'SPY', 'TLT'] # Define your actual tradable assets
123
+ test_df_assets_only = full_eval_df[asset_columns]
124
 
125
  # --- Run Baseline Strategies ---
126
+ bnh_values = buy_and_hold(test_df_assets_only)
127
+ ewp_values = equally_weighted_rebalanced(test_df_assets_only)
128
 
129
  # --- Plot the results ---
130
  plt.style.use('seaborn-v0_8-darkgrid')
 
144
  ax.yaxis.set_major_formatter(formatter)
145
 
146
  plt.tight_layout()
147
+
148
+ # Ensure results directory exists for saving plot
149
+ results_dir = 'results'
150
+ os.makedirs(results_dir, exist_ok=True)
151
+ plt.savefig(os.path.join(results_dir, 'baseline_performance.png'))
152
  plt.show()
153
 
154
  if __name__ == '__main__':
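
A tiny worked example of the buy-and-hold arithmetic above: equal dollars are allocated per asset on day one and the resulting share counts are simply held; the toy two-asset price history below is purely illustrative:

```python
import pandas as pd

prices = pd.DataFrame(
    {"A": [100.0, 110.0, 105.0], "B": [50.0, 48.0, 55.0]},  # toy prices
    index=pd.date_range("2021-01-01", periods=3),
)

initial_balance = 10_000
shares = (initial_balance / len(prices.columns)) / prices.iloc[0]  # equal dollars per asset on day one
curve = prices.dot(shares)  # daily portfolio value
print(curve)  # 10000.0 -> 10300.0 -> 10750.0
```
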
scripts/fetch_data.py DELETED
@@ -1,75 +0,0 @@
1
- import yfinance as yf
2
- import pandas as pd
3
- import os
4
-
5
- # --- Configuration ---
6
- # Asset tickers
7
- TICKERS = ["AAPL", "MSFT", "SPY", "TLT", "BTC-USD"]
8
- # Time periods for training and testing
9
- TRAIN_START_DATE = "2015-01-01"
10
- TRAIN_END_DATE = "2020-12-31"
11
- TEST_START_DATE = "2021-01-01"
12
- TEST_END_DATE = "2023-12-31"
13
-
14
- # Directory to save the data
15
- DATA_DIR = "data"
16
- TRAIN_DATA_PATH = os.path.join(DATA_DIR, "train.csv")
17
- TEST_DATA_PATH = os.path.join(DATA_DIR, "test.csv")
18
-
19
- # --- Data Fetching and Processing ---
20
-
21
- def fetch_and_prepare_data(start_date, end_date, tickers):
22
- """
23
- Fetches historical data for the given tickers and processes it.
24
- Returns a DataFrame with 'Close' prices for each ticker.
25
- """
26
- print(f"Fetching data from {start_date} to {end_date} for {tickers}...")
27
-
28
- data = yf.download(tickers, start=start_date, end=end_date)
29
-
30
- # CHANGE: Add .copy() to explicitly create a new DataFrame and avoid warnings.
31
- close_data = data['Close'].copy()
32
-
33
- print("\nData Head:")
34
- print(close_data.head())
35
-
36
- print("\nMissing values before cleaning:")
37
- print(close_data.isnull().sum())
38
-
39
- # Now, all inplace operations are safely performed on our own copy.
40
- close_data.ffill(inplace=True)
41
- close_data.bfill(inplace=True)
42
-
43
- print("\nMissing values after cleaning:")
44
- print(close_data.isnull().sum())
45
-
46
- for col in close_data.columns:
47
- close_data[col] = pd.to_numeric(close_data[col], errors='coerce')
48
-
49
- close_data.dropna(inplace=True)
50
-
51
- return close_data
52
-
53
- def main():
54
- """Main function to run the data fetching process."""
55
- # Create data directory if it doesn't exist
56
- if not os.path.exists(DATA_DIR):
57
- os.makedirs(DATA_DIR)
58
- print(f"Created directory: {DATA_DIR}")
59
-
60
- # Fetch, process, and save training data
61
- print("--- Preparing Training Data ---")
62
- train_data = fetch_and_prepare_data(TRAIN_START_DATE, TRAIN_END_DATE, TICKERS)
63
- train_data.to_csv(TRAIN_DATA_PATH)
64
- print(f"Training data saved to {TRAIN_DATA_PATH}")
65
-
66
- print("\n" + "="*50 + "\n")
67
-
68
- # Fetch, process, and save testing data
69
- print("--- Preparing Testing Data ---")
70
- test_data = fetch_and_prepare_data(TEST_START_DATE, TEST_END_DATE, TICKERS)
71
- test_data.to_csv(TEST_DATA_PATH)
72
- print(f"Testing data saved to {TEST_DATA_PATH}")
73
-
74
- if __name__ == "__main__":
75
- main()
scripts/fetch_market_data.py CHANGED
@@ -1,78 +1,104 @@
 
 
 
 
1
  import argparse
2
  import os
3
- import pandas as pd
4
- import yfinance as yf
5
- from datetime import date
6
 
7
- def fetch_data(start_date, end_date, output_filename):
8
- """
9
- Fetches, cleans, and saves historical market data for a given date range.
 
 
 
 
 
 
 
 
10
 
11
- Args:
12
- start_date (str): The start date for the data in 'YYYY-MM-DD' format.
13
- end_date (str): The end date for the data in 'YYYY-MM-DD' format.
14
- output_filename (str): The path and name of the file to save the data.
15
  """
16
- print(f"--- Fetching data from {start_date} to {end_date} ---")
 
 
 
 
17
 
18
- # Define the base list of tickers
19
- tickers = ["AAPL", "MSFT", "SPY", "TLT", "BTC-USD"]
20
-
21
- # Smartly remove Bitcoin if the period is before its existence (e.g., before 2013)
22
- if pd.to_datetime(start_date).year < 2013:
23
- print("Note: Bitcoin (BTC-USD) did not exist for the requested period and will be excluded.")
24
- tickers.remove("BTC-USD")
25
-
26
- # Download data from Yahoo Finance
27
- data = yf.download(tickers, start=start_date, end=end_date)
28
- close_data = data['Close'].copy()
29
-
30
- # Data Cleaning
31
- print(f"\nMissing values before cleaning:\n{close_data.isnull().sum()}")
32
- close_data.ffill(inplace=True)
33
- close_data.bfill(inplace=True)
34
-
35
- # Drop any columns that are still all NaN (like BTC in the 2008 data)
36
- close_data.dropna(axis=1, how='all', inplace=True)
37
-
38
- print(f"\nMissing values after cleaning:\n{close_data.isnull().sum()}")
39
 
40
- # Ensure data directory exists
41
- output_dir = os.path.dirname(output_filename)
42
- if output_dir and not os.path.exists(output_dir):
43
- os.makedirs(output_dir)
44
 
45
- # Save to CSV
46
- close_data.to_csv(output_filename)
47
- print(f"\n✅ Data successfully saved to {output_filename}")
48
 
 
 
 
 
 
 
 
49
 
50
- if __name__ == "__main__":
51
- # Set up command-line argument parsing
52
- parser = argparse.ArgumentParser(description="Fetch historical market data for specified periods.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- parser.add_argument(
55
- "--start",
56
- type=str,
57
- default="2018-01-01",
58
- help="Start date in YYYY-MM-DD format. Default is for the 2018 stress test."
59
- )
60
- parser.add_argument(
61
- "--end",
62
- type=str,
63
- default="2019-12-31",
64
- help="End date in YYYY-MM-DD format. Default is for the 2018 stress test."
65
- )
66
- parser.add_argument(
67
- "--filename",
68
- type=str,
69
- default="data/stress_test_2018.csv",
70
- help="Output file name (e.g., 'data/my_data.csv')."
71
- )
72
 
73
- args = parser.parse_args()
74
 
75
- # Use 'today' as the end date if specified
76
- end_date = date.today().strftime('%Y-%m-%d') if args.end.lower() == 'today' else args.end
 
 
 
 
 
77
 
78
- fetch_data(start_date=args.start, end_date=end_date, output_filename=args.filename)
 
1
+ # scripts/fetch_market_data.py
2
+
3
+ import yfinance as yf_lib
4
+ import pandas as pd
5
  import argparse
6
  import os
7
+ from datetime import datetime, timedelta
8
+ from pandas_datareader import data as pdr
 
9
 
10
+ # --- Module-level constants (importable by other scripts, e.g. the Gradio app) ---
11
+ # Define your assets (Global variable, importable)
12
+ ASSETS = ['AAPL', 'MSFT', 'SPY', 'TLT', 'BTC-USD']
13
+
14
+ # Define FRED IDs for macroeconomic data (Global variable, importable)
15
+ FRED_IDS = {
16
+ 'DFF': 'Federal Funds Rate', # Daily Federal Funds Rate
17
+ 'CPIAUCSL': 'CPI', # Consumer Price Index (All Urban Consumers, Seasonally Adjusted, Monthly)
18
+ 'VIXCLS': 'VIX' # CBOE Volatility Index (VIX) from FRED
19
+ }
20
+ # ---------------------------------------
21
 
22
+ def fetch_market_data(start_date, end_date, filename):
 
 
 
23
  """
24
+ Fetches market data, macroeconomic indicators (including VIX from FRED),
25
+ for specified assets and time period, then saves it to a CSV file.
26
+ """
27
+ # No need to re-define assets and fred_ids here.
28
+ # The function will use the global ASSETS and FRED_IDS defined above.
29
 
30
+ print(f"--- Fetching market data for {ASSETS} from {start_date} to {end_date} ---")
31
+
32
+ # 1. Fetch Asset Prices (Daily) using yf_lib.download()
33
+ try:
34
+ # Use the global ASSETS variable
35
+ df_prices = yf_lib.download(ASSETS, start=start_date, end=end_date)['Close']
36
+ df_prices.dropna(inplace=True)
37
+ print(f"✅ Fetched {len(ASSETS)} asset prices.")
38
+ except Exception as e:
39
+ print(f"❌ Error fetching asset prices: {e}")
40
+ return None # Return None on failure
 
 
 
 
 
 
 
 
 
 
41
 
42
+ # 2. Fetch Macro Data (VIX, Federal Funds Rate, CPI) from FRED using pandas_datareader
43
+ print("--- Fetching macroeconomic data from FRED ---")
 
 
44
 
45
+ try:
46
+ # FRED data can be tricky with exact date ranges, fetching a bit more to ensure coverage
47
+ fred_start_date = (datetime.strptime(start_date, '%Y-%m-%d') - timedelta(days=365)).strftime('%Y-%m-%d')
48
 
49
+ # Use the global FRED_IDS variable
50
+ df_fred = pdr.DataReader(list(FRED_IDS.keys()), 'fred', start=fred_start_date, end=end_date)
51
+ df_fred.rename(columns=FRED_IDS, inplace=True)
52
+ print("✅ Fetched Federal Funds Rate, CPI, and VIX data from FRED.")
53
+ except Exception as e:
54
+ print(f"❌ Error fetching FRED data: {e}. Check FRED API access or ticker validity.")
55
+ df_fred = pd.DataFrame() # Create empty dataframe if fetch fails
56
 
57
+ # Combine all dataframes
58
+ df_combined = df_prices.copy()
59
+
60
+ # Merge FRED data (now includes VIX)
61
+ if not df_fred.empty:
62
+ df_combined = df_combined.merge(df_fred, left_index=True, right_index=True, how='left')
63
+
64
+ # Handle missing macro data: forward-fill and then back-fill for initial NaNs
65
+ # This loop now covers all FRED columns
66
+ # Use the global FRED_IDS variable
67
+ for col_name in FRED_IDS.values():
68
+ if col_name in df_combined.columns:
69
+ df_combined[col_name] = df_combined[col_name].ffill().bfill()
70
+ # Drop rows if they still have NaN for macro data after fill
71
+ df_combined.dropna(subset=[col_name], inplace=True) # Added dropna for robustness
72
+
73
+ # Ensure all data is within the requested date range after merging/filling
74
+ df_combined = df_combined.loc[start_date:end_date]
75
+ df_combined.dropna(inplace=True) # Final dropna for any remaining NaNs
76
+
77
+ if df_combined.empty:
78
+ print("❌ Final combined dataframe is empty after merging and cleaning. Check date ranges and data availability.")
79
+ return None # Return None on failure
80
+
81
+ # Save to CSV if a filename is provided
82
+ if filename:
83
+ output_dir = os.path.dirname(filename)
84
+ if output_dir: # os.path.dirname returns "" when the path has no directory component
85
+ os.makedirs(output_dir, exist_ok=True)
86
+
87
+ df_combined.to_csv(filename, index=True)
88
+ print(f"\n✅ Data saved successfully to {filename}")
89
+
90
+ print(f"Final data shape: {df_combined.shape}")
91
+ print("Columns:", df_combined.columns.tolist())
92
 
93
+ return df_combined # Return the DataFrame
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
 
95
 
96
+ if __name__ == '__main__':
97
+ parser = argparse.ArgumentParser(description="Fetch market and macroeconomic data.")
98
+ parser.add_argument("--start", type=str, default="2015-01-01", help="Start date (YYYY-MM-DD).")
99
+ parser.add_argument("--end", type=str, default="2020-12-31", help="End date (YYYY-MM-DD).")
100
+ parser.add_argument("--filename", type=str, default="data/train.csv", help="Output CSV filename.")
101
+
102
+ args = parser.parse_args()
103
 
104
+ fetch_market_data(args.start, args.end, args.filename)
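
Besides the CLI entry point above, `fetch_market_data` returns the combined DataFrame, so the train and evaluation datasets can also be built programmatically. A hedged sketch (date ranges mirror the project's train/test windows; the output paths are illustrative, and network access for yfinance and FRED is assumed):

```python
# Sketch only: assumes it is run from scripts/ so fetch_market_data imports directly.
from fetch_market_data import fetch_market_data

train_df = fetch_market_data("2015-01-01", "2020-12-31", "data/train.csv")
eval_df = fetch_market_data("2021-01-01", "2023-12-31", "data/eval.csv")

if eval_df is not None:
    # Expect 5 asset price columns plus Federal Funds Rate, CPI and VIX
    print(eval_df.columns.tolist())
```
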
scripts/llm_analysis_rag.py ADDED
@@ -0,0 +1,243 @@
1
+ # scripts/llm_analysis_rag.py
2
+ import gc
3
+ import time
4
+ import os
5
+ import shutil
6
+ import torch
7
+ import pandas as pd
8
+ import numpy as np
9
+ import json
10
+ import re
11
+ from datetime import datetime, timedelta
12
+
13
+ # LangChain components
14
+ from langchain_community.vectorstores import Chroma
15
+ from langchain_community.embeddings import HuggingFaceEmbeddings
16
+ from langchain_huggingface import HuggingFacePipeline
17
+ from langchain_classic.chains import RetrievalQA
18
+ from langchain_classic.prompts import PromptTemplate
19
+ from langchain_classic.docstore.document import Document
20
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
21
+
22
+ # --- Configuration ---
23
+ # HF_EMBEDDING_MODEL is no longer used in this reduced scope
24
+ HF_GENERATION_MODEL = "Qwen/Qwen2.5-3B-Instruct"
25
+
26
+ # Global variables
27
+ llm_pipeline_hf_instance = None
28
+
29
+ # --- Helper: Robust JSON Extractor ---
30
+ def extract_clean_json(response_text):
31
+ """Robust JSON extractor handling Python booleans and Markdown."""
32
+ json_match = re.search(r'```json\s*(.*?)\s*```', response_text, re.DOTALL)
33
+ if json_match:
34
+ text_to_parse = json_match.group(1)
35
+ else:
36
+ start_idx = response_text.find('{')
37
+ end_idx = response_text.rfind('}')
38
+ if start_idx != -1 and end_idx != -1:
39
+ text_to_parse = response_text[start_idx:end_idx+1]
40
+ else:
41
+ # print(f"❌ PARSE ERROR: No JSON found: {response_text[:100]}...")
42
+ return None
43
+
44
+ text_to_parse = text_to_parse.replace(": True", ": true").replace(": False", ": false")
45
+
46
+ try:
47
+ return json.loads(text_to_parse)
48
+ except json.JSONDecodeError:
49
+ return None
50
+
51
+ # --- Shared LLM Setup (Singleton Pattern) ---
52
+ def setup_llm_pipeline():
53
+ global llm_pipeline_hf_instance
54
+ if llm_pipeline_hf_instance is None:
55
+ print(f"--- Loading Model: {HF_GENERATION_MODEL} ---")
56
+ tokenizer = AutoTokenizer.from_pretrained(HF_GENERATION_MODEL, trust_remote_code=True)
57
+ # 4-bit quantization config for efficient loading
58
+ bnb_config = BitsAndBytesConfig(
59
+ load_in_4bit=True, bnb_4bit_use_double_quant=True,
60
+ bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
61
+ )
62
+ model = AutoModelForCausalLM.from_pretrained(
63
+ HF_GENERATION_MODEL, trust_remote_code=True,
64
+ quantization_config=bnb_config, device_map="auto"
65
+ )
66
+ # Create the HF pipeline
67
+ pipe = pipeline(
68
+ "text-generation", model=model, tokenizer=tokenizer,
69
+ max_new_tokens=1024, # Increased token limit for detailed historical analysis
70
+ do_sample=False, # Greedy decoding for deterministic, factual output (temperature is ignored when sampling is off)
71
+ return_full_text=False
72
+ )
73
+ llm_pipeline_hf_instance = HuggingFacePipeline(pipeline=pipe)
74
+ return llm_pipeline_hf_instance
75
+
76
+ # =========================================
77
+ # NEW FUNCTION: Structured Historical Analysis
78
+ # =========================================
79
+ def analyze_historical_segment(df_segment, selected_assets, period_name):
80
+ """
81
+ Analyzes a specific segment of historical data directly without RAG.
82
+ Takes a DataFrame slice, calculates summary stats, and prompts the LLM.
83
+ """
84
+ llm = setup_llm_pipeline()
85
+ print(f"--- Running Historical Analysis for {period_name} ---")
86
+
87
+ # 1. Create quantitative summary of the data segment for the prompt
88
+ if df_segment.empty:
89
+ return "No data available for this period to analyze."
90
+
91
+ start_date = df_segment.index.min().date()
92
+ end_date = df_segment.index.max().date()
93
+
94
+ start_vals = df_segment.iloc[0]
95
+ end_vals = df_segment.iloc[-1]
96
+ # Calculate percentage change over the period, handling potential zeros
97
+ pct_changes = ((end_vals - start_vals) / (start_vals.replace(0, np.nan)) * 100).fillna(0)
98
+
99
+ # Build the context string
100
+ data_summary = f"Analysis Period: {period_name} ({start_date} to {end_date})\n\n"
101
+ data_summary += "Performance Summary over Period:\n"
102
+ for asset in selected_assets:
103
+ if asset in df_segment.columns:
104
+ change = pct_changes[asset]
105
+ direction = "gained" if change > 0 else "lost"
106
+ data_summary += f"- {asset}: {direction} {abs(change):.2f}%\n"
107
+
108
+ # Add volatility context (standard deviation of daily returns)
109
+ data_summary += "\nVolatility Context (Daily Return Std Dev):\n"
110
+ daily_rets = df_segment.pct_change()
111
+ std_devs = daily_rets.std() * 100
112
+ for asset in selected_assets:
113
+ if asset in std_devs.index:
114
+ data_summary += f"- {asset}: {std_devs[asset]:.2f}%\n"
115
+
116
+ # 2. Create the Prompt
117
+ # We use Qwen's chat template format (<|im_start|>...)
118
+ prompt_template = """<|im_start|>system
119
+ You are a senior financial analyst. Your job is to analyze historical market data trends for selected assets over a specific time period.
120
+ Provide a concise, professional, and insightful summary of the performance, key trends, and comparative movements based *only* on the provided data summary.
121
+ Highlight significant gains, losses, or differences in volatility between the assets.
122
+
123
+ ### DATA CONTEXT:
124
+ {data_summary}
125
+ <|im_end|>
126
+ <|im_start|>user
127
+ Generate the historical analysis report.
128
+ <|im_end|>
129
+ <|im_start|>assistant
130
+ """
131
+ pt = PromptTemplate(template=prompt_template, input_variables=["data_summary"])
132
+ formatted_prompt = pt.format(data_summary=data_summary)
133
+
134
+ # 3. Invoke LLM
135
+ response = llm.invoke(formatted_prompt)
136
+ return response.strip()
137
+
138
+
139
+ # --- Decision Analysis (Kept for Forecast Tab) ---
140
+ def analyze_agent_decision(current_market_data_window, proposed_allocations):
141
+ """
142
+ HYBRID ANALYZER: Python does the math, LLM does the talking.
143
+ """
144
+ llm = setup_llm_pipeline()
145
+
146
+ # --- 1. PREPARE DATA ---
147
+ # (Logic remains the same as before...)
148
+ vix_level = current_market_data_window['VIX'].iloc[-1] if 'VIX' in current_market_data_window else 0
149
+
150
+ # Identify largest position
151
+ risky_assets = {k:v for k,v in proposed_allocations.items() if k not in ['Cash', 'TLT']}
152
+ if risky_assets:
153
+ max_asset = max(risky_assets, key=risky_assets.get)
154
+ max_val = risky_assets[max_asset] * 100
155
+ else:
156
+ max_asset = "None"
157
+ max_val = 0.0
158
+
159
+ safe_haven_pct = (proposed_allocations.get('Cash', 0) + proposed_allocations.get('TLT', 0)) * 100
160
+
161
+ # --- 2. PYTHON LOGIC CORE ---
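+ # Hard-coded thresholds on VIX and allocation concentration decide the risk verdict here; the LLM below only narrates it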
162
+ trigger_safe_haven = safe_haven_pct > 80.0
163
+ trigger_crash_rule = vix_level > 20.0 and safe_haven_pct < 30.0
164
+ trigger_concentration = vix_level > 15.0 and max_val > 40.0
165
+
166
+ # Determine Verdict Programmatically
167
+ calculated_risk = "MODERATE"
168
+ reason_code = "Standard market conditions."
169
+
170
+ if trigger_safe_haven:
171
+ calculated_risk = "LOW"
172
+ reason_code = f"Safe Haven Exception triggered (Safe Assets: {safe_haven_pct:.1f}% > 80%)."
173
+ elif trigger_crash_rule:
174
+ calculated_risk = "HIGH"
175
+ reason_code = f"Crash Protocol triggered (VIX {vix_level:.1f} > 20 and Safe Haven < 30%)."
176
+ elif trigger_concentration:
177
+ calculated_risk = "HIGH"
178
+ reason_code = f"Concentration Rule triggered (VIX {vix_level:.1f} > 15 and {max_asset} > 40%)."
179
+
180
+ # --- 3. THE "NARRATOR" PROMPT ---
181
+ prompt_template = """<|im_start|>system
182
+ You are a Senior Risk Analyst.
183
+ The Quantitative Engine has already processed the data and determined the Risk Level.
184
+ Your job is to summarize the strategy and explain the risk verdict to the user.
185
+
186
+ ### QUANTITATIVE ENGINE OUTPUT:
187
+ - **Determined Risk Level:** {calculated_risk}
188
+ - **Primary Logic Trigger:** {reason_code}
189
+
190
+ ### DATA CONTEXT:
191
+ - VIX: {vix:.2f}
192
+ - Largest Position: {max_asset} ({max_val:.1f}%)
193
+ - Safe Haven Allocation: {safe_pct:.1f}%
194
+
195
+ ### INSTRUCTIONS:
196
+ 1. **Strategy Summary:** Describe the allocation style (e.g., "Aggressive Tech", "Defensive Cash").
197
+ 2. **Justification:** Explain the Risk Level using the "Primary Logic Trigger" provided above. Do not invent new math.
198
+
199
+ Return ONLY raw JSON:
200
+ {{
201
+ "strategy_summary": "string",
202
+ "risk_level": "{calculated_risk}",
203
+ "justification": "string",
204
+ "confidence_score": 10
205
+ }}
206
+ <|im_end|>
207
+ <|im_start|>user
208
+ Generate the report.
209
+ <|im_end|>
210
+ <|im_start|>assistant
211
+ """
212
+ pt = PromptTemplate(template=prompt_template, input_variables=["calculated_risk", "reason_code", "vix", "max_asset", "max_val", "safe_pct"])
213
+
214
+ formatted = pt.format(
215
+ calculated_risk=calculated_risk,
216
+ reason_code=reason_code,
217
+ vix=vix_level,
218
+ max_asset=max_asset,
219
+ max_val=max_val,
220
+ safe_pct=safe_haven_pct
221
+ )
222
+
223
+ res = llm.invoke(formatted)
224
+ return extract_clean_json(res)
225
+
226
+ # --- MAIN (for testing) ---
227
+ if __name__ == '__main__':
228
+ print("Running test...")
229
+ # Generate Dummy Data
230
+ dates = pd.date_range(start="2023-01-01", periods=180, freq='D')
231
+ df_dummy = pd.DataFrame({
232
+ 'SPY': np.linspace(400, 450, 180) + np.random.normal(0, 5, 180),
233
+ 'BTC-USD': np.linspace(30000, 40000, 180) + np.random.normal(0, 1000, 180),
234
+ 'VIX': np.linspace(20, 15, 180)
235
+ }, index=dates)
236
+
237
+ # Test the new historical analysis function
238
+ selected = ['SPY', 'BTC-USD']
239
+ period = "6 Months"
240
+ print(f"\nTesting analysis for {selected} over {period}...")
241
+ analysis = analyze_historical_segment(df_dummy, selected, period)
242
+ print("\n--- LLM Analysis Output ---")
243
+ print(analysis)
scripts/predict_tomorrow.py ADDED
@@ -0,0 +1,123 @@
1
+ # scripts/predict_tomorrow.py
2
+
3
+ import pandas as pd
4
+ import numpy as np
5
+ import os
6
+ import sys
7
+ from datetime import datetime, timedelta
8
+ from stable_baselines3 import SAC
9
+
10
+ # --- Imports ---
11
+ # Ensure local scripts (fetch_market_data, llm_analysis_rag) are importable
+ sys.path.append(os.getcwd())
16
+
17
+ from fetch_market_data import fetch_market_data, ASSETS, FRED_IDS
18
+ from llm_analysis_rag import analyze_agent_decision
19
+
20
+ # --- Configuration ---
21
+ MODEL_PATH = "checkpoints/sac_portfolio_model.zip"
22
+ WINDOW_SIZE = 30
23
+ MACRO_COLS = list(FRED_IDS.values()) # ['Federal Funds Rate', 'CPI', 'VIX']
24
+
25
+ def get_latest_data_window(window_size=30):
26
+ """
27
+ Fetches live data and returns the last 'window_size' rows.
28
+ """
29
+ print("--- 🔄 Fetching Real-Time Data for Prediction ---")
30
+
31
+ # Fetch a buffer to ensure we have enough data after cleaning
32
+ lookback_days = window_size + 100
33
+ end_date = datetime.now().strftime('%Y-%m-%d')
34
+ start_date = (datetime.now() - timedelta(days=lookback_days)).strftime('%Y-%m-%d')
35
+
36
+ # We don't strictly need to save to a file for prediction, so filename=None
37
+ df = fetch_market_data(start_date, end_date, filename=None)
38
+
39
+ if df is None or len(df) < window_size:
40
+ print(f"❌ Not enough data fetched. Got {len(df) if df is not None else 0} rows, needed {window_size}.")
41
+ return None
42
+
43
+ # Return exactly the last N rows (Observation Window)
44
+ return df.iloc[-window_size:].copy()
45
+
46
+ def prepare_observation(data_window):
47
+ """
48
+ Normalizes data: Window / First_Row_of_Window
49
+ """
50
+ # Extract specific columns to guarantee order
51
+ price_data = data_window[ASSETS].values
52
+ macro_data = data_window[MACRO_COLS].values
53
+
54
+ # Normalize
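+ # Dividing by the first row rescales each series to start near 1.0, so the model sees relative moves within the window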
55
+ norm_prices = price_data / (price_data[0] + 1e-8)
56
+ norm_macro = macro_data / (macro_data[0] + 1e-8)
57
+
58
+ # Concatenate and flatten for MLP input
59
+ obs = np.concatenate([norm_prices, norm_macro], axis=1)
60
+ return obs.flatten().astype(np.float32)
61
+
62
+ def get_allocations(action):
63
+ """Applies Softmax to convert raw action to weights"""
64
+ action = np.asarray(action).flatten()
65
+ exp_action = np.exp(action - np.max(action)) # Subtract the max for numerical stability; the result is unchanged
66
+ return exp_action / np.sum(exp_action)
67
+
68
+ def main():
69
+ print(f"🚀 Prediction Job: {datetime.now().strftime('%Y-%m-%d')}")
70
+
71
+ # 1. Get Data
72
+ data_window = get_latest_data_window(WINDOW_SIZE)
73
+ if data_window is None: return
74
+
75
+ # 2. Prepare Obs
76
+ obs = prepare_observation(data_window)
77
+
78
+ # 3. Load MLP Model
79
+ if not os.path.exists(MODEL_PATH):
80
+ print(f"❌ Model not found at {MODEL_PATH}")
81
+ return
82
+
83
+ print(f"Loading MLP SAC model...")
84
+ model = SAC.load(MODEL_PATH)
85
+
86
+ # 4. Predict
87
+ action, _ = model.predict(obs, deterministic=True)
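+ # deterministic=True uses the policy's mean action rather than sampling from it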
88
+ weights = get_allocations(action)
89
+
90
+ # 5. Format allocations, casting NumPy scalars to plain Python floats
91
+ allocations = {}
92
+ for i, asset in enumerate(ASSETS):
93
+ allocations[asset] = float(weights[i]) # Explicit float() cast
94
+ allocations['Cash'] = float(weights[-1]) # Explicit float() cast
95
+
96
+ # 6. Output Results
97
+ print("\n" + "="*40)
98
+ print(f"🤖 SAC MLP MODEL RECOMMENDATION")
99
+ print("="*40)
100
+ for asset, weight in allocations.items():
101
+ print(f"{asset:<10} : {weight:6.2%}")
102
+ print("="*40)
103
+
104
+ # 7. AI Risk Analyst
105
+ print("\n🧠 Running AI Risk Analysis...")
106
+
107
+ # All values are plain floats now, so they format cleanly in the analysis prompt
108
+ analysis = analyze_agent_decision(data_window, allocations)
109
+
110
+ if isinstance(analysis, dict):
111
+ print(f"\nStrategy: {analysis.get('strategy_summary')}")
112
+ print(f"Risk Level: {analysis.get('risk_level')}")
113
+ print(f"Justification: {analysis.get('justification')}")
114
+
115
+ if str(analysis.get('risk_level', '')).upper() == 'HIGH':
116
+ print("\n⛔ BLOCKING TRADE: High Risk detected by AI Guardrail.")
117
+ else:
118
+ print("\n✅ TRADE APPROVED.")
119
+ else:
120
+ print(analysis)
121
+
122
+ if __name__ == "__main__":
123
+ main()
scripts/tune_sac.py ADDED
@@ -0,0 +1,198 @@
1
+ # scripts/tune_sac.py
2
+
3
+ import os
4
+ import sys
5
+ import pandas as pd
6
+ import numpy as np
7
+ import optuna
8
+ from stable_baselines3 import SAC
9
+ from stable_baselines3.common.vec_env import DummyVecEnv # Use DummyVecEnv
10
+ from stable_baselines3.common.callbacks import EvalCallback
11
+ from stable_baselines3.common.logger import configure
12
+
13
+ from environment import PortfolioEnv
14
+
15
+ # ==============================================================================
16
+ # 1. Configuration & Data Loading
17
+ # ==============================================================================
18
+
19
+ TRAIN_DATA_PATH = 'data/train.csv'
20
+ EVAL_DATA_PATH = 'data/eval.csv'
21
+ OPTUNA_LOG_DIR = 'optuna_logs'
22
+ CHECKPOINT_DIR = 'checkpoints/optuna_sac_trials'
23
+
24
+ # Create directories if they don't exist
25
+ os.makedirs(OPTUNA_LOG_DIR, exist_ok=True)
26
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
27
+
28
+ # Load data once
29
+ df_full_train = pd.read_csv(TRAIN_DATA_PATH, index_col='Date', parse_dates=True)
30
+ df_eval = pd.read_csv(EVAL_DATA_PATH, index_col='Date', parse_dates=True)
31
+
32
+ # Split df_full_train for tuning
33
+ train_split_point = int(len(df_full_train) * 0.8)
34
+ df_train_tune = df_full_train.iloc[:train_split_point]
35
+ df_validation_tune = df_full_train.iloc[train_split_point:]
36
+
37
+ print(f"Total training data points: {len(df_full_train)}")
38
+ print(f"Optuna training data points: {len(df_train_tune)}")
39
+ print(f"Optuna validation data points: {len(df_validation_tune)}")
40
+
41
+
42
+ # ==============================================================================
43
+ # 2. Environment Creation Helper
44
+ # ==============================================================================
45
+
46
+ def make_env(df, window_size=30, initial_balance=10000, transaction_cost_pct=0.001):
47
+ """
48
+ Helper function to create a PortfolioEnv instance.
49
+ """
50
+ def _init():
51
+ env = PortfolioEnv(
52
+ df=df,
53
+ initial_balance=initial_balance,
54
+ window_size=window_size,
55
+ transaction_cost_pct=transaction_cost_pct
56
+ )
57
+ return env
58
+ return _init
59
+
60
+ # ==============================================================================
61
+ # 3. Optuna Objective Function
62
+ # ==============================================================================
63
+
64
+ def objective(trial: optuna.Trial) -> float:
65
+ """
66
+ Objective function for Optuna to optimize hyperparameters for SAC.
67
+ """
68
+ # Hyperparameter search space
69
+ learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True)
70
+ gamma = trial.suggest_float('gamma', 0.9, 0.999)
71
+ tau = trial.suggest_float('tau', 0.005, 0.02)
72
+ buffer_size = trial.suggest_int('buffer_size', 50000, 1000000, log=True)
73
+ batch_size = trial.suggest_categorical('batch_size', [64, 128, 256, 512])
74
+ ent_coef = trial.suggest_float('ent_coef', 0.001, 0.1, log=True) # Use log scale for ent_coef
75
+
76
+ # Network architecture
77
+ n_layers = trial.suggest_int('n_layers', 1, 3)
78
+ net_arch = []
79
+ for i in range(n_layers):
80
+ layer_size = trial.suggest_categorical(f'layer_size_{i}', [64, 128, 256])
81
+ net_arch.append(layer_size)
82
+
83
+ policy_kwargs = dict(net_arch=net_arch) # SAC uses shared network or separate [pi, qf]
84
+
85
+ # Create environments for this trial
86
+ train_env = DummyVecEnv([make_env(df_train_tune)])
87
+ eval_env = DummyVecEnv([make_env(df_validation_tune)])
88
+
89
+ # Set up logger for the trial
90
+ trial_log_path = os.path.join(OPTUNA_LOG_DIR, f"trial_{trial.number}")
91
+ new_logger = configure(trial_log_path, ["stdout", "csv", "tensorboard"])
92
+
93
+ # Create SAC model
94
+ model = SAC(
95
+ "MlpPolicy",
96
+ train_env,
97
+ learning_rate=learning_rate,
98
+ gamma=gamma,
99
+ tau=tau,
100
+ buffer_size=buffer_size,
101
+ batch_size=batch_size,
102
+ ent_coef=ent_coef, # Pass the sampled value
103
+ policy_kwargs=policy_kwargs,
104
+ verbose=0,
105
+ seed=42, # Use a fixed seed for reproducibility within a trial
106
+ tensorboard_log=OPTUNA_LOG_DIR
107
+ )
108
+ model.set_logger(new_logger)
109
+
110
+ # Callback for evaluation
111
+ eval_callback = EvalCallback(
112
+ eval_env,
113
+ best_model_save_path=os.path.join(CHECKPOINT_DIR, f"best_sac_trial_{trial.number}"),
114
+ log_path=trial_log_path,
115
+ eval_freq=5000,
116
+ deterministic=True,
117
+ render=False,
118
+ n_eval_episodes=1
119
+ )
120
+
121
+ try:
122
+ # Train for a set number of steps per trial
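+ # 50k steps keeps each trial cheap; a longer budget gives a more reliable ranking at the cost of tuning time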
123
+ total_timesteps_per_trial = 50000
124
+ model.learn(total_timesteps=total_timesteps_per_trial, callback=eval_callback, progress_bar=False)
125
+
126
+ # Load the best model found during this trial's training
127
+ best_model_path = os.path.join(CHECKPOINT_DIR, f"best_sac_trial_{trial.number}", "best_model.zip")
128
+ if os.path.exists(best_model_path):
129
+ model = SAC.load(best_model_path, env=eval_env)
130
+ else:
131
+ print(f"Warning: No best model saved for trial {trial.number}, using last model.")
132
+
133
+ # --- Final evaluation on the validation set ---
134
+ obs = eval_env.reset()
135
+ portfolio_values = [eval_env.envs[0].initial_balance]
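+ # DummyVecEnv returns batched results of length 1, hence the [0] indexing here and below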
136
+ done = False
137
+ while not done:
138
+ action, _ = model.predict(obs, deterministic=True)
139
+ obs, reward, done, info = eval_env.step(action)
140
+ portfolio_values.append(info[0]['portfolio_value'])
141
+
142
+ final_portfolio_value = portfolio_values[-1]
143
+ initial_portfolio_value = portfolio_values[0]
144
+ total_return = (final_portfolio_value / initial_portfolio_value) - 1
145
+
146
+ print(f"Trial {trial.number} finished. Total Return on validation: {total_return:.4f}")
147
+
148
+ except Exception as e:
149
+ print(f"Trial {trial.number} failed due to: {e}")
150
+ return float('nan') # Optuna handles NaN as a failure
151
+
152
+ finally:
153
+ train_env.close()
154
+ eval_env.close()
155
+
156
+ return total_return # Optuna aims to maximize this metric
157
+
158
+
159
+ # ==============================================================================
160
+ # 4. Run Optuna Study
161
+ # ==============================================================================
162
+
163
+ if __name__ == '__main__':
164
+ study = optuna.create_study(
165
+ direction='maximize',
166
+ sampler=optuna.samplers.TPESampler(seed=42)
167
+ )
168
+
169
+ n_trials_to_run = 50
170
+ study.optimize(objective, n_trials=n_trials_to_run, n_jobs=1) # n_jobs=1 is safer for Colab
171
+
172
+ print("\n--- Optimization finished. ---")
173
+ print("Best trial:")
174
+ trial = study.best_trial
175
+
176
+ print(f" Value: {trial.value:.4f}")
177
+ print(" Params: ")
178
+ for key, value in trial.params.items():
179
+ print(f" {key}: {value}")
180
+
181
+ # Save the best parameters to a file
182
+ best_params = trial.params
183
+ with open('checkpoints/best_sac_params.txt', 'w') as f:
184
+ f.write(str(best_params))
185
+ print(f"\n✅ Best parameters saved to checkpoints/best_sac_params.txt")
186
+
187
+ # Plotting results
188
+ try:
189
+ import plotly
190
+ from optuna.visualization import plot_optimization_history, plot_param_importances
191
+
192
+ fig1 = plot_optimization_history(study)
193
+ fig1.show()
194
+
195
+ fig2 = plot_param_importances(study)
196
+ fig2.show()
197
+ except ImportError:
198
+ print("\nInstall plotly and kaleido to visualize Optuna results: !pip install plotly kaleido")
scripts/visualize_strategy.py DELETED
@@ -1,123 +0,0 @@
1
- import argparse
2
- import os
3
- import pandas as pd
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- from matplotlib.ticker import FuncFormatter
7
- from stable_baselines3 import PPO, SAC, TD3
8
- from environment import PortfolioEnv
9
-
10
- def visualize_strategy(agent_name, checkpoint_path, datafile_path, output_path):
11
- """
12
- Loads a trained agent, runs a simulation, and plots its portfolio allocation strategy.
13
-
14
- Args:
15
- agent_name (str): The type of agent to load ('ppo', 'sac', 'td3').
16
- checkpoint_path (str): The path to the saved model checkpoint file (.zip).
17
- datafile_path (str): The path to the CSV market data for the simulation.
18
- output_path (str): The path to save the output plot image.
19
- """
20
- print(f"--- Visualizing strategy for {agent_name.upper()} agent ---")
21
-
22
- # 1. Define a mapping from agent names to their classes
23
- AGENT_CLASSES = {
24
- "ppo": PPO,
25
- "sac": SAC,
26
- "td3": TD3
27
- }
28
- agent_class = AGENT_CLASSES[agent_name.lower()]
29
-
30
- # 2. Load Data and Model
31
- try:
32
- test_df = pd.read_csv(datafile_path, index_col='Date', parse_dates=True)
33
- model = agent_class.load(checkpoint_path)
34
- except FileNotFoundError as e:
35
- print(f"❌ Error: Could not find a required file. {e}")
36
- return
37
- except Exception as e:
38
- print(f"❌ An error occurred: {e}")
39
- return
40
-
41
- # 3. Create Environment and Run Simulation
42
- env = PortfolioEnv(test_df)
43
- obs, info = env.reset()
44
- terminated, truncated = False, False
45
-
46
- weights_history = [info['weights']]
47
- while not (terminated or truncated):
48
- action, _states = model.predict(obs, deterministic=True)
49
- obs, reward, terminated, truncated, info = env.step(action)
50
- weights_history.append(info['weights'])
51
- print("✅ Simulation complete.")
52
-
53
- # 4. Prepare Data for Plotting
54
- weights_df = pd.DataFrame(weights_history)
55
- asset_names = test_df.columns.tolist() + ['Cash']
56
- weights_df.columns = asset_names
57
- weights_df.index = test_df.index[:len(weights_df)]
58
-
59
- # 5. Plotting the Stacked Area Chart
60
- print("📊 Generating plot...")
61
- plt.style.use('seaborn-v0_8-darkgrid')
62
- fig, ax = plt.subplots(figsize=(15, 8))
63
-
64
- ax.stackplot(weights_df.index, weights_df.T, labels=weights_df.columns, alpha=0.8)
65
-
66
- ax.set_title(f'Agent Portfolio Allocation Over Time ({agent_name.upper()})', fontsize=16)
67
- ax.set_xlabel('Date', fontsize=12)
68
- ax.set_ylabel('Portfolio Allocation (%)', fontsize=12)
69
- ax.legend(loc='upper left', fontsize=10)
70
-
71
- formatter = FuncFormatter(lambda y, p: f'{y:.0%}')
72
- ax.yaxis.set_major_formatter(formatter)
73
-
74
- plt.tight_layout()
75
-
76
- # Ensure output directory exists
77
- output_dir = os.path.dirname(output_path)
78
- if output_dir and not os.path.exists(output_dir):
79
- os.makedirs(output_dir)
80
-
81
- plt.savefig(output_path)
82
- print(f"✅ Plot saved to {output_path}")
83
- plt.show()
84
-
85
-
86
- if __name__ == "__main__":
87
- # Set up command-line argument parsing
88
- parser = argparse.ArgumentParser(description="Visualize a trained RL agent's portfolio allocation strategy.")
89
-
90
- parser.add_argument(
91
- "--agent",
92
- type=str,
93
- required=True,
94
- choices=["ppo", "sac", "td3"],
95
- help="The RL algorithm of the trained agent."
96
- )
97
- parser.add_argument(
98
- "--checkpoint",
99
- type=str,
100
- required=True,
101
- help="Path to the saved model checkpoint .zip file (e.g., 'td3_portfolio_model.zip')."
102
- )
103
- parser.add_argument(
104
- "--datafile",
105
- type=str,
106
- default="data/test.csv",
107
- help="Path to the market data CSV file to run the simulation on."
108
- )
109
- parser.add_argument(
110
- "--output",
111
- type=str,
112
- default="results/agent_allocation.png",
113
- help="Path to save the output plot image."
114
- )
115
-
116
- args = parser.parse_args()
117
-
118
- visualize_strategy(
119
- agent_name=args.agent,
120
- checkpoint_path=args.checkpoint,
121
- datafile_path=args.datafile,
122
- output_path=args.output
123
- )