from __future__ import annotations
from typing import Optional

from transformers import PretrainedConfig


class RecursiveMLMConfig(PretrainedConfig):
    """
    Configuration for RecursiveMaskedLM.

    Stores the base MLM config plus recursive refinement parameters.

    Convergence Schedule System
    ---------------------------
    The convergence schedule controls WHEN each position is allowed to converge
    to a confident prediction during iterative refinement.

    Schedule types:
        - "linear": All positions converge at the same rate (iteration-based only)
        - "causal": Early positions converge first, late positions last

    Effects (mechanisms to enforce the schedule):
        - temperature_max: Raise temperature for positions not yet allowed to converge
        - entropy_target_max: Force exact entropy via bisection search (two-sided, recommended)
        - entropy_floor_max: Force minimum entropy (one-sided, only raises)
        - smear_sigma_max: Spread probability across neighboring positions
        - noise_std_max: Add Gaussian noise to logits
        - iteration_rope_dim_fraction: Apply rotary embedding based on iteration progress

    Soft Embedding Methods
    ----------------------
    Controls how logits are converted to soft embeddings for the next iteration:
        - "softmax": Standard softmax normalization (default). Creates sparse, probabilistic
          mixing but can cause gradient bottlenecks through the softmax Jacobian.
        - "l2_normalize": L2 normalize logits before mixing with embeddings. Removes the
          softmax bottleneck for smoother gradients through long recursion chains.
        - "none": No normalization - use raw logits directly. Warning: this can cause
          scale explosion without additional mechanisms like EMA accumulation.

        - soft_embedding_ema_step: Controls EMA blending with previous soft embeddings.
          1.0 (default) = full update (no EMA), 0.1 = slow update (90% previous + 10% new).
          Formula: new = (1 - ema_step) * prev + ema_step * current

    Recursion Checkpointing
    -----------------------
    Controls gradient flow through the entire recursion chain for memory-efficient training.

    Parameters:
        - use_recursion_checkpointing: Enable gradient checkpointing for iterations
        - loss_weight: Use "last_1" for final-iteration-only loss (learns convergence behavior)

    Flow Matching (CFM-inspired)
    ----------------------------
    Replaces the old temperature-based self-distillation with a Continuous Flow Matching
    framework. Training inputs are interpolated on the probability simplex between random
    noise and the target one-hot, distillation gives the student a noisier (earlier-time)
    version of the same interpolation path, and inference uses a flow map update rule.

    Parameters:
        - flow_matching_enabled: Enable the flow matching framework
        - flow_matching_lambda: Weight of distillation KL loss relative to CE loss
        - flow_matching_t_distribution: How to sample time t ("logit_normal" or "uniform")
        - flow_matching_t_logit_mean: Mean of logit-normal distribution (-0.4 biases toward noisy)
        - flow_matching_t_logit_std: Std of logit-normal distribution
        - flow_matching_t_min: Minimum time value (clamp)
        - flow_matching_t_max: Maximum time value (clamp)
        - flow_matching_noise_scale: Scale of the random-noise component of the interpolated input
        - flow_matching_mask_scale: If True, scale mask_emb by (1-t); if False, binary mask signal

    Time levels are sampled independently per masked token. At t=0 the input is pure noise,
    at t=1 it is the clean target embedding.

    Self-Distillation (legacy, temperature-based)
    ----------------------------------------------
    Kept for backward compatibility. Ignored when flow_matching_enabled=True.

    Parameters:
        - self_distillation_enabled: Enable the self-distillation KL loss
        - self_distillation_lambda: Weight of distillation loss relative to CE loss
        - self_distillation_temperature_min: Minimum degradation temperature
        - self_distillation_temperature_max: Maximum degradation temperature
        - self_distillation_temperature_distribution: How to sample temperature
        - self_distillation_teacher: Which logits to use as teacher ("first" or "last")
    """
    model_type = "recursive-mlm"

    def __init__(
        self,
        base_model_config: Optional[dict] = None,
        num_recursions: int = 8,
        normalization: str = "softmax",
        loss_weight: str = "linear",
        mask_token_id: Optional[int] = None,
        temperature: float = 1.0,
        gradient_steps: Optional[int] = None,
        # === Convergence schedule parameters ===
        schedule: str = "linear",
        causal_strength: float = 1.0,
        # === Effect parameters ===
        temperature_max: float = 0.0,
        entropy_target_max: float = 0.0,
        entropy_floor_max: float = 0.0,
        smear_sigma_max: float = 0.0,
        noise_std_max: float = 0.0,
        iteration_rope_dim_fraction: float = 0.0,
        use_recursion_checkpointing: bool = True,
        # === Soft embedding method ===
        soft_embedding_method: str = "softmax",
        soft_embedding_ema_step: float = 1.0,
        # === Flow matching parameters (CFM-inspired) ===
        flow_matching_enabled: bool = False,
        flow_matching_lambda: float = 0.5,
        flow_matching_t_distribution: str = "logit_normal",
        flow_matching_t_logit_mean: float = -0.4,
        flow_matching_t_logit_std: float = 1.0,
        flow_matching_t_min: float = 0.01,
        flow_matching_t_max: float = 0.99,
        flow_matching_noise_scale: float = 2.0,
        flow_matching_mask_scale: bool = False,
        # === Self-distillation parameters (legacy, ignored when flow_matching_enabled) ===
        self_distillation_enabled: bool = False,
        self_distillation_lambda: float = 0.5,
        self_distillation_temperature_min: float = 1.5,
        self_distillation_temperature_max: float = 10.0,
        self_distillation_temperature_distribution: str = "log_uniform",
        self_distillation_teacher: str = "first",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model_config = base_model_config
        self.num_recursions = num_recursions
        self.normalization = normalization
        self.loss_weight = loss_weight
        self.mask_token_id = mask_token_id
        self.temperature = temperature
        self.gradient_steps = gradient_steps
        # Convergence schedule
        self.schedule = schedule
        self.causal_strength = causal_strength
        # Effects
        self.temperature_max = temperature_max
        self.entropy_target_max = entropy_target_max
        self.entropy_floor_max = entropy_floor_max
        self.smear_sigma_max = smear_sigma_max
        self.noise_std_max = noise_std_max
        self.iteration_rope_dim_fraction = iteration_rope_dim_fraction
        # Recursion checkpointing
        self.use_recursion_checkpointing = use_recursion_checkpointing
        # Soft embedding method
        self.soft_embedding_method = soft_embedding_method
        self.soft_embedding_ema_step = soft_embedding_ema_step
        # Flow matching
        self.flow_matching_enabled = flow_matching_enabled
        self.flow_matching_lambda = flow_matching_lambda
        self.flow_matching_t_distribution = flow_matching_t_distribution
        self.flow_matching_t_logit_mean = flow_matching_t_logit_mean
        self.flow_matching_t_logit_std = flow_matching_t_logit_std
        self.flow_matching_t_min = flow_matching_t_min
        self.flow_matching_t_max = flow_matching_t_max
        self.flow_matching_noise_scale = flow_matching_noise_scale
        self.flow_matching_mask_scale = flow_matching_mask_scale
        # Self-distillation (legacy)
        self.self_distillation_enabled = self_distillation_enabled
        self.self_distillation_lambda = self_distillation_lambda
        self.self_distillation_temperature_min = self_distillation_temperature_min
        self.self_distillation_temperature_max = self_distillation_temperature_max
        self.self_distillation_temperature_distribution = self_distillation_temperature_distribution
        self.self_distillation_teacher = self_distillation_teacher

    @classmethod
    def from_base_model_config(
        cls,
        base_config: PretrainedConfig,
        **kwargs,
    ) -> "RecursiveMLMConfig":
        """Create config from a base MLM's config."""
        return cls(
            base_model_config=base_config.to_dict(),
            **kwargs,
        )
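

# --- Illustrative sketch (not part of the model code) ---
# The docstring above gives the soft-embedding EMA blend as
#     new = (1 - ema_step) * prev + ema_step * current
# A minimal standalone version of that formula on scalar values, assuming
# element-wise blending (the helper name `_ema_blend` is hypothetical):
def _ema_blend(prev: float, current: float, ema_step: float = 1.0) -> float:
    """Blend previous and current soft embeddings; ema_step=1.0 is a full update."""
    return (1.0 - ema_step) * prev + ema_step * current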
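

# --- Illustrative sketch (not part of the model code) ---
# The flow-matching docstring describes sampling a time level t from a
# logit-normal distribution (mean -0.4 biases toward noisier inputs) and
# clamping to [t_min, t_max]. A hedged standalone sketch of that sampler
# (`_sample_t_logit_normal` is a hypothetical helper, not the library's API):
import math
import random


def _sample_t_logit_normal(mean: float = -0.4, std: float = 1.0,
                           t_min: float = 0.01, t_max: float = 0.99) -> float:
    """Sample t in (0, 1): draw z ~ N(mean, std), squash with a sigmoid, clamp."""
    z = random.gauss(mean, std)
    t = 1.0 / (1.0 + math.exp(-z))  # sigmoid of a normal draw -> logit-normal
    return min(max(t, t_min), t_max)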