""" Architecture of the TransformerDecoder """
import torch
import torch.nn as nn
from torch.nn import functional as F
class TransformerDecoder(nn.Module):
    """ GPT-style decoder-only language model.

    Maps a (B, T) tensor of token indices to next-token logits of shape
    (B, T, vocab_size), using learned token + positional embeddings, a stack
    of pre-norm transformer blocks, a final LayerNorm, and a linear LM head.
    """

    def __init__(self, vocab_size, hyperparam_cfg, device):
        """
        Args:
            vocab_size: number of tokens in the vocabulary.
            hyperparam_cfg: config object exposing embedding_dim, num_layers,
                context_len (plus the fields TFBlock reads).
            device: torch device the position indices are created on; the
                caller is responsible for moving the module itself.
        """
        super().__init__()
        self.device = device
        # model hyperparameters
        embedding_dim = hyperparam_cfg.embedding_dim
        num_layers = hyperparam_cfg.num_layers
        self.context_len = hyperparam_cfg.context_len
        # lookup table of tokens is used so that each token reads the logits for the next token
        self.token_embedding_table = nn.Embedding(vocab_size, embedding_dim)
        # pos embedding table adds information about the position of each token in the context
        self.pos_embedding_table = nn.Embedding(self.context_len, embedding_dim)
        # stack multiple transformer blocks to increase model capacity
        self.tfblocks = nn.Sequential(*[TFBlock(hyperparam_cfg) for _ in range(num_layers)])
        # final normalization and linear layer to produce logits for each token in the vocabulary
        self.ln_f = nn.LayerNorm(embedding_dim)
        self.lm_head = nn.Linear(embedding_dim, vocab_size)
        # GPT-style weight initialization (N(0, 0.02)) for faster, more stable training
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """ Initialize Linear/Embedding weights to N(0, 0.02) and biases to zero. """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx):
        """
        The forward pass of the model returns the logits of shape (B,T,C)
        # where: B=batch_size T=context_len C=vocab_size
        """
        # idx is a (B,T) tensor of integers which are indices in the current context
        B, T = idx.shape
        token_embd = self.token_embedding_table(idx)  # (batch_size, context_len, embedding_dim)
        # create positions directly on the target device (avoids a CPU alloc + transfer)
        positions = torch.arange(T, device=self.device)  # tensor([0, 1, 2, ..., T-1])
        pos_embd = self.pos_embedding_table(positions)  # (context_len, embedding_dim)
        x = token_embd + pos_embd  # (batch_size, context_len, embedding_dim)
        x = self.tfblocks(x)  # (batch_size, context_len, embedding_dim)
        x = self.ln_f(x)  # (batch_size, context_len, embedding_dim)
        logits = self.lm_head(x)  # (batch_size, context_len, vocab_size)
        return logits

    @torch.no_grad()  # generation is inference-only; skip building the autograd graph
    def generate(self, idx, max_new_tokens):
        """ Autoregressively sample max_new_tokens tokens, returning idx extended to (B, T+max_new_tokens). """
        for _ in range(max_new_tokens):
            # crop idx to the last context_len tokens (pos table only covers context_len positions)
            idx_context = idx[:, -self.context_len:]
            # get the predictions
            logits = self(idx_context)  # (B,T,C)
            # focus only on the last time step
            logits = logits[:, -1, :]  # (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution to get the next token index
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx
class TFBlock(nn.Module):
    """ Single transformer block: communication (attention) followed by computation (dense). """

    def __init__(self, hyperparam_cfg):
        super().__init__()
        # unpack the hyperparameters this block needs
        emb_dim = hyperparam_cfg.embedding_dim
        n_heads = hyperparam_cfg.num_heads
        drop_p = hyperparam_cfg.dropout
        # attention output width matches emb_dim: n_heads * (emb_dim // n_heads) = emb_dim
        self.sa_heads = MultiHeadAttention(
            num_heads=n_heads,
            head_size=emb_dim // n_heads,
            embedding_dim=emb_dim,
            context_len=hyperparam_cfg.context_len,
            dropout=drop_p,
        )
        self.feed_forward = FeedForward(emb_dim, drop_p)
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)

    def forward(self, x):
        """ Pre-norm residual layout: x + attn(ln1(x)), then + ffwd(ln2(.)). """
        out = x + self.sa_heads(self.ln1(x))
        out = out + self.feed_forward(self.ln2(out))
        return out
class MultiHeadAttention(nn.Module):
    """ Multiple heads of self-attention running in parallel. """

    def __init__(self, num_heads, head_size, embedding_dim, context_len, dropout):
        super().__init__()
        heads = [
            AttentionHead(embedding_dim, head_size, context_len, dropout)
            for _ in range(num_heads)
        ]
        self.heads = nn.ModuleList(heads)
        # projection brings the concatenated heads back to embedding_dim so the
        # residual connection in the enclosing block lines up
        self.projection = nn.Linear(num_heads * head_size, embedding_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """ Run every head on x, concatenate along channels, project, dropout. """
        per_head = [head(x) for head in self.heads]
        concat = torch.cat(per_head, dim=-1)  # (batch, context_len, num_heads * head_size)
        return self.dropout(self.projection(concat))  # (batch, context_len, embedding_dim)
class AttentionHead(nn.Module):
    """ One head of causal (masked) self-attention.

    Projects the input to queries/keys/values of width head_size, computes
    scaled dot-product attention with a lower-triangular causal mask, and
    returns the attention-weighted values of shape (batch, context_len, head_size).
    """

    def __init__(self, embedding_dim, head_size, context_len, dropout):
        """
        Args:
            embedding_dim: channel width of the incoming activations.
            head_size: output width of this head's q/k/v projections.
            context_len: maximum sequence length the causal mask must cover.
            dropout: dropout probability applied to the attention weights.
        """
        super().__init__()
        self.queries = nn.Linear(embedding_dim, head_size, bias=False)
        self.keys = nn.Linear(embedding_dim, head_size, bias=False)
        self.values = nn.Linear(embedding_dim, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        # lower triangular matrix is used to mask out future tokens in the attention mechanism
        self.register_buffer("mask", torch.tril(torch.ones(context_len, context_len)))

    def forward(self, x):
        B, T, C = x.shape  # (batch_size, context_len, embedding_dim)
        q = self.queries(x)  # (batch, context_len, head_size)
        k = self.keys(x)  # (batch, context_len, head_size)
        v = self.values(x)  # (batch, context_len, head_size)
        # compute attention matrix (key and query dot product)
        weights = q @ k.transpose(-2, -1)  # (B,T,hs) @ (B,hs,T) -> (B,T,T)
        # scale by 1/sqrt(head_size) to prevent large dot products (stabilizes gradients).
        # BUG FIX: previously scaled by C**-0.5 (embedding_dim), contradicting the
        # scaled-dot-product formula (1/sqrt(d_k)) and this comment's own intent.
        head_size = k.shape[-1]
        weights = weights * head_size**-0.5
        # mask replaces 0 with -inf and keeps 1 as is (ones are on and below diagonal; zeros above diagonal)
        weights = weights.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
        # softmax along the last dimension to get probabilities per row
        weights = F.softmax(weights, dim=-1)
        weights = self.dropout(weights)
        output = weights @ v  # (B,T,T) @ (B,T,hs) -> (B,T,hs) = (batch, context_len, head_size)
        return output
class FeedForward(nn.Module):
    """ Position-wise MLP: expand 4x, ReLU, project back, dropout. """

    def __init__(self, embedding_dim, dropout):
        super().__init__()
        # the 4x expansion factor follows the original transformer paper
        hidden_dim = 4 * embedding_dim
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """ Apply the MLP independently at every position of x. """
        return self.net(x)