# RL_Project20 / ppo_template_3.py
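"""Train a PPO agent with return-based scaling (RBS) on ALE/Pacman-v5.

The script alternates fixed-length rollouts with PPO updates via
``agent.update_rbs()`` from the local ``ppo_helpers_v2`` module, then plots
the running return std, average loss, and average return per update.

Requires ``gymnasium`` with Atari support (``ale-py`` and the Atari ROMs);
depending on the installed versions, the ALE namespace may need to be
registered explicitly via ``gym.register_envs(ale_py)``.
"""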
import sys

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np  # preprocess() uses np directly, so this import is required

from ppo_helpers_v2 import Agent


def preprocess(obs):
    """Flatten a raw uint8 frame and scale pixel values to [0, 1]."""
    return obs.astype(np.float32).ravel() / 255.0
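
# NOTE: flattening a raw RGB Atari frame (typically 210x160x3 uint8) yields a
# ~100,800-dim float vector, so the network behind `Agent` receives a very
# wide input layer. A CNN encoder is the usual choice for pixels; the flat
# vector here simply matches the input interface assumed by ppo_helpers_v2.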


def main() -> int:
    # env = gym.make("ALE/SpaceInvaders-v5", render_mode='human')
    env = gym.make("ALE/Pacman-v5")

    episode = 0
    total_return = 0
    ep_return = 0
    steps = 2000    # environment steps collected per rollout
    batches = 100   # number of PPO updates

    print("Observation space:", env.observation_space)
    print("Action space:", env.action_space)

    # Standard PPO hyperparameters: clipped ratio (0.2), GAE(lam=0.95),
    # a small entropy bonus, and a 0.5-weighted value loss.
    agent = Agent(obs_space=env.observation_space, action_space=env.action_space,
                  hidden=64, lr=3e-4, gamma=0.99, clip_coef=0.2,
                  entropy_coef=0.01, value_coef=0.5, seed=70,
                  batch_size=64, ppo_epochs=4, lam=0.95)

    # === Return-Based Scaling stats ===
    r_mean, r_var = 0.0, 1e-8
    g2_mean = 1.0
    agent.r_var = r_var
    agent.g2_mean = g2_mean
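
    # NOTE: r_mean is initialized but never updated in this script; the running
    # statistics are presumably maintained inside agent.update_rbs(). Under
    # return-based scaling, rewards/advantages are normalized by a running
    # return std σ, which the agent appears to expose as `sigma_history`
    # (plotted below).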

    try:
        obs, info = env.reset(seed=42)
        state = preprocess(obs)

        loss_history = []
        reward_history = []

        for update in range(1, batches + 1):
            # --- Rollout phase: collect `steps` transitions ---
            for t in range(steps):
                action, logp, value = agent.choose_action(state)
                next_obs, reward, terminated, truncated, info = env.step(action)
                done = terminated or truncated
                next_state = preprocess(next_obs)

                agent.remember(state, action, reward, done, logp, value, next_state)
                ep_return += reward
                state = next_state

                if done:
                    episode += 1
                    total_return += ep_return
                    print(f"Episode {episode} return: {ep_return:.2f}")
                    ep_return = 0
                    obs, info = env.reset()
                    state = preprocess(obs)
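
            # --- Learning phase: one PPO update over the collected rollout ---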
            avg_loss = agent.update_rbs()  # collect loss from update
            loss_history.append(avg_loss)

            avg_ret = (total_return / episode) if episode else 0
            reward_history.append(avg_ret)
            print(f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}, avg_loss={avg_loss:.4f}")
        fig = plt.figure()

        ax1 = plt.subplot(321)
        ax1.plot(agent.sigma_history, label="Return σ")
        ax1.set_xlabel("PPO Update")
        ax1.set_ylabel("σ (Return Std)")

        ax2 = plt.subplot(322)
        ax2.plot(loss_history, label="Avg Loss")
        ax2.set_xlabel("PPO Update")
        ax2.set_ylabel("Average PPO Loss")

        ax3 = plt.subplot(323)
        ax3.plot(reward_history, label="Avg Return")
        ax3.set_xlabel("PPO Update")
        ax3.set_ylabel("Average Return")  # reward_history holds per-update average returns

        fig.suptitle("PPO Training Stability")
        fig.tight_layout()
        plt.show()
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1
    finally:
        avg = total_return / episode if episode else 0
        print(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
        env.close()

    return 0  # placed outside `finally` so it cannot mask the error exit code


if __name__ == "__main__":
    raise SystemExit(main())