DanielKiani's picture
Initial commit
7d2e753
import argparse
import pandas as pd
from stable_baselines3 import PPO, SAC, TD3
from environment import PortfolioEnv
def train_agent(agent_name="td3", timesteps=100000):
"""
Main function to train a specified RL agent.
Args:
agent_name (str): The RL algorithm to use ('ppo', 'sac', or 'td3').
timesteps (int): The total number of timesteps for training.
"""
# 1. Map agent names to their corresponding classes
AGENT_CLASSES = {
"ppo": PPO,
"sac": SAC,
"td3": TD3
}
agent_class = AGENT_CLASSES.get(agent_name.lower())
if agent_class is None:
raise ValueError(f"Unknown agent: {agent_name}. Choose from {list(AGENT_CLASSES.keys())}")
model_name = agent_name.lower()
# 2. Load data and create the environment
print("--- Loading Data and Creating Environment ---")
try:
df = pd.read_csv('data/train.csv', index_col='Date', parse_dates=True)
env = PortfolioEnv(df)
print("Environment created successfully.")
except FileNotFoundError:
print("❌ Error: 'data/train.csv' not found. Make sure to run a data fetching script first.")
return
# 3. Create the RL Agent
print(f"--- Creating {agent_name.upper()} Agent ---")
model = agent_class(
"MlpPolicy",
env,
verbose=1,
tensorboard_log="./tensorboard_logs/"
)
# 4. Train the Agent
print(f"--- Starting Agent Training for {timesteps} timesteps ---")
model.learn(total_timesteps=timesteps)
print("--- Agent Training Complete ---")
# 5. Save the Trained Model
save_path = f"checkpoints/{model_name}_portfolio_model"
model.save(save_path)
print(f"βœ… Model saved to checkpoints/{save_path}.zip")
if __name__ == "__main__":
# 6. Set up command-line argument parsing
parser = argparse.ArgumentParser(description="Train a Reinforcement Learning agent for portfolio management.")
parser.add_argument(
"--agent",
type=str,
default="td3",
choices=["ppo", "sac", "td3"],
help="The RL algorithm to use for training (default: td3)."
)
parser.add_argument(
"--timesteps",
type=int,
default=100000,
help="The total number of timesteps for training (default: 100000)."
)
args = parser.parse_args()
# Call the main training function with the parsed arguments
train_agent(agent_name=args.agent, timesteps=args.timesteps)