| import torch |
| import torch.nn as nn |
| from stable_baselines3.common.torch_layers import BaseFeaturesExtractor |
| from deep_portfolio_torch import DeepPortfolioAgentNetworkTorch |
|
|
class PortfolioFeaturesExtractorTorch(BaseFeaturesExtractor):
    """Stable-Baselines3 features extractor wrapping ``DeepPortfolioAgentNetworkTorch``.

    Every architecture hyper-parameter is forwarded unchanged to the wrapped
    network. This class only adapts that network to SB3's
    ``BaseFeaturesExtractor`` interface.
    """

    def __init__(
        self,
        observation_space,
        features_dim: int = 32,
        num_assets: int = 4,
        sequence_length: int = 60,
        num_features_per_asset: int = 26,
        asset_cnn_filters1: int = 32,
        asset_cnn_filters2: int = 64,
        asset_lstm_units1: int = 64,
        asset_lstm_units2: int = 32,
        final_dense_units1: int = 128,
        final_dense_units2: int = 32,
        final_dropout: float = 0.3,
        mha_num_heads: int = 4,
        mha_key_dim_divisor: int = 2,
        output_latent_features: bool = True,
        use_sentiment_analysis: bool = False,
    ):
        """Build the wrapped portfolio network.

        Args:
            observation_space: Observation space supplied by SB3; passed to
                ``BaseFeaturesExtractor`` as-is.
            features_dim: Output width SB3 will assume for this extractor.
                It is stored by the base class and exposed as
                ``self.features_dim``.
            num_assets, sequence_length, num_features_per_asset,
            asset_cnn_filters1, asset_cnn_filters2, asset_lstm_units1,
            asset_lstm_units2, final_dense_units1, final_dense_units2,
            final_dropout, mha_num_heads, mha_key_dim_divisor,
            output_latent_features, use_sentiment_analysis:
                Forwarded verbatim to ``DeepPortfolioAgentNetworkTorch``.
        """
        # The base class records features_dim in self._features_dim, so no
        # separate assignment is needed here.
        super().__init__(observation_space, features_dim)

        # NOTE(review): nothing here checks that the network's output width
        # equals features_dim. With output_latent_features=True it is
        # presumably final_dense_units2 (32 by default, matching the default
        # features_dim=32). Confirm against DeepPortfolioAgentNetworkTorch
        # before changing either default independently.
        self.network = DeepPortfolioAgentNetworkTorch(
            num_assets=num_assets,
            sequence_length=sequence_length,
            num_features_per_asset=num_features_per_asset,
            asset_cnn_filters1=asset_cnn_filters1,
            asset_cnn_filters2=asset_cnn_filters2,
            asset_lstm_units1=asset_lstm_units1,
            asset_lstm_units2=asset_lstm_units2,
            final_dense_units1=final_dense_units1,
            final_dense_units2=final_dense_units2,
            final_dropout=final_dropout,
            mha_num_heads=mha_num_heads,
            mha_key_dim_divisor=mha_key_dim_divisor,
            output_latent_features=output_latent_features,
            use_sentiment_analysis=use_sentiment_analysis,
        )

    def forward(self, observations):
        """Return the wrapped network's output for ``observations``.

        The observations are passed through untouched. Their expected
        shape/structure is defined by ``DeepPortfolioAgentNetworkTorch``
        (not visible here).
        """
        return self.network(observations)
|
|