llada-8b-table-sft-baseline / llada_generate.py
6uvsoomJ's picture
Upload llada_generate.py with huggingface_hub
8e21640 verified
import torch
import numpy as np
import torch.nn.functional as F
import time
import re
from collections import Counter
from transformers import AutoTokenizer, AutoModel
def add_gumbel_noise(logits, temperature):
    """Perturb logits so that an argmax over the result samples at ``temperature``.

    Returns ``exp(logits) / (-log U) ** temperature`` with ``U ~ Uniform(0, 1)``.
    Taking the log, this is ``logits + temperature * G`` with ``G`` standard
    Gumbel noise, so its argmax is a draw from ``softmax(logits / temperature)``.
    Everything is computed in float64.

    Args:
        logits: Tensor of unnormalised scores.
        temperature: Noise strength; 0 disables sampling.

    Returns:
        The very same ``logits`` object when ``temperature == 0``; otherwise a
        new float64 tensor of the same shape.
    """
    if temperature == 0:
        return logits
    logits64 = logits.to(torch.float64)
    uniform = torch.rand_like(logits64, dtype=torch.float64)
    scale = (-torch.log(uniform)) ** temperature
    return torch.exp(logits64) / scale
def get_num_transfer_tokens(block_mask_index: torch.Tensor, steps: int) -> torch.Tensor:
    """Split each row's masked-token count as evenly as possible across ``steps``.

    Every step receives ``count // steps`` tokens and the first
    ``count % steps`` steps receive one extra, so each output row sums to the
    number of masked positions in the corresponding input row.

    Args:
        block_mask_index: 2-D mask (batch, block_length); True marks a
            still-masked position.
        steps: Number of denoising steps the block is decoded in.

    Returns:
        Long tensor of shape (batch, steps) holding the per-step quota.
    """
    masked_per_row = block_mask_index.sum(dim=1)
    per_step = torch.div(masked_per_row, steps, rounding_mode='floor')
    leftover = masked_per_row - per_step * steps
    step_ids = torch.arange(steps, device=block_mask_index.device)
    # The earliest `leftover` steps absorb the remainder, one token each.
    bonus = (step_ids[None, :] < leftover[:, None]).to(torch.long)
    return per_step[:, None].to(torch.long) + bonus
# =================================================================
# [Modified] Added top_prob_margin remasking support
# =================================================================
def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None):
    """Propose a token for every position and choose which masked ones to commit.

    Args:
        logits: Model logits; the last dim is the vocabulary (presumably
            (batch, seq, vocab) — the confidence tensor below is unpacked as 2-D).
        temperature: Gumbel-noise temperature passed to ``add_gumbel_noise``.
        remasking: Confidence score used for ranking:
            'low_confidence'  - probability of the proposed token,
            'top_prob_margin' - top-1 minus top-2 probability,
            'random'          - uniform random score.
        mask_index: Bool tensor, True where ``x`` is still masked.
        x: Current token ids.
        num_transfer_tokens: Per-row number of tokens to commit (shape (batch,)
            or (batch, 1)); required when ``threshold`` is None.
        threshold: If given, commit every masked position whose confidence is
            >= threshold, plus each row's single most confident masked position.

    Returns:
        (x0, transfer_index): ``x0`` equals ``x`` outside ``mask_index`` and holds
        the proposals inside it; ``transfer_index`` is a bool tensor marking the
        masked positions to commit this step.

    Raises:
        NotImplementedError: unknown ``remasking``.
        ValueError: both ``threshold`` and ``num_transfer_tokens`` are None.
    """
    # --- 1) one proposal per position ---
    proposal = torch.argmax(add_gumbel_noise(logits, temperature=temperature), dim=-1)

    # --- 2) score each proposal ---
    if remasking == "random":
        score = torch.rand(proposal.shape, device=proposal.device, dtype=torch.float64)
    elif remasking in ("low_confidence", "top_prob_margin"):
        probs = F.softmax(logits.to(torch.float64), dim=-1)
        if remasking == "low_confidence":
            score = probs.gather(-1, proposal.unsqueeze(-1)).squeeze(-1)
        else:
            best_two = probs.topk(2, dim=-1).values
            score = best_two[..., 0] - best_two[..., 1]
    else:
        raise NotImplementedError(remasking)

    # Unmasked positions keep their token and get the lowest possible score,
    # so they can never outrank a masked one.
    x0 = torch.where(mask_index, proposal, x)
    floor = torch.tensor(torch.finfo(score.dtype).min, device=score.device, dtype=score.dtype)
    confidence = torch.where(mask_index, score, floor)

    # --- 3a) threshold mode: everything confident enough, plus the row's best ---
    if threshold is not None:
        best = confidence.argmax(dim=1, keepdim=True)
        forced = torch.zeros_like(mask_index).scatter_(1, best, True)
        return x0, ((confidence >= threshold) | forced) & mask_index

    # --- 3b) quota mode: the k most confident positions per row ---
    if num_transfer_tokens is None:
        raise ValueError("num_transfer_tokens must be a tensor when threshold is None.")
    quota = num_transfer_tokens
    if quota.dim() == 2 and quota.size(1) == 1:
        quota = quota.squeeze(1)
    quota = quota.to(dtype=torch.long, device=confidence.device).clamp(min=0)

    # rank[b, p] is the position of p in row b's descending-confidence order
    # (the inverse of the sort permutation).
    order = torch.sort(confidence, dim=1, descending=True).indices
    positions = torch.arange(confidence.shape[1], device=confidence.device).expand_as(order)
    rank = torch.empty_like(order).scatter_(1, order, positions)
    return x0, (rank < quota.unsqueeze(1)) & mask_index
# =================================================================
# [Modified] Added top_prob_margin remasking support (dynamic version)
# =================================================================
def get_transfer_index_dynamic(logits, temperature, remasking, mask_index, x, num_transfer_tokens, factor=1):
    """Commit a variable number of tokens using a rank-dependent confidence bar.

    Proposals are sampled like ``get_transfer_index``. For every row the masked
    positions are ranked by confidence. The k-th ranked one (1-based, k >= 2)
    must reach ``1 - factor / (k + 1)``, while rank 1 always passes. The scan
    stops at the first rank that falls short, and that many top-confidence
    positions are committed.

    Args:
        logits: Model logits; last dim is the vocabulary.
        temperature: Gumbel-noise temperature passed to ``add_gumbel_noise``.
        remasking: 'low_confidence', 'top_prob_margin' or 'random'.
        mask_index: 2-D bool tensor, True where ``x`` is still masked.
        x: Current token ids.
        num_transfer_tokens: Unused; the per-row masked count is used instead.
        factor: Scales how fast the bar rises with rank (larger = more permissive).

    Returns:
        (x0, transfer_index) with the same meaning as in ``get_transfer_index``.

    Raises:
        NotImplementedError: unknown ``remasking``.
    """
    sampled = torch.argmax(add_gumbel_noise(logits, temperature=temperature), dim=-1)

    if remasking == 'random':
        score = torch.rand((sampled.shape[0], sampled.shape[1]), device=sampled.device)
    elif remasking in ('low_confidence', 'top_prob_margin'):
        probs = F.softmax(logits.to(torch.float64), dim=-1)
        if remasking == 'low_confidence':
            score = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)
        else:
            first, second = probs.topk(2, dim=-1).values.unbind(-1)
            score = first - second
    else:
        raise NotImplementedError(remasking)

    x0 = torch.where(mask_index, sampled, x)
    confidence = torch.where(mask_index, score, -np.inf)
    transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)

    masked_counts = mask_index.sum(dim=1)
    for row in range(confidence.shape[0]):
        n_masked = int(masked_counts[row].item())
        if n_masked == 0:
            continue
        ranked = torch.sort(confidence[row][mask_index[row]], dim=-1, descending=True)[0]
        # Bar for the k-th ranked token; -1 guarantees rank 1 always passes.
        bars = [-1] + [1 - factor / (k + 1) for k in range(2, n_masked + 1)]
        keep = next((k for k, (conf, bar) in enumerate(zip(ranked, bars)) if conf < bar), len(bars))
        keep = max(keep, 1)
        _, chosen = torch.topk(confidence[row], k=keep)
        transfer_index[row, chosen] = True
    return x0, transfer_index
# =================================================================
# generate_standard (original sampler, no KV cache)
# =================================================================
@torch.no_grad()
def generate_standard(model, prompt, attention_mask=None, steps=128, gen_length=128, block_length=128, temperature=0.,
                      cfg_scale=0., remasking='low_confidence', mask_id=126336, logits_eos_inf=False,
                      confidence_eos_eot_inf=False, eos_id=126081, eot_id=126348):
    """Block-wise masked-diffusion sampling that runs the full model every step.

    The ``gen_length`` positions after the prompt start as ``mask_id`` and are
    split into ``gen_length // block_length`` blocks decoded left to right. Each
    block gets ``steps // num_blocks`` steps. At every step the model predicts all
    positions, and the most confident masked positions up to the end of the
    current block are committed according to ``get_num_transfer_tokens``.

    Args:
        model: Callable returning an object with ``.logits``; must expose ``.device``.
        prompt: Prompt token ids, shape (batch, prompt_len).
        attention_mask: Optional prompt mask (batch, prompt_len); it is extended
            with ones over the generated region.
        steps: Total denoising steps, split evenly across blocks.
        gen_length: Number of tokens to generate.
        block_length: Tokens per block; must divide ``gen_length``.
        temperature: Gumbel-noise temperature; 0 means greedy argmax.
        cfg_scale: Classifier-free guidance scale; 0 disables guidance.
        remasking: 'low_confidence', 'top_prob_margin' or 'random'.
        mask_id: Token id of the mask token.
        logits_eos_inf: If True, the EOS logit is set to -inf before sampling.
        confidence_eos_eot_inf: If True, EOS and EoT get zero probability when
            computing confidence, so such predictions are committed last.
        eos_id: Token id used as EOS (default is the previously hard-coded id).
        eot_id: Token id used as EoT (default is the previously hard-coded id).

    Returns:
        Tensor (batch, prompt_len + gen_length): the prompt followed by the
        generated tokens.

    Raises:
        AssertionError: ``block_length`` does not divide ``gen_length`` or the
            number of blocks does not divide ``steps``.
        NotImplementedError: unknown ``remasking``.
    """
    batch_size, prompt_len = prompt.shape[0], prompt.shape[1]
    x = torch.full((batch_size, prompt_len + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt_len] = prompt.clone()

    if attention_mask is not None:
        attention_mask = torch.cat(
            [attention_mask,
             torch.ones((batch_size, gen_length), dtype=attention_mask.dtype, device=model.device)],
            dim=-1,
        )
    # CFG doubles the batch (conditional + unconditional), so the mask is doubled
    # too. It stays None when no mask was given. The old code left the name
    # unbound in that case and raised UnboundLocalError whenever cfg_scale > 0.
    cfg_attention_mask = None
    if cfg_scale > 0. and attention_mask is not None:
        cfg_attention_mask = torch.cat([attention_mask, attention_mask], dim=0)

    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length
    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks

    for num_block in range(num_blocks):
        block_start = prompt_len + num_block * block_length
        block_end = block_start + block_length
        block_mask_index = (x[:, block_start:block_end] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)

        for i in range(steps_per_block):
            mask_index = (x == mask_id)
            if cfg_scale > 0.:
                # Unconditional branch: same sequence with the prompt masked out.
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                logits = model(torch.cat([x, un_x], dim=0), attention_mask=cfg_attention_mask).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x, attention_mask=attention_mask).logits

            if logits_eos_inf:
                logits[:, :, eos_id] = -torch.inf

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)

            if confidence_eos_eot_inf:
                # Give EOS and EoT zero probability in the confidence computation
                # below; x0 has already been sampled. The previous chained
                # assignment wrote the EOS entry into `logits_with_noise`, which is
                # a separate tensor whenever temperature > 0. As a result, EOS
                # confidence was only suppressed under greedy decoding.
                logits[:, :, eos_id] = -torch.inf
                logits[:, :, eot_id] = -torch.inf

            if remasking == 'low_confidence':
                p = F.softmax(logits, dim=-1)
                x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
            elif remasking == 'top_prob_margin':
                p = F.softmax(logits, dim=-1)
                top2_probs, _ = torch.topk(p, k=2, dim=-1)
                x0_p = top2_probs[:, :, 0] - top2_probs[:, :, 1]
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Positions after the current block must not be committed yet.
            x0_p[:, block_end:] = -np.inf

            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

    return x
# =================================================================
# generate_with_dual_cache (optimized sampler with KV cache)
# =================================================================
@torch.no_grad()
def generate_with_dual_cache(
    model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
    remasking="low_confidence", mask_id=126336, threshold=None, factor=None,
    cfg_scale=0., logits_eos_inf=False, confidence_eos_eot_inf=False, attention_mask=None
):
    """Block-wise masked-diffusion sampling that reuses a KV cache inside each block.

    For every block, one full-sequence forward pass with ``use_cache=True`` warms
    the cache and drives the first unmasking step. The remaining steps feed only
    the block's tokens together with ``past_key_values``.

    Unmasking policy:
      * quota (``threshold`` and ``factor`` both None): the block's masked tokens
        are spread evenly over ``steps // num_blocks`` steps;
      * threshold (``threshold`` set): commit masked tokens with confidence
        >= ``threshold`` plus each row's most confident one;
      * dynamic (``factor`` set): see ``get_transfer_index_dynamic``.
    In threshold and dynamic mode a block can need more than
    ``steps // num_blocks`` steps. Decoding continues until the block has no
    mask token left; at most ``block_length`` steps are needed, because every
    step commits at least one token per unfinished row.

    Args:
        model: Callable returning ``.logits`` and ``.past_key_values``; must
            expose ``.device``.
        prompt: Prompt token ids, shape (batch, prompt_len).
        steps, gen_length, block_length, temperature, remasking, mask_id:
            As in ``generate_standard``.
        threshold: Confidence threshold for threshold mode.
        factor: Factor for dynamic mode (takes precedence over ``threshold``).
        cfg_scale: If > 0, falls back to ``generate_standard`` (unsupported here).
        logits_eos_inf: If True, the EOS logit (id 126081) is set to -inf.
        confidence_eos_eot_inf: Not supported on the cached path (a warning is
            printed); honoured only by the ``cfg_scale > 0`` fallback.
        attention_mask: Only used by the fallback; a mask with padding triggers
            a warning because the cached path does not forward it.

    Returns:
        Tensor (batch, prompt_len + gen_length): the prompt followed by the
        generated tokens.
    """
    if cfg_scale > 0:
        print("⚠️ Warning: cfg_scale > 0 is not supported in Dual Cache mode. Falling back to standard generate.")
        return generate_standard(model, prompt, attention_mask, steps, gen_length, block_length, temperature,
                                 cfg_scale, remasking, mask_id, logits_eos_inf, confidence_eos_eot_inf)

    # The cached path forwards only token ids; report the options it drops
    # instead of ignoring them silently.
    if confidence_eos_eot_inf:
        print("Warning: confidence_eos_eot_inf is ignored in Dual Cache mode.")
    if attention_mask is not None and not bool(attention_mask.all()):
        print("Warning: attention_mask padding is ignored in Dual Cache mode.")

    B = prompt.shape[0]
    Lp = int(prompt.shape[1])
    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length
    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks
    # Quota mode finishes a block within steps_per_block steps. Threshold and
    # dynamic modes commit >= 1 token per unfinished row per step, so
    # block_length steps always suffice. The old fixed range(1, steps_per_block)
    # could return mask ids in those two modes.
    max_steps = max(steps_per_block, block_length)

    x = torch.full((B, Lp + gen_length), mask_id, dtype=torch.long, device=model.device)
    x[:, :Lp] = prompt

    use_replace_position = True  # cleared once if the model rejects the keyword

    def select(logits, mask, cur_x, step, quotas):
        """Apply the configured unmasking rule to ``cur_x``; returns (x0, transfer_index)."""
        if logits_eos_inf:
            logits[:, :, 126081] = -torch.inf
        if factor is not None:
            return get_transfer_index_dynamic(logits, temperature, remasking, mask, cur_x, None, factor)
        if threshold is not None:
            quota = None
        elif step < steps_per_block:
            quota = quotas[:, step]
        else:
            # Safety net. The quotas sum to the block's mask count, so this is
            # normally unreachable; if reached, commit everything still masked.
            quota = mask.sum(dim=1)
        return get_transfer_index(logits, temperature, remasking, mask, cur_x, quota, threshold)

    def block_logits(block, past_key_values, replace_position):
        """Run the model on the current block only, reusing the warmed cache."""
        nonlocal use_replace_position
        if use_replace_position:
            try:
                return model(
                    block, past_key_values=past_key_values, use_cache=True, replace_position=replace_position
                ).logits
            except TypeError as err:
                # Fall back only when the model does not know `replace_position`;
                # any other TypeError is a genuine error and must surface.
                if "replace_position" not in str(err):
                    raise
                use_replace_position = False
        # NOTE(review): without `replace_position` a generic cache may treat the
        # block as newly appended tokens; confirm this fallback suits the model.
        return model(block, past_key_values=past_key_values, use_cache=True).logits

    for nb in range(num_blocks):
        s = Lp + nb * block_length
        e = s + block_length
        num_transfer_tokens = get_num_transfer_tokens(x[:, s:e] == mask_id, steps_per_block)

        # 1) Warm the KV cache with a full-sequence pass; its logits drive step 0.
        out_full = model(x, use_cache=True)
        past_key_values = out_full.past_key_values

        replace_position = torch.zeros_like(x, dtype=torch.bool)
        replace_position[:, s:e] = True

        # Only mask tokens up to the end of the current block may be committed.
        global_mask_index = (x == mask_id)
        global_mask_index[:, e:] = False
        x0, transfer_index = select(out_full.logits, global_mask_index, x, 0, num_transfer_tokens)
        x = torch.where(transfer_index, x0, x)

        # 2) Remaining steps: feed only the block and keep going until it is done.
        for step in range(1, max_steps):
            block = x[:, s:e]
            mask_blk = (block == mask_id)
            if not mask_blk.any():
                break
            logits_blk = block_logits(block, past_key_values, replace_position)
            x0_blk, transfer_blk = select(logits_blk, mask_blk, block, step, num_transfer_tokens)
            x[:, s:e] = torch.where(transfer_blk, x0_blk, block)

    return x
# Alias: `generate` refers to the cache-free sampler `generate_standard`;
# call `generate_with_dual_cache` explicitly for the KV-cached variant.
generate = generate_standard