# NOTE(review): the original lines here ("Spaces:" / "Runtime error" x2) were
# copy-paste artifacts from a run log, not Python; converted to this comment
# so the file can parse.
# Standard library
import argparse
import multiprocessing
import os
import sys

# Third-party
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor

# Make the project root importable *before* pulling in local modules.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from env_3d import Drone3DEnv
from models.liquid_policy import LiquidFeatureExtractor
def make_env(rank, seed=0):
    """Build a zero-argument environment factory for SubprocVecEnv.

    SubprocVecEnv requires picklable, no-arg callables; the returned thunk
    constructs a fresh Drone3DEnv in the worker process and seeds it with
    ``seed + rank`` so each parallel environment gets a distinct RNG stream.
    """
    def _init():
        environment = Drone3DEnv()
        environment.reset(seed=seed + rank)
        return environment

    return _init
def main():
    """Train PPO with a Liquid-NN feature extractor on Drone3DEnv.

    Parses ``--timesteps`` / ``--seed``, launches parallel environments,
    trains (on CUDA when available; SB3 falls back to CPU otherwise),
    checkpoints periodically, and saves the final model to
    ``models/liquid_ppo_3d_final.zip``.
    """
    # Verify CUDA
    import torch
    print(f"Is CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    else:
        print("WARNING: CUDA is NOT available. Training will be slow.")

    parser = argparse.ArgumentParser()
    parser.add_argument("--timesteps", type=int, default=8_000_000)  # 8M steps as requested
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Create Vectorized Env
    # 8x L40S has ~192 vCPUs. Using 64-96 is usually the sweet spot for PPO.
    # Too many envs can cause overhead or instability.
    max_cpu = 64
    available_cpu = multiprocessing.cpu_count()
    num_cpu = min(max_cpu, available_cpu)
    print(f"Using {num_cpu} CPUs for parallel environments (Available: {available_cpu}).")

    # SubprocVecEnv for true parallelism
    env = SubprocVecEnv([make_env(i, args.seed) for i in range(num_cpu)])
    env = VecMonitor(env)  # Monitor wrapper for logging

    # Liquid Policy Config
    policy_kwargs = dict(
        features_extractor_class=LiquidFeatureExtractor,
        features_extractor_kwargs=dict(features_dim=128, hidden_size=128, dt=0.05),  # Large Capacity
        net_arch=dict(pi=[256, 256], vf=[256, 256]),  # Large Capacity
    )

    # Initialize PPO.
    # BUGFIX: previously --seed only seeded the environments; forwarding it to
    # PPO also seeds torch/numpy/action-space RNGs so runs are reproducible.
    model = PPO(
        "MlpPolicy",
        env,
        policy_kwargs=policy_kwargs,
        verbose=1,
        learning_rate=3e-4,
        n_steps=8192,     # Large horizon for stability
        batch_size=2048,  # Large batch for powerful GPU
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        tensorboard_log="./logs/ppo_liquid_3d/",
        device="cuda",
        seed=args.seed,
    )

    # Checkpoint Callback.
    # save_freq is counted in per-env steps, hence the division by num_cpu;
    # clamp to >= 1 so a huge CPU count can never yield save_freq=0.
    checkpoint_callback = CheckpointCallback(
        save_freq=max(500_000 // num_cpu, 1),
        save_path="./models/checkpoints/",
        name_prefix="liquid_ppo_3d",
    )

    print("Starting Training with Liquid Neural Network Policy...")
    print(f"Target Timesteps: {args.timesteps}")
    print(f"Configuration: {num_cpu} Envs, {model.n_steps} Steps, {model.batch_size} Batch Size")

    try:
        model.learn(total_timesteps=args.timesteps, callback=checkpoint_callback)
    finally:
        # BUGFIX: always shut down the SubprocVecEnv workers, even when
        # training is interrupted — otherwise the subprocesses leak.
        env.close()

    os.makedirs("models", exist_ok=True)  # ensure the save directory exists
    model.save("models/liquid_ppo_3d_final")
    print("Model saved to models/liquid_ppo_3d_final.zip")
if __name__ == "__main__":
    # Entry-point guard is mandatory here: SubprocVecEnv spawns worker
    # processes that re-import this module, and the guard prevents them
    # from recursively launching training.
    main()