# LLaDA-8B-Recursive-ARC / configuration_recursive.py
# (Uploaded by Fraser via huggingface_hub; commit 7493ebb, verified)
from __future__ import annotations
from typing import Optional
from transformers import PretrainedConfig
class RecursiveMLMConfig(PretrainedConfig):
    """
    Configuration for RecursiveMaskedLM.
    Stores the base MLM config plus recursive refinement parameters.

    Core parameters
    ---------------
    - base_model_config: Serialized (dict) config of the wrapped base MLM
    - num_recursions: Number of iterative refinement passes
    - normalization: Normalization applied between iterations
    - loss_weight: How per-iteration losses are weighted ("linear", "last_1", ...)
    - mask_token_id: Token id used for masked positions (None = take from base config)
    - temperature: Base softmax temperature for logits -> soft embeddings
    - gradient_steps: Number of final iterations that receive gradients
      (None = all iterations; see also use_recursion_checkpointing)

    Convergence Schedule System
    ---------------------------
    The convergence schedule controls WHEN each position is allowed to converge
    to a confident prediction during iterative refinement.
    Schedule types:
    - "linear": All positions converge at the same rate (iteration-based only)
    - "causal": Early positions converge first, late positions last
      (causal_strength scales how strongly the position ordering is enforced)
    Effects (mechanisms to enforce the schedule):
    - temperature_max: Raise temperature for positions not yet allowed to converge
    - entropy_target_max: Force exact entropy via bisection search (two-sided, recommended)
    - entropy_floor_max: Force minimum entropy (one-sided, only raises)
    - smear_sigma_max: Spread probability across neighboring positions
    - noise_std_max: Add Gaussian noise to logits
    - iteration_rope_dim_fraction: Apply rotary embedding based on iteration progress
    All effect parameters default to 0.0 (disabled).

    Soft Embedding Methods
    ----------------------
    Controls how logits are converted to soft embeddings for the next iteration:
    - "softmax": Standard softmax normalization (default). Creates sparse, probabilistic
      mixing but can cause gradient bottlenecks through the softmax Jacobian.
    - "l2_normalize": L2 normalize logits before mixing with embeddings. Removes the
      softmax bottleneck for smoother gradients through long recursion chains.
    - "none": No normalization - use raw logits directly. Warning: this can cause
      scale explosion without additional mechanisms like EMA accumulation.
    - soft_embedding_ema_step: Controls EMA blending with previous soft embeddings.
      1.0 (default) = full update (no EMA), 0.1 = slow update (90% previous + 10% new).
      Formula: new = (1 - ema_step) * prev + ema_step * current

    Recursion Checkpointing
    -----------------------
    Controls gradient flow through the entire recursion chain for memory-efficient training.
    Parameters:
    - use_recursion_checkpointing: Enable gradient checkpointing for iterations
    - loss_weight: Use "last_1" for final-iteration-only loss (learns convergence behavior)

    Flow Matching (CFM-inspired)
    ----------------------------
    Replaces the old temperature-based self-distillation with a Continuous Flow Matching
    framework. Training inputs are interpolated on the probability simplex between random
    noise and the target one-hot, distillation gives the student a noisier (earlier-time)
    version of the same interpolation path, and inference uses a flow map update rule.
    Parameters:
    - flow_matching_enabled: Enable the flow matching framework
    - flow_matching_lambda: Weight of distillation KL loss relative to CE loss
    - flow_matching_t_distribution: How to sample time t ("logit_normal" or "uniform")
    - flow_matching_t_logit_mean: Mean of logit-normal distribution (-0.4 biases toward noisy)
    - flow_matching_t_logit_std: Std of logit-normal distribution
    - flow_matching_t_min: Minimum time value (clamp)
    - flow_matching_t_max: Maximum time value (clamp)
    - flow_matching_noise_scale: Scale factor applied to the noise endpoint of the
      interpolation path (exact use is defined by the model implementation)
    - flow_matching_mask_scale: If True, scale mask_emb by (1-t); if False, binary mask signal
    Time levels are sampled independently per masked token. At t=0 the input is pure noise,
    at t=1 it is the clean target embedding.

    Self-Distillation (legacy, temperature-based)
    ----------------------------------------------
    Kept for backward compatibility. Ignored when flow_matching_enabled=True.
    Parameters:
    - self_distillation_enabled: Enable the self-distillation KL loss
    - self_distillation_lambda: Weight of distillation loss relative to CE loss
    - self_distillation_temperature_min: Minimum degradation temperature
    - self_distillation_temperature_max: Maximum degradation temperature
    - self_distillation_temperature_distribution: How to sample temperature
    - self_distillation_teacher: Which logits to use as teacher ("first" or "last")
    """

    # Registered model type for transformers AutoConfig dispatch.
    model_type = "recursive-mlm"

    def __init__(
        self,
        base_model_config: Optional[dict] = None,
        num_recursions: int = 8,
        normalization: str = "softmax",
        loss_weight: str = "linear",
        mask_token_id: Optional[int] = None,
        temperature: float = 1.0,
        gradient_steps: Optional[int] = None,
        # === Convergence schedule parameters ===
        schedule: str = "linear",
        causal_strength: float = 1.0,
        # === Effect parameters ===
        temperature_max: float = 0.0,
        entropy_target_max: float = 0.0,
        entropy_floor_max: float = 0.0,
        smear_sigma_max: float = 0.0,
        noise_std_max: float = 0.0,
        iteration_rope_dim_fraction: float = 0.0,
        use_recursion_checkpointing: bool = True,
        # === Soft embedding method ===
        soft_embedding_method: str = "softmax",
        soft_embedding_ema_step: float = 1.0,
        # === Flow matching parameters (CFM-inspired) ===
        flow_matching_enabled: bool = False,
        flow_matching_lambda: float = 0.5,
        flow_matching_t_distribution: str = "logit_normal",
        flow_matching_t_logit_mean: float = -0.4,
        flow_matching_t_logit_std: float = 1.0,
        flow_matching_t_min: float = 0.01,
        flow_matching_t_max: float = 0.99,
        flow_matching_noise_scale: float = 2.0,
        flow_matching_mask_scale: bool = False,
        # === Self-distillation parameters (legacy, ignored when flow_matching_enabled) ===
        self_distillation_enabled: bool = False,
        self_distillation_lambda: float = 0.5,
        self_distillation_temperature_min: float = 1.5,
        self_distillation_temperature_max: float = 10.0,
        self_distillation_temperature_distribution: str = "log_uniform",
        self_distillation_teacher: str = "first",
        **kwargs,
    ):
        # Forward any remaining transformers-level kwargs (e.g. name_or_path)
        # to PretrainedConfig so standard serialization keeps working.
        super().__init__(**kwargs)
        self.base_model_config = base_model_config
        self.num_recursions = num_recursions
        self.normalization = normalization
        self.loss_weight = loss_weight
        self.mask_token_id = mask_token_id
        self.temperature = temperature
        self.gradient_steps = gradient_steps
        # Convergence schedule
        self.schedule = schedule
        self.causal_strength = causal_strength
        # Effects (all 0.0 = disabled)
        self.temperature_max = temperature_max
        self.entropy_target_max = entropy_target_max
        self.entropy_floor_max = entropy_floor_max
        self.smear_sigma_max = smear_sigma_max
        self.noise_std_max = noise_std_max
        self.iteration_rope_dim_fraction = iteration_rope_dim_fraction
        # Recursion checkpointing
        self.use_recursion_checkpointing = use_recursion_checkpointing
        # Soft embedding method
        self.soft_embedding_method = soft_embedding_method
        self.soft_embedding_ema_step = soft_embedding_ema_step
        # Flow matching
        self.flow_matching_enabled = flow_matching_enabled
        self.flow_matching_lambda = flow_matching_lambda
        self.flow_matching_t_distribution = flow_matching_t_distribution
        self.flow_matching_t_logit_mean = flow_matching_t_logit_mean
        self.flow_matching_t_logit_std = flow_matching_t_logit_std
        self.flow_matching_t_min = flow_matching_t_min
        self.flow_matching_t_max = flow_matching_t_max
        self.flow_matching_noise_scale = flow_matching_noise_scale
        self.flow_matching_mask_scale = flow_matching_mask_scale
        # Self-distillation (legacy; ignored when flow_matching_enabled=True)
        self.self_distillation_enabled = self_distillation_enabled
        self.self_distillation_lambda = self_distillation_lambda
        self.self_distillation_temperature_min = self_distillation_temperature_min
        self.self_distillation_temperature_max = self_distillation_temperature_max
        self.self_distillation_temperature_distribution = self_distillation_temperature_distribution
        self.self_distillation_teacher = self_distillation_teacher

    @classmethod
    def from_base_model_config(
        cls,
        base_config: PretrainedConfig,
        **kwargs,
    ) -> "RecursiveMLMConfig":
        """Create a config from a base MLM's config.

        Args:
            base_config: The base model's PretrainedConfig; serialized via
                ``to_dict()`` so the wrapper config stays JSON-serializable.
            **kwargs: Any RecursiveMLMConfig keyword argument (e.g.
                ``num_recursions``), forwarded to ``__init__``.

        Returns:
            A new RecursiveMLMConfig wrapping ``base_config``.
        """
        return cls(
            base_model_config=base_config.to_dict(),
            **kwargs,
        )