msj19 commited on
Commit
c1eabf9
·
verified ·
1 Parent(s): e3221a7

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. fla2/layers/__pycache__/attn.cpython-38.pyc +0 -0
  2. fla2/layers/__pycache__/gla.cpython-38.pyc +0 -0
  3. fla2/layers/__pycache__/gsa.cpython-312.pyc +0 -0
  4. fla2/layers/__pycache__/gsa.cpython-38.pyc +0 -0
  5. fla2/layers/__pycache__/hgrn.cpython-38.pyc +0 -0
  6. fla2/layers/__pycache__/hgrn.cpython-39.pyc +0 -0
  7. fla2/layers/__pycache__/hgrn2.cpython-39.pyc +0 -0
  8. fla2/layers/__pycache__/linear_attn.cpython-38.pyc +0 -0
  9. fla2/layers/__pycache__/linear_attn.cpython-39.pyc +0 -0
  10. fla2/layers/__pycache__/mask_deltanet.cpython-310.pyc +0 -0
  11. fla2/layers/__pycache__/mask_deltanet.cpython-312.pyc +0 -0
  12. fla2/layers/__pycache__/mask_gdn.cpython-310.pyc +0 -0
  13. fla2/layers/__pycache__/mask_gdn.cpython-312.pyc +0 -0
  14. fla2/layers/__pycache__/multiscale_retention.cpython-39.pyc +0 -0
  15. fla2/layers/__pycache__/rebased.cpython-312.pyc +0 -0
  16. fla2/layers/__pycache__/rebased.cpython-39.pyc +0 -0
  17. fla2/layers/__pycache__/rwkv6.cpython-312.pyc +0 -0
  18. fla2/layers/__pycache__/rwkv6.cpython-38.pyc +0 -0
  19. fla2/layers/__pycache__/rwkv6.cpython-39.pyc +0 -0
  20. fla2/models/gsa/__pycache__/configuration_gsa.cpython-39.pyc +0 -0
  21. fla2/models/hgrn/__pycache__/modeling_hgrn.cpython-38.pyc +0 -0
  22. fla2/models/hgrn/__pycache__/modeling_hgrn.cpython-39.pyc +0 -0
  23. fla2/models/hgrn2/__pycache__/__init__.cpython-312.pyc +0 -0
  24. fla2/models/hgrn2/__pycache__/configuration_hgrn2.cpython-312.pyc +0 -0
  25. fla2/models/hgrn2/__pycache__/configuration_hgrn2.cpython-39.pyc +0 -0
  26. fla2/models/hgrn2/__pycache__/modeling_hgrn2.cpython-39.pyc +0 -0
  27. fla2/models/linear_attn/__pycache__/__init__.cpython-312.pyc +0 -0
  28. fla2/models/linear_attn/__pycache__/__init__.cpython-38.pyc +0 -0
  29. fla2/models/linear_attn/__pycache__/__init__.cpython-39.pyc +0 -0
  30. fla2/models/linear_attn/__pycache__/configuration_linear_attn.cpython-38.pyc +0 -0
  31. fla2/models/linear_attn/__pycache__/modeling_linear_attn.cpython-312.pyc +0 -0
  32. fla2/models/linear_attn/__pycache__/modeling_linear_attn.cpython-38.pyc +0 -0
  33. fla2/models/linear_attn/__pycache__/modeling_linear_attn.cpython-39.pyc +0 -0
  34. fla2/models/linear_attn/configuration_linear_attn.py +72 -0
  35. fla2/models/mamba/__init__.py +14 -0
  36. fla2/models/mamba/__pycache__/__init__.cpython-38.pyc +0 -0
  37. fla2/models/mamba/__pycache__/__init__.cpython-39.pyc +0 -0
  38. fla2/models/mamba/__pycache__/configuration_mamba.cpython-312.pyc +0 -0
  39. fla2/models/mamba/__pycache__/configuration_mamba.cpython-38.pyc +0 -0
  40. fla2/models/mamba/__pycache__/configuration_mamba.cpython-39.pyc +0 -0
  41. fla2/models/mamba/__pycache__/modeling_mamba.cpython-312.pyc +0 -0
  42. fla2/models/mamba/__pycache__/modeling_mamba.cpython-38.pyc +0 -0
  43. fla2/models/mamba/__pycache__/modeling_mamba.cpython-39.pyc +0 -0
  44. fla2/models/mamba/modeling_mamba.py +606 -0
  45. fla2/models/mamba2/__init__.py +13 -0
  46. fla2/models/mamba2/__pycache__/__init__.cpython-312.pyc +0 -0
  47. fla2/models/mamba2/__pycache__/__init__.cpython-38.pyc +0 -0
  48. fla2/models/mamba2/__pycache__/__init__.cpython-39.pyc +0 -0
  49. fla2/models/mamba2/__pycache__/configuration_mamba2.cpython-312.pyc +0 -0
  50. fla2/models/mamba2/__pycache__/configuration_mamba2.cpython-38.pyc +0 -0
fla2/layers/__pycache__/attn.cpython-38.pyc ADDED
Binary file (4.77 kB). View file
 
fla2/layers/__pycache__/gla.cpython-38.pyc ADDED
Binary file (8.53 kB). View file
 
fla2/layers/__pycache__/gsa.cpython-312.pyc ADDED
Binary file (13.1 kB). View file
 
fla2/layers/__pycache__/gsa.cpython-38.pyc ADDED
Binary file (6.64 kB). View file
 
fla2/layers/__pycache__/hgrn.cpython-38.pyc ADDED
Binary file (4.49 kB). View file
 
fla2/layers/__pycache__/hgrn.cpython-39.pyc ADDED
Binary file (4.47 kB). View file
 
fla2/layers/__pycache__/hgrn2.cpython-39.pyc ADDED
Binary file (5.01 kB). View file
 
fla2/layers/__pycache__/linear_attn.cpython-38.pyc ADDED
Binary file (4.55 kB). View file
 
fla2/layers/__pycache__/linear_attn.cpython-39.pyc ADDED
Binary file (4.53 kB). View file
 
fla2/layers/__pycache__/mask_deltanet.cpython-310.pyc ADDED
Binary file (9.16 kB). View file
 
fla2/layers/__pycache__/mask_deltanet.cpython-312.pyc ADDED
Binary file (17.2 kB). View file
 
fla2/layers/__pycache__/mask_gdn.cpython-310.pyc ADDED
Binary file (9.22 kB). View file
 
fla2/layers/__pycache__/mask_gdn.cpython-312.pyc ADDED
Binary file (17.7 kB). View file
 
fla2/layers/__pycache__/multiscale_retention.cpython-39.pyc ADDED
Binary file (8.17 kB). View file
 
fla2/layers/__pycache__/rebased.cpython-312.pyc ADDED
Binary file (8.54 kB). View file
 
fla2/layers/__pycache__/rebased.cpython-39.pyc ADDED
Binary file (4.33 kB). View file
 
fla2/layers/__pycache__/rwkv6.cpython-312.pyc ADDED
Binary file (14 kB). View file
 
fla2/layers/__pycache__/rwkv6.cpython-38.pyc ADDED
Binary file (7.85 kB). View file
 
fla2/layers/__pycache__/rwkv6.cpython-39.pyc ADDED
Binary file (7.83 kB). View file
 
fla2/models/gsa/__pycache__/configuration_gsa.cpython-39.pyc ADDED
Binary file (2.21 kB). View file
 
fla2/models/hgrn/__pycache__/modeling_hgrn.cpython-38.pyc ADDED
Binary file (11.2 kB). View file
 
fla2/models/hgrn/__pycache__/modeling_hgrn.cpython-39.pyc ADDED
Binary file (11.2 kB). View file
 
fla2/models/hgrn2/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (689 Bytes). View file
 
fla2/models/hgrn2/__pycache__/configuration_hgrn2.cpython-312.pyc ADDED
Binary file (2.39 kB). View file
 
fla2/models/hgrn2/__pycache__/configuration_hgrn2.cpython-39.pyc ADDED
Binary file (1.76 kB). View file
 
fla2/models/hgrn2/__pycache__/modeling_hgrn2.cpython-39.pyc ADDED
Binary file (11.3 kB). View file
 
fla2/models/linear_attn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (758 Bytes). View file
 
fla2/models/linear_attn/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (573 Bytes). View file
 
fla2/models/linear_attn/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (575 Bytes). View file
 
fla2/models/linear_attn/__pycache__/configuration_linear_attn.cpython-38.pyc ADDED
Binary file (1.97 kB). View file
 
fla2/models/linear_attn/__pycache__/modeling_linear_attn.cpython-312.pyc ADDED
Binary file (18.4 kB). View file
 
fla2/models/linear_attn/__pycache__/modeling_linear_attn.cpython-38.pyc ADDED
Binary file (11.7 kB). View file
 
fla2/models/linear_attn/__pycache__/modeling_linear_attn.cpython-39.pyc ADDED
Binary file (11.7 kB). View file
 
fla2/models/linear_attn/configuration_linear_attn.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-

from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class LinearAttentionConfig(PretrainedConfig):
    """
    Configuration for a linear-attention language model.

    Stores the hyper-parameters read by the model/layer constructors and
    forwards the special-token and embedding-tying options to
    `PretrainedConfig`.

    Args:
        vocab_size: Size of the token vocabulary.
        hidden_size: Dimensionality of the hidden representations.
        expand_k: Expansion factor for the key projection.
        expand_v: Expansion factor for the value projection.
        hidden_ratio: MLP width multiplier, used when `intermediate_size`
            is None.
        intermediate_size: Explicit MLP width; overrides `hidden_ratio`
            when set.
        num_hidden_layers: Number of stacked layers.
        num_heads: Number of attention heads.
        num_kv_heads: Number of key/value heads. When None, presumably
            defaults to `num_heads` in the consuming layer — confirm there.
        attn_mode: Kernel mode used by the linear-attention computation
            (e.g. "fused_chunk").
        feature_map: Name of the feature map applied to queries/keys.
        tie_feature_map_qk: Whether q and k share one feature map.
        norm_q / norm_k / norm_feature_map: Optional normalization toggles
            for queries, keys and the feature map output.
        hidden_act: Activation function name for the MLP.
        max_position_embeddings: Maximum supported sequence length.
        elementwise_affine: Whether norm layers carry learnable affine
            parameters.
        norm_eps: Epsilon used by normalization layers.
        use_cache: Whether to return past key/value states at inference.
        pad_token_id / bos_token_id / eos_token_id: Special token ids.
        tie_word_embeddings: Whether input/output embeddings are tied.
        initializer_range: Stddev for weight initialization.
        fuse_cross_entropy: Whether to use the fused cross-entropy loss.
    """

    model_type = 'linear_attn'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 2048,
        expand_k: int = 1,
        expand_v: int = 1,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        num_kv_heads: Optional[int] = None,
        attn_mode: str = "fused_chunk",
        feature_map: str = "elementwise_product",
        tie_feature_map_qk: bool = False,
        norm_q: bool = False,
        norm_k: bool = False,
        norm_feature_map: bool = False,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        # fixed: was annotated `int` although the default (and a valid value) is None
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.02,
        fuse_cross_entropy: bool = True,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.attn_mode = attn_mode
        self.feature_map = feature_map
        self.tie_feature_map_qk = tie_feature_map_qk
        self.norm_q = norm_q
        self.norm_k = norm_k
        self.norm_feature_map = norm_feature_map
        self.hidden_act = hidden_act
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.fuse_cross_entropy = fuse_cross_entropy

        # Token ids and embedding tying are handled by the base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla2/models/mamba/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-

# Package entry point for the Mamba model family: registers the config and
# model classes with the transformers Auto* factories so they can be loaded
# through `AutoConfig` / `AutoModel` / `AutoModelForCausalLM`.
#
# NOTE(review): this tree places the package under `fla2/`, but the imports
# below reference `fla.models.mamba` — confirm that `fla` resolves to the
# intended installed package.

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.mamba.configuration_mamba import MambaConfig
from fla.models.mamba.modeling_mamba import (MambaBlock, MambaForCausalLM,
                                             MambaModel)

# The third positional argument (presumably `exist_ok=True` — confirm against
# the installed transformers version) allows re-registration without raising.
AutoConfig.register(MambaConfig.model_type, MambaConfig, True)
AutoModel.register(MambaConfig, MambaModel, True)
AutoModelForCausalLM.register(MambaConfig, MambaForCausalLM, True)


__all__ = ['MambaConfig', 'MambaForCausalLM', 'MambaModel', 'MambaBlock']
fla2/models/mamba/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (548 Bytes). View file
 
fla2/models/mamba/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (548 Bytes). View file
 
fla2/models/mamba/__pycache__/configuration_mamba.cpython-312.pyc ADDED
Binary file (7.08 kB). View file
 
fla2/models/mamba/__pycache__/configuration_mamba.cpython-38.pyc ADDED
Binary file (6.21 kB). View file
 
fla2/models/mamba/__pycache__/configuration_mamba.cpython-39.pyc ADDED
Binary file (6.21 kB). View file
 
fla2/models/mamba/__pycache__/modeling_mamba.cpython-312.pyc ADDED
Binary file (31.8 kB). View file
 
fla2/models/mamba/__pycache__/modeling_mamba.cpython-38.pyc ADDED
Binary file (18.1 kB). View file
 
fla2/models/mamba/__pycache__/modeling_mamba.cpython-39.pyc ADDED
Binary file (18.1 kB). View file
 
fla2/models/mamba/modeling_mamba.py ADDED
@@ -0,0 +1,606 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch MAMBA model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Any, Dict, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from transformers.activations import ACT2FN
25
+ from transformers.modeling_utils import PreTrainedModel
26
+ from transformers.utils import ModelOutput, logging
27
+
28
+ from fla.models.mamba.configuration_mamba import MambaConfig
29
+ from fla.modules import FusedCrossEntropyLoss, RMSNorm
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
# Optional CUDA fast-path kernels. Both imports are best-effort: when the
# `mamba_ssm` / `causal_conv1d` packages are not installed, the corresponding
# names are set to None and the model falls back to the pure-PyTorch
# implementation (`MambaMixer.slow_forward`).
try:
    from mamba_ssm.ops.selective_scan_interface import (mamba_inner_fn,
                                                        selective_scan_fn)
    from mamba_ssm.ops.triton.selective_state_update import \
        selective_state_update
except ImportError:
    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_update, causal_conv1d_fn = None, None

# True only when every kernel import above succeeded; gates the choice between
# `cuda_kernels_forward` and `slow_forward` in `MambaMixer.forward`.
is_fast_path_available = all(
    (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
49
+
50
+
51
class MambaCache:
    """Per-layer recurrent state container used during incremental decoding.

    Holds, for every layer index, a causal-convolution buffer of shape
    `(batch, intermediate_size, conv_kernel)` and an SSM state of shape
    `(batch, intermediate_size, state_size)`, plus the number of tokens
    already consumed (`seqlen_offset`).
    """

    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
        # No tokens processed yet.
        self.seqlen_offset = 0
        self.dtype = dtype

        d_inner = config.intermediate_size
        d_state = config.state_size
        kernel = config.conv_kernel

        self.conv_states = {}
        self.ssm_states = {}
        for layer in range(config.num_hidden_layers):
            self.conv_states[layer] = torch.zeros(
                batch_size, d_inner, kernel, device=device, dtype=dtype
            )
            self.ssm_states[layer] = torch.zeros(
                batch_size, d_inner, d_state, device=device, dtype=dtype
            )
67
+
68
+
69
class MambaMixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel
        self.intermediate_size = config.intermediate_size
        self.time_step_rank = config.time_step_rank
        self.layer_idx = layer_idx
        self.use_conv_bias = config.use_conv_bias
        # Depthwise (groups == channels) convolution; left padding of
        # `conv_kernel - 1` keeps the mixing strictly causal.
        self.conv1d = nn.Conv1d(
            in_channels=self.intermediate_size,
            out_channels=self.intermediate_size,
            bias=config.use_conv_bias,
            kernel_size=config.conv_kernel,
            groups=self.intermediate_size,
            padding=config.conv_kernel - 1,
        )

        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]

        # projection of the input hidden states
        self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
        # selective projection used to make dt, B and C input dependant
        self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
        # time step projection (discretization)
        self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)

        # S4D real initialization. These are not discretized!
        # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
        A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
        A = A.expand(self.intermediate_size, -1).contiguous()

        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.intermediate_size))
        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
        self.use_bias = config.use_bias

        if not is_fast_path_available:
            # Fixed typo in user-facing message: "on of" -> "one of".
            logger.warning_once(
                "The fast path is not available because one of "
                "`(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
                " is None. Falling back to the naive implementation. "
                "To install follow https://github.com/state-spaces/mamba/#installation and"
                " https://github.com/Dao-AILab/causal-conv1d"
            )

    def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: Optional[MambaCache] = None):
        """Run the mixer through the fused CUDA kernels.

        Training without a cache uses the single fused `mamba_inner_fn`;
        otherwise the convolution and SSM steps run separately so the cache
        states can be read and written.
        """
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states).transpose(1, 2)

        if self.training and cache_params is None:  # Doesn't support outputting the states -> used for training
            contextualized_states = mamba_inner_fn(
                projected_states,
                self.conv1d.weight,
                self.conv1d.bias if self.use_conv_bias else None,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias.float() if self.use_bias else None,
                -torch.exp(self.A_log.float()),
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )

        else:
            hidden_states, gate = projected_states.chunk(2, dim=1)

            # 2. Convolution sequence transformation
            conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
            if cache_params is not None and cache_params.seqlen_offset > 0:
                # Single-token decoding step: update the rolling conv buffer in place.
                hidden_states = causal_conv1d_update(
                    hidden_states.squeeze(-1),
                    cache_params.conv_states[self.layer_idx],
                    conv_weights,
                    self.conv1d.bias,
                    self.activation,
                )
                hidden_states = hidden_states.unsqueeze(-1)
            else:
                if cache_params is not None:
                    # Prefill: left-pad to the kernel size and seed the conv cache.
                    conv_states = nn.functional.pad(
                        hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
                    )
                    cache_params.conv_states[self.layer_idx].copy_(conv_states)
                hidden_states = causal_conv1d_fn(
                    hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
                )

            # 3. State Space Model sequence transformation
            # 3.a. input varying initialization of time_step, B and C
            ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
            time_step, B, C = torch.split(
                ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
            )
            discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)

            A = -torch.exp(self.A_log.float())
            # 3.c perform the recurrence y ← SSM(A, B, C)(x)
            time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
            if cache_params is not None and cache_params.seqlen_offset > 0:
                scan_outputs = selective_state_update(
                    cache_params.ssm_states[self.layer_idx],
                    hidden_states[..., 0],
                    discrete_time_step[..., 0],
                    A,
                    B[:, 0],
                    C[:, 0],
                    self.D,
                    gate[..., 0],
                    time_proj_bias,
                    dt_softplus=True,
                ).unsqueeze(-1)
            else:
                scan_outputs, ssm_state = selective_scan_fn(
                    hidden_states,
                    discrete_time_step,
                    A,
                    B.transpose(1, 2),
                    C.transpose(1, 2),
                    self.D.float(),
                    gate,
                    time_proj_bias,
                    delta_softplus=True,
                    return_last_state=True,
                )
                if ssm_state is not None and cache_params is not None:
                    cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

            # 4. Final linear projection
            contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
        return contextualized_states

    # fmt: off
    def slow_forward(self, input_states, cache_params: Optional[MambaCache] = None):
        """Pure-PyTorch reference path: explicit sequential scan over time."""
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype
        # 1. Gated MLP's linear projection
        # [batch, 2 * intermediate_size, seq_len]
        projected_states = self.in_proj(input_states).transpose(1, 2)
        hidden_states, gate = projected_states.chunk(2, dim=1)

        # 2. Convolution sequence transformation
        if cache_params is not None:
            ssm_state = cache_params.ssm_states[self.layer_idx].clone()
            if cache_params.seqlen_offset > 0:
                # [batch, intermediate_size, conv_kernel_size]
                conv_state = cache_params.conv_states[self.layer_idx]
                # Shift the buffer one step left and append the new token.
                conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
                conv_state[:, :, -1] = hidden_states[:, :, 0]
                cache_params.conv_states[self.layer_idx].copy_(conv_state)
                hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
                if self.use_conv_bias:
                    hidden_states += self.conv1d.bias
                # [batch, intermediate_size, 1] : decoding
                hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1)
            else:
                conv_state = nn.functional.pad(
                    hidden_states,
                    (self.conv_kernel_size - hidden_states.shape[-1], 0)
                )
                cache_params.conv_states[self.layer_idx].copy_(conv_state)
                # [batch, intermediate_size, seq_len]
                hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
        else:
            ssm_state = torch.zeros(
                (batch_size, self.intermediate_size, self.ssm_state_size),
                device=hidden_states.device, dtype=dtype
            )
            # [batch, intermediate_size, seq_len]
            hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])

        # 3. State Space Model sequence transformation
        # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
        time_step, B, C = torch.split(
            ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
        )
        # [batch, seq_len, intermediate_size]
        discrete_time_step = self.dt_proj(time_step)
        # [batch, intermediate_size, seq_len]
        discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2)

        # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
        # [intermediate_size, ssm_state_size]
        A = -torch.exp(self.A_log.float())
        # [batch, intermediate_size, seq_len, ssm_state_size]
        discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])
        # [batch, intermediade_size, seq_len, ssm_state_size]
        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()
        deltaB_u = discrete_B * hidden_states[:, :, :, None].float()

        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        scan_outputs = []
        for i in range(seq_len):
            # [batch, intermediade_size, ssm_state]
            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
            # [batch, intermediade_size, 1]
            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))
            scan_outputs.append(scan_output[:, :, 0])
        # [batch, seq_len, intermediade_size]
        scan_output = torch.stack(scan_outputs, dim=-1)
        # Skip connection through D, then gate with the activated gate branch.
        scan_output = scan_output + (hidden_states * self.D[None, :, None])
        scan_output = (scan_output * self.act(gate))

        if cache_params is not None:
            cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

        # 4. Final linear projection
        # [batch, seq_len, hidden_size]
        contextualized_states = self.out_proj(scan_output.transpose(1, 2))
        return contextualized_states
    # fmt: on

    def forward(self, hidden_states, cache_params: Optional[MambaCache] = None):
        """Dispatch to the fused kernels when available and on a CUDA device."""
        if is_fast_path_available and "cuda" in self.x_proj.weight.device.type:
            return self.cuda_kernels_forward(hidden_states, cache_params)
        return self.slow_forward(hidden_states, cache_params)
298
+
299
+
300
class MambaBlock(nn.Module):
    """Pre-norm residual Mamba layer: y = x + Mixer(RMSNorm(x))."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Stored but not applied in forward — the fp32 residual cast was
        # disabled in the original source.
        self.residual_in_fp32 = config.residual_in_fp32
        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mixer = MambaMixer(config, layer_idx=layer_idx)

    def forward(self, hidden_states, cache_params: Optional[MambaCache] = None):
        shortcut = hidden_states
        normed = self.norm(hidden_states)
        # NOTE(review): `residual_in_fp32` is intentionally ignored here; the
        # cast (`residual = residual.to(torch.float32)`) was commented out in
        # the original implementation.
        mixed = self.mixer(normed, cache_params=cache_params)
        return shortcut + mixed
317
+
318
+
319
class MambaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = MambaConfig
    base_model_prefix = "backbone"
    _no_split_modules = ["MambaBlock"]
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, MambaMixer):
            # SSM state parameters are excluded from weight decay.
            module.A_log._no_weight_decay = True
            module.D._no_weight_decay = True

            # dt projection weight: constant or uniform in +/- dt_init_std.
            dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
            if self.config.time_step_init_scheme == "constant":
                nn.init.constant_(module.dt_proj.weight, dt_init_std)
            elif self.config.time_step_init_scheme == "random":
                nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)

            # Sample dt log-uniformly in [time_step_min, time_step_max], then
            # clamp to the floor before inverting the softplus below.
            dt = torch.exp(
                torch.rand(self.config.intermediate_size)
                * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
                + math.log(self.config.time_step_min)
            ).clamp(min=self.config.time_step_floor)
            # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            with torch.no_grad():
                module.dt_proj.bias.copy_(inv_dt)
            # Flag prevents the generic nn.Linear branch below from zeroing
            # this carefully initialized bias on a later pass.
            module.dt_proj.bias._no_reinit = True

        if isinstance(module, nn.Linear):
            if module.bias is not None:
                if not getattr(module.bias, "_no_reinit", False):
                    nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=self.config.initializer_range)

        if self.config.rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ["out_proj.weight"]:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        # NOTE(review): `num_layers` — confirm MambaConfig
                        # defines this attribute; the rest of this file uses
                        # `num_hidden_layers`.
                        p /= math.sqrt(self.config.num_layers)
376
+
377
+
378
@dataclass
class MambaOutput(ModelOutput):
    """
    Class for the MAMBA model outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        cache_params (`MambaCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*,
            returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    cache_params: Optional[MambaCache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
402
+
403
+
404
@dataclass
class MambaCausalLMOutput(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        cache_params (`MambaCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*,
            returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    cache_params: Optional[MambaCache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
431
+
432
+
433
class MambaModel(MambaPreTrainedModel):
    """Bare Mamba backbone: token embeddings, a stack of ``MambaBlock`` layers,
    and a final RMSNorm. Returns hidden states (no LM head)."""

    def __init__(self, config):
        super().__init__(config)

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [MambaBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]
        )

        self.gradient_checkpointing = False
        self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        cache_params: Optional[MambaCache] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,  # `attention_mask` is passed by the tokenizer and we don't want it
    ) -> Union[Tuple, MambaOutput]:
        """Run the Mamba stack.

        Exactly one of `input_ids` / `inputs_embeds` must be given. When
        `use_cache` is enabled a `MambaCache` is created (or updated) so that
        subsequent calls can continue incrementally.
        """
        # Resolve unset flags from the config; caching is disabled while training.
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = self.config.use_cache if not self.training else False
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Require exactly one of the two input forms (xor).
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        # Gradient checkpointing recomputes activations; incompatible with caching.
        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if use_cache and cache_params is None:
            cache_params = MambaCache(
                self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
            )

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        for block in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(block.__call__, hidden_states, cache_params)
            else:
                hidden_states = block(hidden_states, cache_params=cache_params)

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        if use_cache:
            # Advance the recurrent-state offset by the tokens just consumed.
            cache_params.seqlen_offset += inputs_embeds.shape[1]

        hidden_states = self.norm_f(hidden_states)

        # NOTE: like upstream HF Mamba, the final (normed) states are appended
        # in addition to the last layer's raw output collected in the loop.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in (hidden_states, cache_params, all_hidden_states) if v is not None)

        return MambaOutput(
            last_hidden_state=hidden_states,
            cache_params=cache_params if use_cache else None,
            hidden_states=all_hidden_states,
        )
510
+
511
+
512
class MambaForCausalLM(MambaPreTrainedModel):
    """Mamba backbone topped with a (weight-tied) language-modeling head."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.backbone = MambaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_input_embeddings(self):
        return self.backbone.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        return self.backbone.set_input_embeddings(new_embeddings)

    def _update_model_kwargs_for_generation(
        self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
    ) -> Dict[str, Any]:
        # Carry the recurrent state forward between generation steps.
        model_kwargs["cache_params"] = outputs.get("cache_params", None)
        return model_kwargs

    def prepare_inputs_for_generation(
        self, input_ids, cache_params: Optional[MambaCache] = None, inputs_embeds=None, attention_mask=None, **kwargs
    ):
        # only last token for inputs_ids if the state is passed along.
        if cache_params is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        # Embeddings are only usable on the first step (no cache yet).
        if inputs_embeds is not None and cache_params is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs["cache_params"] = cache_params
        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_params: Optional[MambaCache] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        **kwargs,  # for now we need this for generation
    ) -> Union[Tuple, MambaCausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        mamba_outputs = self.backbone(
            input_ids,
            cache_params=cache_params,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
        )
        hidden_states = mamba_outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = (
                FusedCrossEntropyLoss(inplace_backward=True)
                if self.config.fuse_cross_entropy
                else nn.CrossEntropyLoss()
            )
            # Enable model parallelism
            labels = labels.to(logits.device)
            # Shift labels one step left; pad the tail with ignore_index so the
            # last position contributes no loss.
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + mamba_outputs[1:]
            return (loss,) + output if loss is not None else output

        return MambaCausalLMOutput(
            loss=loss,
            logits=logits,
            cache_params=mamba_outputs.cache_params,
            hidden_states=mamba_outputs.hidden_states,
        )
fla2/models/mamba2/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.mamba2.configuration_mamba2 import Mamba2Config
from fla.models.mamba2.modeling_mamba2 import Mamba2ForCausalLM, Mamba2Model

# Register the Mamba2 classes with the HF Auto* factories so that
# `AutoModel.from_pretrained(...)` & co. can resolve them by model_type.
# The trailing `True` is `exist_ok`, making re-imports of this module safe.
AutoConfig.register(Mamba2Config.model_type, Mamba2Config, True)
AutoModel.register(Mamba2Config, Mamba2Model, True)
AutoModelForCausalLM.register(Mamba2Config, Mamba2ForCausalLM, True)

__all__ = ['Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model']
fla2/models/mamba2/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (710 Bytes). View file
 
fla2/models/mamba2/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (528 Bytes). View file
 
fla2/models/mamba2/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (530 Bytes). View file
 
fla2/models/mamba2/__pycache__/configuration_mamba2.cpython-312.pyc ADDED
Binary file (7.67 kB). View file
 
fla2/models/mamba2/__pycache__/configuration_mamba2.cpython-38.pyc ADDED
Binary file (6.66 kB). View file