Fraser committed on
Commit
7493ebb
·
verified ·
1 Parent(s): 56be64d

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. config.json +8 -3
  2. configuration_recursive.py +179 -0
  3. modeling_recursive.py +1732 -0
config.json CHANGED
@@ -17,7 +17,8 @@
17
  "auto_map": {
18
  "AutoConfig": "configuration_llada.LLaDAConfig",
19
  "AutoModel": "modeling_llada.LLaDAModelLM",
20
- "AutoModelForCausalLM": "modeling_llada.LLaDAModelLM"
 
21
  },
22
  "bad_words_ids": null,
23
  "begin_suppress_tokens": null,
@@ -149,5 +150,9 @@
149
  "soft_embedding_method": "softmax",
150
  "temperature_max": 0.0,
151
  "transformers_version": "4.57.0",
152
- "use_recursion_checkpointing": true
153
- }
 
 
 
 
 
17
  "auto_map": {
18
  "AutoConfig": "configuration_llada.LLaDAConfig",
19
  "AutoModel": "modeling_llada.LLaDAModelLM",
20
+ "AutoModelForCausalLM": "modeling_llada.LLaDAModelLM",
21
+ "AutoModelForMaskedLM": "modeling_llada.LLaDAModelLM"
22
  },
23
  "bad_words_ids": null,
24
  "begin_suppress_tokens": null,
 
150
  "soft_embedding_method": "softmax",
151
  "temperature_max": 0.0,
152
  "transformers_version": "4.57.0",
153
+ "use_recursion_checkpointing": true,
154
+ "auto_map": {
155
+ "AutoConfig": "configuration_recursive.RecursiveMLMConfig",
156
+ "AutoModel": "modeling_recursive.RecursiveMaskedLM"
157
+ }
158
+ }
configuration_recursive.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import Optional
3
+
4
+ from transformers import PretrainedConfig
5
+
6
+
7
class RecursiveMLMConfig(PretrainedConfig):
    """
    Configuration for RecursiveMaskedLM.

    Stores the base MLM config plus recursive refinement parameters.

    Convergence Schedule System
    ---------------------------
    The convergence schedule controls WHEN each position is allowed to converge
    to a confident prediction during iterative refinement.

    Schedule types:
    - "linear": All positions converge at the same rate (iteration-based only)
    - "causal": Early positions converge first, late positions last

    Effects (mechanisms to enforce the schedule):
    - temperature_max: Raise temperature for positions not yet allowed to converge
    - entropy_target_max: Force exact entropy via bisection search (two-sided, recommended)
    - entropy_floor_max: Force minimum entropy (one-sided, only raises)
    - smear_sigma_max: Spread probability across neighboring positions
    - noise_std_max: Add Gaussian noise to logits
    - iteration_rope_dim_fraction: Apply rotary embedding based on iteration progress

    Soft Embedding Methods
    ----------------------
    Controls how logits are converted to soft embeddings for the next iteration:
    - "softmax": Standard softmax normalization (default). Creates sparse, probabilistic
      mixing but can cause gradient bottlenecks through the softmax Jacobian.
    - "l2_normalize": L2 normalize logits before mixing with embeddings. Removes the
      softmax bottleneck for smoother gradients through long recursion chains.
    - "none": No normalization - use raw logits directly. Warning: this can cause
      scale explosion without additional mechanisms like EMA accumulation.

    - soft_embedding_ema_step: Controls EMA blending with previous soft embeddings.
      1.0 (default) = full update (no EMA), 0.1 = slow update (90% previous + 10% new).
      Formula: new = (1 - ema_step) * prev + ema_step * current

    Recursion Checkpointing
    -----------------------
    Controls gradient flow through the entire recursion chain for memory-efficient training.

    Parameters:
    - use_recursion_checkpointing: Enable gradient checkpointing for iterations
    - loss_weight: Use "last_1" for final-iteration-only loss (learns convergence behavior)

    Flow Matching (CFM-inspired)
    ----------------------------
    Replaces the old temperature-based self-distillation with a Continuous Flow Matching
    framework. Training inputs are interpolated on the probability simplex between random
    noise and the target one-hot, distillation gives the student a noisier (earlier-time)
    version of the same interpolation path, and inference uses a flow map update rule.

    Parameters:
    - flow_matching_enabled: Enable the flow matching framework
    - flow_matching_lambda: Weight of distillation KL loss relative to CE loss
    - flow_matching_t_distribution: How to sample time t ("logit_normal" or "uniform")
    - flow_matching_t_logit_mean: Mean of logit-normal distribution (-0.4 biases toward noisy)
    - flow_matching_t_logit_std: Std of logit-normal distribution
    - flow_matching_t_min: Minimum time value (clamp)
    - flow_matching_t_max: Maximum time value (clamp)
    - flow_matching_noise_scale: Scale of the random noise end of the interpolation path
    - flow_matching_mask_scale: If True, scale mask_emb by (1-t); if False, binary mask signal

    Time levels are sampled independently per masked token. At t=0 the input is pure noise,
    at t=1 it is the clean target embedding.

    Self-Distillation (legacy, temperature-based)
    ----------------------------------------------
    Kept for backward compatibility. Ignored when flow_matching_enabled=True.

    Parameters:
    - self_distillation_enabled: Enable the self-distillation KL loss
    - self_distillation_lambda: Weight of distillation loss relative to CE loss
    - self_distillation_temperature_min: Minimum degradation temperature
    - self_distillation_temperature_max: Maximum degradation temperature
    - self_distillation_temperature_distribution: How to sample temperature
    - self_distillation_teacher: Which logits to use as teacher ("first" or "last")
    """
    model_type = "recursive-mlm"

    def __init__(
        self,
        base_model_config: Optional[dict] = None,
        num_recursions: int = 8,
        normalization: str = "softmax",
        loss_weight: str = "linear",
        mask_token_id: Optional[int] = None,
        temperature: float = 1.0,
        gradient_steps: Optional[int] = None,
        # === Convergence schedule parameters ===
        schedule: str = "linear",
        causal_strength: float = 1.0,
        # === Effect parameters ===
        temperature_max: float = 0.0,
        entropy_target_max: float = 0.0,
        entropy_floor_max: float = 0.0,
        smear_sigma_max: float = 0.0,
        noise_std_max: float = 0.0,
        iteration_rope_dim_fraction: float = 0.0,
        use_recursion_checkpointing: bool = True,
        # === Soft embedding method ===
        soft_embedding_method: str = "softmax",
        soft_embedding_ema_step: float = 1.0,
        # === Flow matching parameters (CFM-inspired) ===
        flow_matching_enabled: bool = False,
        flow_matching_lambda: float = 0.5,
        flow_matching_t_distribution: str = "logit_normal",
        flow_matching_t_logit_mean: float = -0.4,
        flow_matching_t_logit_std: float = 1.0,
        flow_matching_t_min: float = 0.01,
        flow_matching_t_max: float = 0.99,
        flow_matching_noise_scale: float = 2.0,
        flow_matching_mask_scale: bool = False,
        # === Self-distillation parameters (legacy, ignored when flow_matching_enabled) ===
        self_distillation_enabled: bool = False,
        self_distillation_lambda: float = 0.5,
        self_distillation_temperature_min: float = 1.5,
        self_distillation_temperature_max: float = 10.0,
        self_distillation_temperature_distribution: str = "log_uniform",
        self_distillation_teacher: str = "first",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model_config = base_model_config
        self.num_recursions = num_recursions
        self.normalization = normalization
        self.loss_weight = loss_weight
        self.mask_token_id = mask_token_id
        self.temperature = temperature
        self.gradient_steps = gradient_steps
        # Convergence schedule
        self.schedule = schedule
        self.causal_strength = causal_strength
        # Effects
        self.temperature_max = temperature_max
        self.entropy_target_max = entropy_target_max
        self.entropy_floor_max = entropy_floor_max
        self.smear_sigma_max = smear_sigma_max
        self.noise_std_max = noise_std_max
        self.iteration_rope_dim_fraction = iteration_rope_dim_fraction
        # Recursion checkpointing
        self.use_recursion_checkpointing = use_recursion_checkpointing
        # Soft embedding method
        self.soft_embedding_method = soft_embedding_method
        self.soft_embedding_ema_step = soft_embedding_ema_step
        # Flow matching
        self.flow_matching_enabled = flow_matching_enabled
        self.flow_matching_lambda = flow_matching_lambda
        self.flow_matching_t_distribution = flow_matching_t_distribution
        self.flow_matching_t_logit_mean = flow_matching_t_logit_mean
        self.flow_matching_t_logit_std = flow_matching_t_logit_std
        self.flow_matching_t_min = flow_matching_t_min
        self.flow_matching_t_max = flow_matching_t_max
        self.flow_matching_noise_scale = flow_matching_noise_scale
        self.flow_matching_mask_scale = flow_matching_mask_scale
        # Self-distillation (legacy)
        self.self_distillation_enabled = self_distillation_enabled
        self.self_distillation_lambda = self_distillation_lambda
        self.self_distillation_temperature_min = self_distillation_temperature_min
        self.self_distillation_temperature_max = self_distillation_temperature_max
        self.self_distillation_temperature_distribution = self_distillation_temperature_distribution
        self.self_distillation_teacher = self_distillation_teacher

    @classmethod
    def from_base_model_config(
        cls,
        base_config: PretrainedConfig,
        **kwargs,
    ) -> "RecursiveMLMConfig":
        """Create config from a base MLM's config."""
        return cls(
            base_model_config=base_config.to_dict(),
            **kwargs,
        )
modeling_recursive.py ADDED
@@ -0,0 +1,1732 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import warnings
3
+ from dataclasses import dataclass
4
+ from typing import NamedTuple, Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch.nn import CrossEntropyLoss
9
+ from torch.utils.checkpoint import checkpoint as torch_checkpoint
10
+ from transformers import AutoConfig, AutoModelForMaskedLM, PreTrainedModel
11
+ from transformers.modeling_outputs import MaskedLMOutput
12
+ from transformers.utils import ModelOutput
13
+
14
+ from .configuration_recursive import RecursiveMLMConfig
15
+
16
+
17
@dataclass
class IterationMetrics(ModelOutput):
    """Diagnostics recorded for one iteration of recursive refinement.

    All fields are optional so metrics can be filled in selectively.
    """
    accuracy: Optional[float] = None                  # fraction of correct predictions
    entropy: Optional[float] = None                   # mean predictive entropy
    softmax_ce: Optional[float] = None                # cross-entropy of the softmax output
    full_sequence_accuracy: Optional[float] = None    # fraction of fully-correct sequences
    min_sequence_confidence: Optional[float] = None   # lowest per-sequence confidence
25
+
26
+
27
@dataclass
class RecursiveMaskedLMOutput(MaskedLMOutput):
    """MaskedLMOutput extended with recursive-refinement state.

    Adds per-iteration metrics, the soft embeddings to seed the next step,
    every iteration's logits, and compact flow-matching state for distillation.
    """
    iteration_metrics: Optional[dict[int, IterationMetrics]] = None  # Maps iteration index to metrics
    next_soft_embeds: Optional[torch.Tensor] = None  # For caching between training steps
    all_logits: Optional[list[torch.Tensor]] = None  # All T iterations' logits for trainer loss computation
    # Flow matching state (for distillation — compact H-dim, not V-dim)
    flow_noise_embed: Optional[torch.Tensor] = None  # (num_masked, H) noise embedding
    flow_t: Optional[torch.Tensor] = None  # (num_masked,) per-token time levels
35
+
36
+
37
class SelfDistillationOutput(NamedTuple):
    """Result bundle returned by the self-distillation forward pass."""
    loss: torch.Tensor                 # KL divergence loss (scalar, has grad)
    teacher_logits: torch.Tensor       # For metrics/debugging (detached)
    student_logits: torch.Tensor       # For metrics/debugging (has grad)
    degradation_temperature: float     # Mean per-token temperature sampled
    teacher_entropy: float             # Entropy of teacher distribution (for monitoring)
    student_entropy: float             # Entropy of student distribution (for monitoring)
    agreement_rate: float              # Fraction where teacher and student argmax agree
46
+
47
+
48
+ class RecursiveMaskedLM(PreTrainedModel):
49
+ """
50
+ Wraps any HF MLM with recursive soft-token refinement.
51
+
52
+ At each step:
53
+ 1. Normalize logits -> probs
54
+ 2. Compute soft embeddings: probs @ embedding_weight + mask_embedding
55
+ 3. Forward through MLM
56
+ 4. Accumulate weighted loss
57
+ """
58
+ config_class = RecursiveMLMConfig
59
+ base_model_prefix = "mlm"
60
+ supports_gradient_checkpointing = True
61
+
62
+ def __init__(self, config: RecursiveMLMConfig, base_model: Optional[PreTrainedModel] = None):
63
+ super().__init__(config)
64
+
65
+ if base_model is not None:
66
+ # Pre-trained model provided - assign directly WITHOUT calling post_init()
67
+ # to avoid reinitializing the pre-trained weights via _init_weights()
68
+ self.mlm = base_model
69
+ elif config.base_model_config is not None:
70
+ base_config = AutoConfig.for_model(**config.base_model_config)
71
+ self.mlm = AutoModelForMaskedLM.from_config(base_config)
72
+ # Only call post_init() for freshly created models (needs weight init)
73
+ self.post_init()
74
+ else:
75
+ raise ValueError("Need either base_model or config.base_model_config")
76
+
77
+ @classmethod
78
+ def from_mlm_pretrained(
79
+ cls,
80
+ mlm_name_or_path: str,
81
+ num_recursions: int = 8,
82
+ normalization: str = "softmax",
83
+ loss_weight: str = "linear",
84
+ mask_token_id: Optional[int] = None,
85
+ temperature: float = 1.0,
86
+ gradient_steps: Optional[int] = None,
87
+ # === Convergence schedule parameters ===
88
+ schedule: str = "linear",
89
+ causal_strength: float = 1.0,
90
+ # === Effect parameters ===
91
+ temperature_max: float = 0.0,
92
+ entropy_target_max: float = 0.0,
93
+ entropy_floor_max: float = 0.0,
94
+ smear_sigma_max: float = 0.0,
95
+ noise_std_max: float = 0.0,
96
+ iteration_rope_dim_fraction: float = 0.0,
97
+ use_recursion_checkpointing: bool = True,
98
+ # === Soft embedding method ===
99
+ soft_embedding_method: str = "softmax",
100
+ soft_embedding_ema_step: float = 1.0,
101
+ # === Flow matching parameters ===
102
+ flow_matching_enabled: bool = False,
103
+ flow_matching_lambda: float = 0.5,
104
+ flow_matching_t_distribution: str = "logit_normal",
105
+ flow_matching_t_logit_mean: float = -0.4,
106
+ flow_matching_t_logit_std: float = 1.0,
107
+ flow_matching_t_min: float = 0.01,
108
+ flow_matching_t_max: float = 0.99,
109
+ flow_matching_mask_scale: bool = False,
110
+ **model_kwargs,
111
+ ) -> "RecursiveMaskedLM":
112
+ """Load a pretrained MLM and wrap it for recursive refinement."""
113
+ base_model = AutoModelForMaskedLM.from_pretrained(mlm_name_or_path, **model_kwargs)
114
+ return cls.from_base_model(
115
+ base_model,
116
+ num_recursions=num_recursions,
117
+ normalization=normalization,
118
+ loss_weight=loss_weight,
119
+ mask_token_id=mask_token_id,
120
+ temperature=temperature,
121
+ gradient_steps=gradient_steps,
122
+ schedule=schedule,
123
+ causal_strength=causal_strength,
124
+ temperature_max=temperature_max,
125
+ entropy_target_max=entropy_target_max,
126
+ entropy_floor_max=entropy_floor_max,
127
+ smear_sigma_max=smear_sigma_max,
128
+ noise_std_max=noise_std_max,
129
+ iteration_rope_dim_fraction=iteration_rope_dim_fraction,
130
+ use_recursion_checkpointing=use_recursion_checkpointing,
131
+ soft_embedding_method=soft_embedding_method,
132
+ soft_embedding_ema_step=soft_embedding_ema_step,
133
+ flow_matching_enabled=flow_matching_enabled,
134
+ flow_matching_lambda=flow_matching_lambda,
135
+ flow_matching_t_distribution=flow_matching_t_distribution,
136
+ flow_matching_t_logit_mean=flow_matching_t_logit_mean,
137
+ flow_matching_t_logit_std=flow_matching_t_logit_std,
138
+ flow_matching_t_min=flow_matching_t_min,
139
+ flow_matching_t_max=flow_matching_t_max,
140
+ flow_matching_mask_scale=flow_matching_mask_scale,
141
+ )
142
+
143
+ @classmethod
144
+ def from_base_model(
145
+ cls,
146
+ base_model: PreTrainedModel,
147
+ num_recursions: int = 8,
148
+ normalization: str = "softmax",
149
+ loss_weight: str = "linear",
150
+ mask_token_id: Optional[int] = None,
151
+ temperature: float = 1.0,
152
+ gradient_steps: Optional[int] = None,
153
+ # === Convergence schedule parameters ===
154
+ schedule: str = "linear",
155
+ causal_strength: float = 1.0,
156
+ # === Effect parameters ===
157
+ temperature_max: float = 0.0,
158
+ entropy_target_max: float = 0.0,
159
+ entropy_floor_max: float = 0.0,
160
+ smear_sigma_max: float = 0.0,
161
+ noise_std_max: float = 0.0,
162
+ iteration_rope_dim_fraction: float = 0.0,
163
+ use_recursion_checkpointing: bool = True,
164
+ # === Soft embedding method ===
165
+ soft_embedding_method: str = "softmax",
166
+ soft_embedding_ema_step: float = 1.0,
167
+ # === Flow matching parameters ===
168
+ flow_matching_enabled: bool = False,
169
+ flow_matching_lambda: float = 0.5,
170
+ flow_matching_t_distribution: str = "logit_normal",
171
+ flow_matching_t_logit_mean: float = -0.4,
172
+ flow_matching_t_logit_std: float = 1.0,
173
+ flow_matching_t_min: float = 0.01,
174
+ flow_matching_t_max: float = 0.99,
175
+ flow_matching_mask_scale: bool = False,
176
+ ) -> "RecursiveMaskedLM":
177
+ """Wrap an existing model for recursive refinement.
178
+
179
+ Use this for models not loadable via AutoModelForMaskedLM (e.g., LLaDA).
180
+
181
+ Args:
182
+ base_model: The base MLM model to wrap
183
+ num_recursions: Number of recursive refinement steps
184
+ normalization: Normalization method for logits (softmax, stable_softmax)
185
+ loss_weight: Loss weighting scheme (last_1, last_2, linear, uniform)
186
+ mask_token_id: Token ID for [MASK]
187
+ temperature: Temperature for softmax normalization
188
+ gradient_steps: Number of final steps to backprop through
189
+ schedule: Convergence schedule type ("linear" or "causal")
190
+ causal_strength: How much faster early positions converge (causal only)
191
+ temperature_max: Max temperature boost for uncertain positions
192
+ entropy_target_max: Target entropy at progress=0 (two-sided, recommended)
193
+ entropy_floor_max: Min entropy floor (one-sided)
194
+ smear_sigma_max: Max Gaussian sigma for position smearing
195
+ noise_std_max: Max std of Gaussian noise on logits
196
+ iteration_rope_dim_fraction: Fraction of dims for iteration RoPE
197
+ use_recursion_checkpointing: Enable gradient checkpointing for iterations
198
+ soft_embedding_method: How to convert logits to soft embeddings
199
+ soft_embedding_ema_step: EMA step size (1.0 = no EMA, <1.0 = blend with previous)
200
+ flow_matching_enabled: Enable CFM-inspired flow matching framework
201
+ flow_matching_lambda: Weight of distillation KL loss relative to CE
202
+ flow_matching_t_distribution: Time sampling distribution ("logit_normal" or "uniform")
203
+ flow_matching_t_logit_mean: Mean of logit-normal distribution
204
+ flow_matching_t_logit_std: Std of logit-normal distribution
205
+ flow_matching_t_min: Minimum time value (clamp)
206
+ flow_matching_t_max: Maximum time value (clamp)
207
+ flow_matching_mask_scale: Scale mask_emb by (1-t) if True, binary if False
208
+ """
209
+ config = RecursiveMLMConfig.from_base_model_config(
210
+ base_model.config,
211
+ num_recursions=num_recursions,
212
+ normalization=normalization,
213
+ loss_weight=loss_weight,
214
+ mask_token_id=mask_token_id,
215
+ temperature=temperature,
216
+ gradient_steps=gradient_steps,
217
+ schedule=schedule,
218
+ causal_strength=causal_strength,
219
+ temperature_max=temperature_max,
220
+ entropy_target_max=entropy_target_max,
221
+ entropy_floor_max=entropy_floor_max,
222
+ smear_sigma_max=smear_sigma_max,
223
+ noise_std_max=noise_std_max,
224
+ iteration_rope_dim_fraction=iteration_rope_dim_fraction,
225
+ use_recursion_checkpointing=use_recursion_checkpointing,
226
+ soft_embedding_method=soft_embedding_method,
227
+ soft_embedding_ema_step=soft_embedding_ema_step,
228
+ flow_matching_enabled=flow_matching_enabled,
229
+ flow_matching_lambda=flow_matching_lambda,
230
+ flow_matching_t_distribution=flow_matching_t_distribution,
231
+ flow_matching_t_logit_mean=flow_matching_t_logit_mean,
232
+ flow_matching_t_logit_std=flow_matching_t_logit_std,
233
+ flow_matching_t_min=flow_matching_t_min,
234
+ flow_matching_t_max=flow_matching_t_max,
235
+ flow_matching_mask_scale=flow_matching_mask_scale,
236
+ )
237
+ return cls(config, base_model=base_model)
238
+
239
+ @property
240
+ def embed_weight(self) -> torch.Tensor:
241
+ return self.mlm.get_input_embeddings().weight
242
+
243
+ def get_input_embeddings(self):
244
+ return self.mlm.get_input_embeddings()
245
+
246
+ def set_input_embeddings(self, value):
247
+ self.mlm.set_input_embeddings(value)
248
+
249
+ def get_output_embeddings(self):
250
+ return self.mlm.get_output_embeddings()
251
+
252
+ def set_output_embeddings(self, new_embeddings):
253
+ self.mlm.set_output_embeddings(new_embeddings)
254
+
255
+ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
256
+ """Enable gradient checkpointing with correct settings for recursion.
257
+
258
+ Forces use_reentrant=False which is required for:
259
+ - Nested checkpoint calls (base model + recursion checkpointing)
260
+ - Models with frozen parameters
261
+ - Complex gradient flows through soft embeddings
262
+ """
263
+ if gradient_checkpointing_kwargs is None:
264
+ gradient_checkpointing_kwargs = {}
265
+ # Force use_reentrant=False for nested checkpointing compatibility
266
+ gradient_checkpointing_kwargs.setdefault("use_reentrant", False)
267
+ self.mlm.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
268
+
269
+ def gradient_checkpointing_disable(self):
270
+ """Disable gradient checkpointing in the underlying MLM."""
271
+ self.mlm.gradient_checkpointing_disable()
272
+
273
+ def _single_iteration_checkpointable(
274
+ self,
275
+ soft_embeds: torch.Tensor,
276
+ base_embeds: torch.Tensor,
277
+ mask_pos: torch.Tensor,
278
+ attention_mask: torch.Tensor,
279
+ embed_weight: torch.Tensor,
280
+ mask_emb: torch.Tensor,
281
+ temperature: torch.Tensor,
282
+ position_ids: Optional[torch.Tensor] = None,
283
+ ) -> tuple[torch.Tensor, torch.Tensor]:
284
+ """
285
+ Single differentiable iteration for checkpointing.
286
+
287
+ This method performs one iteration of recursive refinement in a way that
288
+ maintains gradient flow and is compatible with torch.utils.checkpoint.
289
+
290
+ Args:
291
+ soft_embeds: (B, L, H) - current soft embeddings
292
+ base_embeds: (B, L, H) - original token embeddings
293
+ mask_pos: (B, L) bool - which positions are masked
294
+ attention_mask: (B, L) - attention mask for MLM
295
+ embed_weight: (V, H) - embedding weight matrix
296
+ mask_emb: (H,) - mask token embedding
297
+ temperature: scalar tensor - softmax temperature
298
+
299
+ Returns:
300
+ logits: (B, L, V) - output logits from this iteration
301
+ next_soft_embeds: (B, L, H) - soft embeddings for next iteration
302
+ """
303
+ # Blend: use soft_embeds at masked positions, base_embeds elsewhere
304
+ inputs_embeds = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds)
305
+
306
+ # Forward through base MLM
307
+ outputs = self.mlm(
308
+ inputs_embeds=inputs_embeds,
309
+ attention_mask=attention_mask,
310
+ position_ids=position_ids,
311
+ return_dict=True,
312
+ )
313
+ logits = outputs.logits
314
+
315
+ # Compute soft embeddings for next iteration (DIFFERENTIABLE - no detach!)
316
+ next_soft_embeds = base_embeds.clone()
317
+ if mask_pos.any():
318
+ masked_logits = logits[mask_pos] # (num_masked, V)
319
+
320
+ # Convert logits to mixing weights based on soft_embedding_method
321
+ if self.config.soft_embedding_method == "none":
322
+ # No normalization - use raw logits directly
323
+ weights = masked_logits # Differentiable!
324
+ elif self.config.soft_embedding_method == "l2_normalize":
325
+ # L2 normalize logits - removes softmax bottleneck for smoother gradients
326
+ weights = F.normalize(masked_logits, p=2, dim=-1) # Differentiable!
327
+ else:
328
+ # Default: softmax normalization
329
+ weights = F.softmax(masked_logits / temperature, dim=-1) # Differentiable!
330
+
331
+ soft_emb = weights @ embed_weight + mask_emb # Differentiable!
332
+
333
+ # Apply EMA blending with previous soft embeddings if enabled
334
+ ema_step = self.config.soft_embedding_ema_step
335
+ if ema_step < 1.0:
336
+ prev_soft_emb = soft_embeds[mask_pos] # Previous iteration's soft embeddings
337
+ soft_emb = (1.0 - ema_step) * prev_soft_emb + ema_step * soft_emb
338
+
339
+ next_soft_embeds[mask_pos] = soft_emb
340
+
341
+ return logits, next_soft_embeds
342
+
343
+ def _stable_softmax(self, logits: torch.Tensor, T: float = 1.0, dim: int = -1, eps: float = 1e-12) -> torch.Tensor:
344
+ """Numerically stable softmax with temperature T > 0."""
345
+ z = logits / max(T, eps)
346
+ z = z - z.max(dim=dim, keepdim=True).values # subtract max
347
+ z = torch.exp(z) # safe since z <= 0
348
+ z_sum = z.sum(dim=dim, keepdim=True)
349
+ return z / z_sum.clamp(min=eps)
350
+
351
+ def normalize(self, logits: torch.Tensor) -> torch.Tensor:
352
+ """Normalize logits -> mixing weights. Shape: (B, L, V) -> (B, L, V)"""
353
+ norm = self.config.normalization.lower()
354
+ T = self.config.temperature
355
+ V = logits.shape[-1]
356
+
357
+ if norm == "none":
358
+ return logits
359
+
360
+ if norm == "softmax":
361
+ return torch.softmax(logits / T, dim=-1)
362
+
363
+ if norm == "stable_softmax":
364
+ return self._stable_softmax(logits, T=T, dim=-1)
365
+
366
+ raise ValueError(f"Unknown normalization: {norm}")
367
+
368
+ def step_weight(self, t: int, T: int) -> float:
369
+ """Loss weight for step t of T."""
370
+ lw = self.config.loss_weight
371
+ if lw == "linear":
372
+ return (t + 1) / T
373
+ if lw == "uniform":
374
+ return 1.0
375
+ if lw == "last_1":
376
+ return 1.0 if t == T - 1 else 0.0
377
+ if lw == "last_2":
378
+ return 1.0 if T - t <= 2 else 0.0
379
+ raise ValueError(f"Unknown loss_weight: {lw}")
380
+
381
+ # ==================== CONVERGENCE SCHEDULE SYSTEM ====================
382
+ #
383
+ # The core idea: control WHEN each position is allowed to converge.
384
+ #
385
+ # Schedule types:
386
+ # - "linear": All positions converge at the same rate
387
+ # - "causal": Early positions converge first, late positions last
388
+ #
389
+ # Effects (mechanisms to enforce the schedule):
390
+ # - temperature: Raise temperature for positions not yet allowed to converge
391
+ # - entropy_floor: Force minimum entropy
392
+ # - entropy_target: Force exact entropy via bisection (ARChitects-style)
393
+ # - smear: Spread probability across neighboring positions
394
+ # - noise: Add Gaussian noise to logits
395
+ #
396
+ # Each effect uses per-position "convergence progress" (0=uncertain, 1=can converge)
397
+
398
+ def _compute_convergence_progress(
399
+ self,
400
+ iteration: int,
401
+ total_iterations: int,
402
+ seq_length: int,
403
+ mask_positions: torch.Tensor,
404
+ schedule: str = "linear",
405
+ causal_strength: float = 1.0,
406
+ device: torch.device = None,
407
+ dtype: torch.dtype = None,
408
+ ) -> torch.Tensor:
409
+ """
410
+ Compute per-position convergence progress based on schedule.
411
+
412
+ Args:
413
+ iteration: Current iteration (0-indexed)
414
+ total_iterations: Total number of iterations
415
+ seq_length: Full sequence length L
416
+ mask_positions: Position indices of masked tokens (num_masked,)
417
+ schedule: "linear" or "causal"
418
+ causal_strength: How much faster early positions converge (for causal schedule)
419
+
420
+ Returns:
421
+ progress: (num_masked,) tensor with values in [0, 1]
422
+ 0 = position should be maximally uncertain
423
+ 1 = position is allowed to fully converge
424
+ """
425
+ base_progress = iteration / max(total_iterations - 1, 1)
426
+
427
+ if schedule == "linear":
428
+ return torch.full(
429
+ (mask_positions.shape[0],),
430
+ base_progress,
431
+ device=device,
432
+ dtype=dtype
433
+ )
434
+
435
+ elif schedule == "causal":
436
+ position_factor = mask_positions.float() / max(seq_length - 1, 1)
437
+ effective_progress = base_progress * (1.0 + causal_strength * (1.0 - position_factor))
438
+ return effective_progress.clamp(0.0, 1.0)
439
+
440
+ else:
441
+ raise ValueError(f"Unknown schedule: {schedule}")
442
+
443
+ def _apply_temperature_effect(
444
+ self,
445
+ logits: torch.Tensor,
446
+ progress: torch.Tensor,
447
+ temperature_max: float,
448
+ ) -> torch.Tensor:
449
+ """
450
+ Apply per-position temperature scaling based on convergence progress.
451
+ Low progress = high temperature (uncertain), high progress = temperature 1.0.
452
+ """
453
+ if temperature_max <= 0:
454
+ return logits
455
+
456
+ temperature = 1.0 + temperature_max * (1.0 - progress)
457
+ temperature = temperature.unsqueeze(-1)
458
+
459
+ return logits / temperature
460
+
461
+ def _apply_entropy_floor_effect(
462
+ self,
463
+ probs: torch.Tensor,
464
+ progress: torch.Tensor,
465
+ entropy_floor_max: float,
466
+ ) -> torch.Tensor:
467
+ """
468
+ Ensure minimum entropy based on convergence progress.
469
+ Low progress = high entropy floor, high progress = no floor.
470
+
471
+ NOTE: This is a ONE-SIDED constraint (floor only).
472
+ """
473
+ if entropy_floor_max <= 0:
474
+ return probs
475
+
476
+ entropy_floor = entropy_floor_max * (1.0 - progress)
477
+
478
+ log_probs = torch.log(probs + 1e-10)
479
+ current_entropy = -(probs * log_probs).sum(dim=-1)
480
+
481
+ below_floor = current_entropy < entropy_floor
482
+
483
+ if not below_floor.any():
484
+ return probs
485
+
486
+ logits = torch.log(probs + 1e-10)
487
+
488
+ target_ratio = entropy_floor / (current_entropy + 1e-10)
489
+ temperature = torch.ones_like(current_entropy)
490
+ temperature[below_floor] = target_ratio[below_floor].clamp(1.0, 10.0)
491
+
492
+ scaled_probs = torch.softmax(logits / temperature.unsqueeze(-1), dim=-1)
493
+
494
+ result = probs.clone()
495
+ result[below_floor] = scaled_probs[below_floor]
496
+ return result
497
+
498
+ def _find_temperature_for_target_entropy(
499
+ self,
500
+ logits: torch.Tensor,
501
+ target_entropy: torch.Tensor,
502
+ tol: float = 1e-3,
503
+ max_iter: int = 32,
504
+ T_low: float = 1e-6,
505
+ T_high_init: float = 1.0,
506
+ max_T: float = 100.0,
507
+ ) -> torch.Tensor:
508
+ """
509
+ Find per-position temperatures that achieve exactly the target entropy.
510
+ Uses bisection search, adapted from ARChitects' implementation.
511
+
512
+ Args:
513
+ logits: Raw logits (num_positions, V)
514
+ target_entropy: Target entropy per position (num_positions,) or scalar
515
+ tol: Entropy tolerance for convergence
516
+ max_iter: Maximum bisection iterations
517
+ T_low: Minimum temperature (near-greedy)
518
+ T_high_init: Initial upper bound for search
519
+ max_T: Maximum allowed temperature
520
+
521
+ Returns:
522
+ temperatures: (num_positions,) temperatures that achieve target entropy
523
+ """
524
+ N, V = logits.shape
525
+ device, dtype = logits.device, logits.dtype
526
+ H_max = torch.log(torch.tensor(V, device=device, dtype=dtype))
527
+
528
+ if target_entropy.dim() == 0:
529
+ target = target_entropy.expand(N).clone()
530
+ else:
531
+ target = target_entropy.clone()
532
+ target = target.clamp(0.0, H_max)
533
+
534
+ def compute_entropy(logits_: torch.Tensor, temps: torch.Tensor) -> torch.Tensor:
535
+ temps = temps.unsqueeze(-1).clamp(min=T_low)
536
+ scaled = logits_ / temps
537
+ scaled = scaled - scaled.max(dim=-1, keepdim=True).values
538
+ probs = torch.softmax(scaled, dim=-1)
539
+ log_probs = torch.log(probs + 1e-12)
540
+ return -(probs * log_probs).sum(dim=-1)
541
+
542
+ lo = torch.full((N,), T_low, device=device, dtype=dtype)
543
+ hi = torch.full((N,), T_high_init, device=device, dtype=dtype)
544
+
545
+ H_lo = compute_entropy(logits, lo)
546
+
547
+ done_low = target <= (H_lo + tol)
548
+
549
+ H_hi = compute_entropy(logits, hi)
550
+ needs_expansion = (H_hi < target - tol) & ~done_low
551
+
552
+ for _ in range(100):
553
+ if not needs_expansion.any():
554
+ break
555
+ hi[needs_expansion] = (hi[needs_expansion] * 2.0).clamp(max=max_T)
556
+ H_hi[needs_expansion] = compute_entropy(
557
+ logits[needs_expansion], hi[needs_expansion]
558
+ )
559
+ needs_expansion = (H_hi < target - tol) & ~done_low & (hi < max_T - 1e-6)
560
+
561
+ can_bisect = ~done_low & (H_hi >= target - tol)
562
+
563
+ for _ in range(max_iter):
564
+ if not can_bisect.any():
565
+ break
566
+
567
+ mid = (lo + hi) / 2.0
568
+ H_mid = compute_entropy(logits, mid)
569
+
570
+ too_low = (H_mid < target) & can_bisect
571
+ lo[too_low] = mid[too_low]
572
+ hi[~too_low & can_bisect] = mid[~too_low & can_bisect]
573
+
574
+ converged = (hi - lo) <= tol * mid.clamp(min=1.0)
575
+ can_bisect = can_bisect & ~converged
576
+
577
+ temps = torch.zeros(N, device=device, dtype=dtype)
578
+ temps[done_low] = T_low
579
+ temps[~done_low] = (lo[~done_low] + hi[~done_low]) / 2.0
580
+
581
+ return temps
582
+
583
+ def _apply_target_entropy_effect(
584
+ self,
585
+ logits: torch.Tensor,
586
+ progress: torch.Tensor,
587
+ entropy_target_max: float,
588
+ entropy_target_min: float = 0.0,
589
+ ) -> torch.Tensor:
590
+ """
591
+ Adjust temperature to achieve EXACTLY the target entropy per position.
592
+ This is a TWO-SIDED constraint: both raises and lowers entropy as needed.
593
+
594
+ Args:
595
+ logits: Raw logits (num_masked, V)
596
+ progress: Per-position convergence progress (num_masked,)
597
+ entropy_target_max: Target entropy at progress=0
598
+ entropy_target_min: Target entropy at progress=1 (usually ~0)
599
+
600
+ Returns:
601
+ probs: Probabilities with entropy matching targets
602
+ """
603
+ if entropy_target_max <= 0:
604
+ return torch.softmax(logits, dim=-1)
605
+
606
+ target_entropy = entropy_target_max * (1.0 - progress) + entropy_target_min * progress
607
+
608
+ temps = self._find_temperature_for_target_entropy(logits, target_entropy)
609
+
610
+ temps = temps.unsqueeze(-1).clamp(min=1e-6)
611
+ return torch.softmax(logits / temps, dim=-1)
612
+
613
+ def _apply_smear_effect(
614
+ self,
615
+ probs: torch.Tensor,
616
+ mask_pos: torch.Tensor,
617
+ progress_full: torch.Tensor,
618
+ smear_sigma_max: float,
619
+ ) -> torch.Tensor:
620
+ """
621
+ Apply positional smearing with per-position sigma based on progress.
622
+ Low progress = high smearing, high progress = no smearing.
623
+
624
+ Note: This operates on full (B, L, V) tensor because smearing mixes across positions.
625
+ """
626
+ if smear_sigma_max <= 0:
627
+ return probs
628
+
629
+ B, L, V = probs.shape
630
+
631
+ sigma_per_pos = smear_sigma_max * (1.0 - progress_full)
632
+
633
+ avg_sigma = sigma_per_pos[mask_pos].mean().item()
634
+
635
+ if avg_sigma < 0.1:
636
+ return probs
637
+
638
+ positions = torch.arange(L, device=probs.device, dtype=probs.dtype)
639
+ diff = positions.unsqueeze(0) - positions.unsqueeze(1)
640
+ kernel = torch.exp(-0.5 * (diff / avg_sigma) ** 2)
641
+ kernel = kernel / kernel.sum(dim=1, keepdim=True)
642
+
643
+ smeared = torch.einsum('ij,bjv->biv', kernel, probs)
644
+ smeared = smeared / smeared.sum(dim=-1, keepdim=True).clamp(min=1e-10)
645
+
646
+ blend = progress_full.unsqueeze(-1)
647
+ result = blend * probs + (1 - blend) * smeared
648
+
649
+ output = probs.clone()
650
+ output[mask_pos] = result[mask_pos]
651
+ return output
652
+
653
+ def _apply_noise_effect(
654
+ self,
655
+ logits: torch.Tensor,
656
+ progress: torch.Tensor,
657
+ noise_std_max: float,
658
+ ) -> torch.Tensor:
659
+ """
660
+ Add Gaussian noise to logits based on convergence progress.
661
+ Low progress = high noise, high progress = no noise.
662
+ """
663
+ if noise_std_max <= 0:
664
+ return logits
665
+
666
+ noise_std = noise_std_max * (1.0 - progress)
667
+ noise_std = noise_std.unsqueeze(-1)
668
+
669
+ noise = torch.randn_like(logits) * noise_std
670
+ return logits + noise
671
+
672
+ def _apply_iteration_rope(
673
+ self,
674
+ embeds: torch.Tensor,
675
+ iteration: int,
676
+ total_iterations: int,
677
+ dim_fraction: float = 0.25,
678
+ base: float = 10000.0,
679
+ ) -> torch.Tensor:
680
+ """
681
+ Apply rotary embedding based on iteration progress.
682
+ Uses a subset of dimensions to avoid interfering with position RoPE.
683
+ """
684
+ if dim_fraction <= 0:
685
+ return embeds
686
+
687
+ H = embeds.shape[-1]
688
+ rot_dim = int(H * dim_fraction)
689
+ rot_dim = rot_dim - (rot_dim % 2)
690
+
691
+ if rot_dim < 2:
692
+ return embeds
693
+
694
+ progress = iteration / max(total_iterations - 1, 1)
695
+
696
+ inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2, device=embeds.device, dtype=embeds.dtype) / rot_dim))
697
+ angles = progress * inv_freq * 3.14159
698
+ cos, sin = torch.cos(angles), torch.sin(angles)
699
+
700
+ if embeds.dim() == 2:
701
+ cos, sin = cos.unsqueeze(0), sin.unsqueeze(0)
702
+ elif embeds.dim() == 3:
703
+ cos = cos.unsqueeze(0).unsqueeze(0)
704
+ sin = sin.unsqueeze(0).unsqueeze(0)
705
+
706
+ embeds_out = embeds.clone()
707
+ x1, x2 = embeds[..., -rot_dim::2], embeds[..., -rot_dim+1::2]
708
+ embeds_out[..., -rot_dim::2] = x1 * cos - x2 * sin
709
+ embeds_out[..., -rot_dim+1::2] = x1 * sin + x2 * cos
710
+
711
+ return embeds_out
712
+
713
+ # ==================== FLOW MATCHING ====================
714
+
715
+ def _sample_flow_matching_t(self, num_tokens: int, device: torch.device) -> torch.Tensor:
716
+ """Sample per-token time levels for flow matching.
717
+
718
+ Returns:
719
+ t: (num_tokens,) tensor of time levels in [t_min, t_max]
720
+ """
721
+ dist = self.config.flow_matching_t_distribution
722
+ if dist == "logit_normal":
723
+ z = torch.randn(num_tokens, device=device)
724
+ z = z * self.config.flow_matching_t_logit_std + self.config.flow_matching_t_logit_mean
725
+ t = torch.sigmoid(z)
726
+ elif dist == "uniform":
727
+ t = torch.empty(num_tokens, device=device).uniform_(0, 1)
728
+ else:
729
+ raise ValueError(f"Unknown flow_matching_t_distribution: {dist}")
730
+ return t.clamp(self.config.flow_matching_t_min, self.config.flow_matching_t_max)
731
+
732
def compute_flow_matching_distillation_loss(
    self,
    input_ids: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor,
    flow_noise_embed: torch.Tensor,
    flow_t: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
) -> SelfDistillationOutput:
    """
    Conditional-flow-matching distillation step.

    The teacher saw each masked token's state at interpolation time t; here
    the student sees the SAME noise/target interpolation at an earlier
    (noisier) time s ~ U(0, t) and is trained to match the teacher's
    predicted distribution at the masked positions.

    Args:
        input_ids: Token ids with [MASK] at the positions to predict.
        teacher_logits: Logits from the standard forward pass (detached here).
        labels: Ground-truth ids at masked positions, -100 elsewhere.
        flow_noise_embed: (num_masked, H) noise embeddings from the forward.
        flow_t: (num_masked,) per-token teacher time levels from the forward.
        attention_mask: Optional standard attention mask.
        position_ids: Optional position ids for the base model.

    Returns:
        SelfDistillationOutput; degradation_temperature carries the mean
        (t - s) time gap.
    """
    mask_id = self.config.mask_token_id
    is_masked = input_ids == mask_id  # (B, L)
    device = input_ids.device
    n_masked = is_masked.sum().item()

    # Degenerate batch: nothing masked — return a grad-capable zero loss.
    if n_masked == 0:
        zero = torch.tensor(0.0, device=device, requires_grad=True)
        filler = torch.zeros(1, device=device)
        return SelfDistillationOutput(zero, filler, filler, 0.0, 0.0, 0.0, 1.0)

    teacher_logits = teacher_logits.detach()

    embed_weight = self.embed_weight
    mask_vec = embed_weight[mask_id]  # (H,)
    base_embeds = self.get_input_embeddings()(input_ids)  # (B, L, H)

    # Endpoint embeddings the interpolation path converges to.
    gold_ids = labels[is_masked]  # (n_masked,)
    gold_embed = embed_weight[gold_ids]  # (n_masked, H)

    # Student time s in [0, t): strictly noisier than the teacher's state.
    s_times = flow_t * torch.rand(n_masked, device=device)  # (n_masked,)
    s = s_times.unsqueeze(-1).to(base_embeds.dtype)  # (n_masked, 1)
    interp = (1 - s) * flow_noise_embed + s * gold_embed

    # Optionally fade the [MASK] offset out as s -> 1.
    if self.config.flow_matching_mask_scale:
        noisy_masked = interp + (1 - s) * mask_vec
    else:
        noisy_masked = interp + mask_vec

    # Assemble the student's input (detached — gradient flows only through
    # the student's own forward pass below).
    student_embeds = base_embeds.detach().clone()
    student_embeds[is_masked] = noisy_masked.detach()

    student_inputs = torch.where(
        is_masked.unsqueeze(-1), student_embeds, base_embeds.detach()
    )

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=base_embeds.dtype)

    student_logits = self.mlm(
        inputs_embeds=student_inputs,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
    ).logits  # (B, L, V) — has gradient

    # KL(teacher || student) over masked positions only.
    teach = teacher_logits[is_masked]  # (n_masked, V)
    stud = student_logits[is_masked]  # (n_masked, V)

    teacher_probs = F.softmax(teach, dim=-1)
    student_log_probs = F.log_softmax(stud, dim=-1)

    kl_loss = F.kl_div(
        student_log_probs,
        teacher_probs,
        reduction="batchmean",
    )

    # Monitoring-only diagnostics.
    with torch.no_grad():
        t_entropy = -(teacher_probs * torch.log(teacher_probs + 1e-10)).sum(dim=-1).mean().item()

        s_probs = F.softmax(stud.detach(), dim=-1)
        s_entropy = -(s_probs * torch.log(s_probs + 1e-10)).sum(dim=-1).mean().item()

        agreement = (teach.argmax(dim=-1) == stud.detach().argmax(dim=-1)).float().mean().item()

    gap = (flow_t - s_times).mean().item()

    return SelfDistillationOutput(
        loss=kl_loss,
        teacher_logits=teacher_logits,
        student_logits=student_logits,
        degradation_temperature=gap,
        teacher_entropy=t_entropy,
        student_entropy=s_entropy,
        agreement_rate=agreement,
    )
847
+
848
+ # ==================== SELF-DISTILLATION (legacy) ====================
849
+
850
def compute_self_distillation_loss(
    self,
    input_ids: torch.Tensor,
    teacher_logits: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    temperature_min: Optional[float] = None,
    temperature_max: Optional[float] = None,
    temperature_distribution: Optional[str] = None,
) -> SelfDistillationOutput:
    """
    Consistency distillation across input-degradation levels (CFM-style).

    Pipeline:
      1. Detach the teacher logits from a standard forward pass.
      2. Soften them with an independently sampled per-token temperature and
         build degraded soft embeddings at the masked positions.
      3. Run the student forward from those degraded inputs (has grad).
      4. Penalize KL(teacher || student) at the masked positions.

    Args:
        input_ids: Input with [MASK] tokens at positions to predict.
        teacher_logits: Pre-computed teacher logits (detached here).
        attention_mask: Optional standard attention mask.
        position_ids: Optional position ids for the base model.
        temperature_min: Min degradation temperature (None -> config).
        temperature_max: Max degradation temperature (None -> config).
        temperature_distribution: "log_uniform" or "uniform" (None -> config).

    Returns:
        SelfDistillationOutput with loss, both logits, the mean degradation
        temperature, teacher/student entropies, and top-1 agreement rate.
    """
    # Resolve defaults from config.
    cfg = self.config
    if temperature_min is None:
        temperature_min = cfg.self_distillation_temperature_min
    if temperature_max is None:
        temperature_max = cfg.self_distillation_temperature_max
    if temperature_distribution is None:
        temperature_distribution = cfg.self_distillation_temperature_distribution

    mask_id = cfg.mask_token_id
    is_masked = input_ids == mask_id  # (B, L)
    device = input_ids.device
    n_masked = is_masked.sum().item()

    # Degenerate batch: nothing masked — return a grad-capable zero loss.
    if n_masked == 0:
        zero = torch.tensor(0.0, device=device, requires_grad=True)
        filler = torch.zeros(1, device=device)
        return SelfDistillationOutput(zero, filler, filler, 1.0, 0.0, 0.0, 1.0)

    teacher_logits = teacher_logits.detach()

    embed_weight = self.embed_weight
    mask_vec = embed_weight[mask_id]  # (H,)
    base_embeds = self.get_input_embeddings()(input_ids)  # (B, L, H)

    # --- 1. Per-token degradation temperatures, sampled independently ---
    if temperature_distribution == "log_uniform":
        lo = torch.tensor(temperature_min, device=device).log()
        hi = torch.tensor(temperature_max, device=device).log()
        sampled = torch.empty(n_masked, device=device).uniform_(lo.item(), hi.item()).exp()
    elif temperature_distribution == "uniform":
        sampled = torch.empty(n_masked, device=device).uniform_(
            temperature_min, temperature_max
        )
    else:
        raise ValueError(f"Unknown temperature distribution: {temperature_distribution}")

    mean_temp = sampled.mean().item()

    # --- 2. Degraded soft embeddings from temperature-softened teacher ---
    teach_masked = teacher_logits[is_masked]  # (n_masked, V)
    soft_probs = F.softmax(teach_masked / sampled.unsqueeze(-1), dim=-1).to(embed_weight.dtype)
    degraded = soft_probs @ embed_weight + mask_vec

    degraded_embeds = base_embeds.clone()
    degraded_embeds[is_masked] = degraded
    degraded_embeds = degraded_embeds.detach()

    # --- 3. Student forward from the degraded inputs ---
    student_inputs = torch.where(
        is_masked.unsqueeze(-1), degraded_embeds, base_embeds.detach()
    )

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=base_embeds.dtype)

    student_logits = self.mlm(
        inputs_embeds=student_inputs,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
    ).logits  # (B, L, V) — has gradient!

    # --- 4. KL(teacher || student) on masked positions ---
    teach = teacher_logits[is_masked]  # (n_masked, V)
    stud = student_logits[is_masked]  # (n_masked, V)

    teacher_probs = F.softmax(teach, dim=-1)
    student_log_probs = F.log_softmax(stud, dim=-1)

    # KL(teacher || student) = sum teacher * (log_teacher - log_student)
    kl_loss = F.kl_div(
        student_log_probs,
        teacher_probs,
        reduction="batchmean",
    )

    # --- 5. Monitoring-only diagnostics ---
    with torch.no_grad():
        t_entropy = -(teacher_probs * torch.log(teacher_probs + 1e-10)).sum(dim=-1).mean().item()

        s_probs = F.softmax(stud.detach(), dim=-1)
        s_entropy = -(s_probs * torch.log(s_probs + 1e-10)).sum(dim=-1).mean().item()

        agreement = (teach.argmax(dim=-1) == stud.detach().argmax(dim=-1)).float().mean().item()

    return SelfDistillationOutput(
        loss=kl_loss,
        teacher_logits=teacher_logits,
        student_logits=student_logits,
        degradation_temperature=mean_temp,
        teacher_entropy=t_entropy,
        student_entropy=s_entropy,
        agreement_rate=agreement,
    )
985
+
986
+ # ==================== MAIN SOFT EMBEDDING COMPUTATION ====================
987
+
988
@torch.no_grad()
def _compute_next_soft_embeds(
    self,
    logits: torch.Tensor,
    mask_pos: torch.Tensor,
    base_embeds: torch.Tensor,
    prev_soft_embeds: Optional[torch.Tensor] = None,
    iteration: int = 0,
    total_iterations: int = 1,
    # === Schedule parameters (default to config values) ===
    schedule: Optional[str] = None,
    causal_strength: Optional[float] = None,
    # === Effect parameters (default to config values) ===
    temperature_max: Optional[float] = None,
    entropy_target_max: Optional[float] = None,
    entropy_floor_max: Optional[float] = None,
    smear_sigma_max: Optional[float] = None,
    noise_std_max: Optional[float] = None,
    iteration_rope_dim_fraction: Optional[float] = None,
) -> torch.Tensor:
    """
    Turn current-iteration logits into the next iteration's input embeddings.

    Masked positions receive a mixture of token embeddings weighted by the
    (schedule-shaped) output distribution, offset by the [MASK] embedding;
    unmasked positions keep their base embeddings. Convergence-schedule
    effects (temperature, exact-entropy targeting, entropy floor, smearing,
    noise, iteration RoPE) modulate how confident the mixture is allowed to
    be at each position. Any parameter left as None falls back to config.

    Args:
        logits: Output logits from the current iteration, (B, L, V).
        mask_pos: Boolean mask of [MASK] positions, (B, L).
        base_embeds: Base token embeddings for non-masked positions, (B, L, H).
        prev_soft_embeds: Previous soft embeddings for EMA blending, or None.
        iteration: Current iteration index (0-based).
        total_iterations: Total number of iterations.

    Returns:
        Detached (B, L, H) embeddings for the next forward pass.
    """
    cfg = self.config
    schedule = cfg.schedule if schedule is None else schedule
    causal_strength = cfg.causal_strength if causal_strength is None else causal_strength
    temperature_max = cfg.temperature_max if temperature_max is None else temperature_max
    entropy_target_max = cfg.entropy_target_max if entropy_target_max is None else entropy_target_max
    entropy_floor_max = cfg.entropy_floor_max if entropy_floor_max is None else entropy_floor_max
    smear_sigma_max = cfg.smear_sigma_max if smear_sigma_max is None else smear_sigma_max
    noise_std_max = cfg.noise_std_max if noise_std_max is None else noise_std_max
    if iteration_rope_dim_fraction is None:
        iteration_rope_dim_fraction = cfg.iteration_rope_dim_fraction

    next_embeds = base_embeds.clone()

    if not mask_pos.any():
        return next_embeds.detach()

    B, L, V = logits.shape
    device, dtype = logits.device, logits.dtype

    any_effect = any(
        v > 0
        for v in (
            temperature_max,
            entropy_target_max,
            entropy_floor_max,
            smear_sigma_max,
            noise_std_max,
            iteration_rope_dim_fraction,
        )
    )

    embed_weight = self.embed_weight
    mask_emb = embed_weight[cfg.mask_token_id]
    ema_step = cfg.soft_embedding_ema_step

    if not any_effect:
        # Fast path: logits -> mixing weights -> soft embeddings, no schedule.
        raw = logits[mask_pos]

        if cfg.soft_embedding_method == "none":
            mix = raw
        elif cfg.soft_embedding_method == "l2_normalize":
            mix = F.normalize(raw, p=2, dim=-1)
        else:
            mix = self.normalize(raw)

        soft = mix @ embed_weight + mask_emb

        # EMA blend with the previous iteration's soft embeddings, if enabled.
        if ema_step < 1.0 and prev_soft_embeds is not None:
            soft = (1.0 - ema_step) * prev_soft_embeds[mask_pos] + ema_step * soft

        next_embeds[mask_pos] = soft
        return next_embeds.detach()

    # ---- Step 1: per-position convergence progress at masked positions ----
    _, mask_cols = torch.where(mask_pos)

    progress = self._compute_convergence_progress(
        iteration=iteration,
        total_iterations=total_iterations,
        seq_length=L,
        mask_positions=mask_cols,
        schedule=schedule,
        causal_strength=causal_strength,
        device=device,
        dtype=dtype,
    )

    # Smearing mixes across the whole sequence, so it needs (B, L) progress.
    if smear_sigma_max > 0:
        progress_full = self._compute_convergence_progress(
            iteration=iteration,
            total_iterations=total_iterations,
            seq_length=L,
            mask_positions=torch.arange(L, device=device, dtype=dtype),
            schedule=schedule,
            causal_strength=causal_strength,
            device=device,
            dtype=dtype,
        ).unsqueeze(0).expand(B, -1)

    # ---- Step 2: smearing on the full (B, L, V) tensor ----
    full_probs = self.normalize(logits)

    if smear_sigma_max > 0:
        full_probs = self._apply_smear_effect(
            full_probs, mask_pos, progress_full, smear_sigma_max
        )

    # ---- Step 3: restrict to masked positions ----
    masked_logits = logits[mask_pos]
    masked_probs = full_probs[mask_pos]

    # ---- Step 4: temperature (skipped when exact entropy targeting is on) ----
    # NOTE(review): this recomputes probs from logits, so it discards any
    # smearing applied above when both effects are enabled — confirm intended.
    if temperature_max > 0 and entropy_target_max <= 0:
        masked_logits = self._apply_temperature_effect(
            masked_logits, progress, temperature_max
        )
        masked_probs = torch.softmax(masked_logits, dim=-1)

    # ---- Step 5: additive Gaussian noise in log-space ----
    if noise_std_max > 0:
        noisy = self._apply_noise_effect(
            torch.log(masked_probs + 1e-10), progress, noise_std_max
        )
        masked_probs = torch.softmax(noisy, dim=-1)

    # ---- Step 6: entropy control (exact target wins over one-sided floor) ----
    if entropy_target_max > 0:
        masked_probs = self._apply_target_entropy_effect(
            masked_logits, progress, entropy_target_max
        )
    elif entropy_floor_max > 0:
        masked_probs = self._apply_entropy_floor_effect(
            masked_probs, progress, entropy_floor_max
        )

    # ---- Step 7: soft embeddings from mixing weights ----
    if cfg.soft_embedding_method == "none":
        # Raw logits used directly as mixing weights.
        mix = masked_logits
    elif cfg.soft_embedding_method == "l2_normalize":
        # Bypasses the softmax-based effects above by design.
        mix = F.normalize(masked_logits, p=2, dim=-1)
    else:
        mix = masked_probs

    soft = mix @ embed_weight + mask_emb

    # ---- Step 8: iteration RoPE on the soft embeddings ----
    if iteration_rope_dim_fraction > 0:
        soft = self._apply_iteration_rope(
            soft, iteration, total_iterations, iteration_rope_dim_fraction
        )

    # ---- Step 8.5: EMA blend with the previous soft embeddings ----
    if ema_step < 1.0 and prev_soft_embeds is not None:
        soft = (1.0 - ema_step) * prev_soft_embeds[mask_pos] + ema_step * soft

    # ---- Step 9: write back and detach ----
    next_embeds[mask_pos] = soft
    return next_embeds.detach()
1189
+
1190
@torch.no_grad()
def _compute_iteration_metrics(
    self, logits: torch.Tensor, labels: torch.Tensor
) -> IterationMetrics:
    """
    Compute token-level AND sequence-level metrics for a single iteration.

    Everything is reduced to Python scalars on CPU so no large tensors are
    retained — these metrics are for monitoring/logging only, never the loss.

    Token-level metrics:
        accuracy: fraction of correct argmax predictions over valid tokens.
        entropy: mean per-token predictive entropy.
        softmax_ce: mean cross-entropy against the labels.

    Sequence-level metrics:
        full_sequence_accuracy: fraction of sequences where ALL valid tokens
            are predicted correctly.
        min_sequence_confidence: mean (over sequences with valid tokens) of
            the minimum top-1 probability within each sequence.

    Args:
        logits: Prediction logits, (B, L, V).
        labels: Targets, (B, L); -100 marks positions to ignore.
    """
    # Move to CPU to avoid GPU OOM — float32 is sufficient for metrics.
    logits = logits.detach().cpu().float()
    target_labels = labels.detach().cpu().contiguous()
    mask = target_labels != -100

    if mask.sum() == 0:
        # No supervised tokens anywhere in the batch.
        return IterationMetrics(
            accuracy=0.0,
            entropy=0.0,
            softmax_ce=0.0,
            full_sequence_accuracy=0.0,
            min_sequence_confidence=0.0,
        )

    logits = logits.contiguous()
    predictions = logits.argmax(dim=-1)
    correct = (predictions == target_labels) & mask

    # ===== TOKEN-LEVEL METRICS =====

    accuracy = (correct.sum() / mask.sum()).item()

    valid_logits = logits[mask]
    valid_labels = target_labels[mask]

    # log_softmax keeps the entropy computation numerically stable.
    log_probs = torch.nn.functional.log_softmax(valid_logits, dim=-1)
    probs = torch.exp(log_probs)
    entropy = -(probs * log_probs).sum(dim=-1).mean().item()

    softmax_ce = torch.nn.functional.cross_entropy(
        valid_logits, valid_labels, reduction="mean"
    ).item()

    # ===== SEQUENCE-LEVEL METRICS =====

    sequences_with_tokens = mask.any(dim=1)  # (B,)
    num_valid_sequences = sequences_with_tokens.sum().item()

    if num_valid_sequences == 0:
        return IterationMetrics(
            accuracy=accuracy,
            entropy=entropy,
            softmax_ce=softmax_ce,
            full_sequence_accuracy=0.0,
            min_sequence_confidence=0.0,
        )

    # A sequence counts as correct only if every valid token is correct.
    num_correct_per_seq = correct.sum(dim=1)  # (B,)
    num_tokens_per_seq = mask.sum(dim=1)  # (B,)
    all_correct = (num_correct_per_seq == num_tokens_per_seq) & sequences_with_tokens
    full_seq_accuracy = (all_correct.sum() / num_valid_sequences).item()

    # Min sequence confidence, vectorized (replaces a Python per-sequence
    # loop): fill ignored positions with +inf so they never win the min,
    # then average over sequences that actually contain valid tokens.
    top1_confidence = torch.softmax(logits, dim=-1).max(dim=-1).values  # (B, L)
    per_seq_min = top1_confidence.masked_fill(~mask, float("inf")).min(dim=1).values
    min_seq_conf = per_seq_min[sequences_with_tokens].mean().item()

    return IterationMetrics(
        accuracy=accuracy,
        entropy=entropy,
        softmax_ce=softmax_ce,
        full_sequence_accuracy=full_seq_accuracy,
        min_sequence_confidence=min_seq_conf,
    )
1286
+
1287
    def _single_iteration(
        self,
        t: int,
        T: int,
        soft_embeds: torch.Tensor,
        base_embeds: torch.Tensor,
        mask_pos: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        labels: Optional[torch.Tensor],
        compute_metrics: bool,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[IterationMetrics]]:
        """
        Execute a single iteration of recursive refinement.

        Args:
            t: Current iteration index (0 to T-1)
            T: Total number of iterations
            soft_embeds: Soft embeddings for mask positions
            base_embeds: Base token embeddings from input_ids
            mask_pos: Boolean mask of [MASK] positions (B, L)
            attention_mask: Attention mask for MLM
            labels: Target labels for loss computation
            compute_metrics: Whether to compute iteration metrics
            position_ids: Optional position IDs forwarded to the base MLM
            **kwargs: Extra keyword arguments forwarded to the base MLM

        Returns:
            logits: Output logits from MLM (B, L, V)
            weighted_loss: Loss weighted by step_weight(t, T), or None if no labels
            metrics: IterationMetrics, or None if not requested
        """
        # Blend soft embeddings (at mask positions) with base embeddings (at non-mask positions)
        inputs_embeds = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds)

        # Forward through base MLM
        outputs = self.mlm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            labels=labels,
            return_dict=True,
            **kwargs,
        )

        # Compute weighted loss for this iteration
        weighted_loss = outputs.loss
        if labels is not None:
            if weighted_loss is None:
                # Base model doesn't compute loss (e.g., LLaDA) - compute it ourselves
                # Only compute loss on MASKED positions (MDLM training).
                # NOTE(review): this fallback restricts loss to mask_pos, while a base
                # model that does return a loss may compute it over all labeled
                # positions — the two paths are not guaranteed equivalent.
                masked_logits = outputs.logits[mask_pos]  # (num_masked, V)
                masked_labels = labels[mask_pos]  # (num_masked,)
                loss_fct = CrossEntropyLoss()  # -100 index = padding token
                weighted_loss = loss_fct(masked_logits, masked_labels)
            # Scale this iteration's loss by its schedule weight (see step_weight).
            weighted_loss *= self.step_weight(t, T)

        # Compute iteration metrics if requested
        metrics = None
        if compute_metrics and labels is not None:
            metrics = self._compute_iteration_metrics(outputs.logits, labels)

        return outputs.logits, weighted_loss, metrics
1349
+
1350
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        num_recursions: Optional[int] = None,
        compute_iteration_metrics: bool = False,
        use_recursion_checkpointing: Optional[bool] = None,
        # Parameters for single-iteration training mode (DEPRECATED)
        prev_soft_embeds: Optional[torch.Tensor] = None,
        run_set_iteration: Optional[int] = None,
        # === Convergence schedule parameters (None = use config defaults) ===
        schedule: Optional[str] = None,
        causal_strength: Optional[float] = None,
        # === Effect parameters (None = use config defaults) ===
        temperature_max: Optional[float] = None,
        entropy_target_max: Optional[float] = None,
        entropy_floor_max: Optional[float] = None,
        smear_sigma_max: Optional[float] = None,
        noise_std_max: Optional[float] = None,
        iteration_rope_dim_fraction: Optional[float] = None,
        **kwargs,
    ) -> RecursiveMaskedLMOutput:
        """
        Forward with recursive refinement.

        Supports three modes:
        1. Checkpointed mode (default): Run all T recursions with gradient checkpointing.
           Gradients flow through the entire chain; activations recomputed during backward.
        2. Non-checkpointed mode (use_recursion_checkpointing=False): Store all activations.
           Faster backward but higher memory.
        3. Single-iteration mode (DEPRECATED - run_set_iteration is not None): Run only one
           iteration. Use use_recursion_checkpointing=True instead.

        Loss Weighting (config.loss_weight):
            "last_1": Only final iteration loss (enables learning convergence behavior)
            "last_2": Last 2 iterations
            "linear": All iterations, linearly weighted (default)
            "uniform": All iterations, uniformly weighted

        Recursion Checkpointing:
            use_recursion_checkpointing: Enable gradient checkpointing for iterations.
                True = checkpoint each iteration, recompute during backward (default).
                False = store all activations (higher memory, faster backward).

        Convergence Schedule Parameters:
            All schedule/effect parameters default to their config values if not specified.
            Pass explicit values to override config for this forward pass.

            schedule: "linear" or "causal" - controls when positions can converge
            causal_strength: How much faster early positions converge (causal only)
            temperature_max: Max temperature boost for uncertain positions
            entropy_target_max: Target entropy at progress=0 (two-sided, recommended)
            entropy_floor_max: Min entropy floor (one-sided)
            smear_sigma_max: Max Gaussian sigma for position smearing
            noise_std_max: Max std of Gaussian noise on logits
            iteration_rope_dim_fraction: Fraction of dims for iteration RoPE
        """
        B, L = input_ids.shape
        # NOTE(review): V is unused in this scope — it is recomputed inside the
        # flow-matching branch below.
        V = self.embed_weight.shape[0]
        mask_id = self.config.mask_token_id

        if mask_id is None:
            raise ValueError("mask_token_id must be set")

        # Resolve config default for recursion checkpointing
        use_recursion_checkpointing = (
            use_recursion_checkpointing
            if use_recursion_checkpointing is not None
            else self.config.use_recursion_checkpointing
        )

        mask_pos = (input_ids == mask_id)  # (B, L)
        base_embeds = self.get_input_embeddings()(input_ids)  # (B, L, H)
        T = num_recursions or self.config.num_recursions
        # Normalizer for the deprecated single-iteration path's loss below;
        # unused by the checkpointed path (trainer normalizes there).
        weight_sum = sum(self.step_weight(i, T) for i in range(T))

        # Bundle schedule kwargs to pass to _compute_next_soft_embeds.
        # NOTE(review): these are only consumed by the deprecated single-iteration
        # path; the checkpointed loop does not forward them to
        # _single_iteration_checkpointable — confirm that is intended.
        schedule_kwargs = dict(
            schedule=schedule,
            causal_strength=causal_strength,
            temperature_max=temperature_max,
            entropy_target_max=entropy_target_max,
            entropy_floor_max=entropy_floor_max,
            smear_sigma_max=smear_sigma_max,
            noise_std_max=noise_std_max,
            iteration_rope_dim_fraction=iteration_rope_dim_fraction,
        )

        # ===== SINGLE ITERATION MODE (DEPRECATED) =====
        if run_set_iteration is not None:
            warnings.warn(
                "run_set_iteration is deprecated. Use use_recursion_checkpointing=True instead, "
                "which provides proper gradient flow through all iterations.",
                DeprecationWarning,
                stacklevel=2,
            )
            t = run_set_iteration

            # Get soft embeddings for this iteration
            if t == 0:
                # t=0: Uniform prior = average embedding (equivalent to softmax(zeros) @ embed_weight)
                # We compute this efficiently via embed_weight.mean() rather than creating large zero tensors
                soft_embeds = base_embeds.clone()
                if mask_pos.any():
                    avg_embed = self.embed_weight.mean(dim=0)  # (H,) - mean over all V tokens
                    mask_emb = self.embed_weight[mask_id]
                    soft_embeds[mask_pos] = avg_embed + mask_emb
            else:
                if prev_soft_embeds is None:
                    raise ValueError(f"prev_soft_embeds must be provided for iteration {t}")
                soft_embeds = prev_soft_embeds

            logits, weighted_loss, metrics = self._single_iteration(
                t, T, soft_embeds, base_embeds, mask_pos,
                attention_mask, labels, compute_iteration_metrics,
                position_ids=position_ids, **kwargs
            )

            # Normalize loss by total weight sum
            loss = weighted_loss / weight_sum if weighted_loss is not None else None

            # Compute soft embeddings for next iteration (if not last)
            next_soft_embeds = None
            if t < T - 1:
                next_soft_embeds = self._compute_next_soft_embeds(
                    logits, mask_pos, base_embeds,
                    iteration=t,
                    total_iterations=T,
                    **schedule_kwargs,
                )

            return RecursiveMaskedLMOutput(
                loss=loss,
                logits=logits,
                next_soft_embeds=next_soft_embeds,
                iteration_metrics={t: metrics} if metrics is not None else None,
            )

        # ===== CHECKPOINTED MODE (gradient flow through all iterations) =====
        embed_weight = self.embed_weight
        mask_emb = embed_weight[mask_id]  # (H,)

        # Temperature must be a tensor for checkpointing (checkpoint requires tensor inputs)
        temperature = torch.tensor(
            self.config.temperature,
            device=input_ids.device,
            dtype=base_embeds.dtype,
        )

        # Ensure attention_mask is a tensor (required for checkpointing)
        if attention_mask is None:
            attention_mask = torch.ones(B, L, device=input_ids.device, dtype=base_embeds.dtype)

        # Initialize soft embeddings for masked positions
        soft_embeds = base_embeds.clone()
        flow_noise_embed = None
        flow_t_per_token = None

        if self.config.flow_matching_enabled and self.training and labels is not None and mask_pos.any():
            # Flow matching: interpolate between random noise and target on the simplex
            num_masked = mask_pos.sum().item()
            V = embed_weight.shape[0]
            device = input_ids.device

            # Sample per-token time levels (logit-normal by default)
            flow_t_per_token = self._sample_flow_matching_t(num_masked, device)

            # Random noise embedding: sample on simplex, project to H-dim
            z = torch.randn(num_masked, V, device=device, dtype=base_embeds.dtype)
            p_noise = F.softmax(z * self.config.flow_matching_noise_scale, dim=-1).to(base_embeds.dtype)
            flow_noise_embed = p_noise @ embed_weight  # (num_masked, H) — compact

            # Target embedding from labels
            target_ids = labels[mask_pos]  # original token IDs at masked positions
            target_embed = embed_weight[target_ids]  # (num_masked, H)

            # Interpolate in embedding space
            t_col = flow_t_per_token.unsqueeze(-1).to(base_embeds.dtype)  # (num_masked, 1)
            interp_embed = (1 - t_col) * flow_noise_embed + t_col * target_embed

            # Add mask signal (binary or scaled)
            if self.config.flow_matching_mask_scale:
                soft_embeds[mask_pos] = interp_embed + (1 - t_col) * mask_emb
            else:
                soft_embeds[mask_pos] = interp_embed + mask_emb
        elif mask_pos.any():
            # Standard uniform prior (average embedding + mask signal)
            avg_embed = embed_weight.mean(dim=0)  # (H,)
            soft_embeds[mask_pos] = avg_embed + mask_emb

        iteration_metrics = {} if compute_iteration_metrics and labels is not None else None

        # Main recursion loop with optional checkpointing
        all_logits = []
        for t in range(T):
            if self.training and use_recursion_checkpointing:
                # Use checkpointing: activations recomputed during backward
                # This maintains gradient flow while saving memory
                logits, soft_embeds = torch_checkpoint(
                    self._single_iteration_checkpointable,
                    soft_embeds,
                    base_embeds,
                    mask_pos,
                    attention_mask,
                    embed_weight,
                    mask_emb,
                    temperature,
                    position_ids,
                    use_reentrant=False,  # Critical for nested checkpointing!
                )
            else:
                # No checkpointing: store all activations (inference or explicit disable)
                logits, soft_embeds = self._single_iteration_checkpointable(
                    soft_embeds,
                    base_embeds,
                    mask_pos,
                    attention_mask,
                    embed_weight,
                    mask_emb,
                    temperature,
                    position_ids,
                )
            all_logits.append(logits)

            # Compute iteration metrics if requested (no grad needed)
            if iteration_metrics is not None and labels is not None:
                with torch.no_grad():
                    iteration_metrics[t] = self._compute_iteration_metrics(logits, labels)

        # Return all logits for trainer to compute loss with proper normalization
        # Trainer handles: timestep-based weighting, iteration weighting, batch/sequence/token normalization
        return RecursiveMaskedLMOutput(
            loss=None,  # Let trainer compute loss
            logits=logits,  # Final logits for inference/metrics
            all_logits=all_logits if self.training else None,  # Only needed during training
            iteration_metrics=iteration_metrics or None,  # empty dict collapses to None
            flow_noise_embed=flow_noise_embed,  # For flow matching distillation
            flow_t=flow_t_per_token,  # For flow matching distillation
        )
1591
+
1592
+ @torch.no_grad()
1593
+ def _generate_flow_map(
1594
+ self,
1595
+ input_ids: torch.Tensor,
1596
+ attention_mask: Optional[torch.Tensor],
1597
+ position_ids: Optional[torch.Tensor],
1598
+ num_steps: int,
1599
+ ) -> torch.Tensor:
1600
+ """Fill in mask positions using the CFM flow map update rule.
1601
+
1602
+ Starts from a random point on the probability simplex and iteratively
1603
+ moves toward the model's predictions using the flow map step rule.
1604
+
1605
+ Args:
1606
+ input_ids: Input with [MASK] tokens at positions to fill
1607
+ attention_mask: Attention mask
1608
+ position_ids: Position IDs
1609
+ num_steps: Number of flow map steps (finer = better, 1 step = greedy)
1610
+
1611
+ Returns:
1612
+ Tensor with [MASK] positions filled with predicted tokens
1613
+ """
1614
+ mask_pos = (input_ids == self.config.mask_token_id)
1615
+ num_masked = mask_pos.sum().item()
1616
+
1617
+ if num_masked == 0:
1618
+ return input_ids.clone()
1619
+
1620
+ device = input_ids.device
1621
+ V = self.embed_weight.shape[0]
1622
+ embed_weight = self.embed_weight
1623
+ mask_emb = embed_weight[self.config.mask_token_id]
1624
+ base_embeds = self.get_input_embeddings()(input_ids)
1625
+
1626
+ # Start from random simplex point
1627
+ noise_scale = self.config.flow_matching_noise_scale
1628
+ p = F.softmax(torch.randn(num_masked, V, device=device, dtype=base_embeds.dtype) * noise_scale, dim=-1).to(base_embeds.dtype)
1629
+
1630
+ times = torch.linspace(0, 1, num_steps + 1, device=device)
1631
+
1632
+ for i in range(num_steps):
1633
+ t_now = times[i]
1634
+ t_next = times[i + 1]
1635
+ step_size = (t_next - t_now) / (1 - t_now)
1636
+
1637
+ # Mask signal (binary or scaled)
1638
+ if self.config.flow_matching_mask_scale:
1639
+ mask_signal = (1 - t_now) * mask_emb
1640
+ else:
1641
+ mask_signal = mask_emb
1642
+
1643
+ # Project current state to embedding space
1644
+ embed = p @ embed_weight + mask_signal
1645
+
1646
+ soft_embeds = base_embeds.clone()
1647
+ soft_embeds[mask_pos] = embed
1648
+ inputs_embeds = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds)
1649
+
1650
+ outputs = self.mlm(
1651
+ inputs_embeds=inputs_embeds,
1652
+ attention_mask=attention_mask,
1653
+ position_ids=position_ids,
1654
+ return_dict=True,
1655
+ )
1656
+ pi = F.softmax(outputs.logits[mask_pos], dim=-1).to(p.dtype)
1657
+
1658
+ # Flow map update: move toward model's prediction
1659
+ p = p + step_size * (pi - p)
1660
+
1661
+ # Fix floating point drift off the simplex
1662
+ p = p.clamp(min=0)
1663
+ p = p / p.sum(dim=-1, keepdim=True)
1664
+
1665
+ result = input_ids.clone()
1666
+ result[mask_pos] = p.argmax(dim=-1)
1667
+ return result
1668
+
1669
+ @torch.no_grad()
1670
+ def generate(
1671
+ self,
1672
+ input_ids: torch.Tensor,
1673
+ attention_mask: Optional[torch.Tensor] = None,
1674
+ position_ids: Optional[torch.Tensor] = None,
1675
+ num_recursions: Optional[int] = None,
1676
+ # === Convergence schedule parameters (None = use config defaults) ===
1677
+ schedule: Optional[str] = None,
1678
+ causal_strength: Optional[float] = None,
1679
+ # === Effect parameters (None = use config defaults) ===
1680
+ temperature_max: Optional[float] = None,
1681
+ entropy_target_max: Optional[float] = None,
1682
+ entropy_floor_max: Optional[float] = None,
1683
+ smear_sigma_max: Optional[float] = None,
1684
+ noise_std_max: Optional[float] = None,
1685
+ iteration_rope_dim_fraction: Optional[float] = None,
1686
+ ) -> torch.Tensor:
1687
+ """Fill in mask positions via iterative refinement.
1688
+
1689
+ When flow_matching_enabled, uses the CFM flow map update rule.
1690
+ Otherwise, uses standard recursive soft-token refinement.
1691
+
1692
+ Args:
1693
+ input_ids: Input token IDs with [MASK] tokens at positions to fill
1694
+ attention_mask: Attention mask
1695
+ num_recursions: Override number of recursions/steps (default: config value)
1696
+ schedule: "linear" or "causal" convergence schedule
1697
+ causal_strength: How much faster early positions converge (causal only)
1698
+ temperature_max: Max temperature boost for uncertain positions
1699
+ entropy_target_max: Target entropy at progress=0 (two-sided)
1700
+ entropy_floor_max: Min entropy floor (one-sided)
1701
+ smear_sigma_max: Max Gaussian sigma for position smearing
1702
+ noise_std_max: Max std of Gaussian noise on logits
1703
+ iteration_rope_dim_fraction: Fraction of dims for iteration RoPE
1704
+
1705
+ Returns:
1706
+ Tensor with [MASK] positions filled with predicted tokens
1707
+ """
1708
+ num_steps = num_recursions or self.config.num_recursions
1709
+
1710
+ if self.config.flow_matching_enabled:
1711
+ return self._generate_flow_map(
1712
+ input_ids, attention_mask, position_ids, num_steps
1713
+ )
1714
+
1715
+ out = self.forward(
1716
+ input_ids,
1717
+ attention_mask,
1718
+ position_ids=position_ids,
1719
+ num_recursions=num_steps,
1720
+ schedule=schedule,
1721
+ causal_strength=causal_strength,
1722
+ temperature_max=temperature_max,
1723
+ entropy_target_max=entropy_target_max,
1724
+ entropy_floor_max=entropy_floor_max,
1725
+ smear_sigma_max=smear_sigma_max,
1726
+ noise_std_max=noise_std_max,
1727
+ iteration_rope_dim_fraction=iteration_rope_dim_fraction,
1728
+ )
1729
+ result = input_ids.clone()
1730
+ mask_pos = (input_ids == self.config.mask_token_id)
1731
+ result[mask_pos] = out.logits.argmax(dim=-1)[mask_pos]
1732
+ return result