# custom_policy.py
import torch
import torch.nn as nn
from gymnasium import spaces
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class TransformerFeatureExtractor(BaseFeaturesExtractor):
"""
A custom feature extractor that uses a Transformer Encoder.
It takes a flattened observation (window_size * n_features_per_step) and processes
it as a sequence.
"""
def __init__(
self,
observation_space: spaces.Box,
features_dim: int = 256, # The final output dimension
        n_features_per_step: int = 8,   # must match the env's per-step features (here: 5 asset + 3 macro)
window_size: int = 30,
d_model: int = 64, # Transformer's internal embedding dimension
n_head: int = 4, # Number of attention heads
n_layers: int = 2, # Number of transformer encoder layers
dropout: float = 0.1
):
super().__init__(observation_space, features_dim)
self.window_size = window_size
self.n_features_per_step = n_features_per_step
# Input shape check
expected_flat_dim = window_size * n_features_per_step
if observation_space.shape[0] != expected_flat_dim:
raise ValueError(
f"Observation space flat dimension {observation_space.shape[0]} "
f"does not match expected {expected_flat_dim} "
f"(window_size={window_size}, n_features_per_step={n_features_per_step})."
)
        # 1. Input projection: map each timestep's raw features into d_model.
self.input_projection = nn.Linear(n_features_per_step, d_model)
        # 2. Positional encoding: a learned (1, window_size, d_model) parameter,
        #    broadcast-added to every batch element.
self.positional_encoding = nn.Parameter(torch.randn(1, window_size, d_model))
        # 3. Transformer encoder: n_layers of self-attention over the window.
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=n_head,
dropout=dropout,
batch_first=True
)
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        # 4. Output head: flatten the encoded sequence and project to features_dim.
self.flatten = nn.Flatten()
self.linear_out = nn.Linear(window_size * d_model, features_dim)
self.relu = nn.ReLU()

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
# Input shape: (batch_size, window_size * n_features_per_step)
# 1. Reshape to (batch_size, window_size, n_features_per_step)
x = observations.reshape(-1, self.window_size, self.n_features_per_step)
# 2. Project input features to d_model
x = self.input_projection(x)
# 3. Add positional encoding
x = x + self.positional_encoding
# 4. Pass through Transformer
x = self.transformer_encoder(x)
# 5. Flatten and project to final output
x = self.flatten(x)
x = self.relu(self.linear_out(x))
return x
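
# ------------------------------------------------------------------
# Quick smoke test: a minimal sketch (not part of the original file)
# verifying that the extractor maps a batch of flat observations to
# (batch, features_dim). Shapes assume the defaults window_size=30
# and n_features_per_step=8.
if __name__ == "__main__":
    import numpy as np

    window_size, n_features = 30, 8
    obs_space = spaces.Box(
        low=-np.inf,
        high=np.inf,
        shape=(window_size * n_features,),
        dtype=np.float32,
    )
    extractor = TransformerFeatureExtractor(
        obs_space,
        features_dim=256,
        n_features_per_step=n_features,
        window_size=window_size,
    )
    dummy_obs = torch.zeros(4, window_size * n_features)  # batch of 4 flat observations
    print(extractor(dummy_obs).shape)  # expected: torch.Size([4, 256])

# To plug the extractor into a stable-baselines3 algorithm, pass it through
# policy_kwargs. A hedged usage sketch; "env" stands in for any env whose flat
# observation has length window_size * n_features_per_step:
#
#   from stable_baselines3 import PPO
#
#   policy_kwargs = dict(
#       features_extractor_class=TransformerFeatureExtractor,
#       features_extractor_kwargs=dict(
#           features_dim=256,
#           n_features_per_step=8,
#           window_size=30,
#       ),
#   )
#   model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)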