remove LossKwargs

#1
by kashif HF Staff - opened
Files changed (1) hide show
  1. modeling_sdar.py +65 -23
modeling_sdar.py CHANGED
@@ -43,7 +43,7 @@ from transformers.modeling_outputs import (
43
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
44
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
45
  from transformers.processing_utils import Unpack
46
- from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
47
  from .configuration_sdar import SDARConfig
48
 
49
  from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm
@@ -261,22 +261,41 @@ class SDARAttention(nn.Module):
261
  query_states, key_states = apply_rotary_pos_emb(
262
  query_states, key_states, cos, sin)
263
 
264
- if past_key_value is not None and kwargs.get("store_kv", False):
265
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
 
 
266
  key_states, value_states = past_key_value.update(
267
  key_states, value_states, self.layer_idx)
268
- elif past_key_value is not None and not kwargs.get("store_kv", False) and len(past_key_value) > self.layer_idx:
269
- # only retrive, do not store kv
270
- past_key_states, past_value_states = past_key_value[self.layer_idx]
271
- key_states = torch.cat(
272
- [past_key_states, key_states], dim=-2
273
- )
274
- value_states = torch.cat(
275
- [past_value_states, value_states], dim=-2
276
- )
277
 
278
  attention_mask = attention_mask.bool() if attention_mask is not None else None
279
- if torch.all(attention_mask): # decoding
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  query_states = query_states.transpose(1, 2)
281
  key_states = key_states.transpose(1, 2)
282
  value_states = value_states.transpose(1, 2)
@@ -329,7 +348,6 @@ class SDARDecoderLayer(GradientCheckpointingLayer):
329
  past_key_value: Optional[Cache] = None,
330
  output_attentions: Optional[bool] = False,
331
  use_cache: Optional[bool] = False,
332
- store_kv: Optional[bool] = False,
333
  cache_position: Optional[torch.LongTensor] = None,
334
  # necessary, but kept here for BC
335
  position_embeddings: Optional[Tuple[torch.Tensor,
@@ -347,7 +365,6 @@ class SDARDecoderLayer(GradientCheckpointingLayer):
347
  past_key_value=past_key_value,
348
  output_attentions=output_attentions,
349
  use_cache=use_cache,
350
- store_kv=store_kv,
351
  cache_position=cache_position,
352
  position_embeddings=position_embeddings,
353
  **kwargs,
@@ -394,9 +411,27 @@ class SDARPreTrainedModel(PreTrainedModel):
394
  module.weight.data[module.padding_idx].zero_()
395
  elif isinstance(module, SDARRMSNorm):
396
  module.weight.data.fill_(1.0)
 
 
 
 
397
 
398
 
399
  class SDARRotaryEmbedding(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  def __init__(self, config: SDARConfig, device=None):
401
  super().__init__()
402
  # BC: "rope_type" was originally "type"
@@ -409,12 +444,18 @@ class SDARRotaryEmbedding(nn.Module):
409
  self.original_max_seq_len = config.max_position_embeddings
410
 
411
  self.config = config
412
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
413
 
414
- inv_freq, self.attention_scaling = self.rope_init_fn(
415
- self.config, device)
 
 
 
 
 
 
 
416
  self.register_buffer("inv_freq", inv_freq, persistent=False)
417
- self.original_inv_freq = self.inv_freq
418
 
419
  @torch.no_grad()
420
  # power user: used with advanced RoPE types (e.g. dynamic rope)
@@ -440,7 +481,10 @@ class SDARRotaryEmbedding(nn.Module):
440
  class SDARModel(SDARPreTrainedModel):
441
  def __init__(self, config: SDARConfig):
442
  super().__init__(config)
443
- self.padding_idx = config.pad_token_id
 
 
 
444
  self.vocab_size = config.vocab_size
445
 
446
  self.embed_tokens = nn.Embedding(
@@ -472,7 +516,6 @@ class SDARModel(SDARPreTrainedModel):
472
  past_key_values: Optional[Cache] = None,
473
  inputs_embeds: Optional[torch.FloatTensor] = None,
474
  use_cache: Optional[bool] = None,
475
- store_kv: Optional[bool] = None,
476
  output_attentions: Optional[bool] = None,
477
  output_hidden_states: Optional[bool] = None,
478
  cache_position: Optional[torch.LongTensor] = None,
@@ -539,7 +582,6 @@ class SDARModel(SDARPreTrainedModel):
539
  past_key_value=past_key_values,
540
  output_attentions=output_attentions,
541
  use_cache=use_cache,
542
- store_kv=store_kv,
543
  cache_position=cache_position,
544
  position_embeddings=position_embeddings,
545
  **flash_attn_kwargs,
@@ -734,7 +776,7 @@ class SDARModel(SDARPreTrainedModel):
734
  return causal_mask
735
 
736
 
737
- class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
738
  ...
739
 
740
 
 
43
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
44
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
45
  from transformers.processing_utils import Unpack
46
+ from transformers.utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
47
  from .configuration_sdar import SDARConfig
48
 
49
  from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm
 
261
  query_states, key_states = apply_rotary_pos_emb(
262
  query_states, key_states, cos, sin)
263
 
264
+ # Standard transformers v5 cache convention: when a cache is provided, always `.update()` it.
265
+ # Callers that want a read-only forward should pass `past_key_values=None`, or use
266
+ # `DynamicCache.crop(prev_seq_len)` to roll back the append after reading the logits.
267
+ if past_key_value is not None:
268
  key_states, value_states = past_key_value.update(
269
  key_states, value_states, self.layer_idx)
 
 
 
 
 
 
 
 
 
270
 
271
  attention_mask = attention_mask.bool() if attention_mask is not None else None
272
+
273
+ # I-DLM / strict-causal mode: rely on PyTorch's built-in `is_causal=True` path so GQA
274
+ # broadcasting works cleanly with a KV cache (query q_len ≠ key k_len). We compute a
275
+ # per-query offset such that `is_causal=True` masks against key position `q + offset`,
276
+ # matching the Dream-shifted causal-LM convention.
277
+ use_regular_causal = bool(getattr(self.config, "use_regular_causal", False))
278
+ if use_regular_causal:
279
+ q_len = query_states.shape[-2]
280
+ k_len = key_states.shape[-2]
281
+ if q_len == k_len:
282
+ attn_output = F.scaled_dot_product_attention(
283
+ query=query_states, key=key_states, value=value_states,
284
+ is_causal=True, scale=self.scaling, enable_gqa=True,
285
+ )
286
+ else:
287
+ # Non-square causal: build a (q_len, k_len) mask where row `i` attends to key
288
+ # positions `0..k_len - q_len + i`. Works for any cache state.
289
+ offset = k_len - q_len
290
+ rows = torch.arange(q_len, device=query_states.device).unsqueeze(1)
291
+ cols = torch.arange(k_len, device=query_states.device).unsqueeze(0)
292
+ causal_mask = cols <= rows + offset # [q_len, k_len]
293
+ attn_output = F.scaled_dot_product_attention(
294
+ query=query_states, key=key_states, value=value_states,
295
+ attn_mask=causal_mask, is_causal=False, scale=self.scaling, enable_gqa=True,
296
+ )
297
+ attn_output = attn_output.transpose(1, 2).contiguous()
298
+ elif attention_mask is not None and torch.all(attention_mask): # decoding
299
  query_states = query_states.transpose(1, 2)
300
  key_states = key_states.transpose(1, 2)
301
  value_states = value_states.transpose(1, 2)
 
348
  past_key_value: Optional[Cache] = None,
349
  output_attentions: Optional[bool] = False,
350
  use_cache: Optional[bool] = False,
 
351
  cache_position: Optional[torch.LongTensor] = None,
352
  # necessary, but kept here for BC
353
  position_embeddings: Optional[Tuple[torch.Tensor,
 
365
  past_key_value=past_key_value,
366
  output_attentions=output_attentions,
367
  use_cache=use_cache,
 
368
  cache_position=cache_position,
369
  position_embeddings=position_embeddings,
370
  **kwargs,
 
411
  module.weight.data[module.padding_idx].zero_()
412
  elif isinstance(module, SDARRMSNorm):
413
  module.weight.data.fill_(1.0)
414
+ # Delegate rotary-embedding buffer re-init to the base PreTrainedModel, which handles
415
+ # transformers v5's meta-device load by recomputing inv_freq via compute_default_rope_parameters.
416
+ else:
417
+ super()._init_weights(module)
418
 
419
 
420
  class SDARRotaryEmbedding(nn.Module):
421
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
422
+
423
+ @staticmethod
424
+ def compute_default_rope_parameters(config, device=None, seq_len=None):
425
+ # transformers v5 removed "default" from ROPE_INIT_FUNCTIONS; match the Qwen3 implementation.
426
+ base = getattr(config, "rope_theta", None)
427
+ if base is None:
428
+ base = config.rope_parameters["rope_theta"]
429
+ dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
430
+ inv_freq = 1.0 / (
431
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
432
+ )
433
+ return inv_freq, 1.0
434
+
435
  def __init__(self, config: SDARConfig, device=None):
436
  super().__init__()
437
  # BC: "rope_type" was originally "type"
 
444
  self.original_max_seq_len = config.max_position_embeddings
445
 
446
  self.config = config
 
447
 
448
+ if self.rope_type == "default":
449
+ inv_freq, self.attention_scaling = self.compute_default_rope_parameters(config, device)
450
+ else:
451
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
452
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
453
+
454
+ # Register both as buffers — transformers v5's `_move_missing_keys_from_meta_to_device`
455
+ # replaces non-persistent buffers with `torch.empty_like` (uninitialized / zeros); the base
456
+ # `_init_weights` then re-copies into them IF they're buffers with `original_inv_freq` present.
457
  self.register_buffer("inv_freq", inv_freq, persistent=False)
458
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
459
 
460
  @torch.no_grad()
461
  # power user: used with advanced RoPE types (e.g. dynamic rope)
 
481
  class SDARModel(SDARPreTrainedModel):
482
  def __init__(self, config: SDARConfig):
483
  super().__init__(config)
484
+ # transformers v5 configs may not have pad_token_id; fall back to eos_token_id.
485
+ self.padding_idx = getattr(config, "pad_token_id", None)
486
+ if self.padding_idx is None:
487
+ self.padding_idx = getattr(config, "eos_token_id", None)
488
  self.vocab_size = config.vocab_size
489
 
490
  self.embed_tokens = nn.Embedding(
 
516
  past_key_values: Optional[Cache] = None,
517
  inputs_embeds: Optional[torch.FloatTensor] = None,
518
  use_cache: Optional[bool] = None,
 
519
  output_attentions: Optional[bool] = None,
520
  output_hidden_states: Optional[bool] = None,
521
  cache_position: Optional[torch.LongTensor] = None,
 
582
  past_key_value=past_key_values,
583
  output_attentions=output_attentions,
584
  use_cache=use_cache,
 
585
  cache_position=cache_position,
586
  position_embeddings=position_embeddings,
587
  **flash_attn_kwargs,
 
776
  return causal_mask
777
 
778
 
779
+ class KwargsForCausalLM(FlashAttentionKwargs):
780
  ...
781
 
782