Update generation_utils.py
Browse files- generation_utils.py +210 -293
generation_utils.py
CHANGED
|
@@ -30,23 +30,8 @@ from transformers.utils import (
|
|
| 30 |
is_torchdynamo_compiling,
|
| 31 |
logging,
|
| 32 |
)
|
| 33 |
-
from .generate_from_llada import get_num_transfer_tokens_sch
|
| 34 |
logger = logging.get_logger(__name__)
|
| 35 |
-
|
| 36 |
-
import sys
|
| 37 |
-
import pdb
|
| 38 |
-
class ForkedPdb(pdb.Pdb):
    """
    PDB subclass for debugging multi-processed code.

    Suggested in: https://stackoverflow.com/questions/4716533/how-to-attach-debugger-to-a-python-subproccess
    """

    def interaction(self, *args, **kwargs):
        """Run the normal pdb interaction with ``sys.stdin`` bound to ``/dev/stdin``.

        ``sys.stdin`` is swapped for a freshly opened ``/dev/stdin`` for the
        duration of the interaction and always restored afterwards. The
        temporary handle is closed on exit so repeated breakpoints do not
        leak file descriptors.
        """
        saved_stdin = sys.stdin
        with open('/dev/stdin') as tty_stdin:
            sys.stdin = tty_stdin
            try:
                pdb.Pdb.interaction(self, *args, **kwargs)
            finally:
                # Restore before the handle is closed so sys.stdin never
                # points at a closed file.
                sys.stdin = saved_stdin
|
| 50 |
|
| 51 |
|
| 52 |
def top_p_logits(logits, top_p=None):
|
|
@@ -70,123 +55,59 @@ def top_k_logits(logits, top_k=None):
|
|
| 70 |
return logits
|
| 71 |
|
| 72 |
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
# if temperature > 0:
|
| 76 |
-
# logits = logits / temperature
|
| 77 |
-
# if top_p is not None and top_p < 1:
|
| 78 |
-
# logits = top_p_logits(logits, top_p)
|
| 79 |
-
# if top_k is not None:
|
| 80 |
-
# logits = top_k_logits(logits, top_k)
|
| 81 |
-
# probs = torch.softmax(logits, dim=-1)
|
| 82 |
-
|
| 83 |
-
# if temperature > 0:
|
| 84 |
-
# try:
|
| 85 |
-
# x0 = dists.Categorical(probs=probs).sample()
|
| 86 |
-
# confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
|
| 87 |
-
# except:
|
| 88 |
-
# confidence, x0 = probs.max(dim=-1)
|
| 89 |
-
# else:
|
| 90 |
-
# confidence, x0 = probs.max(dim=-1)
|
| 91 |
-
|
| 92 |
-
# if margin_confidence:
|
| 93 |
-
# sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
|
| 94 |
-
# # Extract top1 and top2 probabilities
|
| 95 |
-
# top1_probs = sorted_probs[:, 0]
|
| 96 |
-
# top2_probs = sorted_probs[:, 1]
|
| 97 |
-
# # Calculate confidence as top1 - top2
|
| 98 |
-
# confidence = top1_probs - top2_probs
|
| 99 |
-
|
| 100 |
-
# if neg_entropy:
|
| 101 |
-
# epsilon = 1e-10
|
| 102 |
-
# log_probs = torch.log(probs + epsilon)
|
| 103 |
-
# confidence = torch.sum(probs * log_probs, dim=-1)
|
| 104 |
-
|
| 105 |
-
# return confidence, x0
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
|
| 109 |
-
"""
|
| 110 |
-
从给定的 logits 中采样或贪心选取 token,并返回置信度和 token ID。
|
| 111 |
-
|
| 112 |
-
参数:
|
| 113 |
-
logits (Tensor):形状 [batch_size, vocab_size],模型对各候选 token 的打分(未经 softmax)。
|
| 114 |
-
temperature (float):温度系数,默认 0.0。>0 时按概率采样,=0 时贪心选取。
|
| 115 |
-
top_p (float 或 None):核采样参数(nucleus sampling),若指定且 <1,只保留累计概率前 top_p 的 token。
|
| 116 |
-
top_k (int 或 None):前 k 采样参数(top-k sampling),若指定,只从概率最高的 k 个 token 中选取。
|
| 117 |
-
margin_confidence (bool):是否使用 top1−top2 之差作为置信度,默认 False。
|
| 118 |
-
neg_entropy (bool):是否使用负熵(−∑p·logp)作为置信度,默认 False。
|
| 119 |
-
|
| 120 |
-
返回:
|
| 121 |
-
confidence (Tensor):形状 [batch_size] 的置信度值(可用概率、margin 差值或负熵)。
|
| 122 |
-
x0 (Tensor):形状 [batch_size] 的 int64 张量,表示采样或贪心得到的 token ID。
|
| 123 |
-
"""
|
| 124 |
-
|
| 125 |
-
# ======================================================
|
| 126 |
-
# 1. 温度缩放 (Temperature Scaling)
|
| 127 |
-
# ======================================================
|
| 128 |
if temperature > 0:
|
| 129 |
-
# 当 temperature>0 时,将 logits 除以 temperature,使得 softmax 分布更平滑或更尖锐
|
| 130 |
logits = logits / temperature
|
| 131 |
|
| 132 |
-
# ======================================================
|
| 133 |
-
# 2. Top-p (Nucleus) 与 Top-k 过滤
|
| 134 |
-
# ======================================================
|
| 135 |
if top_p is not None and top_p < 1:
|
| 136 |
-
# 调用 top_p_logits,保留累计概率达到 top_p 的 token,其它 logits 置为很小的负值
|
| 137 |
logits = top_p_logits(logits, top_p)
|
| 138 |
if top_k is not None:
|
| 139 |
-
# 调用 top_k_logits,仅保留概率最高的 top_k 个 token,其它 logits 置为很小的负值
|
| 140 |
logits = top_k_logits(logits, top_k)
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
-
|
| 149 |
-
# 4. 根据 temperature 决定采样或贪心选取
|
| 150 |
-
# ======================================================
|
| 151 |
if temperature > 0:
|
| 152 |
-
# 随机采样分支:从 Categorical 分布中采样 token
|
| 153 |
try:
|
| 154 |
-
# 从多项分布中采样得到 token ID,形状 [batch_size]
|
| 155 |
x0 = dists.Categorical(probs=probs).sample()
|
| 156 |
-
# 用 gather 取出对应位置的概率值作为置信度,形状 [batch_size]
|
| 157 |
confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
|
| 158 |
except:
|
| 159 |
-
# 若采样出错(如概率分布不合法),退化为贪心选取
|
| 160 |
confidence, x0 = probs.max(dim=-1)
|
| 161 |
else:
|
| 162 |
-
# 当 temperature=0 时,直接贪心选取概率最大的 token
|
| 163 |
confidence, x0 = probs.max(dim=-1)
|
| 164 |
|
| 165 |
-
# ======================================================
|
| 166 |
-
# 5. margin_confidence: 使用 top1−top2 差值作为置信度
|
| 167 |
-
# ======================================================
|
| 168 |
if margin_confidence:
|
| 169 |
-
# 将每行概率按降序排序,sorted_probs[:,0] 为 top1,sorted_probs[:,1] 为 top2
|
| 170 |
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
|
| 171 |
top1_probs = sorted_probs[:, 0]
|
| 172 |
top2_probs = sorted_probs[:, 1]
|
| 173 |
-
# 置信度设为 top1_probs − top2_probs
|
| 174 |
confidence = top1_probs - top2_probs
|
| 175 |
|
| 176 |
-
# ======================================================
|
| 177 |
-
# 6. neg_entropy: 使用负熵(−∑ p·log p)作为置信度
|
| 178 |
-
# ======================================================
|
| 179 |
if neg_entropy:
|
| 180 |
epsilon = 1e-10
|
| 181 |
-
# 为避免 log(0) 产生 −inf,加上一个小常数 epsilon
|
| 182 |
log_probs = torch.log(probs + epsilon)
|
| 183 |
-
# 计算 ∑ p_i * log p_i,结果是负熵值(值越接近 0,表示分布更“尖锐”)
|
| 184 |
confidence = torch.sum(probs * log_probs, dim=-1)
|
| 185 |
|
| 186 |
return confidence, x0
|
| 187 |
|
| 188 |
|
| 189 |
-
|
| 190 |
@dataclass
|
| 191 |
class DreamModelOutput(ModelOutput):
|
| 192 |
sequences: torch.LongTensor = None
|
|
@@ -398,6 +319,10 @@ class DreamGenerationMixin:
|
|
| 398 |
generation_config: Optional[DreamGenerationConfig] = None,
|
| 399 |
inputs_embeds=None,
|
| 400 |
prefix_lm=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
**kwargs,
|
| 402 |
) -> Union[DreamModelOutput, torch.LongTensor]:
|
| 403 |
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
|
@@ -406,7 +331,6 @@ class DreamGenerationMixin:
|
|
| 406 |
generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
|
| 407 |
# breakpoint()
|
| 408 |
# 2. Define model inputs
|
| 409 |
-
# import pdb;pdb.set_trace()
|
| 410 |
if inputs is not None:
|
| 411 |
input_ids = inputs
|
| 412 |
device = input_ids.device
|
|
@@ -440,7 +364,6 @@ class DreamGenerationMixin:
|
|
| 440 |
f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
|
| 441 |
" running `.generate()`.",
|
| 442 |
UserWarning,
|
| 443 |
-
|
| 444 |
)
|
| 445 |
# breakpoint()
|
| 446 |
if (
|
|
@@ -454,14 +377,13 @@ class DreamGenerationMixin:
|
|
| 454 |
"generation results, please set `attention_mask` when batch-padding inputs.",
|
| 455 |
UserWarning,
|
| 456 |
)
|
| 457 |
-
assert generation_config.num_return_sequences == 1,
|
| 458 |
-
|
| 459 |
input_ids, attention_mask = self._expand_inputs_for_generation(
|
| 460 |
expand_size=generation_config.num_return_sequences,
|
| 461 |
input_ids=input_ids,
|
| 462 |
attention_mask=attention_mask
|
| 463 |
)
|
| 464 |
-
|
| 465 |
result = self._sample(
|
| 466 |
input_ids,
|
| 467 |
attention_mask=attention_mask,
|
|
@@ -471,9 +393,14 @@ class DreamGenerationMixin:
|
|
| 471 |
inputs_embeds=inputs_embeds,
|
| 472 |
device=device,
|
| 473 |
prefix_lm=prefix_lm,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
**kwargs,
|
| 475 |
)
|
| 476 |
return result
|
|
|
|
| 477 |
def _sample(
|
| 478 |
self,
|
| 479 |
input_ids: torch.LongTensor,
|
|
@@ -484,223 +411,213 @@ class DreamGenerationMixin:
|
|
| 484 |
inputs_embeds=None,
|
| 485 |
prefix_lm=False,
|
| 486 |
device=None,
|
| 487 |
-
schedule_kwargs=None,
|
| 488 |
-
schedule=None,
|
| 489 |
step_ratio=None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 490 |
**kwargs,
|
| 491 |
) -> Union[DreamModelOutput, torch.LongTensor]:
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
# histories 用于保存每一步的 x,如果需要返回历史则初始化为列表,否则为 None
|
| 508 |
histories = [] if (return_dict_in_generate and output_history) else None
|
|
|
|
|
|
|
|
|
|
| 509 |
|
| 510 |
-
# 2. 如果没有传入 input_ids,而是直接传了 inputs_embeds,就根据 inputs_embeds 构造一个 placeholder 的 input_ids
|
| 511 |
if input_ids is None:
|
| 512 |
assert device is not None
|
| 513 |
assert inputs_embeds is not None
|
| 514 |
-
bsz, seq_len = inputs_embeds.shape[:2]
|
| 515 |
-
max_length = seq_len + max_new_tokens
|
| 516 |
-
# 创建一个全 0 的张量作为占位,后续会把 embedding 覆盖回去
|
| 517 |
input_ids = torch.full((bsz, seq_len), 0, dtype=torch.long).to(device)
|
| 518 |
|
| 519 |
-
# tok_idx 和 past_key_values 暂时留空,后面 prefix_lm 分支会用到
|
| 520 |
tok_idx = None
|
| 521 |
past_key_values = None
|
| 522 |
|
| 523 |
-
|
| 524 |
-
# F.pad 的 (0, L) 表示在右侧 pad 长度为 (max_length - seq_len),值为 mask_token_id
|
| 525 |
-
# import pdb;pdb.set_trace()
|
| 526 |
-
x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id) # 生成初始的 […, MASK, MASK, …]
|
| 527 |
-
|
| 528 |
-
# 4. 如果启用 prefix_lm 模式,先用 inputs_embeds 做一次常规模型前缀推理,得到 past_key_values 和首个 token
|
| 529 |
-
if prefix_lm:
|
| 530 |
-
dtype = inputs_embeds.dtype
|
| 531 |
-
# 先做一次前缀推理,use_cache=True 以获取 past_key_values
|
| 532 |
-
prefill = self.forward_dream(
|
| 533 |
-
None, attention_mask, tok_idx,
|
| 534 |
-
inputs_embeds=inputs_embeds.to(dtype),
|
| 535 |
-
use_cache=True
|
| 536 |
-
)
|
| 537 |
-
past_key_values = prefill.past_key_values
|
| 538 |
-
# 把前缀阶段模型最后一步的预测 token 取出,作为去噪的第一个位置
|
| 539 |
-
first_token = prefill.logits[:, -1:].argmax(dim=-1) # 形状为 [B, 1]
|
| 540 |
-
# 只保留 mask 区域(原 x 的 right half)
|
| 541 |
-
x = x[:, input_ids.shape[1]:] # 形状 [B, max_new_tokens]
|
| 542 |
-
# 把 mask 区域第一位填为 first_token
|
| 543 |
-
x[:, :1] = first_token
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
#. prefill['logits'].shape. torch.Size([1, 1063, 151667]) 即输入是这个
|
| 547 |
-
|
| 548 |
-
# 5. 当前不支持带 attention_mask 的情形,断言确保 attention_mask 一定为 None
|
| 549 |
-
assert attention_mask is None
|
| 550 |
|
| 551 |
-
# 6. 构造去噪时刻表 timesteps,线性从 1 -> eps,共 (steps + 1) 个值
|
| 552 |
-
# timesteps[i] 对应上一步噪声权重,timesteps[i+1] 对应本步噪声权重
|
| 553 |
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
| 554 |
-
# import pdb;pdb.set_trace()
|
| 555 |
-
# 7. 给用户一个机会在第 0 步“初始 x”阶段插入自定义逻辑
|
| 556 |
x = generation_tokens_hook_func(None, x, None)
|
| 557 |
|
| 558 |
-
# 8. 如果用户指定 step_ratio,就根据比例重计算步数
|
| 559 |
if step_ratio is not None:
|
| 560 |
steps = int(max_new_tokens * step_ratio)
|
| 561 |
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
|
| 574 |
-
# 10.2 先把 x 转成 embedding,得到形状 [B, current_length, D]
|
| 575 |
-
inputs_embeds_curr = self.model.embed_tokens(x)
|
| 576 |
-
|
| 577 |
-
# 10.3 如果非 prefix_lm,且外部传入了 inputs_embeds,则把前缀部分覆盖回去
|
| 578 |
-
if not prefix_lm:
|
| 579 |
if inputs_embeds is not None:
|
| 580 |
inputs_embeds_curr[:, :inputs_embeds.shape[1]] = inputs_embeds
|
| 581 |
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
|
| 628 |
-
elif alg == 'topk_margin':
|
| 629 |
-
confidence, x0 = sample_tokens(
|
| 630 |
-
mask_logits,
|
| 631 |
temperature=temperature,
|
| 632 |
top_p=top_p,
|
| 633 |
top_k=top_k,
|
| 634 |
-
|
| 635 |
-
)
|
| 636 |
-
elif alg == 'entropy':
|
| 637 |
-
confidence, x0 = sample_tokens(
|
| 638 |
-
mask_logits,
|
| 639 |
-
temperature,
|
| 640 |
-
top_p=top_p,
|
| 641 |
-
top_k=top_k,
|
| 642 |
-
neg_entropy=True
|
| 643 |
)
|
|
|
|
|
|
|
| 644 |
else:
|
| 645 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 646 |
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
# 根据 schedule(或默认比例)决定本轮要去噪多少个
|
| 650 |
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
else:
|
| 654 |
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else num_mask_token
|
|
|
|
| 655 |
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
for b in range(x.shape[0]):
|
| 681 |
-
row = x[b]
|
| 682 |
-
# 找到第一个出现 SPECIAL_TOKEN_ID 的位置
|
| 683 |
-
idx = (row == SPECIAL_TOKEN_ID).nonzero(as_tuple=True)[0]
|
| 684 |
-
if len(idx) > 0:
|
| 685 |
-
first_idx = idx[0].item()
|
| 686 |
-
# 该位置及其后面全部赋值为 SPECIAL_TOKEN_ID
|
| 687 |
-
row[first_idx:] = SPECIAL_TOKEN_ID
|
| 688 |
-
x[b] = row
|
| 689 |
-
|
| 690 |
-
# 10.8 用户自定义 token 钩子:对本轮更新后的 x 做额外处理
|
| 691 |
-
x = generation_tokens_hook_func(i, x, logits)
|
| 692 |
-
|
| 693 |
-
# 10.9 如果需要保存历史,就把当前 x clone 一份放进去
|
| 694 |
-
if histories is not None:
|
| 695 |
-
histories.append(x.clone())
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
# ForkedPdb().set_trace()
|
| 699 |
-
# 11. 循环结束后,根据 return_dict_in_generate 决定返回形式
|
| 700 |
-
if return_dict_in_generate:
|
| 701 |
-
return DreamModelOutput(
|
| 702 |
-
sequences=x, # 最终生成的完整 token 序列 [B, max_length]
|
| 703 |
-
history=histories, # 如果启用,会包含每一步的 x
|
| 704 |
-
)
|
| 705 |
-
else:
|
| 706 |
-
return x # 只返回最终序列 [B, max_length]
|
|
|
|
| 30 |
is_torchdynamo_compiling,
|
| 31 |
logging,
|
| 32 |
)
|
|
|
|
| 33 |
logger = logging.get_logger(__name__)
|
| 34 |
+
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
|
| 37 |
def top_p_logits(logits, top_p=None):
|
|
|
|
| 55 |
return logits
|
| 56 |
|
| 57 |
|
| 58 |
+
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False,
                  repeat_penalty=1.0, max_position_penalty=1.0, past_x=None, mask_id=None,):
    """Pick one token per row of ``logits`` and return a confidence score for it.

    Args:
        logits: Unnormalised scores. The penalty and margin branches index it
            as a 2-D ``(num_positions, vocab_size)`` tensor.
        temperature: When ``> 0`` the logits are divided by it and a token is
            sampled from the resulting distribution; otherwise the arg-max
            token is taken.
        top_p: If not ``None`` and ``< 1``, logits are filtered through
            ``top_p_logits`` first.
        top_k: If not ``None``, logits are filtered through ``top_k_logits``.
        margin_confidence: Report ``top1_prob - top2_prob`` as the confidence.
        neg_entropy: Report ``sum(p * log(p + 1e-10))`` as the confidence.
            Applied after ``margin_confidence``, so it wins if both are set.
        repeat_penalty: When ``!= 1.0``, logits of tokens found in ``past_x``
            are multiplied by it if negative and divided by it otherwise.
        max_position_penalty: When ``!= 1.0`` and there are more than 100
            rows, rows past the first 100 are scaled by a factor that ramps
            linearly from 1.0 towards this value (same multiply/divide rule).
        past_x: Tensor of previously generated token ids used by the repeat
            penalty. ``None`` or a non-tensor/empty value disables the penalty.
        mask_id: Token id excluded from the repeat penalty (together with
            id 0). NOTE(review): 0 looks like the placeholder id used for
            embedded prefixes - confirm against the caller.

    Returns:
        ``(confidence, x0)``: per-row confidence and the chosen token ids.
    """
    if temperature > 0:
        logits = logits / temperature

    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)

    # Repeat penalty: only meaningful when there is a tensor history to read.
    has_history = torch.is_tensor(past_x) and past_x.numel() > 0
    if repeat_penalty != 1.0 and has_history:
        seen = past_x != 0
        if mask_id is not None:
            seen = torch.logical_and(seen, past_x != mask_id)
        token_ids = torch.unique(past_x[seen]).to(logits.device)
        if token_ids.numel() > 0:
            # Clone so the caller's tensor is never modified in place.
            logits = logits.clone()
            cols = logits[:, token_ids]
            logits[:, token_ids] = torch.where(cols < 0, cols * repeat_penalty, cols / repeat_penalty)

    # Position penalty: the first `free_prefix` rows are left untouched.
    free_prefix = 100
    if max_position_penalty != 1.0 and logits.shape[-2] > free_prefix:
        tail = logits.shape[-2] - free_prefix
        ramp = [i / tail * (max_position_penalty - 1.0) + 1.0 for i in range(tail)]
        ramp = torch.tensor(ramp).to(logits.device).to(logits.dtype)
        scale = torch.cat([torch.ones_like(logits[:free_prefix, 0]), ramp], dim=0).unsqueeze(-1)
        # A (rows, 1) column broadcasts over the vocabulary axis.
        logits = torch.where(logits < 0, logits * scale, logits / scale)

    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except (ValueError, RuntimeError):
            # Best-effort fallback: if the distribution is rejected (e.g.
            # invalid probabilities) pick the arg-max token instead.
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        top1_probs = sorted_probs[:, 0]
        top2_probs = sorted_probs[:, 1]
        confidence = top1_probs - top2_probs

    if neg_entropy:
        epsilon = 1e-10  # keeps log() finite where a probability is exactly 0
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0
|
| 109 |
|
| 110 |
|
|
|
|
| 111 |
@dataclass
|
| 112 |
class DreamModelOutput(ModelOutput):
|
| 113 |
sequences: torch.LongTensor = None
|
|
|
|
| 319 |
generation_config: Optional[DreamGenerationConfig] = None,
|
| 320 |
inputs_embeds=None,
|
| 321 |
prefix_lm=False,
|
| 322 |
+
alg=None,
|
| 323 |
+
block_size=-1,
|
| 324 |
+
cfg=0.0,
|
| 325 |
+
add_boa_token=False,
|
| 326 |
**kwargs,
|
| 327 |
) -> Union[DreamModelOutput, torch.LongTensor]:
|
| 328 |
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
|
|
|
| 331 |
generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
|
| 332 |
# breakpoint()
|
| 333 |
# 2. Define model inputs
|
|
|
|
| 334 |
if inputs is not None:
|
| 335 |
input_ids = inputs
|
| 336 |
device = input_ids.device
|
|
|
|
| 364 |
f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
|
| 365 |
" running `.generate()`.",
|
| 366 |
UserWarning,
|
|
|
|
| 367 |
)
|
| 368 |
# breakpoint()
|
| 369 |
if (
|
|
|
|
| 377 |
"generation results, please set `attention_mask` when batch-padding inputs.",
|
| 378 |
UserWarning,
|
| 379 |
)
|
| 380 |
+
assert generation_config.num_return_sequences == 1, \
|
| 381 |
+
"Currently, we only support num_return_sequences = 1 for diffusion generation."
|
| 382 |
input_ids, attention_mask = self._expand_inputs_for_generation(
|
| 383 |
expand_size=generation_config.num_return_sequences,
|
| 384 |
input_ids=input_ids,
|
| 385 |
attention_mask=attention_mask
|
| 386 |
)
|
|
|
|
| 387 |
result = self._sample(
|
| 388 |
input_ids,
|
| 389 |
attention_mask=attention_mask,
|
|
|
|
| 393 |
inputs_embeds=inputs_embeds,
|
| 394 |
device=device,
|
| 395 |
prefix_lm=prefix_lm,
|
| 396 |
+
alg=alg,
|
| 397 |
+
block_size=block_size,
|
| 398 |
+
cfg=cfg,
|
| 399 |
+
add_boa_token=add_boa_token,
|
| 400 |
**kwargs,
|
| 401 |
)
|
| 402 |
return result
|
| 403 |
+
|
| 404 |
def _sample(
|
| 405 |
self,
|
| 406 |
input_ids: torch.LongTensor,
|
|
|
|
| 411 |
inputs_embeds=None,
|
| 412 |
prefix_lm=False,
|
| 413 |
device=None,
|
|
|
|
|
|
|
| 414 |
step_ratio=None,
|
| 415 |
+
penalty=1.2,
|
| 416 |
+
alg=None,
|
| 417 |
+
block_size=None,
|
| 418 |
+
add_boa_token=False,
|
| 419 |
+
max_position_penalty=1.0,
|
| 420 |
+
repeat_penalty=1.0,
|
| 421 |
+
cfg=0.0,
|
| 422 |
**kwargs,
|
| 423 |
) -> Union[DreamModelOutput, torch.LongTensor]:
|
| 424 |
+
output_history = True
|
| 425 |
+
return_dict_in_generate = generation_config.return_dict_in_generate
|
| 426 |
+
max_length = generation_config.max_length
|
| 427 |
+
mask_token_id = generation_config.mask_token_id
|
| 428 |
+
max_new_tokens = generation_config.max_new_tokens
|
| 429 |
+
steps = min(generation_config.steps, max_new_tokens)
|
| 430 |
+
eps = generation_config.eps
|
| 431 |
+
alg = generation_config.alg if alg is None else alg
|
| 432 |
+
print("denoise algorithm: " + alg)
|
| 433 |
+
alg_temp = generation_config.alg_temp
|
| 434 |
+
temperature = generation_config.temperature
|
| 435 |
+
top_p = generation_config.top_p
|
| 436 |
+
top_k = generation_config.top_k
|
| 437 |
+
|
|
|
|
|
|
|
| 438 |
histories = [] if (return_dict_in_generate and output_history) else None
|
| 439 |
+
all_logit = []
|
| 440 |
+
generated_tokens = []
|
| 441 |
+
block_size = max_new_tokens if block_size < 0 else block_size
|
| 442 |
|
|
|
|
| 443 |
if input_ids is None:
|
| 444 |
assert device is not None
|
| 445 |
assert inputs_embeds is not None
|
| 446 |
+
bsz, seq_len = inputs_embeds.shape[:2]
|
| 447 |
+
max_length = seq_len + max_new_tokens
|
|
|
|
| 448 |
input_ids = torch.full((bsz, seq_len), 0, dtype=torch.long).to(device)
|
| 449 |
|
|
|
|
| 450 |
tok_idx = None
|
| 451 |
past_key_values = None
|
| 452 |
|
| 453 |
+
x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
|
|
|
|
|
|
|
| 455 |
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
|
|
|
|
|
|
| 456 |
x = generation_tokens_hook_func(None, x, None)
|
| 457 |
|
|
|
|
| 458 |
if step_ratio is not None:
|
| 459 |
steps = int(max_new_tokens * step_ratio)
|
| 460 |
|
| 461 |
+
if add_boa_token:
|
| 462 |
+
bos_index = int((x.shape[1] - (x == mask_token_id).sum()) + (x == mask_token_id).sum() * 0.2)
|
| 463 |
+
x[:, bos_index] = 151684 # <|begin_of_audio|>
|
| 464 |
+
|
| 465 |
+
input_x = x.clone()
|
| 466 |
+
total_steps = steps
|
| 467 |
+
block_num = (x == mask_token_id).sum() // block_size
|
| 468 |
+
if block_num * block_size < (x == mask_token_id).sum(): block_num += 1
|
| 469 |
+
input_length = input_ids.shape[-1]
|
| 470 |
+
|
| 471 |
+
task = None
|
| 472 |
+
if "task" in kwargs: task = kwargs['task']
|
| 473 |
+
if cfg > 0:
|
| 474 |
+
import random
|
| 475 |
+
empty_prompt = ""
|
| 476 |
+
if task == "S2I":
|
| 477 |
+
empty_prompt = "<|im_start|>system\nPlease generate an image based on the input audio.<|im_end|>\n"
|
| 478 |
+
empty_prompt += "<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\n"
|
| 479 |
+
un_x = kwargs['tokenizer'].encode(empty_prompt)
|
| 480 |
+
elif task == "T2I":
|
| 481 |
+
empty_prompt = "<|im_start|>user\nGenerate an image based on the provided text description.\n"
|
| 482 |
+
empty_prompt += "<|im_end|>\n<|im_start|>assistant\n"
|
| 483 |
+
first_audio_token = kwargs['tokenizer'].encode("<|begin_of_audio|>")[0]
|
| 484 |
+
un_x_text = random.sample([_ for _ in range(first_audio_token)],
|
| 485 |
+
input_ids.shape[1] - len(kwargs['tokenizer'].encode(empty_prompt)))
|
| 486 |
+
un_x = kwargs['tokenizer'].encode("<|im_start|>user\nGenerate an image based on the provided \
|
| 487 |
+
text description.\n")
|
| 488 |
+
un_x = un_x + un_x_text + kwargs['tokenizer'].encode("<|im_end|>\n<|im_start|>assistant\n")
|
| 489 |
+
|
| 490 |
+
for block_idx in range(block_num):
|
| 491 |
+
block_mask = torch.zeros([x.shape[-1]]).to(torch.bool).to(x.device)
|
| 492 |
+
block_mask[input_length + block_idx * block_size: input_length + (block_idx + 1) * block_size] = True
|
| 493 |
+
steps = int(block_mask.sum() / (x.shape[-1] - input_length) * total_steps)
|
| 494 |
+
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
| 495 |
+
for i in tqdm(range(steps)):
|
| 496 |
+
mask_index = (x == mask_token_id)
|
| 497 |
+
if mask_index.sum() == 0: break
|
| 498 |
+
inputs_embeds_curr = self.model.embed_tokens(x)
|
| 499 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
if inputs_embeds is not None:
|
| 501 |
inputs_embeds_curr[:, :inputs_embeds.shape[1]] = inputs_embeds
|
| 502 |
|
| 503 |
+
if cfg > 0:
|
| 504 |
+
input_un_x = torch.tensor(un_x).unsqueeze(0).to(x.dtype).to(x.device)
|
| 505 |
+
input_un_x = torch.cat([input_un_x, x[:, input_ids.shape[1]:]], dim=1)
|
| 506 |
+
un_inpus_embeds = self.model.embed_tokens(input_un_x)
|
| 507 |
+
|
| 508 |
+
attention_mask_cond = torch.ones([1, inputs_embeds_curr.shape[1], inputs_embeds_curr.shape[1]])
|
| 509 |
+
attention_mask_cond = attention_mask_cond.to(torch.bool).to(inputs_embeds_curr.device)
|
| 510 |
+
attention_mask_uncond = torch.zeros([1, inputs_embeds_curr.shape[1], inputs_embeds_curr.shape[1]])
|
| 511 |
+
attention_mask_uncond[:, :un_inpus_embeds.shape[1], :un_inpus_embeds.shape[1]] = 1
|
| 512 |
+
attention_mask_uncond = attention_mask_uncond.to(torch.bool).to(inputs_embeds.device)
|
| 513 |
+
attention_mask = torch.cat([attention_mask_cond, attention_mask_uncond])
|
| 514 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 515 |
+
|
| 516 |
+
if inputs_embeds_curr.shape[1] != un_inpus_embeds.shape[1]:
|
| 517 |
+
un_inpus_embeds = torch.cat([un_inpus_embeds,
|
| 518 |
+
torch.zeros_like(inputs_embeds_curr[:, :inputs_embeds_curr.shape[1] -
|
| 519 |
+
un_inpus_embeds.shape[1], :])], dim=1)
|
| 520 |
+
input_inputs_embeds_curr = torch.cat([inputs_embeds_curr, un_inpus_embeds])
|
| 521 |
+
|
| 522 |
+
model_logits = self.forward_dream(None, attention_mask, tok_idx,
|
| 523 |
+
inputs_embeds=input_inputs_embeds_curr).logits
|
| 524 |
+
logits = model_logits[:1]; un_logits = model_logits[1:]
|
| 525 |
+
logits = un_logits + (cfg + 1) * (logits - un_logits)
|
| 526 |
+
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
|
| 527 |
+
|
| 528 |
+
else:
|
| 529 |
+
logits = self.forward_dream(None, attention_mask, tok_idx,
|
| 530 |
+
inputs_embeds=inputs_embeds_curr).logits
|
| 531 |
+
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
|
| 532 |
+
|
| 533 |
+
logits = generation_logits_hook_func(i, x, logits)
|
| 534 |
+
|
| 535 |
+
mask_logits = logits[mask_index]
|
| 536 |
+
if i == 0:
|
| 537 |
+
input_index = torch.where(mask_index[0]==True)[0][0]
|
| 538 |
+
|
| 539 |
+
t = timesteps[i]
|
| 540 |
+
s = timesteps[i + 1]
|
| 541 |
+
|
| 542 |
+
if alg == 'origin':
|
| 543 |
+
p_transfer = 1 - s / t if i < steps - 1 else 1
|
| 544 |
+
x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
|
| 545 |
+
transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
|
| 546 |
+
_, x0[transfer_index_t_s] = sample_tokens(
|
| 547 |
+
mask_logits[transfer_index_t_s],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 548 |
temperature=temperature,
|
| 549 |
top_p=top_p,
|
| 550 |
top_k=top_k,
|
| 551 |
+
max_position_penalty=max_position_penalty,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 552 |
)
|
| 553 |
+
x[mask_index] = x0.clone()
|
| 554 |
+
|
| 555 |
else:
|
| 556 |
+
if alg == 'maskgit_plus':
|
| 557 |
+
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k,
|
| 558 |
+
max_position_penalty=max_position_penalty)
|
| 559 |
+
elif alg == 'topk_margin':
|
| 560 |
+
confidence, x0 = sample_tokens(
|
| 561 |
+
mask_logits,
|
| 562 |
+
temperature=temperature,
|
| 563 |
+
top_p=top_p,
|
| 564 |
+
top_k=top_k,
|
| 565 |
+
margin_confidence=True,
|
| 566 |
+
max_position_penalty=max_position_penalty,
|
| 567 |
+
)
|
| 568 |
+
elif alg == 'entropy':
|
| 569 |
+
confidence, x0 = sample_tokens(
|
| 570 |
+
mask_logits,
|
| 571 |
+
temperature,
|
| 572 |
+
top_p=top_p,
|
| 573 |
+
top_k=top_k,
|
| 574 |
+
neg_entropy=True,
|
| 575 |
+
max_position_penalty=max_position_penalty,
|
| 576 |
+
)
|
| 577 |
+
elif alg == "entropy-penalty":
|
| 578 |
+
confidence, x0 = sample_tokens(
|
| 579 |
+
mask_logits,
|
| 580 |
+
temperature,
|
| 581 |
+
top_p=top_p,
|
| 582 |
+
top_k=top_k,
|
| 583 |
+
neg_entropy=True,
|
| 584 |
+
repeat_penalty=repeat_penalty if len(histories) != 0 else 1.0,
|
| 585 |
+
past_x=histories[-1] if len(histories) != 0 else [],
|
| 586 |
+
mask_id=mask_token_id,
|
| 587 |
+
max_position_penalty=max_position_penalty,
|
| 588 |
+
)
|
| 589 |
+
else:
|
| 590 |
+
raise RuntimeError(f"Unknown alg: {alg}")
|
| 591 |
|
| 592 |
+
block_mask_1 = block_mask[mask_index[0]]
|
| 593 |
+
confidence = confidence + torch.where(block_mask_1, 0, -torch.inf).to(confidence.device)
|
|
|
|
| 594 |
|
| 595 |
+
num_mask_token = mask_index.sum()
|
| 596 |
+
num_mask_token = (x[:, block_mask] == mask_token_id).sum()
|
|
|
|
| 597 |
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else num_mask_token
|
| 598 |
+
number_transfer_tokens = max(number_transfer_tokens, 1)
|
| 599 |
|
| 600 |
+
if number_transfer_tokens > 0:
|
| 601 |
+
if alg_temp is None or alg_temp == 0:
|
| 602 |
+
_, transfer_index = torch.topk(confidence, number_transfer_tokens)
|
| 603 |
+
else:
|
| 604 |
+
confidence = confidence / alg_temp
|
| 605 |
+
confidence = F.softmax(confidence, dim=-1)
|
| 606 |
+
transfer_index = torch.multinomial(confidence, num_samples=number_transfer_tokens)
|
| 607 |
+
|
| 608 |
+
x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_token_id
|
| 609 |
+
x0_[transfer_index] = x0[transfer_index].clone()
|
| 610 |
+
x[mask_index] = x0_
|
| 611 |
+
|
| 612 |
+
logit,indic = torch.max(torch.softmax(logits.clone(),dim=-1),-1)
|
| 613 |
+
logit = logit[0][x[0]!=0]
|
| 614 |
+
indic = indic[0][x[0]!=0]
|
| 615 |
+
temp_X = x[0][x[0]!=0]
|
| 616 |
+
|
| 617 |
+
x = generation_tokens_hook_func(i, x, logits)
|
| 618 |
+
|
| 619 |
+
if histories is not None:
|
| 620 |
+
histories.append(x.clone())
|
| 621 |
+
all_logit.append(torch.max(logits.clone(),-1)[-1])
|
| 622 |
+
|
| 623 |
+
return (x, histories)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|