Huy0502 committed
Commit 1f1857d · verified · 1 Parent(s): d73e6af

Create stock_embedder.py

Files changed (1)
  1. Models/stock_embedder.py +189 -0
Models/stock_embedder.py ADDED
@@ -0,0 +1,189 @@
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from einops import rearrange
import os
import pickle

from stock_embedder.modules.utils import *


class Encoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        # batch_first=True: inputs/outputs are (bs, seq_len, features),
        # matching the shape comments used throughout this file.
        self.rnn = nn.RNN(input_size=config['z_dim'],
                          hidden_size=config['hidden_dim'],
                          num_layers=config['num_layer'],
                          batch_first=True)
        self.fc = nn.Linear(in_features=config['hidden_dim'],
                            out_features=config['hidden_dim'])

    def forward(self, x):
        x_enc, _ = self.rnn(x)
        x_enc = self.fc(x_enc)
        return x_enc


class Decoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.rnn = nn.RNN(input_size=config['hidden_dim'],
                          hidden_size=config['hidden_dim'],
                          num_layers=config['num_layer'],
                          batch_first=True)
        self.fc = nn.Linear(in_features=config['hidden_dim'],
                            out_features=config['z_dim'])

    def forward(self, x_enc):
        x_dec, _ = self.rnn(x_enc)
        x_dec = self.fc(x_dec)
        return x_dec


class Interpolator(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Interpolates from the visible sequence length back to the full
        # series length, then mixes along the feature dimension.
        self.sequence_inter = nn.Linear(in_features=(config['ts_size'] - config['total_mask_size']),
                                        out_features=config['ts_size'])
        self.feature_inter = nn.Linear(in_features=config['hidden_dim'],
                                       out_features=config['hidden_dim'])

    def forward(self, x):
        # x: (bs, vis_size, hidden_dim)
        x = rearrange(x, 'b l f -> b f l')  # (bs, hidden_dim, vis_size)
        x = self.sequence_inter(x)          # (bs, hidden_dim, ts_size)
        x = rearrange(x, 'b f l -> b l f')  # (bs, ts_size, hidden_dim)
        x = self.feature_inter(x)           # (bs, ts_size, hidden_dim)
        return x


class StockEmbedder(nn.Module):
    def __init__(self, cfg: dict = None) -> None:
        """
        Args:
            cfg (dict): {
                'ts_size': 24,
                'mask_size': 1,
                'num_masks': 3,
                'hidden_dim': 12,
                'embed_dim': 6,
                'num_layer': 3,
                'z_dim': 6,
                'num_embed': 32,
                'stock_features': [],
                'min_val': 0,
                'max_val': 1e6
            }
        """
        super().__init__()

        self.config = cfg
        self.config['total_mask_size'] = self.config['num_masks'] * self.config['mask_size']

        self.encoder = Encoder(config=self.config)
        self.interpolator = Interpolator(config=self.config)
        self.decoder = Decoder(config=self.config)

        print('StockEmbedder initialized')

    def mask_it(self,
                x: torch.Tensor,
                masks: torch.Tensor):
        # x: (bs, ts_size, z_dim); masks: (bs, ts_size) with 1 marking a
        # masked time step. Keeps only the visible steps; the reshape
        # requires every batch row to mask the same number of steps.
        b, l, f = x.shape
        x_visible = x[~masks.bool(), :].reshape(b, -1, f)  # (bs, vis_size, z_dim)
        return x_visible

    def forward_ae(self, x: torch.Tensor):
        """Autoencoder mode (equivalent to MAE with a pseudo mask).
        There is no interpolator in this mode.

        Args:
            x (torch.Tensor): shape (bs, ts_size, z_dim)
        """
        out_encoder = self.encoder(x)
        out_decoder = self.decoder(out_encoder)
        return out_encoder, out_decoder

    def forward_mae(self,
                    x: torch.Tensor,
                    masks: torch.Tensor):
        """No mask tokens; interpolation happens in the latent space.

        Args:
            x (torch.Tensor): shape (bs, ts_size, z_dim)
            masks (torch.Tensor): shape (bs, ts_size); 1 marks a masked time step
        """
        x_vis = self.mask_it(x, masks=masks)               # (bs, vis_size, z_dim)
        out_encoder = self.encoder(x_vis)                  # (bs, vis_size, hidden_dim)
        out_interpolator = self.interpolator(out_encoder)  # (bs, ts_size, hidden_dim)
        out_decoder = self.decoder(out_interpolator)       # (bs, ts_size, z_dim)
        return out_encoder, out_interpolator, out_decoder

    def forward(self,
                x: torch.Tensor,
                masks: torch.Tensor = None,
                mode: str = 'ae | mae'):

        x = torch.as_tensor(x, dtype=torch.float32)
        if masks is not None:
            masks = torch.as_tensor(masks, dtype=torch.float32)

        if mode == 'ae':
            out_encoder, out_decoder = self.forward_ae(x)
            return out_encoder, out_decoder
        elif mode == 'mae':
            out_encoder, out_interpolator, out_decoder = self.forward_mae(x, masks=masks)
            return out_encoder, out_interpolator, out_decoder
        else:
            raise ValueError("mode must be 'ae' or 'mae'")

    def get_embedding(self,
                      stock_data: torch.Tensor,
                      embedding_used: str = 'encoder | decoder'):
        """Get the stock embedding.

        Args:
            stock_data (torch.Tensor): shape (batch_size, stock_days, stock_features); must be NORMALIZED
            embedding_used (str): 'encoder' or 'decoder'
        """
        with torch.no_grad():
            out_encoder, out_decoder = self.forward(stock_data, masks=None, mode='ae')

        if embedding_used == 'encoder':
            stock_embedding = out_encoder
        elif embedding_used == 'decoder':
            stock_embedding = out_decoder
        else:
            raise ValueError("embedding_used must be 'encoder' or 'decoder'")

        return stock_embedding

    def save(self, model_dir: str):
        os.makedirs(model_dir, exist_ok=True)

        # Save model weights:
        torch.save(obj=self.state_dict(), f=os.path.join(model_dir, 'model.pth'))
        # Save config:
        with open(file=os.path.join(model_dir, 'config.pkl'), mode='wb') as f:
            pickle.dump(obj=self.config, file=f)
+ pickle.dump(obj=self.config, file=f)