# %%
# Import required packages
import gymnasium as gym
from huggingface_sb3 import package_to_hub
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder

# %%
# Test random environment: take 20 random actions, resetting whenever an
# episode terminates or is truncated, to sanity-check that the env runs.
env_id = "LunarLander-v3"
env = gym.make(env_id)
observation, info = env.reset()
for _ in range(20):
    action = env.action_space.sample()
    print("Action taken:", action)
    observation, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        print("Environment is reset")
        observation, info = env.reset()
env.close()

# %%
# Check observation and action spaces.
# The previous cell closed `env`, so create a fresh environment rather than
# reusing a closed one; this instance is also the one used for training below.
env = gym.make(env_id)
env.reset()
print("_____OBSERVATION SPACE_____ \n")
print("Observation Space Shape", env.observation_space.shape)
print("Sample observation", env.observation_space.sample())  # Get a random observation
print("\n _____ACTION SPACE_____ \n")
# `.n` is the number of discrete actions, not a shape.
print("Number of actions", env.action_space.n)
print("Action Space Sample", env.action_space.sample())  # Take a random action

# %%
# Check SB3 model device ("auto" selects CUDA when available, else CPU)
model = PPO("MlpPolicy", env, device="auto")
print(model.device)

# %%
# Train PPO agent
model = PPO(
    policy="MlpPolicy",
    env=env,
    n_steps=1024,
    batch_size=64,
    n_epochs=4,
    gamma=0.999,
    gae_lambda=0.98,
    ent_coef=0.01,
    verbose=1,
)
model.learn(total_timesteps=500_000)

# %%
# Continue training the same model for another 1M timesteps (1.5M in total).
# reset_num_timesteps=False keeps the timestep counter and logs continuous
# instead of starting a new run from zero.
model.learn(total_timesteps=1_000_000, reset_num_timesteps=False)
model.save("ppo-lunar-lander")

# %%
# Evaluate the agent over 100 deterministic episodes on a separate,
# Monitor-wrapped environment so reported returns are the raw episode rewards.
model = PPO.load("ppo-lunar-lander", env=env)
eval_env = Monitor(gym.make(env_id))
mean_reward, std_reward = evaluate_policy(
    model, eval_env, n_eval_episodes=100, deterministic=True
)
eval_env.close()
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")

# %%
# Publish the trained agent.
# package_to_hub needs a VecEnv that can render "rgb_array" frames. The extra
# VecVideoRecorder also saves, under videos/, a local recording of the first
# 1000 steps this env is stepped through.
eval_env = DummyVecEnv(
    [lambda: Monitor(gym.make(env_id, render_mode="rgb_array"))]
)
eval_env = VecVideoRecorder(
    eval_env,
    "videos/",
    record_video_trigger=lambda x: x == 0,
    video_length=1000,
    name_prefix="ppo-lunar-lander-demo",
)
package_to_hub(
    model=model,
    # NOTE: model/repo names keep their historical "-v2" suffix so the existing
    # Hub repo keeps being updated; the environment itself is `env_id`.
    model_name="ppo-lunar-lander-v2",
    model_architecture="PPO",
    env_id=env_id,
    eval_env=eval_env,
    repo_id="pabloramesc/ppo-lunar-lander-v2",
    commit_message=f"Upload PPO agent for {env_id}",
)

# %%