~test conditional workarounds

#1
by exdysa - opened
Files changed (1) hide show
  1. modeling_sdar.py +179 -199
modeling_sdar.py CHANGED
@@ -23,15 +23,20 @@
23
 
24
  from typing import Callable, Optional, Tuple, Union
25
 
 
26
  import torch
27
  from torch import nn
28
-
29
  from transformers.activations import ACT2FN
30
- from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
 
 
 
 
 
31
  from transformers.generation import GenerationMixin
32
  from transformers.integrations import use_kernel_forward_from_hub
33
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
34
- from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
35
  from transformers.modeling_layers import GradientCheckpointingLayer
36
  from transformers.modeling_outputs import (
37
  BaseModelOutputWithPast,
@@ -43,31 +48,63 @@ from transformers.modeling_outputs import (
43
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
44
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
45
  from transformers.processing_utils import Unpack
46
- from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
47
- from .configuration_sdar import SDARConfig
 
 
 
 
 
48
 
49
- from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm
50
 
51
- import torch.nn.functional as F
52
- try:
53
- from flash_attn import flash_attn_func, flash_attn_varlen_func
54
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
55
- except:
56
- pass
57
 
 
58
  try:
59
- from liger_kernel.ops.swiglu import LigerSiLUMulFunction # noqa: F401
60
- liger_kernel_is_available = True
61
  except ImportError:
62
- liger_kernel_is_available = False
63
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  if is_torch_flex_attn_available():
66
- from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention
67
- from transformers.integrations.flex_attention import make_flex_block_causal_mask
 
 
 
68
 
 
 
69
 
70
- logger = logging.get_logger(__name__)
 
 
71
 
72
 
73
  @use_kernel_forward_from_hub("RMSNorm")
@@ -81,16 +118,20 @@ class SDARRMSNorm(nn.Module):
81
  self.variance_epsilon = eps
82
 
83
  def forward(self, hidden_states):
84
- return flash_rms_norm(
85
- hidden_states, weight=self.weight, bias=None, eps=self.variance_epsilon)
86
- '''
 
 
 
 
 
 
87
  input_dtype = hidden_states.dtype
88
  hidden_states = hidden_states.to(torch.float32)
89
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
90
- hidden_states = hidden_states * \
91
- torch.rsqrt(variance + self.variance_epsilon)
92
  return self.weight * hidden_states.to(input_dtype)
93
- '''
94
 
95
  def extra_repr(self):
96
  return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
@@ -102,27 +143,23 @@ class SDARMLP(nn.Module):
102
  self.config = config
103
  self.hidden_size = config.hidden_size
104
  self.intermediate_size = config.intermediate_size
105
- self.gate_proj = nn.Linear(
106
- self.hidden_size, self.intermediate_size, bias=False)
107
- self.up_proj = nn.Linear(
108
- self.hidden_size, self.intermediate_size, bias=False)
109
- self.down_proj = nn.Linear(
110
- self.intermediate_size, self.hidden_size, bias=False)
111
  self.act_fn = ACT2FN[config.hidden_act]
112
 
113
  def forward(self, x):
114
  if liger_kernel_is_available:
115
  return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
116
  else:
117
- down_proj = self.down_proj(self.act_fn(
118
- self.gate_proj(x)) * self.up_proj(x))
119
  return down_proj
120
 
121
 
122
  def rotate_half(x):
123
  """Rotates half the hidden dims of the input."""
124
  x1 = x[..., : x.shape[-1] // 2]
125
- x2 = x[..., x.shape[-1] // 2:]
126
  return torch.cat((-x2, x1), dim=-1)
127
 
128
 
@@ -160,8 +197,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
160
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
161
  if n_rep == 1:
162
  return hidden_states
163
- hidden_states = hidden_states[:, :, None, :, :].expand(
164
- batch, num_key_value_heads, n_rep, slen, head_dim)
165
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
166
 
167
 
@@ -183,10 +219,8 @@ def eager_attention_forward(
183
  causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
184
  attn_weights = attn_weights + causal_mask
185
 
186
- attn_weights = nn.functional.softmax(
187
- attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
188
- attn_weights = nn.functional.dropout(
189
- attn_weights, p=dropout, training=module.training)
190
  attn_output = torch.matmul(attn_weights, value_states)
191
  attn_output = attn_output.transpose(1, 2).contiguous()
192
 
@@ -200,8 +234,7 @@ class SDARAttention(nn.Module):
200
  super().__init__()
201
  self.config = config
202
  self.layer_idx = layer_idx
203
- self.head_dim = getattr(
204
- config, "head_dim", config.hidden_size // config.num_attention_heads)
205
  self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
206
  self.scaling = self.head_dim**-0.5
207
  self.attention_dropout = config.attention_dropout
@@ -211,28 +244,16 @@ class SDARAttention(nn.Module):
211
  self.num_attention_heads = config.num_attention_heads
212
  self.num_key_value_heads = config.num_key_value_heads
213
 
214
- self.q_proj = nn.Linear(
215
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
216
- )
217
- self.k_proj = nn.Linear(
218
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
219
- )
220
- self.v_proj = nn.Linear(
221
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
222
- )
223
- self.o_proj = nn.Linear(
224
- config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
225
- )
226
  # unlike olmo, only on the head dim!
227
  self.q_norm = SDARRMSNorm(self.head_dim, eps=config.rms_norm_eps)
228
  # thus post q_norm does not need reshape
229
  self.k_norm = SDARRMSNorm(self.head_dim, eps=config.rms_norm_eps)
230
  self.sliding_window = config.sliding_window
231
- if not (
232
- self.config.use_sliding_window
233
- and getattr(self.config, "sliding_window", None) is not None
234
- and self.layer_idx >= self.config.max_window_layers
235
- ):
236
  self.sliding_window = None
237
 
238
  def forward(
@@ -248,32 +269,23 @@ class SDARAttention(nn.Module):
248
  bsz, q_len = input_shape
249
  hidden_shape = (*input_shape, -1, self.head_dim)
250
 
251
- query_states = self.q_norm(self.q_proj(
252
- hidden_states).view(hidden_shape)).transpose(1, 2)
253
- key_states = self.k_norm(self.k_proj(
254
- hidden_states).view(hidden_shape)).transpose(1, 2)
255
- value_states = self.v_proj(hidden_states).view(
256
- hidden_shape).transpose(1, 2)
257
-
258
-
259
 
260
  cos, sin = position_embeddings
261
- query_states, key_states = apply_rotary_pos_emb(
262
- query_states, key_states, cos, sin)
263
 
264
  if past_key_value is not None and kwargs.get("store_kv", False):
265
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
266
- key_states, value_states = past_key_value.update(
267
- key_states, value_states, self.layer_idx)
268
  elif past_key_value is not None and not kwargs.get("store_kv", False) and len(past_key_value) > self.layer_idx:
269
  # only retrive, do not store kv
270
  past_key_states, past_value_states = past_key_value[self.layer_idx]
271
- key_states = torch.cat(
272
- [past_key_states, key_states], dim=-2)
273
- value_states = torch.cat(
274
- [past_value_states, value_states], dim=-2)
275
 
276
- '''
277
  attention_mask = attention_mask.bool() if attention_mask is not None else None
278
  if torch.all(attention_mask): # decoding
279
  query_states = query_states.transpose(1, 2)
@@ -298,12 +310,12 @@ class SDARAttention(nn.Module):
298
  enable_gqa=True
299
  )
300
  attn_output = attn_output.transpose(1, 2).contiguous()
301
- '''
302
 
303
- #print(query_states.shape, key_states.shape, value_states.shape)
304
 
305
  # --- After RoPE and KV-cache handling, expand KV to all heads ---
306
- key_states = repeat_kv(key_states, self.num_key_value_groups) # [B, H, K, D]
307
  value_states = repeat_kv(value_states, self.num_key_value_groups) # [B, H, K, D]
308
 
309
  # --- Convert a 0/1 or bool 4D mask into an *additive* mask, and align to [B, H, Q, K] ---
@@ -313,14 +325,14 @@ class SDARAttention(nn.Module):
313
  am = attention_mask
314
  # Support either 2D [B, K] or 4D [B, 1/H, Q, K]
315
  if am.dim() == 2:
316
- am = am[:, None, None, :k_len] # -> [B,1,1,K]
317
  else:
318
- am = am[:, :, :, :k_len] # -> [B,1/H,Q,K]
319
 
320
  finfo_min = torch.finfo(query_states.dtype).min
321
  # 0/1 or bool -> float additive mask: 1->0, 0->-inf
322
  if am.dtype == torch.bool:
323
- zero = torch.zeros((), dtype=query_states.dtype, device=am.device)
324
  neginf = torch.full((), finfo_min, dtype=query_states.dtype, device=am.device)
325
  am = torch.where(am, zero, neginf)
326
  else:
@@ -329,33 +341,59 @@ class SDARAttention(nn.Module):
329
  am = torch.where(am > 0, torch.zeros_like(am), torch.full_like(am, finfo_min))
330
 
331
  # Expand to all heads
332
- #if am.shape[1] == 1 and self.num_attention_heads > 1:
333
  # am = am.expand(am.shape[0], self.num_attention_heads, am.shape[2], am.shape[3])
334
 
335
- #attn_mask = am.contiguous()
336
  attn_mask = am
337
-
338
 
339
  bsz, q_len = input_shape
340
 
341
  if q_len == 1 and past_key_value is not None:
342
- # --- Decoding: flash-attn ---
343
- q = query_states.transpose(1, 2) # [B,Q,H,D]
344
- k = key_states.transpose(1, 2)
345
- v = value_states.transpose(1, 2)
346
- attn_output = flash_attn_func(
347
- q, k, v,
348
- causal=True, # For decoding, explicitly set causal=True
349
- softmax_scale=self.scaling
350
- )
351
- attn_output = attn_output.transpose(1, 2).contiguous()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  else:
353
  attn_output = F.scaled_dot_product_attention(
354
- query=query_states, # [B,H,Q,D]
355
- key=key_states, # [B,H,K,D]
356
- value=value_states, # [B,H,K,D]
357
- attn_mask=attn_mask, # float additive mask
358
- is_causal=False, # All constraints are already encoded in the mask
359
  scale=self.scaling,
360
  )
361
  attn_output = attn_output.transpose(1, 2).contiguous() # -> [B,Q,H,D]
@@ -371,17 +409,10 @@ class SDARDecoderLayer(GradientCheckpointingLayer):
371
  self.hidden_size = config.hidden_size
372
  self.self_attn = SDARAttention(config=config, layer_idx=layer_idx)
373
  self.mlp = SDARMLP(config)
374
- self.input_layernorm = SDARRMSNorm(
375
- config.hidden_size, eps=config.rms_norm_eps)
376
- self.post_attention_layernorm = SDARRMSNorm(
377
- config.hidden_size, eps=config.rms_norm_eps)
378
- if (
379
- config.sliding_window and config._attn_implementation != "flash_attention_2"
380
- ): # diff with Llama is this warning
381
- logger.warning_once(
382
- f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
383
- "unexpected results may be encountered."
384
- )
385
 
386
  def forward(
387
  self,
@@ -394,8 +425,7 @@ class SDARDecoderLayer(GradientCheckpointingLayer):
394
  store_kv: Optional[bool] = False,
395
  cache_position: Optional[torch.LongTensor] = None,
396
  # necessary, but kept here for BC
397
- position_embeddings: Optional[Tuple[torch.Tensor,
398
- torch.Tensor]] = None,
399
  **kwargs: Unpack[FlashAttentionKwargs],
400
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
401
  residual = hidden_states
@@ -463,8 +493,7 @@ class SDARRotaryEmbedding(nn.Module):
463
  super().__init__()
464
  # BC: "rope_type" was originally "type"
465
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
466
- self.rope_type = config.rope_scaling.get(
467
- "rope_type", config.rope_scaling.get("type"))
468
  else:
469
  self.rope_type = "default"
470
  self.max_seq_len_cached = config.max_position_embeddings
@@ -473,8 +502,7 @@ class SDARRotaryEmbedding(nn.Module):
473
  self.config = config
474
  self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
475
 
476
- inv_freq, self.attention_scaling = self.rope_init_fn(
477
- self.config, device)
478
  self.register_buffer("inv_freq", inv_freq, persistent=False)
479
  self.original_inv_freq = self.inv_freq
480
 
@@ -482,15 +510,12 @@ class SDARRotaryEmbedding(nn.Module):
482
  # power user: used with advanced RoPE types (e.g. dynamic rope)
483
  @dynamic_rope_update
484
  def forward(self, x, position_ids):
485
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
486
- position_ids.shape[0], -1, 1).to(x.device)
487
  position_ids_expanded = position_ids[:, None, :].float()
488
 
489
- device_type = x.device.type if isinstance(
490
- x.device.type, str) and x.device.type != "mps" else "cpu"
491
  with torch.autocast(device_type=device_type, enabled=False): # Force float32
492
- freqs = (inv_freq_expanded.float() @
493
- position_ids_expanded.float()).transpose(1, 2)
494
  emb = torch.cat((freqs, freqs), dim=-1)
495
  cos = emb.cos() * self.attention_scaling
496
  sin = emb.sin() * self.attention_scaling
@@ -505,12 +530,8 @@ class SDARModel(SDARPreTrainedModel):
505
  self.padding_idx = config.pad_token_id
506
  self.vocab_size = config.vocab_size
507
 
508
- self.embed_tokens = nn.Embedding(
509
- config.vocab_size, config.hidden_size, self.padding_idx)
510
- self.layers = nn.ModuleList(
511
- [SDARDecoderLayer(config, layer_idx)
512
- for layer_idx in range(config.num_hidden_layers)]
513
- )
514
  self.norm = SDARRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
515
  self.rotary_emb = SDARRotaryEmbedding(config=config)
516
  self.gradient_checkpointing = False
@@ -541,25 +562,19 @@ class SDARModel(SDARPreTrainedModel):
541
  **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
542
  ) -> BaseModelOutputWithPast:
543
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
544
- output_hidden_states = (
545
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
546
- )
547
  use_cache = use_cache if use_cache is not None else self.config.use_cache
548
 
549
  if (input_ids is None) ^ (inputs_embeds is not None):
550
- raise ValueError(
551
- "You must specify exactly one of input_ids or inputs_embeds")
552
 
553
  if self.gradient_checkpointing and self.training and use_cache:
554
- logger.warning_once(
555
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
556
- )
557
  use_cache = False
558
 
559
  # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
560
  if not isinstance(past_key_values, (type(None), Cache)):
561
- raise ValueError(
562
- "The `past_key_values` should be either a `Cache` object or `None`.")
563
 
564
  if inputs_embeds is None:
565
  inputs_embeds = self.embed_tokens(input_ids)
@@ -568,11 +583,8 @@ class SDARModel(SDARPreTrainedModel):
568
  past_key_values = DynamicCache()
569
 
570
  if cache_position is None:
571
- past_seen_tokens = past_key_values.get_seq_length(
572
- ) if past_key_values is not None else 0
573
- cache_position = torch.arange(
574
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
575
- )
576
 
577
  if position_ids is None:
578
  position_ids = cache_position.unsqueeze(0)
@@ -635,8 +647,7 @@ class SDARModel(SDARPreTrainedModel):
635
  ):
636
  if self.config._attn_implementation == "flash_attention_2":
637
  if attention_mask is not None and past_key_values is not None:
638
- is_padding_right = attention_mask[:, -
639
- 1].sum().item() != input_tensor.size()[0]
640
  if is_padding_right:
641
  raise ValueError(
642
  "You are attempting to perform batched generation with padding_side='right'"
@@ -653,7 +664,10 @@ class SDARModel(SDARPreTrainedModel):
653
  attention_mask = create_block_mask(
654
  # 2d bool tensor, shape: [2*seqlen, 2*seqlen]
655
  lambda b, h, q_idx, kv_idx: attention_mask[q_idx, kv_idx],
656
- B=None, H=None, Q_LEN=seq_len_q, KV_LEN=seq_len_kv,
 
 
 
657
  )
658
  else:
659
  # Here we pass in flex mask computed externally
@@ -663,18 +677,12 @@ class SDARModel(SDARPreTrainedModel):
663
  # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
664
  # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
665
  # to infer the attention mask.
666
- past_seen_tokens = past_key_values.get_seq_length(
667
- ) if past_key_values is not None else 0
668
  using_static_cache = isinstance(past_key_values, StaticCache)
669
- using_sliding_window_cache = isinstance(
670
- past_key_values, SlidingWindowCache)
671
 
672
  # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
673
- if (
674
- self.config._attn_implementation == "sdpa"
675
- and not (using_static_cache or using_sliding_window_cache)
676
- and not output_attentions
677
- ):
678
  if AttentionMaskConverter._ignore_causal_mask_sdpa(
679
  attention_mask,
680
  inputs_embeds=input_tensor,
@@ -692,11 +700,7 @@ class SDARModel(SDARPreTrainedModel):
692
  target_length = past_key_values.get_max_cache_shape()
693
  # DynamicCache or no cache
694
  else:
695
- target_length = (
696
- attention_mask.shape[-1]
697
- if isinstance(attention_mask, torch.Tensor)
698
- else past_seen_tokens + sequence_length + 1
699
- )
700
 
701
  # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
702
  causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
@@ -710,17 +714,11 @@ class SDARModel(SDARPreTrainedModel):
710
  past_key_values=past_key_values,
711
  )
712
 
713
- if (
714
- self.config._attn_implementation == "sdpa"
715
- and attention_mask is not None
716
- and attention_mask.device.type in ["cuda", "xpu", "npu"]
717
- and not output_attentions
718
- ):
719
  # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
720
  # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
721
  # Details: https://github.com/pytorch/pytorch/issues/110213
722
- causal_mask = AttentionMaskConverter._unmask_unattended(
723
- causal_mask, min_dtype)
724
 
725
  return causal_mask
726
 
@@ -761,42 +759,29 @@ class SDARModel(SDARPreTrainedModel):
761
  causal_mask = attention_mask
762
  else:
763
  min_dtype = torch.finfo(dtype).min
764
- causal_mask = torch.full(
765
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
766
- )
767
- diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
768
- -1, 1
769
- )
770
  text_config = config.get_text_config()
771
  if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
772
  # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
773
  # the check is needed to verify is current checkpoint was trained with sliding window or not
774
  if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
775
- sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
776
- cache_position.reshape(-1, 1) -
777
- text_config.sliding_window
778
- )
779
  diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
780
  causal_mask *= diagonal_attend_mask
781
- causal_mask = causal_mask[None, None,
782
- :, :].expand(batch_size, 1, -1, -1)
783
  if attention_mask is not None:
784
  causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
785
  if attention_mask.shape[-1] > target_length:
786
  attention_mask = attention_mask[:, :target_length]
787
  mask_length = attention_mask.shape[-1]
788
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
789
- causal_mask.device
790
- )
791
  padding_mask = padding_mask == 0
792
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
793
- padding_mask, min_dtype
794
- )
795
  return causal_mask
796
 
797
 
798
- class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
799
- ...
800
 
801
 
802
  @auto_docstring
@@ -809,8 +794,7 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
809
  super().__init__(config)
810
  self.model = SDARModel(config)
811
  self.vocab_size = config.vocab_size
812
- self.lm_head = nn.Linear(
813
- config.hidden_size, config.vocab_size, bias=False)
814
 
815
  # Initialize weights and apply final processing
816
  self.post_init()
@@ -868,9 +852,7 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
868
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
869
  ```"""
870
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
871
- output_hidden_states = (
872
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
873
- )
874
 
875
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
876
  outputs: BaseModelOutputWithPast = self.model(
@@ -888,8 +870,7 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
888
 
889
  hidden_states = outputs.last_hidden_state
890
  # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
891
- slice_indices = slice(-logits_to_keep,
892
- None) if isinstance(logits_to_keep, int) else logits_to_keep
893
  hidden_states = hidden_states[:, slice_indices, :].contiguous()
894
  fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
895
  if fuse_linear_and_cross_entropy:
@@ -903,8 +884,7 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
903
  # FusedLinearCrossEntropyLoss will be implemented by monkey patch when training
904
  # We don't use it when inferencing
905
  loss_fct = nn.CrossEntropyLoss() # nn.CE
906
- loss = loss_fct(
907
- logits.view(-1, self.config.vocab_size), labels.view(-1))
908
 
909
  return CausalLMOutputWithPast(
910
  loss=loss,
@@ -919,4 +899,4 @@ __all__ = [
919
  "SDARForCausalLM",
920
  "SDARModel",
921
  "SDARPreTrainedModel",
922
- ]
 
23
 
24
  from typing import Callable, Optional, Tuple, Union
25
 
26
+ from nnll.init_gpu import device
27
  import torch
28
  from torch import nn
29
+ import torch.nn.functional as F
30
  from transformers.activations import ACT2FN
31
+ from transformers.cache_utils import (
32
+ Cache,
33
+ DynamicCache,
34
+ SlidingWindowCache,
35
+ StaticCache,
36
+ )
37
  from transformers.generation import GenerationMixin
38
  from transformers.integrations import use_kernel_forward_from_hub
39
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
40
  from transformers.modeling_layers import GradientCheckpointingLayer
41
  from transformers.modeling_outputs import (
42
  BaseModelOutputWithPast,
 
48
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
49
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
50
  from transformers.processing_utils import Unpack
51
+ from transformers.utils import (
52
+ LossKwargs,
53
+ auto_docstring,
54
+ can_return_tuple,
55
+ is_torch_flex_attn_available,
56
+ logging,
57
+ )
58
 
59
+ from divisor.trado.configuration_sdar import SDARConfig
60
 
61
+ logger = logging.get_logger(__name__)
 
 
 
 
 
62
 
63
+ # Make FlashAttentionKwargs available for all devices (used in type hints)
64
  try:
65
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 
66
  except ImportError:
67
+ # Fallback if not available
68
+ from typing import TypedDict
69
+
70
+ FlashAttentionKwargs = TypedDict("FlashAttentionKwargs", {})
71
+
72
+ # Conditionally import flash attention components (CUDA only)
73
+ flash_rms_norm = None
74
+ flash_attn_func = None
75
+ flash_attn_varlen_func = None
76
+ index_first_axis = None
77
+ pad_input = None
78
+ unpad_input = None
79
+
80
+ if device.type == "cuda":
81
+ try:
82
+ from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm
83
+ except (ImportError, ModuleNotFoundError):
84
+ logger.warning("Flash attention RMS norm not available. Falling back to standard implementation.")
85
+ flash_rms_norm = None
86
+
87
+ try:
88
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
89
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
90
+ except (ImportError, ModuleNotFoundError):
91
+ logger.warning("Flash attention not available. Falling back to standard attention.")
92
+ flash_attn_func = None
93
+ flash_attn_varlen_func = None
94
 
95
  if is_torch_flex_attn_available():
96
+ try:
97
+ from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention
98
+ from transformers.integrations.flex_attention import make_flex_block_causal_mask
99
+ except ImportError:
100
+ pass
101
 
102
+ try:
103
+ from liger_kernel.ops.swiglu import LigerSiLUMulFunction # noqa: F401
104
 
105
+ liger_kernel_is_available = True
106
+ except ImportError:
107
+ liger_kernel_is_available = False
108
 
109
 
110
  @use_kernel_forward_from_hub("RMSNorm")
 
118
  self.variance_epsilon = eps
119
 
120
  def forward(self, hidden_states):
121
+ # Use flash RMS norm if available (CUDA only), otherwise fall back to standard implementation
122
+ if flash_rms_norm is not None and hidden_states.device.type == "cuda":
123
+ try:
124
+ return flash_rms_norm(hidden_states, weight=self.weight, bias=None, eps=self.variance_epsilon)
125
+ except Exception as e:
126
+ logger.warning(f"Flash RMS norm failed ({e}). Falling back to standard implementation.")
127
+ # Fall through to standard implementation
128
+
129
+ # Standard RMS norm implementation (fallback for MPS, CPU, or when flash_rms_norm fails)
130
  input_dtype = hidden_states.dtype
131
  hidden_states = hidden_states.to(torch.float32)
132
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
133
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
 
134
  return self.weight * hidden_states.to(input_dtype)
 
135
 
136
  def extra_repr(self):
137
  return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
143
  self.config = config
144
  self.hidden_size = config.hidden_size
145
  self.intermediate_size = config.intermediate_size
146
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
147
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
148
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
 
 
 
149
  self.act_fn = ACT2FN[config.hidden_act]
150
 
151
  def forward(self, x):
152
  if liger_kernel_is_available:
153
  return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
154
  else:
155
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
156
  return down_proj
157
 
158
 
159
  def rotate_half(x):
160
  """Rotates half the hidden dims of the input."""
161
  x1 = x[..., : x.shape[-1] // 2]
162
+ x2 = x[..., x.shape[-1] // 2 :]
163
  return torch.cat((-x2, x1), dim=-1)
164
 
165
 
 
197
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
198
  if n_rep == 1:
199
  return hidden_states
200
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
 
201
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
202
 
203
 
 
219
  causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
220
  attn_weights = attn_weights + causal_mask
221
 
222
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
223
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
 
 
224
  attn_output = torch.matmul(attn_weights, value_states)
225
  attn_output = attn_output.transpose(1, 2).contiguous()
226
 
 
234
  super().__init__()
235
  self.config = config
236
  self.layer_idx = layer_idx
237
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
 
238
  self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
239
  self.scaling = self.head_dim**-0.5
240
  self.attention_dropout = config.attention_dropout
 
244
  self.num_attention_heads = config.num_attention_heads
245
  self.num_key_value_heads = config.num_key_value_heads
246
 
247
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias)
248
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
249
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
250
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
 
 
 
 
 
 
 
 
251
  # unlike olmo, only on the head dim!
252
  self.q_norm = SDARRMSNorm(self.head_dim, eps=config.rms_norm_eps)
253
  # thus post q_norm does not need reshape
254
  self.k_norm = SDARRMSNorm(self.head_dim, eps=config.rms_norm_eps)
255
  self.sliding_window = config.sliding_window
256
+ if not (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and self.layer_idx >= self.config.max_window_layers):
 
 
 
 
257
  self.sliding_window = None
258
 
259
  def forward(
 
269
  bsz, q_len = input_shape
270
  hidden_shape = (*input_shape, -1, self.head_dim)
271
 
272
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
273
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
274
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
 
 
 
 
275
 
276
  cos, sin = position_embeddings
277
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
278
 
279
  if past_key_value is not None and kwargs.get("store_kv", False):
280
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
281
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
 
282
  elif past_key_value is not None and not kwargs.get("store_kv", False) and len(past_key_value) > self.layer_idx:
283
  # only retrive, do not store kv
284
  past_key_states, past_value_states = past_key_value[self.layer_idx]
285
+ key_states = torch.cat([past_key_states, key_states], dim=-2)
286
+ value_states = torch.cat([past_value_states, value_states], dim=-2)
 
 
287
 
288
+ """
289
  attention_mask = attention_mask.bool() if attention_mask is not None else None
290
  if torch.all(attention_mask): # decoding
291
  query_states = query_states.transpose(1, 2)
 
310
  enable_gqa=True
311
  )
312
  attn_output = attn_output.transpose(1, 2).contiguous()
313
+ """
314
 
315
+ # print(query_states.shape, key_states.shape, value_states.shape)
316
 
317
  # --- After RoPE and KV-cache handling, expand KV to all heads ---
318
+ key_states = repeat_kv(key_states, self.num_key_value_groups) # [B, H, K, D]
319
  value_states = repeat_kv(value_states, self.num_key_value_groups) # [B, H, K, D]
320
 
321
  # --- Convert a 0/1 or bool 4D mask into an *additive* mask, and align to [B, H, Q, K] ---
 
325
  am = attention_mask
326
  # Support either 2D [B, K] or 4D [B, 1/H, Q, K]
327
  if am.dim() == 2:
328
+ am = am[:, None, None, :k_len] # -> [B,1,1,K]
329
  else:
330
+ am = am[:, :, :, :k_len] # -> [B,1/H,Q,K]
331
 
332
  finfo_min = torch.finfo(query_states.dtype).min
333
  # 0/1 or bool -> float additive mask: 1->0, 0->-inf
334
  if am.dtype == torch.bool:
335
+ zero = torch.zeros((), dtype=query_states.dtype, device=am.device)
336
  neginf = torch.full((), finfo_min, dtype=query_states.dtype, device=am.device)
337
  am = torch.where(am, zero, neginf)
338
  else:
 
341
  am = torch.where(am > 0, torch.zeros_like(am), torch.full_like(am, finfo_min))
342
 
343
  # Expand to all heads
344
+ # if am.shape[1] == 1 and self.num_attention_heads > 1:
345
  # am = am.expand(am.shape[0], self.num_attention_heads, am.shape[2], am.shape[3])
346
 
347
+ # attn_mask = am.contiguous()
348
  attn_mask = am
 
349
 
350
  bsz, q_len = input_shape
351
 
352
  if q_len == 1 and past_key_value is not None:
353
+ # --- Decoding: try flash-attn if available (CUDA only), otherwise fall back to SDPA ---
354
+ if flash_attn_func is not None and query_states.device.type == "cuda":
355
+ try:
356
+ q = query_states.transpose(1, 2) # [B,Q,H,D]
357
+ k = key_states.transpose(1, 2)
358
+ v = value_states.transpose(1, 2)
359
+ attn_output = flash_attn_func(
360
+ q,
361
+ k,
362
+ v,
363
+ causal=True, # For decoding, explicitly set causal=True
364
+ softmax_scale=self.scaling,
365
+ )
366
+ attn_output = attn_output.transpose(1, 2).contiguous()
367
+ except Exception as e:
368
+ logger.warning(f"Flash attention failed during decoding ({e}). Falling back to SDPA.")
369
+ # Fall through to SDPA implementation below
370
+ attn_output = F.scaled_dot_product_attention(
371
+ query=query_states, # [B,H,Q,D]
372
+ key=key_states, # [B,H,K,D]
373
+ value=value_states, # [B,H,K,D]
374
+ attn_mask=attn_mask, # float additive mask
375
+ is_causal=False, # All constraints are already encoded in the mask
376
+ scale=self.scaling,
377
+ )
378
+ attn_output = attn_output.transpose(1, 2).contiguous() # -> [B,Q,H,D]
379
+ else:
380
+ # Fallback to SDPA for MPS, CPU, or when flash_attn_func is not available
381
+ attn_output = F.scaled_dot_product_attention(
382
+ query=query_states, # [B,H,Q,D]
383
+ key=key_states, # [B,H,K,D]
384
+ value=value_states, # [B,H,K,D]
385
+ attn_mask=attn_mask, # float additive mask
386
+ is_causal=False, # All constraints are already encoded in the mask
387
+ scale=self.scaling,
388
+ )
389
+ attn_output = attn_output.transpose(1, 2).contiguous() # -> [B,Q,H,D]
390
  else:
391
  attn_output = F.scaled_dot_product_attention(
392
+ query=query_states, # [B,H,Q,D]
393
+ key=key_states, # [B,H,K,D]
394
+ value=value_states, # [B,H,K,D]
395
+ attn_mask=attn_mask, # float additive mask
396
+ is_causal=False, # All constraints are already encoded in the mask
397
  scale=self.scaling,
398
  )
399
  attn_output = attn_output.transpose(1, 2).contiguous() # -> [B,Q,H,D]
 
409
  self.hidden_size = config.hidden_size
410
  self.self_attn = SDARAttention(config=config, layer_idx=layer_idx)
411
  self.mlp = SDARMLP(config)
412
+ self.input_layernorm = SDARRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
413
+ self.post_attention_layernorm = SDARRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
414
+ if config.sliding_window and config._attn_implementation != "flash_attention_2": # diff with Llama is this warning
415
+ logger.warning_once(f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; unexpected results may be encountered.")
 
 
 
 
 
 
 
416
 
417
  def forward(
418
  self,
 
425
  store_kv: Optional[bool] = False,
426
  cache_position: Optional[torch.LongTensor] = None,
427
  # necessary, but kept here for BC
428
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
 
429
  **kwargs: Unpack[FlashAttentionKwargs],
430
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
431
  residual = hidden_states
 
493
  super().__init__()
494
  # BC: "rope_type" was originally "type"
495
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
496
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
 
497
  else:
498
  self.rope_type = "default"
499
  self.max_seq_len_cached = config.max_position_embeddings
 
502
  self.config = config
503
  self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
504
 
505
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
 
506
  self.register_buffer("inv_freq", inv_freq, persistent=False)
507
  self.original_inv_freq = self.inv_freq
508
 
 
510
  # power user: used with advanced RoPE types (e.g. dynamic rope)
511
  @dynamic_rope_update
512
  def forward(self, x, position_ids):
513
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 
514
  position_ids_expanded = position_ids[:, None, :].float()
515
 
516
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
 
517
  with torch.autocast(device_type=device_type, enabled=False): # Force float32
518
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 
519
  emb = torch.cat((freqs, freqs), dim=-1)
520
  cos = emb.cos() * self.attention_scaling
521
  sin = emb.sin() * self.attention_scaling
 
530
  self.padding_idx = config.pad_token_id
531
  self.vocab_size = config.vocab_size
532
 
533
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
534
+ self.layers = nn.ModuleList([SDARDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
 
 
 
 
535
  self.norm = SDARRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
536
  self.rotary_emb = SDARRotaryEmbedding(config=config)
537
  self.gradient_checkpointing = False
 
562
  **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
563
  ) -> BaseModelOutputWithPast:
564
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
565
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
566
  use_cache = use_cache if use_cache is not None else self.config.use_cache
567
 
568
  if (input_ids is None) ^ (inputs_embeds is not None):
569
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
570
 
571
  if self.gradient_checkpointing and self.training and use_cache:
572
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
 
 
573
  use_cache = False
574
 
575
  # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
576
  if not isinstance(past_key_values, (type(None), Cache)):
577
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
 
578
 
579
  if inputs_embeds is None:
580
  inputs_embeds = self.embed_tokens(input_ids)
 
583
  past_key_values = DynamicCache()
584
 
585
  if cache_position is None:
586
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
587
+ cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)
 
 
 
588
 
589
  if position_ids is None:
590
  position_ids = cache_position.unsqueeze(0)
 
647
  ):
648
  if self.config._attn_implementation == "flash_attention_2":
649
  if attention_mask is not None and past_key_values is not None:
650
+ is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
 
651
  if is_padding_right:
652
  raise ValueError(
653
  "You are attempting to perform batched generation with padding_side='right'"
 
664
  attention_mask = create_block_mask(
665
  # 2d bool tensor, shape: [2*seqlen, 2*seqlen]
666
  lambda b, h, q_idx, kv_idx: attention_mask[q_idx, kv_idx],
667
+ B=None,
668
+ H=None,
669
+ Q_LEN=seq_len_q,
670
+ KV_LEN=seq_len_kv,
671
  )
672
  else:
673
  # Here we pass in flex mask computed externally
 
677
  # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
678
  # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
679
  # to infer the attention mask.
680
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
 
681
  using_static_cache = isinstance(past_key_values, StaticCache)
682
+ using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
 
683
 
684
  # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
685
+ if self.config._attn_implementation == "sdpa" and not (using_static_cache or using_sliding_window_cache) and not output_attentions:
 
 
 
 
686
  if AttentionMaskConverter._ignore_causal_mask_sdpa(
687
  attention_mask,
688
  inputs_embeds=input_tensor,
 
700
  target_length = past_key_values.get_max_cache_shape()
701
  # DynamicCache or no cache
702
  else:
703
+ target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1
 
 
 
 
704
 
705
  # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
706
  causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
 
714
  past_key_values=past_key_values,
715
  )
716
 
717
+ if self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu", "mps"] and not output_attentions:
 
 
 
 
 
718
  # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
719
  # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
720
  # Details: https://github.com/pytorch/pytorch/issues/110213
721
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
 
722
 
723
  return causal_mask
724
 
 
759
  causal_mask = attention_mask
760
  else:
761
  min_dtype = torch.finfo(dtype).min
762
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device)
763
+ diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
 
 
 
 
764
  text_config = config.get_text_config()
765
  if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
766
  # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
767
  # the check is needed to verify is current checkpoint was trained with sliding window or not
768
  if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
769
+ sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (cache_position.reshape(-1, 1) - text_config.sliding_window)
 
 
 
770
  diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
771
  causal_mask *= diagonal_attend_mask
772
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
 
773
  if attention_mask is not None:
774
  causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
775
  if attention_mask.shape[-1] > target_length:
776
  attention_mask = attention_mask[:, :target_length]
777
  mask_length = attention_mask.shape[-1]
778
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
 
 
779
  padding_mask = padding_mask == 0
780
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
 
 
781
  return causal_mask
782
 
783
 
784
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
 
785
 
786
 
787
  @auto_docstring
 
794
  super().__init__(config)
795
  self.model = SDARModel(config)
796
  self.vocab_size = config.vocab_size
797
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
798
 
799
  # Initialize weights and apply final processing
800
  self.post_init()
 
852
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
853
  ```"""
854
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
855
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
856
 
857
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
858
  outputs: BaseModelOutputWithPast = self.model(
 
870
 
871
  hidden_states = outputs.last_hidden_state
872
  # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
873
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
 
874
  hidden_states = hidden_states[:, slice_indices, :].contiguous()
875
  fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
876
  if fuse_linear_and_cross_entropy:
 
884
  # FusedLinearCrossEntropyLoss will be implemented by monkey patch when training
885
  # We don't use it when inferencing
886
  loss_fct = nn.CrossEntropyLoss() # nn.CE
887
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
 
888
 
889
  return CausalLMOutputWithPast(
890
  loss=loss,
 
899
  "SDARForCausalLM",
900
  "SDARModel",
901
  "SDARPreTrainedModel",
902
+ ]