lijiang committed on
Commit
f3cab6a
·
verified ·
1 Parent(s): 0f53954

Update generation_utils.py

Browse files
Files changed (1) hide show
  1. generation_utils.py +210 -293
generation_utils.py CHANGED
@@ -30,23 +30,8 @@ from transformers.utils import (
30
  is_torchdynamo_compiling,
31
  logging,
32
  )
33
- from .generate_from_llada import get_num_transfer_tokens_sch
34
  logger = logging.get_logger(__name__)
35
-
36
- import sys
37
- import pdb
38
- class ForkedPdb(pdb.Pdb):
39
- """
40
- PDB Subclass for debugging multi-processed code
41
- Suggested in: https://stackoverflow.com/questions/4716533/how-to-attach-debugger-to-a-python-subproccess
42
- """
43
- def interaction(self, *args, **kwargs):
44
- _stdin = sys.stdin
45
- try:
46
- sys.stdin = open('/dev/stdin')
47
- pdb.Pdb.interaction(self, *args, **kwargs)
48
- finally:
49
- sys.stdin = _stdin
50
 
51
 
52
  def top_p_logits(logits, top_p=None):
@@ -70,123 +55,59 @@ def top_k_logits(logits, top_k=None):
70
  return logits
71
 
72
 
73
- # def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
74
-
75
- # if temperature > 0:
76
- # logits = logits / temperature
77
- # if top_p is not None and top_p < 1:
78
- # logits = top_p_logits(logits, top_p)
79
- # if top_k is not None:
80
- # logits = top_k_logits(logits, top_k)
81
- # probs = torch.softmax(logits, dim=-1)
82
-
83
- # if temperature > 0:
84
- # try:
85
- # x0 = dists.Categorical(probs=probs).sample()
86
- # confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
87
- # except:
88
- # confidence, x0 = probs.max(dim=-1)
89
- # else:
90
- # confidence, x0 = probs.max(dim=-1)
91
-
92
- # if margin_confidence:
93
- # sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
94
- # # Extract top1 and top2 probabilities
95
- # top1_probs = sorted_probs[:, 0]
96
- # top2_probs = sorted_probs[:, 1]
97
- # # Calculate confidence as top1 - top2
98
- # confidence = top1_probs - top2_probs
99
-
100
- # if neg_entropy:
101
- # epsilon = 1e-10
102
- # log_probs = torch.log(probs + epsilon)
103
- # confidence = torch.sum(probs * log_probs, dim=-1)
104
-
105
- # return confidence, x0
106
-
107
-
108
- def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
109
- """
110
- 从给定的 logits 中采样或贪心选取 token,并返回置信度和 token ID。
111
-
112
- 参数:
113
- logits (Tensor):形状 [batch_size, vocab_size],模型对各候选 token 的打分(未经 softmax)。
114
- temperature (float):温度系数,默认 0.0。>0 时按概率采样,=0 时贪心选取。
115
- top_p (float 或 None):核采样参数(nucleus sampling),若指定且 <1,只保留累计概率前 top_p 的 token。
116
- top_k (int 或 None):前 k 采样参数(top-k sampling),若指定,只从概率最高的 k 个 token 中选取。
117
- margin_confidence (bool):是否使用 top1−top2 之差作为置信度,默认 False。
118
- neg_entropy (bool):是否使用负熵(−∑p·logp)作为置信度,默认 False。
119
-
120
- 返回:
121
- confidence (Tensor):形状 [batch_size] 的置信度值(可用概率、margin 差值或负熵)。
122
- x0 (Tensor):形状 [batch_size] 的 int64 张量,表示采样或贪心得到的 token ID。
123
- """
124
-
125
- # ======================================================
126
- # 1. 温度缩放 (Temperature Scaling)
127
- # ======================================================
128
  if temperature > 0:
129
- # 当 temperature>0 时,将 logits 除以 temperature,使得 softmax 分布更平滑或更尖锐
130
  logits = logits / temperature
131
 
132
- # ======================================================
133
- # 2. Top-p (Nucleus) 与 Top-k 过滤
134
- # ======================================================
135
  if top_p is not None and top_p < 1:
136
- # 调用 top_p_logits,保留累计概率达到 top_p 的 token,其它 logits 置为很小的负值
137
  logits = top_p_logits(logits, top_p)
138
  if top_k is not None:
139
- # 调用 top_k_logits,仅保留概率最高的 top_k 个 token,其它 logits 置为很小的负值
140
  logits = top_k_logits(logits, top_k)
141
 
142
- # ======================================================
143
- # 3. 计算概率分布 (Softmax)
144
- # ======================================================
145
- probs = torch.softmax(logits, dim=-1)
146
- # 此时 probs 形状为 [batch_size, vocab_size],每行和为 1
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
- # ======================================================
149
- # 4. 根据 temperature 决定采样或贪心选取
150
- # ======================================================
151
  if temperature > 0:
152
- # 随机采样分支:从 Categorical 分布中采样 token
153
  try:
154
- # 从多项分布中采样得到 token ID,形状 [batch_size]
155
  x0 = dists.Categorical(probs=probs).sample()
156
- # 用 gather 取出对应位置的概率值作为置信度,形状 [batch_size]
157
  confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
158
  except:
159
- # 若采样出错(如概率分布不合法),退化为贪心选取
160
  confidence, x0 = probs.max(dim=-1)
161
  else:
162
- # 当 temperature=0 时,直接贪心选取概率最大的 token
163
  confidence, x0 = probs.max(dim=-1)
164
 
165
- # ======================================================
166
- # 5. margin_confidence: 使用 top1−top2 差值作为置信度
167
- # ======================================================
168
  if margin_confidence:
169
- # 将每行概率按降序排序,sorted_probs[:,0] 为 top1,sorted_probs[:,1] 为 top2
170
  sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
171
  top1_probs = sorted_probs[:, 0]
172
  top2_probs = sorted_probs[:, 1]
173
- # 置信度设为 top1_probs − top2_probs
174
  confidence = top1_probs - top2_probs
175
 
176
- # ======================================================
177
- # 6. neg_entropy: 使用负熵(−∑ p·log p)作为置信度
178
- # ======================================================
179
  if neg_entropy:
180
  epsilon = 1e-10
181
- # 为避免 log(0) 产生 −inf,加上一个小常数 epsilon
182
  log_probs = torch.log(probs + epsilon)
183
- # 计算 ∑ p_i * log p_i,结果是负熵值(值越接近 0,表示分布更“尖锐”)
184
  confidence = torch.sum(probs * log_probs, dim=-1)
185
 
186
  return confidence, x0
187
 
188
 
189
-
190
  @dataclass
191
  class DreamModelOutput(ModelOutput):
192
  sequences: torch.LongTensor = None
@@ -398,6 +319,10 @@ class DreamGenerationMixin:
398
  generation_config: Optional[DreamGenerationConfig] = None,
399
  inputs_embeds=None,
400
  prefix_lm=False,
 
 
 
 
401
  **kwargs,
402
  ) -> Union[DreamModelOutput, torch.LongTensor]:
403
  # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
@@ -406,7 +331,6 @@ class DreamGenerationMixin:
406
  generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
407
  # breakpoint()
408
  # 2. Define model inputs
409
- # import pdb;pdb.set_trace()
410
  if inputs is not None:
411
  input_ids = inputs
412
  device = input_ids.device
@@ -440,7 +364,6 @@ class DreamGenerationMixin:
440
  f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
441
  " running `.generate()`.",
442
  UserWarning,
443
-
444
  )
445
  # breakpoint()
446
  if (
@@ -454,14 +377,13 @@ class DreamGenerationMixin:
454
  "generation results, please set `attention_mask` when batch-padding inputs.",
455
  UserWarning,
456
  )
457
- assert generation_config.num_return_sequences == 1, "Currently, we only support num_return_sequences = 1 for diffusion generation."
458
- # import pdb;pdb.set_trace()
459
  input_ids, attention_mask = self._expand_inputs_for_generation(
460
  expand_size=generation_config.num_return_sequences,
461
  input_ids=input_ids,
462
  attention_mask=attention_mask
463
  )
464
-
465
  result = self._sample(
466
  input_ids,
467
  attention_mask=attention_mask,
@@ -471,9 +393,14 @@ class DreamGenerationMixin:
471
  inputs_embeds=inputs_embeds,
472
  device=device,
473
  prefix_lm=prefix_lm,
 
 
 
 
474
  **kwargs,
475
  )
476
  return result
 
477
  def _sample(
478
  self,
479
  input_ids: torch.LongTensor,
@@ -484,223 +411,213 @@ class DreamGenerationMixin:
484
  inputs_embeds=None,
485
  prefix_lm=False,
486
  device=None,
487
- schedule_kwargs=None,
488
- schedule=None,
489
  step_ratio=None,
 
 
 
 
 
 
 
490
  **kwargs,
491
  ) -> Union[DreamModelOutput, torch.LongTensor]:
492
- # 1. 从 generation_config 中提取常用参数
493
- output_history = generation_config.output_history # 是否保存每一步的中间结果
494
- # output_history = True
495
- return_dict_in_generate = generation_config.return_dict_in_generate # 生成时是否返回字典形式
496
- max_length = generation_config.max_length # 生成后序列的最大长度(包括前缀)
497
- mask_token_id = generation_config.mask_token_id # [MASK] 的 token ID
498
- max_new_tokens = generation_config.max_new_tokens # 最多新增的 token 数量
499
- steps = min(generation_config.steps, max_new_tokens) # 实际去噪步数,不能超过最大新增 token
500
- eps = generation_config.eps # 噪声下限,用于时刻表
501
- alg = generation_config.alg # 选择的去噪算法('origin'/ 'maskgit_plus'/ 'topk_margin'/ 'entropy')
502
- alg_temp = generation_config.alg_temp # 针对某些算法(margin/entropy)调整置信度的温度参数
503
- temperature = generation_config.temperature # 采样时的温度
504
- top_p = generation_config.top_p # top-p 截断采样参数
505
- top_k = generation_config.top_k # top-k 截断采样参数
506
-
507
- # histories 用于保存每一步的 x,如果需要返回历史则初始化为列表,否则为 None
508
  histories = [] if (return_dict_in_generate and output_history) else None
 
 
 
509
 
510
- # 2. 如果没有传入 input_ids,而是直接传了 inputs_embeds,就根据 inputs_embeds 构造一个 placeholder 的 input_ids
511
  if input_ids is None:
512
  assert device is not None
513
  assert inputs_embeds is not None
514
- bsz, seq_len = inputs_embeds.shape[:2] # batch size 和前缀长度
515
- max_length = seq_len + max_new_tokens # 重新计算 max_length
516
- # 创建一个全 0 的张量作为占位,后续会把 embedding 覆盖回去
517
  input_ids = torch.full((bsz, seq_len), 0, dtype=torch.long).to(device)
518
 
519
- # tok_idx 和 past_key_values 暂时留空,后面 prefix_lm 分支会用到
520
  tok_idx = None
521
  past_key_values = None
522
 
523
- # 3. input_ids pad max_length,后面补 [MASK]
524
- # F.pad 的 (0, L) 表示在右侧 pad 长度为 (max_length - seq_len),值为 mask_token_id
525
- # import pdb;pdb.set_trace()
526
- x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id) # 生成初始的 […, MASK, MASK, …]
527
-
528
- # 4. 如果启用 prefix_lm 模式,先用 inputs_embeds 做一次常规模型前缀推理,得到 past_key_values 和首个 token
529
- if prefix_lm:
530
- dtype = inputs_embeds.dtype
531
- # 先做一次前缀推理,use_cache=True 以获取 past_key_values
532
- prefill = self.forward_dream(
533
- None, attention_mask, tok_idx,
534
- inputs_embeds=inputs_embeds.to(dtype),
535
- use_cache=True
536
- )
537
- past_key_values = prefill.past_key_values
538
- # 把前缀阶段模型最后一步的预测 token 取出,作为去噪的第一个位置
539
- first_token = prefill.logits[:, -1:].argmax(dim=-1) # 形状为 [B, 1]
540
- # 只保留 mask 区域(原 x 的 right half)
541
- x = x[:, input_ids.shape[1]:] # 形状 [B, max_new_tokens]
542
- # 把 mask 区域第一位填为 first_token
543
- x[:, :1] = first_token
544
-
545
-
546
- #. prefill['logits'].shape. torch.Size([1, 1063, 151667]) 即输入是这个
547
-
548
- # 5. 当前不支持带 attention_mask 的情形,断言确保 attention_mask 一定为 None
549
- assert attention_mask is None
550
 
551
- # 6. 构造去噪时刻表 timesteps,线性从 1 -> eps,共 (steps + 1) 个值
552
- # timesteps[i] 对应上一步噪声权重,timesteps[i+1] 对应本步噪声权重
553
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
554
- # import pdb;pdb.set_trace()
555
- # 7. 给用户一个机会在第 0 步“初始 x”阶段插入自定义逻辑
556
  x = generation_tokens_hook_func(None, x, None)
557
 
558
- # 8. 如果用户指定 step_ratio,就根据比例重计算步数
559
  if step_ratio is not None:
560
  steps = int(max_new_tokens * step_ratio)
561
 
562
- # 9. 计算每一步要去噪多少个 mask(如果传了 schedule,就用自定义调度)
563
- if schedule is None:
564
- sch = None
565
- else:
566
- # get_num_transfer_tokens_sch 返回形状 [B, steps] 的矩阵
567
- sch = get_num_transfer_tokens_sch((x == mask_token_id), steps, schedule, schedule_kwargs)
568
-
569
- # 10. 进入去噪主循环
570
- for i in range(steps):
571
- # 10.1 找出当前仍是 [MASK] 的位置,mask_index 为布尔矩阵 [B, current_length]
572
- mask_index = (x == mask_token_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
573
 
574
- # 10.2 先把 x 转成 embedding,得到形状 [B, current_length, D]
575
- inputs_embeds_curr = self.model.embed_tokens(x)
576
-
577
- # 10.3 如果非 prefix_lm,且外部传入了 inputs_embeds,则把前缀部分覆盖回去
578
- if not prefix_lm:
579
  if inputs_embeds is not None:
580
  inputs_embeds_curr[:, :inputs_embeds.shape[1]] = inputs_embeds
581
 
582
- # 用当前 embedding 做一次前向,得到 logits,形状 [B, current_length, V]
583
- logits = self.forward_dream(None, attention_mask, tok_idx, inputs_embeds=inputs_embeds_curr).logits
584
- # logits 拼接成对齐当前预测:logits[:,1:] 对齐到 x[:, :-1]
585
- logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
586
- else:
587
- # prefix_lm 模式,用 past_key_values 加速推理
588
- logits = self.forward_dream(
589
- None, attention_mask, tok_idx,
590
- inputs_embeds=inputs_embeds_curr,
591
- past_key_values=past_key_values
592
- ).logits
593
- logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
594
-
595
- # 10.4 用户自定义 logits 钩子,可以修改 logits 分布
596
- # import pdb;pdb.set_trace()
597
- logits = generation_logits_hook_func(i, x, logits)
598
-
599
- # 10.5 取出当前所有 [MASK] 位置对应的 logits,形状 [num_mask, V]
600
- mask_logits = logits[mask_index]
601
-
602
- # 10.6 从 timesteps 中取出噪声权重 t, s
603
- t = timesteps[i]
604
- s = timesteps[i + 1]
605
-
606
- # 10.7 根据不同算法决定本轮去噪逻辑
607
- if alg == 'origin':
608
- # 基础扩散算法:按概率 p_transfer 随机把一部分 mask 位置替换成 token
609
- p_transfer = 1 - s / t if i < steps - 1 else 1 # 最后一轮保证把所有剩余 mask 都去掉
610
- # x0 临时占位,全填 mask
611
- x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
612
- # 随机采样哪些位置在本轮去噪:如果 torch.rand < p_transfer 就先去噪
613
- transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
614
- # 对这些选中的位置,从 mask_logits 中采样真实 token
615
- _, x0[transfer_index_t_s] = sample_tokens(
616
- mask_logits[transfer_index_t_s],
617
- temperature=temperature,
618
- top_p=top_p,
619
- top_k=top_k
620
- )
621
- # 更新 x:只替换 mask_index 位置
622
- x[mask_index] = x0.clone()
623
- else:
624
- # MaskGIT+ / Top-K Margin / Entropy 算法
625
- if alg == 'maskgit_plus':
626
- # 返回 confidence(置信度)和 x0(最可能的 token ID)
627
- confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
628
- elif alg == 'topk_margin':
629
- confidence, x0 = sample_tokens(
630
- mask_logits,
631
  temperature=temperature,
632
  top_p=top_p,
633
  top_k=top_k,
634
- margin_confidence=True
635
- )
636
- elif alg == 'entropy':
637
- confidence, x0 = sample_tokens(
638
- mask_logits,
639
- temperature,
640
- top_p=top_p,
641
- top_k=top_k,
642
- neg_entropy=True
643
  )
 
 
644
  else:
645
- raise RuntimeError(f"Unknown alg: {alg}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
646
 
647
- # 当前还有多少 mask 位置
648
- num_mask_token = mask_index.sum()
649
- # 根据 schedule(或默认比例)决定本轮要去噪多少个
650
 
651
- if sch is not None:
652
- number_transfer_tokens = sch[0, i]
653
- else:
654
  number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else num_mask_token
 
655
 
656
- if number_transfer_tokens > 0:
657
- if alg_temp is None or alg_temp == 0:
658
- # 直接选置信度最高的 number_transfer_tokens 个位置
659
- _, transfer_index = torch.topk(confidence, number_transfer_tokens)
660
- else:
661
- # 用温度调节 confidence,再按多项式采样 number_transfer_tokens 个
662
- confidence = confidence / alg_temp
663
- confidence = F.softmax(confidence, dim=-1)
664
- transfer_index = torch.multinomial(confidence, num_samples=number_transfer_tokens)
665
-
666
- # x0_ 临时占位,全填 mask
667
- x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_token_id
668
- # 在选中的位置填入从 x0 (argmax token) 中取得的 token
669
- x0_[transfer_index] = x0[transfer_index].clone()
670
-
671
-
672
-
673
- # 更新 x:只替换 mask_index 位置
674
- x[mask_index] = x0_
675
-
676
- #如果出现的token有 151643(eos) ,那么他后面的所有都换成 151643,不需要再次mask
677
- SPECIAL_TOKEN_ID = 151643
678
- if (x == SPECIAL_TOKEN_ID).any():
679
- # 对每个 batch 处理
680
- for b in range(x.shape[0]):
681
- row = x[b]
682
- # 找到第一个出现 SPECIAL_TOKEN_ID 的位置
683
- idx = (row == SPECIAL_TOKEN_ID).nonzero(as_tuple=True)[0]
684
- if len(idx) > 0:
685
- first_idx = idx[0].item()
686
- # 该位置及其后面全部赋值为 SPECIAL_TOKEN_ID
687
- row[first_idx:] = SPECIAL_TOKEN_ID
688
- x[b] = row
689
-
690
- # 10.8 用户自定义 token 钩子:对本轮更新后的 x 做额外处理
691
- x = generation_tokens_hook_func(i, x, logits)
692
-
693
- # 10.9 如果需要保存历史,就把当前 x clone 一份放进去
694
- if histories is not None:
695
- histories.append(x.clone())
696
-
697
-
698
- # ForkedPdb().set_trace()
699
- # 11. 循环结束后,根据 return_dict_in_generate 决定返回形式
700
- if return_dict_in_generate:
701
- return DreamModelOutput(
702
- sequences=x, # 最终生成的完整 token 序列 [B, max_length]
703
- history=histories, # 如果启用,会包含每一步的 x
704
- )
705
- else:
706
- return x # 只返回最终序列 [B, max_length]
 
30
  is_torchdynamo_compiling,
31
  logging,
32
  )
 
33
  logger = logging.get_logger(__name__)
34
+ from tqdm import tqdm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
 
37
  def top_p_logits(logits, top_p=None):
 
55
  return logits
56
 
57
 
58
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False,
                  repeat_penalty=1.0, max_position_penalty=1.0, past_x=None, mask_id=None,):
    """Pick one token per row of ``logits`` and score how confident the pick is.

    Processing order: temperature scaling, top-p / top-k filtering, repetition
    penalty, position penalty, softmax, then sampling (``temperature > 0``) or
    argmax (``temperature == 0``).

    Args:
        logits: 2-D float tensor of unnormalised scores, one row per position
            to fill and one column per vocabulary entry. When the repetition
            penalty is applied the tensor that reaches that step is modified
            in place; pass a copy if the caller still needs the raw scores.
        temperature: Values ``> 0`` divide the logits and enable random
            sampling; ``0`` selects the arg-max token.
        top_p: If not ``None`` and ``< 1``, logits are filtered through
            ``top_p_logits``.
        top_k: If not ``None``, logits are filtered through ``top_k_logits``.
        margin_confidence: Report ``top1_prob - top2_prob`` as confidence.
        neg_entropy: Report ``sum(p * log(p + 1e-10))`` as confidence. Takes
            precedence over ``margin_confidence`` when both are set.
        repeat_penalty: When ``!= 1.0``, every token id found in ``past_x``
            (ignoring id ``0`` and ``mask_id``) has its logit column divided
            by this value where non-negative and multiplied by it where
            negative.
        max_position_penalty: When ``!= 1.0`` and there are more than 100
            rows, row ``100 + i`` is penalised (same divide / multiply rule)
            by ``1 + i / (rows - 100) * (max_position_penalty - 1)``; the
            first 100 rows are left untouched.
        past_x: Integer tensor of previously produced token ids; required
            when ``repeat_penalty != 1.0``.
        mask_id: Id of the mask token, excluded from the repetition penalty;
            required when ``repeat_penalty != 1.0``.

    Returns:
        ``(confidence, x0)``: per-row confidence scores and the chosen token
        ids.

    Raises:
        ValueError: If ``repeat_penalty != 1.0`` but ``past_x`` is not a
            tensor or ``mask_id`` is ``None``.
    """
    # Rows before this index are exempt from the position penalty.
    penalty_free_rows = 100

    if temperature > 0:
        logits = logits / temperature

    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)

    if repeat_penalty != 1.0:
        if not torch.is_tensor(past_x) or mask_id is None:
            raise ValueError(
                "repeat_penalty != 1.0 requires `past_x` (tensor of previous token ids) and `mask_id`."
            )
        # Id 0 and the mask id are treated as "nothing generated here".
        seen = torch.logical_and(past_x != 0, past_x != mask_id)
        token_ids = torch.unique(past_x[seen]).to(logits.device)
        if token_ids.numel() > 0:
            columns = logits[:, token_ids]
            # Push already-used tokens away from selection regardless of sign.
            logits[:, token_ids] = torch.where(
                columns < 0, columns * repeat_penalty, columns / repeat_penalty
            )

    if max_position_penalty != 1.0:
        num_rows = logits.shape[-2]
        if num_rows > penalty_free_rows:
            ramp_len = num_rows - penalty_free_rows
            ramp = torch.arange(ramp_len, dtype=torch.float64) / ramp_len * (max_position_penalty - 1.0) + 1.0
            penalty = torch.cat([torch.ones(penalty_free_rows, dtype=torch.float64), ramp])
            # One factor per row, broadcast across the vocabulary dimension.
            penalty = penalty.to(torch.float32).to(logits.device).to(logits.dtype).unsqueeze(-1)
            logits = torch.where(logits < 0, logits * penalty, logits / penalty)

    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except (ValueError, RuntimeError):
            # Invalid distribution (e.g. NaN probabilities): fall back to greedy.
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if margin_confidence:
        # Only the two largest probabilities are needed, so skip the full sort.
        top2_probs, _ = torch.topk(probs, 2, dim=-1)
        confidence = top2_probs[:, 0] - top2_probs[:, 1]

    if neg_entropy:
        epsilon = 1e-10  # keeps log() finite for zero-probability entries
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0
109
 
110
 
 
111
  @dataclass
112
  class DreamModelOutput(ModelOutput):
113
  sequences: torch.LongTensor = None
 
319
  generation_config: Optional[DreamGenerationConfig] = None,
320
  inputs_embeds=None,
321
  prefix_lm=False,
322
+ alg=None,
323
+ block_size=-1,
324
+ cfg=0.0,
325
+ add_boa_token=False,
326
  **kwargs,
327
  ) -> Union[DreamModelOutput, torch.LongTensor]:
328
  # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
 
331
  generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
332
  # breakpoint()
333
  # 2. Define model inputs
 
334
  if inputs is not None:
335
  input_ids = inputs
336
  device = input_ids.device
 
364
  f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
365
  " running `.generate()`.",
366
  UserWarning,
 
367
  )
368
  # breakpoint()
369
  if (
 
377
  "generation results, please set `attention_mask` when batch-padding inputs.",
378
  UserWarning,
379
  )
380
+ assert generation_config.num_return_sequences == 1, \
381
+ "Currently, we only support num_return_sequences = 1 for diffusion generation."
382
  input_ids, attention_mask = self._expand_inputs_for_generation(
383
  expand_size=generation_config.num_return_sequences,
384
  input_ids=input_ids,
385
  attention_mask=attention_mask
386
  )
 
387
  result = self._sample(
388
  input_ids,
389
  attention_mask=attention_mask,
 
393
  inputs_embeds=inputs_embeds,
394
  device=device,
395
  prefix_lm=prefix_lm,
396
+ alg=alg,
397
+ block_size=block_size,
398
+ cfg=cfg,
399
+ add_boa_token=add_boa_token,
400
  **kwargs,
401
  )
402
  return result
403
+
404
  def _sample(
405
  self,
406
  input_ids: torch.LongTensor,
 
411
  inputs_embeds=None,
412
  prefix_lm=False,
413
  device=None,
 
 
414
  step_ratio=None,
415
+ penalty=1.2,
416
+ alg=None,
417
+ block_size=None,
418
+ add_boa_token=False,
419
+ max_position_penalty=1.0,
420
+ repeat_penalty=1.0,
421
+ cfg=0.0,
422
  **kwargs,
423
  ) -> Union[DreamModelOutput, torch.LongTensor]:
424
+ output_history = True
425
+ return_dict_in_generate = generation_config.return_dict_in_generate
426
+ max_length = generation_config.max_length
427
+ mask_token_id = generation_config.mask_token_id
428
+ max_new_tokens = generation_config.max_new_tokens
429
+ steps = min(generation_config.steps, max_new_tokens)
430
+ eps = generation_config.eps
431
+ alg = generation_config.alg if alg is None else alg
432
+ print("denoise algorithm: " + alg)
433
+ alg_temp = generation_config.alg_temp
434
+ temperature = generation_config.temperature
435
+ top_p = generation_config.top_p
436
+ top_k = generation_config.top_k
437
+
 
 
438
  histories = [] if (return_dict_in_generate and output_history) else None
439
+ all_logit = []
440
+ generated_tokens = []
441
+ block_size = max_new_tokens if block_size < 0 else block_size
442
 
 
443
  if input_ids is None:
444
  assert device is not None
445
  assert inputs_embeds is not None
446
+ bsz, seq_len = inputs_embeds.shape[:2]
447
+ max_length = seq_len + max_new_tokens
 
448
  input_ids = torch.full((bsz, seq_len), 0, dtype=torch.long).to(device)
449
 
 
450
  tok_idx = None
451
  past_key_values = None
452
 
453
+ x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
 
 
 
455
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
 
 
456
  x = generation_tokens_hook_func(None, x, None)
457
 
 
458
  if step_ratio is not None:
459
  steps = int(max_new_tokens * step_ratio)
460
 
461
+ if add_boa_token:
462
+ bos_index = int((x.shape[1] - (x == mask_token_id).sum()) + (x == mask_token_id).sum() * 0.2)
463
+ x[:, bos_index] = 151684 # <|begin_of_audio|>
464
+
465
+ input_x = x.clone()
466
+ total_steps = steps
467
+ block_num = (x == mask_token_id).sum() // block_size
468
+ if block_num * block_size < (x == mask_token_id).sum(): block_num += 1
469
+ input_length = input_ids.shape[-1]
470
+
471
+ task = None
472
+ if "task" in kwargs: task = kwargs['task']
473
+ if cfg > 0:
474
+ import random
475
+ empty_prompt = ""
476
+ if task == "S2I":
477
+ empty_prompt = "<|im_start|>system\nPlease generate an image based on the input audio.<|im_end|>\n"
478
+ empty_prompt += "<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\n"
479
+ un_x = kwargs['tokenizer'].encode(empty_prompt)
480
+ elif task == "T2I":
481
+ empty_prompt = "<|im_start|>user\nGenerate an image based on the provided text description.\n"
482
+ empty_prompt += "<|im_end|>\n<|im_start|>assistant\n"
483
+ first_audio_token = kwargs['tokenizer'].encode("<|begin_of_audio|>")[0]
484
+ un_x_text = random.sample([_ for _ in range(first_audio_token)],
485
+ input_ids.shape[1] - len(kwargs['tokenizer'].encode(empty_prompt)))
486
+ un_x = kwargs['tokenizer'].encode("<|im_start|>user\nGenerate an image based on the provided \
487
+ text description.\n")
488
+ un_x = un_x + un_x_text + kwargs['tokenizer'].encode("<|im_end|>\n<|im_start|>assistant\n")
489
+
490
+ for block_idx in range(block_num):
491
+ block_mask = torch.zeros([x.shape[-1]]).to(torch.bool).to(x.device)
492
+ block_mask[input_length + block_idx * block_size: input_length + (block_idx + 1) * block_size] = True
493
+ steps = int(block_mask.sum() / (x.shape[-1] - input_length) * total_steps)
494
+ timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
495
+ for i in tqdm(range(steps)):
496
+ mask_index = (x == mask_token_id)
497
+ if mask_index.sum() == 0: break
498
+ inputs_embeds_curr = self.model.embed_tokens(x)
499
 
 
 
 
 
 
500
  if inputs_embeds is not None:
501
  inputs_embeds_curr[:, :inputs_embeds.shape[1]] = inputs_embeds
502
 
503
+ if cfg > 0:
504
+ input_un_x = torch.tensor(un_x).unsqueeze(0).to(x.dtype).to(x.device)
505
+ input_un_x = torch.cat([input_un_x, x[:, input_ids.shape[1]:]], dim=1)
506
+ un_inpus_embeds = self.model.embed_tokens(input_un_x)
507
+
508
+ attention_mask_cond = torch.ones([1, inputs_embeds_curr.shape[1], inputs_embeds_curr.shape[1]])
509
+ attention_mask_cond = attention_mask_cond.to(torch.bool).to(inputs_embeds_curr.device)
510
+ attention_mask_uncond = torch.zeros([1, inputs_embeds_curr.shape[1], inputs_embeds_curr.shape[1]])
511
+ attention_mask_uncond[:, :un_inpus_embeds.shape[1], :un_inpus_embeds.shape[1]] = 1
512
+ attention_mask_uncond = attention_mask_uncond.to(torch.bool).to(inputs_embeds.device)
513
+ attention_mask = torch.cat([attention_mask_cond, attention_mask_uncond])
514
+ attention_mask = attention_mask.unsqueeze(1)
515
+
516
+ if inputs_embeds_curr.shape[1] != un_inpus_embeds.shape[1]:
517
+ un_inpus_embeds = torch.cat([un_inpus_embeds,
518
+ torch.zeros_like(inputs_embeds_curr[:, :inputs_embeds_curr.shape[1] -
519
+ un_inpus_embeds.shape[1], :])], dim=1)
520
+ input_inputs_embeds_curr = torch.cat([inputs_embeds_curr, un_inpus_embeds])
521
+
522
+ model_logits = self.forward_dream(None, attention_mask, tok_idx,
523
+ inputs_embeds=input_inputs_embeds_curr).logits
524
+ logits = model_logits[:1]; un_logits = model_logits[1:]
525
+ logits = un_logits + (cfg + 1) * (logits - un_logits)
526
+ logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
527
+
528
+ else:
529
+ logits = self.forward_dream(None, attention_mask, tok_idx,
530
+ inputs_embeds=inputs_embeds_curr).logits
531
+ logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
532
+
533
+ logits = generation_logits_hook_func(i, x, logits)
534
+
535
+ mask_logits = logits[mask_index]
536
+ if i == 0:
537
+ input_index = torch.where(mask_index[0]==True)[0][0]
538
+
539
+ t = timesteps[i]
540
+ s = timesteps[i + 1]
541
+
542
+ if alg == 'origin':
543
+ p_transfer = 1 - s / t if i < steps - 1 else 1
544
+ x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
545
+ transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
546
+ _, x0[transfer_index_t_s] = sample_tokens(
547
+ mask_logits[transfer_index_t_s],
 
 
 
 
548
  temperature=temperature,
549
  top_p=top_p,
550
  top_k=top_k,
551
+ max_position_penalty=max_position_penalty,
 
 
 
 
 
 
 
 
552
  )
553
+ x[mask_index] = x0.clone()
554
+
555
  else:
556
+ if alg == 'maskgit_plus':
557
+ confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k,
558
+ max_position_penalty=max_position_penalty)
559
+ elif alg == 'topk_margin':
560
+ confidence, x0 = sample_tokens(
561
+ mask_logits,
562
+ temperature=temperature,
563
+ top_p=top_p,
564
+ top_k=top_k,
565
+ margin_confidence=True,
566
+ max_position_penalty=max_position_penalty,
567
+ )
568
+ elif alg == 'entropy':
569
+ confidence, x0 = sample_tokens(
570
+ mask_logits,
571
+ temperature,
572
+ top_p=top_p,
573
+ top_k=top_k,
574
+ neg_entropy=True,
575
+ max_position_penalty=max_position_penalty,
576
+ )
577
+ elif alg == "entropy-penalty":
578
+ confidence, x0 = sample_tokens(
579
+ mask_logits,
580
+ temperature,
581
+ top_p=top_p,
582
+ top_k=top_k,
583
+ neg_entropy=True,
584
+ repeat_penalty=repeat_penalty if len(histories) != 0 else 1.0,
585
+ past_x=histories[-1] if len(histories) != 0 else [],
586
+ mask_id=mask_token_id,
587
+ max_position_penalty=max_position_penalty,
588
+ )
589
+ else:
590
+ raise RuntimeError(f"Unknown alg: {alg}")
591
 
592
+ block_mask_1 = block_mask[mask_index[0]]
593
+ confidence = confidence + torch.where(block_mask_1, 0, -torch.inf).to(confidence.device)
 
594
 
595
+ num_mask_token = mask_index.sum()
596
+ num_mask_token = (x[:, block_mask] == mask_token_id).sum()
 
597
  number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else num_mask_token
598
+ number_transfer_tokens = max(number_transfer_tokens, 1)
599
 
600
+ if number_transfer_tokens > 0:
601
+ if alg_temp is None or alg_temp == 0:
602
+ _, transfer_index = torch.topk(confidence, number_transfer_tokens)
603
+ else:
604
+ confidence = confidence / alg_temp
605
+ confidence = F.softmax(confidence, dim=-1)
606
+ transfer_index = torch.multinomial(confidence, num_samples=number_transfer_tokens)
607
+
608
+ x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_token_id
609
+ x0_[transfer_index] = x0[transfer_index].clone()
610
+ x[mask_index] = x0_
611
+
612
+ logit,indic = torch.max(torch.softmax(logits.clone(),dim=-1),-1)
613
+ logit = logit[0][x[0]!=0]
614
+ indic = indic[0][x[0]!=0]
615
+ temp_X = x[0][x[0]!=0]
616
+
617
+ x = generation_tokens_hook_func(i, x, logits)
618
+
619
+ if histories is not None:
620
+ histories.append(x.clone())
621
+ all_logit.append(torch.max(logits.clone(),-1)[-1])
622
+
623
+ return (x, histories)