import torch
import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
from ncps.torch import LTC
from ncps.wirings import AutoNCP

from env.drone_3d import Drone3DEnv

class LTCFeatureExtractor(BaseFeaturesExtractor):
    """
    Custom feature extractor using Liquid Time-Constant (LTC) cells.

    LTC cells model continuous-time dynamics, letting the agent handle
    irregular time steps and stiff dynamics better than standard MLPs or LSTMs.
    """
    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 32):
        super().__init__(observation_space, features_dim)
        input_size = observation_space.shape[0]
        # self.features_dim is already set by super().__init__().

        # Neural Circuit Policy (NCP) wiring gives the LTC structured, sparse
        # connectivity. A small wiring keeps inference fast (< 10 ms).
        # AutoNCP requires units > output_size: here, 48 units for 32 outputs.
        wiring = AutoNCP(48, output_size=features_dim)
        self.ltc = LTC(input_size, wiring, batch_first=True)

        # Hidden state of the LTC, carried across forward calls.
        self.hx = None
    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        # The LTC expects (batch, time, features); SB3 provides (batch, features),
        # so we add a singleton time dimension.
        if observations.dim() == 2:
            observations = observations.unsqueeze(1)

        # (Re)initialize the hidden state when it is missing or the batch size
        # changes (e.g. between rollout collection and minibatch updates).
        batch_size = observations.size(0)
        if self.hx is None or self.hx.size(0) != batch_size:
            self.hx = torch.zeros(
                batch_size, self.ltc.state_size, device=observations.device
            )
        # Caveat: standard PPO assumes a stateless policy, so carrying self.hx
        # between calls is a demo-level simplification: episode resets cannot be
        # detected here, so the hidden state is not aligned with episode
        # boundaries and simply evolves across calls. For production,
        # RecurrentPPO from sb3-contrib manages hidden states properly (see
        # make_recurrent_ppo below).
        # Detach the hidden state from the previous computation graph to avoid
        # "trying to backward through the graph a second time" errors.
        self.hx = self.hx.detach()
        output, self.hx = self.ltc(observations, self.hx)

        # Remove the time dimension.
        return output.squeeze(1)
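
# A minimal alternative sketch, assuming the optional sb3-contrib package is
# installed: RecurrentPPO manages hidden states across rollouts and episode
# boundaries, which the stateless-PPO workaround in LTCFeatureExtractor.forward
# only approximates. "MlpLstmPolicy" is sb3-contrib's built-in LSTM policy, not
# the LTC extractor above; the function name make_recurrent_ppo is ours.
def make_recurrent_ppo(env, verbose=1):
    from sb3_contrib import RecurrentPPO

    return RecurrentPPO("MlpLstmPolicy", env, verbose=verbose, device="auto")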

def make_liquid_ppo(env=None, verbose=1):
    """
    Factory function that creates a PPO agent with a liquid (LTC) feature extractor.

    If no environment is passed, a vectorized Drone3DEnv is built. GPUs such as
    the A100/A10G are data hungry, so the physics runs on several CPU workers
    (SubprocVecEnv) to keep them fed. Pass your own (e.g. DummyVecEnv-backed)
    environment to debug in a single process.
    """
    if env is None:
        n_envs = 4
        env = make_vec_env(
            lambda: Drone3DEnv(render_mode=None, wind_scale=10.0, wind_speed=5.0),
            n_envs=n_envs,
            vec_env_cls=SubprocVecEnv,
        )
    policy_kwargs = dict(
        features_extractor_class=LTCFeatureExtractor,
        features_extractor_kwargs=dict(features_dim=32),
    )

    # PPO hyperparameters: largely the SB3 defaults, with a higher learning
    # rate (1e-3 instead of 3e-4) for faster convergence in this demo.
    model = PPO(
        "MlpPolicy",
        env,
        verbose=verbose,
        learning_rate=1e-3,
        n_steps=2048,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        policy_kwargs=policy_kwargs,
        device="cuda" if torch.cuda.is_available() else "cpu",  # GPU if available
    )
    return model
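
# A minimal usage sketch, assuming Drone3DEnv is importable as above; the
# timestep budget and save path are illustrative. The __main__ guard is
# required because SubprocVecEnv spawns worker processes.
if __name__ == "__main__":
    # Quick shape check of the wiring: 48 LTC units projecting to 32 features
    # (the input dimension 12 is a hypothetical observation size).
    demo_ltc = LTC(12, AutoNCP(48, output_size=32), batch_first=True)
    out, _ = demo_ltc(torch.zeros(2, 1, 12))
    assert out.shape == (2, 1, 32)

    model = make_liquid_ppo()
    model.learn(total_timesteps=10_000)
    model.save("ppo_liquid_drone")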