Spaces:
Sleeping
Sleeping
| import argparse | |
| import gymnasium as gym | |
| import sys | |
| import matplotlib.pyplot as plt | |
| import ale_py | |
| import pandas as pd | |
| from ppo_helpers_cnn import * | |
| from gymnasium.spaces import Box | |
| import cv2 | |
| import logging | |
| import numpy as np | |
| # Set up logging | |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') | |
| logger = logging.getLogger(__name__) | |
| # Preprocess environment | |
| def preprocess(obs): | |
| # Convert to grayscale | |
| obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY) | |
| # Resize | |
| obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA) | |
| return np.expand_dims(obs, axis=0).astype(np.float32) / 255.0 | |
| import pandas as pd | |
| import numpy as np | |
| def df_ops(lst_df, seeds): | |
| for df in lst_df: | |
| seed_data = df[seeds] | |
| df['Avg'] = seed_data.mean(axis=1) | |
| df['High'] = seed_data.max(axis=1) | |
| df['Low'] = seed_data.min(axis=1) | |
| return lst_df | |
| # Main loop | |
| def main() -> int: | |
| # Initialize variables | |
| """ | |
| batches = 5 | |
| steps = 5 | |
| clip_interval = 2 | |
| seeds = [10, 20] | |
| ep_per_batch = 2 | |
| """ | |
| batches = 1000 | |
| steps = 5 | |
| clip_interval = 2 | |
| seeds = [10, 20, 30, 40, 50] | |
| ep_per_batch = 5 | |
| # Arguments - 'vanilla', 'reward_clip', 'rbs', 'grad_clip', 'obs_norm', 'adv_norm', 'return_norm', 'reward_norm' | |
| """ | |
| 'vanilla', 'reward_clip', 'rbs', 'grad_clip', 'obs_norm', 'adv_norm', 'return_norm', 'reward_norm' | |
| python Poster/ppo_main.py --method vanilla --env ALE/Pacman-v5 | |
| usage examples: | |
| python3 ppo_main.py --method vanilla | |
| python3 ppo_main.py --method grad_clip | |
| python3 ppo_main.py --method rbs | |
| """ | |
| parser = argparse.ArgumentParser(description='PPO Training') | |
| parser.add_argument('--method', type=str, choices=['vanilla', 'reward_clip', 'rbs', 'grad_clip', | |
| 'obs_norm', 'adv_norm', 'return_norm', 'reward_norm'], | |
| default='vanilla', help='PPO update method') | |
| parser.add_argument('--env', type=str, default='ALE/Pacman-v5', | |
| help='Gym environment name (e.g., ALE/Pacman-v5, ALE/SpaceInvaders-v5, ALE/BattleZone-v5)') | |
| parser.add_argument('--render', action='store_true', help='Enable rendering') | |
| parser.add_argument('--clip_window', type=int, default=clip_interval, | |
| help='Number of batches to collect rewards for clipping range update') | |
| args = parser.parse_args() | |
| # Set up environment | |
| if args.render: | |
| env = gym.make(args.env, render_mode='human') | |
| else: | |
| env = gym.make(args.env) | |
| logger.info(f"Observation space: {env.observation_space}") | |
| logger.info(f"Action space: {env.action_space}") | |
| logger.info(f'Method: {args.method}') | |
| # Initialize CNN with a dummy observation to get correct input shape | |
| obs, _ = env.reset() | |
| dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape) | |
| # Initialize PPO agent | |
| agent = Agent(obs_space=dummy_obs_space, action_space=env.action_space, | |
| hidden=64, lr=0.00001, gamma=0.997, clip_coef=0.2, | |
| entropy_coef=0.01, value_coef=0.5, seed=70, | |
| batch_size=64, ppo_epochs=32, lam=0.95) | |
| # === Return-Based Scaling stats (for RBS method) === | |
| r_mean, r_var = 0.0, 1e-8 | |
| g2_mean = 1.0 | |
| agent.r_var = r_var | |
| agent.g2_mean = g2_mean | |
| # Initialize data structure outside the loop | |
| all_reward_histories = pd.DataFrame(columns=[i for i in seeds], index=[i for i in range(1, batches + 1)]) | |
| all_loss_histories = pd.DataFrame(columns=[i for i in seeds], index=[i for i in range(1, batches + 1)]) | |
| all_policy_loss = pd.DataFrame(columns=[i for i in seeds]) | |
| all_value_loss = pd.DataFrame(columns=[i for i in seeds]) | |
| # Main update loop | |
| try: | |
| for seed in seeds: | |
| obs, info = env.reset(seed=seed) | |
| state = preprocess(obs) | |
| loss_history = [] | |
| reward_history = [] | |
| policy_loss_history = [] | |
| value_loss_history = [] | |
| episode = 0 | |
| total_return = 0 | |
| steps = [0] | |
| """ Update loop: Gradient, Reward Normalization """ | |
| if args.method == 'reward_clip': | |
| alpha = np.random.uniform(1, 2) | |
| logger.info(f"α sampled = {alpha:.3f} seed = {seed}") | |
| clip_low, clip_high = None, None | |
| ep_reward_history = [] | |
| obs, info = env.reset() | |
| state = preprocess(obs) | |
| for update in range(1, batches + 1): | |
| batch_episode_returns = [] # used for μ, σ | |
| for _ in range(ep_per_batch): | |
| ep_rewards = [] | |
| done = False | |
| while not done: | |
| action, logp, value = agent.choose_action(state) | |
| next_obs, reward, terminated, truncated, info = env.step(action) | |
| done = terminated or truncated | |
| next_state = preprocess(next_obs) | |
| ep_rewards.append(reward) | |
| agent.remember(state, action, reward, done, logp, value, next_state) | |
| state = next_state | |
| if done: | |
| ep_return = sum(ep_rewards) | |
| if clip_low is not None: | |
| clipped_return = np.clip(ep_return, clip_low, clip_high) | |
| else: | |
| clipped_return = ep_return | |
| ep_reward_history.append(clipped_return) | |
| batch_episode_returns.append(clipped_return) | |
| episode += 1 | |
| total_return += clipped_return | |
| logger.info(f"Episode {episode} return: {clipped_return:.2f}") | |
| obs, info = env.reset() | |
| state = preprocess(obs) | |
| # === Compute clipping bounds using Code 1 logic === | |
| mu = np.mean(batch_episode_returns) | |
| sigma = np.std(batch_episode_returns) + 1e-8 if np.std(batch_episode_returns) != 0 else 1 | |
| clip_low = mu - alpha * sigma | |
| clip_high = mu + alpha * sigma | |
| logger.info( | |
| f"[UPDATE {update}] New Reward Clip Range: " | |
| f"[{clip_low:.4f}, {clip_high:.4f}]" | |
| ) | |
| # === PPO UPDATE === | |
| avg_loss = agent.vanilla_ppo_update() | |
| loss_history.append(avg_loss) | |
| avg_ret = np.mean(batch_episode_returns) | |
| reward_history.append(avg_ret) | |
| logger.info( | |
| f"Update {update}: batch_mean={avg_ret:.4f}, " | |
| f"batch_std={np.std(batch_episode_returns):.4f}, " | |
| f"episodes={episode}, avg_loss={avg_loss:.4f}" | |
| ) | |
| current_steps = len(agent.value_loss_history) | |
| steps.append(current_steps - 1 - steps[-1]) | |
| x = len(steps) - 1 | |
| value_loss_history.append( | |
| sum(agent.value_loss_history[steps[x - 1]:steps[x]]) / (steps[x - 1] - steps[x])) | |
| policy_loss_history.append(sum(agent.policy_loss_history[x - 1:x]) / (steps[x - 1] - steps[x])) | |
| """ Update loop: Other Normalization Methods """ | |
| else: | |
| for update in range(1, batches + 1): | |
| batch_episode_rewards = [] | |
| ep_per_batch = 5 | |
| for _ in range(ep_per_batch): | |
| ep_rewards = [] | |
| done = False | |
| while not done: | |
| action, logp, value = agent.choose_action(state) | |
| next_obs, reward, terminated, truncated, info = env.step(action) | |
| done = terminated or truncated | |
| next_state = preprocess(next_obs) | |
| ep_rewards.append(reward) # Add this line to collect rewards | |
| agent.remember(state, action, reward, done, logp, value, next_state) | |
| state = next_state | |
| if done: | |
| ep_return = sum(ep_rewards) | |
| episode += 1 | |
| total_return += ep_return | |
| batch_episode_rewards.append(ep_return) | |
| logger.info(f"Episode {episode} return: {ep_return:.2f}") | |
| obs, info = env.reset() | |
| state = preprocess(obs) | |
| # Choose normalization method | |
| if args.method == 'vanilla': | |
| avg_loss = agent.vanilla_ppo_update() | |
| elif args.method == 'grad_clip': | |
| avg_loss = agent.update_gradient_clipping() | |
| elif args.method == 'obs_norm': | |
| avg_loss = agent.update_obs_norm() | |
| elif args.method == 'return_norm': | |
| avg_loss = agent.update_return_norm() | |
| elif args.method == 'reward_norm': | |
| avg_loss = agent.update_reward_norm() | |
| else: # rbs | |
| avg_loss = agent.update_rbs() | |
| loss_history.append(avg_loss) | |
| avg_ret = (total_return / episode) if episode else 0 | |
| reward_history.append(avg_ret) | |
| logger.info( | |
| f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}, avg_loss={avg_loss:.4f}") | |
| current_steps = len(agent.value_loss_history) | |
| steps.append(current_steps-1 - steps[-1]) | |
| x = len(steps)-1 | |
| value_loss_history.append(sum(agent.value_loss_history[steps[x-1]:steps[x]]) / (steps[x-1] - steps[x])) | |
| policy_loss_history.append(sum(agent.policy_loss_history[x - 1:x]) / (steps[x - 1] - steps[x])) | |
| all_reward_histories[seed] = reward_history | |
| all_loss_histories[seed] = loss_history | |
| # print(agent.value_loss_history) | |
| all_value_loss[seed] = value_loss_history[1:] | |
| # print(len(agent.value_loss_history)) | |
| # print(agent.policy_loss_history) | |
| all_policy_loss[seed] = policy_loss_history[1:] | |
| # print(len(agent.policy_loss_history)) | |
| [all_reward_histories, all_loss_histories, all_value_loss, all_policy_loss] = df_ops([all_reward_histories, | |
| all_loss_histories, | |
| all_value_loss, | |
| all_policy_loss], seeds) | |
| # [all_reward_histories, all_loss_histories] = df_ops([all_reward_histories, | |
| # all_loss_histories], seeds) | |
| all_policy_loss.to_csv(args.method + '_policy_loss.csv') | |
| all_reward_histories.to_csv(args.method + '_reward_history.csv') | |
| all_loss_histories.to_csv(args.method + '_loss_history.csv') | |
| all_value_loss.to_csv(args.method + '_value_loss.csv') | |
| fig = plt.figure(figsize=(15, 10)) | |
| # --- Subplot 1: Average PPO Loss --- | |
| ax2 = plt.subplot(221) | |
| # Plot the shaded High-Low Range | |
| ax2.fill_between( | |
| all_loss_histories.index, | |
| all_loss_histories['Low'], | |
| all_loss_histories['High'], | |
| color='#A8DADC', # Light blue for aesthetic shading | |
| alpha=0.5, | |
| label="High-Low Range" | |
| ) | |
| # Plot the Average Line | |
| ax2.plot(all_loss_histories['Avg'], label="Avg Loss", color='#1D3557', linewidth=2) | |
| ax2.set_ylabel("Average PPO Loss") | |
| ax2.set_xlabel("PPO Update") | |
| ax2.legend() | |
| # --- Subplot 2: Reward --- | |
| ax3 = plt.subplot(222) | |
| # Plot the shaded High-Low Range | |
| ax3.fill_between( | |
| all_reward_histories.index, | |
| all_reward_histories['Low'], | |
| all_reward_histories['High'], | |
| color='#FEDCC8', # Light orange/peach | |
| alpha=0.5, | |
| label="High-Low Range" | |
| ) | |
| # Plot the Average Line | |
| ax3.plot(all_reward_histories['Avg'], label="Avg Reward", color='#E63946', linewidth=2) | |
| ax3.set_ylabel("Average Reward") | |
| ax3.set_xlabel("PPO Update") | |
| ax3.legend() | |
| # --- Subplot 3: Policy Loss --- | |
| ax4 = plt.subplot(223) | |
| # Plot the shaded High-Low Range | |
| ax4.fill_between( | |
| all_policy_loss.index, | |
| all_policy_loss['Low'], | |
| all_policy_loss['High'], | |
| color='#B0E0A0', # Light green | |
| alpha=0.5, | |
| label="High-Low Range" | |
| ) | |
| # Plot the Average Line | |
| ax4.plot(all_policy_loss['Avg'], label="Policy Loss", color='#38B000', linewidth=2) | |
| ax4.set_ylabel("Average Policy Loss") | |
| ax4.set_xlabel("PPO Update") | |
| ax4.legend() | |
| # --- Subplot 4: Value Loss --- | |
| ax5 = plt.subplot(224) | |
| # Plot the shaded High-Low Range | |
| ax5.fill_between( | |
| all_value_loss.index, | |
| all_value_loss['Low'], | |
| all_value_loss['High'], | |
| color='#D7BDE2', # Light purple | |
| alpha=0.5, | |
| label="High-Low Range" | |
| ) | |
| # Plot the Average Line | |
| ax5.plot(all_value_loss['Avg'], label="Value Loss", color='#8E44AD', linewidth=2) | |
| ax5.set_ylabel("Average Value Loss") | |
| ax5.set_xlabel("PPO Update") | |
| ax5.legend() | |
| # --- Figure Settings --- | |
| fig.suptitle(f"PPO Training Stability - {args.method}", fontsize=16, fontweight='bold') | |
| # fig.tight_layout() # Adjust layout to make room for suptitle | |
| plt.show() | |
| except Exception as e: | |
| logger.error(f"Error: {e}", exc_info=True) | |
| return 1 | |
| finally: | |
| avg = total_return / episode if episode else 0 | |
| logger.info(f"\nEpisodes: {episode}, Avg return: {avg:.3f}") | |
| env.close() | |
| return 0 | |
| if __name__ == "__main__": | |
| raise SystemExit(main()) | |