Erland committed on
Commit
f63b81d
·
verified ·
1 Parent(s): e4ebaab

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. configs/delta_net_1B.json +29 -0
  2. configs/gla_340M.json +24 -0
  3. configs/gla_7B.json +25 -0
  4. configs/gsa_340M.json +29 -0
  5. configs/hgrn2_340M.json +20 -0
  6. configs/mtp_transformer_120M.json +19 -0
  7. configs/myopic_transformer_120M.json +19 -0
  8. configs/myopic_transformer_1B.json +23 -0
  9. configs/myopic_transformer_340M.json +19 -0
  10. configs/transformer_120M.json +18 -0
  11. configs/transformer_1B.json +22 -0
  12. configs/transformer_340M.json +18 -0
  13. configs/transformer_7B.json +21 -0
  14. fla/layers/__pycache__/__init__.cpython-311.pyc +0 -0
  15. fla/layers/__pycache__/abc.cpython-311.pyc +0 -0
  16. fla/layers/__pycache__/attn.cpython-311.pyc +0 -0
  17. fla/layers/__pycache__/based.cpython-311.pyc +0 -0
  18. fla/layers/__pycache__/forgetting_attn.cpython-311.pyc +0 -0
  19. fla/layers/__pycache__/gated_deltaproduct.cpython-311.pyc +0 -0
  20. fla/layers/__pycache__/gla.cpython-311.pyc +0 -0
  21. fla/layers/__pycache__/gsa.cpython-311.pyc +0 -0
  22. fla/layers/__pycache__/hgrn.cpython-311.pyc +0 -0
  23. fla/layers/__pycache__/multiscale_retention.cpython-311.pyc +0 -0
  24. fla/layers/__pycache__/nsa.cpython-311.pyc +0 -0
  25. fla/layers/__pycache__/rwkv6.cpython-311.pyc +0 -0
  26. fla/layers/attn.py +203 -0
  27. fla/layers/based.py +96 -0
  28. fla/layers/forgetting_attn.py +109 -0
  29. fla/layers/hgrn2.py +211 -0
  30. fla/layers/linear_attn.py +166 -0
  31. fla/layers/multiscale_retention.py +298 -0
  32. fla/models/abc/__pycache__/configuration_abc.cpython-311.pyc +0 -0
  33. fla/models/abc/__pycache__/modeling_abc.cpython-311.pyc +0 -0
  34. fla/models/abc/configuration_abc.py +91 -0
  35. fla/models/bitnet/modeling_bitnet.py +441 -0
  36. fla/models/delta_net/__pycache__/configuration_delta_net.cpython-311.pyc +0 -0
  37. fla/models/delta_net/modeling_delta_net.py +415 -0
  38. fla/models/forgetting_transformer/__init__.py +16 -0
  39. fla/models/forgetting_transformer/modeling_forgetting_transformer.py +408 -0
  40. fla/models/gated_deltanet/__init__.py +12 -0
  41. fla/models/gated_deltanet/__pycache__/modeling_gated_deltanet.cpython-311.pyc +0 -0
  42. fla/models/gated_deltaproduct/__init__.py +14 -0
  43. fla/models/gated_deltaproduct/__pycache__/__init__.cpython-311.pyc +0 -0
  44. fla/models/gla/__pycache__/configuration_gla.cpython-311.pyc +0 -0
  45. fla/models/gla/__pycache__/modeling_gla.cpython-311.pyc +0 -0
  46. fla/models/hgrn/__pycache__/__init__.cpython-311.pyc +0 -0
  47. fla/models/hgrn/configuration_hgrn.py +81 -0
  48. fla/models/hgrn2/__init__.py +13 -0
  49. fla/models/hgrn2/__pycache__/__init__.cpython-311.pyc +0 -0
  50. fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-311.pyc +0 -0
configs/delta_net_1B.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attn": null,
3
+ "attn_mode": "chunk",
4
+ "bos_token_id": 1,
5
+ "conv_size": 4,
6
+ "eos_token_id": 2,
7
+ "expand_k": 1,
8
+ "expand_v": 1,
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "hidden_act": "swish",
12
+ "hidden_ratio": 4,
13
+ "hidden_size": 2048,
14
+ "initializer_range": 0.006,
15
+ "intermediate_size": null,
16
+ "model_type": "delta_net",
17
+ "norm_eps": 1e-06,
18
+ "num_heads": 16,
19
+ "num_hidden_layers": 24,
20
+ "pad_token_id": 2,
21
+ "qk_activation": "silu",
22
+ "qk_norm": "l2",
23
+ "tie_word_embeddings": false,
24
+ "use_beta": true,
25
+ "use_cache": true,
26
+ "use_gate": false,
27
+ "use_output_norm": true,
28
+ "use_short_conv": true
29
+ }
configs/gla_340M.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "clamp_min": null,
5
+ "eos_token_id": 2,
6
+ "expand_k": 0.5,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.006,
14
+ "intermediate_size": null,
15
+ "model_type": "gla",
16
+ "num_heads": 4,
17
+ "num_hidden_layers": 24,
18
+ "norm_eps": 1e-06,
19
+ "tie_word_embeddings": false,
20
+ "use_cache": true,
21
+ "use_gk": true,
22
+ "use_gv": false,
23
+ "vocab_size": 32000
24
+ }
configs/gla_7B.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attn": null,
3
+ "attn_mode": "chunk",
4
+ "bos_token_id": 1,
5
+ "eos_token_id": 2,
6
+ "expand_k": 0.5,
7
+ "expand_v": 1,
8
+ "fuse_cross_entropy": true,
9
+ "fuse_norm": true,
10
+ "hidden_act": "swish",
11
+ "hidden_ratio": 4,
12
+ "hidden_size": 4096,
13
+ "initializer_range": 0.006,
14
+ "intermediate_size": 11008,
15
+ "model_type": "gla",
16
+ "norm_eps": 1e-06,
17
+ "num_heads": 16,
18
+ "num_hidden_layers": 32,
19
+ "tie_word_embeddings": false,
20
+ "use_cache": true,
21
+ "use_gk": true,
22
+ "use_gv": false,
23
+ "use_output_gate": true,
24
+ "use_short_conv": false
25
+ }
configs/gsa_340M.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "conv_size": 4,
4
+ "eos_token_id": 2,
5
+ "expand_k": 1,
6
+ "expand_v": 1,
7
+ "elementwise_affine": false,
8
+ "feature_map": "swish",
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": true,
11
+ "gate_logit_normalizer": 4,
12
+ "hidden_act": "swish",
13
+ "hidden_ratio": 4,
14
+ "hidden_size": 1024,
15
+ "initializer_range": 0.006,
16
+ "intermediate_size": null,
17
+ "model_type": "gsa",
18
+ "num_heads": 4,
19
+ "num_hidden_layers": 24,
20
+ "num_slots": 64,
21
+ "norm_eps": 1e-06,
22
+ "share_conv_kernel": true,
23
+ "tie_word_embeddings": false,
24
+ "use_cache": true,
25
+ "use_norm": true,
26
+ "use_output_gate": true,
27
+ "use_rope": false,
28
+ "use_short_conv": false
29
+ }
configs/hgrn2_340M.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attn_mode": "chunk",
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "expand_ratio": 128,
6
+ "fuse_cross_entropy": true,
7
+ "fuse_norm": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 1024,
11
+ "initializer_range": 0.006,
12
+ "intermediate_size": null,
13
+ "model_type": "hgrn2",
14
+ "num_heads": 8,
15
+ "num_hidden_layers": 24,
16
+ "norm_eps": 1e-06,
17
+ "tie_word_embeddings": false,
18
+ "use_cache": true,
19
+ "vocab_size": 32000
20
+ }
configs/mtp_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": false,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 4096,
11
+ "model_type": "mtp_transformer",
12
+ "num_heads": 12,
13
+ "num_hidden_layers": 14,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": true,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "n_future_tokens": 4
19
+ }
configs/myopic_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": false,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 4096,
11
+ "model_type": "transformer",
12
+ "num_heads": 12,
13
+ "num_hidden_layers": 14,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": true,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "use_myopic_loss": true
19
+ }
configs/myopic_transformer_1B.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "elementwise_affine": true,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "fuse_swiglu": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.006,
12
+ "intermediate_size": null,
13
+ "max_position_embeddings": 8192,
14
+ "model_type": "transformer",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 32,
17
+ "num_hidden_layers": 32,
18
+ "num_kv_heads": null,
19
+ "pad_token_id": 2,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false,
22
+ "use_myopic_loss": true
23
+ }
configs/myopic_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "use_myopic_loss": true
19
+ }
configs/transformer_120M.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": false,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 4096,
11
+ "model_type": "transformer",
12
+ "num_heads": 12,
13
+ "num_hidden_layers": 14,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": true,
16
+ "use_cache": true,
17
+ "vocab_size": 32000
18
+ }
configs/transformer_1B.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "elementwise_affine": true,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "fuse_swiglu": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.006,
12
+ "intermediate_size": null,
13
+ "max_position_embeddings": 8192,
14
+ "model_type": "transformer",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 32,
17
+ "num_hidden_layers": 24,
18
+ "num_kv_heads": null,
19
+ "pad_token_id": 2,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false
22
+ }
configs/transformer_340M.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000
18
+ }
configs/transformer_7B.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_ratio": 4,
9
+ "hidden_size": 4096,
10
+ "initializer_range": 0.006,
11
+ "intermediate_size": 14336,
12
+ "model_type": "transformer",
13
+ "norm_eps": 1e-06,
14
+ "num_heads": 32,
15
+ "num_hidden_layers": 32,
16
+ "num_kv_heads": 8,
17
+ "rope_theta": 10000.0,
18
+ "tie_word_embeddings": false,
19
+ "use_cache": true,
20
+ "window_size": null
21
+ }
fla/layers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.5 kB). View file
 
fla/layers/__pycache__/abc.cpython-311.pyc ADDED
Binary file (9.78 kB). View file
 
fla/layers/__pycache__/attn.cpython-311.pyc ADDED
Binary file (9.99 kB). View file
 
fla/layers/__pycache__/based.cpython-311.pyc ADDED
Binary file (6.91 kB). View file
 
fla/layers/__pycache__/forgetting_attn.cpython-311.pyc ADDED
Binary file (5.47 kB). View file
 
fla/layers/__pycache__/gated_deltaproduct.cpython-311.pyc ADDED
Binary file (16.2 kB). View file
 
fla/layers/__pycache__/gla.cpython-311.pyc ADDED
Binary file (13.7 kB). View file
 
fla/layers/__pycache__/gsa.cpython-311.pyc ADDED
Binary file (10.3 kB). View file
 
fla/layers/__pycache__/hgrn.cpython-311.pyc ADDED
Binary file (7.23 kB). View file
 
fla/layers/__pycache__/multiscale_retention.cpython-311.pyc ADDED
Binary file (13 kB). View file
 
fla/layers/__pycache__/nsa.cpython-311.pyc ADDED
Binary file (6.73 kB). View file
 
fla/layers/__pycache__/rwkv6.cpython-311.pyc ADDED
Binary file (15.6 kB). View file
 
fla/layers/attn.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from einops import rearrange
14
+ from transformers.utils import logging
15
+
16
+ from fla.modules import RMSNorm, RotaryEmbedding
17
+
18
+ if TYPE_CHECKING:
19
+ from fla.models.utils import Cache
20
+
21
+ try:
22
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
23
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
24
+ except ImportError:
25
+ warnings.warn(
26
+ "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
27
+ category=ImportWarning
28
+ )
29
+ flash_attn_func = None
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
class Attention(nn.Module):
    """Causal multi-head attention with rotary embeddings, computed by Flash Attention.

    Queries use ``num_heads`` heads while keys/values use ``num_kv_heads`` heads
    (defaulting to ``num_heads``). All attention calls are made with ``causal=True``;
    when ``window_size`` is set, a ``(window_size - 1, 0)`` local window is requested.

    Args:
        hidden_size: Size of the input and output features.
        num_heads: Number of query heads; the per-head size is ``hidden_size // num_heads``.
        num_kv_heads: Number of key/value heads. ``None`` means ``num_heads``.
        qkv_bias: Whether the q/k/v projections carry a bias (the output projection never does).
        qk_norm: If True, apply a per-head ``RMSNorm`` to queries and keys.
        window_size: Optional local attention window; also forwarded to the cache update.
        rope_theta: Base passed to ``RotaryEmbedding``.
        max_position_embeddings: If set, used as a lower bound on the ``max_seqlen``
            handed to the rotary embedding.
        layer_idx: Index of this layer, used to address its slot in the cache.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        window_size: Optional[int] = None,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: Optional[int] = None,
        layer_idx: Optional[int] = None
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        if num_kv_heads is None:
            self.num_kv_heads = self.num_heads
        else:
            self.num_kv_heads = num_kv_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm

        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.layer_idx = layer_idx

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        if qk_norm:
            self.q_norm = RMSNorm(self.head_dim)
            self.k_norm = RMSNorm(self.head_dim)

        self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run attention over ``hidden_states``.

        Args:
            hidden_states: Input of shape ``[batch_size, q_len, hidden_size]``.
            attention_mask: Optional 0-1 padding mask of shape ``[batch_size, seq_len]``
                (0 marks padding). Any other rank is rejected.
            past_key_values: Optional cache; when given, it is updated with this call's
                keys/values and, if it already held content for this layer, the cached
                keys/values are attended over.
            output_attentions: Accepted for interface compatibility. Attention weights
                are never produced here, so the second return value is always ``None``.
            use_cache: Accepted for interface compatibility; not read by this method.
            **kwargs: ``cu_seqlens`` (optional) selects the variable-length kernel
                when no ``attention_mask`` is given.

        Returns:
            ``(output, None, past_key_values)`` where ``output`` has shape
            ``[batch_size, q_len, hidden_size]``.

        Raises:
            ImportError: If Flash Attention is not installed.
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        batch_size, q_len, _ = hidden_states.size()

        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)

        # split the channel axis into (heads, head_dim)
        q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
        k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
        v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)

        if self.qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)

        # equivalent to cu_seqlens in `flash_attn`
        cu_seqlens = kwargs.get('cu_seqlens', None)

        seqlen_offset, max_seqlen = 0, q_len
        if past_key_values is not None:
            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            max_seqlen = q.shape[1] + seqlen_offset

            if attention_mask is not None:
                # account for the offsets introduced by padding tokens:
                # the offset becomes a per-sequence tensor from here on
                seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
                max_seqlen = q.shape[1] + max(seqlen_offset)

        if self.max_position_embeddings is not None:
            max_seqlen = max(max_seqlen, self.max_position_embeddings)
        q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)

        if past_key_values is not None:
            # must be read before `update`, which stores this call's keys/values
            cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
            k_cached, v_cached = past_key_values.update(
                attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
                layer_idx=self.layer_idx,
                offset=q_len,
                cache_kwargs=dict(window_size=self.window_size)
            )['attn_state']
            if cache_has_content:
                k, v = k_cached, v_cached
                k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
                v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)

        if flash_attn_func is None:
            raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_q, max_seqlen_k = max_seq_lens
            o = flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                causal=True,
                window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
            )
            o = pad_input(o, indices_q, batch_size, q_len)
        elif cu_seqlens is not None:
            o = flash_attn_varlen_func(
                q.squeeze(0), k.squeeze(0), v.squeeze(0),
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                causal=True,
                window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
            ).unsqueeze(0)
        else:
            o = flash_attn_func(
                q, k, v,
                causal=True,
                window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
            )
        o = o.reshape(batch_size, q_len, -1)
        o = self.o_proj(o)

        # None of the kernel calls above yields attention weights, so there is nothing
        # to return even when `output_attentions=True`. Bind the name unconditionally:
        # it used to be assigned only when `output_attentions` was False, which made
        # the return statement raise UnboundLocalError otherwise.
        attentions = None

        return o, attentions, past_key_values

    def _upad_input(self, q, k, v, attention_mask, q_len):
        """Strip padded positions from q/k/v for the variable-length kernel.

        Returns the unpadded ``q``, ``k``, ``v``, the indices of the kept query
        positions, the ``(cu_seqlens_q, cu_seqlens_k)`` cumulative lengths and the
        ``(max_seqlen_q, max_seqlen_k)`` maxima.
        """
        batch_size, seq_len, num_key_value_heads, head_dim = k.shape
        # only the last `seq_len` mask columns line up with the keys
        cache_mask = attention_mask[:, -seq_len:]
        seqlens = cache_mask.sum(-1, dtype=torch.int32)
        indices_k = torch.nonzero(cache_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_k = seqlens.max().item()
        cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

        k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
        v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
        if q_len == seq_len:
            # queries cover the same positions as keys: reuse the key bookkeeping
            q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_q = max_seqlen_k
            indices_q = indices_k
        elif q_len == 1:
            # one query token per sequence
            max_seqlen_q = 1
            # There is a memcpy here, that is very bad.
            cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
            indices_q = cu_seqlens_q[:-1]
            q = q.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -q_len:]
            q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)

        return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
fla/layers/based.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ """
5
+ Linear attention in Based.
6
+ https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+
13
+ from fla.modules.feature_map import TaylorFeatureMap
14
+ from fla.ops.based import parallel_based
15
+ from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
16
+
17
+
18
class BasedLinearAttention(nn.Module):
    """Linear attention layer from Based, with a Taylor feature map on queries and keys.

    Queries and keys are projected to ``feature_dim`` channels per head
    (``feature_dim * num_heads`` in total); values are projected to
    ``head_dim = hidden_size // num_key_value_heads`` channels per head.

    Args:
        hidden_size: Size of the input and output features.
        feature_dim: Per-head size of queries and keys before the feature map.
        num_key_value_heads: Number of value heads; also defines ``head_dim``.
        num_heads: Number of query/key heads.
        feature_name: Stored on the module; not otherwise read in this class.
        eps: Added to the normaliser in ``forward_reference``.
        causal: Whether ``forward_reference`` uses cumulative (causal) sums.
        mode: Kernel used by ``forward``: ``"parallel"``, ``"chunk"`` or ``"fused_chunk"``.
    """

    def __init__(
        self,
        hidden_size: int,
        feature_dim: int = 16,
        num_key_value_heads: int = 12,
        num_heads: int = 12,
        feature_name: str = "taylor_exp",
        eps: float = 1e-12,
        causal: bool = True,
        mode: str = "parallel",
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.mode = mode
        self.feature_name = feature_name
        self.feature_dim = feature_dim
        self.num_key_value_heads = num_key_value_heads
        self.num_heads = num_heads
        self.head_dim = self.hidden_size // self.num_key_value_heads
        assert self.hidden_size % self.head_dim == 0
        self.causal = causal

        self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.dropout = nn.Identity()
        self.feature_map = TaylorFeatureMap(feature_dim)
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor, **kwargs):
        """Apply Based linear attention using the kernel selected by ``self.mode``.

        Args:
            hidden_states: Input whose last axis has size ``hidden_size``.
            **kwargs: Ignored.

        Returns:
            The projected attention output.

        Raises:
            NotImplementedError: If ``self.mode`` is not one of the supported kernels.
        """
        mode = self.mode
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
        # q/k carry `feature_dim` channels per head (see q_proj/k_proj), values carry
        # `head_dim`. Splitting all three with `head_dim` gave q/k the wrong head
        # count whenever feature_dim != head_dim and disagreed with `forward_reference`.
        q, k = (rearrange(x, "... (h d) -> ... h d", d=self.feature_dim) for x in (q, k))
        v = rearrange(v, "... (h d) -> ... h d", d=self.head_dim)
        if mode == "fused_chunk":
            q, k = self.feature_map(q), self.feature_map(k)
            o, _ = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1, head_first=False)
        elif mode == 'chunk':
            q, k = self.feature_map(q), self.feature_map(k)
            o, _ = chunk_linear_attn(q, k, v, normalize=True, scale=1, head_first=False)
        elif mode == 'parallel':
            # no explicit feature map here: raw q/k are passed to the kernel
            assert q.shape[-1] <= 128
            o = parallel_based(q, k, v, scale=1, use_norm=True, head_first=False)
        else:
            # previously fell through and failed on the unbound name `o`
            raise NotImplementedError(f"Not supported mode `{mode}`.")
        o = rearrange(o, 'b t h d -> b t (h d)')
        o = self.o_proj(o)
        o = self.dropout(o)
        return o

    # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119

    def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs):
        """Pure-PyTorch reference implementation of the layer.

        Args:
            hidden_states: Tensor of shape ``(b, t, hidden_size)``.
            filters: Unused; kept for signature compatibility.

        Returns:
            Tensor of shape ``(b, t, hidden_size)`` in the dtype of ``hidden_states``.
        """
        b, t, _ = hidden_states.size()
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)

        # -> (b, h, t, d)
        q = q.view(b, t, self.num_heads, self.feature_dim).transpose(1, 2)
        k = k.view(b, t, self.num_key_value_heads, self.feature_dim).transpose(1, 2)
        v = v.view(b, t, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Linear attention
        q, k = self.feature_map(q), self.feature_map(k)
        # add broadcast axes so that k * v forms the per-position outer product
        q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)

        # Compute attention
        if self.causal:
            # running sums over time (dim 2) restrict each position to its past
            y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
        else:
            y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
        y = rearrange(y, 'b h t d -> b t (h d)')
        y = self.o_proj(y.to(hidden_states.dtype))
        y = self.dropout(y)
        return y.to(hidden_states.dtype)
fla/layers/forgetting_attn.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from einops import rearrange
13
+ from transformers.utils import logging
14
+
15
+ from fla.modules import GroupNorm
16
+ from fla.ops.forgetting_attn.parallel import parallel_forgetting_attn
17
+
18
+ if TYPE_CHECKING:
19
+ from fla.models.utils import Cache
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
class ForgettingAttention(nn.Module):
    """Attention layer with a learned per-head forget gate.

    Besides the usual q/k/v projections, a bias-enabled linear layer produces one
    gate logit per head; its log-sigmoid is passed to ``parallel_forgetting_attn``
    together with q, k and v. An optional sigmoid output gate can scale the result
    before the output projection.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        window_size: Optional[int] = None,
        use_output_gate: bool = False,
        layer_idx: int = None
    ):
        super().__init__()

        # head layout
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = self.num_heads if num_kv_heads is None else num_kv_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_dim = self.num_kv_heads * self.head_dim

        # remaining options are stored as given
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm
        self.window_size = window_size
        self.use_output_gate = use_output_gate
        self.layer_idx = layer_idx

        # projections (creation order is kept so parameter registration is unchanged)
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        # one forget-gate logit per head
        self.f_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)

        if use_output_gate:
            self.g_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        if qk_norm:
            # grouped RMS norm with one group per head, applied before the head split
            self.q_norm = GroupNorm(
                num_groups=self.num_heads,
                hidden_size=self.hidden_size,
                is_rms_norm=True,
            )
            self.k_norm = GroupNorm(
                num_groups=self.num_kv_heads,
                hidden_size=self.kv_dim,
                is_rms_norm=True,
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute forgetting attention over ``hidden_states``.

        ``attention_mask`` is only validated for rank; ``past_key_values`` is returned
        untouched. An optional ``cu_seqlens`` keyword is forwarded to the kernel.

        Returns:
            ``(output, None, past_key_values)``.
        """
        assert attention_mask is None or len(attention_mask.shape) == 2, (
            "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
            "for padding purposes (0 indicating padding). "
            "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
        )

        cu_seqlens = kwargs.get('cu_seqlens')

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        # log forget gate, computed in float32
        gate_logits = self.f_proj(hidden_states).float()
        f = F.logsigmoid(gate_logits)

        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        # split the channel axis into (heads, head_dim)
        head_split = '... (h d) -> ... h d'
        q = rearrange(q, head_split, d=self.head_dim)
        k = rearrange(k, head_split, d=self.head_dim)
        v = rearrange(v, head_split, d=self.head_dim)

        o = parallel_forgetting_attn(q, k, v, f, cu_seqlens=cu_seqlens)
        o = rearrange(o, '... h d -> ... (h d)')
        if self.use_output_gate:
            o = self.g_proj(hidden_states).sigmoid() * o

        return self.o_proj(o), None, past_key_values
fla/layers/hgrn2.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # "HGRN2: Gated Linear RNNs with State Expansion"[https://arxiv.org/abs/2404.07904]
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from einops import rearrange
14
+
15
+ from fla.modules import RMSNorm, ShortConvolution
16
+ from fla.modules.activations import swish
17
+ from fla.modules.layernorm import rms_norm_linear
18
+ from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
19
+
20
+ if TYPE_CHECKING:
21
+ from transformers.processing_utils import Unpack
22
+
23
+ from fla.models.utils import Cache
24
+
25
+
26
class HGRN2Attention(nn.Module):
    r"""
    Gated linear-recurrence token mixer that runs on the GLA kernels.

    The input is projected to a query `q`, a forget-gate pre-activation `f` and an input `i`.
    The forget gate yields both the key (`1 - gate`) and the log-space decay (`log(gate)`)
    that are handed to `chunk_gla` / `fused_chunk_gla` / `fused_recurrent_gla`, with `i`
    playing the role of the value.

    Args:
        mode (str, Optional):
            Kernel to use: `chunk`, `fused_recurrent` or `fused_chunk`. Default: `chunk`.
        hidden_size (int, Optional):
            The hidden size of the input. Default: 1024.
        num_heads (int, Optional):
            The number of heads. Derived as `hidden_size // expand_ratio` when `None`. Default: `None`.
        expand_ratio (int, Optional):
            The per-head width of `q`/`f`. Derived as `hidden_size // num_heads` when `None`. Default: 128.
        use_short_conv (bool, Optional):
            Whether to apply a short convolution after each of the three projections. Default: `False`.
        conv_size (int, Optional):
            The kernel size of the short convolution. Default: 4.
        conv_bias (bool, Optional):
            Stored on the module; it is not forwarded to the convolutions built here. Default: `False`.
        elementwise_affine (bool, Optional):
            Passed to the output `RMSNorm`. Default: `True`.
        norm_eps (float, Optional):
            The epsilon of the output `RMSNorm`. Default: 1e-5.
        layer_idx (int, Optional):
            The index of the layer, used to address the cache. Default: `None`.
    """

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        num_heads: Optional[int] = None,
        expand_ratio: Optional[int] = 128,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        layer_idx: Optional[int] = None
    ) -> None:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size

        # whichever of `expand_ratio` / `num_heads` is missing is derived from the other
        if expand_ratio is None and num_heads is not None:
            expand_ratio = hidden_size // num_heads
        elif expand_ratio is not None and num_heads is None:
            num_heads = hidden_size // expand_ratio
        elif expand_ratio is None and num_heads is None:
            raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.")
        self.num_heads = num_heads
        self.expand_ratio = expand_ratio

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.forget_dim = int(self.num_heads * self.expand_ratio)
        self.input_dim = hidden_size
        self.layer_idx = layer_idx

        assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not supported mode `{mode}`."
        assert self.forget_dim % num_heads == 0, f"forget dim must be divisible by num_heads of {num_heads}"
        assert self.input_dim % num_heads == 0, f"input dim must be divisible by num_heads of {num_heads}"

        self.head_f_dim = self.expand_ratio
        self.head_i_dim = self.hidden_size // num_heads

        self.q_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
        self.f_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
        self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)

        if use_short_conv:
            self.q_conv1d = ShortConvolution(self.forget_dim, conv_size, activation=None)
            self.f_conv1d = ShortConvolution(self.forget_dim, conv_size, activation=None)
            self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)

        self.g_norm = RMSNorm(hidden_size=self.hidden_size, elementwise_affine=elementwise_affine, eps=norm_eps)
        self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        lower_bound: Optional[torch.Tensor] = None,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """
        Apply the layer to `hidden_states`.

        Args:
            hidden_states: The input; its second dimension is treated as the sequence length.
            attention_mask: Optional 2-D 0-1 padding mask of shape `[batch_size, seq_len]`.
            past_key_values: Optional cache holding this layer's `conv_state` and `recurrent_state`.
            use_cache: Whether the kernels should return their final states.
            output_attentions: Accepted for interface compatibility; not used here.
            lower_bound: Optional lower bound of the forget gate. Ignored when `layer_idx == 0`.
            **kwargs: Only `cu_seqlens` is read.

        Returns:
            `(output, None, past_key_values)`.
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # inputs of at most 64 tokens always go through the recurrent kernel,
        # whatever mode the layer was configured with
        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_q, conv_state_f, conv_state_i = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_f, conv_state_i = last_state['conv_state']
            # keep only the mask columns that belong to the tokens of this call
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            f, conv_state_f = self.f_conv1d(
                x=self.f_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_f,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            i, conv_state_i = self.i_conv1d(
                x=self.i_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_i,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.q_proj(hidden_states)
            f = self.f_proj(hidden_states)
            i = self.i_proj(hidden_states)

        # dealing with left-padding: zero the inputs at padded positions
        if attention_mask is not None:
            i = i.mul_(attention_mask[:, -i.shape[-2]:, None])

        q = swish(q)

        # improve precision
        f = f.float()

        # the lower bound for the first layer is zero
        if lower_bound is None or self.layer_idx == 0:
            k, g = 1 - f.sigmoid(), F.logsigmoid(f)
        else:
            # squash the gate into [lower_bound, 1] before taking the log
            g = lower_bound + (1 - lower_bound) * f.sigmoid()
            k, g = 1 - g, g.log()

        # split heads; `k` is cast to the dtype/device of `i`, `g` stays in float32
        q, k, g = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_f_dim), (q, k.to(i), g))
        i = rearrange(i, '... (h d) -> ... h d', d=self.head_i_dim)

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_gla(
                q=q,
                k=k,
                v=i,
                gk=g,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        elif mode == 'fused_chunk':
            # NOTE(review): `cu_seqlens` is not forwarded on this path -- confirm whether
            # `fused_chunk_gla` supports variable-length inputs.
            o, recurrent_state = fused_chunk_gla(
                q=q,
                k=k,
                v=i,
                g=g,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                head_first=False
            )
        elif mode == 'chunk':
            o, recurrent_state = chunk_gla(
                q=q,
                k=k,
                v=i,
                g=g,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_f, conv_state_i) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        o = rearrange(o, '... h d -> ... (h d)')
        # output norm and output projection are applied by a single helper
        o = rms_norm_linear(o, self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias)
        return o, None, past_key_values

    def state_size(self, **kwargs) -> int:
        """Return `forget_dim * head_i_dim` plus the `state_size` of every `ShortConvolution` child."""
        state_size = self.forget_dim * self.head_i_dim
        for module in self.children():
            if isinstance(module, ShortConvolution):
                state_size += module.state_size
        return state_size
fla/layers/linear_attn.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
+ from fla.modules import RMSNorm
12
+ from fla.modules.feature_map import DPFPFeatureMap, HadamardFeatureMap, HedgehogFeatureMap, T2RFeatureMap
13
+ from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn, fused_recurrent_linear_attn
14
+
15
+
16
class LinearAttention(nn.Module):
    r"""
    Linear attention layer with a selectable feature map applied to queries and keys.

    Args:
        mode (str, Optional):
            Kernel to use: `chunk`, `fused_chunk` or `fused_recurrent`. Default: `chunk`.
        hidden_size (int, Optional):
            The hidden size of the input. Default: 1024.
        expand_k (float, Optional):
            The expansion ratio for the key dim. Default: 1.0.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 1.0.
        num_heads (int, Optional):
            The number of heads. Default: 8.
        num_kv_heads (int, Optional):
            The number of key/value heads; defaults to `num_heads` when `None`. Default: `None`.
        feature_map (str, Optional):
            One of `hedgehog`, `t2r`, `elementwise_product`, `dpfp`, `elu`, `relu`, `identity`.
            Default: `elementwise_product`.
        tie_feature_map_qk (bool, Optional):
            Share one feature-map module between queries and keys. Only consulted for
            `hedgehog`, `t2r` and `elementwise_product`. Default: `False`.
        output_norm (str, Optional):
            `rmsnorm` or `identity`. Default: `rmsnorm`.
        norm_q (bool, Optional):
            Divide the mapped queries by their sum over the last dim. Default: `False`.
        norm_k (bool, Optional):
            Divide the mapped keys by their sum over the last dim. Default: `False`.
        do_feature_map_norm (bool, Optional):
            Forwarded to the kernels as `normalize`. Default: `False`.
        elementwise_affine (bool, Optional):
            Passed to the output `RMSNorm`. Default: `True`.
        norm_eps (float, Optional):
            The epsilon of the output `RMSNorm`. Default: 1e-5.
    """

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 1.0,
        expand_v: float = 1.0,
        num_heads: int = 8,
        num_kv_heads: Optional[int] = None,
        feature_map: str = 'elementwise_product',
        tie_feature_map_qk: bool = False,
        output_norm: str = 'rmsnorm',
        norm_q: bool = False,
        norm_k: bool = False,
        do_feature_map_norm: bool = False,
        elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        **kwargs
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.mode = mode
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        # width of the shared key/value projections before they are repeated across groups
        self.key_dim_per_group = self.key_dim // self.num_kv_groups
        self.value_dim_per_group = self.value_dim // self.num_kv_groups

        assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_k_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads
        self.do_feature_map_norm = do_feature_map_norm

        if feature_map == 'hedgehog':
            if tie_feature_map_qk:
                self.feature_map_q = self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_k_dim)
            else:
                self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_k_dim)
                self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_k_dim)

        elif feature_map == 't2r':
            if tie_feature_map_qk:
                self.feature_map_q = self.feature_map_k = T2RFeatureMap(head_dim=self.head_k_dim)
            else:
                self.feature_map_q = T2RFeatureMap(head_dim=self.head_k_dim)
                self.feature_map_k = T2RFeatureMap(head_dim=self.head_k_dim)

        elif feature_map == 'elementwise_product':
            if tie_feature_map_qk:
                self.feature_map_q = self.feature_map_k = HadamardFeatureMap(head_dim=self.head_k_dim)
            else:
                self.feature_map_q = HadamardFeatureMap(head_dim=self.head_k_dim)
                self.feature_map_k = HadamardFeatureMap(head_dim=self.head_k_dim)

        elif feature_map == 'dpfp':
            self.feature_map_q = DPFPFeatureMap(head_dim=self.head_k_dim)
            self.feature_map_k = DPFPFeatureMap(head_dim=self.head_k_dim)

        elif feature_map == 'elu':
            def elu(x):
                # shift ELU by one so the mapped features are positive
                return F.elu(x) + 1
            self.feature_map_q = elu
            self.feature_map_k = elu

        elif feature_map == 'relu':
            self.feature_map_q = nn.ReLU()
            self.feature_map_k = nn.ReLU()

        elif feature_map == 'identity':
            self.feature_map_q = nn.Identity()
            self.feature_map_k = nn.Identity()
        else:
            raise NotImplementedError(f"Not supported feature map `{feature_map}`.")

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)

        if output_norm == 'rmsnorm':
            self.norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps)
        elif output_norm == 'identity':
            self.norm = nn.Identity()
        else:
            raise NotImplementedError(f"Not supported output norm `{output_norm}`.")

        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        self.norm_q = norm_q
        self.norm_k = norm_k

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """
        Apply linear attention to `hidden_states` and return the projected output.

        Extra keyword arguments are accepted and ignored.
        """
        mode = self.mode
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # split heads; every tensor below is laid out as `[..., heads, head_dim]`
        q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim)
        if self.num_kv_groups > 1:
            # replicate each key/value head for all query heads in its group
            k = repeat(k, '... (h d) -> ... (h g) d', d=self.head_k_dim, g=self.num_kv_groups)
            v = repeat(v, '... (h d) -> ... (h g) d', d=self.head_v_dim, g=self.num_kv_groups)
        else:
            k = rearrange(k, '... (h d) -> ... h d', d=self.head_k_dim)
            v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)

        q = self.feature_map_q(q)
        k = self.feature_map_k(k)

        if self.norm_q:
            q = q / (q.sum(-1, True) + 1e-4)
        if self.norm_k:
            k = k / (k.sum(-1, True) + 1e-4)

        # All three kernels receive the same `[..., heads, head_dim]` layout, so all three
        # must be told the head axis is not first; previously only `chunk` was.
        if mode == 'chunk':
            o, _ = chunk_linear_attn(
                q=q,
                k=k,
                v=v,
                normalize=self.do_feature_map_norm,
                head_first=False
            )
        elif mode == 'fused_chunk':
            o, _ = fused_chunk_linear_attn(
                q=q,
                k=k,
                v=v,
                normalize=self.do_feature_map_norm,
                head_first=False
            )
        elif mode == 'fused_recurrent':
            o, _ = fused_recurrent_linear_attn(
                q=q,
                k=k,
                v=v,
                normalize=self.do_feature_map_norm,
                head_first=False
            )
        else:
            raise NotImplementedError
        o = self.norm(o)
        o = rearrange(o, '... h d -> ... (h d)')
        o = self.o_proj(o)
        return o
fla/layers/multiscale_retention.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange, repeat
11
+ from transformers.activations import ACT2FN
12
+
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
14
+ from fla.modules.rotary import RotaryEmbedding
15
+ from fla.ops.retention import chunk_retention, fused_chunk_retention, fused_recurrent_retention, parallel_retention
16
+
17
+ if TYPE_CHECKING:
18
+ from fla.models.utils import Cache
19
+
20
+
21
class MultiScaleRetention(nn.Module):
    r"""
    The layer implementation for [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/pdf/2307.08621.pdf).  # noqa

    Args:
        mode (str, Optional):
            Which Retention kernel to use.
            Currently available: `chunk`, `fused_recurrent`, `parallel`, and `fused_chunk`.
            Default: `chunk`.
        hidden_size (int, Optional):
            The hidden size of the input. Default: 1024.
        expand_k (float, Optional):
            The expansion ratio for the key dim. Default: 1.0.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 2.0.
        num_heads (int, Optional):
            The number of heads. Default: 8.
        num_kv_heads (int, Optional):
            The number of key/value heads, used for MQA. Default: None.
        feature_map (str, Optional):
            Feature map function applied to queries/keys. Default: None.
        use_short_conv (bool, Optional):
            Whether to use short convolutions. Default: `False`.
        conv_size (int, Optional):
            The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
        conv_bias (bool, Optional):
            Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
        use_output_gate (bool, Optional):
            Whether to use output gate. Default: `True`.
        gate_fn (str, Optional):
            The activation function for the output gate. Default: `swish`.
        elementwise_affine (bool, Optional):
            If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
        norm_eps (float, Optional):
            The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
        fuse_norm (bool, Optional):
            Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
        layer_idx (int, Optional):
            The index of the layer. Default: None.
    """

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 1.0,
        expand_v: float = 2.0,
        num_heads: int = 8,
        num_kv_heads: Optional[int] = None,
        feature_map: Optional[str] = None,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        use_output_gate: bool = True,
        gate_fn: str = 'swish',
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        fuse_norm: bool = True,
        layer_idx: Optional[int] = None,
        **kwargs
    ) -> None:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias
        self.use_output_gate = use_output_gate

        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        # width of the shared key/value projections before they are repeated across groups
        self.key_dim_per_group = self.key_dim // self.num_kv_groups
        self.value_dim_per_group = self.value_dim // self.num_kv_groups
        self.layer_idx = layer_idx

        assert mode in ['chunk', 'fused_chunk', 'parallel', 'fused_recurrent'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_k_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
        if self.use_output_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)

        if use_short_conv:
            self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
            self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
            self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')

        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        # the fused norm+gate module is only used for a swish output gate
        if gate_fn == 'swish' and fuse_norm and use_output_gate:
            self.g_norm_swish_gate = FusedRMSNormGated(
                hidden_size=self.head_v_dim,
                elementwise_affine=elementwise_affine,
                eps=norm_eps
            )
            self.fuse_norm_and_gate = True
        else:
            self.fuse_norm_and_gate = False
            self.g_norm = RMSNorm(
                hidden_size=self.head_v_dim,
                elementwise_affine=elementwise_affine,
                eps=norm_eps
            )
            self.gate_fn = ACT2FN[gate_fn]

        # TODO: fix this issue
        # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py#L180
        # Ideally, we would want to support arbitrary d_head_qk
        assert self.head_k_dim <= 256, "head_k_dim must be less than or equal to 256"
        self.rotary = RotaryEmbedding(dim=self.head_k_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """
        Apply the retention layer to `hidden_states`.

        Args:
            hidden_states: The input; its second dimension is treated as the sequence length.
            attention_mask: Optional 2-D 0-1 padding mask of shape `[batch_size, seq_len]`.
            past_key_values: Optional cache holding this layer's `conv_state` and `recurrent_state`.
            use_cache: Whether the kernels should return their final states.
            output_attentions: Accepted for interface compatibility; not used here.
            **kwargs: Only `cu_seqlens` is read.

        Returns:
            `(output, None, past_key_values)`.
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # inputs of at most 64 tokens always go through the recurrent kernel,
        # whatever mode the layer was configured with
        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            # keep only the mask columns that belong to the tokens of this call
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)

        # dealing with left-padding: zero the values at padded positions
        if attention_mask is not None:
            v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
        q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim)
        k = rearrange(k, '... (h d) -> ... h d', d=self.head_k_dim)
        if self.feature_map_fn is not None:
            q, k = map(self.feature_map_fn, (q, k))

        seqlen_offset, max_seqlen = 0, q.shape[1]
        if past_key_values is not None:
            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            max_seqlen = q.shape[1] + seqlen_offset

            if attention_mask is not None:
                # subtract each row's number of padding tokens (mask width minus mask sum)
                # from its rotary offset
                seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
                # one tensor reduction instead of iterating the batch with the builtin `max`
                max_seqlen = q.shape[1] + seqlen_offset.max()

        q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)

        if self.num_kv_groups > 1:
            # replicate each key/value head for all query heads in its group
            k = repeat(k, 'b t h d -> b t (h g) d', g=self.num_kv_groups)
            v = repeat(v, 'b t (h d) -> b t (h g) d', d=self.head_v_dim, g=self.num_kv_groups)
        else:
            v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'chunk':
            o, recurrent_state = chunk_retention(
                q=q,
                k=k,
                v=v,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        elif mode == 'fused_chunk':
            o, recurrent_state = fused_chunk_retention(
                q=q,
                k=k,
                v=v,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        elif mode == 'parallel':
            # NOTE(review): no initial state is passed on this path, and the second return
            # value is stored in the cache as `recurrent_state` -- confirm that is intended.
            o, recurrent_state = parallel_retention(
                q=q,
                k=k,
                v=v,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        elif mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_retention(
                q=q,
                k=k,
                v=v,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        if self.use_output_gate:
            g = self.g_proj(hidden_states)
            if self.fuse_norm_and_gate:
                g = rearrange(g, 'b t (h d) -> b t h d', d=self.head_v_dim)
                o = self.g_norm_swish_gate(o, g)
                o = rearrange(o, 'b t h d -> b t (h d)')
            else:
                o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
                o = o * self.gate_fn(g)
        else:
            o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
        o = self.o_proj(o)

        return o, None, past_key_values

    def state_size(self, **kwargs) -> int:
        """Return `key_dim * head_v_dim` plus the `state_size` of every `ShortConvolution` child."""
        state_size = self.key_dim * self.head_v_dim
        for module in self.children():
            if isinstance(module, ShortConvolution):
                state_size += module.state_size
        return state_size
fla/models/abc/__pycache__/configuration_abc.cpython-311.pyc ADDED
Binary file (4.02 kB). View file
 
fla/models/abc/__pycache__/modeling_abc.cpython-311.pyc ADDED
Binary file (19.2 kB). View file
 
fla/models/abc/configuration_abc.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
class ABCConfig(PretrainedConfig):
    """
    Configuration for ABC models.

    Every constructor argument is stored on the instance under the same name, except
    `exapnd_k` / `exapnd_v`, which are stored as `expand_k` / `expand_v`. The misspelled
    parameter names are kept for backward compatibility; the correctly spelled keyword
    arguments `expand_k` / `expand_v` are accepted as well and take precedence.

    When `attn` is given it must be a dict containing at least `layers` and `num_heads`;
    missing `num_kv_heads`, `qkv_bias`, `window_size` and `rope_theta` entries are filled
    with defaults on a copy of the dict.
    """

    model_type = 'abc'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        gate_low_rank_dim: int = 16,
        clamp_min: float = -32,
        clamp_max: float = 32,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        num_slots: Optional[int] = 64,
        use_short_conv: bool = False,
        conv_size: int = 4,
        exapnd_k: float = 0.5,
        exapnd_v: float = 1,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_rope: bool = True,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        # The stored attributes are spelled `expand_k` / `expand_v`, so callers and saved
        # configs supply those keys. Consume them here explicitly instead of relying on the
        # base class to copy leftover keyword arguments onto the instance.
        expand_k = kwargs.pop('expand_k', exapnd_k)
        expand_v = kwargs.pop('expand_v', exapnd_v)

        if attn is not None:
            if not isinstance(attn, dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            # fill the defaults on a copy so the caller's dict is left untouched
            attn = dict(attn)
            attn.setdefault('num_kv_heads', attn['num_heads'])
            attn.setdefault('qkv_bias', False)
            attn.setdefault('window_size', None)
            attn.setdefault('rope_theta', 10000.)

        self.hidden_size = hidden_size
        self.gate_low_rank_dim = gate_low_rank_dim
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_slots = num_slots
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_rope = use_rope
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/bitnet/modeling_bitnet.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.bitattn import BitAttention
19
+ from fla.models.bitnet.configuration_bitnet import BitNetConfig
20
+ from fla.models.utils import Cache
21
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
22
+ from fla.modules.activations import swiglu
23
+ from fla.modules.fused_bitlinear import FusedBitLinear
24
+
25
+ if TYPE_CHECKING:
26
+ from transformers.processing_utils import Unpack
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
class BitNetMLP(nn.Module):
    """
    Gated feed-forward block: `down_proj(swiglu(gate_proj(x), up_proj(x)))`.

    When `intermediate_size` is not given it is derived as `2/3 * hidden_size * hidden_ratio`
    (with `hidden_ratio` defaulting to 4), truncated to an integer and then rounded up to the
    next multiple of 256. Only `hidden_act='swish'` is accepted. `fuse_swiglu` is stored on
    the module but is not consulted in `forward`.
    """

    def __init__(
        self,
        hidden_size: int,
        hidden_ratio: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        hidden_act: str = 'swish',
        fuse_swiglu: bool = True
    ) -> BitNetMLP:
        super().__init__()

        self.hidden_size = hidden_size

        ratio = 4 if hidden_ratio is None else hidden_ratio
        width = intermediate_size
        if width is None:
            target = int(hidden_size * ratio * 2 / 3)
            # ceiling division: round `target` up to the next multiple of 256
            width = -(-target // 256) * 256

        self.hidden_ratio = ratio
        self.intermediate_size = width
        self.hidden_act = hidden_act
        self.fuse_swiglu = fuse_swiglu

        if hidden_act != 'swish':
            raise ValueError(f'Unsupported hidden_act: {hidden_act}')

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        **kwargs: Unpack[Any]
    ) -> torch.Tensor:
        """Project `x` through the gate and up branches, combine them with `swiglu`, and project back down."""
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        return self.down_proj(swiglu(gate, up))
70
+
71
+
72
class BitNetBlock(nn.Module):
    """
    One pre-norm block: normalised attention followed by a normalised MLP, each with a residual connection.

    `config.fuse_norm` selects between the project's `RMSNorm` and `nn.RMSNorm`, and also decides
    whether the first residual addition is delegated to `mlp_norm` or done explicitly.
    """

    def __init__(self, config: BitNetConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        norm_cls = RMSNorm if config.fuse_norm else nn.RMSNorm

        self.attn_norm = norm_cls(config.hidden_size, eps=config.norm_eps)
        self.attn = BitAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            window_size=config.window_size,
            rope_theta=config.rope_theta,
            max_position_embeddings=config.max_position_embeddings,
            layer_idx=layer_idx
        )
        self.mlp_norm = norm_cls(config.hidden_size, eps=config.norm_eps)
        self.mlp = BitNetMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs: Unpack[Any]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Run the block.

        Returns a tuple that starts with the output hidden states, followed by the attention
        weights when `output_attentions` is truthy and by the cache when `use_cache` is truthy.
        """
        shortcut = hidden_states
        attn_out, attentions, past_key_values = self.attn(
            hidden_states=self.attn_norm(hidden_states),
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )

        if self.config.fuse_norm:
            # the norm takes the residual and hands back (normalised, updated residual);
            # the third positional flag is presumably its pre-norm switch -- TODO confirm
            normed, shortcut = self.mlp_norm(attn_out, shortcut, True)
        else:
            shortcut = shortcut + attn_out
            normed = self.mlp_norm(shortcut)

        hidden_states = shortcut + self.mlp(normed, **kwargs)

        extras = []
        if output_attentions:
            extras.append(attentions)
        if use_cache:
            extras.append(past_key_values)
        return (hidden_states, *extras)
137
+
138
+
139
class BitNetPreTrainedModel(PreTrainedModel):
    """Common base for BitNet models: class-level settings plus weight initialization."""

    config_class = BitNetConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['BitNetBlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = False,
        num_residuals_per_layer: int = 2,
    ):
        """Initialize the parameters of a single ``module``.

        Args:
            module: The module to initialize.
            rescale_prenorm_residual: When true, the ``o_proj`` / ``down_proj``
                weight of ``module`` (if it has one) is re-drawn and scaled down
                by ``sqrt(num_residuals_per_layer * num_hidden_layers)``.
            num_residuals_per_layer: Number of residual additions per layer used
                in that scale factor.
        """
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv1d, FusedBitLinear)):
            # Plain normal init (the TF reference uses truncated_normal instead),
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if not rescale_prenorm_residual:
            return

        # GPT-2 style scaled init: weights that write into the residual stream
        # are scaled by 1/sqrt(N), with N the number of residual additions.
        #   > -- GPT-2 :: https://openai.com/blog/better-language-models/
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        weight = None
        for proj_name in ('o_proj', 'down_proj'):
            if hasattr(module, proj_name):
                weight = getattr(module, proj_name).weight
                break
        if weight is None:
            return

        # Re-draw before scaling: this hook may run more than once, and scaling
        # the existing values in place would shrink them on every call.
        nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
        with torch.no_grad():
            weight /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
187
+
188
+
189
class BitNetModel(BitNetPreTrainedModel):
    """Stack of ``BitNetBlock`` layers between a token embedding and a final norm."""

    def __init__(
        self,
        config: BitNetConfig
    ) -> None:
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([BitNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed the inputs, run every layer, and apply the final norm.

        Exactly one of ``input_ids`` and ``inputs_embeds`` must be given.

        Args:
            input_ids: Token ids looked up in ``self.embeddings``.
            attention_mask: Forwarded unchanged to every layer.
            past_key_values: Cache; anything that is not a ``Cache`` is converted
                with ``Cache.from_legacy_cache`` when caching is enabled.
            inputs_embeds: Pre-computed embeddings used instead of ``input_ids``.
            use_cache: Defaults to ``config.use_cache`` outside training and to
                ``False`` during training.
            output_attentions: A truthy value triggers a warning and is reset to
                ``False``; ``None`` falls back to ``config.output_attentions``.
            output_hidden_states: Collect the input of every layer plus the final
                normalized output.
            return_dict: Return a ``BaseModelOutputWithPast`` instead of a tuple.
            **kwargs: Forwarded to every layer.

        Returns:
            A ``BaseModelOutputWithPast``, or, when ``return_dict`` is false, a
            tuple of its non-``None`` fields.

        Raises:
            ValueError: If both or neither of ``input_ids`` / ``inputs_embeds``
                are provided.
        """
        if output_attentions:
            warnings.warn(
                "`BitNetModel` does not support output attention weights now, so `output_attentions` is set to `False`."
            )
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        # The embeddings are fed to the first layer as-is.
        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        next_cache = None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional order must match `BitNetBlock.forward`.
                layer_outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    **kwargs
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    **kwargs
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache sits after the attentions when those are returned.
                next_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )
308
+
309
+
310
class BitNetForCausalLM(BitNetPreTrainedModel, GenerationMixin):
    """``BitNetModel`` with a linear language-modeling head and optional loss."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = BitNetModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Optional user-supplied loss; when left as None, `forward` picks one.
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        """Build the keyword arguments for one generation step.

        With a non-empty cache only the last token of ``input_ids`` is fed.
        ``inputs_embeds`` are used only while the cache is missing or empty.

        Returns:
            A dict with either ``inputs_embeds`` or ``input_ids``, plus
            ``past_key_values``, ``use_cache``, ``attention_mask`` and, when
            given, ``logits_to_keep``.
        """
        # A missing cache counts as empty; calling `len(None)` would raise TypeError.
        has_past = past_key_values is not None and len(past_key_values) > 0

        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if has_past:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and not has_past:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the backbone, then compute logits and/or the language-modeling loss.

        Args:
            labels: Target ids. They are shifted left by one position here, with
                the last position filled with the criterion's ``ignore_index``.
            logits_to_keep: Number of trailing positions passed to the head.
                ``0`` keeps every position because ``-0:`` slices the whole
                axis, and ``None`` also keeps every position.
            **kwargs: Forwarded to ``self.model``.

            The remaining arguments are passed through to ``self.model``.

        Returns:
            A ``CausalLMOutputWithPast``, or a tuple when ``return_dict`` is
            false. ``logits`` is ``None`` when the fused linear + cross-entropy
            path computes the loss directly from the hidden states.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        # The fused head + loss is only used in training mode.
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            # Shift so position t is scored against token t + 1; the final slot is ignored.
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
fla/models/delta_net/__pycache__/configuration_delta_net.cpython-311.pyc ADDED
Binary file (3.99 kB). View file
 
fla/models/delta_net/modeling_delta_net.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.delta_net import DeltaNet
20
+ from fla.models.delta_net.configuration_delta_net import DeltaNetConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as DeltaNetMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+ if TYPE_CHECKING:
29
+ from transformers.processing_utils import Unpack
30
+
31
+
32
class DeltaNetBlock(nn.Module):
    """One layer: a token-mixing sub-layer followed by a gated MLP, both with residual adds."""

    def __init__(self, config: DeltaNetConfig, layer_idx: int):
        """Build the sub-modules for layer ``layer_idx``.

        Layers listed in ``config.attn['layers']`` use ``Attention``; every
        other layer uses ``DeltaNet``.
        """
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        # Same norm class for both sub-layers, selected by `config.fuse_norm`.
        norm_cls = RMSNorm if config.fuse_norm else nn.RMSNorm

        self.attn_norm = norm_cls(config.hidden_size, eps=config.norm_eps)
        use_attention = config.attn is not None and layer_idx in config.attn['layers']
        if use_attention:
            attn_cfg = config.attn
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=attn_cfg['num_heads'],
                num_kv_heads=attn_cfg['num_kv_heads'],
                qkv_bias=attn_cfg['qkv_bias'],
                window_size=attn_cfg['window_size'],
                rope_theta=attn_cfg['rope_theta'],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx
            )
        else:
            self.attn = DeltaNet(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                expand_k=config.expand_k,
                expand_v=config.expand_v,
                num_heads=config.num_heads,
                use_gate=config.use_gate,
                use_beta=config.use_beta,
                use_short_conv=config.use_short_conv,
                use_output_norm=config.use_output_norm,
                conv_size=config.conv_size,
                qk_norm=config.qk_norm,
                qk_activation=config.qk_activation,
                norm_eps=config.norm_eps,
                layer_idx=layer_idx
            )
        self.mlp_norm = norm_cls(config.hidden_size, eps=config.norm_eps)
        self.mlp = DeltaNetMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Apply the mixing sub-layer and the MLP sub-layer.

        Args:
            hidden_states: Input activations of the block.
            attention_mask: Forwarded unchanged to ``self.attn``.
            past_key_values: Cache object forwarded to ``self.attn``.
            use_cache: Forwarded to ``self.attn``.
            output_attentions: Forwarded to ``self.attn``.
            **kwargs: Forwarded to both ``self.attn`` and ``self.mlp``.

        Returns:
            Always the triple ``(hidden_states, attentions, past_key_values)``.
        """
        skip = hidden_states
        mixed, attn_weights, past_key_values = self.attn(
            hidden_states=self.attn_norm(hidden_states),
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )

        if self.config.fuse_norm:
            # The fused norm receives the residual too and returns the
            # normalized activations together with the updated residual.
            # NOTE(review): the positional `True` is presumably a prenorm flag -- confirm.
            normed, skip = self.mlp_norm(mixed, skip, True)
        else:
            skip = skip + mixed
            normed = self.mlp_norm(skip)

        hidden_states = skip + self.mlp(normed, **kwargs)

        return (hidden_states, attn_weights, past_key_values)
108
+
109
+
110
class DeltaNetPreTrainedModel(PreTrainedModel):
    """Common base for DeltaNet models: class-level settings plus weight initialization."""

    config_class = DeltaNetConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['DeltaNetBlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        prenorm_residual_strategy: Optional[str] = 'rescale',
        num_residuals_per_layer: int = 2,
    ):
        """Initialize the parameters of a single ``module``.

        Args:
            module: The module to initialize.
            prenorm_residual_strategy: How to treat the ``o_proj`` / ``down_proj``
                weight of ``module`` when it has one. ``'rescale'`` re-draws it
                and divides by ``sqrt(num_residuals_per_layer * num_hidden_layers)``,
                ``'zero'`` zeroes it, and ``None`` leaves it alone.
            num_residuals_per_layer: Number of residual additions per layer used
                in the ``'rescale'`` factor.

        Raises:
            ValueError: If the strategy is not one of the values above and the
                module has a matching projection.
        """
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Plain normal init (the TF reference uses truncated_normal instead),
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if prenorm_residual_strategy is None:
            return

        # GPT-2 style scaled init: weights that write into the residual stream
        # are scaled by 1/sqrt(N), with N the number of residual additions.
        #   > -- GPT-2 :: https://openai.com/blog/better-language-models/
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        weight = None
        for proj_name in ('o_proj', 'down_proj'):
            if hasattr(module, proj_name):
                weight = getattr(module, proj_name).weight
                break
        if weight is None:
            return

        if prenorm_residual_strategy == 'rescale':
            # Re-draw before scaling: this hook may run more than once, and
            # scaling the existing values in place would shrink them each time.
            nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
            with torch.no_grad():
                weight /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
        elif prenorm_residual_strategy == 'zero':
            nn.init.zeros_(weight)
        else:
            raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
163
+
164
+
165
class DeltaNetModel(DeltaNetPreTrainedModel):
    """Stack of ``DeltaNetBlock`` layers between a token embedding and a final norm."""

    def __init__(self, config: DeltaNetConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        norm_cls = RMSNorm if config.fuse_norm else nn.RMSNorm
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([DeltaNetBlock(config, idx) for idx in range(config.num_hidden_layers)])
        self.norm = norm_cls(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed the inputs, run every layer, and apply the final norm.

        Exactly one of ``input_ids`` and ``inputs_embeds`` must be given.
        Unset flags fall back to the config; ``use_cache`` falls back to
        ``False`` during training. A truthy ``output_attentions`` triggers a
        warning and is reset to ``False``.

        Returns:
            A ``BaseModelOutputWithPast``, or, when ``return_dict`` is false, a
            tuple of its non-``None`` fields.

        Raises:
            ValueError: If both or neither of ``input_ids`` / ``inputs_embeds``
                are provided.
        """
        if output_attentions:
            warnings.warn("`DeltaNetModel` does not `output_attentions` now, setting it to `False`.")
            output_attentions = False

        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if use_cache is None:
            use_cache = False if self.training else cfg.use_cache
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # Exactly one input source is allowed.
        if (input_ids is None) == (inputs_embeds is None):
            if input_ids is not None:
                raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        hidden_states = self.embeddings(input_ids) if inputs_embeds is None else inputs_embeds

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        checkpointing = self.gradient_checkpointing and self.training
        if checkpointing and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        for block in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if checkpointing:
                # Positional order must match `DeltaNetBlock.forward`.
                block_out = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions,
                    **kwargs
                )
            else:
                block_out = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    **kwargs
                )
            hidden_states, attentions, past_key_values = block_out

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)

        # The final normalized output closes the list of hidden states.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if return_dict:
            return BaseModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=past_key_values,
                hidden_states=all_hidden_states,
                attentions=all_attns
            )
        return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
266
+
267
+
268
class DeltaNetForCausalLM(DeltaNetPreTrainedModel, GenerationMixin):
    """``DeltaNetModel`` with a linear language-modeling head and optional loss."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = DeltaNetModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Optional user-supplied loss; when left as None, `forward` picks one.
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        """Delegate to the parent ``generate`` with a clearer error for cache-related failures.

        Raises:
            AttributeError: With an explanatory message when the original error
                mentions ``past_key_values``; any other ``AttributeError`` is
                re-raised unchanged.
        """
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                # Chain the original error so its traceback stays visible.
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                ) from exception
            raise

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        """Build the keyword arguments for one generation step.

        With a non-empty cache only the last token of ``input_ids`` is fed.
        ``inputs_embeds`` are used only while the cache is missing or empty.

        Returns:
            A dict with either ``inputs_embeds`` or ``input_ids``, plus
            ``past_key_values``, ``use_cache``, ``attention_mask`` and, when
            given, ``logits_to_keep``.
        """
        # A missing cache counts as empty; calling `len(None)` would raise TypeError.
        has_past = past_key_values is not None and len(past_key_values) > 0

        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if has_past:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and not has_past:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the backbone, then compute logits and/or the language-modeling loss.

        Args:
            labels: Target ids. They are shifted left by one position here, with
                the last position filled with the criterion's ``ignore_index``.
            logits_to_keep: Number of trailing positions passed to the head.
                ``0`` keeps every position because ``-0:`` slices the whole
                axis, and ``None`` also keeps every position.
            **kwargs: Forwarded to ``self.model``.

            The remaining arguments are passed through to ``self.model``.

        Returns:
            A ``CausalLMOutputWithPast``, or a tuple when ``return_dict`` is
            false. ``logits`` is ``None`` when the fused linear + cross-entropy
            path computes the loss directly from the hidden states.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        # The fused head + loss is only used in training mode.
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            # Shift so position t is scored against token t + 1; the final slot is ignored.
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
fla/models/forgetting_transformer/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.forgetting_transformer.configuration_forgetting_transformer import ForgettingTransformerConfig
6
+ from fla.models.forgetting_transformer.modeling_forgetting_transformer import (
7
+ ForgettingTransformerForCausalLM,
8
+ ForgettingTransformerModel
9
+ )
10
+
11
# Register the Forgetting Transformer classes with the transformers Auto* factories.
# `exist_ok=True` keeps a repeated registration (e.g. a module reload) from
# raising ValueError.
AutoConfig.register(ForgettingTransformerConfig.model_type, ForgettingTransformerConfig, exist_ok=True)
AutoModel.register(ForgettingTransformerConfig, ForgettingTransformerModel, exist_ok=True)
AutoModelForCausalLM.register(ForgettingTransformerConfig, ForgettingTransformerForCausalLM, exist_ok=True)


__all__ = ['ForgettingTransformerConfig', 'ForgettingTransformerForCausalLM', 'ForgettingTransformerModel']
fla/models/forgetting_transformer/modeling_forgetting_transformer.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.forgetting_attn import ForgettingAttention
19
+ from fla.models.forgetting_transformer.configuration_forgetting_transformer import ForgettingTransformerConfig
20
+ from fla.models.utils import Cache
21
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
22
+ from fla.modules import GatedMLP as ForgettingTransformerMLP
23
+ from fla.modules import RMSNorm
24
+
25
+ if TYPE_CHECKING:
26
+ from transformers.processing_utils import Unpack
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
class ForgettingTransformerBlock(nn.Module):
    """A single decoder block: normalised forgetting attention followed by a normalised gated MLP.

    Both sub-layers are wrapped in residual connections. When
    ``config.fuse_norm`` is set, the second norm is called with the residual
    so that the add and the normalisation happen in one call.
    """

    def __init__(self, config: ForgettingTransformerConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        # Pick the norm implementation once; both sub-layers use the same one.
        norm_cls = RMSNorm if config.fuse_norm else nn.RMSNorm

        # NOTE: sub-modules are created in the order attn_norm, attn, mlp_norm, mlp.
        self.attn_norm = norm_cls(config.hidden_size, eps=config.norm_eps)
        attn_kwargs = dict(
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            qkv_bias=config.qkv_bias,
            qk_norm=config.qk_norm,
            window_size=config.window_size,
            use_output_gate=config.use_output_gate,
            layer_idx=layer_idx,
        )
        self.attn = ForgettingAttention(**attn_kwargs)

        self.mlp_norm = norm_cls(config.hidden_size, eps=config.norm_eps)
        mlp_kwargs = dict(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu,
        )
        self.mlp = ForgettingTransformerMLP(**mlp_kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs: Unpack[Any]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run the block on ``hidden_states``.

        Returns:
            A tuple whose first item is the block output. The attention value
            returned by the attention layer is appended when
            ``output_attentions`` is truthy, and the cache returned by the
            attention layer is appended when ``use_cache`` is truthy.
        """
        shortcut = hidden_states
        normed = self.attn_norm(hidden_states)
        attn_out, attentions, past_key_values = self.attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )

        if not self.config.fuse_norm:
            # Plain path: add the residual first, then normalise.
            shortcut = shortcut + attn_out
            mlp_in = self.mlp_norm(shortcut)
        else:
            # Fused path: the norm receives the residual and returns the updated one.
            mlp_in, shortcut = self.mlp_norm(attn_out, shortcut, True)

        block_out = shortcut + self.mlp(mlp_in, **kwargs)

        extras = []
        if output_attentions:
            extras.append(attentions)
        if use_cache:
            extras.append(past_key_values)

        return (block_out, *extras)
99
+
100
+
101
class ForgettingTransformerPreTrainedModel(PreTrainedModel):
    """Shared base class for the Forgetting Transformer models.

    Declares the configuration class and loading-related class attributes,
    and provides the weight initialisation used by the models.
    """

    config_class = ForgettingTransformerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['ForgettingTransformerBlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = False,
        num_residuals_per_layer: int = 2,
    ):
        """Initialise the parameters of ``module``.

        Linear, Conv1d and Embedding weights are drawn from a normal
        distribution with std ``config.initializer_range``; Linear/Conv1d
        biases are zeroed. Any other module exposing ``reset_parameters`` is
        reset through that method.

        If ``rescale_prenorm_residual`` is set, the weight of the module's
        ``o_proj`` (or, failing that, ``down_proj``) is re-initialised and
        divided by ``sqrt(num_residuals_per_layer * num_hidden_layers)``.
        """
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            # Embeddings carry no bias; only Linear/Conv1d biases are zeroed.
            if not isinstance(module, nn.Embedding) and module.bias is not None:
                nn.init.zeros_(module.bias)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if not rescale_prenorm_residual:
            return

        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth.
        #   > Scale the weights of residual layers at initialization by a factor of 1/sqrt(N) where N is the # of
        #   > residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        weight = None
        for proj_name in ('o_proj', 'down_proj'):
            if hasattr(module, proj_name):
                weight = getattr(module, proj_name).weight
                break
        if weight is None:
            return

        # The weight is re-initialised before scaling because this method may run
        # several times; scaling alone would shrink the weight on every call.
        nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
        with torch.no_grad():
            weight /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
149
+
150
+
151
class ForgettingTransformerModel(ForgettingTransformerPreTrainedModel):
    """Token embedding, a stack of `ForgettingTransformerBlock`s and a final norm (no LM head)."""

    def __init__(
        self,
        config: ForgettingTransformerConfig
    ) -> None:
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([
            ForgettingTransformerBlock(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        Exactly one of ``input_ids`` and ``inputs_embeds`` must be given.
        Flags left as ``None`` fall back to the values stored in the config;
        ``use_cache`` additionally defaults to ``False`` while training.

        Returns:
            A `BaseModelOutputWithPast`, or, when ``return_dict`` is falsy, a
            tuple of the non-``None`` values among (last hidden state, cache,
            all hidden states, all attentions).

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are provided.
        """
        # Resolve the config default first so the "unsupported" guard below also
        # catches `config.output_attentions=True` when the argument is left as None.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        if output_attentions:
            warnings.warn(
                "`ForgettingTransformerModel` does not support output attention weights now, "
                "so `output_attentions` is set to `False`."
            )
            output_attentions = False
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Wrap legacy (or missing) caches in the project `Cache` type.
        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        # The embeddings are fed to the first layer as-is.
        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        next_cache = None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    **kwargs
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    **kwargs
                )

            hidden_states = layer_outputs[0]

            # The cache sits after the attentions in the layer's output tuple when both are present.
            if use_cache:
                next_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )
274
+
275
+
276
class ForgettingTransformerForCausalLM(ForgettingTransformerPreTrainedModel, GenerationMixin):
    """`ForgettingTransformerModel` with a linear language-modelling head and next-token loss."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = ForgettingTransformerModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Optional user-supplied loss; when left as None, `forward` picks one from the config.
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped model."""
        return self.model.embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped model."""
        self.model.embeddings = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped decoder model."""
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped decoder model."""
        return self.model

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        """Build the keyword arguments for one generation step.

        With a non-empty cache only the last token of ``input_ids`` is fed.
        ``inputs_embeds`` is used only while there is no populated cache
        (``past_key_values`` is ``None`` or empty), i.e. on the first step.
        """
        # A missing cache counts as empty; calling len() on None would raise TypeError.
        has_cache = past_key_values is not None and len(past_key_values) > 0

        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if has_cache:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and not has_cache:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder and, when ``labels`` are given, compute the next-token loss.

        ``logits_to_keep`` selects the trailing positions fed to the LM head;
        ``0`` or ``None`` keeps every position. When ``config.fuse_cross_entropy``
        is set, the model is training and ``labels`` are given, the LM head and
        the loss are evaluated together and the returned logits are ``None``.

        Returns:
            A `CausalLMOutputWithPast`, or, when ``return_dict`` is falsy, a
            tuple ``(loss?, logits, *rest_of_decoder_outputs)``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        # Logits are skipped only when the fused linear+loss path will actually run;
        # without labels there is no loss, so the logits must still be produced.
        if not fuse_linear_and_cross_entropy or labels is None:
            # `-0:` and "no limit" both mean the full sequence.
            kept_states = hidden_states if not logits_to_keep else hidden_states[:, -logits_to_keep:]
            logits = self.lm_head(kept_states)

        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            # Enable model parallelism
            labels = labels.to(hidden_states.device)
            # Shift labels left by one and pad the last position with the ignore index,
            # so position t is scored against token t + 1.
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
fla/models/gated_deltanet/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig
6
+ from fla.models.gated_deltanet.modeling_gated_deltanet import GatedDeltaNetForCausalLM, GatedDeltaNetModel
7
+
8
# Register the Gated DeltaNet classes with the transformers Auto* factories.
# `exist_ok=True` makes registration idempotent, so re-executing this module
# (e.g. via importlib.reload) does not raise ValueError for an already-known type.
AutoConfig.register(GatedDeltaNetConfig.model_type, GatedDeltaNetConfig, exist_ok=True)
AutoModel.register(GatedDeltaNetConfig, GatedDeltaNetModel, exist_ok=True)
AutoModelForCausalLM.register(GatedDeltaNetConfig, GatedDeltaNetForCausalLM, exist_ok=True)

__all__ = ['GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel']
fla/models/gated_deltanet/__pycache__/modeling_gated_deltanet.cpython-311.pyc ADDED
Binary file (19.4 kB). View file
 
fla/models/gated_deltaproduct/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
2
+
3
+ from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig
4
+ from fla.models.gated_deltaproduct.modeling_gated_deltaproduct import GatedDeltaProductForCausalLM, GatedDeltaProductModel
5
+
6
# Register the Gated DeltaProduct classes with the transformers Auto* factories.
# `exist_ok=True` makes registration idempotent, so re-executing this module
# (e.g. via importlib.reload) does not raise ValueError for an already-known type.
AutoConfig.register(GatedDeltaProductConfig.model_type, GatedDeltaProductConfig, exist_ok=True)
AutoModel.register(GatedDeltaProductConfig, GatedDeltaProductModel, exist_ok=True)
AutoModelForCausalLM.register(GatedDeltaProductConfig, GatedDeltaProductForCausalLM, exist_ok=True)

__all__ = [
    "GatedDeltaProductConfig",
    "GatedDeltaProductForCausalLM",
    "GatedDeltaProductModel",
]
fla/models/gated_deltaproduct/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (838 Bytes). View file
 
fla/models/gla/__pycache__/configuration_gla.cpython-311.pyc ADDED
Binary file (4.14 kB). View file
 
fla/models/gla/__pycache__/modeling_gla.cpython-311.pyc ADDED
Binary file (19.4 kB). View file
 
fla/models/hgrn/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (722 Bytes). View file
 
fla/models/hgrn/configuration_hgrn.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
class HGRNConfig(PretrainedConfig):
    """Configuration for HGRN models.

    Every constructor argument is stored as an attribute of the same name,
    except the token ids and ``tie_word_embeddings``, which are forwarded to
    `PretrainedConfig` together with any extra ``kwargs``.

    ``attn`` optionally describes hybrid attention layers. When given it must
    be a dict containing ``'layers'`` and ``'num_heads'``; the keys
    ``'num_kv_heads'`` (defaults to ``num_heads``), ``'qkv_bias'`` (``False``),
    ``'window_size'`` (``None``) and ``'rope_theta'`` (``10000.``) are filled
    in when missing.

    Raises:
        ValueError: if ``attn`` is not a dict or lacks a required key.
    """

    model_type = 'hgrn'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        attn_mode: str = "fused_recurrent",
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        expand_ratio: Optional[int] = 1,
        use_short_conv: bool = False,
        conv_size: int = 4,
        use_lower_bound: bool = True,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        if attn is not None:
            if not isinstance(attn, dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            # Fill defaults on a shallow copy so the caller's dict is left untouched.
            attn = dict(attn)
            attn.setdefault('num_kv_heads', attn['num_heads'])
            attn.setdefault('qkv_bias', False)
            attn.setdefault('window_size', None)
            attn.setdefault('rope_theta', 10000.)

        self.attn_mode = attn_mode
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.expand_ratio = expand_ratio
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.use_lower_bound = use_lower_bound
        self.max_position_embeddings = max_position_embeddings
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.elementwise_affine = elementwise_affine
        self.attn = attn
        self.norm_eps = norm_eps
        self.hidden_act = hidden_act
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
fla/models/hgrn2/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config
6
+ from fla.models.hgrn2.modeling_hgrn2 import HGRN2ForCausalLM, HGRN2Model
7
+
8
# Register the HGRN2 classes with the transformers Auto* factories.
# `exist_ok=True` makes registration idempotent, so re-executing this module
# (e.g. via importlib.reload) does not raise ValueError for an already-known type.
AutoConfig.register(HGRN2Config.model_type, HGRN2Config, exist_ok=True)
AutoModel.register(HGRN2Config, HGRN2Model, exist_ok=True)
AutoModelForCausalLM.register(HGRN2Config, HGRN2ForCausalLM, exist_ok=True)


__all__ = ['HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model']
fla/models/hgrn2/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (731 Bytes). View file
 
fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-311.pyc ADDED
Binary file (19.7 kB). View file