Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| import tqdm | |
class Head(nn.Module):
    """A single causal self-attention head.

    Projects the input to query/key/value vectors of width ``head_size``,
    lets every position attend only to itself and earlier positions (via a
    lower-triangular mask sized ``block_size``), and returns the
    attention-weighted values.
    """

    def __init__(self, n_embd, head_size, block_size, dropout):
        super().__init__()
        # Creation order (key, query, value) determines parameter-init RNG order.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        causal_mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', causal_mask)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Unpacking enforces a 3-D (B, T, C) input.
        _, seq_len, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)
        values = self.value(x)
        scale = keys.shape[-1] ** -0.5
        scores = (queries @ keys.transpose(-2, -1)) * scale
        hidden_future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(hidden_future, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ values
class MultiHeadAttention(nn.Module):
    """Run ``n_head`` independent ``Head`` modules side by side.

    Every head receives the full input. Their outputs (each ``n_embd // n_head``
    wide) are concatenated along the feature axis, projected back to
    ``n_embd`` and passed through dropout.
    """

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        assert n_embd % n_head == 0, f"n_embd ({n_embd}) must be divisible by num_heads ({n_head})"
        self.n_embd = n_embd
        self.n_head = n_head
        self.head_size = n_embd // n_head
        self.heads = nn.ModuleList(
            Head(n_embd, self.head_size, block_size, dropout) for _ in range(n_head)
        )
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(n_embd -> 4*n_embd) -> ReLU -> Linear(4*n_embd -> n_embd) -> Dropout."""

    def __init__(self, n_embd, dropout):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
class Block(nn.Module):
    """Pre-norm transformer block.

    Self-attention ("communication") and then a feed-forward MLP
    ("computation"). Each is applied to a LayerNorm-ed copy of its input and
    added back as a residual.
    """

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        self.sa = MultiHeadAttention(n_embd, n_head, block_size, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        attn_out = self.sa(self.ln1(x))
        x = x + attn_out
        mlp_out = self.ffwd(self.ln2(x))
        return x + mlp_out
class RoPE(nn.Module):
    """Rotary Positional Encoding (RoPE) layer.

    Rotates each consecutive feature pair ``(x[..., 2i], x[..., 2i + 1])`` by
    the angle ``pos * inv_freq[i]``, where ``inv_freq[i] = base ** (-2i / embd_dim)``
    and ``pos`` is the absolute position of the row along the second-to-last
    axis (``offset``, ``offset + 1``, ...). Applied to queries and keys, the
    rotation makes their dot product depend only on the relative distance
    between the two positions.

    Args:
        embd_dim: Size of the last axis of the inputs; must be even.
        max_freq: Accepted for backward compatibility; not used.
        base: Wavelength base of the rotation frequencies.

    Raises:
        ValueError: if ``embd_dim`` is odd.
    """

    def __init__(self, embd_dim, max_freq=10, base=10000.0):
        super().__init__()
        if embd_dim % 2 != 0:
            raise ValueError(f"embd_dim ({embd_dim}) must be even for RoPE")
        self.embd_dim = embd_dim
        self.max_freq = max_freq
        self.base = base
        exponents = torch.arange(0, embd_dim, 2, dtype=torch.float32) / embd_dim
        # Non-persistent buffer: it follows .to(device) but adds no state_dict entry.
        self.register_buffer('inv_freq', 1.0 / (base ** exponents), persistent=False)

    def forward(self, x, offset=0):
        """Return ``x`` with rotary position embedding applied.

        Args:
            x: Tensor of shape ``(..., seq_len, embd_dim)``.
            offset: Absolute position of the first row, e.g. the number of
                tokens already held in a KV cache.
        """
        seq_len = x.shape[-2]
        inv_freq = self.inv_freq.to(x.device)
        positions = torch.arange(offset, offset + seq_len, device=x.device, dtype=inv_freq.dtype)
        angles = positions[:, None] * inv_freq[None, :]  # (seq_len, embd_dim // 2)
        # Angles are computed in float32, then cast so half-precision inputs stay half.
        cos = angles.cos().to(x.dtype)
        sin = angles.sin().to(x.dtype)
        x_even = x[..., 0::2]
        x_odd = x[..., 1::2]
        rotated = torch.stack((x_even * cos - x_odd * sin, x_even * sin + x_odd * cos), dim=-1)
        # Interleave the rotated pairs back into the original feature order.
        return rotated.flatten(-2)


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (RMSNorm).

    Divides each vector along the last axis by its root mean square (there is
    no mean subtraction, unlike LayerNorm). It then applies the learned gain
    ``gamma`` and offset ``beta``.

    Args:
        embd_dim: Size of the last axis of the inputs.
        epsilon: Added to the mean square before the square root for stability.
    """

    def __init__(self, embd_dim, epsilon=1e-8):
        super().__init__()
        self.embd_dim = embd_dim
        self.epsilon = epsilon
        self.gamma = nn.Parameter(torch.ones(embd_dim))
        # Not part of canonical RMSNorm. It is kept, zero-initialised so it is a
        # no-op at start, so the parameter set / state_dict keys are unchanged.
        self.beta = nn.Parameter(torch.zeros(embd_dim))

    def forward(self, x: torch.Tensor):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(mean_square + self.epsilon)
        return x * self.gamma + self.beta


class LlamaFFN(nn.Module):
    """Feed-forward network of the LLAMA model with SwiGLU activation.

    ``linear1`` widens the input to ``4 * n_embd``. SwiGLU splits that into
    two halves and gates one with the other, which leaves ``2 * n_embd``
    features. ``linear2`` projects those back to ``n_embd``.
    """

    def __init__(self, n_embd, dropout):
        super().__init__()
        self.linear1 = nn.Linear(n_embd, 4 * n_embd)
        # SwiGLU halves the width produced by linear1, so linear2 takes 2 * n_embd.
        self.linear2 = nn.Linear(2 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def swiglu(self, x):
        """Split ``x`` in two along the last axis and return ``x1 * silu(x2)``."""
        x1, x2 = torch.chunk(x, 2, dim=-1)
        return x1 * F.silu(x2)

    def forward(self, x):
        x = self.linear1(x)
        x = self.swiglu(x)
        x = self.dropout(x)
        return self.linear2(x)


class AttentionHeadWithKVCacheAndRoPE(nn.Module):
    """One causal self-attention head with RoPE and an optional key/value cache.

    Rotary embeddings are applied to the queries and keys, not to the output.
    They use absolute positions that continue from whatever is already
    cached, so incremental decoding matches a single full pass.

    Args:
        n_embd: Input feature size.
        head_size: Width of this head's query/key/value projections.
        block_size: Maximum total sequence length (cached + new tokens).
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, n_embd, head_size, block_size, dropout):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)
        self.pe = RoPE(head_size)

    def forward(self, x, kv_cache):
        """Attend from the new tokens in ``x`` over cached and new positions.

        Args:
            x: Tensor of shape ``(B, T, n_embd)`` holding the new tokens.
            kv_cache: ``None`` to disable caching, or a dict that is updated in
                place. This head stores its rotated keys and values as a
                ``(k, v)`` tuple under its own entry (keyed by the head
                module), so one dict can be shared by all heads and layers.
                Start with ``{}`` and pass the same dict on every later step.

        Returns:
            Tensor of shape ``(B, T, head_size)``. The residual connection is
            left to the caller, since ``x`` is ``n_embd`` wide.

        Raises:
            ValueError: if cached + new tokens exceed ``block_size``.
        """
        _, T, _ = x.shape
        past = kv_cache.get(self) if kv_cache is not None else None
        past_len = 0 if past is None else past[0].shape[1]
        total_len = past_len + T
        block_size = self.tril.size(0)
        if total_len > block_size:
            raise ValueError(
                f"sequence length {total_len} (cached {past_len} + new {T}) "
                f"exceeds block_size {block_size}"
            )

        # New tokens sit at absolute positions past_len .. total_len - 1.
        q = self.pe(self.query(x), offset=past_len)
        k = self.pe(self.key(x), offset=past_len)
        v = self.value(x)
        if past is not None:
            k = torch.cat([past[0], k], dim=1)
            v = torch.cat([past[1], v], dim=1)
        if kv_cache is not None:
            kv_cache[self] = (k, v)

        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        # Query row i is absolute position past_len + i; it may see keys 0..past_len + i.
        causal = self.tril[past_len:total_len, :total_len]
        wei = wei.masked_fill(causal == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        return wei @ v
class MultiHeadAttentionWithKVCacheAndRoPE(nn.Module):
    """Run ``n_head`` RoPE / KV-cache attention heads side by side.

    Every head receives the full input and the same ``kv_cache`` object,
    which is forwarded unchanged. The head outputs are concatenated along the
    feature axis, projected back to ``n_embd`` and passed through dropout.
    """

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        assert n_embd % n_head == 0, f"n_embd ({n_embd}) must be divisible by num_heads ({n_head})"
        self.n_embd = n_embd
        self.n_head = n_head
        self.head_size = n_embd // n_head
        self.heads = nn.ModuleList(
            AttentionHeadWithKVCacheAndRoPE(n_embd, self.head_size, block_size, dropout)
            for _ in range(n_head)
        )
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, kv_cache):
        per_head = [head(x, kv_cache) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)
class LlamaBlock(nn.Module):
    """Pre-norm LLaMA-style decoder block.

    Step 1: normalise with ``ln1`` (an RMSNorm), run multi-head attention
    with ``kv_cache`` forwarded, and add the result back as a residual.
    Step 2: normalise with ``ln2`` (an RMSNorm), run ``LlamaFFN``, and add
    that result back as a residual.
    """

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        self.ln1 = RMSNorm(n_embd)
        self.sa = MultiHeadAttentionWithKVCacheAndRoPE(n_embd, n_head, block_size, dropout)
        self.ln2 = RMSNorm(n_embd)
        self.ffwd = LlamaFFN(n_embd, dropout)

    def forward(self, x, kv_cache):
        attn_out = self.sa(self.ln1(x), kv_cache)
        x = x + attn_out
        ffn_out = self.ffwd(self.ln2(x))
        return x + ffn_out