"""Causal Transformer for the selective copy task.""" import torch import torch.nn as nn class CausalSelfAttention(nn.Module): def __init__(self, d_model, n_heads, dropout=0.0): super().__init__() assert d_model % n_heads == 0 self.n_heads = n_heads self.d_head = d_model // n_heads self.qkv = nn.Linear(d_model, 3 * d_model, bias=False) self.proj = nn.Linear(d_model, d_model, bias=False) self.attn_drop_p = dropout def forward(self, x): B, T, C = x.shape q, k, v = self.qkv(x).split(C, dim=-1) def split_heads(t): return t.view(B, T, self.n_heads, self.d_head).transpose(1, 2) out = torch.nn.functional.scaled_dot_product_attention( split_heads(q), split_heads(k), split_heads(v), is_causal=True, dropout_p=self.attn_drop_p if self.training else 0.0, ) return self.proj(out.transpose(1, 2).contiguous().view(B, T, C)) class TransformerBlock(nn.Module): def __init__(self, d_model, n_heads, dropout=0.0): super().__init__() self.ln1 = nn.LayerNorm(d_model) self.attn = CausalSelfAttention(d_model, n_heads, dropout) self.ln2 = nn.LayerNorm(d_model) self.mlp = nn.Sequential( nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model), nn.Dropout(dropout), ) def forward(self, x): x = x + self.attn(self.ln1(x)) x = x + self.mlp(self.ln2(x)) return x class TransformerModel(nn.Module): def __init__(self, d_input, d_model, d_output, n_layers=2, n_heads=4, max_len=4096, dropout=0.0, **kwargs): super().__init__() self.input_proj = nn.Linear(d_input, d_model) self.pos_emb = nn.Embedding(max_len, d_model) self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads, dropout) for _ in range(n_layers)]) self.ln_f = nn.LayerNorm(d_model) self.head = nn.Linear(d_model, d_output) def forward(self, x: torch.Tensor) -> torch.Tensor: """(B, T, d_input) → (B, T, d_output)""" B, T, _ = x.shape if T > self.pos_emb.num_embeddings: raise ValueError(f"Sequence length {T} exceeds max_len {self.pos_emb.num_embeddings}") h = self.input_proj(x) + self.pos_emb(torch.arange(T, device=x.device)) for block in self.blocks: h = block(h) return self.head(self.ln_f(h)) @staticmethod def extra_kwargs(model_cfg) -> dict: return {"n_heads": model_cfg.n_heads}