Azrail committed on
Commit a78a30e · verified · 1 Parent(s): 06918f2

Upload SmalLmForCausalLM

Files changed (4):
  1. README.md +1 -0
  2. config.json +5 -1
  3. config.py +134 -0
  4. model.py +882 -0
README.md CHANGED
@@ -2,6 +2,7 @@
  library_name: transformers
  tags:
  - generated_from_trainer
+ - smallm
  model-index:
  - name: smallm_350
    results: []
config.json CHANGED
@@ -4,6 +4,10 @@
   ],
   "attention_bias": false,
   "attention_dropout": 0.1,
+  "auto_map": {
+    "AutoConfig": "config.SmalLmConfig",
+    "AutoModelForCausalLM": "model.SmalLmForCausalLM"
+  },
   "balancing_coef": 0.0001,
   "bos_token_id": 1,
   "embedding_dropout": 0.0,
@@ -38,7 +42,7 @@
   "sliding_window_attention": true,
   "sliding_window_context": 1024,
   "sliding_window_period": 4,
-  "static_residual": false,
+  "static_residual": true,
   "token_experts": 3,
   "torch_dtype": "float32",
   "transformers_version": "4.50.3",
config.py ADDED
@@ -0,0 +1,134 @@
+ import logging
+ from transformers import PretrainedConfig
+ from typing import Optional
+
+ logger = logging.getLogger(__name__)
+
+
+ class SmalLmConfig(PretrainedConfig):
+     """
+     Base config for all SmalLm models.
+
+     Raises:
+         ValueError: if positional_bias_type is not one of the supported
+             types ('alibi' or 'rope').
+         ValueError: if positional_bias_type is 'rope' and head_size is odd.
+     """
+
+     model_type = "smallm"
+
+     def __init__(
+         self,
+         # global model params
+         hidden_size: int = 512,
+         intermediate_size: int = 2048,
+         mlp_bias: bool = False,
+         num_hidden_layers: int = 27,
+         rms_norm_eps: float = 1e-6,
+         rms_affine: bool = False,
+         initializer_range: float = 0.02,
+         output_hidden_states: bool = False,
+         output_attentions: bool = False,
+         use_cache: bool = True,
+         sliding_window_attention: bool = True,
+         sliding_window_context: int = 1024,
+         sliding_window_period: int = 4,
+         embedding_dropout: float = 0.0,
+         layer_dropout: float = 0.1,
+         max_seq_len: int = 2048,
+         original_seq_len: Optional[int] = None,
+         tie_word_embeddings: bool = True,
+         # attention params
+         num_attention_heads: int = 9,
+         num_kv_heads: int = 3,
+         head_size: Optional[int] = None,
+         attention_dropout: float = 0.1,
+         positional_bias_type: str = "rope",
+         high_rotations: int = 32,
+         low_rotations: int = 1,
+         attention_bias: bool = False,
+         rope_base: int = 100000,
+         # MoE params
+         use_moe: bool = True,
+         moe_period: int = 3,
+         expert_size: int = 256,
+         shared_experts: int = 2,
+         routed_experts: int = 16,
+         token_experts: int = 4,
+         noisy_experts: bool = False,
+         moe_bias: bool = False,
+         balancing_coef: float = 1e-4,
+         no_moe_layers: int = 5,
+         # extra params
+         vocab_size: int = 60000,
+         bos_token_id: int = 1,
+         eos_token_id: int = 0,
+         pad_token_id: int = 0,
+         static_residual: bool = False,
+         **kwargs,
+     ):
+         if positional_bias_type not in ["alibi", "rope"]:
+             raise ValueError(
+                 f"positional_bias_type must be 'alibi' or 'rope', got {positional_bias_type}"
+             )
+         # stored as-is so save/load round-trips cleanly; when True,
+         # WeightedResidual keeps the skip weight frozen at ones
+         self.static_residual = static_residual
+         self.no_moe_layers = no_moe_layers
+         self.moe_bias = moe_bias
+         self.balancing_coef = balancing_coef
+         self.noisy_experts = noisy_experts
+         self.high_rotations = high_rotations
+         self.low_rotations = low_rotations
+         self.positional_bias_type = positional_bias_type
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.mlp_bias = mlp_bias
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_kv_heads = num_kv_heads
+         self.attention_dropout = attention_dropout
+         self.rms_norm_eps = rms_norm_eps
+         self.max_seq_len = max_seq_len
+         self.use_cache = use_cache
+         self.initializer_range = initializer_range
+         self.embedding_dropout = embedding_dropout
+         self.rms_affine = rms_affine
+         self.output_hidden_states = output_hidden_states
+         self.output_attentions = output_attentions
+         self.layer_dropout = layer_dropout
+         self.use_moe = use_moe
+         self.moe_period = moe_period
+         self.expert_size = expert_size
+         self.shared_experts = shared_experts
+         self.routed_experts = routed_experts
+         self.token_experts = token_experts
+         self.intermediate_size = intermediate_size
+         self.attention_bias = attention_bias
+         self.rope_base = rope_base
+         self.head_size = head_size if head_size else hidden_size // num_attention_heads
+         self.original_seq_len = (
+             original_seq_len if original_seq_len is not None else max_seq_len
+         )
+
+         self.sliding_window_attention = sliding_window_attention
+         self.sliding_window_context = sliding_window_context
+         self.sliding_window_period = sliding_window_period
+         if sliding_window_attention and sliding_window_context > max_seq_len:
+             logger.warning(
+                 "sliding_window_context is greater than max_seq_len; "
+                 f"setting sliding_window_context to {max_seq_len}"
+             )
+             self.sliding_window_context = max_seq_len
+         if not sliding_window_attention:
+             self.sliding_window_context = max_seq_len
+
+         if self.head_size % 2 != 0 and self.positional_bias_type == "rope":
+             raise ValueError("head_size must be even when positional_bias_type is 'rope'")
+
+         super().__init__(
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             pad_token_id=pad_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+
+
+ __all__ = ["SmalLmConfig"]
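A quick sanity sketch of the validation above (assumes it is run next to config.py):

    from config import SmalLmConfig

    cfg = SmalLmConfig()  # defaults: 9 heads, head_size = 512 // 9 = 56, RoPE
    assert cfg.head_size == 56 and cfg.positional_bias_type == "rope"

    # RoPE rotates pairs of dimensions, so an odd head_size is rejected:
    try:
        SmalLmConfig(head_size=57)
    except ValueError as err:
        print(err)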
model.py ADDED
@@ -0,0 +1,882 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel, GenerationMixin
+ from transformers.cache_utils import Cache, DynamicCache
+ from transformers.modeling_outputs import (
+     BaseModelOutputWithPast,
+     CausalLMOutputWithPast,
+ )
+ from .config import SmalLmConfig
+ from typing import Optional
+ import logging
+ from einops import rearrange
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+ from einops._torch_specific import allow_ops_in_compiled_graph
+
+ allow_ops_in_compiled_graph()
+ from transformers.utils import is_flash_attn_2_available
+
+ if is_flash_attn_2_available():
+     from flash_attn import flash_attn_varlen_func
+     from flash_attn.bert_padding import unpad_input, pad_input
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class SwiGLU(nn.Module):
+     def __init__(
+         self, input_size: int, hidden_size: int, bias: bool = False, *args, **kwargs
+     ):
+         super().__init__(*args, **kwargs)
+         self.input_size = input_size
+         self.hidden_size = hidden_size
+         # one fused projection produces both the up and the gate branch
+         self.up_proj = nn.Linear(input_size, hidden_size * 2, bias=bias)
+         self.down_proj = nn.Linear(hidden_size, input_size, bias=bias)
+
+     def forward(self, x):
+         up_gate = self.up_proj(x)
+         up, gate = rearrange(up_gate, "... (d span) -> span ... d", d=self.hidden_size)
+         down = F.silu(gate) * up
+         return self.down_proj(down)
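A quick shape check of the gated MLP above (a minimal sketch with the class in scope):

    import torch

    mlp = SwiGLU(input_size=512, hidden_size=2048)
    x = torch.randn(2, 16, 512)                   # [batch, seq, hidden]
    assert mlp.up_proj(x).shape == (2, 16, 4096)  # fused up and gate branches
    assert mlp(x).shape == (2, 16, 512)           # projected back to input_size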
+
+
+ class Router(nn.Module):
+     """
+     Routes tokens to the routed experts of the MoE layer.
+     """
+
+     def __init__(self, config: SmalLmConfig, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.config = config
+         self.experts_to_select = self.config.token_experts - self.config.shared_experts
+         self.gate = nn.Linear(config.hidden_size, config.routed_experts, bias=False)
+         self.gate_noise = (
+             nn.Linear(config.hidden_size, config.routed_experts, bias=False)
+             if config.noisy_experts is True
+             else None
+         )
+         self.bias_coef = config.balancing_coef
+         self.register_buffer(
+             "bias", torch.zeros(config.routed_experts), persistent=True
+         )
+         self.register_buffer(
+             "expert_counts", torch.zeros(config.routed_experts), persistent=False
+         )
+
+     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
+         # x: [num_tokens, hidden_size] -> gate_logits: [num_tokens, routed_experts]
+         gate_logits = self.gate(x)
+         if self.gate_noise is not None:
+             gate_logits_noise = F.softplus(self.gate_noise(x))
+             gate_logits_noise = torch.randn_like(gate_logits_noise) * gate_logits_noise
+             gate_logits = gate_logits + gate_logits_noise
+
+         gate_weights = gate_logits.sigmoid()
+         original_weights = gate_weights
+
+         # the balancing bias only influences which experts are selected,
+         # not the mixing weights
+         gate_weights = gate_weights + self.bias
+
+         _, top_experts_idx = torch.topk(gate_weights, self.experts_to_select, dim=-1)
+         counts = torch.bincount(
+             top_experts_idx.flatten(), minlength=self.config.routed_experts
+         ).detach()
+         if self.training:
+             self.expert_counts += counts
+         top_experts_weights = original_weights.gather(1, top_experts_idx)
+         top_experts_weights = top_experts_weights / top_experts_weights.sum(
+             dim=-1, keepdim=True
+         )
+         return top_experts_idx, top_experts_weights.type_as(x), counts.tolist()
+
+     def update_bias(self):
+         mean = self.expert_counts.float().mean()
+         delta = self.bias_coef * torch.sign(mean - self.expert_counts)
+         self.bias += delta
+         self.expert_counts.zero_()
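The bias buffer implements the auxiliary-loss-free balancing cited on the MoE class below: expert usage accumulates in expert_counts during training, and update_bias nudges under-used experts up and over-used experts down. It has to be called from the training loop; a sketch (model, batch, and optimizer are placeholder names):

    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    with torch.no_grad():
        for module in model.modules():
            if isinstance(module, Router):
                module.update_bias()  # refresh routing bias from usage counts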
+
+
+ class MoE(nn.Module):
+     """
+     MoE layer with shared and routed experts, in the style of DeepSeekMoE,
+     using Auxiliary-Loss-Free Load Balancing
+     ref: https://arxiv.org/abs/2408.15664
+     """
+
+     def __init__(self, config: SmalLmConfig, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.config = config
+         # the shared experts are fused into one wider SwiGLU
+         self.shared_experts = SwiGLU(
+             config.hidden_size,
+             config.shared_experts * config.expert_size,
+             config.moe_bias,
+         )
+         self.routed_experts = nn.ModuleList(
+             [
+                 SwiGLU(config.hidden_size, config.expert_size, config.moe_bias)
+                 for _ in range(config.routed_experts)
+             ]
+         )
+         self.router = Router(config)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         shape = x.size()
+         x = x.view(-1, self.config.hidden_size)
+         experts_idx, experts_weights, counts = self.router(x)
+         out = torch.zeros_like(x)
+         for i, expert in enumerate(self.routed_experts):
+             if counts[i] == 0:
+                 continue
+             idx, pos = torch.where(experts_idx == i)
+             out[idx] += expert(x[idx]) * experts_weights[idx, pos, None]
+         shared_out = self.shared_experts(x)
+         return (out + shared_out).view(shape)
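With the Python defaults (expert_size=256, shared_experts=2, routed_experts=16, token_experts=4), each token activates the fused shared SwiGLU (hidden size 2 * 256 = 512) plus two routed experts of size 256, i.e. 1024 of the 4608 expert hidden units; note that the config.json in this commit sets token_experts to 3, so the shipped checkpoint selects one routed expert per token. A sketch of that bookkeeping:

    from config import SmalLmConfig

    cfg = SmalLmConfig()
    active = (cfg.shared_experts + cfg.token_experts - cfg.shared_experts) * cfg.expert_size
    total = (cfg.shared_experts + cfg.routed_experts) * cfg.expert_size
    print(active, total)  # 1024 4608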
+
+
+ def build_alibi_bias(config: SmalLmConfig) -> torch.Tensor:
+     """
+     Build the ALiBi slopes for the configured number of heads
+     ref: https://arxiv.org/abs/2108.12409v2
+
+     Returns:
+         Tensor of per-head slopes, shape: [num_heads]
+     """
+     bias = (
+         2**-8
+         / config.num_attention_heads
+         * torch.arange(1, config.num_attention_heads + 1).float()
+     )
+     return bias
+
+
+ def calc_rotation(num_rotations, dim, base, seq_len) -> torch.Tensor:
+     """
+     In terms of wavelength, calculate the dimension index whose rotation
+     frequency completes the given number of rotations over the sequence length.
+     """
+     return (
+         dim
+         * torch.log(torch.tensor(seq_len).float() / (num_rotations * 2 * torch.pi))
+         / torch.log(torch.tensor(base))
+     )
+
+
+ def get_ramp_interpolation(min_idx, max_idx, thetas_dim, eps=1e-6) -> torch.Tensor:
+     """
+     Ramp interpolation function that preserves high frequencies and stretches
+     low frequencies.
+     """
+     if min_idx == max_idx:
+         max_idx += eps
+     mult = (torch.arange(thetas_dim) - min_idx) / (max_idx - min_idx)
+     mult = torch.clamp(mult, 0, 1)
+     return 1 - mult
+
+
+ def build_rope_bias(config: SmalLmConfig) -> torch.Tensor:
+     """
+     Build the RoPE rotations for the configured head size and maximum sequence
+     length, using complex numbers for simplicity and convenience
+     ref: https://arxiv.org/abs/2104.09864v5
+     Also applies the NTK-by-parts interpolation method
+     ref: https://arxiv.org/abs/2309.00071
+     good explanation: https://blog.eleuther.ai/yarn/
+
+     Args:
+         config (SmalLmConfig): base model config
+
+     Returns:
+         torch.Tensor: complex rotations, shape: [max_seq_len, head_size // 2]
+     """
+     dim = config.head_size
+
+     theta = 1.0 / (config.rope_base ** (torch.arange(0, dim, 2).float() / dim))
+
+     # NTK-by-parts correction
+     if config.max_seq_len > config.original_seq_len:
+         scale = config.max_seq_len / config.original_seq_len
+         # from lambda = 2 * pi / theta_i and lambda = seq_len / num_rotations,
+         # where lambda is the wavelength
+         low_interpolation_idx = max(
+             0,
+             torch.ceil(
+                 calc_rotation(
+                     config.high_rotations,
+                     dim,
+                     config.rope_base,
+                     config.original_seq_len,
+                 )
+             ).item(),
+         )
+         high_interpolation_idx = min(
+             dim - 1,
+             torch.floor(
+                 calc_rotation(
+                     config.low_rotations, dim, config.rope_base, config.original_seq_len
+                 )
+             ).item(),
+         )
+         interpolation_mult = get_ramp_interpolation(
+             low_interpolation_idx, high_interpolation_idx, dim // 2
+         )
+         theta = (1 - interpolation_mult) * theta / scale + interpolation_mult * theta
+
+     seq_idx = torch.arange(config.max_seq_len)
+     seq_theta = torch.outer(seq_idx, theta)
+     bias = torch.polar(torch.ones_like(seq_theta), seq_theta)
+     return bias
+
+
+ def apply_rope_bias(x: torch.Tensor, precompute_bias: torch.Tensor) -> torch.Tensor:
+     """
+     Apply the RoPE rotations in complex space
+
+     Args:
+         x (torch.Tensor): input embeddings for the heads
+         precompute_bias (torch.Tensor): precomputed complex rotations
+
+     Returns:
+         torch.Tensor: rotated embeddings
+     """
+     ini_dtype = x.dtype
+     # convert to fp32 for numerical stability
+     x = rearrange(x.float(), "b n s (d i) -> b n s d i", i=2).contiguous()
+     x = torch.view_as_complex(x)
+     x = x * precompute_bias
+     x = torch.view_as_real(x)
+     x = rearrange(x, "b n s d i -> b n s (d i)")
+     return x.to(ini_dtype)
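Because the precomputed rotations are unit-modulus complex numbers, applying them preserves per-head vector norms; a minimal check with the two functions above in scope:

    import torch
    from config import SmalLmConfig

    cfg = SmalLmConfig()
    rope = build_rope_bias(cfg)  # [max_seq_len, head_size // 2], complex
    q = torch.randn(1, cfg.num_attention_heads, 16, cfg.head_size)
    q_rot = apply_rope_bias(q, rope[:16])
    assert q_rot.shape == q.shape
    assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)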
+
+
+ def flash_attention_forward(
+     module: nn.Module,
+     x: torch.Tensor,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: torch.Tensor,
+     alibi_slope: Optional[torch.Tensor],
+ ) -> tuple[torch.Tensor, None]:
+     query = rearrange(query, "b n s d -> b s n d")
+     key = rearrange(key, "b n s d -> b s n d")
+     value = rearrange(value, "b n s d -> b s n d")
+     query, idx_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(query, attention_mask)
+     key, _, cu_seqlens_k, max_seqlen_k, _ = unpad_input(key, attention_mask)
+     value, _, _, _, _ = unpad_input(value, attention_mask)
+
+     key = key.contiguous()
+     value = value.contiguous()
+     query = query.contiguous()
+
+     attention_probs = flash_attn_varlen_func(
+         query,
+         key,
+         value,
+         cu_seqlens_q=cu_seqlens_q,
+         cu_seqlens_k=cu_seqlens_k,
+         max_seqlen_q=max_seqlen_q,
+         max_seqlen_k=max_seqlen_k,
+         dropout_p=module.config.attention_dropout if module.training else 0.0,
+         causal=True,
+         # slopes are only defined when the positional bias is ALiBi
+         alibi_slopes=(
+             alibi_slope
+             if module.config.positional_bias_type == "alibi"
+             else None
+         ),
+     )
+     attention_probs = pad_input(attention_probs, idx_q, x.size(0), x.size(1))
+     out = rearrange(attention_probs, "b s n d -> b s (n d)")
+     return out, None
+
+
+ def sdpa_attention_forward(
+     module: nn.Module,
+     x: torch.Tensor,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: torch.Tensor,
+     alibi_slope: Optional[torch.Tensor],
+ ) -> tuple[torch.Tensor, None]:
+     is_causal = attention_mask is None and query.size(-2) > 1
+
+     attention_probs = F.scaled_dot_product_attention(
+         query,
+         key,
+         value,
+         attn_mask=attention_mask,
+         enable_gqa=True,
+         is_causal=is_causal,
+         dropout_p=module.config.attention_dropout if module.training else 0.0,
+     )
+     out = rearrange(attention_probs, "b n s d -> b s (n d)")
+
+     return out, None
+
+
+ def eager_attention_forward(
+     module: nn.Module,
+     x: torch.Tensor,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: torch.Tensor,
+     alibi_slope: Optional[torch.Tensor],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     # grouped-query attention: fold the query heads into [kv_heads, group]
+     query = rearrange(
+         query,
+         "b (kv group) s d -> b kv group s d",
+         kv=module.config.num_kv_heads,
+         group=module.head_per_group,
+     )
+     key = rearrange(key, "b kv s d -> b kv 1 s d")
+     value = rearrange(value, "b kv s d -> b kv 1 s d")
+     attention_weights = query @ key.transpose(-1, -2)
+     attention_weights = attention_weights / torch.sqrt(
+         torch.tensor(value.size(-1), device=x.device)
+     )
+     # prefer the prepared mask: when ALiBi is active it already carries the
+     # positional bias (see SmalLmModel.forward), so nothing is double counted
+     if attention_mask is not None:
+         attention_mask = attention_mask.expand(
+             -1, module.config.num_attention_heads, -1, -1
+         )
+         attention_mask = rearrange(
+             attention_mask,
+             "b (kv group) s1 s2 -> b kv group s1 s2",
+             kv=module.config.num_kv_heads,
+             group=module.head_per_group,
+         )
+         attention_probs = attention_weights + attention_mask
+     elif alibi_slope is not None:
+         alibi_slope = rearrange(
+             alibi_slope,
+             "b (kv group) s1 s2 -> b kv group s1 s2",
+             kv=module.config.num_kv_heads,
+             group=module.head_per_group,
+         )
+         attention_probs = attention_weights + alibi_slope
+     else:
+         attention_probs = attention_weights
+     attention_probs = F.softmax(attention_probs, dim=-1)
+     # dropout belongs after the softmax, on the attention probabilities
+     attention_probs = F.dropout(
+         attention_probs, p=module.config.attention_dropout if module.training else 0.0
+     )
+     attention_probs = attention_probs @ value
+     out = rearrange(attention_probs, "b kv group s d -> b s (kv group d)")
+     return out, attention_weights
+
+
+ ALL_ATTENTION_FUNCTIONS = {
+     "eager": eager_attention_forward,
+     "sdpa": sdpa_attention_forward,
+     "flash_attention_2": flash_attention_forward,
+ }
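The key used for this lookup is config._attn_implementation, which transformers sets from the attn_implementation argument at load time; a sketch (repo id assumed as above):

    model = AutoModelForCausalLM.from_pretrained(
        "Azrail/smallm_350",          # assumed repo id
        trust_remote_code=True,
        attn_implementation="sdpa",   # or "eager" / "flash_attention_2"
    )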
+
+
+ class CausalSelfAttention(nn.Module):
+     """
+     Causal self-attention supporting several implementations,
+     currently: eager (native torch), sdpa, and flash_attention_2
+     """
+
+     def __init__(self, config: SmalLmConfig, layer_idx: int, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         if config.num_attention_heads % config.num_kv_heads != 0:
+             raise ValueError("num_attention_heads must be divisible by num_kv_heads")
+
+         self.config = config
+         self.layer_idx = layer_idx
+         self.head_per_group = config.num_attention_heads // config.num_kv_heads
+         self.q_proj = nn.Linear(
+             config.hidden_size,
+             config.head_size * config.num_attention_heads,
+             bias=config.attention_bias,
+         )
+         self.kv_proj = nn.Linear(
+             config.hidden_size,
+             config.head_size * config.num_kv_heads * 2,
+             bias=config.attention_bias,
+         )
+         self.out_proj = nn.Linear(
+             config.head_size * config.num_attention_heads,
+             config.hidden_size,
+             bias=config.attention_bias,
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attention_mask: torch.Tensor,
+         past_key_values: Optional[Cache | torch.FloatTensor],
+         cache_position: Optional[torch.LongTensor],
+         bias: torch.Tensor,
+     ):
+         q = self.q_proj(x)
+         kv = self.kv_proj(x)
+         q = rearrange(q, "b s (n d) -> b n s d", n=self.config.num_attention_heads)
+         k, v = rearrange(kv, "b s (n d q) -> q b n s d", q=2, d=self.config.head_size)
+
+         if self.config.positional_bias_type == "rope":
+             k = apply_rope_bias(k, bias)
+             q = apply_rope_bias(q, bias)
+
+         if past_key_values is not None:
+             # cache_position is needed for a static cache
+             cache_kwargs = {"cache_position": cache_position}
+             k, v = past_key_values.update(
+                 key_states=k,
+                 value_states=v,
+                 layer_idx=self.layer_idx,
+                 cache_kwargs=cache_kwargs,
+             )
+
+         attention_interface = eager_attention_forward
+         if self.config._attn_implementation != "eager":
+             attention_interface = ALL_ATTENTION_FUNCTIONS[
+                 self.config._attn_implementation
+             ]
+
+         out, attention_weights = attention_interface(
+             self,
+             x,
+             q,
+             k,
+             v,
+             attention_mask,
+             bias if self.config.positional_bias_type == "alibi" else None,
+         )
+
+         out = self.out_proj(out)
+         return out, attention_weights
+
+
+ class WeightedResidual(nn.Module):
+     """
+     Weighted residual connection with an optionally learnable skip weight.
+     """
+
+     def __init__(self, config: SmalLmConfig, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         # static_residual=True keeps the skip weight frozen at ones
+         self.weight = nn.Parameter(
+             torch.ones(config.hidden_size), requires_grad=not config.static_residual
+         )
+
+     def forward(self, short, long):
+         return self.weight * short + long
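With static_residual=True the skip weight stays frozen at ones and the module reduces to a plain residual connection; with the default False it is a per-channel learned gate on the skip path. A small sketch:

    from config import SmalLmConfig

    learned = WeightedResidual(SmalLmConfig(static_residual=False))
    frozen = WeightedResidual(SmalLmConfig(static_residual=True))
    print(learned.weight.requires_grad, frozen.weight.requires_grad)  # True False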
+
+
+ class Block(nn.Module):
+     def __init__(self, config: SmalLmConfig, layer_idx: int, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.attn_norm = nn.RMSNorm(
+             config.hidden_size,
+             eps=config.rms_norm_eps,
+             elementwise_affine=config.rms_affine,
+         )
+         self.ffn_norm = nn.RMSNorm(
+             config.hidden_size,
+             eps=config.rms_norm_eps,
+             elementwise_affine=config.rms_affine,
+         )
+         self.dropout1 = nn.Dropout(config.layer_dropout)
+         self.dropout2 = nn.Dropout(config.layer_dropout)
+         self.attention = CausalSelfAttention(config, layer_idx)
+         # every moe_period-th layer after the first no_moe_layers uses MoE
+         self.mlp = (
+             MoE(config)
+             if (
+                 config.use_moe
+                 and layer_idx % config.moe_period == 0
+                 and layer_idx > config.no_moe_layers
+             )
+             else SwiGLU(config.hidden_size, config.intermediate_size, config.mlp_bias)
+         )
+         self.attention_residual = WeightedResidual(config)
+         self.ffn_residual = WeightedResidual(config)
+
+     def forward(
+         self,
+         inputs_embeds: torch.Tensor,
+         attention_mask: torch.Tensor,
+         past_key_values: Optional[Cache | torch.FloatTensor],
+         output_attentions: bool,
+         cache_position: Optional[torch.LongTensor],
+         bias: torch.Tensor,
+     ) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
+         identity = inputs_embeds
+
+         # attention block
+         out = self.attn_norm(inputs_embeds)
+         out, attention_probs = self.attention(
+             out, attention_mask, past_key_values, cache_position, bias
+         )
+         out = self.dropout1(out)
+         identity = self.attention_residual(identity, out)
+
+         # SwiGLU / MoE block
+         out = self.dropout2(self.mlp(self.ffn_norm(identity)))
+         out = self.ffn_residual(identity, out)
+         if output_attentions:
+             return out, attention_probs
+         return (out,)
+
+
+ class SmalLmPreTrainedModel(PreTrainedModel):
+     config_class = SmalLmConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["Block"]
+     _skip_keys_device_placement = "past_key_values"
+     _supports_sdpa = True
+     _supports_flash_attn_2 = True
+
+     def __init__(self, *inputs, **kwargs):
+         super().__init__(*inputs, **kwargs)
+
+     def _init_weights(self, module):
+         std = self.config.initializer_range
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+             # use the config's pad token id: not every subclass defines self.pad_idx
+             module.weight.data[self.config.pad_token_id].zero_()
+
+
+ class SmalLmModel(SmalLmPreTrainedModel):
+     def __init__(self, config: SmalLmConfig, *args, **kwargs):
+         super().__init__(config, *args, **kwargs)
+         self.config = config
+         self.pad_idx = config.pad_token_id
+         self.pad_token_id = config.pad_token_id
+         self.vocab_size = config.vocab_size
+         precompute_bias = (
+             build_alibi_bias(config)
+             if config.positional_bias_type == "alibi"
+             else build_rope_bias(config)
+         )
+         self.register_buffer("precompute_bias", precompute_bias, persistent=False)
+         # don't forget weight sharing with the output head:
+         # self.embedding.weight = self.output.weight (handled via tie_word_embeddings)
+         self.embedding = nn.Embedding(
+             self.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
+         )
+         self.embedding_dropout = nn.Dropout(config.embedding_dropout)
+         self.layers = nn.ModuleList(
+             [Block(config, idx) for idx in range(1, config.num_hidden_layers + 1)]
+         )
+         self.out_norm = nn.RMSNorm(
+             config.hidden_size,
+             eps=config.rms_norm_eps,
+             elementwise_affine=config.rms_affine,
+         )
+
+         self.gradient_checkpointing = False
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embedding
+
+     def set_input_embeddings(self, value):
+         self.embedding = value
+
+     def forward(
+         self,
+         # input options
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         # output options
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         # cache options
+         use_cache: Optional[bool] = None,
+         past_key_values: Optional[Cache | torch.FloatTensor] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ) -> tuple | BaseModelOutputWithPast:
+         # resolve optional flags against the config
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         use_cache = (
+             use_cache
+             if use_cache is not None
+             else (False if self.training else self.config.use_cache)
+         )
+         return_dict = (
+             return_dict if return_dict is not None else self.config.return_dict
+         )
+
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError(
+                 "You must specify exactly one of input_ids or inputs_embeds, not both"
+             )
+
+         if self.training and use_cache:
+             use_cache = False
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embedding(input_ids)
+
+         if use_cache and past_key_values is None:
+             past_key_values = DynamicCache()
+
+         # calculate positions (needed for StaticCache)
+         if cache_position is None:
+             last_position = (
+                 past_key_values.get_seq_length() if past_key_values is not None else 0
+             )
+             cache_position = torch.arange(
+                 last_position,
+                 last_position + inputs_embeds.size(1),
+                 device=inputs_embeds.device,
+             )
+
+         causal_mask = self._get_causal_masks(
+             attention_mask, inputs_embeds, past_key_values, cache_position
+         )
+         if self.config.positional_bias_type == "rope":
+             end_pos = (
+                 inputs_embeds.size(1)
+                 if past_key_values is None
+                 else cache_position[-1] + 1
+             )
+             start_pos = 0 if past_key_values is None else cache_position[0]
+             bias = self.precompute_bias[start_pos:end_pos]
+
+         elif self.config.positional_bias_type == "alibi":
+             if self.config._attn_implementation == "flash_attention_2":
+                 # flash-attn applies the slopes itself
+                 bias = self.precompute_bias
+             else:
+                 i = torch.arange(
+                     (
+                         inputs_embeds.size(1)
+                         if past_key_values is None
+                         else cache_position[-1] + 1
+                     ),
+                     device=inputs_embeds.device,
+                 )
+                 # ALiBi penalizes attention by the negative distance to the key
+                 bias = i[None, :] - i[:, None]
+                 bias = torch.tril(bias).expand(
+                     inputs_embeds.size(0), self.config.num_attention_heads, -1, -1
+                 ) * rearrange(self.precompute_bias, "n -> 1 n 1 1")
+                 if causal_mask is not None:
+                     causal_mask = causal_mask + bias
+                 else:
+                     causal_mask = bias
+
+         hidden_state = self.embedding_dropout(inputs_embeds)
+         hidden_states = [hidden_state] if output_hidden_states else None
+         attentions = [] if output_attentions else None
+         for idx, layer in enumerate(self.layers, 1):
+             if self.gradient_checkpointing:
+                 # for details see:
+                 # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3107
+                 # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3149
+                 layer_out = self._gradient_checkpointing_func(
+                     layer.__call__,
+                     hidden_state,
+                     causal_mask,
+                     past_key_values,
+                     output_attentions,
+                     cache_position,
+                     bias,
+                 )
+             else:
+                 layer_out = layer(
+                     hidden_state,
+                     causal_mask,
+                     past_key_values,
+                     output_attentions,
+                     cache_position,
+                     bias,
+                 )
+             hidden_state = layer_out[0]
+             if output_hidden_states:
+                 hidden_states.append(hidden_state)
+             if output_attentions:
+                 attentions.append(layer_out[1])
+
+         hidden_state = self.out_norm(hidden_state)
+         out = BaseModelOutputWithPast(
+             last_hidden_state=hidden_state,
+             past_key_values=past_key_values if use_cache else None,
+             hidden_states=tuple(hidden_states) if hidden_states is not None else None,
+             attentions=tuple(attentions) if attentions is not None else None,
+         )
+         return out if return_dict else out.to_tuple()
+
+     def _get_causal_masks(
+         self,
+         attention_mask: Optional[torch.Tensor],
+         inputs_embeds: torch.Tensor,
+         past_key_values: Optional[torch.Tensor],
+         cache_position: Optional[torch.Tensor],
+     ):
+         if self.config._attn_implementation == "flash_attention_2":
+             if attention_mask is None:
+                 attention_mask = torch.ones(
+                     (inputs_embeds.size(0), inputs_embeds.size(1)),
+                     device=inputs_embeds.device,
+                 ).long()
+             return attention_mask
+         dtype, device = inputs_embeds.dtype, inputs_embeds.device
+         past_token = (
+             past_key_values.get_seq_length() if past_key_values is not None else 0
+         )
+         if attention_mask is not None and torch.all(attention_mask == 0.0):
+             return None
+         # only sdpa may fall back to is_causal and skip building the 4d mask;
+         # eager relies on the additive mask for causality
+         if (
+             self.config._attn_implementation == "sdpa"
+             and AttentionMaskConverter._ignore_causal_mask_sdpa(
+                 attention_mask=attention_mask,
+                 inputs_embeds=inputs_embeds,
+                 past_key_values_length=past_token,
+                 is_training=self.training,
+             )
+         ):
+             return None
+
+         sequence_length = inputs_embeds.size(1)
+         target_length = (
+             attention_mask.size(-1)
+             if isinstance(attention_mask, torch.Tensor)
+             else past_token + sequence_length + 1
+         )
+
+         causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+             attention_mask=attention_mask,
+             sequence_length=sequence_length,
+             target_length=target_length,
+             dtype=dtype,
+             device=device,
+             cache_position=cache_position,
+             batch_size=inputs_embeds.size(0),
+         )
+
+         min_dtype = torch.finfo(dtype).min
+         causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+         return causal_mask
+
+     @staticmethod
+     def _prepare_4d_causal_attention_mask_with_cache_position(
+         attention_mask: Optional[torch.Tensor],
+         sequence_length: int,
+         target_length: int,
+         dtype: torch.dtype,
+         device: torch.device,
+         cache_position: Optional[torch.Tensor],
+         batch_size: int,
+     ):
+         if attention_mask is not None and attention_mask.dim() == 4:
+             causal_mask = attention_mask
+         else:
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = torch.full(
+                 (sequence_length, target_length),
+                 fill_value=min_dtype,
+                 dtype=dtype,
+                 device=device,
+             )
+             if sequence_length != 1:
+                 causal_mask = torch.triu(causal_mask, diagonal=1)
+             causal_mask *= torch.arange(
+                 target_length, device=device
+             ) > cache_position.reshape(-1, 1)
+             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+             if attention_mask is not None:
+                 causal_mask = causal_mask.clone()
+                 mask_length = attention_mask.shape[-1]
+                 padding_mask = (
+                     causal_mask[:, :, :, :mask_length]
+                     + attention_mask[:, None, None, :]
+                 )
+                 padding_mask = padding_mask == 0
+                 causal_mask[:, :, :, :mask_length] = causal_mask[
+                     :, :, :, :mask_length
+                 ].masked_fill(padding_mask, min_dtype)
+         return causal_mask
+
+
+ class SmalLmForCausalLM(SmalLmPreTrainedModel, GenerationMixin):
+     _tied_weights_keys = ["lm_head.weight"]
+
+     def __init__(self, config: SmalLmConfig, *args, **kwargs):
+         super().__init__(config, *args, **kwargs)
+         self.config = config
+         self.model = SmalLmModel(config)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.post_init()
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def forward(
+         self,
+         # input options
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         # output options
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         # cache options
+         use_cache: Optional[bool] = None,
+         past_key_values: Optional[Cache | torch.FloatTensor] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         # generation options
+         labels: Optional[torch.Tensor] = None,
+         logits_to_keep: int | torch.Tensor = 0,
+         **kwargs,
+     ) -> tuple | CausalLMOutputWithPast:
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = (
+             return_dict if return_dict is not None else self.config.return_dict
+         )
+
+         model_outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+             **kwargs,
+         )
+
+         hidden_states = model_outputs[0]
+         # keep only the requested logits (logits_to_keep=0 keeps all of them)
+         slice_indices = (
+             slice(-logits_to_keep, None)
+             if isinstance(logits_to_keep, int)
+             else logits_to_keep
+         )
+         logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+         loss = None
+         if labels is not None:
+             loss = self.loss_function(
+                 logits=logits,
+                 labels=labels,
+                 vocab_size=self.config.vocab_size,
+                 **kwargs,
+             )
+
+         if not return_dict:
+             output = (logits,) + model_outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=model_outputs.past_key_values,
+             hidden_states=model_outputs.hidden_states,
+             attentions=model_outputs.attentions,
+         )
+
+
+ __all__ = ["SmalLmForCausalLM", "SmalLmModel", "SmalLmPreTrainedModel"]
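End-to-end usage once these files sit in a Hub repo with the auto_map from config.json (repo id assumed as above; also assumes a tokenizer was uploaded alongside the model):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo = "Azrail/smallm_350"  # assumed repo id
    tokenizer = AutoTokenizer.from_pretrained(repo)
    model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

    inputs = tokenizer("Once upon a time", return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=32, do_sample=True)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))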