muverqqw committed on
Commit
598ebaa
·
1 Parent(s): 341ca1c

Update modeling_alinlight.py

Browse files
Files changed (1) hide show
  1. modeling_alinlight.py +359 -101
modeling_alinlight.py CHANGED
@@ -13,29 +13,35 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
16
-
17
  import math
18
  import torch
19
  import torch.nn as nn
20
  import torch.nn.functional as F
21
  from typing import Optional, Tuple, List, Union
 
 
22
  from transformers import PreTrainedModel, GenerationMixin
23
  from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
24
  from configuration_alinlight import AlinlightConfig
25
 
 
 
 
 
26
  class AlinlightRMSNorm(nn.Module):
27
- def __init__(self, hidden_size, eps=1e-6):
28
  super().__init__()
29
  self.weight = nn.Parameter(torch.ones(hidden_size))
30
  self.eps = eps
31
 
32
- def forward(self, x):
33
  input_dtype = x.dtype
34
  x = x.to(torch.float32)
35
  variance = x.pow(2).mean(-1, keepdim=True)
36
  x = x * torch.rsqrt(variance + self.eps)
37
  return self.weight * x.to(input_dtype)
38
 
 
39
  class AlinlightRotaryEmbedding(nn.Module):
40
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
41
  super().__init__()
@@ -43,16 +49,18 @@ class AlinlightRotaryEmbedding(nn.Module):
43
  self.base = base
44
  self.max_position_embeddings = max_position_embeddings
45
  self.scaling_factor = scaling_factor
46
-
47
 
48
  inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float32).to(device) / self.dim))
49
  self.register_buffer("inv_freq", inv_freq, persistent=False)
50
-
51
-
52
  self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())
53
 
54
  def _set_cos_sin_cache(self, seq_len, device, dtype):
55
-
 
 
 
 
 
56
  t = torch.arange(seq_len, device=device, dtype=torch.int64).type_as(self.inv_freq)
57
  t = t / self.scaling_factor
58
  freqs = torch.outer(t, self.inv_freq)
@@ -61,20 +69,20 @@ class AlinlightRotaryEmbedding(nn.Module):
61
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
62
 
63
  def forward(self, x, seq_len=None):
64
-
65
  if seq_len > self.cos_cached.shape[0]:
66
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
67
-
68
  return (
69
  self.cos_cached[:seq_len].to(dtype=x.dtype, device=x.device),
70
  self.sin_cached[:seq_len].to(dtype=x.dtype, device=x.device)
71
  )
72
 
73
- def rotate_half(x):
 
74
  x1 = x[..., : x.shape[-1] // 2]
75
  x2 = x[..., x.shape[-1] // 2 :]
76
  return torch.cat((-x2, x1), dim=-1)
77
 
 
78
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
79
  cos = cos[position_ids].unsqueeze(unsqueeze_dim)
80
  sin = sin[position_ids].unsqueeze(unsqueeze_dim)
@@ -82,6 +90,11 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
82
  k_embed = (k * cos) + (rotate_half(k) * sin)
83
  return q_embed, k_embed
84
 
 
 
 
 
 
85
  class AlinlightMLP(nn.Module):
86
  def __init__(self, config):
87
  super().__init__()
@@ -91,9 +104,20 @@ class AlinlightMLP(nn.Module):
91
  self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
92
  self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
93
  self.act_fn = nn.SiLU()
 
 
 
 
94
 
95
  def forward(self, x):
96
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
 
 
 
 
 
 
97
 
98
  class AlinlightAttention(nn.Module):
99
  def __init__(self, config, layer_idx: Optional[int] = None):
@@ -106,71 +130,120 @@ class AlinlightAttention(nn.Module):
106
  self.num_key_value_heads = config.num_key_value_heads
107
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
108
  self.sliding_window = config.sliding_window
109
-
 
110
  self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
111
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
112
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
113
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
 
 
 
 
 
 
 
 
 
 
114
 
115
  def forward(
116
- self,
117
- hidden_states,
118
- attention_mask=None,
119
- position_ids=None,
120
- past_key_value=None,
121
- output_attentions=False,
122
- use_cache=False,
123
- cos_sin=None
124
- ):
 
125
  bsz, q_len, _ = hidden_states.size()
126
-
127
  query_states = self.q_proj(hidden_states)
128
  key_states = self.k_proj(hidden_states)
129
  value_states = self.v_proj(hidden_states)
130
-
131
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
132
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
133
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
134
 
 
 
 
 
 
135
  if cos_sin is not None:
136
  cos, sin = cos_sin
137
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
138
 
 
139
  if past_key_value is not None:
140
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
141
  value_states = torch.cat([past_key_value[1], value_states], dim=2)
142
 
 
143
 
144
- if self.sliding_window is not None and key_states.shape[2] > self.sliding_window:
145
- key_states = key_states[:, :, -self.sliding_window:, :]
146
- value_states = value_states[:, :, -self.sliding_window:, :]
 
 
 
 
 
147
 
148
  past_key_value = (key_states, value_states) if use_cache else None
149
-
150
-
151
  if self.num_key_value_groups > 1:
152
- key_states = key_states[:, :, None, :, :].expand(
153
- bsz, self.num_key_value_heads, self.num_key_value_groups, key_states.shape[-2], self.head_dim
154
- ).reshape(bsz, self.num_heads, key_states.shape[-2], self.head_dim)
155
-
156
- value_states = value_states[:, :, None, :, :].expand(
157
- bsz, self.num_key_value_heads, self.num_key_value_groups, value_states.shape[-2], self.head_dim
158
- ).reshape(bsz, self.num_heads, value_states.shape[-2], self.head_dim)
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
- is_causal = q_len > 1
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
- attn_output = F.scaled_dot_product_attention(
164
- query_states,
165
- key_states,
166
- value_states,
167
- attn_mask=None,
168
- dropout_p=0.0,
169
- is_causal=is_causal
170
- )
171
-
172
  attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
173
- return self.o_proj(attn_output), None, past_key_value
 
 
 
 
 
174
 
175
  class AlinlightDecoderLayer(nn.Module):
176
  def __init__(self, config, layer_idx: int):
@@ -179,99 +252,218 @@ class AlinlightDecoderLayer(nn.Module):
179
  self.mlp = AlinlightMLP(config)
180
  self.input_layernorm = AlinlightRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
181
  self.post_attention_layernorm = AlinlightRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
 
182
 
183
- def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False, cos_sin=None):
 
 
 
 
 
 
 
 
 
184
  residual = hidden_states
185
  hidden_states = self.input_layernorm(hidden_states)
186
- hidden_states, _, present_key_value = self.self_attn(
187
- hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cos_sin
 
 
 
 
 
 
 
188
  )
189
- hidden_states = residual + hidden_states
190
-
191
  residual = hidden_states
192
  hidden_states = self.post_attention_layernorm(hidden_states)
193
  hidden_states = self.mlp(hidden_states)
194
- hidden_states = residual + hidden_states
195
- return hidden_states, None, present_key_value
 
 
196
 
197
  class AlinlightModel(PreTrainedModel):
198
  config_class = AlinlightConfig
 
199
  def __init__(self, config: AlinlightConfig):
200
  super().__init__(config)
201
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
 
 
 
 
 
 
 
202
  self.layers = nn.ModuleList([AlinlightDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
203
  self.norm = AlinlightRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
204
-
205
  scaling_factor = 1.0
206
  if config.rope_scaling and config.rope_scaling.get("type") == "linear":
207
  scaling_factor = config.rope_scaling.get("factor", 1.0)
208
-
209
  self.rotary_emb = AlinlightRotaryEmbedding(
210
- config.hidden_size // config.num_attention_heads,
211
- max_position_embeddings=config.max_position_embeddings,
212
- base=config.rope_theta,
213
  scaling_factor=scaling_factor
214
  )
 
215
  self.post_init()
216
 
217
  def get_input_embeddings(self): return self.embed_tokens
218
  def set_input_embeddings(self, value): self.embed_tokens = value
219
 
220
- def forward(self, input_ids=None, past_key_values=None, use_cache=None, **kwargs):
221
- if input_ids is not None:
222
- inputs_embeds = self.embed_tokens(input_ids)
 
 
 
 
 
 
 
 
 
 
223
  else:
224
- inputs_embeds = kwargs.get("inputs_embeds")
 
 
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
- seq_len = inputs_embeds.shape[1]
 
 
 
 
228
  if past_key_values is not None:
229
- seq_len += past_key_values[0][0].shape[2]
230
-
231
-
232
- cos, sin = self.rotary_emb(inputs_embeds, seq_len=seq_len)
233
-
234
- position_ids = kwargs.get("position_ids")
235
  if position_ids is None:
236
-
237
- position_ids = torch.arange(seq_len - inputs_embeds.shape[1], seq_len, dtype=torch.long, device=inputs_embeds.device)
238
- position_ids = position_ids.unsqueeze(0).expand(inputs_embeds.shape[0], -1)
 
 
 
 
239
 
240
  hidden_states = inputs_embeds
241
  next_decoder_cache = () if use_cache else None
 
 
242
 
243
  for idx, layer in enumerate(self.layers):
 
 
 
244
  past_key_value = past_key_values[idx] if past_key_values is not None else None
245
- layer_outputs = layer(
246
- hidden_states,
247
- position_ids=position_ids,
248
- past_key_value=past_key_value,
249
- use_cache=use_cache,
250
- cos_sin=(cos, sin)
251
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  hidden_states = layer_outputs[0]
 
 
253
  if use_cache:
254
  next_decoder_cache += (layer_outputs[2],)
255
 
256
  hidden_states = self.norm(hidden_states)
257
-
 
 
 
 
 
 
258
  return BaseModelOutputWithPast(
259
  last_hidden_state=hidden_states,
260
- past_key_values=next_decoder_cache
 
 
261
  )
262
 
 
 
 
 
 
263
  class AlinlightForCausalLM(PreTrainedModel, GenerationMixin):
264
  config_class = AlinlightConfig
265
  _keys_to_ignore_on_load_missing = ["model.rotary_emb.inv_freq"]
 
266
 
267
  def __init__(self, config):
268
  super().__init__(config)
269
  self.model = AlinlightModel(config)
270
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
271
 
 
 
 
272
  if config.tie_word_embeddings:
273
  self.lm_head.weight = self.model.embed_tokens.weight
274
-
275
  self.post_init()
276
 
277
  def get_input_embeddings(self): return self.model.embed_tokens
@@ -279,37 +471,103 @@ class AlinlightForCausalLM(PreTrainedModel, GenerationMixin):
279
  def get_output_embeddings(self): return self.lm_head
280
  def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings
281
 
282
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
283
-
284
- if past_key_values:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  input_ids = input_ids[:, -1:]
286
-
287
  position_ids = kwargs.get("position_ids", None)
288
  if position_ids is None:
289
  if past_key_values:
290
-
291
- past_length = past_key_values[0][0].shape[2]
292
- position_ids = torch.tensor([[past_length]], dtype=torch.long, device=input_ids.device)
293
  else:
294
-
295
- position_ids = torch.arange(input_ids.shape[1], dtype=torch.long, device=input_ids.device).unsqueeze(0)
296
 
297
  return {
298
  "input_ids": input_ids,
299
  "past_key_values": past_key_values,
300
  "use_cache": True,
301
- "position_ids": position_ids
 
302
  }
303
 
304
- def forward(self, input_ids=None, past_key_values=None, labels=None, **kwargs):
305
- outputs = self.model(input_ids=input_ids, past_key_values=past_key_values, **kwargs)
306
- hidden_states = outputs.last_hidden_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  logits = self.lm_head(hidden_states)
308
-
 
 
 
 
309
  loss = None
310
  if labels is not None:
311
  shift_logits = logits[..., :-1, :].contiguous()
312
  shift_labels = labels[..., 1:].contiguous()
313
- loss = F.cross_entropy(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
314
-
315
- return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
 
16
  import math
17
  import torch
18
  import torch.nn as nn
19
  import torch.nn.functional as F
20
  from typing import Optional, Tuple, List, Union
21
+ from torch.utils.checkpoint import checkpoint
22
+
23
  from transformers import PreTrainedModel, GenerationMixin
24
  from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
25
  from configuration_alinlight import AlinlightConfig
26
 
27
+ # ==========================================
28
+ # 1. BASE COMPONENTS
29
+ # ==========================================
30
+
31
class AlinlightRMSNorm(nn.Module):
    """Root-mean-square layer normalization: rescales by 1/RMS with a learned gain.

    Unlike LayerNorm there is no mean-centering and no bias term.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        # Learned per-channel gain; initialized to the identity.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor):
        """Normalize `x` over its last dimension, computing in fp32 for stability."""
        orig_dtype = x.dtype
        x_fp32 = x.to(torch.float32)
        mean_square = x_fp32.pow(2).mean(dim=-1, keepdim=True)
        normed = x_fp32 * torch.rsqrt(mean_square + self.eps)
        # Cast back to the caller's dtype before applying the gain.
        return self.weight * normed.to(orig_dtype)
43
 
44
+
45
  class AlinlightRotaryEmbedding(nn.Module):
46
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
47
  super().__init__()
 
49
  self.base = base
50
  self.max_position_embeddings = max_position_embeddings
51
  self.scaling_factor = scaling_factor
 
52
 
53
  inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float32).to(device) / self.dim))
54
  self.register_buffer("inv_freq", inv_freq, persistent=False)
 
 
55
  self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())
56
 
57
  def _set_cos_sin_cache(self, seq_len, device, dtype):
58
+ if (hasattr(self, 'cos_cached') and
59
+ self.cos_cached.device == device and
60
+ self.cos_cached.dtype == dtype and
61
+ self.cos_cached.shape[0] >= seq_len):
62
+ return
63
+
64
  t = torch.arange(seq_len, device=device, dtype=torch.int64).type_as(self.inv_freq)
65
  t = t / self.scaling_factor
66
  freqs = torch.outer(t, self.inv_freq)
 
69
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
70
 
71
  def forward(self, x, seq_len=None):
 
72
  if seq_len > self.cos_cached.shape[0]:
73
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 
74
  return (
75
  self.cos_cached[:seq_len].to(dtype=x.dtype, device=x.device),
76
  self.sin_cached[:seq_len].to(dtype=x.dtype, device=x.device)
77
  )
78
 
79
+
80
def rotate_half(x: torch.Tensor):
    """Rotate the last dimension by half: (a, b) -> (-b, a).

    Used by rotary position embeddings to form the 90-degree-rotated
    counterpart of a query/key vector.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
84
 
85
+
86
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
87
  cos = cos[position_ids].unsqueeze(unsqueeze_dim)
88
  sin = sin[position_ids].unsqueeze(unsqueeze_dim)
 
90
  k_embed = (k * cos) + (rotate_half(k) * sin)
91
  return q_embed, k_embed
92
 
93
+
94
+ # ==========================================
95
+ # 2. MLP
96
+ # ==========================================
97
+
98
  class AlinlightMLP(nn.Module):
99
  def __init__(self, config):
100
  super().__init__()
 
104
  self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
105
  self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
106
  self.act_fn = nn.SiLU()
107
+ self.pre_down_norm = AlinlightRMSNorm(self.intermediate_size, eps=config.rms_norm_eps)
108
+
109
+ # Tag for specialized initialization
110
+ self.down_proj._is_residual_projection = True
111
 
112
  def forward(self, x):
113
+ intermediate = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
114
+ intermediate = self.pre_down_norm(intermediate)
115
+ return self.down_proj(intermediate)
116
+
117
+
118
+ # ==========================================
119
+ # 3. ATTENTION
120
+ # ==========================================
121
 
122
  class AlinlightAttention(nn.Module):
123
  def __init__(self, config, layer_idx: Optional[int] = None):
 
130
  self.num_key_value_heads = config.num_key_value_heads
131
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
132
  self.sliding_window = config.sliding_window
133
+ self.attention_dropout = config.attention_dropout
134
+
135
  self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
136
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
137
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
138
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
139
+
140
+ # Tag for specialized initialization
141
+ self.o_proj._is_residual_projection = True
142
+
143
+ self.use_qk_norm = getattr(config, "use_qk_norm", True)
144
+ if self.use_qk_norm:
145
+ self.q_norm = AlinlightRMSNorm(self.head_dim, eps=config.rms_norm_eps)
146
+ self.k_norm = AlinlightRMSNorm(self.head_dim, eps=config.rms_norm_eps)
147
+
148
+ self.attn_logit_softcapping = getattr(config, 'attn_logit_softcapping', None)
149
 
150
  def forward(
151
+ self,
152
+ hidden_states: torch.Tensor,
153
+ attention_mask: Optional[torch.Tensor] = None,
154
+ position_ids: Optional[torch.LongTensor] = None,
155
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
156
+ output_attentions: bool = False,
157
+ use_cache: bool = False,
158
+ cos_sin: Optional[Tuple[torch.Tensor]] = None
159
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
160
+
161
  bsz, q_len, _ = hidden_states.size()
162
+
163
  query_states = self.q_proj(hidden_states)
164
  key_states = self.k_proj(hidden_states)
165
  value_states = self.v_proj(hidden_states)
166
+
167
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
168
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
169
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
170
 
171
+ if self.use_qk_norm:
172
+ query_states = self.q_norm(query_states)
173
+ key_states = self.k_norm(key_states)
174
+
175
+ # 1. RoPE (Applied before caching)
176
  if cos_sin is not None:
177
  cos, sin = cos_sin
178
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
179
 
180
+ # 2. KV Cache
181
  if past_key_value is not None:
182
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
183
  value_states = torch.cat([past_key_value[1], value_states], dim=2)
184
 
185
+ kv_seq_len = key_states.shape[2]
186
 
187
+ # 3. Sliding Window (Slicing)
188
+ if self.sliding_window is not None and kv_seq_len > self.sliding_window:
189
+ slicing_tokens = kv_seq_len - self.sliding_window
190
+ key_states = key_states[:, :, slicing_tokens:, :]
191
+ value_states = value_states[:, :, slicing_tokens:, :]
192
+
193
+ if attention_mask is not None:
194
+ attention_mask = attention_mask[:, :, :, slicing_tokens:]
195
 
196
  past_key_value = (key_states, value_states) if use_cache else None
197
+
198
+ # 4. GQA Repeat
199
  if self.num_key_value_groups > 1:
200
+ key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
201
+ value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
 
 
 
 
 
202
 
203
+ # 5. Attention Mechanism
204
+ attn_weights = None
205
+
206
+ # We must use manual implementation if:
207
+ # a) Output weights are requested
208
+ # b) Soft-capping is enabled (SDPA doesn't support intermediate logit transforms)
209
+ if output_attentions or self.attn_logit_softcapping is not None:
210
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
211
+
212
+ if self.attn_logit_softcapping is not None:
213
+ attn_weights = self.attn_logit_softcapping * torch.tanh(attn_weights / self.attn_logit_softcapping)
214
+
215
+ if attention_mask is not None:
216
+ attn_weights = attn_weights + attention_mask
217
+
218
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
219
+
220
+ if not output_attentions:
221
+ # If we only calculated weights for soft-capping but user didn't ask for them, drop reference
222
+ attn_weights_for_output = None
223
+ else:
224
+ attn_weights_for_output = attn_weights
225
 
226
+ attn_weights_dropped = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
227
+ attn_output = torch.matmul(attn_weights_dropped, value_states)
228
+ else:
229
+ # Fast Path (SDPA)
230
+ attn_output = F.scaled_dot_product_attention(
231
+ query_states,
232
+ key_states,
233
+ value_states,
234
+ attn_mask=attention_mask,
235
+ dropout_p=self.attention_dropout if self.training else 0.0,
236
+ is_causal=False
237
+ )
238
+ attn_weights_for_output = None
239
 
 
 
 
 
 
 
 
 
 
240
  attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
241
+ return self.o_proj(attn_output), attn_weights_for_output, past_key_value
242
+
243
+
244
+ # ==========================================
245
+ # 4. DECODER LAYER & MODEL
246
+ # ==========================================
247
 
248
  class AlinlightDecoderLayer(nn.Module):
249
  def __init__(self, config, layer_idx: int):
 
252
  self.mlp = AlinlightMLP(config)
253
  self.input_layernorm = AlinlightRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
254
  self.post_attention_layernorm = AlinlightRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
255
+
256
+ self.resid_pdrop = getattr(config, 'resid_pdrop', 0.0)
257
+ self.resid_dropout = nn.Dropout(self.resid_pdrop) if self.resid_pdrop > 0 else nn.Identity()
258
 
259
+ def forward(
260
+ self,
261
+ hidden_states,
262
+ attention_mask=None,
263
+ position_ids=None,
264
+ past_key_value=None,
265
+ output_attentions=False,
266
+ use_cache=False,
267
+ cos_sin=None
268
+ ):
269
  residual = hidden_states
270
  hidden_states = self.input_layernorm(hidden_states)
271
+
272
+ hidden_states, attn_weights, present_key_value = self.self_attn(
273
+ hidden_states=hidden_states,
274
+ attention_mask=attention_mask,
275
+ position_ids=position_ids,
276
+ past_key_value=past_key_value,
277
+ output_attentions=output_attentions,
278
+ use_cache=use_cache,
279
+ cos_sin=cos_sin
280
  )
281
+ hidden_states = residual + self.resid_dropout(hidden_states)
282
+
283
  residual = hidden_states
284
  hidden_states = self.post_attention_layernorm(hidden_states)
285
  hidden_states = self.mlp(hidden_states)
286
+ hidden_states = residual + self.resid_dropout(hidden_states)
287
+
288
+ return hidden_states, attn_weights, present_key_value
289
+
290
 
291
  class AlinlightModel(PreTrainedModel):
292
  config_class = AlinlightConfig
293
+
294
  def __init__(self, config: AlinlightConfig):
295
  super().__init__(config)
296
+ self.padding_idx = config.pad_token_id
297
+ self.vocab_size = config.vocab_size
298
+
299
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
300
+
301
+ self.embed_scale = math.sqrt(config.hidden_size) if getattr(config, 'embed_scale', False) else 1.0
302
+ self.embed_dropout = nn.Dropout(config.embed_pdrop) if config.embed_pdrop > 0 else nn.Identity()
303
+
304
  self.layers = nn.ModuleList([AlinlightDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
305
  self.norm = AlinlightRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
306
+
307
  scaling_factor = 1.0
308
  if config.rope_scaling and config.rope_scaling.get("type") == "linear":
309
  scaling_factor = config.rope_scaling.get("factor", 1.0)
310
+
311
  self.rotary_emb = AlinlightRotaryEmbedding(
312
+ config.hidden_size // config.num_attention_heads,
313
+ max_position_embeddings=config.max_position_embeddings,
314
+ base=config.rope_theta,
315
  scaling_factor=scaling_factor
316
  )
317
+ self.gradient_checkpointing = False
318
  self.post_init()
319
 
320
  def get_input_embeddings(self): return self.embed_tokens
321
  def set_input_embeddings(self, value): self.embed_tokens = value
322
 
323
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
324
+ bsz, seq_len = input_shape
325
+ dtype = inputs_embeds.dtype
326
+ device = inputs_embeds.device
327
+
328
+ if attention_mask is not None:
329
+ current_mask = attention_mask[:, None, None, :].to(dtype=dtype)
330
+ else:
331
+ current_mask = torch.ones((bsz, 1, 1, seq_len), dtype=dtype, device=device)
332
+
333
+ if past_key_values_length > 0:
334
+ past_mask = torch.ones((bsz, 1, 1, past_key_values_length), dtype=dtype, device=device)
335
+ combined_mask = torch.cat([past_mask, current_mask], dim=-1)
336
  else:
337
+ combined_mask = current_mask
338
+
339
+ inverted_mask = (1.0 - combined_mask) * torch.finfo(dtype).min
340
 
341
+ if seq_len > 1:
342
+ causal_mask = torch.triu(
343
+ torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=dtype),
344
+ diagonal=1
345
+ )
346
+ if past_key_values_length > 0:
347
+ past_causal = torch.zeros((seq_len, past_key_values_length), dtype=dtype, device=device)
348
+ causal_mask = torch.cat([past_causal, causal_mask], dim=-1)
349
+
350
+ causal_mask = causal_mask[None, None, :, :]
351
+ inverted_mask = inverted_mask + causal_mask
352
+
353
+ return inverted_mask
354
+
355
+ def forward(
356
+ self,
357
+ input_ids: torch.LongTensor = None,
358
+ attention_mask: Optional[torch.Tensor] = None,
359
+ position_ids: Optional[torch.LongTensor] = None,
360
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
361
+ inputs_embeds: Optional[torch.FloatTensor] = None,
362
+ use_cache: Optional[bool] = None,
363
+ output_attentions: Optional[bool] = None,
364
+ output_hidden_states: Optional[bool] = None,
365
+ return_dict: Optional[bool] = None,
366
+ ):
367
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
368
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
369
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
370
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
371
+
372
+ if inputs_embeds is None:
373
+ inputs_embeds = self.embed_tokens(input_ids)
374
 
375
+ inputs_embeds = inputs_embeds * self.embed_scale
376
+ inputs_embeds = self.embed_dropout(inputs_embeds)
377
+
378
+ batch_size, seq_length = inputs_embeds.shape[:2]
379
+ past_key_values_length = 0
380
  if past_key_values is not None:
381
+ past_key_values_length = past_key_values[0][0].shape[2]
382
+
383
+ total_seq_len = seq_length + past_key_values_length
384
+ cos, sin = self.rotary_emb(inputs_embeds, seq_len=total_seq_len)
385
+
 
386
  if position_ids is None:
387
+ position_ids = torch.arange(
388
+ past_key_values_length, total_seq_len, dtype=torch.long, device=inputs_embeds.device
389
+ ).unsqueeze(0).expand(batch_size, -1)
390
+
391
+ attention_mask = self._prepare_decoder_attention_mask(
392
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
393
+ )
394
 
395
  hidden_states = inputs_embeds
396
  next_decoder_cache = () if use_cache else None
397
+ all_hidden_states = () if output_hidden_states else None
398
+ all_self_attns = () if output_attentions else None
399
 
400
  for idx, layer in enumerate(self.layers):
401
+ if output_hidden_states:
402
+ all_hidden_states += (hidden_states,)
403
+
404
  past_key_value = past_key_values[idx] if past_key_values is not None else None
405
+
406
+ if self.gradient_checkpointing and self.training:
407
+ def create_custom_forward(module):
408
+ def custom_forward(*inputs):
409
+ return module(*inputs, output_attentions=output_attentions, use_cache=False, cos_sin=(cos, sin))
410
+ return custom_forward
411
+ layer_outputs = checkpoint(
412
+ create_custom_forward(layer), hidden_states, attention_mask, position_ids, past_key_value, use_reentrant=True
413
+ )
414
+ else:
415
+ layer_outputs = layer(
416
+ hidden_states,
417
+ attention_mask=attention_mask,
418
+ position_ids=position_ids,
419
+ past_key_value=past_key_value,
420
+ output_attentions=output_attentions,
421
+ use_cache=use_cache,
422
+ cos_sin=(cos, sin)
423
+ )
424
+
425
  hidden_states = layer_outputs[0]
426
+ if output_attentions:
427
+ all_self_attns += (layer_outputs[1],)
428
  if use_cache:
429
  next_decoder_cache += (layer_outputs[2],)
430
 
431
  hidden_states = self.norm(hidden_states)
432
+
433
+ if output_hidden_states:
434
+ all_hidden_states += (hidden_states,)
435
+
436
+ if not return_dict:
437
+ return tuple(v for v in [hidden_states, next_decoder_cache, all_hidden_states, all_self_attns] if v is not None)
438
+
439
  return BaseModelOutputWithPast(
440
  last_hidden_state=hidden_states,
441
+ past_key_values=next_decoder_cache,
442
+ hidden_states=all_hidden_states,
443
+ attentions=all_self_attns,
444
  )
445
 
446
+
447
+ # ==========================================
448
+ # 5. CAUSAL LM HEAD
449
+ # ==========================================
450
+
451
  class AlinlightForCausalLM(PreTrainedModel, GenerationMixin):
452
  config_class = AlinlightConfig
453
  _keys_to_ignore_on_load_missing = ["model.rotary_emb.inv_freq"]
454
+ _supports_gradient_checkpointing = True
455
 
456
  def __init__(self, config):
457
  super().__init__(config)
458
  self.model = AlinlightModel(config)
459
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
460
 
461
+ self.final_logit_softcapping = getattr(config, 'final_logit_softcapping', None)
462
+ self.z_loss_weight = getattr(config, 'z_loss_weight', 0.0)
463
+
464
  if config.tie_word_embeddings:
465
  self.lm_head.weight = self.model.embed_tokens.weight
466
+
467
  self.post_init()
468
 
469
  def get_input_embeddings(self): return self.model.embed_tokens
 
471
  def get_output_embeddings(self): return self.lm_head
472
  def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings
473
 
474
+ def _init_weights(self, module):
475
+ std = self.config.initializer_range
476
+ if isinstance(module, nn.Linear):
477
+ # Scale down residual projections to improve training stability at depth
478
+ if getattr(module, '_is_residual_projection', False):
479
+ module.weight.data.normal_(mean=0.0, std=std / math.sqrt(2 * self.config.num_hidden_layers))
480
+ else:
481
+ module.weight.data.normal_(mean=0.0, std=std)
482
+
483
+ if module.bias is not None:
484
+ module.bias.data.zero_()
485
+ elif isinstance(module, nn.Embedding):
486
+ module.weight.data.normal_(mean=0.0, std=std)
487
+ if module.padding_idx is not None:
488
+ module.weight.data[module.padding_idx].zero_()
489
+
490
+ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
491
+ self.model.gradient_checkpointing = True
492
+ self.config.use_cache = False
493
+
494
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
495
+ if past_key_values is not None:
496
  input_ids = input_ids[:, -1:]
497
+
498
  position_ids = kwargs.get("position_ids", None)
499
  if position_ids is None:
500
  if past_key_values:
501
+ position_ids = (attention_mask.long().sum(dim=-1) - 1).unsqueeze(-1)
 
 
502
  else:
503
+ position_ids = torch.arange(input_ids.shape[1], dtype=torch.long, device=input_ids.device).unsqueeze(0)
 
504
 
505
  return {
506
  "input_ids": input_ids,
507
  "past_key_values": past_key_values,
508
  "use_cache": True,
509
+ "position_ids": position_ids,
510
+ "attention_mask": attention_mask,
511
  }
512
 
513
+ def forward(
514
+ self,
515
+ input_ids=None,
516
+ attention_mask=None,
517
+ position_ids=None,
518
+ past_key_values=None,
519
+ labels=None,
520
+ use_cache=None,
521
+ output_attentions=None,
522
+ output_hidden_states=None,
523
+ return_dict=None,
524
+ **kwargs
525
+ ):
526
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
527
+
528
+ outputs = self.model(
529
+ input_ids=input_ids,
530
+ attention_mask=attention_mask,
531
+ position_ids=position_ids,
532
+ past_key_values=past_key_values,
533
+ use_cache=use_cache,
534
+ output_attentions=output_attentions,
535
+ output_hidden_states=output_hidden_states,
536
+ return_dict=return_dict,
537
+ **kwargs
538
+ )
539
+
540
+ hidden_states = outputs[0]
541
  logits = self.lm_head(hidden_states)
542
+
543
+ # Final Logit Soft-Capping
544
+ if self.final_logit_softcapping is not None:
545
+ logits = self.final_logit_softcapping * torch.tanh(logits / self.final_logit_softcapping)
546
+
547
  loss = None
548
  if labels is not None:
549
  shift_logits = logits[..., :-1, :].contiguous()
550
  shift_labels = labels[..., 1:].contiguous()
551
+
552
+ loss_fct = nn.CrossEntropyLoss()
553
+ ce_loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
554
+
555
+ # Z-Loss Regularization
556
+ if self.z_loss_weight > 0 and self.training:
557
+ # log(sum(exp(x)))^2
558
+ z_loss = torch.logsumexp(shift_logits, dim=-1).pow(2).mean()
559
+ loss = ce_loss + self.z_loss_weight * z_loss
560
+ else:
561
+ loss = ce_loss
562
+
563
+ if not return_dict:
564
+ output = (logits,) + outputs[1:]
565
+ return ((loss,) + output) if loss is not None else output
566
+
567
+ return CausalLMOutputWithPast(
568
+ loss=loss,
569
+ logits=logits,
570
+ past_key_values=outputs.past_key_values,
571
+ hidden_states=outputs.hidden_states,
572
+ attentions=outputs.attentions,
573
+ )