# DeepFin / agents / portfolio_features_extractor_torch.py
# Uploaded by amos-fernandes ("Upload 151 files", commit cb9259f).
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from deep_portfolio_torch import DeepPortfolioAgentNetworkTorch
class PortfolioFeaturesExtractorTorch(BaseFeaturesExtractor):
    """Stable-Baselines3 features extractor backed by ``DeepPortfolioAgentNetworkTorch``.

    The extractor owns a single ``DeepPortfolioAgentNetworkTorch`` instance
    (``self.network``), built from the hyper-parameters below. ``forward``
    simply runs the observations through that network.

    Args:
        observation_space: The environment's observation space. It is handed
            to ``BaseFeaturesExtractor`` only; the network does not see it.
        features_dim: Width of the feature vector this extractor reports to
            SB3 through ``BaseFeaturesExtractor``. It is NOT forwarded to, or
            derived from, the network, so it must be kept equal to the
            network's real output width by hand.
            NOTE(review): with ``output_latent_features=True`` this is
            presumably ``final_dense_units2`` (both default to 32). Confirm
            against ``DeepPortfolioAgentNetworkTorch``.
        num_assets, sequence_length, num_features_per_asset,
        asset_cnn_filters1, asset_cnn_filters2, asset_lstm_units1,
        asset_lstm_units2, final_dense_units1, final_dense_units2,
        final_dropout, mha_num_heads, mha_key_dim_divisor,
        output_latent_features, use_sentiment_analysis:
            Forwarded verbatim, as keyword arguments, to
            ``DeepPortfolioAgentNetworkTorch``. See that class for their
            meaning.
    """

    def __init__(
        self,
        observation_space,
        features_dim: int = 32,
        num_assets: int = 4,
        sequence_length: int = 60,
        num_features_per_asset: int = 26,
        asset_cnn_filters1: int = 32,
        asset_cnn_filters2: int = 64,
        asset_lstm_units1: int = 64,
        asset_lstm_units2: int = 32,
        final_dense_units1: int = 128,
        final_dense_units2: int = 32,
        final_dropout: float = 0.3,
        mha_num_heads: int = 4,
        mha_key_dim_divisor: int = 2,
        output_latent_features: bool = True,
        use_sentiment_analysis: bool = False,
    ) -> None:
        # SB3's BaseFeaturesExtractor.__init__ already records features_dim
        # (as ``_features_dim``), so it is not re-assigned here.
        super().__init__(observation_space, features_dim)
        self.network = DeepPortfolioAgentNetworkTorch(
            num_assets=num_assets,
            sequence_length=sequence_length,
            num_features_per_asset=num_features_per_asset,
            asset_cnn_filters1=asset_cnn_filters1,
            asset_cnn_filters2=asset_cnn_filters2,
            asset_lstm_units1=asset_lstm_units1,
            asset_lstm_units2=asset_lstm_units2,
            final_dense_units1=final_dense_units1,
            final_dense_units2=final_dense_units2,
            final_dropout=final_dropout,
            mha_num_heads=mha_num_heads,
            mha_key_dim_divisor=mha_key_dim_divisor,
            output_latent_features=output_latent_features,
            use_sentiment_analysis=use_sentiment_analysis,
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Return ``self.network``'s output for a batch of observations.

        NOTE(review): the original author documents the input as
        ``(batch, seq_len, num_assets * num_features_per_asset)``. That
        layout is not checked here; confirm it against the environment's
        observation space.
        """
        return self.network(observations)