| ## To use these checkpoints, you need to use the following model structure for Transformer | |
| ### Import used packages | |
| ```python | |
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn
| ``` | |
| ### PositionalEncoding | |
| ```python | |
class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal position table to a sequence, then apply dropout.

    The table is computed once in the constructor and registered as the
    buffer ``pe`` with shape ``(max_len, 1, d_model)``. Even feature columns
    hold sines and odd feature columns hold cosines of the same angles.

    Args:
        d_model: Size of the feature (last) dimension of the table.
        dropout: Dropout probability applied to the summed output.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Positions as a column vector so they broadcast against the frequencies.
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even feature column: exp(-ln(10000) * 2i / d_model).
        freq_scale = -math.log(10000.0) / d_model
        frequencies = torch.exp(torch.arange(0, d_model, 2).float() * freq_scale)
        angles = positions * frequencies  # (max_len, ceil(d_model / 2))

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so only the first d_model // 2 angle columns are used for the cosines.
        table[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Stored as (max_len, 1, d_model); the size-1 axis broadcasts in forward().
        self.register_buffer("pe", table.unsqueeze(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe[:x.size(0)])``.

        The first axis of ``x`` is treated as the position axis; it must not
        exceed ``max_len`` or the addition fails to broadcast.
        """
        n_positions = x.size(0)
        return self.dropout(x + self.pe[:n_positions])
| ``` | |
### RPB Class
| ```python | |
class RelativePositionBiasV2(nn.Module):
    """Learned per-head attention bias looked up by bucketed relative position.

    Each (query, key) position pair is mapped to a bucket id, and the bucket id
    indexes an ``nn.Embedding(num_buckets, n_heads)`` stored as ``self.emb``.

    Args:
        n_heads: Number of attention heads (embedding width).
        num_buckets: Total number of buckets; must be even.
        max_distance: Distance scale used for the logarithmic buckets.
        bidirectional: If True, half of the buckets are reserved for keys that
            come after the query; otherwise such keys share bucket 0.
    """

    def __init__(self, n_heads, num_buckets=32, max_distance=128, bidirectional=True):
        super().__init__()
        # NOTE: the check runs for both modes, although the message mentions
        # only the bidirectional case.
        assert num_buckets % 2 == 0, "num_buckets should be even for bidirectional"
        self.n_heads = n_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.bidirectional = bidirectional
        self.emb = nn.Embedding(num_buckets, n_heads)

    def _relative_position_bucket(self, relative_position):
        """Map relative positions to bucket ids.

        relative_position: [Tq, Tk] = k - q
        returns bucket ids in [0, num_buckets-1]

        Small distances get one bucket each; larger distances share
        logarithmically spaced buckets and are clamped to the last bucket.
        """
        num_buckets = self.num_buckets
        max_distance = self.max_distance
        ret = torch.zeros_like(relative_position, dtype=torch.long)
        n = -relative_position  # n = q - k; n > 0 means the key is before the query

        if self.bidirectional:
            half = num_buckets // 2
            # Keys after the query (n < 0) use the upper half of the bucket range.
            ret += (n < 0).long() * half
            n = n.abs()
            num_buckets = half  # remaining buckets for non-negative distances
        else:
            # Keys after the query all collapse to distance 0.
            n = torch.clamp(n, min=0)

        # Now n >= 0. Distances below max_exact map to their own bucket.
        max_exact = num_buckets // 2
        is_small = n < max_exact

        # Avoid log(0) and division by zero; also ensure max_distance > max_exact.
        denom = max(1.0, math.log(max(max_distance, max_exact + 1) / max(1, max_exact)))
        # For small n this value can be negative; it is discarded by the
        # torch.where below, so only the upper clamp is needed.
        val_if_large = (
            max_exact
            + (
                (torch.log(n.float() / max(1, max_exact) + 1e-6) / denom)
                * (num_buckets - max_exact)
            ).long()
        )
        val_if_large = torch.clamp(val_if_large, max=num_buckets - 1)

        ret += torch.where(is_small, n.long(), val_if_large)
        # Final clamp for absolute safety when bidirectional half-split was applied.
        return torch.clamp(ret, min=0, max=self.num_buckets - 1)

    def forward(self, Tq, Tk, device=None):
        """Return the bias tensor of shape ``[n_heads, Tq, Tk]``.

        Args:
            Tq: Number of query positions.
            Tk: Number of key positions.
            device: Device on which the position indices are built. When
                omitted, the embedding weight's device is used so the lookup
                works wherever the module has been moved.
        """
        if device is None:
            device = self.emb.weight.device
        qpos = torch.arange(Tq, device=device)[:, None]
        kpos = torch.arange(Tk, device=device)[None, :]
        buckets = self._relative_position_bucket(kpos - qpos)  # [Tq, Tk]
        bias = self.emb(buckets)  # [Tq, Tk, H]
        return bias.permute(2, 0, 1)  # [H, Tq, Tk]
| ``` | |
| ### Transformer Base Class | |
| ```python | |
class BaseTransformerComp(nn.Module):
    """Base class for transformer-based intra-stock components.

    Stores the shared hyper-parameters and provides layout helpers that
    convert between the 4-D input layout and the time-major layout used by
    the encoder layers. It defines no ``forward`` of its own.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_layers: int,
        num_heads: int,
        dropout: float = 0.1,
        mask_type: str = "none",
    ) -> None:
        super().__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        self.num_layers, self.num_heads = num_layers, num_heads
        self.dropout_rate = dropout
        self.mask_type = mask_type

    def _reshape_input(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
        """Fold stocks into the batch axis and move time to the front.

        Converts ``[batch, seq_len, n_stocks, n_feats]`` into
        ``[seq_len, batch * n_stocks, n_feats]``.

        Returns:
            The reshaped tensor plus the original ``batch`` and ``n_stocks``
            sizes, which ``_reshape_output`` needs to undo the folding.

        Raises:
            ValueError: If ``batch``, ``seq_len`` or ``n_stocks`` is zero.
        """
        batch, seq_len, n_stocks, n_feats = x.shape
        if 0 in (batch, seq_len, n_stocks):
            raise ValueError(
                f"Invalid input dimensions: batch={batch}, seq_len={seq_len}, "
                f"n_stocks={n_stocks}, n_feats={n_feats}"
            )

        time_major = x.permute(1, 0, 2, 3)  # [t, b, s, f]
        # Merging (b, s) keeps the stock index fastest-varying: row = b * n_stocks + s.
        folded = time_major.reshape(seq_len, batch * n_stocks, n_feats).contiguous()
        return folded, batch, n_stocks

    def _reshape_output(
        self, x: torch.Tensor, batch: int, n_stocks: int
    ) -> torch.Tensor:
        """Keep only the last time step and unfold it to ``[batch, n_stocks, hidden_dim]``.

        ``x`` is expected as ``[seq_len, batch * n_stocks, hidden_dim]``.
        """
        last_step = x[-1]  # [b * s, hidden_dim]
        return last_step.reshape(batch, n_stocks, -1)

    def _generate_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Build an additive ``[seq_len, seq_len]`` mask.

        Entries strictly above the diagonal are ``-inf``; all others are 0.
        """
        neg_inf = torch.full((seq_len, seq_len), float("-inf"), device=device)
        return neg_inf.triu(diagonal=1)
| ``` | |
| ### Transformer Encoder Layer with RPB | |
| ```python | |
class TransformerEncoderLayerWithRPB(nn.Module):
    """Post-norm transformer encoder layer with an additive relative position bias.

    Self-attention is computed explicitly (fused QKV projection, scaled
    dot-product, softmax) so that the bias returned by ``rbp`` can be added to
    the attention logits. A two-layer ReLU feed-forward block follows; each
    sub-block is wrapped as ``LayerNorm(residual + sublayer)``.

    Args:
        d_model: Feature size of the input and output.
        nhead: Number of attention heads; must divide ``d_model``.
        dim_feedforward: Hidden size of the feed-forward block.
        dropout: Dropout probability used throughout the layer.
        rbp: Callable invoked as ``rbp(Tq, Tk, device=...)`` that returns a
            bias of shape ``[nhead, Tq, Tk]`` (e.g. ``RelativePositionBiasV2``).
            The attribute name ``rbp`` is kept because it determines the
            state_dict keys of a registered bias module.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``nhead``.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int,
        dropout: float,
        rbp,
    ):
        super().__init__()
        # Fail early with a clear message instead of an opaque reshape error
        # inside forward().
        if nhead <= 0 or d_model % nhead != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by nhead ({nhead})"
            )
        self.d_model = d_model
        self.nhead = nhead
        self.rbp = rbp

        # QKV projections
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # FFN layers
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Normalization and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = F.relu

    def forward(
        self,
        src: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """Apply self-attention and the feed-forward block.

        Args:
            src: Input of shape ``[seq_len, batch, d_model]``.
            src_mask: Optional additive mask of shape ``[seq_len, seq_len]``,
                indexed ``[query, key]`` and broadcast over batch and heads.
            src_key_padding_mask: Optional boolean mask of shape
                ``[batch, seq_len]``; True marks keys to ignore. A row whose
                keys are all masked yields NaN after the softmax.
            is_causal: Accepted for signature compatibility; not used.

        Returns:
            Tensor of shape ``[seq_len, batch, d_model]``.
        """
        seq_len, batch_size, d_model = src.shape
        head_dim = d_model // self.nhead

        qkv = self.qkv_proj(src)
        q, k, v = qkv.chunk(3, dim=-1)
        # [seq_len, batch, d_model] -> [batch, nhead, seq_len, head_dim]
        q = q.reshape(seq_len, batch_size, self.nhead, head_dim).permute(1, 2, 0, 3)
        k = k.reshape(seq_len, batch_size, self.nhead, head_dim).permute(1, 2, 0, 3)
        v = v.reshape(seq_len, batch_size, self.nhead, head_dim).permute(1, 2, 0, 3)

        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)

        # Add the relative position bias after QK^T.
        rbp_bias = self.rbp(
            seq_len, seq_len, device=src.device
        )  # [nhead, seq_len, seq_len]
        attn_weights = attn_weights + rbp_bias.unsqueeze(
            0
        )  # [batch, nhead, seq_len, seq_len]

        if src_mask is not None:
            attn_weights = attn_weights + src_mask.unsqueeze(0).unsqueeze(0)
        if src_key_padding_mask is not None:
            # [batch, seq_len] -> [batch, 1, 1, seq_len]: masks key columns.
            attn_weights = attn_weights.masked_fill(
                src_key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
            )

        attn_weights = F.softmax(attn_weights, dim=-1)
        # NOTE: dropout1 is applied here and again on the attention output
        # below; this is kept as-is to preserve the original behavior.
        attn_weights = self.dropout1(attn_weights)

        attn_output = torch.matmul(attn_weights, v)  # [batch, nhead, seq_len, head_dim]
        attn_output = attn_output.permute(2, 0, 1, 3).reshape(
            seq_len, batch_size, d_model
        )
        attn_output = self.out_proj(attn_output)

        src2 = src + self.dropout1(attn_output)
        src2 = self.norm1(src2)

        ffn_output = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src3 = src2 + self.dropout2(ffn_output)
        src3 = self.norm2(src3)
        return src3
| ``` | |
| ### RPB Components | |
| ```python | |
class TransformerRPBComp(BaseTransformerComp):
    """Transformer component whose attention uses a relative position bias (RPB).

    Pipeline: linear feature projection -> sinusoidal positional encoding ->
    LayerNorm -> ``num_layers`` x ``TransformerEncoderLayerWithRPB``. A single
    ``RelativePositionBiasV2`` instance (``self.rbp``) is shared by all layers.

    Only ``mask_type == "causal"`` produces an attention mask; any other
    value runs without a mask.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_layers: int,
        num_heads: int,
        dropout: float = 0.1,
        mask_type: str = "none",
    ) -> None:
        super().__init__(
            input_dim, hidden_dim, num_layers, num_heads, dropout, mask_type
        )
        self.feature_layer = nn.Linear(input_dim, hidden_dim)
        self.pe = PositionalEncoding(hidden_dim, dropout)
        self.encoder_norm = nn.LayerNorm(hidden_dim)

        # One bias module, handed to every layer so they share its embedding.
        self.rbp = RelativePositionBiasV2(n_heads=num_heads)
        stacked_layers = [
            TransformerEncoderLayerWithRPB(
                d_model=hidden_dim,
                nhead=num_heads,
                dim_feedforward=hidden_dim * 4,
                dropout=dropout,
                rbp=self.rbp,
            )
            for _ in range(num_layers)
        ]
        self.encoder_layers = nn.ModuleList(stacked_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return the last time step per stock.

        Args:
            x: Tensor of shape ``[batch, seq_len, n_stocks, n_feats]``.

        Returns:
            Tensor of shape ``[batch, n_stocks, hidden_dim]``.
        """
        tokens, batch, n_stocks = self._reshape_input(x)  # [t, b * s, f]
        hidden = self.feature_layer(tokens)
        hidden = self.pe(hidden)
        hidden = self.encoder_norm(hidden)  # [t, b * s, d_model]

        attn_mask = None
        if self.mask_type == "causal":
            # NOTE(review): the mask is transposed before use. After the
            # transpose, entry [query, key] is -inf when key < query, so each
            # position attends to itself and LATER positions only. This looks
            # inverted for a causal mask, but it is kept unchanged because
            # altering it would change what existing checkpoints compute.
            # Confirm with the authors before changing it.
            attn_mask = self._generate_causal_mask(tokens.shape[0], hidden.device).t()

        for encoder_layer in self.encoder_layers:
            hidden = encoder_layer(hidden, src_mask=attn_mask)
        return self._reshape_output(hidden, batch, n_stocks)
| ``` | |
| ### Transformer Module | |
| ```python | |
class Transformer(nn.Module):
    """Transformer encoder followed by a two-layer MLP prediction head.

    Args:
        input_dim: Number of input features per stock and time step.
        output_dim: Size of the prediction per stock.
        hidden_dim: Model width used by the encoder and the head.
        num_layers: Number of encoder layers.
        num_heads: Number of attention heads.
        dropout: Dropout probability passed to the encoder.
        tfm_type: Encoder variant: "base", "alibi", "rope" or "rpb".
            "base" and "alibi" both select ``TransformerComp``.
        mask_type: "none", "alibi" or "causal"; forwarded to the encoder.

    Raises:
        KeyError: If ``tfm_type`` is not one of the supported keys.
        NameError: If the class for the selected ``tfm_type`` is not defined
            in this module.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int = 1,
        hidden_dim: int = 256,
        num_layers: int = 2,
        num_heads: int = 4,
        dropout: float = 0.1,
        tfm_type: str = "base",
        mask_type: str = "none",
    ) -> None:
        super().__init__()
        self.tfm_type = tfm_type
        self.mask_type = mask_type

        # Resolve the encoder class lazily: only the selected entry is looked
        # up, so e.g. tfm_type="rpb" works even when TransformerComp and
        # TransformerRoPEComp are not defined alongside this class.
        tfm_type_mapper = {
            "base": lambda: TransformerComp,
            "alibi": lambda: TransformerComp,
            "rope": lambda: TransformerRoPEComp,
            "rpb": lambda: TransformerRPBComp,
        }
        encoder_cls = tfm_type_mapper[self.tfm_type]()
        self.transformer_encoder = encoder_cls(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            dropout=dropout,
            mask_type=mask_type,
        )
        self.fc_out = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim, bias=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and apply the prediction head.

        The trailing axis is squeezed, so with ``output_dim == 1`` the result
        is ``[b, s]``; for other ``output_dim`` values the squeeze is a no-op.
        """
        tfm_out = self.transformer_encoder(x)  # [b, s, d_model]
        final_out = self.fc_out(tfm_out).squeeze(-1)  # [b, s]
        return final_out
| ``` | |
| ### Model Configuration | |
| ```yaml | |
input_dim: 8
output_dim: 1
hidden_dim: 64
num_layers: 2
num_heads: 4
dropout: 0.0
tfm_type: "rpb"
mask_type: "causal"
| ``` |