# modell-name/src/models/transformer.py
# Commit 8006486 (verified) by RabidUmarell: "Add model checkpoint and source"
"""Causal Transformer for the selective copy task."""
import torch
import torch.nn as nn
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention built on ``scaled_dot_product_attention``.

    Args:
        d_model: Width of the input/output features; must be a positive
            multiple of ``n_heads``.
        n_heads: Number of attention heads (must be > 0).
        dropout: Dropout probability applied inside attention, only while
            the module is in training mode.

    Raises:
        ValueError: If ``n_heads`` is not positive or ``d_model`` is not
            divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        # Explicit checks instead of ``assert``: asserts vanish under
        # ``python -O``, and a bad head count would then only surface later
        # as a confusing ``view`` error inside ``forward``.
        if n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {n_heads}")
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        # One fused projection produces queries, keys and values together.
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.attn_drop_p = dropout

    def forward(self, x):
        """Attend causally over ``x``.

        Args:
            x: Tensor of shape ``(B, T, C)`` with ``C == d_model``.

        Returns:
            Tensor of shape ``(B, T, C)``; position ``t`` depends only on
            positions ``<= t`` of ``x``.
        """
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=-1)

        def split_heads(t):
            # (B, T, C) -> (B, n_heads, T, d_head)
            return t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        out = torch.nn.functional.scaled_dot_product_attention(
            split_heads(q), split_heads(k), split_heads(v),
            is_causal=True,
            # Attention dropout only during training; eval stays deterministic.
            dropout_p=self.attn_drop_p if self.training else 0.0,
        )
        # Merge heads back: (B, n_heads, T, d_head) -> (B, T, C).
        return self.proj(out.transpose(1, 2).contiguous().view(B, T, C))
class TransformerBlock(nn.Module):
    """Pre-LayerNorm Transformer layer: causal attention, then a feed-forward MLP.

    Each sub-layer sees a LayerNorm-ed copy of the residual stream, and its
    output is added back onto that stream.

    Args:
        d_model: Width of the residual stream.
        n_heads: Number of heads for ``CausalSelfAttention``.
        dropout: Probability forwarded to the attention module and used by
            the dropout layer at the end of the MLP.
    """

    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        expanded = 4 * d_model  # MLP hidden width (4x expansion)
        # Submodules are created in the order ln1, attn, ln2, mlp and keep
        # their attribute names, so random initialisation and state_dict
        # keys line up with existing checkpoints.
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        feed_forward = [
            nn.Linear(d_model, expanded),
            nn.GELU(),
            nn.Linear(expanded, d_model),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Apply the attention and MLP residual sub-layers to ``x``."""
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))
class TransformerModel(nn.Module):
    """Causal Transformer that maps per-step input features to per-step outputs.

    Inputs are linearly projected to ``d_model`` and summed with learned
    absolute position embeddings. They then pass through ``n_layers``
    ``TransformerBlock`` layers, a final LayerNorm, and a linear head to
    ``d_output``.

    Args:
        d_input: Feature size of each input step.
        d_model: Hidden width.
        d_output: Feature size of each output step.
        n_layers: Number of ``TransformerBlock`` layers.
        n_heads: Attention heads per block.
        max_len: Longest accepted sequence (rows in the position table).
        dropout: Dropout probability forwarded to every block.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, d_input, d_model, d_output, n_layers=2, n_heads=4, max_len=4096, dropout=0.0, **kwargs):
        super().__init__()
        # Construction order and attribute names are kept; they determine
        # parameter initialisation order and the state_dict layout.
        self.input_proj = nn.Linear(d_input, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.blocks = nn.ModuleList(
            TransformerBlock(d_model, n_heads, dropout) for _ in range(n_layers)
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, d_output)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a ``(B, T, d_input)`` tensor to a ``(B, T, d_output)`` tensor.

        Raises:
            ValueError: If ``T`` is larger than the position-embedding table.
        """
        _, seq_len, _ = x.shape
        limit = self.pos_emb.num_embeddings
        if seq_len > limit:
            raise ValueError(f"Sequence length {seq_len} exceeds max_len {limit}")
        positions = torch.arange(seq_len, device=x.device)
        hidden = self.input_proj(x) + self.pos_emb(positions)
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.head(self.ln_f(hidden))

    @staticmethod
    def extra_kwargs(model_cfg) -> dict:
        """Return the model-specific constructor kwargs taken from ``model_cfg``."""
        return dict(n_heads=model_cfg.n_heads)