# scripts/tune_sac.py
"""Tune SAC hyperparameters for ``PortfolioEnv`` with Optuna.

The script splits the training CSV chronologically by row position (first
80% for fitting, last 20% for validation), trains one SAC model per Optuna
trial, and scores each trial by the total return its best checkpoint achieves
on the validation slice.
"""

import os
import sys

import pandas as pd
import numpy as np
import optuna
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import DummyVecEnv  # Use DummyVecEnv
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.logger import configure

from environment import PortfolioEnv

# ==============================================================================
# 1. Configuration & Data Loading
# ==============================================================================
TRAIN_DATA_PATH = 'data/train.csv'
EVAL_DATA_PATH = 'data/eval.csv'
OPTUNA_LOG_DIR = 'optuna_logs'
CHECKPOINT_DIR = 'checkpoints/optuna_sac_trials'

# Create directories if they don't exist
os.makedirs(OPTUNA_LOG_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Load data once so every trial shares the same frames.
df_full_train = pd.read_csv(TRAIN_DATA_PATH, index_col='Date', parse_dates=True)
# NOTE(review): df_eval is loaded but not used anywhere in this script.
df_eval = pd.read_csv(EVAL_DATA_PATH, index_col='Date', parse_dates=True)

# Split df_full_train for tuning: first 80% of rows to fit, last 20% to validate.
train_split_point = int(len(df_full_train) * 0.8)
df_train_tune = df_full_train.iloc[:train_split_point]
df_validation_tune = df_full_train.iloc[train_split_point:]

print(f"Total training data points: {len(df_full_train)}")
print(f"Optuna training data points: {len(df_train_tune)}")
print(f"Optuna validation data points: {len(df_validation_tune)}")


# ==============================================================================
# 2. Environment Creation Helper
# ==============================================================================
def make_env(df, window_size=30, initial_balance=10000, transaction_cost_pct=0.001):
    """Return a zero-argument factory that builds a ``PortfolioEnv``.

    ``DummyVecEnv`` expects a list of callables, so the environment is
    constructed lazily inside the returned closure.

    Args:
        df: Data frame handed to ``PortfolioEnv`` as ``df``.
        window_size: Forwarded to ``PortfolioEnv`` unchanged.
        initial_balance: Forwarded to ``PortfolioEnv`` unchanged.
        transaction_cost_pct: Forwarded to ``PortfolioEnv`` unchanged.

    Returns:
        A callable taking no arguments that returns a new ``PortfolioEnv``.
    """
    def _init():
        return PortfolioEnv(
            df=df,
            initial_balance=initial_balance,
            window_size=window_size,
            transaction_cost_pct=transaction_cost_pct,
        )

    return _init


# ==============================================================================
# 3. Optuna Objective Function
# ==============================================================================
def objective(trial: optuna.Trial) -> float:
    """Train one SAC model with sampled hyperparameters and score it.

    The model is trained on ``df_train_tune`` while ``EvalCallback``
    periodically evaluates it on ``df_validation_tune`` and saves the best
    checkpoint. That checkpoint (or the final model if none was saved) is then
    rolled out once, deterministically, on the validation environment.

    Args:
        trial: The Optuna trial used to sample hyperparameters.

    Returns:
        Total return on the validation slice, computed as
        ``final_portfolio_value / initial_balance - 1``. Returns NaN if
        anything raised during training or evaluation, which Optuna records
        as a failed trial.
    """
    # Hyperparameter search space
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True)
    gamma = trial.suggest_float('gamma', 0.9, 0.999)
    tau = trial.suggest_float('tau', 0.005, 0.02)
    # NOTE(review): each trial runs 50,000 steps on a single env, so a buffer
    # of >= 50,000 transitions should never wrap; this parameter probably has
    # no effect on results at the current budget - confirm and adjust range.
    buffer_size = trial.suggest_int('buffer_size', 50000, 1000000, log=True)
    batch_size = trial.suggest_categorical('batch_size', [64, 128, 256, 512])
    ent_coef = trial.suggest_float('ent_coef', 0.001, 0.1, log=True)  # Use log scale for ent_coef

    # Network architecture: 1-3 hidden layers, each sized independently.
    n_layers = trial.suggest_int('n_layers', 1, 3)
    net_arch = [
        trial.suggest_categorical(f'layer_size_{i}', [64, 128, 256])
        for i in range(n_layers)
    ]
    policy_kwargs = dict(net_arch=net_arch)  # SAC uses shared network or separate [pi, qf]

    # Create environments for this trial
    train_env = DummyVecEnv([make_env(df_train_tune)])
    eval_env = DummyVecEnv([make_env(df_validation_tune)])

    # Set up logger for the trial
    trial_log_path = os.path.join(OPTUNA_LOG_DIR, f"trial_{trial.number}")
    new_logger = configure(trial_log_path, ["stdout", "csv", "tensorboard"])

    # Create SAC model
    model = SAC(
        "MlpPolicy",
        train_env,
        learning_rate=learning_rate,
        gamma=gamma,
        tau=tau,
        buffer_size=buffer_size,
        batch_size=batch_size,
        ent_coef=ent_coef,  # Pass the sampled value
        policy_kwargs=policy_kwargs,
        verbose=0,
        seed=42,  # Use a fixed seed for reproducibility within a trial
        tensorboard_log=OPTUNA_LOG_DIR,
    )
    model.set_logger(new_logger)

    # Callback for evaluation; it writes best_model.zip under trial_checkpoint_dir.
    trial_checkpoint_dir = os.path.join(CHECKPOINT_DIR, f"best_sac_trial_{trial.number}")
    eval_callback = EvalCallback(
        eval_env,
        best_model_save_path=trial_checkpoint_dir,
        log_path=trial_log_path,
        eval_freq=5000,
        deterministic=True,
        render=False,
        n_eval_episodes=1,
    )

    # Train for a set number of steps per trial
    total_timesteps_per_trial = 50000

    try:
        model.learn(
            total_timesteps=total_timesteps_per_trial,
            callback=eval_callback,
            progress_bar=False,
        )

        # Load the best model found during this trial's training
        best_model_path = os.path.join(trial_checkpoint_dir, "best_model.zip")
        if os.path.exists(best_model_path):
            model = SAC.load(best_model_path, env=eval_env)
        else:
            print(f"Warning: No best model saved for trial {trial.number}, using last model.")

        # --- Final evaluation on the validation set ---
        obs = eval_env.reset()
        initial_portfolio_value = eval_env.envs[0].initial_balance
        final_portfolio_value = initial_portfolio_value
        episode_over = False
        while not episode_over:
            action, _ = model.predict(obs, deterministic=True)
            obs, _rewards, dones, infos = eval_env.step(action)
            # A VecEnv returns one entry per sub-environment; there is exactly
            # one here, so index it explicitly rather than relying on the
            # truthiness of a length-1 array.
            final_portfolio_value = infos[0]['portfolio_value']
            episode_over = bool(dones[0])

        total_return = (final_portfolio_value / initial_portfolio_value) - 1

        print(f"Trial {trial.number} finished. Total Return on validation: {total_return:.4f}")

    except Exception as e:
        # Deliberate best-effort: one bad hyperparameter combination should
        # not abort the whole study.
        print(f"Trial {trial.number} failed due to: {e}")
        return float('nan')  # Optuna handles NaN as a failure
    finally:
        train_env.close()
        eval_env.close()

    return total_return  # Optuna aims to maximize this metric


# ==============================================================================
# 4. Run Optuna Study
# ==============================================================================
def main() -> None:
    """Run the Optuna study, report the best trial, and save its parameters."""
    study = optuna.create_study(
        direction='maximize',
        sampler=optuna.samplers.TPESampler(seed=42),
    )

    n_trials_to_run = 50
    study.optimize(objective, n_trials=n_trials_to_run, n_jobs=1)  # n_jobs=1 is safer for Colab

    print("\n--- Optimization finished. ---")

    # study.best_trial raises ValueError when no trial completed successfully
    # (e.g. every trial returned NaN); report that instead of crashing.
    try:
        trial = study.best_trial
    except ValueError:
        print("No trial completed successfully; nothing to report or save.")
        return

    print("Best trial:")
    print(f" Value: {trial.value:.4f}")
    print(" Params: ")
    for key, value in trial.params.items():
        print(f" {key}: {value}")

    # Save the best parameters to a file
    best_params = trial.params
    with open('checkpoints/best_sac_params.txt', 'w') as f:
        f.write(str(best_params))
    print("\n✅ Best parameters saved to checkpoints/best_sac_params.txt")

    # Plotting results
    try:
        import plotly  # noqa: F401 - imported only to detect a missing install
        from optuna.visualization import plot_optimization_history, plot_param_importances

        fig1 = plot_optimization_history(study)
        fig1.show()

        fig2 = plot_param_importances(study)
        fig2.show()
    except ImportError:
        print("\nInstall plotly and kaleido to visualize Optuna results: !pip install plotly kaleido")


if __name__ == '__main__':
    main()