| import torch |
| import torch.nn as nn |
| import math |
|
|
|
|
class MaskedMultiHeadedSelfAttention(nn.Module):
    """Causal multi-head self-attention backed by ``scaled_dot_product_attention``.

    Queries, keys and values come from one fused linear projection (``W_qkv``)
    and are split into ``num_heads`` heads of size ``d_out // num_heads``.
    Causal masking is delegated to ``scaled_dot_product_attention`` via
    ``is_causal=True``.

    Args:
        d_in: Size of the last dimension of the input tensor.
        d_out: Output size; must be divisible by ``num_heads``.
        context_length: Accepted for interface compatibility but not used here,
            because the causal mask is produced per call by ``is_causal=True``.
        dropout: Dropout probability passed to ``scaled_dot_product_attention``
            while the module is in training mode.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused QKV projection has a bias term.

    Raises:
        ValueError: If ``d_out`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if d_out % num_heads != 0:
            raise ValueError("d_out must be divisible by num_heads")

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Single fused projection producing [q | k | v] along the last dimension.
        self.W_qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out, bias=False)
        # Only ``self.dropout.p`` is read (and handed to SDPA); the module itself
        # is kept so the probability lives in one place.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        batch_size, num_tokens, _ = x.shape

        q, k, v = self.W_qkv(x).split(self.d_out, dim=-1)

        # (batch, tokens, d_out) -> (batch, heads, tokens, head_dim), the layout
        # scaled_dot_product_attention expects.
        q, k, v = (
            t.view(batch_size, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
            for t in (q, k, v)
        )

        context_vec = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            # Attention dropout only while training; eval must be deterministic.
            dropout_p=self.dropout.p if self.training else 0.0,
            is_causal=True,
        )

        # Merge heads back: (batch, heads, tokens, head_dim) -> (batch, tokens, d_out).
        context_vec = context_vec.transpose(1, 2).reshape(batch_size, num_tokens, self.d_out)
        return self.out_proj(context_vec)
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise MLP: widen 4x, tanh-approximated GELU, project back.

    Args:
        configuration: Mapping providing ``"embedding_dim"``, the input and
            output width of the MLP.
    """

    def __init__(self, configuration):
        super().__init__()
        width = configuration["embedding_dim"]
        hidden_width = width * 4
        # Keep this exact index layout (0: expand, 1: GELU, 2: contract): the
        # residual-scaled init in this file matches parameters ending in
        # "layers.2.weight".
        stages = [
            nn.Linear(width, hidden_width),
            nn.GELU(approximate='tanh'),
            nn.Linear(hidden_width, width),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the expand -> GELU -> contract stack."""
        out = self.layers(x)
        return out
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: an attention sublayer, then an MLP sublayer.

    Each sublayer is applied as a residual step ``x + dropout(sublayer(norm(x)))``.

    Args:
        configuration: Mapping providing ``"embedding_dim"``, ``"context_length"``,
            ``"n_heads"``, ``"dropout_rate"`` and ``"qkv_bias"`` (plus whatever
            ``FeedForward`` reads).
    """

    def __init__(self, configuration):
        super().__init__()
        dim = configuration["embedding_dim"]
        self.attention = MaskedMultiHeadedSelfAttention(
            d_in=dim,
            d_out=dim,
            context_length=configuration["context_length"],
            num_heads=configuration["n_heads"],
            dropout=configuration["dropout_rate"],
            qkv_bias=configuration["qkv_bias"],
        )
        self.feed_forward = FeedForward(configuration)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        # One dropout module serves both residual branches; nn.Dropout holds no
        # parameters, so sharing it is equivalent to two separate instances.
        self.drop_shortcut = nn.Dropout(configuration["dropout_rate"])

    def forward(self, x):
        """Apply the attention residual step, then the MLP residual step."""
        x = self._residual(x, self.norm1, self.attention)
        return self._residual(x, self.norm2, self.feed_forward)

    def _residual(self, x, norm, sublayer):
        """Return ``dropout(sublayer(norm(x))) + x`` (one pre-norm residual step)."""
        return self.drop_shortcut(sublayer(norm(x))) + x
|
|
|
|
class LanguageModel(nn.Module):
    """GPT-style decoder-only language model.

    Token embeddings and learned position embeddings are summed, passed through
    ``n_layers`` transformer blocks and a final LayerNorm, and projected to
    vocabulary logits by an output head whose weight is tied to the token
    embedding.

    Args:
        configuration: Mapping with ``"vocab_size"``, ``"embedding_dim"``,
            ``"context_length"``, ``"dropout_rate"`` and ``"n_layers"``, plus the
            keys each ``TransformerBlock`` reads (``"n_heads"``, ``"qkv_bias"``).
    """

    def __init__(self, configuration):
        super().__init__()
        self.config = configuration

        self.token_embedding = nn.Embedding(configuration["vocab_size"], configuration["embedding_dim"])
        self.pos_embedding = nn.Embedding(configuration["context_length"], configuration["embedding_dim"])
        self.drop_embedding = nn.Dropout(configuration["dropout_rate"])

        self.transformer_blocks = nn.Sequential(
            *[TransformerBlock(configuration) for _ in range(configuration["n_layers"])]
        )

        self.final_norm = nn.LayerNorm(configuration["embedding_dim"])
        self.out_head = nn.Linear(configuration["embedding_dim"], configuration["vocab_size"], bias=False)

        # Weight tying: the output head and the token embedding share one Parameter.
        self.out_head.weight = self.token_embedding.weight

        # Base init for every Linear/Embedding. The tied matrix is visited twice
        # (as out_head and as token_embedding); both paths use N(0, 0.02).
        self.apply(self._init_weights)

        # GPT-2 style scaled init for projections that write into the residual
        # stream (attention ``out_proj`` and the MLP's ``layers.2``): each block
        # contributes two residual additions, hence 1/sqrt(2 * n_layers).
        # With no blocks there is nothing to rescale, and the formula would
        # divide by zero, so it is skipped.
        n_layers = configuration["n_layers"]
        if n_layers > 0:
            residual_std = 0.02 / math.sqrt(2 * n_layers)
            for name, param in self.named_parameters():
                if name.endswith(("out_proj.weight", "layers.2.weight")):
                    torch.nn.init.normal_(param, mean=0.0, std=residual_std)

    def _init_weights(self, module):
        """Initialize weights following GPT-2: N(0, 0.02) for all projections."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, in_idx):
        """Compute logits for every position of a batch of token-id sequences.

        Args:
            in_idx: Integer token ids of shape ``(batch, seq_len)``.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the number of learned positions
                (the configured ``context_length``).
        """
        _, seq_len = in_idx.shape
        max_len = self.pos_embedding.num_embeddings
        # Fail clearly here instead of with an opaque out-of-range index error
        # (a device-side assert on CUDA) from the position embedding lookup.
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds context_length {max_len}"
            )

        tok_embeds = self.token_embedding(in_idx)
        pos_embeds = self.pos_embedding(torch.arange(seq_len, device=in_idx.device))
        x = self.drop_embedding(tok_embeds + pos_embeds)
        x = self.transformer_blocks(x)
        x = self.final_norm(x)
        return self.out_head(x)