# PPO training script for ALE/Pacman-v5 (exported from a Hugging Face Space; page status lines removed).
import sys

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np

from ppo_helpers_v2 import *
def preprocess(obs):
    """Flatten an image observation and rescale pixel values to [0, 1].

    Casts to float32 first, then divides by 255, so the result matches the
    network's expected input dtype.
    """
    flat = obs.astype(np.float32).ravel()
    return flat / 255.0
def main() -> int:
    """Train a PPO agent on ALE/Pacman-v5 and plot training-stability curves.

    Returns:
        0 on successful completion, 1 if training raised an exception.
    """
    # env = gym.make("ALE/SpaceInvaders-v5", render_mode='human')
    env = gym.make("ALE/Pacman-v5")
    episode = 0
    total_return = 0
    ep_return = 0
    steps = 2000    # environment steps collected per PPO update (rollout length)
    batches = 100   # number of PPO updates to run
    print("Observation space:", env.observation_space)
    print("Action space:", env.action_space)
    agent = Agent(obs_space=env.observation_space, action_space=env.action_space,
                  hidden=64, lr=3e-4, gamma=0.99, clip_coef=0.2,
                  entropy_coef=0.01, value_coef=0.5, seed=70,
                  batch_size=64, ppo_epochs=4, lam=0.95)
    # === Return-Based Scaling stats ===
    # Seed the agent's running return-variance statistics; the tiny epsilon
    # avoids division by zero before the first update. (An unused r_mean
    # local was removed — only r_var and g2_mean are consumed by the agent.)
    agent.r_var = 1e-8
    agent.g2_mean = 1.0
    try:
        obs, info = env.reset(seed=42)
        state = preprocess(obs)
        loss_history = []
        reward_history = []
        for update in range(1, batches + 1):
            # --- Rollout: collect `steps` transitions into agent memory ---
            for _ in range(steps):
                action, logp, value = agent.choose_action(state)
                next_obs, reward, terminated, truncated, info = env.step(action)
                done = terminated or truncated
                next_state = preprocess(next_obs)
                agent.remember(state, action, reward, done, logp, value, next_state)
                ep_return += reward
                state = next_state
                if done:
                    episode += 1
                    total_return += ep_return
                    print(f"Episode {episode} return: {ep_return:.2f}")
                    ep_return = 0
                    obs, info = env.reset()
                    state = preprocess(obs)
            # --- Learn: one PPO update over the collected rollout ---
            avg_loss = agent.update_rbs()  # collect loss from update
            loss_history.append(avg_loss)
            avg_ret = (total_return / episode) if episode else 0
            reward_history.append(avg_ret)
            print(f"Update {update}: episodes={episode}, avg_return={avg_ret:.2f}, avg_loss={avg_loss:.4f}")
        # --- Diagnostics: sigma, loss, and reward curves per PPO update ---
        fig = plt.figure()
        ax1 = plt.subplot(321)
        ax1.plot(agent.sigma_history, label="Return σ")
        ax1.set_xlabel("PPO Update")
        ax1.set_ylabel("σ (Return Std)")
        ax1.legend()  # FIX: labels were set but legend() was never called
        ax2 = plt.subplot(322)
        ax2.plot(loss_history, label="Avg Loss")
        ax2.set_ylabel("Average PPO Loss")
        ax2.set_xlabel("PPO Update")
        ax2.legend()
        ax3 = plt.subplot(323)
        ax3.plot(reward_history, label="Reward")
        ax3.set_ylabel("Reward")
        ax3.set_xlabel("PPO Update")
        ax3.legend()
        fig.suptitle("PPO Training Stability")
        fig.tight_layout()
        plt.show()
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1
    finally:
        # Always report progress and release the env, even on failure.
        # NOTE: no `return` here — a return inside `finally` would swallow
        # the except branch's exit code.
        avg = total_return / episode if episode else 0
        print(f"\nEpisodes: {episode}, Avg return: {avg:.3f}")
        env.close()
    return 0
if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return value, same as
    # `raise SystemExit(main())`.
    sys.exit(main())