from __future__ import annotations

from typing import Optional

from transformers import PretrainedConfig


class RecursiveMLMConfig(PretrainedConfig):
    """
    Configuration for RecursiveMaskedLM.

    Holds the wrapped base MLM's config (as a plain dict) together with the
    hyperparameters that drive recursive refinement.

    Convergence Schedule System
    ---------------------------
    The convergence schedule decides WHEN each position may settle on a
    confident prediction during iterative refinement.

    Schedule types:
    - "linear": Every position converges at the same rate (iteration-based only)
    - "causal": Early positions converge first, late positions last

    Effects (mechanisms that enforce the schedule):
    - temperature_max: Raise temperature for positions not yet allowed to converge
    - entropy_target_max: Force exact entropy via bisection search
      (two-sided, recommended)
    - entropy_floor_max: Force minimum entropy (one-sided, only raises)
    - smear_sigma_max: Spread probability across neighboring positions
    - noise_std_max: Add Gaussian noise to logits
    - iteration_rope_dim_fraction: Apply rotary embedding based on iteration
      progress

    Soft Embedding Methods
    ----------------------
    Selects how logits become soft embeddings for the next iteration:
    - "softmax": Standard softmax normalization (default). Gives sparse,
      probabilistic mixing, but the softmax Jacobian can bottleneck gradients.
    - "l2_normalize": L2-normalize logits before mixing with embeddings. Avoids
      the softmax bottleneck, giving smoother gradients through long recursion
      chains.
    - "none": No normalization; raw logits are used directly. Warning: this can
      cause scale explosion without additional mechanisms such as EMA
      accumulation.
    - soft_embedding_ema_step: EMA blending with the previous soft embeddings.
      1.0 (default) = full update (no EMA), 0.1 = slow update
      (90% previous + 10% new).
      Formula: new = (1 - ema_step) * prev + ema_step * current

    Recursion Checkpointing
    -----------------------
    Controls gradient flow through the entire recursion chain for
    memory-efficient training.

    Parameters:
    - use_recursion_checkpointing: Enable gradient checkpointing for iterations
    - loss_weight: Use "last_1" for final-iteration-only loss (learns
      convergence behavior)

    Flow Matching (CFM-inspired)
    ----------------------------
    Supersedes the old temperature-based self-distillation with a Continuous
    Flow Matching framework. Training inputs are interpolated on the
    probability simplex between random noise and the target one-hot,
    distillation hands the student a noisier (earlier-time) point on the same
    interpolation path, and inference uses a flow map update rule.

    Parameters:
    - flow_matching_enabled: Enable the flow matching framework
    - flow_matching_lambda: Weight of distillation KL loss relative to CE loss
    - flow_matching_t_distribution: How to sample time t
      ("logit_normal" or "uniform")
    - flow_matching_t_logit_mean: Mean of logit-normal distribution
      (-0.4 biases toward noisy)
    - flow_matching_t_logit_std: Std of logit-normal distribution
    - flow_matching_t_min: Minimum time value (clamp)
    - flow_matching_t_max: Maximum time value (clamp)
    - flow_matching_noise_scale: Stored with the other flow matching settings
      (default 2.0). NOTE(review): its role is not described here; presumably
      it scales the noise endpoint -- confirm against the model code.
    - flow_matching_mask_scale: If True, scale mask_emb by (1-t); if False,
      binary mask signal

    Time levels are sampled independently per masked token. At t=0 the input is
    pure noise, at t=1 it is the clean target embedding.

    Self-Distillation (legacy, temperature-based)
    ----------------------------------------------
    Retained for backward compatibility. Ignored when
    flow_matching_enabled=True.

    Parameters:
    - self_distillation_enabled: Enable the self-distillation KL loss
    - self_distillation_lambda: Weight of distillation loss relative to CE loss
    - self_distillation_temperature_min: Minimum degradation temperature
    - self_distillation_temperature_max: Maximum degradation temperature
    - self_distillation_temperature_distribution: How to sample temperature
    - self_distillation_teacher: Which logits to use as teacher
      ("first" or "last")
    """

    model_type = "recursive-mlm"

    def __init__(
        self,
        base_model_config: Optional[dict] = None,
        num_recursions: int = 8,
        normalization: str = "softmax",
        loss_weight: str = "linear",
        mask_token_id: Optional[int] = None,
        temperature: float = 1.0,
        gradient_steps: Optional[int] = None,
        # --- Convergence schedule ---
        schedule: str = "linear",
        causal_strength: float = 1.0,
        # --- Schedule effects ---
        temperature_max: float = 0.0,
        entropy_target_max: float = 0.0,
        entropy_floor_max: float = 0.0,
        smear_sigma_max: float = 0.0,
        noise_std_max: float = 0.0,
        iteration_rope_dim_fraction: float = 0.0,
        use_recursion_checkpointing: bool = True,
        # --- Soft embedding method ---
        soft_embedding_method: str = "softmax",
        soft_embedding_ema_step: float = 1.0,
        # --- Flow matching (CFM-inspired) ---
        flow_matching_enabled: bool = False,
        flow_matching_lambda: float = 0.5,
        flow_matching_t_distribution: str = "logit_normal",
        flow_matching_t_logit_mean: float = -0.4,
        flow_matching_t_logit_std: float = 1.0,
        flow_matching_t_min: float = 0.01,
        flow_matching_t_max: float = 0.99,
        flow_matching_noise_scale: float = 2.0,
        flow_matching_mask_scale: bool = False,
        # --- Self-distillation (legacy; ignored when flow_matching_enabled) ---
        self_distillation_enabled: bool = False,
        self_distillation_lambda: float = 0.5,
        self_distillation_temperature_min: float = 1.5,
        self_distillation_temperature_max: float = 10.0,
        self_distillation_temperature_distribution: str = "log_uniform",
        self_distillation_teacher: str = "first",
        **kwargs,
    ):
        """Record every hyperparameter as a same-named instance attribute.

        Keyword arguments not named in the signature are forwarded to
        ``PretrainedConfig.__init__`` before any attribute is stored.
        """
        super().__init__(**kwargs)

        core_settings = {
            "base_model_config": base_model_config,
            "num_recursions": num_recursions,
            "normalization": normalization,
            "loss_weight": loss_weight,
            "mask_token_id": mask_token_id,
            "temperature": temperature,
            "gradient_steps": gradient_steps,
        }
        convergence_settings = {
            "schedule": schedule,
            "causal_strength": causal_strength,
        }
        effect_settings = {
            "temperature_max": temperature_max,
            "entropy_target_max": entropy_target_max,
            "entropy_floor_max": entropy_floor_max,
            "smear_sigma_max": smear_sigma_max,
            "noise_std_max": noise_std_max,
            "iteration_rope_dim_fraction": iteration_rope_dim_fraction,
        }
        checkpoint_settings = {
            "use_recursion_checkpointing": use_recursion_checkpointing,
        }
        soft_embedding_settings = {
            "soft_embedding_method": soft_embedding_method,
            "soft_embedding_ema_step": soft_embedding_ema_step,
        }
        flow_matching_settings = {
            "flow_matching_enabled": flow_matching_enabled,
            "flow_matching_lambda": flow_matching_lambda,
            "flow_matching_t_distribution": flow_matching_t_distribution,
            "flow_matching_t_logit_mean": flow_matching_t_logit_mean,
            "flow_matching_t_logit_std": flow_matching_t_logit_std,
            "flow_matching_t_min": flow_matching_t_min,
            "flow_matching_t_max": flow_matching_t_max,
            "flow_matching_noise_scale": flow_matching_noise_scale,
            "flow_matching_mask_scale": flow_matching_mask_scale,
        }
        distillation_settings = {
            "self_distillation_enabled": self_distillation_enabled,
            "self_distillation_lambda": self_distillation_lambda,
            "self_distillation_temperature_min": self_distillation_temperature_min,
            "self_distillation_temperature_max": self_distillation_temperature_max,
            "self_distillation_temperature_distribution": (
                self_distillation_temperature_distribution
            ),
            "self_distillation_teacher": self_distillation_teacher,
        }

        # Groups are applied in signature order (dicts keep insertion order),
        # so attributes land on the instance in the same sequence as a
        # hand-written list of assignments would produce.
        ordered_groups = (
            core_settings,
            convergence_settings,
            effect_settings,
            checkpoint_settings,
            soft_embedding_settings,
            flow_matching_settings,
            distillation_settings,
        )
        for settings in ordered_groups:
            for attr_name, attr_value in settings.items():
                setattr(self, attr_name, attr_value)

    @classmethod
    def from_base_model_config(
        cls,
        base_config: PretrainedConfig,
        **kwargs,
    ) -> "RecursiveMLMConfig":
        """Build a config around a base MLM's config.

        ``base_config.to_dict()`` is stored as ``base_model_config``; all other
        keyword arguments are passed straight to the constructor.
        """
        base_as_dict = base_config.to_dict()
        return cls(base_model_config=base_as_dict, **kwargs)