import argparse
import logging
import sys

import ale_py  # imported so gymnasium can resolve the ALE/* Atari environments
import cv2
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from gymnasium.spaces import Box

from a2c_helpers import *

# Set up logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')
logger = logging.getLogger(__name__)


# Preprocess an environment observation
def preprocess(obs):
    # Convert to grayscale
    obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
    # Resize to 84x84, add a channel axis, and scale to [0, 1]
    obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
    return np.expand_dims(obs, axis=0).astype(np.float32) / 255.0


def df_ops(lst_df, seeds):
    # Add per-update Avg / High / Low columns computed across seeds
    for df in lst_df:
        seed_data = df[seeds]
        df['Avg'] = seed_data.mean(axis=1)
        df['High'] = seed_data.max(axis=1)
        df['Low'] = seed_data.min(axis=1)
    return lst_df


# Main loop
def main() -> int:
    # Initialize variables
    batches = 1000
    steps = 5
    clip_interval = 2
    seeds = [10, 20, 30, 40, 50]
    ep_per_batch = 5

    # Smaller settings for quick debugging runs:
    # batches = 5
    # steps = 5
    # clip_interval = 2
    # seeds = [10, 20]
    # ep_per_batch = 2

    # Arguments
    """
    usage examples:
    python3 a2c_main.py --method vanilla
    python3 a2c_main.py --method grad_clip
    python3 a2c_main.py --method rbs
    """
    parser = argparse.ArgumentParser(description='A2C Training')
    parser.add_argument('--method', type=str,
                        choices=['vanilla', 'reward_clip', 'rbs', 'grad_clip', 'obs_norm',
                                 'adv_norm', 'return_norm', 'reward_norm'],
                        default='vanilla', help='A2C update method')
    parser.add_argument('--env', type=str, default='ALE/Pacman-v5',
                        help='Gym environment name (e.g., ALE/Pacman-v5, ALE/SpaceInvaders-v5)')
    parser.add_argument('--render', action='store_true', help='Enable rendering')
    parser.add_argument('--clip_window', type=int, default=clip_interval,
                        help='Number of batches to collect rewards for clipping range update')
    args = parser.parse_args()

    # Set up environment
    if args.render:
        env = gym.make(args.env, render_mode="human")
    else:
        env = gym.make(args.env)

    logger.info(f"Observation space: {env.observation_space}")
    logger.info(f"Action space: {env.action_space}")
    logger.info(f"Method: {args.method}")

    # Initialize the CNN with a dummy observation to get the correct input shape
    obs, _ = env.reset()
    dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape)

    # Initialize A2C agent
    agent = Agent(obs_space=dummy_obs_space, action_space=env.action_space,
                  hidden=64, lr=0.00001, gamma=0.997, entropy_coef=0.01,
                  value_coef=0.5, seed=70, lam=0.95)

    # === Return-Based Scaling stats (for the RBS method) ===
    r_mean, r_var = 0.0, 1e-8
    g2_mean = 1.0
    agent.r_var = r_var
    agent.g2_mean = g2_mean

    all_reward_histories = pd.DataFrame(columns=seeds, index=range(1, batches + 1))
    all_loss_histories = pd.DataFrame(columns=seeds, index=range(1, batches + 1))
    all_policy_loss = pd.DataFrame(columns=seeds)
    all_value_loss = pd.DataFrame(columns=seeds)
    step = 0
    episode = 0
    total_return = 0  # initialized here so the summary in `finally` is safe if training fails early

    # Main update loop
    try:
        for seed in seeds:
            obs, info = env.reset(seed=seed)
            state = preprocess(obs)
            loss_history = []
            reward_history = []
            episode = 0
            total_return = 0

            """
            Update loop: Reward Clipping
            """
            if args.method == 'reward_clip':
                alpha = np.random.uniform(1, 2)
                logger.info(f"α sampled = {alpha:.3f} seed = {seed}")
                clip_low, clip_high = None, None
                ep_reward_history = []
                obs, info = env.reset()
                state = preprocess(obs)

                for update in range(1, batches + 1):
                    batch_episode_returns = []  # used for μ, σ
                    for _ in range(ep_per_batch):
                        ep_rewards = []
                        done = False
                        while not done:
                            action, value = agent.choose_action(state)
                            next_obs, reward, terminated, truncated, info = env.step(action)
                            done = terminated or truncated
                            next_state = preprocess(next_obs)
                            ep_rewards.append(reward)
                            agent.remember(state, action, reward, done, value, next_state)
                            state = next_state

                            if done:
                                ep_return = sum(ep_rewards)
                                # Clip the episode return once a clip range has been computed
                                if clip_low is not None:
                                    clipped_return = np.clip(ep_return, clip_low, clip_high)
                                else:
                                    clipped_return = ep_return
                                ep_reward_history.append(clipped_return)
                                batch_episode_returns.append(clipped_return)
                                episode += 1
                                total_return += clipped_return
                                logger.info(f"Episode {episode} return: {clipped_return:.2f}")
                                obs, info = env.reset()
                                state = preprocess(obs)

                    # === Compute clipping bounds as μ ± α·σ over this batch's returns ===
                    mu = np.mean(batch_episode_returns)
                    batch_std = np.std(batch_episode_returns)
                    sigma = batch_std + 1e-8 if batch_std != 0 else 1
                    clip_low = mu - alpha * sigma
                    clip_high = mu + alpha * sigma
                    logger.info(
                        f"[UPDATE {update}] New Reward Clip Range: "
                        f"[{clip_low:.4f}, {clip_high:.4f}]"
                    )

                    # === A2C UPDATE ===
                    avg_loss = agent.vanilla_a2c_update()
                    loss_history.append(avg_loss)
                    avg_ret = np.mean(batch_episode_returns)
                    reward_history.append(avg_ret)
                    logger.info(
                        f"Update {update}: batch_mean={avg_ret:.4f}, "
                        f"batch_std={np.std(batch_episode_returns):.4f}, "
                        f"episodes={episode}, avg_loss={avg_loss:.4f}"
                    )

            else:
                """
                Update loop: Other Normalization Methods
                """
                for update in range(1, batches + 1):
                    batch_episode_rewards = []
                    for _ in range(ep_per_batch):
                        ep_rewards = []
                        done = False
                        while not done:
                            action, value = agent.choose_action(state)
                            next_obs, reward, terminated, truncated, info = env.step(action)
                            done = terminated or truncated
                            next_state = preprocess(next_obs)
                            ep_rewards.append(reward)
                            agent.remember(state, action, reward, done, value, next_state)
                            state = next_state

                            if done:
                                ep_return = sum(ep_rewards)
                                episode += 1
                                total_return += ep_return
                                batch_episode_rewards.append(ep_return)
                                logger.info(f"Episode {episode} return: {ep_return:.2f}")
                                obs, info = env.reset()
                                state = preprocess(obs)

                    # Choose normalization method
                    if args.method == 'vanilla':
                        avg_loss = agent.vanilla_a2c_update()
                    elif args.method == 'grad_clip':
                        avg_loss = agent.update_gradient_clipping()
                    elif args.method == 'obs_norm':
                        avg_loss = agent.update_obs_norm()
                    elif args.method == 'adv_norm':
                        avg_loss = agent.update_adv_norm()
                    elif args.method == 'reward_norm':
                        avg_loss = agent.update_reward_norm()
                    else:  # rbs
                        avg_loss = agent.update_rbs()

                    loss_history.append(avg_loss)
                    avg_ret = (total_return / episode) if episode else 0
                    reward_history.append(avg_ret)
                    logger.info(
                        f"Update {update}: episodes={episode}, "
                        f"avg_return={avg_ret:.2f}, avg_loss={avg_loss:.4f}")

            all_reward_histories[seed] = reward_history
            all_loss_histories[seed] = loss_history
            all_value_loss[seed] = agent.value_loss_history[step:step + batches]
            all_policy_loss[seed] = agent.policy_loss_history[step:step + batches]
            step += batches

        [all_reward_histories, all_loss_histories, all_value_loss, all_policy_loss] = df_ops(
            [all_reward_histories, all_loss_histories, all_value_loss, all_policy_loss], seeds)
        # [all_reward_histories, all_loss_histories] = df_ops([all_reward_histories, all_loss_histories], seeds)

        # all_policy_loss.to_csv(args.method + '_policy_loss.csv')
        # all_value_loss.to_csv(args.method + '_value_loss.csv')
        # all_reward_histories.to_csv(args.method + '_a2c_reward_history.csv')
        # all_loss_histories.to_csv(args.method + '_a2c_loss_history.csv')

        fig = plt.figure(figsize=(15, 10))

        # --- Subplot 1: Average A2C Loss ---
        ax2 = plt.subplot(221)
        # Plot the shaded High-Low Range
        ax2.fill_between(
            all_loss_histories.index,
            all_loss_histories['Low'],
            all_loss_histories['High'],
            color='#A8DADC',  # Light blue for aesthetic shading
            alpha=0.5,
            label="High-Low Range"
        )
        # Plot the Average Line
        ax2.plot(all_loss_histories['Avg'], label="Avg Loss", color='#1D3557', linewidth=2)
        ax2.set_ylabel("Average A2C Loss")
        ax2.set_xlabel("A2C Update")
        ax2.legend()

        # --- Subplot 2: Reward ---
        ax3 = plt.subplot(222)
        # Plot the shaded High-Low Range
        ax3.fill_between(
            all_reward_histories.index,
            all_reward_histories['Low'],
            all_reward_histories['High'],
            color='#FEDCC8',  # Light orange/peach
            alpha=0.5,
            label="High-Low Range"
        )
        # Plot the Average Line
        ax3.plot(all_reward_histories['Avg'], label="Avg Reward", color='#E63946', linewidth=2)
        ax3.set_ylabel("Average Reward")
        ax3.set_xlabel("A2C Update")
        ax3.legend()

        # --- Subplot 3: Policy Loss ---
        ax4 = plt.subplot(223)
        # Plot the shaded High-Low Range
        ax4.fill_between(
            all_policy_loss.index,
            all_policy_loss['Low'],
            all_policy_loss['High'],
            color='#B0E0A0',  # Light green
            alpha=0.5,
            label="High-Low Range"
        )
        # Plot the Average Line
        ax4.plot(all_policy_loss['Avg'], label="Policy Loss", color='#38B000', linewidth=2)
        ax4.set_ylabel("Average Policy Loss")
        ax4.set_xlabel("A2C Update")
        ax4.legend()

        # --- Subplot 4: Value Loss ---
        ax5 = plt.subplot(224)
        # Plot the shaded High-Low Range
        ax5.fill_between(
            all_value_loss.index,
            all_value_loss['Low'],
            all_value_loss['High'],
            color='#D7BDE2',  # Light purple
            alpha=0.5,
            label="High-Low Range"
        )
        # Plot the Average Line
        ax5.plot(all_value_loss['Avg'], label="Value Loss", color='#8E44AD', linewidth=2)
        ax5.set_ylabel("Average Value Loss")
        ax5.set_xlabel("A2C Update")
        ax5.legend()

        # --- Figure Settings ---
        fig.suptitle(f"A2C Training Stability - {args.method}", fontsize=16, fontweight='bold')
        # fig.tight_layout()  # Adjust layout to make room for suptitle
        plt.show()

    except Exception as e:
        logger.error(f"Error: {e}", exc_info=True)
        return 1
    finally:
        avg = total_return / episode if episode else 0
        logger.info(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
        env.close()

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
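
# ---------------------------------------------------------------------------
# Reference sketch (assumption): the `Agent` class comes from a2c_helpers and
# is not defined in this file. The outline below is only a non-authoritative
# summary of the interface this script relies on, inferred from the calls
# above; the real implementation lives in a2c_helpers.py.
#
#   class Agent:
#       def __init__(self, obs_space, action_space, hidden, lr, gamma,
#                    entropy_coef, value_coef, seed, lam): ...
#       # per-step rollout collection
#       def choose_action(self, state): ...      # -> (action, value_estimate)
#       def remember(self, state, action, reward, done, value, next_state): ...
#       # one call per batch; each returns the average batch loss (float)
#       def vanilla_a2c_update(self): ...
#       def update_gradient_clipping(self): ...
#       def update_obs_norm(self): ...
#       def update_adv_norm(self): ...
#       def update_reward_norm(self): ...
#       def update_rbs(self): ...                # reads agent.r_var, agent.g2_mean
#       # per-update diagnostics read back for plotting
#       value_loss_history: list
#       policy_loss_history: list
# ---------------------------------------------------------------------------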