File size: 3,058 Bytes
28dbd6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import sys
import argparse
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
import multiprocessing
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
from stable_baselines3.common.callbacks import CheckpointCallback

# Add project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from env_3d import Drone3DEnv
from models.liquid_policy import LiquidFeatureExtractor

def make_env(rank, seed=0):
    """Build a thunk that creates and seeds a fresh ``Drone3DEnv``.

    ``SubprocVecEnv`` invokes the returned callable inside each worker
    process; seeding with ``seed + rank`` gives every worker a distinct
    random stream.
    """
    def _thunk():
        instance = Drone3DEnv()
        instance.reset(seed=seed + rank)
        return instance
    return _thunk

def main():
    """Train a PPO agent with a Liquid-NN feature extractor on Drone3DEnv.

    Parses ``--timesteps`` and ``--seed``, builds a pool of subprocess
    environments, trains PPO with periodic checkpoints, and saves the
    final model to ``models/liquid_ppo_3d_final``.
    """
    # Verify CUDA up front so a slow CPU-only run is flagged immediately.
    import torch
    print(f"Is CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    else:
        print("WARNING: CUDA is NOT available. Training will be slow.")

    parser = argparse.ArgumentParser()
    parser.add_argument("--timesteps", type=int, default=8_000_000) # 8M steps as requested
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Create Vectorized Env
    # 8x L40S has ~192 vCPUs. Using 64-96 is usually the sweet spot for PPO.
    # Too many envs can cause overhead or instability.
    max_cpu = 64
    available_cpu = multiprocessing.cpu_count()
    num_cpu = min(max_cpu, available_cpu)
    print(f"Using {num_cpu} CPUs for parallel environments (Available: {available_cpu}).")

    # SubprocVecEnv for true parallelism across worker processes.
    env = SubprocVecEnv([make_env(i, args.seed) for i in range(num_cpu)])
    env = VecMonitor(env) # Episode-stat logging on the vectorized env

    # Liquid Policy Config: custom feature extractor feeding large MLP heads.
    policy_kwargs = dict(
        features_extractor_class=LiquidFeatureExtractor,
        features_extractor_kwargs=dict(features_dim=128, hidden_size=128, dt=0.05), # Large Capacity
        net_arch=dict(pi=[256, 256], vf=[256, 256]) # Large Capacity
    )

    # Initialize PPO. Fix: pass `seed` to the learner as well — previously
    # args.seed only seeded the envs, leaving network init/sampling unseeded.
    model = PPO(
        "MlpPolicy",
        env,
        policy_kwargs=policy_kwargs,
        verbose=1,
        learning_rate=3e-4,
        n_steps=8192, # Large horizon for stability
        batch_size=2048, # Large batch for powerful GPU
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        tensorboard_log="./logs/ppo_liquid_3d/",
        seed=args.seed,
        device="cuda"
    )

    # Checkpoint roughly every 500k total env steps. CheckpointCallback counts
    # save_freq in per-env calls, hence the division; max(1, ...) guards
    # against a zero frequency if num_cpu is ever very large.
    checkpoint_callback = CheckpointCallback(
        save_freq=max(1, 500_000 // num_cpu),
        save_path="./models/checkpoints/",
        name_prefix="liquid_ppo_3d"
    )

    print("Starting Training with Liquid Neural Network Policy...")
    print(f"Target Timesteps: {args.timesteps}")
    print(f"Configuration: {num_cpu} Envs, {model.n_steps} Steps, {model.batch_size} Batch Size")
    model.learn(total_timesteps=args.timesteps, callback=checkpoint_callback)

    model.save("models/liquid_ppo_3d_final")
    print("Model saved to models/liquid_ppo_3d_final.zip")

# Entry-point guard: keeps training from re-running when this module is
# imported (e.g. by the subprocess env workers, which re-import the module).
if __name__ == "__main__":
    main()