| |
|
|
| import torch |
| import torch.nn as nn |
| from gymnasium import spaces |
| from stable_baselines3.common.torch_layers import BaseFeaturesExtractor |
|
|
class TransformerFeatureExtractor(BaseFeaturesExtractor):
    """
    Feature extractor that encodes a window of time steps with a Transformer encoder.

    Each observation holds ``window_size * n_features_per_step`` values. It is
    reshaped into a sequence of ``window_size`` steps with
    ``n_features_per_step`` features each (row-major, i.e. all features of
    step 0 first, then step 1, ...), projected to ``d_model``, offset by a
    learned positional encoding, passed through a Transformer encoder, and
    finally flattened and mapped to ``features_dim`` outputs.

    Args:
        observation_space: Box space whose total number of elements per
            observation equals ``window_size * n_features_per_step``.
        features_dim: Size of the feature vector returned by ``forward``.
        n_features_per_step: Number of features in each time step.
        window_size: Number of time steps in each observation.
        d_model: Embedding width used inside the Transformer encoder.
        n_head: Number of attention heads in every encoder layer.
        n_layers: Number of stacked encoder layers.
        dropout: Dropout probability passed to the encoder layers.

    Raises:
        ValueError: If the observation space does not contain exactly
            ``window_size * n_features_per_step`` elements.
    """

    def __init__(
        self,
        observation_space: spaces.Box,
        features_dim: int = 256,
        n_features_per_step: int = 8,
        window_size: int = 30,
        d_model: int = 64,
        n_head: int = 4,
        n_layers: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__(observation_space, features_dim)

        self.window_size = window_size
        self.n_features_per_step = n_features_per_step

        # Validate the TOTAL number of elements per observation, not just the
        # first axis. forward() reshapes with a free (-1) batch dimension, so
        # an observation holding more or fewer elements than one full window
        # would otherwise be silently split or merged across batch entries.
        flat_dim = 1
        for axis_len in observation_space.shape:
            flat_dim *= int(axis_len)

        expected_flat_dim = window_size * n_features_per_step
        if flat_dim != expected_flat_dim:
            raise ValueError(
                f"Observation space flat dimension {flat_dim} "
                f"does not match expected {expected_flat_dim} "
                f"(window_size={window_size}, n_features_per_step={n_features_per_step})."
            )

        # Per-step embedding: n_features_per_step -> d_model.
        self.input_projection = nn.Linear(n_features_per_step, d_model)

        # Learned positional encoding with one d_model vector per time step;
        # the leading singleton axis broadcasts over the batch.
        self.positional_encoding = nn.Parameter(torch.randn(1, window_size, d_model))

        # batch_first=True so the encoder consumes (batch, sequence, d_model).
        # All other encoder-layer options are left at their PyTorch defaults.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_head,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # Output head: flatten every encoded step, then map to features_dim.
        self.flatten = nn.Flatten()
        self.linear_out = nn.Linear(window_size * d_model, features_dim)
        self.relu = nn.ReLU()

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of observations into feature vectors.

        Args:
            observations: Tensor whose element count is a multiple of
                ``window_size * n_features_per_step``; the batch dimension is
                inferred during the reshape.

        Returns:
            Tensor of shape ``(batch, features_dim)``, passed through a ReLU.
        """
        # (batch, ...) -> (batch, window_size, n_features_per_step)
        x = observations.reshape(-1, self.window_size, self.n_features_per_step)

        # -> (batch, window_size, d_model)
        x = self.input_projection(x)

        # Add the learned per-step offsets (broadcast over the batch axis).
        x = x + self.positional_encoding

        x = self.transformer_encoder(x)

        # (batch, window_size, d_model) -> (batch, window_size * d_model)
        x = self.flatten(x)
        x = self.relu(self.linear_out(x))

        return x