| from __future__ import annotations |
| import warnings |
| from dataclasses import dataclass |
| from typing import NamedTuple, Optional |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.nn import CrossEntropyLoss |
| from torch.utils.checkpoint import checkpoint as torch_checkpoint |
| from transformers import AutoConfig, AutoModelForMaskedLM, PreTrainedModel |
| from transformers.modeling_outputs import MaskedLMOutput |
| from transformers.utils import ModelOutput |
|
|
| from .configuration_recursive import RecursiveMLMConfig |
|
|
|
|
@dataclass
class IterationMetrics(ModelOutput):
    """Metrics for a single iteration of recursive refinement.

    All fields are optional scalar diagnostics; a field is None when the
    metric was not computed for that iteration.
    """

    # Token-level accuracy for this iteration's predictions.
    accuracy: Optional[float] = None
    # Mean predictive entropy of this iteration's output distribution.
    entropy: Optional[float] = None
    # Softmax cross-entropy of this iteration's logits.
    softmax_ce: Optional[float] = None
    # Fraction of sequences predicted entirely correctly.
    full_sequence_accuracy: Optional[float] = None
    # Minimum per-sequence prediction confidence in the batch.
    min_sequence_confidence: Optional[float] = None
|
|
|
|
@dataclass
class RecursiveMaskedLMOutput(MaskedLMOutput):
    """MaskedLMOutput extended with recursive-refinement bookkeeping."""

    # Per-iteration diagnostics keyed by iteration index.
    iteration_metrics: Optional[dict[int, IterationMetrics]] = None
    # Soft embeddings produced for the next refinement iteration.
    next_soft_embeds: Optional[torch.Tensor] = None
    # Logits from every refinement iteration, in order.
    all_logits: Optional[list[torch.Tensor]] = None

    # CFM flow-matching state from the forward pass (populated only when
    # flow matching is enabled): sampled noise embeddings and per-token times.
    flow_noise_embed: Optional[torch.Tensor] = None
    flow_t: Optional[torch.Tensor] = None
|
|
|
|
class SelfDistillationOutput(NamedTuple):
    """Output from self-distillation forward pass.

    Field order matters (NamedTuple); callers may unpack positionally.
    """

    # Distillation loss (KL teacher || student), differentiable.
    loss: torch.Tensor
    # Detached teacher logits used as the distillation target.
    teacher_logits: torch.Tensor
    # Student logits produced from the degraded inputs.
    student_logits: torch.Tensor
    # Mean degradation level (temperature, or time gap in the flow variant).
    degradation_temperature: float
    # Mean entropy of the teacher distribution (diagnostic).
    teacher_entropy: float
    # Mean entropy of the student distribution (diagnostic).
    student_entropy: float
    # Fraction of masked positions where teacher and student argmax agree.
    agreement_rate: float
|
|
|
|
| class RecursiveMaskedLM(PreTrainedModel): |
| """ |
| Wraps any HF MLM with recursive soft-token refinement. |
| |
| At each step: |
| 1. Normalize logits -> probs |
| 2. Compute soft embeddings: probs @ embedding_weight + mask_embedding |
| 3. Forward through MLM |
| 4. Accumulate weighted loss |
| """ |
| config_class = RecursiveMLMConfig |
| base_model_prefix = "mlm" |
| supports_gradient_checkpointing = True |
|
|
    def __init__(self, config: RecursiveMLMConfig, base_model: Optional[PreTrainedModel] = None):
        """Build the wrapper around a base MLM.

        Args:
            config: Recursive-MLM configuration (holds the base model config
                when no prebuilt model is supplied).
            base_model: Optional already-constructed MLM to wrap directly.

        Raises:
            ValueError: if neither ``base_model`` nor
                ``config.base_model_config`` is provided.
        """
        super().__init__(config)

        if base_model is not None:
            # Wrap a prebuilt (typically pretrained) model as-is.
            # NOTE(review): post_init() is skipped on this path — presumably
            # to avoid re-initializing pretrained weights; confirm intended.
            self.mlm = base_model
        elif config.base_model_config is not None:
            model_type = config.base_model_config.get("model_type", "")
            if model_type == "llada":
                # LLaDA gets special-cased construction via project-local
                # classes (see from_base_model: not loadable through
                # AutoModelForMaskedLM).
                from .configuration_llada import LLaDAConfig
                from .modeling_llada import LLaDAModelLM
                base_config = LLaDAConfig.from_dict(config.base_model_config)
                self.mlm = LLaDAModelLM(base_config)
            else:
                base_config = AutoConfig.for_model(**config.base_model_config)
                self.mlm = AutoModelForMaskedLM.from_config(base_config)
            # Freshly constructed weights: run HF post-construction setup.
            self.post_init()
        else:
            raise ValueError("Need either base_model or config.base_model_config")
|
|
| @classmethod |
| def from_mlm_pretrained( |
| cls, |
| mlm_name_or_path: str, |
| num_recursions: int = 8, |
| normalization: str = "softmax", |
| loss_weight: str = "linear", |
| mask_token_id: Optional[int] = None, |
| temperature: float = 1.0, |
| gradient_steps: Optional[int] = None, |
| |
| schedule: str = "linear", |
| causal_strength: float = 1.0, |
| |
| temperature_max: float = 0.0, |
| entropy_target_max: float = 0.0, |
| entropy_floor_max: float = 0.0, |
| smear_sigma_max: float = 0.0, |
| noise_std_max: float = 0.0, |
| iteration_rope_dim_fraction: float = 0.0, |
| use_recursion_checkpointing: bool = True, |
| |
| soft_embedding_method: str = "softmax", |
| soft_embedding_ema_step: float = 1.0, |
| |
| flow_matching_enabled: bool = False, |
| flow_matching_lambda: float = 0.5, |
| flow_matching_t_distribution: str = "logit_normal", |
| flow_matching_t_logit_mean: float = -0.4, |
| flow_matching_t_logit_std: float = 1.0, |
| flow_matching_t_min: float = 0.01, |
| flow_matching_t_max: float = 0.99, |
| flow_matching_mask_scale: bool = False, |
| **model_kwargs, |
| ) -> "RecursiveMaskedLM": |
| """Load a pretrained MLM and wrap it for recursive refinement.""" |
| base_model = AutoModelForMaskedLM.from_pretrained(mlm_name_or_path, **model_kwargs) |
| return cls.from_base_model( |
| base_model, |
| num_recursions=num_recursions, |
| normalization=normalization, |
| loss_weight=loss_weight, |
| mask_token_id=mask_token_id, |
| temperature=temperature, |
| gradient_steps=gradient_steps, |
| schedule=schedule, |
| causal_strength=causal_strength, |
| temperature_max=temperature_max, |
| entropy_target_max=entropy_target_max, |
| entropy_floor_max=entropy_floor_max, |
| smear_sigma_max=smear_sigma_max, |
| noise_std_max=noise_std_max, |
| iteration_rope_dim_fraction=iteration_rope_dim_fraction, |
| use_recursion_checkpointing=use_recursion_checkpointing, |
| soft_embedding_method=soft_embedding_method, |
| soft_embedding_ema_step=soft_embedding_ema_step, |
| flow_matching_enabled=flow_matching_enabled, |
| flow_matching_lambda=flow_matching_lambda, |
| flow_matching_t_distribution=flow_matching_t_distribution, |
| flow_matching_t_logit_mean=flow_matching_t_logit_mean, |
| flow_matching_t_logit_std=flow_matching_t_logit_std, |
| flow_matching_t_min=flow_matching_t_min, |
| flow_matching_t_max=flow_matching_t_max, |
| flow_matching_mask_scale=flow_matching_mask_scale, |
| ) |
|
|
| @classmethod |
| def from_base_model( |
| cls, |
| base_model: PreTrainedModel, |
| num_recursions: int = 8, |
| normalization: str = "softmax", |
| loss_weight: str = "linear", |
| mask_token_id: Optional[int] = None, |
| temperature: float = 1.0, |
| gradient_steps: Optional[int] = None, |
| |
| schedule: str = "linear", |
| causal_strength: float = 1.0, |
| |
| temperature_max: float = 0.0, |
| entropy_target_max: float = 0.0, |
| entropy_floor_max: float = 0.0, |
| smear_sigma_max: float = 0.0, |
| noise_std_max: float = 0.0, |
| iteration_rope_dim_fraction: float = 0.0, |
| use_recursion_checkpointing: bool = True, |
| |
| soft_embedding_method: str = "softmax", |
| soft_embedding_ema_step: float = 1.0, |
| |
| flow_matching_enabled: bool = False, |
| flow_matching_lambda: float = 0.5, |
| flow_matching_t_distribution: str = "logit_normal", |
| flow_matching_t_logit_mean: float = -0.4, |
| flow_matching_t_logit_std: float = 1.0, |
| flow_matching_t_min: float = 0.01, |
| flow_matching_t_max: float = 0.99, |
| flow_matching_mask_scale: bool = False, |
| ) -> "RecursiveMaskedLM": |
| """Wrap an existing model for recursive refinement. |
| |
| Use this for models not loadable via AutoModelForMaskedLM (e.g., LLaDA). |
| |
| Args: |
| base_model: The base MLM model to wrap |
| num_recursions: Number of recursive refinement steps |
| normalization: Normalization method for logits (softmax, stable_softmax) |
| loss_weight: Loss weighting scheme (last_1, last_2, linear, uniform) |
| mask_token_id: Token ID for [MASK] |
| temperature: Temperature for softmax normalization |
| gradient_steps: Number of final steps to backprop through |
| schedule: Convergence schedule type ("linear" or "causal") |
| causal_strength: How much faster early positions converge (causal only) |
| temperature_max: Max temperature boost for uncertain positions |
| entropy_target_max: Target entropy at progress=0 (two-sided, recommended) |
| entropy_floor_max: Min entropy floor (one-sided) |
| smear_sigma_max: Max Gaussian sigma for position smearing |
| noise_std_max: Max std of Gaussian noise on logits |
| iteration_rope_dim_fraction: Fraction of dims for iteration RoPE |
| use_recursion_checkpointing: Enable gradient checkpointing for iterations |
| soft_embedding_method: How to convert logits to soft embeddings |
| soft_embedding_ema_step: EMA step size (1.0 = no EMA, <1.0 = blend with previous) |
| flow_matching_enabled: Enable CFM-inspired flow matching framework |
| flow_matching_lambda: Weight of distillation KL loss relative to CE |
| flow_matching_t_distribution: Time sampling distribution ("logit_normal" or "uniform") |
| flow_matching_t_logit_mean: Mean of logit-normal distribution |
| flow_matching_t_logit_std: Std of logit-normal distribution |
| flow_matching_t_min: Minimum time value (clamp) |
| flow_matching_t_max: Maximum time value (clamp) |
| flow_matching_mask_scale: Scale mask_emb by (1-t) if True, binary if False |
| """ |
| config = RecursiveMLMConfig.from_base_model_config( |
| base_model.config, |
| num_recursions=num_recursions, |
| normalization=normalization, |
| loss_weight=loss_weight, |
| mask_token_id=mask_token_id, |
| temperature=temperature, |
| gradient_steps=gradient_steps, |
| schedule=schedule, |
| causal_strength=causal_strength, |
| temperature_max=temperature_max, |
| entropy_target_max=entropy_target_max, |
| entropy_floor_max=entropy_floor_max, |
| smear_sigma_max=smear_sigma_max, |
| noise_std_max=noise_std_max, |
| iteration_rope_dim_fraction=iteration_rope_dim_fraction, |
| use_recursion_checkpointing=use_recursion_checkpointing, |
| soft_embedding_method=soft_embedding_method, |
| soft_embedding_ema_step=soft_embedding_ema_step, |
| flow_matching_enabled=flow_matching_enabled, |
| flow_matching_lambda=flow_matching_lambda, |
| flow_matching_t_distribution=flow_matching_t_distribution, |
| flow_matching_t_logit_mean=flow_matching_t_logit_mean, |
| flow_matching_t_logit_std=flow_matching_t_logit_std, |
| flow_matching_t_min=flow_matching_t_min, |
| flow_matching_t_max=flow_matching_t_max, |
| flow_matching_mask_scale=flow_matching_mask_scale, |
| ) |
| return cls(config, base_model=base_model) |
|
|
| @property |
| def embed_weight(self) -> torch.Tensor: |
| return self.mlm.get_input_embeddings().weight |
|
|
    def get_input_embeddings(self):
        """Delegate to the wrapped MLM's input embedding accessor."""
        return self.mlm.get_input_embeddings()
|
|
    def set_input_embeddings(self, value):
        """Delegate to the wrapped MLM's input embedding setter."""
        self.mlm.set_input_embeddings(value)
|
|
    def get_output_embeddings(self):
        """Delegate to the wrapped MLM's output embedding accessor."""
        return self.mlm.get_output_embeddings()
|
|
    def set_output_embeddings(self, new_embeddings):
        """Delegate to the wrapped MLM's output embedding setter."""
        self.mlm.set_output_embeddings(new_embeddings)
|
|
| def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): |
| """Enable gradient checkpointing with correct settings for recursion. |
| |
| Forces use_reentrant=False which is required for: |
| - Nested checkpoint calls (base model + recursion checkpointing) |
| - Models with frozen parameters |
| - Complex gradient flows through soft embeddings |
| """ |
| if gradient_checkpointing_kwargs is None: |
| gradient_checkpointing_kwargs = {} |
| |
| gradient_checkpointing_kwargs.setdefault("use_reentrant", False) |
| self.mlm.gradient_checkpointing_enable(gradient_checkpointing_kwargs) |
|
|
    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing in the underlying MLM."""
        self.mlm.gradient_checkpointing_disable()
|
|
    def _single_iteration_checkpointable(
        self,
        soft_embeds: torch.Tensor,
        base_embeds: torch.Tensor,
        mask_pos: torch.Tensor,
        attention_mask: torch.Tensor,
        embed_weight: torch.Tensor,
        mask_emb: torch.Tensor,
        temperature: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Single differentiable iteration for checkpointing.

        This method performs one iteration of recursive refinement in a way that
        maintains gradient flow and is compatible with torch.utils.checkpoint.

        Args:
            soft_embeds: (B, L, H) - current soft embeddings
            base_embeds: (B, L, H) - original token embeddings
            mask_pos: (B, L) bool - which positions are masked
            attention_mask: (B, L) - attention mask for MLM
            embed_weight: (V, H) - embedding weight matrix
            mask_emb: (H,) - mask token embedding
            temperature: scalar tensor - softmax temperature
            position_ids: optional (B, L) position ids forwarded to the MLM

        Returns:
            logits: (B, L, V) - output logits from this iteration
            next_soft_embeds: (B, L, H) - soft embeddings for next iteration
        """
        # Masked slots carry the refined soft embeddings; everything else
        # keeps its original token embedding.
        inputs_embeds = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds)

        outputs = self.mlm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        logits = outputs.logits

        # Default next-step embeddings to the base embeddings; masked
        # positions are overwritten below.
        next_soft_embeds = base_embeds.clone()
        if mask_pos.any():
            masked_logits = logits[mask_pos]

            if self.config.soft_embedding_method == "none":
                # Use raw logits as (unnormalized) mixing weights.
                weights = masked_logits
            elif self.config.soft_embedding_method == "l2_normalize":
                # Unit-norm logits as mixing weights.
                weights = F.normalize(masked_logits, p=2, dim=-1)
            else:
                # Default: temperature-scaled softmax over the vocabulary.
                weights = F.softmax(masked_logits / temperature, dim=-1)

            # Expected embedding under the weights, re-anchored with the mask
            # embedding so masked slots remain identifiable to the model.
            soft_emb = weights @ embed_weight + mask_emb

            # Optional EMA blend with the previous iteration's soft embedding
            # (ema_step == 1.0 means no blending).
            ema_step = self.config.soft_embedding_ema_step
            if ema_step < 1.0:
                prev_soft_emb = soft_embeds[mask_pos]
                soft_emb = (1.0 - ema_step) * prev_soft_emb + ema_step * soft_emb

            next_soft_embeds[mask_pos] = soft_emb

        return logits, next_soft_embeds
|
|
| def _stable_softmax(self, logits: torch.Tensor, T: float = 1.0, dim: int = -1, eps: float = 1e-12) -> torch.Tensor: |
| """Numerically stable softmax with temperature T > 0.""" |
| z = logits / max(T, eps) |
| z = z - z.max(dim=dim, keepdim=True).values |
| z = torch.exp(z) |
| z_sum = z.sum(dim=dim, keepdim=True) |
| return z / z_sum.clamp(min=eps) |
|
|
| def normalize(self, logits: torch.Tensor) -> torch.Tensor: |
| """Normalize logits -> mixing weights. Shape: (B, L, V) -> (B, L, V)""" |
| norm = self.config.normalization.lower() |
| T = self.config.temperature |
| V = logits.shape[-1] |
|
|
| if norm == "none": |
| return logits |
|
|
| if norm == "softmax": |
| return torch.softmax(logits / T, dim=-1) |
|
|
| if norm == "stable_softmax": |
| return self._stable_softmax(logits, T=T, dim=-1) |
|
|
| raise ValueError(f"Unknown normalization: {norm}") |
|
|
| def step_weight(self, t: int, T: int) -> float: |
| """Loss weight for step t of T.""" |
| lw = self.config.loss_weight |
| if lw == "linear": |
| return (t + 1) / T |
| if lw == "uniform": |
| return 1.0 |
| if lw == "last_1": |
| return 1.0 if t == T - 1 else 0.0 |
| if lw == "last_2": |
| return 1.0 if T - t <= 2 else 0.0 |
| raise ValueError(f"Unknown loss_weight: {lw}") |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| def _compute_convergence_progress( |
| self, |
| iteration: int, |
| total_iterations: int, |
| seq_length: int, |
| mask_positions: torch.Tensor, |
| schedule: str = "linear", |
| causal_strength: float = 1.0, |
| device: torch.device = None, |
| dtype: torch.dtype = None, |
| ) -> torch.Tensor: |
| """ |
| Compute per-position convergence progress based on schedule. |
| |
| Args: |
| iteration: Current iteration (0-indexed) |
| total_iterations: Total number of iterations |
| seq_length: Full sequence length L |
| mask_positions: Position indices of masked tokens (num_masked,) |
| schedule: "linear" or "causal" |
| causal_strength: How much faster early positions converge (for causal schedule) |
| |
| Returns: |
| progress: (num_masked,) tensor with values in [0, 1] |
| 0 = position should be maximally uncertain |
| 1 = position is allowed to fully converge |
| """ |
| base_progress = iteration / max(total_iterations - 1, 1) |
|
|
| if schedule == "linear": |
| return torch.full( |
| (mask_positions.shape[0],), |
| base_progress, |
| device=device, |
| dtype=dtype |
| ) |
|
|
| elif schedule == "causal": |
| position_factor = mask_positions.float() / max(seq_length - 1, 1) |
| effective_progress = base_progress * (1.0 + causal_strength * (1.0 - position_factor)) |
| return effective_progress.clamp(0.0, 1.0) |
|
|
| else: |
| raise ValueError(f"Unknown schedule: {schedule}") |
|
|
| def _apply_temperature_effect( |
| self, |
| logits: torch.Tensor, |
| progress: torch.Tensor, |
| temperature_max: float, |
| ) -> torch.Tensor: |
| """ |
| Apply per-position temperature scaling based on convergence progress. |
| Low progress = high temperature (uncertain), high progress = temperature 1.0. |
| """ |
| if temperature_max <= 0: |
| return logits |
|
|
| temperature = 1.0 + temperature_max * (1.0 - progress) |
| temperature = temperature.unsqueeze(-1) |
|
|
| return logits / temperature |
|
|
| def _apply_entropy_floor_effect( |
| self, |
| probs: torch.Tensor, |
| progress: torch.Tensor, |
| entropy_floor_max: float, |
| ) -> torch.Tensor: |
| """ |
| Ensure minimum entropy based on convergence progress. |
| Low progress = high entropy floor, high progress = no floor. |
| |
| NOTE: This is a ONE-SIDED constraint (floor only). |
| """ |
| if entropy_floor_max <= 0: |
| return probs |
|
|
| entropy_floor = entropy_floor_max * (1.0 - progress) |
|
|
| log_probs = torch.log(probs + 1e-10) |
| current_entropy = -(probs * log_probs).sum(dim=-1) |
|
|
| below_floor = current_entropy < entropy_floor |
|
|
| if not below_floor.any(): |
| return probs |
|
|
| logits = torch.log(probs + 1e-10) |
|
|
| target_ratio = entropy_floor / (current_entropy + 1e-10) |
| temperature = torch.ones_like(current_entropy) |
| temperature[below_floor] = target_ratio[below_floor].clamp(1.0, 10.0) |
|
|
| scaled_probs = torch.softmax(logits / temperature.unsqueeze(-1), dim=-1) |
|
|
| result = probs.clone() |
| result[below_floor] = scaled_probs[below_floor] |
| return result |
|
|
    def _find_temperature_for_target_entropy(
        self,
        logits: torch.Tensor,
        target_entropy: torch.Tensor,
        tol: float = 1e-3,
        max_iter: int = 32,
        T_low: float = 1e-6,
        T_high_init: float = 1.0,
        max_T: float = 100.0,
    ) -> torch.Tensor:
        """
        Find per-position temperatures that achieve exactly the target entropy.
        Uses bisection search, adapted from ARChitects' implementation.

        Args:
            logits: Raw logits (num_positions, V)
            target_entropy: Target entropy per position (num_positions,) or scalar
            tol: Entropy tolerance for convergence
            max_iter: Maximum bisection iterations
            T_low: Minimum temperature (near-greedy)
            T_high_init: Initial upper bound for search
            max_T: Maximum allowed temperature

        Returns:
            temperatures: (num_positions,) temperatures that achieve target entropy
        """
        N, V = logits.shape
        device, dtype = logits.device, logits.dtype
        # Maximum achievable entropy for a V-way distribution (uniform).
        H_max = torch.log(torch.tensor(V, device=device, dtype=dtype))

        # Broadcast a scalar target to all positions; clamp into [0, H_max].
        if target_entropy.dim() == 0:
            target = target_entropy.expand(N).clone()
        else:
            target = target_entropy.clone()
        target = target.clamp(0.0, H_max)

        def compute_entropy(logits_: torch.Tensor, temps: torch.Tensor) -> torch.Tensor:
            # Entropy of softmax(logits / T), with a max-shift for stability.
            temps = temps.unsqueeze(-1).clamp(min=T_low)
            scaled = logits_ / temps
            scaled = scaled - scaled.max(dim=-1, keepdim=True).values
            probs = torch.softmax(scaled, dim=-1)
            log_probs = torch.log(probs + 1e-12)
            return -(probs * log_probs).sum(dim=-1)

        lo = torch.full((N,), T_low, device=device, dtype=dtype)
        hi = torch.full((N,), T_high_init, device=device, dtype=dtype)

        H_lo = compute_entropy(logits, lo)

        # Targets at/below the near-greedy entropy need no search: T_low works.
        done_low = target <= (H_lo + tol)

        H_hi = compute_entropy(logits, hi)
        needs_expansion = (H_hi < target - tol) & ~done_low

        # Grow the upper bracket geometrically (doubling, capped at max_T)
        # until H(hi) reaches the target or the cap is hit.
        for _ in range(100):
            if not needs_expansion.any():
                break
            hi[needs_expansion] = (hi[needs_expansion] * 2.0).clamp(max=max_T)
            H_hi[needs_expansion] = compute_entropy(
                logits[needs_expansion], hi[needs_expansion]
            )
            needs_expansion = (H_hi < target - tol) & ~done_low & (hi < max_T - 1e-6)

        # Bisect only where the bracket [lo, hi] actually contains the target.
        can_bisect = ~done_low & (H_hi >= target - tol)

        for _ in range(max_iter):
            if not can_bisect.any():
                break

            mid = (lo + hi) / 2.0
            H_mid = compute_entropy(logits, mid)

            # Entropy increases with temperature: tighten whichever bound
            # keeps the target inside the bracket.
            too_low = (H_mid < target) & can_bisect
            lo[too_low] = mid[too_low]
            hi[~too_low & can_bisect] = mid[~too_low & can_bisect]

            # Stop per-position once the bracket is relatively tight.
            converged = (hi - lo) <= tol * mid.clamp(min=1.0)
            can_bisect = can_bisect & ~converged

        # Final answer: T_low for trivially-done positions, bracket midpoint
        # for the rest.
        temps = torch.zeros(N, device=device, dtype=dtype)
        temps[done_low] = T_low
        temps[~done_low] = (lo[~done_low] + hi[~done_low]) / 2.0

        return temps
|
|
| def _apply_target_entropy_effect( |
| self, |
| logits: torch.Tensor, |
| progress: torch.Tensor, |
| entropy_target_max: float, |
| entropy_target_min: float = 0.0, |
| ) -> torch.Tensor: |
| """ |
| Adjust temperature to achieve EXACTLY the target entropy per position. |
| This is a TWO-SIDED constraint: both raises and lowers entropy as needed. |
| |
| Args: |
| logits: Raw logits (num_masked, V) |
| progress: Per-position convergence progress (num_masked,) |
| entropy_target_max: Target entropy at progress=0 |
| entropy_target_min: Target entropy at progress=1 (usually ~0) |
| |
| Returns: |
| probs: Probabilities with entropy matching targets |
| """ |
| if entropy_target_max <= 0: |
| return torch.softmax(logits, dim=-1) |
|
|
| target_entropy = entropy_target_max * (1.0 - progress) + entropy_target_min * progress |
|
|
| temps = self._find_temperature_for_target_entropy(logits, target_entropy) |
|
|
| temps = temps.unsqueeze(-1).clamp(min=1e-6) |
| return torch.softmax(logits / temps, dim=-1) |
|
|
    def _apply_smear_effect(
        self,
        probs: torch.Tensor,
        mask_pos: torch.Tensor,
        progress_full: torch.Tensor,
        smear_sigma_max: float,
    ) -> torch.Tensor:
        """
        Apply positional smearing with per-position sigma based on progress.
        Low progress = high smearing, high progress = no smearing.

        Note: This operates on full (B, L, V) tensor because smearing mixes across positions.
        """
        if smear_sigma_max <= 0:
            return probs

        B, L, V = probs.shape

        # Per-position sigma decays linearly with progress.
        sigma_per_pos = smear_sigma_max * (1.0 - progress_full)

        # Approximation: one shared kernel built from the AVERAGE sigma over
        # masked positions (a true per-position sigma would need a kernel per
        # row).
        avg_sigma = sigma_per_pos[mask_pos].mean().item()

        # Negligible smearing: skip the O(L^2) kernel construction.
        if avg_sigma < 0.1:
            return probs

        # L x L Gaussian kernel over position distance, row-normalized so each
        # output is a convex combination of input positions.
        positions = torch.arange(L, device=probs.device, dtype=probs.dtype)
        diff = positions.unsqueeze(0) - positions.unsqueeze(1)
        kernel = torch.exp(-0.5 * (diff / avg_sigma) ** 2)
        kernel = kernel / kernel.sum(dim=1, keepdim=True)

        # Mix probabilities across positions, then renormalize over the vocab.
        smeared = torch.einsum('ij,bjv->biv', kernel, probs)
        smeared = smeared / smeared.sum(dim=-1, keepdim=True).clamp(min=1e-10)

        # Blend per position: fully smeared at progress 0, untouched at 1.
        blend = progress_full.unsqueeze(-1)
        result = blend * probs + (1 - blend) * smeared

        # Only masked positions receive the smeared distribution.
        output = probs.clone()
        output[mask_pos] = result[mask_pos]
        return output
|
|
| def _apply_noise_effect( |
| self, |
| logits: torch.Tensor, |
| progress: torch.Tensor, |
| noise_std_max: float, |
| ) -> torch.Tensor: |
| """ |
| Add Gaussian noise to logits based on convergence progress. |
| Low progress = high noise, high progress = no noise. |
| """ |
| if noise_std_max <= 0: |
| return logits |
|
|
| noise_std = noise_std_max * (1.0 - progress) |
| noise_std = noise_std.unsqueeze(-1) |
|
|
| noise = torch.randn_like(logits) * noise_std |
| return logits + noise |
|
|
    def _apply_iteration_rope(
        self,
        embeds: torch.Tensor,
        iteration: int,
        total_iterations: int,
        dim_fraction: float = 0.25,
        base: float = 10000.0,
    ) -> torch.Tensor:
        """
        Apply rotary embedding based on iteration progress.
        Uses a subset of dimensions to avoid interfering with position RoPE.

        Args:
            embeds: (..., H) embeddings; 2-D (N, H) and 3-D (B, L, H) inputs
                get appropriate broadcasting of the rotation angles
            iteration: current iteration index (0-based)
            total_iterations: total number of refinement iterations
            dim_fraction: fraction of the LAST hidden dims to rotate
            base: RoPE frequency base

        Returns:
            Copy of embeds with the trailing rot_dim dims rotated; the
            remaining dims are unchanged.
        """
        if dim_fraction <= 0:
            return embeds

        H = embeds.shape[-1]
        rot_dim = int(H * dim_fraction)
        # Dims rotate in (even, odd) pairs, so rot_dim must be even.
        rot_dim = rot_dim - (rot_dim % 2)

        if rot_dim < 2:
            return embeds

        # Scalar progress through the iteration budget, in [0, 1].
        progress = iteration / max(total_iterations - 1, 1)

        inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2, device=embeds.device, dtype=embeds.dtype) / rot_dim))
        # NOTE(review): hard-coded 3.14159 approximates pi; consider math.pi
        # (changes values at ~1e-6 relative level).
        angles = progress * inv_freq * 3.14159
        cos, sin = torch.cos(angles), torch.sin(angles)

        # Broadcast cos/sin over the leading (batch / sequence) dims.
        if embeds.dim() == 2:
            cos, sin = cos.unsqueeze(0), sin.unsqueeze(0)
        elif embeds.dim() == 3:
            cos = cos.unsqueeze(0).unsqueeze(0)
            sin = sin.unsqueeze(0).unsqueeze(0)

        # 2-D rotation of interleaved (even, odd) pairs within the LAST
        # rot_dim dimensions; earlier dims are copied through unchanged.
        embeds_out = embeds.clone()
        x1, x2 = embeds[..., -rot_dim::2], embeds[..., -rot_dim+1::2]
        embeds_out[..., -rot_dim::2] = x1 * cos - x2 * sin
        embeds_out[..., -rot_dim+1::2] = x1 * sin + x2 * cos

        return embeds_out
|
|
| |
|
|
| def _sample_flow_matching_t(self, num_tokens: int, device: torch.device) -> torch.Tensor: |
| """Sample per-token time levels for flow matching. |
| |
| Returns: |
| t: (num_tokens,) tensor of time levels in [t_min, t_max] |
| """ |
| dist = self.config.flow_matching_t_distribution |
| if dist == "logit_normal": |
| z = torch.randn(num_tokens, device=device) |
| z = z * self.config.flow_matching_t_logit_std + self.config.flow_matching_t_logit_mean |
| t = torch.sigmoid(z) |
| elif dist == "uniform": |
| t = torch.empty(num_tokens, device=device).uniform_(0, 1) |
| else: |
| raise ValueError(f"Unknown flow_matching_t_distribution: {dist}") |
| return t.clamp(self.config.flow_matching_t_min, self.config.flow_matching_t_max) |
|
|
    def compute_flow_matching_distillation_loss(
        self,
        input_ids: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor,
        flow_noise_embed: torch.Tensor,
        flow_t: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> SelfDistillationOutput:
        """
        CFM flow matching distillation: teacher sees state at time t, student sees
        noisier state at time s < t on the same interpolation path.

        Both should predict the same endpoint (target token). The student must
        learn to refine from noisier inputs by matching the teacher's predictions.

        Args:
            input_ids: Input with [MASK] tokens at positions to predict
            teacher_logits: Logits from the forward pass (will be detached)
            labels: Target tokens at masked positions (-100 elsewhere)
            flow_noise_embed: (num_masked, H) noise embeddings from forward
            flow_t: (num_masked,) per-token time levels from forward
            attention_mask: Standard attention mask
            position_ids: Position IDs (if needed by base model)

        Returns:
            SelfDistillationOutput with loss, logits, time gap, and diagnostics
        """
        mask_id = self.config.mask_token_id
        mask_pos = (input_ids == mask_id)
        device = input_ids.device
        num_masked = mask_pos.sum().item()

        # No masked tokens: return a zero loss that still participates in
        # autograd, plus dummy logits to keep the output shape consistent.
        if num_masked == 0:
            zero = torch.tensor(0.0, device=device, requires_grad=True)
            dummy = torch.zeros(1, device=device)
            return SelfDistillationOutput(zero, dummy, dummy, 0.0, 0.0, 0.0, 1.0)

        # Teacher serves as a fixed target; no gradient flows into it.
        teacher_logits = teacher_logits.detach()

        embed_weight = self.embed_weight
        mask_emb = embed_weight[mask_id]
        base_embeds = self.get_input_embeddings()(input_ids)

        # Endpoint embeddings of the interpolation path (the true tokens).
        target_ids = labels[mask_pos]
        target_embed = embed_weight[target_ids]

        # Sample a strictly noisier time s in [0, t) per masked token.
        s_per_token = flow_t * torch.rand(num_masked, device=device)

        # Linear interpolation between noise and target embedding at time s.
        s_col = s_per_token.unsqueeze(-1).to(base_embeds.dtype)
        student_interp = (1 - s_col) * flow_noise_embed + s_col * target_embed

        if self.config.flow_matching_mask_scale:
            # Mask embedding fades out as s -> 1.
            student_masked_embeds = student_interp + (1 - s_col) * mask_emb
        else:
            # Binary mask signal, independent of s.
            student_masked_embeds = student_interp + mask_emb

        # Student inputs are fully detached: gradients reach parameters only
        # through the student forward pass below.
        student_embeds = base_embeds.detach().clone()
        student_embeds[mask_pos] = student_masked_embeds.detach()

        student_inputs = torch.where(
            mask_pos.unsqueeze(-1), student_embeds, base_embeds.detach()
        )

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=base_embeds.dtype)

        student_out = self.mlm(
            inputs_embeds=student_inputs,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        student_logits = student_out.logits

        # Compare distributions only at masked positions.
        t_logits = teacher_logits[mask_pos]
        s_logits = student_logits[mask_pos]

        teacher_probs = F.softmax(t_logits, dim=-1)
        student_log_probs = F.log_softmax(s_logits, dim=-1)

        # KL(teacher || student), batch-mean over masked tokens.
        kl_loss = F.kl_div(
            student_log_probs,
            teacher_probs,
            reduction="batchmean",
        )

        # Diagnostics only (no gradients): entropies and argmax agreement.
        with torch.no_grad():
            teacher_log_probs = torch.log(teacher_probs + 1e-10)
            teacher_entropy = -(teacher_probs * teacher_log_probs).sum(dim=-1).mean().item()

            student_probs = F.softmax(s_logits.detach(), dim=-1)
            student_log_probs_det = torch.log(student_probs + 1e-10)
            student_entropy = -(student_probs * student_log_probs_det).sum(dim=-1).mean().item()

            agreement = (t_logits.argmax(dim=-1) == s_logits.detach().argmax(dim=-1)).float().mean().item()

        mean_time_gap = (flow_t - s_per_token).mean().item()

        # NOTE(review): degradation_temperature is reused to carry the mean
        # t-s time gap in this flow-matching variant.
        return SelfDistillationOutput(
            loss=kl_loss,
            teacher_logits=teacher_logits,
            student_logits=student_logits,
            degradation_temperature=mean_time_gap,
            teacher_entropy=teacher_entropy,
            student_entropy=student_entropy,
            agreement_rate=agreement,
        )
|
|
| |
|
|
    def compute_self_distillation_loss(
        self,
        input_ids: torch.Tensor,
        teacher_logits: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        temperature_min: Optional[float] = None,
        temperature_max: Optional[float] = None,
        temperature_distribution: Optional[str] = None,
    ) -> SelfDistillationOutput:
        """
        CFM-style self-distillation: model's predictions should be consistent
        across different levels of input degradation.

        Process:
        1. Take teacher logits (from standard forward pass, DETACHED)
        2. Degrade: per-token random temperature -> softer soft embeddings
        3. Student: forward pass from degraded embeddings -> logits (has grad)
        4. Loss: KL(teacher || student) on masked positions

        Each masked token gets its own independently sampled degradation
        temperature, creating varied difficulty across the sequence.

        Args:
            input_ids: Input with [MASK] tokens at positions to predict
            teacher_logits: Pre-computed teacher logits (will be detached).
                Typically outputs.all_logits[0] or outputs.logits from standard forward.
            attention_mask: Standard attention mask
            position_ids: Position IDs (if needed by base model)
            temperature_min: Min degradation temperature (default: config value)
            temperature_max: Max degradation temperature (default: config value)
            temperature_distribution: How to sample T (default: config value)

        Returns:
            SelfDistillationOutput with loss, logits, temperature, and diagnostics
        """
        # Fall back to config defaults for any unspecified degradation knob.
        temperature_min = temperature_min if temperature_min is not None else self.config.self_distillation_temperature_min
        temperature_max = temperature_max if temperature_max is not None else self.config.self_distillation_temperature_max
        temperature_distribution = temperature_distribution if temperature_distribution is not None else self.config.self_distillation_temperature_distribution

        mask_id = self.config.mask_token_id
        mask_pos = (input_ids == mask_id)
        device = input_ids.device
        num_masked = mask_pos.sum().item()

        # No masked tokens: zero loss that still supports backward, dummy logits.
        if num_masked == 0:
            zero = torch.tensor(0.0, device=device, requires_grad=True)
            dummy = torch.zeros(1, device=device)
            return SelfDistillationOutput(zero, dummy, dummy, 1.0, 0.0, 0.0, 1.0)

        # Teacher is a fixed target; no gradient flows into it.
        teacher_logits = teacher_logits.detach()

        embed_weight = self.embed_weight
        mask_emb = embed_weight[mask_id]
        base_embeds = self.get_input_embeddings()(input_ids)

        # One independently sampled degradation temperature per masked token.
        if temperature_distribution == "log_uniform":
            log_min = torch.tensor(temperature_min, device=device).log()
            log_max = torch.tensor(temperature_max, device=device).log()
            log_T = torch.empty(num_masked, device=device).uniform_(log_min.item(), log_max.item())
            T_per_token = log_T.exp()
        elif temperature_distribution == "uniform":
            T_per_token = torch.empty(num_masked, device=device).uniform_(
                temperature_min, temperature_max
            )
        else:
            raise ValueError(f"Unknown temperature distribution: {temperature_distribution}")

        T_mean = T_per_token.mean().item()

        # Degrade the teacher distribution by raising its temperature, then
        # map it to soft embeddings anchored by the mask embedding.
        masked_teacher_logits = teacher_logits[mask_pos]
        degraded_probs = F.softmax(masked_teacher_logits / T_per_token.unsqueeze(-1), dim=-1).to(embed_weight.dtype)
        degraded_soft = degraded_probs @ embed_weight + mask_emb

        degraded_soft_embeds = base_embeds.clone()
        degraded_soft_embeds[mask_pos] = degraded_soft
        # Student inputs carry no gradient; parameters learn only through the
        # student forward pass below.
        degraded_soft_embeds = degraded_soft_embeds.detach()

        student_inputs = torch.where(
            mask_pos.unsqueeze(-1), degraded_soft_embeds, base_embeds.detach()
        )

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=base_embeds.dtype)

        student_out = self.mlm(
            inputs_embeds=student_inputs,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        student_logits = student_out.logits

        # Compare distributions only at masked positions.
        t_logits = teacher_logits[mask_pos]
        s_logits = student_logits[mask_pos]

        teacher_probs = F.softmax(t_logits, dim=-1)
        student_log_probs = F.log_softmax(s_logits, dim=-1)

        # KL(teacher || student), batch-mean over masked tokens.
        kl_loss = F.kl_div(
            student_log_probs,
            teacher_probs,
            reduction="batchmean",
        )

        # Diagnostics only (no gradients): entropies and argmax agreement.
        with torch.no_grad():
            teacher_log_probs = torch.log(teacher_probs + 1e-10)
            teacher_entropy = -(teacher_probs * teacher_log_probs).sum(dim=-1).mean().item()

            student_probs = F.softmax(s_logits.detach(), dim=-1)
            student_log_probs_det = torch.log(student_probs + 1e-10)
            student_entropy = -(student_probs * student_log_probs_det).sum(dim=-1).mean().item()

            agreement = (t_logits.argmax(dim=-1) == s_logits.detach().argmax(dim=-1)).float().mean().item()

        return SelfDistillationOutput(
            loss=kl_loss,
            teacher_logits=teacher_logits,
            student_logits=student_logits,
            degradation_temperature=T_mean,
            teacher_entropy=teacher_entropy,
            student_entropy=student_entropy,
            agreement_rate=agreement,
        )
|
|
| |
|
|
    @torch.no_grad()
    def _compute_next_soft_embeds(
        self,
        logits: torch.Tensor,
        mask_pos: torch.Tensor,
        base_embeds: torch.Tensor,
        prev_soft_embeds: Optional[torch.Tensor] = None,
        iteration: int = 0,
        total_iterations: int = 1,
        # Convergence schedule selection.
        schedule: Optional[str] = None,
        causal_strength: Optional[float] = None,
        # Effect strengths; a value <= 0 disables that effect.
        temperature_max: Optional[float] = None,
        entropy_target_max: Optional[float] = None,
        entropy_floor_max: Optional[float] = None,
        smear_sigma_max: Optional[float] = None,
        noise_std_max: Optional[float] = None,
        iteration_rope_dim_fraction: Optional[float] = None,
    ) -> torch.Tensor:
        """
        Compute soft embeddings from logits for the next iteration.

        This function implements a unified "convergence schedule" system that controls
        when each position is allowed to converge to a confident prediction.

        Schedule Types:
            "linear": All positions converge at the same rate (iteration-based only)
            "causal": Early positions converge first, late positions last

        Effects (mechanisms to enforce the schedule):
            temperature_max: High temperature = more uniform distribution (one-sided)
            entropy_target_max: Force EXACT entropy via bisection search (two-sided, recommended)
            entropy_floor_max: Force MINIMUM entropy (one-sided, only prevents too confident)
            smear_sigma_max: Spread probability across neighboring positions
            noise_std_max: Add Gaussian noise to logits

        All parameters default to their config values if not specified.

        Args:
            logits: Output logits from current iteration (B, L, V)
            mask_pos: Boolean mask indicating which positions are masked (B, L)
            base_embeds: Base token embeddings for non-masked positions (B, L, H)
            prev_soft_embeds: Previous iteration's soft embeddings; blended in via EMA
                when config.soft_embedding_ema_step < 1.0
            iteration: Current iteration index (0-indexed)
            total_iterations: Total number of iterations

        Returns:
            Detached soft embeddings for next iteration (B, L, H)
        """
        # Resolve every per-call override against its config default.
        schedule = schedule if schedule is not None else self.config.schedule
        causal_strength = causal_strength if causal_strength is not None else self.config.causal_strength
        temperature_max = temperature_max if temperature_max is not None else self.config.temperature_max
        entropy_target_max = entropy_target_max if entropy_target_max is not None else self.config.entropy_target_max
        entropy_floor_max = entropy_floor_max if entropy_floor_max is not None else self.config.entropy_floor_max
        smear_sigma_max = smear_sigma_max if smear_sigma_max is not None else self.config.smear_sigma_max
        noise_std_max = noise_std_max if noise_std_max is not None else self.config.noise_std_max
        iteration_rope_dim_fraction = iteration_rope_dim_fraction if iteration_rope_dim_fraction is not None else self.config.iteration_rope_dim_fraction

        # Start from the base embeddings; only masked positions are overwritten below.
        soft_embeds = base_embeds.clone()

        # Nothing masked -> nothing to refine.
        if not mask_pos.any():
            return soft_embeds.detach()

        B, L, V = logits.shape
        device, dtype = logits.device, logits.dtype

        # Fast path is taken only when every schedule effect is disabled.
        has_effects = (
            temperature_max > 0 or
            entropy_target_max > 0 or
            entropy_floor_max > 0 or
            smear_sigma_max > 0 or
            noise_std_max > 0 or
            iteration_rope_dim_fraction > 0
        )

        if not has_effects:
            # Fast path: plain soft embedding of the masked logits, no schedule.
            masked_logits = logits[mask_pos]
            embed_weight = self.embed_weight

            # How logits become mixture weights over the vocabulary.
            if self.config.soft_embedding_method == "none":
                weights = masked_logits
            elif self.config.soft_embedding_method == "l2_normalize":
                weights = F.normalize(masked_logits, p=2, dim=-1)
            else:
                weights = self.normalize(masked_logits)

            # Soft embedding = weighted mix of token embeddings + mask embedding offset.
            masked_soft = weights @ embed_weight
            mask_emb = embed_weight[self.config.mask_token_id]
            masked_soft = masked_soft + mask_emb

            # Optional EMA blend with the previous iteration's soft embeddings.
            ema_step = self.config.soft_embedding_ema_step
            if ema_step < 1.0 and prev_soft_embeds is not None:
                prev_masked_soft = prev_soft_embeds[mask_pos]
                masked_soft = (1.0 - ema_step) * prev_masked_soft + ema_step * masked_soft

            soft_embeds[mask_pos] = masked_soft
            return soft_embeds.detach()

        # Full path: compute per-position convergence progress that drives the effects.
        # NOTE(review): batch_indices is unused; only position_indices feeds the schedule.
        batch_indices, position_indices = torch.where(mask_pos)

        progress = self._compute_convergence_progress(
            iteration=iteration,
            total_iterations=total_iterations,
            seq_length=L,
            mask_positions=position_indices,
            schedule=schedule,
            causal_strength=causal_strength,
            device=device,
            dtype=dtype,
        )

        # Smearing needs progress for ALL positions (B, L), not just masked ones.
        if smear_sigma_max > 0:
            all_positions = torch.arange(L, device=device, dtype=dtype)
            progress_full = self._compute_convergence_progress(
                iteration=iteration,
                total_iterations=total_iterations,
                seq_length=L,
                mask_positions=all_positions,
                schedule=schedule,
                causal_strength=causal_strength,
                device=device,
                dtype=dtype,
            )
            progress_full = progress_full.unsqueeze(0).expand(B, -1)

        # Normalize the full logits once; smearing operates on the (B, L, V) probs.
        full_probs = self.normalize(logits)

        if smear_sigma_max > 0:
            full_probs = self._apply_smear_effect(
                full_probs, mask_pos, progress_full, smear_sigma_max
            )

        # Restrict to masked positions for the remaining (per-token) effects.
        masked_logits = logits[mask_pos]
        masked_probs = full_probs[mask_pos]

        # Temperature effect — skipped when entropy targeting is active, since the
        # target-entropy effect fully determines sharpness on its own.
        if temperature_max > 0 and entropy_target_max <= 0:
            masked_logits = self._apply_temperature_effect(
                masked_logits, progress, temperature_max
            )
            masked_probs = torch.softmax(masked_logits, dim=-1)

        # Noise effect: convert probs back to log-space, perturb, re-softmax.
        if noise_std_max > 0:
            masked_logits_noisy = self._apply_noise_effect(
                torch.log(masked_probs + 1e-10), progress, noise_std_max
            )
            masked_probs = torch.softmax(masked_logits_noisy, dim=-1)

        # Entropy shaping. NOTE(review): when entropy_target_max > 0 the probs are
        # recomputed from masked_logits, discarding the noise/smear effects applied
        # above — confirm this ordering is intended.
        if entropy_target_max > 0:
            masked_probs = self._apply_target_entropy_effect(
                masked_logits, progress, entropy_target_max
            )
        elif entropy_floor_max > 0:
            masked_probs = self._apply_entropy_floor_effect(
                masked_probs, progress, entropy_floor_max
            )

        embed_weight = self.embed_weight

        # Choose mixture weights for the soft embedding, mirroring the fast path.
        if self.config.soft_embedding_method == "none":
            # Raw (possibly temperature-scaled) logits as weights.
            weights = masked_logits
        elif self.config.soft_embedding_method == "l2_normalize":
            # L2-normalized logits as weights.
            weights = F.normalize(masked_logits, p=2, dim=-1)
        else:
            weights = masked_probs

        masked_soft = weights @ embed_weight
        mask_emb = embed_weight[self.config.mask_token_id]
        masked_soft = masked_soft + mask_emb

        # Optionally rotate a fraction of embedding dims to encode the iteration index.
        if iteration_rope_dim_fraction > 0:
            masked_soft = self._apply_iteration_rope(
                masked_soft, iteration, total_iterations, iteration_rope_dim_fraction
            )

        # Optional EMA blend with the previous iteration's soft embeddings.
        ema_step = self.config.soft_embedding_ema_step
        if ema_step < 1.0 and prev_soft_embeds is not None:
            prev_masked_soft = prev_soft_embeds[mask_pos]
            masked_soft = (1.0 - ema_step) * prev_masked_soft + ema_step * masked_soft

        soft_embeds[mask_pos] = masked_soft

        return soft_embeds.detach()
|
|
| @torch.no_grad() |
| def _compute_iteration_metrics( |
| self, logits: torch.Tensor, labels: torch.Tensor |
| ) -> IterationMetrics: |
| """ |
| Compute token-level AND sequence-level metrics for a single iteration. |
| Returns scalars only - no large tensor storage. |
| |
| Token-level metrics: |
| - accuracy: fraction of correct token predictions |
| - entropy: average entropy per token |
| - softmax_ce: cross-entropy loss per token |
| |
| Sequence-level metrics: |
| - full_sequence_accuracy: fraction of sequences where ALL tokens are correct |
| - min_sequence_confidence: mean of minimum top-1 confidence per sequence |
| """ |
| B = logits.shape[0] |
|
|
| |
| logits = logits.detach().cpu().float() |
| target_labels = labels.detach().cpu().contiguous() |
| mask = target_labels != -100 |
|
|
| if mask.sum() == 0: |
| return IterationMetrics( |
| accuracy=0.0, |
| entropy=0.0, |
| softmax_ce=0.0, |
| full_sequence_accuracy=0.0, |
| min_sequence_confidence=0.0, |
| ) |
|
|
| logits = logits.contiguous() |
| predictions = logits.argmax(dim=-1) |
| correct = (predictions == target_labels) & mask |
|
|
| |
|
|
| |
| accuracy = (correct.sum() / mask.sum()).item() |
|
|
| |
| valid_logits = logits[mask] |
| valid_labels = target_labels[mask] |
|
|
| |
| log_probs = torch.nn.functional.log_softmax(valid_logits, dim=-1) |
| probs = torch.exp(log_probs) |
| entropy = -(probs * log_probs).sum(dim=-1).mean().item() |
|
|
| |
| softmax_ce = torch.nn.functional.cross_entropy( |
| valid_logits, valid_labels, reduction="mean" |
| ).item() |
|
|
| |
|
|
| |
| sequences_with_tokens = mask.any(dim=1) |
| num_valid_sequences = sequences_with_tokens.sum().item() |
|
|
| if num_valid_sequences == 0: |
| return IterationMetrics( |
| accuracy=accuracy, |
| entropy=entropy, |
| softmax_ce=softmax_ce, |
| full_sequence_accuracy=0.0, |
| min_sequence_confidence=0.0, |
| ) |
|
|
| |
| num_correct_per_seq = correct.sum(dim=1) |
| num_tokens_per_seq = mask.sum(dim=1) |
| all_correct = (num_correct_per_seq == num_tokens_per_seq) & sequences_with_tokens |
| full_seq_accuracy = (all_correct.sum() / num_valid_sequences).item() |
|
|
| |
| probs_full = torch.softmax(logits, dim=-1) |
| top1_confidence = probs_full.max(dim=-1).values |
|
|
| min_confidences = [] |
| for i in range(B): |
| if sequences_with_tokens[i]: |
| seq_confidences = top1_confidence[i][mask[i]] |
| min_confidences.append(seq_confidences.min().item()) |
|
|
| min_seq_conf = sum(min_confidences) / len(min_confidences) if min_confidences else 0.0 |
|
|
| return IterationMetrics( |
| accuracy=accuracy, |
| entropy=entropy, |
| softmax_ce=softmax_ce, |
| full_sequence_accuracy=full_seq_accuracy, |
| min_sequence_confidence=min_seq_conf, |
| ) |
|
|
| def _single_iteration( |
| self, |
| t: int, |
| T: int, |
| soft_embeds: torch.Tensor, |
| base_embeds: torch.Tensor, |
| mask_pos: torch.Tensor, |
| attention_mask: Optional[torch.Tensor], |
| labels: Optional[torch.Tensor], |
| compute_metrics: bool, |
| position_ids: Optional[torch.Tensor] = None, |
| **kwargs, |
| ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[IterationMetrics]]: |
| """ |
| Execute a single iteration of recursive refinement. |
| |
| Args: |
| t: Current iteration index (0 to T-1) |
| T: Total number of iterations |
| soft_embeds: Soft embeddings for mask positions |
| base_embeds: Base token embeddings from input_ids |
| mask_pos: Boolean mask of [MASK] positions (B, L) |
| attention_mask: Attention mask for MLM |
| labels: Target labels for loss computation |
| compute_metrics: Whether to compute iteration metrics |
| |
| Returns: |
| logits: Output logits from MLM (B, L, V) |
| weighted_loss: Loss weighted by step_weight(t, T), or None if no labels |
| metrics: IterationMetrics, or None if not requested |
| """ |
| |
| inputs_embeds = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds) |
|
|
| |
| outputs = self.mlm( |
| inputs_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| labels=labels, |
| return_dict=True, |
| **kwargs, |
| ) |
|
|
| |
| weighted_loss = outputs.loss |
| if labels is not None: |
| if weighted_loss is None: |
| |
| |
| masked_logits = outputs.logits[mask_pos] |
| masked_labels = labels[mask_pos] |
| loss_fct = CrossEntropyLoss() |
| weighted_loss = loss_fct(masked_logits, masked_labels) |
| weighted_loss *= self.step_weight(t, T) |
|
|
| |
| metrics = None |
| if compute_metrics and labels is not None: |
| metrics = self._compute_iteration_metrics(outputs.logits, labels) |
|
|
| return outputs.logits, weighted_loss, metrics |
|
|
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        num_recursions: Optional[int] = None,
        compute_iteration_metrics: bool = False,
        use_recursion_checkpointing: Optional[bool] = None,
        # Deprecated single-iteration mode.
        prev_soft_embeds: Optional[torch.Tensor] = None,
        run_set_iteration: Optional[int] = None,
        # Convergence schedule selection.
        schedule: Optional[str] = None,
        causal_strength: Optional[float] = None,
        # Schedule effect strengths.
        temperature_max: Optional[float] = None,
        entropy_target_max: Optional[float] = None,
        entropy_floor_max: Optional[float] = None,
        smear_sigma_max: Optional[float] = None,
        noise_std_max: Optional[float] = None,
        iteration_rope_dim_fraction: Optional[float] = None,
        **kwargs,
    ) -> RecursiveMaskedLMOutput:
        """
        Forward with recursive refinement.

        Supports three modes:
        1. Checkpointed mode (default): Run all T recursions with gradient checkpointing.
           Gradients flow through the entire chain; activations recomputed during backward.
        2. Non-checkpointed mode (use_recursion_checkpointing=False): Store all activations.
           Faster backward but higher memory.
        3. Single-iteration mode (DEPRECATED - run_set_iteration is not None): Run only one
           iteration. Use use_recursion_checkpointing=True instead.

        Loss Weighting (config.loss_weight):
            "last_1": Only final iteration loss (enables learning convergence behavior)
            "last_2": Last 2 iterations
            "linear": All iterations, linearly weighted (default)
            "uniform": All iterations, uniformly weighted

        Recursion Checkpointing:
            use_recursion_checkpointing: Enable gradient checkpointing for iterations.
                True = checkpoint each iteration, recompute during backward (default).
                False = store all activations (higher memory, faster backward).

        Convergence Schedule Parameters:
            All schedule/effect parameters default to their config values if not specified.
            Pass explicit values to override config for this forward pass.

            schedule: "linear" or "causal" - controls when positions can converge
            causal_strength: How much faster early positions converge (causal only)
            temperature_max: Max temperature boost for uncertain positions
            entropy_target_max: Target entropy at progress=0 (two-sided, recommended)
            entropy_floor_max: Min entropy floor (one-sided)
            smear_sigma_max: Max Gaussian sigma for position smearing
            noise_std_max: Max std of Gaussian noise on logits
            iteration_rope_dim_fraction: Fraction of dims for iteration RoPE

        Returns:
            RecursiveMaskedLMOutput. In the main (non-deprecated) path, ``loss`` is
            None and ``all_logits`` carries every iteration's logits during training
            — presumably the caller computes the loss from those; confirm against
            the training loop.
        """
        B, L = input_ids.shape
        # NOTE(review): V is unused on the main path (the flow-matching branch
        # rebinds it locally).
        V = self.embed_weight.shape[0]
        mask_id = self.config.mask_token_id

        if mask_id is None:
            raise ValueError("mask_token_id must be set")

        # Resolve checkpointing override against the config default.
        use_recursion_checkpointing = (
            use_recursion_checkpointing
            if use_recursion_checkpointing is not None
            else self.config.use_recursion_checkpointing
        )

        mask_pos = (input_ids == mask_id)
        base_embeds = self.get_input_embeddings()(input_ids)
        T = num_recursions or self.config.num_recursions
        # NOTE(review): weight_sum is only used by the deprecated
        # run_set_iteration path below.
        weight_sum = sum(self.step_weight(i, T) for i in range(T))

        # Schedule/effect overrides, forwarded to _compute_next_soft_embeds.
        # NOTE(review): only consumed by the deprecated single-iteration path.
        schedule_kwargs = dict(
            schedule=schedule,
            causal_strength=causal_strength,
            temperature_max=temperature_max,
            entropy_target_max=entropy_target_max,
            entropy_floor_max=entropy_floor_max,
            smear_sigma_max=smear_sigma_max,
            noise_std_max=noise_std_max,
            iteration_rope_dim_fraction=iteration_rope_dim_fraction,
        )

        # --- DEPRECATED single-iteration mode -------------------------------
        if run_set_iteration is not None:
            warnings.warn(
                "run_set_iteration is deprecated. Use use_recursion_checkpointing=True instead, "
                "which provides proper gradient flow through all iterations.",
                DeprecationWarning,
                stacklevel=2,
            )
            t = run_set_iteration

            # Iteration 0 starts from the vocabulary-average embedding plus the
            # mask embedding; later iterations continue from prev_soft_embeds.
            if t == 0:
                soft_embeds = base_embeds.clone()
                if mask_pos.any():
                    avg_embed = self.embed_weight.mean(dim=0)
                    mask_emb = self.embed_weight[mask_id]
                    soft_embeds[mask_pos] = avg_embed + mask_emb
            else:
                if prev_soft_embeds is None:
                    raise ValueError(f"prev_soft_embeds must be provided for iteration {t}")
                soft_embeds = prev_soft_embeds

            logits, weighted_loss, metrics = self._single_iteration(
                t, T, soft_embeds, base_embeds, mask_pos,
                attention_mask, labels, compute_iteration_metrics,
                position_ids=position_ids, **kwargs
            )

            # Normalize the per-step weighted loss by the total schedule weight.
            loss = weighted_loss / weight_sum if weighted_loss is not None else None

            # Produce the next iteration's soft embeddings unless this was the last step.
            next_soft_embeds = None
            if t < T - 1:
                next_soft_embeds = self._compute_next_soft_embeds(
                    logits, mask_pos, base_embeds,
                    iteration=t,
                    total_iterations=T,
                    **schedule_kwargs,
                )

            return RecursiveMaskedLMOutput(
                loss=loss,
                logits=logits,
                next_soft_embeds=next_soft_embeds,
                iteration_metrics={t: metrics} if metrics is not None else None,
            )

        # --- Main multi-iteration path --------------------------------------
        embed_weight = self.embed_weight
        mask_emb = embed_weight[mask_id]

        # Softmax temperature passed to the checkpointable iteration function.
        temperature = torch.tensor(
            self.config.temperature,
            device=input_ids.device,
            dtype=base_embeds.dtype,
        )

        if attention_mask is None:
            attention_mask = torch.ones(B, L, device=input_ids.device, dtype=base_embeds.dtype)

        # Initialize soft embeddings for iteration 0.
        soft_embeds = base_embeds.clone()
        flow_noise_embed = None
        flow_t_per_token = None

        if self.config.flow_matching_enabled and self.training and labels is not None and mask_pos.any():
            # CFM-style initialization: interpolate between a random point on
            # the simplex and the target token embedding at sampled times t.
            num_masked = mask_pos.sum().item()
            V = embed_weight.shape[0]
            device = input_ids.device

            # Per-masked-token interpolation time t in [0, 1].
            flow_t_per_token = self._sample_flow_matching_t(num_masked, device)

            # Random simplex point: softmax of scaled Gaussian noise.
            z = torch.randn(num_masked, V, device=device, dtype=base_embeds.dtype)
            p_noise = F.softmax(z * self.config.flow_matching_noise_scale, dim=-1).to(base_embeds.dtype)
            flow_noise_embed = p_noise @ embed_weight

            # Ground-truth target embeddings at the masked positions.
            target_ids = labels[mask_pos]
            target_embed = embed_weight[target_ids]

            # Linear interpolation: (1 - t) * noise + t * target.
            t_col = flow_t_per_token.unsqueeze(-1).to(base_embeds.dtype)
            interp_embed = (1 - t_col) * flow_noise_embed + t_col * target_embed

            # Mask embedding either fades with t or is added at full strength.
            if self.config.flow_matching_mask_scale:
                soft_embeds[mask_pos] = interp_embed + (1 - t_col) * mask_emb
            else:
                soft_embeds[mask_pos] = interp_embed + mask_emb
        elif mask_pos.any():
            # Standard initialization: vocabulary-average + mask embedding.
            avg_embed = embed_weight.mean(dim=0)
            soft_embeds[mask_pos] = avg_embed + mask_emb

        iteration_metrics = {} if compute_iteration_metrics and labels is not None else None

        # Run T refinement iterations, optionally under gradient checkpointing.
        all_logits = []
        for t in range(T):
            if self.training and use_recursion_checkpointing:
                # Checkpointed: activations recomputed during backward.
                logits, soft_embeds = torch_checkpoint(
                    self._single_iteration_checkpointable,
                    soft_embeds,
                    base_embeds,
                    mask_pos,
                    attention_mask,
                    embed_weight,
                    mask_emb,
                    temperature,
                    position_ids,
                    use_reentrant=False,
                )
            else:
                # Non-checkpointed: all activations kept in memory.
                logits, soft_embeds = self._single_iteration_checkpointable(
                    soft_embeds,
                    base_embeds,
                    mask_pos,
                    attention_mask,
                    embed_weight,
                    mask_emb,
                    temperature,
                    position_ids,
                )
            all_logits.append(logits)

            # Per-iteration diagnostics (scalars only), computed without grad.
            if iteration_metrics is not None and labels is not None:
                with torch.no_grad():
                    iteration_metrics[t] = self._compute_iteration_metrics(logits, labels)

        # ``iteration_metrics or None`` maps an empty dict to None.
        return RecursiveMaskedLMOutput(
            loss=None,
            logits=logits,
            all_logits=all_logits if self.training else None,
            iteration_metrics=iteration_metrics or None,
            flow_noise_embed=flow_noise_embed,
            flow_t=flow_t_per_token,
        )
|
|
| @torch.no_grad() |
| def _generate_flow_map( |
| self, |
| input_ids: torch.Tensor, |
| attention_mask: Optional[torch.Tensor], |
| position_ids: Optional[torch.Tensor], |
| num_steps: int, |
| ) -> torch.Tensor: |
| """Fill in mask positions using the CFM flow map update rule. |
| |
| Starts from a random point on the probability simplex and iteratively |
| moves toward the model's predictions using the flow map step rule. |
| |
| Args: |
| input_ids: Input with [MASK] tokens at positions to fill |
| attention_mask: Attention mask |
| position_ids: Position IDs |
| num_steps: Number of flow map steps (finer = better, 1 step = greedy) |
| |
| Returns: |
| Tensor with [MASK] positions filled with predicted tokens |
| """ |
| mask_pos = (input_ids == self.config.mask_token_id) |
| num_masked = mask_pos.sum().item() |
|
|
| if num_masked == 0: |
| return input_ids.clone() |
|
|
| device = input_ids.device |
| V = self.embed_weight.shape[0] |
| embed_weight = self.embed_weight |
| mask_emb = embed_weight[self.config.mask_token_id] |
| base_embeds = self.get_input_embeddings()(input_ids) |
|
|
| |
| noise_scale = self.config.flow_matching_noise_scale |
| p = F.softmax(torch.randn(num_masked, V, device=device, dtype=base_embeds.dtype) * noise_scale, dim=-1).to(base_embeds.dtype) |
|
|
| times = torch.linspace(0, 1, num_steps + 1, device=device) |
|
|
| for i in range(num_steps): |
| t_now = times[i] |
| t_next = times[i + 1] |
| step_size = (t_next - t_now) / (1 - t_now) |
|
|
| |
| if self.config.flow_matching_mask_scale: |
| mask_signal = (1 - t_now) * mask_emb |
| else: |
| mask_signal = mask_emb |
|
|
| |
| embed = p @ embed_weight + mask_signal |
|
|
| soft_embeds = base_embeds.clone() |
| soft_embeds[mask_pos] = embed |
| inputs_embeds = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds) |
|
|
| outputs = self.mlm( |
| inputs_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| return_dict=True, |
| ) |
| pi = F.softmax(outputs.logits[mask_pos], dim=-1).to(p.dtype) |
|
|
| |
| p = p + step_size * (pi - p) |
|
|
| |
| p = p.clamp(min=0) |
| p = p / p.sum(dim=-1, keepdim=True) |
|
|
| result = input_ids.clone() |
| result[mask_pos] = p.argmax(dim=-1) |
| return result |
|
|
| @torch.no_grad() |
| def generate( |
| self, |
| input_ids: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.Tensor] = None, |
| num_recursions: Optional[int] = None, |
| |
| schedule: Optional[str] = None, |
| causal_strength: Optional[float] = None, |
| |
| temperature_max: Optional[float] = None, |
| entropy_target_max: Optional[float] = None, |
| entropy_floor_max: Optional[float] = None, |
| smear_sigma_max: Optional[float] = None, |
| noise_std_max: Optional[float] = None, |
| iteration_rope_dim_fraction: Optional[float] = None, |
| ) -> torch.Tensor: |
| """Fill in mask positions via iterative refinement. |
| |
| When flow_matching_enabled, uses the CFM flow map update rule. |
| Otherwise, uses standard recursive soft-token refinement. |
| |
| Args: |
| input_ids: Input token IDs with [MASK] tokens at positions to fill |
| attention_mask: Attention mask |
| num_recursions: Override number of recursions/steps (default: config value) |
| schedule: "linear" or "causal" convergence schedule |
| causal_strength: How much faster early positions converge (causal only) |
| temperature_max: Max temperature boost for uncertain positions |
| entropy_target_max: Target entropy at progress=0 (two-sided) |
| entropy_floor_max: Min entropy floor (one-sided) |
| smear_sigma_max: Max Gaussian sigma for position smearing |
| noise_std_max: Max std of Gaussian noise on logits |
| iteration_rope_dim_fraction: Fraction of dims for iteration RoPE |
| |
| Returns: |
| Tensor with [MASK] positions filled with predicted tokens |
| """ |
| num_steps = num_recursions or self.config.num_recursions |
|
|
| if self.config.flow_matching_enabled: |
| return self._generate_flow_map( |
| input_ids, attention_mask, position_ids, num_steps |
| ) |
|
|
| out = self.forward( |
| input_ids, |
| attention_mask, |
| position_ids=position_ids, |
| num_recursions=num_steps, |
| schedule=schedule, |
| causal_strength=causal_strength, |
| temperature_max=temperature_max, |
| entropy_target_max=entropy_target_max, |
| entropy_floor_max=entropy_floor_max, |
| smear_sigma_max=smear_sigma_max, |
| noise_std_max=noise_std_max, |
| iteration_rope_dim_fraction=iteration_rope_dim_fraction, |
| ) |
| result = input_ids.clone() |
| mask_pos = (input_ids == self.config.mask_token_id) |
| result[mask_pos] = out.logits.argmax(dim=-1)[mask_pos] |
| return result |
|
|