| Transformer_ALiBi shares most of its modules with [Transformer-RPB](https://huggingface.co/Abner0803/Transformer-RPB), except for the modules below. | |
| ## TransformerComp | |
| Add `TransformerComp` to your current script: | |
| ```python | |
class TransformerComp(BaseTransformerComp):
    """Transformer encoder component with an optional attention mask.

    The input features are projected to ``hidden_dim``, positionally encoded,
    layer-normalised and passed through a stack of
    ``nn.TransformerEncoderLayer`` modules. Depending on ``mask_type`` the
    encoder receives no mask, a causal mask, an ALiBi distance bias, or a
    causal ALiBi bias.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_layers: int,
        num_heads: int,
        dropout: float = 0.1,
        mask_type: str = "none",
    ) -> None:
        """Build the projection, positional encoding and encoder stack.

        Args:
            input_dim: Size of the last (feature) axis of the input.
            hidden_dim: Model width (``d_model``); the feed-forward width is
                ``4 * hidden_dim``.
            num_layers: Number of stacked encoder layers.
            num_heads: Number of attention heads.
            dropout: Dropout probability for the positional encoding and the
                encoder layers.
            mask_type: One of "none", "alibi", "calibi", "causal". Any other
                value results in no mask being applied in ``forward``.
        """
        super().__init__(input_dim, hidden_dim, num_layers, num_heads, dropout)
        self.feature_layer = nn.Linear(input_dim, hidden_dim)
        self.pe = PositionalEncoding(hidden_dim, dropout)
        self.mask_type = mask_type
        if self.mask_type in ["alibi", "calibi"]:
            # Persistent buffer: follows the module across devices and is
            # stored in the state dict, but is never trained.
            self.register_buffer(
                "slopes", self._build_alibi_slopes(num_heads).view(-1, 1, 1)
            )  # [n_heads, 1, 1]
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            activation="relu",
            batch_first=False,
        )
        self.encoder_norm = nn.LayerNorm(hidden_dim)
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers
        )

    @staticmethod
    def _build_alibi_slopes(num_heads: int) -> torch.Tensor:
        """Return the per-head ALiBi slopes as a 1-D float32 tensor.

        For a power-of-two head count ``n`` the slopes are the geometric
        sequence ``2 ** (-8 * i / n)`` for ``i = 1..n``. Otherwise the slopes
        of the closest lower power of two are used, followed by every other
        slope of the schedule for twice that power of two.
        """
        closest_power_of_2 = 2 ** int(math.log2(num_heads))
        slopes = torch.pow(
            2,
            -torch.arange(1, closest_power_of_2 + 1, dtype=torch.float32)
            * 8
            / closest_power_of_2,
        )
        if closest_power_of_2 != num_heads:
            # The extra heads interleave with the base schedule, so their
            # exponent step is that of 2 * closest_power_of_2 heads. Using
            # 8 / closest_power_of_2 here would merely duplicate base slopes.
            extra_slopes = torch.pow(
                2,
                -torch.arange(
                    1,
                    2 * (num_heads - closest_power_of_2) + 1,
                    2,
                    dtype=torch.float32,
                )
                * 8
                / (2 * closest_power_of_2),
            )
            slopes = torch.cat([slopes, extra_slopes])
        return slopes

    def _generate_alibi_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Creates a mask that is Relative (ALiBi).
        Returns: [Num_Heads, Seq_Len, Seq_Len]
        """
        context_pos = torch.arange(seq_len, device=device).unsqueeze(1)
        memory_pos = torch.arange(seq_len, device=device).unsqueeze(0)
        # Symmetric |i - j| distance, broadcast against the [n_heads, 1, 1] slopes.
        distance = torch.abs(context_pos - memory_pos)
        return distance * -1.0 * self.slopes

    def _generate_causal_alibi_mask(
        self, seq_len: int, device: torch.device
    ) -> torch.Tensor:
        """
        Creates a mask that is Relative (ALiBi) and Causal (Mask Wall).
        Returns: [Num_Heads, Seq_Len, Seq_Len]
        """
        # Freshly allocated by the helper, so the in-place fill below is safe.
        alibi_bias = self._generate_alibi_mask(seq_len, device)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1
        )
        # Positions strictly above the diagonal (the future) get -inf.
        alibi_bias.masked_fill_(causal_mask, float("-inf"))
        return alibi_bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x.shape [batch, seq_len, n_stocks, n_feats]"""
        x, batch, n_stocks = self._reshape_input(x)
        seq_len = x.shape[0]
        x = self.encoder_norm(self.pe(self.feature_layer(x)))  # [t, b * s, d_model]
        if self.mask_type == "causal":
            mask = self._generate_causal_mask(seq_len, x.device).permute(1, 0)
        elif self.mask_type == "alibi":
            # Tile the per-head bias once per sequence in the batch axis.
            mask = self._generate_alibi_mask(seq_len, x.device).repeat(
                x.shape[1], 1, 1
            )  # [b * s * n_heads, t, t]
        elif self.mask_type == "calibi":
            mask = self._generate_causal_alibi_mask(seq_len, x.device).repeat(
                x.shape[1], 1, 1
            )  # [b * s * n_heads, t, t]
        else:
            mask = None
        x = self.transformer_encoder(x, mask=mask)
        return self._reshape_output(x, batch, n_stocks)
| ``` | |
| ## Model Config | |
| ```yaml | |
| input_dim: 8 | |
| output_dim: 1 | |
| hidden_dim: 64 | |
| num_layers: 2 | |
| num_heads: 4 | |
| dropout: 0.0 | |
| tfm_type: "base" | |
| mask_type: "alibi" | |
| ``` |