Azrail committed
Commit ac9f33c · verified · 1 Parent(s): 2667be4

Upload SmalLmForCausalLM

Files changed (3)
  1. config.json +4 -0
  2. config.py +129 -0
  3. model.py +801 -0
config.json CHANGED
@@ -4,6 +4,10 @@
   ],
   "attention_bias": false,
   "attention_dropout": 0.1,
+  "auto_map": {
+    "AutoConfig": "config.SmalLmConfig",
+    "AutoModelForCausalLM": "model.SmalLmForCausalLM"
+  },
  "balancing_coef": 0.0001,
  "bos_token_id": 1,
  "embedding_dropout": 0.0,
config.py ADDED
@@ -0,0 +1,129 @@
import logging
from transformers import PretrainedConfig
from typing import Optional

logger = logging.getLogger(__name__)


class SmalLmConfig(PretrainedConfig):
    model_type = "smallm"

    def __init__(
        self,
        # global model params
        hidden_size: int = 512,
        intermediate_size: int = 2048,
        mlp_bias: bool = False,
        num_hidden_layers: int = 27,
        rms_norm_eps: float = 1e-6,
        rms_affine: bool = False,
        initializer_range: float = 0.02,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
        use_cache: bool = True,
        sliding_window_attention: bool = True,
        sliding_window_context: int = 1024,
        sliding_window_period: int = 4,
        embedding_dropout: float = 0.0,
        layer_dropout: float = 0.1,
        max_seq_len: int = 2048,
        original_seq_len: int | None = None,
        tie_word_embeddings: bool = True,
        # attention params
        num_attention_heads: int = 9,
        num_kv_heads: int = 3,
        head_size: Optional[int] = None,
        attention_dropout: float = 0.1,
        positional_bias_type: str = "rope",
        high_rotations: int = 32,
        low_rotations: int = 1,
        attention_bias: bool = False,
        rope_base: int = 100000,
        # MoE params
        use_moe: bool = True,
        moe_period: int = 3,
        expert_size: int = 256,
        shared_experts: int = 2,
        routed_experts: int = 16,
        token_experts: int = 4,
        noisy_experts: bool = False,
        moe_bias: bool = False,
        balancing_coef: float = 1e-4,
        no_moe_layers: int = 5,
        # extra params
        vocab_size: int = 60000,
        bos_token_id: int = 1,
        eos_token_id: int = 0,
        pad_token_id: int = 0,
        static_residual: bool = False,
        moe_type: str = "default",
        **kwargs,
    ):
        if positional_bias_type not in ["alibi", "rope"]:
            raise ValueError(
                f"positional_bias_type must be 'alibi' or 'rope', got {positional_bias_type}"
            )
        self.moe_type = moe_type
        self.static_residual = not static_residual
        self.no_moe_layers = no_moe_layers
        self.moe_bias = moe_bias
        self.balancing_coef = balancing_coef
        self.noisy_experts = noisy_experts
        self.high_rotations = high_rotations
        self.low_rotations = low_rotations
        self.positional_bias_type = positional_bias_type
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.mlp_bias = mlp_bias
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_kv_heads = num_kv_heads
        self.attention_dropout = attention_dropout
        self.rms_norm_eps = rms_norm_eps
        self.max_seq_len = max_seq_len
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.embedding_dropout = embedding_dropout
        self.rms_affine = rms_affine
        self.output_hidden_states = output_hidden_states
        self.output_attentions = output_attentions
        self.layer_dropout = layer_dropout
        self.use_moe = use_moe
        self.moe_period = moe_period
        self.expert_size = expert_size
        self.shared_experts = shared_experts
        self.routed_experts = routed_experts
        self.token_experts = token_experts
        self.intermediate_size = intermediate_size
        self.attention_bias = attention_bias
        self.rope_base = rope_base
        self.head_size = head_size if head_size else hidden_size // num_attention_heads
        self.original_seq_len = (
            original_seq_len if original_seq_len is not None else max_seq_len
        )

        self.sliding_window_attention = sliding_window_attention
        self.sliding_window_context = sliding_window_context
        self.sliding_window_period = sliding_window_period
        if sliding_window_attention and sliding_window_context > max_seq_len:
            logger.warning(
                "sliding_window_context is greater than max_seq_len, "
                f"setting sliding_window_context to {max_seq_len}"
            )
            self.sliding_window_context = max_seq_len
        if not sliding_window_attention:
            self.sliding_window_context = max_seq_len

        if self.head_size % 2 != 0 and self.positional_bias_type == "rope":
            raise ValueError("head_size must be divisible by 2 when using RoPE")

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


__all__ = ["SmalLmConfig"]
model.py ADDED
@@ -0,0 +1,801 @@
import torch
from torch import nn
import torch.nn.functional as F
from transformers import PreTrainedModel, GenerationMixin
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from .config import SmalLmConfig
from typing import Optional
import logging
from einops import rearrange, repeat
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from einops._torch_specific import allow_ops_in_compiled_graph

allow_ops_in_compiled_graph()
from transformers.utils import is_flash_attn_2_available

USE_FLASH = False
if is_flash_attn_2_available():
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
    from flash_attn.bert_padding import unpad_input, pad_input

    USE_FLASH = True


logger = logging.getLogger(__name__)
logger.info(f"USE FLASH: {USE_FLASH=}")


class SwiGLU(nn.Module):
    def __init__(
        self, input_size: int, hidden_size: int, bias: bool = False, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.up_proj = nn.Linear(input_size, hidden_size * 2, bias=bias)
        self.down_proj = nn.Linear(hidden_size, input_size, bias=bias)

    def forward(self, x):
        up_gate = self.up_proj(x)
        up, gate = rearrange(up_gate, "... (d span) -> span ... d", d=self.hidden_size)
        down = F.silu(gate) * up
        return self.down_proj(down)

class Router(nn.Module):
    def __init__(self, config: SmalLmConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        self.experts_to_select = self.config.token_experts - self.config.shared_experts
        self.gate = nn.Linear(config.hidden_size, config.routed_experts, bias=False)
        self.gate_noise = (
            nn.Linear(config.hidden_size, config.routed_experts, bias=False)
            if config.noisy_experts is True
            else None
        )
        self.bias_coef = config.balancing_coef
        self.register_buffer(
            "bias", torch.zeros(config.routed_experts), persistent=True
        )
        self.register_buffer(
            "expert_counts", torch.zeros(config.routed_experts), persistent=False
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
        # calculating with fp32 for stability
        # num_tokens n_shared_experts
        gate_logits = self.gate(x)
        if self.gate_noise is not None:
            gate_logits_noise = F.softplus(self.gate_noise(x))
            gate_logits_noise = torch.randn_like(gate_logits_noise) * gate_logits_noise
            gate_logits = gate_logits + gate_logits_noise

        gate_weights = gate_logits.sigmoid()
        original_weights = gate_weights

        gate_weights = gate_weights + self.bias

        _, top_experts_idx = torch.topk(gate_weights, self.experts_to_select, dim=-1)
        counts = torch.bincount(
            top_experts_idx.flatten(), minlength=self.config.routed_experts
        ).detach()
        if self.training:
            self.expert_counts += counts
        top_experts_weights = original_weights.gather(1, top_experts_idx)
        top_experts_weights = top_experts_weights / top_experts_weights.sum(
            dim=-1, keepdim=True
        )
        return top_experts_idx, top_experts_weights.type_as(x), counts.tolist()

    def update_bias(self):
        mean = self.expert_counts.float().mean()
        delta = self.bias_coef * torch.sign(mean - self.expert_counts)
        self.bias += delta
        self.expert_counts.zero_()

class MoE(nn.Module):
    def __init__(self, config: SmalLmConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        self.shared_experts = SwiGLU(
            config.hidden_size,
            config.shared_experts * config.expert_size,
            config.moe_bias,
        )
        self.routed_experts = nn.ModuleList(
            [
                SwiGLU(config.hidden_size, config.expert_size, config.moe_bias)
                for _ in range(config.routed_experts)
            ]
        )
        self.router = Router(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        x = x.view(-1, self.config.hidden_size)
        experts_idx, experts_weights, counts = self.router(x)
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.routed_experts):
            if counts[i] == 0:
                continue
            idx, pos = torch.where(experts_idx == i)
            out[idx] += expert(x[idx]) * experts_weights[idx, pos, None]
        shared_out = self.shared_experts(x)
        return (out + shared_out).view(shape)


class ComboMoe(nn.Module):
    def __init__(self, config: SmalLmConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        self.shared_experts = SwiGLU(
            config.hidden_size,
            config.shared_experts * config.expert_size,
            config.moe_bias,
        )
        self.input_router = Router(config)
        self.middle_router = Router(config)
        self.out_router = Router(config)
        self.routed_experts = nn.ModuleList(
            [
                nn.Linear(config.hidden_size, config.expert_size, bias=config.moe_bias)
                for _ in range(config.routed_experts)
            ]
        )
        self.middle_routed_experts = nn.ModuleList(
            [
                nn.Linear(config.expert_size, config.hidden_size, bias=config.moe_bias)
                for _ in range(config.routed_experts)
            ]
        )
        self.out_routed_experts = nn.ModuleList(
            [
                nn.Linear(config.expert_size, config.hidden_size, bias=config.moe_bias)
                for _ in range(config.routed_experts)
            ]
        )
        self.offset = config.routed_experts

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        x = x.view(-1, self.config.hidden_size)
        iexpert_idx, iexpert_weights, icounts = self.input_router(x)
        iout = torch.zeros((*x.shape[:-1], self.config.expert_size), device=x.device)
        for i, expert in enumerate(self.routed_experts[: self.offset]):
            if icounts[i] == 0:
                continue
            idx, pos = torch.where(iexpert_idx == i)
            iout[idx] += expert(x[idx]) * iexpert_weights[idx, pos, None]

        mexpert_idx, mexpert_weights, mcounts = self.middle_router(x)
        for i, expert in enumerate(self.middle_routed_experts):
            if mcounts[i] == 0:
                continue
            idx, pos = torch.where(mexpert_idx == i)
            iout[idx] *= F.silu(expert(x[idx]) * mexpert_weights[idx, pos, None])

        out = torch.zeros_like(x)
        oexpert_idx, oexpert_weights, ocounts = self.out_router(iout)
        for i, expert in enumerate(self.out_routed_experts):
            if ocounts[i] == 0:
                continue
            idx, pos = torch.where(oexpert_idx == i)
            out[idx] += expert(iout[idx]) * oexpert_weights[idx, pos, None]

        shared_out = self.shared_experts(x)
        return (out + shared_out).view(shape)

def build_alibi_bias(config: SmalLmConfig) -> torch.Tensor:
    """Build ALiBi slopes for the specified number of heads.

    Returns:
        Tensor with ALiBi slopes, shape: [num_heads]
    """
    bias = (
        2**-8
        / config.num_attention_heads
        * torch.arange(1, config.num_attention_heads + 1).float()
    )
    return bias


def calc_rotation(num_rotations, dim, base, seq_len):
    return (
        dim
        * torch.log(torch.tensor(seq_len).float() / (num_rotations * 2 * torch.pi))
        / torch.log(torch.tensor(base))
    )


def get_ramp_interpolation(min_idx, max_idx, thetas_dim, eps=1e-6):
    if min_idx == max_idx:
        max_idx += eps
    mult = (torch.arange(thetas_dim) - min_idx) / (max_idx - min_idx)
    mult = torch.clamp(mult, 0, 1)
    return 1 - mult


def build_rope_bias(config: SmalLmConfig) -> torch.Tensor:
    dim = config.head_size

    theta = 1.0 / (config.rope_base ** (torch.arange(0, dim, 2).float() / dim))

    # NTK-by-parts correction
    if config.max_seq_len > config.original_seq_len:
        scale = config.max_seq_len / config.original_seq_len
        # from the idea that lambda = 2 * pi / theta_i and lambda = seq_len / num_rotations,
        # where lambda is the wavelength
        low_interpolation_idx = max(
            0,
            torch.ceil(
                calc_rotation(
                    config.high_rotations,
                    dim,
                    config.rope_base,
                    config.original_seq_len,
                )
            ).item(),
        )
        high_interpolation_idx = min(
            dim - 1,
            torch.floor(
                calc_rotation(
                    config.low_rotations, dim, config.rope_base, config.original_seq_len
                )
            ).item(),
        )
        interpolation_mult = get_ramp_interpolation(
            low_interpolation_idx, high_interpolation_idx, dim // 2
        )
        theta = (1 - interpolation_mult) * theta / scale + interpolation_mult * theta

    seq_idx = torch.arange(config.max_seq_len)
    seq_theta = torch.outer(seq_idx, theta)
    bias = torch.polar(torch.ones_like(seq_theta), seq_theta)
    return bias


def apply_rope_bias(x: torch.Tensor, precompute_bias: torch.Tensor) -> torch.Tensor:
    ini_dtype = x.dtype
    # cast to fp32 for stability before torch.view_as_complex
    x = rearrange(x.float(), "b n s (d i) -> b n s d i", i=2).contiguous()
    x = torch.view_as_complex(x)
    x = x * precompute_bias
    x = torch.view_as_real(x)
    x = rearrange(x, "b n s d i -> b n s (d i)")
    return x.to(ini_dtype)

class CausalSelfAttention(nn.Module):
    def __init__(self, config: SmalLmConfig, layer_idx: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if config.num_attention_heads % config.num_kv_heads != 0:
            raise ValueError("num_attention_heads must be divisible by num_kv_heads")

        self.config = config
        self.layer_idx = layer_idx
        self.head_per_group = config.num_attention_heads // config.num_kv_heads
        self.q_proj = nn.Linear(
            config.hidden_size,
            config.head_size * config.num_attention_heads,
            bias=config.attention_bias,
        )
        self.kv_proj = nn.Linear(
            config.hidden_size,
            config.head_size * config.num_kv_heads * 2,
            bias=config.attention_bias,
        )
        self.out_proj = nn.Linear(
            config.head_size * config.num_attention_heads,
            config.hidden_size,
            bias=config.attention_bias,
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        past_key_values: Optional[Cache | torch.FloatTensor],
        cache_position: Optional[torch.LongTensor],
        bias: torch.Tensor,
    ):
        q = self.q_proj(x)
        kv = self.kv_proj(x)
        q = rearrange(q, "b s (n d) -> b n s d", n=self.config.num_attention_heads)
        k, v = rearrange(kv, "b s (n d q) -> q b n s d", q=2, d=self.config.head_size)

        if self.config.positional_bias_type == "rope":
            k = apply_rope_bias(k, bias)
            q = apply_rope_bias(q, bias)

        if past_key_values is not None:
            # for static cache
            cache_kwargs = {"cache_position": cache_position}
            k, v = past_key_values.update(
                key_states=k,
                value_states=v,
                layer_idx=self.layer_idx,
                cache_kwargs=cache_kwargs,
            )

        is_causal = attention_mask is None and q.size(-2) > 1

        # for enabling the flash attention kernel
        if USE_FLASH and x.is_cuda:
            q = rearrange(q, "b n s d -> b s n d")
            k = rearrange(k, "b n s d -> b s n d")
            v = rearrange(v, "b n s d -> b s n d")
            q, idx_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(q, attention_mask)
            k, _, cu_seqlens_k, max_seqlen_k, _ = unpad_input(k, attention_mask)
            v, _, _, _, _ = unpad_input(v, attention_mask)

        k = k.contiguous()
        v = v.contiguous()
        q = q.contiguous()
        if USE_FLASH and x.is_cuda:
            attention_probs = flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                dropout_p=self.config.attention_dropout if self.training else 0.0,
                causal=True,
                alibi_slopes=bias if self.config.positional_bias_type == "alibi" else None,
            )
            attention_probs = pad_input(attention_probs, idx_q, x.size(0), x.size(1))
            out = rearrange(attention_probs, "b s n d -> b s (n d)")
        else:
            attention_probs = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attention_mask,
                enable_gqa=True,
                is_causal=is_causal,
                dropout_p=self.config.attention_dropout if self.training else 0.0,
            )
            out = rearrange(attention_probs, "b n s d -> b s (n d)")

        out = self.out_proj(out)

        return out, attention_probs

class WeightedResidual(nn.Module):
    def __init__(self, config: SmalLmConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight = nn.Parameter(
            torch.ones(config.hidden_size), requires_grad=config.static_residual
        )

    def forward(self, short, long):
        return self.weight * short + long


class Block(nn.Module):
    def __init__(self, config: SmalLmConfig, layer_idx: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attn_norm = nn.RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
            elementwise_affine=config.rms_affine,
        )
        self.ffn_norm = nn.RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
            elementwise_affine=config.rms_affine,
        )
        self.dropout1 = nn.Dropout(config.layer_dropout)
        self.dropout2 = nn.Dropout(config.layer_dropout)
        self.attention = CausalSelfAttention(config, layer_idx)
        moe_class = MoE if config.moe_type == "default" else ComboMoe
        self.mlp = (
            moe_class(config)
            if (
                config.use_moe
                and layer_idx % config.moe_period == 0
                and layer_idx > config.no_moe_layers
            )
            else SwiGLU(config.hidden_size, config.intermediate_size, config.mlp_bias)
        )
        self.attention_residual = WeightedResidual(config)
        self.ffn_residual = WeightedResidual(config)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor,
        past_key_values: Optional[Cache | torch.FloatTensor],
        output_attentions: bool,
        cache_position: Optional[torch.LongTensor],
        bias: torch.Tensor,
    ) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        identity = inputs_embeds

        # attention block
        out = self.attn_norm(inputs_embeds)
        out, attention_probs = self.attention(
            out, attention_mask, past_key_values, cache_position, bias
        )
        out = self.dropout1(out)
        identity = self.attention_residual(identity, out)

        # SwiGLU / MoE block
        out = self.dropout2(self.mlp(self.ffn_norm(identity)))
        out = self.ffn_residual(identity, out)
        if output_attentions:
            return out, attention_probs
        return (out,)

class SmalLmPreTrainedModel(PreTrainedModel):
    config_class = SmalLmConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Block"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            module.weight.data[self.config.pad_token_id].zero_()

class SmalLmModel(SmalLmPreTrainedModel):
    def __init__(self, config: SmalLmConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.config = config
        self.pad_idx = config.pad_token_id
        self.pad_token_id = config.pad_token_id
        self.vocab_size = config.vocab_size
        precompute_bias = (
            build_alibi_bias(config)
            if config.positional_bias_type == "alibi"
            else build_rope_bias(config)
        )
        self.register_buffer("precompute_bias", precompute_bias, persistent=False)
        # do not forget about weight sharing with the output head: self.embedding.weight = self.output.weight
        self.embedding = nn.Embedding(
            self.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.embedding_dropout = nn.Dropout(config.embedding_dropout)
        self.layers = nn.ModuleList(
            [Block(config, idx) for idx in range(1, config.num_hidden_layers + 1)]
        )
        self.out_norm = nn.RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
            elementwise_affine=config.rms_affine,
        )

        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embedding

    def set_input_embeddings(self, value):
        self.embedding = value

    def forward(
        self,
        # input options
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        # output options
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # cache options
        use_cache: Optional[bool] = None,
        past_key_values: Optional[Cache | torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple | BaseModelOutputWithPast:
        # check additional parameters
        output_attentions = False
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = (
            use_cache
            if use_cache is not None
            else (False if self.training else self.config.use_cache)
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You must specify only input_ids or inputs_embeds, not both"
            )

        if self.training and use_cache:
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        # calculating position for StaticCache
        if cache_position is None:
            last_position = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                last_position,
                last_position + inputs_embeds.size(1),
                device=inputs_embeds.device,
            )

        causal_mask = self._get_causal_masks(
            attention_mask, inputs_embeds, past_key_values, cache_position
        )
        if self.config.positional_bias_type == "rope":
            end_pos = (
                inputs_embeds.size(1)
                if past_key_values is None
                else cache_position[-1] + 1
            )
            start_pos = 0 if past_key_values is None else cache_position[0]
            bias = self.precompute_bias[start_pos:end_pos]

        elif self.config.positional_bias_type == "alibi":
            if USE_FLASH and input_ids.is_cuda:
                bias = self.precompute_bias
            else:
                i = torch.arange(
                    (
                        inputs_embeds.size(1)
                        if past_key_values is None
                        else cache_position[-1] + 1
                    ),
                    device=inputs_embeds.device,
                )
                bias = i[:, None] - i[None, :]
                bias = torch.tril(bias).expand(
                    inputs_embeds.size(0), self.config.num_attention_heads, -1, -1
                ) * rearrange(self.precompute_bias, "n -> 1 n 1 1")
                if causal_mask is not None:
                    causal_mask = causal_mask + bias
                else:
                    causal_mask = bias

        hidden_state = inputs_embeds
        hidden_states = [hidden_state] if output_hidden_states else None
        attentions = [] if output_attentions else None
        for idx, layer in enumerate(self.layers, 1):
            if self.gradient_checkpointing:
                # for details see:
                # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3107
                # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3149
                layer_out = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_state,
                    causal_mask,
                    past_key_values,
                    output_attentions,
                    cache_position,
                    bias,
                )
            else:
                layer_out = layer(
                    hidden_state,
                    causal_mask,
                    past_key_values,
                    output_attentions,
                    cache_position,
                    bias,
                )
            hidden_state = layer_out[0]
            if output_hidden_states:
                hidden_states.append(hidden_state)
            if output_attentions:
                attentions.append(layer_out[1])

        hidden_state = self.out_norm(hidden_state)
        out = BaseModelOutputWithPast(
            last_hidden_state=hidden_state,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=tuple(hidden_states) if hidden_states is not None else None,
            attentions=tuple(attentions) if attentions is not None else None,
        )
        return out if return_dict else out.to_tuple()

    def _get_causal_masks(
        self,
        attention_mask: Optional[torch.Tensor],
        inputs_embeds: torch.Tensor,
        past_key_values: Optional[torch.Tensor],
        cache_position: Optional[torch.Tensor],
    ):
        if USE_FLASH and inputs_embeds.is_cuda:
            return attention_mask
        dtype, device = inputs_embeds.dtype, inputs_embeds.device
        past_token = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        if attention_mask is not None and torch.all(attention_mask == 0.0):
            return None
        if AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_token,
            is_training=self.training,
        ):
            return None

        sequence_length = inputs_embeds.size(1)
        target_length = (
            attention_mask.size(-1)
            if isinstance(attention_mask, torch.Tensor)
            else past_token + sequence_length + 1
        )

        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask=attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=inputs_embeds.size(0),
        )

        min_dtype = torch.finfo(dtype).min
        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: Optional[torch.Tensor],
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: Optional[torch.Tensor],
        batch_size: int,
    ):
        if attention_mask is not None and attention_mask.dim() == 4:
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=device,
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(
                target_length, device=device
            ) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()
                mask_length = attention_mask.shape[-1]
                padding_mask = (
                    causal_mask[:, :, :, :mask_length]
                    + attention_mask[:, None, None, :]
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)
        return causal_mask

class SmalLmForCausalLM(SmalLmPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: SmalLmConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.config = config
        self.model = SmalLmModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        # input options
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        # output options
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # cache options
        use_cache: Optional[bool] = None,
        past_key_values: Optional[Cache | torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        # generation options
        labels: Optional[torch.Tensor] = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> tuple | CausalLMOutputWithPast:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )

        model_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = model_outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            shifted_logits = logits[:, :-1].contiguous()
            shifted_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shifted_logits.view(-1, self.config.vocab_size), shifted_labels.view(-1)
            )

        if not return_dict:
            output = (logits,) + model_outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )


__all__ = ["SmalLmForCausalLM", "SmalLmModel", "SmalLmPreTrainedModel"]