muverqqw committed on
Commit
5f2b985
·
1 Parent(s): bc3697f

Update modeling_alinlight.py

Browse files
Files changed (1) hide show
  1. modeling_alinlight.py +74 -53
modeling_alinlight.py CHANGED
@@ -17,13 +17,43 @@ import math
17
  import torch
18
  import torch.nn as nn
19
  import torch.nn.functional as F
 
20
  from typing import Optional, Tuple, List, Union
21
  from torch.utils.checkpoint import checkpoint
22
 
23
  from transformers import PreTrainedModel, GenerationMixin
24
  from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
 
25
  from configuration_alinlight import AlinlightConfig
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # ==========================================
28
  # 1. BASE COMPONENTS
29
  # ==========================================
@@ -137,7 +167,6 @@ class AlinlightAttention(nn.Module):
137
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
138
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
139
 
140
- # Tag for specialized initialization
141
  self.o_proj._is_residual_projection = True
142
 
143
  self.use_qk_norm = getattr(config, "use_qk_norm", True)
@@ -155,7 +184,7 @@ class AlinlightAttention(nn.Module):
155
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
156
  output_attentions: bool = False,
157
  use_cache: bool = False,
158
- cos_sin: Optional[Tuple[torch.Tensor]] = None
159
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
160
 
161
  bsz, q_len, _ = hidden_states.size()
@@ -172,25 +201,25 @@ class AlinlightAttention(nn.Module):
172
  query_states = self.q_norm(query_states)
173
  key_states = self.k_norm(key_states)
174
 
175
- # 1. RoPE (Applied before caching)
176
- if cos_sin is not None:
177
- cos, sin = cos_sin
178
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
179
 
180
- # 2. KV Cache
181
  if past_key_value is not None:
182
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
183
  value_states = torch.cat([past_key_value[1], value_states], dim=2)
184
 
185
- kv_seq_len = key_states.shape[2]
186
-
187
  # 3. Sliding Window (Slicing)
 
 
188
  if self.sliding_window is not None and kv_seq_len > self.sliding_window:
189
  slicing_tokens = kv_seq_len - self.sliding_window
190
  key_states = key_states[:, :, slicing_tokens:, :]
191
  value_states = value_states[:, :, slicing_tokens:, :]
192
 
193
- if attention_mask is not None:
194
  attention_mask = attention_mask[:, :, :, slicing_tokens:]
195
 
196
  past_key_value = (key_states, value_states) if use_cache else None
@@ -203,9 +232,6 @@ class AlinlightAttention(nn.Module):
203
  # 5. Attention Mechanism
204
  attn_weights = None
205
 
206
- # We must use manual implementation if:
207
- # a) Output weights are requested
208
- # b) Soft-capping is enabled (SDPA doesn't support intermediate logit transforms)
209
  if output_attentions or self.attn_logit_softcapping is not None:
210
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
211
 
@@ -217,16 +243,11 @@ class AlinlightAttention(nn.Module):
217
 
218
  attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
219
 
220
- if not output_attentions:
221
- # If we only calculated weights for soft-capping but user didn't ask for them, drop reference
222
- attn_weights_for_output = None
223
- else:
224
- attn_weights_for_output = attn_weights
225
 
226
  attn_weights_dropped = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
227
  attn_output = torch.matmul(attn_weights_dropped, value_states)
228
  else:
229
- # Fast Path (SDPA)
230
  attn_output = F.scaled_dot_product_attention(
231
  query_states,
232
  key_states,
@@ -264,7 +285,7 @@ class AlinlightDecoderLayer(nn.Module):
264
  past_key_value=None,
265
  output_attentions=False,
266
  use_cache=False,
267
- cos_sin=None
268
  ):
269
  residual = hidden_states
270
  hidden_states = self.input_layernorm(hidden_states)
@@ -276,7 +297,7 @@ class AlinlightDecoderLayer(nn.Module):
276
  past_key_value=past_key_value,
277
  output_attentions=output_attentions,
278
  use_cache=use_cache,
279
- cos_sin=cos_sin
280
  )
281
  hidden_states = residual + self.resid_dropout(hidden_states)
282
 
@@ -288,9 +309,7 @@ class AlinlightDecoderLayer(nn.Module):
288
  return hidden_states, attn_weights, present_key_value
289
 
290
 
291
- class AlinlightModel(PreTrainedModel):
292
- config_class = AlinlightConfig
293
-
294
  def __init__(self, config: AlinlightConfig):
295
  super().__init__(config)
296
  self.padding_idx = config.pad_token_id
@@ -299,7 +318,9 @@ class AlinlightModel(PreTrainedModel):
299
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
300
 
301
  self.embed_scale = math.sqrt(config.hidden_size) if getattr(config, 'embed_scale', False) else 1.0
302
- self.embed_dropout = nn.Dropout(config.embed_pdrop) if config.embed_pdrop > 0 else nn.Identity()
 
 
303
 
304
  self.layers = nn.ModuleList([AlinlightDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
305
  self.norm = AlinlightRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -369,6 +390,14 @@ class AlinlightModel(PreTrainedModel):
369
  use_cache = use_cache if use_cache is not None else self.config.use_cache
370
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
371
 
 
 
 
 
 
 
 
 
372
  if inputs_embeds is None:
373
  inputs_embeds = self.embed_tokens(input_ids)
374
 
@@ -406,10 +435,17 @@ class AlinlightModel(PreTrainedModel):
406
  if self.gradient_checkpointing and self.training:
407
  def create_custom_forward(module):
408
  def custom_forward(*inputs):
409
- return module(*inputs, output_attentions=output_attentions, use_cache=False, cos_sin=(cos, sin))
 
410
  return custom_forward
 
411
  layer_outputs = checkpoint(
412
- create_custom_forward(layer), hidden_states, attention_mask, position_ids, past_key_value, use_reentrant=True
 
 
 
 
 
413
  )
414
  else:
415
  layer_outputs = layer(
@@ -419,7 +455,7 @@ class AlinlightModel(PreTrainedModel):
419
  past_key_value=past_key_value,
420
  output_attentions=output_attentions,
421
  use_cache=use_cache,
422
- cos_sin=(cos, sin)
423
  )
424
 
425
  hidden_states = layer_outputs[0]
@@ -448,11 +484,7 @@ class AlinlightModel(PreTrainedModel):
448
  # 5. CAUSAL LM HEAD
449
  # ==========================================
450
 
451
- class AlinlightForCausalLM(PreTrainedModel, GenerationMixin):
452
- config_class = AlinlightConfig
453
- _keys_to_ignore_on_load_missing = ["model.rotary_emb.inv_freq"]
454
- _supports_gradient_checkpointing = True
455
-
456
  def __init__(self, config):
457
  super().__init__(config)
458
  self.model = AlinlightModel(config)
@@ -464,6 +496,8 @@ class AlinlightForCausalLM(PreTrainedModel, GenerationMixin):
464
  if config.tie_word_embeddings:
465
  self.lm_head.weight = self.model.embed_tokens.weight
466
 
 
 
467
  self.post_init()
468
 
469
  def get_input_embeddings(self): return self.model.embed_tokens
@@ -471,25 +505,11 @@ class AlinlightForCausalLM(PreTrainedModel, GenerationMixin):
471
  def get_output_embeddings(self): return self.lm_head
472
  def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings
473
 
474
- def _init_weights(self, module):
475
- std = self.config.initializer_range
476
- if isinstance(module, nn.Linear):
477
- # Scale down residual projections to improve training stability at depth
478
- if getattr(module, '_is_residual_projection', False):
479
- module.weight.data.normal_(mean=0.0, std=std / math.sqrt(2 * self.config.num_hidden_layers))
480
- else:
481
- module.weight.data.normal_(mean=0.0, std=std)
482
-
483
- if module.bias is not None:
484
- module.bias.data.zero_()
485
- elif isinstance(module, nn.Embedding):
486
- module.weight.data.normal_(mean=0.0, std=std)
487
- if module.padding_idx is not None:
488
- module.weight.data[module.padding_idx].zero_()
489
-
490
  def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
491
  self.model.gradient_checkpointing = True
492
- self.config.use_cache = False
 
 
493
 
494
  def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
495
  if past_key_values is not None:
@@ -498,7 +518,11 @@ class AlinlightForCausalLM(PreTrainedModel, GenerationMixin):
498
  position_ids = kwargs.get("position_ids", None)
499
  if position_ids is None:
500
  if past_key_values:
501
- position_ids = (attention_mask.long().sum(dim=-1) - 1).unsqueeze(-1)
 
 
 
 
502
  else:
503
  position_ids = torch.arange(input_ids.shape[1], dtype=torch.long, device=input_ids.device).unsqueeze(0)
504
 
@@ -540,7 +564,6 @@ class AlinlightForCausalLM(PreTrainedModel, GenerationMixin):
540
  hidden_states = outputs[0]
541
  logits = self.lm_head(hidden_states)
542
 
543
- # Final Logit Soft-Capping
544
  if self.final_logit_softcapping is not None:
545
  logits = self.final_logit_softcapping * torch.tanh(logits / self.final_logit_softcapping)
546
 
@@ -552,9 +575,7 @@ class AlinlightForCausalLM(PreTrainedModel, GenerationMixin):
552
  loss_fct = nn.CrossEntropyLoss()
553
  ce_loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
554
 
555
- # Z-Loss Regularization
556
  if self.z_loss_weight > 0 and self.training:
557
- # log(sum(exp(x)))^2
558
  z_loss = torch.logsumexp(shift_logits, dim=-1).pow(2).mean()
559
  loss = ce_loss + self.z_loss_weight * z_loss
560
  else:
 
17
  import torch
18
  import torch.nn as nn
19
  import torch.nn.functional as F
20
+ import warnings
21
  from typing import Optional, Tuple, List, Union
22
  from torch.utils.checkpoint import checkpoint
23
 
24
  from transformers import PreTrainedModel, GenerationMixin
25
  from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
26
+ from transformers.utils import logging
27
  from configuration_alinlight import AlinlightConfig
28
 
29
+ logger = logging.get_logger(__name__)
30
+
31
+ # ==========================================
32
+ # 0. BASE PRETRAINED MODEL
33
+ # ==========================================
34
+
35
+ class AlinlightPreTrainedModel(PreTrainedModel):
36
+ config_class = AlinlightConfig
37
+ base_model_prefix = "model"
38
+ _no_split_modules = ["AlinlightDecoderLayer"]
39
+ _supports_gradient_checkpointing = True
40
+
41
+ def _init_weights(self, module):
42
+ std = self.config.initializer_range
43
+ if isinstance(module, nn.Linear):
44
+ # Scale down residual projections to improve training stability at depth
45
+ if getattr(module, '_is_residual_projection', False):
46
+ module.weight.data.normal_(mean=0.0, std=std / math.sqrt(2 * self.config.num_hidden_layers))
47
+ else:
48
+ module.weight.data.normal_(mean=0.0, std=std)
49
+
50
+ if module.bias is not None:
51
+ module.bias.data.zero_()
52
+ elif isinstance(module, nn.Embedding):
53
+ module.weight.data.normal_(mean=0.0, std=std)
54
+ if module.padding_idx is not None:
55
+ module.weight.data[module.padding_idx].zero_()
56
+
57
  # ==========================================
58
  # 1. BASE COMPONENTS
59
  # ==========================================
 
167
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
168
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
169
 
 
170
  self.o_proj._is_residual_projection = True
171
 
172
  self.use_qk_norm = getattr(config, "use_qk_norm", True)
 
184
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
185
  output_attentions: bool = False,
186
  use_cache: bool = False,
187
+ rotary_pos_emb: Optional[Tuple[torch.Tensor]] = None
188
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
189
 
190
  bsz, q_len, _ = hidden_states.size()
 
201
  query_states = self.q_norm(query_states)
202
  key_states = self.k_norm(key_states)
203
 
204
+ # 1. RoPE
205
+ if rotary_pos_emb is not None:
206
+ cos, sin = rotary_pos_emb
207
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
208
 
209
+ # 2. KV Cache Update
210
  if past_key_value is not None:
211
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
212
  value_states = torch.cat([past_key_value[1], value_states], dim=2)
213
 
 
 
214
  # 3. Sliding Window (Slicing)
215
+ kv_seq_len = key_states.shape[2] # NOTE: This is the length BEFORE slicing
216
+
217
  if self.sliding_window is not None and kv_seq_len > self.sliding_window:
218
  slicing_tokens = kv_seq_len - self.sliding_window
219
  key_states = key_states[:, :, slicing_tokens:, :]
220
  value_states = value_states[:, :, slicing_tokens:, :]
221
 
222
+ if attention_mask is not None and attention_mask.shape[-1] == kv_seq_len:
223
  attention_mask = attention_mask[:, :, :, slicing_tokens:]
224
 
225
  past_key_value = (key_states, value_states) if use_cache else None
 
232
  # 5. Attention Mechanism
233
  attn_weights = None
234
 
 
 
 
235
  if output_attentions or self.attn_logit_softcapping is not None:
236
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
237
 
 
243
 
244
  attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
245
 
246
+ attn_weights_for_output = attn_weights if output_attentions else None
 
 
 
 
247
 
248
  attn_weights_dropped = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
249
  attn_output = torch.matmul(attn_weights_dropped, value_states)
250
  else:
 
251
  attn_output = F.scaled_dot_product_attention(
252
  query_states,
253
  key_states,
 
285
  past_key_value=None,
286
  output_attentions=False,
287
  use_cache=False,
288
+ rotary_pos_emb=None
289
  ):
290
  residual = hidden_states
291
  hidden_states = self.input_layernorm(hidden_states)
 
297
  past_key_value=past_key_value,
298
  output_attentions=output_attentions,
299
  use_cache=use_cache,
300
+ rotary_pos_emb=rotary_pos_emb
301
  )
302
  hidden_states = residual + self.resid_dropout(hidden_states)
303
 
 
309
  return hidden_states, attn_weights, present_key_value
310
 
311
 
312
+ class AlinlightModel(AlinlightPreTrainedModel):
 
 
313
  def __init__(self, config: AlinlightConfig):
314
  super().__init__(config)
315
  self.padding_idx = config.pad_token_id
 
318
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
319
 
320
  self.embed_scale = math.sqrt(config.hidden_size) if getattr(config, 'embed_scale', False) else 1.0
321
+
322
+ embed_pdrop = getattr(config, 'embed_pdrop', 0.0)
323
+ self.embed_dropout = nn.Dropout(embed_pdrop) if embed_pdrop > 0 else nn.Identity()
324
 
325
  self.layers = nn.ModuleList([AlinlightDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
326
  self.norm = AlinlightRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
390
  use_cache = use_cache if use_cache is not None else self.config.use_cache
391
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
392
 
393
+ # --- SAFETY CHECK FOR GRADIENT CHECKPOINTING ---
394
+ if self.gradient_checkpointing and self.training:
395
+ if use_cache:
396
+ logger.warning_once(
397
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
398
+ )
399
+ use_cache = False
400
+
401
  if inputs_embeds is None:
402
  inputs_embeds = self.embed_tokens(input_ids)
403
 
 
435
  if self.gradient_checkpointing and self.training:
436
  def create_custom_forward(module):
437
  def custom_forward(*inputs):
438
+ # Force use_cache=False inside checkpoint to be safe
439
+ return module(*inputs, output_attentions=output_attentions, use_cache=False, rotary_pos_emb=(cos, sin))
440
  return custom_forward
441
+
442
  layer_outputs = checkpoint(
443
+ create_custom_forward(layer),
444
+ hidden_states,
445
+ attention_mask,
446
+ position_ids,
447
+ past_key_value,
448
+ use_reentrant=False
449
  )
450
  else:
451
  layer_outputs = layer(
 
455
  past_key_value=past_key_value,
456
  output_attentions=output_attentions,
457
  use_cache=use_cache,
458
+ rotary_pos_emb=(cos, sin)
459
  )
460
 
461
  hidden_states = layer_outputs[0]
 
484
  # 5. CAUSAL LM HEAD
485
  # ==========================================
486
 
487
+ class AlinlightForCausalLM(AlinlightPreTrainedModel, GenerationMixin):
 
 
 
 
488
  def __init__(self, config):
489
  super().__init__(config)
490
  self.model = AlinlightModel(config)
 
496
  if config.tie_word_embeddings:
497
  self.lm_head.weight = self.model.embed_tokens.weight
498
 
499
+ # Note: self.post_init() is called here, and inside AlinlightModel.
500
+ # This re-initialization is consistent with standard HF models (e.g. Llama).
501
  self.post_init()
502
 
503
  def get_input_embeddings(self): return self.model.embed_tokens
 
505
  def get_output_embeddings(self): return self.lm_head
506
  def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings
507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
  def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
509
  self.model.gradient_checkpointing = True
510
+
511
+ def gradient_checkpointing_disable(self):
512
+ self.model.gradient_checkpointing = False
513
 
514
  def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
515
  if past_key_values is not None:
 
518
  position_ids = kwargs.get("position_ids", None)
519
  if position_ids is None:
520
  if past_key_values:
521
+ if attention_mask is not None:
522
+ position_ids = (attention_mask.long().sum(dim=-1) - 1).unsqueeze(-1)
523
+ else:
524
+ past_length = past_key_values[0][0].shape[2]
525
+ position_ids = torch.tensor([[past_length]], device=input_ids.device)
526
  else:
527
  position_ids = torch.arange(input_ids.shape[1], dtype=torch.long, device=input_ids.device).unsqueeze(0)
528
 
 
564
  hidden_states = outputs[0]
565
  logits = self.lm_head(hidden_states)
566
 
 
567
  if self.final_logit_softcapping is not None:
568
  logits = self.final_logit_softcapping * torch.tanh(logits / self.final_logit_softcapping)
569
 
 
575
  loss_fct = nn.CrossEntropyLoss()
576
  ce_loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
577
 
 
578
  if self.z_loss_weight > 0 and self.training:
 
579
  z_loss = torch.logsumexp(shift_logits, dim=-1).pow(2).mean()
580
  loss = ce_loss + self.z_loss_weight * z_loss
581
  else: