""" |
|
|
Transformer utilities for GSLRM. |
|
|
|
|
|
This module contains the core transformer components used by the GSLRM model, |
|
|
including self-attention, MLP layers, and transformer blocks. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from einops import rearrange |

try:
    import xformers.ops as xops
except ImportError as e:
    print("Please install xformers to use flash attention v2")
    raise e


def _init_weights(module):
    """
    Initialize weights for transformer modules.

    Reference: https://github.com/karpathy/nanoGPT/blob/eba36e84649f3c6d840a93092cb779a260544d08/model.py#L162-L168

    Args:
        module: Neural network module to initialize
    """
    if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
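
# Usage sketch (illustrative only; the model below is hypothetical): this helper
# is intended to be applied recursively to a freshly built module tree via
# nn.Module.apply, e.g.
#
#     model = TransformerBlock(d=768, d_head=64)
#     model.apply(_init_weights)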


class MLP(nn.Module):
    """
    Multi-layer perceptron with GELU activation.

    Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L49-L65
    """

    def __init__(
        self,
        d,
        mlp_ratio=4,
        mlp_bias=False,
        mlp_dropout=0.0,
        mlp_dim=None,
    ):
        """
        Initialize MLP layer.

        Args:
            d: Input/output dimension
            mlp_ratio: Hidden dimension ratio (hidden_dim = d * mlp_ratio)
            mlp_bias: Whether to use bias in linear layers
            mlp_dropout: Dropout probability
            mlp_dim: Explicit hidden dimension (overrides mlp_ratio if provided)
        """
        super().__init__()
        if mlp_dim is None:
            mlp_dim = d * mlp_ratio

        self.mlp = nn.Sequential(
            nn.Linear(d, mlp_dim, bias=mlp_bias),
            nn.GELU(),
            nn.Linear(mlp_dim, d, bias=mlp_bias),
            nn.Dropout(mlp_dropout),
        )

    def forward(self, x):
        """
        Forward pass through MLP.

        Args:
            x: Input tensor of shape (batch, seq_len, d)

        Returns:
            Output tensor of shape (batch, seq_len, d)
        """
        return self.mlp(x)
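
# Usage sketch (illustrative only): the MLP preserves the token dimension, so for
# an input of shape (batch, seq_len, d) the output has the same shape, e.g.
#
#     mlp = MLP(d=768)                   # hidden dim defaults to 4 * 768
#     y = mlp(torch.randn(2, 16, 768))   # y.shape == (2, 16, 768)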


class SelfAttention(nn.Module):
    """
    Multi-head self-attention with flash attention support.

    Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L68-L92
    """

    def __init__(
        self,
        d,
        d_head,
        attn_qkv_bias=False,
        attn_dropout=0.0,
        attn_fc_bias=False,
        attn_fc_dropout=0.0,
        use_flashatt_v2=True,
    ):
        """
        Initialize self-attention layer.

        Args:
            d: Token dimension
            d_head: Head dimension
            attn_qkv_bias: Whether to use bias in QKV projection
            attn_dropout: Attention dropout probability (only applied in the
                non-flash SDPA path)
            attn_fc_bias: Whether to use bias in output projection
            attn_fc_dropout: Output projection dropout probability
            use_flashatt_v2: Whether to use flash attention v2 via xformers;
                otherwise fall back to torch.nn.functional.scaled_dot_product_attention
        """
        super().__init__()
        assert d % d_head == 0, f"Token dimension {d} should be divisible by head dimension {d_head}"

        self.d = d
        self.d_head = d_head
        self.attn_dropout = attn_dropout
        self.use_flashatt_v2 = use_flashatt_v2

        # Fused projection producing queries, keys, and values in a single matmul.
        self.to_qkv = nn.Linear(d, 3 * d, bias=attn_qkv_bias)

        # Output projection applied after the heads are re-concatenated.
        self.fc = nn.Linear(d, d, bias=attn_fc_bias)
        self.attn_fc_dropout = nn.Dropout(attn_fc_dropout)

    def forward(self, x, subset_attention_size=None):
        """
        Forward pass through self-attention.

        Args:
            x: Input tensor of shape (batch, seq_len, d)
            subset_attention_size: If provided and smaller than seq_len, the first
                `subset_attention_size` tokens attend only to each other, while the
                remaining tokens attend to the full sequence.

        Returns:
            Output tensor of shape (batch, seq_len, d)
        """
        # Project to queries, keys, and values, each of shape (batch, seq_len, d).
        q, k, v = self.to_qkv(x).split(self.d, dim=2)

        if self.use_flashatt_v2:
            # xformers expects (batch, seq_len, num_heads, head_dim).
            q, k, v = map(
                lambda t: rearrange(t, "b l (nh dh) -> b l nh dh", dh=self.d_head),
                (q, k, v),
            )

            if subset_attention_size is not None and subset_attention_size < q.shape[1]:
                # First `subset_attention_size` tokens attend only within the subset.
                x_subset = xops.memory_efficient_attention(
                    q[:, :subset_attention_size, :, :].contiguous(),
                    k[:, :subset_attention_size, :, :].contiguous(),
                    v[:, :subset_attention_size, :, :].contiguous(),
                    attn_bias=None,
                    op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
                )
                # Remaining tokens attend to the full sequence (all keys/values).
                x_rest = xops.memory_efficient_attention(
                    q[:, subset_attention_size:, :, :].contiguous(),
                    k,
                    v,
                    attn_bias=None,
                    op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
                )
                x = torch.cat([x_subset, x_rest], dim=1)
            else:
                x = xops.memory_efficient_attention(
                    q, k, v,
                    attn_bias=None,
                    op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
                )

            x = rearrange(x, "b l nh dh -> b l (nh dh)")
        else:
            # PyTorch SDPA expects (batch, num_heads, seq_len, head_dim).
            q, k, v = (
                rearrange(q, "b l (nh dh) -> b nh l dh", dh=self.d_head),
                rearrange(k, "b l (nh dh) -> b nh l dh", dh=self.d_head),
                rearrange(v, "b l (nh dh) -> b nh l dh", dh=self.d_head),
            )

            dropout_p = self.attn_dropout if self.training else 0.0

            if subset_attention_size is not None and subset_attention_size < q.shape[2]:
                # First `subset_attention_size` tokens attend only within the subset.
                x_subset = F.scaled_dot_product_attention(
                    q[:, :, :subset_attention_size, :].contiguous(),
                    k[:, :, :subset_attention_size, :].contiguous(),
                    v[:, :, :subset_attention_size, :].contiguous(),
                    dropout_p=dropout_p,
                )
                # Remaining tokens attend to the full sequence (all keys/values).
                x_rest = F.scaled_dot_product_attention(
                    q[:, :, subset_attention_size:, :].contiguous(),
                    k, v,
                    dropout_p=dropout_p,
                )
                x = torch.cat([x_subset, x_rest], dim=2)
            else:
                x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

            x = rearrange(x, "b nh l dh -> b l (nh dh)")

        return self.attn_fc_dropout(self.fc(x))
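
# Usage sketch (illustrative only): with subset_attention_size set, the sequence is
# split into two groups; the first group attends only within itself, the second
# attends over the full sequence. The SDPA fallback below runs on CPU, whereas the
# default flash path requires xformers on a CUDA device.
#
#     attn = SelfAttention(d=768, d_head=64, use_flashatt_v2=False)
#     y = attn(torch.randn(2, 100, 768), subset_attention_size=64)  # (2, 100, 768)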


class TransformerBlock(nn.Module):
    """
    Standard transformer block with pre-normalization.

    Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L95-L113
    """

    def __init__(
        self,
        d,
        d_head,
        ln_bias=False,
        attn_qkv_bias=False,
        attn_dropout=0.0,
        attn_fc_bias=False,
        attn_fc_dropout=0.0,
        mlp_ratio=4,
        mlp_bias=False,
        mlp_dropout=0.0,
    ):
        """
        Initialize transformer block.

        Args:
            d: Token dimension
            d_head: Attention head dimension
            ln_bias: Whether to use bias in layer norm
            attn_qkv_bias: Whether to use bias in attention QKV projection
            attn_dropout: Attention dropout probability
            attn_fc_bias: Whether to use bias in attention output projection
            attn_fc_dropout: Attention output dropout probability
            mlp_ratio: MLP hidden dimension ratio
            mlp_bias: Whether to use bias in MLP layers
            mlp_dropout: MLP dropout probability
        """
        super().__init__()

        # Pre-normalization layers for the attention and MLP sub-blocks.
        self.norm1 = nn.LayerNorm(d, bias=ln_bias)
        self.norm2 = nn.LayerNorm(d, bias=ln_bias)

        self.attn = SelfAttention(
            d=d,
            d_head=d_head,
            attn_qkv_bias=attn_qkv_bias,
            attn_dropout=attn_dropout,
            attn_fc_bias=attn_fc_bias,
            attn_fc_dropout=attn_fc_dropout,
        )

        self.mlp = MLP(
            d=d,
            mlp_ratio=mlp_ratio,
            mlp_bias=mlp_bias,
            mlp_dropout=mlp_dropout,
        )

    def forward(self, x, subset_attention_size=None):
        """
        Forward pass through transformer block.

        Args:
            x: Input tensor of shape (batch, seq_len, d)
            subset_attention_size: Optional subset size forwarded to SelfAttention

        Returns:
            Output tensor of shape (batch, seq_len, d)
        """
        # Pre-norm residual attention.
        x = x + self.attn(self.norm1(x), subset_attention_size=subset_attention_size)

        # Pre-norm residual MLP.
        x = x + self.mlp(self.norm2(x))

        return x
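
# Usage sketch (illustrative only; depth and dimensions below are hypothetical):
# blocks are typically stacked in an nn.ModuleList and applied sequentially. The
# default configuration uses the xformers flash path, so this assumes a CUDA device.
#
#     blocks = nn.ModuleList([TransformerBlock(d=768, d_head=64) for _ in range(12)])
#     x = torch.randn(2, 256, 768, device="cuda")
#     for block in blocks:
#         x = block(x)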