|
|
import argparse |
|
|
import pandas as pd |
|
|
from stable_baselines3 import PPO, SAC, TD3 |
|
|
from environment import PortfolioEnv |
|
|
|
|
|
def train_agent(agent_name: str = "td3", timesteps: int = 100000) -> None:
    """Train the selected RL agent on the portfolio environment and save it.

    Loads ``data/train.csv``, wraps it in a ``PortfolioEnv``, trains the
    chosen algorithm with an ``MlpPolicy`` and saves the trained model
    under ``checkpoints/``.

    Args:
        agent_name (str): The RL algorithm to use ('ppo', 'sac', or 'td3').
            The name is lower-cased before lookup, so matching is
            case-insensitive.
        timesteps (int): The total number of timesteps for training.

    Returns:
        None. If ``data/train.csv`` does not exist, an error message is
        printed and the function returns without training anything.

    Raises:
        ValueError: If ``agent_name`` is not one of the supported algorithms.
    """
    AGENT_CLASSES = {
        "ppo": PPO,
        "sac": SAC,
        "td3": TD3,
    }

    # Normalise once; the same lower-cased name is used for the class
    # lookup and for the checkpoint file name.
    model_name = agent_name.lower()
    agent_class = AGENT_CLASSES.get(model_name)
    if agent_class is None:
        raise ValueError(f"Unknown agent: {agent_name}. Choose from {list(AGENT_CLASSES.keys())}")

    print("--- Loading Data and Creating Environment ---")
    # Only the CSV read is guarded: a FileNotFoundError raised while building
    # the environment would otherwise be misreported as a missing train.csv.
    try:
        df = pd.read_csv('data/train.csv', index_col='Date', parse_dates=True)
    except FileNotFoundError:
        print("❌ Error: 'data/train.csv' not found. Make sure to run a data fetching script first.")
        return
    env = PortfolioEnv(df)
    print("Environment created successfully.")

    print(f"--- Creating {agent_name.upper()} Agent ---")
    model = agent_class(
        "MlpPolicy",
        env,
        verbose=1,
        tensorboard_log="./tensorboard_logs/",
    )

    print(f"--- Starting Agent Training for {timesteps} timesteps ---")
    model.learn(total_timesteps=timesteps)
    print("--- Agent Training Complete ---")

    save_path = f"checkpoints/{model_name}_portfolio_model"
    model.save(save_path)
    # save_path already contains the "checkpoints/" directory, so it is
    # reported as-is; the ".zip" suffix reflects the archive extension that
    # Stable-Baselines3 adds when the path has none.
    print(f"✅ Model saved to {save_path}.zip")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read the training options from the command line
    # and pass them straight to train_agent().
    cli = argparse.ArgumentParser(
        description="Train a Reinforcement Learning agent for portfolio management."
    )

    # Each entry is (flag, keyword arguments forwarded to add_argument).
    option_specs = (
        (
            "--agent",
            {
                "type": str,
                "default": "td3",
                "choices": ["ppo", "sac", "td3"],
                "help": "The RL algorithm to use for training (default: td3).",
            },
        ),
        (
            "--timesteps",
            {
                "type": int,
                "default": 100000,
                "help": "The total number of timesteps for training (default: 100000).",
            },
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    options = cli.parse_args()

    train_agent(agent_name=options.agent, timesteps=options.timesteps)