| """Full definition of a GPT NeoX Language Model, all of it in this single file. |
| |
| Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and |
| https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model. |
| """ |
| import math |
| from typing import Any, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| from typing_extensions import Self |
|
|
| from lit_gpt.config import Config |
|
|
|
|
class GPT(nn.Module):
    """Decoder-only transformer: token embedding, a stack of `Block`s, a final norm and the LM head.

    `forward` additionally accepts a tuple form that writes externally supplied embeddings
    ("motion tokens") into the embedded token sequence before it enters the transformer stack.
    """

    def __init__(self, config: Config) -> None:
        super().__init__()
        assert config.padded_vocab_size is not None
        self.config = config

        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
            )
        )
        # Goes through the property setter below, which also creates the `cos`/`sin` RoPE buffers.
        self.max_seq_length = self.config.block_size
        # Causal mask used together with the KV cache; built lazily by `set_kv_cache()`.
        self.mask_cache: Optional[torch.Tensor] = None

    @property
    def max_seq_length(self) -> int:
        """Maximum sequence length the RoPE cache (and any KV cache) is sized for."""
        return self._max_seq_length

    @max_seq_length.setter
    def max_seq_length(self, value: int) -> None:
        """
        When doing inference, the sequences used might be shorter than the model's context length.
        This allows setting a smaller number to avoid allocating unused memory
        """
        if value > self.config.block_size:
            raise ValueError(f"Cannot attend to {value}, block size is only {self.config.block_size}")
        self._max_seq_length = value
        if not hasattr(self, "cos"):
            # First call (from __init__): register the RoPE tables as non-persistent buffers.
            cos, sin = self.rope_cache()
            self.register_buffer("cos", cos, persistent=False)
            self.register_buffer("sin", sin, persistent=False)
        elif value != self.cos.size(0):
            # Length changed: rebuild the tables on the device the buffers already live on.
            self.cos, self.sin = self.rope_cache(device=self.cos.device)
        # The mask and KV cache are resized by `set_kv_cache()`, since only the caller knows whether
        # a KV cache is wanted.

    def reset_parameters(self) -> None:
        """Reset `max_seq_length` to the block size and recompute the RoPE buffers."""
        self.max_seq_length = self.config.block_size
        # The setter only rebuilds the tables when the length changes. Force a rebuild so buffers whose
        # storage is stale or uninitialised (e.g. after `Module.to_empty()`) get real values.
        self.cos, self.sin = self.rope_cache(device=self.cos.device)

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, maxlen: Optional[int] = None
    ) -> torch.Tensor:
        """Run the model and return the output of `lm_head`.

        Args:
            idx: either a tensor of token ids indexed as ``idx[batch, position]``, or a tuple
                ``(stack_before_tokens_x, motion_tokens, before_len)``. In the tuple form the ids in
                ``stack_before_tokens_x`` are embedded, then for every batch row ``i`` the tensor
                ``motion_tokens[i]`` overwrites embedding positions
                ``before_len[i] : before_len[i] + len(motion_tokens[i])``.
            input_pos: positions to process when decoding with the KV cache; ``None`` for a full,
                cache-less pass (causal attention without an explicit mask).
            maxlen: sequence length used to slice the RoPE tables and to validate against
                ``max_seq_length``; defaults to ``size(1)`` of the token-id tensor.

        Raises:
            ValueError: if the sequence length exceeds ``max_seq_length``.
            TypeError: if ``input_pos`` is given but ``set_kv_cache()`` has not been called.
        """
        spliced = isinstance(idx, tuple)
        if spliced:
            token_ids, motion_tokens, before_len = idx
        else:
            token_ids = idx
        # Derive T from the token-id tensor, not from the tuple wrapper (which has no `.size`).
        T = token_ids.size(1) if maxlen is None else maxlen
        if self.max_seq_length < T:
            raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")

        if input_pos is not None:
            # Cached decoding: pick the RoPE rows and mask rows for the requested positions.
            cos = self.cos.index_select(0, input_pos)
            sin = self.sin.index_select(0, input_pos)
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            mask = self.mask_cache.index_select(2, input_pos)
        else:
            cos = self.cos[:T]
            sin = self.sin[:T]
            mask = None

        if spliced:
            wte = self.transformer.wte
            # Follow the embedding's device instead of hard-coding `.cuda()`.
            x = wte(token_ids.to(wte.weight.device))
            for i in range(x.size(0)):
                start = before_len[i]
                tokens = motion_tokens[i]
                # Index `x` directly (a single view) so the in-place write is autograd-safe.
                x[i, start : start + len(tokens)] = tokens.to(device=x.device, dtype=x.dtype)
        else:
            x = self.transformer.wte(idx)
        for block in self.transformer.h:
            x = block(x, cos, sin, mask, input_pos)
        x = self.transformer.ln_f(x)
        return self.lm_head(x)

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        """Build a model from a named `Config` preset."""
        return cls(Config.from_name(name, **kwargs))

    def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute `(cos, sin)` RoPE tables for the current `max_seq_length`."""
        return build_rope_cache(
            seq_len=self.max_seq_length,
            n_elem=self.config.rope_n_elem,
            device=device,
            condense_ratio=self.config.rope_condense_ratio,
            base=self.config.rope_base,
        )

    def set_kv_cache(
        self,
        batch_size: int,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Allocate a KV cache in every attention layer and (re)build the causal mask.

        Args:
            batch_size: batch dimension of the cache buffers.
            rope_cache_length: last-dim size of the RoPE tables; defaults to ``self.cos.size(-1)``.
            device: device for the cache buffers and the mask.
            dtype: dtype for the cache buffers.
        """
        if rope_cache_length is None:
            rope_cache_length = self.cos.size(-1)
        max_seq_length = self.max_seq_length

        for block in self.transformer.h:
            block.attn.kv_cache = block.attn.build_kv_cache(
                batch_size, max_seq_length, rope_cache_length, device, dtype
            )

        if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
            # Lower-triangular boolean mask of shape (1, 1, max_seq_length, max_seq_length).
            ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
            self.mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0)

    def clear_kv_cache(self) -> None:
        """Drop the causal mask and every layer's KV cache."""
        self.mask_cache = None
        for block in self.transformer.h:
            block.attn.kv_cache = None
|
|
|
|
class Block(nn.Module):
    """A single transformer layer: attention followed by an MLP, each with a residual connection.

    With `config.parallel_residual` the attention and MLP branches both read from the block input;
    otherwise the MLP consumes the attention-updated residual stream.
    """

    def __init__(self, config: Config) -> None:
        super().__init__()
        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config)
        if config.shared_attention_norm:
            # The MLP reuses norm_1's output, so no separate norm module is created.
            self.norm_2 = None
        else:
            self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.mlp = config.mlp_class(config)
        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the layer; `cos`/`sin`/`mask`/`input_pos` are forwarded to the attention module."""
        cfg = self.config
        normed = self.norm_1(x)
        attn_out = self.attn(normed, cos, sin, mask, input_pos)

        if not cfg.parallel_residual:
            if cfg.shared_attention_norm:
                raise NotImplementedError(
                    "No checkpoint amongst the ones we support uses this configuration"
                    " (non-parallel residual and shared attention norm)."
                )
            residual = attn_out + x
            return self.mlp(self.norm_2(residual)) + residual

        mlp_in = normed if cfg.shared_attention_norm else self.norm_2(x)
        return self.mlp(mlp_in) + attn_out + x
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Causal self-attention with rotary position embeddings and query-group (MHA/GQA/MQA) support.

    `n_query_groups == n_head` gives one key/value head per query head; `n_query_groups == 1` shares a
    single key/value head across all query heads; values in between share one key/value per group.
    """

    def __init__(self, config: Config) -> None:
        super().__init__()
        # Fused projection producing, for every query group, `n_head // n_query_groups` queries plus
        # one key and one value.
        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
        self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
        # Output projection. Its input is all heads concatenated, which equals n_embd only when
        # n_head * head_size == n_embd; sizing it from the heads supports configs where it does not.
        self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)
        # Assigned via `build_kv_cache()` (by the owning model's `set_kv_cache()`); None disables caching.
        self.kv_cache: Optional[KVCache] = None

        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over `x` of shape (B, T, n_embd) and return a tensor of shape (B, T, n_embd).

        Raises:
            TypeError: if `input_pos` is given but no KV cache has been set.
        """
        B, T, _ = x.size()

        qkv = self.attn(x)

        # Group the fused projection as (B, n_query_groups, q_per_kv + 2, T, head_size).
        q_per_kv = self.config.n_head // self.config.n_query_groups
        total_qkv = q_per_kv + 2  # q_per_kv queries, 1 key and 1 value per group
        qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
        qkv = qkv.permute(0, 2, 3, 1, 4)

        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)

        # Broadcast each group's key/value to its queries. When decoding with a single key/value
        # group this is skipped, so the KV cache (sized with 1 head in `build_kv_cache`) stays small.
        if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1):
            k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
            v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)

        # Flatten groups into heads: (B, heads, T, head_size).
        q = q.reshape(B, -1, T, self.config.head_size)
        k = k.reshape(B, -1, T, self.config.head_size)
        v = v.reshape(B, -1, T, self.config.head_size)

        # RoPE is applied only to the first `rope_n_elem` channels of each head.
        q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
        k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
        q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
        k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)

        if input_pos is not None:
            if not isinstance(self.kv_cache, KVCache):
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            k, v = self.kv_cache(input_pos, k, v)

        y = self.scaled_dot_product_attention(q, k, v, mask)

        # Concatenate the head outputs side by side before the output projection.
        y = y.reshape(B, T, self.config.head_size * self.config.n_head)

        return self.proj(y)

    def scaled_dot_product_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run SDPA and return the result transposed to (B, T, heads, head_size).

        Without an explicit `mask` the attention is causal. SDPA's default scale is
        1/sqrt(q.size(-1)), and q's last dimension is `head_size` here, so no explicit scale is passed.
        """
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=mask is None
        )
        return y.transpose(1, 2)

    def build_kv_cache(
        self,
        batch_size: int,
        max_seq_length: int,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "KVCache":
        """Allocate a `KVCache` with buffers of shape (batch_size, heads, max_seq_length, dim).

        `heads` is 1 when there is a single query group, otherwise `n_head`. The key buffer's last
        dimension accounts for a RoPE table width of `rope_cache_length`.

        Raises:
            TypeError: if `rope_cache_length` is omitted while only part of each head is rotated.
        """
        heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
        v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
        if rope_cache_length is None:
            if self.config.rotary_percentage != 1.0:
                raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` value")
            k_shape = v_shape
        else:
            k_shape = (
                batch_size,
                heads,
                max_seq_length,
                rope_cache_length + self.config.head_size - self.config.rope_n_elem,
            )
        return KVCache(k_shape, v_shape, device=device, dtype=dtype)
|
|
|
|
class GptNeoxMLP(nn.Module):
    """Two-layer feed-forward block: n_embd -> intermediate_size -> n_embd with a GELU in between."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        width, hidden = config.n_embd, config.intermediate_size
        self.fc = nn.Linear(width, hidden, bias=config.bias)
        self.proj = nn.Linear(hidden, width, bias=config.bias)

        # Kept so `forward` can read `gelu_approximate`.
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand, apply GELU (approximation chosen by the config), then project back."""
        hidden = torch.nn.functional.gelu(self.fc(x), approximate=self.config.gelu_approximate)
        return self.proj(hidden)
|
|
|
|
class LLaMAMLP(nn.Module):
    """Gated feed-forward block: proj(silu(fc_1(x)) * fc_2(x))."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        width, hidden, use_bias = config.n_embd, config.intermediate_size, config.bias
        self.fc_1 = nn.Linear(width, hidden, bias=use_bias)
        self.fc_2 = nn.Linear(width, hidden, bias=use_bias)
        self.proj = nn.Linear(hidden, width, bias=use_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply SiLU to the first branch, gate the second branch with it, and project back."""
        gate_in = self.fc_1(x)
        up = self.fc_2(x)
        gated = torch.nn.functional.silu(gate_in) * up
        return self.proj(gated)
|
|
|
|
def build_rope_cache(
    seq_len: int, n_elem: int, device: Optional[torch.device] = None, base: int = 10000, condense_ratio: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Enhanced Transformer with Rotary Position Embedding.

    Returns `(cos, sin)`, each of shape (seq_len, n_elem): row `p` holds the rotation angles
    `(p / condense_ratio) * base ** (-2i / n_elem)` for `i < n_elem // 2`, with the half-width
    table tiled twice along the last dimension.

    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
    transformers/rope/__init__.py. MIT License:
    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
    """
    exponents = torch.arange(0, n_elem, 2, device=device).float() / n_elem
    inv_freq = 1.0 / (base ** exponents)

    # Dividing by condense_ratio compresses positions (positional interpolation).
    positions = torch.arange(seq_len, device=device) / condense_ratio

    half = torch.outer(positions, inv_freq)
    angles = torch.cat((half, half), dim=-1)
    return angles.cos(), angles.sin()
|
|
|
|
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate `x` with rotary embeddings: `x * cos + rotate_half(x) * sin`, cast back to `x`'s dtype.

    `rotate_half` splits the last dimension in two halves `(a, b)` and returns `(-b, a)`.
    """
    split = x.size(-1) // 2
    front, back = x[..., :split], x[..., split:]
    rotated = torch.cat((-back, front), dim=-1)
    return ((x * cos) + (rotated * sin)).type_as(x)
|
|
|
|
class KVCache(nn.Module):
    """Preallocated key/value buffers that are overwritten in place at the given positions.

    The buffers are registered as non-persistent, so they are excluded from `state_dict()`. As built by
    `CausalSelfAttention.build_kv_cache`, they are laid out as (batch, heads, max_seq_length, dim),
    and positions are written along dim 2.
    """

    def __init__(
        self,
        k_shape: Tuple[int, int, int, int],
        v_shape: Tuple[int, int, int, int],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        for buffer_name, buffer_shape in (("k", k_shape), ("v", v_shape)):
            zeros = torch.zeros(buffer_shape, device=device, dtype=dtype)
            self.register_buffer(buffer_name, zeros, persistent=False)

    def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write `k`/`v` at `input_pos` (dim 2) and return the full, updated buffers."""
        # Follow the incoming activations' dtype (e.g. under mixed precision) before writing.
        self.k, self.v = self.k.to(k.dtype), self.v.to(v.dtype)
        return self.k.index_copy_(2, input_pos, k), self.v.index_copy_(2, input_pos, v)
|
|