| import math |
| import json |
| import re |
| from copy import deepcopy |
| from pathlib import Path |
| from dataclasses import dataclass |
| from typing import Callable |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
| from torch.utils.checkpoint import checkpoint |
|
|
| import xformers.ops as xops |
|
|
| from huggingface_hub import PyTorchModelHubMixin |
|
|
| from open_lm.attention import get_attn_func, xformers_attn, torch_attn |
| from open_lm.norms import get_norm_class |
| from open_lm.positional_embedding.head_rotary import HeadRotaryWithCast |
| from open_lm.positional_embedding.rotary import RotaryWithCast |
| from open_lm.positional_embedding.llama_rotary import LLaMARotaryWithCast |
| from open_lm.positional_embedding.none import identity_with_cast |
|
|
| |
| try: |
| from megablocks.layers.moe import MoE |
| from megablocks.layers.arguments import Arguments as MoEArgs |
| except ImportError: |
| MoE = None |
| MoEArgs = None |
|
|
| try: |
| from mamba_ssm import MambaLMHeadModel |
| except ImportError: |
| MambaLMHeadModel = None |
|
|
| |
# Locations searched for bundled model-config JSON files (directories or files).
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs"]
# Registry mapping a config name (the JSON file's stem) to its parsed contents.
_MODEL_CONFIGS = {}
|
|
|
|
def _natural_key(string_):
    """Return a sort key that orders embedded numbers numerically ("natural" order).

    The lower-cased input is split into alternating text and digit runs; the
    digit runs are converted to ``int`` so that, for example, ``"model2"``
    sorts before ``"model10"``.

    Args:
        string_: The string to build a key for.

    Returns:
        A list of alternating ``str`` and ``int`` chunks. The first chunk is
        always a (possibly empty) ``str``.
    """
    # ``str.isdecimal`` accepts exactly the characters that ``\d`` captures and
    # that ``int`` can parse. ``str.isdigit`` is also true for characters such
    # as "²", which ``\d`` leaves in a text chunk and ``int`` rejects with a
    # ValueError.
    return [int(s) if s.isdecimal() else s for s in re.split(r"(\d+)", string_.lower())]
|
|
|
|
def _rescan_model_configs(model_config_paths=None):
    """Load model-config JSON files into the module-level ``_MODEL_CONFIGS`` registry.

    Args:
        model_config_paths: Optional single path (a ``.json`` file or a
            directory containing ``.json`` files). When ``None``, the default
            locations in ``_MODEL_CONFIG_PATHS`` are scanned instead.

    Side effects:
        Adds or overwrites one registry entry per JSON file found, keyed by the
        file's stem, then rebinds ``_MODEL_CONFIGS`` to a dict sorted in
        natural order of the keys. Paths that are neither a ``.json`` file nor
        a directory are silently ignored.
    """
    global _MODEL_CONFIGS

    if model_config_paths is not None:
        search_paths = [Path(model_config_paths)]
    else:
        search_paths = _MODEL_CONFIG_PATHS

    config_ext = (".json",)
    config_files = []
    for config_path in search_paths:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f"*{ext}"))

    for cf in config_files:
        # JSON is UTF-8 by specification; do not depend on the platform's
        # locale-derived default encoding.
        with open(cf, "r", encoding="utf-8") as f:
            _MODEL_CONFIGS[cf.stem] = json.load(f)

    # Keep the registry in natural order (e.g. "m2" before "m10").
    _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda item: _natural_key(item[0])))
|
|
|
|
# Populate the config registry at import time from the default search paths.
_rescan_model_configs()
|
|
|
|
| |
@dataclass
class Params:
    """Hyperparameters read by the model-building classes in this module."""

    dim: int = 512  # model width: embedding size and in/out size of every block
    n_layers: int = 8  # number of Block layers stacked by Transformer
    n_heads: int = 8  # attention heads; per-head size is dim // n_heads
    vocab_size: int = -1  # embedding rows / output logits; presumably -1 is a placeholder to be overridden
    norm_eps: float = 1e-5  # passed as ``eps`` to every norm_type(...) call
    seq_len: int = 2048  # passed to the positional-embedding constructors
    post_embed_norm: bool = False  # if True, a norm is applied right after the token embedding
    weight_tying: bool = False  # if True, token embedding and output projection share one weight
    # Called as ``norm_type(size, eps=norm_eps)`` to build norm layers, i.e. it
    # is used as a constructor rather than as a module instance.
    norm_type: nn.Module = nn.LayerNorm
    # NOTE: this default is evaluated once, when the class is defined.
    attn_func: Callable = xformers_attn if torch.cuda.is_available() else torch_attn
    apply_qk_norm: bool = False  # if True, queries and keys are normalized before attention
    # The moe_* fields below (except moe_freq) are forwarded to MoEArgs when
    # ffn_type is "moe".
    moe_loss_weight: float = 0.1
    moe_capacity_factor: float = 1.25
    moe_expert_model_parallelism: bool = False
    moe_weight_parallelism: bool = False
    moe_num_experts: int = 8
    moe_top_k: int = 2
    moe_freq: int = 0  # when > 0, layers with layer_id % moe_freq == 0 use the MoE feed-forward
    positional_embedding_type: str = "rotary"  # one of "rotary", "llama_rotary", "head_rotary", "none"
    ffn_type: str = "swiglu"  # one of "swiglu", "swiglu_torch", "gelu", "gemma_geglu", "moe"
|
|
|
|
def get_pos_embed(args: Params):
    """Build the positional-embedding callable selected by ``args.positional_embedding_type``.

    Args:
        args: Model hyperparameters; ``dim``, ``n_heads``, ``seq_len`` and
            ``positional_embedding_type`` are read.

    Returns:
        A rotary-embedding object for "rotary", "llama_rotary" or
        "head_rotary", or ``identity_with_cast`` for "none".

    Raises:
        RuntimeError: If the requested type is not one of the four above.
    """
    head_dim = args.dim // args.n_heads
    kind = args.positional_embedding_type

    if kind == "rotary":
        return RotaryWithCast(head_dim, args.seq_len)
    if kind == "llama_rotary":
        return LLaMARotaryWithCast(head_dim, args.n_heads, args.seq_len)
    if kind == "head_rotary":
        return HeadRotaryWithCast(head_dim, args.seq_len)
    if kind == "none":
        return identity_with_cast
    raise RuntimeError(f"Unknown positional embedding type {args.positional_embedding_type}")
|
|
|
|
class CustomAttn(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Queries and keys can optionally be normalized before being split into
    heads, positional embeddings are applied through ``get_pos_embed``, and the
    attention kernel itself is the pluggable ``args.attn_func``.
    """

    def __init__(self, layer_id, args: Params):
        """Create the projections, optional QK norms and positional embedding.

        Args:
            layer_id: Zero-based index of the enclosing layer; used to scale
                the output-projection init.
            args: Model hyperparameters.
        """
        super().__init__()
        self.n_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads
        inner_dim = args.n_heads * self.head_dim

        # One matmul yields queries, keys and values side by side.
        self.in_proj = nn.Linear(args.dim, 3 * inner_dim, bias=False)
        self.out_proj = nn.Linear(inner_dim, args.dim, bias=False)
        self.pos_embed = get_pos_embed(args)
        self.attn_fn = args.attn_func
        self.apply_qk_norm = args.apply_qk_norm

        # Norms act on the full (all heads) projection; identity when disabled.
        if self.apply_qk_norm:
            self.q_norm = args.norm_type(inner_dim, eps=args.norm_eps)
            self.k_norm = args.norm_type(inner_dim, eps=args.norm_eps)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

        self.layer_id = layer_id
        self.dim = args.dim
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize both projections with truncated normals (cut at 3 std)."""
        in_std = 1.0 / math.sqrt(self.dim)
        torch.nn.init.trunc_normal_(self.in_proj.weight, std=in_std, a=-3 * in_std, b=3 * in_std)

        # The output projection is additionally scaled down with layer depth.
        out_std = in_std / math.sqrt(2 * (self.layer_id + 1))
        torch.nn.init.trunc_normal_(self.out_proj.weight, std=out_std, a=-3 * out_std, b=3 * out_std)

    def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cache=False, attention_mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape (batch, q_len, dim).
            is_causal: Forwarded to the attention function.
            past_key_value: Optional ``[keys, vals]`` from earlier steps, each
                shaped (batch, past_len, n_heads, head_dim).
            use_cache: If True, past keys/values (when given) are prepended and
                the combined ``[keys, vals]`` is returned for the next step.
            attention_mask: Forwarded to the attention function.

        Returns:
            ``(output, past_key_value)`` where ``output`` has shape
            (batch, q_len, dim). ``past_key_value`` is the updated cache when
            ``use_cache`` is True, otherwise the argument unchanged.
        """
        batchsize, q_len, _ = x.shape
        queries, keys, vals = self.in_proj(x).chunk(3, dim=-1)

        head_shape = (batchsize, q_len, self.n_heads, self.head_dim)
        queries = self.q_norm(queries).view(*head_shape)
        keys = self.k_norm(keys).view(*head_shape)
        vals = vals.view(*head_shape)

        # New positions start after whatever is already cached.
        has_past = past_key_value is not None
        past_length = past_key_value[0].shape[1] if has_past else 0
        queries, keys, vals = self.pos_embed(queries, keys, vals, offset=past_length)

        if use_cache:
            if has_past:
                keys = torch.cat([past_key_value[0], keys], dim=1)
                vals = torch.cat([past_key_value[1], vals], dim=1)
            past_key_value = [keys, vals]

        attn_out = self.attn_fn(
            queries,
            keys,
            vals,
            is_causal=is_causal,
            attention_mask=attention_mask,
        )

        # Merge the head and head_dim axes back into one feature axis.
        attn_out = attn_out.view(batchsize, q_len, -1)
        return self.out_proj(attn_out), past_key_value
|
|
|
|
class GemmaMLP(nn.Module):
    """Gated-GELU feed-forward block (GeGLU) as used in Google's Gemma model.

    Modified from https://github.com/google/gemma_pytorch/blob/01062c9ef4cf89ac0c985b25a734164ede017d0b/gemma/model.py#L182-L201
    """

    def __init__(self, dim: int, hidden_dim: int, layer_id: int):
        """Create the gate, up and down projections (all with bias).

        Args:
            dim: Input and output feature size.
            hidden_dim: Size of the gated hidden representation.
            layer_id: Zero-based layer index, used by ``reset_parameters``.
        """
        super().__init__()
        self.dim, self.hidden_dim = dim, hidden_dim
        self.gate_proj = nn.Linear(dim, hidden_dim)
        self.up_proj = nn.Linear(dim, hidden_dim)
        self.down_proj = nn.Linear(hidden_dim, dim)
        self._layer_id = layer_id

    def forward(self, x):
        """Return ``down_proj(gelu(gate_proj(x)) * up_proj(x))``."""
        gated = F.gelu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))

    def reset_parameters(self):
        """Re-initialize the three weight matrices with truncated normals.

        Biases are left untouched. This method is not invoked by ``__init__``;
        it only runs when called explicitly.
        """
        in_std = 1.0 / math.sqrt(self.dim)
        for proj in (self.gate_proj, self.up_proj):
            torch.nn.init.trunc_normal_(proj.weight, std=in_std, a=-3 * in_std, b=3 * in_std)

        # The down projection is scaled by hidden size and by layer depth.
        out_std = 1.0 / math.sqrt(self.hidden_dim)
        out_std = out_std / math.sqrt(2 * (self._layer_id + 1))
        torch.nn.init.trunc_normal_(self.down_proj.weight, std=out_std, a=-3 * out_std, b=3 * out_std)
|
|
|
|
| |
| |
class SwiGLUTorch(nn.Module):
    """SwiGLU feed-forward layer written with plain PyTorch ops.

    Attribute names (``w12``, ``w3``) are the ones ``Block.reset_parameters``
    accesses on the feed-forward module.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, bias=True):
        """Create the fused gate/value projection and the output projection."""
        super().__init__()
        # Gate and value halves come out of a single matmul.
        self.w12 = nn.Linear(in_dim, hidden_dim * 2, bias=bias)
        self.w3 = nn.Linear(hidden_dim, out_dim, bias=bias)

    def forward(self, x):
        """Return ``w3(silu(gate) * value)`` where ``gate, value`` are the halves of ``w12(x)``."""
        gate, value = torch.chunk(self.w12(x), 2, dim=-1)
        return self.w3(F.silu(gate) * value)
|
|
|
|
class Block(nn.Module):
    """One transformer layer: pre-norm self-attention followed by a pre-norm feed-forward.

    The feed-forward variant is chosen by ``args.ffn_type``: "swiglu"
    (xformers), "swiglu_torch", "gelu", "gemma_geglu" or "moe" (megablocks).
    """

    def __init__(self, layer_id, args: Params):
        """Build the attention, feed-forward and norm sub-modules.

        Args:
            layer_id: Zero-based index of this layer; used for depth-scaled init.
            args: Model hyperparameters.

        Raises:
            ImportError: If ``ffn_type`` is "moe" but megablocks is not installed.
            ValueError: If ``ffn_type`` is not one of the supported values.
        """
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim

        self.head_dim = args.dim // args.n_heads
        self.attention = CustomAttn(layer_id, args)
        self._ffn_type = args.ffn_type
        if args.ffn_type == "swiglu":
            # int(8/3 * dim), rounded up to the next multiple of 256.
            self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
            self.feed_forward = xops.SwiGLU(args.dim, self.hidden_dim, args.dim, bias=False)
        elif args.ffn_type == "swiglu_torch":
            # Same sizing as "swiglu", but implemented with plain torch ops.
            self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
            self.feed_forward = SwiGLUTorch(args.dim, self.hidden_dim, args.dim, bias=False)
        elif args.ffn_type == "gelu":
            # Classic two-layer MLP with a 4x expansion.
            self.hidden_dim = args.dim * 4
            self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False)
            self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False)
            self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2)
        elif args.ffn_type == "gemma_geglu":
            # Same sizing as "swiglu".
            self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
            self.feed_forward = GemmaMLP(args.dim, self.hidden_dim, layer_id)
        elif args.ffn_type == "moe":
            # megablocks is an optional dependency; fail with a clear message
            # instead of "'NoneType' object is not callable".
            if MoE is None or MoEArgs is None:
                raise ImportError(
                    "MoE is not available. Please install the 'megablocks' package by running 'pip install megablocks'."
                )
            moe_args = MoEArgs(
                hidden_size=args.dim,
                ffn_hidden_size=args.dim * 4,
                moe_num_experts=args.moe_num_experts,
                moe_weight_parallelism=args.moe_weight_parallelism,
                moe_expert_model_parallelism=args.moe_expert_model_parallelism,
                moe_top_k=args.moe_top_k,
                moe_capacity_factor=args.moe_capacity_factor,
                moe_loss_weight=args.moe_loss_weight,
                device=torch.cuda.current_device(),
                bf16=False,
                fp16=False,
            )
            self.feed_forward = MoE(moe_args)
        else:
            # Without this, ``feed_forward`` would be missing and the error
            # would only surface on the first forward pass.
            raise ValueError(f"Unknown ffn_type {args.ffn_type}")

        self.layer_id = layer_id
        self.attention_norm = args.norm_type(
            args.dim,
            eps=args.norm_eps,
        )
        self.ffn_norm = args.norm_type(
            args.dim,
            eps=args.norm_eps,
        )
        self.attention.seq_len = args.seq_len
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the feed-forward weights for the swiglu and gelu variants.

        The input-side projection uses std ``1/sqrt(dim)``; the output-side
        projection uses ``1/sqrt(hidden_dim)`` further divided by
        ``sqrt(2 * (layer_id + 1))``. Both are truncated at 3 std.

        NOTE(review): the "gemma_geglu" and "moe" feed-forwards are not touched
        here and keep whatever init their constructors applied — confirm that
        this is intended.
        """
        if self._ffn_type == "swiglu" or self._ffn_type == "swiglu_torch":
            # Input-side (fused gate/value) projection.
            std = 1.0 / math.sqrt(self.dim)
            torch.nn.init.trunc_normal_(self.feed_forward.w12.weight, std=std, a=-3 * std, b=3 * std)
            # Output-side projection, scaled down with depth.
            std = 1.0 / math.sqrt(self.hidden_dim)
            std = std / math.sqrt(2 * (self.layer_id + 1))
            torch.nn.init.trunc_normal_(self.feed_forward.w3.weight, std=std, a=-3 * std, b=3 * std)
        elif self._ffn_type == "gelu":
            std = 1.0 / math.sqrt(self.dim)
            torch.nn.init.trunc_normal_(self._ff_w1.weight, std=std, a=-3 * std, b=3 * std)

            std = 1.0 / math.sqrt(self.hidden_dim)
            std = std / math.sqrt(2 * (self.layer_id + 1))
            torch.nn.init.trunc_normal_(self._ff_w2.weight, std=std, a=-3 * std, b=3 * std)

    def forward(self, x, past_key_value=None, use_cache=False, attention_mask=None):
        """Run attention and feed-forward, each with a residual connection.

        Args:
            x: Input hidden states.
            past_key_value: Optional cached keys/values for this layer.
            use_cache: Whether the attention should return an updated cache.
            attention_mask: Forwarded to the attention module.

        Returns:
            ``(out, past_key_value)``: the new hidden states and whatever the
            attention module returned as its cache.
        """
        h, past_key_value = self.attention(
            self.attention_norm(x),
            is_causal=True,
            past_key_value=past_key_value,
            use_cache=use_cache,
            attention_mask=attention_mask,
        )
        h = x + h
        if self._ffn_type == "moe":
            # The MoE layer returns a pair; only the first element is used.
            ffn_out, _ = self.feed_forward(self.ffn_norm(h))
        else:
            ffn_out = self.feed_forward(self.ffn_norm(h))
        out = h + ffn_out
        return out, past_key_value
|
|
|
|
class Transformer(nn.Module, PyTorchModelHubMixin):
    """Decoder-only transformer language model built from ``Block`` layers.

    Inputs are embedded, optionally normalized, passed through
    ``params.n_layers`` blocks, normalized again and projected to vocabulary
    logits.
    """

    def __init__(self, params):
        """Build the model described by ``params``.

        Args:
            params: Hyperparameter object exposing the ``Params`` fields read
                below. It is kept on the model as ``self.params``. Its
                ``ffn_type`` is overridden per layer while the blocks are being
                built and is restored to its original value before returning.
        """
        super().__init__()
        # For compatibility with hugging face
        self.params = params
        self.dim = params.dim
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers
        self.moe_num_experts = params.moe_num_experts
        self.seq_len = params.seq_len
        self.post_embed_norm = (
            params.norm_type(
                params.dim,
                eps=params.norm_eps,
            )
            if params.post_embed_norm
            else nn.Identity()
        )
        self.weight_tying = params.weight_tying

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)

        self.layers = torch.nn.ModuleList()
        ffn_type_ = params.ffn_type
        try:
            for layer_id in range(params.n_layers):
                # Every moe_freq-th layer (starting at 0) uses the MoE feed-forward.
                if params.moe_freq > 0 and layer_id % params.moe_freq == 0:
                    params.ffn_type = "moe"
                else:
                    params.ffn_type = ffn_type_
                self.layers.append(Block(layer_id, params))
        finally:
            # Do not leak the temporary per-layer override into the caller's
            # (and ``self.params``') object.
            params.ffn_type = ffn_type_

        # Final norm applied before the output projection.
        self.norm = params.norm_type(
            params.dim,
            eps=params.norm_eps,
        )
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        if self.weight_tying:
            self.tok_embeddings.weight = self.output.weight
        self.grad_checkpointing = False
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the output and token-embedding weights.

        Both use a truncated normal with std ``1/sqrt(dim)``, cut at 3 std.
        With weight tying the two names refer to the same tensor, which is then
        simply initialized twice.
        """
        std = 1.0 / math.sqrt(self.params.dim)
        torch.nn.init.trunc_normal_(self.output.weight, std=std, a=-3 * std, b=3 * std)
        torch.nn.init.trunc_normal_(self.tok_embeddings.weight, std=std, a=-3 * std, b=3 * std)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable or disable activation checkpointing of each layer in ``forward``."""
        self.grad_checkpointing = enable

    def forward(self, input_ids=None, inputs_embeds=None, past_key_values=None, use_cache=False, attention_mask=None):
        """Compute logits for a batch of sequences.

        Args:
            input_ids: Token ids to embed. Takes precedence over
                ``inputs_embeds`` when both are given.
            inputs_embeds: Pre-computed embeddings used when ``input_ids`` is None.
            past_key_values: Optional per-layer cache (list or tuple with one
                entry per layer). A list passed in is updated in place.
            use_cache (bool): Whether layers should return updated caches.
            attention_mask (torch.Tensor): Shape (batch_size, sequence_len), indicates tokens that should not be
                attended to. attention_mask[s, i] = False indicates that token i should not be attended to by any other
                token for sequence s.

        Returns:
            A 3-tuple ``(logits, hidden, past_key_values)``: the logits cast to
            float32, the final normalized hidden states, and the per-layer
            cache (``None`` when no cache was produced).

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if input_ids is not None:
            x = self.tok_embeddings(input_ids)
        elif inputs_embeds is not None:
            x = inputs_embeds
        else:
            raise ValueError("Either input_ids or inputs_embeds must be provided.")

        x = self.post_embed_norm(x)

        if past_key_values is None:
            past_key_values = [None] * self.n_layers
        elif isinstance(past_key_values, tuple):
            # Tuples are immutable; entries are reassigned below.
            past_key_values = list(past_key_values)
        for i, layer in enumerate(self.layers):
            if self.grad_checkpointing:
                x, past_key_values[i] = checkpoint(layer, x, past_key_values[i], use_cache, attention_mask)
            else:
                x, past_key_values[i] = layer(x, past_key_values[i], use_cache=use_cache, attention_mask=attention_mask)
        # An empty list (zero layers) or an unset first entry means no cache.
        if not past_key_values or past_key_values[0] is None:
            past_key_values = None
        x = self.norm(x)
        output = self.output(x)
        # follow llama in casting this to float.
        return output.float(), x, past_key_values

    def get_input_embeddings(self):
        """Return the token-embedding module."""
        return self.tok_embeddings

    def get_output_embeddings(self):
        """Return the output projection module."""
        return self.output
|
|
|
|
def create_params(args):
    """Resolve ``args.model`` to a config and build the matching parameter object.

    ``args.model`` may be a registered config name or a path ending in
    ".json". In the latter case the file is loaded into the registry and
    ``args.model`` is replaced in place by the file's stem.

    Args:
        args: Namespace-like object providing ``model`` plus the fallback
            attributes read below.

    Returns:
        A plain dict (``d_model``, ``n_layer``, ``vocab_size``, ``seq_len``)
        when the model name contains "mamba", otherwise a ``Params`` instance.
        For the optional fields a value in the config takes precedence over
        the corresponding attribute on ``args``.

    Raises:
        ValueError: If the name is not found in the registry.
    """
    if args.model.endswith(".json"):
        _rescan_model_configs(model_config_paths=args.model)
        args.model = Path(args.model).stem

    if args.model not in _MODEL_CONFIGS:
        raise ValueError("Pass a pre-defined open_lm model name or a json config")
    cfg = deepcopy(_MODEL_CONFIGS[args.model])

    if "mamba" in args.model:
        return {key: cfg[key] for key in ("d_model", "n_layer", "vocab_size", "seq_len")}

    def from_cfg_or_args(name):
        # The args attribute is read eagerly, so it must exist even when the
        # config supplies the value.
        return cfg.get(name, getattr(args, name))

    return Params(
        dim=cfg["hidden_dim"],
        n_layers=cfg["n_layers"],
        n_heads=cfg["n_heads"],
        seq_len=cfg["seq_len"],
        vocab_size=cfg["vocab_size"],
        post_embed_norm=cfg["post_embed_norm"],
        weight_tying=cfg["weight_tying"],
        norm_type=get_norm_class(from_cfg_or_args("model_norm")),
        attn_func=get_attn_func(
            args.attn_name, args.attn_activation, args.attn_seq_scalar, args.attn_seq_scalar_alpha
        ),
        apply_qk_norm=from_cfg_or_args("qk_norm"),
        positional_embedding_type=from_cfg_or_args("positional_embedding_type"),
        ffn_type=from_cfg_or_args("ffn_type"),
        moe_num_experts=from_cfg_or_args("moe_num_experts"),
        moe_loss_weight=from_cfg_or_args("moe_loss_weight"),
        moe_expert_model_parallelism=from_cfg_or_args("moe_expert_model_parallelism"),
        moe_weight_parallelism=from_cfg_or_args("moe_weight_parallelism"),
        moe_capacity_factor=from_cfg_or_args("moe_capacity_factor"),
        moe_freq=from_cfg_or_args("moe_freq"),
        moe_top_k=from_cfg_or_args("moe_top_k"),
    )
|
|
|
|
class Mamba(nn.Module):
    """Wrapper around ``mamba_ssm``'s ``MambaLMHeadModel``.

    ``forward`` returns a 3-tuple so that callers can treat it like
    ``Transformer.forward``.
    """

    def __init__(self, params):
        """Build the wrapped model.

        Args:
            params: Mapping with ``seq_len``, ``vocab_size`` and the remaining
                keyword arguments for ``MambaLMHeadModel``. The mapping is
                copied, so the caller's object is left unmodified.

        Raises:
            ImportError: If the optional ``mamba_ssm`` package is not installed.
        """
        if MambaLMHeadModel is None:
            raise ImportError(
                "MambaLMHeadModel is not available. Please install the 'mamba_ssm' package by running 'pip install mamba-ssm'."
            )

        super().__init__()
        # Work on a copy: ``pop`` below must not mutate the caller's dict.
        params = dict(params)
        # ``seq_len`` is kept on the wrapper and not forwarded to the model.
        self.seq_len = params.pop("seq_len")
        self.vocab_size = params["vocab_size"]

        self.model = MambaLMHeadModel(**params)

    def reset_parameters(self):
        """No-op; this wrapper does no weight initialization of its own."""
        return

    def forward(self, x):
        """Return ``(logits, None, None)`` for input ``x``."""
        out = self.model(x).logits
        return out, None, None
|
|
|
|
def create_model(args):
    """Instantiate the model selected by ``args.model``.

    Args:
        args: Namespace-like object accepted by ``create_params``.

    Returns:
        A ``Mamba`` instance when the resolved model name contains "mamba",
        otherwise a ``Transformer``.
    """
    # Resolve the parameters first: ``create_params`` replaces a ".json" path
    # in ``args.model`` by the config's stem and picks dict vs. ``Params``
    # based on that stem. Testing the raw path here instead would mis-dispatch
    # for e.g. "configs/mamba/foo.json" and hand ``Params`` to ``Mamba``.
    params = create_params(args)
    if "mamba" in args.model:
        return Mamba(params)
    return Transformer(params)
|
|