Transformer-RPB

To use these checkpoints, define the Transformer with the following model structure.

Imports

import math
from typing import Optional

import torch
from torch import nn
import torch.nn.functional as F

PositionalEncoding

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(
            1
        )  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # (max_len, d_model // 2)
        truncated_div_term = div_term[: d_model // 2]  # drop the extra frequency term when d_model is odd
        pe[:, 1::2] = torch.cos(position * truncated_div_term)  # (max_len, d_model // 2)
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pe[: x.size(0), :, :]
        return self.dropout(x)
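
The module expects a sequence-first input of shape [seq_len, batch, d_model]. A minimal, illustrative sanity check (not part of the checkpoint code):

import torch

pos_enc = PositionalEncoding(d_model=64, dropout=0.0)
x = torch.zeros(20, 2, 64)   # [seq_len, batch, d_model]
out = pos_enc(x)
print(out.shape)             # torch.Size([20, 2, 64])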

Relative Position Bias Class

class RelativePositionBiasV2(nn.Module):
    def __init__(self, n_heads, num_buckets=32, max_distance=128, bidirectional=True):
        super().__init__()
        assert num_buckets % 2 == 0, "num_buckets should be even for bidirectional"
        self.n_heads = n_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.bidirectional = bidirectional
        self.emb = nn.Embedding(num_buckets, n_heads)

    def _relative_position_bucket(self, relative_position):
        """
        relative_position: [Tq, Tk] = k - q
        returns bucket ids in [0, num_buckets-1]
        """
        num_buckets = self.num_buckets
        max_distance = self.max_distance

        ret = torch.zeros_like(relative_position, dtype=torch.long)
        n = -relative_position  # want smaller buckets for n > 0 (keys before queries)

        if self.bidirectional:
            half = num_buckets // 2
            ret += (n < 0).long() * half
            n = n.abs()
            num_buckets = half  # remaining buckets for non-negative distances
        else:
            n = torch.clamp(n, min=0)

        # Now n >= 0
        max_exact = num_buckets // 2
        is_small = n < max_exact
        # Avoid log(0) and division by zero; also ensure max_distance > max_exact
        denom = max(1.0, math.log(max(max_distance, max_exact + 1) / max(1, max_exact)))
        val_if_large = (
            max_exact
            + (
                (torch.log(n.float() / max(1, max_exact) + 1e-6) / denom)
                * (num_buckets - max_exact)
            ).long()
        )
        val_if_large = torch.clamp(val_if_large, max=num_buckets - 1)

        ret += torch.where(is_small, n.long(), val_if_large)
        # Final clamp for absolute safety when bidirectional half-split was applied
        return torch.clamp(ret, min=0, max=self.num_buckets - 1)

    def forward(self, Tq, Tk, device=None):
        device = device or torch.device("cpu")
        qpos = torch.arange(Tq, device=device)[:, None]
        kpos = torch.arange(Tk, device=device)[None, :]
        buckets = self._relative_position_bucket(kpos - qpos)  # [Tq, Tk]
        bias = self.emb(buckets)  # [Tq, Tk, H]
        return bias.permute(2, 0, 1)  # [H, Tq, Tk]
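
The module returns one learned bias value per head and relative offset, which is later broadcast over the batch dimension of the attention scores. A minimal, illustrative shape check:

rpb = RelativePositionBiasV2(n_heads=4)
bias = rpb(Tq=5, Tk=5)   # bucketed relative positions looked up in the embedding table
print(bias.shape)        # torch.Size([4, 5, 5]) = [n_heads, Tq, Tk]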

Transformer Base Class

class BaseTransformerComp(nn.Module):
    """Base class for transformer-based intra-stock components."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_layers: int,
        num_heads: int,
        dropout: float = 0.1,
        mask_type: str = "none",
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout_rate = dropout
        self.mask_type = mask_type

    def _reshape_input(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
        """
        Reshape input from [batch, seq_len, n_stocks, n_feats] to [seq_len, batch*n_stocks, n_feats].
        Returns reshaped tensor and original batch/n_stocks sizes for later reconstruction.
        """
        batch, seq_len, n_stocks, n_feats = x.shape

        if batch == 0 or seq_len == 0 or n_stocks == 0:
            raise ValueError(
                f"Invalid input dimensions: batch={batch}, seq_len={seq_len}, "
                f"n_stocks={n_stocks}, n_feats={n_feats}"
            )

        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.reshape(batch * n_stocks, seq_len, n_feats)  # [b * s, t, f]
        x = x.permute(1, 0, 2).contiguous()  # [t, b * s, f]

        return x, batch, n_stocks

    def _reshape_output(
        self, x: torch.Tensor, batch: int, n_stocks: int
    ) -> torch.Tensor:
        """Reshape output from [seq_len, batch*n_stocks, hidden_dim] to [batch, n_stocks, hidden_dim]."""
        output = x[-1]  # Take last time step: [b * s, hidden_dim]
        output = output.reshape(batch, n_stocks, -1)  # [b, s, hidden_dim]
        return output

    def _generate_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Generate causal attention mask."""
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device) * float("-inf"), diagonal=1
        )
        return mask
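
The reshaping contract merges the batch and stock dimensions so that each stock's time series is encoded independently. An illustrative sketch of the round trip (not part of the checkpoint code):

import torch

comp = BaseTransformerComp(input_dim=8, hidden_dim=64, num_layers=2, num_heads=4)
x = torch.randn(2, 20, 30, 8)                    # [batch, seq_len, n_stocks, n_feats]
flat, batch, n_stocks = comp._reshape_input(x)
print(flat.shape)                                # torch.Size([20, 60, 8]) = [t, b * s, f]
out = comp._reshape_output(torch.randn(20, 60, 64), batch, n_stocks)
print(out.shape)                                 # torch.Size([2, 30, 64]) = [b, s, hidden_dim]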

Transformer Encoder Layer with RPB

class TransformerEncoderLayerWithRPB(nn.Module):
    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int,
        dropout: float,
        rbp,
    ):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.rbp = rbp

        # QKV projections
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # FFN layers
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Normalization and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = F.relu

    def forward(
        self,
        src: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> torch.Tensor:
        seq_len, batch_size, d_model = src.shape
        head_dim = d_model // self.nhead
        qkv = self.qkv_proj(src)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.reshape(seq_len, batch_size, self.nhead, head_dim).permute(1, 2, 0, 3)
        k = k.reshape(seq_len, batch_size, self.nhead, head_dim).permute(1, 2, 0, 3)
        v = v.reshape(seq_len, batch_size, self.nhead, head_dim).permute(1, 2, 0, 3)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)

        # Add the relative position bias (RPB) to the raw QK^T attention scores
        rbp_bias = self.rbp(
            seq_len, seq_len, device=src.device
        )  # [nhead, seq_len, seq_len]
        attn_weights = attn_weights + rbp_bias.unsqueeze(
            0
        )  # [batch, nhead, seq_len, seq_len]

        if src_mask is not None:
            attn_weights = attn_weights + src_mask.unsqueeze(0).unsqueeze(0)

        if src_key_padding_mask is not None:
            attn_weights = attn_weights.masked_fill(
                src_key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
            )

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout1(attn_weights)
        attn_output = torch.matmul(attn_weights, v)  # [batch, nhead, seq_len, head_dim]
        attn_output = attn_output.permute(2, 0, 1, 3).reshape(
            seq_len, batch_size, d_model
        )
        attn_output = self.out_proj(attn_output)
        src2 = src + self.dropout1(attn_output)
        src2 = self.norm1(src2)
        ffn_output = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src3 = src2 + self.dropout2(ffn_output)
        src3 = self.norm2(src3)

        return src3
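
Each layer is constructed with a reference to the same RelativePositionBiasV2 instance, so the bias table is shared across layers (as TransformerRPBComp does below). A minimal, illustrative single-layer check:

import torch

rpb = RelativePositionBiasV2(n_heads=4)
layer = TransformerEncoderLayerWithRPB(
    d_model=64, nhead=4, dim_feedforward=256, dropout=0.1, rbp=rpb
)
src = torch.randn(10, 2, 64)   # [seq_len, batch, d_model]
out = layer(src)
print(out.shape)               # torch.Size([10, 2, 64])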

RPB Component

class TransformerRPBComp(BaseTransformerComp):
    """TransformerComp with Relative Bias Pooling."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_layers: int,
        num_heads: int,
        dropout: float = 0.1,
        mask_type: str = "none",
    ) -> None:
        super().__init__(input_dim, hidden_dim, num_layers, num_heads, dropout)
        self.feature_layer = nn.Linear(input_dim, hidden_dim)
        self.pe = PositionalEncoding(hidden_dim, dropout)
        self.encoder_norm = nn.LayerNorm(hidden_dim)
        self.mask_type = mask_type
        self.rbp = RelativePositionBiasV2(n_heads=num_heads)
        self.encoder_layers = nn.ModuleList(
            [
                TransformerEncoderLayerWithRPB(
                    d_model=hidden_dim,
                    nhead=num_heads,
                    dim_feedforward=hidden_dim * 4,
                    dropout=dropout,
                    rbp=self.rbp,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x.shape [batch, seq_len, n_stocks, n_feats]"""
        x, batch, n_stocks = self._reshape_input(x)
        seq_len = x.shape[0]

        x = self.encoder_norm(self.pe(self.feature_layer(x)))  # [t, b * s, d_model]

        if self.mask_type == "causal":
            mask = self._generate_causal_mask(seq_len, x.device).permute(1, 0)
        else:
            mask = None

        for layer in self.encoder_layers:
            x = layer(x, src_mask=mask)

        return self._reshape_output(x, batch, n_stocks)
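
As an illustrative end-to-end check of the component (random weights, not checkpoint weights):

import torch

comp = TransformerRPBComp(
    input_dim=8, hidden_dim=64, num_layers=2, num_heads=4,
    dropout=0.0, mask_type="causal",
)
x = torch.randn(2, 20, 30, 8)   # [batch, seq_len, n_stocks, n_feats]
out = comp(x)
print(out.shape)                # torch.Size([2, 30, 64]) = [batch, n_stocks, hidden_dim]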

Transformer Module

class Transformer(nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int = 1,
        hidden_dim: int = 256,
        num_layers: int = 2,
        num_heads: int = 4,
        dropout: float = 0.1,
        tfm_type: str = "base",
        mask_type: str = "none",
    ) -> None:
        """
        tfm_type: "base", "rope", "rpb"
        mask_type: "none", "alibi", "causal"
        """
        super().__init__()
        self.tfm_type = tfm_type
        self.mask_type = mask_type

        # TransformerComp ("base"/"alibi") and TransformerRoPEComp ("rope") belong to the
        # full codebase and are not reproduced in this README; only the "rpb" variant is
        # needed for these checkpoints.
        tfm_type_mapper = {
            "rpb": TransformerRPBComp,
        }
        self.transformer_encoder = tfm_type_mapper[self.tfm_type](
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            dropout=dropout,
            mask_type=mask_type,
        )
        self.fc_out = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim, bias=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tfm_out = self.transformer_encoder(x)  # [b, s, d_model]
        final_out = self.fc_out(tfm_out).squeeze(-1)  # [b, s]

        return final_out

Model Configuration

input_dim: 8
output_dim: 1
hidden_dim: 64
num_layers: 2
num_heads: 4
dropout: 0.0
tfm_type: "rpb"
mask_type: "causal"