# cartpole-ppo-agent / train_cartpole.py
# Uploaded by jatinror ("Upload train_cartpole.py", commit a4d9d34, verified)
import gymnasium as gym
from stable_baselines3 import PPO
import os
# -----------------------------
# Create environment
# -----------------------------
# Gymnasium CartPole-v1 task. It is closed in the `finally` clause below so the
# environment is released even if training or saving raises.
env = gym.make("CartPole-v1")

try:
    # -----------------------------
    # Initialize PPO agent (CPU only)
    # -----------------------------
    # Stable-Baselines3 defaults to device="auto", which selects CUDA whenever it
    # is available. PPO with an MLP policy is meant to run on the CPU, so pin the
    # device explicitly to match this script's "CPU only" intent.
    model = PPO(
        policy="MlpPolicy",
        env=env,
        device="cpu",
        verbose=1,  # prints training progress
    )

    # -----------------------------
    # Training
    # -----------------------------
    total_timesteps = 100_000  # You can increase for better performance
    print(f"🚀 Training started for {total_timesteps} timesteps...")
    model.learn(total_timesteps=total_timesteps)
    print("✅ Training finished!")

    # -----------------------------
    # Save trained model
    # -----------------------------
    model_path = "model/ppo_cartpole"
    # Create the parent directory derived from model_path, so the two stay in
    # sync if the path is edited. `or "."` covers a bare filename, where
    # os.path.dirname returns "" and os.makedirs("") would raise.
    os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)
    # Stable-Baselines3 writes a .zip archive and appends the ".zip" extension
    # when the path has none, hence the suffix in the message below.
    model.save(model_path)
    print(f"✅ Model saved at {model_path}.zip")
finally:
    env.close()