Azrail committed on
Commit 466c359 · verified · 1 Parent(s): 87e79ae

Upload SmalLmForCausalLM

Files changed (4)
  1. README.md +1 -0
  2. config.json +5 -1
  3. config.py +129 -0
  4. model.py +878 -0
README.md CHANGED
@@ -2,6 +2,7 @@
  library_name: transformers
  tags:
  - generated_from_trainer
+ - smallm
  model-index:
  - name: smallm_70_rope
    results: []
config.json CHANGED
@@ -4,6 +4,10 @@
    ],
    "attention_bias": false,
    "attention_dropout": 0.1,
+   "auto_map": {
+     "AutoConfig": "config.SmalLmConfig",
+     "AutoModelForCausalLM": "model.SmalLmForCausalLM"
+   },
    "balancing_coef": 0.0001,
    "bos_token_id": 1,
    "embedding_dropout": 0.0,
@@ -39,7 +43,7 @@
    "sliding_window_attention": true,
    "sliding_window_context": 1024,
    "sliding_window_period": 4,
-   "static_residual": false,
+   "static_residual": true,
    "token_experts": 3,
    "torch_dtype": "float32",
    "transformers_version": "4.50.3",
config.py ADDED
@@ -0,0 +1,129 @@
+ import logging
+ from transformers import PretrainedConfig
+ from typing import Optional
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ class SmalLmConfig(PretrainedConfig):
+     model_type = "smallm"
+ 
+     def __init__(
+         self,
+         # global model params
+         hidden_size: int = 512,
+         intermediate_size: int = 2048,
+         mlp_bias: bool = False,
+         num_hidden_layers: int = 27,
+         rms_norm_eps: float = 1e-6,
+         rms_affine: bool = False,
+         initializer_range: float = 0.02,
+         output_hidden_states: bool = False,
+         output_attentions: bool = False,
+         use_cache: bool = True,
+         sliding_window_attention: bool = True,
+         sliding_window_context: int = 1024,
+         sliding_window_period: int = 4,
+         embedding_dropout: float = 0.0,
+         layer_dropout: float = 0.1,
+         max_seq_len: int = 2048,
+         original_seq_len: Optional[int] = None,
+         tie_word_embeddings: bool = True,
+         # attention params
+         num_attention_heads: int = 9,
+         num_kv_heads: int = 3,
+         head_size: Optional[int] = None,
+         attention_dropout: float = 0.1,
+         positional_bias_type: str = "rope",
+         high_rotations: int = 32,
+         low_rotations: int = 1,
+         attention_bias: bool = False,
+         rope_base: int = 100000,
+         # MoE params
+         use_moe: bool = True,
+         moe_period: int = 3,
+         expert_size: int = 256,
+         shared_experts: int = 2,
+         routed_experts: int = 16,
+         token_experts: int = 4,
+         noisy_experts: bool = False,
+         moe_bias: bool = False,
+         balancing_coef: float = 1e-4,
+         no_moe_layers: int = 5,
+         # extra params
+         vocab_size: int = 60000,
+         bos_token_id: int = 1,
+         eos_token_id: int = 0,
+         pad_token_id: int = 0,
+         static_residual: bool = False,
+         moe_type: str = "default",
+         **kwargs,
+     ):
+         if positional_bias_type not in ["alibi", "rope"]:
+             raise ValueError(
+                 f"positional_bias_type must be 'alibi' or 'rope', got {positional_bias_type}"
+             )
+         self.moe_type = moe_type
+         # static_residual=True freezes the residual gate weights at ones
+         # (see WeightedResidual in model.py); the flag is stored as passed so
+         # it survives save/load round trips unchanged
+         self.static_residual = static_residual
+         self.no_moe_layers = no_moe_layers
+         self.moe_bias = moe_bias
+         self.balancing_coef = balancing_coef
+         self.noisy_experts = noisy_experts
+         self.high_rotations = high_rotations
+         self.low_rotations = low_rotations
+         self.positional_bias_type = positional_bias_type
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.mlp_bias = mlp_bias
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_kv_heads = num_kv_heads
+         self.attention_dropout = attention_dropout
+         self.rms_norm_eps = rms_norm_eps
+         self.max_seq_len = max_seq_len
+         self.use_cache = use_cache
+         self.initializer_range = initializer_range
+         self.embedding_dropout = embedding_dropout
+         self.rms_affine = rms_affine
+         self.output_hidden_states = output_hidden_states
+         self.output_attentions = output_attentions
+         self.layer_dropout = layer_dropout
+         self.use_moe = use_moe
+         self.moe_period = moe_period
+         self.expert_size = expert_size
+         self.shared_experts = shared_experts
+         self.routed_experts = routed_experts
+         self.token_experts = token_experts
+         self.intermediate_size = intermediate_size
+         self.attention_bias = attention_bias
+         self.rope_base = rope_base
+         self.head_size = head_size if head_size else hidden_size // num_attention_heads
+         self.original_seq_len = (
+             original_seq_len if original_seq_len is not None else max_seq_len
+         )
+ 
+         self.sliding_window_attention = sliding_window_attention
+         self.sliding_window_context = sliding_window_context
+         self.sliding_window_period = sliding_window_period
+         if sliding_window_attention and sliding_window_context > max_seq_len:
+             logger.warning(
+                 f"sliding_window_context ({sliding_window_context}) exceeds "
+                 f"max_seq_len; setting sliding_window_context to {max_seq_len}"
+             )
+             self.sliding_window_context = max_seq_len
+         if not sliding_window_attention:
+             self.sliding_window_context = max_seq_len
+ 
+         if self.head_size % 2 != 0 and self.positional_bias_type == "rope":
+             raise ValueError("head_size must be even when positional_bias_type is 'rope'")
+ 
+         super().__init__(
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             pad_token_id=pad_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+ 
+ 
+ __all__ = ["SmalLmConfig"]
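
A quick sanity check of the derived fields (a sketch; assumes config.py is importable from the working directory):

    from config import SmalLmConfig

    cfg = SmalLmConfig()
    # head_size defaults to hidden_size // num_attention_heads = 512 // 9 = 56,
    # which is even, so the RoPE check passes
    print(cfg.head_size)

    # a window wider than max_seq_len is clamped back to max_seq_len (with a warning)
    wide = SmalLmConfig(sliding_window_context=4096, max_seq_len=2048)
    print(wide.sliding_window_context)  # 2048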
model.py ADDED
@@ -0,0 +1,878 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel, GenerationMixin
+ from transformers.cache_utils import Cache, DynamicCache
+ from transformers.modeling_outputs import (
+     BaseModelOutputWithPast,
+     CausalLMOutputWithPast,
+ )
+ from .config import SmalLmConfig
+ from typing import Optional
+ import logging
+ from einops import rearrange
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+ from einops._torch_specific import allow_ops_in_compiled_graph
+ 
+ allow_ops_in_compiled_graph()
+ from transformers.utils import is_flash_attn_2_available
+ 
+ if is_flash_attn_2_available():
+     from flash_attn import flash_attn_varlen_func
+     from flash_attn.bert_padding import unpad_input, pad_input
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ class SwiGLU(nn.Module):
+     def __init__(
+         self, input_size: int, hidden_size: int, bias: bool = False, *args, **kwargs
+     ):
+         super().__init__(*args, **kwargs)
+         self.input_size = input_size
+         self.hidden_size = hidden_size
+         self.up_proj = nn.Linear(input_size, hidden_size * 2, bias=bias)
+         self.down_proj = nn.Linear(hidden_size, input_size, bias=bias)
+ 
+     def forward(self, x):
+         # one fused projection, split into the value ("up") and gate halves
+         up_gate = self.up_proj(x)
+         up, gate = rearrange(up_gate, "... (d span) -> span ... d", d=self.hidden_size)
+         down = F.silu(gate) * up
+         return self.down_proj(down)
+ 
+ 
+ class Router(nn.Module):
+     def __init__(
+         self,
+         config: SmalLmConfig,
+         input_size: Optional[int] = None,
+         *args,
+         **kwargs,
+     ):
+         super().__init__(*args, **kwargs)
+         self.config = config
+         self.experts_to_select = self.config.token_experts - self.config.shared_experts
+         # input_size defaults to hidden_size; ComboMoe routes its output stage
+         # on expert_size-dimensional activations instead
+         input_size = input_size or config.hidden_size
+         self.gate = nn.Linear(input_size, config.routed_experts, bias=False)
+         self.gate_noise = (
+             nn.Linear(input_size, config.routed_experts, bias=False)
+             if config.noisy_experts
+             else None
+         )
+         self.bias_coef = config.balancing_coef
+         self.register_buffer(
+             "bias", torch.zeros(config.routed_experts), persistent=True
+         )
+         self.register_buffer(
+             "expert_counts", torch.zeros(config.routed_experts), persistent=False
+         )
+ 
+     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
+         gate_logits = self.gate(x)
+         if self.gate_noise is not None:
+             gate_logits_noise = F.softplus(self.gate_noise(x))
+             gate_logits_noise = torch.randn_like(gate_logits_noise) * gate_logits_noise
+             gate_logits = gate_logits + gate_logits_noise
+ 
+         gate_weights = gate_logits.sigmoid()
+         original_weights = gate_weights
+ 
+         # the balancing bias only influences selection, not the mixing weights
+         gate_weights = gate_weights + self.bias
+ 
+         _, top_experts_idx = torch.topk(gate_weights, self.experts_to_select, dim=-1)
+         counts = torch.bincount(
+             top_experts_idx.flatten(), minlength=self.config.routed_experts
+         ).detach()
+         if self.training:
+             self.expert_counts += counts
+         top_experts_weights = original_weights.gather(1, top_experts_idx)
+         top_experts_weights = top_experts_weights / top_experts_weights.sum(
+             dim=-1, keepdim=True
+         )
+         return top_experts_idx, top_experts_weights.type_as(x), counts.tolist()
+ 
+     def update_bias(self):
+         # auxiliary-loss-free balancing: nudge under-used experts up and
+         # over-used experts down by a fixed step
+         mean = self.expert_counts.float().mean()
+         delta = self.bias_coef * torch.sign(mean - self.expert_counts)
+         self.bias += delta
+         self.expert_counts.zero_()
+ 
+ 
+ class MoE(nn.Module):
+     def __init__(self, config: SmalLmConfig, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.config = config
+         self.shared_experts = SwiGLU(
+             config.hidden_size,
+             config.shared_experts * config.expert_size,
+             config.moe_bias,
+         )
+         self.routed_experts = nn.ModuleList(
+             [
+                 SwiGLU(config.hidden_size, config.expert_size, config.moe_bias)
+                 for _ in range(config.routed_experts)
+             ]
+         )
+         self.router = Router(config)
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         shape = x.size()
+         x = x.view(-1, self.config.hidden_size)
+         experts_idx, experts_weights, counts = self.router(x)
+         out = torch.zeros_like(x)
+         for i, expert in enumerate(self.routed_experts):
+             if counts[i] == 0:
+                 continue
+             # idx: tokens routed to expert i; pos: their top-k slot
+             idx, pos = torch.where(experts_idx == i)
+             out[idx] += expert(x[idx]) * experts_weights[idx, pos, None]
+         shared_out = self.shared_experts(x)
+         return (out + shared_out).view(shape)
+ 
+ 
+ class ComboMoe(nn.Module):
+     # a SwiGLU decomposed across experts: the input stage is the "up"
+     # projection, the middle stage the gate, the output stage the "down"
+     # projection, each with its own router
+     def __init__(self, config: SmalLmConfig, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.config = config
+         self.shared_experts = SwiGLU(
+             config.hidden_size,
+             config.shared_experts * config.expert_size,
+             config.moe_bias,
+         )
+         self.input_router = Router(config)
+         self.middle_router = Router(config)
+         # the output stage routes on expert_size-dimensional activations
+         self.out_router = Router(config, input_size=config.expert_size)
+         self.routed_experts = nn.ModuleList(
+             [
+                 nn.Linear(config.hidden_size, config.expert_size, bias=config.moe_bias)
+                 for _ in range(config.routed_experts)
+             ]
+         )
+         self.middle_routed_experts = nn.ModuleList(
+             [
+                 nn.Linear(config.hidden_size, config.expert_size, bias=config.moe_bias)
+                 for _ in range(config.routed_experts)
+             ]
+         )
+         self.out_routed_experts = nn.ModuleList(
+             [
+                 nn.Linear(config.expert_size, config.hidden_size, bias=config.moe_bias)
+                 for _ in range(config.routed_experts)
+             ]
+         )
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         shape = x.size()
+         x = x.view(-1, self.config.hidden_size)
+         iexpert_idx, iexpert_weights, icounts = self.input_router(x)
+         iout = torch.zeros(
+             (*x.shape[:-1], self.config.expert_size), dtype=x.dtype, device=x.device
+         )
+         for i, expert in enumerate(self.routed_experts):
+             if icounts[i] == 0:
+                 continue
+             idx, pos = torch.where(iexpert_idx == i)
+             iout[idx] += expert(x[idx]) * iexpert_weights[idx, pos, None]
+ 
+         mexpert_idx, mexpert_weights, mcounts = self.middle_router(x)
+         for i, expert in enumerate(self.middle_routed_experts):
+             if mcounts[i] == 0:
+                 continue
+             idx, pos = torch.where(mexpert_idx == i)
+             iout[idx] *= F.silu(expert(x[idx]) * mexpert_weights[idx, pos, None])
+ 
+         out = torch.zeros_like(x)
+         oexpert_idx, oexpert_weights, ocounts = self.out_router(iout)
+         for i, expert in enumerate(self.out_routed_experts):
+             if ocounts[i] == 0:
+                 continue
+             idx, pos = torch.where(oexpert_idx == i)
+             out[idx] += expert(iout[idx]) * oexpert_weights[idx, pos, None]
+ 
+         shared_out = self.shared_experts(x)
+         return (out + shared_out).view(shape)
+ 
+ 
+ def build_alibi_bias(config: SmalLmConfig) -> torch.Tensor:
+     """Build ALiBi slopes for the configured number of heads.
+ 
+     Returns:
+         Tensor of per-head slopes, shape: [num_heads]
+     """
+     bias = (
+         2**-8
+         / config.num_attention_heads
+         * torch.arange(1, config.num_attention_heads + 1).float()
+     )
+     return bias
+ 
+ 
+ def calc_rotation(num_rotations, dim, base, seq_len):
+     # dimension index whose frequency completes num_rotations full rotations
+     # over seq_len positions
+     return (
+         dim
+         * torch.log(torch.tensor(seq_len).float() / (num_rotations * 2 * torch.pi))
+         / torch.log(torch.tensor(base))
+     )
+ 
+ 
+ def get_ramp_interpolation(min_idx, max_idx, thetas_dim, eps=1e-6):
+     if min_idx == max_idx:
+         max_idx += eps
+     mult = (torch.arange(thetas_dim) - min_idx) / (max_idx - min_idx)
+     mult = torch.clamp(mult, 0, 1)
+     return 1 - mult
+ 
+ 
+ def build_rope_bias(config: SmalLmConfig) -> torch.Tensor:
+     dim = config.head_size
+ 
+     theta = 1.0 / (config.rope_base ** (torch.arange(0, dim, 2).float() / dim))
+ 
+     # NTK-by-parts correction for context extension
+     if config.max_seq_len > config.original_seq_len:
+         scale = config.max_seq_len / config.original_seq_len
+         # wavelength lambda_i = 2*pi / theta_i, and a dimension completes
+         # seq_len / lambda_i rotations over the original context
+         low_interpolation_idx = max(
+             0,
+             torch.ceil(
+                 calc_rotation(
+                     config.high_rotations,
+                     dim,
+                     config.rope_base,
+                     config.original_seq_len,
+                 )
+             ).item(),
+         )
+         high_interpolation_idx = min(
+             dim - 1,
+             torch.floor(
+                 calc_rotation(
+                     config.low_rotations, dim, config.rope_base, config.original_seq_len
+                 )
+             ).item(),
+         )
+         interpolation_mult = get_ramp_interpolation(
+             low_interpolation_idx, high_interpolation_idx, dim // 2
+         )
+         theta = (1 - interpolation_mult) * theta / scale + interpolation_mult * theta
+ 
+     seq_idx = torch.arange(config.max_seq_len)
+     seq_theta = torch.outer(seq_idx, theta)
+     # complex rotations e^(i * pos * theta), shape [max_seq_len, head_size // 2]
+     bias = torch.polar(torch.ones_like(seq_theta), seq_theta)
+     return bias
+ 
+ 
+ def apply_rope_bias(x: torch.Tensor, precompute_bias: torch.Tensor) -> torch.Tensor:
+     ini_dtype = x.dtype
+     # compute in fp32 for stability; torch.view_as_complex also requires it
+     x = rearrange(x.float(), "b n s (d i) -> b n s d i", i=2).contiguous()
+     x = torch.view_as_complex(x)
+     x = x * precompute_bias
+     x = torch.view_as_real(x)
+     x = rearrange(x, "b n s d i -> b n s (d i)")
+     return x.to(ini_dtype)
+ 
+ 
+ def flash_attention_forward(
+     module: nn.Module,
+     x: torch.Tensor,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: torch.Tensor,
+     alibi_slope: Optional[torch.Tensor],
+ ) -> tuple[torch.Tensor, None]:
+     query = rearrange(query, "b n s d -> b s n d")
+     key = rearrange(key, "b n s d -> b s n d")
+     value = rearrange(value, "b n s d -> b s n d")
+     # flash-attn works on packed, unpadded sequences
+     query, idx_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(query, attention_mask)
+     key, _, cu_seqlens_k, max_seqlen_k, _ = unpad_input(key, attention_mask)
+     value, _, _, _, _ = unpad_input(value, attention_mask)
+ 
+     key = key.contiguous()
+     value = value.contiguous()
+     query = query.contiguous()
+ 
+     attention_probs = flash_attn_varlen_func(
+         query,
+         key,
+         value,
+         cu_seqlens_q=cu_seqlens_q,
+         cu_seqlens_k=cu_seqlens_k,
+         max_seqlen_q=max_seqlen_q,
+         max_seqlen_k=max_seqlen_k,
+         dropout_p=module.config.attention_dropout if module.training else 0.0,
+         causal=True,
+         # alibi_slope is already None unless positional_bias_type == "alibi"
+         alibi_slopes=alibi_slope,
+     )
+     attention_probs = pad_input(attention_probs, idx_q, x.size(0), x.size(1))
+     out = rearrange(attention_probs, "b s n d -> b s (n d)")
+     return out, None
+ 
+ 
+ def sdpa_attention_forward(
+     module: nn.Module,
+     x: torch.Tensor,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: torch.Tensor,
+     alibi_slope: Optional[torch.Tensor],
+ ) -> tuple[torch.Tensor, None]:
+     # with no explicit mask we can use the fused causal path
+     is_causal = attention_mask is None and query.size(-2) > 1
+ 
+     attention_probs = F.scaled_dot_product_attention(
+         query,
+         key,
+         value,
+         attn_mask=attention_mask,
+         enable_gqa=True,
+         is_causal=is_causal,
+         dropout_p=module.config.attention_dropout if module.training else 0.0,
+     )
+     out = rearrange(attention_probs, "b n s d -> b s (n d)")
+ 
+     return out, None
+ 
+ 
+ def eager_attention_forward(
+     module: nn.Module,
+     x: torch.Tensor,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: torch.Tensor,
+     alibi_slope: Optional[torch.Tensor],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     # grouped-query attention: split the query heads into (kv head, group)
+     # pairs so that each kv head broadcasts over its group of query heads
+     query = rearrange(
+         query,
+         "b (kv group) s d -> b kv group s d",
+         kv=module.config.num_kv_heads,
+         group=module.head_per_group,
+     )
+     key = rearrange(key, "b kv s d -> b kv 1 s d")
+     value = rearrange(value, "b kv s d -> b kv 1 s d")
+     attention_weights = query @ key.transpose(-1, -2) / torch.sqrt(
+         torch.tensor(value.size(-1), dtype=query.dtype, device=x.device)
+     )
+     # attention_mask already carries the causal structure and, for ALiBi, the
+     # positional bias merged in SmalLmModel.forward; alibi_slope itself is
+     # only consumed by the flash-attention path
+     if attention_mask is not None:
+         attention_mask = attention_mask.expand(
+             -1, module.config.num_attention_heads, -1, -1
+         )
+         attention_mask = rearrange(
+             attention_mask,
+             "b (kv group) s1 s2 -> b kv group s1 s2",
+             kv=module.config.num_kv_heads,
+             group=module.head_per_group,
+         )
+         attention_weights = attention_weights + attention_mask
+     elif query.size(-2) > 1:
+         # no mask was materialised upstream: fall back to a plain causal mask
+         s_q, s_k = query.size(-2), key.size(-2)
+         causal = torch.full(
+             (s_q, s_k),
+             torch.finfo(query.dtype).min,
+             dtype=query.dtype,
+             device=query.device,
+         ).triu(diagonal=s_k - s_q + 1)
+         attention_weights = attention_weights + causal
+     attention_probs = F.softmax(attention_weights, dim=-1)
+     attention_probs = F.dropout(
+         attention_probs, p=module.config.attention_dropout if module.training else 0.0
+     )
+     attention_probs = attention_probs @ value
+     out = rearrange(attention_probs, "b kv group s d -> b s (kv group d)")
+     return out, attention_weights
+ 
+ 
+ ALL_ATTENTION_FUNCTIONS = {
+     "eager": eager_attention_forward,
+     "sdpa": sdpa_attention_forward,
+     "flash_attention_2": flash_attention_forward,
+ }
+ 
+ 
+ class CausalSelfAttention(nn.Module):
+     def __init__(self, config: SmalLmConfig, layer_idx: int, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         if config.num_attention_heads % config.num_kv_heads != 0:
+             raise ValueError("num_attention_heads must be divisible by num_kv_heads")
+ 
+         self.config = config
+         self.layer_idx = layer_idx
+         self.head_per_group = config.num_attention_heads // config.num_kv_heads
+         self.q_proj = nn.Linear(
+             config.hidden_size,
+             config.head_size * config.num_attention_heads,
+             bias=config.attention_bias,
+         )
+         self.kv_proj = nn.Linear(
+             config.hidden_size,
+             config.head_size * config.num_kv_heads * 2,
+             bias=config.attention_bias,
+         )
+         self.out_proj = nn.Linear(
+             config.head_size * config.num_attention_heads,
+             config.hidden_size,
+             bias=config.attention_bias,
+         )
+ 
+     def forward(
+         self,
+         x: torch.Tensor,
+         attention_mask: torch.Tensor,
+         past_key_values: Optional[Cache | torch.FloatTensor],
+         cache_position: Optional[torch.LongTensor],
+         bias: torch.Tensor,
+     ):
+         q = self.q_proj(x)
+         kv = self.kv_proj(x)
+         q = rearrange(q, "b s (n d) -> b n s d", n=self.config.num_attention_heads)
+         k, v = rearrange(kv, "b s (n d q) -> q b n s d", q=2, d=self.config.head_size)
+ 
+         if self.config.positional_bias_type == "rope":
+             k = apply_rope_bias(k, bias)
+             q = apply_rope_bias(q, bias)
+ 
+         if past_key_values is not None:
+             # cache_position is needed for static caches
+             cache_kwargs = {"cache_position": cache_position}
+             k, v = past_key_values.update(
+                 key_states=k,
+                 value_states=v,
+                 layer_idx=self.layer_idx,
+                 cache_kwargs=cache_kwargs,
+             )
+ 
+         attention_interface = eager_attention_forward
+         if self.config._attn_implementation != "eager":
+             attention_interface = ALL_ATTENTION_FUNCTIONS[
+                 self.config._attn_implementation
+             ]
+ 
+         out, attention_weights = attention_interface(
+             self,
+             x,
+             q,
+             k,
+             v,
+             attention_mask,
+             bias if self.config.positional_bias_type == "alibi" else None,
+         )
+ 
+         out = self.out_proj(out)
+         return out, attention_weights
+ 
+ 
+ class WeightedResidual(nn.Module):
+     def __init__(self, config: SmalLmConfig, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         # with static_residual=True the per-channel gate stays frozen at ones;
+         # otherwise it is learned
+         self.weight = nn.Parameter(
+             torch.ones(config.hidden_size), requires_grad=not config.static_residual
+         )
+ 
+     def forward(self, short, long):
+         return self.weight * short + long
+ 
+ 
+ class Block(nn.Module):
+     def __init__(self, config: SmalLmConfig, layer_idx: int, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.attn_norm = nn.RMSNorm(
+             config.hidden_size,
+             eps=config.rms_norm_eps,
+             elementwise_affine=config.rms_affine,
+         )
+         self.ffn_norm = nn.RMSNorm(
+             config.hidden_size,
+             eps=config.rms_norm_eps,
+             elementwise_affine=config.rms_affine,
+         )
+         self.dropout1 = nn.Dropout(config.layer_dropout)
+         self.dropout2 = nn.Dropout(config.layer_dropout)
+         self.attention = CausalSelfAttention(config, layer_idx)
+         moe_class = MoE if config.moe_type == "default" else ComboMoe
+         # dense SwiGLU everywhere except every moe_period-th layer past the
+         # first no_moe_layers layers
+         self.mlp = (
+             moe_class(config)
+             if (
+                 config.use_moe
+                 and layer_idx % config.moe_period == 0
+                 and layer_idx > config.no_moe_layers
+             )
+             else SwiGLU(config.hidden_size, config.intermediate_size, config.mlp_bias)
+         )
+         self.attention_residual = WeightedResidual(config)
+         self.ffn_residual = WeightedResidual(config)
+ 
+     def forward(
+         self,
+         inputs_embeds: torch.Tensor,
+         attention_mask: torch.Tensor,
+         past_key_values: Optional[Cache | torch.FloatTensor],
+         output_attentions: bool,
+         cache_position: Optional[torch.LongTensor],
+         bias: torch.Tensor,
+     ) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
+         identity = inputs_embeds
+ 
+         # attention block
+         out = self.attn_norm(inputs_embeds)
+         out, attention_probs = self.attention(
+             out, attention_mask, past_key_values, cache_position, bias
+         )
+         out = self.dropout1(out)
+         identity = self.attention_residual(identity, out)
+ 
+         # SwiGLU / MoE block
+         out = self.dropout2(self.mlp(self.ffn_norm(identity)))
+         out = self.ffn_residual(identity, out)
+         if output_attentions:
+             return out, attention_probs
+         return (out,)
+ 
+ 
+ class SmalLmPreTrainedModel(PreTrainedModel):
+     config_class = SmalLmConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["Block"]
+     _skip_keys_device_placement = "past_key_values"
+     _supports_sdpa = True
+     _supports_flash_attn_2 = True
+ 
+     def __init__(self, *inputs, **kwargs):
+         super().__init__(*inputs, **kwargs)
+ 
+     def _init_weights(self, module):
+         std = self.config.initializer_range
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+             # read the padding index from the config so this also works on
+             # subclasses that do not define pad_idx
+             module.weight.data[self.config.pad_token_id].zero_()
+ 
+ 
+ class SmalLmModel(SmalLmPreTrainedModel):
+     def __init__(self, config: SmalLmConfig, *args, **kwargs):
+         super().__init__(config, *args, **kwargs)
+         self.config = config
+         self.pad_idx = config.pad_token_id
+         self.pad_token_id = config.pad_token_id
+         self.vocab_size = config.vocab_size
+         precompute_bias = (
+             build_alibi_bias(config)
+             if config.positional_bias_type == "alibi"
+             else build_rope_bias(config)
+         )
+         self.register_buffer("precompute_bias", precompute_bias, persistent=False)
+         # weight sharing with the output head (self.embedding.weight =
+         # self.lm_head.weight) is handled via tie_word_embeddings
+         self.embedding = nn.Embedding(
+             self.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
+         )
+         self.embedding_dropout = nn.Dropout(config.embedding_dropout)
+         self.layers = nn.ModuleList(
+             [Block(config, idx) for idx in range(1, config.num_hidden_layers + 1)]
+         )
+         self.out_norm = nn.RMSNorm(
+             config.hidden_size,
+             eps=config.rms_norm_eps,
+             elementwise_affine=config.rms_affine,
+         )
+ 
+         self.gradient_checkpointing = False
+         self.post_init()
+ 
+     def get_input_embeddings(self):
+         return self.embedding
+ 
+     def set_input_embeddings(self, value):
+         self.embedding = value
+ 
+     def forward(
+         self,
+         # input options
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         # output options
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         # cache options
+         use_cache: Optional[bool] = None,
+         past_key_values: Optional[Cache | torch.FloatTensor] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ) -> tuple | BaseModelOutputWithPast:
+         # resolve optional flags against the config defaults
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         use_cache = (
+             use_cache
+             if use_cache is not None
+             else (False if self.training else self.config.use_cache)
+         )
+         return_dict = (
+             return_dict if return_dict is not None else self.config.return_dict
+         )
+ 
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("Specify either input_ids or inputs_embeds, not both")
+ 
+         if self.training and use_cache:
+             use_cache = False
+ 
+         if inputs_embeds is None:
+             inputs_embeds = self.embedding(input_ids)
+         inputs_embeds = self.embedding_dropout(inputs_embeds)
+ 
+         if use_cache and past_key_values is None:
+             past_key_values = DynamicCache()
+ 
+         # positions of the current tokens (needed for StaticCache)
+         if cache_position is None:
+             last_position = (
+                 past_key_values.get_seq_length() if past_key_values is not None else 0
+             )
+             cache_position = torch.arange(
+                 last_position,
+                 last_position + inputs_embeds.size(1),
+                 device=inputs_embeds.device,
+             )
+ 
+         causal_mask = self._get_causal_masks(
+             attention_mask, inputs_embeds, past_key_values, cache_position
+         )
+         if self.config.positional_bias_type == "rope":
+             end_pos = (
+                 inputs_embeds.size(1)
+                 if past_key_values is None
+                 else cache_position[-1] + 1
+             )
+             start_pos = 0 if past_key_values is None else cache_position[0]
+             bias = self.precompute_bias[start_pos:end_pos]
+ 
+         elif self.config.positional_bias_type == "alibi":
+             if self.config._attn_implementation == "flash_attention_2":
+                 # flash-attn takes the raw per-head slopes
+                 bias = self.precompute_bias
+             else:
+                 i = torch.arange(
+                     (
+                         inputs_embeds.size(1)
+                         if past_key_values is None
+                         else cache_position[-1] + 1
+                     ),
+                     device=inputs_embeds.device,
+                 )
+                 bias = i[:, None] - i[None, :]
+                 bias = torch.tril(bias).expand(
+                     inputs_embeds.size(0), self.config.num_attention_heads, -1, -1
+                 ) * rearrange(self.precompute_bias, "n -> 1 n 1 1")
+                 # keep only the rows for the current query positions
+                 bias = bias[:, :, -inputs_embeds.size(1) :, :]
+                 if causal_mask is not None:
+                     causal_mask = causal_mask + bias
+                 else:
+                     # the additive ALiBi matrix is not causal on its own:
+                     # mask the upper triangle explicitly
+                     q_len, k_len = bias.size(-2), bias.size(-1)
+                     causal = torch.full(
+                         (q_len, k_len),
+                         torch.finfo(inputs_embeds.dtype).min,
+                         dtype=inputs_embeds.dtype,
+                         device=inputs_embeds.device,
+                     ).triu(diagonal=k_len - q_len + 1)
+                     causal_mask = bias + causal
+ 
+         hidden_state = inputs_embeds
+         hidden_states = [hidden_state] if output_hidden_states else None
+         attentions = [] if output_attentions else None
+         for layer in self.layers:
+             if self.gradient_checkpointing:
+                 # for details see:
+                 # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3107
+                 # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3149
+                 layer_out = self._gradient_checkpointing_func(
+                     layer.__call__,
+                     hidden_state,
+                     causal_mask,
+                     past_key_values,
+                     output_attentions,
+                     cache_position,
+                     bias,
+                 )
+             else:
+                 layer_out = layer(
+                     hidden_state,
+                     causal_mask,
+                     past_key_values,
+                     output_attentions,
+                     cache_position,
+                     bias,
+                 )
+             hidden_state = layer_out[0]
+             if output_hidden_states:
+                 hidden_states.append(hidden_state)
+             if output_attentions:
+                 attentions.append(layer_out[1])
+ 
+         hidden_state = self.out_norm(hidden_state)
+         out = BaseModelOutputWithPast(
+             last_hidden_state=hidden_state,
+             past_key_values=past_key_values if use_cache else None,
+             hidden_states=tuple(hidden_states) if hidden_states is not None else None,
+             attentions=tuple(attentions) if attentions is not None else None,
+         )
+         return out if return_dict else out.to_tuple()
+ 
+     def _get_causal_masks(
+         self,
+         attention_mask: Optional[torch.Tensor],
+         inputs_embeds: torch.Tensor,
+         past_key_values: Optional[torch.Tensor],
+         cache_position: Optional[torch.Tensor],
+     ):
+         if self.config._attn_implementation == "flash_attention_2":
+             # flash-attn consumes the 2D padding mask directly
+             if attention_mask is None:
+                 attention_mask = torch.ones(
+                     (inputs_embeds.size(0), inputs_embeds.size(1)),
+                     device=inputs_embeds.device,
+                 ).long()
+             return attention_mask
+         dtype, device = inputs_embeds.dtype, inputs_embeds.device
+         past_token = (
+             past_key_values.get_seq_length() if past_key_values is not None else 0
+         )
+         if attention_mask is not None and torch.all(attention_mask == 1.0):
+             # nothing is padded, so the plain causal structure is enough
+             attention_mask = None
+         if AttentionMaskConverter._ignore_causal_mask_sdpa(
+             attention_mask=attention_mask,
+             inputs_embeds=inputs_embeds,
+             past_key_values_length=past_token,
+             is_training=self.training,
+         ):
+             return None
+ 
+         sequence_length = inputs_embeds.size(1)
+         target_length = (
+             attention_mask.size(-1)
+             if isinstance(attention_mask, torch.Tensor)
+             else past_token + sequence_length + 1
+         )
+ 
+         causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+             attention_mask=attention_mask,
+             sequence_length=sequence_length,
+             target_length=target_length,
+             dtype=dtype,
+             device=device,
+             cache_position=cache_position,
+             batch_size=inputs_embeds.size(0),
+         )
+ 
+         min_dtype = torch.finfo(dtype).min
+         causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+         return causal_mask
+ 
+     @staticmethod
+     def _prepare_4d_causal_attention_mask_with_cache_position(
+         attention_mask: Optional[torch.Tensor],
+         sequence_length: int,
+         target_length: int,
+         dtype: torch.dtype,
+         device: torch.device,
+         cache_position: Optional[torch.Tensor],
+         batch_size: int,
+     ):
+         if attention_mask is not None and attention_mask.dim() == 4:
+             # already in the expected 4D additive form, trust the caller
+             causal_mask = attention_mask
+         else:
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = torch.full(
+                 (sequence_length, target_length),
+                 fill_value=min_dtype,
+                 dtype=dtype,
+                 device=device,
+             )
+             if sequence_length != 1:
+                 causal_mask = torch.triu(causal_mask, diagonal=1)
+             causal_mask *= torch.arange(
+                 target_length, device=device
+             ) > cache_position.reshape(-1, 1)
+             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+             if attention_mask is not None:
+                 causal_mask = causal_mask.clone()
+                 mask_length = attention_mask.shape[-1]
+                 padding_mask = (
+                     causal_mask[:, :, :, :mask_length]
+                     + attention_mask[:, None, None, :]
+                 )
+                 padding_mask = padding_mask == 0
+                 causal_mask[:, :, :, :mask_length] = causal_mask[
+                     :, :, :, :mask_length
+                 ].masked_fill(padding_mask, min_dtype)
+         return causal_mask
+ 
+ 
+ class SmalLmForCausalLM(SmalLmPreTrainedModel, GenerationMixin):
+     _tied_weights_keys = ["lm_head.weight"]
+ 
+     def __init__(self, config: SmalLmConfig, *args, **kwargs):
+         super().__init__(config, *args, **kwargs)
+         self.config = config
+         self.model = SmalLmModel(config)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.post_init()
+ 
+     def get_output_embeddings(self):
+         return self.lm_head
+ 
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+ 
+     def forward(
+         self,
+         # input options
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         # output options
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         # cache options
+         use_cache: Optional[bool] = None,
+         past_key_values: Optional[Cache | torch.FloatTensor] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         # generation options
+         labels: Optional[torch.Tensor] = None,
+         logits_to_keep: int | torch.Tensor = 0,
+         **kwargs,
+     ) -> tuple | CausalLMOutputWithPast:
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = (
+             return_dict if return_dict is not None else self.config.return_dict
+         )
+ 
+         model_outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+             **kwargs,
+         )
+ 
+         hidden_states = model_outputs[0]
+         # only materialise logits for the positions we actually need
+         slice_indices = (
+             slice(-logits_to_keep, None)
+             if isinstance(logits_to_keep, int)
+             else logits_to_keep
+         )
+         logits = self.lm_head(hidden_states[:, slice_indices, :])
+ 
+         loss = None
+         if labels is not None:
+             loss = self.loss_function(
+                 logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs
+             )
+ 
+         if not return_dict:
+             output = (logits,) + model_outputs[1:]
+             return (loss,) + output if loss is not None else output
+ 
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=model_outputs.past_key_values,
+             hidden_states=model_outputs.hidden_states,
+             attentions=model_outputs.attentions,
+         )
+ 
+ 
+ __all__ = ["SmalLmForCausalLM", "SmalLmModel", "SmalLmPreTrainedModel"]
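
End-to-end, generation goes through GenerationMixin, which drives SmalLmForCausalLM.forward with a DynamicCache. A hypothetical usage sketch (the tokenizer is not part of this commit, so its presence in the repo is an assumption, as is the repo id):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "Azrail/smallm_70_rope"  # hypothetical repo id

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
    model.eval()

    inputs = tokenizer("Once upon a time", return_tensors="pt")
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=32, do_sample=False)
    print(tokenizer.decode(out[0], skip_special_tokens=True))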