# Source: team_22/train/train_mlp_ppo.py
# (Antigravity Agent — "Deploy Neuro-Flyt 3D Training", 6083286)
"""
Train a PPO agent with MLP policy on the DroneWindEnv environment.
This script uses stable-baselines3 PPO with a 2-layer MLP (64, 64) to train
an agent to survive and navigate in the 2D drone environment with wind.
By default, the trained model is saved to models/mlp_baseline.zip and TensorBoard
logs are written to logs/ppo_mlp/ (override with --model-path and --logdir).
"""
import os
import sys
import argparse
from typing import Optional
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.monitor import Monitor
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from env.drone_env import DroneWindEnv
def make_env(seed: Optional[int] = None) -> gym.Env:
    """
    Build a single DroneWindEnv instance wrapped in a Monitor.

    Args:
        seed: When provided, the wrapped environment is reset once with this
            seed so its random state starts from a known value.

    Returns:
        The Monitor-wrapped Gymnasium environment.
    """
    wrapped = Monitor(DroneWindEnv())
    if seed is None:
        return wrapped
    # Seeding happens through an initial reset (Gymnasium API); the returned
    # observation/info are intentionally discarded.
    wrapped.reset(seed=seed)
    return wrapped
def make_vec_env(num_envs: int = 4, base_seed: int = 0) -> DummyVecEnv:
    """
    Create a vectorized environment holding several DroneWindEnv instances.

    The environments live in a DummyVecEnv, which steps them one after another
    in the current process (no subprocess parallelism).

    Args:
        num_envs: Number of environment copies; must be at least 1.
        base_seed: Environment ``i`` is created with seed ``base_seed + i``.
            The default of 0 keeps the original seeds ``0 .. num_envs - 1``.

    Returns:
        A DummyVecEnv wrapping ``num_envs`` Monitor-wrapped environments.

    Raises:
        ValueError: If ``num_envs`` is less than 1.
    """
    if num_envs < 1:
        # DummyVecEnv fails with an opaque error on an empty factory list.
        raise ValueError(f"num_envs must be >= 1, got {num_envs}")

    def _env_factory(seed: int):
        # Bind ``seed`` eagerly: DummyVecEnv calls the factories later, so a
        # bare lambda over the loop variable would see only its final value.
        def _init():
            return make_env(seed)

        return _init

    return DummyVecEnv([_env_factory(base_seed + i) for i in range(num_envs)])
def _parse_args(argv: Optional[list] = None) -> argparse.Namespace:
    """Parse and validate the training script's command-line options."""
    parser = argparse.ArgumentParser(description="Train PPO agent on DroneWindEnv")
    parser.add_argument(
        "--timesteps",
        type=int,
        default=100_000,
        help="Total number of training timesteps (default: 100000)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Random seed (default: 0)",
    )
    parser.add_argument(
        "--logdir",
        type=str,
        default="logs/ppo_mlp",
        help="Directory for TensorBoard logs (default: logs/ppo_mlp)",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="models/mlp_baseline.zip",
        help="Path to save the trained model (default: models/mlp_baseline.zip)",
    )
    parser.add_argument(
        "--num-envs",
        type=int,
        default=4,
        help="Number of parallel environments (default: 4)",
    )
    parser.add_argument(
        "--no-progress-bar",
        action="store_true",
        help="Disable the training progress bar (it requires tqdm and rich)",
    )
    args = parser.parse_args(argv)
    # Reject bad values up front with a CLI error instead of failing deep
    # inside environment construction or training.
    if args.timesteps <= 0:
        parser.error(f"--timesteps must be positive, got {args.timesteps}")
    if args.num_envs < 1:
        parser.error(f"--num-envs must be at least 1, got {args.num_envs}")
    return args


def _ensure_parent_dir(path: str) -> None:
    """Create the directory that will contain ``path``, if it names one."""
    parent = os.path.dirname(path)
    # A bare filename has no directory part; os.makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)


def main(argv: Optional[list] = None) -> None:
    """
    Train a PPO agent on DroneWindEnv and save the resulting model.

    Args:
        argv: Argument list to parse instead of the real command line.
            ``None`` (the default) parses ``sys.argv[1:]``.
    """
    args = _parse_args(argv)

    # Create output directories if they don't exist
    _ensure_parent_dir(args.model_path)
    os.makedirs(args.logdir, exist_ok=True)

    print("=" * 60)
    print("Training PPO Agent on DroneWindEnv")
    print("=" * 60)
    print(f"Total timesteps: {args.timesteps:,}")
    print(f"Number of parallel environments: {args.num_envs}")
    print(f"Model will be saved to: {args.model_path}")
    print(f"TensorBoard logs: {args.logdir}")
    print("=" * 60)

    # Create vectorized environment
    print("Creating vectorized environment...")
    vec_env = make_vec_env(num_envs=args.num_envs)
    try:
        # Configure policy (2-layer MLP with 64 hidden units each)
        policy_kwargs = dict(net_arch=[64, 64])

        # Create PPO agent
        print("Initializing PPO agent...")
        model = PPO(
            policy="MlpPolicy",
            env=vec_env,
            policy_kwargs=policy_kwargs,
            n_steps=1024,
            batch_size=64,
            gamma=0.99,
            learning_rate=3e-4,
            gae_lambda=0.95,
            clip_range=0.2,
            ent_coef=0.0,
            verbose=1,
            tensorboard_log=args.logdir,
            seed=args.seed,
        )

        # Train the agent
        print("\nStarting training...")
        model.learn(
            total_timesteps=args.timesteps,
            progress_bar=not args.no_progress_bar,
        )

        # Save the model
        print(f"\nSaving model to {args.model_path}...")
        model.save(args.model_path)
    finally:
        # Release the environments even if training fails or is interrupted.
        vec_env.close()

    print("\n" + "=" * 60)
    print("Training completed successfully!")
    print(f"Model saved to: {args.model_path}")
    print(f"TensorBoard logs available at: {args.logdir}")
    print("=" * 60)
    print("\nTo view training progress, run:")
    print(f" tensorboard --logdir {args.logdir}")
    print("\nTo evaluate the model, run:")
    print(f" python eval/eval_mlp_baseline.py --model-path {args.model_path}")


if __name__ == "__main__":
    main()