from dataclasses import dataclass
from typing import Any, Dict, Optional, Union

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_linear

from mlx_lm.models.activations import swiglu
from mlx_lm.models.base import (
    BaseModelArgs,
    scaled_dot_product_attention,
)
from mlx_lm.models.rope_utils import initialize_rope


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: int
    max_position_embeddings: int
    rope_theta: float
    head_dim: int
    tie_word_embeddings: bool
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None


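# A hypothetical configuration for illustration (all values assumed, not taken
# from the source); in practice these fields are populated from the
# checkpoint's config.json:
#
#   args = ModelArgs(
#       model_type="example",
#       hidden_size=1024,
#       num_hidden_layers=28,
#       intermediate_size=3072,
#       num_attention_heads=16,
#       rms_norm_eps=1e-6,
#       vocab_size=151_000,
#       num_key_value_heads=8,
#       max_position_embeddings=32768,
#       rope_theta=1_000_000.0,
#       head_dim=128,
#       tie_word_embeddings=True,
#   )

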
def _make_bidirectional_mask(
    attention_mask: mx.array,
    batch_size: int,
    seq_len: int,
    offset: int = 0,
) -> mx.array:
    """Build a boolean mask of shape [batch, 1, seq_len, offset + seq_len].

    Unlike a causal mask, every non-padding position may attend to every other
    non-padding position; ``True`` marks pairs that participate in attention.
    """
    if attention_mask.ndim != 2:
        raise ValueError(
            f"Expected 2D attention_mask with shape [batch, seq], got {attention_mask.shape}"
        )
    if attention_mask.shape[0] != batch_size:
        raise ValueError("attention_mask batch size does not match input batch size")
    if attention_mask.shape[1] < offset + seq_len:
        raise ValueError(
            "attention_mask sequence length is shorter than the required cached length"
        )

    # Valid query positions for the current chunk, and valid key positions for
    # everything up to and including it.
    q = attention_mask[:, offset : offset + seq_len].astype(mx.bool_)
    k = attention_mask[:, : offset + seq_len].astype(mx.bool_)
    mask = q[:, :, None] & k[:, None, :]
    return mask[:, None, :, :]


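# A minimal sketch of the mask semantics with hypothetical inputs (not from
# the source): a batch of two length-3 sequences, the second right-padded:
#
#   pad = mx.array([[1, 1, 1], [1, 1, 0]])
#   m = _make_bidirectional_mask(pad, batch_size=2, seq_len=3)
#   # m.shape == (2, 1, 3, 3); for the second sequence, any entry whose query
#   # or key index is the padded position 2 is False.

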
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        head_dim = args.head_dim
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        # Per-head RMSNorm on queries and keys (QK-norm), applied over head_dim.
        self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
        self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
        self.rope = initialize_rope(
            head_dim,
            base=args.rope_theta,
            traditional=False,
            scaling_config=args.rope_scaling,
            max_position_embeddings=args.max_position_embeddings,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        bsz, seq_len, _ = x.shape

        queries = self.q_proj(x)
        keys = self.k_proj(x)
        values = self.v_proj(x)

        # [batch, seq, heads * head_dim] -> [batch, heads, seq, head_dim]
        queries = self.q_norm(queries.reshape(bsz, seq_len, self.n_heads, -1)).transpose(
            0, 2, 1, 3
        )
        keys = self.k_norm(keys.reshape(bsz, seq_len, self.n_kv_heads, -1)).transpose(
            0, 2, 1, 3
        )
        values = values.reshape(bsz, seq_len, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = scaled_dot_product_attention(
            queries,
            keys,
            values,
            cache=cache,
            scale=self.scale,
            mask=mask,
        )
        output = output.transpose(0, 2, 1, 3).reshape(bsz, seq_len, -1)
        return self.o_proj(output)


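# Shape sketch under hypothetical sizes (assumed, not from the source): with
# n_heads=8, n_kv_heads=2, and head_dim=64, an input x of shape (1, 10, dim)
# yields queries of shape (1, 8, 10, 64) and keys/values of shape
# (1, 2, 10, 64); scaled_dot_product_attention broadcasts each KV head across
# four query heads (grouped-query attention).

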
class MLP(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        return self.down_proj(swiglu(self.gate_proj(x), self.up_proj(x)))


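# Assuming swiglu(a, b) computes silu(a) * b (the usual fused definition),
# the MLP above is the standard gated feed-forward:
#
#   down_proj(silu(gate_proj(x)) * up_proj(x))

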
class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        # Pre-norm residual structure: x + Attn(norm(x)), then + MLP(norm(...)).
        residual = x + self.self_attn(self.input_layernorm(x), mask, cache)
        return residual + self.mlp(self.post_attention_layernorm(residual))


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
        input_embeddings: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
    ) -> mx.array:
        if input_embeddings is not None:
            h = input_embeddings
        else:
            h = self.embed_tokens(inputs)

        if cache is None:
            cache = [None] * len(self.layers)
        elif len(cache) != len(self.layers):
            raise ValueError(
                f"Expected cache with {len(self.layers)} layers, got {len(cache)}"
            )

        # Bidirectional attention is incompatible with incremental decoding, so
        # a populated KV cache is rejected outright.
        if any(layer_cache is not None for layer_cache in cache):
            raise ValueError(
                "KV cache is not supported for this bidirectional embedding model."
            )

        if attention_mask is not None:
            mask = _make_bidirectional_mask(
                attention_mask,
                batch_size=h.shape[0],
                seq_len=h.shape[1],
                offset=0,
            )
        else:
            mask = None

        for layer, layer_cache in zip(self.layers, cache):
            h = layer(h, mask, layer_cache)

        return self.norm(h)

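
    # A minimal embedding workflow (assumed usage, not from the source): run
    # the encoder over padded token ids, then mean-pool the non-padding
    # positions into one vector per sequence.
    #
    #   tokens = mx.array([[5, 17, 9, 0]])            # 0 = hypothetical pad id
    #   attn = mx.array([[1, 1, 1, 0]])
    #   h = model(tokens, attention_mask=attn)        # [1, 4, hidden_size]
    #   w = attn[..., None].astype(h.dtype)
    #   pooled = (h * w).sum(axis=1) / w.sum(axis=1)  # [1, hidden_size]
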
    def shard(self, group: Optional[mx.distributed.Group] = None):
        group = group or mx.distributed.init()
        n = group.size()
        if self.args.num_attention_heads % n or self.args.num_key_value_heads % n:
            raise ValueError(
                f"Attention heads must be divisible by the group size ({n})."
            )
        for layer in self.layers:
            # Column-shard Q/K/V (splits heads across ranks) and row-shard the
            # output projection, which all-reduces the partial results.
            layer.self_attn.q_proj = shard_linear(
                layer.self_attn.q_proj, "all-to-sharded", group=group
            )
            layer.self_attn.k_proj = shard_linear(
                layer.self_attn.k_proj, "all-to-sharded", group=group
            )
            layer.self_attn.v_proj = shard_linear(
                layer.self_attn.v_proj, "all-to-sharded", group=group
            )
            layer.self_attn.o_proj = shard_linear(
                layer.self_attn.o_proj, "sharded-to-all", group=group
            )
            # Each rank now computes only its local slice of heads.
            layer.self_attn.n_heads //= n
            layer.self_attn.n_kv_heads //= n

            layer.mlp.gate_proj = shard_linear(
                layer.mlp.gate_proj, "all-to-sharded", group=group
            )
            layer.mlp.down_proj = shard_linear(
                layer.mlp.down_proj, "sharded-to-all", group=group
            )
            layer.mlp.up_proj = shard_linear(
                layer.mlp.up_proj, "all-to-sharded", group=group
            )
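

# Hedged note on the sharding above: "all-to-sharded" column-shards a linear
# layer's output features while "sharded-to-all" row-shards its input features
# and all-reduces the partial outputs, so each attention and MLP sub-block
# costs one all-reduce and stays mathematically equivalent to the unsharded
# model. A typical launch (assumed, not from the source) uses `mlx.launch` or
# an MPI runner so that mx.distributed.init() reports a group size above one.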