YongganFu committed (verified)
Commit 5fbddec · 1 Parent(s): 2aae7ad

Upload FastSLMForCausalLM

config.json CHANGED
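The two hunks below record a re-serialization of the checkpoint with a newer transformers release: the precision key moves from `torch_dtype` to `dtype`, and `transformers_version` is bumped from 4.48.2 to 4.56.2. A minimal loading sketch, assuming the files live at a hypothetical local path `./FastSLM` (the real repo id is not part of this diff) and an installed transformers release recent enough to read the renamed key:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# trust_remote_code is required: FastSLMForCausalLM ships as custom code
# (modeling_fast_slm.py, delta_net.py) next to the weights.
tokenizer = AutoTokenizer.from_pretrained("./FastSLM", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "./FastSLM",                 # hypothetical local path; substitute the real repo id
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # matches the "dtype": "bfloat16" entry below
)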
@@ -13,6 +13,7 @@
   "bos_token_id": 1,
   "calc_logits_for_entire_prompt": false,
   "d_conv": 4,
+  "dtype": "bfloat16",
   "eos_token_id": 2,
   "ffn_expand_ratio": 3,
   "global_attn_idx": [],
@@ -127,8 +128,7 @@
   "router_aux_loss_coef": 0.001,
   "sliding_window": null,
   "tie_word_embeddings": true,
-  "torch_dtype": "bfloat16",
-  "transformers_version": "4.48.2",
+  "transformers_version": "4.56.2",
   "use_cache": false,
   "use_mamba_kernels": true,
   "v_head_dim": -1,
delta_net.py CHANGED
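The second hunk below targets the newer `transformers.cache_utils.Cache` constructor by passing `layers=[0]` where the old code called `super().__init__()` with no arguments. A small compatibility sketch, under the assumption that the `layers` keyword simply does not exist in older transformers releases; as written it would sit inside the same `__init__` and replace the single added line:

# Compatibility sketch (assumption: older transformers Cache.__init__ accepts
# no arguments, while the pinned newer release takes a `layers` list as below).
try:
    super().__init__(layers=[0])   # newer constructor signature
except TypeError:
    super().__init__()             # fall back to the legacy no-argument form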
@@ -10,7 +10,6 @@ import torch.nn as nn
 from einops import rearrange
 from torch.nn import functional as F
 
-import fla
 from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
 from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule
 
@@ -331,7 +330,7 @@ class Cache(transformers.cache_utils.Cache):
         self,
         seen_tokens: int = 0
     ) -> Cache:
-        super().__init__()
+        super().__init__(layers=[0])
 
         self.states: List[Dict[str, Any]] = []
 
model.safetensors.index.json CHANGED
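The hunk below adds a `total_parameters` field to the shard-index metadata; it is consistent with the existing `total_size`, since bfloat16 weights take 2 bytes per parameter: 2,750,017,056 × 2 = 5,500,034,112 bytes. A one-line sanity check (sketch):

assert 2_750_017_056 * 2 == 5_500_034_112  # bf16 weights: 2 bytes per parameter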
@@ -1,5 +1,6 @@
 {
   "metadata": {
+    "total_parameters": 2750017056,
     "total_size": 5500034112
   },
   "weight_map": {
modeling_fast_slm.py CHANGED
@@ -46,6 +46,12 @@ from transformers.modeling_outputs import (
     SequenceClassifierOutputWithPast,
 )
 from transformers.modeling_utils import PreTrainedModel
+
+try:
+    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+except ImportError:
+    pass
+
 from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
 from transformers.utils import (
     add_start_docstrings,
@@ -280,7 +286,6 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
     def __init__(self, config, batch_size, dtype=torch.float16, device=None, layer_type=None):
         self.dtype = dtype
         # self.layers_block_type = config.layers_block_type
-        self.has_previous_state = False
         intermediate_size = config.mamba_expand * config.hidden_size
         ssm_state_size = config.mamba_d_state
         conv_kernel_size = config.mamba_d_conv
@@ -804,6 +809,75 @@ class FastSLMFlashAttention2(FastSLMAttention):
         )
 
 
+
+class FastSLMSDPAAttention(nn.Module):
+
+    def __init__(self, config, layer_idx: int, reuse_kv=False):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = config.attention_dropout
+        self.is_causal = True
+
+        self.q_proj = nn.Linear(
+            config.hidden_size, config.num_attention_heads * self.head_dim, bias=False
+        )
+        self.k_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.v_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.o_proj = nn.Linear(
+            config.num_attention_heads * self.head_dim, config.hidden_size, bias=False
+        )
+
+        self.sliding_window = self.config.sliding_window if self.layer_idx not in self.config.global_attn_idx else None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        # position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_value: Optional[Cache] = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        # cos, sin = position_embeddings
+        # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)  # , cache_kwargs)
+
+        attention_interface = ALL_ATTENTION_FUNCTIONS['flash_attention_2']
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            sliding_window=self.sliding_window,  # diff with Llama
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, attn_weights, past_key_value, (key_states, value_states)
+
+
 class FastSLMFused_MHA(FastSLMAttention):
     """
     FastSLM flash attention module. This module inherits from `FastSLMAttention` as the weights of the module stays
@@ -938,9 +1012,6 @@ class FastSLMFused_MHA(FastSLMAttention):
         v_dim = query_states.shape[-2] * value_states.shape[-1]
         attn_output = attn_output.reshape(bsz, q_len, v_dim).contiguous()
 
-        if past_key_value is not None:
-            past_key_value.has_previous_state = True
-
         attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
@@ -952,6 +1023,7 @@ class FastSLMFused_MHA(FastSLMAttention):
 JAMBA_ATTENTION_CLASSES = {
     "flash_attention_2": FastSLMFlashAttention2,
     "fused_mha": FastSLMFused_MHA,
+    "sdpa": FastSLMSDPAAttention,
 }
 
 class FastSLMMLP(nn.Module):
@@ -1633,6 +1705,8 @@ class FastSLMModel(FastSLMPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+        self.has_previous_state = False
+
 
     def get_input_embeddings(self):
         return self.embed_tokens
@@ -1684,21 +1758,13 @@ class FastSLMModel(FastSLMPreTrainedModel):
             )
             use_cache = False
 
-        past_key_values_length = 0
-        if use_cache:
-            if past_key_values is not None:
-                past_key_values_length = past_key_values.get_usable_length(seq_length, 0)
-            else:
-                use_cache = False
-
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device
             )
             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
         else:
-            if self.config.num_memory_tokens > 0 and past_key_values is not None and not past_key_values.has_previous_state:
+            if self.config.num_memory_tokens > 0 and past_key_values is not None and not self.has_previous_state:
                 position_ids = position_ids.view(-1, seq_length + self.config.num_memory_tokens).long()
             else:
                 position_ids = position_ids.view(-1, seq_length).long()
@@ -1708,7 +1774,7 @@ class FastSLMModel(FastSLMPreTrainedModel):
 
         ori_b, ori_n = inputs_embeds.shape[0], inputs_embeds.shape[1]
 
-        if self.config.num_memory_tokens > 0 and (past_key_values is None or not past_key_values.has_previous_state):
+        if self.config.num_memory_tokens > 0 and (past_key_values is None or not self.has_previous_state):
            mem = repeat(self.memory_tokens, 'n d -> b n d', b = inputs_embeds.shape[0]) # prepend the memory to every segment of m by repeating the memory tokens
            inputs_embeds, mem_packed_shape = pack((mem, inputs_embeds), 'b * d')
 
@@ -1718,6 +1784,7 @@ class FastSLMModel(FastSLMPreTrainedModel):
         if attention_mask is not None and attention_mask.shape[1] < inputs_embeds.shape[1]:
             assert attention_mask.shape[1] + self.config.num_memory_tokens == inputs_embeds.shape[1]
             attention_mask = torch.cat([torch.ones(inputs_embeds.shape[0], self.config.num_memory_tokens, device=attention_mask.device), attention_mask], dim=1)
+
 
         if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
@@ -1784,21 +1851,12 @@ class FastSLMModel(FastSLMPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        if self.config.num_memory_tokens > 0 and (past_key_values is None or not past_key_values.has_previous_state):
+        if self.config.num_memory_tokens > 0 and (past_key_values is None or not self.has_previous_state):
            mem, hidden_states = unpack(hidden_states, mem_packed_shape, 'b * d')
            hidden_states = hidden_states[:, :ori_n, :]
 
-        if past_key_values and not past_key_values.has_previous_state:
-            for layer_idx_ in range(len(self.layers)):
-                if past_key_values.get_seq_length(layer_idx_) > 0:
-                    past_key_values.has_previous_state = True
-                    break
-
-        if mamba_inference_params is not None and mamba_inference_params.seqlen_offset > 0:
-            past_key_values.has_previous_state = True
-
-        if fla_past_key_values is not None and len(fla_past_key_values.states) > 0:
-            past_key_values.has_previous_state = True
+        if past_key_values is not None and not self.has_previous_state:
+            self.has_previous_state = True
 
         next_cache = None
         if use_cache:
@@ -2011,7 +2069,7 @@ class FastSLMForCausalLM(FastSLMPreTrainedModel):
         static_logits = torch.zeros((batch_size, self.config.vocab_size), device=device)
 
         # Set up for graph capture
-        past_key_values.has_previous_state = True
+        self.model.has_previous_state = True
         if mamba_inference_params is not None:
             mamba_inference_params.seqlen_offset = 1
 
@@ -2055,7 +2113,7 @@ class FastSLMForCausalLM(FastSLMPreTrainedModel):
             if hasattr(module, 'reset_kv_cache'):
                 module.reset_kv_cache()
 
-        past_key_values.has_previous_state = False
+        self.model.has_previous_state = False
 
         # Return generation state
         generation_state = {
@@ -2134,7 +2192,7 @@ class FastSLMForCausalLM(FastSLMPreTrainedModel):
            if hasattr(module, 'reset_kv_cache'):
                module.reset_kv_cache()
 
-        past_key_values.has_previous_state = False
+        self.model.has_previous_state = False
 
         # Prefill phase - process input sequence
         position_ids = torch.arange(
@@ -2215,6 +2273,48 @@ class FastSLMForCausalLM(FastSLMPreTrainedModel):
         return generated_ids
 
 
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        output_router_logits=False,
+        **kwargs,
+    ):
+        if self.config.num_memory_tokens > 0:
+            attention_mask = torch.cat([torch.ones(input_ids.shape[0], self.config.num_memory_tokens, device=attention_mask.device), attention_mask], dim=1)
+
+        past_key_values = None  # Disable cache for now
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            position_ids = position_ids[:, -input_ids.shape[1]:]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None:
+            if input_ids.shape[1] == 0:
+                model_inputs = {"inputs_embeds": inputs_embeds}
+            else:
+                inputs_embeds_new = self.model.embed_tokens(input_ids)
+                model_inputs = {"inputs_embeds": torch.cat([inputs_embeds, inputs_embeds_new], dim=1)}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+
 def sample_token(logits, temperature=1.0, top_k=0, top_p=0.9):
     """
     Sample a token from logits with temperature, top-k, and top-p filtering.
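Taken together, the modeling changes above add a new attention class registered as "sdpa" in JAMBA_ATTENTION_CLASSES (its forward still dispatches through ALL_ATTENTION_FUNCTIONS['flash_attention_2']), move the has_previous_state flag from the cache object onto FastSLMModel, and add a prepare_inputs_for_generation override, the hook that GenerationMixin.generate() calls before each forward pass to widen the attention mask by config.num_memory_tokens, rebuild position_ids, and (for now) drop past_key_values. A short end-to-end generation sketch under the same assumptions as the loading example in the config.json section (hypothetical ./FastSLM path and prompt):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("./FastSLM", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "./FastSLM",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # assumption: resolves to the new "sdpa" entry in JAMBA_ATTENTION_CLASSES
).eval()

inputs = tok("Hello, FastSLM!", return_tensors="pt")
with torch.no_grad():
    # use_cache=False mirrors the config ("use_cache": false) and the cache
    # being forced to None in prepare_inputs_for_generation above.
    out = model.generate(**inputs, max_new_tokens=32, use_cache=False)
print(tok.decode(out[0], skip_special_tokens=True))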