"""Train a PPO agent on CartPole-v1 and save the result under ``model/``."""

import os

import gymnasium as gym
from stable_baselines3 import PPO

# -----------------------------
# Create environment
# -----------------------------
env = gym.make("CartPole-v1")

# try/finally guarantees the environment is closed even when building the
# agent, training, or saving raises (or the run is interrupted with Ctrl+C).
try:
    # -----------------------------
    # Initialize PPO agent
    # -----------------------------
    model = PPO(
        policy="MlpPolicy",
        env=env,
        verbose=1,  # prints training progress
    )

    # -----------------------------
    # Training
    # -----------------------------
    total_timesteps = 100_000  # increase to train for longer

    # Status messages are plain ASCII so they print on any console encoding.
    print(f"Training started for {total_timesteps} timesteps...")
    model.learn(total_timesteps=total_timesteps)
    print("Training finished!")

    # -----------------------------
    # Save trained model
    # -----------------------------
    model_path = "model/ppo_cartpole"
    # Derive the directory from the path so the two cannot drift apart.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    # The path carries no extension; Stable-Baselines3 writes a ".zip"
    # archive, which is what the message below reports.
    model.save(model_path)
    print(f"Model saved at {model_path}.zip")
finally:
    env.close()