BlockRefinementScheduler

The BlockRefinementScheduler manages block-wise iterative refinement for discrete token diffusion. At each step it commits the most confident tokens and optionally edits already-committed tokens when the model predicts a different token with high confidence.

This scheduler is used by LLaDA2Pipeline.

BlockRefinementScheduler[[diffusers.BlockRefinementScheduler]]

diffusers.BlockRefinementScheduler[[diffusers.BlockRefinementScheduler]]

Source

Scheduler for block-wise iterative refinement (commit-by-confidence).

At each step, the scheduler samples candidate tokens from model logits and commits those with the highest confidence. The number of tokens to commit per step is determined by evenly distributing the block length across the number of refinement steps.

Optionally supports editing: after all mask tokens are resolved, tokens can be replaced if the model predicts a different token with confidence above a positive editing_threshold (None, 0.0, or negative disables editing).
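The enablement rule for editing can be stated compactly. A minimal sketch, assuming only the rule quoted above (the helper name is hypothetical, not part of the diffusers API):

```python
def editing_is_enabled(editing_threshold):
    # Hypothetical helper (not a diffusers function): mirrors the documented
    # rule that None, 0.0, or a negative editing_threshold disables editing.
    return editing_threshold is not None and editing_threshold > 0.0
```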

add_noise[[diffusers.BlockRefinementScheduler.add_noise]]

Source

Apply the forward (noising) process for semi-autoregressive block masking.

For each block after the prompt, a random fraction of valid (non-padding) tokens are replaced with mask_token_id. Two complementary views are returned: noisy and noisy_rev, where the masked positions in one are the unmasked positions in the other.

Parameters:

original_samples (torch.LongTensor of shape (batch_size, seq_len)) : Clean token IDs.

attention_mask (torch.LongTensor of shape (batch_size, seq_len)) : Padding mask (1 for valid, 0 for padding).

prompt_length (int) : Number of leading prompt tokens to keep unmasked.

block_length (int) : Block size for masking.

mask_token_id (int) : Token ID to use for masked positions.

generator (torch.Generator, optional) : RNG for reproducibility.

Returns:

tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]

(noisy, noisy_rev, masked, masked_rev) — the two complementary noisy sequences and their corresponding boolean masks.
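The complementary-views idea can be illustrated with plain Python lists in place of tensors. This is a sketch of the concept, not the diffusers implementation: each position after the prompt is masked in exactly one of the two views, so the boolean masks are complements of each other over the generated region.

```python
import random

def mask_block_complementary(tokens, prompt_length, block_length,
                             mask_token_id, rng):
    # Illustrative sketch (not the diffusers source): for each block after
    # the prompt, send a random fraction of positions to the first noisy
    # view and the remaining positions to the complementary view.
    noisy, noisy_rev = list(tokens), list(tokens)
    masked = [False] * len(tokens)
    masked_rev = [False] * len(tokens)
    for start in range(prompt_length, len(tokens), block_length):
        end = min(start + block_length, len(tokens))
        frac = rng.random()  # random masking fraction for this block
        for i in range(start, end):
            if rng.random() < frac:
                noisy[i] = mask_token_id
                masked[i] = True
            else:
                noisy_rev[i] = mask_token_id
                masked_rev[i] = True
    return noisy, noisy_rev, masked, masked_rev
```

Prompt positions stay unmasked in both views; every generated position is masked in exactly one of them.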

check_block_should_continue[[diffusers.BlockRefinementScheduler.check_block_should_continue]]

Source

Determine whether the inner refinement loop should continue for the current block.

Parameters:

step_idx (int) : Current refinement step index within this block.

masks_remaining (bool) : Whether any mask tokens remain in the block.

editing_enabled (bool) : Whether editing mode is active.

editing_transfer_index (torch.BoolTensor) : Which tokens were edited in this step.

post_steps (int) : Number of post-mask editing steps taken so far.

max_post_steps (int) : Maximum allowed post-mask editing steps.

finished (torch.BoolTensor) : Per-batch finished flags (from EOS detection).

Returns:

bool

True if refinement should continue, False to break.
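One plausible reading of this decision, sketched with scalars in place of the tensor arguments (this is an illustration of the documented inputs, not the actual diffusers source; `edited_any` stands in for a reduction over `editing_transfer_index`):

```python
def should_continue(masks_remaining, editing_enabled, edited_any,
                    post_steps, max_post_steps, all_finished):
    # Illustrative sketch: keep refining while masks remain; once all masks
    # are resolved, continue only in editing mode, while edits are still
    # being made and the post-mask step budget is not exhausted; always
    # stop when every batch element is finished.
    if all_finished:
        return False
    if masks_remaining:
        return True
    if not editing_enabled:
        return False
    return edited_any and post_steps < max_post_steps
```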

check_eos_finished[[diffusers.BlockRefinementScheduler.check_eos_finished]]

Source

Update per-batch finished flags when EOS tokens are committed.

Parameters:

cur_x (torch.LongTensor of shape (batch_size, seq_len)) : Current full sequence including all blocks up to the current window.

sampled_tokens (torch.LongTensor of shape (batch_size, block_length)) : Tokens sampled by the scheduler in this step.

final_transfer (torch.BoolTensor of shape (batch_size, block_length)) : Combined mask of committed and edited positions.

finished (torch.BoolTensor of shape (batch_size,)) : Current per-batch finished flags.

eos_token_id (int) : EOS token ID.

mask_token_id (int) : Mask token ID.

prompt_length (int) : Number of prompt tokens at the start of the sequence.

Returns:

torch.BoolTensor

Updated finished flags.
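The per-sequence version of this update can be sketched as follows. This is an assumption about the intent (not the diffusers source): a sequence is considered finished once a committed EOS appears in the generated region with no unresolved mask tokens before it.

```python
def update_finished(sequence, finished, eos_token_id, mask_token_id,
                    prompt_length):
    # Illustrative sketch (not the diffusers source): scan the generated
    # region left to right; an unresolved mask before any EOS means the
    # sequence cannot yet be declared finished.
    if finished:
        return True
    for tok in sequence[prompt_length:]:
        if tok == mask_token_id:
            return False
        if tok == eos_token_id:
            return True
    return False
```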

get_num_transfer_tokens[[diffusers.BlockRefinementScheduler.get_num_transfer_tokens]]

Source

Evenly distribute block_length token commits across num_inference_steps steps.
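The even split works like integer division with the remainder spread over a few steps. A minimal sketch (the exact placement of the remainder in diffusers may differ; here it goes to the earliest steps):

```python
def get_num_transfer_tokens(block_length, num_inference_steps):
    # Illustrative sketch: each step commits block_length // num_inference_steps
    # tokens, and the first (block_length % num_inference_steps) steps
    # commit one extra so the counts sum to block_length.
    base, rem = divmod(block_length, num_inference_steps)
    return [base + 1 if i < rem else base for i in range(num_inference_steps)]
```

For example, a block of 32 tokens over 5 steps yields the schedule `[7, 7, 6, 6, 6]`.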

step[[diffusers.BlockRefinementScheduler.step]]

Source

Perform a single refinement step: sample from logits, commit confident tokens, and optionally edit existing ones.

Parameters:

model_output (torch.Tensor of shape (batch_size, block_length, vocab_size)) : Raw logits from the model for the current block.

timestep (int or torch.Tensor) : Current step index within the block's refinement schedule.

sample (torch.LongTensor of shape (batch_size, block_length)) : Current block token IDs (contains mask tokens for uncommitted positions).

mask_token_id (int) : Token ID used for masked positions.

temperature (float) : Sampling temperature.

top_p (float, optional) : Nucleus sampling cutoff.

top_k (int, optional) : Top-k sampling cutoff.

sampling_method (str) : Sampling method (auto, greedy, multinomial).

threshold (float, optional) : Confidence threshold for committing tokens. Defaults to config value.

editing_threshold (float, optional) : Confidence threshold for editing non-mask tokens; must be positive to enable editing. Defaults to config value.

minimal_topk (int, optional) : Minimum tokens to commit per step. Defaults to config value.

prompt_mask (torch.BoolTensor, optional) : Boolean mask of shape (block_length,) where True marks prompt (non-editable) positions.

generator (torch.Generator, optional) : RNG for sampling.

return_dict (bool) : Whether to return a BlockRefinementSchedulerOutput or a tuple.
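The commit-by-confidence core of this step can be illustrated with plain Python lists in place of tensors. This sketch (not the diffusers source) uses greedy sampling and a fixed commit count; it omits temperature, top-p/top-k, thresholds, and editing:

```python
import math

def commit_step(logits, sample, mask_token_id, num_to_commit):
    # Illustrative sketch: take the argmax token and its softmax probability
    # at every position, then commit the num_to_commit most confident
    # predictions among the still-masked positions.
    probs, picks = [], []
    for row in logits:
        z = max(row)  # shift for numerical stability
        exps = [math.exp(v - z) for v in row]
        total = sum(exps)
        p = [e / total for e in exps]
        picks.append(p.index(max(p)))
        probs.append(max(p))
    masked = [i for i, t in enumerate(sample) if t == mask_token_id]
    chosen = sorted(masked, key=lambda i: probs[i], reverse=True)[:num_to_commit]
    out = list(sample)
    for i in chosen:
        out[i] = picks[i]
    return out, sorted(chosen)
```

With two masked positions and a commit budget of one, the position whose argmax probability is higher is filled first; the other stays masked for a later step.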

BlockRefinementSchedulerOutput[[diffusers.BlockRefinementSchedulerOutput]]

diffusers.BlockRefinementSchedulerOutput[[diffusers.BlockRefinementSchedulerOutput]]

Source

Output class for block refinement scheduling.

Parameters:

prev_sample (torch.LongTensor of shape (batch_size, block_length)) : Updated block tokens after the current refinement step.

transfer_index (torch.BoolTensor of shape (batch_size, block_length)) : Boolean mask indicating which tokens were committed (mask-filling).

editing_transfer_index (torch.BoolTensor of shape (batch_size, block_length)) : Boolean mask indicating which tokens were edited (non-mask replacement).

sampled_tokens (torch.LongTensor of shape (batch_size, block_length)) : Sampled token IDs from the model logits.

sampled_probs (torch.Tensor of shape (batch_size, block_length)) : Probabilities of the sampled tokens.
