## To use these checkpoints, you need to use the following `Transformer` model structure

### Import required packages

```python
import math
from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F
```

### PositionalEncoding

```python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # (max_len, d_model // 2)
        truncated_div_term = div_term[: d_model // 2]  # handles odd d_model
        pe[:, 1::2] = torch.cos(position * truncated_div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [seq_len, batch, d_model]; add the encoding for the first seq_len positions
        x = x + self.pe[: x.size(0), :, :]
        return self.dropout(x)
```

### RelativePositionBias

```python
class RelativePositionBiasV2(nn.Module):
    def __init__(self, n_heads, num_buckets=32, max_distance=128, bidirectional=True):
        super().__init__()
        assert num_buckets % 2 == 0, "num_buckets should be even for bidirectional"
        self.n_heads = n_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.bidirectional = bidirectional
        self.emb = nn.Embedding(num_buckets, n_heads)

    def _relative_position_bucket(self, relative_position):
        """
        relative_position: [Tq, Tk] = k - q
        returns bucket ids in [0, num_buckets - 1]
        """
        num_buckets = self.num_buckets
        max_distance = self.max_distance

        ret = torch.zeros_like(relative_position, dtype=torch.long)
        n = -relative_position  # want smaller buckets for n > 0 (keys before queries)

        if self.bidirectional:
            half = num_buckets // 2
            ret += (n < 0).long() * half
            n = n.abs()
            num_buckets = half  # remaining buckets for non-negative distances
        else:
            n = torch.clamp(n, min=0)

        # Now n >= 0
        max_exact = num_buckets // 2
        is_small = n < max_exact

        # Avoid log(0) and division by zero; also ensure max_distance > max_exact
        denom = max(1.0, math.log(max(max_distance, max_exact + 1) / max(1, max_exact)))
        val_if_large = (
            max_exact
            + (
                (torch.log(n.float() / max(1, max_exact) + 1e-6) / denom)
                * (num_buckets - max_exact)
            ).long()
        )
        val_if_large = torch.clamp(val_if_large, max=num_buckets - 1)

        ret += torch.where(is_small, n.long(), val_if_large)
        # Final clamp for absolute safety when the bidirectional half-split was applied
        return torch.clamp(ret, min=0, max=self.num_buckets - 1)

    def forward(self, Tq, Tk, device=None):
        device = device or torch.device("cpu")
        qpos = torch.arange(Tq, device=device)[:, None]
        kpos = torch.arange(Tk, device=device)[None, :]
        buckets = self._relative_position_bucket(kpos - qpos)  # [Tq, Tk]
        bias = self.emb(buckets)  # [Tq, Tk, H]
        return bias.permute(2, 0, 1)  # [H, Tq, Tk]
```
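The bias module is queried with the query and key lengths and returns one learned scalar per (head, query position, key position). As a quick shape check (the sizes below are illustrative only, not the checkpoint configuration):

```python
# Illustrative shape check; n_heads=4 and Tq=Tk=10 are arbitrary example sizes.
rpb = RelativePositionBiasV2(n_heads=4)
bias = rpb(Tq=10, Tk=10)  # additive bias applied to the attention logits
print(bias.shape)         # torch.Size([4, 10, 10]) -> [H, Tq, Tk]
```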
### Transformer Base Class

```python
class BaseTransformerComp(nn.Module):
    """Base class for transformer-based intra-stock components."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_layers: int,
        num_heads: int,
        dropout: float = 0.1,
        mask_type: str = "none",
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout_rate = dropout
        self.mask_type = mask_type

    def _reshape_input(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
        """
        Reshape input from [batch, seq_len, n_stocks, n_feats] to
        [seq_len, batch * n_stocks, n_feats].

        Returns the reshaped tensor and the original batch / n_stocks sizes
        for later reconstruction.
        """
        batch, seq_len, n_stocks, n_feats = x.shape
        if batch == 0 or seq_len == 0 or n_stocks == 0:
            raise ValueError(
                f"Invalid input dimensions: batch={batch}, seq_len={seq_len}, "
                f"n_stocks={n_stocks}, n_feats={n_feats}"
            )
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.reshape(batch * n_stocks, seq_len, n_feats)  # [b * s, t, f]
        x = x.permute(1, 0, 2).contiguous()  # [t, b * s, f]
        return x, batch, n_stocks

    def _reshape_output(
        self, x: torch.Tensor, batch: int, n_stocks: int
    ) -> torch.Tensor:
        """Take the last time step of [seq_len, batch * n_stocks, hidden_dim]
        and reshape it to [batch, n_stocks, hidden_dim]."""
        output = x[-1]  # Take last time step: [b * s, hidden_dim]
        output = output.reshape(batch, n_stocks, -1)  # [b, s, hidden_dim]
        return output

    def _generate_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Generate a causal attention mask (float("-inf") above the diagonal)."""
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device) * float("-inf"), diagonal=1
        )
        return mask
```

### Transformer Encoder Layer with RPB

```python
class TransformerEncoderLayerWithRPB(nn.Module):
    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int,
        dropout: float,
        rbp,
    ):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.rbp = rbp

        # QKV projections
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # FFN layers
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Normalization and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = F.relu

    def forward(
        self,
        src: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> torch.Tensor:
        seq_len, batch_size, d_model = src.shape
        head_dim = d_model // self.nhead

        qkv = self.qkv_proj(src)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.reshape(seq_len, batch_size, self.nhead, head_dim).permute(1, 2, 0, 3)
        k = k.reshape(seq_len, batch_size, self.nhead, head_dim).permute(1, 2, 0, 3)
        v = v.reshape(seq_len, batch_size, self.nhead, head_dim).permute(1, 2, 0, 3)

        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)

        # Add the relative position bias (RPB) after QK^T
        rbp_bias = self.rbp(seq_len, seq_len, device=src.device)  # [nhead, seq_len, seq_len]
        attn_weights = attn_weights + rbp_bias.unsqueeze(0)  # broadcasts to [batch, nhead, seq_len, seq_len]

        if src_mask is not None:
            attn_weights = attn_weights + src_mask.unsqueeze(0).unsqueeze(0)
        if src_key_padding_mask is not None:
            attn_weights = attn_weights.masked_fill(
                src_key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
            )

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout1(attn_weights)

        attn_output = torch.matmul(attn_weights, v)  # [batch, nhead, seq_len, head_dim]
        attn_output = attn_output.permute(2, 0, 1, 3).reshape(seq_len, batch_size, d_model)
        attn_output = self.out_proj(attn_output)

        src2 = src + self.dropout1(attn_output)
        src2 = self.norm1(src2)

        ffn_output = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src3 = src2 + self.dropout2(ffn_output)
        src3 = self.norm2(src3)
        return src3
```
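A single layer preserves the `[seq_len, batch, d_model]` shape of its input. A minimal smoke test might look like the following; all sizes here are made up for illustration and do not correspond to the checkpoint dimensions:

```python
# Minimal smoke test with arbitrary sizes (not the checkpoint configuration).
rpb = RelativePositionBiasV2(n_heads=4)
layer = TransformerEncoderLayerWithRPB(
    d_model=64, nhead=4, dim_feedforward=256, dropout=0.0, rbp=rpb
)
src = torch.randn(30, 8, 64)   # [seq_len, batch, d_model]
out = layer(src)               # shape is preserved
print(out.shape)               # torch.Size([30, 8, 64])
```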
### RPB Components

```python
class TransformerRPBComp(BaseTransformerComp):
    """TransformerComp with relative position bias (RPB)."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_layers: int,
        num_heads: int,
        dropout: float = 0.1,
        mask_type: str = "none",
    ) -> None:
        super().__init__(input_dim, hidden_dim, num_layers, num_heads, dropout)
        self.feature_layer = nn.Linear(input_dim, hidden_dim)
        self.pe = PositionalEncoding(hidden_dim, dropout)
        self.encoder_norm = nn.LayerNorm(hidden_dim)
        self.mask_type = mask_type
        self.rbp = RelativePositionBiasV2(n_heads=num_heads)
        self.encoder_layers = nn.ModuleList(
            [
                TransformerEncoderLayerWithRPB(
                    d_model=hidden_dim,
                    nhead=num_heads,
                    dim_feedforward=hidden_dim * 4,
                    dropout=dropout,
                    rbp=self.rbp,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x.shape: [batch, seq_len, n_stocks, n_feats]"""
        x, batch, n_stocks = self._reshape_input(x)
        seq_len = x.shape[0]
        x = self.encoder_norm(self.pe(self.feature_layer(x)))  # [t, b * s, d_model]
        if self.mask_type == "causal":
            mask = self._generate_causal_mask(seq_len, x.device).permute(1, 0)
        else:
            mask = None
        for layer in self.encoder_layers:
            x = layer(x, src_mask=mask)
        return self._reshape_output(x, batch, n_stocks)
```

### Transformer Module

```python
class Transformer(nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int = 1,
        hidden_dim: int = 256,
        num_layers: int = 2,
        num_heads: int = 4,
        dropout: float = 0.1,
        tfm_type: str = "base",
        mask_type: str = "none",
    ) -> None:
        """
        tfm_type: "base", "rope", "rpb"
        mask_type: "none", "alibi", "causal"
        """
        super().__init__()
        self.tfm_type = tfm_type
        self.mask_type = mask_type
        # TransformerComp and TransformerRoPEComp are separate variants that are
        # not included in this snippet; the configuration below uses tfm_type="rpb",
        # which maps to TransformerRPBComp. Note that the mapper references all of
        # these names at construction time, so they must be defined (or stubbed).
        tfm_type_mapper = {
            "base": TransformerComp,
            "alibi": TransformerComp,
            "rope": TransformerRoPEComp,
            "rpb": TransformerRPBComp,
        }
        self.transformer_encoder = tfm_type_mapper[self.tfm_type](
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            dropout=dropout,
            mask_type=mask_type,
        )
        self.fc_out = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim, bias=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tfm_out = self.transformer_encoder(x)  # [b, s, d_model]
        final_out = self.fc_out(tfm_out).squeeze(-1)  # [b, s]
        return final_out
```

### Model Configuration

```yaml
input_dim: 8
output_dim: 1
hidden_dim: 64
num_layers: 2
num_heads: 4
dropout: 0.0
tfm_type: "rpb"
mask_type: "causal"
```
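Putting the pieces together, a loading sketch under the configuration above might look like the following. The checkpoint filename and the assumption that the file stores the model's `state_dict` directly are placeholders; `TransformerComp` and `TransformerRoPEComp` are aliased to `TransformerRPBComp` here only so the lookup table in `Transformer.__init__` resolves, since those variants are not part of this snippet.

```python
# Placeholder aliases: only the "rpb" branch is used by this configuration, but
# the mapper in Transformer.__init__ still references the other class names.
TransformerComp = TransformerRoPEComp = TransformerRPBComp

model = Transformer(
    input_dim=8,
    output_dim=1,
    hidden_dim=64,
    num_layers=2,
    num_heads=4,
    dropout=0.0,
    tfm_type="rpb",
    mask_type="causal",
)

# "checkpoint.pt" is a placeholder path; adjust it to the actual checkpoint file
# and to how its contents are keyed (here the file is assumed to hold the
# state_dict itself).
state_dict = torch.load("checkpoint.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Dummy input: [batch, seq_len, n_stocks, n_feats] with illustrative sizes.
x = torch.randn(2, 30, 100, 8)
with torch.no_grad():
    preds = model(x)  # [batch, n_stocks] -> torch.Size([2, 100])
```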