from dataclasses import dataclass
from typing import Optional, Tuple
import math

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs

@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    add_bias_linear: bool = False
    add_qkv_bias: bool = True
    apply_query_key_layer_scaling: bool = True
    apply_residual_connection_post_layernorm: bool = False
    attention_dropout: float = 0.0
    attention_softmax_in_fp32: bool = True
    bias_dropout_fusion: bool = True
    ffn_hidden_size: int = 13696
    fp32_residual_connection: bool = False
    hidden_dropout: float = 0.0
    hidden_size: int = 4096
    kv_channels: int = 128
    layernorm_epsilon: float = 1.5625e-07
    multi_query_attention: bool = True
    multi_query_group_num: int = 2
    num_attention_heads: int = 32
    num_hidden_layers: int = 40
    num_layers: int = 40
    rope_ratio: int = 500
    original_rope: bool = True
    padded_vocab_size: int = 151552
    post_layer_norm: bool = True
    rmsnorm: bool = True
    seq_length: int = 131072
    use_cache: bool = True
    torch_dtype: str = "bfloat16"
    tie_word_embeddings: bool = False

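# With multi_query_attention enabled, multi_query_group_num is the number of
# KV heads (2 here) shared across the num_attention_heads (32) query heads,
# i.e. grouped-query attention with a 16:1 query-to-KV ratio.
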
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, rope_ratio=1, original_impl=False, dtype=None):
        super().__init__()
        # The config stores the dtype as a string (e.g. "bfloat16"); resolve
        # it to an mx.Dtype so the cast in forward_impl actually triggers.
        if isinstance(dtype, str):
            dtype = getattr(mx, dtype)
        self.inv_freq_type = dtype
        self.dim = dim
        self.original_impl = original_impl
        self.rope_ratio = rope_ratio

    def forward_impl(
        self, seq_len: int, n_elem: int, dtype: mx.Dtype, base: int = 10000
    ):
        """Enhanced Transformer with Rotary Position Embedding.

        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/rope/__init__.py.
        MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        base = base * self.rope_ratio
        # Compute the angles in float32: float16 loses integer precision past
        # 2048 and overflows past 65504, while seq_length can reach 131072.
        theta = 1.0 / (base ** (mx.arange(0, n_elem, 2, dtype=mx.float32) / n_elem))

        # Position indices [0, 1, ..., seq_len - 1].
        seq_idx = mx.arange(seq_len, dtype=mx.float32)

        # Outer product of positions and frequencies: (seq_len, n_elem // 2).
        idx_theta = mx.outer(seq_idx, theta)

        # Cache of (cos, sin) pairs: (seq_len, n_elem // 2, 2).
        cache = mx.stack([mx.cos(idx_theta), mx.sin(idx_theta)], axis=-1)

        # Match the model's reduced-precision compute dtype.
        if dtype in (mx.float16, mx.bfloat16, mx.int8):
            cache = cache.astype(mx.bfloat16) if dtype == mx.bfloat16 else cache.astype(mx.float16)
        return cache

    def __call__(self, max_seq_len, offset=0):
        return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq_type)

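# The cache above stores a (cos, sin) pair per position and frequency, with
# shape (seq_len, dim // 2, 2). apply_rotary_pos_emb below vectorizes the
# standard pairwise rotation; for one feature pair (x0, x1) at one position:
#
#   x0' = x0 * cos - x1 * sin
#   x1' = x1 * cos + x0 * sin
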
def apply_rotary_pos_emb(x: mx.array, rope_cache: mx.array) -> mx.array:
    # x: (batch, heads, seq, head_dim); rope_cache: (batch, seq, rot_dim // 2, 2)
    b, n_head, sq, _ = x.shape
    rot_dim = rope_cache.shape[-2] * 2
    # Only the first rot_dim features are rotated; the rest pass through.
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]

    rope_cache = rope_cache[:, :sq]
    xshaped = x.reshape(b, n_head, sq, rot_dim // 2, 2)
    rope_cache = rope_cache.reshape(-1, 1, sq, xshaped.shape[3], 2)
    x_out2 = mx.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        axis=-1,
    )
    x_out2 = x_out2.flatten(3)
    return mx.concatenate((x_out2, x_pass), axis=-1)

class CoreAttention(nn.Module):
    def __init__(self, args: ModelArgs, layer_number):
        super().__init__()

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)

        projection_size = args.kv_channels * args.num_attention_heads

        # Per-partition sizes (no tensor parallelism here, so these are the
        # full sizes).
        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = projection_size // args.num_attention_heads
        self.num_attention_heads_per_partition = args.num_attention_heads

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
        self.coeff = coeff

        self.attention_dropout = nn.Dropout(args.attention_dropout)

    def __call__(self, query_layer, key_layer, value_layer, attention_mask):
        # The fused kernel applies the scale itself; norm_factor above is kept
        # only for parity with the reference configuration.
        scale_factor = query_layer.shape[-1] ** -0.5
        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
            # No mask and no cached prefix: use a standard causal mask.
            attention_mask = nn.MultiHeadAttention.create_additive_causal_mask(
                query_layer.shape[2]
            ).astype(query_layer.dtype)
        elif attention_mask is not None:
            # get_masks marks blocked positions with True; the fused kernel
            # expects True where attention is allowed, so invert.
            attention_mask = ~attention_mask
        context_layer = mx.fast.scaled_dot_product_attention(
            query_layer, key_layer, value_layer, scale=scale_factor, mask=attention_mask
        )
        # (batch, heads, seq, head_dim) -> (batch, seq, hidden_size)
        context_layer = context_layer.transpose(0, 2, 1, 3)
        new_context_layer_shape = context_layer.shape[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.reshape(*new_context_layer_shape)

        return context_layer

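# Note: mx.fast.scaled_dot_product_attention supports grouped KV heads natively
# (the number of query heads need only be a multiple of the KV heads), so the
# 2 KV groups used by this model require no manual tiling before the call.
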
class SelfAttention(nn.Module):
    def __init__(self, args: ModelArgs, layer_number):
        super().__init__()
        self.layer_number = max(1, layer_number)

        self.projection_size = args.kv_channels * args.num_attention_heads

        # Per-attention-head sizes.
        self.hidden_size_per_attention_head = self.projection_size // args.num_attention_heads
        self.num_attention_heads_per_partition = args.num_attention_heads
        self.multi_query_attention = args.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size
        if self.multi_query_attention:
            self.num_multi_query_groups_per_partition = args.multi_query_group_num
            self.qkv_hidden_size = (
                self.projection_size
                + 2 * self.hidden_size_per_attention_head * args.multi_query_group_num
            )
        self.query_key_value = nn.Linear(
            args.hidden_size,
            self.qkv_hidden_size,
            bias=args.add_bias_linear or args.add_qkv_bias,
        )

        self.core_attention = CoreAttention(args, self.layer_number)

        # Output projection.
        self.dense = nn.Linear(self.projection_size, args.hidden_size, bias=args.add_bias_linear)

    def __call__(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True):
        # Fused QKV projection: (batch, seq, qkv_hidden_size).
        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            # Queries get all heads; keys and values get multi_query_group_num
            # heads each.
            q_len = self.num_attention_heads_per_partition * self.hidden_size_per_attention_head
            kv_len = self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head
            query_layer, key_layer, value_layer = mx.split(
                mixed_x_layer, [q_len, q_len + kv_len], axis=-1
            )
            query_layer = query_layer.reshape(
                query_layer.shape[:-1]
                + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )
            key_layer = key_layer.reshape(
                key_layer.shape[:-1]
                + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
            value_layer = value_layer.reshape(
                value_layer.shape[:-1]
                + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
        else:
            new_tensor_shape = mixed_x_layer.shape[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            mixed_x_layer = mixed_x_layer.reshape(*new_tensor_shape)
            query_layer, key_layer, value_layer = mx.split(mixed_x_layer, 3, axis=-1)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        query_layer, key_layer, value_layer = [
            k.transpose(0, 2, 1, 3) for k in (query_layer, key_layer, value_layer)
        ]

        # Rotate queries and keys before caching so cached keys carry their
        # positional phase.
        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        if kv_cache is not None and use_cache:
            key_layer, value_layer = kv_cache.update_and_fetch(key_layer, value_layer)

        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)

        output = self.dense(context_layer)

        return output

class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.add_bias = args.add_bias_linear

        # Fused projection: produces the gate and up halves in one matmul.
        self.dense_h_to_4h = nn.Linear(
            args.hidden_size,
            args.ffn_hidden_size * 2,
            bias=self.add_bias,
        )

        def swiglu(x):
            x = mx.split(x, 2, axis=-1)
            return nn.silu(x[0]) * x[1]

        self.activation_func = swiglu

        self.dense_4h_to_h = nn.Linear(
            args.ffn_hidden_size,
            args.hidden_size,
            bias=self.add_bias,
        )

    def __call__(self, hidden_states):
        # (batch, seq, 2 * ffn_hidden_size)
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # (batch, seq, hidden_size)
        output = self.dense_4h_to_h(intermediate_parallel)
        return output

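# Quick shape check for the fused SwiGLU path above (illustrative only; the
# model_type value is a placeholder):
#
#   args = ModelArgs(model_type="chatglm")
#   mlp = MLP(args)
#   x = mx.zeros((1, 2, args.hidden_size))
#   assert mlp(x).shape == (1, 2, args.hidden_size)
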
class GLMBlock(nn.Module):
    def __init__(self, args: ModelArgs, layer_number):
        super().__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm = args.apply_residual_connection_post_layernorm

        self.fp32_residual_connection = args.fp32_residual_connection

        LayerNormFunc = nn.RMSNorm if args.rmsnorm else nn.LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = LayerNormFunc(args.hidden_size, eps=args.layernorm_epsilon)

        # Self attention.
        self.self_attention = SelfAttention(args, layer_number)
        self.hidden_dropout = args.hidden_dropout

        self.dropout = nn.Dropout(self.hidden_dropout)

        # Layernorm on the attention output.
        self.post_attention_layernorm = LayerNormFunc(args.hidden_size, eps=args.layernorm_epsilon)

        # MLP.
        self.mlp = MLP(args)

    def __call__(
        self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
    ):
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output = self.self_attention(
            layernorm_output,
            attention_mask,
            rotary_pos_emb,
            kv_cache=kv_cache,
            use_cache=use_cache,
        )

        # First residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = self.dropout(attention_output)
        layernorm_input = residual + layernorm_input

        # Layer norm after the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.dropout(mlp_output)
        output = residual + output

        return output

class GLMTransformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.fp32_residual_connection = args.fp32_residual_connection
        self.post_layer_norm = args.post_layer_norm

        # Number of layers.
        self.num_layers = args.num_layers

        # Transformer layers.
        def build_layer(layer_number):
            return GLMBlock(args, layer_number)

        self.layers = [build_layer(i + 1) for i in range(self.num_layers)]

        if self.post_layer_norm:
            LayerNormFunc = nn.RMSNorm if args.rmsnorm else nn.LayerNorm
            # Final layer norm before output.
            self.final_layernorm = LayerNormFunc(args.hidden_size, eps=args.layernorm_epsilon)

        self.gradient_checkpointing = False

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def __call__(
        self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
        use_cache: Optional[bool] = True,
    ):
        if not kv_caches:
            kv_caches = [None for _ in range(self.num_layers)]

        for index in range(self.num_layers):
            layer = self._get_layer(index)
            hidden_states = layer(
                hidden_states,
                attention_mask,
                rotary_pos_emb,
                kv_cache=kv_caches[index],
                use_cache=use_cache,
            )

        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states

class Embedding(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.hidden_size = args.hidden_size
        # Word embeddings.
        self.word_embeddings = nn.Embedding(
            args.padded_vocab_size,
            self.hidden_size,
        )
        self.fp32_residual_connection = args.fp32_residual_connection

    def __call__(self, input_ids):
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings
        # If the fp32 residual connection flag is set, promote to float32.
        if self.fp32_residual_connection:
            embeddings = embeddings.astype(mx.float32)
        return embeddings

class ChatGLMModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.embedding = Embedding(args)
        self.num_layers = args.num_layers
        self.multi_query_group_num = args.multi_query_group_num

        self.kv_channels = args.kv_channels
        self.use_cache = args.use_cache
        self.use_return_dict = False
        self.output_hidden_states = False

        # Rotary positional embeddings.
        self.seq_length = args.seq_length
        rotary_dim = (
            args.hidden_size // args.num_attention_heads
            if args.kv_channels is None
            else args.kv_channels
        )

        self.rotary_pos_emb = RotaryEmbedding(
            rotary_dim // 2,
            rope_ratio=args.rope_ratio,
            original_impl=args.original_rope,
            dtype=args.torch_dtype,
        )
        self.encoder = GLMTransformer(args)
        self.output_layer = nn.Linear(args.hidden_size, args.padded_vocab_size, bias=False)

        # Position ids are tracked across calls so incremental decoding keeps
        # extending the sequence by one position per generated token.
        self.new_position_id = None
        self.is_first_forward = True

    def get_input_embeddings(self):
        return self.embedding.word_embeddings

    def set_input_embeddings(self, value):
        self.embedding.word_embeddings = value

    def get_masks(self, input_ids, past_key_values, padding_mask=None):
        batch_size, seq_length = input_ids.shape
        full_attention_mask = mx.ones((batch_size, seq_length, seq_length), dtype=input_ids.dtype)
        full_attention_mask = mx.tril(full_attention_mask)
        past_length = 0
        if past_key_values and past_key_values[0].keys is not None:
            past_length = past_key_values[0].offset
        if past_length:
            full_attention_mask = mx.concatenate(
                (
                    mx.ones((batch_size, seq_length, past_length), dtype=input_ids.dtype),
                    full_attention_mask,
                ),
                axis=-1,
            )
        if padding_mask is not None:
            full_attention_mask = full_attention_mask * mx.expand_dims(padding_mask, 1)
        if not past_length and padding_mask is not None:
            full_attention_mask -= mx.expand_dims(padding_mask, -1) - 1
        # True marks positions that must NOT be attended to; CoreAttention
        # inverts the mask before handing it to the fused kernel.
        full_attention_mask = full_attention_mask < 0.5
        full_attention_mask = mx.expand_dims(full_attention_mask, 1)
        return full_attention_mask
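    # Illustrative mask (hypothetical sizes): for seq_length=3 with
    # past_length=2, the concatenated mask before thresholding is
    #     [[1, 1, 1, 0, 0],
    #      [1, 1, 1, 1, 0],
    #      [1, 1, 1, 1, 1]]
    # and `< 0.5` flips it so True marks the blocked future positions.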

    def get_position_ids(self, input_ids):
        batch_size, seq_length = input_ids.shape
        position_ids = mx.arange(seq_length, dtype=mx.int32)
        position_ids = mx.broadcast_to(position_ids, (batch_size, seq_length))
        return position_ids

    def __call__(
        self,
        input_ids,
        position_ids: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
        full_attention_mask: Optional[mx.array] = None,
        past_key_values: Optional[Tuple[Tuple[mx.array, mx.array], ...]] = None,
        inputs_embeds: Optional[mx.array] = None,
        use_cache: Optional[bool] = None,
    ):
        # Reset the tracked positions when starting a fresh sequence (empty
        # cache); otherwise state from a previous prompt would leak in.
        if not past_key_values or past_key_values[0].offset == 0:
            self.new_position_id = None

        if self.new_position_id is None:
            position_ids = self.get_position_ids(input_ids)
        else:
            position_ids = self.new_position_id
        # Append the position for the token generated on the next call.
        new_position_id = position_ids[..., -1:] + 1
        self.new_position_id = mx.concatenate([position_ids, new_position_id], axis=-1)

        if past_key_values and past_key_values[0].offset > 0:
            # Incremental decoding: only the newest token and its position.
            position_ids = position_ids[..., -1:]
            input_ids = input_ids[:, -1:]

        batch_size, seq_length = input_ids.shape

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        # Rotary position embeddings, gathered at the positions in use.
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        if position_ids is not None:
            rotary_pos_emb = rotary_pos_emb[position_ids]
        else:
            rotary_pos_emb = rotary_pos_emb[None, :seq_length]

        # Run encoder.
        hidden_states = self.encoder(
            inputs_embeds,
            full_attention_mask,
            rotary_pos_emb=rotary_pos_emb,
            kv_caches=past_key_values,
            use_cache=use_cache,
        )

        return hidden_states

class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.transformer = ChatGLMModel(args)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.transformer(inputs, past_key_values=cache, use_cache=True)
        if self.args.tie_word_embeddings:
            out = self.model.embedding.word_embeddings.as_linear(out)
        else:
            out = self.model.output_layer(out)
        return out

    def sanitize(self, weights):
        # Drop the precomputed rotary frequencies; they are rebuilt at runtime.
        return {
            k: v
            for k, v in weights.items()
            if "transformer.rotary_pos_emb.inv_freq" not in k
        }

    @property
    def layers(self):
        return self.model.encoder.layers

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        return self.args.multi_query_group_num

    @property
    def model(self):
        return self.transformer
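
# Minimal usage sketch (hypothetical driver; assumes an mlx-lm style KV cache
# exposing `update_and_fetch`, `keys`, and `offset`, as used above; the import
# path, constructor, and model_type value below are assumptions):
#
#   from mlx_lm.models.cache import KVCache
#
#   args = ModelArgs(model_type="chatglm")
#   model = Model(args)
#   cache = [KVCache() for _ in range(args.num_layers)]
#   logits = model(mx.array([[1, 2, 3]]), cache=cache)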