# NOTE(review): the three lines below were page/scrape residue ("Spaces:
# Sleeping Sleeping"), not Python source — commented out so the module parses.
# Spaces: Sleeping Sleeping
| import torch | |
| from torch import nn | |
| from torch import optim | |
| from torch import functional as F | |
| from einops import rearrange | |
| import os | |
| import pickle | |
| #from modules.utils import * | |
| from .utils import * | |
class Encoder(nn.Module):
    """RNN encoder: maps (bs, seq_len, z_dim) -> (bs, seq_len, hidden_dim)."""

    def __init__(self, config):
        """
        Args:
            config (dict): must provide 'z_dim', 'hidden_dim', 'num_layer'.
        """
        super().__init__()
        # batch_first=True so the recurrence runs over the time axis of the
        # (bs, seq_len, z_dim) inputs documented throughout this module; the
        # nn.RNN default (seq-first) would silently recur across the batch
        # dimension instead (output shapes happen to match either way).
        self.rnn = nn.RNN(input_size=config['z_dim'],
                          hidden_size=config['hidden_dim'],
                          num_layers=config['num_layer'],
                          batch_first=True)
        self.fc = nn.Linear(in_features=config['hidden_dim'],
                            out_features=config['hidden_dim'])

    def forward(self, x):
        """Encode a batch-first sequence.

        Args:
            x (torch.Tensor): shape (bs, seq_len, z_dim).

        Returns:
            torch.Tensor: shape (bs, seq_len, hidden_dim).
        """
        x_enc, _ = self.rnn(x)
        x_enc = self.fc(x_enc)
        return x_enc
class Decoder(nn.Module):
    """RNN decoder: maps (bs, seq_len, hidden_dim) -> (bs, seq_len, z_dim)."""

    def __init__(self, config):
        """
        Args:
            config (dict): must provide 'hidden_dim', 'num_layer', 'z_dim'.
        """
        super().__init__()
        # batch_first=True matches the (bs, seq_len, hidden_dim) layout the
        # callers pass in; the nn.RNN default would run the recurrence over
        # the batch dimension (shapes coincide, semantics do not).
        self.rnn = nn.RNN(input_size=config['hidden_dim'],
                          hidden_size=config['hidden_dim'],
                          num_layers=config['num_layer'],
                          batch_first=True)
        self.fc = nn.Linear(in_features=config['hidden_dim'],
                            out_features=config['z_dim'])

    def forward(self, x_enc):
        """Decode a batch-first latent sequence.

        Args:
            x_enc (torch.Tensor): shape (bs, seq_len, hidden_dim).

        Returns:
            torch.Tensor: shape (bs, seq_len, z_dim).
        """
        x_dec, _ = self.rnn(x_enc)
        x_dec = self.fc(x_dec)
        return x_dec
class Interpolator(nn.Module):
    """Expand a visible-length latent sequence back to the full series length.

    One Linear layer interpolates along the time axis (vis_size -> ts_size),
    a second mixes along the feature axis (hidden_dim -> hidden_dim).
    """

    def __init__(self, config):
        """
        Args:
            config (dict): must provide 'ts_size', 'total_mask_size',
                'hidden_dim'.
        """
        super().__init__()
        visible_len = config['ts_size'] - config['total_mask_size']
        self.sequence_inter = nn.Linear(in_features=visible_len,
                                        out_features=config['ts_size'])
        self.feature_inter = nn.Linear(in_features=config['hidden_dim'],
                                       out_features=config['hidden_dim'])

    def forward(self, x):
        """(bs, vis_size, hidden_dim) -> (bs, ts_size, hidden_dim)."""
        # Move the time axis last so sequence_inter interpolates over it.
        stretched = self.sequence_inter(x.transpose(1, 2))
        # Restore (batch, time, feature) layout, then mix the features.
        return self.feature_inter(stretched.transpose(1, 2))
class StockEmbedder(nn.Module):
    """Masked-autoencoder-style embedder for stock time series.

    Pipeline: Encoder -> (optional Interpolator) -> Decoder.  In 'ae' mode
    the full series is auto-encoded; in 'mae' mode masked timesteps are
    dropped before encoding and the Interpolator reconstructs the full
    sequence length in latent space.
    """

    def __init__(self, cfg: dict = None) -> None:
        """
        Args:
            cfg (dict): required (passing None fails); e.g. {
                'ts_size': 24,
                'mask_size': 1,
                'num_masks': 3,
                'hidden_dim': 12,
                'embed_dim': 6,
                'num_layer': 3,
                'z_dim': 6,
                'num_embed': 32,
                'stock_features': [],
                'min_val': 0,
                'max_val': 1e6
            }
        """
        super().__init__()
        self.config = cfg
        # Total number of masked timesteps; consumed by the Interpolator to
        # size its time-axis Linear layer.
        self.config['total_mask_size'] = self.config['num_masks'] * self.config['mask_size']
        self.encoder = Encoder(config=self.config)
        self.interpolator = Interpolator(config=self.config)
        self.decoder = Decoder(config=self.config)
        print('StockEmbedder initialized')

    def mask_it(self,
                x: torch.Tensor,
                masks: torch.Tensor) -> torch.Tensor:
        """Drop the masked timesteps from a batch of sequences.

        Args:
            x (torch.Tensor): shape (bs, ts_size, z_dim).
            masks (torch.Tensor): shape (bs, ts_size); nonzero marks a masked
                step.  Every sample must mask the same number of steps so the
                result reshapes cleanly to (bs, vis_size, z_dim).

        Returns:
            torch.Tensor: visible timesteps only, shape (bs, vis_size, z_dim).
        """
        b, l, f = x.shape
        x_visible = x[~masks.bool(), :].reshape(b, -1, f)  # (bs, vis_size, z_dim)
        return x_visible

    def forward_ae(self, x: torch.Tensor):
        """Plain autoencoder pass (no masking, no interpolator).

        Args:
            x (torch.Tensor): shape (bs, ts_size, z_dim).

        Returns:
            tuple: (out_encoder, out_decoder).
        """
        out_encoder = self.encoder(x)
        out_decoder = self.decoder(out_encoder)
        return out_encoder, out_decoder

    def forward_mae(self,
                    x: torch.Tensor,
                    masks: torch.Tensor):
        """Masked pass: no mask tokens, interpolation happens in latent space.

        Args:
            x (torch.Tensor): shape (bs, ts_size, z_dim).
            masks (torch.Tensor): shape (bs, ts_size), see `mask_it`.

        Returns:
            tuple: (out_encoder, out_interpolator, out_decoder).
        """
        x_vis = self.mask_it(x, masks=masks)               # (bs, vis_size, z_dim)
        out_encoder = self.encoder(x_vis)                  # (bs, vis_size, hidden_dim)
        out_interpolator = self.interpolator(out_encoder)  # (bs, ts_size, hidden_dim)
        out_decoder = self.decoder(out_interpolator)       # (bs, ts_size, z_dim)
        return out_encoder, out_interpolator, out_decoder

    def forward(self,
                x: torch.Tensor,
                masks: torch.Tensor = None,
                mode: str = 'ae | mae'):
        """Dispatch to `forward_ae` ('ae') or `forward_mae` ('mae').

        Raises:
            ValueError: if `mode` is neither 'ae' nor 'mae' (the placeholder
                default previously fell through and returned None).
        """
        # as_tensor avoids the copy + UserWarning that torch.tensor() raises
        # when x is already a tensor.
        x = torch.as_tensor(x, dtype=torch.float32)
        if masks is not None:
            masks = torch.as_tensor(masks, dtype=torch.float32)
        if mode == 'ae':
            out_encoder, out_decoder = self.forward_ae(x)
            return out_encoder, out_decoder
        elif mode == 'mae':
            out_encoder, out_interpolator, out_decoder = self.forward_mae(x, masks=masks)
            return out_encoder, out_interpolator, out_decoder
        raise ValueError(f"unknown mode: {mode!r} (expected 'ae' or 'mae')")

    def get_embedding(self,
                      stock_data: torch.Tensor,
                      embedding_used: str = 'encoder | decoder'):
        """Return a stock embedding from a no-grad autoencoder pass.

        Args:
            stock_data (torch.Tensor): shape (batch_size, stock_days,
                stock_features); NORMALIZED.
            embedding_used (str): 'encoder' or 'decoder'.

        Raises:
            ValueError: for any other value (previously an unbound-variable
                NameError).
        """
        with torch.no_grad():
            out_encoder, out_decoder = self.forward(stock_data, masks=None, mode='ae')
        if embedding_used == 'encoder':
            return out_encoder
        elif embedding_used == 'decoder':
            return out_decoder
        raise ValueError(
            f"unknown embedding_used: {embedding_used!r} (expected 'encoder' or 'decoder')")

    def save(self, model_dir: str):
        """Persist weights (model.pth) and config (config.pkl) to model_dir."""
        os.makedirs(model_dir, exist_ok=True)
        # Save model:
        torch.save(obj=self.state_dict(), f=os.path.join(model_dir, 'model.pth'))
        # Save config:
        with open(file=os.path.join(model_dir, 'config.pkl'), mode='wb') as f:
            pickle.dump(obj=self.config, file=f)