Spaces:
Runtime error
Runtime error
File size: 4,442 Bytes
6083286 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
"""
Train a PPO agent with MLP policy on the DroneWindEnv environment.
This script uses stable-baselines3 PPO with a 2-layer MLP (64, 64) to train
an agent to survive and navigate in the 2D drone environment with wind.
The trained model is saved to models/mlp_baseline.zip and TensorBoard logs
are written to logs/ppo_mlp/.
"""
import os
import sys
import argparse
from typing import Optional
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.monitor import Monitor
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from env.drone_env import DroneWindEnv
def make_env(seed: Optional[int] = None) -> gym.Env:
    """
    Build a single DroneWindEnv wrapped in a Monitor.

    When a seed is supplied, the wrapped environment is reset once with that
    seed before being handed back; otherwise it is returned without a reset.

    Args:
        seed: Optional random seed forwarded to the initial ``reset`` call

    Returns:
        The Monitor-wrapped environment
    """
    monitored = Monitor(DroneWindEnv())

    # Nothing to seed: hand the wrapped environment back untouched.
    if seed is None:
        return monitored

    monitored.reset(seed=seed)
    return monitored
def make_vec_env(num_envs: int = 4, seed: Optional[int] = 0) -> DummyVecEnv:
    """
    Create a vectorized environment holding several DroneWindEnv instances.

    Args:
        num_envs: Number of environment instances to create (must be >= 1)
        seed: Base random seed. Environment ``i`` is created with
            ``seed + i``, so the default of 0 yields seeds
            ``0 .. num_envs - 1``. Pass ``None`` to create every
            environment without a seed.

    Returns:
        DummyVecEnv wrapping ``num_envs`` environments built by ``make_env``

    Raises:
        ValueError: If ``num_envs`` is smaller than 1
    """
    if num_envs < 1:
        raise ValueError(f"num_envs must be at least 1, got {num_envs}")

    def make_vec_env_fn(env_seed: Optional[int] = None):
        # Bind env_seed through a factory argument so each thunk keeps its
        # own value instead of sharing the loop variable (late binding).
        def _init():
            return make_env(env_seed)

        return _init

    env_fns = [
        make_vec_env_fn(None if seed is None else seed + rank)
        for rank in range(num_envs)
    ]
    return DummyVecEnv(env_fns)
def main(argv: Optional[list] = None) -> None:
    """
    Parse command-line options, train a PPO agent on DroneWindEnv and save it.

    Args:
        argv: Command-line arguments to parse, without the program name.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Train PPO agent on DroneWindEnv")
    parser.add_argument(
        "--timesteps",
        type=int,
        default=100_000,
        help="Total number of training timesteps (default: 100000)"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Random seed (default: 0)"
    )
    parser.add_argument(
        "--logdir",
        type=str,
        default="logs/ppo_mlp",
        help="Directory for TensorBoard logs (default: logs/ppo_mlp)"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="models/mlp_baseline.zip",
        help="Path to save the trained model (default: models/mlp_baseline.zip)"
    )
    parser.add_argument(
        "--num-envs",
        type=int,
        default=4,
        help="Number of parallel environments (default: 4)"
    )
    args = parser.parse_args(argv)

    # Fail early with a usage message instead of an obscure error from the
    # vectorized-environment constructor.
    if args.num_envs < 1:
        parser.error("--num-envs must be at least 1")

    # Create directories if they don't exist. A bare file name such as
    # "model.zip" has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the model directory when there is one.
    model_dir = os.path.dirname(args.model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    os.makedirs(args.logdir, exist_ok=True)

    print("=" * 60)
    print("Training PPO Agent on DroneWindEnv")
    print("=" * 60)
    print(f"Total timesteps: {args.timesteps:,}")
    print(f"Number of parallel environments: {args.num_envs}")
    print(f"Model will be saved to: {args.model_path}")
    print(f"TensorBoard logs: {args.logdir}")
    print("=" * 60)

    # Create vectorized environment
    print("Creating vectorized environment...")
    vec_env = make_vec_env(num_envs=args.num_envs)

    try:
        # Configure policy (2-layer MLP with 64 hidden units each)
        policy_kwargs = dict(net_arch=[64, 64])

        # Create PPO agent
        print("Initializing PPO agent...")
        model = PPO(
            policy="MlpPolicy",
            env=vec_env,
            policy_kwargs=policy_kwargs,
            n_steps=1024,
            batch_size=64,
            gamma=0.99,
            learning_rate=3e-4,
            gae_lambda=0.95,
            clip_range=0.2,
            ent_coef=0.0,
            verbose=1,
            tensorboard_log=args.logdir,
            seed=args.seed,
        )

        # Train the agent
        print("\nStarting training...")
        model.learn(
            total_timesteps=args.timesteps,
            progress_bar=True
        )

        # Save the model
        print(f"\nSaving model to {args.model_path}...")
        model.save(args.model_path)
    finally:
        # Release the environments even when training or saving fails.
        vec_env.close()

    print("\n" + "=" * 60)
    print("Training completed successfully!")
    print(f"Model saved to: {args.model_path}")
    print(f"TensorBoard logs available at: {args.logdir}")
    print("=" * 60)
    print("\nTo view training progress, run:")
    print(f" tensorboard --logdir {args.logdir}")
    print("\nTo evaluate the model, run:")
    print(f" python eval/eval_mlp_baseline.py --model-path {args.model_path}")
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|