# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F
from xformers.ops import RMSNorm, fmha, rope_padded
from xformers.ops.fmha.attn_bias import (
    BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
)

import mp_utils
@dataclass
class ModelArgs:
    """Hyper-parameters of a Llama-style transformer.

    Decorated with ``@dataclass`` (the import was already present but the
    decorator was missing) so instances can be built with keyword
    arguments, e.g. ``ModelArgs(dim=4096, vocab_size=32000)``.
    """

    dim: int = 512  # model / embedding dimension
    n_layers: int = 8
    n_heads: int = 8  # number of attention (query) heads
    # Number of key/value heads for grouped-query attention;
    # defaults to n_heads when left as None (see TransformerBlock).
    n_kv_heads: Optional[int] = None
    # Must be set to a positive value before building a Transformer
    # (asserted in Transformer.__init__).
    vocab_size: int = -1
    ffn_dim_multiplier: Optional[float] = None
    # Enforces that the SwiGLU hidden layer size is a multiple
    # of large power of 2.
    multiple_of: int = 256
    norm_eps: float = 1e-5  # epsilon used by the RMSNorm layers
    # Positional encoding parameter; increase to 1e6 to run
    # Code Llama models with long contexts.
    rope_theta: float = 10000.0
| LayerCache = Tuple[torch.Tensor, torch.Tensor] | |
class Attention(nn.Module):
    """Multi-head attention with grouped key/value heads, a rotary
    embedding applied via ``rope_padded`` (which also writes new k/v
    entries into the provided cache), and heads sharded across
    model-parallel ranks via ``mp_utils``.
    """

    def __init__(
        self,
        dim: int,
        head_dim: int,
        n_heads: int,
        n_kv_heads: int,
        rope_theta: float,
    ):
        super().__init__()

        mp_size = mp_utils.get_world_size()
        self.head_dim = head_dim
        self.rope_theta = rope_theta

        # Each model-parallel rank holds an even share of the heads.
        self.n_local_heads = n_heads // mp_size
        self.n_local_kv_heads = n_kv_heads // mp_size

        # Fused projection: q, k and v are produced by a single matmul
        # and split apart in forward().
        self.wqkv = nn.Linear(
            dim,
            (self.n_local_heads + 2 * self.n_local_kv_heads) * head_dim,
            bias=False,
        )
        self.wo = nn.Linear(
            self.n_local_heads * head_dim,
            dim,
            bias=False,
        )

        self._register_load_state_dict_pre_hook(self.load_hook)

    # This adapter makes sure we can load vanilla
    # Llama checkpoints where wq, wk, and wv are
    # not fused in a single parameter
    def load_hook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            # Concatenation order must match the forward() slicing:
            # queries first, then keys, then values.
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: torch.Tensor,
        cache: LayerCache,
        attn_bias: AttnBias,
        position_index: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # x.shape is (sum(seq_lens), dim)
        #
        # Since we support heterogenous sequence
        # lengths, the hidden states are all
        # concatenated together along the usual
        # sequence dimension. The attention below
        # finds out where sequences start & end
        # using the provided attention bias.
        xqkv = self.wqkv(x)
        # Undo the fused projection: first n_local_heads*head_dim columns
        # are q; the rest splits evenly into k and v.
        xq = xqkv[:, : (self.n_local_heads * self.head_dim)]
        xkv = xqkv[:, (self.n_local_heads * self.head_dim) :]
        xk, xv = xkv.chunk(2, 1)

        output_shape = xq.shape
        # Reshape into the 5-D (batch=1, seq, kv_groups, heads_per_group,
        # head_dim) layout used by the xformers ops; k/v carry a size-1
        # heads-per-group dim that broadcasts against the grouped queries.
        heads_per_group = self.n_local_heads // self.n_local_kv_heads
        xq = xq.view(
            1, xq.shape[0], self.n_local_kv_heads, heads_per_group, self.head_dim
        )
        xk = xk.view(1, xk.shape[0], self.n_local_kv_heads, 1, self.head_dim)
        xv = xv.view(1, xv.shape[0], self.n_local_kv_heads, 1, self.head_dim)

        cache_k, cache_v = cache
        # rope_padded applies rotary embeddings to the queries and writes
        # the rotated k (and v) into cache_k/cache_v at the positions
        # described by attn_bias.
        xq = rope_padded(
            xq=xq,
            xk=xk,
            xv=xv,
            cache_k=cache_k,
            cache_v=cache_v,
            attn_bias=attn_bias,
            theta=self.rope_theta,
        )

        # rope_padded() updated the caches, so we
        # call attention directly
        output = fmha.memory_efficient_attention_forward(
            xq, cache_k, cache_v, attn_bias
        )

        # Back to the packed (sum(seq_lens), n_local_heads * head_dim) shape.
        output = output.reshape(output_shape)

        # When the caller only needs some positions (see TransformerBlock:
        # the last token of each sequence on the last layer), slice before
        # the output projection to avoid wasted compute.
        if position_index is not None:
            output = output[position_index]

        output = self.wo(output)
        # Sum the partial projections across model-parallel ranks.
        mp_utils.all_reduce(output)
        return output
class FeedForward(nn.Module):
    """SwiGLU feed-forward block with its hidden dimension sharded
    across model-parallel ranks.

    ``w1`` (gate) and ``w3`` (up-projection) are fused into a single
    ``w13`` linear layer; a load-state-dict hook keeps compatibility
    with checkpoints that store them separately.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        world = mp_utils.get_world_size()

        # Llama sizing recipe: shrink to 2/3 of the nominal hidden size,
        # apply the optional multiplier, then round *up* to a multiple
        # of `multiple_of` (ceil-division form).
        inner = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            inner = int(ffn_dim_multiplier * inner)
        inner = -(-inner // multiple_of) * multiple_of
        assert inner % world == 0

        # Fused gate + up projection (2x columns), and the down projection.
        self.w13 = nn.Linear(dim, 2 * inner // world, bias=False)
        self.w2 = nn.Linear(inner // world, dim, bias=False)

        self._register_load_state_dict_pre_hook(self.load_hook)

    # This adapter makes sure we can load vanilla
    # Llama checkpoints where w1 and w3 are not
    # fused in a single parameter
    def load_hook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        key_w1 = prefix + "w1.weight"
        if key_w1 not in state_dict:
            return
        fused = torch.cat(
            [state_dict.pop(key_w1), state_dict.pop(prefix + "w3.weight")]
        )
        state_dict[prefix + "w13.weight"] = fused

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Split the fused projection into gate and up halves, apply SwiGLU.
        gate, up = self.w13(x).chunk(2, dim=-1)
        out = self.w2(F.silu(gate) * up)
        # Sum partial results across model-parallel ranks.
        mp_utils.all_reduce(out)
        return out
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: RMSNorm -> attention -> residual,
    then RMSNorm -> SwiGLU feed-forward -> residual."""

    def __init__(self, args: ModelArgs, layer_index: int):
        super().__init__()

        assert args.dim % args.n_heads == 0
        head_dim = args.dim // args.n_heads
        # Grouped-query attention: fall back to one kv head per query
        # head when n_kv_heads is not set.
        if args.n_kv_heads is not None:
            n_kv_heads = args.n_kv_heads
        else:
            n_kv_heads = args.n_heads

        # Head counts must shard evenly across model-parallel ranks.
        mp_size = mp_utils.get_world_size()
        assert args.n_heads % n_kv_heads == 0
        assert args.n_heads % mp_size == 0
        assert n_kv_heads % mp_size == 0

        # Used in forward() to trim work on the final layer.
        self.is_last_layer = layer_index + 1 == args.n_layers

        self.attention = Attention(
            dim=args.dim,
            head_dim=head_dim,
            n_heads=args.n_heads,
            n_kv_heads=n_kv_heads,
            rope_theta=args.rope_theta,
        )
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        cache: LayerCache,
        attn_bias: AttnBias,
    ) -> torch.Tensor:
        # On the last layer during prefill (queries longer than a single
        # token), only the final position of each sequence contributes to
        # the next-token logits, so restrict both the attention output and
        # the residual stream to those positions.
        position_index = None
        if self.is_last_layer and attn_bias.q_seqinfo.max_seqlen > 1:
            # seqstart[1:] - 1 picks the last token of every sequence in
            # the packed (sum(seq_lens), dim) layout.
            position_index = attn_bias.q_seqinfo.seqstart[1:] - 1
        h = self.attention.forward(
            self.attention_norm(x),
            cache,
            attn_bias,
            position_index=position_index,
        )
        if position_index is not None:
            # Keep the residual aligned with the sliced attention output.
            x = x[position_index]
        h = h + x
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
class Transformer(nn.Module):
    """Llama-style decoder stack with model-parallel embedding and
    output projections (re-assembled across ranks via
    ``mp_utils.all_gather``)."""

    def __init__(self, args: ModelArgs):
        super().__init__()

        mp_size = mp_utils.get_world_size()
        assert args.dim % mp_size == 0
        assert args.vocab_size > 0
        assert args.vocab_size % mp_size == 0

        # The embedding is sharded along the *feature* dimension: each
        # rank holds dim // mp_size features per token.
        self.tok_embeddings = nn.Embedding(
            num_embeddings=args.vocab_size,
            embedding_dim=args.dim // mp_size,
        )
        self.layers = nn.ModuleList()
        for layer_index in range(args.n_layers):
            self.layers.append(TransformerBlock(args, layer_index))
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        # The output projection is sharded along the *vocabulary* dimension.
        self.output = nn.Linear(
            args.dim,
            args.vocab_size // mp_size,
            bias=False,
        )

    def forward_with_attn_bias(
        self,
        token_values: torch.Tensor,
        attn_bias: AttnBias,
        cache: list[LayerCache],
    ) -> torch.Tensor:
        # token_values are the packed token ids for all sequences
        # (shape (sum(seq_lens),) to match the layers' packed layout —
        # see the comment in Attention.forward).
        h_parallel = self.tok_embeddings(token_values)
        # Re-assemble the full hidden state from the per-rank feature
        # shards (assumes all_gather concatenates shards — see mp_utils).
        h = mp_utils.all_gather(h_parallel)
        for i, layer in enumerate(self.layers):
            h = layer(h, cache[i], attn_bias)
        logits_parallel = self.output(self.norm(h))
        # Same for the vocabulary shards of the logits.
        logits = mp_utils.all_gather(logits_parallel)
        return logits.float()

    def forward(
        self,
        token_values: torch.Tensor,
        token_lengths: torch.Tensor,
        start_pos: torch.Tensor,
        cache: list[LayerCache],
        kv_padding: int,
    ) -> torch.Tensor:
        # Build the block-diagonal causal attention bias describing where
        # each packed sequence starts and how much kv history it has
        # (start_pos tokens already in the cache + the new tokens).
        attn_bias = AttnBias.from_seqlens(
            q_seqlen=token_lengths.tolist(),
            kv_seqlen=(start_pos + token_lengths).tolist(),
            kv_padding=kv_padding,
        )
        return self.forward_with_attn_bias(token_values, attn_bias, cache)
def make_cache(
    args: ModelArgs,
    length: int,
    device: Optional[Union[str, torch.device]] = None,
    n_layers: Optional[int] = None,
    dtype: Optional[torch.dtype] = None,
) -> list[LayerCache]:
    """
    Allocate a cache to be used with the Transformer module.

    Args:
        args (ModelArgs): the model configuration.
        length (int): per layer cache size.
            It is usually budgeted as ``max_batch * max_seq``
        device (torch.device, optional): the device on which
            the cache should be allocated.
        n_layers (int, optional): the number of layers to
            allocate a cache for (defaults to the model
            settings).
        dtype (torch.dtype, optional): the dtype to use for
            cache entries (defaults to the default dtype).

    Returns:
        The cache object to pass to ``Transformer.forward``.
    """
    head_dim = args.dim // args.n_heads
    n_kv_heads = args.n_kv_heads
    if n_kv_heads is None:
        n_kv_heads = args.n_heads
    # Only this rank's share of the kv heads is cached.
    n_local_kv_heads = n_kv_heads // mp_utils.get_world_size()

    if n_layers is None:
        n_layers = args.n_layers

    # 5-D layout matching Attention.forward: (batch=1, pos, kv_heads,
    # heads_per_group, head_dim); the size-1 dim 3 is expanded below.
    shape = (1, length, n_local_kv_heads, 1, head_dim)
    heads_per_group = args.n_heads // n_kv_heads
    # expand() yields stride-0 views along dim 3, so all query heads in
    # a group share the same cached k/v storage — no extra memory.
    expansion = (-1, -1, -1, heads_per_group, -1)
    return [
        (
            torch.zeros(shape, device=device, dtype=dtype).expand(expansion),
            torch.zeros(shape, device=device, dtype=dtype).expand(expansion),
        )
        for _ in range(n_layers)
    ]
def cache_prefix(cache: list[LayerCache], length: int) -> list[LayerCache]:
    """
    Return a length-limited view of each layer's cache.

    The returned pairs alias the storage of *cache*, which keeps its
    full size and remains valid after the views have been used. This is
    useful when the cache was allocated for a larger batch size than is
    currently needed.

    Args:
        cache: the cache to take a view in.
        length (int): the desired length

    Returns:
        A view in the input cache object.
    """
    if cache:
        # All layers share the same geometry; checking the first suffices.
        assert cache[0][0].shape[1] >= length
    views: list[LayerCache] = []
    for keys, values in cache:
        views.append((keys[:, :length], values[:, :length]))
    return views