---
license: apache-2.0
pipeline_tag: text-generation
tags:
- dllm
- diffusion
- llm
- text_generation
library_name: transformers
---

# ReFusion

[![arXiv](https://img.shields.io/badge/Paper-arXiv-red.svg)](http://arxiv.org/abs/2512.13586) [![GitHub](https://img.shields.io/badge/GitHub-ReFusion-black?logo=github)](https://github.com/ML-GSAI/ReFusion)

**ReFusion** is a masked diffusion model that combines strong performance with efficient inference: it reuses the full KV cache while still supporting any-order generation.

# Quickstart

The following code snippet shows how to load the tokenizer and model, and how to generate content.

```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM


def add_gumbel_noise(logits, temperature):
    '''
    Gumbel-max sampling for categorical distributions.
    According to arXiv:2409.02908, low-precision Gumbel-max improves the
    perplexity score of MDMs but reduces generation quality, so we use float64.
    '''
    if temperature == 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (-torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


@torch.no_grad()
def generate_refusion(model, tokenizer, prompt, gen_length=128, temperature=0.,
                      mask_id=151670, slot_size=8, model_path='', serial_num_blocks=2,
                      slot_threshold=0.9, token_threshold=0.9):
    sum_TPF = 0.0  # accumulates the tokens-per-forward (TPF) statistic
    forward_count = 0
    eos_token_id = tokenizer.eos_token_id
    batch_size = 1
    prompt_len = prompt.shape[1]
    device = model.device

    # Pad the generation length up to a multiple of the slot size.
    gen_pad_len = (slot_size - (gen_length % slot_size)) % slot_size
    gen_length = gen_length + gen_pad_len
    gen_x = torch.full((batch_size, gen_length), mask_id, dtype=torch.long, device=device)
    prompt_pos_ids = torch.arange(prompt_len, dtype=torch.long, device=device).unsqueeze(0)
    gen_pos_ids = torch.arange(prompt_len, prompt_len + gen_length, dtype=torch.long, device=device).unsqueeze(0)

    # cur_x / cur_pos hold the tokens committed so far, in generation order.
    cur_x = prompt.clone()
    cur_pos = prompt_pos_ids.clone()
    cur_slot_size = slot_size
    eos_flag = False
    block_length = gen_length // serial_num_blocks
    past_key_values = None

    for serial_num_block in range(serial_num_blocks):
        # block level
        cur_gen_x = gen_x[:, serial_num_block*block_length:(serial_num_block+1)*block_length]  # (batch_size, block_length)
        cur_gen_pos_ids = gen_pos_ids[:, serial_num_block*block_length:(serial_num_block+1)*block_length]  # (batch_size, block_length)
        cur_gen_blocks_x = cur_gen_x.reshape(batch_size, -1, cur_slot_size)
        cur_gen_blocks_pos_ids = cur_gen_pos_ids.reshape(batch_size, -1, cur_slot_size)

        # slot level generation
        while cur_gen_blocks_x.numel() > 0:
            cur_gen_blocks_x = cur_gen_blocks_x.reshape(batch_size, -1, cur_slot_size)
            cur_gen_blocks_pos_ids = cur_gen_blocks_pos_ids.reshape(batch_size, -1, cur_slot_size)
            flat_gen_blocks_x = cur_gen_blocks_x.view(batch_size, -1)
            flat_gen_blocks_pos_ids = cur_gen_blocks_pos_ids.view(batch_size, -1)
            prefix_block_tag = False

            # MDM
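            # Draft pass: predict every remaining masked slot in parallel,
            # conditioning on the committed prefix through the KV cache
            # (on the first step there is no cache yet, so the prompt is encoded too).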
            if past_key_values is None:
                input_x = torch.cat((cur_x, flat_gen_blocks_x), dim=1)
                input_pos_ids = torch.cat((cur_pos, flat_gen_blocks_pos_ids), dim=1)
                outputs = model(input_ids=input_x, position_ids=input_pos_ids,
                                past_key_values=past_key_values, use_cache=True)
            else:
                outputs = model(input_ids=flat_gen_blocks_x, position_ids=flat_gen_blocks_pos_ids,
                                past_key_values=past_key_values, use_cache=True)
            logits = outputs.logits
            gen_logits = logits[:, -flat_gen_blocks_x.shape[1]:, :]
            past_key_values = outputs.past_key_values
            past_key_values.crop(cur_x.shape[1])  # keep only the committed prefix in the cache
            assert cur_x.shape[-1] == past_key_values[0][0].shape[-2]

            logits_with_noise = add_gumbel_noise(gen_logits, temperature=temperature)
            x0_gen = torch.argmax(logits_with_noise, dim=-1)
            x0_gen_blocks = x0_gen.view(batch_size, -1, cur_slot_size)
            p_softmax = F.softmax(gen_logits, dim=-1)
            x0_p_softmax = torch.gather(p_softmax, dim=-1, index=torch.unsqueeze(x0_gen, -1)).squeeze(-1)
            x0_p_softmax_blocks = x0_p_softmax.view(batch_size, -1, cur_slot_size)

            # A slot is a candidate when its first token is predicted confidently.
            block_confidence_softmax = x0_p_softmax_blocks[:, :, 0]  # (bsz, num_slots)
            is_confident_block = block_confidence_softmax > slot_threshold
            counts_block = torch.sum(is_confident_block, dim=1).item()
            topk_indices_relative = is_confident_block[0].nonzero(as_tuple=True)[0]
            if counts_block <= 0:
                # Fall back to the single most confident slot.
                counts_block = 1
                _, topk_indices_relative = torch.topk(block_confidence_softmax.squeeze(0), k=1)

            # choose slots, kept in positional order
            topk_indices_relative, _ = torch.sort(topk_indices_relative)
            chosen_gen_blocks = x0_gen_blocks[0, topk_indices_relative, :]
            chosen_position_ids = cur_gen_blocks_pos_ids[0, topk_indices_relative, :]
            chosen_p_softmax_blocks = x0_p_softmax_blocks[0, topk_indices_relative, :]

            # Global Verification
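            # Re-score the chosen slots with a second, autoregressive-style pass:
            # the logits are shifted right by one so that position i is scored given
            # the tokens before it, and tokens are accepted left-to-right until the
            # first one whose probability falls below token_threshold, e.g.
            #   prob_mask [1, 1, 0, 1] --cumprod--> [1, 1, 0, 0].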
            outputs = model(input_ids=chosen_gen_blocks.reshape(1, -1),
                            position_ids=chosen_position_ids.reshape(1, -1),
                            past_key_values=past_key_values, use_cache=True)
            AR_logits = outputs.logits  # (1, len, vocab_size)
            AR_logits = torch.cat([AR_logits[:, :1], AR_logits[:, :-1]], dim=1)  # shift right by one
            AR_p_softmax = F.softmax(AR_logits, dim=-1)  # (1, len, vocab_size)
            AR_x0_p_softmax = torch.gather(AR_p_softmax, dim=-1, index=torch.unsqueeze(chosen_gen_blocks.reshape(1, -1), -1)).squeeze(-1)  # (1, len)
            AR_x0_p_softmax_blocks = AR_x0_p_softmax.reshape(-1, cur_slot_size)
            chosen_p_softmax_blocks[:, 1:] = AR_x0_p_softmax_blocks[:, 1:]

            prob_mask = chosen_p_softmax_blocks > token_threshold
            prob_mask[:, 0] = True  # the first token of each chosen slot is always kept
            tag_blocks = torch.cumprod(prob_mask.int(), dim=-1)
            tag_tokens = torch.cumprod(prob_mask.int().reshape(1, -1), dim=-1)
            prefix_len = torch.sum(tag_tokens, dim=-1)
            flat_chosen_gen_blocks = chosen_gen_blocks.reshape(1, -1)
            confident_prefix_tokens = flat_chosen_gen_blocks[:, :prefix_len]

            if prefix_len > 0:
                is_eos_in_prefix = (confident_prefix_tokens.squeeze(0) == eos_token_id)
                eos_found_flag = torch.any(is_eos_in_prefix)
                remain_indices = []
                indices_to_remove = set()
                if eos_found_flag:
                    # Commit everything up to and including the first EOS token.
                    first_eos_pos_tensor = torch.argmax(is_eos_in_prefix.int())
                    eos_block_pos = first_eos_pos_tensor // cur_slot_size + 1
                    eos_token_pos = first_eos_pos_tensor - (first_eos_pos_tensor // cur_slot_size) * cur_slot_size
                    eos_block = topk_indices_relative[eos_block_pos - 1].item()
                    remain_indices.extend(topk_indices_relative[:eos_block_pos].tolist())
                    topk_indices_relative = torch.tensor([], device=device)
                    eos_flag = True
                    indices_after_eos = list(range(eos_block, cur_gen_blocks_x.shape[1]))
                    indices_to_remove.update(indices_after_eos)
                elif (prefix_len // cur_slot_size) > 0:
                    # Commit the fully verified slots; the rest stay for refinement.
                    num_prefix_blocks = prefix_len // cur_slot_size
                    remain_indices.extend(topk_indices_relative[:num_prefix_blocks].tolist())
                    topk_indices_relative = topk_indices_relative[num_prefix_blocks:]
                    tag_blocks = tag_blocks[num_prefix_blocks:]
                if len(remain_indices) > 0:
                    indices_to_remove.update(remain_indices)
                    token_indices = []
                    for i_idx, b_idx in enumerate(remain_indices):
                        start_index = b_idx * cur_slot_size
                        current_block_len = cur_slot_size
                        # If EOS exists and this is the last slot, adjust the length.
                        if eos_found_flag and i_idx == len(remain_indices) - 1:
                            current_block_len = eos_token_pos + 1
                        end_index = start_index + current_block_len
                        block_range = torch.arange(start_index, end_index, dtype=torch.long, device=device)
                        token_indices.append(block_range)
                    full_token_indices = torch.cat(token_indices)
                    cur_x = torch.cat((cur_x, x0_gen[:, full_token_indices]), dim=1)
                    cur_pos = torch.cat((cur_pos, flat_gen_blocks_pos_ids[:, full_token_indices]), dim=1)
                    past_key_values = outputs.past_key_values
                    past_key_values.crop(cur_x.shape[1])
                    assert cur_x.shape[-1] == past_key_values[0][0].shape[-2]
                    prefix_block_tag = True
                    sum_TPF += cur_slot_size * len(remain_indices) / 2
                    forward_count += 1

            if prefix_block_tag:
                # Some slots were committed: drop them and restart the slot loop.
                keep_mask = torch.ones(cur_gen_blocks_x.shape[1], dtype=torch.bool, device=device)
                keep_mask[list(indices_to_remove)] = False
                cur_gen_blocks_x = cur_gen_blocks_x[:, keep_mask, :]
                cur_gen_blocks_pos_ids = cur_gen_blocks_pos_ids[:, keep_mask, :]
                continue

            # No slot was fully verified: refine the chosen slots in parallel.
            past_key_values = outputs.past_key_values
            past_key_values.crop(cur_x.shape[1])
            assert cur_x.shape[-1] == past_key_values[0][0].shape[-2]
            indices_to_remove = set(topk_indices_relative.tolist())
            current_speculative_blocks = chosen_gen_blocks.clone()
            accepted_prefix_len = 0
            eos_found_in_loop = False
            if past_key_values is not None and counts_block > 1:
                # Replicate the prefix cache so each chosen slot gets its own batch row.
                past_key_values.batch_repeat_interleave(counts_block)
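            # Speculative refinement: alternate two forward passes per iteration.
            # 1) Prediction: re-mask each slot's unverified suffix and let the MDM
            #    propose tokens for it.
            # 2) Verification: score the proposal with shifted logits and extend each
            #    slot's accepted prefix while probabilities stay above token_threshold.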
            for loop_iter in range(cur_slot_size):
                if not torch.any(tag_blocks == 0):
                    break
                input_tokens = current_speculative_blocks[:, accepted_prefix_len:]
                input_pos = chosen_position_ids[:, accepted_prefix_len:]
                current_tags = tag_blocks[:, accepted_prefix_len:]
                masked_input_tokens = torch.where(current_tags.bool(), input_tokens, mask_id)

                # Prediction
                draft_len = past_key_values[0][0].shape[2]
                draft_outputs = model(input_ids=masked_input_tokens, position_ids=input_pos,
                                      past_key_values=past_key_values, use_cache=False)
                past_key_values.crop(draft_len)
                draft_logits = draft_outputs.logits
                proposed_tokens = torch.argmax(draft_logits, dim=-1)
                input_tokens = torch.where(current_tags.bool(), input_tokens, proposed_tokens)
                current_speculative_blocks[:, accepted_prefix_len:] = input_tokens

                # Verification
                verify_outputs = model(input_ids=input_tokens, position_ids=input_pos,
                                       past_key_values=past_key_values, use_cache=True)
                verify_logits = verify_outputs.logits
                verify_logits = torch.cat([verify_logits[:, :1], verify_logits[:, :-1]], dim=1)
                verify_probs = F.softmax(verify_logits, dim=-1)
                gathered_probs = torch.gather(verify_probs, -1, input_tokens.unsqueeze(-1)).squeeze(-1)
                prob_mask = gathered_probs > token_threshold
                # Keep at least one token per iteration so the loop always makes progress.
                update_tag_blocks = F.pad(tag_blocks[:, accepted_prefix_len:], (1, 0), value=1)[:, :-1]
                prob_mask[update_tag_blocks == 1] = True
                new_tags = torch.cumprod(prob_mask.int(), dim=-1)
                tag_blocks[:, accepted_prefix_len:] = new_tags

                newly_verified_mask = (tag_blocks[:, accepted_prefix_len:] == 1)
                is_eos_in_new = (current_speculative_blocks[:, accepted_prefix_len:] == eos_token_id) & newly_verified_mask
                if torch.any(is_eos_in_new):
                    # A verified EOS appeared: truncate to the slots up to and including it.
                    eos_found_in_loop = True
                    first_eos_block_idx = torch.where(torch.any(is_eos_in_new, dim=1))[0][0].item()
                    current_speculative_blocks = current_speculative_blocks[:first_eos_block_idx+1]
                    tag_blocks = tag_blocks[:first_eos_block_idx+1]
                    tag_blocks[first_eos_block_idx] = 1
                    chosen_position_ids = chosen_position_ids[:first_eos_block_idx+1]
                    topk_indices_relative = topk_indices_relative[:first_eos_block_idx+1]
                    if verify_outputs.past_key_values is not None:
                        verify_outputs.past_key_values.batch_select_minibatch(first_eos_block_idx + 1)

                current_tags = tag_blocks[:, accepted_prefix_len:]
                len_per_block = torch.sum(current_tags, dim=1)
                newly_accepted_len = torch.min(len_per_block).item()
                if newly_accepted_len > 0:
                    if torch.any(tag_blocks == 0):
                        accepted_prefix_len = accepted_prefix_len + newly_accepted_len - 1
                    else:
                        accepted_prefix_len = accepted_prefix_len + newly_accepted_len
                past_key_values = verify_outputs.past_key_values
                if past_key_values is not None:
                    past_key_values.crop(cur_x.shape[1] + accepted_prefix_len)
                sum_TPF += (cur_slot_size * counts_block) / (loop_iter * 2 + 2)
                forward_count += 1

            # Slice each replicated row's slot-sized KV entries out of the cache.
            ar_kv_cache = tuple(
                (
                    layer_past[0][:, :, -cur_slot_size:, :],  # key
                    layer_past[1][:, :, -cur_slot_size:, :],  # value
                )
                for layer_past in past_key_values
            )
            past_key_values.crop(cur_x.shape[1])
            past_key_values.batch_select_indices(torch.tensor([0]).to(device))

            # Drop everything after the first EOS token.
            eos_mask = (current_speculative_blocks == eos_token_id)
            keep_mask = (torch.cumsum(eos_mask.flatten().int(), dim=-1) - eos_mask.flatten().int()) == 0
            kept_tokens = current_speculative_blocks.flatten()[keep_mask].reshape(batch_size, -1)
            kept_pos_ids = chosen_position_ids.flatten()[keep_mask].reshape(batch_size, -1)

            # Update the KV cache: fold the per-slot entries back into one sequence.
            if kept_tokens.numel() > 0 and ar_kv_cache is not None:
                new_past = []
                for key, val in ar_kv_cache:
                    num_heads, head_dim = key.shape[1], key.shape[3]
                    flat_key = key.permute(1, 0, 2, 3).reshape(1, num_heads, -1, head_dim)
                    flat_val = val.permute(1, 0, 2, 3).reshape(1, num_heads, -1, head_dim)
                    kept_key = flat_key[:, :, keep_mask, :]
                    kept_val = flat_val[:, :, keep_mask, :]
                    new_past.append((kept_key, kept_val))
                kept_kv = tuple(new_past)
                past_key_values.full_update(kept_kv)

            cur_x = torch.cat((cur_x, kept_tokens), dim=1)
            cur_pos = torch.cat((cur_pos, kept_pos_ids), dim=1)
            assert cur_x.shape[-1] == past_key_values[0][0].shape[-2]

            if eos_found_in_loop:
                indices_after_eos = list(range(first_eos_block_idx, cur_gen_blocks_x.shape[1]))
                indices_to_remove.update(indices_after_eos)
                eos_flag = True
            keep_mask = torch.ones(cur_gen_blocks_x.shape[1], dtype=torch.bool, device=device)
            keep_mask[list(indices_to_remove)] = False
            cur_gen_blocks_x = cur_gen_blocks_x[:, keep_mask, :]
            cur_gen_blocks_pos_ids = cur_gen_blocks_pos_ids[:, keep_mask, :]

        if eos_flag:
            break
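    # Tokens were committed in acceptance order, which need not be left-to-right;
    # sort by position id to restore the natural sequence before returning.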
    _, re_mask_indices = torch.sort(cur_pos, dim=-1)
    x = torch.gather(cur_x, dim=-1, index=re_mask_indices)
    TPF = sum_TPF / forward_count  # tokens-per-forward statistic
    return x, TPF


def main():
    device = 'cuda'
    model_path = "ReFusion"
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    prompt = "You are an expert Python programmer. Your task is to write a single Python function to solve the problem described below, and here is your task: Write a function to sum all amicable numbers from 1 to a specified number.\n\nDirectly after the '[BEGIN]' marker, you must write only the Python code for the function. Do not provide any explanations, comments, or introductory text. The function must include the 'def' line, its arguments, the function body, and a 'return' statement. Your code should pass these tests:\n\nassert amicable_numbers_sum(999)==504\nassert amicable_numbers_sum(9999)==31626\nassert amicable_numbers_sum(99)==0\n[BEGIN]\n"
    m = [{"role": "user", "content": prompt}]
    prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False, enable_thinking=True)
    print(prompt)
    input_ids = tokenizer(prompt)['input_ids']
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

    out, TPF = generate_refusion(model, tokenizer, input_ids, gen_length=512, temperature=0.,
                                 mask_id=151670, slot_size=4, model_path=model_path, serial_num_blocks=32,
                                 slot_threshold=0.6, token_threshold=0.3)
    print(tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0])
    print("---------TPF:", TPF)


if __name__ == '__main__':
    main()
```

# Citation

If you find our work helpful, please consider citing our paper.

```bibtex
@misc{li2025refusiondiffusionlargelanguage,
      title={ReFusion: A Diffusion Large Language Model with Parallel Autoregressive Decoding},
      author={Jia-Nan Li and Jian Guan and Wei Wu and Chongxuan Li},
      year={2025},
      eprint={2512.13586},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2512.13586},
}
```