File size: 2,862 Bytes
349ad65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# custom_policy.py 

import torch
import torch.nn as nn
from gymnasium import spaces
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class TransformerFeatureExtractor(BaseFeaturesExtractor):
    """
    A custom feature extractor that uses a Transformer Encoder.

    It takes a flattened observation (window_size * n_features_per_step) and processes
    it as a sequence.
    """
    def __init__(
        self,
        observation_space: spaces.Box,
        features_dim: int = 256,  # The final output dimension
        n_features_per_step: int = 8,  # <--- CRITICAL CHANGE: Matches 5 assets + 3 macro
        window_size: int = 30,
        d_model: int = 64,       # Transformer's internal embedding dimension
        n_head: int = 4,         # Number of attention heads
        n_layers: int = 2,       # Number of transformer encoder layers
        dropout: float = 0.1
    ):

        super().__init__(observation_space, features_dim)

        self.window_size = window_size
        self.n_features_per_step = n_features_per_step

        # Input shape check
        expected_flat_dim = window_size * n_features_per_step
        if observation_space.shape[0] != expected_flat_dim:
            raise ValueError(
                f"Observation space flat dimension {observation_space.shape[0]} "
                f"does not match expected {expected_flat_dim} "
                f"(window_size={window_size}, n_features_per_step={n_features_per_step})."
            )

        # 1. Input Projection:
        self.input_projection = nn.Linear(n_features_per_step, d_model)

        # 2. Positional Encoding:
        self.positional_encoding = nn.Parameter(torch.randn(1, window_size, d_model))

        # 3. Transformer Encoder:
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_head,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # 4. Output Layers:
        self.flatten = nn.Flatten()
        self.linear_out = nn.Linear(window_size * d_model, features_dim)
        self.relu = nn.ReLU()

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        # Input shape: (batch_size, window_size * n_features_per_step)

        # 1. Reshape to (batch_size, window_size, n_features_per_step)
        x = observations.reshape(-1, self.window_size, self.n_features_per_step)

        # 2. Project input features to d_model
        x = self.input_projection(x)

        # 3. Add positional encoding
        x = x + self.positional_encoding

        # 4. Pass through Transformer
        x = self.transformer_encoder(x)

        # 5. Flatten and project to final output
        x = self.flatten(x)
        x = self.relu(self.linear_out(x))

        return x