"""Train a PPO agent (with return-based scaling) on ALE/Pacman-v5.

Collects fixed-length rollouts, runs PPO updates via ``Agent.update_rbs``,
and plots training-stability diagnostics at the end.
"""

import sys

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np  # was commented out but preprocess() needs np; don't rely on the star import

from ppo_helpers_v2 import *


def preprocess(obs):
    """Flatten an image observation and scale uint8 pixels to [0, 1] float32."""
    return obs.astype(np.float32).ravel() / 255.0


def main() -> int:
    """Run PPO training for a fixed number of updates.

    Returns:
        Process exit code: 0 on success, 1 if training raised.
    """
    # env = gym.make("ALE/SpaceInvaders-v5", render_mode='human')
    env = gym.make("ALE/Pacman-v5")
    episode = 0
    total_return = 0
    ep_return = 0
    steps = 2000    # environment steps collected per PPO update
    batches = 100   # number of PPO updates

    print("Observation space:", env.observation_space)
    print("Action space:", env.action_space)

    agent = Agent(
        obs_space=env.observation_space,
        action_space=env.action_space,
        hidden=64,
        lr=3e-4,
        gamma=0.99,
        clip_coef=0.2,
        entropy_coef=0.01,
        value_coef=0.5,
        seed=70,
        batch_size=64,
        ppo_epochs=4,
        lam=0.95,
    )

    # === Return-Based Scaling stats ===
    # Seed the running statistics consumed by agent.update_rbs().
    # (dropped unused r_mean that was assigned but never read)
    r_var = 1e-8
    g2_mean = 1.0
    agent.r_var = r_var
    agent.g2_mean = g2_mean

    try:
        obs, info = env.reset(seed=42)
        state = preprocess(obs)

        loss_history = []
        reward_history = []

        for update in range(1, batches + 1):
            # --- collect a fixed-length rollout (episodes may span updates) ---
            for t in range(steps):
                action, logp, value = agent.choose_action(state)
                next_obs, reward, terminated, truncated, info = env.step(action)
                done = terminated or truncated
                next_state = preprocess(next_obs)
                agent.remember(state, action, reward, done, logp, value, next_state)
                ep_return += reward
                state = next_state
                if done:
                    episode += 1
                    total_return += ep_return
                    print(f"Episode {episode} return: {ep_return:.2f}")
                    ep_return = 0
                    obs, info = env.reset()
                    state = preprocess(obs)

            avg_loss = agent.update_rbs()  # collect loss from update
            loss_history.append(avg_loss)
            avg_ret = (total_return / episode) if episode else 0
            reward_history.append(avg_ret)
            print(f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}, avg_loss={avg_loss:.4f}")

        # --- training-stability diagnostics ---
        fig = plt.figure()
        ax1 = plt.subplot(321)
        ax1.plot(agent.sigma_history, label="Return σ")
        ax1.set_xlabel("PPO Update")
        ax1.set_ylabel("σ (Return Std)")
        ax1.legend()  # labels were set but never rendered without legend()

        ax2 = plt.subplot(322)
        ax2.plot(loss_history, label="Avg Loss")
        ax2.set_ylabel("Average PPO Loss")
        ax2.set_xlabel("PPO Update")
        ax2.legend()

        ax3 = plt.subplot(323)
        ax3.plot(reward_history, label="Reward")
        ax3.set_ylabel("Reward")
        ax3.set_xlabel("PPO Update")
        ax3.legend()

        fig.suptitle("PPO Training Stability")
        fig.tight_layout()
        plt.show()
    except Exception as e:
        # Top-level boundary: report and signal failure via exit code.
        print(f"Error: {e}", file=sys.stderr)
        return 1
    finally:
        avg = total_return / episode if episode else 0
        print(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
        env.close()
    return 0


if __name__ == "__main__":
    raise SystemExit(main())