Upload folder using huggingface_hub
Browse files- config.json +8 -3
- configuration_recursive.py +179 -0
- modeling_recursive.py +1732 -0
config.json
CHANGED
|
@@ -17,7 +17,8 @@
|
|
| 17 |
"auto_map": {
|
| 18 |
"AutoConfig": "configuration_llada.LLaDAConfig",
|
| 19 |
"AutoModel": "modeling_llada.LLaDAModelLM",
|
| 20 |
-
"AutoModelForCausalLM": "modeling_llada.LLaDAModelLM"
|
|
|
|
| 21 |
},
|
| 22 |
"bad_words_ids": null,
|
| 23 |
"begin_suppress_tokens": null,
|
|
@@ -149,5 +150,9 @@
|
|
| 149 |
"soft_embedding_method": "softmax",
|
| 150 |
"temperature_max": 0.0,
|
| 151 |
"transformers_version": "4.57.0",
|
| 152 |
-
"use_recursion_checkpointing": true
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
"auto_map": {
|
| 18 |
"AutoConfig": "configuration_llada.LLaDAConfig",
|
| 19 |
"AutoModel": "modeling_llada.LLaDAModelLM",
|
| 20 |
+
"AutoModelForCausalLM": "modeling_llada.LLaDAModelLM",
|
| 21 |
+
"AutoModelForMaskedLM": "modeling_llada.LLaDAModelLM"
|
| 22 |
},
|
| 23 |
"bad_words_ids": null,
|
| 24 |
"begin_suppress_tokens": null,
|
|
|
|
| 150 |
"soft_embedding_method": "softmax",
|
| 151 |
"temperature_max": 0.0,
|
| 152 |
"transformers_version": "4.57.0",
|
| 153 |
+
"use_recursion_checkpointing": true,
|
| 154 |
+
"auto_map": {
|
| 155 |
+
"AutoConfig": "configuration_recursive.RecursiveMLMConfig",
|
| 156 |
+
"AutoModel": "modeling_recursive.RecursiveMaskedLM"
|
| 157 |
+
}
|
| 158 |
+
}
|
configuration_recursive.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
from transformers import PretrainedConfig
|
| 5 |
+
|
| 6 |
+
|
class RecursiveMLMConfig(PretrainedConfig):
    """
    Configuration for ``RecursiveMaskedLM``.

    Bundles the wrapped base MLM's configuration (stored as a plain dict in
    ``base_model_config``) with every knob controlling the recursive
    soft-token refinement loop.

    Convergence schedule
        ``schedule`` selects WHEN positions may converge to a confident
        prediction: ``"linear"`` lets all positions converge at the same
        iteration-driven rate; ``"causal"`` lets earlier positions converge
        first, scaled by ``causal_strength``.  The ``*_max`` effect parameters
        enforce the schedule: ``temperature_max`` (raise temperature for
        positions not yet allowed to converge), ``entropy_target_max``
        (two-sided exact-entropy forcing via bisection, recommended),
        ``entropy_floor_max`` (one-sided minimum entropy), ``smear_sigma_max``
        (spread probability across neighboring positions), ``noise_std_max``
        (Gaussian noise on logits), and ``iteration_rope_dim_fraction``
        (rotary embedding driven by iteration progress).

    Soft embeddings
        ``soft_embedding_method`` chooses how logits become the next
        iteration's soft embeddings: ``"softmax"`` (default; sparse
        probabilistic mixing, but the softmax Jacobian can bottleneck
        gradients), ``"l2_normalize"`` (L2-normalize logits before mixing, for
        smoother gradients through long recursion chains), or ``"none"`` (raw
        logits — can cause scale explosion without e.g. EMA accumulation).
        ``soft_embedding_ema_step`` blends with the previous soft embeddings:
        ``new = (1 - ema_step) * prev + ema_step * current`` (1.0 = no EMA).

    Recursion checkpointing
        ``use_recursion_checkpointing`` enables gradient checkpointing across
        refinement iterations; ``loss_weight="last_1"`` gives a
        final-iteration-only loss that learns convergence behavior.

    Flow matching (CFM-inspired)
        When ``flow_matching_enabled`` is True, training inputs are
        interpolated on the probability simplex between random noise and the
        target one-hot; time ``t`` is sampled independently per masked token
        (t=0 pure noise, t=1 clean target) from
        ``flow_matching_t_distribution`` (``"logit_normal"`` with
        ``flow_matching_t_logit_mean``/``_std``, or ``"uniform"``), clamped to
        [``flow_matching_t_min``, ``flow_matching_t_max``].
        ``flow_matching_lambda`` weights the distillation KL loss against CE,
        ``flow_matching_noise_scale`` scales the injected noise, and
        ``flow_matching_mask_scale`` scales the mask embedding by (1 - t)
        instead of a binary mask signal.

    Self-distillation (legacy, temperature-based)
        The ``self_distillation_*`` parameters configure the older
        temperature-based self-distillation and are ignored whenever
        ``flow_matching_enabled`` is True.
    """
    model_type = "recursive-mlm"

    def __init__(
        self,
        base_model_config: Optional[dict] = None,
        num_recursions: int = 8,
        normalization: str = "softmax",
        loss_weight: str = "linear",
        mask_token_id: Optional[int] = None,
        temperature: float = 1.0,
        gradient_steps: Optional[int] = None,
        # === Convergence schedule parameters ===
        schedule: str = "linear",
        causal_strength: float = 1.0,
        # === Effect parameters ===
        temperature_max: float = 0.0,
        entropy_target_max: float = 0.0,
        entropy_floor_max: float = 0.0,
        smear_sigma_max: float = 0.0,
        noise_std_max: float = 0.0,
        iteration_rope_dim_fraction: float = 0.0,
        use_recursion_checkpointing: bool = True,
        # === Soft embedding method ===
        soft_embedding_method: str = "softmax",
        soft_embedding_ema_step: float = 1.0,
        # === Flow matching parameters (CFM-inspired) ===
        flow_matching_enabled: bool = False,
        flow_matching_lambda: float = 0.5,
        flow_matching_t_distribution: str = "logit_normal",
        flow_matching_t_logit_mean: float = -0.4,
        flow_matching_t_logit_std: float = 1.0,
        flow_matching_t_min: float = 0.01,
        flow_matching_t_max: float = 0.99,
        flow_matching_noise_scale: float = 2.0,
        flow_matching_mask_scale: bool = False,
        # === Self-distillation parameters (legacy, ignored when flow_matching_enabled) ===
        self_distillation_enabled: bool = False,
        self_distillation_lambda: float = 0.5,
        self_distillation_temperature_min: float = 1.5,
        self_distillation_temperature_max: float = 10.0,
        self_distillation_temperature_distribution: str = "log_uniform",
        self_distillation_teacher: str = "first",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Base model + core recursion loop.
        self.base_model_config = base_model_config
        self.num_recursions = num_recursions
        self.normalization = normalization
        self.loss_weight = loss_weight
        self.mask_token_id = mask_token_id
        self.temperature = temperature
        self.gradient_steps = gradient_steps
        # Convergence schedule.
        self.schedule = schedule
        self.causal_strength = causal_strength
        # Schedule-enforcing effects.
        self.temperature_max = temperature_max
        self.entropy_target_max = entropy_target_max
        self.entropy_floor_max = entropy_floor_max
        self.smear_sigma_max = smear_sigma_max
        self.noise_std_max = noise_std_max
        self.iteration_rope_dim_fraction = iteration_rope_dim_fraction
        # Recursion checkpointing.
        self.use_recursion_checkpointing = use_recursion_checkpointing
        # Soft-embedding construction.
        self.soft_embedding_method = soft_embedding_method
        self.soft_embedding_ema_step = soft_embedding_ema_step
        # Flow matching.
        self.flow_matching_enabled = flow_matching_enabled
        self.flow_matching_lambda = flow_matching_lambda
        self.flow_matching_t_distribution = flow_matching_t_distribution
        self.flow_matching_t_logit_mean = flow_matching_t_logit_mean
        self.flow_matching_t_logit_std = flow_matching_t_logit_std
        self.flow_matching_t_min = flow_matching_t_min
        self.flow_matching_t_max = flow_matching_t_max
        self.flow_matching_noise_scale = flow_matching_noise_scale
        self.flow_matching_mask_scale = flow_matching_mask_scale
        # Self-distillation (legacy path).
        self.self_distillation_enabled = self_distillation_enabled
        self.self_distillation_lambda = self_distillation_lambda
        self.self_distillation_temperature_min = self_distillation_temperature_min
        self.self_distillation_temperature_max = self_distillation_temperature_max
        self.self_distillation_temperature_distribution = self_distillation_temperature_distribution
        self.self_distillation_teacher = self_distillation_teacher

    @classmethod
    def from_base_model_config(
        cls,
        base_config: PretrainedConfig,
        **kwargs,
    ) -> "RecursiveMLMConfig":
        """Create config from a base MLM's config."""
        return cls(base_model_config=base_config.to_dict(), **kwargs)
|
modeling_recursive.py
ADDED
|
@@ -0,0 +1,1732 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
import warnings
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import NamedTuple, Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch.nn import CrossEntropyLoss
|
| 9 |
+
from torch.utils.checkpoint import checkpoint as torch_checkpoint
|
| 10 |
+
from transformers import AutoConfig, AutoModelForMaskedLM, PreTrainedModel
|
| 11 |
+
from transformers.modeling_outputs import MaskedLMOutput
|
| 12 |
+
from transformers.utils import ModelOutput
|
| 13 |
+
|
| 14 |
+
from .configuration_recursive import RecursiveMLMConfig
|
| 15 |
+
|
| 16 |
+
|
@dataclass
class IterationMetrics(ModelOutput):
    """Metrics for a single iteration of recursive refinement.

    All fields default to None; they are populated elsewhere in this module
    (field semantics below are inferred from the names — confirm against the
    forward pass that fills them).
    """
    accuracy: Optional[float] = None  # presumably masked-token accuracy at this iteration
    entropy: Optional[float] = None  # presumably mean predictive entropy at this iteration
    softmax_ce: Optional[float] = None  # presumably cross-entropy of the softmax distribution
    full_sequence_accuracy: Optional[float] = None  # presumably fraction of fully-correct sequences
    min_sequence_confidence: Optional[float] = None  # presumably minimum per-sequence confidence
|
| 25 |
+
|
| 26 |
+
|
@dataclass
class RecursiveMaskedLMOutput(MaskedLMOutput):
    """``MaskedLMOutput`` extended with per-iteration state from recursive refinement."""
    iteration_metrics: Optional[dict[int, IterationMetrics]] = None  # Maps iteration index to metrics
    next_soft_embeds: Optional[torch.Tensor] = None  # For caching between training steps
    all_logits: Optional[list[torch.Tensor]] = None  # All T iterations' logits for trainer loss computation
    # Flow matching state (for distillation — compact H-dim, not V-dim)
    flow_noise_embed: Optional[torch.Tensor] = None  # (num_masked, H) noise embedding
    flow_t: Optional[torch.Tensor] = None  # (num_masked,) per-token time levels
|
| 35 |
+
|
| 36 |
+
|
class SelfDistillationOutput(NamedTuple):
    """Output from self-distillation forward pass.

    Scalars beyond ``loss`` are diagnostics only; ``loss`` is the term that
    participates in the training objective.
    """
    loss: torch.Tensor  # KL divergence loss (scalar, has grad)
    teacher_logits: torch.Tensor  # For metrics/debugging (detached)
    student_logits: torch.Tensor  # For metrics/debugging (has grad)
    degradation_temperature: float  # Mean per-token temperature sampled
    teacher_entropy: float  # Entropy of teacher distribution (for monitoring)
    student_entropy: float  # Entropy of student distribution (for monitoring)
    agreement_rate: float  # Fraction where teacher and student argmax agree
|
| 46 |
+
|
| 47 |
+
|
class RecursiveMaskedLM(PreTrainedModel):
    """
    Wraps any HF MLM with recursive soft-token refinement.

    At each step:
    1. Normalize logits -> probs
    2. Compute soft embeddings: probs @ embedding_weight + mask_embedding
    3. Forward through MLM
    4. Accumulate weighted loss
    """
    # Ties this wrapper to its configuration class for save/load round-trips.
    config_class = RecursiveMLMConfig
    # The wrapped MLM lives under ``self.mlm``; HF uses this prefix when mapping weights.
    base_model_prefix = "mlm"
    # Checkpointing requests are delegated to the wrapped model.
    supports_gradient_checkpointing = True
|
| 61 |
+
|
| 62 |
+
def __init__(self, config: RecursiveMLMConfig, base_model: Optional[PreTrainedModel] = None):
|
| 63 |
+
super().__init__(config)
|
| 64 |
+
|
| 65 |
+
if base_model is not None:
|
| 66 |
+
# Pre-trained model provided - assign directly WITHOUT calling post_init()
|
| 67 |
+
# to avoid reinitializing the pre-trained weights via _init_weights()
|
| 68 |
+
self.mlm = base_model
|
| 69 |
+
elif config.base_model_config is not None:
|
| 70 |
+
base_config = AutoConfig.for_model(**config.base_model_config)
|
| 71 |
+
self.mlm = AutoModelForMaskedLM.from_config(base_config)
|
| 72 |
+
# Only call post_init() for freshly created models (needs weight init)
|
| 73 |
+
self.post_init()
|
| 74 |
+
else:
|
| 75 |
+
raise ValueError("Need either base_model or config.base_model_config")
|
| 76 |
+
|
| 77 |
+
@classmethod
|
| 78 |
+
def from_mlm_pretrained(
|
| 79 |
+
cls,
|
| 80 |
+
mlm_name_or_path: str,
|
| 81 |
+
num_recursions: int = 8,
|
| 82 |
+
normalization: str = "softmax",
|
| 83 |
+
loss_weight: str = "linear",
|
| 84 |
+
mask_token_id: Optional[int] = None,
|
| 85 |
+
temperature: float = 1.0,
|
| 86 |
+
gradient_steps: Optional[int] = None,
|
| 87 |
+
# === Convergence schedule parameters ===
|
| 88 |
+
schedule: str = "linear",
|
| 89 |
+
causal_strength: float = 1.0,
|
| 90 |
+
# === Effect parameters ===
|
| 91 |
+
temperature_max: float = 0.0,
|
| 92 |
+
entropy_target_max: float = 0.0,
|
| 93 |
+
entropy_floor_max: float = 0.0,
|
| 94 |
+
smear_sigma_max: float = 0.0,
|
| 95 |
+
noise_std_max: float = 0.0,
|
| 96 |
+
iteration_rope_dim_fraction: float = 0.0,
|
| 97 |
+
use_recursion_checkpointing: bool = True,
|
| 98 |
+
# === Soft embedding method ===
|
| 99 |
+
soft_embedding_method: str = "softmax",
|
| 100 |
+
soft_embedding_ema_step: float = 1.0,
|
| 101 |
+
# === Flow matching parameters ===
|
| 102 |
+
flow_matching_enabled: bool = False,
|
| 103 |
+
flow_matching_lambda: float = 0.5,
|
| 104 |
+
flow_matching_t_distribution: str = "logit_normal",
|
| 105 |
+
flow_matching_t_logit_mean: float = -0.4,
|
| 106 |
+
flow_matching_t_logit_std: float = 1.0,
|
| 107 |
+
flow_matching_t_min: float = 0.01,
|
| 108 |
+
flow_matching_t_max: float = 0.99,
|
| 109 |
+
flow_matching_mask_scale: bool = False,
|
| 110 |
+
**model_kwargs,
|
| 111 |
+
) -> "RecursiveMaskedLM":
|
| 112 |
+
"""Load a pretrained MLM and wrap it for recursive refinement."""
|
| 113 |
+
base_model = AutoModelForMaskedLM.from_pretrained(mlm_name_or_path, **model_kwargs)
|
| 114 |
+
return cls.from_base_model(
|
| 115 |
+
base_model,
|
| 116 |
+
num_recursions=num_recursions,
|
| 117 |
+
normalization=normalization,
|
| 118 |
+
loss_weight=loss_weight,
|
| 119 |
+
mask_token_id=mask_token_id,
|
| 120 |
+
temperature=temperature,
|
| 121 |
+
gradient_steps=gradient_steps,
|
| 122 |
+
schedule=schedule,
|
| 123 |
+
causal_strength=causal_strength,
|
| 124 |
+
temperature_max=temperature_max,
|
| 125 |
+
entropy_target_max=entropy_target_max,
|
| 126 |
+
entropy_floor_max=entropy_floor_max,
|
| 127 |
+
smear_sigma_max=smear_sigma_max,
|
| 128 |
+
noise_std_max=noise_std_max,
|
| 129 |
+
iteration_rope_dim_fraction=iteration_rope_dim_fraction,
|
| 130 |
+
use_recursion_checkpointing=use_recursion_checkpointing,
|
| 131 |
+
soft_embedding_method=soft_embedding_method,
|
| 132 |
+
soft_embedding_ema_step=soft_embedding_ema_step,
|
| 133 |
+
flow_matching_enabled=flow_matching_enabled,
|
| 134 |
+
flow_matching_lambda=flow_matching_lambda,
|
| 135 |
+
flow_matching_t_distribution=flow_matching_t_distribution,
|
| 136 |
+
flow_matching_t_logit_mean=flow_matching_t_logit_mean,
|
| 137 |
+
flow_matching_t_logit_std=flow_matching_t_logit_std,
|
| 138 |
+
flow_matching_t_min=flow_matching_t_min,
|
| 139 |
+
flow_matching_t_max=flow_matching_t_max,
|
| 140 |
+
flow_matching_mask_scale=flow_matching_mask_scale,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
@classmethod
|
| 144 |
+
def from_base_model(
|
| 145 |
+
cls,
|
| 146 |
+
base_model: PreTrainedModel,
|
| 147 |
+
num_recursions: int = 8,
|
| 148 |
+
normalization: str = "softmax",
|
| 149 |
+
loss_weight: str = "linear",
|
| 150 |
+
mask_token_id: Optional[int] = None,
|
| 151 |
+
temperature: float = 1.0,
|
| 152 |
+
gradient_steps: Optional[int] = None,
|
| 153 |
+
# === Convergence schedule parameters ===
|
| 154 |
+
schedule: str = "linear",
|
| 155 |
+
causal_strength: float = 1.0,
|
| 156 |
+
# === Effect parameters ===
|
| 157 |
+
temperature_max: float = 0.0,
|
| 158 |
+
entropy_target_max: float = 0.0,
|
| 159 |
+
entropy_floor_max: float = 0.0,
|
| 160 |
+
smear_sigma_max: float = 0.0,
|
| 161 |
+
noise_std_max: float = 0.0,
|
| 162 |
+
iteration_rope_dim_fraction: float = 0.0,
|
| 163 |
+
use_recursion_checkpointing: bool = True,
|
| 164 |
+
# === Soft embedding method ===
|
| 165 |
+
soft_embedding_method: str = "softmax",
|
| 166 |
+
soft_embedding_ema_step: float = 1.0,
|
| 167 |
+
# === Flow matching parameters ===
|
| 168 |
+
flow_matching_enabled: bool = False,
|
| 169 |
+
flow_matching_lambda: float = 0.5,
|
| 170 |
+
flow_matching_t_distribution: str = "logit_normal",
|
| 171 |
+
flow_matching_t_logit_mean: float = -0.4,
|
| 172 |
+
flow_matching_t_logit_std: float = 1.0,
|
| 173 |
+
flow_matching_t_min: float = 0.01,
|
| 174 |
+
flow_matching_t_max: float = 0.99,
|
| 175 |
+
flow_matching_mask_scale: bool = False,
|
| 176 |
+
) -> "RecursiveMaskedLM":
|
| 177 |
+
"""Wrap an existing model for recursive refinement.
|
| 178 |
+
|
| 179 |
+
Use this for models not loadable via AutoModelForMaskedLM (e.g., LLaDA).
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
base_model: The base MLM model to wrap
|
| 183 |
+
num_recursions: Number of recursive refinement steps
|
| 184 |
+
normalization: Normalization method for logits (softmax, stable_softmax)
|
| 185 |
+
loss_weight: Loss weighting scheme (last_1, last_2, linear, uniform)
|
| 186 |
+
mask_token_id: Token ID for [MASK]
|
| 187 |
+
temperature: Temperature for softmax normalization
|
| 188 |
+
gradient_steps: Number of final steps to backprop through
|
| 189 |
+
schedule: Convergence schedule type ("linear" or "causal")
|
| 190 |
+
causal_strength: How much faster early positions converge (causal only)
|
| 191 |
+
temperature_max: Max temperature boost for uncertain positions
|
| 192 |
+
entropy_target_max: Target entropy at progress=0 (two-sided, recommended)
|
| 193 |
+
entropy_floor_max: Min entropy floor (one-sided)
|
| 194 |
+
smear_sigma_max: Max Gaussian sigma for position smearing
|
| 195 |
+
noise_std_max: Max std of Gaussian noise on logits
|
| 196 |
+
iteration_rope_dim_fraction: Fraction of dims for iteration RoPE
|
| 197 |
+
use_recursion_checkpointing: Enable gradient checkpointing for iterations
|
| 198 |
+
soft_embedding_method: How to convert logits to soft embeddings
|
| 199 |
+
soft_embedding_ema_step: EMA step size (1.0 = no EMA, <1.0 = blend with previous)
|
| 200 |
+
flow_matching_enabled: Enable CFM-inspired flow matching framework
|
| 201 |
+
flow_matching_lambda: Weight of distillation KL loss relative to CE
|
| 202 |
+
flow_matching_t_distribution: Time sampling distribution ("logit_normal" or "uniform")
|
| 203 |
+
flow_matching_t_logit_mean: Mean of logit-normal distribution
|
| 204 |
+
flow_matching_t_logit_std: Std of logit-normal distribution
|
| 205 |
+
flow_matching_t_min: Minimum time value (clamp)
|
| 206 |
+
flow_matching_t_max: Maximum time value (clamp)
|
| 207 |
+
flow_matching_mask_scale: Scale mask_emb by (1-t) if True, binary if False
|
| 208 |
+
"""
|
| 209 |
+
config = RecursiveMLMConfig.from_base_model_config(
|
| 210 |
+
base_model.config,
|
| 211 |
+
num_recursions=num_recursions,
|
| 212 |
+
normalization=normalization,
|
| 213 |
+
loss_weight=loss_weight,
|
| 214 |
+
mask_token_id=mask_token_id,
|
| 215 |
+
temperature=temperature,
|
| 216 |
+
gradient_steps=gradient_steps,
|
| 217 |
+
schedule=schedule,
|
| 218 |
+
causal_strength=causal_strength,
|
| 219 |
+
temperature_max=temperature_max,
|
| 220 |
+
entropy_target_max=entropy_target_max,
|
| 221 |
+
entropy_floor_max=entropy_floor_max,
|
| 222 |
+
smear_sigma_max=smear_sigma_max,
|
| 223 |
+
noise_std_max=noise_std_max,
|
| 224 |
+
iteration_rope_dim_fraction=iteration_rope_dim_fraction,
|
| 225 |
+
use_recursion_checkpointing=use_recursion_checkpointing,
|
| 226 |
+
soft_embedding_method=soft_embedding_method,
|
| 227 |
+
soft_embedding_ema_step=soft_embedding_ema_step,
|
| 228 |
+
flow_matching_enabled=flow_matching_enabled,
|
| 229 |
+
flow_matching_lambda=flow_matching_lambda,
|
| 230 |
+
flow_matching_t_distribution=flow_matching_t_distribution,
|
| 231 |
+
flow_matching_t_logit_mean=flow_matching_t_logit_mean,
|
| 232 |
+
flow_matching_t_logit_std=flow_matching_t_logit_std,
|
| 233 |
+
flow_matching_t_min=flow_matching_t_min,
|
| 234 |
+
flow_matching_t_max=flow_matching_t_max,
|
| 235 |
+
flow_matching_mask_scale=flow_matching_mask_scale,
|
| 236 |
+
)
|
| 237 |
+
return cls(config, base_model=base_model)
|
| 238 |
+
|
| 239 |
+
@property
def embed_weight(self) -> torch.Tensor:
    """Weight matrix of the wrapped MLM's input embedding table, shape (V, H)."""
    embedding_layer = self.mlm.get_input_embeddings()
    return embedding_layer.weight
|
| 242 |
+
|
| 243 |
+
def get_input_embeddings(self):
    """Return the input embedding module of the wrapped MLM (pure delegation)."""
    embeddings = self.mlm.get_input_embeddings()
    return embeddings
|
| 245 |
+
|
| 246 |
+
def set_input_embeddings(self, value):
    """Forward the replacement input embedding module to the wrapped MLM."""
    inner = self.mlm
    inner.set_input_embeddings(value)
|
| 248 |
+
|
| 249 |
+
def get_output_embeddings(self):
    """Return the output (LM head) embedding module of the wrapped MLM."""
    head = self.mlm.get_output_embeddings()
    return head
|
| 251 |
+
|
| 252 |
+
def set_output_embeddings(self, new_embeddings):
    """Forward the replacement output (LM head) module to the wrapped MLM."""
    inner = self.mlm
    inner.set_output_embeddings(new_embeddings)
|
| 254 |
+
|
| 255 |
+
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
    """Enable gradient checkpointing with correct settings for recursion.

    Forces use_reentrant=False which is required for:
    - Nested checkpoint calls (base model + recursion checkpointing)
    - Models with frozen parameters
    - Complex gradient flows through soft embeddings

    Args:
        gradient_checkpointing_kwargs: Optional kwargs forwarded to the base
            MLM's ``gradient_checkpointing_enable``. The caller's dict is
            never mutated; a shallow copy is taken before defaults are set.
    """
    # Bug fix: the previous implementation called setdefault() directly on
    # the caller's dict, mutating the argument in place. Copy first.
    kwargs = dict(gradient_checkpointing_kwargs or {})
    # Force use_reentrant=False for nested checkpointing compatibility
    kwargs.setdefault("use_reentrant", False)
    self.mlm.gradient_checkpointing_enable(kwargs)
|
| 268 |
+
|
| 269 |
+
def gradient_checkpointing_disable(self):
    """Turn off gradient checkpointing on the wrapped MLM."""
    inner = self.mlm
    inner.gradient_checkpointing_disable()
|
| 272 |
+
|
| 273 |
+
def _single_iteration_checkpointable(
    self,
    soft_embeds: torch.Tensor,
    base_embeds: torch.Tensor,
    mask_pos: torch.Tensor,
    attention_mask: torch.Tensor,
    embed_weight: torch.Tensor,
    mask_emb: torch.Tensor,
    temperature: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Single differentiable iteration for checkpointing.

    This method performs one iteration of recursive refinement in a way that
    maintains gradient flow and is compatible with torch.utils.checkpoint
    (all inputs are tensors, no detach anywhere on the refinement path).

    Args:
        soft_embeds: (B, L, H) - current soft embeddings
        base_embeds: (B, L, H) - original token embeddings
        mask_pos: (B, L) bool - which positions are masked
        attention_mask: (B, L) - attention mask for MLM
        embed_weight: (V, H) - embedding weight matrix
        mask_emb: (H,) - mask token embedding
        temperature: scalar tensor - softmax temperature
        position_ids: optional position IDs forwarded to the base MLM

    Returns:
        logits: (B, L, V) - output logits from this iteration
        next_soft_embeds: (B, L, H) - soft embeddings for next iteration
    """
    # Blend: use soft_embeds at masked positions, base_embeds elsewhere
    inputs_embeds = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds)

    # Forward through base MLM
    outputs = self.mlm(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
    )
    logits = outputs.logits

    # Compute soft embeddings for next iteration (DIFFERENTIABLE - no detach!)
    # Unmasked positions always carry the original token embeddings.
    next_soft_embeds = base_embeds.clone()
    if mask_pos.any():
        masked_logits = logits[mask_pos]  # (num_masked, V)

        # Convert logits to mixing weights based on soft_embedding_method
        if self.config.soft_embedding_method == "none":
            # No normalization - use raw logits directly
            weights = masked_logits  # Differentiable!
        elif self.config.soft_embedding_method == "l2_normalize":
            # L2 normalize logits - removes softmax bottleneck for smoother gradients
            weights = F.normalize(masked_logits, p=2, dim=-1)  # Differentiable!
        else:
            # Default: softmax normalization
            weights = F.softmax(masked_logits / temperature, dim=-1)  # Differentiable!

        # Weighted mix over the vocabulary embedding table, offset by the
        # mask embedding so the model can still tell the position is masked.
        soft_emb = weights @ embed_weight + mask_emb  # Differentiable!

        # Apply EMA blending with previous soft embeddings if enabled
        # (ema_step == 1.0 means "replace entirely", < 1.0 smooths updates).
        ema_step = self.config.soft_embedding_ema_step
        if ema_step < 1.0:
            prev_soft_emb = soft_embeds[mask_pos]  # Previous iteration's soft embeddings
            soft_emb = (1.0 - ema_step) * prev_soft_emb + ema_step * soft_emb

        next_soft_embeds[mask_pos] = soft_emb

    return logits, next_soft_embeds
|
| 342 |
+
|
| 343 |
+
def _stable_softmax(self, logits: torch.Tensor, T: float = 1.0, dim: int = -1, eps: float = 1e-12) -> torch.Tensor:
|
| 344 |
+
"""Numerically stable softmax with temperature T > 0."""
|
| 345 |
+
z = logits / max(T, eps)
|
| 346 |
+
z = z - z.max(dim=dim, keepdim=True).values # subtract max
|
| 347 |
+
z = torch.exp(z) # safe since z <= 0
|
| 348 |
+
z_sum = z.sum(dim=dim, keepdim=True)
|
| 349 |
+
return z / z_sum.clamp(min=eps)
|
| 350 |
+
|
| 351 |
+
def normalize(self, logits: torch.Tensor) -> torch.Tensor:
    """Normalize logits -> mixing weights. Shape: (B, L, V) -> (B, L, V).

    Dispatches on ``config.normalization`` (case-insensitive):
      - "none": return raw logits unchanged
      - "softmax": temperature-scaled softmax (T = ``config.temperature``)
      - "stable_softmax": max-subtracted softmax via ``_stable_softmax``

    Raises:
        ValueError: if ``config.normalization`` is not one of the above.
    """
    norm = self.config.normalization.lower()
    T = self.config.temperature
    # Fix: removed dead local `V = logits.shape[-1]` (never used).

    if norm == "none":
        return logits

    if norm == "softmax":
        return torch.softmax(logits / T, dim=-1)

    if norm == "stable_softmax":
        return self._stable_softmax(logits, T=T, dim=-1)

    raise ValueError(f"Unknown normalization: {norm}")
|
| 367 |
+
|
| 368 |
+
def step_weight(self, t: int, T: int) -> float:
    """Loss weight for recursion step t (0-indexed) out of T total steps.

    Modes (``config.loss_weight``):
      "linear"  - weight grows linearly to 1.0 at the final step
      "uniform" - every step weighted 1.0
      "last_1"  - only the final step contributes
      "last_2"  - only the final two steps contribute
    """
    mode = self.config.loss_weight
    if mode == "linear":
        return (t + 1) / T
    if mode == "uniform":
        return 1.0
    if mode == "last_1":
        return float(t == T - 1)
    if mode == "last_2":
        return float(T - t <= 2)
    raise ValueError(f"Unknown loss_weight: {mode}")
|
| 380 |
+
|
| 381 |
+
# ==================== CONVERGENCE SCHEDULE SYSTEM ====================
|
| 382 |
+
#
|
| 383 |
+
# The core idea: control WHEN each position is allowed to converge.
|
| 384 |
+
#
|
| 385 |
+
# Schedule types:
|
| 386 |
+
# - "linear": All positions converge at the same rate
|
| 387 |
+
# - "causal": Early positions converge first, late positions last
|
| 388 |
+
#
|
| 389 |
+
# Effects (mechanisms to enforce the schedule):
|
| 390 |
+
# - temperature: Raise temperature for positions not yet allowed to converge
|
| 391 |
+
# - entropy_floor: Force minimum entropy
|
| 392 |
+
# - entropy_target: Force exact entropy via bisection (ARChitects-style)
|
| 393 |
+
# - smear: Spread probability across neighboring positions
|
| 394 |
+
# - noise: Add Gaussian noise to logits
|
| 395 |
+
#
|
| 396 |
+
# Each effect uses per-position "convergence progress" (0=uncertain, 1=can converge)
|
| 397 |
+
|
| 398 |
+
def _compute_convergence_progress(
|
| 399 |
+
self,
|
| 400 |
+
iteration: int,
|
| 401 |
+
total_iterations: int,
|
| 402 |
+
seq_length: int,
|
| 403 |
+
mask_positions: torch.Tensor,
|
| 404 |
+
schedule: str = "linear",
|
| 405 |
+
causal_strength: float = 1.0,
|
| 406 |
+
device: torch.device = None,
|
| 407 |
+
dtype: torch.dtype = None,
|
| 408 |
+
) -> torch.Tensor:
|
| 409 |
+
"""
|
| 410 |
+
Compute per-position convergence progress based on schedule.
|
| 411 |
+
|
| 412 |
+
Args:
|
| 413 |
+
iteration: Current iteration (0-indexed)
|
| 414 |
+
total_iterations: Total number of iterations
|
| 415 |
+
seq_length: Full sequence length L
|
| 416 |
+
mask_positions: Position indices of masked tokens (num_masked,)
|
| 417 |
+
schedule: "linear" or "causal"
|
| 418 |
+
causal_strength: How much faster early positions converge (for causal schedule)
|
| 419 |
+
|
| 420 |
+
Returns:
|
| 421 |
+
progress: (num_masked,) tensor with values in [0, 1]
|
| 422 |
+
0 = position should be maximally uncertain
|
| 423 |
+
1 = position is allowed to fully converge
|
| 424 |
+
"""
|
| 425 |
+
base_progress = iteration / max(total_iterations - 1, 1)
|
| 426 |
+
|
| 427 |
+
if schedule == "linear":
|
| 428 |
+
return torch.full(
|
| 429 |
+
(mask_positions.shape[0],),
|
| 430 |
+
base_progress,
|
| 431 |
+
device=device,
|
| 432 |
+
dtype=dtype
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
elif schedule == "causal":
|
| 436 |
+
position_factor = mask_positions.float() / max(seq_length - 1, 1)
|
| 437 |
+
effective_progress = base_progress * (1.0 + causal_strength * (1.0 - position_factor))
|
| 438 |
+
return effective_progress.clamp(0.0, 1.0)
|
| 439 |
+
|
| 440 |
+
else:
|
| 441 |
+
raise ValueError(f"Unknown schedule: {schedule}")
|
| 442 |
+
|
| 443 |
+
def _apply_temperature_effect(
|
| 444 |
+
self,
|
| 445 |
+
logits: torch.Tensor,
|
| 446 |
+
progress: torch.Tensor,
|
| 447 |
+
temperature_max: float,
|
| 448 |
+
) -> torch.Tensor:
|
| 449 |
+
"""
|
| 450 |
+
Apply per-position temperature scaling based on convergence progress.
|
| 451 |
+
Low progress = high temperature (uncertain), high progress = temperature 1.0.
|
| 452 |
+
"""
|
| 453 |
+
if temperature_max <= 0:
|
| 454 |
+
return logits
|
| 455 |
+
|
| 456 |
+
temperature = 1.0 + temperature_max * (1.0 - progress)
|
| 457 |
+
temperature = temperature.unsqueeze(-1)
|
| 458 |
+
|
| 459 |
+
return logits / temperature
|
| 460 |
+
|
| 461 |
+
def _apply_entropy_floor_effect(
    self,
    probs: torch.Tensor,
    progress: torch.Tensor,
    entropy_floor_max: float,
) -> torch.Tensor:
    """
    Ensure minimum entropy based on convergence progress.
    Low progress = high entropy floor, high progress = no floor.

    NOTE: This is a ONE-SIDED constraint (floor only). Positions already
    above their floor are returned untouched.

    NOTE(review): the temperature used below is the ratio floor/current
    entropy, a heuristic - it raises entropy in the right direction but does
    not hit the floor exactly (contrast with _apply_target_entropy_effect,
    which bisects for an exact match).
    """
    if entropy_floor_max <= 0:
        return probs

    # Per-position floor: full floor at progress 0, none at progress 1.
    entropy_floor = entropy_floor_max * (1.0 - progress)

    log_probs = torch.log(probs + 1e-10)
    current_entropy = -(probs * log_probs).sum(dim=-1)

    below_floor = current_entropy < entropy_floor

    if not below_floor.any():
        return probs

    # Recover logits (up to an additive constant) from the probabilities.
    logits = torch.log(probs + 1e-10)

    # Heuristic temperature: scale up by how far below the floor we are,
    # clamped so one step never flattens the distribution too aggressively.
    target_ratio = entropy_floor / (current_entropy + 1e-10)
    temperature = torch.ones_like(current_entropy)
    temperature[below_floor] = target_ratio[below_floor].clamp(1.0, 10.0)

    scaled_probs = torch.softmax(logits / temperature.unsqueeze(-1), dim=-1)

    # Only replace the positions that violated their floor.
    result = probs.clone()
    result[below_floor] = scaled_probs[below_floor]
    return result
|
| 497 |
+
|
| 498 |
+
def _find_temperature_for_target_entropy(
    self,
    logits: torch.Tensor,
    target_entropy: torch.Tensor,
    tol: float = 1e-3,
    max_iter: int = 32,
    T_low: float = 1e-6,
    T_high_init: float = 1.0,
    max_T: float = 100.0,
) -> torch.Tensor:
    """
    Find per-position temperatures that achieve exactly the target entropy.
    Uses bisection search, adapted from ARChitects' implementation.

    Relies on entropy being monotonically non-decreasing in temperature:
    T -> 0 gives near-greedy (low entropy), T -> inf gives uniform (log V).

    Args:
        logits: Raw logits (num_positions, V)
        target_entropy: Target entropy per position (num_positions,) or scalar
        tol: Entropy tolerance for convergence
        max_iter: Maximum bisection iterations
        T_low: Minimum temperature (near-greedy)
        T_high_init: Initial upper bound for search
        max_T: Maximum allowed temperature

    Returns:
        temperatures: (num_positions,) temperatures that achieve target entropy
    """
    N, V = logits.shape
    device, dtype = logits.device, logits.dtype
    # Maximum achievable entropy for a V-way distribution (uniform case).
    H_max = torch.log(torch.tensor(V, device=device, dtype=dtype))

    # Broadcast a scalar target to all positions; clamp into feasible range.
    if target_entropy.dim() == 0:
        target = target_entropy.expand(N).clone()
    else:
        target = target_entropy.clone()
    target = target.clamp(0.0, H_max)

    def compute_entropy(logits_: torch.Tensor, temps: torch.Tensor) -> torch.Tensor:
        # Max-subtraction keeps the softmax numerically stable at tiny temps.
        temps = temps.unsqueeze(-1).clamp(min=T_low)
        scaled = logits_ / temps
        scaled = scaled - scaled.max(dim=-1, keepdim=True).values
        probs = torch.softmax(scaled, dim=-1)
        log_probs = torch.log(probs + 1e-12)
        return -(probs * log_probs).sum(dim=-1)

    lo = torch.full((N,), T_low, device=device, dtype=dtype)
    hi = torch.full((N,), T_high_init, device=device, dtype=dtype)

    H_lo = compute_entropy(logits, lo)

    # Positions whose entropy at T_low already meets/exceeds the target:
    # no search needed, they get T_low directly.
    done_low = target <= (H_lo + tol)

    # Grow the upper bracket by doubling until it exceeds the target entropy
    # (or hits max_T); 100 doublings is far more than enough to reach max_T.
    H_hi = compute_entropy(logits, hi)
    needs_expansion = (H_hi < target - tol) & ~done_low

    for _ in range(100):
        if not needs_expansion.any():
            break
        hi[needs_expansion] = (hi[needs_expansion] * 2.0).clamp(max=max_T)
        H_hi[needs_expansion] = compute_entropy(
            logits[needs_expansion], hi[needs_expansion]
        )
        needs_expansion = (H_hi < target - tol) & ~done_low & (hi < max_T - 1e-6)

    # Only bisect where the bracket actually contains the target.
    can_bisect = ~done_low & (H_hi >= target - tol)

    for _ in range(max_iter):
        if not can_bisect.any():
            break

        mid = (lo + hi) / 2.0
        H_mid = compute_entropy(logits, mid)

        # Entropy too low at mid -> need a higher temperature -> move lo up.
        too_low = (H_mid < target) & can_bisect
        lo[too_low] = mid[too_low]
        hi[~too_low & can_bisect] = mid[~too_low & can_bisect]

        # Relative bracket-width convergence test.
        converged = (hi - lo) <= tol * mid.clamp(min=1.0)
        can_bisect = can_bisect & ~converged

    # Final answer: T_low for trivially-done positions, bracket midpoint else.
    temps = torch.zeros(N, device=device, dtype=dtype)
    temps[done_low] = T_low
    temps[~done_low] = (lo[~done_low] + hi[~done_low]) / 2.0

    return temps
|
| 582 |
+
|
| 583 |
+
def _apply_target_entropy_effect(
|
| 584 |
+
self,
|
| 585 |
+
logits: torch.Tensor,
|
| 586 |
+
progress: torch.Tensor,
|
| 587 |
+
entropy_target_max: float,
|
| 588 |
+
entropy_target_min: float = 0.0,
|
| 589 |
+
) -> torch.Tensor:
|
| 590 |
+
"""
|
| 591 |
+
Adjust temperature to achieve EXACTLY the target entropy per position.
|
| 592 |
+
This is a TWO-SIDED constraint: both raises and lowers entropy as needed.
|
| 593 |
+
|
| 594 |
+
Args:
|
| 595 |
+
logits: Raw logits (num_masked, V)
|
| 596 |
+
progress: Per-position convergence progress (num_masked,)
|
| 597 |
+
entropy_target_max: Target entropy at progress=0
|
| 598 |
+
entropy_target_min: Target entropy at progress=1 (usually ~0)
|
| 599 |
+
|
| 600 |
+
Returns:
|
| 601 |
+
probs: Probabilities with entropy matching targets
|
| 602 |
+
"""
|
| 603 |
+
if entropy_target_max <= 0:
|
| 604 |
+
return torch.softmax(logits, dim=-1)
|
| 605 |
+
|
| 606 |
+
target_entropy = entropy_target_max * (1.0 - progress) + entropy_target_min * progress
|
| 607 |
+
|
| 608 |
+
temps = self._find_temperature_for_target_entropy(logits, target_entropy)
|
| 609 |
+
|
| 610 |
+
temps = temps.unsqueeze(-1).clamp(min=1e-6)
|
| 611 |
+
return torch.softmax(logits / temps, dim=-1)
|
| 612 |
+
|
| 613 |
+
def _apply_smear_effect(
    self,
    probs: torch.Tensor,
    mask_pos: torch.Tensor,
    progress_full: torch.Tensor,
    smear_sigma_max: float,
) -> torch.Tensor:
    """
    Apply positional smearing with per-position sigma based on progress.
    Low progress = high smearing, high progress = no smearing.

    Note: This operates on full (B, L, V) tensor because smearing mixes across positions.

    NOTE(review): although sigma is computed per position, a single batch-wide
    AVERAGE sigma (over masked positions) is used to build one shared Gaussian
    kernel - an approximation; per-position blending below partially restores
    the per-position schedule.
    """
    if smear_sigma_max <= 0:
        return probs

    B, L, V = probs.shape

    # Per-position sigma shrinks to 0 as progress reaches 1.
    sigma_per_pos = smear_sigma_max * (1.0 - progress_full)

    avg_sigma = sigma_per_pos[mask_pos].mean().item()

    # Below ~0.1 the kernel is effectively a delta; skip the work.
    if avg_sigma < 0.1:
        return probs

    # Gaussian kernel over position distances, row-normalized so each output
    # position receives a convex combination of inputs.
    positions = torch.arange(L, device=probs.device, dtype=probs.dtype)
    diff = positions.unsqueeze(0) - positions.unsqueeze(1)
    kernel = torch.exp(-0.5 * (diff / avg_sigma) ** 2)
    kernel = kernel / kernel.sum(dim=1, keepdim=True)

    # smeared[b, i, v] = sum_j kernel[i, j] * probs[b, j, v]
    smeared = torch.einsum('ij,bjv->biv', kernel, probs)
    # Renormalize over the vocab axis to keep valid distributions.
    smeared = smeared / smeared.sum(dim=-1, keepdim=True).clamp(min=1e-10)

    # Per-position blend: converged positions (progress ~1) keep their own
    # distribution, uncertain ones take the smeared mixture.
    blend = progress_full.unsqueeze(-1)
    result = blend * probs + (1 - blend) * smeared

    # Only masked positions are modified; others stay untouched.
    output = probs.clone()
    output[mask_pos] = result[mask_pos]
    return output
|
| 652 |
+
|
| 653 |
+
def _apply_noise_effect(
|
| 654 |
+
self,
|
| 655 |
+
logits: torch.Tensor,
|
| 656 |
+
progress: torch.Tensor,
|
| 657 |
+
noise_std_max: float,
|
| 658 |
+
) -> torch.Tensor:
|
| 659 |
+
"""
|
| 660 |
+
Add Gaussian noise to logits based on convergence progress.
|
| 661 |
+
Low progress = high noise, high progress = no noise.
|
| 662 |
+
"""
|
| 663 |
+
if noise_std_max <= 0:
|
| 664 |
+
return logits
|
| 665 |
+
|
| 666 |
+
noise_std = noise_std_max * (1.0 - progress)
|
| 667 |
+
noise_std = noise_std.unsqueeze(-1)
|
| 668 |
+
|
| 669 |
+
noise = torch.randn_like(logits) * noise_std
|
| 670 |
+
return logits + noise
|
| 671 |
+
|
| 672 |
+
def _apply_iteration_rope(
    self,
    embeds: torch.Tensor,
    iteration: int,
    total_iterations: int,
    dim_fraction: float = 0.25,
    base: float = 10000.0,
) -> torch.Tensor:
    """
    Apply rotary embedding based on iteration progress.
    Uses a subset of dimensions (the LAST rot_dim channels) to avoid
    interfering with position RoPE.

    Args:
        embeds: (L, H) or (B, L, H) embeddings
        iteration: current iteration (0-indexed)
        total_iterations: total iteration count
        dim_fraction: fraction of H to rotate (<=0 disables the effect)
        base: RoPE frequency base

    NOTE(review): 3.14159 below is a truncated pi (math.pi would be exact);
    the rotation sweeps [0, ~pi] over the iteration schedule.
    """
    if dim_fraction <= 0:
        return embeds

    H = embeds.shape[-1]
    rot_dim = int(H * dim_fraction)
    # Round down to an even count: channels rotate in (even, odd) pairs.
    rot_dim = rot_dim - (rot_dim % 2)

    if rot_dim < 2:
        return embeds

    # max(..., 1) guards the single-iteration case.
    progress = iteration / max(total_iterations - 1, 1)

    # Standard RoPE inverse frequencies over the rotated sub-dimensions.
    inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2, device=embeds.device, dtype=embeds.dtype) / rot_dim))
    angles = progress * inv_freq * 3.14159
    cos, sin = torch.cos(angles), torch.sin(angles)

    # Broadcast cos/sin over the leading (sequence / batch+sequence) axes.
    if embeds.dim() == 2:
        cos, sin = cos.unsqueeze(0), sin.unsqueeze(0)
    elif embeds.dim() == 3:
        cos = cos.unsqueeze(0).unsqueeze(0)
        sin = sin.unsqueeze(0).unsqueeze(0)

    # Rotate interleaved (even, odd) pairs within the last rot_dim channels;
    # the first H - rot_dim channels pass through unchanged.
    embeds_out = embeds.clone()
    x1, x2 = embeds[..., -rot_dim::2], embeds[..., -rot_dim+1::2]
    embeds_out[..., -rot_dim::2] = x1 * cos - x2 * sin
    embeds_out[..., -rot_dim+1::2] = x1 * sin + x2 * cos

    return embeds_out
|
| 712 |
+
|
| 713 |
+
# ==================== FLOW MATCHING ====================
|
| 714 |
+
|
| 715 |
+
def _sample_flow_matching_t(self, num_tokens: int, device: torch.device) -> torch.Tensor:
|
| 716 |
+
"""Sample per-token time levels for flow matching.
|
| 717 |
+
|
| 718 |
+
Returns:
|
| 719 |
+
t: (num_tokens,) tensor of time levels in [t_min, t_max]
|
| 720 |
+
"""
|
| 721 |
+
dist = self.config.flow_matching_t_distribution
|
| 722 |
+
if dist == "logit_normal":
|
| 723 |
+
z = torch.randn(num_tokens, device=device)
|
| 724 |
+
z = z * self.config.flow_matching_t_logit_std + self.config.flow_matching_t_logit_mean
|
| 725 |
+
t = torch.sigmoid(z)
|
| 726 |
+
elif dist == "uniform":
|
| 727 |
+
t = torch.empty(num_tokens, device=device).uniform_(0, 1)
|
| 728 |
+
else:
|
| 729 |
+
raise ValueError(f"Unknown flow_matching_t_distribution: {dist}")
|
| 730 |
+
return t.clamp(self.config.flow_matching_t_min, self.config.flow_matching_t_max)
|
| 731 |
+
|
| 732 |
+
def compute_flow_matching_distillation_loss(
    self,
    input_ids: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor,
    flow_noise_embed: torch.Tensor,
    flow_t: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
) -> SelfDistillationOutput:
    """
    CFM flow matching distillation: teacher sees state at time t, student sees
    noisier state at time s < t on the same interpolation path.

    Both should predict the same endpoint (target token). The student must
    learn to refine from noisier inputs by matching the teacher's predictions.

    Args:
        input_ids: Input with [MASK] tokens at positions to predict
        teacher_logits: Logits from the forward pass (will be detached)
        labels: Target tokens at masked positions (-100 elsewhere)
        flow_noise_embed: (num_masked, H) noise embeddings from forward
        flow_t: (num_masked,) per-token time levels from forward
        attention_mask: Standard attention mask
        position_ids: Position IDs (if needed by base model)

    Returns:
        SelfDistillationOutput with loss, logits, time gap, and diagnostics.
        NOTE(review): the ``degradation_temperature`` field is reused here to
        carry the mean teacher-student time gap.
    """
    mask_id = self.config.mask_token_id
    mask_pos = (input_ids == mask_id)  # (B, L)
    device = input_ids.device
    num_masked = mask_pos.sum().item()

    # Degenerate case: nothing to distill (still return a grad-capable zero).
    if num_masked == 0:
        zero = torch.tensor(0.0, device=device, requires_grad=True)
        dummy = torch.zeros(1, device=device)
        return SelfDistillationOutput(zero, dummy, dummy, 0.0, 0.0, 0.0, 1.0)

    # Teacher is a fixed target: no gradient flows into it.
    teacher_logits = teacher_logits.detach()

    embed_weight = self.embed_weight
    mask_emb = embed_weight[mask_id]  # (H,)
    base_embeds = self.get_input_embeddings()(input_ids)  # (B, L, H)

    # Target embeddings from labels
    target_ids = labels[mask_pos]  # (num_masked,)
    target_embed = embed_weight[target_ids]  # (num_masked, H)

    # Sample student time s ~ U(0, t) per token
    s_per_token = flow_t * torch.rand(num_masked, device=device)  # (num_masked,)

    # Student state: same noise, earlier time (noisier)
    s_col = s_per_token.unsqueeze(-1).to(base_embeds.dtype)  # (num_masked, 1)
    student_interp = (1 - s_col) * flow_noise_embed + s_col * target_embed

    # Mask embedding: either scaled down as t -> 1, or added at full strength.
    if self.config.flow_matching_mask_scale:
        student_masked_embeds = student_interp + (1 - s_col) * mask_emb
    else:
        student_masked_embeds = student_interp + mask_emb

    # Build full student input (detached — gradient only flows through student's forward)
    student_embeds = base_embeds.detach().clone()
    student_embeds[mask_pos] = student_masked_embeds.detach()

    student_inputs = torch.where(
        mask_pos.unsqueeze(-1), student_embeds, base_embeds.detach()
    )

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=base_embeds.dtype)

    student_out = self.mlm(
        inputs_embeds=student_inputs,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
    )
    student_logits = student_out.logits  # (B, L, V) — has gradient

    # KL divergence loss on masked positions
    t_logits = teacher_logits[mask_pos]  # (num_masked, V)
    s_logits = student_logits[mask_pos]  # (num_masked, V)

    teacher_probs = F.softmax(t_logits, dim=-1)
    student_log_probs = F.log_softmax(s_logits, dim=-1)

    # KL(teacher || student); batchmean divides by num_masked rows.
    kl_loss = F.kl_div(
        student_log_probs,
        teacher_probs,
        reduction="batchmean",
    )

    # Diagnostic metrics
    with torch.no_grad():
        teacher_log_probs = torch.log(teacher_probs + 1e-10)
        teacher_entropy = -(teacher_probs * teacher_log_probs).sum(dim=-1).mean().item()

        student_probs = F.softmax(s_logits.detach(), dim=-1)
        student_log_probs_det = torch.log(student_probs + 1e-10)
        student_entropy = -(student_probs * student_log_probs_det).sum(dim=-1).mean().item()

        # Fraction of masked positions where teacher and student argmax agree.
        agreement = (t_logits.argmax(dim=-1) == s_logits.detach().argmax(dim=-1)).float().mean().item()

    # Mean t - s gap: how much noisier the student's input was on average.
    mean_time_gap = (flow_t - s_per_token).mean().item()

    return SelfDistillationOutput(
        loss=kl_loss,
        teacher_logits=teacher_logits,
        student_logits=student_logits,
        degradation_temperature=mean_time_gap,
        teacher_entropy=teacher_entropy,
        student_entropy=student_entropy,
        agreement_rate=agreement,
    )
|
| 847 |
+
|
| 848 |
+
# ==================== SELF-DISTILLATION (legacy) ====================
|
| 849 |
+
|
| 850 |
+
def compute_self_distillation_loss(
    self,
    input_ids: torch.Tensor,
    teacher_logits: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    temperature_min: Optional[float] = None,
    temperature_max: Optional[float] = None,
    temperature_distribution: Optional[str] = None,
) -> SelfDistillationOutput:
    """
    CFM-style self-distillation: model's predictions should be consistent
    across different levels of input degradation.

    Process:
    1. Take teacher logits (from standard forward pass, DETACHED)
    2. Degrade: per-token random temperature → softer soft embeddings
    3. Student: forward pass from degraded embeddings → logits (has grad)
    4. Loss: KL(teacher || student) on masked positions

    Each masked token gets its own independently sampled degradation
    temperature, creating varied difficulty across the sequence.

    Args:
        input_ids: Input with [MASK] tokens at positions to predict
        teacher_logits: Pre-computed teacher logits (will be detached).
            Typically outputs.all_logits[0] or outputs.logits from standard forward.
        attention_mask: Standard attention mask
        position_ids: Position IDs (if needed by base model)
        temperature_min: Min degradation temperature (default: config value)
        temperature_max: Max degradation temperature (default: config value)
        temperature_distribution: How to sample T (default: config value)

    Returns:
        SelfDistillationOutput with loss, logits, temperature, and diagnostics

    Raises:
        ValueError: on an unrecognized temperature_distribution.
    """
    # Resolve defaults from config
    temperature_min = temperature_min if temperature_min is not None else self.config.self_distillation_temperature_min
    temperature_max = temperature_max if temperature_max is not None else self.config.self_distillation_temperature_max
    temperature_distribution = temperature_distribution if temperature_distribution is not None else self.config.self_distillation_temperature_distribution

    mask_id = self.config.mask_token_id
    mask_pos = (input_ids == mask_id)  # (B, L)
    device = input_ids.device
    num_masked = mask_pos.sum().item()

    # Handle degenerate case: no masked positions
    if num_masked == 0:
        zero = torch.tensor(0.0, device=device, requires_grad=True)
        dummy = torch.zeros(1, device=device)
        return SelfDistillationOutput(zero, dummy, dummy, 1.0, 0.0, 0.0, 1.0)

    # Ensure teacher logits are detached
    teacher_logits = teacher_logits.detach()

    embed_weight = self.embed_weight
    mask_emb = embed_weight[mask_id]  # (H,)
    base_embeds = self.get_input_embeddings()(input_ids)  # (B, L, H)

    # ===== STEP 1: Sample per-token degradation temperatures =====
    # Each masked position gets its own temperature independently
    if temperature_distribution == "log_uniform":
        # Uniform in log space, then exponentiate back.
        log_min = torch.tensor(temperature_min, device=device).log()
        log_max = torch.tensor(temperature_max, device=device).log()
        log_T = torch.empty(num_masked, device=device).uniform_(log_min.item(), log_max.item())
        T_per_token = log_T.exp()  # (num_masked,)
    elif temperature_distribution == "uniform":
        T_per_token = torch.empty(num_masked, device=device).uniform_(
            temperature_min, temperature_max
        )  # (num_masked,)
    else:
        raise ValueError(f"Unknown temperature distribution: {temperature_distribution}")

    T_mean = T_per_token.mean().item()

    # ===== STEP 2: Create degraded soft embeddings =====
    # Per-token temperature scaling: each position gets its own T
    masked_teacher_logits = teacher_logits[mask_pos]  # (num_masked, V)
    degraded_probs = F.softmax(masked_teacher_logits / T_per_token.unsqueeze(-1), dim=-1).to(embed_weight.dtype)
    degraded_soft = degraded_probs @ embed_weight + mask_emb

    degraded_soft_embeds = base_embeds.clone()
    degraded_soft_embeds[mask_pos] = degraded_soft
    # Detach: the degraded input is a fixed starting point, not a grad path.
    degraded_soft_embeds = degraded_soft_embeds.detach()

    # ===== STEP 3: Student forward from degraded input =====
    student_inputs = torch.where(
        mask_pos.unsqueeze(-1), degraded_soft_embeds, base_embeds.detach()
    )

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=base_embeds.dtype)

    student_out = self.mlm(
        inputs_embeds=student_inputs,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
    )
    student_logits = student_out.logits  # (B, L, V) — has gradient!

    # ===== STEP 4: KL divergence loss on masked positions =====
    t_logits = teacher_logits[mask_pos]  # (num_masked, V)
    s_logits = student_logits[mask_pos]  # (num_masked, V)

    teacher_probs = F.softmax(t_logits, dim=-1)
    student_log_probs = F.log_softmax(s_logits, dim=-1)

    # KL(teacher || student) = sum teacher * (log_teacher - log_student)
    kl_loss = F.kl_div(
        student_log_probs,
        teacher_probs,
        reduction="batchmean",
    )

    # ===== STEP 5: Compute diagnostic metrics =====
    with torch.no_grad():
        teacher_log_probs = torch.log(teacher_probs + 1e-10)
        teacher_entropy = -(teacher_probs * teacher_log_probs).sum(dim=-1).mean().item()

        student_probs = F.softmax(s_logits.detach(), dim=-1)
        student_log_probs_det = torch.log(student_probs + 1e-10)
        student_entropy = -(student_probs * student_log_probs_det).sum(dim=-1).mean().item()

        # Fraction of masked positions where teacher and student argmax agree.
        agreement = (t_logits.argmax(dim=-1) == s_logits.detach().argmax(dim=-1)).float().mean().item()

    return SelfDistillationOutput(
        loss=kl_loss,
        teacher_logits=teacher_logits,
        student_logits=student_logits,
        degradation_temperature=T_mean,
        teacher_entropy=teacher_entropy,
        student_entropy=student_entropy,
        agreement_rate=agreement,
    )
|
| 985 |
+
|
| 986 |
+
# ==================== MAIN SOFT EMBEDDING COMPUTATION ====================
|
| 987 |
+
|
| 988 |
+
@torch.no_grad()
def _compute_next_soft_embeds(
    self,
    logits: torch.Tensor,
    mask_pos: torch.Tensor,
    base_embeds: torch.Tensor,
    prev_soft_embeds: Optional[torch.Tensor] = None,
    iteration: int = 0,
    total_iterations: int = 1,
    # === Schedule parameters (default to config values) ===
    schedule: Optional[str] = None,
    causal_strength: Optional[float] = None,
    # === Effect parameters (default to config values) ===
    temperature_max: Optional[float] = None,
    entropy_target_max: Optional[float] = None,
    entropy_floor_max: Optional[float] = None,
    smear_sigma_max: Optional[float] = None,
    noise_std_max: Optional[float] = None,
    iteration_rope_dim_fraction: Optional[float] = None,
) -> torch.Tensor:
    """
    Compute soft embeddings from logits for the next iteration.

    This function implements a unified "convergence schedule" system that controls
    when each position is allowed to converge to a confident prediction.

    Schedule Types:
        "linear": All positions converge at the same rate (iteration-based only)
        "causal": Early positions converge first, late positions last

    Effects (mechanisms to enforce the schedule):
        temperature_max: High temperature = more uniform distribution (one-sided)
        entropy_target_max: Force EXACT entropy via bisection search (two-sided, recommended)
        entropy_floor_max: Force MINIMUM entropy (one-sided, only prevents too confident)
        smear_sigma_max: Spread probability across neighboring positions
        noise_std_max: Add Gaussian noise to logits

    All parameters default to their config values if not specified.

    NOTE: runs under @torch.no_grad() and returns a detached tensor — the soft
    embeddings fed into the next iteration never carry gradient from this path.

    Args:
        logits: Output logits from current iteration (B, L, V)
        mask_pos: Boolean mask indicating which positions are masked (B, L)
        base_embeds: Base token embeddings for non-masked positions (B, L, H)
        prev_soft_embeds: Previous iteration's soft embeddings, used for EMA
            blending when config.soft_embedding_ema_step < 1.0
        iteration: Current iteration index (0-indexed)
        total_iterations: Total number of iterations

    Returns:
        Soft embeddings for next iteration (B, L, H), detached.
    """
    # Use config values as defaults (explicit args override per-call)
    schedule = schedule if schedule is not None else self.config.schedule
    causal_strength = causal_strength if causal_strength is not None else self.config.causal_strength
    temperature_max = temperature_max if temperature_max is not None else self.config.temperature_max
    entropy_target_max = entropy_target_max if entropy_target_max is not None else self.config.entropy_target_max
    entropy_floor_max = entropy_floor_max if entropy_floor_max is not None else self.config.entropy_floor_max
    smear_sigma_max = smear_sigma_max if smear_sigma_max is not None else self.config.smear_sigma_max
    noise_std_max = noise_std_max if noise_std_max is not None else self.config.noise_std_max
    iteration_rope_dim_fraction = iteration_rope_dim_fraction if iteration_rope_dim_fraction is not None else self.config.iteration_rope_dim_fraction

    soft_embeds = base_embeds.clone()

    # Nothing to refine: return base embeddings unchanged.
    if not mask_pos.any():
        return soft_embeds.detach()

    B, L, V = logits.shape
    device, dtype = logits.device, logits.dtype

    # Check if any effects are enabled
    has_effects = (
        temperature_max > 0 or
        entropy_target_max > 0 or
        entropy_floor_max > 0 or
        smear_sigma_max > 0 or
        noise_std_max > 0 or
        iteration_rope_dim_fraction > 0
    )

    if not has_effects:
        # Simple path: no convergence schedule effects
        masked_logits = logits[mask_pos]
        embed_weight = self.embed_weight

        # Convert logits to mixing weights based on soft_embedding_method
        if self.config.soft_embedding_method == "none":
            weights = masked_logits
        elif self.config.soft_embedding_method == "l2_normalize":
            weights = F.normalize(masked_logits, p=2, dim=-1)
        else:
            # NOTE(review): self.normalize presumably maps logits to a
            # probability simplex (e.g. softmax) — defined elsewhere in class.
            weights = self.normalize(masked_logits)

        # Mix token embeddings by the weights, then add the [MASK] embedding
        # so the model can still tell which positions are being refined.
        masked_soft = weights @ embed_weight
        mask_emb = embed_weight[self.config.mask_token_id]
        masked_soft = masked_soft + mask_emb

        # Apply EMA blending with previous soft embeddings if enabled
        ema_step = self.config.soft_embedding_ema_step
        if ema_step < 1.0 and prev_soft_embeds is not None:
            prev_masked_soft = prev_soft_embeds[mask_pos]
            masked_soft = (1.0 - ema_step) * prev_masked_soft + ema_step * masked_soft

        soft_embeds[mask_pos] = masked_soft
        return soft_embeds.detach()

    # ========== STEP 1: Compute per-position convergence progress ==========
    batch_indices, position_indices = torch.where(mask_pos)

    # progress is per masked position (flattened over the batch).
    progress = self._compute_convergence_progress(
        iteration=iteration,
        total_iterations=total_iterations,
        seq_length=L,
        mask_positions=position_indices,
        schedule=schedule,
        causal_strength=causal_strength,
        device=device,
        dtype=dtype,
    )

    # Compute full (B, L) progress for smearing if needed
    if smear_sigma_max > 0:
        all_positions = torch.arange(L, device=device, dtype=dtype)
        progress_full = self._compute_convergence_progress(
            iteration=iteration,
            total_iterations=total_iterations,
            seq_length=L,
            mask_positions=all_positions,
            schedule=schedule,
            causal_strength=causal_strength,
            device=device,
            dtype=dtype,
        )
        progress_full = progress_full.unsqueeze(0).expand(B, -1)

    # ========== STEP 2: Apply smearing (needs full tensor) ==========
    full_probs = self.normalize(logits)

    if smear_sigma_max > 0:
        full_probs = self._apply_smear_effect(
            full_probs, mask_pos, progress_full, smear_sigma_max
        )

    # ========== STEP 3: Extract masked positions ==========
    masked_logits = logits[mask_pos]
    masked_probs = full_probs[mask_pos]

    # ========== STEP 4: Apply temperature effect (on logits) ==========
    # Skipped when entropy targeting is active — the two effects would fight.
    if temperature_max > 0 and entropy_target_max <= 0:
        masked_logits = self._apply_temperature_effect(
            masked_logits, progress, temperature_max
        )
        # NOTE: this discards any smearing applied in STEP 2 for these
        # positions, since probs are recomputed from tempered logits.
        masked_probs = torch.softmax(masked_logits, dim=-1)

    # ========== STEP 5: Apply noise effect (on logits) ==========
    if noise_std_max > 0:
        # Noise is injected in log-space of the current probs so it composes
        # with the previous effects; 1e-10 guards log(0).
        masked_logits_noisy = self._apply_noise_effect(
            torch.log(masked_probs + 1e-10), progress, noise_std_max
        )
        masked_probs = torch.softmax(masked_logits_noisy, dim=-1)

    # ========== STEP 6: Apply entropy control ==========
    if entropy_target_max > 0:
        # Two-sided: rebuilds probs from raw logits at the target entropy.
        masked_probs = self._apply_target_entropy_effect(
            masked_logits, progress, entropy_target_max
        )
    elif entropy_floor_max > 0:
        # One-sided: only raises entropy of over-confident positions.
        masked_probs = self._apply_entropy_floor_effect(
            masked_probs, progress, entropy_floor_max
        )

    # ========== STEP 7: Compute soft embeddings ==========
    embed_weight = self.embed_weight

    # Convert to mixing weights based on soft_embedding_method
    if self.config.soft_embedding_method == "none":
        # No normalization - use raw logits directly
        weights = masked_logits
    elif self.config.soft_embedding_method == "l2_normalize":
        # L2 normalize bypasses all the softmax-based effects above
        weights = F.normalize(masked_logits, p=2, dim=-1)
    else:
        weights = masked_probs

    masked_soft = weights @ embed_weight
    mask_emb = embed_weight[self.config.mask_token_id]
    masked_soft = masked_soft + mask_emb

    # ========== STEP 8: Apply iteration RoPE ==========
    # Rotates a fraction of embedding dims by the iteration index so the model
    # can condition on "how far along" the refinement is.
    if iteration_rope_dim_fraction > 0:
        masked_soft = self._apply_iteration_rope(
            masked_soft, iteration, total_iterations, iteration_rope_dim_fraction
        )

    # ========== STEP 8.5: Apply EMA blending ==========
    ema_step = self.config.soft_embedding_ema_step
    if ema_step < 1.0 and prev_soft_embeds is not None:
        prev_masked_soft = prev_soft_embeds[mask_pos]
        masked_soft = (1.0 - ema_step) * prev_masked_soft + ema_step * masked_soft

    # ========== STEP 9: Place back and return ==========
    soft_embeds[mask_pos] = masked_soft

    return soft_embeds.detach()
@torch.no_grad()
def _compute_iteration_metrics(
    self, logits: torch.Tensor, labels: torch.Tensor
) -> IterationMetrics:
    """
    Compute token-level AND sequence-level metrics for a single iteration.
    Returns scalars only - no large tensor storage.

    Token-level metrics:
    - accuracy: fraction of correct token predictions
    - entropy: average entropy per token
    - softmax_ce: cross-entropy loss per token

    Sequence-level metrics:
    - full_sequence_accuracy: fraction of sequences where ALL tokens are correct
    - min_sequence_confidence: mean of minimum top-1 confidence per sequence

    Args:
        logits: Prediction logits (B, L, V)
        labels: Target IDs (B, L); positions with -100 are ignored

    Returns:
        IterationMetrics with all five scalar fields populated.
    """
    # Move to CPU to avoid GPU OOM - metrics are for monitoring only
    logits = logits.detach().cpu().float()  # float32 is sufficient for metrics
    target_labels = labels.detach().cpu().contiguous()
    mask = target_labels != -100

    # No supervised positions at all: every metric is trivially zero.
    if mask.sum() == 0:
        return IterationMetrics(
            accuracy=0.0,
            entropy=0.0,
            softmax_ce=0.0,
            full_sequence_accuracy=0.0,
            min_sequence_confidence=0.0,
        )

    logits = logits.contiguous()
    predictions = logits.argmax(dim=-1)
    correct = (predictions == target_labels) & mask

    # ===== TOKEN-LEVEL METRICS =====

    # Token accuracy over supervised positions
    accuracy = (correct.sum() / mask.sum()).item()

    # Extract valid tokens for entropy/CE
    valid_logits = logits[mask]
    valid_labels = target_labels[mask]

    # Entropy (using log_softmax for numerical stability)
    log_probs = torch.nn.functional.log_softmax(valid_logits, dim=-1)
    probs = torch.exp(log_probs)
    entropy = -(probs * log_probs).sum(dim=-1).mean().item()

    # Cross-entropy
    softmax_ce = torch.nn.functional.cross_entropy(
        valid_logits, valid_labels, reduction="mean"
    ).item()

    # ===== SEQUENCE-LEVEL METRICS =====

    # Check which sequences have valid tokens
    sequences_with_tokens = mask.any(dim=1)  # (B,)
    num_valid_sequences = sequences_with_tokens.sum().item()

    if num_valid_sequences == 0:
        return IterationMetrics(
            accuracy=accuracy,
            entropy=entropy,
            softmax_ce=softmax_ce,
            full_sequence_accuracy=0.0,
            min_sequence_confidence=0.0,
        )

    # Full sequence accuracy: all tokens in sequence must be correct
    num_correct_per_seq = correct.sum(dim=1)  # (B,)
    num_tokens_per_seq = mask.sum(dim=1)  # (B,)
    all_correct = (num_correct_per_seq == num_tokens_per_seq) & sequences_with_tokens
    full_seq_accuracy = (all_correct.sum() / num_valid_sequences).item()

    # Min sequence confidence: minimum top-1 probability within each sequence.
    # Vectorized (was a Python loop over the batch): fill non-supervised
    # positions with +inf so they never win the per-row min, take the row
    # minimum, then average over sequences that actually have valid tokens.
    probs_full = torch.softmax(logits, dim=-1)  # (B, L, V) - already float32
    top1_confidence = probs_full.max(dim=-1).values  # (B, L)
    per_seq_min = top1_confidence.masked_fill(~mask, float("inf")).min(dim=1).values  # (B,)
    min_seq_conf = per_seq_min[sequences_with_tokens].mean().item()

    return IterationMetrics(
        accuracy=accuracy,
        entropy=entropy,
        softmax_ce=softmax_ce,
        full_sequence_accuracy=full_seq_accuracy,
        min_sequence_confidence=min_seq_conf,
    )
def _single_iteration(
    self,
    t: int,
    T: int,
    soft_embeds: torch.Tensor,
    base_embeds: torch.Tensor,
    mask_pos: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    labels: Optional[torch.Tensor],
    compute_metrics: bool,
    position_ids: Optional[torch.Tensor] = None,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[IterationMetrics]]:
    """
    Run one pass of recursive refinement through the base MLM.

    Args:
        t: Current iteration index (0 to T-1)
        T: Total number of iterations
        soft_embeds: Soft embeddings for mask positions
        base_embeds: Base token embeddings from input_ids
        mask_pos: Boolean mask of [MASK] positions (B, L)
        attention_mask: Attention mask for MLM
        labels: Target labels for loss computation
        compute_metrics: Whether to compute iteration metrics
        position_ids: Optional explicit position IDs forwarded to the MLM

    Returns:
        logits: Output logits from MLM (B, L, V)
        weighted_loss: Loss scaled by step_weight(t, T), or None if no labels
        metrics: IterationMetrics, or None if not requested
    """
    # Soft embeddings at [MASK] positions, plain token embeddings elsewhere.
    blended = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds)

    outputs = self.mlm(
        inputs_embeds=blended,
        attention_mask=attention_mask,
        position_ids=position_ids,
        labels=labels,
        return_dict=True,
        **kwargs,
    )
    logits = outputs.logits

    # Per-iteration loss, scaled by the iteration weight.
    weighted_loss = outputs.loss
    if labels is not None:
        if weighted_loss is None:
            # Base model did not compute a loss (e.g. LLaDA): fall back to
            # MDLM-style cross-entropy over MASKED positions only.
            # (-100 entries act as ignored padding.)
            weighted_loss = CrossEntropyLoss()(logits[mask_pos], labels[mask_pos])
        weighted_loss *= self.step_weight(t, T)

    metrics = (
        self._compute_iteration_metrics(logits, labels)
        if compute_metrics and labels is not None
        else None
    )

    return logits, weighted_loss, metrics
def forward(
    self,
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    num_recursions: Optional[int] = None,
    compute_iteration_metrics: bool = False,
    use_recursion_checkpointing: Optional[bool] = None,
    # Parameters for single-iteration training mode (DEPRECATED)
    prev_soft_embeds: Optional[torch.Tensor] = None,
    run_set_iteration: Optional[int] = None,
    # === Convergence schedule parameters (None = use config defaults) ===
    schedule: Optional[str] = None,
    causal_strength: Optional[float] = None,
    # === Effect parameters (None = use config defaults) ===
    temperature_max: Optional[float] = None,
    entropy_target_max: Optional[float] = None,
    entropy_floor_max: Optional[float] = None,
    smear_sigma_max: Optional[float] = None,
    noise_std_max: Optional[float] = None,
    iteration_rope_dim_fraction: Optional[float] = None,
    **kwargs,
) -> RecursiveMaskedLMOutput:
    """
    Forward with recursive refinement.

    Supports three modes:
    1. Checkpointed mode (default): Run all T recursions with gradient checkpointing.
       Gradients flow through the entire chain; activations recomputed during backward.
    2. Non-checkpointed mode (use_recursion_checkpointing=False): Store all activations.
       Faster backward but higher memory.
    3. Single-iteration mode (DEPRECATED - run_set_iteration is not None): Run only one
       iteration. Use use_recursion_checkpointing=True instead.

    Loss Weighting (config.loss_weight):
        "last_1": Only final iteration loss (enables learning convergence behavior)
        "last_2": Last 2 iterations
        "linear": All iterations, linearly weighted (default)
        "uniform": All iterations, uniformly weighted

    Recursion Checkpointing:
        use_recursion_checkpointing: Enable gradient checkpointing for iterations.
            True = checkpoint each iteration, recompute during backward (default).
            False = store all activations (higher memory, faster backward).

    Convergence Schedule Parameters:
        All schedule/effect parameters default to their config values if not specified.
        Pass explicit values to override config for this forward pass.

        schedule: "linear" or "causal" - controls when positions can converge
        causal_strength: How much faster early positions converge (causal only)
        temperature_max: Max temperature boost for uncertain positions
        entropy_target_max: Target entropy at progress=0 (two-sided, recommended)
        entropy_floor_max: Min entropy floor (one-sided)
        smear_sigma_max: Max Gaussian sigma for position smearing
        noise_std_max: Max std of Gaussian noise on logits
        iteration_rope_dim_fraction: Fraction of dims for iteration RoPE

    NOTE: in the default (checkpointed / non-checkpointed) path, loss is
    returned as None — the trainer computes the loss from `all_logits`.
    Loss is only computed here in the deprecated single-iteration mode.
    """
    B, L = input_ids.shape
    V = self.embed_weight.shape[0]
    mask_id = self.config.mask_token_id

    if mask_id is None:
        raise ValueError("mask_token_id must be set")

    # Resolve config default for recursion checkpointing
    use_recursion_checkpointing = (
        use_recursion_checkpointing
        if use_recursion_checkpointing is not None
        else self.config.use_recursion_checkpointing
    )

    mask_pos = (input_ids == mask_id)  # (B, L)
    base_embeds = self.get_input_embeddings()(input_ids)  # (B, L, H)
    T = num_recursions or self.config.num_recursions
    # Normalizer for per-iteration loss weights (used in single-iteration mode).
    weight_sum = sum(self.step_weight(i, T) for i in range(T))

    # Bundle schedule kwargs to pass to _compute_next_soft_embeds
    schedule_kwargs = dict(
        schedule=schedule,
        causal_strength=causal_strength,
        temperature_max=temperature_max,
        entropy_target_max=entropy_target_max,
        entropy_floor_max=entropy_floor_max,
        smear_sigma_max=smear_sigma_max,
        noise_std_max=noise_std_max,
        iteration_rope_dim_fraction=iteration_rope_dim_fraction,
    )

    # ===== SINGLE ITERATION MODE (DEPRECATED) =====
    if run_set_iteration is not None:
        warnings.warn(
            "run_set_iteration is deprecated. Use use_recursion_checkpointing=True instead, "
            "which provides proper gradient flow through all iterations.",
            DeprecationWarning,
            stacklevel=2,
        )
        t = run_set_iteration

        # Get soft embeddings for this iteration
        if t == 0:
            # t=0: Uniform prior = average embedding (equivalent to softmax(zeros) @ embed_weight)
            # We compute this efficiently via embed_weight.mean() rather than creating large zero tensors
            soft_embeds = base_embeds.clone()
            if mask_pos.any():
                avg_embed = self.embed_weight.mean(dim=0)  # (H,) - mean over all V tokens
                mask_emb = self.embed_weight[mask_id]
                soft_embeds[mask_pos] = avg_embed + mask_emb
        else:
            if prev_soft_embeds is None:
                raise ValueError(f"prev_soft_embeds must be provided for iteration {t}")
            soft_embeds = prev_soft_embeds

        logits, weighted_loss, metrics = self._single_iteration(
            t, T, soft_embeds, base_embeds, mask_pos,
            attention_mask, labels, compute_iteration_metrics,
            position_ids=position_ids, **kwargs
        )

        # Normalize loss by total weight sum
        loss = weighted_loss / weight_sum if weighted_loss is not None else None

        # Compute soft embeddings for next iteration (if not last)
        next_soft_embeds = None
        if t < T - 1:
            next_soft_embeds = self._compute_next_soft_embeds(
                logits, mask_pos, base_embeds,
                iteration=t,
                total_iterations=T,
                **schedule_kwargs,
            )

        return RecursiveMaskedLMOutput(
            loss=loss,
            logits=logits,
            next_soft_embeds=next_soft_embeds,
            iteration_metrics={t: metrics} if metrics is not None else None,
        )

    # ===== CHECKPOINTED MODE (gradient flow through all iterations) =====
    embed_weight = self.embed_weight
    mask_emb = embed_weight[mask_id]  # (H,)

    # Temperature must be a tensor for checkpointing (checkpoint requires tensor inputs)
    temperature = torch.tensor(
        self.config.temperature,
        device=input_ids.device,
        dtype=base_embeds.dtype,
    )

    # Ensure attention_mask is a tensor (required for checkpointing)
    if attention_mask is None:
        attention_mask = torch.ones(B, L, device=input_ids.device, dtype=base_embeds.dtype)

    # Initialize soft embeddings for masked positions
    soft_embeds = base_embeds.clone()
    flow_noise_embed = None
    flow_t_per_token = None

    if self.config.flow_matching_enabled and self.training and labels is not None and mask_pos.any():
        # Flow matching: interpolate between random noise and target on the simplex
        num_masked = mask_pos.sum().item()
        V = embed_weight.shape[0]
        device = input_ids.device

        # Sample per-token time levels (logit-normal by default)
        flow_t_per_token = self._sample_flow_matching_t(num_masked, device)

        # Random noise embedding: sample on simplex, project to H-dim
        z = torch.randn(num_masked, V, device=device, dtype=base_embeds.dtype)
        p_noise = F.softmax(z * self.config.flow_matching_noise_scale, dim=-1).to(base_embeds.dtype)
        flow_noise_embed = p_noise @ embed_weight  # (num_masked, H) — compact

        # Target embedding from labels.
        # NOTE(review): assumes labels hold the ORIGINAL token IDs at every
        # masked position (no -100 there) — verify against the data collator.
        target_ids = labels[mask_pos]
        target_embed = embed_weight[target_ids]  # (num_masked, H)

        # Interpolate in embedding space: t=0 is pure noise, t=1 is the target
        t_col = flow_t_per_token.unsqueeze(-1).to(base_embeds.dtype)  # (num_masked, 1)
        interp_embed = (1 - t_col) * flow_noise_embed + t_col * target_embed

        # Add mask signal (binary or scaled down as t -> 1)
        if self.config.flow_matching_mask_scale:
            soft_embeds[mask_pos] = interp_embed + (1 - t_col) * mask_emb
        else:
            soft_embeds[mask_pos] = interp_embed + mask_emb
    elif mask_pos.any():
        # Standard uniform prior (average embedding + mask signal)
        avg_embed = embed_weight.mean(dim=0)  # (H,)
        soft_embeds[mask_pos] = avg_embed + mask_emb

    iteration_metrics = {} if compute_iteration_metrics and labels is not None else None

    # Main recursion loop with optional checkpointing
    all_logits = []
    for t in range(T):
        if self.training and use_recursion_checkpointing:
            # Use checkpointing: activations recomputed during backward
            # This maintains gradient flow while saving memory
            logits, soft_embeds = torch_checkpoint(
                self._single_iteration_checkpointable,
                soft_embeds,
                base_embeds,
                mask_pos,
                attention_mask,
                embed_weight,
                mask_emb,
                temperature,
                position_ids,
                use_reentrant=False,  # Critical for nested checkpointing!
            )
        else:
            # No checkpointing: store all activations (inference or explicit disable)
            logits, soft_embeds = self._single_iteration_checkpointable(
                soft_embeds,
                base_embeds,
                mask_pos,
                attention_mask,
                embed_weight,
                mask_emb,
                temperature,
                position_ids,
            )
        all_logits.append(logits)

        # Compute iteration metrics if requested (no grad needed)
        if iteration_metrics is not None and labels is not None:
            with torch.no_grad():
                iteration_metrics[t] = self._compute_iteration_metrics(logits, labels)

    # Return all logits for trainer to compute loss with proper normalization
    # Trainer handles: timestep-based weighting, iteration weighting, batch/sequence/token normalization
    return RecursiveMaskedLMOutput(
        loss=None,  # Let trainer compute loss
        logits=logits,  # Final logits for inference/metrics
        all_logits=all_logits if self.training else None,  # Only needed during training
        iteration_metrics=iteration_metrics or None,
        flow_noise_embed=flow_noise_embed,  # For flow matching distillation
        flow_t=flow_t_per_token,  # For flow matching distillation
    )
@torch.no_grad()
def _generate_flow_map(
    self,
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    position_ids: Optional[torch.Tensor],
    num_steps: int,
) -> torch.Tensor:
    """Fill in mask positions using the CFM flow map update rule.

    Starts from a random point on the probability simplex and iteratively
    moves toward the model's predictions using the flow map step rule.

    Args:
        input_ids: Input with [MASK] tokens at positions to fill
        attention_mask: Attention mask
        position_ids: Position IDs
        num_steps: Number of flow map steps (finer = better, 1 step = greedy)

    Returns:
        Tensor with [MASK] positions filled with predicted tokens
    """
    mask_pos = (input_ids == self.config.mask_token_id)
    num_masked = mask_pos.sum().item()

    # Nothing masked: return the input unchanged (copied, so callers may mutate).
    if num_masked == 0:
        return input_ids.clone()

    device = input_ids.device
    V = self.embed_weight.shape[0]
    embed_weight = self.embed_weight
    mask_emb = embed_weight[self.config.mask_token_id]
    base_embeds = self.get_input_embeddings()(input_ids)

    # Start from random simplex point (softmax of scaled Gaussian noise),
    # matching the noise distribution used during flow-matching training.
    noise_scale = self.config.flow_matching_noise_scale
    p = F.softmax(torch.randn(num_masked, V, device=device, dtype=base_embeds.dtype) * noise_scale, dim=-1).to(base_embeds.dtype)

    # Uniform time grid 0 = pure noise ... 1 = data.
    times = torch.linspace(0, 1, num_steps + 1, device=device)

    for i in range(num_steps):
        t_now = times[i]
        t_next = times[i + 1]
        # Flow-map step size; t_now < 1 inside the loop, so no division by zero.
        step_size = (t_next - t_now) / (1 - t_now)

        # Mask signal (binary, or scaled down as t -> 1)
        if self.config.flow_matching_mask_scale:
            mask_signal = (1 - t_now) * mask_emb
        else:
            mask_signal = mask_emb

        # Project current simplex state to embedding space
        embed = p @ embed_weight + mask_signal

        soft_embeds = base_embeds.clone()
        soft_embeds[mask_pos] = embed
        inputs_embeds = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds)

        outputs = self.mlm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        # Model's predicted distribution at the masked positions
        pi = F.softmax(outputs.logits[mask_pos], dim=-1).to(p.dtype)

        # Flow map update: move toward model's prediction
        p = p + step_size * (pi - p)

        # Fix floating point drift off the simplex
        p = p.clamp(min=0)
        p = p / p.sum(dim=-1, keepdim=True)

    # Decode greedily from the final simplex state
    result = input_ids.clone()
    result[mask_pos] = p.argmax(dim=-1)
    return result
@torch.no_grad()
def generate(
    self,
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    num_recursions: Optional[int] = None,
    # === Convergence schedule parameters (None = use config defaults) ===
    schedule: Optional[str] = None,
    causal_strength: Optional[float] = None,
    # === Effect parameters (None = use config defaults) ===
    temperature_max: Optional[float] = None,
    entropy_target_max: Optional[float] = None,
    entropy_floor_max: Optional[float] = None,
    smear_sigma_max: Optional[float] = None,
    noise_std_max: Optional[float] = None,
    iteration_rope_dim_fraction: Optional[float] = None,
) -> torch.Tensor:
    """Fill in mask positions via iterative refinement.

    When ``config.flow_matching_enabled`` is set, dispatches to the CFM
    flow-map sampler (``_generate_flow_map``). Otherwise runs the standard
    recursive soft-token refinement via ``forward`` and reads off the
    argmax prediction at each masked position.

    Args:
        input_ids: Input token IDs with [MASK] tokens at positions to fill.
        attention_mask: Attention mask.
        position_ids: Optional explicit position IDs forwarded to the model.
        num_recursions: Override number of recursions/steps
            (default: ``config.num_recursions``).
        schedule: "linear" or "causal" convergence schedule.
        causal_strength: How much faster early positions converge (causal only).
        temperature_max: Max temperature boost for uncertain positions.
        entropy_target_max: Target entropy at progress=0 (two-sided).
        entropy_floor_max: Min entropy floor (one-sided).
        smear_sigma_max: Max Gaussian sigma for position smearing.
        noise_std_max: Max std of Gaussian noise on logits.
        iteration_rope_dim_fraction: Fraction of dims for iteration RoPE.

    Returns:
        Tensor with [MASK] positions filled with predicted tokens; all
        non-masked positions are copied through from ``input_ids``.
    """
    # BUG FIX: the original `num_recursions or self.config.num_recursions`
    # treated an explicit 0 as "not provided" (0 is falsy) and silently fell
    # back to the config default. Use an explicit None check so every
    # caller-supplied integer is honored.
    num_steps = (
        num_recursions if num_recursions is not None else self.config.num_recursions
    )

    if self.config.flow_matching_enabled:
        # NOTE: the schedule/effect parameters above are ignored on this
        # path — the flow-map sampler has its own fixed update rule and
        # only consumes the step count.
        return self._generate_flow_map(
            input_ids, attention_mask, position_ids, num_steps
        )

    out = self.forward(
        input_ids,
        attention_mask,
        position_ids=position_ids,
        num_recursions=num_steps,
        schedule=schedule,
        causal_strength=causal_strength,
        temperature_max=temperature_max,
        entropy_target_max=entropy_target_max,
        entropy_floor_max=entropy_floor_max,
        smear_sigma_max=smear_sigma_max,
        noise_std_max=noise_std_max,
        iteration_rope_dim_fraction=iteration_rope_dim_fraction,
    )
    # Keep the original tokens everywhere except the masked slots, which
    # receive the model's greedy (argmax) prediction.
    result = input_ids.clone()
    mask_pos = (input_ids == self.config.mask_token_id)
    result[mask_pos] = out.logits.argmax(dim=-1)[mask_pos]
    return result
|