|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
from torch.nn import functional as F |
|
|
from transformers.cache_utils import DynamicCache |
|
|
|
|
|
|
|
|
def top_k_logits(logits, k): |
|
|
if k <= 0: |
|
|
return logits |
|
|
else: |
|
|
values, _ = torch.topk(logits, k) |
|
|
min_values = values[..., -1, None] |
|
|
return torch.where(logits < min_values, torch.full_like(logits, float("-inf")), logits) |
|
|
|
|
|
|
|
|
def top_p_logits(logits, p): |
|
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True) |
|
|
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) |
|
|
sorted_mask = cumulative_probs > p |
|
|
sorted_mask[..., 1:] = sorted_mask[..., :-1].clone() |
|
|
sorted_mask[..., 0] = False |
|
|
mask_indices = torch.scatter(torch.full_like(logits, False, dtype=torch.bool), -1, sorted_indices, sorted_mask) |
|
|
logits = logits.masked_fill(mask_indices, float("-inf")) |
|
|
return logits |
|
|
|
|
|
|
|
|
def sample_with_temperature_topk_topp(logits, temperature=1.0, top_k=0, top_p=1.0): |
|
|
orig_shape = logits.shape[:-1] |
|
|
vocab_size = logits.shape[-1] |
|
|
|
|
|
logits = logits.reshape(-1, vocab_size) |
|
|
|
|
|
if temperature != 1.0: |
|
|
logits = logits / temperature |
|
|
if top_k > 0: |
|
|
logits = top_k_logits(logits, top_k) |
|
|
if top_p < 1.0: |
|
|
logits = top_p_logits(logits, top_p) |
|
|
probs = F.softmax(logits, dim=-1) |
|
|
assert probs.dim() == 2 |
|
|
token = torch.multinomial(probs, num_samples=1) |
|
|
token_prob = torch.gather(probs, -1, token) |
|
|
|
|
|
return token.view(*orig_shape), token_prob.view(*orig_shape) |
|
|
|
|
|
|
|
|
def get_num_transfer_tokens(block_length, steps): |
|
|
base = block_length // steps |
|
|
remainder = block_length % steps |
|
|
num_transfer_tokens = torch.zeros(steps, dtype=torch.int64) + base |
|
|
num_transfer_tokens[:remainder] += 1 |
|
|
return num_transfer_tokens |
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
def block_diffusion_generate( |
|
|
model, |
|
|
prompt, |
|
|
mask_id, |
|
|
gen_length=128, |
|
|
block_length=8, |
|
|
denoising_steps=8, |
|
|
temperature=1.0, |
|
|
top_k=0, |
|
|
top_p=1.0, |
|
|
remasking_strategy="low_confidence_dynamic", |
|
|
confidence_threshold=0.85, |
|
|
stopping_criteria_idx=None, |
|
|
): |
|
|
model.eval() |
|
|
input_ids = prompt["input_ids"] |
|
|
prompt_length = input_ids.shape[1] |
|
|
past_key_values = DynamicCache() |
|
|
|
|
|
num_blocks = (prompt_length + gen_length + block_length - 1) // block_length |
|
|
total_length = num_blocks * block_length |
|
|
|
|
|
block_mask = torch.tril(torch.ones(num_blocks, num_blocks, device=model.device)) |
|
|
block_diffusion_attention_mask = block_mask.repeat_interleave(block_length, dim=0).repeat_interleave(block_length, dim=1).unsqueeze(0) |
|
|
position_ids = torch.arange(total_length, device=model.device).unsqueeze(0) |
|
|
|
|
|
x = torch.full((1, total_length), mask_id, dtype=torch.long, device=model.device) |
|
|
x[:, :prompt_length] = input_ids |
|
|
prefill_blocks = prompt_length // block_length |
|
|
prefill_length = prefill_blocks * block_length |
|
|
|
|
|
|
|
|
if prefill_length > 0: |
|
|
cur_x = x[:, :prefill_length] |
|
|
cur_attn_mask = block_diffusion_attention_mask[:, :prefill_length, :prefill_length] |
|
|
if cur_attn_mask.dim() == 3: |
|
|
cur_attn_mask = cur_attn_mask[:, None, :, :] |
|
|
cur_position_ids = position_ids[:, :prefill_length] |
|
|
model(cur_x, attention_mask=cur_attn_mask, position_ids=cur_position_ids, past_key_values=past_key_values, use_cache=True, store_kv=True) |
|
|
|
|
|
num_transfer_tokens = get_num_transfer_tokens(block_length, denoising_steps) |
|
|
|
|
|
|
|
|
for num_block in range(prefill_blocks, num_blocks): |
|
|
cur_x = x[:, num_block * block_length : (num_block + 1) * block_length].clone() |
|
|
cur_attn_mask = block_diffusion_attention_mask[:, num_block * block_length : (num_block + 1) * block_length, : (num_block + 1) * block_length] |
|
|
if cur_attn_mask.dim() == 3: |
|
|
cur_attn_mask = cur_attn_mask[:, None, :, :] |
|
|
cur_position_ids = position_ids[:, num_block * block_length : (num_block + 1) * block_length] |
|
|
for step in range(denoising_steps + 1): |
|
|
mask_index = cur_x == mask_id |
|
|
if mask_index.sum() == 0: |
|
|
|
|
|
model(cur_x, attention_mask=cur_attn_mask, position_ids=cur_position_ids, past_key_values=past_key_values, use_cache=True, store_kv=True) |
|
|
break |
|
|
|
|
|
|
|
|
output = model(cur_x, attention_mask=cur_attn_mask, position_ids=cur_position_ids, past_key_values=past_key_values, use_cache=True, store_kv=False) |
|
|
|
|
|
if hasattr(output, "logits") and output.logits is not None: |
|
|
logits = output.logits |
|
|
elif hasattr(output, "last_hidden_state"): |
|
|
|
|
|
|
|
|
if hasattr(model, "lm_head"): |
|
|
hidden_states = output.last_hidden_state |
|
|
logits = model.lm_head(hidden_states) |
|
|
else: |
|
|
raise ValueError("Model output does not contain logits and model does not have lm_head to compute them.") |
|
|
else: |
|
|
raise ValueError(f"Unexpected model output type: {type(output)}. Expected CausalLMOutputWithPast or BaseModelOutputWithPast with logits or last_hidden_state.") |
|
|
|
|
|
|
|
|
x0, x0_p = sample_with_temperature_topk_topp(logits, temperature=temperature, top_k=top_k, top_p=top_p) |
|
|
|
|
|
|
|
|
if remasking_strategy == "sequential": |
|
|
transfer_index = torch.zeros_like(x0, dtype=torch.bool) |
|
|
for j in range(cur_x.shape[0]): |
|
|
if mask_index[j].any(): |
|
|
first_mask_index = mask_index[j].nonzero(as_tuple=True)[0].min().item() |
|
|
transfer_index[j, first_mask_index : first_mask_index + num_transfer_tokens[step]] = True |
|
|
else: |
|
|
raise ValueError("No mask tokens found in the current block.") |
|
|
|
|
|
elif remasking_strategy == "low_confidence_static": |
|
|
confidence = torch.where(mask_index, x0_p, -torch.inf) |
|
|
transfer_index = torch.zeros_like(x0, dtype=torch.bool) |
|
|
for j in range(confidence.shape[0]): |
|
|
_, idx = torch.topk(confidence[j], num_transfer_tokens[step]) |
|
|
transfer_index[j, idx] = True |
|
|
|
|
|
elif remasking_strategy == "low_confidence_dynamic": |
|
|
confidence = torch.where(mask_index, x0_p, -torch.inf) |
|
|
transfer_index = torch.zeros_like(x0, dtype=torch.bool) |
|
|
for j in range(confidence.shape[0]): |
|
|
high_conf_mask = confidence[j] > confidence_threshold |
|
|
num_high_confidence = high_conf_mask.sum() |
|
|
if num_high_confidence >= num_transfer_tokens[step]: |
|
|
transfer_index[j] = high_conf_mask |
|
|
else: |
|
|
_, idx = torch.topk(confidence[j], num_transfer_tokens[step]) |
|
|
transfer_index[j, idx] = True |
|
|
else: |
|
|
raise ValueError(f"Unknown remasking strategy: {remasking_strategy}") |
|
|
|
|
|
cur_x[transfer_index] = x0[transfer_index] |
|
|
|
|
|
x[:, num_block * block_length : (num_block + 1) * block_length] = cur_x |
|
|
if stopping_criteria_idx is not None and any(stop_idx in x[:, prompt_length:] for stop_idx in stopping_criteria_idx): |
|
|
break |
|
|
|
|
|
return x |
|
|
|