llaa33219 committed
Commit 91ebbd3 · verified · 1 parent: f3deb25

Upload modeling_solar_open.py with huggingface_hub

Files changed (1): modeling_solar_open.py (+605, −0)
modeling_solar_open.py ADDED

# coding=utf-8
# Copyright 2025 Upstage AI.
# Copyright 2025 The GLM4 & ZhipuAI team and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file has been modified by Upstage AI including:
# - Hybrid MoE Architecture: Replaced the standard dense structure with a depth-dependent Hybrid MoE, adding `SolarOpenMoE` and `SolarOpenTopkRouter` classes.
# - RoPE Strategy: Changed the rotary position embedding strategy from GLM4's interleaved rotation to Llama-style block rotation (via a modified `rotate_half`).
# - Normalization Logic: Simplified the layer normalization structure by removing GLM4's extra post-operation norms and adding optional Query-Key Normalization (`use_qk_norm`).
#
# Based on code from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4/modeling_glm4.py

from typing import Callable, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub
from transformers.masking_utils import create_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import check_model_inputs
from .configuration_solar_open import SolarOpenConfig


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcast to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    # Keep half or full tensor for later concatenation
    rotary_dim = cos.shape[-1]
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]

    # Apply rotary embeddings on the first half or full tensor
    q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)

    # Concatenate back to full shape
    q_embed = torch.cat([q_embed, q_pass], dim=-1)
    k_embed = torch.cat([k_embed, k_pass], dim=-1)
    return q_embed, k_embed

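# Illustration (not executed): the header notes that this file swaps GLM4's
# interleaved rotation for Llama-style block rotation. Block rotation pairs
# channel i with channel i + rotary_dim // 2, so for x = [x0, x1, x2, x3]
# `rotate_half` returns [-x2, -x3, x0, x1]; an interleaved scheme would pair
# adjacent channels instead (x0 with x1, x2 with x3). Note also the
# partial-rotary split above: only the first `rotary_dim` channels are
# rotated, the remainder passes through unchanged.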

class SolarOpenAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: SolarOpenConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.rope_scaling = config.rope_scaling
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
        self.use_qk_norm = config.use_qk_norm
        if self.use_qk_norm:
            self.q_norm = SolarOpenRMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm = SolarOpenRMSNorm(self.head_dim, eps=config.rms_norm_eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape)
        key_states = self.k_proj(hidden_states).view(hidden_shape)
        value_states = self.v_proj(hidden_states).view(hidden_shape)

        if self.use_qk_norm:  # main diff from Llama
            query_states = self.q_norm(query_states)
            key_states = self.k_norm(key_states)

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

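# Shape walkthrough (sketch, batch B, sequence length S): the q/k/v projections
# produce (B, S, n_heads * head_dim), viewed as (B, S, n_heads, head_dim) and
# transposed to (B, n_heads, S, head_dim). With grouped-query attention,
# num_key_value_heads < num_attention_heads, and `repeat_kv` expands k/v by
# num_key_value_groups in the eager path above. When enabled, QK-norm is an
# RMSNorm over each head_dim slice, applied before the transpose and before RoPE.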

class SolarOpenMLP(nn.Module):
    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
        self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


class SolarOpenTopkRouter(nn.Module):
    def __init__(self, config: SolarOpenConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob

        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
        self.e_score_correction_bias = nn.Parameter(
            torch.zeros(self.n_routed_experts, dtype=torch.float32)
        )

    @torch.no_grad()
    def get_topk_indices(self, scores):
        scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
        group_scores = (
            scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
            .reshape(-1, self.n_routed_experts)
        )
        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
        return topk_indices

    def forward(self, hidden_states):
        hidden_states = hidden_states.view(-1, self.config.hidden_size)
        router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
        scores = router_logits.sigmoid()
        topk_indices = self.get_topk_indices(scores)
        topk_weights = scores.gather(1, topk_indices)
        if self.norm_topk_prob:
            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
            topk_weights /= denominator
        topk_weights = topk_weights * self.routed_scaling_factor
        return topk_indices, topk_weights

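# Worked example (hypothetical numbers): with n_routed_experts=128, n_group=8,
# topk_group=4 and top_k=8, each group holds 128 / 8 = 16 experts. A group's
# score is the sum of its 2 best (bias-corrected) expert scores; the 4
# best-scoring groups survive, all other groups are masked to 0, and the final
# top-8 experts are picked from the 64 surviving ones. The gating weights come
# from the *uncorrected* sigmoid scores (`scores.gather`), optionally
# renormalized to sum to 1, then scaled by `routed_scaling_factor`; the
# correction bias influences selection only.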

@use_kernel_forward_from_hub("RMSNorm")
class SolarOpenRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        SolarOpenRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

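# In symbols: y = w * x / sqrt(mean(x_i^2) + eps), computed in float32 for
# numerical stability and cast back to the input dtype. Unlike LayerNorm,
# there is no mean subtraction and no bias term.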

class SolarOpenMoE(nn.Module):
    """
    A mixed expert module containing shared experts.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList(
            [
                SolarOpenMLP(config, intermediate_size=config.moe_intermediate_size)
                for _ in range(config.n_routed_experts)
            ]
        )
        self.gate = SolarOpenTopkRouter(config)
        self.shared_experts = SolarOpenMLP(
            config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
        )

    @torch.compiler.disable()
    def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
        r"""
        MoE forward pass that only executes the selected experts.
        Uses @torch.compiler.disable() to allow dynamic-shape operations.
        Requires the --enforce-eager flag when serving with vLLM.
        """
        final_hidden_states = torch.zeros_like(hidden_states)

        for expert_idx in range(len(self.experts)):
            expert = self.experts[expert_idx]

            # Find positions where this expert was selected
            batch_idx, topk_pos = torch.where(topk_indices == expert_idx)

            if batch_idx.numel() == 0:
                continue

            # Extract only the tokens routed to this expert
            expert_input = hidden_states[batch_idx]
            expert_output = expert(expert_input)

            # Apply weights and accumulate results
            weights = topk_weights[batch_idx, topk_pos].unsqueeze(-1)
            final_hidden_states.index_add_(0, batch_idx, (expert_output * weights).to(hidden_states.dtype))

        return final_hidden_states

    def forward(self, hidden_states):
        residuals = hidden_states
        orig_shape = hidden_states.shape
        topk_indices, topk_weights = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states

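# Per token t the layer computes (sketch):
#   y_t = shared_experts(x_t) + sum_{i in TopK(t)} w_{t,i} * expert_i(x_t)
# The shared expert always runs on every token, while the router supplies the
# top-k indices and weights for the routed experts. The residual connection
# around the whole MLP block is added by the decoder layer, not here.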

class SolarOpenDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: SolarOpenConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = SolarOpenAttention(config=config, layer_idx=layer_idx)

        if layer_idx >= config.first_k_dense_replace:
            self.mlp = SolarOpenMoE(config)
        else:
            self.mlp = SolarOpenMLP(config)

        self.input_layernorm = SolarOpenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SolarOpenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states

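# Hybrid depth layout: layers [0, first_k_dense_replace) use a dense
# SolarOpenMLP; every deeper layer uses SolarOpenMoE. Both variants sit in a
# standard pre-norm residual block: x = x + Attn(RMSNorm(x)) followed by
# x = x + MLP(RMSNorm(x)).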

@auto_docstring
class SolarOpenPreTrainedModel(PreTrainedModel):
    config: SolarOpenConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SolarOpenDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = False
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": SolarOpenDecoderLayer,
        "attentions": SolarOpenAttention,
    }

    def _init_weights(self, module):
        super()._init_weights(module)
        if isinstance(module, SolarOpenTopkRouter):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)


class SolarOpenRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: SolarOpenConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

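# Note: `emb = torch.cat((freqs, freqs), dim=-1)` duplicates each frequency
# across the two halves of the rotary channels, matching the block-style
# `rotate_half` above (an interleaved layout would instead repeat each
# frequency twice in place). cos/sin come back as (batch, seq_len, rotary_dim)
# and are broadcast over the head dimension via `unsqueeze_dim` in
# `apply_rotary_pos_emb`.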

@auto_docstring
class SolarOpenModel(SolarOpenPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"model\.layers\.92.*", r"model\.layers\.46.*"]

    def __init__(self, config: SolarOpenConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [SolarOpenDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = SolarOpenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = SolarOpenRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position: torch.Tensor = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )


@auto_docstring
class SolarOpenForCausalLM(SolarOpenPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = SolarOpenModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["SolarOpenPreTrainedModel", "SolarOpenModel", "SolarOpenForCausalLM"]
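
# Minimal usage sketch (assumes the checkpoint repo bundles this file together
# with configuration_solar_open.py and a tokenizer; the repo id below is a
# placeholder, not something this file guarantees):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     repo_id = "upstage/solar-open"  # hypothetical repo id
#     tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
#     model = AutoModelForCausalLM.from_pretrained(
#         repo_id, trust_remote_code=True, torch_dtype="auto", device_map="auto"
#     )
#     inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
#     out = model.generate(**inputs, max_new_tokens=32)
#     print(tokenizer.decode(out[0], skip_special_tokens=True))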