File size: 2,494 Bytes
7d2e753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import argparse
import os

import pandas as pd
from stable_baselines3 import PPO, SAC, TD3

from environment import PortfolioEnv

def train_agent(
    agent_name: str = "td3",
    timesteps: int = 100000,
    *,
    data_path: str = "data/train.csv",
    checkpoint_dir: str = "checkpoints",
    tensorboard_log: str = "./tensorboard_logs/",
) -> None:
    """
    Train a specified RL agent on the portfolio environment and save it.

    Reads the training data from ``data_path`` with ``Date`` as the parsed
    index, wraps it in ``PortfolioEnv``, trains the requested Stable-Baselines3
    algorithm with an ``"MlpPolicy"`` and saves the model under
    ``checkpoint_dir`` as ``<agent>_portfolio_model``.

    Args:
        agent_name (str): The RL algorithm to use ('ppo', 'sac', or 'td3');
            matched case-insensitively.
        timesteps (int): The total number of timesteps for training; must be
            positive.
        data_path (str): CSV file with the training data.
        checkpoint_dir (str): Directory the trained model is saved into.
        tensorboard_log (str): Directory passed to the agent for TensorBoard logs.

    Raises:
        ValueError: If ``agent_name`` is not a supported algorithm or
            ``timesteps`` is not positive.

    If ``data_path`` does not exist, an error is printed and the function
    returns without training.
    """
    # 1. Map agent names to their corresponding classes
    agent_classes = {"ppo": PPO, "sac": SAC, "td3": TD3}
    model_name = agent_name.lower()
    agent_class = agent_classes.get(model_name)
    if agent_class is None:
        raise ValueError(f"Unknown agent: {agent_name}. Choose from {list(agent_classes.keys())}")
    # Fail fast instead of silently saving an untrained model.
    if timesteps <= 0:
        raise ValueError(f"timesteps must be a positive integer, got {timesteps}")

    # 2. Load data and create the environment
    print("--- Loading Data and Creating Environment ---")
    try:
        df = pd.read_csv(data_path, index_col="Date", parse_dates=True)
    except FileNotFoundError:
        print(f"❌ Error: '{data_path}' not found. Make sure to run a data fetching script first.")
        return
    # Built outside the try so an unrelated FileNotFoundError raised while
    # constructing the environment is not misreported as missing training data.
    env = PortfolioEnv(df)
    print("Environment created successfully.")

    # 3. Create the RL Agent
    print(f"--- Creating {agent_name.upper()} Agent ---")
    model = agent_class(
        "MlpPolicy",
        env,
        verbose=1,
        tensorboard_log=tensorboard_log,
    )

    # 4. Train the Agent
    print(f"--- Starting Agent Training for {timesteps} timesteps ---")
    model.learn(total_timesteps=timesteps)
    print("--- Agent Training Complete ---")

    # 5. Save the Trained Model
    save_path = os.path.join(checkpoint_dir, f"{model_name}_portfolio_model")
    model.save(save_path)
    # save_path already contains the directory; SB3 writes the archive with a
    # ".zip" suffix, so report that file name rather than re-prefixing the dir.
    print(f"✅ Model saved to {save_path}.zip")


if __name__ == "__main__":
    # 6. Command-line interface: pick the algorithm and the training budget,
    #    then hand both to train_agent.
    cli = argparse.ArgumentParser(
        description="Train a Reinforcement Learning agent for portfolio management."
    )
    cli.add_argument(
        "--agent", type=str, default="td3", choices=["ppo", "sac", "td3"],
        help="The RL algorithm to use for training (default: td3).",
    )
    cli.add_argument(
        "--timesteps", type=int, default=100000,
        help="The total number of timesteps for training (default: 100000).",
    )

    options = cli.parse_args()
    train_agent(agent_name=options.agent, timesteps=options.timesteps)