|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from gymnasium import spaces |
|
|
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor |
|
|
|
|
|
class TransformerFeatureExtractor(BaseFeaturesExtractor):
    """
    A custom feature extractor that uses a Transformer Encoder.

    It takes a flattened observation of shape ``(window_size * n_features_per_step,)``
    and processes it as a sequence of ``window_size`` steps with
    ``n_features_per_step`` features each.

    :param observation_space: 1-D Box space holding the flattened window.
    :param features_dim: Size of the extracted feature vector.
    :param n_features_per_step: Number of features at each time step.
    :param window_size: Number of time steps in the observation window.
    :param d_model: Transformer embedding dimension.
    :param n_head: Number of attention heads (must divide ``d_model``).
    :param n_layers: Number of stacked encoder layers.
    :param dropout: Dropout probability inside the encoder layers.
    :param dim_feedforward: Hidden size of the encoder feed-forward network
        (default matches PyTorch's ``TransformerEncoderLayer`` default).
    :raises ValueError: If the observation space is not 1-D or its flat
        dimension does not equal ``window_size * n_features_per_step``.
    """

    def __init__(
        self,
        observation_space: spaces.Box,
        features_dim: int = 256,
        n_features_per_step: int = 8,
        window_size: int = 30,
        d_model: int = 64,
        n_head: int = 4,
        n_layers: int = 2,
        dropout: float = 0.1,
        dim_feedforward: int = 2048,
    ):
        super().__init__(observation_space, features_dim)

        self.window_size = window_size
        self.n_features_per_step = n_features_per_step

        # Fail fast on mis-configured spaces.  Checking only shape[0] would let
        # a multi-dimensional space (e.g. (240, 5) with 240 == 30 * 8) slip
        # through and be silently mis-reshaped in forward(), so require a
        # genuinely flat (1-D) space first.
        if len(observation_space.shape) != 1:
            raise ValueError(
                f"Expected a 1-D (flattened) observation space, "
                f"got shape {observation_space.shape}."
            )

        expected_flat_dim = window_size * n_features_per_step
        if observation_space.shape[0] != expected_flat_dim:
            raise ValueError(
                f"Observation space flat dimension {observation_space.shape[0]} "
                f"does not match expected {expected_flat_dim} "
                f"(window_size={window_size}, n_features_per_step={n_features_per_step})."
            )

        # Project each raw time step into the transformer's embedding space.
        self.input_projection = nn.Linear(n_features_per_step, d_model)

        # Learned positional encoding; leading dim of 1 broadcasts over batch.
        self.positional_encoding = nn.Parameter(torch.randn(1, window_size, d_model))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_head,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # Flatten the encoded sequence and project to the requested feature size.
        self.flatten = nn.Flatten()
        self.linear_out = nn.Linear(window_size * d_model, features_dim)
        self.relu = nn.ReLU()

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """
        Extract features from a batch of flattened observations.

        :param observations: Tensor of shape
            ``(batch, window_size * n_features_per_step)``.
        :return: Tensor of shape ``(batch, features_dim)``.
        """
        # Un-flatten into a (batch, seq_len, features_per_step) sequence.
        x = observations.reshape(-1, self.window_size, self.n_features_per_step)

        # Embed each time step, then add the learned positional encoding.
        x = self.input_projection(x)
        x = x + self.positional_encoding

        # Contextualise the sequence with self-attention.
        x = self.transformer_encoder(x)

        # Flatten the sequence and project down to the final feature vector.
        x = self.flatten(x)
        x = self.relu(self.linear_out(x))

        return x