Abner0803 commited on
Commit
011fa3c
·
verified ·
1 Parent(s): 0c2f092

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +268 -0
README.md ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Model structure required to load and use these checkpoints (Transformer)
2
+
3
+ ### Import used packages
4
+
5
+ ```python
6
+ import math
7
+
8
+ import torch
9
+ from torch import nn
10
+ ```
11
+
12
+ ### PositionalEncoding
13
+
14
+ ```python
15
class PositionalEncoding(nn.Module):
    """Inject fixed sinusoidal position information into a sequence.

    Implements the sin/cos encoding from "Attention Is All You Need" and
    applies dropout after adding it. Expects sequence-first input of shape
    (seq_len, batch, d_model).
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        freq = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freq)  # (max_len, ceil(d_model/2))
        # For odd d_model the cosine half has one column fewer, so trim freq.
        table[:, 1::2] = torch.cos(positions * freq[: d_model // 2])
        # Stored as (max_len, 1, d_model) so it broadcasts over the batch dim.
        self.register_buffer("pe", table.unsqueeze(0).transpose(0, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add encodings for the first x.size(0) positions, then apply dropout."""
        return self.dropout(x + self.pe[: x.size(0), :, :])
36
+ ```
37
+
38
+ ### Relative Position Bias (RPB) Class
39
+
40
+ ```python
41
class RelativePositionBiasV2(nn.Module):
    """T5-style relative position bias.

    Buckets the signed key-query offset and maps each bucket to a learned
    per-head scalar, returned as an additive attention-bias tensor.
    """

    def __init__(self, n_heads, num_buckets=32, max_distance=128, bidirectional=True):
        super().__init__()
        assert num_buckets % 2 == 0, "num_buckets should be even for bidirectional"
        self.n_heads = n_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.bidirectional = bidirectional
        # One learned bias value per (bucket, head) pair.
        self.emb = nn.Embedding(num_buckets, n_heads)

    def _relative_position_bucket(self, relative_position):
        """
        relative_position: [Tq, Tk] = k - q
        returns bucket ids in [0, num_buckets-1]
        """
        buckets_left = self.num_buckets
        bucket = torch.zeros_like(relative_position, dtype=torch.long)
        # Positive dist means the key precedes the query.
        dist = -relative_position

        if self.bidirectional:
            # Upper half of the bucket range is reserved for keys after the query.
            half = buckets_left // 2
            bucket = bucket + (dist < 0).long() * half
            dist = dist.abs()
            buckets_left = half  # remaining buckets cover non-negative distances
        else:
            dist = torch.clamp(dist, min=0)

        # dist >= 0 from here on: small distances get one bucket each,
        # larger ones share logarithmically spaced buckets.
        max_exact = buckets_left // 2
        small = dist < max_exact
        # Guards against log(0), division by zero, and max_distance <= max_exact.
        scale = max(
            1.0,
            math.log(max(self.max_distance, max_exact + 1) / max(1, max_exact)),
        )
        large = max_exact + (
            (torch.log(dist.float() / max(1, max_exact) + 1e-6) / scale)
            * (buckets_left - max_exact)
        ).long()
        large = torch.clamp(large, max=buckets_left - 1)

        bucket = bucket + torch.where(small, dist.long(), large)
        # Safety clamp after the bidirectional half-offset was applied.
        return torch.clamp(bucket, min=0, max=self.num_buckets - 1)

    def forward(self, Tq, Tk, device=None):
        """Return the additive bias tensor of shape [n_heads, Tq, Tk]."""
        device = device or torch.device("cpu")
        q_idx = torch.arange(Tq, device=device).unsqueeze(1)  # [Tq, 1]
        k_idx = torch.arange(Tk, device=device).unsqueeze(0)  # [1, Tk]
        bucket_ids = self._relative_position_bucket(k_idx - q_idx)  # [Tq, Tk]
        return self.emb(bucket_ids).permute(2, 0, 1)  # [H, Tq, Tk]
95
+ ```
96
+
97
+ ### Transformer Base Class
98
+
99
+ ```python
100
class BaseTransformerComp(nn.Module):
    """Base class for transformer-based intra-stock components.

    Holds the shared hyper-parameters and provides the reshape / masking
    helpers used by the concrete transformer variants.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_layers: int,
        num_heads: int,
        dropout: float = 0.1,
        mask_type: str = "none",
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout_rate = dropout
        self.mask_type = mask_type

    def _reshape_input(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
        """Fold the stock axis into the batch axis.

        Converts [batch, seq_len, n_stocks, n_feats] into sequence-first
        [seq_len, batch * n_stocks, n_feats]; also returns the original
        batch and n_stocks sizes for later reconstruction.
        """
        batch, seq_len, n_stocks, n_feats = x.shape

        if min(batch, seq_len, n_stocks) == 0:
            raise ValueError(
                f"Invalid input dimensions: batch={batch}, seq_len={seq_len}, "
                f"n_stocks={n_stocks}, n_feats={n_feats}"
            )

        folded = (
            x.permute(0, 2, 1, 3)  # [b, s, t, f]
            .contiguous()
            .reshape(batch * n_stocks, seq_len, n_feats)  # [b * s, t, f]
            .permute(1, 0, 2)  # [t, b * s, f]
            .contiguous()
        )
        return folded, batch, n_stocks

    def _reshape_output(
        self, x: torch.Tensor, batch: int, n_stocks: int
    ) -> torch.Tensor:
        """Take the final time step and unfold to [batch, n_stocks, hidden_dim]."""
        last_step = x[-1]  # [b * s, hidden_dim]
        return last_step.reshape(batch, n_stocks, -1)

    def _generate_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Additive causal mask: -inf strictly above the diagonal, 0 elsewhere."""
        blocked = torch.full((seq_len, seq_len), float("-inf"), device=device)
        return torch.triu(blocked, diagonal=1)
153
+ ```
154
+
155
+ ### RPB Components
156
+
157
+ ```python
158
class TransformerRPBComp(BaseTransformerComp):
    """Transformer component whose attention adds a learned relative
    position bias (``RelativePositionBiasV2``), shared across all layers.

    NOTE(review): ``TransformerEncoderLayerWithRPB`` is referenced but not
    defined in this README — its definition must come from the project.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_layers: int,
        num_heads: int,
        dropout: float = 0.1,
        mask_type: str = "none",
    ) -> None:
        super().__init__(input_dim, hidden_dim, num_layers, num_heads, dropout)
        # Attribute names/order match the released checkpoints' state-dict keys.
        self.feature_layer = nn.Linear(input_dim, hidden_dim)
        self.pe = PositionalEncoding(hidden_dim, dropout)
        self.encoder_norm = nn.LayerNorm(hidden_dim)
        self.mask_type = mask_type
        # One bias module, shared by every encoder layer below.
        self.rbp = RelativePositionBiasV2(n_heads=num_heads)
        layers = [
            TransformerEncoderLayerWithRPB(
                d_model=hidden_dim,
                nhead=num_heads,
                dim_feedforward=hidden_dim * 4,
                dropout=dropout,
                rbp=self.rbp,
            )
            for _ in range(num_layers)
        ]
        self.encoder_layers = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x.shape [batch, seq_len, n_stocks, n_feats] -> [batch, n_stocks, hidden_dim]."""
        x, batch, n_stocks = self._reshape_input(x)
        seq_len = x.shape[0]

        # Embed features, add positions, normalize: [t, b * s, d_model].
        x = self.encoder_norm(self.pe(self.feature_layer(x)))

        # NOTE(review): permute(1, 0) transposes the causal mask so -inf sits
        # BELOW the diagonal; confirm TransformerEncoderLayerWithRPB expects
        # this orientation — nn.Transformer convention is the untransposed form.
        if self.mask_type == "causal":
            attn_mask = self._generate_causal_mask(seq_len, x.device).permute(1, 0)
        else:
            attn_mask = None

        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x, src_mask=attn_mask)

        return self._reshape_output(x, batch, n_stocks)
205
+ ```
206
+
207
+ ### Transformer Module
208
+
209
+ ```python
210
class Transformer(nn.Module):
    """Stock-prediction model: a transformer encoder variant chosen by
    ``tfm_type``, followed by a two-layer GELU prediction head.

    NOTE(review): ``TransformerComp`` and ``TransformerRoPEComp`` are
    referenced but not defined in this README — they must come from the
    project alongside these checkpoints.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int = 1,
        hidden_dim: int = 256,
        num_layers: int = 2,
        num_heads: int = 4,
        dropout: float = 0.1,
        tfm_type: str = "base",
        mask_type: str = "none",
    ) -> None:
        """
        tfm_type: "base", "rope", "rpb"
        mask_type: "none", "alibi", "causal"
        """
        super().__init__()
        self.tfm_type = tfm_type
        self.mask_type = mask_type

        # Dispatch table from variant name to component class.
        # NOTE(review): "alibi" is a key here (tfm_type) but the docstring
        # lists it under mask_type — confirm which is intended.
        encoder_cls = {
            "base": TransformerComp,
            "alibi": TransformerComp,
            "rope": TransformerRoPEComp,
            "rpb": TransformerRPBComp,
        }[self.tfm_type]
        self.transformer_encoder = encoder_cls(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            dropout=dropout,
            mask_type=mask_type,
        )
        self.fc_out = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim, bias=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [batch, seq_len, n_stocks, n_feats] -> predictions [b, s] (output_dim=1)."""
        encoded = self.transformer_encoder(x)  # [b, s, d_model]
        return self.fc_out(encoded).squeeze(-1)  # [b, s]
255
+ ```
256
+
257
+ ### Model Configuration
258
+
259
+ ```yaml
260
+ input_dim: 8
+ output_dim: 1
+ hidden_dim: 64
+ num_layers: 2
+ num_heads: 4
+ dropout: 0.0
+ tfm_type: "rpb"
+ mask_type: "causal"
268
+ ```