zaydzuhri commited on
Commit
c93077d
·
verified ·
1 Parent(s): 89d2952

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fla/__pycache__/__init__.cpython-312.pyc +0 -0
  2. fla/layers/__pycache__/forgetting_attn.cpython-312.pyc +0 -0
  3. fla/layers/forgetting_attn.py +109 -0
  4. fla/layers/nsa.py +138 -0
  5. fla/models/__pycache__/__init__.cpython-312.pyc +0 -0
  6. fla/models/abc/__init__.py +13 -0
  7. fla/models/abc/__pycache__/__init__.cpython-312.pyc +0 -0
  8. fla/models/abc/__pycache__/configuration_abc.cpython-312.pyc +0 -0
  9. fla/models/abc/__pycache__/modeling_abc.cpython-312.pyc +0 -0
  10. fla/models/bitnet/__pycache__/__init__.cpython-312.pyc +0 -0
  11. fla/models/bitnet/__pycache__/configuration_bitnet.cpython-312.pyc +0 -0
  12. fla/models/bitnet/__pycache__/modeling_bitnet.cpython-312.pyc +0 -0
  13. fla/models/delta_net/__pycache__/__init__.cpython-312.pyc +0 -0
  14. fla/models/delta_net/__pycache__/configuration_delta_net.cpython-312.pyc +0 -0
  15. fla/models/delta_net/__pycache__/modeling_delta_net.cpython-312.pyc +0 -0
  16. fla/models/delta_net/configuration_delta_net.py +91 -0
  17. fla/models/forgetting_transformer/__pycache__/__init__.cpython-312.pyc +0 -0
  18. fla/models/forgetting_transformer/__pycache__/configuration_forgetting_transformer.cpython-312.pyc +0 -0
  19. fla/models/forgetting_transformer/__pycache__/modeling_forgetting_transformer.cpython-312.pyc +0 -0
  20. fla/models/gated_deltanet/__pycache__/__init__.cpython-312.pyc +0 -0
  21. fla/models/gated_deltanet/modeling_gated_deltanet.py +412 -0
  22. fla/models/gated_deltaproduct/__init__.py +14 -0
  23. fla/models/gla/__pycache__/__init__.cpython-312.pyc +0 -0
  24. fla/models/gla/__pycache__/modeling_gla.cpython-312.pyc +0 -0
  25. fla/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc +0 -0
  26. fla/models/gsa/__pycache__/modeling_gsa.cpython-312.pyc +0 -0
  27. fla/models/gsa/modeling_gsa.py +420 -0
  28. fla/models/hgrn/__init__.py +13 -0
  29. fla/models/hgrn/__pycache__/__init__.cpython-312.pyc +0 -0
  30. fla/models/hgrn/__pycache__/configuration_hgrn.cpython-312.pyc +0 -0
  31. fla/models/hgrn/configuration_hgrn.py +81 -0
  32. fla/models/hgrn/modeling_hgrn.py +420 -0
  33. fla/models/hgrn2/__pycache__/__init__.cpython-312.pyc +0 -0
  34. fla/models/hgrn2/__pycache__/configuration_hgrn2.cpython-312.pyc +0 -0
  35. fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc +0 -0
  36. fla/models/hgrn2/modeling_hgrn2.py +421 -0
  37. fla/models/lightnet/__pycache__/configuration_lightnet.cpython-312.pyc +0 -0
  38. fla/models/lightnet/__pycache__/modeling_lightnet.cpython-312.pyc +0 -0
  39. fla/models/lightnet/configuration_lightnet.py +83 -0
  40. fla/models/lightnet/modeling_lightnet.py +410 -0
  41. fla/models/linear_attn/__init__.py +12 -0
  42. fla/models/linear_attn/__pycache__/__init__.cpython-312.pyc +0 -0
  43. fla/models/linear_attn/__pycache__/configuration_linear_attn.cpython-312.pyc +0 -0
  44. fla/models/linear_attn/__pycache__/modeling_linear_attn.cpython-312.pyc +0 -0
  45. fla/models/mamba/__pycache__/configuration_mamba.cpython-312.pyc +0 -0
  46. fla/models/mamba2/__init__.py +13 -0
  47. fla/models/mamba2/__pycache__/__init__.cpython-312.pyc +0 -0
  48. fla/models/nsa/__init__.py +15 -0
  49. fla/models/nsa/__pycache__/configuration_nsa.cpython-312.pyc +0 -0
  50. fla/models/nsa/__pycache__/modeling_nsa.cpython-312.pyc +0 -0
fla/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.85 kB). View file
 
fla/layers/__pycache__/forgetting_attn.cpython-312.pyc ADDED
Binary file (5.3 kB). View file
 
fla/layers/forgetting_attn.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from einops import rearrange
13
+ from transformers.utils import logging
14
+
15
+ from fla.modules import GroupNorm
16
+ from fla.ops.forgetting_attn.parallel import parallel_forgetting_attn
17
+
18
+ if TYPE_CHECKING:
19
+ from fla.models.utils import Cache
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class ForgettingAttention(nn.Module):
26
+
27
+ def __init__(
28
+ self,
29
+ hidden_size: int = 2048,
30
+ num_heads: int = 32,
31
+ num_kv_heads: Optional[int] = None,
32
+ qkv_bias: bool = False,
33
+ qk_norm: bool = False,
34
+ window_size: Optional[int] = None,
35
+ use_output_gate: bool = False,
36
+ layer_idx: int = None
37
+ ):
38
+ super().__init__()
39
+
40
+ self.hidden_size = hidden_size
41
+ self.num_heads = num_heads
42
+ if num_kv_heads is None:
43
+ self.num_kv_heads = self.num_heads
44
+ else:
45
+ self.num_kv_heads = num_kv_heads
46
+ self.num_kv_groups = num_heads // self.num_kv_heads
47
+ self.head_dim = self.hidden_size // self.num_heads
48
+ self.kv_dim = self.num_kv_heads * self.head_dim
49
+ self.qkv_bias = qkv_bias
50
+ self.qk_norm = qk_norm
51
+
52
+ self.window_size = window_size
53
+ self.use_output_gate = use_output_gate
54
+ self.layer_idx = layer_idx
55
+
56
+ self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
57
+ self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
58
+ self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
59
+ self.f_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
60
+
61
+ if use_output_gate:
62
+ self.g_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
63
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
64
+
65
+ if qk_norm:
66
+ self.q_norm = GroupNorm(
67
+ num_groups=self.num_heads,
68
+ hidden_size=self.hidden_size,
69
+ is_rms_norm=True,
70
+ )
71
+ self.k_norm = GroupNorm(
72
+ num_groups=self.num_kv_heads,
73
+ hidden_size=self.kv_dim,
74
+ is_rms_norm=True,
75
+ )
76
+
77
+ def forward(
78
+ self,
79
+ hidden_states: torch.Tensor,
80
+ attention_mask: Optional[torch.LongTensor] = None,
81
+ past_key_values: Optional[Cache] = None,
82
+ output_attentions: bool = False,
83
+ use_cache: bool = False,
84
+ **kwargs,
85
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
86
+ if attention_mask is not None:
87
+ assert len(attention_mask.shape) == 2, (
88
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
89
+ "for padding purposes (0 indicating padding). "
90
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
91
+ )
92
+
93
+ cu_seqlens = kwargs.get('cu_seqlens', None)
94
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
95
+ f = F.logsigmoid(self.f_proj(hidden_states).float())
96
+ if self.qk_norm:
97
+ q, k = self.q_norm(q), self.k_norm(k)
98
+
99
+ q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
100
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
101
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
102
+
103
+ o = parallel_forgetting_attn(q, k, v, f, cu_seqlens=cu_seqlens)
104
+ o = rearrange(o, '... h d -> ... (h d)')
105
+ if self.use_output_gate:
106
+ o = self.g_proj(hidden_states).sigmoid() * o
107
+ o = self.o_proj(o)
108
+
109
+ return o, None, past_key_values
fla/layers/nsa.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange
11
+ from transformers.utils import logging
12
+
13
+ from fla.modules import RotaryEmbedding
14
+ from fla.ops.nsa.parallel import parallel_nsa
15
+
16
+ if TYPE_CHECKING:
17
+ from fla.models.utils import Cache
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ class NativeSparseAttention(nn.Module):
23
+
24
+ def __init__(
25
+ self,
26
+ hidden_size: int = 2048,
27
+ num_heads: int = 64,
28
+ num_kv_heads: Optional[int] = 4,
29
+ head_dim: int = 64,
30
+ qkv_bias: bool = False,
31
+ block_size: Optional[int] = 64,
32
+ block_counts: Optional[Union[torch.LongTensor, int]] = 16,
33
+ window_size: Optional[int] = 512,
34
+ rope_theta: Optional[float] = 10000.,
35
+ max_position_embeddings: Optional[int] = None,
36
+ layer_idx: int = None
37
+ ):
38
+ super().__init__()
39
+
40
+ self.hidden_size = hidden_size
41
+ self.num_heads = num_heads
42
+ if num_kv_heads is None:
43
+ self.num_kv_heads = self.num_heads
44
+ else:
45
+ self.num_kv_heads = num_kv_heads
46
+ self.num_kv_groups = num_heads // self.num_kv_heads
47
+ self.head_dim = head_dim
48
+ self.kv_dim = self.num_kv_heads * self.head_dim
49
+ self.qkv_bias = qkv_bias
50
+
51
+ self.block_size = block_size
52
+ self.block_counts = block_counts
53
+ self.window_size = window_size
54
+ self.rope_theta = rope_theta
55
+ self.max_position_embeddings = max_position_embeddings
56
+ self.layer_idx = layer_idx
57
+
58
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.qkv_bias)
59
+ self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
60
+ self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
61
+ self.g_proj = nn.Linear(self.hidden_size, self.num_heads * 3, bias=False)
62
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
63
+
64
+ self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
65
+
66
+ def forward(
67
+ self,
68
+ hidden_states: torch.Tensor,
69
+ attention_mask: Optional[torch.LongTensor] = None,
70
+ past_key_values: Optional[Cache] = None,
71
+ output_attentions: bool = False,
72
+ use_cache: bool = False,
73
+ **kwargs,
74
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
75
+ if attention_mask is not None:
76
+ assert len(attention_mask.shape) == 2, (
77
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
78
+ "for padding purposes (0 indicating padding). "
79
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
80
+ )
81
+
82
+ batch_size, seq_len, _ = hidden_states.size()
83
+
84
+ q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
85
+ k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
86
+ v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
87
+ g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=3)
88
+ g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)
89
+
90
+ cu_seqlens = kwargs.get('cu_seqlens', None)
91
+
92
+ seqlen_offset, max_seqlen = 0, seq_len
93
+ if past_key_values is not None:
94
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
95
+ max_seqlen = q.shape[1] + seqlen_offset
96
+
97
+ if attention_mask is not None:
98
+ # to deliminate the offsets of padding tokens
99
+ seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
100
+ max_seqlen = q.shape[1] + max(seqlen_offset)
101
+
102
+ if self.max_position_embeddings is not None:
103
+ max_seqlen = max(max_seqlen, self.max_position_embeddings)
104
+ q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
105
+
106
+ if past_key_values is not None:
107
+ cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
108
+ k_cached, v_cached = past_key_values.update(
109
+ attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
110
+ layer_idx=self.layer_idx,
111
+ offset=seq_len,
112
+ cache_kwargs=dict(window_size=self.window_size)
113
+ )['attn_state']
114
+ if cache_has_content:
115
+ k, v = k_cached, v_cached
116
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
117
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
118
+
119
+ o = parallel_nsa(
120
+ q=q,
121
+ k=k,
122
+ v=v,
123
+ g_cmp=g_cmp,
124
+ g_slc=g_slc,
125
+ g_swa=g_swa,
126
+ block_size=self.block_size,
127
+ block_counts=self.block_counts,
128
+ window_size=self.window_size,
129
+ cu_seqlens=cu_seqlens,
130
+ head_first=False
131
+ )
132
+ o = o.reshape(batch_size, seq_len, -1)
133
+ o = self.o_proj(o)
134
+
135
+ if not output_attentions:
136
+ attentions = None
137
+
138
+ return o, attentions, past_key_values
fla/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (3.07 kB). View file
 
fla/models/abc/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.abc.configuration_abc import ABCConfig
6
+ from fla.models.abc.modeling_abc import ABCForCausalLM, ABCModel
7
+
8
+ AutoConfig.register(ABCConfig.model_type, ABCConfig)
9
+ AutoModel.register(ABCConfig, ABCModel)
10
+ AutoModelForCausalLM.register(ABCConfig, ABCForCausalLM)
11
+
12
+
13
+ __all__ = ['ABCConfig', 'ABCForCausalLM', 'ABCModel']
fla/models/abc/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (657 Bytes). View file
 
fla/models/abc/__pycache__/configuration_abc.cpython-312.pyc ADDED
Binary file (3.61 kB). View file
 
fla/models/abc/__pycache__/modeling_abc.cpython-312.pyc ADDED
Binary file (18.4 kB). View file
 
fla/models/bitnet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (682 Bytes). View file
 
fla/models/bitnet/__pycache__/configuration_bitnet.cpython-312.pyc ADDED
Binary file (2.37 kB). View file
 
fla/models/bitnet/__pycache__/modeling_bitnet.cpython-312.pyc ADDED
Binary file (18.6 kB). View file
 
fla/models/delta_net/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (701 Bytes). View file
 
fla/models/delta_net/__pycache__/configuration_delta_net.cpython-312.pyc ADDED
Binary file (3.59 kB). View file
 
fla/models/delta_net/__pycache__/modeling_delta_net.cpython-312.pyc ADDED
Binary file (18.5 kB). View file
 
fla/models/delta_net/configuration_delta_net.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class DeltaNetConfig(PretrainedConfig):
9
+
10
+ model_type = 'delta_net'
11
+ keys_to_ignore_at_inference = ['past_key_values']
12
+
13
+ def __init__(
14
+ self,
15
+ attn_mode: str = "chunk",
16
+ hidden_size: int = 2048,
17
+ expand_k: int = 1,
18
+ expand_v: int = 1,
19
+ use_gate: bool = False,
20
+ use_short_conv: bool = True,
21
+ conv_size: int = 4,
22
+ use_beta: bool = True,
23
+ use_output_norm: bool = True,
24
+ num_heads: int = 16,
25
+ qk_norm: str = 'l2',
26
+ qk_activation: str = 'silu',
27
+ max_position_embeddings: int = 2048,
28
+ hidden_ratio: Optional[int] = 4,
29
+ intermediate_size: Optional[int] = None,
30
+ hidden_act: str = "swish",
31
+ num_hidden_layers: int = 24,
32
+ norm_eps: float = 1e-6,
33
+ attn: Optional[Dict] = None,
34
+ use_cache: bool = True,
35
+ pad_token_id: int = None,
36
+ bos_token_id: int = 1,
37
+ eos_token_id: int = 2,
38
+ tie_word_embeddings: bool = False,
39
+ initializer_range: float = 0.006,
40
+ fuse_norm: bool = True,
41
+ fuse_swiglu: bool = True,
42
+ fuse_cross_entropy: bool = True,
43
+ vocab_size: int = 32000,
44
+ **kwargs
45
+ ):
46
+ self.attn_mode = attn_mode
47
+ self.hidden_size = hidden_size
48
+ self.expand_k = expand_k
49
+ self.expand_v = expand_v
50
+ self.use_gate = use_gate
51
+ self.use_short_conv = use_short_conv
52
+ self.conv_size = conv_size
53
+ self.use_beta = use_beta
54
+ self.use_output_norm = use_output_norm
55
+ self.num_heads = num_heads
56
+ self.qk_norm = qk_norm
57
+ self.qk_activation = qk_activation
58
+ self.max_position_embeddings = max_position_embeddings
59
+
60
+ self.hidden_ratio = hidden_ratio
61
+ self.intermediate_size = intermediate_size
62
+ self.hidden_act = hidden_act
63
+ self.num_hidden_layers = num_hidden_layers
64
+ self.norm_eps = norm_eps
65
+ self.attn = attn
66
+ self.use_cache = use_cache
67
+ self.initializer_range = initializer_range
68
+ self.fuse_norm = fuse_norm
69
+ self.fuse_swiglu = fuse_swiglu
70
+ self.fuse_cross_entropy = fuse_cross_entropy
71
+ self.vocab_size = vocab_size
72
+
73
+ if attn is not None:
74
+ if not isinstance(attn, Dict):
75
+ raise ValueError("attn must be a dictionary")
76
+ if 'layers' not in attn:
77
+ raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
78
+ if 'num_heads' not in attn:
79
+ raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
80
+ attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
81
+ attn['qkv_bias'] = attn.get('qkv_bias', False)
82
+ attn['window_size'] = attn.get('window_size', None)
83
+ attn['rope_theta'] = attn.get('rope_theta', 10000.)
84
+
85
+ super().__init__(
86
+ pad_token_id=pad_token_id,
87
+ bos_token_id=bos_token_id,
88
+ eos_token_id=eos_token_id,
89
+ tie_word_embeddings=tie_word_embeddings,
90
+ **kwargs,
91
+ )
fla/models/forgetting_transformer/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (817 Bytes). View file
 
fla/models/forgetting_transformer/__pycache__/configuration_forgetting_transformer.cpython-312.pyc ADDED
Binary file (2.5 kB). View file
 
fla/models/forgetting_transformer/__pycache__/modeling_forgetting_transformer.cpython-312.pyc ADDED
Binary file (17.2 kB). View file
 
fla/models/gated_deltanet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (746 Bytes). View file
 
fla/models/gated_deltanet/modeling_gated_deltanet.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.gated_deltanet import GatedDeltaNet
20
+ from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as GatedDeltaNetMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ class GatedDeltaNetBlock(nn.Module):
34
+ def __init__(self, config: GatedDeltaNetConfig, layer_idx: int):
35
+ super().__init__()
36
+
37
+ self.config = config
38
+ self.layer_idx = layer_idx
39
+
40
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
41
+ if config.attn is not None and layer_idx in config.attn['layers']:
42
+ self.attn = Attention(
43
+ hidden_size=config.hidden_size,
44
+ num_heads=config.attn['num_heads'],
45
+ num_kv_heads=config.attn['num_kv_heads'],
46
+ qkv_bias=config.attn['qkv_bias'],
47
+ window_size=config.attn['window_size'],
48
+ rope_theta=config.attn['rope_theta'],
49
+ max_position_embeddings=config.max_position_embeddings,
50
+ layer_idx=layer_idx
51
+ )
52
+ else:
53
+ self.attn = GatedDeltaNet(
54
+ mode=config.attn_mode,
55
+ hidden_size=config.hidden_size,
56
+ expand_v=config.expand_v,
57
+ head_dim=config.head_dim,
58
+ num_heads=config.num_heads,
59
+ use_gate=config.use_gate,
60
+ use_short_conv=config.use_short_conv,
61
+ conv_size=config.conv_size,
62
+ norm_eps=config.norm_eps,
63
+ layer_idx=layer_idx
64
+ )
65
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
66
+ self.mlp = GatedDeltaNetMLP(
67
+ hidden_size=config.hidden_size,
68
+ hidden_ratio=config.hidden_ratio,
69
+ intermediate_size=config.intermediate_size,
70
+ hidden_act=config.hidden_act,
71
+ fuse_swiglu=config.fuse_swiglu
72
+ )
73
+
74
+ def forward(
75
+ self,
76
+ hidden_states: torch.Tensor,
77
+ attention_mask: Optional[torch.Tensor] = None,
78
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
79
+ use_cache: Optional[bool] = False,
80
+ output_attentions: Optional[bool] = False,
81
+ **kwargs: Unpack[Dict]
82
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
83
+ residual = hidden_states
84
+ hidden_states = self.attn_norm(hidden_states)
85
+ hidden_states, attentions, past_key_values = self.attn(
86
+ hidden_states=hidden_states,
87
+ attention_mask=attention_mask,
88
+ past_key_values=past_key_values,
89
+ use_cache=use_cache,
90
+ output_attentions=output_attentions,
91
+ **kwargs
92
+ )
93
+ if self.config.fuse_norm:
94
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
95
+ else:
96
+ hidden_states = residual + hidden_states
97
+ residual = hidden_states
98
+ hidden_states = self.mlp_norm(hidden_states)
99
+ hidden_states = self.mlp(hidden_states, **kwargs)
100
+ hidden_states = residual + hidden_states
101
+
102
+ outputs = (hidden_states, attentions, past_key_values)
103
+
104
+ return outputs
105
+
106
+
107
+ class GatedDeltaNetPreTrainedModel(PreTrainedModel):
108
+
109
+ config_class = GatedDeltaNetConfig
110
+ base_model_prefix = 'model'
111
+ supports_gradient_checkpointing = True
112
+ _no_split_modules = ['GatedDeltaNetBlock']
113
+ _supports_cache_class = True
114
+
115
+ def __init__(self, *inputs, **kwargs):
116
+ super().__init__(*inputs, **kwargs)
117
+
118
+ def _init_weights(
119
+ self,
120
+ module: nn.Module,
121
+ prenorm_residual_strategy: Optional[str] = 'rescale',
122
+ num_residuals_per_layer: int = 2,
123
+ ):
124
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
125
+ # Slightly different from the TF version which uses truncated_normal for initialization
126
+ # cf https://github.com/pytorch/pytorch/pull/5617
127
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
128
+ if module.bias is not None:
129
+ nn.init.zeros_(module.bias)
130
+ elif isinstance(module, nn.Embedding):
131
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
132
+ elif hasattr(module, 'reset_parameters'):
133
+ module.reset_parameters()
134
+
135
+ if prenorm_residual_strategy is not None:
136
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
137
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
138
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
139
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
140
+ #
141
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
142
+ p = None
143
+ if hasattr(module, 'o_proj'):
144
+ p = module.o_proj.weight
145
+ elif hasattr(module, 'down_proj'):
146
+ p = module.down_proj.weight
147
+ if p is not None:
148
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
149
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
150
+ # We need to reinit p since this code could be called multiple times
151
+ # Having just p *= scale would repeatedly scale it down
152
+ if prenorm_residual_strategy == 'rescale':
153
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
154
+ with torch.no_grad():
155
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
156
+ elif prenorm_residual_strategy == 'zero':
157
+ nn.init.zeros_(p)
158
+ else:
159
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
160
+
161
+
162
+ class GatedDeltaNetModel(GatedDeltaNetPreTrainedModel):
163
+
164
+ def __init__(self, config: GatedDeltaNetConfig):
165
+ super().__init__(config)
166
+ self.padding_idx = config.pad_token_id
167
+ self.vocab_size = config.vocab_size
168
+
169
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
170
+ self.layers = nn.ModuleList([GatedDeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
171
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
172
+
173
+ self.gradient_checkpointing = False
174
+
175
+ self.post_init()
176
+
177
+ def get_input_embeddings(self):
178
+ return self.embeddings
179
+
180
+ def set_input_embeddings(self, value):
181
+ self.embeddings = value
182
+
183
+ def forward(
184
+ self,
185
+ input_ids: Optional[torch.LongTensor] = None,
186
+ attention_mask: Optional[torch.Tensor] = None, # noqa
187
+ inputs_embeds: Optional[torch.FloatTensor] = None,
188
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
189
+ use_cache: Optional[bool] = None,
190
+ output_attentions: Optional[bool] = None,
191
+ output_hidden_states: Optional[bool] = None,
192
+ return_dict: Optional[bool] = None,
193
+ **kwargs: Unpack[Dict]
194
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
195
+ if output_attentions:
196
+ warnings.warn("`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.")
197
+ output_attentions = False
198
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
199
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
200
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
201
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
202
+
203
+ # retrieve input_ids and inputs_embeds
204
+ if input_ids is not None and inputs_embeds is not None:
205
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
206
+ if input_ids is None and inputs_embeds is None:
207
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
208
+
209
+ if inputs_embeds is None:
210
+ inputs_embeds = self.embeddings(input_ids)
211
+ hidden_states = inputs_embeds
212
+
213
+ if use_cache and not isinstance(past_key_values, Cache):
214
+ past_key_values = Cache.from_legacy_cache(past_key_values)
215
+
216
+ if self.gradient_checkpointing and self.training and use_cache:
217
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
218
+ use_cache = False
219
+
220
+ all_hidden_states = () if output_hidden_states else None
221
+ all_attns = () if output_attentions else None
222
+ for layer in self.layers:
223
+ if output_hidden_states:
224
+ all_hidden_states += (hidden_states,)
225
+
226
+ if self.gradient_checkpointing and self.training:
227
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
228
+ layer.__call__,
229
+ hidden_states,
230
+ attention_mask,
231
+ past_key_values,
232
+ use_cache,
233
+ output_attentions,
234
+ **kwargs
235
+ )
236
+ else:
237
+ hidden_states, attentions, past_key_values = layer(
238
+ hidden_states,
239
+ attention_mask=attention_mask,
240
+ past_key_values=past_key_values,
241
+ use_cache=use_cache,
242
+ output_attentions=output_attentions,
243
+ **kwargs
244
+ )
245
+
246
+ if output_attentions:
247
+ all_attns += (attentions,)
248
+
249
+ hidden_states = self.norm(hidden_states)
250
+
251
+ # add hidden states from the last decoder layer
252
+ if output_hidden_states:
253
+ all_hidden_states += (hidden_states,)
254
+
255
+ if not return_dict:
256
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
257
+ return BaseModelOutputWithPast(
258
+ last_hidden_state=hidden_states,
259
+ past_key_values=past_key_values,
260
+ hidden_states=all_hidden_states,
261
+ attentions=all_attns
262
+ )
263
+
264
+
265
+ class GatedDeltaNetForCausalLM(GatedDeltaNetPreTrainedModel, GenerationMixin):
266
+
267
+ _tied_weights_keys = ["lm_head.weight"]
268
+
269
+ def __init__(self, config):
270
+ super().__init__(config)
271
+ self.model = GatedDeltaNetModel(config)
272
+ self.vocab_size = config.vocab_size
273
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
274
+ self.criterion = None
275
+
276
+ # Initialize weights and apply final processing
277
+ self.post_init()
278
+
279
+ def get_input_embeddings(self):
280
+ return self.model.embeddings
281
+
282
+ def set_input_embeddings(self, value):
283
+ self.model.embeddings = value
284
+
285
+ def get_output_embeddings(self):
286
+ return self.lm_head
287
+
288
+ def set_output_embeddings(self, new_embeddings):
289
+ self.lm_head = new_embeddings
290
+
291
+ def set_decoder(self, decoder):
292
+ self.model = decoder
293
+
294
+ def get_decoder(self):
295
+ return self.model
296
+
297
+ def generate(self, *args, **kwargs):
298
+ try:
299
+ return super().generate(*args, **kwargs)
300
+ except AttributeError as exception:
301
+ if 'past_key_values' in str(exception):
302
+ raise AttributeError(
303
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
304
+ f"which is not supported for {self.__class__.__name__}. "
305
+ f"Try another generation strategy instead. "
306
+ f"For the available generation strategies, check this doc: "
307
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
308
+ )
309
+ else:
310
+ raise exception
311
+
312
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
313
+ def prepare_inputs_for_generation(
314
+ self,
315
+ input_ids: torch.LongTensor = None,
316
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
317
+ attention_mask: Optional[torch.Tensor] = None,
318
+ inputs_embeds: Optional[torch.Tensor] = None,
319
+ use_cache: bool = True,
320
+ logits_to_keep: Optional[int] = None,
321
+ **kwargs
322
+ ):
323
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
324
+ if past_key_values is not None and len(past_key_values) > 0:
325
+ input_ids = input_ids[:, -1:]
326
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
327
+ if inputs_embeds is not None and len(past_key_values) == 0:
328
+ model_inputs = {'inputs_embeds': inputs_embeds}
329
+ else:
330
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
331
+ # recompiles graphs as the stride of the inputs is a guard.
332
+ # Ref: https://github.com/huggingface/transformers/pull/29114
333
+ # TODO: use `next_tokens` directly instead.
334
+ model_inputs = {'input_ids': input_ids.contiguous()}
335
+
336
+ if logits_to_keep is not None:
337
+ model_inputs['logits_to_keep'] = logits_to_keep
338
+
339
+ model_inputs.update({
340
+ 'past_key_values': past_key_values,
341
+ 'use_cache': use_cache,
342
+ 'attention_mask': attention_mask,
343
+ })
344
+ return model_inputs
345
+
346
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
347
+ def forward(
348
+ self,
349
+ input_ids: torch.LongTensor = None,
350
+ attention_mask: Optional[torch.Tensor] = None,
351
+ inputs_embeds: Optional[torch.Tensor] = None,
352
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
353
+ labels: Optional[torch.LongTensor] = None,
354
+ use_cache: Optional[bool] = None,
355
+ output_attentions: Optional[bool] = None,
356
+ output_hidden_states: Optional[bool] = None,
357
+ return_dict: Optional[bool] = None,
358
+ logits_to_keep: Optional[int] = 0,
359
+ **kwargs: Unpack[Dict]
360
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
361
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
362
+ output_hidden_states = (
363
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
364
+ )
365
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
366
+
367
+ outputs = self.model(
368
+ input_ids=input_ids,
369
+ attention_mask=attention_mask,
370
+ inputs_embeds=inputs_embeds,
371
+ past_key_values=past_key_values,
372
+ use_cache=use_cache,
373
+ output_attentions=output_attentions,
374
+ output_hidden_states=output_hidden_states,
375
+ return_dict=return_dict,
376
+ **kwargs
377
+ )
378
+
379
+ hidden_states = outputs[0]
380
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
381
+
382
+ loss, logits = None, None
383
+ if not fuse_linear_and_cross_entropy or labels is None:
384
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
385
+ if labels is not None:
386
+ if getattr(self, 'criterion', None) is None:
387
+ if fuse_linear_and_cross_entropy:
388
+ criterion = FusedLinearCrossEntropyLoss()
389
+ elif self.config.fuse_cross_entropy:
390
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
391
+ else:
392
+ criterion = nn.CrossEntropyLoss()
393
+ else:
394
+ criterion = self.criterion
395
+ labels = labels.to(hidden_states.device)
396
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
397
+ if fuse_linear_and_cross_entropy:
398
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
399
+ else:
400
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
401
+
402
+ if not return_dict:
403
+ output = (logits,) + outputs[1:]
404
+ return (loss,) + output if loss is not None else output
405
+
406
+ return CausalLMOutputWithPast(
407
+ loss=loss,
408
+ logits=logits,
409
+ past_key_values=outputs.past_key_values,
410
+ hidden_states=outputs.hidden_states,
411
+ attentions=outputs.attentions,
412
+ )
fla/models/gated_deltaproduct/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
2
+
3
+ from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig
4
+ from fla.models.gated_deltaproduct.modeling_gated_deltaproduct import GatedDeltaProductForCausalLM, GatedDeltaProductModel
5
+
6
+ AutoConfig.register(GatedDeltaProductConfig.model_type, GatedDeltaProductConfig)
7
+ AutoModel.register(GatedDeltaProductConfig, GatedDeltaProductModel)
8
+ AutoModelForCausalLM.register(GatedDeltaProductConfig, GatedDeltaProductForCausalLM)
9
+
10
+ __all__ = [
11
+ "GatedDeltaProductConfig",
12
+ "GatedDeltaProductForCausalLM",
13
+ "GatedDeltaProductModel",
14
+ ]
fla/models/gla/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (657 Bytes). View file
 
fla/models/gla/__pycache__/modeling_gla.cpython-312.pyc ADDED
Binary file (18.6 kB). View file
 
fla/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc ADDED
Binary file (3.84 kB). View file
 
fla/models/gsa/__pycache__/modeling_gsa.cpython-312.pyc ADDED
Binary file (18.7 kB). View file
 
fla/models/gsa/modeling_gsa.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.gsa import GatedSlotAttention
20
+ from fla.models.gsa.configuration_gsa import GSAConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as GSAMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class GSABlock(nn.Module):
33
+ def __init__(self, config: GSAConfig, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ rope_theta=config.attn['rope_theta'],
48
+ max_position_embeddings=config.max_position_embeddings,
49
+ layer_idx=layer_idx
50
+ )
51
+ else:
52
+ self.attn = GatedSlotAttention(
53
+ hidden_size=config.hidden_size,
54
+ expand_k=config.expand_k,
55
+ expand_v=config.expand_v,
56
+ num_heads=config.num_heads,
57
+ num_kv_heads=config.num_kv_heads,
58
+ num_slots=config.num_slots,
59
+ use_short_conv=config.use_short_conv,
60
+ conv_size=config.conv_size,
61
+ feature_map=config.feature_map,
62
+ use_output_gate=config.use_output_gate,
63
+ use_norm=config.use_norm,
64
+ gate_fn=config.hidden_act,
65
+ gate_logit_normalizer=config.gate_logit_normalizer,
66
+ elementwise_affine=config.elementwise_affine,
67
+ norm_eps=config.norm_eps,
68
+ fuse_norm=config.fuse_norm,
69
+ layer_idx=layer_idx
70
+ )
71
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
72
+ self.mlp = GSAMLP(
73
+ hidden_size=config.hidden_size,
74
+ hidden_ratio=config.hidden_ratio,
75
+ intermediate_size=config.intermediate_size,
76
+ hidden_act=config.hidden_act,
77
+ fuse_swiglu=config.fuse_swiglu
78
+ )
79
+
80
+ def forward(
81
+ self,
82
+ hidden_states: torch.Tensor,
83
+ attention_mask: Optional[torch.Tensor] = None,
84
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
85
+ use_cache: Optional[bool] = False,
86
+ output_attentions: Optional[bool] = False,
87
+ **kwargs: Unpack[Dict]
88
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
89
+ residual = hidden_states
90
+ hidden_states = self.attn_norm(hidden_states)
91
+ hidden_states, attentions, past_key_values = self.attn(
92
+ hidden_states=hidden_states,
93
+ attention_mask=attention_mask,
94
+ past_key_values=past_key_values,
95
+ use_cache=use_cache,
96
+ output_attentions=output_attentions,
97
+ **kwargs
98
+ )
99
+ if self.config.fuse_norm:
100
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
101
+ else:
102
+ hidden_states = residual + hidden_states
103
+ residual = hidden_states
104
+ hidden_states = self.mlp_norm(hidden_states)
105
+ hidden_states = self.mlp(hidden_states, **kwargs)
106
+ hidden_states = residual + hidden_states
107
+
108
+ outputs = (hidden_states, attentions, past_key_values)
109
+
110
+ return outputs
111
+
112
+
113
+ class GSAPreTrainedModel(PreTrainedModel):
114
+
115
+ config_class = GSAConfig
116
+ base_model_prefix = 'model'
117
+ supports_gradient_checkpointing = True
118
+ _no_split_modules = ['GSABlock']
119
+ _supports_cache_class = True
120
+
121
+ def __init__(self, *inputs, **kwargs):
122
+ super().__init__(*inputs, **kwargs)
123
+
124
+ def _init_weights(
125
+ self,
126
+ module: nn.Module,
127
+ prenorm_residual_strategy: Optional[str] = 'rescale',
128
+ num_residuals_per_layer: int = 2,
129
+ ):
130
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
131
+ # Slightly different from the TF version which uses truncated_normal for initialization
132
+ # cf https://github.com/pytorch/pytorch/pull/5617
133
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
134
+ if module.bias is not None:
135
+ nn.init.zeros_(module.bias)
136
+ elif isinstance(module, nn.Embedding):
137
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
138
+ elif hasattr(module, 'reset_parameters'):
139
+ module.reset_parameters()
140
+
141
+ if prenorm_residual_strategy is not None:
142
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
143
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
144
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
145
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
146
+ #
147
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
148
+ p = None
149
+ if hasattr(module, 'o_proj'):
150
+ p = module.o_proj.weight
151
+ elif hasattr(module, 'down_proj'):
152
+ p = module.down_proj.weight
153
+ if p is not None:
154
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
155
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
156
+ # We need to reinit p since this code could be called multiple times
157
+ # Having just p *= scale would repeatedly scale it down
158
+ if prenorm_residual_strategy == 'rescale':
159
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
160
+ with torch.no_grad():
161
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
162
+ elif prenorm_residual_strategy == 'zero':
163
+ nn.init.zeros_(p)
164
+ else:
165
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
166
+
167
+
168
+ class GSAModel(GSAPreTrainedModel):
169
+
170
+ def __init__(self, config: GSAConfig):
171
+ super().__init__(config)
172
+ self.padding_idx = config.pad_token_id
173
+ self.vocab_size = config.vocab_size
174
+
175
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
176
+ self.layers = nn.ModuleList([GSABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
177
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
178
+
179
+ self.gradient_checkpointing = False
180
+
181
+ self.post_init()
182
+
183
+ def get_input_embeddings(self):
184
+ return self.embeddings
185
+
186
+ def set_input_embeddings(self, value):
187
+ self.embeddings = value
188
+
189
+ def forward(
190
+ self,
191
+ input_ids: Optional[torch.LongTensor] = None,
192
+ attention_mask: Optional[torch.Tensor] = None, # noqa
193
+ inputs_embeds: Optional[torch.FloatTensor] = None,
194
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
195
+ use_cache: Optional[bool] = None,
196
+ output_attentions: Optional[bool] = None,
197
+ output_hidden_states: Optional[bool] = None,
198
+ return_dict: Optional[bool] = None,
199
+ **kwargs: Unpack[Dict]
200
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
201
+ if output_attentions:
202
+ warnings.warn("`GSAModel` does not `output_attentions` now, setting it to `False`.")
203
+ output_attentions = False
204
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
205
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
206
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
207
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
208
+
209
+ # retrieve input_ids and inputs_embeds
210
+ if input_ids is not None and inputs_embeds is not None:
211
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
212
+ if input_ids is None and inputs_embeds is None:
213
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
214
+
215
+ if inputs_embeds is None:
216
+ inputs_embeds = self.embeddings(input_ids)
217
+ hidden_states = inputs_embeds
218
+
219
+ if use_cache and not isinstance(past_key_values, Cache):
220
+ past_key_values = Cache.from_legacy_cache(past_key_values)
221
+
222
+ if self.gradient_checkpointing and self.training and use_cache:
223
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
224
+ use_cache = False
225
+
226
+ all_hidden_states = () if output_hidden_states else None
227
+ all_attns = () if output_attentions else None
228
+ for layer in self.layers:
229
+ if output_hidden_states:
230
+ all_hidden_states += (hidden_states,)
231
+
232
+ if self.gradient_checkpointing and self.training:
233
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
234
+ layer.__call__,
235
+ hidden_states,
236
+ attention_mask,
237
+ past_key_values,
238
+ use_cache,
239
+ output_attentions,
240
+ **kwargs
241
+ )
242
+ else:
243
+ hidden_states, attentions, past_key_values = layer(
244
+ hidden_states,
245
+ attention_mask=attention_mask,
246
+ past_key_values=past_key_values,
247
+ use_cache=use_cache,
248
+ output_attentions=output_attentions,
249
+ **kwargs
250
+ )
251
+
252
+ if output_attentions:
253
+ all_attns += (attentions,)
254
+
255
+ hidden_states = self.norm(hidden_states)
256
+
257
+ # add hidden states from the last decoder layer
258
+ if output_hidden_states:
259
+ all_hidden_states += (hidden_states,)
260
+
261
+ if not return_dict:
262
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
263
+ return BaseModelOutputWithPast(
264
+ last_hidden_state=hidden_states,
265
+ past_key_values=past_key_values,
266
+ hidden_states=all_hidden_states,
267
+ attentions=all_attns
268
+ )
269
+
270
+
271
+ class GSAForCausalLM(GSAPreTrainedModel, GenerationMixin):
272
+
273
+ _tied_weights_keys = ["lm_head.weight"]
274
+
275
+ def __init__(self, config):
276
+
277
+ super().__init__(config)
278
+ self.model = GSAModel(config)
279
+ self.vocab_size = config.vocab_size
280
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
281
+ self.criterion = None
282
+
283
+ # Initialize weights and apply final processing
284
+ self.post_init()
285
+
286
+ def get_input_embeddings(self):
287
+ return self.model.embeddings
288
+
289
+ def set_input_embeddings(self, value):
290
+ self.model.embeddings = value
291
+
292
+ def get_output_embeddings(self):
293
+ return self.lm_head
294
+
295
+ def set_output_embeddings(self, new_embeddings):
296
+ self.lm_head = new_embeddings
297
+
298
+ def set_decoder(self, decoder):
299
+ self.model = decoder
300
+
301
+ def get_decoder(self):
302
+ return self.model
303
+
304
+ def generate(self, *args, **kwargs):
305
+ try:
306
+ return super().generate(*args, **kwargs)
307
+ except AttributeError as exception:
308
+ if 'past_key_values' in str(exception):
309
+ raise AttributeError(
310
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
311
+ f"which is not supported for {self.__class__.__name__}. "
312
+ f"Try another generation strategy instead. "
313
+ f"For the available generation strategies, check this doc: "
314
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
315
+ )
316
+ else:
317
+ raise exception
318
+
319
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
320
+ def prepare_inputs_for_generation(
321
+ self,
322
+ input_ids: torch.LongTensor = None,
323
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
324
+ attention_mask: Optional[torch.Tensor] = None,
325
+ inputs_embeds: Optional[torch.Tensor] = None,
326
+ use_cache: bool = True,
327
+ logits_to_keep: Optional[int] = None,
328
+ **kwargs
329
+ ):
330
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
331
+ if past_key_values is not None and len(past_key_values) > 0:
332
+ input_ids = input_ids[:, -1:]
333
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
334
+ if inputs_embeds is not None and len(past_key_values) == 0:
335
+ model_inputs = {'inputs_embeds': inputs_embeds}
336
+ else:
337
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
338
+ # recompiles graphs as the stride of the inputs is a guard.
339
+ # Ref: https://github.com/huggingface/transformers/pull/29114
340
+ # TODO: use `next_tokens` directly instead.
341
+ model_inputs = {'input_ids': input_ids.contiguous()}
342
+
343
+ if logits_to_keep is not None:
344
+ model_inputs['logits_to_keep'] = logits_to_keep
345
+
346
+ model_inputs.update({
347
+ 'past_key_values': past_key_values,
348
+ 'use_cache': use_cache,
349
+ 'attention_mask': attention_mask,
350
+ })
351
+ return model_inputs
352
+
353
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
354
+ def forward(
355
+ self,
356
+ input_ids: torch.LongTensor = None,
357
+ attention_mask: Optional[torch.Tensor] = None,
358
+ inputs_embeds: Optional[torch.Tensor] = None,
359
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
360
+ labels: Optional[torch.LongTensor] = None,
361
+ use_cache: Optional[bool] = None,
362
+ output_attentions: Optional[bool] = None,
363
+ output_hidden_states: Optional[bool] = None,
364
+ return_dict: Optional[bool] = None,
365
+ logits_to_keep: Optional[int] = 0,
366
+ **kwargs: Unpack[Dict]
367
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
368
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
369
+ output_hidden_states = (
370
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
371
+ )
372
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
373
+
374
+ outputs = self.model(
375
+ input_ids=input_ids,
376
+ attention_mask=attention_mask,
377
+ inputs_embeds=inputs_embeds,
378
+ past_key_values=past_key_values,
379
+ use_cache=use_cache,
380
+ output_attentions=output_attentions,
381
+ output_hidden_states=output_hidden_states,
382
+ return_dict=return_dict,
383
+ **kwargs
384
+ )
385
+
386
+ hidden_states = outputs[0]
387
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
388
+
389
+ loss, logits = None, None
390
+ if not fuse_linear_and_cross_entropy or labels is None:
391
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
392
+ if labels is not None:
393
+ if getattr(self, 'criterion', None) is None:
394
+ if fuse_linear_and_cross_entropy:
395
+ criterion = FusedLinearCrossEntropyLoss()
396
+ elif self.config.fuse_cross_entropy:
397
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
398
+ else:
399
+ criterion = nn.CrossEntropyLoss()
400
+ else:
401
+ criterion = self.criterion
402
+ # Enable model parallelism
403
+ labels = labels.to(hidden_states.device)
404
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
405
+ if fuse_linear_and_cross_entropy:
406
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
407
+ else:
408
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
409
+
410
+ if not return_dict:
411
+ output = (logits,) + outputs[1:]
412
+ return (loss,) + output if loss is not None else output
413
+
414
+ return CausalLMOutputWithPast(
415
+ loss=loss,
416
+ logits=logits,
417
+ past_key_values=outputs.past_key_values,
418
+ hidden_states=outputs.hidden_states,
419
+ attentions=outputs.attentions,
420
+ )
fla/models/hgrn/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.hgrn.configuration_hgrn import HGRNConfig
6
+ from fla.models.hgrn.modeling_hgrn import HGRNForCausalLM, HGRNModel
7
+
8
+ AutoConfig.register(HGRNConfig.model_type, HGRNConfig)
9
+ AutoModel.register(HGRNConfig, HGRNModel)
10
+ AutoModelForCausalLM.register(HGRNConfig, HGRNForCausalLM)
11
+
12
+
13
+ __all__ = ['HGRNConfig', 'HGRNForCausalLM', 'HGRNModel']
fla/models/hgrn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (665 Bytes). View file
 
fla/models/hgrn/__pycache__/configuration_hgrn.cpython-312.pyc ADDED
Binary file (3.28 kB). View file
 
fla/models/hgrn/configuration_hgrn.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class HGRNConfig(PretrainedConfig):
9
+
10
+ model_type = 'hgrn'
11
+ keys_to_ignore_at_inference = ['past_key_values']
12
+
13
+ def __init__(
14
+ self,
15
+ attn_mode: str = "fused_recurrent",
16
+ hidden_size: int = 2048,
17
+ num_hidden_layers: int = 24,
18
+ expand_ratio: Optional[int] = 1,
19
+ use_short_conv: bool = False,
20
+ conv_size: int = 4,
21
+ use_lower_bound: bool = True,
22
+ max_position_embeddings: int = 2048,
23
+ hidden_ratio: Optional[int] = 4,
24
+ intermediate_size: Optional[int] = None,
25
+ hidden_act: str = "swish",
26
+ elementwise_affine: Optional[bool] = True,
27
+ norm_eps: float = 1e-6,
28
+ attn: Optional[Dict] = None,
29
+ use_cache: bool = True,
30
+ pad_token_id: int = None,
31
+ bos_token_id: int = 1,
32
+ eos_token_id: int = 2,
33
+ tie_word_embeddings: bool = False,
34
+ initializer_range: float = 0.006,
35
+ fuse_norm: bool = True,
36
+ fuse_swiglu: bool = True,
37
+ fuse_cross_entropy: bool = True,
38
+ vocab_size: int = 32000,
39
+ **kwargs
40
+ ):
41
+ self.attn_mode = attn_mode
42
+ self.hidden_size = hidden_size
43
+ self.num_hidden_layers = num_hidden_layers
44
+ self.expand_ratio = expand_ratio
45
+ self.use_short_conv = use_short_conv
46
+ self.conv_size = conv_size
47
+ self.use_lower_bound = use_lower_bound
48
+ self.max_position_embeddings = max_position_embeddings
49
+ self.hidden_ratio = hidden_ratio
50
+ self.intermediate_size = intermediate_size
51
+ self.elementwise_affine = elementwise_affine
52
+ self.attn = attn
53
+ self.norm_eps = norm_eps
54
+ self.hidden_act = hidden_act
55
+ self.use_cache = use_cache
56
+ self.initializer_range = initializer_range
57
+
58
+ self.fuse_norm = fuse_norm
59
+ self.fuse_swiglu = fuse_swiglu
60
+ self.fuse_cross_entropy = fuse_cross_entropy
61
+ self.vocab_size = vocab_size
62
+
63
+ if attn is not None:
64
+ if not isinstance(attn, Dict):
65
+ raise ValueError("attn must be a dictionary")
66
+ if 'layers' not in attn:
67
+ raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
68
+ if 'num_heads' not in attn:
69
+ raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
70
+ attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
71
+ attn['qkv_bias'] = attn.get('qkv_bias', False)
72
+ attn['window_size'] = attn.get('window_size', None)
73
+ attn['rope_theta'] = attn.get('rope_theta', 10000.)
74
+
75
+ super().__init__(
76
+ pad_token_id=pad_token_id,
77
+ bos_token_id=bos_token_id,
78
+ eos_token_id=eos_token_id,
79
+ tie_word_embeddings=tie_word_embeddings,
80
+ **kwargs,
81
+ )
fla/models/hgrn/modeling_hgrn.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.hgrn import HGRNAttention
20
+ from fla.models.hgrn.configuration_hgrn import HGRNConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as HGRNMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class HGRNBlock(nn.Module):
33
+ def __init__(self, config: HGRNConfig, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ rope_theta=config.attn['rope_theta'],
48
+ max_position_embeddings=config.max_position_embeddings,
49
+ layer_idx=layer_idx
50
+ )
51
+ else:
52
+ self.attn = HGRNAttention(
53
+ mode=config.attn_mode,
54
+ hidden_size=config.hidden_size,
55
+ expand_ratio=config.expand_ratio,
56
+ use_short_conv=config.use_short_conv,
57
+ conv_size=config.conv_size,
58
+ elementwise_affine=config.elementwise_affine,
59
+ norm_eps=config.norm_eps,
60
+ layer_idx=layer_idx
61
+ )
62
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
63
+ self.mlp = HGRNMLP(
64
+ hidden_size=config.hidden_size,
65
+ hidden_ratio=config.hidden_ratio,
66
+ intermediate_size=config.intermediate_size,
67
+ hidden_act=config.hidden_act,
68
+ fuse_swiglu=config.fuse_swiglu
69
+ )
70
+
71
+ def forward(
72
+ self,
73
+ hidden_states: torch.Tensor,
74
+ attention_mask: Optional[torch.Tensor] = None,
75
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
76
+ use_cache: Optional[bool] = False,
77
+ output_attentions: Optional[bool] = False,
78
+ lower_bound: Optional[torch.Tensor] = False,
79
+ **kwargs: Unpack[Dict]
80
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
81
+ residual = hidden_states
82
+ hidden_states = self.attn_norm(hidden_states)
83
+ hidden_states, attentions, past_key_values = self.attn(
84
+ hidden_states=hidden_states,
85
+ attention_mask=attention_mask,
86
+ past_key_values=past_key_values,
87
+ use_cache=use_cache,
88
+ output_attentions=output_attentions,
89
+ lower_bound=lower_bound,
90
+ **kwargs
91
+ )
92
+ if self.config.fuse_norm:
93
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
94
+ else:
95
+ hidden_states = residual + hidden_states
96
+ residual = hidden_states
97
+ hidden_states = self.mlp_norm(hidden_states)
98
+ hidden_states = self.mlp(hidden_states, **kwargs)
99
+ hidden_states = residual + hidden_states
100
+
101
+ outputs = (hidden_states, attentions, past_key_values)
102
+
103
+ return outputs
104
+
105
+
106
+ class HGRNPreTrainedModel(PreTrainedModel):
107
+
108
+ config_class = HGRNConfig
109
+ base_model_prefix = 'model'
110
+ supports_gradient_checkpointing = True
111
+ _no_split_modules = ['HGRNBlock']
112
+ _supports_cache_class = True
113
+
114
+ def __init__(self, *inputs, **kwargs):
115
+ super().__init__(*inputs, **kwargs)
116
+
117
+ def _init_weights(
118
+ self,
119
+ module: nn.Module,
120
+ prenorm_residual_strategy: Optional[str] = 'rescale',
121
+ num_residuals_per_layer: int = 2,
122
+ ):
123
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
124
+ # Slightly different from the TF version which uses truncated_normal for initialization
125
+ # cf https://github.com/pytorch/pytorch/pull/5617
126
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
127
+ if module.bias is not None:
128
+ nn.init.zeros_(module.bias)
129
+ elif isinstance(module, nn.Embedding):
130
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
131
+ elif hasattr(module, 'reset_parameters'):
132
+ module.reset_parameters()
133
+
134
+ if prenorm_residual_strategy is not None:
135
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
136
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
137
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
138
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
139
+ #
140
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
141
+ p = None
142
+ if hasattr(module, 'o_proj'):
143
+ p = module.o_proj.weight
144
+ elif hasattr(module, 'down_proj'):
145
+ p = module.down_proj.weight
146
+ if p is not None:
147
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
148
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
149
+ # We need to reinit p since this code could be called multiple times
150
+ # Having just p *= scale would repeatedly scale it down
151
+ if prenorm_residual_strategy == 'rescale':
152
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
153
+ with torch.no_grad():
154
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
155
+ elif prenorm_residual_strategy == 'zero':
156
+ nn.init.zeros_(p)
157
+ else:
158
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
159
+
160
+
161
+ class HGRNModel(HGRNPreTrainedModel):
162
+
163
+ def __init__(self, config: HGRNConfig):
164
+ super().__init__(config)
165
+ self.padding_idx = config.pad_token_id
166
+ self.vocab_size = config.vocab_size
167
+
168
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
169
+ if config.use_lower_bound:
170
+ self.lower_bounds = nn.Parameter(torch.zeros(config.num_hidden_layers, config.hidden_size))
171
+ self.layers = nn.ModuleList([HGRNBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
172
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
173
+
174
+ self.gradient_checkpointing = False
175
+
176
+ self.post_init()
177
+
178
+ def get_input_embeddings(self):
179
+ return self.embeddings
180
+
181
+ def set_input_embeddings(self, value):
182
+ self.embeddings = value
183
+
184
+ def forward(
185
+ self,
186
+ input_ids: Optional[torch.LongTensor] = None,
187
+ attention_mask: Optional[torch.Tensor] = None, # noqa
188
+ inputs_embeds: Optional[torch.FloatTensor] = None,
189
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
190
+ use_cache: Optional[bool] = None,
191
+ output_attentions: Optional[bool] = None,
192
+ output_hidden_states: Optional[bool] = None,
193
+ return_dict: Optional[bool] = None,
194
+ **kwargs: Unpack[Dict]
195
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
196
+ if output_attentions:
197
+ warnings.warn("`HGRNModel` does not `output_attentions` now, setting it to `False`.")
198
+ output_attentions = False
199
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
200
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
201
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
202
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
203
+
204
+ # retrieve input_ids and inputs_embeds
205
+ if input_ids is not None and inputs_embeds is not None:
206
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
207
+ if input_ids is None and inputs_embeds is None:
208
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
209
+
210
+ if inputs_embeds is None:
211
+ inputs_embeds = self.embeddings(input_ids)
212
+ hidden_states = inputs_embeds
213
+
214
+ if use_cache and not isinstance(past_key_values, Cache):
215
+ past_key_values = Cache.from_legacy_cache(past_key_values)
216
+
217
+ if self.gradient_checkpointing and self.training and use_cache:
218
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
219
+ use_cache = False
220
+
221
+ all_hidden_states = () if output_hidden_states else None
222
+ all_attns = () if output_attentions else None
223
+
224
+ if self.config.use_lower_bound:
225
+ lower_bounds = self.lower_bounds.softmax(0)
226
+ lower_bounds = lower_bounds.cumsum(0) - lower_bounds[0]
227
+ for i, layer in enumerate(self.layers):
228
+ if output_hidden_states:
229
+ all_hidden_states += (hidden_states,)
230
+
231
+ lower_bound = lower_bounds[i] if self.config.use_lower_bound else None
232
+ if self.gradient_checkpointing and self.training:
233
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
234
+ layer.__call__,
235
+ hidden_states,
236
+ attention_mask,
237
+ past_key_values,
238
+ use_cache,
239
+ output_attentions,
240
+ lower_bound,
241
+ **kwargs
242
+ )
243
+ else:
244
+ hidden_states, attentions, past_key_values = layer(
245
+ hidden_states,
246
+ attention_mask=attention_mask,
247
+ past_key_values=past_key_values,
248
+ use_cache=use_cache,
249
+ output_attentions=output_attentions,
250
+ lower_bound=lower_bound,
251
+ **kwargs
252
+ )
253
+
254
+ if output_attentions:
255
+ all_attns += (attentions,)
256
+
257
+ hidden_states = self.norm(hidden_states)
258
+
259
+ # add hidden states from the last decoder layer
260
+ if output_hidden_states:
261
+ all_hidden_states += (hidden_states,)
262
+
263
+ if not return_dict:
264
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
265
+ return BaseModelOutputWithPast(
266
+ last_hidden_state=hidden_states,
267
+ past_key_values=past_key_values,
268
+ hidden_states=all_hidden_states,
269
+ attentions=all_attns
270
+ )
271
+
272
+
273
+ class HGRNForCausalLM(HGRNPreTrainedModel, GenerationMixin):
274
+
275
+ _tied_weights_keys = ["lm_head.weight"]
276
+
277
+ def __init__(self, config):
278
+ super().__init__(config)
279
+ self.model = HGRNModel(config)
280
+ self.vocab_size = config.vocab_size
281
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
282
+ self.criterion = None
283
+
284
+ # Initialize weights and apply final processing
285
+ self.post_init()
286
+
287
+ def get_input_embeddings(self):
288
+ return self.model.embeddings
289
+
290
+ def set_input_embeddings(self, value):
291
+ self.model.embeddings = value
292
+
293
+ def get_output_embeddings(self):
294
+ return self.lm_head
295
+
296
+ def set_output_embeddings(self, new_embeddings):
297
+ self.lm_head = new_embeddings
298
+
299
+ def set_decoder(self, decoder):
300
+ self.model = decoder
301
+
302
+ def get_decoder(self):
303
+ return self.model
304
+
305
+ def generate(self, *args, **kwargs):
306
+ try:
307
+ return super().generate(*args, **kwargs)
308
+ except AttributeError as exception:
309
+ if 'past_key_values' in str(exception):
310
+ raise AttributeError(
311
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
312
+ f"which is not supported for {self.__class__.__name__}. "
313
+ f"Try another generation strategy instead. "
314
+ f"For the available generation strategies, check this doc: "
315
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
316
+ )
317
+ else:
318
+ raise exception
319
+
320
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
321
+ def prepare_inputs_for_generation(
322
+ self,
323
+ input_ids: torch.LongTensor = None,
324
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
325
+ attention_mask: Optional[torch.Tensor] = None,
326
+ inputs_embeds: Optional[torch.Tensor] = None,
327
+ use_cache: bool = True,
328
+ logits_to_keep: Optional[int] = None,
329
+ **kwargs: Unpack[Dict]
330
+ ):
331
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
332
+ if past_key_values is not None and len(past_key_values) > 0:
333
+ input_ids = input_ids[:, -1:]
334
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
335
+ if inputs_embeds is not None and len(past_key_values) == 0:
336
+ model_inputs = {'inputs_embeds': inputs_embeds}
337
+ else:
338
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
339
+ # recompiles graphs as the stride of the inputs is a guard.
340
+ # Ref: https://github.com/huggingface/transformers/pull/29114
341
+ # TODO: use `next_tokens` directly instead.
342
+ model_inputs = {'input_ids': input_ids.contiguous()}
343
+
344
+ if logits_to_keep is not None:
345
+ model_inputs['logits_to_keep'] = logits_to_keep
346
+
347
+ model_inputs.update({
348
+ 'past_key_values': past_key_values,
349
+ 'use_cache': use_cache,
350
+ 'attention_mask': attention_mask,
351
+ })
352
+ return model_inputs
353
+
354
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
355
+ def forward(
356
+ self,
357
+ input_ids: torch.LongTensor = None,
358
+ attention_mask: Optional[torch.Tensor] = None,
359
+ inputs_embeds: Optional[torch.Tensor] = None,
360
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
361
+ labels: Optional[torch.LongTensor] = None,
362
+ use_cache: Optional[bool] = None,
363
+ output_attentions: Optional[bool] = None,
364
+ output_hidden_states: Optional[bool] = None,
365
+ return_dict: Optional[bool] = None,
366
+ logits_to_keep: Optional[int] = 0,
367
+ **kwargs: Unpack[Dict]
368
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
369
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
370
+ output_hidden_states = (
371
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
372
+ )
373
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
374
+
375
+ outputs = self.model(
376
+ input_ids=input_ids,
377
+ attention_mask=attention_mask,
378
+ inputs_embeds=inputs_embeds,
379
+ past_key_values=past_key_values,
380
+ use_cache=use_cache,
381
+ output_attentions=output_attentions,
382
+ output_hidden_states=output_hidden_states,
383
+ return_dict=return_dict,
384
+ **kwargs
385
+ )
386
+
387
+ hidden_states = outputs[0]
388
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
389
+
390
+ loss, logits = None, None
391
+ if not fuse_linear_and_cross_entropy or labels is None:
392
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
393
+ if labels is not None:
394
+ if getattr(self, 'criterion', None) is None:
395
+ if fuse_linear_and_cross_entropy:
396
+ criterion = FusedLinearCrossEntropyLoss()
397
+ elif self.config.fuse_cross_entropy:
398
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
399
+ else:
400
+ criterion = nn.CrossEntropyLoss()
401
+ else:
402
+ criterion = self.criterion
403
+ labels = labels.to(hidden_states.device)
404
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
405
+ if fuse_linear_and_cross_entropy:
406
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
407
+ else:
408
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
409
+
410
+ if not return_dict:
411
+ output = (logits,) + outputs[1:]
412
+ return (loss,) + output if loss is not None else output
413
+
414
+ return CausalLMOutputWithPast(
415
+ loss=loss,
416
+ logits=logits,
417
+ past_key_values=outputs.past_key_values,
418
+ hidden_states=outputs.hidden_states,
419
+ attentions=outputs.attentions,
420
+ )
fla/models/hgrn2/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (674 Bytes). View file
 
fla/models/hgrn2/__pycache__/configuration_hgrn2.cpython-312.pyc ADDED
Binary file (3.55 kB). View file
 
fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc ADDED
Binary file (18.9 kB). View file
 
fla/models/hgrn2/modeling_hgrn2.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.hgrn2 import HGRN2Attention
20
+ from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as HGRN2MLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class HGRN2Block(nn.Module):
33
+ def __init__(self, config: HGRN2Config, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ rope_theta=config.attn['rope_theta'],
48
+ max_position_embeddings=config.max_position_embeddings,
49
+ layer_idx=layer_idx
50
+ )
51
+ else:
52
+ self.attn = HGRN2Attention(
53
+ mode=config.attn_mode,
54
+ hidden_size=config.hidden_size,
55
+ num_heads=config.num_heads,
56
+ expand_ratio=config.expand_ratio,
57
+ use_short_conv=config.use_short_conv,
58
+ conv_size=config.conv_size,
59
+ elementwise_affine=config.elementwise_affine,
60
+ norm_eps=config.norm_eps,
61
+ layer_idx=layer_idx
62
+ )
63
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
64
+ self.mlp = HGRN2MLP(
65
+ hidden_size=config.hidden_size,
66
+ hidden_ratio=config.hidden_ratio,
67
+ intermediate_size=config.intermediate_size,
68
+ hidden_act=config.hidden_act,
69
+ fuse_swiglu=config.fuse_swiglu
70
+ )
71
+
72
+ def forward(
73
+ self,
74
+ hidden_states: torch.Tensor,
75
+ attention_mask: Optional[torch.Tensor] = None,
76
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
77
+ use_cache: Optional[bool] = False,
78
+ output_attentions: Optional[bool] = False,
79
+ lower_bound: Optional[torch.Tensor] = False,
80
+ **kwargs: Unpack[Dict]
81
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
82
+ residual = hidden_states
83
+ hidden_states = self.attn_norm(hidden_states)
84
+ hidden_states, attentions, past_key_values = self.attn(
85
+ hidden_states=hidden_states,
86
+ attention_mask=attention_mask,
87
+ past_key_values=past_key_values,
88
+ use_cache=use_cache,
89
+ output_attentions=output_attentions,
90
+ lower_bound=lower_bound,
91
+ **kwargs
92
+ )
93
+ if self.config.fuse_norm:
94
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
95
+ else:
96
+ hidden_states = residual + hidden_states
97
+ residual = hidden_states
98
+ hidden_states = self.mlp_norm(hidden_states)
99
+ hidden_states = self.mlp(hidden_states, **kwargs)
100
+ hidden_states = residual + hidden_states
101
+
102
+ outputs = (hidden_states, attentions, past_key_values)
103
+
104
+ return outputs
105
+
106
+
107
+ class HGRN2PreTrainedModel(PreTrainedModel):
108
+
109
+ config_class = HGRN2Config
110
+ base_model_prefix = 'model'
111
+ supports_gradient_checkpointing = True
112
+ _no_split_modules = ['HGRN2Block']
113
+ _supports_cache_class = True
114
+
115
+ def __init__(self, *inputs, **kwargs):
116
+ super().__init__(*inputs, **kwargs)
117
+
118
+ def _init_weights(
119
+ self,
120
+ module: nn.Module,
121
+ prenorm_residual_strategy: Optional[str] = 'rescale',
122
+ num_residuals_per_layer: int = 2,
123
+ ):
124
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
125
+ # Slightly different from the TF version which uses truncated_normal for initialization
126
+ # cf https://github.com/pytorch/pytorch/pull/5617
127
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
128
+ if module.bias is not None:
129
+ nn.init.zeros_(module.bias)
130
+ elif isinstance(module, nn.Embedding):
131
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
132
+ elif hasattr(module, 'reset_parameters'):
133
+ module.reset_parameters()
134
+
135
+ if prenorm_residual_strategy is not None:
136
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
137
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
138
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
139
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
140
+ #
141
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
142
+ p = None
143
+ if hasattr(module, 'o_proj'):
144
+ p = module.o_proj.weight
145
+ elif hasattr(module, 'down_proj'):
146
+ p = module.down_proj.weight
147
+ if p is not None:
148
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
149
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
150
+ # We need to reinit p since this code could be called multiple times
151
+ # Having just p *= scale would repeatedly scale it down
152
+ if prenorm_residual_strategy == 'rescale':
153
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
154
+ with torch.no_grad():
155
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
156
+ elif prenorm_residual_strategy == 'zero':
157
+ nn.init.zeros_(p)
158
+ else:
159
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
160
+
161
+
162
+ class HGRN2Model(HGRN2PreTrainedModel):
163
+
164
+ def __init__(self, config: HGRN2Config):
165
+ super().__init__(config)
166
+ self.padding_idx = config.pad_token_id
167
+ self.vocab_size = config.vocab_size
168
+
169
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
170
+ if config.use_lower_bound:
171
+ self.lower_bounds = nn.Parameter(torch.zeros(config.num_hidden_layers, config.hidden_size))
172
+ self.layers = nn.ModuleList([HGRN2Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
173
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
174
+
175
+ self.gradient_checkpointing = False
176
+
177
+ self.post_init()
178
+
179
+ def get_input_embeddings(self):
180
+ return self.embeddings
181
+
182
+ def set_input_embeddings(self, value):
183
+ self.embeddings = value
184
+
185
+ def forward(
186
+ self,
187
+ input_ids: Optional[torch.LongTensor] = None,
188
+ attention_mask: Optional[torch.Tensor] = None, # noqa
189
+ inputs_embeds: Optional[torch.FloatTensor] = None,
190
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
191
+ use_cache: Optional[bool] = None,
192
+ output_attentions: Optional[bool] = None,
193
+ output_hidden_states: Optional[bool] = None,
194
+ return_dict: Optional[bool] = None,
195
+ **kwargs: Unpack[Dict]
196
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
197
+ if output_attentions:
198
+ warnings.warn("`HGRN2Model` does not `output_attentions` now, setting it to `False`.")
199
+ output_attentions = False
200
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
201
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
202
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
203
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
204
+
205
+ # retrieve input_ids and inputs_embeds
206
+ if input_ids is not None and inputs_embeds is not None:
207
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
208
+ if input_ids is None and inputs_embeds is None:
209
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
210
+
211
+ if inputs_embeds is None:
212
+ inputs_embeds = self.embeddings(input_ids)
213
+ hidden_states = inputs_embeds
214
+
215
+ if use_cache and not isinstance(past_key_values, Cache):
216
+ past_key_values = Cache.from_legacy_cache(past_key_values)
217
+
218
+ if self.gradient_checkpointing and self.training and use_cache:
219
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
220
+ use_cache = False
221
+
222
+ all_hidden_states = () if output_hidden_states else None
223
+ all_attns = () if output_attentions else None
224
+
225
+ if self.config.use_lower_bound:
226
+ lower_bounds = self.lower_bounds.softmax(0)
227
+ lower_bounds = lower_bounds.cumsum(0) - lower_bounds[0]
228
+ for i, layer in enumerate(self.layers):
229
+ if output_hidden_states:
230
+ all_hidden_states += (hidden_states,)
231
+
232
+ lower_bound = lower_bounds[i] if self.config.use_lower_bound else None
233
+ if self.gradient_checkpointing and self.training:
234
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
235
+ layer.__call__,
236
+ hidden_states,
237
+ attention_mask,
238
+ past_key_values,
239
+ use_cache,
240
+ output_attentions,
241
+ lower_bound,
242
+ **kwargs
243
+ )
244
+ else:
245
+ hidden_states, attentions, past_key_values = layer(
246
+ hidden_states,
247
+ attention_mask=attention_mask,
248
+ past_key_values=past_key_values,
249
+ use_cache=use_cache,
250
+ output_attentions=output_attentions,
251
+ lower_bound=lower_bound,
252
+ **kwargs
253
+ )
254
+
255
+ if output_attentions:
256
+ all_attns += (attentions,)
257
+
258
+ hidden_states = self.norm(hidden_states)
259
+
260
+ # add hidden states from the last decoder layer
261
+ if output_hidden_states:
262
+ all_hidden_states += (hidden_states,)
263
+
264
+ if not return_dict:
265
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
266
+ return BaseModelOutputWithPast(
267
+ last_hidden_state=hidden_states,
268
+ past_key_values=past_key_values,
269
+ hidden_states=all_hidden_states,
270
+ attentions=all_attns
271
+ )
272
+
273
+
274
+ class HGRN2ForCausalLM(HGRN2PreTrainedModel, GenerationMixin):
275
+
276
+ _tied_weights_keys = ["lm_head.weight"]
277
+
278
+ def __init__(self, config):
279
+ super().__init__(config)
280
+ self.model = HGRN2Model(config)
281
+ self.vocab_size = config.vocab_size
282
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
283
+ self.criterion = None
284
+
285
+ # Initialize weights and apply final processing
286
+ self.post_init()
287
+
288
+ def get_input_embeddings(self):
289
+ return self.model.embeddings
290
+
291
+ def set_input_embeddings(self, value):
292
+ self.model.embeddings = value
293
+
294
+ def get_output_embeddings(self):
295
+ return self.lm_head
296
+
297
+ def set_output_embeddings(self, new_embeddings):
298
+ self.lm_head = new_embeddings
299
+
300
+ def set_decoder(self, decoder):
301
+ self.model = decoder
302
+
303
+ def get_decoder(self):
304
+ return self.model
305
+
306
+ def generate(self, *args, **kwargs):
307
+ try:
308
+ return super().generate(*args, **kwargs)
309
+ except AttributeError as exception:
310
+ if 'past_key_values' in str(exception):
311
+ raise AttributeError(
312
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
313
+ f"which is not supported for {self.__class__.__name__}. "
314
+ f"Try another generation strategy instead. "
315
+ f"For the available generation strategies, check this doc: "
316
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
317
+ )
318
+ else:
319
+ raise exception
320
+
321
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
322
+ def prepare_inputs_for_generation(
323
+ self,
324
+ input_ids: torch.LongTensor = None,
325
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
326
+ attention_mask: Optional[torch.Tensor] = None,
327
+ inputs_embeds: Optional[torch.Tensor] = None,
328
+ use_cache: bool = True,
329
+ logits_to_keep: Optional[int] = None,
330
+ **kwargs: Unpack[Dict]
331
+ ):
332
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
333
+ if past_key_values is not None and len(past_key_values) > 0:
334
+ input_ids = input_ids[:, -1:]
335
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
336
+ if inputs_embeds is not None and len(past_key_values) == 0:
337
+ model_inputs = {'inputs_embeds': inputs_embeds}
338
+ else:
339
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
340
+ # recompiles graphs as the stride of the inputs is a guard.
341
+ # Ref: https://github.com/huggingface/transformers/pull/29114
342
+ # TODO: use `next_tokens` directly instead.
343
+ model_inputs = {'input_ids': input_ids.contiguous()}
344
+
345
+ if logits_to_keep is not None:
346
+ model_inputs['logits_to_keep'] = logits_to_keep
347
+
348
+ model_inputs.update({
349
+ 'past_key_values': past_key_values,
350
+ 'use_cache': use_cache,
351
+ 'attention_mask': attention_mask,
352
+ })
353
+ return model_inputs
354
+
355
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
356
+ def forward(
357
+ self,
358
+ input_ids: torch.LongTensor = None,
359
+ attention_mask: Optional[torch.Tensor] = None,
360
+ inputs_embeds: Optional[torch.Tensor] = None,
361
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
362
+ labels: Optional[torch.LongTensor] = None,
363
+ use_cache: Optional[bool] = None,
364
+ output_attentions: Optional[bool] = None,
365
+ output_hidden_states: Optional[bool] = None,
366
+ return_dict: Optional[bool] = None,
367
+ logits_to_keep: Optional[int] = 0,
368
+ **kwargs: Unpack[Dict]
369
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
370
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
371
+ output_hidden_states = (
372
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
373
+ )
374
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
375
+
376
+ outputs = self.model(
377
+ input_ids=input_ids,
378
+ attention_mask=attention_mask,
379
+ inputs_embeds=inputs_embeds,
380
+ past_key_values=past_key_values,
381
+ use_cache=use_cache,
382
+ output_attentions=output_attentions,
383
+ output_hidden_states=output_hidden_states,
384
+ return_dict=return_dict,
385
+ **kwargs
386
+ )
387
+
388
+ hidden_states = outputs[0]
389
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
390
+
391
+ loss, logits = None, None
392
+ if not fuse_linear_and_cross_entropy or labels is None:
393
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
394
+ if labels is not None:
395
+ if getattr(self, 'criterion', None) is None:
396
+ if fuse_linear_and_cross_entropy:
397
+ criterion = FusedLinearCrossEntropyLoss()
398
+ elif self.config.fuse_cross_entropy:
399
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
400
+ else:
401
+ criterion = nn.CrossEntropyLoss()
402
+ else:
403
+ criterion = self.criterion
404
+ labels = labels.to(hidden_states.device)
405
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
406
+ if fuse_linear_and_cross_entropy:
407
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
408
+ else:
409
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
410
+
411
+ if not return_dict:
412
+ output = (logits,) + outputs[1:]
413
+ return (loss,) + output if loss is not None else output
414
+
415
+ return CausalLMOutputWithPast(
416
+ loss=loss,
417
+ logits=logits,
418
+ past_key_values=outputs.past_key_values,
419
+ hidden_states=outputs.hidden_states,
420
+ attentions=outputs.attentions,
421
+ )
fla/models/lightnet/__pycache__/configuration_lightnet.cpython-312.pyc ADDED
Binary file (3.36 kB). View file
 
fla/models/lightnet/__pycache__/modeling_lightnet.cpython-312.pyc ADDED
Binary file (18.3 kB). View file
 
fla/models/lightnet/configuration_lightnet.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class LightNetConfig(PretrainedConfig):
9
+
10
+ model_type = 'lightnet'
11
+ keys_to_ignore_at_inference = ['past_key_values']
12
+
13
+ def __init__(
14
+ self,
15
+ hidden_size: int = 2048,
16
+ num_hidden_layers: int = 24,
17
+ attn_mode: str = "chunk",
18
+ num_heads: Optional[int] = None,
19
+ expand_ratio: Optional[int] = 128,
20
+ use_short_conv: bool = False,
21
+ conv_size: int = 4,
22
+ hidden_ratio: Optional[int] = 4,
23
+ intermediate_size: Optional[int] = None,
24
+ hidden_act: str = "swish",
25
+ max_position_embeddings: int = 2048,
26
+ gate_low_rank_dim: int = 128,
27
+ elementwise_affine: Optional[bool] = True,
28
+ norm_eps: float = 1e-6,
29
+ attn: Optional[Dict] = None,
30
+ use_cache: bool = True,
31
+ pad_token_id: int = None,
32
+ bos_token_id: int = 1,
33
+ eos_token_id: int = 2,
34
+ tie_word_embeddings: bool = False,
35
+ initializer_range: float = 0.006,
36
+ fuse_norm: bool = True,
37
+ fuse_swiglu: bool = True,
38
+ fuse_cross_entropy: bool = True,
39
+ vocab_size: int = 32000,
40
+ **kwargs
41
+ ):
42
+ self.hidden_size = hidden_size
43
+ self.num_hidden_layers = num_hidden_layers
44
+ self.attn_mode = attn_mode
45
+ self.num_heads = num_heads
46
+ self.expand_ratio = expand_ratio
47
+ self.use_short_conv = use_short_conv
48
+ self.conv_size = conv_size
49
+ self.max_position_embeddings = max_position_embeddings
50
+ self.gate_low_rank_dim = gate_low_rank_dim
51
+ self.hidden_ratio = hidden_ratio
52
+ self.intermediate_size = intermediate_size
53
+ self.hidden_act = hidden_act
54
+ self.elementwise_affine = elementwise_affine
55
+ self.norm_eps = norm_eps
56
+ self.attn = attn
57
+ self.use_cache = use_cache
58
+ self.initializer_range = initializer_range
59
+
60
+ self.fuse_norm = fuse_norm
61
+ self.fuse_swiglu = fuse_swiglu
62
+ self.fuse_cross_entropy = fuse_cross_entropy
63
+ self.vocab_size = vocab_size
64
+
65
+ if attn is not None:
66
+ if not isinstance(attn, Dict):
67
+ raise ValueError("attn must be a dictionary")
68
+ if 'layers' not in attn:
69
+ raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
70
+ if 'num_heads' not in attn:
71
+ raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
72
+ attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
73
+ attn['qkv_bias'] = attn.get('qkv_bias', False)
74
+ attn['window_size'] = attn.get('window_size', None)
75
+ attn['rope_theta'] = attn.get('rope_theta', 10000.)
76
+
77
+ super().__init__(
78
+ pad_token_id=pad_token_id,
79
+ bos_token_id=bos_token_id,
80
+ eos_token_id=eos_token_id,
81
+ tie_word_embeddings=tie_word_embeddings,
82
+ **kwargs,
83
+ )
fla/models/lightnet/modeling_lightnet.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.lightnet import LightNetAttention
20
+ from fla.models.lightnet.configuration_lightnet import LightNetConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as LightNetMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class LightNetBlock(nn.Module):
33
+ def __init__(self, config: LightNetConfig, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ max_position_embeddings=config.max_position_embeddings,
48
+ layer_idx=layer_idx
49
+ )
50
+ else:
51
+ self.attn = LightNetAttention(
52
+ mode=config.attn_mode,
53
+ hidden_size=config.hidden_size,
54
+ num_heads=config.num_heads,
55
+ expand_ratio=config.expand_ratio,
56
+ use_short_conv=config.use_short_conv,
57
+ conv_size=config.conv_size,
58
+ gate_low_rank_dim=config.gate_low_rank_dim,
59
+ elementwise_affine=config.elementwise_affine,
60
+ norm_eps=config.norm_eps,
61
+ layer_idx=layer_idx
62
+ )
63
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
64
+ self.mlp = LightNetMLP(
65
+ hidden_size=config.hidden_size,
66
+ hidden_ratio=config.hidden_ratio,
67
+ intermediate_size=config.intermediate_size,
68
+ hidden_act=config.hidden_act,
69
+ fuse_swiglu=config.fuse_swiglu
70
+ )
71
+
72
+ def forward(
73
+ self,
74
+ hidden_states: torch.Tensor,
75
+ attention_mask: Optional[torch.Tensor] = None,
76
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
77
+ use_cache: Optional[bool] = False,
78
+ output_attentions: Optional[bool] = False,
79
+ **kwargs: Unpack[Dict]
80
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
81
+ residual = hidden_states
82
+ hidden_states = self.attn_norm(hidden_states)
83
+ hidden_states, attentions, past_key_values = self.attn(
84
+ hidden_states=hidden_states,
85
+ attention_mask=attention_mask,
86
+ past_key_values=past_key_values,
87
+ use_cache=use_cache,
88
+ output_attentions=output_attentions,
89
+ **kwargs
90
+ )
91
+ if self.config.fuse_norm:
92
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
93
+ else:
94
+ hidden_states = residual + hidden_states
95
+ residual = hidden_states
96
+ hidden_states = self.mlp_norm(hidden_states)
97
+ hidden_states = self.mlp(hidden_states, **kwargs)
98
+ hidden_states = residual + hidden_states
99
+
100
+ outputs = (hidden_states, attentions, past_key_values)
101
+
102
+ return outputs
103
+
104
+
105
+ class LightNetPreTrainedModel(PreTrainedModel):
106
+
107
+ config_class = LightNetConfig
108
+ supports_gradient_checkpointing = True
109
+ _no_split_modules = ['LightNetBlock']
110
+ _supports_cache_class = True
111
+
112
+ def __init__(self, *inputs, **kwargs):
113
+ super().__init__(*inputs, **kwargs)
114
+
115
+ def _init_weights(
116
+ self,
117
+ module: nn.Module,
118
+ prenorm_residual_strategy: Optional[str] = 'rescale',
119
+ num_residuals_per_layer: int = 2,
120
+ ):
121
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
122
+ # Slightly different from the TF version which uses truncated_normal for initialization
123
+ # cf https://github.com/pytorch/pytorch/pull/5617
124
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
125
+ if module.bias is not None:
126
+ nn.init.zeros_(module.bias)
127
+ elif isinstance(module, nn.Embedding):
128
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
129
+ elif hasattr(module, 'reset_parameters'):
130
+ module.reset_parameters()
131
+
132
+ if prenorm_residual_strategy is not None:
133
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
134
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
135
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
136
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
137
+ #
138
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
139
+ p = None
140
+ if hasattr(module, 'o_proj'):
141
+ p = module.o_proj.weight
142
+ elif hasattr(module, 'down_proj'):
143
+ p = module.down_proj.weight
144
+ if p is not None:
145
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
146
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
147
+ # We need to reinit p since this code could be called multiple times
148
+ # Having just p *= scale would repeatedly scale it down
149
+ if prenorm_residual_strategy == 'rescale':
150
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
151
+ with torch.no_grad():
152
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
153
+ elif prenorm_residual_strategy == 'zero':
154
+ nn.init.zeros_(p)
155
+ else:
156
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
157
+
158
+
159
+ class LightNetModel(LightNetPreTrainedModel):
160
+
161
+ def __init__(self, config: LightNetConfig):
162
+ super().__init__(config)
163
+ self.padding_idx = config.pad_token_id
164
+ self.vocab_size = config.vocab_size
165
+
166
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
167
+ self.layers = nn.ModuleList([LightNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
168
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
169
+
170
+ self.gradient_checkpointing = False
171
+
172
+ self.post_init()
173
+
174
+ def get_input_embeddings(self):
175
+ return self.embeddings
176
+
177
+ def set_input_embeddings(self, value):
178
+ self.embeddings = value
179
+
180
+ def forward(
181
+ self,
182
+ input_ids: Optional[torch.LongTensor] = None,
183
+ attention_mask: Optional[torch.Tensor] = None, # noqa
184
+ inputs_embeds: Optional[torch.FloatTensor] = None,
185
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
186
+ use_cache: Optional[bool] = None,
187
+ output_attentions: Optional[bool] = None,
188
+ output_hidden_states: Optional[bool] = None,
189
+ return_dict: Optional[bool] = None,
190
+ **kwargs: Unpack[Dict]
191
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
192
+ if output_attentions:
193
+ warnings.warn("`LightNetModel` does not `output_attentions` now, setting it to `False`.")
194
+ output_attentions = False
195
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
196
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
197
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
198
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
199
+
200
+ # retrieve input_ids and inputs_embeds
201
+ if input_ids is not None and inputs_embeds is not None:
202
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
203
+ if input_ids is None and inputs_embeds is None:
204
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
205
+
206
+ if inputs_embeds is None:
207
+ inputs_embeds = self.embeddings(input_ids)
208
+ hidden_states = inputs_embeds
209
+
210
+ if use_cache and not isinstance(past_key_values, Cache):
211
+ past_key_values = Cache.from_legacy_cache(past_key_values)
212
+
213
+ if self.gradient_checkpointing and self.training and use_cache:
214
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
215
+ use_cache = False
216
+
217
+ all_hidden_states = () if output_hidden_states else None
218
+ all_attns = () if output_attentions else None
219
+
220
+ for i, layer in enumerate(self.layers):
221
+ if output_hidden_states:
222
+ all_hidden_states += (hidden_states,)
223
+
224
+ if self.gradient_checkpointing and self.training:
225
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
226
+ layer.__call__,
227
+ hidden_states,
228
+ attention_mask,
229
+ past_key_values,
230
+ use_cache,
231
+ output_attentions,
232
+ **kwargs
233
+ )
234
+ else:
235
+ hidden_states, attentions, past_key_values = layer(
236
+ hidden_states,
237
+ attention_mask=attention_mask,
238
+ past_key_values=past_key_values,
239
+ use_cache=use_cache,
240
+ output_attentions=output_attentions,
241
+ **kwargs
242
+ )
243
+
244
+ if output_attentions:
245
+ all_attns += (attentions,)
246
+
247
+ hidden_states = self.norm(hidden_states)
248
+
249
+ # add hidden states from the last decoder layer
250
+ if output_hidden_states:
251
+ all_hidden_states += (hidden_states,)
252
+
253
+ if not return_dict:
254
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
255
+ return BaseModelOutputWithPast(
256
+ last_hidden_state=hidden_states,
257
+ past_key_values=past_key_values,
258
+ hidden_states=all_hidden_states,
259
+ attentions=all_attns
260
+ )
261
+
262
+
263
+ class LightNetForCausalLM(LightNetPreTrainedModel, GenerationMixin):
264
+
265
+ _tied_weights_keys = ["lm_head.weight"]
266
+
267
+ def __init__(self, config):
268
+ super().__init__(config)
269
+ self.model = LightNetModel(config)
270
+ self.vocab_size = config.vocab_size
271
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
272
+ self.criterion = None
273
+
274
+ # Initialize weights and apply final processing
275
+ self.post_init()
276
+
277
+ def get_input_embeddings(self):
278
+ return self.model.embeddings
279
+
280
+ def set_input_embeddings(self, value):
281
+ self.model.embeddings = value
282
+
283
+ def get_output_embeddings(self):
284
+ return self.lm_head
285
+
286
+ def set_output_embeddings(self, new_embeddings):
287
+ self.lm_head = new_embeddings
288
+
289
+ def set_decoder(self, decoder):
290
+ self.model = decoder
291
+
292
+ def get_decoder(self):
293
+ return self.model
294
+
295
+ def generate(self, *args, **kwargs):
296
+ try:
297
+ return super().generate(*args, **kwargs)
298
+ except AttributeError as exception:
299
+ if 'past_key_values' in str(exception):
300
+ raise AttributeError(
301
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
302
+ f"which is not supported for {self.__class__.__name__}. "
303
+ f"Try another generation strategy instead. "
304
+ f"For the available generation strategies, check this doc: "
305
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
306
+ )
307
+ else:
308
+ raise exception
309
+
310
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
311
+ def prepare_inputs_for_generation(
312
+ self,
313
+ input_ids: torch.LongTensor = None,
314
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
315
+ attention_mask: Optional[torch.Tensor] = None,
316
+ inputs_embeds: Optional[torch.Tensor] = None,
317
+ use_cache: bool = True,
318
+ logits_to_keep: Optional[int] = None,
319
+ **kwargs: Unpack[Dict]
320
+ ):
321
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
322
+ if past_key_values is not None and len(past_key_values) > 0:
323
+ input_ids = input_ids[:, -1:]
324
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
325
+ if inputs_embeds is not None and len(past_key_values) == 0:
326
+ model_inputs = {'inputs_embeds': inputs_embeds}
327
+ else:
328
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
329
+ # recompiles graphs as the stride of the inputs is a guard.
330
+ # Ref: https://github.com/huggingface/transformers/pull/29114
331
+ # TODO: use `next_tokens` directly instead.
332
+ model_inputs = {'input_ids': input_ids.contiguous()}
333
+
334
+ if logits_to_keep is not None:
335
+ model_inputs['logits_to_keep'] = logits_to_keep
336
+
337
+ model_inputs.update({
338
+ 'past_key_values': past_key_values,
339
+ 'use_cache': use_cache,
340
+ 'attention_mask': attention_mask,
341
+ })
342
+ return model_inputs
343
+
344
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
345
+ def forward(
346
+ self,
347
+ input_ids: torch.LongTensor = None,
348
+ attention_mask: Optional[torch.Tensor] = None,
349
+ inputs_embeds: Optional[torch.Tensor] = None,
350
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
351
+ labels: Optional[torch.LongTensor] = None,
352
+ use_cache: Optional[bool] = None,
353
+ output_attentions: Optional[bool] = None,
354
+ output_hidden_states: Optional[bool] = None,
355
+ return_dict: Optional[bool] = None,
356
+ logits_to_keep: Optional[int] = 0,
357
+ **kwargs: Unpack[Dict]
358
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
359
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
360
+ output_hidden_states = (
361
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
362
+ )
363
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
364
+
365
+ outputs = self.model(
366
+ input_ids=input_ids,
367
+ attention_mask=attention_mask,
368
+ inputs_embeds=inputs_embeds,
369
+ past_key_values=past_key_values,
370
+ use_cache=use_cache,
371
+ output_attentions=output_attentions,
372
+ output_hidden_states=output_hidden_states,
373
+ return_dict=return_dict,
374
+ **kwargs
375
+ )
376
+
377
+ hidden_states = outputs[0]
378
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
379
+
380
+ loss, logits = None, None
381
+ if not fuse_linear_and_cross_entropy or labels is None:
382
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
383
+ if labels is not None:
384
+ if getattr(self, 'criterion', None) is None:
385
+ if fuse_linear_and_cross_entropy:
386
+ criterion = FusedLinearCrossEntropyLoss()
387
+ elif self.config.fuse_cross_entropy:
388
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
389
+ else:
390
+ criterion = nn.CrossEntropyLoss()
391
+ else:
392
+ criterion = self.criterion
393
+ labels = labels.to(hidden_states.device)
394
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
395
+ if fuse_linear_and_cross_entropy:
396
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
397
+ else:
398
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
399
+
400
+ if not return_dict:
401
+ output = (logits,) + outputs[1:]
402
+ return (loss,) + output if loss is not None else output
403
+
404
+ return CausalLMOutputWithPast(
405
+ loss=loss,
406
+ logits=logits,
407
+ past_key_values=outputs.past_key_values,
408
+ hidden_states=outputs.hidden_states,
409
+ attentions=outputs.attentions,
410
+ )
fla/models/linear_attn/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.linear_attn.configuration_linear_attn import LinearAttentionConfig
6
+ from fla.models.linear_attn.modeling_linear_attn import LinearAttentionForCausalLM, LinearAttentionModel
7
+
8
+ AutoConfig.register(LinearAttentionConfig.model_type, LinearAttentionConfig)
9
+ AutoModel.register(LinearAttentionConfig, LinearAttentionModel)
10
+ AutoModelForCausalLM.register(LinearAttentionConfig, LinearAttentionForCausalLM)
11
+
12
+ __all__ = ['LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel']
fla/models/linear_attn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (737 Bytes). View file
 
fla/models/linear_attn/__pycache__/configuration_linear_attn.cpython-312.pyc ADDED
Binary file (3.65 kB). View file
 
fla/models/linear_attn/__pycache__/modeling_linear_attn.cpython-312.pyc ADDED
Binary file (18.5 kB). View file
 
fla/models/mamba/__pycache__/configuration_mamba.cpython-312.pyc ADDED
Binary file (7.06 kB). View file
 
fla/models/mamba2/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.mamba2.configuration_mamba2 import Mamba2Config
6
+ from fla.models.mamba2.modeling_mamba2 import Mamba2ForCausalLM, Mamba2Model
7
+
8
+ AutoConfig.register(Mamba2Config.model_type, Mamba2Config, True)
9
+ AutoModel.register(Mamba2Config, Mamba2Model, True)
10
+ AutoModelForCausalLM.register(Mamba2Config, Mamba2ForCausalLM, True)
11
+
12
+
13
+ __all__ = ['Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model']
fla/models/mamba2/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (695 Bytes). View file
 
fla/models/nsa/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.nsa.configuration_nsa import NSAConfig
6
+ from fla.models.nsa.modeling_nsa import NSAForCausalLM, NSAModel
7
+
8
+ AutoConfig.register(NSAConfig.model_type, NSAConfig)
9
+ AutoModel.register(NSAConfig, NSAModel)
10
+ AutoModelForCausalLM.register(NSAConfig, NSAForCausalLM)
11
+
12
+
13
+ __all__ = [
14
+ 'NSAConfig', 'NSAModel', 'NSAForCausalLM',
15
+ ]
fla/models/nsa/__pycache__/configuration_nsa.cpython-312.pyc ADDED
Binary file (2.64 kB). View file
 
fla/models/nsa/__pycache__/modeling_nsa.cpython-312.pyc ADDED
Binary file (17.6 kB). View file