from __future__ import annotations
from typing import Optional
from transformers import PretrainedConfig
class RecursiveMLMConfig(PretrainedConfig):
    """
    Configuration for RecursiveMaskedLM.

    Stores the base MLM config plus recursive refinement parameters.

    Core Parameters
    ---------------
    - base_model_config: Serialized config (dict) of the wrapped base MLM;
      normally produced via :meth:`from_base_model_config`.
    - num_recursions: Number of recursive refinement iterations.
    - normalization: Name of a normalization mode consumed by the model
      (stored verbatim; distinct from ``soft_embedding_method``).
    - loss_weight: Per-iteration loss weighting scheme (e.g. "linear",
      "last_1" — see Recursion Checkpointing below).
    - mask_token_id: Presumably the [MASK] token id of the base tokenizer —
      confirm against the model implementation.
    - temperature: Base temperature; exact usage is defined by the model.
    - gradient_steps: NOTE(review): semantics not visible here — presumably
      limits how many iterations receive gradients; confirm in the model.

    Convergence Schedule System
    ---------------------------
    The convergence schedule controls WHEN each position is allowed to converge
    to a confident prediction during iterative refinement.

    Schedule types:
    - "linear": All positions converge at the same rate (iteration-based only)
    - "causal": Early positions converge first, late positions last
      (``causal_strength`` scales how strongly the schedule skews causal)

    Effects (mechanisms to enforce the schedule):
    - temperature_max: Raise temperature for positions not yet allowed to converge
    - entropy_target_max: Force exact entropy via bisection search (two-sided, recommended)
    - entropy_floor_max: Force minimum entropy (one-sided, only raises)
    - smear_sigma_max: Spread probability across neighboring positions
    - noise_std_max: Add Gaussian noise to logits
    - iteration_rope_dim_fraction: Apply rotary embedding based on iteration progress

    Soft Embedding Methods
    ----------------------
    Controls how logits are converted to soft embeddings for the next iteration:
    - "softmax": Standard softmax normalization (default). Creates sparse, probabilistic
      mixing but can cause gradient bottlenecks through the softmax Jacobian.
    - "l2_normalize": L2 normalize logits before mixing with embeddings. Removes the
      softmax bottleneck for smoother gradients through long recursion chains.
    - "none": No normalization - use raw logits directly. Warning: this can cause
      scale explosion without additional mechanisms like EMA accumulation.
    - soft_embedding_ema_step: Controls EMA blending with previous soft embeddings.
      1.0 (default) = full update (no EMA), 0.1 = slow update (90% previous + 10% new).
      Formula: new = (1 - ema_step) * prev + ema_step * current

    Recursion Checkpointing
    -----------------------
    Controls gradient flow through the entire recursion chain for memory-efficient training.

    Parameters:
    - use_recursion_checkpointing: Enable gradient checkpointing for iterations
    - loss_weight: Use "last_1" for final-iteration-only loss (learns convergence behavior)

    Flow Matching (CFM-inspired)
    ----------------------------
    Replaces the old temperature-based self-distillation with a Continuous Flow Matching
    framework. Training inputs are interpolated on the probability simplex between random
    noise and the target one-hot, distillation gives the student a noisier (earlier-time)
    version of the same interpolation path, and inference uses a flow map update rule.

    Parameters:
    - flow_matching_enabled: Enable the flow matching framework
    - flow_matching_lambda: Weight of distillation KL loss relative to CE loss
    - flow_matching_t_distribution: How to sample time t ("logit_normal" or "uniform")
    - flow_matching_t_logit_mean: Mean of logit-normal distribution (-0.4 biases toward noisy)
    - flow_matching_t_logit_std: Std of logit-normal distribution
    - flow_matching_t_min: Minimum time value (clamp)
    - flow_matching_t_max: Maximum time value (clamp)
    - flow_matching_noise_scale: Scale of the random-noise endpoint of the
      interpolation path (NOTE(review): previously undocumented — presumed
      semantics; confirm against the model implementation)
    - flow_matching_mask_scale: If True, scale mask_emb by (1-t); if False, binary mask signal

    Time levels are sampled independently per masked token. At t=0 the input is pure noise,
    at t=1 it is the clean target embedding.

    Self-Distillation (legacy, temperature-based)
    ----------------------------------------------
    Kept for backward compatibility. Ignored when flow_matching_enabled=True.

    Parameters:
    - self_distillation_enabled: Enable the self-distillation KL loss
    - self_distillation_lambda: Weight of distillation loss relative to CE loss
    - self_distillation_temperature_min: Minimum degradation temperature
    - self_distillation_temperature_max: Maximum degradation temperature
    - self_distillation_temperature_distribution: How to sample temperature
    - self_distillation_teacher: Which logits to use as teacher ("first" or "last")
    """

    model_type = "recursive-mlm"

    def __init__(
        self,
        base_model_config: Optional[dict] = None,
        num_recursions: int = 8,
        normalization: str = "softmax",
        loss_weight: str = "linear",
        mask_token_id: Optional[int] = None,
        temperature: float = 1.0,
        gradient_steps: Optional[int] = None,
        # === Convergence schedule parameters ===
        schedule: str = "linear",
        causal_strength: float = 1.0,
        # === Effect parameters ===
        temperature_max: float = 0.0,
        entropy_target_max: float = 0.0,
        entropy_floor_max: float = 0.0,
        smear_sigma_max: float = 0.0,
        noise_std_max: float = 0.0,
        iteration_rope_dim_fraction: float = 0.0,
        use_recursion_checkpointing: bool = True,
        # === Soft embedding method ===
        soft_embedding_method: str = "softmax",
        soft_embedding_ema_step: float = 1.0,
        # === Flow matching parameters (CFM-inspired) ===
        flow_matching_enabled: bool = False,
        flow_matching_lambda: float = 0.5,
        flow_matching_t_distribution: str = "logit_normal",
        flow_matching_t_logit_mean: float = -0.4,
        flow_matching_t_logit_std: float = 1.0,
        flow_matching_t_min: float = 0.01,
        flow_matching_t_max: float = 0.99,
        flow_matching_noise_scale: float = 2.0,
        flow_matching_mask_scale: bool = False,
        # === Self-distillation parameters (legacy, ignored when flow_matching_enabled) ===
        self_distillation_enabled: bool = False,
        self_distillation_lambda: float = 0.5,
        self_distillation_temperature_min: float = 1.5,
        self_distillation_temperature_max: float = 10.0,
        self_distillation_temperature_distribution: str = "log_uniform",
        self_distillation_teacher: str = "first",
        **kwargs,
    ):
        # Remaining kwargs (e.g. vocab_size, pad_token_id) go to PretrainedConfig.
        super().__init__(**kwargs)
        self.base_model_config = base_model_config
        self.num_recursions = num_recursions
        self.normalization = normalization
        self.loss_weight = loss_weight
        self.mask_token_id = mask_token_id
        self.temperature = temperature
        self.gradient_steps = gradient_steps
        # Convergence schedule
        self.schedule = schedule
        self.causal_strength = causal_strength
        # Effects (each "max" is the strongest value the schedule may apply)
        self.temperature_max = temperature_max
        self.entropy_target_max = entropy_target_max
        self.entropy_floor_max = entropy_floor_max
        self.smear_sigma_max = smear_sigma_max
        self.noise_std_max = noise_std_max
        self.iteration_rope_dim_fraction = iteration_rope_dim_fraction
        # Recursion checkpointing
        self.use_recursion_checkpointing = use_recursion_checkpointing
        # Soft embedding method
        self.soft_embedding_method = soft_embedding_method
        self.soft_embedding_ema_step = soft_embedding_ema_step
        # Flow matching
        self.flow_matching_enabled = flow_matching_enabled
        self.flow_matching_lambda = flow_matching_lambda
        self.flow_matching_t_distribution = flow_matching_t_distribution
        self.flow_matching_t_logit_mean = flow_matching_t_logit_mean
        self.flow_matching_t_logit_std = flow_matching_t_logit_std
        self.flow_matching_t_min = flow_matching_t_min
        self.flow_matching_t_max = flow_matching_t_max
        self.flow_matching_noise_scale = flow_matching_noise_scale
        self.flow_matching_mask_scale = flow_matching_mask_scale
        # Self-distillation (legacy; ignored when flow_matching_enabled=True)
        self.self_distillation_enabled = self_distillation_enabled
        self.self_distillation_lambda = self_distillation_lambda
        self.self_distillation_temperature_min = self_distillation_temperature_min
        self.self_distillation_temperature_max = self_distillation_temperature_max
        self.self_distillation_temperature_distribution = self_distillation_temperature_distribution
        self.self_distillation_teacher = self_distillation_teacher

    @classmethod
    def from_base_model_config(
        cls,
        base_config: PretrainedConfig,
        **kwargs,
    ) -> "RecursiveMLMConfig":
        """Create a config from a base MLM's config.

        The base config is serialized via ``to_dict()`` and stored under
        ``base_model_config``; any extra kwargs override recursive-refinement
        defaults.
        """
        return cls(
            base_model_config=base_config.to_dict(),
            **kwargs,
        )