msj19 committed · verified
Commit 779d3b2 · 1 Parent(s): 454c53f

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. fla3/layers/__pycache__/__init__.cpython-310.pyc +0 -0
  2. fla3/layers/__pycache__/__init__.cpython-312.pyc +0 -0
  3. fla3/layers/__pycache__/abc.cpython-310.pyc +0 -0
  4. fla3/layers/__pycache__/attn.cpython-310.pyc +0 -0
  5. fla3/layers/__pycache__/attn.cpython-312.pyc +0 -0
  6. fla3/layers/__pycache__/based.cpython-310.pyc +0 -0
  7. fla3/layers/__pycache__/bitattn.cpython-310.pyc +0 -0
  8. fla3/layers/__pycache__/delta_net.cpython-310.pyc +0 -0
  9. fla3/layers/__pycache__/emdeltanet.cpython-310.pyc +0 -0
  10. fla3/layers/__pycache__/emdeltanet.cpython-312.pyc +0 -0
  11. fla3/layers/__pycache__/forgetting_attn.cpython-310.pyc +0 -0
  12. fla3/layers/__pycache__/gated_deltanet.cpython-310.pyc +0 -0
  13. fla3/layers/__pycache__/gated_deltanet.cpython-312.pyc +0 -0
  14. fla3/layers/__pycache__/lightnet.cpython-310.pyc +0 -0
  15. fla3/layers/utils.py +197 -0
  16. fla3/models/gated_deltanet/__pycache__/__init__.cpython-312.pyc +0 -0
  17. fla3/models/gsa/__pycache__/__init__.cpython-310.pyc +0 -0
  18. fla3/models/hgrn/__pycache__/modeling_hgrn.cpython-310.pyc +0 -0
  19. fla3/models/hgrn/configuration_hgrn.py +81 -0
  20. fla3/models/hgrn2/__pycache__/__init__.cpython-310.pyc +0 -0
  21. fla3/models/hgrn2/__pycache__/configuration_hgrn2.cpython-310.pyc +0 -0
  22. fla3/models/hgrn2/__pycache__/modeling_hgrn2.cpython-310.pyc +0 -0
  23. fla3/models/hgrn2/modeling_hgrn2.py +421 -0
  24. fla3/models/lightnet/__pycache__/__init__.cpython-310.pyc +0 -0
  25. fla3/models/lightnet/__pycache__/configuration_lightnet.cpython-310.pyc +0 -0
  26. fla3/models/lightnet/__pycache__/modeling_lightnet.cpython-310.pyc +0 -0
  27. fla3/models/lightnet/configuration_lightnet.py +83 -0
  28. fla3/models/lightnet/modeling_lightnet.py +410 -0
  29. fla3/models/linear_attn/__init__.py +12 -0
  30. fla3/models/linear_attn/__pycache__/__init__.cpython-310.pyc +0 -0
  31. fla3/models/linear_attn/__pycache__/configuration_linear_attn.cpython-310.pyc +0 -0
  32. fla3/models/linear_attn/__pycache__/modeling_linear_attn.cpython-310.pyc +0 -0
  33. fla3/models/linear_attn/configuration_linear_attn.py +91 -0
  34. fla3/models/linear_attn/modeling_linear_attn.py +406 -0
  35. fla3/models/mamba/__init__.py +13 -0
  36. fla3/models/mamba/__pycache__/__init__.cpython-310.pyc +0 -0
  37. fla3/models/mamba/__pycache__/configuration_mamba.cpython-310.pyc +0 -0
  38. fla3/models/mamba/__pycache__/modeling_mamba.cpython-310.pyc +0 -0
  39. fla3/models/mamba/configuration_mamba.py +166 -0
  40. fla3/models/mamba/modeling_mamba.py +565 -0
  41. fla3/models/mamba2/__init__.py +13 -0
  42. fla3/models/mamba2/__pycache__/__init__.cpython-310.pyc +0 -0
  43. fla3/models/mamba2/__pycache__/configuration_mamba2.cpython-310.pyc +0 -0
  44. fla3/models/mamba2/__pycache__/modeling_mamba2.cpython-310.pyc +0 -0
  45. fla3/models/mamba2/configuration_mamba2.py +167 -0
  46. fla3/models/mamba2/modeling_mamba2.py +562 -0
  47. fla3/models/nsa/__init__.py +15 -0
  48. fla3/models/nsa/__pycache__/__init__.cpython-310.pyc +0 -0
  49. fla3/models/nsa/__pycache__/configuration_nsa.cpython-310.pyc +0 -0
  50. fla3/models/nsa/__pycache__/modeling_nsa.cpython-310.pyc +0 -0
fla3/layers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (204 Bytes)
 
fla3/layers/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (210 Bytes)
 
fla3/layers/__pycache__/abc.cpython-310.pyc ADDED
Binary file (5.55 kB)
 
fla3/layers/__pycache__/attn.cpython-310.pyc ADDED
Binary file (4.47 kB)
 
fla3/layers/__pycache__/attn.cpython-312.pyc ADDED
Binary file (7.7 kB)
 
fla3/layers/__pycache__/based.cpython-310.pyc ADDED
Binary file (3.32 kB)
 
fla3/layers/__pycache__/bitattn.cpython-310.pyc ADDED
Binary file (4.37 kB)
 
fla3/layers/__pycache__/delta_net.cpython-310.pyc ADDED
Binary file (8.32 kB)
 
fla3/layers/__pycache__/emdeltanet.cpython-310.pyc ADDED
Binary file (9.49 kB)
 
fla3/layers/__pycache__/emdeltanet.cpython-312.pyc ADDED
Binary file (18.3 kB)
 
fla3/layers/__pycache__/forgetting_attn.cpython-310.pyc ADDED
Binary file (3.87 kB)
 
fla3/layers/__pycache__/gated_deltanet.cpython-310.pyc ADDED
Binary file (8.62 kB)
 
fla3/layers/__pycache__/gated_deltanet.cpython-312.pyc ADDED
Binary file (13.6 kB)
 
fla3/layers/__pycache__/lightnet.cpython-310.pyc ADDED
Binary file (5.32 kB)
 
fla3/layers/utils.py ADDED
@@ -0,0 +1,197 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

# Code is adapted from flash-attn.bert_padding.py

from typing import Tuple

import torch
from einops import rearrange, repeat

from ..ops.utils.index import prepare_cu_seqlens_from_mask, prepare_lens_from_mask
from ..utils import tensor_cache


class IndexFirstAxis(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, indices):
        ctx.save_for_backward(indices)
        assert x.ndim >= 2
        ctx.first_axis_dim, other_shape = x.shape[0], x.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        # return x[indices]
        return torch.gather(
            rearrange(x, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
        ).reshape(-1, *other_shape)

    @staticmethod
    def backward(ctx, do):
        (indices,) = ctx.saved_tensors
        assert do.ndim >= 2
        other_shape = do.shape[1:]
        do = rearrange(do, "b ... -> b (...)")
        dx = torch.zeros(
            [ctx.first_axis_dim, do.shape[1]],
            device=do.device,
            dtype=do.dtype,
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        # dx[indices] = do
        dx.scatter_(0, repeat(indices, "z -> z d", d=do.shape[1]), do)
        return dx.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis = IndexFirstAxis.apply


class IndexPutFirstAxis(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, indices, first_axis_dim):
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert x.ndim >= 2
        y = torch.zeros(first_axis_dim, *x.shape[1:], device=x.device, dtype=x.dtype)
        # TODO [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        y[indices] = x
        # y.scatter_(0, repeat(indices, 'z -> z d', d=x.shape[1]), x)
        return y

    @staticmethod
    def backward(ctx, do):
        (indices,) = ctx.saved_tensors
        # TODO [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        dx = do[indices]
        # dx = torch.gather(do, 0, repeat(indices, 'z -> z d', d=do.shape[1]))
        return dx, None, None


index_put_first_axis = IndexPutFirstAxis.apply


@tensor_cache
def get_unpad_data(
    attention_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """
    Retrieves indexing data required to repad unpadded (ragged) tensors.

    Args:
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.

    Return:
        indices (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input sequence.
        cu_seqlens (`torch.Tensor`):
            The cumulative sequence lengths, used to index into ragged (unpadded) tensors.
            `cu_seqlens` shape is [batch_size + 1].
        max_seqlen_in_batch (`int`):
            Maximum sequence length in batch.
    """
    lens = prepare_lens_from_mask(attention_mask)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = lens.max().item()
    cu_seqlens = prepare_cu_seqlens_from_mask(attention_mask)
    return indices, cu_seqlens, max_seqlen_in_batch


def unpad_input(
    q: torch.Tensor,
    states: Tuple[torch.Tensor],
    attention_mask: torch.Tensor,
    q_len: int,
    keepdim: bool = False,
):
    """
    Unpads query, key, and values tensors, using a single dimension for all tokens
    even though they belong to different batches.


    Arguments:
        q (`torch.Tensor`):
            Query state with padding. Shape: [batch_size, q_len, ...].
        states (`Tuple[torch.Tensor]`):
            Attention state with padding. Shape: [batch_size, seq_len, ...].
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape [batch_size, sequence_length], 1 means valid and 0 means not valid.
        q_len (`int`):
            Target length.
        keepdim (`bool`):
            Whether to keep the batch dimension. Default: `False`.

    Return:
        q (`torch.Tensor`):
            Query state without padding.
            Shape: [1, total_target_length, ...] if `keepdim=True` else [total_target_length, ...].
        states (`Tuple[torch.Tensor]`):
            Attention state without padding.
            Shape: [1, total_source_length, ...] if `keepdim=True` else [total_source_length, ...].
        indices_q (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input target sequence.
        (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
            The cumulative sequence lengths for the target (query) and source (key, value),
            used to index into ragged (unpadded) tensors.
            `cu_seqlens` shape is [batch_size + 1].
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence
            i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """
    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = get_unpad_data(attention_mask)
    batch_size, seq_len, *_ = states[0].shape

    state = tuple(
        index_first_axis(rearrange(s, "b s ... -> (b s) ..."), indices_k)
        for s in states
    )

    if q_len == seq_len:
        q = index_first_axis(rearrange(q, "b s ... -> (b s) ..."), indices_k)
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif q_len == 1:
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
        indices_q = cu_seqlens_q[:-1]
        q = q.squeeze(1)
    else:
        raise NotImplementedError("We only support either q_len == k_len (prefilling) or q_len == 1 (decoding)")

    if keepdim:
        q = q.unsqueeze(0)
        state = tuple(s.unsqueeze(0) for s in state)

    return (
        q,
        state,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    )


def pad_input(
    hidden_states: torch.Tensor,
    indices: torch.LongTensor,
    batch_size: int,
    seq_len: int,
) -> torch.Tensor:
    """
    Args:
        hidden_states ([total_tokens, ...]):
            where total_tokens denotes the number of tokens selected in attention_mask.
        indices ([total_tokens]):
            the indices that represent the non-masked tokens of the original padded input sequence.
        batch_size (int):
            batch size for the padded sequence.
        seq_len (int):
            maximum sequence length for the padded sequence.

    Return:
        hidden_states of shape [batch_size, seq_len, ...]
    """
    output = index_put_first_axis(hidden_states, indices, batch_size * seq_len)
    return rearrange(output, "(b s) ... -> b s ...", b=batch_size)
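The file above ports flash-attn's `bert_padding` helpers. Below is a minimal round-trip sketch of how they are meant to be used; it is not part of the commit and assumes the `fla3` package and its `ops`/`utils` dependencies import cleanly in your environment.

# Illustrative only: a round trip through unpad_input/pad_input on a toy batch.
import torch
from fla3.layers.utils import pad_input, unpad_input

batch_size, seq_len, hidden_size = 2, 4, 8
# 1 marks valid tokens, 0 marks padding (the second sequence has one padded slot).
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 1, 0]])
q = torch.randn(batch_size, seq_len, hidden_size)
k = torch.randn(batch_size, seq_len, hidden_size)
v = torch.randn(batch_size, seq_len, hidden_size)

# Drop the padded positions: 7 of the 8 tokens survive, flattened into one axis.
q_unpad, (k_unpad, v_unpad), indices_q, (cu_q, cu_k), (max_q, max_k) = unpad_input(
    q, (k, v), attention_mask, q_len=seq_len
)
assert q_unpad.shape[0] == int(attention_mask.sum())

# ... a varlen attention kernel would consume (cu_q, cu_k, max_q, max_k) here ...

# Scatter the ragged tokens back into the padded [batch_size, seq_len, ...] layout.
out = pad_input(q_unpad, indices_q, batch_size, seq_len)
assert out.shape == (batch_size, seq_len, hidden_size)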
fla3/models/gated_deltanet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (737 Bytes)
 
fla3/models/gsa/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (516 Bytes)
 
fla3/models/hgrn/__pycache__/modeling_hgrn.cpython-310.pyc ADDED
Binary file (11.9 kB)
 
fla3/models/hgrn/configuration_hgrn.py ADDED
@@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-

from typing import Dict, Optional

from transformers.configuration_utils import PretrainedConfig


class HGRNConfig(PretrainedConfig):

    model_type = 'hgrn'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        attn_mode: str = "fused_recurrent",
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        expand_ratio: Optional[int] = 1,
        use_short_conv: bool = False,
        conv_size: int = 4,
        use_lower_bound: bool = True,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.02,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.attn_mode = attn_mode
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.expand_ratio = expand_ratio
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.use_lower_bound = use_lower_bound
        self.max_position_embeddings = max_position_embeddings
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.elementwise_affine = elementwise_affine
        self.attn = attn
        self.norm_eps = norm_eps
        self.hidden_act = hidden_act
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
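As a quick illustration of the `attn` validation above, the sketch below builds a hybrid config; the layer indices and sizes are made up. Only `layers` and `num_heads` are required, the remaining keys are filled with the defaults set in `__init__`.

from fla3.models.hgrn.configuration_hgrn import HGRNConfig

config = HGRNConfig(
    hidden_size=1024,
    num_hidden_layers=12,
    # Use standard attention in layers 3 and 7, HGRN layers elsewhere.
    attn={'layers': [3, 7], 'num_heads': 8},
)
assert config.attn['num_kv_heads'] == 8     # defaulted to num_heads
assert config.attn['rope_theta'] == 10000.  # defaulted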
fla3/models/hgrn2/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (532 Bytes)
 
fla3/models/hgrn2/__pycache__/configuration_hgrn2.cpython-310.pyc ADDED
Binary file (2.64 kB)
 
fla3/models/hgrn2/__pycache__/modeling_hgrn2.cpython-310.pyc ADDED
Binary file (11.9 kB)
 
fla3/models/hgrn2/modeling_hgrn2.py ADDED
@@ -0,0 +1,421 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.deprecation import deprecate_kwarg

from fla.layers.attn import Attention
from fla.layers.hgrn2 import HGRN2Attention
from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config
from fla.models.utils import Cache
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
from fla.modules import GatedMLP as HGRN2MLP
from fla.modules import RMSNorm

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

logger = logging.get_logger(__name__)


class HGRN2Block(nn.Module):
    def __init__(self, config: HGRN2Config, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        if config.attn is not None and layer_idx in config.attn['layers']:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn['num_heads'],
                num_kv_heads=config.attn['num_kv_heads'],
                qkv_bias=config.attn['qkv_bias'],
                window_size=config.attn['window_size'],
                rope_theta=config.attn['rope_theta'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.attn = HGRN2Attention(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                num_heads=config.num_heads,
                expand_ratio=config.expand_ratio,
                use_short_conv=config.use_short_conv,
                conv_size=config.conv_size,
                elementwise_affine=config.elementwise_affine,
                norm_eps=config.norm_eps,
                layer_idx=layer_idx
            )
        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = HGRN2MLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        lower_bound: Optional[torch.Tensor] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            lower_bound=lower_bound,
            **kwargs
        )
        if self.config.fuse_norm:
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs


class HGRN2PreTrainedModel(PreTrainedModel):

    config_class = HGRN2Config
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['HGRN2Block']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        prenorm_residual_strategy: Optional[str] = 'rescale',
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if prenorm_residual_strategy is not None:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                if prenorm_residual_strategy == 'rescale':
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
                elif prenorm_residual_strategy == 'zero':
                    nn.init.zeros_(p)
                else:
                    raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")


class HGRN2Model(HGRN2PreTrainedModel):

    def __init__(self, config: HGRN2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        if config.use_lower_bound:
            self.lower_bounds = nn.Parameter(torch.zeros(config.num_hidden_layers, config.hidden_size))
        self.layers = nn.ModuleList([HGRN2Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if output_attentions:
            warnings.warn("`HGRN2Model` does not `output_attentions` now, setting it to `False`.")
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None

        if self.config.use_lower_bound:
            lower_bounds = self.lower_bounds.softmax(0)
            lower_bounds = lower_bounds.cumsum(0) - lower_bounds[0]
        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            lower_bound = lower_bounds[i] if self.config.use_lower_bound else None
            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions,
                    lower_bound,
                    **kwargs
                )
            else:
                hidden_states, attentions, past_key_values = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    lower_bound=lower_bound,
                    **kwargs
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )


class HGRN2ForCausalLM(HGRN2PreTrainedModel, GenerationMixin):

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = HGRN2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs: Unpack[Dict]
    ):
        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and len(past_key_values) == 0:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
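The `use_lower_bound` branch in `HGRN2Model.forward` above turns a single learnable `[num_hidden_layers, hidden_size]` parameter into a per-layer bound that is zero at the first layer and grows monotonically while staying below 1. A standalone plain-torch sketch of just that computation, with toy sizes and no fla dependency:

import torch

num_hidden_layers, hidden_size = 4, 3
lower_bounds_param = torch.zeros(num_hidden_layers, hidden_size)  # same init as in HGRN2Model

lower_bounds = lower_bounds_param.softmax(0)             # normalize across layers
lower_bounds = lower_bounds.cumsum(0) - lower_bounds[0]  # shift so layer 0 gets a zero bound

print(lower_bounds[:, 0])  # tensor([0.0000, 0.2500, 0.5000, 0.7500]): monotone, strictly below 1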
fla3/models/lightnet/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (556 Bytes)
 
fla3/models/lightnet/__pycache__/configuration_lightnet.cpython-310.pyc ADDED
Binary file (2.5 kB)
 
fla3/models/lightnet/__pycache__/modeling_lightnet.cpython-310.pyc ADDED
Binary file (11.7 kB)
 
fla3/models/lightnet/configuration_lightnet.py ADDED
@@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-

from typing import Dict, Optional

from transformers.configuration_utils import PretrainedConfig


class LightNetConfig(PretrainedConfig):

    model_type = 'lightnet'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        attn_mode: str = "chunk",
        num_heads: Optional[int] = None,
        expand_ratio: Optional[int] = 128,
        use_short_conv: bool = False,
        conv_size: int = 4,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        gate_low_rank_dim: int = 128,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.02,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.attn_mode = attn_mode
        self.num_heads = num_heads
        self.expand_ratio = expand_ratio
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.max_position_embeddings = max_position_embeddings
        self.gate_low_rank_dim = gate_low_rank_dim
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla3/models/lightnet/modeling_lightnet.py ADDED
@@ -0,0 +1,410 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.deprecation import deprecate_kwarg

from fla.layers.attn import Attention
from fla.layers.lightnet import LightNetAttention
from fla.models.lightnet.configuration_lightnet import LightNetConfig
from fla.models.utils import Cache
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
from fla.modules import GatedMLP as LightNetMLP
from fla.modules import RMSNorm

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

logger = logging.get_logger(__name__)


class LightNetBlock(nn.Module):
    def __init__(self, config: LightNetConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        if config.attn is not None and layer_idx in config.attn['layers']:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn['num_heads'],
                num_kv_heads=config.attn['num_kv_heads'],
                qkv_bias=config.attn['qkv_bias'],
                window_size=config.attn['window_size'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.attn = LightNetAttention(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                num_heads=config.num_heads,
                expand_ratio=config.expand_ratio,
                use_short_conv=config.use_short_conv,
                conv_size=config.conv_size,
                gate_low_rank_dim=config.gate_low_rank_dim,
                elementwise_affine=config.elementwise_affine,
                norm_eps=config.norm_eps,
                layer_idx=layer_idx
            )
        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = LightNetMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )
        if self.config.fuse_norm:
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs


class LightNetPreTrainedModel(PreTrainedModel):

    config_class = LightNetConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ['LightNetBlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        prenorm_residual_strategy: Optional[str] = 'rescale',
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if prenorm_residual_strategy is not None:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                if prenorm_residual_strategy == 'rescale':
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
                elif prenorm_residual_strategy == 'zero':
                    nn.init.zeros_(p)
                else:
                    raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")


class LightNetModel(LightNetPreTrainedModel):

    def __init__(self, config: LightNetConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LightNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if output_attentions:
            warnings.warn("`LightNetModel` does not `output_attentions` now, setting it to `False`.")
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions,
                    **kwargs
                )
            else:
                hidden_states, attentions, past_key_values = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    **kwargs
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )


class LightNetForCausalLM(LightNetPreTrainedModel, GenerationMixin):

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LightNetModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs: Unpack[Dict]
    ):
        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and len(past_key_values) == 0:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
fla3/models/linear_attn/__init__.py ADDED
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.linear_attn.configuration_linear_attn import LinearAttentionConfig
from fla.models.linear_attn.modeling_linear_attn import LinearAttentionForCausalLM, LinearAttentionModel

AutoConfig.register(LinearAttentionConfig.model_type, LinearAttentionConfig)
AutoModel.register(LinearAttentionConfig, LinearAttentionModel)
AutoModelForCausalLM.register(LinearAttentionConfig, LinearAttentionForCausalLM)

__all__ = ['LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel']
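The registrations above let the standard `Auto*` factories resolve `model_type = 'linear_attn'`. A hedged usage sketch: it assumes the package's internal `fla.*` imports resolve (i.e. an installed `fla` distribution alongside these `fla3` files) and that the Triton kernels are only required once the model actually runs on GPU.

from transformers import AutoModelForCausalLM

from fla3.models.linear_attn import LinearAttentionConfig  # importing also runs the register() calls above

config = LinearAttentionConfig(hidden_size=512, num_hidden_layers=2, num_heads=4)
model = AutoModelForCausalLM.from_config(config)  # dispatches to LinearAttentionForCausalLM
print(sum(p.numel() for p in model.parameters()))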
fla3/models/linear_attn/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (592 Bytes)
 
fla3/models/linear_attn/__pycache__/configuration_linear_attn.cpython-310.pyc ADDED
Binary file (2.7 kB)
 
fla3/models/linear_attn/__pycache__/modeling_linear_attn.cpython-310.pyc ADDED
Binary file (11.8 kB)
 
fla3/models/linear_attn/configuration_linear_attn.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class LinearAttentionConfig(PretrainedConfig):
9
+
10
+ model_type = 'linear_attn'
11
+ keys_to_ignore_at_inference = ['past_key_values']
12
+
13
+ def __init__(
14
+ self,
15
+ attn_mode: str = "fused_chunk",
16
+ hidden_size: int = 2048,
17
+ expand_k: int = 1,
18
+ expand_v: int = 1,
19
+ hidden_ratio: Optional[int] = 4,
20
+ intermediate_size: Optional[int] = None,
21
+ num_hidden_layers: int = 24,
22
+ num_heads: int = 4,
23
+ num_kv_heads: Optional[int] = None,
24
+ feature_map: str = "elementwise_product",
25
+ tie_feature_map_qk: bool = False,
26
+ norm_q: bool = False,
27
+ norm_k: bool = False,
28
+ norm_feature_map: bool = False,
29
+ hidden_act: str = "swish",
30
+ max_position_embeddings: int = 2048,
31
+ elementwise_affine: Optional[bool] = True,
32
+ norm_eps: float = 1e-6,
33
+ attn: Optional[Dict] = None,
34
+ use_cache: bool = True,
35
+ pad_token_id: int = None,
36
+ bos_token_id: int = 1,
37
+ eos_token_id: int = 2,
38
+ tie_word_embeddings: bool = False,
39
+ initializer_range: float = 0.02,
40
+ fuse_norm: bool = True,
41
+ fuse_swiglu: bool = True,
42
+ fuse_cross_entropy: bool = True,
43
+ vocab_size: int = 32000,
44
+ **kwargs
45
+ ):
46
+ self.attn_mode = attn_mode
47
+ self.hidden_size = hidden_size
48
+ self.expand_k = expand_k
49
+ self.expand_v = expand_v
50
+ self.hidden_ratio = hidden_ratio
51
+ self.intermediate_size = intermediate_size
52
+ self.num_hidden_layers = num_hidden_layers
53
+ self.num_heads = num_heads
54
+ self.num_kv_heads = num_kv_heads
55
+ self.feature_map = feature_map
56
+ self.tie_feature_map_qk = tie_feature_map_qk
57
+ self.norm_q = norm_q
58
+ self.norm_k = norm_k
59
+ self.norm_feature_map = norm_feature_map
60
+ self.hidden_act = hidden_act
61
+ self.max_position_embeddings = max_position_embeddings
62
+ self.elementwise_affine = elementwise_affine
63
+ self.norm_eps = norm_eps
64
+ self.attn = attn
65
+ self.use_cache = use_cache
66
+ self.initializer_range = initializer_range
67
+
68
+ self.fuse_norm = fuse_norm
69
+ self.fuse_swiglu = fuse_swiglu
70
+ self.fuse_cross_entropy = fuse_cross_entropy
71
+ self.vocab_size = vocab_size
72
+
73
+ if attn is not None:
74
+ if not isinstance(attn, Dict):
75
+ raise ValueError("attn must be a dictionary")
76
+ if 'layers' not in attn:
77
+ raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
78
+ if 'num_heads' not in attn:
79
+ raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
80
+ attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
81
+ attn['qkv_bias'] = attn.get('qkv_bias', False)
82
+ attn['window_size'] = attn.get('window_size', None)
83
+ attn['rope_theta'] = attn.get('rope_theta', 10000.)
84
+
85
+ super().__init__(
86
+ pad_token_id=pad_token_id,
87
+ bos_token_id=bos_token_id,
88
+ eos_token_id=eos_token_id,
89
+ tie_word_embeddings=tie_word_embeddings,
90
+ **kwargs,
91
+ )
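Given the validation above, a hybrid configuration only has to supply the `layers` and `num_heads` keys of `attn`; the constructor backfills the remaining keys. A short sketch with illustrative values:

```python
from fla.models.linear_attn import LinearAttentionConfig

# Use standard (softmax) attention at layers 0 and 12; every other layer stays linear attention.
config = LinearAttentionConfig(
    num_hidden_layers=24,
    attn={
        'layers': [0, 12],  # required
        'num_heads': 16,    # required
    },
)
# Optional keys are filled with defaults by __init__:
print(config.attn['num_kv_heads'], config.attn['qkv_bias'], config.attn['window_size'], config.attn['rope_theta'])
# 16 False None 10000.0
```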
fla3/models/linear_attn/modeling_linear_attn.py ADDED
@@ -0,0 +1,406 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.linear_attn import LinearAttention
20
+ from fla.models.linear_attn.configuration_linear_attn import LinearAttentionConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as LinearAttentionMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ class LinearAttentionBlock(nn.Module):
30
+ def __init__(self, config: LinearAttentionConfig, layer_idx: int):
31
+ super().__init__()
32
+
33
+ self.config = config
34
+ self.layer_idx = layer_idx
35
+
36
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
37
+ if config.attn is not None and layer_idx in config.attn['layers']:
38
+ self.attn = Attention(
39
+ hidden_size=config.hidden_size,
40
+ num_heads=config.attn['num_heads'],
41
+ num_kv_heads=config.attn['num_kv_heads'],
42
+ qkv_bias=config.attn['qkv_bias'],
43
+ window_size=config.attn['window_size'],
44
+ rope_theta=config.attn['rope_theta'],
45
+ max_position_embeddings=config.max_position_embeddings,
46
+ layer_idx=layer_idx
47
+ )
48
+ else:
49
+ self.attn = LinearAttention(
50
+ mode=config.attn_mode,
51
+ hidden_size=config.hidden_size,
52
+ expand_k=config.expand_k,
53
+ expand_v=config.expand_v,
54
+ num_heads=config.num_heads,
55
+ num_kv_heads=config.num_kv_heads,
56
+ feature_map=config.feature_map,
57
+ tie_feature_map_qk=config.tie_feature_map_qk,
58
+ norm_q=config.norm_q,
59
+ norm_k=config.norm_k,
60
+ do_feature_map_norm=config.norm_feature_map,
61
+ elementwise_affine=config.elementwise_affine,
62
+ norm_eps=config.norm_eps,
63
+ layer_idx=layer_idx
64
+ )
65
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
66
+ self.mlp = LinearAttentionMLP(
67
+ hidden_size=config.hidden_size,
68
+ hidden_ratio=config.hidden_ratio,
69
+ intermediate_size=config.intermediate_size,
70
+ hidden_act=config.hidden_act,
71
+ fuse_swiglu=config.fuse_swiglu
72
+ )
73
+
74
+ def forward(
75
+ self,
76
+ hidden_states: torch.Tensor,
77
+ attention_mask: Optional[torch.Tensor] = None,
78
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
79
+ use_cache: Optional[bool] = False,
80
+ output_attentions: Optional[bool] = False,
81
+ **kwargs,
82
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
83
+ residual = hidden_states
84
+ # currently not supported
85
+ attentions, past_key_values = None, None
86
+ hidden_states = self.attn_norm(hidden_states)
87
+ hidden_states = self.attn(hidden_states=hidden_states, **kwargs)
88
+ if self.config.fuse_norm:
89
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
90
+ else:
91
+ hidden_states = residual + hidden_states
92
+ residual = hidden_states
93
+ hidden_states = self.mlp_norm(hidden_states)
94
+ hidden_states = self.mlp(hidden_states, **kwargs)
95
+ hidden_states = residual + hidden_states
96
+
97
+ outputs = (hidden_states, attentions, past_key_values)
98
+
99
+ return outputs
100
+
101
+
102
+ class LinearAttentionPreTrainedModel(PreTrainedModel):
103
+
104
+ config_class = LinearAttentionConfig
105
+ base_model_prefix = 'model'
106
+ supports_gradient_checkpointing = True
107
+ _no_split_modules = ['LinearAttentionBlock']
108
+ _supports_cache_class = True
109
+
110
+ def __init__(self, *inputs, **kwargs):
111
+ super().__init__(*inputs, **kwargs)
112
+
113
+ def _init_weights(
114
+ self,
115
+ module: nn.Module,
116
+ prenorm_residual_strategy: Optional[str] = 'rescale',
117
+ num_residuals_per_layer: int = 2,
118
+ ):
119
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
120
+ # Slightly different from the TF version which uses truncated_normal for initialization
121
+ # cf https://github.com/pytorch/pytorch/pull/5617
122
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
123
+ if module.bias is not None:
124
+ nn.init.zeros_(module.bias)
125
+ elif isinstance(module, nn.Embedding):
126
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
127
+ elif hasattr(module, 'reset_parameters'):
128
+ module.reset_parameters()
129
+
130
+ if prenorm_residual_strategy is not None:
131
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
132
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
133
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
134
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
135
+ #
136
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
137
+ p = None
138
+ if hasattr(module, 'o_proj'):
139
+ p = module.o_proj.weight
140
+ elif hasattr(module, 'down_proj'):
141
+ p = module.down_proj.weight
142
+ if p is not None:
143
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
144
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
145
+ # We need to reinit p since this code could be called multiple times
146
+ # Having just p *= scale would repeatedly scale it down
147
+ if prenorm_residual_strategy == 'rescale':
148
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
149
+ with torch.no_grad():
150
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
151
+ elif prenorm_residual_strategy == 'zero':
152
+ nn.init.zeros_(p)
153
+ else:
154
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
155
+
156
+
157
+ class LinearAttentionModel(LinearAttentionPreTrainedModel):
158
+
159
+ def __init__(self, config: LinearAttentionConfig):
160
+ super().__init__(config)
161
+ self.padding_idx = config.pad_token_id
162
+ self.vocab_size = config.vocab_size
163
+
164
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
165
+ self.layers = nn.ModuleList([LinearAttentionBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
166
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
167
+
168
+ self.gradient_checkpointing = False
169
+
170
+ self.post_init()
171
+
172
+ def get_input_embeddings(self):
173
+ return self.embeddings
174
+
175
+ def set_input_embeddings(self, value):
176
+ self.embeddings = value
177
+
178
+ def forward(
179
+ self,
180
+ input_ids: Optional[torch.LongTensor] = None,
181
+ attention_mask: Optional[torch.Tensor] = None, # noqa
182
+ inputs_embeds: Optional[torch.FloatTensor] = None,
183
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
184
+ use_cache: Optional[bool] = None,
185
+ output_attentions: Optional[bool] = None,
186
+ output_hidden_states: Optional[bool] = None,
187
+ return_dict: Optional[bool] = None
188
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
189
+ if output_attentions:
190
+ warnings.warn(
191
+ "`LinearAttentionModel` does not support output attention weights now, "
192
+ "so `output_attentions` is set to `False`."
193
+ )
194
+ output_attentions = False
195
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
196
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
197
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
198
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
199
+
200
+ # retrieve input_ids and inputs_embeds
201
+ if input_ids is not None and inputs_embeds is not None:
202
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
203
+ if input_ids is None and inputs_embeds is None:
204
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
205
+
206
+ if inputs_embeds is None:
207
+ inputs_embeds = self.embeddings(input_ids)
208
+ hidden_states = inputs_embeds
209
+
210
+ if use_cache and not isinstance(past_key_values, Cache):
211
+ past_key_values = Cache.from_legacy_cache(past_key_values)
212
+
213
+ if self.gradient_checkpointing and self.training and use_cache:
214
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
215
+ use_cache = False
216
+
217
+ all_hidden_states = () if output_hidden_states else None
218
+ all_attns = () if output_attentions else None
219
+
220
+ for i, layer in enumerate(self.layers):
221
+ if output_hidden_states:
222
+ all_hidden_states += (hidden_states,)
223
+
224
+ if self.gradient_checkpointing and self.training:
225
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
226
+ layer.__call__,
227
+ hidden_states,
228
+ attention_mask,
229
+ past_key_values,
230
+ use_cache,
231
+ output_attentions,
232
+ )
233
+ else:
234
+ hidden_states, attentions, past_key_values = layer(
235
+ hidden_states,
236
+ attention_mask=attention_mask,
237
+ past_key_values=past_key_values,
238
+ use_cache=use_cache,
239
+ output_attentions=output_attentions
240
+ )
241
+
242
+ if output_attentions:
243
+ all_attns += (attentions,)
244
+
245
+ hidden_states = self.norm(hidden_states)
246
+
247
+ # add hidden states from the last decoder layer
248
+ if output_hidden_states:
249
+ all_hidden_states += (hidden_states,)
250
+
251
+ if not return_dict:
252
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
253
+ return BaseModelOutputWithPast(
254
+ last_hidden_state=hidden_states,
255
+ past_key_values=past_key_values,
256
+ hidden_states=all_hidden_states,
257
+ attentions=all_attns
258
+ )
259
+
260
+
261
+ class LinearAttentionForCausalLM(LinearAttentionPreTrainedModel, GenerationMixin):
262
+
263
+ _tied_weights_keys = ["lm_head.weight"]
264
+
265
+ def __init__(self, config):
266
+ super().__init__(config)
267
+ self.model = LinearAttentionModel(config)
268
+ self.vocab_size = config.vocab_size
269
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
270
+ self.criterion = None
271
+
272
+ # Initialize weights and apply final processing
273
+ self.post_init()
274
+
275
+ def get_input_embeddings(self):
276
+ return self.model.embeddings
277
+
278
+ def set_input_embeddings(self, value):
279
+ self.model.embeddings = value
280
+
281
+ def get_output_embeddings(self):
282
+ return self.lm_head
283
+
284
+ def set_output_embeddings(self, new_embeddings):
285
+ self.lm_head = new_embeddings
286
+
287
+ def set_decoder(self, decoder):
288
+ self.model = decoder
289
+
290
+ def get_decoder(self):
291
+ return self.model
292
+
293
+ def generate(self, *args, **kwargs):
294
+ try:
295
+ return super().generate(*args, **kwargs)
296
+ except AttributeError as exception:
297
+ if 'past_key_values' in str(exception):
298
+ raise AttributeError(
299
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
300
+ f"which is not supported for {self.__class__.__name__}. "
301
+ f"Try another generation strategy instead. "
302
+ f"For the available generation strategies, check this doc: "
303
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
304
+ )
305
+ else:
306
+ raise exception
307
+
308
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
309
+ def prepare_inputs_for_generation(
310
+ self,
311
+ input_ids: torch.LongTensor = None,
312
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
313
+ attention_mask: Optional[torch.Tensor] = None,
314
+ inputs_embeds: Optional[torch.Tensor] = None,
315
+ use_cache: bool = True,
316
+ logits_to_keep: Optional[int] = None,
317
+ **kwargs
318
+ ):
319
+ # only last token for `input_ids` if the `past_key_values` is not empty.
320
+ if past_key_values is not None and len(past_key_values) > 0:
321
+ input_ids = input_ids[:, -1:]
322
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
323
+ if inputs_embeds is not None and len(past_key_values) == 0:
324
+ model_inputs = {'inputs_embeds': inputs_embeds}
325
+ else:
326
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
327
+ # recompiles graphs as the stride of the inputs is a guard.
328
+ # Ref: https://github.com/huggingface/transformers/pull/29114
329
+ # TODO: use `next_tokens` directly instead.
330
+ model_inputs = {'input_ids': input_ids.contiguous()}
331
+
332
+ if logits_to_keep is not None:
333
+ model_inputs['logits_to_keep'] = logits_to_keep
334
+
335
+ model_inputs.update({
336
+ 'past_key_values': past_key_values,
337
+ 'use_cache': use_cache,
338
+ 'attention_mask': attention_mask,
339
+ })
340
+ return model_inputs
341
+
342
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
343
+ def forward(
344
+ self,
345
+ input_ids: torch.LongTensor = None,
346
+ attention_mask: Optional[torch.Tensor] = None,
347
+ inputs_embeds: Optional[torch.Tensor] = None,
348
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
349
+ labels: Optional[torch.LongTensor] = None,
350
+ use_cache: Optional[bool] = None,
351
+ output_attentions: Optional[bool] = None,
352
+ output_hidden_states: Optional[bool] = None,
353
+ return_dict: Optional[bool] = None,
354
+ logits_to_keep: Optional[int] = 0
355
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
356
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
357
+ output_hidden_states = (
358
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
359
+ )
360
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
361
+
362
+ outputs = self.model(
363
+ input_ids=input_ids,
364
+ attention_mask=attention_mask,
365
+ inputs_embeds=inputs_embeds,
366
+ past_key_values=past_key_values,
367
+ use_cache=use_cache,
368
+ output_attentions=output_attentions,
369
+ output_hidden_states=output_hidden_states,
370
+ return_dict=return_dict
371
+ )
372
+
373
+ hidden_states = outputs[0]
374
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
375
+
376
+ loss, logits = None, None
377
+ if not fuse_linear_and_cross_entropy or labels is None:
378
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
379
+ if labels is not None:
380
+ if getattr(self, 'criterion', None) is None:
381
+ if fuse_linear_and_cross_entropy:
382
+ criterion = FusedLinearCrossEntropyLoss()
383
+ elif self.config.fuse_cross_entropy:
384
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
385
+ else:
386
+ criterion = nn.CrossEntropyLoss()
387
+ else:
388
+ criterion = self.criterion
389
+ labels = labels.to(hidden_states.device)
390
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
391
+ if fuse_linear_and_cross_entropy:
392
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
393
+ else:
394
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
395
+
396
+ if not return_dict:
397
+ output = (logits,) + outputs[1:]
398
+ return (loss,) + output if loss is not None else output
399
+
400
+ return CausalLMOutputWithPast(
401
+ loss=loss,
402
+ logits=logits,
403
+ past_key_values=outputs.past_key_values,
404
+ hidden_states=outputs.hidden_states,
405
+ attentions=outputs.attentions,
406
+ )
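A note on the loss path above: whether the fused linear+cross-entropy kernel or a plain criterion is used, the labels are first shifted left by one position and the final slot is filled with the criterion's `ignore_index`, so the hidden state at position t is trained to predict token t+1. A standalone sketch of just that shift (hypothetical values, `ignore_index=-100` as in `nn.CrossEntropyLoss`):

```python
import torch

labels = torch.tensor([[5, 7, 9, 2]])
ignore_index = -100  # nn.CrossEntropyLoss default; the fused criteria expose the same attribute

shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], ignore_index)), dim=1)
print(shifted)  # tensor([[   7,    9,    2, -100]])
```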
fla3/models/mamba/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.mamba.configuration_mamba import MambaConfig
6
+ from fla.models.mamba.modeling_mamba import MambaBlock, MambaForCausalLM, MambaModel
7
+
8
+ AutoConfig.register(MambaConfig.model_type, MambaConfig, True)
9
+ AutoModel.register(MambaConfig, MambaModel, True)
10
+ AutoModelForCausalLM.register(MambaConfig, MambaForCausalLM, True)
11
+
12
+
13
+ __all__ = ['MambaConfig', 'MambaForCausalLM', 'MambaModel', 'MambaBlock']
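Unlike the `linear_attn` registration, these calls pass `True` as the third positional argument (which corresponds to `exist_ok=True` in recent transformers versions), since transformers already ships a built-in `mamba` mapping that these classes are meant to override. A quick way to see which config class the `mamba` model type resolves to after importing this module (assuming `fla` is installed):

```python
from transformers import AutoConfig

import fla.models.mamba  # noqa: F401  -- importing runs the register() calls above

config = AutoConfig.for_model("mamba")
print(type(config).__module__, type(config).__name__)  # expected: fla's MambaConfig, not the built-in one
```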
fla3/models/mamba/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (565 Bytes). View file
 
fla3/models/mamba/__pycache__/configuration_mamba.cpython-310.pyc ADDED
Binary file (6.37 kB). View file
 
fla3/models/mamba/__pycache__/modeling_mamba.cpython-310.pyc ADDED
Binary file (18.2 kB). View file
 
fla3/models/mamba/configuration_mamba.py ADDED
@@ -0,0 +1,166 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """MAMBA configuration"""
16
+
17
+ import math
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+
21
+
22
+ class MambaConfig(PretrainedConfig):
23
+ """
24
+ This is the configuration class to store the configuration of a [`MambaModel`]. It is used to instantiate a MAMBA
25
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
26
+ defaults will yield a similar configuration to that of the MAMBA
27
+ [state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) architecture.
28
+
29
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
30
+ documentation from [`PretrainedConfig`] for more information.
31
+
32
+
33
+ Args:
34
+ vocab_size (`int`, *optional*):
35
+ Vocabulary size of the Mamba model.
36
+ hidden_size (`int`, *optional*):
37
+ Dimensionality of the embeddings and hidden states. Default: 2048.
38
+ state_size (`int`, *optional*):
39
+ Shape of the state space latents. Default: 16.
40
+ num_hidden_layers (`int`, *optional*):
41
+ Number of hidden layers in the model. Default: 48.
42
+ norm_eps (`float`, *optional*):
43
+ The epsilon to use in the layer normalization layers. Default: 1e-5.
44
+ pad_token_id (`int`, *optional*):
45
+ Padding token id. Default: 0.
46
+ bos_token_id (`int`, *optional*):
47
+ The id of the beginning of sentence token in the vocabulary. Default: 0.
48
+ eos_token_id (`int`, *optional*):
49
+ The id of the end of sentence token in the vocabulary. Default: 0.
50
+ expand (`int`, *optional*):
51
+ Expanding factor used to determine the intermediate size. Default: 2.
52
+ conv_kernel (`int`, *optional*):
53
+ Size of the convolution kernel. Default: 4.
54
+ use_bias (`bool`, *optional*):
55
+ Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block. Default: `False`.
56
+ use_conv_bias (`bool`, *optional*):
57
+ Whether or not to use bias in the convolution layer of the mixer block. Default: `True`.
58
+ hidden_act (`str`, *optional*):
59
+ The non-linear activation function (function or string) in the decoder. Default: `"silu"`.
60
+ initializer_range (`float`, *optional*):
61
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices. Default: 0.02.
62
+ residual_in_fp32 (`bool`, *optional*):
63
+ Whether or not residuals should be in `float32`.
64
+ If set to `False` residuals will keep the same `dtype` as the rest of the model. Default: `True`.
65
+ time_step_rank (`Union[int,str]`, *optional*):
66
+ Rank of the discretization projection matrix.
67
+ `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`. Default: `"auto"`.
68
+ time_step_scale (`float`, *optional*):
69
+ Scale used to scale `dt_proj.bias`. Default: 1.0.
70
+ time_step_min (`float`, *optional*):
71
+ Minimum `time_step` used to bound `dt_proj.bias`. Default: 0.001.
72
+ time_step_max (`float`, *optional*):
73
+ Maximum `time_step` used to bound `dt_proj.bias`. Default: 0.1.
74
+ time_step_init_scheme (`float`, *optional*):
75
+ Init scheme used for `dt_proj.weight`. Should be one of `["random","uniform"]`. Default: `"random"`.
76
+ time_step_floor (`float`, *optional*):
77
+ Minimum clamping value of the `dt_proj.bias` layer initialization. Default: 0.0001.
78
+ window_size (`int`, *optional*):
79
+ The window size used for sliding window attention. Default: 2048.
80
+ rescale_prenorm_residual (`bool`, *optional*):
81
+ Whether or not to rescale `out_proj` weights when initializing. Default: `False`.
82
+ use_cache (`bool`, *optional*):
83
+ Whether or not the cache should be used. Default: `True`.
84
+
85
+
86
+ Example:
87
+
88
+ ```python
89
+ >>> from transformers import MambaConfig, MambaModel
90
+
91
+ >>> # Initializing a Mamba configuration
92
+ >>> configuration = MambaConfig()
93
+
94
+ >>> # Initializing a model (with random weights) from the configuration
95
+ >>> model = MambaModel(configuration)
96
+
97
+ >>> # Accessing the model configuration
98
+ >>> configuration = model.config
99
+ ```"""
100
+
101
+ model_type = "mamba"
102
+
103
+ def __init__(
104
+ self,
105
+ vocab_size: int = 32000,
106
+ hidden_size: int = 2048,
107
+ state_size: int = 16,
108
+ num_hidden_layers: int = 48,
109
+ norm_eps=1e-5,
110
+ pad_token_id: int = 0,
111
+ bos_token_id: int = 1,
112
+ eos_token_id: int = 2,
113
+ expand: int = 2,
114
+ conv_kernel: int = 4,
115
+ use_bias: bool = False,
116
+ use_conv_bias: bool = True,
117
+ hidden_act: str = "silu",
118
+ initializer_range: float = 0.02,
119
+ residual_in_fp32: bool = False,
120
+ time_step_rank: str = "auto",
121
+ time_step_scale: float = 1.0,
122
+ time_step_min: float = 0.001,
123
+ time_step_max: float = 0.1,
124
+ time_step_init_scheme: str = "random",
125
+ time_step_floor: float = 1e-4,
126
+ rescale_prenorm_residual: bool = False,
127
+ use_cache: bool = True,
128
+ fuse_norm: bool = True,
129
+ fuse_cross_entropy: bool = True,
130
+ tie_word_embeddings: bool = False,
131
+ **kwargs,
132
+ ):
133
+ self.vocab_size = vocab_size
134
+ self.hidden_size = hidden_size
135
+ self.state_size = state_size
136
+ self.num_hidden_layers = num_hidden_layers
137
+ self.norm_eps = norm_eps
138
+ self.conv_kernel = conv_kernel
139
+ self.expand = expand
140
+ self.intermediate_size = int(expand * self.hidden_size)
141
+ self.bos_token_id = bos_token_id
142
+ self.eos_token_id = eos_token_id
143
+ self.pad_token_id = pad_token_id
144
+ self.use_bias = use_bias
145
+ self.use_conv_bias = use_conv_bias
146
+ self.hidden_act = hidden_act
147
+ self.initializer_range = initializer_range
148
+ self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
149
+ self.time_step_scale = time_step_scale
150
+ self.time_step_min = time_step_min
151
+ self.time_step_max = time_step_max
152
+ self.time_step_init_scheme = time_step_init_scheme
153
+ self.time_step_floor = time_step_floor
154
+ self.rescale_prenorm_residual = rescale_prenorm_residual
155
+ self.residual_in_fp32 = residual_in_fp32
156
+ self.use_cache = use_cache
157
+ self.fuse_norm = fuse_norm
158
+ self.fuse_cross_entropy = fuse_cross_entropy
159
+
160
+ super().__init__(
161
+ bos_token_id=bos_token_id,
162
+ eos_token_id=eos_token_id,
163
+ pad_token_id=pad_token_id,
164
+ tie_word_embeddings=tie_word_embeddings,
165
+ **kwargs
166
+ )
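Two fields above are derived from the others rather than stored verbatim: `intermediate_size = int(expand * hidden_size)` and, when `time_step_rank == "auto"`, `time_step_rank = math.ceil(hidden_size / 16)`. With the defaults (`hidden_size=2048`, `expand=2`) that gives 4096 and 128 respectively, which a quick sketch can confirm (assuming `fla` is installed):

```python
from fla.models.mamba import MambaConfig

config = MambaConfig()  # defaults: hidden_size=2048, expand=2, time_step_rank="auto"
print(config.intermediate_size)  # 4096 = int(2 * 2048)
print(config.time_step_rank)     # 128  = math.ceil(2048 / 16)
```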
fla3/models/mamba/modeling_mamba.py ADDED
@@ -0,0 +1,565 @@
1
+ # Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.utils.checkpoint
21
+ from torch import nn
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.generation import GenerationMixin
24
+ from transformers.modeling_utils import PreTrainedModel
25
+ from transformers.utils import ModelOutput, logging
26
+ from transformers.utils.deprecation import deprecate_kwarg
27
+
28
+ from fla.layers.mamba import Mamba
29
+ from fla.models.mamba.configuration_mamba import MambaConfig
30
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ class MambaCache:
36
+ """
37
+ Cache for the Mamba model, which has no attention mechanism or key-value states.
38
+
39
+ Arguments:
40
+ config (`PretrainedConfig):
41
+ The configuration file defining the shape-related attributes required to initialize the static cache.
42
+ batch_size (`int`):
43
+ The batch size with which the model will be used. Note that a new instance must be instantiated if a
44
+ smaller batch size is used.
45
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
46
+ The default `dtype` to use when initializing the layer.
47
+ device (`torch.device` or `str`, *optional*):
48
+ The device on which the cache should be initialized. Should be the same as the layer.
49
+
50
+ Attributes:
51
+ dtype: (`torch.dtype`):
52
+ The default `dtype` used when initializing the cache.
53
+ intermediate_size: (`int`):
54
+ Model's intermediate_size taken from config.
55
+ ssm_state_size: (`int`):
56
+ Model's state_size taken from config.
57
+ conv_kernel_size: (`int`):
58
+ Model's convolution kernel size taken from config
59
+ conv_states: (`torch.Tensor`):
60
+ A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states.
61
+ ssm_states: (`torch.Tensor`):
62
+ A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds ssm states
63
+
64
+ Example:
65
+
66
+ ```python
67
+ >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache
68
+
69
+ >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
70
+ >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
71
+
72
+ >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
73
+
74
+ >>> # Prepare a cache class and pass it to model's forward
75
+ >>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype)
76
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
77
+ >>> outputs.past_key_values
78
+ MambaCache()
79
+ ```
80
+ """
81
+
82
+ # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
83
+ def __init__(
84
+ self,
85
+ config: PretrainedConfig,
86
+ batch_size: int = None,
87
+ dtype: torch.dtype = torch.float16,
88
+ device: Optional[Union[torch.device, str]] = None,
89
+ max_batch_size: Optional[int] = None,
90
+ ):
91
+ if max_batch_size is not None:
92
+ logger.warning_once(
93
+ f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
94
+ "v4.46. Use the more precisely named 'batch_size' argument instead."
95
+ )
96
+ self.dtype = dtype
97
+ self.batch_size = batch_size or max_batch_size
98
+ self.intermediate_size = config.intermediate_size
99
+ self.ssm_state_size = config.state_size
100
+ self.conv_kernel_size = config.conv_kernel
101
+
102
+ self.conv_states: torch.Tensor = torch.zeros(
103
+ config.num_hidden_layers,
104
+ self.batch_size,
105
+ self.intermediate_size,
106
+ self.conv_kernel_size,
107
+ device=device,
108
+ dtype=dtype,
109
+ )
110
+ self.ssm_states: torch.Tensor = torch.zeros(
111
+ config.num_hidden_layers,
112
+ self.batch_size,
113
+ self.intermediate_size,
114
+ self.ssm_state_size,
115
+ device=device,
116
+ dtype=dtype,
117
+ )
118
+
119
+ torch._dynamo.mark_static_address(self.conv_states)
120
+ torch._dynamo.mark_static_address(self.ssm_states)
121
+
122
+ def update_conv_state(
123
+ self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
124
+ ) -> torch.Tensor:
125
+ conv_state = self.conv_states[layer_idx]
126
+ cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
127
+
128
+ conv_state = conv_state.roll(shifts=-1, dims=-1)
129
+ conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
130
+ self.conv_states[layer_idx].zero_()
131
+ self.conv_states[layer_idx] += conv_state
132
+ return self.conv_states[layer_idx]
133
+
134
+ def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
135
+ self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
136
+ return self.ssm_states[layer_idx]
137
+
138
+ def reset(self):
139
+ self.conv_states.zero_()
140
+ self.ssm_states.zero_()
141
+
142
+
143
+ class MambaBlock(nn.Module):
144
+ def __init__(self, config, layer_idx):
145
+ super().__init__()
146
+ self.config = config
147
+ self.layer_idx = layer_idx
148
+ self.residual_in_fp32 = config.residual_in_fp32
149
+ self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
150
+ self.mixer = Mamba(
151
+ hidden_size=config.hidden_size,
152
+ state_size=config.state_size,
153
+ conv_kernel=config.conv_kernel,
154
+ intermediate_size=config.intermediate_size,
155
+ time_step_rank=config.time_step_rank,
156
+ use_bias=config.use_bias,
157
+ layer_idx=layer_idx
158
+ )
159
+
160
+ def forward(
161
+ self,
162
+ hidden_states,
163
+ cache_params: Optional[MambaCache] = None,
164
+ cache_position: Optional[torch.LongTensor] = None,
165
+ attention_mask: Optional[torch.LongTensor] = None,
166
+ ):
167
+ residual = hidden_states
168
+ hidden_states = self.norm(hidden_states)
169
+ if self.residual_in_fp32:
170
+ residual = residual.to(torch.float32)
171
+
172
+ hidden_states = self.mixer(
173
+ hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
174
+ )
175
+ hidden_states = residual + hidden_states
176
+ if self.residual_in_fp32:
177
+ hidden_states = hidden_states.to(dtype=self.norm.weight.dtype)
178
+ return hidden_states
179
+
180
+
181
+ class MambaPreTrainedModel(PreTrainedModel):
182
+ """
183
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
184
+ models.
185
+ """
186
+
187
+ config_class = MambaConfig
188
+ base_model_prefix = 'backbone'
189
+ _no_split_modules = ['Mamba', 'MambaBlock']
190
+ supports_gradient_checkpointing = True
191
+ _is_stateful = True
192
+
193
+ def _init_weights(self, module):
194
+ """Initialize the weights."""
195
+ if isinstance(module, nn.Linear):
196
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
197
+ if module.bias is not None:
198
+ if not getattr(module.bias, "_no_reinit", False):
199
+ nn.init.zeros_(module.bias)
200
+ elif isinstance(module, Mamba):
201
+ module.A_log._no_weight_decay = True
202
+ module.D._no_weight_decay = True
203
+
204
+ dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
205
+ if self.config.time_step_init_scheme == "constant":
206
+ nn.init.constant_(module.dt_proj.weight, dt_init_std)
207
+ elif self.config.time_step_init_scheme == "random":
208
+ nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)
209
+
210
+ dt = torch.exp(
211
+ torch.rand(self.config.intermediate_size)
212
+ * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
213
+ + math.log(self.config.time_step_min)
214
+ ).clamp(min=self.config.time_step_floor)
215
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
216
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
217
+ with torch.no_grad():
218
+ module.dt_proj.bias.data = nn.Parameter(inv_dt.to(module.dt_proj.bias.device))
219
+ module.dt_proj.bias._no_reinit = True
220
+ elif isinstance(module, nn.Embedding):
221
+ nn.init.normal_(module.weight, std=self.config.initializer_range)
222
+ elif hasattr(module, 'reset_parameters'):
223
+ module.reset_parameters()
224
+
225
+ if self.config.rescale_prenorm_residual:
226
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
227
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
228
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
229
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
230
+ #
231
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
232
+ for name, p in module.named_parameters():
233
+ if name in ["out_proj.weight"]:
234
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
235
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
236
+ # We need to reinit p since this code could be called multiple times
237
+ # Having just p *= scale would repeatedly scale it down
238
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
239
+ with torch.no_grad():
240
+ p /= math.sqrt(self.config.num_hidden_layers)
241
+
242
+
243
+ @dataclass
244
+ class MambaOutput(ModelOutput):
245
+ """
246
+ Class for the MAMBA model outputs.
247
+
248
+ Args:
249
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
250
+ Sequence of hidden-states at the output of the last layer of the model.
251
+ cache_params (`MambaCache`):
252
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
253
+ avoid providing the old `input_ids`.
254
+
255
+ Includes both the State space model state matrices after the selective scan, and the Convolutional states
256
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*,
257
+ returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
258
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
259
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
260
+
261
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
262
+ """
263
+
264
+ last_hidden_state: Optional[torch.FloatTensor] = None
265
+ cache_params: Optional[MambaCache] = None
266
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
267
+
268
+
269
+ @dataclass
270
+ class MambaCausalLMOutput(ModelOutput):
271
+ """
272
+ Base class for causal language model (or autoregressive) outputs.
273
+
274
+ Args:
275
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
276
+ Language modeling loss (for next-token prediction).
277
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
278
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
279
+ cache_params (`MambaCache`):
280
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
281
+ avoid providing the old `input_ids`.
282
+
283
+ Includes both the State space model state matrices after the selective scan, and the Convolutional states
284
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*,
285
+ returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
286
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
287
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
288
+
289
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
290
+ """
291
+
292
+ loss: Optional[torch.FloatTensor] = None
293
+ logits: Optional[torch.FloatTensor] = None
294
+ cache_params: Optional[MambaCache] = None
295
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
296
+
297
+
298
+ class MambaModel(MambaPreTrainedModel):
299
+ def __init__(self, config):
300
+ super().__init__(config)
301
+
302
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
303
+ self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
304
+
305
+ self.gradient_checkpointing = False
306
+ self.norm_f = RMSNorm(config.hidden_size, eps=config.norm_eps)
307
+ # Initialize weights and apply final processing
308
+ self._register_load_state_dict_pre_hook(self.load_hook)
309
+ self.post_init()
310
+
311
+ def load_hook(self, state_dict, prefix, *args):
312
+ for k in state_dict:
313
+ if "embedding." in k:
314
+ state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
315
+ break
316
+
317
+ def get_input_embeddings(self):
318
+ return self.embeddings
319
+
320
+ def set_input_embeddings(self, new_embeddings):
321
+ self.embeddings = new_embeddings
322
+
323
+ def forward(
324
+ self,
325
+ input_ids: Optional[torch.LongTensor] = None,
326
+ inputs_embeds: Optional[torch.LongTensor] = None,
327
+ cache_params: Optional[MambaCache] = None,
328
+ use_cache: Optional[bool] = None,
329
+ output_hidden_states: Optional[bool] = None,
330
+ return_dict: Optional[bool] = None,
331
+ cache_position: Optional[torch.LongTensor] = None,
332
+ attention_mask: Optional[torch.LongTensor] = None,
333
+ ) -> Union[Tuple, MambaOutput]:
334
+ output_hidden_states = (
335
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
336
+ )
337
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
338
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
339
+
340
+ if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
341
+ raise ValueError(
342
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
343
+ )
344
+
345
+ if inputs_embeds is None:
346
+ inputs_embeds = self.embeddings(input_ids)
347
+
348
+ if self.gradient_checkpointing and self.training and use_cache:
349
+ use_cache = False
350
+
351
+ if use_cache:
352
+ if cache_params is None:
353
+ cache_params = MambaCache(
354
+ self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
355
+ )
356
+ cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
357
+ elif cache_position is None:
358
+ # cases when we do manual forward instead of using `model.generate` which will initiate
359
+ # `cache_position` and makes sure it is not None, throw error here instead of doing some
360
+ # hack to conjecture the current cache position
361
+ raise ValueError(
362
+ "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
363
+ "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
364
+ "be initialized for you automatically"
365
+ )
366
+ else:
367
+ cache_params = None
368
+
369
+ hidden_states = inputs_embeds
370
+ all_hidden_states = () if output_hidden_states else None
371
+ for mixer_block in self.layers:
372
+ if self.gradient_checkpointing and self.training:
373
+ hidden_states = self._gradient_checkpointing_func(
374
+ mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
375
+ )
376
+ else:
377
+ hidden_states = mixer_block(
378
+ hidden_states,
379
+ cache_params=cache_params,
380
+ cache_position=cache_position,
381
+ attention_mask=attention_mask,
382
+ )
383
+
384
+ if output_hidden_states:
385
+ all_hidden_states = all_hidden_states + (hidden_states,)
386
+
387
+ hidden_states = self.norm_f(hidden_states)
388
+
389
+ if output_hidden_states:
390
+ all_hidden_states = all_hidden_states + (hidden_states,)
391
+
392
+ if not return_dict:
393
+ return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
394
+
395
+ return MambaOutput(
396
+ last_hidden_state=hidden_states,
397
+ cache_params=cache_params if use_cache else None,
398
+ hidden_states=all_hidden_states,
399
+ )
400
+
401
+
402
+ class MambaForCausalLM(MambaPreTrainedModel, GenerationMixin):
403
+
404
+ _tied_weights_keys = ["lm_head.weight"]
405
+
406
+ def __init__(self, config):
407
+ super().__init__(config)
408
+ self.backbone = MambaModel(config)
409
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
410
+ self.criterion = None
411
+
412
+ # Initialize weights and apply final processing
413
+ self.post_init()
414
+
415
+ def get_output_embeddings(self):
416
+ return self.lm_head
417
+
418
+ def set_output_embeddings(self, new_embeddings):
419
+ self.lm_head = new_embeddings
420
+
421
+ def get_input_embeddings(self):
422
+ return self.backbone.get_input_embeddings()
423
+
424
+ def set_input_embeddings(self, new_embeddings):
425
+ return self.backbone.set_input_embeddings(new_embeddings)
426
+
427
+ def _update_model_kwargs_for_generation(
428
+ self, outputs: ModelOutput,
429
+ model_kwargs: Dict[str, Any],
430
+ num_new_tokens: int = 1,
431
+ **kwargs
432
+ ) -> Dict[str, Any]:
433
+ model_kwargs["cache_params"] = outputs.get("cache_params", None)
434
+ if (
435
+ model_kwargs.get("use_cache", True)
436
+ and "cache_position" in model_kwargs
437
+ and model_kwargs["cache_position"] is not None
438
+ ):
439
+ model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
440
+
441
+ if "attention_mask" in model_kwargs:
442
+ attention_mask = model_kwargs["attention_mask"]
443
+ model_kwargs["attention_mask"] = torch.cat(
444
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
445
+ )
446
+
447
+ return model_kwargs
448
+
449
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
450
+ def prepare_inputs_for_generation(
451
+ self,
452
+ input_ids,
453
+ inputs_embeds=None,
454
+ use_cache=None,
455
+ cache_params: Optional[MambaCache] = None,
456
+ cache_position: Optional[torch.LongTensor] = None,
457
+ attention_mask: Optional[torch.LongTensor] = None,
458
+ logits_to_keep: Optional[int] = None,
459
+ **kwargs,
460
+ ):
461
+ if use_cache:
462
+ # `cache_position` should have been initialized in `generate`
463
+ if cache_position is None:
464
+ raise ValueError(
465
+ "`cache_position` should not be None as it should have been initialized in "
466
+ "`model.generate`, you are responsible for passing in a valid `cache_position` if "
467
+ "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
468
+ )
469
+ if cache_position[0] > 0:
470
+ input_ids = input_ids[:, -1].unsqueeze(-1)
471
+
472
+ if attention_mask is not None:
473
+ attention_mask = None
474
+
475
+ else:
476
+ # we initialize the `cache_position` to full size of `conv_states` at prefill stage
477
+ # considering padding will be applied when input length is shorter, and truncation
478
+ # will be applied when it is longer, so it will be equivalent to always have it match
479
+ # the length of `cache_params.conv_states`, which is `config.conv_kernel`
480
+ cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)
481
+
482
+ if inputs_embeds is not None and cache_params is None:
483
+ model_inputs = {"inputs_embeds": inputs_embeds}
484
+ else:
485
+ model_inputs = {"input_ids": input_ids.contiguous()}
486
+
487
+ if logits_to_keep is not None:
488
+ model_inputs['logits_to_keep'] = logits_to_keep
489
+
490
+ model_inputs.update({
491
+ 'cache_params': cache_params,
492
+ 'use_cache': use_cache,
493
+ 'cache_position': cache_position,
494
+ 'attention_mask': attention_mask,
495
+ 'logits_to_keep': logits_to_keep,
496
+ })
497
+ return model_inputs
498
+
499
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
500
+ def forward(
501
+ self,
502
+ input_ids: Optional[torch.LongTensor] = None,
503
+ attention_mask: Optional[torch.LongTensor] = None,
504
+ inputs_embeds: Optional[torch.FloatTensor] = None,
505
+ cache_params: Optional[MambaCache] = None,
506
+ labels: Optional[torch.LongTensor] = None,
507
+ output_hidden_states: Optional[bool] = None,
508
+ return_dict: Optional[bool] = None,
509
+ use_cache: Optional[bool] = None,
510
+ cache_position: Optional[torch.Tensor] = None,
511
+ logits_to_keep: Optional[int] = 0,
512
+ **kwargs, # for now we need this for generation
513
+ ) -> Union[Tuple, MambaCausalLMOutput]:
514
+ r"""
515
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
516
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
517
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
518
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
519
+ """
520
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
521
+
522
+ mamba_outputs = self.backbone(
523
+ input_ids,
524
+ cache_params=cache_params,
525
+ inputs_embeds=inputs_embeds,
526
+ output_hidden_states=output_hidden_states,
527
+ return_dict=return_dict,
528
+ use_cache=use_cache,
529
+ cache_position=cache_position,
530
+ attention_mask=attention_mask,
531
+ )
532
+ hidden_states = mamba_outputs[0]
533
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
534
+
535
+ loss, logits = None, None
536
+ if not fuse_linear_and_cross_entropy or labels is None:
537
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
538
+ if labels is not None:
539
+ if getattr(self, 'criterion', None) is None:
540
+ if fuse_linear_and_cross_entropy:
541
+ criterion = FusedLinearCrossEntropyLoss()
542
+ elif self.config.fuse_cross_entropy:
543
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
544
+ else:
545
+ criterion = nn.CrossEntropyLoss()
546
+ else:
547
+ criterion = self.criterion
548
+ # Enable model parallelism
549
+ labels = labels.to(hidden_states.device)
550
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
551
+ if fuse_linear_and_cross_entropy:
552
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
553
+ else:
554
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
555
+
556
+ if not return_dict:
557
+ output = (logits,) + mamba_outputs[1:]
558
+ return (loss,) + output if loss is not None else output
559
+
560
+ return MambaCausalLMOutput(
561
+ loss=loss,
562
+ logits=logits,
563
+ cache_params=mamba_outputs.cache_params,
564
+ hidden_states=mamba_outputs.hidden_states,
565
+ )
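One detail of `_init_weights` above worth calling out: `dt` is drawn log-uniformly from `[time_step_min, time_step_max]`, clamped at `time_step_floor`, and then stored in `dt_proj.bias` as its inverse softplus, `inv_dt = dt + log(-expm1(-dt))`, so that `softplus(dt_proj.bias)` reproduces the sampled `dt` at the start of training. A standalone numerical check of that identity (not part of the model code):

```python
import torch
import torch.nn.functional as F

time_step_min, time_step_max, time_step_floor = 0.001, 0.1, 1e-4

# Log-uniform sample in [time_step_min, time_step_max], as in _init_weights.
dt = torch.exp(
    torch.rand(8) * (torch.log(torch.tensor(time_step_max)) - torch.log(torch.tensor(time_step_min)))
    + torch.log(torch.tensor(time_step_min))
).clamp(min=time_step_floor)

inv_dt = dt + torch.log(-torch.expm1(-dt))  # inverse of softplus
print(torch.allclose(F.softplus(inv_dt), dt, atol=1e-6))  # True
```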
fla3/models/mamba2/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.mamba2.configuration_mamba2 import Mamba2Config
6
+ from fla.models.mamba2.modeling_mamba2 import Mamba2ForCausalLM, Mamba2Model
7
+
8
+ AutoConfig.register(Mamba2Config.model_type, Mamba2Config, True)
9
+ AutoModel.register(Mamba2Config, Mamba2Model, True)
10
+ AutoModelForCausalLM.register(Mamba2Config, Mamba2ForCausalLM, True)
11
+
12
+
13
+ __all__ = ['Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model']
fla3/models/mamba2/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (547 Bytes). View file
 
fla3/models/mamba2/__pycache__/configuration_mamba2.cpython-310.pyc ADDED
Binary file (6.51 kB). View file
 
fla3/models/mamba2/__pycache__/modeling_mamba2.cpython-310.pyc ADDED
Binary file (17 kB). View file
 
fla3/models/mamba2/configuration_mamba2.py ADDED
@@ -0,0 +1,167 @@
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """MAMBA2 configuration"""
15
+
16
+ import math
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+
20
+
21
+ class Mamba2Config(PretrainedConfig):
22
+ """
23
+ This is the configuration class to store the configuration of a [`Mamba2Model`]. It is used to instantiate a MAMBA2
24
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
25
+ defaults will yield a similar configuration to that of the MAMBA2
26
+ [state-spaces/mamba2-2.8b](https://huggingface.co/state-spaces/mamba2-2.8b) architecture.
27
+
28
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
29
+ documentation from [`PretrainedConfig`] for more information.
30
+
31
+
32
+ Args:
33
+ head_dim (`int`, *optional*, defaults to 64):
34
+ Dimension of each head.
35
+ vocab_size (`int`, *optional*, defaults to 32768):
36
+ Vocabulary size of the MAMBA2 model. Defines the number of different tokens that can be represented by the
37
+ `inputs_ids` passed when calling [`Mamba2Model`].
38
+ hidden_size (`int`, *optional*, defaults to 2048):
39
+ Dimensionality of the embeddings and hidden states.
40
+ state_size (`int`, *optional*, defaults to 128): shape of the state space latents.
41
+ num_hidden_layers (`int`, *optional*, defaults to 48):
42
+ Number of hidden layers in the model.
43
+ norm_eps (`float`, *optional*, defaults to 1e-05):
44
+ The epsilon to use in the layer normalization layers.
45
+ pad_token_id (`int`, *optional*, defaults to 0):
46
+ Padding token id.
47
+ bos_token_id (`int`, *optional*, defaults to 1):
48
+ The id of the beginning of sentence token in the vocabulary.
49
+ eos_token_id (`int`, *optional*, defaults to 2):
50
+ The id of the end of sentence token in the vocabulary.
51
+ expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
52
+ conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
53
+ n_groups (`int`, *optional*, defaults to 1):
54
+ Number of groups for the evolution matrices of mamba 2.
55
+ use_bias (`bool`, *optional*, defaults to `False`):
56
+ Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
57
+ use_conv_bias (`bool`, *optional*, defaults to `True`):
58
+ Whether or not to use bias in the convolution layer of the mixer block.
59
+ hidden_act (`str`, *optional*, defaults to `"silu"`):
60
+ The non-linear activation function (function or string) in the decoder.
61
+ initializer_range (`float`, *optional*, defaults to 0.02):
62
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
63
+ residual_in_fp32 (`bool`, *optional*, defaults to `True`):
64
+ Whether or not residuals should be in `float32`.
65
+ If set to `False` residuals will keep the same `dtype` as the rest of the model
66
+ time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
67
+ Rank of the discretization projection matrix.
68
+ `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
69
+ time_step_min (`float`, *optional*, defaults to 0.001):
70
+ Minimum `time_step` used to bound `dt_proj.bias`.
71
+ time_step_max (`float`, *optional*, defaults to 0.1):
72
+ Maximum `time_step` used to bound `dt_proj.bias`.
73
+ time_step_floor (`float`, *optional*, defaults to 0.0001):
74
+ Minimum clamping value of the `dt_proj.bias` layer initialization.
75
+ time_step_limit (`tuple`, *optional*, defaults to `(0.0, inf)`):
76
+ Accepted range of time step values.
77
+ rescale_prenorm_residual (`bool`, *optional*, defaults to `True`):
78
+ Whether or not to rescale `out_proj` weights when initializing.
79
+ use_cache (`bool`, *optional*, defaults to `True`):
80
+ Whether or not the cache should be used.
81
+ rms_norm (`bool`, *optional*, defaults to `True`):
82
+ Whether to use RMS norm or not.
83
+ chunk_size (`int`, *optional*, defaults to 256):
84
+ Size of the chunks that will comprise the sequence.
85
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
86
+ Whether to tie word embeddings or not.
87
+ """
88
+
89
+ model_type = "mamba2"
90
+
91
+ def __init__(
92
+ self,
93
+ head_dim: int = 64,
94
+ vocab_size: int = 32000,
95
+ hidden_size: int = 2048,
96
+ state_size: int = 128,
97
+ num_hidden_layers: int = 48,
98
+ norm_eps: float = 1e-5,
99
+ pad_token_id: int = 0,
100
+ bos_token_id: int = 1,
101
+ eos_token_id: int = 2,
102
+ expand: int = 2,
103
+ conv_kernel: int = 4,
104
+ n_groups: int = 1,
105
+ use_bias: bool = False,
106
+ use_conv_bias: bool = True,
107
+ hidden_act: str = "silu",
108
+ initializer_range: float = 0.02,
109
+ residual_in_fp32: bool = True,
110
+ time_step_rank: str = "auto",
111
+ time_step_min: float = 0.001,
112
+ time_step_max: float = 0.1,
113
+ time_step_floor: float = 1e-4,
114
+ time_step_limit=(0.0, float("inf")),
115
+ rescale_prenorm_residual: bool = True,
116
+ use_cache: bool = True,
117
+ rms_norm: bool = True,
118
+ chunk_size: int = 256,
119
+ fuse_norm: bool = True,
120
+ fuse_cross_entropy: bool = True,
121
+ tie_word_embeddings: bool = False,
122
+ **kwargs,
123
+ ):
124
+ self.vocab_size = vocab_size
125
+ self.hidden_size = hidden_size
126
+ self.state_size = state_size
127
+ self.num_hidden_layers = num_hidden_layers
128
+ self.norm_eps = norm_eps
129
+ self.conv_kernel = conv_kernel
130
+ self.expand = expand
131
+
132
+ self.bos_token_id = bos_token_id
133
+ self.eos_token_id = eos_token_id
134
+ self.pad_token_id = pad_token_id
135
+ self.use_bias = use_bias
136
+ self.use_conv_bias = use_conv_bias
137
+ self.hidden_act = hidden_act
138
+ self.initializer_range = initializer_range
139
+ self.time_step_rank = (
140
+ math.ceil(self.hidden_size / 16)
141
+ if time_step_rank == "auto"
142
+ else time_step_rank
143
+ )
144
+ self.time_step_min = time_step_min
145
+ self.time_step_max = time_step_max
146
+ self.time_step_floor = time_step_floor
147
+ self.rescale_prenorm_residual = rescale_prenorm_residual
148
+ self.residual_in_fp32 = residual_in_fp32
149
+ self.use_cache = use_cache
150
+ self.n_groups = n_groups
151
+ self.head_dim = head_dim
152
+ self.num_heads = int(self.expand * self.hidden_size / self.head_dim)
153
+ self.rms_norm = rms_norm
154
+ self.state_size = state_size
155
+ self.chunk_size = chunk_size
156
+ self.time_step_limit = time_step_limit
157
+ self.fuse_norm = fuse_norm
158
+ self.fuse_cross_entropy = fuse_cross_entropy
159
+ self.tie_word_embeddings = tie_word_embeddings
160
+
161
+ super().__init__(
162
+ bos_token_id=bos_token_id,
163
+ eos_token_id=eos_token_id,
164
+ pad_token_id=pad_token_id,
165
+ tie_word_embeddings=tie_word_embeddings,
166
+ **kwargs,
167
+ )
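
A quick sketch of how the derived attributes above fall out of the constructor (the values follow directly from the code in this file; the import path mirrors the one used by `modeling_mamba2.py` below):

import math
from fla.models.mamba2.configuration_mamba2 import Mamba2Config

# defaults above: hidden_size=2048, head_dim=64, expand=2, time_step_rank="auto"
config = Mamba2Config()
assert config.num_heads == int(2 * 2048 / 64)          # 64 heads
assert config.time_step_rank == math.ceil(2048 / 16)   # "auto" resolves to 128
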
fla3/models/mamba2/modeling_mamba2.py ADDED
@@ -0,0 +1,562 @@
1
+ # Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.utils.checkpoint
21
+ from torch import nn
22
+ from transformers.generation import GenerationMixin
23
+ from transformers.modeling_utils import PreTrainedModel
24
+ from transformers.utils import ModelOutput, logging
25
+ from transformers.utils.deprecation import deprecate_kwarg
26
+
27
+ from fla.layers.mamba2 import Mamba2
28
+ from fla.models.mamba2.configuration_mamba2 import Mamba2Config
29
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ class Mamba2Cache:
35
+ """
36
+ Arguments:
37
+ config: Mamba2Config
38
+ batch_size: int
39
+ dtype: torch.dtype
40
+ device: torch.device
41
+
42
+ Attributes:
43
+ dtype: (`torch.dtype`):
44
+             The default `dtype` used to initialize the cache.
45
+ conv_kernel_size: (`int`):
46
+ Model's convolution kernel size taken from config.
47
+ n_groups: (`int`):
48
+ Model's number of groups taken from the config - similar to tensor parallel in Transformer.
49
+ state_size: (`int`):
50
+ Model's SSM state size taken from config.
51
+ num_heads: (`int`):
52
+ The number of heads used in the linear attention / SSM.
53
+ head_dim: (`int`):
54
+ The respective dimension of the heads used in the linear attention / SSM.
55
+ intermediate_size: (`int`):
56
+ Model's intermediate_size based on (expand * hidden_dim) from config.
57
+ conv_states: (`torch.Tensor`):
58
+             A tensor of shape `[num_layers, batch_size, intermediate_size + 2 * n_groups * state_size, conv_kernel_size]`
59
+ that holds convolutional states.
60
+ ssm_states: (`torch.Tensor`):
61
+ A tensor of shape `[num_layers, batch_size, num_heads, head_dim, state_size]` that holds ssm states.
62
+ """
63
+
64
+ def __init__(
65
+ self,
66
+ config: Mamba2Config,
67
+ batch_size: int,
68
+ dtype: torch.dtype = torch.float16,
69
+ device: Optional[str] = None,
70
+ ):
71
+ self.dtype = dtype
72
+ self.conv_kernel_size = config.conv_kernel
73
+ self.n_groups = config.n_groups
74
+ self.state_size = config.state_size
75
+ self.num_heads = config.num_heads
76
+ self.head_dim = config.head_dim
77
+ self.intermediate_size = int(config.expand * config.hidden_size)
78
+
79
+ self.conv_states = torch.zeros(
80
+ config.num_hidden_layers,
81
+ batch_size,
82
+ self.intermediate_size + 2 * self.n_groups * self.state_size,
83
+ self.conv_kernel_size,
84
+ device=device,
85
+ dtype=dtype,
86
+ )
87
+ self.ssm_states = torch.zeros(
88
+ config.num_hidden_layers,
89
+ batch_size,
90
+ self.num_heads,
91
+ self.head_dim,
92
+ self.state_size,
93
+ device=device,
94
+ dtype=dtype,
95
+ )
96
+
97
+ def update_conv_state(
98
+ self,
99
+ layer_idx: int,
100
+ new_conv_state: torch.Tensor,
101
+ cache_init: bool = False
102
+ ) -> torch.Tensor:
103
+ if cache_init:
104
+ self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device)
105
+ else:
106
+ self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
107
+ self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device)
108
+ return self.conv_states[layer_idx]
109
+
110
+ def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
111
+ self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
112
+ return self.ssm_states[layer_idx]
113
+
114
+ def reset(self):
115
+ self.conv_states.zero_()
116
+ self.ssm_states.zero_()
117
+
118
+
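
As a worked illustration of `update_conv_state` above: at decode time the cached window is shifted one step to the left along the kernel dimension and the newest frame is written into the last slot. A standalone sketch with hypothetical shapes (plain PyTorch, not part of the diff):

import torch

batch, conv_dim, kernel = 1, 4, 4
conv_state = torch.arange(16.).reshape(batch, conv_dim, kernel)  # per-layer cache: [batch, conv_dim, conv_kernel]
new_frame = torch.full((batch, 1, conv_dim), 99.)                # one new timestep

conv_state = conv_state.roll(shifts=-1, dims=-1)  # drop the oldest column
conv_state[:, :, -1] = new_frame[:, 0, :]         # the newest inputs now sit at index -1
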
119
+ class Mamba2Block(nn.Module):
120
+ def __init__(self, config, layer_idx):
121
+ super().__init__()
122
+ self.config = config
123
+ self.layer_idx = layer_idx
124
+ self.residual_in_fp32 = config.residual_in_fp32
125
+ self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
126
+ self.mixer = Mamba2(
127
+ num_heads=config.num_heads,
128
+ head_dim=config.head_dim,
129
+ hidden_size=config.hidden_size,
130
+ state_size=config.state_size,
131
+ expand=config.expand,
132
+ n_groups=config.n_groups,
133
+ conv_kernel=config.conv_kernel,
134
+ use_conv_bias=config.use_conv_bias,
135
+ hidden_act=config.hidden_act,
136
+ rms_norm=config.rms_norm,
137
+ chunk_size=config.chunk_size,
138
+ time_step_rank=config.time_step_rank,
139
+ time_step_limit=config.time_step_limit,
140
+ time_step_min=config.time_step_min,
141
+ time_step_max=config.time_step_max,
142
+ use_bias=config.use_bias,
143
+ norm_eps=config.norm_eps,
144
+ layer_idx=layer_idx,
145
+ )
146
+
147
+ def forward(
148
+ self,
149
+ hidden_states,
150
+ cache_params: Optional[Mamba2Cache] = None,
151
+ cache_position: Optional[torch.LongTensor] = None,
152
+ attention_mask: Optional[torch.Tensor] = None,
153
+ ):
154
+ residual = hidden_states
155
+ hidden_states = self.norm(hidden_states)
156
+ if self.residual_in_fp32:
157
+ residual = residual.to(torch.float32)
158
+
159
+ hidden_states = self.mixer(
160
+ hidden_states,
161
+ cache_params=cache_params,
162
+ cache_position=cache_position,
163
+ attention_mask=attention_mask,
164
+ )
165
+ hidden_states = residual + hidden_states
166
+ if self.residual_in_fp32:
167
+ hidden_states = hidden_states.to(dtype=self.norm.weight.dtype)
168
+ return hidden_states
169
+
170
+
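
The block above is the standard pre-norm residual pattern, with the skip path optionally kept in float32. A minimal sketch of just that dataflow (a `Linear` stands in for the `Mamba2` mixer; purely illustrative):

import torch
from torch import nn

norm = nn.LayerNorm(16, dtype=torch.bfloat16)    # stand-in for fla's RMSNorm
mixer = nn.Linear(16, 16, dtype=torch.bfloat16)  # stand-in for the Mamba2 mixer

hidden = torch.randn(2, 8, 16, dtype=torch.bfloat16)
residual = hidden.to(torch.float32)              # residual_in_fp32=True: keep the skip path in fp32
hidden = residual + mixer(norm(hidden))          # promoted to fp32 by the addition
hidden = hidden.to(norm.weight.dtype)            # cast back to the model dtype, as in the block above
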
171
+ class Mamba2PreTrainedModel(PreTrainedModel, GenerationMixin):
172
+ """
173
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
174
+ models.
175
+ """
176
+
177
+ config_class = Mamba2Config
178
+ base_model_prefix = "backbone"
179
+ _no_split_modules = ["Mamba2Block"]
180
+ supports_gradient_checkpointing = True
181
+ _is_stateful = True
182
+
183
+ def _init_weights(
184
+ self,
185
+ module: nn.Module,
186
+ num_residuals_per_layer: int = 1,
187
+ ):
188
+ """Initialize the weights."""
189
+ if isinstance(module, Mamba2):
190
+
191
+ # --- A_log ---
192
+ A = torch.arange(1, module.num_heads + 1)
193
+ with torch.no_grad():
194
+ if not isinstance(module.A_log, torch.distributed.tensor.DTensor):
195
+ module.A_log.copy_(torch.log(A))
196
+ else:
197
+ logger.warning_once("`A_log` is a DTensor, skipping initialization")
198
+ module.A_log._no_weight_decay = True
199
+
200
+ # --- D ---
201
+ nn.init.ones_(module.D)
202
+ module.D._no_weight_decay = True
203
+
204
+ # --- dt_bias ---
205
+ dt = torch.exp(
206
+ torch.rand(self.config.num_heads)
207
+ * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
208
+ + math.log(self.config.time_step_min)
209
+ ).clamp(min=self.config.time_step_floor)
210
+
211
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
212
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
213
+ with torch.no_grad():
214
+ if not isinstance(module.dt_bias, torch.distributed.tensor.DTensor):
215
+ module.dt_bias.copy_(inv_dt)
216
+ else:
217
+ logger.warning_once("`dt_bias` is a DTensor, skipping initialization")
218
+ module.dt_bias._no_reinit = True
219
+
220
+ elif isinstance(module, (nn.Linear, nn.Conv1d)):
221
+ # Slightly different from the TF version which uses truncated_normal for initialization
222
+ # cf https://github.com/pytorch/pytorch/pull/5617
223
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
224
+ if module.bias is not None:
225
+ nn.init.zeros_(module.bias)
226
+ # guard against deprecated behavior
227
+ if hasattr(module.bias, "_no_reinit"):
228
+ raise ValueError("This is not supposed to happen")
229
+ elif isinstance(module, nn.Embedding):
230
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
231
+ elif hasattr(module, 'reset_parameters'):
232
+ module.reset_parameters()
233
+
234
+ if self.config.rescale_prenorm_residual:
235
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
236
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
237
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
238
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
239
+ #
240
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
241
+ p = None
242
+ if hasattr(module, 'o_proj'):
243
+ # p = module.o_proj.weight
244
+ # guard against deprecated behavior
245
+ raise ValueError("This is not supposed to happen")
246
+ elif hasattr(module, 'out_proj'):
247
+ p = module.out_proj.weight
248
+ elif hasattr(module, 'down_proj'):
249
+ p = module.down_proj.weight
250
+ if p is not None:
251
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
252
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
253
+ # We need to reinit p since this code could be called multiple times
254
+ # Having just p *= scale would repeatedly scale it down
255
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
256
+ with torch.no_grad():
257
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
258
+
259
+
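
The `dt_bias` initialization above samples `dt` log-uniformly in `[time_step_min, time_step_max]` and stores `inv_dt = dt + log(-expm1(-dt))`, the inverse of softplus, so that `softplus(dt_bias)` recovers `dt` at run time. A quick numerical check of that identity (standalone, not part of the diff):

import torch
import torch.nn.functional as F

dt = torch.tensor([1e-3, 1e-2, 1e-1], dtype=torch.float64)
inv_dt = dt + torch.log(-torch.expm1(-dt))     # inverse softplus, as used in _init_weights above
assert torch.allclose(F.softplus(inv_dt), dt)  # softplus(inv_dt) == dt
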
260
+ @dataclass
261
+ # Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2
262
+ class Mamba2Output(ModelOutput):
263
+ """
264
+ Class for the MAMBA2 model outputs.
265
+
266
+ Args:
267
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
268
+ Sequence of hidden-states at the output of the last layer of the model.
269
+ cache_params (`Mamba2Cache`):
270
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
271
+ avoid providing the old `input_ids`.
272
+
273
+             Includes both the State space model state matrices after the selective scan, and the Convolutional states.
274
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*,
275
+ returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
276
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
277
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
278
+
279
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
280
+ """
281
+
282
+ last_hidden_state: Optional[torch.FloatTensor] = None
283
+ cache_params: Optional[Mamba2Cache] = None
284
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
285
+
286
+
287
+ @dataclass
288
+ # Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->Mamba2
289
+ class Mamba2CausalLMOutput(ModelOutput):
290
+ """
291
+ Base class for causal language model (or autoregressive) outputs.
292
+
293
+ Args:
294
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
295
+ Language modeling loss (for next-token prediction).
296
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
297
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
298
+ cache_params (`Mamba2Cache`):
299
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
300
+ avoid providing the old `input_ids`.
301
+
302
+             Includes both the State space model state matrices after the selective scan, and the Convolutional states.
303
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*,
304
+ returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
305
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
306
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
307
+
308
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
309
+ """
310
+
311
+ loss: Optional[torch.FloatTensor] = None
312
+ logits: Optional[torch.FloatTensor] = None
313
+ cache_params: Optional[Mamba2Cache] = None
314
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
315
+
316
+
317
+ class Mamba2Model(Mamba2PreTrainedModel):
318
+ def __init__(self, config):
319
+ super().__init__(config)
320
+
321
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
322
+ self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
323
+
324
+ self.gradient_checkpointing = False
325
+ self.norm_f = RMSNorm(config.hidden_size, eps=config.norm_eps)
326
+ # Initialize weights and apply final processing
327
+ self._register_load_state_dict_pre_hook(self.load_hook)
328
+ self.post_init()
329
+
330
+ def load_hook(self, state_dict, prefix, *args):
331
+ for k in state_dict:
332
+ if "embedding." in k:
333
+ state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
334
+ break
335
+
336
+ def get_input_embeddings(self):
337
+ return self.embeddings
338
+
339
+ def set_input_embeddings(self, new_embeddings):
340
+ self.embeddings = new_embeddings
341
+
342
+ def forward(
343
+ self,
344
+ input_ids: Optional[torch.LongTensor] = None,
345
+ inputs_embeds: Optional[torch.LongTensor] = None,
346
+ cache_params: Optional[Mamba2Cache] = None,
347
+ use_cache: Optional[bool] = None,
348
+ output_hidden_states: Optional[bool] = None,
349
+ return_dict: Optional[bool] = None,
350
+ cache_position: Optional[torch.LongTensor] = None,
351
+ attention_mask: Optional[torch.Tensor] = None,
352
+ **kwargs,
353
+ ) -> Union[Tuple, Mamba2Output]:
354
+ output_hidden_states = (
355
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
356
+ )
357
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
358
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
359
+
360
+ if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
361
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
362
+
363
+ if inputs_embeds is None:
364
+ inputs_embeds = self.embeddings(input_ids)
365
+
366
+ if self.gradient_checkpointing and self.training and use_cache:
367
+ use_cache = False
368
+
369
+ if use_cache:
370
+ if cache_params is None:
371
+ cache_params = Mamba2Cache(
372
+ self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
373
+ )
374
+ cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
375
+ elif cache_position is None:
376
+                 # Covers a manual forward pass (not `model.generate`, which initializes `cache_position`
377
+                 # and guarantees it is not None); raise an error here instead of resorting to some hack
378
+                 # to guess the current cache position
379
+ raise ValueError(
380
+ "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
381
+ "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
382
+ "be initialized for you automatically"
383
+ )
384
+ else:
385
+ cache_params = None
386
+
387
+ hidden_states = inputs_embeds
388
+ all_hidden_states = () if output_hidden_states else None
389
+ for mixer_block in self.layers:
390
+ if self.gradient_checkpointing and self.training:
391
+ hidden_states = self._gradient_checkpointing_func(
392
+ mixer_block.__call__,
393
+ hidden_states,
394
+ cache_params,
395
+ cache_position,
396
+ attention_mask,
397
+ )
398
+ else:
399
+ hidden_states = mixer_block(
400
+ hidden_states,
401
+ cache_params=cache_params,
402
+ cache_position=cache_position,
403
+ attention_mask=attention_mask,
404
+ )
405
+
406
+ if output_hidden_states:
407
+ all_hidden_states = all_hidden_states + (hidden_states,)
408
+
409
+ hidden_states = self.norm_f(hidden_states)
410
+
411
+ if output_hidden_states:
412
+ all_hidden_states = all_hidden_states + (hidden_states,)
413
+
414
+ if not return_dict:
415
+ return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
416
+
417
+ return Mamba2Output(
418
+ last_hidden_state=hidden_states,
419
+ cache_params=cache_params if use_cache else None,
420
+ hidden_states=all_hidden_states,
421
+ )
422
+
423
+
424
+ class Mamba2ForCausalLM(Mamba2PreTrainedModel):
425
+ _tied_weights_keys = []
426
+
427
+ def __init__(self, config):
428
+ super().__init__(config)
429
+ self.backbone = Mamba2Model(config)
430
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
431
+ self.criterion = None
432
+
433
+ # Initialize weights and apply final processing
434
+ self.post_init()
435
+
436
+ def get_output_embeddings(self):
437
+ return self.lm_head
438
+
439
+ def set_output_embeddings(self, new_embeddings):
440
+ self.lm_head = new_embeddings
441
+
442
+ def get_input_embeddings(self):
443
+ return self.backbone.get_input_embeddings()
444
+
445
+ def set_input_embeddings(self, new_embeddings):
446
+ return self.backbone.set_input_embeddings(new_embeddings)
447
+
448
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
449
+ def prepare_inputs_for_generation(
450
+ self,
451
+ input_ids,
452
+ inputs_embeds=None,
453
+ use_cache=None,
454
+ cache_params: Optional[Mamba2Cache] = None,
455
+ cache_position: Optional[torch.LongTensor] = None,
456
+ attention_mask: Optional[torch.Tensor] = None,
457
+ logits_to_keep: Optional[int] = None,
458
+ **kwargs,
459
+ ):
460
+ if use_cache:
461
+ # `cache_position` should have been initialized in `generate`
462
+ if cache_position is None:
463
+ raise ValueError(
464
+ "`cache_position` should not be None as it should have been initialized in "
465
+ "`model.generate`, you are responsible for passing in a valid `cache_position` if "
466
+ "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
467
+ )
468
+ if cache_position[0] > 0:
469
+ input_ids = input_ids[:, -1][..., None]
470
+
471
+ if attention_mask is not None:
472
+ attention_mask = None
473
+ else:
474
+ # we initialize the `cache_position` to full size of `conv_states` at prefill stage
475
+ # considering padding will be applied when input length is shorter, and truncation
476
+ # will be applied when it is longer, so it will be equivalent to always have it match
477
+ # the length of `cache_params.conv_states`, which is `config.conv_kernel`
478
+ cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)
479
+
480
+ if inputs_embeds is not None and cache_params is None:
481
+ model_inputs = {"inputs_embeds": inputs_embeds}
482
+ else:
483
+ model_inputs = {"input_ids": input_ids}
484
+
485
+ if logits_to_keep is not None:
486
+ model_inputs['logits_to_keep'] = logits_to_keep
487
+
488
+ model_inputs.update({
489
+ 'attention_mask': attention_mask,
490
+ 'cache_params': cache_params,
491
+ 'use_cache': use_cache,
492
+ 'cache_position': cache_position,
493
+ 'logits_to_keep': logits_to_keep
494
+ })
495
+ return model_inputs
496
+
497
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
498
+ def forward(
499
+ self,
500
+ input_ids: Optional[torch.LongTensor] = None,
501
+ inputs_embeds: Optional[torch.FloatTensor] = None,
502
+ cache_params: Optional[Mamba2Cache] = None,
503
+ labels: Optional[torch.LongTensor] = None,
504
+ output_hidden_states: Optional[bool] = None,
505
+ return_dict: Optional[bool] = None,
506
+ use_cache: Optional[bool] = None,
507
+ cache_position: Optional[torch.Tensor] = None,
508
+ attention_mask: Optional[torch.Tensor] = None,
509
+ logits_to_keep: Optional[int] = 0,
510
+ **kwargs, # for now we need this for generation
511
+ ) -> Union[Tuple, Mamba2CausalLMOutput]:
512
+ r"""
513
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
514
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
515
+             `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
516
+             are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
517
+ """
518
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
519
+
520
+ outputs = self.backbone(
521
+ input_ids,
522
+ cache_params=cache_params,
523
+ inputs_embeds=inputs_embeds,
524
+ output_hidden_states=output_hidden_states,
525
+ return_dict=return_dict,
526
+ use_cache=use_cache,
527
+ cache_position=cache_position,
528
+ attention_mask=attention_mask,
529
+ )
530
+ hidden_states = outputs[0]
531
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
532
+
533
+ loss, logits = None, None
534
+ if not fuse_linear_and_cross_entropy or labels is None:
535
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
536
+ if labels is not None:
537
+ if getattr(self, 'criterion', None) is None:
538
+ if fuse_linear_and_cross_entropy:
539
+ criterion = FusedLinearCrossEntropyLoss()
540
+ elif self.config.fuse_cross_entropy:
541
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
542
+ else:
543
+ criterion = nn.CrossEntropyLoss()
544
+ else:
545
+ criterion = self.criterion
546
+ labels = labels.to(hidden_states.device)
547
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
548
+ if fuse_linear_and_cross_entropy:
549
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
550
+ else:
551
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
552
+
553
+ if not return_dict:
554
+ output = (logits,) + outputs[1:]
555
+ return (loss,) + output if loss is not None else output
556
+
557
+ return Mamba2CausalLMOutput(
558
+ loss=loss,
559
+ logits=logits,
560
+ cache_params=outputs.cache_params,
561
+ hidden_states=outputs.hidden_states,
562
+ )
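
A hedged usage sketch for the classes in this file (tiny, made-up hyperparameters; assumes the `fla` package and its fused kernels are installed and a CUDA device is available for the Mamba2 mixer):

import torch
from fla.models.mamba2.configuration_mamba2 import Mamba2Config
from fla.models.mamba2.modeling_mamba2 import Mamba2ForCausalLM

# fuse_cross_entropy=False falls back to plain nn.CrossEntropyLoss for this sketch
config = Mamba2Config(hidden_size=256, num_hidden_layers=2, vocab_size=1000,
                      head_dim=64, fuse_cross_entropy=False)
model = Mamba2ForCausalLM(config).cuda()

input_ids = torch.randint(0, config.vocab_size, (1, 32), device='cuda')
out = model(input_ids=input_ids, labels=input_ids)  # labels are shifted inside the model
print(out.loss)

# `generate` sets up `cache_params`/`cache_position` through prepare_inputs_for_generation above
generated = model.generate(input_ids, max_new_tokens=8)
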
fla3/models/nsa/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.nsa.configuration_nsa import NSAConfig
6
+ from fla.models.nsa.modeling_nsa import NSAForCausalLM, NSAModel
7
+
8
+ AutoConfig.register(NSAConfig.model_type, NSAConfig)
9
+ AutoModel.register(NSAConfig, NSAModel)
10
+ AutoModelForCausalLM.register(NSAConfig, NSAForCausalLM)
11
+
12
+
13
+ __all__ = [
14
+ 'NSAConfig', 'NSAModel', 'NSAForCausalLM',
15
+ ]
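
Importing this module runs the `AutoConfig`/`AutoModel` registrations above, so the NSA classes become reachable through the Auto factories. A minimal sketch of the intended usage (the `NSAConfig` defaults live in `configuration_nsa.py`, which is not shown in this hunk, so the constructor call is illustrative):

from transformers import AutoModelForCausalLM

import fla.models.nsa  # the import itself registers NSAConfig / NSAModel / NSAForCausalLM
from fla.models.nsa import NSAConfig

config = NSAConfig()                              # defaults defined in configuration_nsa.py
model = AutoModelForCausalLM.from_config(config)  # resolves to NSAForCausalLM via the registration above
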
fla3/models/nsa/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (516 Bytes). View file
 
fla3/models/nsa/__pycache__/configuration_nsa.cpython-310.pyc ADDED
Binary file (2.1 kB). View file
 
fla3/models/nsa/__pycache__/modeling_nsa.cpython-310.pyc ADDED
Binary file (11.3 kB). View file