File size: 6,584 Bytes
8511ba7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
# Copyright (c) 2026 ByteDance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import torch
import numpy as np
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
def add_gumbel_noise(logits, temperature):
    """Perturb ``logits`` with random noise so a later ``argmax`` samples.

    When ``temperature`` equals 0 the input tensor is handed back untouched
    (same object, same dtype). Otherwise the result is computed in float64 as
    ``exp(logits) / (-log(u)) ** temperature`` where ``u`` is drawn uniformly
    from [0, 1) with the shape of ``logits``.
    """
    if temperature != 0:
        logits64 = logits.to(torch.float64)
        uniform = torch.rand_like(logits64, dtype=torch.float64)
        # Raising -log(u) to the temperature scales how strongly the noise
        # can reorder the entries.
        denominator = torch.pow(-torch.log(uniform), temperature)
        return torch.exp(logits64) / denominator
    return logits
def get_num_transfer_tokens(mask_index, steps):
    """Split each row's masked-token count as evenly as possible over ``steps``.

    ``mask_index`` is summed along dim 1 to get the number of masked positions
    per row. Every step receives ``count // steps`` tokens and the first
    ``count % steps`` steps receive one extra, so each row of the returned
    ``(rows, steps)`` int64 tensor sums to that row's masked count.
    """
    masked_per_row = mask_index.sum(dim=1, keepdim=True)
    quotient = masked_per_row // steps
    leftover = masked_per_row % steps

    schedule = torch.zeros(
        masked_per_row.size(0), steps, device=mask_index.device, dtype=torch.int64
    ) + quotient

    # Step s gets one extra token exactly when s < leftover for that row.
    step_ids = torch.arange(steps, device=mask_index.device).unsqueeze(0)
    bonus = (step_ids < leftover).to(torch.int64)
    return schedule + bonus
def make_block_causal_mask(seq_len, block_size=2, device=None, dtype=torch.bool):
    """Build a block-causal attention mask of shape ``(1, 1, seq_len, seq_len)``.

    Positions are grouped into consecutive blocks of ``block_size`` (the last
    block may be shorter). A query position may attend to every key position
    whose block index is less than or equal to its own, i.e. attention is
    bidirectional inside a block and causal across blocks.

    Args:
        seq_len: Number of positions along each mask axis.
        block_size: Number of consecutive positions per block.
        device: Device on which the mask is created.
        dtype: Output dtype. For ``torch.bool`` the mask is ``True`` where
            attention is allowed and ``False`` elsewhere. For any other dtype
            the mask holds ``1.0`` where attention is allowed and ``-inf``
            elsewhere, cast to ``dtype``.

    Returns:
        The mask tensor with two leading singleton dimensions.
    """
    # Block index of every position; comparing indices pairwise gives the
    # block-lower-triangular pattern directly at the requested size.
    block_ids = torch.arange(seq_len, device=device) // block_size
    allowed = block_ids.unsqueeze(1) >= block_ids.unsqueeze(0)

    if dtype == torch.bool:
        # Casting the 1.0 / -inf float mask to bool would turn every entry
        # into True (both values are non-zero), so return the boolean pattern
        # itself.
        return allowed.unsqueeze(0).unsqueeze(0)

    # NOTE(review): allowed entries are 1.0 rather than the more common 0.0;
    # this is kept so existing float-mask callers see unchanged values.
    attention_mask = allowed.float().masked_fill(~allowed, float('-inf'))
    return attention_mask.unsqueeze(0).unsqueeze(0).to(dtype)
@torch.no_grad()
def generate_block(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
                   remasking='low_confidence', tokenizer=None, mask_id=5, threshold=0.95, shift=False, eos_id=None):
    """Fill ``gen_length`` masked positions after ``prompt``, one block at a time.

    A sequence of ``prompt`` followed by ``gen_length`` copies of ``mask_id``
    is built. The block-aligned part of the prompt is run through the model
    once to obtain a key/value cache; each following block is then refined
    repeatedly (model forward, then ``get_transfer_index`` decides which
    masked positions to commit) until it contains no ``mask_id``. A finished
    block is run through the model once more to obtain the cache used by the
    next block.

    Args:
        model: Callable returning an object with ``.logits`` and
            ``.past_key_values`` and exposing ``.device``. It is called with
            ``attention_mask``, ``past_key_values``, ``use_cache`` and
            ``cache_position`` keyword arguments.
        prompt: Token ids; written into a ``(1, prompt_len + gen_length)``
            buffer, so it is assumed to have shape ``(1, prompt_len)``.
        steps: Total refinement budget; divided evenly over the
            ``gen_length // block_length`` blocks.
        gen_length: Number of positions appended after the prompt. Must be a
            multiple of ``block_length``.
        block_length: Block size for decoding and for the attention mask.
        temperature: Forwarded to ``get_transfer_index``.
        remasking: Forwarded to ``get_transfer_index``.
        tokenizer: Unused.
        mask_id: Token id marking a not-yet-decoded position.
        threshold: Forwarded to ``get_transfer_index``. When ``None``, the
            per-step schedule from ``get_num_transfer_tokens`` is used.
        shift: Accepted for interface compatibility; ``get_transfer_index`` is
            always called with ``shift=False``.
        eos_id: If given, generation stops after the first finished block that
            contains this id and the output is cut at that block's end.

    Returns:
        ``(x, nfe)``: the token buffer (possibly truncated as described for
        ``eos_id``) and the number of refinement forward passes. The prefill
        pass and the per-block cache-refresh passes are not counted.

    Raises:
        AssertionError: If ``gen_length`` is not a multiple of
            ``block_length`` or ``steps`` is not a multiple of the block count.
    """
    x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length
    assert steps % num_blocks == 0
    steps = steps // num_blocks

    # Block boundaries relative to the end of the prompt. The first entry
    # completes the block that the prompt's tail partially occupies.
    prompt_len = prompt.shape[1]
    res_block = block_length - prompt_len % block_length
    every_block = [block_length for _ in range(num_blocks)]
    if res_block > 0:
        every_block = [res_block] + every_block
        every_block[-1] = block_length - res_block
    cum_block = [sum(every_block[:i+1]) for i in range(len(every_block))]
    # NOTE(review): cum_block has num_blocks + 1 entries but only the first
    # num_blocks are visited below, so when prompt_len is not a multiple of
    # block_length the last prompt_len % block_length positions keep mask_id.
    # Confirm whether callers rely on this before changing it.

    block_diffusion_attention_mask = make_block_causal_mask(
        prompt.shape[1] + gen_length, block_length, model.device, dtype=torch.bfloat16)
    nfe = 0
    final_flag = 0

    # Only whole prompt blocks are prefilled; the prompt's tail is processed
    # together with the first generated block.
    prefill_length = prompt_len // block_length * block_length
    # Start without a cache so a prompt shorter than one block (no prefill)
    # does not leave this name unbound at the first model call.
    past_key_values = None
    if prefill_length > 0:
        cur_attn_mask = block_diffusion_attention_mask[:, :, :prefill_length, :prefill_length]
        past_key_values = model(x[:, :prefill_length], attention_mask=cur_attn_mask, use_cache=True).past_key_values

    for block_idx in range(num_blocks):
        if block_idx > 0:
            current_block_start = prompt_len + cum_block[block_idx - 1]
        else:
            current_block_start = prefill_length
        current_block_end = prompt_len + cum_block[block_idx]

        block_mask_index = (x[:, current_block_start:current_block_end] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)

        # Absolute positions of this block, passed to the model as cache_position.
        replace_position = torch.zeros_like(x, dtype=torch.bool)
        replace_position[:, current_block_start:current_block_end] = 1
        cache_position = replace_position.nonzero(as_tuple=True)[1]
        # Rows: this block's queries; columns: every position up to the block end.
        cur_attn_mask = block_diffusion_attention_mask[:, :, current_block_start:current_block_end, :current_block_end]

        i = 0
        while True:
            nfe += 1
            mask_index = (x[:, current_block_start:current_block_end] == mask_id)
            output = model(x[:, current_block_start:current_block_end], attention_mask=cur_attn_mask,
                           past_key_values=past_key_values, use_cache=True, cache_position=cache_position)
            logits = output.logits

            if threshold is None:
                # Clamp so an iteration beyond the planned schedule reuses the
                # last entry instead of indexing out of range.
                step_budget = num_transfer_tokens[:, min(i, steps - 1)]
            else:
                step_budget = None
            # NOTE(review): the `shift` argument of this function is not
            # forwarded; shift=False is hard-coded as in the original call.
            x0, transfer_index = get_transfer_index(logits, temperature, remasking, mask_index,
                                                    x[:, current_block_start:current_block_end],
                                                    step_budget, threshold, shift=False)
            # The slice is a view, so this writes the committed tokens into x.
            x[:, current_block_start:current_block_end][transfer_index] = x0[transfer_index]

            if (x[:, current_block_start:current_block_end] == mask_id).sum() == 0:
                if eos_id is not None and (x[:, current_block_start:current_block_end] == eos_id).sum() > 0:
                    final_flag = 1
                    x = x[:, :current_block_end]
                    break
                # Re-run the finished block to get the cache for the next block.
                past_key_values = model(x[:, current_block_start:current_block_end], attention_mask=cur_attn_mask,
                                        past_key_values=past_key_values, use_cache=True,
                                        cache_position=cache_position).past_key_values
                break
            i += 1

        if final_flag == 1:
            break
    return x, nfe
def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None, shift=False):
    """Pick predicted tokens and decide which masked positions to commit.

    Candidate tokens are the argmax of ``logits`` after ``add_gumbel_noise``.
    Each candidate gets a confidence: its softmax probability
    (``remasking='low_confidence'``) or a uniform random number
    (``remasking='random'``). Positions where ``mask_index`` is False keep
    their token from ``x`` and get confidence ``-inf``.

    Without a ``threshold``, the ``num_transfer_tokens[j]`` most confident
    positions of row ``j`` are selected. With a ``threshold``,
    ``num_transfer_tokens`` is replaced by the per-row masked count, all of
    those positions are selected, and every selection except the most
    confident one is dropped again if its confidence is below ``threshold``.

    Returns:
        ``(x0, transfer_index)``: the candidate tokens (b, l) and a boolean
        tensor of the same shape marking the positions to commit.

    Raises:
        NotImplementedError: For an unknown ``remasking`` value.
    """
    noisy_logits = add_gumbel_noise(logits, temperature=temperature)
    x0 = noisy_logits.argmax(dim=-1)  # b, l

    if shift == True:
        # Move predictions one position to the right; position 0 keeps the
        # token from x and gets all-zero logits.
        x0 = torch.cat([x[:, :1], x0[:, :-1]], dim=-1)
        zero_logits = torch.zeros_like(logits[:, :1])
        logits = torch.cat([zero_logits, logits[:, :-1]], dim=1)

    if remasking == 'low_confidence':
        probs = F.softmax(logits.to(torch.float64), dim=-1)
        x0_p = probs.gather(-1, x0.unsqueeze(-1)).squeeze(-1)  # b, l
    elif remasking == 'random':
        x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
    else:
        raise NotImplementedError(remasking)

    x0 = torch.where(mask_index, x0, x)
    confidence = torch.where(mask_index, x0_p, -np.inf)
    transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)

    use_threshold = threshold is not None
    if use_threshold:
        num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)

    for row in range(confidence.shape[0]):
        row_conf = confidence[row]
        _, picked = torch.topk(row_conf, k=num_transfer_tokens[row])
        transfer_index[row, picked] = True
        if not use_threshold:
            continue
        # The best-ranked position always stays selected; the others must
        # reach the threshold.
        others = picked[1:]
        transfer_index[row, others[row_conf[others] < threshold]] = False
    return x0, transfer_index