import operator
from functools import partial, reduce
from typing import Iterable, List, Optional

import torch
import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as torch_ckpt
import torch.nn as nn
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import AttentionBias, BlockDiagonalCausalMask

from .args import ModelArgs
from .lora import LoRALinear
from .moe import MoeLayer
from .rope import apply_rotary_emb, precompute_freqs_cis


def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int):
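    """Repeat K/V heads along ``dim`` so they match the number of query heads.

    Used for grouped-query attention, where each K/V head is shared by
    ``repeats`` query heads.
    """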
    keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim)
    values = torch.repeat_interleave(values, repeats=repeats, dim=dim)
    return keys, values


def maybe_lora_layer(
    args: ModelArgs, rank: Optional[int] = None
) -> partial[LoRALinear] | type[nn.Linear]:
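    """Return the linear-layer factory to use for this model.

    If LoRA is disabled this is plain ``nn.Linear``; otherwise it is a
    ``functools.partial`` of ``LoRALinear`` pre-bound with the configured
    rank (or the ``rank`` override), scaling and dropout.
    """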
    MaybeLora: partial[LoRALinear] | type[nn.Linear]
    if not args.lora.enable:
        return nn.Linear

    rank = rank or args.lora.rank
    scaling = args.lora.scaling
    dropout = args.lora.dropout

    MaybeLora = partial(
        LoRALinear,
        rank=rank,
        scaling=scaling,
        dropout=dropout,
    )

    return MaybeLora


class Attention(nn.Module):
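    """Multi-head attention with rotary embeddings and grouped-query K/V heads.

    Projections are plain ``nn.Linear`` or ``LoRALinear`` depending on the
    LoRA config, and the attention itself is computed with xformers'
    ``memory_efficient_attention`` over packed (unpadded) sequences.
    """
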
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.n_heads: int = args.n_heads
        self.n_kv_heads: int = args.n_kv_heads
        self.head_dim: int = args.head_dim

        self.repeats = self.n_heads // self.n_kv_heads

        self.scale = self.args.head_dim**-0.5

        MaybeLora = maybe_lora_layer(args)

        self.wq = MaybeLora(args.dim, args.n_heads * args.head_dim, bias=False)
        self.wk = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wv = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False)

        self.wo = MaybeLora(args.n_heads * args.head_dim, args.dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        mask: AttentionBias,
    ) -> torch.Tensor:
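        """Attend over packed sequences.

        ``x`` has shape ``(seqlen_sum, dim)``, where ``seqlen_sum`` is the total
        number of tokens across all sequences; ``mask`` (typically a
        ``BlockDiagonalCausalMask``) keeps attention within each sequence.
        """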
        seqlen_sum, _ = x.shape

        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(seqlen_sum, self.n_heads, self.args.head_dim)
        xk = xk.view(seqlen_sum, self.n_kv_heads, self.args.head_dim)
        xv = xv.view(seqlen_sum, self.n_kv_heads, self.args.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        key, val = xk, xv

        # Repeat K/V heads so they match the number of query heads (grouped-query attention).
        key, val = repeat_kv(key, val, self.repeats, dim=1)

        # memory_efficient_attention expects a leading batch dimension; add a singleton one.
        xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
        output = memory_efficient_attention(xq, key, val, mask)

        return self.wo(output.view(seqlen_sum, -1))


class FeedForward(nn.Module):
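    """Gated (SwiGLU-style) feed-forward network: w2(silu(w1(x)) * w3(x))."""
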
    def __init__(self, args: ModelArgs):
        super().__init__()

        MaybeLora = maybe_lora_layer(args)
        self.w1 = MaybeLora(args.dim, args.hidden_dim, bias=False)
        self.w2 = MaybeLora(args.hidden_dim, args.dim, bias=False)
        self.w3 = MaybeLora(args.dim, args.hidden_dim, bias=False)

    def forward(self, x) -> torch.Tensor:
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))


class RMSNorm(torch.nn.Module):
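    """Root-mean-square layer norm: scales by 1/RMS(x) and a learned per-dim weight."""
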
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32 for numerical stability, then cast back.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class TransformerBlock(nn.Module):
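    """Pre-norm transformer block: attention with residual, then a dense or MoE feed-forward with residual."""
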
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.attention = Attention(args)

        self.feed_forward: MoeLayer | FeedForward
        if args.moe is not None:
            self.feed_forward = MoeLayer(
                experts=[FeedForward(args=args) for _ in range(args.moe.num_experts)],
                gate=nn.Linear(args.dim, args.moe.num_experts, bias=False),
                moe_args=args.moe,
            )
        else:
            self.feed_forward = FeedForward(args=args)

        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.args = args

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        att_mask: AttentionBias,
    ) -> torch.Tensor:
        r = self.attention(self.attention_norm(x), freqs_cis, att_mask)
        h = x + r

        r = self.feed_forward(self.ffn_norm(h))
        out = h + r

        return out


class Transformer(nn.Module):
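    """Decoder-only transformer over packed sequences.

    Token embeddings feed a stack of ``TransformerBlock``s (optionally wrapped
    with non-reentrant activation checkpointing), followed by RMSNorm and an
    output projection to vocabulary logits.
    """
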
    def __init__(self, args: ModelArgs, checkpoint: bool = False):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        assert self.vocab_size > 0
        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList()
        for _ in range(args.n_layers):
            block: torch.nn.Module = TransformerBlock(args=args)
            if checkpoint:
                # Wrap the block with non-reentrant activation checkpointing.
                non_reentrant_wrapper = partial(
                    torch_ckpt.checkpoint_wrapper,
                    checkpoint_impl=torch_ckpt.CheckpointImpl.NO_REENTRANT,
                )
                block = non_reentrant_wrapper(block)

            self.layers.append(block)

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)

        self.output = torch.nn.Linear(
            args.dim,
            args.vocab_size,
            bias=False,
        )

        # Rotary frequencies are precomputed lazily (see the freqs_cis property).
        self._freqs_cis = None

    @property
    def dtype(self) -> torch.dtype:
        return self.tok_embeddings.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.tok_embeddings.weight.device

    @property
    def freqs_cis(self):
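        """Rotary-embedding frequency table, built lazily on first access and cached."""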
        # Build the table on the parameters' device the first time it is needed.
        device = next(iter(self.parameters())).device
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis(
                self.args.head_dim, 128_000, theta=self.args.rope_theta, device=device
            )

        return self._freqs_cis

    def forward(
        self,
        input_ids: torch.Tensor,
        seqlens: List[int],
    ) -> torch.Tensor:
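        """Compute logits for a flat tensor of token ids.

        ``input_ids`` concatenates all sequences; ``seqlens`` gives the length
        of each one, so ``sum(seqlens) == input_ids.shape[0]``. Returns float32
        logits of shape ``(sum(seqlens), vocab_size)``.
        """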
        assert sum(seqlens) == input_ids.shape[0], (sum(seqlens), input_ids.shape[0])

        h = self.tok_embeddings(input_ids)
        positions = positions_from_sizes(seqlens, self.freqs_cis.device)
        att_mask = BlockDiagonalCausalMask.from_seqlens(seqlens)

        freqs_cis = self.freqs_cis[positions].to(device=h.device)

        for layer in self.layers:
            h = layer(h, freqs_cis, att_mask)

        return self.output(self.norm(h)).float()


def positions_from_sizes(sizes: Iterable[int], device):
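    """Concatenate ``range(s)`` for every sequence length ``s`` into one position tensor."""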
    return torch.tensor(
        reduce(operator.iadd, [list(range(s)) for s in sizes], []),
        dtype=torch.long,
        device=device,
    )