|
|
import torch |
|
|
import numpy as np |
|
|
import torch.nn.functional as F |
|
|
import time |
|
|
import re |
|
|
from collections import Counter |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
|
|
|
def add_gumbel_noise(logits, temperature):
    """Perturb ``logits`` so that a later ``argmax`` acts as temperature sampling.

    Args:
        logits: Tensor of unnormalised scores.
        temperature: Exponent applied to the noise term. ``0`` disables the
            perturbation entirely.

    Returns:
        ``logits`` itself (same object, same dtype) when ``temperature == 0``;
        otherwise a new float64 tensor ``exp(logits) / (-log(u)) ** temperature``
        with ``u`` drawn uniformly from ``[0, 1)`` for every element.
    """
    if temperature == 0:
        # Greedy decoding: hand back the input untouched.
        return logits

    # All arithmetic is done in float64, regardless of the input dtype.
    logits64 = logits.to(torch.float64)
    uniform = torch.rand_like(logits64, dtype=torch.float64)
    noise_scale = torch.pow(-torch.log(uniform), temperature)
    return torch.exp(logits64) / noise_scale
|
|
|
|
|
def get_num_transfer_tokens(block_mask_index: torch.Tensor, steps: int) -> torch.Tensor:
    """Split each row's masked-token count as evenly as possible over ``steps``.

    Args:
        block_mask_index: 2-D tensor whose row sums give the number of masked
            positions per sample (summed over ``dim=1``).
        steps: Number of denoising steps to spread the tokens across.

    Returns:
        A ``torch.long`` tensor of shape ``(rows, steps)``. Every step gets
        ``total // steps`` tokens and the first ``total % steps`` steps get one
        extra, so each row sums to that row's masked count.
    """
    masked_per_row = block_mask_index.sum(dim=1)
    per_step = torch.div(masked_per_row, steps, rounding_mode='floor')
    leftover = masked_per_row - per_step * steps

    # The earliest `leftover` steps of each row absorb one extra token.
    step_ids = torch.arange(steps, device=block_mask_index.device).unsqueeze(0)
    extra = (step_ids < leftover.unsqueeze(1)).to(torch.long)

    even_share = per_step.unsqueeze(1).expand(-1, steps).to(torch.long)
    return even_share + extra
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None):
    """Pick candidate tokens and decide which masked positions to commit.

    Args:
        logits: Model scores; the last dimension is reduced with ``argmax``.
        temperature: Passed to ``add_gumbel_noise`` before the ``argmax``.
        remasking: Confidence rule, one of ``"low_confidence"`` (softmax
            probability of the chosen token), ``"top_prob_margin"`` (top-1
            minus top-2 softmax probability) or ``"random"`` (uniform scores).
        mask_index: Boolean tensor, True where ``x`` is still masked.
        x: Current token ids; kept wherever ``mask_index`` is False.
        num_transfer_tokens: Per-row number of tokens to commit. Only used
            (and required) when ``threshold`` is None; a ``(B, 1)`` tensor is
            squeezed to ``(B,)`` and negative values are clamped to 0.
        threshold: If given, commit every masked position whose confidence is
            ``>= threshold`` and always the single most confident position of
            each row (provided it is masked).

    Returns:
        ``(x0, transfer_index)``: the candidate token ids (equal to ``x`` at
        unmasked positions) and a boolean tensor marking positions to commit.

    Raises:
        NotImplementedError: Unknown ``remasking`` value.
        ValueError: ``threshold`` and ``num_transfer_tokens`` are both None.
    """
    # 1) Candidate token for every position.
    x0 = add_gumbel_noise(logits, temperature=temperature).argmax(dim=-1)

    # 2) Confidence score of each candidate (float64 throughout).
    if remasking == "low_confidence":
        probs = F.softmax(logits.to(torch.float64), dim=-1)
        x0_p = probs.gather(-1, x0.unsqueeze(-1)).squeeze(-1)
    elif remasking == "top_prob_margin":
        probs = F.softmax(logits.to(torch.float64), dim=-1)
        best_two = torch.topk(probs, k=2, dim=-1).values
        x0_p = best_two[..., 0] - best_two[..., 1]
    elif remasking == "random":
        x0_p = torch.rand(x0.shape, device=x0.device, dtype=torch.float64)
    else:
        raise NotImplementedError(remasking)

    # 3) Already-decoded positions keep their token and get the lowest finite
    #    score so they can never outrank a masked position.
    x0 = torch.where(mask_index, x0, x)
    lowest = x0_p.new_tensor(torch.finfo(x0_p.dtype).min)
    confidence = torch.where(mask_index, x0_p, lowest)

    # 4a) Threshold mode: everything above the bar, plus the per-row best.
    if threshold is not None:
        above_bar = mask_index & (confidence >= threshold)
        best_pos = confidence.argmax(dim=1, keepdim=True)
        forced = torch.zeros_like(above_bar).scatter_(1, best_pos, True)
        return x0, (above_bar | forced) & mask_index

    # 4b) Quota mode: the k most confident positions of each row.
    if num_transfer_tokens is None:
        raise ValueError("num_transfer_tokens must be a tensor when threshold is None.")

    quota = num_transfer_tokens
    if quota.dim() == 2 and quota.size(1) == 1:
        quota = quota.squeeze(1)
    quota = quota.to(dtype=torch.long, device=confidence.device).clamp(min=0)

    order = torch.sort(confidence, dim=1, descending=True).indices
    batch, length = confidence.shape
    rank = torch.arange(length, device=confidence.device).unsqueeze(0).expand(batch, length)
    within_quota = rank < quota.unsqueeze(1).expand(batch, length)

    # Scatter the "rank < k" flags back from sorted order to sequence order.
    picked = torch.zeros(batch, length, device=confidence.device, dtype=torch.int8)
    picked = picked.scatter(1, order, within_quota.to(torch.int8))

    return x0, picked.bool() & mask_index
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_transfer_index_dynamic(logits, temperature, remasking, mask_index, x, num_transfer_tokens, factor=1):
    """Pick candidates and commit a confidence-dependent number of them per row.

    For each row the masked positions are ranked by confidence. The n-th
    ranked position (1-based) is accepted while its confidence is at least
    ``1 - factor / (n + 1)``; the first position is always accepted. The walk
    stops at the first position that falls below its bar.

    Args:
        logits: Model scores; the last dimension is reduced with ``argmax``.
        temperature: Passed to ``add_gumbel_noise`` before the ``argmax``.
        remasking: ``'low_confidence'``, ``'top_prob_margin'`` or ``'random'``.
        mask_index: Boolean tensor, True where ``x`` is still masked.
        x: Current token ids; kept wherever ``mask_index`` is False.
        num_transfer_tokens: Accepted for signature parity but ignored; the
            per-row count is recomputed from ``mask_index``.
        factor: Scales how quickly the acceptance bar rises with rank.

    Returns:
        ``(x0, transfer_index)``: candidate token ids and a boolean tensor of
        positions to commit.

    Raises:
        NotImplementedError: Unknown ``remasking`` value.
    """
    x0 = add_gumbel_noise(logits, temperature=temperature).argmax(dim=-1)

    if remasking == 'low_confidence':
        probs = F.softmax(logits.to(torch.float64), dim=-1)
        x0_p = probs.gather(-1, x0.unsqueeze(-1)).squeeze(-1)
    elif remasking == 'top_prob_margin':
        probs = F.softmax(logits.to(torch.float64), dim=-1)
        best_two = torch.topk(probs, k=2, dim=-1).values
        x0_p = best_two[..., 0] - best_two[..., 1]
    elif remasking == 'random':
        x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
    else:
        raise NotImplementedError(remasking)

    x0 = torch.where(mask_index, x0, x)
    confidence = torch.where(mask_index, x0_p, -np.inf)
    transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
    masked_counts = mask_index.sum(dim=1, keepdim=True)

    for row in range(confidence.shape[0]):
        n_masked = int(masked_counts[row].item())
        if n_masked == 0:
            continue

        # Acceptance bar per rank; the top-ranked position always passes.
        bars = [1 - factor / (rank + 1) for rank in range(1, n_masked + 1)]
        bars[0] = -1

        ranked = torch.sort(confidence[row][mask_index[row]], dim=-1, descending=True)[0]
        # Index of the first ranked position that misses its bar (all pass -> n_masked).
        keep = next(
            (pos for pos, bar in enumerate(bars) if ranked[pos] < bar),
            len(bars),
        )
        keep = max(keep, 1)

        _, chosen = torch.topk(confidence[row], k=keep)
        transfer_index[row, chosen] = True

    return x0, transfer_index
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def generate_standard(model, prompt, attention_mask=None, steps=128, gen_length=128, block_length=128, temperature=0.,
                      cfg_scale=0., remasking='low_confidence', mask_id=126336, logits_eos_inf=False, confidence_eos_eot_inf=False):
    """Generate ``gen_length`` tokens after ``prompt`` by iterative unmasking.

    The sequence starts as ``prompt`` followed by ``gen_length`` copies of
    ``mask_id``. Generation proceeds block by block; within a block, each step
    runs the model on the full sequence, scores a candidate for every masked
    position and commits the most confident ones.

    Args:
        model: Callable as ``model(ids, attention_mask=...)`` returning an
            object with ``.logits``; must expose ``.device``.
        prompt: Token ids, indexed as ``(batch, prompt_len)``.
        attention_mask: Optional mask for ``prompt``; it is extended with ones
            over the generated region. ``None`` is forwarded to the model as is.
        steps: Total number of denoising steps; must be divisible by the
            number of blocks.
        gen_length: Number of tokens to generate; must be divisible by
            ``block_length``.
        block_length: Size of each block.
        temperature: Sampling temperature passed to ``add_gumbel_noise``.
        cfg_scale: When ``> 0``, the model is also run on a copy whose prompt
            positions are replaced by ``mask_id`` and the two logit sets are
            combined as ``un + (cfg_scale + 1) * (cond - un)``.
        remasking: ``'low_confidence'``, ``'top_prob_margin'`` or ``'random'``.
        mask_id: Token id that marks a masked position.
        logits_eos_inf: Set logit 126081 to ``-inf`` before choosing
            candidates (presumably the EOS id; verify against the tokenizer).
        confidence_eos_eot_inf: Set logits 126081 / 126348 to ``-inf`` after
            candidates are chosen, which lowers their confidence (presumably
            the EOS / EoT ids; verify against the tokenizer).

    Returns:
        Tensor of shape ``(batch, prompt_len + gen_length)`` on ``model.device``.

    Raises:
        AssertionError: If the divisibility constraints above are violated.
        NotImplementedError: Unknown ``remasking`` value.
    """
    prompt_len = prompt.shape[1]
    x = torch.full((prompt.shape[0], prompt_len + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt_len] = prompt.clone()

    if attention_mask is not None:
        attention_mask = torch.cat([attention_mask, torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype, device=model.device)], dim=-1)

    # Mask for the doubled (conditional + unconditional) CFG batch. It is
    # loop-invariant, and it must stay None when no mask was supplied:
    # previously the name was only bound inside an `if`, so cfg_scale > 0
    # with attention_mask=None raised UnboundLocalError.
    cfg_attention_mask = None
    if cfg_scale > 0. and attention_mask is not None:
        cfg_attention_mask = torch.cat([attention_mask, attention_mask], dim=0)

    prompt_index = (x != mask_id)
    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length
    assert steps % num_blocks == 0
    steps = steps // num_blocks

    for num_block in range(num_blocks):
        block_start = prompt_len + num_block * block_length
        block_end = prompt_len + (num_block + 1) * block_length
        block_mask_index = (x[:, block_start:block_end] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)

        for i in range(steps):
            mask_index = (x == mask_id)
            if cfg_scale > 0.:
                # Unconditional branch: same sequence with the prompt masked out.
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_, attention_mask=cfg_attention_mask).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x, attention_mask=attention_mask).logits

            if logits_eos_inf:
                logits[:, :, 126081] = -torch.inf

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)

            if confidence_eos_eot_inf:
                # NOTE(review): this runs after x0 is chosen, so it only affects
                # the confidence computed from `logits` below. With
                # temperature != 0, logits_with_noise is a separate tensor, so
                # only id 126348 is suppressed in `logits` -- confirm intent.
                logits_with_noise[:, :, 126081] = logits[:, :, 126348] = -torch.inf

            if remasking == 'low_confidence':
                p = F.softmax(logits, dim=-1)
                x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
            elif remasking == 'top_prob_margin':
                p = F.softmax(logits, dim=-1)
                top2_probs, _ = torch.topk(p, k=2, dim=-1)
                x0_p = top2_probs[:, :, 0] - top2_probs[:, :, 1]
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Positions beyond the current block are not eligible yet.
            x0_p[:, block_end:] = -np.inf
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

    return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def generate_with_dual_cache(
    model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
    remasking="low_confidence", mask_id=126336, threshold=None, factor=None,
    cfg_scale=0., logits_eos_inf=False, confidence_eos_eot_inf=False, attention_mask=None
):
    """Block-wise unmasking generation that reuses a key/value cache per block.

    For each block the model is run once on the full sequence with
    ``use_cache=True``; the remaining steps of that block feed only the block's
    tokens together with the cached ``past_key_values``.

    Args:
        model: Callable as ``model(ids, use_cache=True, ...)`` returning an
            object with ``.logits`` and ``.past_key_values``; must expose
            ``.device``. ``replace_position`` is passed when the model accepts it.
        prompt: Token ids, indexed as ``(batch, prompt_len)``.
        steps: Total number of denoising steps; must be divisible by the
            number of blocks.
        gen_length: Number of tokens to generate; must be divisible by
            ``block_length``.
        block_length: Size of each block.
        temperature: Sampling temperature for candidate selection.
        remasking: Confidence rule forwarded to the transfer-index helpers.
        mask_id: Token id that marks a masked position.
        threshold: If set (and ``factor`` is None), commit by confidence
            threshold instead of a fixed per-step quota.
        factor: If set, use ``get_transfer_index_dynamic`` with this factor.
        cfg_scale: Values ``> 0`` are not supported here; the call is delegated
            to ``generate_standard`` after printing a warning.
        logits_eos_inf: Only forwarded to ``generate_standard`` on fallback;
            not used by the cached path.
        confidence_eos_eot_inf: Only forwarded to ``generate_standard`` on
            fallback; not used by the cached path.
        attention_mask: Only forwarded to ``generate_standard`` on fallback;
            not used by the cached path.

    Returns:
        Tensor of shape ``(batch, prompt_len + gen_length)`` on ``model.device``.
        NOTE(review): each block runs at most ``steps // num_blocks`` steps, so
        in threshold/factor mode mask tokens may remain -- confirm with callers.

    Raises:
        AssertionError: If the divisibility constraints above are violated.
    """
    if cfg_scale > 0:
        # The leading glyph was mojibake in the original message; restored.
        print("⚠️ Warning: cfg_scale > 0 is not supported in Dual Cache mode. Falling back to standard generate.")
        return generate_standard(model, prompt, attention_mask, steps, gen_length, block_length, temperature, cfg_scale, remasking, mask_id, logits_eos_inf, confidence_eos_eot_inf)

    B = prompt.shape[0]
    Lp = int(prompt.shape[1])

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length
    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks

    x = torch.full((B, Lp + gen_length), mask_id, dtype=torch.long, device=model.device)
    x[:, :Lp] = prompt

    # Flipped to False after the first TypeError so the failing call is not
    # retried (raised and caught) on every remaining step.
    supports_replace_position = True

    for nb in range(num_blocks):
        s = Lp + nb * block_length
        e = s + block_length

        block_mask_index = (x[:, s:e] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)

        # Step 0 of the block: full forward pass that also builds the cache.
        out_full = model(x, use_cache=True)
        past_key_values = out_full.past_key_values

        # Marks the current block's positions for the cached forward passes.
        replace_position = torch.zeros_like(x, dtype=torch.bool)
        replace_position[:, s:e] = True

        # Every masked position up to the end of this block is eligible.
        global_mask_index = (x == mask_id)
        global_mask_index[:, e:] = False

        if factor is None:
            quota0 = None if threshold is not None else num_transfer_tokens[:, 0]
            x0, transfer_index = get_transfer_index(
                out_full.logits, temperature, remasking, global_mask_index, x, quota0, threshold
            )
        else:
            x0, transfer_index = get_transfer_index_dynamic(
                out_full.logits, temperature, remasking, global_mask_index, x, None, factor
            )

        x = torch.where(transfer_index, x0, x)

        # Remaining steps: only the block's tokens go through the model.
        for i in range(1, steps_per_block):
            blk_old = x[:, s:e]
            mask_blk = (blk_old == mask_id)
            if mask_blk.sum() == 0:
                break

            logits_blk = None
            if supports_replace_position:
                try:
                    logits_blk = model(
                        blk_old, past_key_values=past_key_values, use_cache=True, replace_position=replace_position
                    ).logits
                except TypeError:
                    # Best-effort fallback for models without `replace_position`.
                    supports_replace_position = False
            if logits_blk is None:
                logits_blk = model(
                    blk_old, past_key_values=past_key_values, use_cache=True
                ).logits

            if factor is None:
                quota_i = None if threshold is not None else num_transfer_tokens[:, i]
                x0_blk, transfer_idx_blk = get_transfer_index(
                    logits_blk, temperature, remasking, mask_blk, blk_old, quota_i, threshold
                )
            else:
                x0_blk, transfer_idx_blk = get_transfer_index_dynamic(
                    logits_blk, temperature, remasking, mask_blk, blk_old, None, factor
                )

            blk_new = torch.where(transfer_idx_blk, x0_blk, blk_old)
            x = torch.cat([x[:, :s], blk_new, x[:, e:]], dim=1)

    return x
|
|
|
|
|
|
|
|
# Default entry point: `generate` is an alias of `generate_standard`, the
# variant that runs the model on the full sequence at every step.
generate = generate_standard