# Source: RL_Project20 / a2c_main.py
# Uploaded by ShuvamGanguli ("Upload 4 files"), commit 3d433ba (verified).
import argparse
import gymnasium as gym
import sys
import matplotlib.pyplot as plt
import ale_py
from a2c_helpers import *
from gymnasium.spaces import Box
import cv2
import logging
import numpy as np
import pandas as pd
# Configure root logging once: timestamped INFO-level records.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
# Preprocess environment
def preprocess(obs):
    """Convert a raw RGB frame into a single-channel 84x84 float32 array.

    The frame is converted to grayscale, resized to 84x84 with area
    interpolation, given a leading axis of length 1, cast to float32 and
    divided by 255.

    Args:
        obs: RGB image array as returned by the environment
            (presumably uint8 in [0, 255] -- confirm against the env).

    Returns:
        A float32 array with a leading axis of length 1 followed by the
        84x84 image.
    """
    gray = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
    small = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)
    # Prepend the leading axis, then scale by 1/255.
    scaled = small[np.newaxis, ...].astype(np.float32) / 255.0
    return scaled
def df_ops(lst_df, seeds):
    """Add row-wise summary columns over the seed columns of each frame.

    For every DataFrame in ``lst_df`` the columns listed in ``seeds`` are
    reduced along axis 1 and the results are stored in place as ``'Avg'``
    (mean), ``'High'`` (max) and ``'Low'`` (min).

    Args:
        lst_df: list of DataFrames that each contain the ``seeds`` columns.
        seeds: list of column labels to aggregate over.

    Returns:
        The same ``lst_df`` list; its frames have been modified in place.
    """
    reducers = (('Avg', 'mean'), ('High', 'max'), ('Low', 'min'))
    for frame in lst_df:
        # Select the seed columns once, before any summary column is added.
        per_seed = frame[seeds]
        for column, reducer in reducers:
            frame[column] = getattr(per_seed, reducer)(axis=1)
    return lst_df
# Main loop
def main() -> int:
    """Train an A2C agent over several seeds and plot the aggregated curves.

    Usage examples:
        python3 a2c_main.py --method vanilla
        python3 a2c_main.py --method grad_clip
        python3 a2c_main.py --method rbs

    For every seed the environment is reset with that seed and the agent is
    trained for a fixed number of updates; each update is preceded by a
    fixed number of complete episodes.  Per-update reward and loss values
    are collected per seed, summarised across seeds with ``df_ops`` and
    shown in a 2x2 matplotlib figure.

    Returns:
        0 on success, 1 if an exception was raised during training or
        plotting (the exception is logged with its traceback).
    """
    # --- Experiment configuration ---
    batches = 1000                 # number of agent updates per seed
    clip_interval = 2              # default for --clip_window
    seeds = [10, 20, 30, 40, 50]
    ep_per_batch = 5               # complete episodes played before each update

    # --- Command-line arguments ---
    parser = argparse.ArgumentParser(description='A2C Training')
    parser.add_argument('--method', type=str, choices=['vanilla', 'reward_clip', 'rbs', 'grad_clip',
                                                       'obs_norm', 'adv_norm', 'return_norm', 'reward_norm'],
                        default='vanilla', help='A2C update method')
    parser.add_argument('--env', type=str, default='ALE/Pacman-v5',
                        help='Gym environment name (e.g., ALE/Pacman-v5, ALE/SpaceInvaders-v5)')
    parser.add_argument('--render', action='store_true', help='Enable rendering')
    # NOTE(review): --clip_window is parsed but not consulted below; the clip
    # range is recomputed after every batch.  Confirm whether a multi-batch
    # window was intended.
    parser.add_argument('--clip_window', type=int, default=clip_interval,
                        help='Number of batches to collect rewards for clipping range update')
    args = parser.parse_args()

    # --- Environment ---
    if args.render:
        env = gym.make(args.env, render_mode="human")
    else:
        env = gym.make(args.env)
    logger.info(f"Observation space: {env.observation_space}")
    logger.info(f"Action space: {env.action_space}")
    logger.info(f'Method: {args.method}')

    # Preprocess one real observation so the agent is built for the
    # preprocessed shape rather than the raw frame shape.
    obs, _ = env.reset()
    dummy_obs_space = Box(low=0.0, high=1.0, shape=preprocess(obs).shape)

    # --- Agent ---
    # NOTE(review): one agent instance is reused for every seed, so later
    # seeds continue training the same agent -- confirm this is intended.
    agent = Agent(obs_space=dummy_obs_space, action_space=env.action_space,
                  hidden=64, lr=0.00001, gamma=0.997,
                  entropy_coef=0.01, value_coef=0.5, seed=70, lam=0.95)
    # Initial Return-Based Scaling statistics stored on the agent
    # (presumably read by agent.update_rbs -- confirm in a2c_helpers).
    agent.r_var = 1e-8
    agent.g2_mean = 1.0

    # Name of the agent method that performs the update for each --method.
    update_names = {
        'vanilla': 'vanilla_a2c_update',
        'reward_clip': 'vanilla_a2c_update',
        'grad_clip': 'update_gradient_clipping',
        'obs_norm': 'update_obs_norm',
        'adv_norm': 'update_adv_norm',
        'reward_norm': 'update_reward_norm',
        'rbs': 'update_rbs',
        # NOTE(review): return_norm is routed to the RBS update; confirm
        # whether a dedicated return-normalisation update should be used.
        'return_norm': 'update_rbs',
    }
    run_update = getattr(agent, update_names[args.method])
    clip_returns = args.method == 'reward_clip'

    update_index = list(range(1, batches + 1))
    all_reward_histories = pd.DataFrame(columns=list(seeds), index=update_index)
    all_loss_histories = pd.DataFrame(columns=list(seeds), index=update_index)
    all_policy_loss = pd.DataFrame(columns=list(seeds))
    all_value_loss = pd.DataFrame(columns=list(seeds))

    # Offset into the agent's loss histories where the current seed starts.
    step = 0
    # Bound before the try block so the summary in `finally` can always run,
    # even when the very first reset raises.
    episode = 0
    total_return = 0

    try:
        for seed in seeds:
            obs, _ = env.reset(seed=seed)
            state = preprocess(obs)
            loss_history = []
            reward_history = []
            episode = 0
            total_return = 0

            alpha = None
            clip_low, clip_high = None, None
            if clip_returns:
                # Width of the clip range in standard deviations.
                alpha = np.random.uniform(1, 2)
                logger.info(f"α sampled = {alpha:.3f} seed = {seed}")

            for update in range(1, batches + 1):
                batch_returns = []
                for _ in range(ep_per_batch):
                    ep_return = _run_episode(env, agent, state)
                    # NOTE(review): clipping only changes the logged and
                    # aggregated return; the rewards stored in the agent are
                    # unclipped.  Confirm this is the intended method.
                    if clip_returns and clip_low is not None:
                        ep_return = np.clip(ep_return, clip_low, clip_high)
                    batch_returns.append(ep_return)
                    episode += 1
                    total_return += ep_return
                    logger.info(f"Episode {episode} return: {ep_return:.2f}")
                    obs, _ = env.reset()
                    state = preprocess(obs)

                if clip_returns:
                    # Clip range for the next batch: mean +/- alpha * std of
                    # this batch's returns (std replaced by 1 when it is 0).
                    batch_mean = np.mean(batch_returns)
                    batch_std = np.std(batch_returns)
                    sigma = batch_std + 1e-8 if batch_std != 0 else 1
                    clip_low = batch_mean - alpha * sigma
                    clip_high = batch_mean + alpha * sigma
                    logger.info(
                        f"[UPDATE {update}] New Reward Clip Range: "
                        f"[{clip_low:.4f}, {clip_high:.4f}]"
                    )

                avg_loss = run_update()
                loss_history.append(avg_loss)

                if clip_returns:
                    avg_ret = batch_mean
                    logger.info(
                        f"Update {update}: batch_mean={avg_ret:.4f}, "
                        f"batch_std={batch_std:.4f}, "
                        f"episodes={episode}, avg_loss={avg_loss:.4f}"
                    )
                else:
                    # Running average over every episode of this seed so far.
                    avg_ret = (total_return / episode) if episode else 0
                    logger.info(
                        f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}, avg_loss={avg_loss:.4f}")
                reward_history.append(avg_ret)

            all_reward_histories[seed] = reward_history
            all_loss_histories[seed] = loss_history
            # Slice out this seed's portion of the agent's loss histories
            # (assumes the agent appends one entry per update -- confirm).
            all_value_loss[seed] = agent.value_loss_history[step:step + batches]
            all_policy_loss[seed] = agent.policy_loss_history[step:step + batches]
            step += batches

        # Adds the 'Avg' / 'High' / 'Low' columns to each frame in place.
        df_ops([all_reward_histories, all_loss_histories, all_value_loss, all_policy_loss], seeds)

        fig = plt.figure(figsize=(15, 10))
        # (subplot position, frame, line label, line colour, band colour, y label)
        panels = [
            (221, all_loss_histories, "Avg Loss", '#1D3557', '#A8DADC', "Average A2C Loss"),
            (222, all_reward_histories, "Avg Reward", '#E63946', '#FEDCC8', "Average Reward"),
            (223, all_policy_loss, "Policy Loss", '#38B000', '#B0E0A0', "Average Policy Loss"),
            (224, all_value_loss, "Value Loss", '#8E44AD', '#D7BDE2', "Average Value Loss"),
        ]
        for position, frame, line_label, line_color, band_color, ylabel in panels:
            _plot_band(plt.subplot(position), frame, line_label, line_color, band_color, ylabel)
        fig.suptitle(f"A2C Training Stability - {args.method}", fontsize=16, fontweight='bold')
        plt.show()
    except Exception as e:
        # Top-level boundary: log with traceback and report failure.
        logger.exception(f"Error: {e}")
        return 1
    finally:
        # Summary covers the most recently processed seed.
        avg = total_return / episode if episode else 0
        logger.info(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
        env.close()
    return 0


def _run_episode(env, agent, state):
    """Play one episode to completion and return its undiscounted return.

    Every transition is stored in the agent via ``agent.remember``.  The
    environment is left in its finished state; the caller resets it.
    """
    ep_return = 0
    done = False
    while not done:
        action, value = agent.choose_action(state)
        next_obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        next_state = preprocess(next_obs)
        ep_return += reward
        agent.remember(state, action, reward, done, value, next_state)
        state = next_state
    return ep_return


def _plot_band(ax, frame, line_label, line_color, band_color, ylabel):
    """Draw the 'Avg' line of ``frame`` over a shaded 'Low'-'High' band."""
    ax.fill_between(
        frame.index,
        frame['Low'],
        frame['High'],
        color=band_color,
        alpha=0.5,
        label="High-Low Range",
    )
    ax.plot(frame['Avg'], label=line_label, color=line_color, linewidth=2)
    ax.set_ylabel(ylabel)
    ax.set_xlabel("A2C Update")
    ax.legend()
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())