diff --git a/fla2/models/emdeltanet/__init__.py b/fla2/models/emdeltanet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9230484e43ba29d230276e0407b9d97cc9aaa93a
--- /dev/null
+++ b/fla2/models/emdeltanet/__init__.py
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+from .configuration_emdeltanet import emdeltanetConfig
+from .modeling_emdeltanet import emdeltanetForCausalLM, emdeltanetModel
+
+AutoConfig.register(emdeltanetConfig.model_type, emdeltanetConfig)
+AutoModel.register(emdeltanetConfig, emdeltanetModel)
+AutoModelForCausalLM.register(emdeltanetConfig, emdeltanetForCausalLM)
+
+__all__ = ['emdeltanetConfig', 'emdeltanetForCausalLM', 'emdeltanetModel']
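Because importing this subpackage runs the `register` calls above, the new model resolves through the standard transformers Auto classes. A minimal usage sketch, assuming the `fla2` package is importable and that `emdeltanetConfig` accepts the usual size fields seen in the sibling emgla configuration (the values below are illustrative, not defaults from this patch):

from transformers import AutoModelForCausalLM

# importing the subpackage triggers the AutoConfig/AutoModel registration above
from fla2.models.emdeltanet import emdeltanetConfig

# hypothetical small settings, just to keep the example cheap to instantiate
config = emdeltanetConfig(hidden_size=512, num_hidden_layers=2, num_heads=4, vocab_size=32000)
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # emdeltanetForCausalLM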
0000000000000000000000000000000000000000..936c59211678c79d020b9af30bc180bc1c407e02 Binary files /dev/null and b/fla2/models/emdeltanet/__pycache__/configuration_emdeltanet.cpython-312.pyc differ diff --git a/fla2/models/emdeltanet/__pycache__/configuration_emgla.cpython-310.pyc b/fla2/models/emdeltanet/__pycache__/configuration_emgla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9c664fd9d4914d4f96f85972368580a45bea817 Binary files /dev/null and b/fla2/models/emdeltanet/__pycache__/configuration_emgla.cpython-310.pyc differ diff --git a/fla2/models/emdeltanet/__pycache__/modeling_emdeltanet.cpython-310.pyc b/fla2/models/emdeltanet/__pycache__/modeling_emdeltanet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4239c45100ed9a6e42879bf733538c9f87e6527 Binary files /dev/null and b/fla2/models/emdeltanet/__pycache__/modeling_emdeltanet.cpython-310.pyc differ diff --git a/fla2/models/emdeltanet/__pycache__/modeling_emdeltanet.cpython-312.pyc b/fla2/models/emdeltanet/__pycache__/modeling_emdeltanet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b86c5c70db19112ea71f2a008b7d4a72942ccd28 Binary files /dev/null and b/fla2/models/emdeltanet/__pycache__/modeling_emdeltanet.cpython-312.pyc differ diff --git a/fla2/models/emdeltanet/__pycache__/modeling_emgla.cpython-310.pyc b/fla2/models/emdeltanet/__pycache__/modeling_emgla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b15c2e90086bd98e7bb2c3a4dbc409f774ac8649 Binary files /dev/null and b/fla2/models/emdeltanet/__pycache__/modeling_emgla.cpython-310.pyc differ diff --git a/fla2/models/emdeltanet/modeling_emdeltanet.py b/fla2/models/emdeltanet/modeling_emdeltanet.py new file mode 100644 index 0000000000000000000000000000000000000000..75a780d5af987c2f52fdd2f8e0b1628f97430671 --- /dev/null +++ b/fla2/models/emdeltanet/modeling_emdeltanet.py @@ -0,0 +1,535 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union,Iterable +from einops import rearrange +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from ...layers.attn import Attention +from ...layers.emdeltanet import emdeltanet +from ...models.emdeltanet.configuration_emdeltanet import emdeltanetConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as emdeltanetMLP +from ...modules import RMSNorm +from ...modules import RotaryEmbedding +logger = logging.get_logger(__name__) + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +class emdeltanetBlock(nn.Module): + def __init__(self, config: emdeltanetConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + print(config) + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + 
num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = emdeltanet( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + use_gate=config.use_gate, + use_short_conv=config.use_short_conv, + use_output_norm=config.use_output_norm, + conv_size=config.conv_size, + norm_eps=config.norm_eps, + ratio = config.ratio, + topk = config.topk, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = emdeltanetMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values,router_logits = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values,router_logits) + + return outputs + + +class emdeltanetPreTrainedModel(PreTrainedModel): + + config_class = emdeltanetConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['emdeltanetBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class emdeltanetModel(emdeltanetPreTrainedModel): + + def __init__(self, config: emdeltanetConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([emdeltanetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`emdeltanetModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + all_router_logits = () + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values, router_logits = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values, router_logits = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + all_router_logits += (router_logits,) + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + # return BaseModelOutputWithPast( + # last_hidden_state=hidden_states, + # past_key_values=past_key_values, + # hidden_states=all_hidden_states, + # attentions=all_attns + # ) + return emdeltanetOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns, + router_logits=all_router_logits + ) + + +from dataclasses import dataclass +@dataclass +class emdeltanetOutputWithPast(BaseModelOutputWithPast): + router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class emdeltanetCausalLMOutputWithPast(CausalLMOutputWithPast): + aux_loss: Optional[torch.FloatTensor] = None + router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple], + num_memories: torch.Tensor = None, + top_k=2, + use_layer_wise_balance=False, +) -> torch.FloatTensor: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, seqeunce_length, num_memories]. + num_memories (`int`, *optional*): + Number of experts + + Returns: + The auxiliary loss. + """ + if gate_logits is None or ( + isinstance(gate_logits, Iterable) and len(gate_logits) == 0 + ): + return 0 + + # ✨ Here is the fix for balance loss in Mixtral. + # We should calculate the balance loss in a layer-wise manner otherwise it may lead to degenerated solutions. 
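For reference, the per-layer term that this function computes is the Switch-Transformer balance loss: the fraction of tokens routed to each expert, multiplied by the mean router probability assigned to that expert, averaged over experts and scaled by num_experts**2, so a near-uniform router scores about 1. Below is a self-contained sketch of the single-layer math with illustrative shapes; it is not the committed function, which additionally handles tuples of per-layer logits and 4-D head-wise logits.

import torch
import torch.nn.functional as F

def balance_loss_one_layer(router_logits: torch.Tensor, num_experts: int, top_k: int = 1) -> torch.Tensor:
    # router_logits: [num_tokens, num_experts] for a single layer (illustrative shape)
    routing_weights, selected = torch.topk(router_logits, top_k, dim=-1)
    routing_weights = routing_weights.softmax(dim=-1)
    probs_full = torch.zeros_like(router_logits).scatter(-1, selected, routing_weights)
    # f_i: fraction of tokens whose top-k choice includes expert i
    expert_mask = F.one_hot(selected, num_experts).max(dim=-2).values.float()
    tokens_per_expert = expert_mask.mean(dim=0)
    # P_i: mean router probability mass assigned to expert i
    prob_per_expert = probs_full.mean(dim=0)
    return (tokens_per_expert * prob_per_expert).mean() * num_experts ** 2

# near-uniform routing over 4 experts gives a loss close to 1;
# collapsing every token onto one expert gives ~num_experts (here ~4)
logits = 0.01 * torch.randn(256, 4)
print(balance_loss_one_layer(logits, num_experts=4, top_k=1))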
+ if use_layer_wise_balance: + if not isinstance(gate_logits, Iterable): + gate_logits = (gate_logits,) + else: + if isinstance(gate_logits, Iterable): + gate_logits = (torch.cat(gate_logits, dim=0),) + else: + gate_logits = (gate_logits,) + + all_balance_losses = [] + + for logits in gate_logits: + if logits.dim() == 4: + logits = rearrange(logits,'b h l r-> (b h) l r') + routing_weights, selected_experts = torch.topk(logits, top_k, dim=-1) + routing_weights = routing_weights.softmax(dim=-1).to(logits.dtype) + routing_weights_full = torch.zeros_like(logits).scatter(-1, selected_experts, routing_weights) + + # cast the expert indices to int64, otherwise one-hot encoding will fail + if selected_experts.dtype != torch.int64: + selected_experts = selected_experts.to(torch.int64) + + if len(selected_experts.shape) == 2: + selected_experts = selected_experts.unsqueeze(2) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_memories) + + # For a given token, determine if it was routed to a given expert. + expert_mask = torch.max(expert_mask, axis=-2).values + + # cast to float32 otherwise mean will fail + expert_mask = expert_mask.to(torch.float32) + tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) + + router_prob_per_group_and_expert = torch.mean(routing_weights_full, axis=-2) + + # ✨ balance loss for this layer + balance_loss = torch.mean( + tokens_per_group_and_expert * router_prob_per_group_and_expert + ) * (num_memories**2) + all_balance_losses.append(balance_loss.reshape(1)) + + all_balance_losses = torch.cat(all_balance_losses).mean() # ✨ + + return all_balance_losses + + +class emdeltanetForCausalLM(emdeltanetPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = emdeltanetModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. 
+ if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + valid_router_logits = tuple( + logits + for logits in (outputs.router_logits if return_dict else outputs[-1]) + if logits is not None + ) + aux_loss = load_balancing_loss_func( + valid_router_logits, + self.config.ratio, + self.config.topk, + use_layer_wise_balance=True, + ) + aux_loss *= self.config.aux_loss_scale + # print('aux_loss:',aux_loss) + if aux_loss: + loss += aux_loss + + if not return_dict: + output = (logits,) + outputs[1:] + 
return (loss,) + output if loss is not None else output + + # return CausalLMOutputWithPast( + # loss=loss, + # logits=logits, + # past_key_values=outputs.past_key_values, + # hidden_states=outputs.hidden_states, + # attentions=outputs.attentions, + # ) + + return emdeltanetCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + aux_loss=aux_loss + ) \ No newline at end of file diff --git a/fla2/models/emgla-noaux/__init__.py b/fla2/models/emgla-noaux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..343b9f41073201acffd5c6a2dbd7c40a1ecb2fdd --- /dev/null +++ b/fla2/models/emgla-noaux/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from .configuration_emgla import emglaConfig +from .modeling_emgla import emglaForCausalLM, emglaModel + +AutoConfig.register(emglaConfig.model_type, emglaConfig) +AutoModel.register(emglaConfig, emglaModel) +AutoModelForCausalLM.register(emglaConfig, emglaForCausalLM) + +__all__ = ['emglaConfig', 'emglaForCausalLM', 'emglaModel'] diff --git a/fla2/models/emgla-noaux/__pycache__/__init__.cpython-310.pyc b/fla2/models/emgla-noaux/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd5e1f5a1495aa395ff4609f8744310bccc69de2 Binary files /dev/null and b/fla2/models/emgla-noaux/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla2/models/emgla-noaux/__pycache__/configuration_emgla.cpython-310.pyc b/fla2/models/emgla-noaux/__pycache__/configuration_emgla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..419751705abc5d01d7cedd1dd53dfae59e520adc Binary files /dev/null and b/fla2/models/emgla-noaux/__pycache__/configuration_emgla.cpython-310.pyc differ diff --git a/fla2/models/emgla-noaux/__pycache__/modeling_emgla.cpython-310.pyc b/fla2/models/emgla-noaux/__pycache__/modeling_emgla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12332ab4cd9d3c64deb8d98658d97c11f3334515 Binary files /dev/null and b/fla2/models/emgla-noaux/__pycache__/modeling_emgla.cpython-310.pyc differ diff --git a/fla2/models/emgla-noaux/configuration_emgla.py b/fla2/models/emgla-noaux/configuration_emgla.py new file mode 100644 index 0000000000000000000000000000000000000000..36432b379cd47b26bc9688f299c01d420f979956 --- /dev/null +++ b/fla2/models/emgla-noaux/configuration_emgla.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + + +class emglaConfig(PretrainedConfig): + + model_type = 'emgla' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + attn_mode: str = "chunk", + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 1, + use_gate: bool = False, + use_short_conv: bool = True, + conv_size: int = 4, + use_beta: bool = True, + use_output_norm: bool = True, + num_heads: int = 16, + qk_norm: str = 'l2', + qk_activation: str = 'silu', + max_position_embeddings: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + num_hidden_layers: int = 24, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + 
tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + ratio : int = 2, + top_k : int = 1, + **kwargs + ): + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_beta = use_beta + self.use_output_norm = use_output_norm + self.num_heads = num_heads + self.qk_norm = qk_norm + self.qk_activation = qk_activation + self.max_position_embeddings = max_position_embeddings + self.topk = top_k + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + self.ratio = ratio + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla2/models/emgla-noaux/modeling_emgla.py b/fla2/models/emgla-noaux/modeling_emgla.py new file mode 100644 index 0000000000000000000000000000000000000000..c164999d7fabe3e0881a65809a645589e4b6a002 --- /dev/null +++ b/fla2/models/emgla-noaux/modeling_emgla.py @@ -0,0 +1,414 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from ...layers.attn import Attention +from ...layers.emgla import emgla +from ...models.emgla.configuration_emgla import emglaConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as emglaMLP +from ...modules import RMSNorm +from ...modules import RotaryEmbedding +logger = logging.get_logger(__name__) + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +class emglaBlock(nn.Module): + def __init__(self, config: emglaConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + 
hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = emgla( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + use_gate=config.use_gate, + use_short_conv=config.use_short_conv, + use_output_norm=config.use_output_norm, + conv_size=config.conv_size, + norm_eps=config.norm_eps, + ratio = config.ratio, + top_k = config.topk, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = emglaMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class emglaPreTrainedModel(PreTrainedModel): + + config_class = emglaConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['emglaBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class emglaModel(emglaPreTrainedModel): + + def __init__(self, config: emglaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([emglaBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`emglaModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class emglaForCausalLM(emglaPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = emglaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. 
torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla2/models/emgla/__init__.py b/fla2/models/emgla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..343b9f41073201acffd5c6a2dbd7c40a1ecb2fdd --- /dev/null +++ b/fla2/models/emgla/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from .configuration_emgla import emglaConfig +from .modeling_emgla import emglaForCausalLM, emglaModel + +AutoConfig.register(emglaConfig.model_type, emglaConfig) 
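The emgla configuration registered here (added just below in this patch) accepts an optional `attn` dict that turns selected layers into standard softmax attention; its validation requires 'layers' and 'num_heads' and fills in the remaining keys with defaults. A sketch of the intended shape, assuming `fla2` is importable and with illustrative values:

from fla2.models.emgla import emglaConfig

config = emglaConfig(
    num_hidden_layers=24,
    attn={
        'layers': [6, 12, 18],  # these layer indices use the Attention class instead of the emgla layer
        'num_heads': 16,
        # defaults added by the config if omitted:
        # 'num_kv_heads' = num_heads, 'qkv_bias' = False, 'window_size' = None, 'rope_theta' = 10000.
    },
)
print(config.attn['rope_theta'])  # 10000.0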
+AutoModel.register(emglaConfig, emglaModel) +AutoModelForCausalLM.register(emglaConfig, emglaForCausalLM) + +__all__ = ['emglaConfig', 'emglaForCausalLM', 'emglaModel'] diff --git a/fla2/models/emgla/__pycache__/__init__.cpython-310.pyc b/fla2/models/emgla/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..163decd8df58835371ef7e86920df660e42b175d Binary files /dev/null and b/fla2/models/emgla/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla2/models/emgla/__pycache__/__init__.cpython-312.pyc b/fla2/models/emgla/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e366d8d10511ba4a82925da57bb22234a45ad38 Binary files /dev/null and b/fla2/models/emgla/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/models/emgla/__pycache__/configuration_emgla.cpython-310.pyc b/fla2/models/emgla/__pycache__/configuration_emgla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34bb7510b4b4a44595f3fde649777e59fdf26ae7 Binary files /dev/null and b/fla2/models/emgla/__pycache__/configuration_emgla.cpython-310.pyc differ diff --git a/fla2/models/emgla/__pycache__/configuration_emgla.cpython-312.pyc b/fla2/models/emgla/__pycache__/configuration_emgla.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfbb49c59e23119f9e8de67f3bc0ad9300178916 Binary files /dev/null and b/fla2/models/emgla/__pycache__/configuration_emgla.cpython-312.pyc differ diff --git a/fla2/models/emgla/__pycache__/modeling_emgla.cpython-310.pyc b/fla2/models/emgla/__pycache__/modeling_emgla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da13b7aa96e5f6bb4faf0321bfafafe5948be728 Binary files /dev/null and b/fla2/models/emgla/__pycache__/modeling_emgla.cpython-310.pyc differ diff --git a/fla2/models/emgla/__pycache__/modeling_emgla.cpython-312.pyc b/fla2/models/emgla/__pycache__/modeling_emgla.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3949b4811eaa67fd0f351c12c0fb55031769f665 Binary files /dev/null and b/fla2/models/emgla/__pycache__/modeling_emgla.cpython-312.pyc differ diff --git a/fla2/models/emgla/configuration_emgla.py b/fla2/models/emgla/configuration_emgla.py new file mode 100644 index 0000000000000000000000000000000000000000..3f7740105f6f8b76adedba3328baa805c2d102fa --- /dev/null +++ b/fla2/models/emgla/configuration_emgla.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + + +class emglaConfig(PretrainedConfig): + + model_type = 'emgla' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + attn_mode: str = "chunk", + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 1, + use_gate: bool = False, + use_short_conv: bool = True, + conv_size: int = 4, + use_beta: bool = True, + use_output_norm: bool = True, + num_heads: int = 16, + qk_norm: str = 'l2', + qk_activation: str = 'silu', + max_position_embeddings: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + num_hidden_layers: int = 24, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_swiglu: bool = 
True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + aux_loss_scale:float = 0.01, + ratio : int = 2, + topk : int = 1, + **kwargs + ): + self.attn_mode = attn_mode + self.aux_loss_scale = aux_loss_scale + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_beta = use_beta + self.use_output_norm = use_output_norm + self.num_heads = num_heads + self.qk_norm = qk_norm + self.qk_activation = qk_activation + self.max_position_embeddings = max_position_embeddings + + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + self.ratio = ratio + self.topk = topk + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla2/models/emgla/modeling_emgla.py b/fla2/models/emgla/modeling_emgla.py new file mode 100644 index 0000000000000000000000000000000000000000..e394df093ccae97fbeae4a41c037bd37e88bad8e --- /dev/null +++ b/fla2/models/emgla/modeling_emgla.py @@ -0,0 +1,534 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union,Iterable +from einops import rearrange +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from ...layers.attn import Attention +from ...layers.emgla import emgla +from ...models.emgla.configuration_emgla import emglaConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as emglaMLP +from ...modules import RMSNorm +from ...modules import RotaryEmbedding +logger = logging.get_logger(__name__) + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +class emglaBlock(nn.Module): + def __init__(self, config: emglaConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + print(config) + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + 
hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = emgla( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + use_gate=config.use_gate, + use_short_conv=config.use_short_conv, + use_output_norm=config.use_output_norm, + conv_size=config.conv_size, + norm_eps=config.norm_eps, + ratio = config.ratio, + topk = config.topk, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = emglaMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values,router_logits = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values,router_logits) + + return outputs + + +class emglaPreTrainedModel(PreTrainedModel): + + config_class = emglaConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['emglaBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class emglaModel(emglaPreTrainedModel): + + def __init__(self, config: emglaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([emglaBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`emglaModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + all_router_logits = () + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values, router_logits = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values, router_logits = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + all_router_logits += (router_logits,) + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + # return BaseModelOutputWithPast( + # last_hidden_state=hidden_states, + # past_key_values=past_key_values, + # hidden_states=all_hidden_states, + # attentions=all_attns + # ) + return emglaOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns, + router_logits=all_router_logits + ) + + +from dataclasses import dataclass +@dataclass +class emglaOutputWithPast(BaseModelOutputWithPast): + router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class emglaCausalLMOutputWithPast(CausalLMOutputWithPast): + aux_loss: Optional[torch.FloatTensor] = None + router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple], + num_memories: torch.Tensor = None, + top_k=2, + use_layer_wise_balance=False, +) -> torch.FloatTensor: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, seqeunce_length, num_memories]. + num_memories (`int`, *optional*): + Number of experts + + Returns: + The auxiliary loss. + """ + if gate_logits is None or ( + isinstance(gate_logits, Iterable) and len(gate_logits) == 0 + ): + return 0 + + # ✨ Here is the fix for balance loss in Mixtral. + # We should calculate the balance loss in a layer-wise manner otherwise it may lead to degenerated solutions. 
+ if use_layer_wise_balance: + if not isinstance(gate_logits, Iterable): + gate_logits = (gate_logits,) + else: + if isinstance(gate_logits, Iterable): + gate_logits = (torch.cat(gate_logits, dim=0),) + else: + gate_logits = (gate_logits,) + + all_balance_losses = [] + + for logits in gate_logits: + if logits.dim() == 4: + logits = rearrange(logits,'b h l r-> (b h) l r') + routing_weights, selected_experts = torch.topk(logits, top_k, dim=-1) + routing_weights = routing_weights.softmax(dim=-1).to(logits.dtype) + routing_weights_full = torch.zeros_like(logits).scatter(-1, selected_experts, routing_weights) + + # cast the expert indices to int64, otherwise one-hot encoding will fail + if selected_experts.dtype != torch.int64: + selected_experts = selected_experts.to(torch.int64) + + if len(selected_experts.shape) == 2: + selected_experts = selected_experts.unsqueeze(2) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_memories) + + # For a given token, determine if it was routed to a given expert. + expert_mask = torch.max(expert_mask, axis=-2).values + + # cast to float32 otherwise mean will fail + expert_mask = expert_mask.to(torch.float32) + tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) + + router_prob_per_group_and_expert = torch.mean(routing_weights_full, axis=-2) + + # ✨ balance loss for this layer + balance_loss = torch.mean( + tokens_per_group_and_expert * router_prob_per_group_and_expert + ) * (num_memories**2) + all_balance_losses.append(balance_loss.reshape(1)) + + all_balance_losses = torch.cat(all_balance_losses).mean() # ✨ + + return all_balance_losses + + +class emglaForCausalLM(emglaPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = emglaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. 
+ if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + valid_router_logits = tuple( + logits + for logits in (outputs.router_logits if return_dict else outputs[-1]) + if logits is not None + ) + aux_loss = load_balancing_loss_func( + valid_router_logits, + self.config.ratio, + self.config.topk, + use_layer_wise_balance=True, + ) + aux_loss *= self.config.aux_loss_scale + # print('aux_loss:',aux_loss) + loss += aux_loss + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + 
output if loss is not None else output + + # return CausalLMOutputWithPast( + # loss=loss, + # logits=logits, + # past_key_values=outputs.past_key_values, + # hidden_states=outputs.hidden_states, + # attentions=outputs.attentions, + # ) + + return emglaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + aux_loss=aux_loss + ) \ No newline at end of file diff --git a/fla2/models/emla-noaux/__init__.py b/fla2/models/emla-noaux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d70f0a70375f4e940daf9ad8cca5ae089918bfda --- /dev/null +++ b/fla2/models/emla-noaux/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from .configuration_emla import emlaConfig +from .modeling_emla import emlaForCausalLM, emlaModel + +AutoConfig.register(emlaConfig.model_type, emlaConfig) +AutoModel.register(emlaConfig, emlaModel) +AutoModelForCausalLM.register(emlaConfig, emlaForCausalLM) + +__all__ = ['emlaConfig', 'emlaForCausalLM', 'emlaModel'] diff --git a/fla2/models/emla-noaux/__pycache__/__init__.cpython-310.pyc b/fla2/models/emla-noaux/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cec329e7970f5f20429024276f7d82b4f44e25d Binary files /dev/null and b/fla2/models/emla-noaux/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla2/models/emla-noaux/__pycache__/__init__.cpython-38.pyc b/fla2/models/emla-noaux/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d67f6d09c84e0cc0933f3d05e42663ca58b53c6f Binary files /dev/null and b/fla2/models/emla-noaux/__pycache__/__init__.cpython-38.pyc differ diff --git a/fla2/models/emla-noaux/__pycache__/configuration_emla.cpython-310.pyc b/fla2/models/emla-noaux/__pycache__/configuration_emla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1af87c67c30e1a5079bd965debb038c267e8f79b Binary files /dev/null and b/fla2/models/emla-noaux/__pycache__/configuration_emla.cpython-310.pyc differ diff --git a/fla2/models/emla-noaux/__pycache__/configuration_emla.cpython-38.pyc b/fla2/models/emla-noaux/__pycache__/configuration_emla.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4587c645ec03ed84b1b717e4815862d12987dc89 Binary files /dev/null and b/fla2/models/emla-noaux/__pycache__/configuration_emla.cpython-38.pyc differ diff --git a/fla2/models/emla-noaux/__pycache__/modeling_emla.cpython-310.pyc b/fla2/models/emla-noaux/__pycache__/modeling_emla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0af4fa0799e4a035db4e463bb8a036af71438ab Binary files /dev/null and b/fla2/models/emla-noaux/__pycache__/modeling_emla.cpython-310.pyc differ diff --git a/fla2/models/emla-noaux/__pycache__/modeling_emla.cpython-38.pyc b/fla2/models/emla-noaux/__pycache__/modeling_emla.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ea88399d0df732f34e6029913e82bd24846e7fc Binary files /dev/null and b/fla2/models/emla-noaux/__pycache__/modeling_emla.cpython-38.pyc differ diff --git a/fla2/models/emla-noaux/configuration_emla.py b/fla2/models/emla-noaux/configuration_emla.py new file mode 100644 index 0000000000000000000000000000000000000000..bb58bf401bcbbc3dd95f072e9358080ba6b54fab --- /dev/null 
+++ b/fla2/models/emla-noaux/configuration_emla.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + + +class emlaConfig(PretrainedConfig): + + model_type = 'emla' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + attn_mode: str = "chunk", + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 1, + use_gate: bool = False, + use_short_conv: bool = True, + conv_size: int = 4, + use_beta: bool = True, + use_output_norm: bool = True, + num_heads: int = 16, + qk_norm: str = 'l2', + qk_activation: str = 'silu', + max_position_embeddings: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + num_hidden_layers: int = 24, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + ratio : int = 2, + **kwargs + ): + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_beta = use_beta + self.use_output_norm = use_output_norm + self.num_heads = num_heads + self.qk_norm = qk_norm + self.qk_activation = qk_activation + self.max_position_embeddings = max_position_embeddings + + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + self.ratio = ratio + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) 
+ + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla2/models/emla-noaux/modeling_emla.py b/fla2/models/emla-noaux/modeling_emla.py new file mode 100644 index 0000000000000000000000000000000000000000..84443211f22bd130e6d9926d1098f035d9a15bd9 --- /dev/null +++ b/fla2/models/emla-noaux/modeling_emla.py @@ -0,0 +1,413 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from ...layers.attn import Attention +from ...layers.emla import emla +from ...models.emla.configuration_emla import emlaConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as emlaMLP +from ...modules import RMSNorm +from ...modules import RotaryEmbedding +logger = logging.get_logger(__name__) + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +class emlaBlock(nn.Module): + def __init__(self, config: emlaConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = emla( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + use_gate=config.use_gate, + use_short_conv=config.use_short_conv, + use_output_norm=config.use_output_norm, + conv_size=config.conv_size, + norm_eps=config.norm_eps, + ratio = config.ratio, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = emlaMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if 
self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class emlaPreTrainedModel(PreTrainedModel): + + config_class = emlaConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['emlaBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class emlaModel(emlaPreTrainedModel): + + def __init__(self, config: emlaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([emlaBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: 
Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`emlaModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class emlaForCausalLM(emlaPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = emlaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + 
self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else 
hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla2/models/emla/__init__.py b/fla2/models/emla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d70f0a70375f4e940daf9ad8cca5ae089918bfda --- /dev/null +++ b/fla2/models/emla/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from .configuration_emla import emlaConfig +from .modeling_emla import emlaForCausalLM, emlaModel + +AutoConfig.register(emlaConfig.model_type, emlaConfig) +AutoModel.register(emlaConfig, emlaModel) +AutoModelForCausalLM.register(emlaConfig, emlaForCausalLM) + +__all__ = ['emlaConfig', 'emlaForCausalLM', 'emlaModel'] diff --git a/fla2/models/emla/__pycache__/__init__.cpython-310.pyc b/fla2/models/emla/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd1b499904a0a5d83b0d37e6864bb15124898d25 Binary files /dev/null and b/fla2/models/emla/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla2/models/emla/__pycache__/__init__.cpython-312.pyc b/fla2/models/emla/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2b626807ee7452dfd0659ee59b32a85ca7aa7cf Binary files /dev/null and b/fla2/models/emla/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/models/emla/__pycache__/__init__.cpython-38.pyc b/fla2/models/emla/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f2d0d22cbd50d90fac263803737b3ac106ddc75 Binary files /dev/null and b/fla2/models/emla/__pycache__/__init__.cpython-38.pyc differ diff --git a/fla2/models/emla/__pycache__/configuration_emgla.cpython-310.pyc b/fla2/models/emla/__pycache__/configuration_emgla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c16d0a7d44d25704f2b515c0eb87bd0ba149efbf Binary files /dev/null and b/fla2/models/emla/__pycache__/configuration_emgla.cpython-310.pyc differ diff --git a/fla2/models/emla/__pycache__/configuration_emla.cpython-310.pyc b/fla2/models/emla/__pycache__/configuration_emla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90870ad2c95a175021dc4f0d861b1d18ac183e87 Binary files /dev/null and b/fla2/models/emla/__pycache__/configuration_emla.cpython-310.pyc differ diff --git a/fla2/models/emla/__pycache__/configuration_emla.cpython-312.pyc b/fla2/models/emla/__pycache__/configuration_emla.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..4451f6df21866ca2d78a83fd4e52de93fb3b21dd Binary files /dev/null and b/fla2/models/emla/__pycache__/configuration_emla.cpython-312.pyc differ diff --git a/fla2/models/emla/__pycache__/configuration_emla.cpython-38.pyc b/fla2/models/emla/__pycache__/configuration_emla.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43299fc4200fa1dadcd29f75f084bf99f357f294 Binary files /dev/null and b/fla2/models/emla/__pycache__/configuration_emla.cpython-38.pyc differ diff --git a/fla2/models/emla/__pycache__/modeling_emgla.cpython-310.pyc b/fla2/models/emla/__pycache__/modeling_emgla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b15c2e90086bd98e7bb2c3a4dbc409f774ac8649 Binary files /dev/null and b/fla2/models/emla/__pycache__/modeling_emgla.cpython-310.pyc differ diff --git a/fla2/models/emla/__pycache__/modeling_emla.cpython-310.pyc b/fla2/models/emla/__pycache__/modeling_emla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90b3e77f198da70d734fa75b29055d3325520af7 Binary files /dev/null and b/fla2/models/emla/__pycache__/modeling_emla.cpython-310.pyc differ diff --git a/fla2/models/emla/__pycache__/modeling_emla.cpython-312.pyc b/fla2/models/emla/__pycache__/modeling_emla.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d50cd7fce1c53032c688595ee1fab80acd38cd1 Binary files /dev/null and b/fla2/models/emla/__pycache__/modeling_emla.cpython-312.pyc differ diff --git a/fla2/models/emla/__pycache__/modeling_emla.cpython-38.pyc b/fla2/models/emla/__pycache__/modeling_emla.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fbb55b95acef8263e567cea19e0895a46470d30 Binary files /dev/null and b/fla2/models/emla/__pycache__/modeling_emla.cpython-38.pyc differ diff --git a/fla2/models/emla/configuration_emla.py b/fla2/models/emla/configuration_emla.py new file mode 100644 index 0000000000000000000000000000000000000000..f4fc04a359a427a261b3823d7f2bf7d971bcd257 --- /dev/null +++ b/fla2/models/emla/configuration_emla.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + + +class emlaConfig(PretrainedConfig): + + model_type = 'emla' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + attn_mode: str = "chunk", + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 1, + use_gate: bool = False, + use_short_conv: bool = True, + conv_size: int = 4, + use_beta: bool = True, + use_output_norm: bool = True, + num_heads: int = 16, + qk_norm: str = 'l2', + qk_activation: str = 'silu', + max_position_embeddings: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + num_hidden_layers: int = 24, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + aux_loss_scale:float = 0.01, + ratio : int = 2, + topk : int = 1, + **kwargs + ): + self.attn_mode = attn_mode + self.aux_loss_scale = aux_loss_scale + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.use_gate = use_gate + 
self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_beta = use_beta + self.use_output_norm = use_output_norm + self.num_heads = num_heads + self.qk_norm = qk_norm + self.qk_activation = qk_activation + self.max_position_embeddings = max_position_embeddings + self.topk = topk + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + self.ratio = ratio + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla2/models/emla/modeling_emla.py b/fla2/models/emla/modeling_emla.py new file mode 100644 index 0000000000000000000000000000000000000000..0e9bd9c322bed819351053f94955b113ca3542c0 --- /dev/null +++ b/fla2/models/emla/modeling_emla.py @@ -0,0 +1,533 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union,Iterable +from einops import rearrange +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from ...layers.attn import Attention +from ...layers.emla import emla +from ...models.emla.configuration_emla import emlaConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as emlaMLP +from ...modules import RMSNorm +from ...modules import RotaryEmbedding +logger = logging.get_logger(__name__) + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +class emlaBlock(nn.Module): + def __init__(self, config: emlaConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + print(config) + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = emla( + mode=config.attn_mode, + 
hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + use_gate=config.use_gate, + use_short_conv=config.use_short_conv, + use_output_norm=config.use_output_norm, + conv_size=config.conv_size, + norm_eps=config.norm_eps, + ratio = config.ratio, + topk = config.topk, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = emlaMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values,router_logits = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values,router_logits) + + return outputs + + +class emlaPreTrainedModel(PreTrainedModel): + + config_class = emlaConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['emlaBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class emlaModel(emlaPreTrainedModel): + + def __init__(self, config: emlaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([emlaBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`emlaModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + all_router_logits = () + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values, router_logits = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values, router_logits = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + all_router_logits += (router_logits,) + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + # return BaseModelOutputWithPast( + # last_hidden_state=hidden_states, + # past_key_values=past_key_values, + # hidden_states=all_hidden_states, + # attentions=all_attns + # ) + return emlaOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns, + router_logits=all_router_logits + ) + + +from dataclasses import dataclass +@dataclass +class emlaOutputWithPast(BaseModelOutputWithPast): + router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class emlaCausalLMOutputWithPast(CausalLMOutputWithPast): + aux_loss: Optional[torch.FloatTensor] = None + router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple], + num_memories: torch.Tensor = None, + top_k=2, + use_layer_wise_balance=False, +) -> torch.FloatTensor: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, seqeunce_length, num_memories]. + num_memories (`int`, *optional*): + Number of experts + + Returns: + The auxiliary loss. + """ + if gate_logits is None or ( + isinstance(gate_logits, Iterable) and len(gate_logits) == 0 + ): + return 0 + + # ✨ Here is the fix for balance loss in Mixtral. + # We should calculate the balance loss in a layer-wise manner otherwise it may lead to degenerated solutions. 
+ if use_layer_wise_balance: + if not isinstance(gate_logits, Iterable): + gate_logits = (gate_logits,) + else: + if isinstance(gate_logits, Iterable): + gate_logits = (torch.cat(gate_logits, dim=0),) + else: + gate_logits = (gate_logits,) + + all_balance_losses = [] + + for logits in gate_logits: + logits = rearrange(logits,'b h l r-> (b h) l r') + routing_weights, selected_experts = torch.topk(logits, top_k, dim=-1) + routing_weights = routing_weights.softmax(dim=-1).to(logits.dtype) + routing_weights_full = torch.zeros_like(logits).scatter(-1, selected_experts, routing_weights) + + # cast the expert indices to int64, otherwise one-hot encoding will fail + if selected_experts.dtype != torch.int64: + selected_experts = selected_experts.to(torch.int64) + + if len(selected_experts.shape) == 2: + selected_experts = selected_experts.unsqueeze(2) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_memories) + + # For a given token, determine if it was routed to a given expert. + expert_mask = torch.max(expert_mask, axis=-2).values + + # cast to float32 otherwise mean will fail + expert_mask = expert_mask.to(torch.float32) + tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) + + router_prob_per_group_and_expert = torch.mean(routing_weights_full, axis=-2) + + # ✨ balance loss for this layer + balance_loss = torch.mean( + tokens_per_group_and_expert * router_prob_per_group_and_expert + ) * (num_memories**2) + all_balance_losses.append(balance_loss.reshape(1)) + + all_balance_losses = torch.cat(all_balance_losses).mean() # ✨ + + return all_balance_losses + + +class emlaForCausalLM(emlaPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = emlaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. 
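# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical toy recurrence, not the fla kernels): once
# a state cache exists, the recurrent state already summarises every earlier
# token, so each decoding step only needs the newest token -- which is why
# `prepare_inputs_for_generation` below truncates `input_ids` to the last
# position when `past_key_values` is non-empty.
import torch

def toy_step(state: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # stand-in for one recurrent decode step: new_state = decay * state + input
    return 0.9 * state + x

state = torch.zeros(1, 4)            # plays the role of the cached state
for x_t in torch.randn(5, 1, 4):     # five decoding steps
    state = toy_step(state, x_t)     # only the current token is fed each step
# ---------------------------------------------------------------------------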
+        if past_key_values is not None and len(past_key_values) > 0:
+            input_ids = input_ids[:, -1:]
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and len(past_key_values) == 0:
+            model_inputs = {'inputs_embeds': inputs_embeds}
+        else:
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard.
+            # Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {'input_ids': input_ids.contiguous()}
+
+        if logits_to_keep is not None:
+            model_inputs['logits_to_keep'] = logits_to_keep
+
+        model_inputs.update({
+            'past_key_values': past_key_values,
+            'use_cache': use_cache,
+            'attention_mask': attention_mask,
+        })
+        return model_inputs
+
+    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        logits_to_keep: Optional[int] = 0,
+        **kwargs: Unpack[Dict]
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs
+        )
+
+        hidden_states = outputs[0]
+        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
+
+        loss, logits = None, None
+        if not fuse_linear_and_cross_entropy or labels is None:
+            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
+        if labels is not None:
+            if getattr(self, 'criterion', None) is None:
+                if fuse_linear_and_cross_entropy:
+                    criterion = FusedLinearCrossEntropyLoss()
+                elif self.config.fuse_cross_entropy:
+                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
+                else:
+                    criterion = nn.CrossEntropyLoss()
+            else:
+                criterion = self.criterion
+            labels = labels.to(hidden_states.device)
+            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
+            if fuse_linear_and_cross_entropy:
+                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
+            else:
+                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
+
+        valid_router_logits = tuple(
+            router_logits
+            for router_logits in (outputs.router_logits if return_dict else outputs[-1])
+            if router_logits is not None
+        )
+        aux_loss = load_balancing_loss_func(
+            valid_router_logits,
+            self.config.ratio,
+            self.config.topk,
+            use_layer_wise_balance=True,
+        )
+        aux_loss *= self.config.aux_loss_scale
+        # only add the auxiliary loss when a LM loss exists (i.e. labels were given);
+        # otherwise `loss` is None, e.g. during generation
+        if loss is not None:
+            loss = loss + aux_loss
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + 
output if loss is not None else output + + # return CausalLMOutputWithPast( + # loss=loss, + # logits=logits, + # past_key_values=outputs.past_key_values, + # hidden_states=outputs.hidden_states, + # attentions=outputs.attentions, + # ) + + return emlaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + aux_loss=aux_loss + ) \ No newline at end of file diff --git a/fla2/models/gla/__init__.py b/fla2/models/gla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..edccb515af8f04144308bfcbb72be8e91e714cd7 --- /dev/null +++ b/fla2/models/gla/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.gla.configuration_gla import GLAConfig +from fla.models.gla.modeling_gla import GLAForCausalLM, GLAModel + +AutoConfig.register(GLAConfig.model_type, GLAConfig) +AutoModel.register(GLAConfig, GLAModel) +AutoModelForCausalLM.register(GLAConfig, GLAForCausalLM) + + +__all__ = ['GLAConfig', 'GLAForCausalLM', 'GLAModel'] diff --git a/fla2/models/gla/__pycache__/__init__.cpython-312.pyc b/fla2/models/gla/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f8d6291a732fb30c1c670dc46d8951d1136bc8d Binary files /dev/null and b/fla2/models/gla/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/models/gla/__pycache__/__init__.cpython-38.pyc b/fla2/models/gla/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bce8b1042a948af8d1061778e830766a9523f7f Binary files /dev/null and b/fla2/models/gla/__pycache__/__init__.cpython-38.pyc differ diff --git a/fla2/models/gla/__pycache__/__init__.cpython-39.pyc b/fla2/models/gla/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e58e76d4041c2ecbce2d997dda780e528296bf92 Binary files /dev/null and b/fla2/models/gla/__pycache__/__init__.cpython-39.pyc differ diff --git a/fla2/models/gla/__pycache__/configuration_gla.cpython-312.pyc b/fla2/models/gla/__pycache__/configuration_gla.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e289c68086d949fd5311252ab7518aa6244a088 Binary files /dev/null and b/fla2/models/gla/__pycache__/configuration_gla.cpython-312.pyc differ diff --git a/fla2/models/gla/__pycache__/configuration_gla.cpython-38.pyc b/fla2/models/gla/__pycache__/configuration_gla.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a91d13096b615b81c6769fadb73d5b6ca4407de Binary files /dev/null and b/fla2/models/gla/__pycache__/configuration_gla.cpython-38.pyc differ diff --git a/fla2/models/gla/__pycache__/configuration_gla.cpython-39.pyc b/fla2/models/gla/__pycache__/configuration_gla.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5da62a2a857df5c99ec1421df0d40c40c774625f Binary files /dev/null and b/fla2/models/gla/__pycache__/configuration_gla.cpython-39.pyc differ diff --git a/fla2/models/gla/__pycache__/modeling_gla.cpython-312.pyc b/fla2/models/gla/__pycache__/modeling_gla.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3458bf4141580f3ad40a77fbb224415aa0f5ef2d Binary files /dev/null and b/fla2/models/gla/__pycache__/modeling_gla.cpython-312.pyc differ diff --git 
a/fla2/models/gla/__pycache__/modeling_gla.cpython-38.pyc b/fla2/models/gla/__pycache__/modeling_gla.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b440ac64fcf0bb1bbed06ffb20862d19c39466a Binary files /dev/null and b/fla2/models/gla/__pycache__/modeling_gla.cpython-38.pyc differ diff --git a/fla2/models/gla/__pycache__/modeling_gla.cpython-39.pyc b/fla2/models/gla/__pycache__/modeling_gla.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5f3497e0d9b3602780ba5116e7e0af89b43b73f Binary files /dev/null and b/fla2/models/gla/__pycache__/modeling_gla.cpython-39.pyc differ diff --git a/fla2/models/gla/configuration_gla.py b/fla2/models/gla/configuration_gla.py new file mode 100644 index 0000000000000000000000000000000000000000..98c33138ab0f3c186ef41afdc3a394e25a6bdbad --- /dev/null +++ b/fla2/models/gla/configuration_gla.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class GLAConfig(PretrainedConfig): + + model_type = 'gla' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + vocab_size: int = 32000, + hidden_size: int = 2048, + expand_k: int = 0.5, + expand_v: int = 1, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 4, + num_kv_heads: Optional[int] = None, + feature_map: Optional[str] = None, + attn_mode: str = "chunk", + use_short_conv: bool = False, + conv_size: int = 4, + use_output_gate: bool = True, + clamp_min: Optional[float] = None, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_gk: bool = True, + use_gv: bool = False, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + **kwargs + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.feature_map = feature_map + self.attn_mode = attn_mode + self.clamp_min = clamp_min + self.hidden_act = hidden_act + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_gk = use_gk + self.use_gv = use_gv + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_cross_entropy = fuse_cross_entropy + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_output_gate = use_output_gate + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla2/models/gla/modeling_gla.py b/fla2/models/gla/modeling_gla.py new file mode 100644 index 0000000000000000000000000000000000000000..bbef0afa0152fead1fc2d06786242af8ee270420 --- /dev/null +++ b/fla2/models/gla/modeling_gla.py @@ -0,0 +1,402 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn 
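# ---------------------------------------------------------------------------
# Usage sketch (assumption: the package resolves as `fla`, matching the imports
# in fla/models/gla/__init__.py above; the sizes are placeholders): after the
# AutoConfig/AutoModel registrations, the GLA model can be built through the
# standard Auto* API.
from transformers import AutoModelForCausalLM
from fla.models.gla import GLAConfig

config = GLAConfig(vocab_size=32000, hidden_size=512, num_hidden_layers=4, num_heads=4)
model = AutoModelForCausalLM.from_config(config)   # dispatches to GLAForCausalLM
# ---------------------------------------------------------------------------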
+import torch.utils.checkpoint +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from fla.layers.gla import GatedLinearAttention +from fla.models.gla.configuration_gla import GLAConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, RMSNorm +from fla.modules.activations import swiglu_linear + +logger = logging.get_logger(__name__) + +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as GLAMLP + +# class GLAMLP(nn.Module): + +# def __init__( +# self, +# hidden_size: int, +# hidden_ratio: Optional[int] = None, +# intermediate_size: Optional[int] = None, +# hidden_act: str = 'swish' +# ) -> GLAMLP: +# super().__init__() + +# self.hidden_size = hidden_size +# # the final number of params is `hidden_ratio * hidden_size^2` +# # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` +# if hidden_ratio is None: +# hidden_ratio = 4 +# if intermediate_size is None: +# intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) +# intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) +# self.hidden_ratio = hidden_ratio +# self.intermediate_size = intermediate_size + +# self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) +# self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) +# self.act_fn = ACT2FN[hidden_act] + +# def forward(self, x): +# y = self.gate_proj(x) +# gate, y = y.chunk(2, -1) +# return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) + + +class GLABlock(nn.Module): + def __init__(self, config: GLAConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.attn = GatedLinearAttention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + feature_map=config.feature_map, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + use_output_gate=config.use_output_gate, + gate_fn=config.hidden_act, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + clamp_min=config.clamp_min, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.mlp = GLAMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + hidden_states, residual = self.mlp_norm(hidden_states, 
residual, True) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class GLAPreTrainedModel(PreTrainedModel): + + config_class = GLAConfig + supports_gradient_checkpointing = True + _no_split_modules = ['GLABlock'] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["o_proj.weight", "down_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class GLAModel(GLAPreTrainedModel): + + def __init__(self, config: GLAConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([GLABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`GLAModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache 
if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache: + if past_key_values is None: + past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers] + if not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class GLAForCausalLM(GLAPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = GLAModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. 
" + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. + if past_key_values is not None: + if not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1) + input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + }) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla2/models/gsa/__init__.py b/fla2/models/gsa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a134f758e0bea0eb844a2db73957936078f889b6 --- /dev/null +++ 
b/fla2/models/gsa/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.gsa.configuration_gsa import GSAConfig +from fla.models.gsa.modeling_gsa import GSAForCausalLM, GSAModel + +AutoConfig.register(GSAConfig.model_type, GSAConfig) +AutoModel.register(GSAConfig, GSAModel) +AutoModelForCausalLM.register(GSAConfig, GSAForCausalLM) + + +__all__ = ['GSAConfig', 'GSAForCausalLM', 'GSAModel'] diff --git a/fla2/models/gsa/__pycache__/__init__.cpython-312.pyc b/fla2/models/gsa/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecf0f258ae5e740b2c3d17269a8e360d16362f58 Binary files /dev/null and b/fla2/models/gsa/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/models/gsa/__pycache__/__init__.cpython-38.pyc b/fla2/models/gsa/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56d74ae032fee95b094c29cc052b0d5d5443f09e Binary files /dev/null and b/fla2/models/gsa/__pycache__/__init__.cpython-38.pyc differ diff --git a/fla2/models/gsa/__pycache__/__init__.cpython-39.pyc b/fla2/models/gsa/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b05068bbaaa3d1b98e261679ee8e55cdabe726f Binary files /dev/null and b/fla2/models/gsa/__pycache__/__init__.cpython-39.pyc differ diff --git a/fla2/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc b/fla2/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f254321c6e42a5c3d3fac891b666d1f7aa284723 Binary files /dev/null and b/fla2/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc differ diff --git a/fla2/models/gsa/__pycache__/configuration_gsa.cpython-38.pyc b/fla2/models/gsa/__pycache__/configuration_gsa.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8b6994f2233d85a57a8b3e7bc338901ab9ea017 Binary files /dev/null and b/fla2/models/gsa/__pycache__/configuration_gsa.cpython-38.pyc differ diff --git a/fla2/models/gsa/__pycache__/configuration_gsa.cpython-39.pyc b/fla2/models/gsa/__pycache__/configuration_gsa.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a0303581190d1898e5f8d72d2fc0001bb49d8d8 Binary files /dev/null and b/fla2/models/gsa/__pycache__/configuration_gsa.cpython-39.pyc differ diff --git a/fla2/models/gsa/__pycache__/modeling_gsa.cpython-312.pyc b/fla2/models/gsa/__pycache__/modeling_gsa.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a224be5210b1997106f3e7166dc4aecbbf996d2 Binary files /dev/null and b/fla2/models/gsa/__pycache__/modeling_gsa.cpython-312.pyc differ diff --git a/fla2/models/gsa/__pycache__/modeling_gsa.cpython-38.pyc b/fla2/models/gsa/__pycache__/modeling_gsa.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32097ae38d783dc609c6a0f1f68ff884921414de Binary files /dev/null and b/fla2/models/gsa/__pycache__/modeling_gsa.cpython-38.pyc differ diff --git a/fla2/models/gsa/__pycache__/modeling_gsa.cpython-39.pyc b/fla2/models/gsa/__pycache__/modeling_gsa.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6039ba335b1bd462f5d832a62cd106b82ad3e522 Binary files /dev/null and b/fla2/models/gsa/__pycache__/modeling_gsa.cpython-39.pyc differ diff --git a/fla2/models/gsa/configuration_gsa.py b/fla2/models/gsa/configuration_gsa.py new 
file mode 100644 index 0000000000000000000000000000000000000000..df1ab641d584480401bd1c76a6c62ceca353a06c --- /dev/null +++ b/fla2/models/gsa/configuration_gsa.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class GSAConfig(PretrainedConfig): + + model_type = 'gsa' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + vocab_size: int = 32000, + hidden_size: int = 2048, + gate_low_rank_dim: Optional[int] = None, + gate_logit_normalizer: Optional[int] = 16, + clamp_min: Optional[float] = None, + clamp_max: Optional[float] = None, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 8, + num_kv_heads: Optional[int] = None, + num_slots: Optional[int] = 64, + use_short_conv: bool = False, + conv_size: int = 4, + exapnd_k: float = 1, + exapnd_v: float = 1, + feature_map: str = 'swish', + use_rope: bool = False, + use_output_gate: bool = False, + use_norm: bool = True, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_first: bool = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + initializer_range: float = 0.02, + tie_word_embeddings: bool = False, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + **kwargs + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.gate_low_rank_dim = gate_low_rank_dim + self.gate_logit_normalizer = gate_logit_normalizer + self.clamp_min = clamp_min + self.clamp_max = clamp_max + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.num_slots = num_slots + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.expand_k = exapnd_k + self.expand_v = exapnd_v + self.feature_map = feature_map + self.use_rope = use_rope + self.use_output_gate = use_output_gate + self.use_norm = use_norm + self.hidden_act = hidden_act + self.elementwise_affine = elementwise_affine + self.norm_first = norm_first + self.norm_eps = norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + self.fuse_norm = fuse_norm + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla2/models/gsa/modeling_gsa.py b/fla2/models/gsa/modeling_gsa.py new file mode 100644 index 0000000000000000000000000000000000000000..267b25492260a169a92c8a7611edde45a4e2b1c1 --- /dev/null +++ b/fla2/models/gsa/modeling_gsa.py @@ -0,0 +1,423 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from fla.layers.gsa import GatedSlotAttention +from fla.models.gsa.configuration_gsa import GSAConfig +from fla.models.utils import Cache +from 
fla.modules import FusedCrossEntropyLoss, RMSNorm, RMSNormLinear +from fla.modules.activations import swiglu_linear + +logger = logging.get_logger(__name__) + + +class GSAMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'swish', + norm_first: bool = True, + norm_eps: float = 1e-5 + ) -> GSAMLP: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.norm_first = norm_first + + if norm_first: + self.norm = RMSNormLinear(hidden_size=hidden_size, eps=norm_eps) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + if self.norm_first: + gate, y = self.norm(x, self.gate_proj.weight, self.gate_proj.bias).chunk(2, -1) + else: + gate, y = self.gate_proj(x).chunk(2, -1) + return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) + + +class GSABlock(nn.Module): + def __init__(self, config: GSAConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if not config.norm_first: + self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.attn = GatedSlotAttention( + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + num_slots=config.num_slots, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + feature_map=config.feature_map, + use_rope=config.use_rope, + use_output_gate=config.use_output_gate, + use_norm=config.use_norm, + gate_fn=config.hidden_act, + gate_low_rank_dim=config.gate_low_rank_dim, + gate_logit_normalizer=config.gate_logit_normalizer, + elementwise_affine=config.elementwise_affine, + norm_first=config.norm_first, + norm_eps=config.norm_eps, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + if not config.norm_first: + self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.mlp = GSAMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + norm_first=config.norm_first, + norm_eps=config.norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + if hasattr(self, 'attn_norm'): + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + if hasattr(self, 'mlp_norm'): + hidden_states, 
residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class GSAPreTrainedModel(PreTrainedModel): + + config_class = GSAConfig + supports_gradient_checkpointing = True + _no_split_modules = ['GSABlock'] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["o_proj.weight", "down_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class GSAModel(GSAPreTrainedModel): + + def __init__(self, config: GSAConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([GSABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`GSAModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = 
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache: + if past_key_values is None: + past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers] + if not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class GSAForCausalLM(GSAPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + + super().__init__(config) + self.model = GSAModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for 
{self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. + if past_key_values is not None: + if not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1) + input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + }) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla2/models/hgrn/__init__.py b/fla2/models/hgrn/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..3b29a3dd82da6d64bac6cc887e24295a03de5b23 --- /dev/null +++ b/fla2/models/hgrn/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.hgrn.configuration_hgrn import HGRNConfig +from fla.models.hgrn.modeling_hgrn import HGRNForCausalLM, HGRNModel + +AutoConfig.register(HGRNConfig.model_type, HGRNConfig) +AutoModel.register(HGRNConfig, HGRNModel) +AutoModelForCausalLM.register(HGRNConfig, HGRNForCausalLM) + + +__all__ = ['HGRNConfig', 'HGRNForCausalLM', 'HGRNModel'] diff --git a/fla2/models/hgrn/__pycache__/__init__.cpython-312.pyc b/fla2/models/hgrn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e5e64d33d10ad4f3752bcf513029d0752a176b1 Binary files /dev/null and b/fla2/models/hgrn/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/models/hgrn/__pycache__/__init__.cpython-38.pyc b/fla2/models/hgrn/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bf66a2ad7164b7e2549a6c5b3b7b61bf06c1d93 Binary files /dev/null and b/fla2/models/hgrn/__pycache__/__init__.cpython-38.pyc differ diff --git a/fla2/models/hgrn/__pycache__/__init__.cpython-39.pyc b/fla2/models/hgrn/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f0c0283769e9d69f0be4f96eda7567a5916310a Binary files /dev/null and b/fla2/models/hgrn/__pycache__/__init__.cpython-39.pyc differ diff --git a/fla2/models/hgrn/__pycache__/configuration_hgrn.cpython-312.pyc b/fla2/models/hgrn/__pycache__/configuration_hgrn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54b3e017a60927e9b362c773187744f7ccb29645 Binary files /dev/null and b/fla2/models/hgrn/__pycache__/configuration_hgrn.cpython-312.pyc differ diff --git a/fla2/models/hgrn/__pycache__/configuration_hgrn.cpython-38.pyc b/fla2/models/hgrn/__pycache__/configuration_hgrn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbfe7d3e022d61e443c4cc8585131944d5123581 Binary files /dev/null and b/fla2/models/hgrn/__pycache__/configuration_hgrn.cpython-38.pyc differ diff --git a/fla2/models/hgrn/__pycache__/configuration_hgrn.cpython-39.pyc b/fla2/models/hgrn/__pycache__/configuration_hgrn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..148d53686f00f29acf9f26e27d378808112a55d8 Binary files /dev/null and b/fla2/models/hgrn/__pycache__/configuration_hgrn.cpython-39.pyc differ diff --git a/fla2/models/hgrn/__pycache__/modeling_hgrn.cpython-312.pyc b/fla2/models/hgrn/__pycache__/modeling_hgrn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7148de4c62abe61a99455706f1c91fdb951b04ff Binary files /dev/null and b/fla2/models/hgrn/__pycache__/modeling_hgrn.cpython-312.pyc differ diff --git a/fla2/models/hgrn/configuration_hgrn.py b/fla2/models/hgrn/configuration_hgrn.py new file mode 100644 index 0000000000000000000000000000000000000000..b8cd12aa2acf3b43a013f6a73115b2fa6dc668c2 --- /dev/null +++ b/fla2/models/hgrn/configuration_hgrn.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class HGRNConfig(PretrainedConfig): + + model_type = 'hgrn' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + attn_mode: str = "chunk", + vocab_size: int = 32000, + 
hidden_size: int = 2048, + num_hidden_layers: int = 24, + num_heads: Optional[int] = 1, + expand_ratio: Optional[int] = 1, + use_short_conv: bool = False, + conv_size: int = 4, + use_lower_bound: bool = True, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_cross_entropy: bool = True, + **kwargs + ): + self.attn_mode = attn_mode + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.expand_ratio = expand_ratio + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_lower_bound = use_lower_bound + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla2/models/hgrn/modeling_hgrn.py b/fla2/models/hgrn/modeling_hgrn.py new file mode 100644 index 0000000000000000000000000000000000000000..bd39b4f640b10c0dcdf4d2f12fa090ebe59535e1 --- /dev/null +++ b/fla2/models/hgrn/modeling_hgrn.py @@ -0,0 +1,403 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from fla.layers.hgrn import HGRNAttention +from fla.models.hgrn.configuration_hgrn import HGRNConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, RMSNorm +from fla.modules.activations import swiglu_linear + +logger = logging.get_logger(__name__) + + +class HGRNMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'swish' + ) -> HGRNMLP: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + y = self.gate_proj(x) + gate, y = y.chunk(2, -1) + return swiglu_linear(gate, y, self.down_proj.weight, 
self.down_proj.bias) + + +class HGRNBlock(nn.Module): + def __init__(self, config: HGRNConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.attn = HGRNAttention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + num_heads=config.num_heads, + expand_ratio=config.expand_ratio, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.mlp = HGRNMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + lower_bound: Optional[torch.Tensor] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + lower_bound=lower_bound + ) + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class HGRNPreTrainedModel(PreTrainedModel): + + config_class = HGRNConfig + supports_gradient_checkpointing = True + _no_split_modules = ['HGRNBlock'] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["o_proj.weight", "down_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class HGRNModel(HGRNPreTrainedModel): + + def __init__(self, config: HGRNConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + if config.use_lower_bound: + self.lower_bounds = nn.Parameter(torch.zeros(config.num_hidden_layers, config.hidden_size)) + self.layers = nn.ModuleList([HGRNBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`HGRNModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache: + if past_key_values is None: + past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers] + if not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + + if self.config.use_lower_bound: + lower_bounds = self.lower_bounds.softmax(0) + lower_bounds = lower_bounds.cumsum(0) - lower_bounds[0] + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + lower_bound = lower_bounds[i] if self.config.use_lower_bound else None + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + lower_bound + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + lower_bound=lower_bound + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class HGRNForCausalLM(HGRNPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = HGRNModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. 
+ if past_key_values is not None: + if not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1) + input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + }) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla2/models/hgrn2/__init__.py b/fla2/models/hgrn2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..306b8082220a57091f2e99cd689c011690db0439 --- /dev/null +++ b/fla2/models/hgrn2/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config +from fla.models.hgrn2.modeling_hgrn2 import HGRN2ForCausalLM, HGRN2Model + +AutoConfig.register(HGRN2Config.model_type, HGRN2Config) +AutoModel.register(HGRN2Config, HGRN2Model) +AutoModelForCausalLM.register(HGRN2Config, HGRN2ForCausalLM) + + +__all__ = ['HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model'] diff --git 
a/fla2/models/hgrn2/__pycache__/__init__.cpython-312.pyc b/fla2/models/hgrn2/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..737de2b5f8c01ae5219e5189bc353e854439d1d0 Binary files /dev/null and b/fla2/models/hgrn2/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/models/hgrn2/__pycache__/__init__.cpython-38.pyc b/fla2/models/hgrn2/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa108df9388b0585dfb9ec0c45d4545b571160f5 Binary files /dev/null and b/fla2/models/hgrn2/__pycache__/__init__.cpython-38.pyc differ diff --git a/fla2/models/hgrn2/__pycache__/configuration_hgrn2.cpython-312.pyc b/fla2/models/hgrn2/__pycache__/configuration_hgrn2.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d471ddda7bd1aecb494fdb533980eb9693b1abc4 Binary files /dev/null and b/fla2/models/hgrn2/__pycache__/configuration_hgrn2.cpython-312.pyc differ diff --git a/fla2/models/hgrn2/__pycache__/configuration_hgrn2.cpython-39.pyc b/fla2/models/hgrn2/__pycache__/configuration_hgrn2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..071883d402f464ab1b2f1805c76c02f9cb3c12c8 Binary files /dev/null and b/fla2/models/hgrn2/__pycache__/configuration_hgrn2.cpython-39.pyc differ diff --git a/fla2/models/hgrn2/__pycache__/modeling_hgrn2.cpython-38.pyc b/fla2/models/hgrn2/__pycache__/modeling_hgrn2.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c86c9696cd1b6204402df5e66dac04366aa1093d Binary files /dev/null and b/fla2/models/hgrn2/__pycache__/modeling_hgrn2.cpython-38.pyc differ diff --git a/fla2/models/hgrn2/configuration_hgrn2.py b/fla2/models/hgrn2/configuration_hgrn2.py new file mode 100644 index 0000000000000000000000000000000000000000..6c88abaa2ecdbfb2e9519b75fdb453aa0db2bf86 --- /dev/null +++ b/fla2/models/hgrn2/configuration_hgrn2.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class HGRN2Config(PretrainedConfig): + + model_type = 'hgrn2' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + vocab_size: int = 32000, + hidden_size: int = 2048, + num_hidden_layers: int = 24, + attn_mode: str = "chunk", + num_heads: Optional[int] = None, + expand_ratio: Optional[int] = 128, + use_short_conv: bool = False, + conv_size: int = 4, + use_lower_bound: bool = True, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_cross_entropy: bool = True, + **kwargs + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.attn_mode = attn_mode + self.num_heads = num_heads + self.expand_ratio = expand_ratio + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_lower_bound = use_lower_bound + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + 
self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla2/models/hgrn2/modeling_hgrn2.py b/fla2/models/hgrn2/modeling_hgrn2.py new file mode 100644 index 0000000000000000000000000000000000000000..d0392e9e41875496f267fee4fbd2eef62088e381 --- /dev/null +++ b/fla2/models/hgrn2/modeling_hgrn2.py @@ -0,0 +1,403 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from fla.layers.hgrn2 import HGRN2Attention +from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, RMSNorm +from fla.modules.activations import swiglu_linear + +logger = logging.get_logger(__name__) + + +class HGRN2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'swish' + ) -> HGRN2MLP: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + y = self.gate_proj(x) + gate, y = y.chunk(2, -1) + return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) + + +class HGRN2Block(nn.Module): + def __init__(self, config: HGRN2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.attn = HGRN2Attention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + num_heads=config.num_heads, + expand_ratio=config.expand_ratio, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.mlp = HGRN2MLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + lower_bound: Optional[torch.Tensor] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, 
torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + lower_bound=lower_bound + ) + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class HGRN2PreTrainedModel(PreTrainedModel): + + config_class = HGRN2Config + supports_gradient_checkpointing = True + _no_split_modules = ['HGRN2Block'] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["o_proj.weight", "down_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class HGRN2Model(HGRN2PreTrainedModel): + + def __init__(self, config: HGRN2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + if config.use_lower_bound: + self.lower_bounds = nn.Parameter(torch.zeros(config.num_hidden_layers, config.hidden_size)) + self.layers = nn.ModuleList([HGRN2Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`HGRN2Model` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache: + if past_key_values is None: + past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers] + if not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + + if self.config.use_lower_bound: + lower_bounds = self.lower_bounds.softmax(0) + lower_bounds = lower_bounds.cumsum(0) - lower_bounds[0] + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + lower_bound = lower_bounds[i] if self.config.use_lower_bound else None + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + lower_bound + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + lower_bound=lower_bound + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class HGRN2ForCausalLM(HGRN2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = HGRN2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. 
+ if past_key_values is not None: + if not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1) + input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + }) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla2/models/linear_attn/__init__.py b/fla2/models/linear_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..72d5d022de95afe9dc6cf76d3c2026a6a7f9e7a0 --- /dev/null +++ b/fla2/models/linear_attn/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.linear_attn.configuration_linear_attn import \ + LinearAttentionConfig +from fla.models.linear_attn.modeling_linear_attn import ( + LinearAttentionForCausalLM, LinearAttentionModel) + +AutoConfig.register(LinearAttentionConfig.model_type, LinearAttentionConfig) +AutoModel.register(LinearAttentionConfig, LinearAttentionModel) 
+AutoModelForCausalLM.register(LinearAttentionConfig, LinearAttentionForCausalLM) + +__all__ = ['LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel'] diff --git a/fla2/models/linear_attn/__pycache__/__init__.cpython-312.pyc b/fla2/models/linear_attn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d46dfda5fc8ea84730f47977fe6c6eb169c9306 Binary files /dev/null and b/fla2/models/linear_attn/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/models/linear_attn/__pycache__/__init__.cpython-38.pyc b/fla2/models/linear_attn/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb8b6486e6068b6378ee38002d8ddd495c30007b Binary files /dev/null and b/fla2/models/linear_attn/__pycache__/__init__.cpython-38.pyc differ diff --git a/fla2/models/linear_attn/__pycache__/__init__.cpython-39.pyc b/fla2/models/linear_attn/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb9d48004e5a2708a13cdc68ddee5bc96034a4e0 Binary files /dev/null and b/fla2/models/linear_attn/__pycache__/__init__.cpython-39.pyc differ diff --git a/fla2/models/linear_attn/configuration_linear_attn.py b/fla2/models/linear_attn/configuration_linear_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..ed4bae518434b978a725e1f2437b11751cf3d644 --- /dev/null +++ b/fla2/models/linear_attn/configuration_linear_attn.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class LinearAttentionConfig(PretrainedConfig): + + model_type = 'linear_attn' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + vocab_size: int = 32000, + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 1, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 4, + num_kv_heads: Optional[int] = None, + attn_mode: str = "fused_chunk", + feature_map: str = "elementwise_product", + tie_feature_map_qk: bool = False, + norm_q: bool = False, + norm_k: bool = False, + norm_feature_map: bool = False, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_cross_entropy: bool = True, + **kwargs + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.attn_mode = attn_mode + self.feature_map = feature_map + self.tie_feature_map_qk = tie_feature_map_qk + self.norm_q = norm_q + self.norm_k = norm_k + self.norm_feature_map = norm_feature_map + self.hidden_act = hidden_act + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + 
tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla2/models/linear_attn/modeling_linear_attn.py b/fla2/models/linear_attn/modeling_linear_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..9977c32456d802540349b2ba8c94f86981598ecf --- /dev/null +++ b/fla2/models/linear_attn/modeling_linear_attn.py @@ -0,0 +1,425 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from fla.layers.linear_attn import LinearAttention +from fla.models.linear_attn.configuration_linear_attn import \ + LinearAttentionConfig +from fla.modules import FusedCrossEntropyLoss, RMSNorm +from fla.modules.activations import swiglu_linear + +logger = logging.get_logger(__name__) + + +class LinearAttentionMLP(nn.Module): + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'swish' + ) -> LinearAttentionMLP: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + y = self.gate_proj(x) + gate, y = y.chunk(2, -1) + return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) + + +class LinearAttentionBlock(nn.Module): + def __init__(self, config: LinearAttentionConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.attn = LinearAttention( + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + mode=config.attn_mode, + feature_map=config.feature_map, + tie_feature_map_qk=config.tie_feature_map_qk, + norm_q=config.norm_q, + norm_k=config.norm_k, + do_feature_map_norm=config.norm_feature_map, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.mlp = LinearAttentionMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = 
False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + # currently not supported + attn_weights, present_key_value = None, None + + hidden_states = self.attn_norm(hidden_states) + hidden_states = self.attn(hidden_states) + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class LinearAttentionPreTrainedModel(PreTrainedModel): + config_class = LinearAttentionConfig + supports_gradient_checkpointing = True + _no_split_modules = ['LinearAttentionBlock'] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["o_proj.weight", "down_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class LinearAttentionModel(LinearAttentionPreTrainedModel): + + def __init__(self, config: LinearAttentionConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LinearAttentionBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn( + "`LinearAttentionModel` does not support output attention weights now, " + "so `output_attentions` is set to `False`." 
+ ) + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + _, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + _, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + # embed positions + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class LinearAttentionForCausalLM(LinearAttentionPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LinearAttentionModel(config) + self.vocab_size = config.vocab_size + 
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exc: + # Expected exception: "AttributeError: '(object name)' object has no attribute 'past_key_values'" + if 'past_key_values' in str(exc): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exc + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + state: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs + ): + # only last token for inputs_ids if the state is passed along. + if state is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and state is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + model_inputs["state"] = state + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), 
labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla2/models/mamba/__init__.py b/fla2/models/mamba/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0eff2ea26f3a11bcf2333002509686eca2289aa --- /dev/null +++ b/fla2/models/mamba/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.mamba.configuration_mamba import MambaConfig +from fla.models.mamba.modeling_mamba import (MambaBlock, MambaForCausalLM, + MambaModel) + +AutoConfig.register(MambaConfig.model_type, MambaConfig, True) +AutoModel.register(MambaConfig, MambaModel, True) +AutoModelForCausalLM.register(MambaConfig, MambaForCausalLM, True) + + +__all__ = ['MambaConfig', 'MambaForCausalLM', 'MambaModel', 'MambaBlock'] diff --git a/fla2/models/mamba/__pycache__/__init__.cpython-39.pyc b/fla2/models/mamba/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f0abab48e715df82caeaf80a9d5832d55faeda1 Binary files /dev/null and b/fla2/models/mamba/__pycache__/__init__.cpython-39.pyc differ
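A note on the shared MLP sizing rule in `HGRNMLP`, `HGRN2MLP`, and `LinearAttentionMLP` above: when `intermediate_size` is not given explicitly, it is derived by taking 2/3 of `hidden_size * hidden_ratio` and rounding up to the nearest multiple of 256, which keeps the parameter count of the gated SwiGLU MLP close to `hidden_ratio * hidden_size^2`. A minimal standalone sketch of that computation (the helper name and printed numbers are illustrative, not part of the diff):

def swiglu_intermediate_size(hidden_size: int, hidden_ratio: int = 4, multiple_of: int = 256) -> int:
    # 2/3 compensates for the extra gate projection of SwiGLU, keeping the
    # parameter count close to hidden_ratio * hidden_size ** 2
    size = int(hidden_size * hidden_ratio * 2 / 3)
    # round up to the nearest multiple of 256, as in the modeling code above
    return multiple_of * ((size + multiple_of - 1) // multiple_of)

# hidden_size=2048, hidden_ratio=4: 2/3 * 8192 = 5461.33 -> 5461 -> 5632
print(swiglu_intermediate_size(2048, 4))  # 5632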
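The `_init_weights` methods repeated above all apply the GPT-2 / Megatron-LM prenorm rescaling: after the standard normal initialization, the projections that write into the residual stream (`o_proj.weight` and `down_proj.weight`) are divided by sqrt(num_residuals_per_layer * num_hidden_layers), so the variance accumulated on the residual path stays roughly constant with depth. A hedged sketch of just that step, applied to a toy module rather than an actual block:

import math
import torch
import torch.nn as nn

def rescale_residual_projections(module: nn.Module, num_hidden_layers: int, num_residuals_per_layer: int = 2) -> None:
    # divide residual-path projections by sqrt(N), N = residuals per layer * depth,
    # mirroring the loop inside the _init_weights methods above
    with torch.no_grad():
        for name, p in module.named_parameters():
            if name in ("o_proj.weight", "down_proj.weight"):
                p /= math.sqrt(num_residuals_per_layer * num_hidden_layers)

# toy module exposing a `down_proj`, purely for illustration
toy = nn.Module()
toy.down_proj = nn.Linear(8, 4, bias=False)
rescale_residual_projections(toy, num_hidden_layers=24)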
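`HGRNModel` and `HGRN2Model` keep a learnable `lower_bounds` parameter of shape `[num_hidden_layers, hidden_size]` when `use_lower_bound=True`. Each forward pass converts it into per-layer forget-gate lower bounds: a softmax over the layer dimension followed by a cumulative sum, shifted so that layer 0 gets a bound of exactly zero and deeper layers get monotonically increasing bounds. A small sketch of just that transformation (toy shapes):

import torch

num_hidden_layers, hidden_size = 4, 8
lower_bounds_param = torch.zeros(num_hidden_layers, hidden_size)  # initialized to zeros, as above

probs = lower_bounds_param.softmax(dim=0)          # distribute mass across layers
lower_bounds = probs.cumsum(dim=0) - probs[0]      # layer 0 -> 0, later layers grow

print(lower_bounds[:, 0])  # tensor([0.0000, 0.2500, 0.5000, 0.7500]) at init
# lower_bounds[i] is then passed to layer i as its `lower_bound` argument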
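Both `HGRNForCausalLM.forward` and its HGRN2/LinearAttention counterparts shift the labels rather than the logits before computing the loss: labels are rolled left by one position and the vacated last slot is filled with `ignore_index`, so `logits[:, t]` is scored against the token at position `t + 1`. A minimal sketch of the same shift with plain `nn.CrossEntropyLoss` (toy tensors, not tied to any model in the diff):

import torch
import torch.nn as nn

vocab_size, batch_size, seq_len = 10, 2, 5
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

loss_fct = nn.CrossEntropyLoss()  # ignore_index defaults to -100
# shift labels left and pad the last position with ignore_index;
# equivalent to the usual "shift logits right by one" formulation
shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), dim=1)
loss = loss_fct(logits.view(-1, vocab_size), shifted.view(-1))
print(loss)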
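The `__init__.py` files above register each config/model pair with the `transformers` Auto classes, so the architectures become reachable through the standard factory API once the package is imported. A hedged usage sketch (the import name follows the `fla.` paths used inside the modeling files; the tiny config values are illustrative and not taken from any released checkpoint):

from transformers import AutoConfig, AutoModelForCausalLM

import fla  # noqa: F401  -- importing the package runs the Auto* registrations above

# build a small, randomly initialized HGRN2 model via the registered config
config = AutoConfig.for_model('hgrn2', hidden_size=256, num_hidden_layers=2)
model = AutoModelForCausalLM.from_config(config)
print(sum(p.numel() for p in model.parameters()))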