YongganFu commited on
Commit
378400a
·
verified ·
1 Parent(s): 1c298a2

Upload model

Browse files
Files changed (1) hide show
  1. modeling_ministral_dlm.py +20 -10
modeling_ministral_dlm.py CHANGED
@@ -1,4 +1,5 @@
1
  import copy
 
2
  from typing import Callable, Optional, Tuple, Union
3
  import random
4
  import os
@@ -10,6 +11,7 @@ import torch
10
  import torch.nn.functional as F
11
  from torch import nn
12
  from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
 
13
 
14
  from torch.nn.attention.flex_attention import flex_attention, create_block_mask
15
 
@@ -29,6 +31,17 @@ from .chat_utils import generate_with_prefix_cache_block_diff
29
  from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
30
  from .configuration_ministral_dlm import MinistralDLMConfig
31
 
 
 
 
 
 
 
 
 
 
 
 
32
  # @torch.compile(dynamic=True, mode="reduce-overhead")
33
  # @torch.compile(mode="default")
34
  # @torch.compile(fullgraph=True, mode="reduce-overhead", dynamic=False)
@@ -81,13 +94,11 @@ class MinistralFlexAttention(Ministral3Attention):
81
  - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
82
  - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
83
  - **Block Causal Mask (M_BC)**: Attention to update x0
84
-
85
  Args:
86
  b, h: Batch and head indices (ignored for mask logic).
87
  q_idx, kv_idx: Query and Key indices.
88
  seq_len: Total sequence length.
89
  block_size: Defines the block structure.
90
-
91
  Returns:
92
  A boolean attention mask.
93
  """
@@ -378,16 +389,13 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
378
  ):
379
  """
380
  Two-stage corruption with optional per-block sampling.
381
-
382
  • Stage 1: m ~ U(eps, 1) → k = round(m · len) (exact budget).
383
  • Stage 2: sample exactly k positions with weights
384
  w_i(m) = exp[ λ · (1−m) · i ] (late-heavy when m→0,
385
  uniform when m→1).
386
-
387
  If `block_size` is given, the procedure is run *independently*
388
  inside each contiguous block of that length (last block may be shorter).
389
  When block_size is provided, m is sampled per-block and p_mask is per-block.
390
-
391
  Args
392
  ----
393
  input_ids : (B, L) LongTensor
@@ -479,6 +487,7 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
479
  loss_mask: Optional[torch.Tensor] = None,
480
  ce_loss_weight: float = 1.0,
481
  output_last_hidden_states_only: bool = False,
 
482
  **kwargs,
483
  ) -> CausalLMOutputWithPast:
484
 
@@ -555,7 +564,8 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
555
  return BaseModelOutput(last_hidden_state=enc_out.last_hidden_state)
556
 
557
  logits = self.diffusion_head(enc_out.last_hidden_state) # (batch, len_B, vocab)
558
-
 
559
  if labels is not None and self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
560
  if self.config.dlm_paradigm == 'sbd_block_diff':
561
  causal_logits = logits[:, input_ids_len:]
@@ -565,7 +575,7 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
565
  logits = logits[:, :input_ids_len]
566
 
567
  loss = None
568
- if labels is not None:
569
  if self.config.dlm_paradigm == 'autoregressive':
570
  shift_logits = logits[..., :-1, :].contiguous()
571
  shift_labels = labels[..., 1:].contiguous()
@@ -702,9 +712,10 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
702
  else:
703
  loss = (loss, num_mask_tokens)
704
 
705
- return CausalLMOutputWithPast(
706
  loss=loss if not is_teacher else logits,
707
  logits=logits,
 
708
  past_key_values=enc_out.past_key_values,
709
  hidden_states=None,
710
  attentions=None,
@@ -729,5 +740,4 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
729
 
730
  return out_ids, nfe
731
 
732
- __all__ = ["MinistralDiffEncoderModel", "MinistralFlexAttention"]
733
-
 
1
  import copy
2
+ from dataclasses import dataclass
3
  from typing import Callable, Optional, Tuple, Union
4
  import random
5
  import os
 
11
  import torch.nn.functional as F
12
  from torch import nn
13
  from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput
14
+ from transformers.utils import ModelOutput
15
 
16
  from torch.nn.attention.flex_attention import flex_attention, create_block_mask
17
 
 
31
  from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
32
  from .configuration_ministral_dlm import MinistralDLMConfig
33
 
34
+
35
+ @dataclass
36
+ class MinistralDiffOutputWithPast(ModelOutput):
37
+ loss: torch.FloatTensor | None = None
38
+ logits: torch.FloatTensor | None = None
39
+ causal_logits: torch.FloatTensor | None = None
40
+ past_key_values: Cache | None = None
41
+ hidden_states: tuple[torch.FloatTensor, ...] | None = None
42
+ attentions: tuple[torch.FloatTensor, ...] | None = None
43
+
44
+
45
  # @torch.compile(dynamic=True, mode="reduce-overhead")
46
  # @torch.compile(mode="default")
47
  # @torch.compile(fullgraph=True, mode="reduce-overhead", dynamic=False)
 
94
  - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
95
  - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
96
  - **Block Causal Mask (M_BC)**: Attention to update x0
 
97
  Args:
98
  b, h: Batch and head indices (ignored for mask logic).
99
  q_idx, kv_idx: Query and Key indices.
100
  seq_len: Total sequence length.
101
  block_size: Defines the block structure.
 
102
  Returns:
103
  A boolean attention mask.
104
  """
 
389
  ):
390
  """
391
  Two-stage corruption with optional per-block sampling.
 
392
  • Stage 1: m ~ U(eps, 1) → k = round(m · len) (exact budget).
393
  • Stage 2: sample exactly k positions with weights
394
  w_i(m) = exp[ λ · (1−m) · i ] (late-heavy when m→0,
395
  uniform when m→1).
 
396
  If `block_size` is given, the procedure is run *independently*
397
  inside each contiguous block of that length (last block may be shorter).
398
  When block_size is provided, m is sampled per-block and p_mask is per-block.
 
399
  Args
400
  ----
401
  input_ids : (B, L) LongTensor
 
487
  loss_mask: Optional[torch.Tensor] = None,
488
  ce_loss_weight: float = 1.0,
489
  output_last_hidden_states_only: bool = False,
490
+ skip_loss: bool = False,
491
  **kwargs,
492
  ) -> CausalLMOutputWithPast:
493
 
 
564
  return BaseModelOutput(last_hidden_state=enc_out.last_hidden_state)
565
 
566
  logits = self.diffusion_head(enc_out.last_hidden_state) # (batch, len_B, vocab)
567
+ causal_logits = None
568
+
569
  if labels is not None and self.config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
570
  if self.config.dlm_paradigm == 'sbd_block_diff':
571
  causal_logits = logits[:, input_ids_len:]
 
575
  logits = logits[:, :input_ids_len]
576
 
577
  loss = None
578
+ if labels is not None and not skip_loss:
579
  if self.config.dlm_paradigm == 'autoregressive':
580
  shift_logits = logits[..., :-1, :].contiguous()
581
  shift_labels = labels[..., 1:].contiguous()
 
712
  else:
713
  loss = (loss, num_mask_tokens)
714
 
715
+ return MinistralDiffOutputWithPast(
716
  loss=loss if not is_teacher else logits,
717
  logits=logits,
718
+ causal_logits=causal_logits,
719
  past_key_values=enc_out.past_key_values,
720
  hidden_states=None,
721
  attentions=None,
 
740
 
741
  return out_ids, nfe
742
 
743
+ __all__ = ["MinistralDiffEncoderModel", "MinistralFlexAttention"]