| """Full definition of a LLaMA Language Model, all of it in this single file. |
| |
| Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT. |
| """ |
| |
| import math |
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn import functional as F |
| from typing_extensions import Self |
|
|
|
|
@dataclass
class LLaMAConfig:
    block_size: int = 4096
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    n_embd: int = 4096

    @classmethod
    def from_name(cls, name: str) -> Self:
        return cls(**llama_configs[name])


llama_configs = {
    "7B": dict(n_layer=32, n_head=32, n_embd=4096),
    "13B": dict(n_layer=40, n_head=40, n_embd=5120),
    "30B": dict(n_layer=60, n_head=52, n_embd=6656),
    "65B": dict(n_layer=80, n_head=64, n_embd=8192),
}
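
# Example (illustrative): `LLaMAConfig.from_name` fills in the size-dependent
# fields and keeps the defaults for the rest, e.g.
#   config = LLaMAConfig.from_name("13B")
#   # config.n_layer == 40, config.n_embd == 5120,
#   # config.block_size == 4096, config.vocab_size == 32000

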
class LLaMA(nn.Module):
    def __init__(self, config: LLaMAConfig) -> None:
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=RMSNorm(config.n_embd),
            )
        )

        # Projections between a 512-dimensional external feature space and the
        # model width. They are not used in `forward` below; presumably they
        # serve downstream code that feeds non-text (e.g. motion) features
        # through the model.
        self.llama_proj = nn.Linear(512, config.n_embd)
        self.motion_proj = nn.Linear(config.n_embd, 512)

    def _init_weights(self, module: nn.Module) -> None:
        # Scaled init; not called in `__init__`, intended to be applied
        # externally via `model.apply(model._init_weights)`.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        _, t = idx.size()
        assert (
            t <= self.config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

        # token embeddings of shape (b, t, n_embd)
        x = self.transformer.wte(idx)

        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        logits = self.lm_head(x)

        return logits

    @classmethod
    def from_name(cls, name: str) -> Self:
        return cls(LLaMAConfig.from_name(name))
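
# Example (illustrative): building a randomly initialized model by size and
# applying the scaled init defined above (weights would normally be loaded
# from a checkpoint instead); assumes enough memory for the chosen size:
#   model = LLaMA.from_name("7B")
#   model.apply(model._init_weights)

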
class Block(nn.Module):
    def __init__(self, config: LLaMAConfig) -> None:
        super().__init__()
        self.rms_1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.rms_2 = RMSNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # pre-normalization residual block: norm -> sublayer -> residual add
        x = x + self.attn(self.rms_1(x))
        x = x + self.mlp(self.rms_2(x))
        return x


class CausalSelfAttention(nn.Module):
    def __init__(self, config: LLaMAConfig) -> None:
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.block_size = config.block_size
        self.rope_cache = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        # move the head dimension forward to be the batch dim
        head_size = C // self.n_head
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)

        if self.rope_cache is None:
            # build the cache once for the full block size; it is sliced to the
            # actual sequence length inside `apply_rope`
            self.rope_cache = build_rope_cache(
                seq_len=self.block_size,
                n_elem=self.n_embd // self.n_head,
                dtype=x.dtype,
                device=x.device,
            )

        q = apply_rope(q, self.rope_cache)
        k = apply_rope(k, self.rope_cache)

        # efficient attention using Flash Attention CUDA kernels
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)

        # re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # output projection
        y = self.c_proj(y)

        return y


class MLP(nn.Module):
    def __init__(self, config: LLaMAConfig) -> None:
        super().__init__()
        hidden_dim = 4 * config.n_embd
        n_hidden = int(2 * hidden_dim / 3)
        N = 256
        # ensure n_hidden is a multiple of N by rounding up
        n_hidden = ((n_hidden - 1) // N) * N + N

        self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
        self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
        self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: silu-gated linear unit, then project back to n_embd
        x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
        x = self.c_proj(x)
        return x
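
# Worked example of the hidden-size arithmetic in `MLP.__init__` above, for the
# 7B config (n_embd=4096): hidden_dim = 16384, two thirds of that is 10922, and
# rounding up to the next multiple of 256 gives n_hidden = 11008, which matches
# LLaMA-7B's feed-forward size.

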
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # normalize by the root mean square: x / sqrt(mean(x^2) + eps),
        # then apply the learned per-feature gain
        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
        x_normed = x * torch.rsqrt(norm_x + self.eps)
        return self.scale * x_normed
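
# Numeric example (sketch): for x = [3.0, 4.0], mean(x * x) = 12.5, so with the
# default unit scale (and ignoring eps) RMSNorm returns x / sqrt(12.5),
# approximately [0.8485, 1.1314].

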
def build_rope_cache(
    seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
) -> torch.Tensor:
    """Enhanced Transformer with Rotary Position Embedding.

    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
    transformers/rope/__init__.py. MIT License:
    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
    """
    # frequencies: theta_i = 1 / base^(2i / n_elem) for i in [0, n_elem / 2)
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))

    # position indexes [0, 1, ..., seq_len - 1]
    seq_idx = torch.arange(seq_len, dtype=dtype, device=device)

    # outer product of position index and theta_i: one angle per (position, frequency) pair
    idx_theta = torch.outer(seq_idx, theta)

    # `torch.polar` only accepts float32/float64, so cast when working with
    # 16-bit (or int8-quantized) dtypes, and store the cache in a compact complex dtype
    dtypes_requiring_casting = [torch.float16, torch.bfloat16, torch.int8]
    working_dtype = torch.float32 if dtype in dtypes_requiring_casting else dtype
    complex_dtype = torch.complex32 if dtype in dtypes_requiring_casting else torch.complex64
    cache = torch.polar(torch.ones_like(idx_theta).to(working_dtype), idx_theta.to(working_dtype)).to(complex_dtype)
    return cache
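
# Cache layout (sketch): the returned tensor is complex-valued with shape
# (seq_len, n_elem // 2); entry [t, i] is exp(j * t * theta_i). `apply_rope`
# below multiplies it against query/key features viewed as complex pairs,
# which rotates each pair by a position-dependent angle.

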
def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # (B, nh, T, hs) -> (B, T, nh, hs)
    x = x.transpose(1, 2)

    # truncate the cache to support variable sequence lengths
    T = x.size(1)
    rope_cache = rope_cache[:T]

    # view consecutive feature pairs as complex numbers (float32 for the complex math)
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    rope_cache = rope_cache.view(1, xc.size(1), 1, xc.size(3))
    x_out = torch.view_as_real(xc * rope_cache).flatten(3)
    return x_out.transpose(1, 2).type_as(x)
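

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original file's
    # API): run a tiny randomly initialized model on dummy token ids and check
    # the output shape. The small hyperparameters are arbitrary, chosen for speed.
    config = LLaMAConfig(block_size=128, vocab_size=100, n_layer=2, n_head=4, n_embd=64)
    model = LLaMA(config)
    model.apply(model._init_weights)
    idx = torch.randint(0, config.vocab_size, (2, 16))  # (batch, time)
    logits = model(idx)
    print(logits.shape)  # expected: torch.Size([2, 16, 100])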