# SPDX-License-Identifier: Apache-2.0
# adapted from https://github.com/Gen-Verse/dLLM-RL
# adapted from SDAR https://github.com/JetAstra/SDAR/blob/main/generate.py
import torch
from torch.nn import functional as F
from transformers.cache_utils import DynamicCache


def top_k_logits(logits, k):
    """Mask all but the k largest logits along the last dim to -inf."""
    if k <= 0:
        return logits
    values, _ = torch.topk(logits, k)
    min_values = values[..., -1, None]
    return torch.where(logits < min_values, torch.full_like(logits, float("-inf")), logits)
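
# Illustrative sketch (not in the original file): with k=2, only the two
# largest logits survive and the rest are masked, e.g.
#   top_k_logits(torch.tensor([1.0, 3.0, 2.0]), k=2)  ->  tensor([-inf, 3., 2.])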


def top_p_logits(logits, p):
    """Nucleus (top-p) filtering: mask logits outside the smallest prefix of probability mass > p."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_mask = cumulative_probs > p
    # Shift right so the first token that crosses the threshold is still kept.
    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
    sorted_mask[..., 0] = False
    mask_indices = torch.scatter(torch.full_like(logits, False, dtype=torch.bool), -1, sorted_indices, sorted_mask)
    logits = logits.masked_fill(mask_indices, float("-inf"))
    return logits
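
# Illustrative sketch (not in the original file): softmax([2, 1, 0]) is roughly
# [0.67, 0.24, 0.09], so with p=0.8 the cumulative mass crosses the threshold
# after the second token and only the smallest logit is masked:
#   top_p_logits(torch.tensor([2.0, 1.0, 0.0]), p=0.8)  ->  tensor([2., 1., -inf])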


def sample_with_temperature_topk_topp(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Sample one token per position after temperature scaling and top-k/top-p filtering.

    Returns the sampled token ids and their probabilities, both shaped like
    logits without the vocab dim.
    """
    orig_shape = logits.shape[:-1]  # [batch, block]
    vocab_size = logits.shape[-1]
    logits = logits.reshape(-1, vocab_size)  # [batch*block, vocab]
    if temperature != 1.0:
        logits = logits / temperature
    if top_k > 0:
        logits = top_k_logits(logits, top_k)
    if top_p < 1.0:
        logits = top_p_logits(logits, top_p)
    probs = F.softmax(logits, dim=-1)  # [batch*block, vocab]
    assert probs.dim() == 2
    token = torch.multinomial(probs, num_samples=1)  # [batch*block, 1]
    token_prob = torch.gather(probs, -1, token)  # [batch*block, 1]
    return token.view(*orig_shape), token_prob.view(*orig_shape)
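
# Illustrative sketch (not in the original file; shapes are arbitrary): for a
# [batch=1, block=4, vocab] logit tensor this returns a [1, 4] tensor of token
# ids and a [1, 4] tensor of their sampled probabilities, e.g.
#   tokens, probs = sample_with_temperature_topk_topp(torch.randn(1, 4, 32000), temperature=0.7, top_k=50, top_p=0.9)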


def get_num_transfer_tokens(block_length, steps):
    """Split block_length token commits across denoising steps as evenly as possible."""
    base = block_length // steps
    remainder = block_length % steps
    num_transfer_tokens = torch.zeros(steps, dtype=torch.int64) + base
    num_transfer_tokens[:remainder] += 1
    return num_transfer_tokens
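
# Illustrative sketch (not in the original file): the remainder is spread over
# the earliest steps, so the per-step counts always sum to block_length, e.g.
#   get_num_transfer_tokens(block_length=8, steps=3)  ->  tensor([3, 3, 2])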


@torch.no_grad()
def block_diffusion_generate(
    model,
    prompt,
    mask_id,
    gen_length=128,
    block_length=8,
    denoising_steps=8,
    temperature=1.0,
    top_k=0,
    top_p=1.0,
    remasking_strategy="low_confidence_dynamic",
    confidence_threshold=0.85,
    stopping_criteria_idx=None,
):
    """Block-wise diffusion decoding with a block-causal attention mask and KV cache.

    The sequence is processed in blocks of block_length tokens: fully known
    prompt blocks are prefilled into the KV cache, then each remaining block is
    iteratively denoised from mask_id tokens according to remasking_strategy.
    """
    model.eval()
    input_ids = prompt["input_ids"]
    prompt_length = input_ids.shape[1]
    past_key_values = DynamicCache()
    num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
    total_length = num_blocks * block_length
    # Block-causal mask: each token attends bidirectionally within its own
    # block and causally to all preceding blocks.
    block_mask = torch.tril(torch.ones(num_blocks, num_blocks, device=model.device))
    block_diffusion_attention_mask = block_mask.repeat_interleave(block_length, dim=0).repeat_interleave(block_length, dim=1).unsqueeze(0)
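    # Illustrative sketch (not in the original file): with num_blocks=2 and
    # block_length=2 the expanded mask is
    #   [[1, 1, 0, 0],
    #    [1, 1, 0, 0],
    #    [1, 1, 1, 1],
    #    [1, 1, 1, 1]]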
    position_ids = torch.arange(total_length, device=model.device).unsqueeze(0)
    x = torch.full((1, total_length), mask_id, dtype=torch.long, device=model.device)
    x[:, :prompt_length] = input_ids
    prefill_blocks = prompt_length // block_length
    prefill_length = prefill_blocks * block_length
    # Prefill stage: cache the KV states of all fully known prompt blocks.
    if prefill_length > 0:
        cur_x = x[:, :prefill_length]
        cur_attn_mask = block_diffusion_attention_mask[:, :prefill_length, :prefill_length]
        if cur_attn_mask.dim() == 3:
            cur_attn_mask = cur_attn_mask[:, None, :, :]
        cur_position_ids = position_ids[:, :prefill_length]
        model(cur_x, attention_mask=cur_attn_mask, position_ids=cur_position_ids, past_key_values=past_key_values, use_cache=True, store_kv=True)
    num_transfer_tokens = get_num_transfer_tokens(block_length, denoising_steps)
    # Decode stage: denoise one block at a time, left to right.
    for num_block in range(prefill_blocks, num_blocks):
        cur_x = x[:, num_block * block_length : (num_block + 1) * block_length].clone()
        cur_attn_mask = block_diffusion_attention_mask[:, num_block * block_length : (num_block + 1) * block_length, : (num_block + 1) * block_length]
        if cur_attn_mask.dim() == 3:
            cur_attn_mask = cur_attn_mask[:, None, :, :]
        cur_position_ids = position_ids[:, num_block * block_length : (num_block + 1) * block_length]
        for step in range(denoising_steps + 1):
            mask_index = cur_x == mask_id
            if mask_index.sum() == 0:
                # Block fully decoded: run once more to store its KV cache.
                model(cur_x, attention_mask=cur_attn_mask, position_ids=cur_position_ids, past_key_values=past_key_values, use_cache=True, store_kv=True)
                break
            # Denoising forward pass (no KV write while tokens are still masked)
            output = model(cur_x, attention_mask=cur_attn_mask, position_ids=cur_position_ids, past_key_values=past_key_values, use_cache=True, store_kv=False)
            # Extract logits from the output - handle both CausalLMOutputWithPast and BaseModelOutputWithPast
            if hasattr(output, "logits") and output.logits is not None:
                logits = output.logits
            elif hasattr(output, "last_hidden_state"):
                # If the model returns BaseModelOutputWithPast instead of
                # CausalLMOutputWithPast, project hidden states through lm_head.
                if hasattr(model, "lm_head"):
                    hidden_states = output.last_hidden_state
                    logits = model.lm_head(hidden_states)
                else:
                    raise ValueError("Model output does not contain logits and model does not have lm_head to compute them.")
            else:
                raise ValueError(f"Unexpected model output type: {type(output)}. Expected CausalLMOutputWithPast or BaseModelOutputWithPast with logits or last_hidden_state.")
            # Sampling
            x0, x0_p = sample_with_temperature_topk_topp(logits, temperature=temperature, top_k=top_k, top_p=top_p)
            # Remasking strategy: decide which sampled tokens to commit this step
            if remasking_strategy == "sequential":
                # Commit tokens strictly left to right.
                transfer_index = torch.zeros_like(x0, dtype=torch.bool)
                for j in range(cur_x.shape[0]):
                    if mask_index[j].any():
                        first_mask_index = mask_index[j].nonzero(as_tuple=True)[0].min().item()
                        transfer_index[j, first_mask_index : first_mask_index + num_transfer_tokens[step]] = True
                    else:
                        raise ValueError("No mask tokens found in the current block.")
            elif remasking_strategy == "low_confidence_static":
                # Commit a fixed number of the most confident tokens per step.
                confidence = torch.where(mask_index, x0_p, -torch.inf)
                transfer_index = torch.zeros_like(x0, dtype=torch.bool)
                for j in range(confidence.shape[0]):
                    _, idx = torch.topk(confidence[j], num_transfer_tokens[step])
                    transfer_index[j, idx] = True
            elif remasking_strategy == "low_confidence_dynamic":
                # Commit every token above the confidence threshold, falling
                # back to the fixed top-k schedule when too few qualify.
                confidence = torch.where(mask_index, x0_p, -torch.inf)
                transfer_index = torch.zeros_like(x0, dtype=torch.bool)
                for j in range(confidence.shape[0]):
                    high_conf_mask = confidence[j] > confidence_threshold
                    num_high_confidence = high_conf_mask.sum()
                    if num_high_confidence >= num_transfer_tokens[step]:
                        transfer_index[j] = high_conf_mask
                    else:
                        _, idx = torch.topk(confidence[j], num_transfer_tokens[step])
                        transfer_index[j, idx] = True
            else:
                raise ValueError(f"Unknown remasking strategy: {remasking_strategy}")
            cur_x[transfer_index] = x0[transfer_index]
        x[:, num_block * block_length : (num_block + 1) * block_length] = cur_x
        # Early exit once any stop token appears in the generated region.
        if stopping_criteria_idx is not None and any(stop_idx in x[:, prompt_length:] for stop_idx in stopping_criteria_idx):
            break
    return x
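

if __name__ == "__main__":
    # Usage sketch, not part of the original script. The checkpoint path is a
    # placeholder; the model must be a block-diffusion LM whose forward accepts
    # the custom store_kv kwarg used above, and whose tokenizer defines a mask
    # token to use as mask_id.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "path/to/block-diffusion-checkpoint"  # hypothetical
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to("cuda")
    prompt = tokenizer("Briefly explain diffusion language models.", return_tensors="pt").to(model.device)
    out = block_diffusion_generate(
        model,
        prompt,
        mask_id=tokenizer.mask_token_id,  # assumes the tokenizer defines one
        gen_length=128,
        block_length=8,
        denoising_steps=8,
        temperature=0.7,
        top_p=0.9,
        remasking_strategy="low_confidence_dynamic",
    )
    print(tokenizer.decode(out[0], skip_special_tokens=True))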