autoprogrammer commited on
Commit
631ce9b
·
verified ·
1 Parent(s): 32f22f3

Update generation_utils.py

Browse files
Files changed (1) hide show
  1. generation_utils.py +122 -137
generation_utils.py CHANGED
@@ -1,19 +1,5 @@
1
  # coding=utf-8
2
- # Copyright 2024 The Dream team, HKUNLP Group and the
3
- # HuggingFace Inc. team. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # You may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
  import warnings
18
  import copy
19
  from dataclasses import dataclass
@@ -33,10 +19,8 @@ def top_p_logits(logits, top_p=None):
33
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
34
  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
35
  sorted_indices_to_remove = cumulative_probs > top_p
36
- # Shift the indices to the right to keep the first token above the threshold
37
  sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
38
  sorted_indices_to_remove[..., 0] = 0
39
-
40
  mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
41
  mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
42
  logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
@@ -44,20 +28,27 @@ def top_p_logits(logits, top_p=None):
44
 
45
 
46
  def top_k_logits(logits, top_k=None):
47
- top_k = min(top_k, logits.size(-1)) # Safety check
48
- # Remove all tokens with a probability less than the last token of the top-k
49
  indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
50
  logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
51
  return logits
52
 
53
 
54
- def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
 
 
 
 
 
 
 
55
  if temperature > 0:
56
  logits = logits / temperature
57
  if top_p is not None and top_p < 1:
58
  logits = top_p_logits(logits, top_p)
59
  if top_k is not None:
60
  logits = top_k_logits(logits, top_k)
 
61
  probs = torch.softmax(logits, dim=-1)
62
 
63
  if temperature > 0:
@@ -76,10 +67,10 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confid
76
  confidence = top1_probs - top2_probs
77
 
78
  if neg_entropy:
 
79
  epsilon = 1e-10
80
  log_probs = torch.log(probs + epsilon)
81
- # 改动 1:用“负熵”的正值(越大越自信),与其它置信度方向保持一致
82
- confidence = -(probs * log_probs).sum(dim=-1)
83
 
84
  return confidence, x0
85
 
@@ -97,31 +88,29 @@ class DreamGenerationConfig(GenerationConfig):
97
  self.top_k: Optional[int] = kwargs.pop("top_k", None)
98
  self.max_length = kwargs.pop("max_length", 20)
99
  self.max_new_tokens = kwargs.pop("max_new_tokens", None)
100
- # diffusion specific params
101
  self.eps: float = kwargs.pop("eps", 1e-3)
102
  self.steps: int = kwargs.pop("steps", 512)
103
- self.alg: str = kwargs.pop("alg", 'origin')
104
  self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
105
 
106
- # === RCR 相关参数(默认不影响原逻辑) ===
107
  self.rcr: bool = kwargs.pop("rcr", False)
108
- self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')
109
 
110
- # Parameters that define the output variables of `generate`
111
  self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
112
  self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
113
  self.output_history: bool = kwargs.pop("output_history", False)
114
 
115
- # Special tokens that can be used at generation time
116
  self.mask_token_id = kwargs.pop("mask_token_id", None)
117
  self.pad_token_id = kwargs.pop("pad_token_id", None)
118
  self.bos_token_id = kwargs.pop("bos_token_id", None)
119
  self.eos_token_id = kwargs.pop("eos_token_id", None)
120
 
121
- # Wild card
122
  self.generation_kwargs = kwargs.pop("generation_kwargs", {})
123
 
124
- # hub interface
125
  self._from_model_config = kwargs.pop("_from_model_config", False)
126
  self._commit_hash = kwargs.pop("_commit_hash", None)
127
  self.transformers_version = kwargs.pop("transformers_version", __version__)
@@ -145,7 +134,7 @@ class DreamGenerationMixin:
145
  def _expand_inputs_for_generation(
146
  expand_size: int = 1,
147
  input_ids: Optional[torch.LongTensor] = None,
148
- attention_mask: Optional[torch.LongTensor] = None
149
  ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
150
  if expand_size == 1:
151
  return input_ids, attention_mask
@@ -155,56 +144,50 @@ class DreamGenerationMixin:
155
  attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
156
  return input_ids, attention_mask
157
 
158
- # === RCR:仅在 rcr=True 时调用;不改动 baseline 分支 ===
159
- def _apply_rcr_logic(self, x, x0, confidence, mask_index, overtime_confidence,
160
- mask_token_id, step, total_steps, s, t):
 
 
 
 
 
 
 
 
 
161
  """
162
- Running Confidence Remasking:
163
- - 采用 Dream 的调度:k_step = num_mask_token * (1 - s/t)
164
- - 本步先按置信度从 [MASK] 中挑 top-k_step 写入预测,并把置信度累计到 overtime_confidence
165
- - 再施加“累计目标”约束:target_cum = num_mask_token * (1 - s/t)
166
- 若当前累计 > 目标,则把最低置信度的 token 反遮回 [MASK]
167
  """
168
- device = x.device
169
- B = x.shape[0]
170
-
171
- # Dream 一致的“批均值”口径
172
- num_mask_token = mask_index.sum() / mask_index.shape[0]
173
- number_transfer_tokens = int(num_mask_token * (1 - s / t)) if step < total_steps - 1 else int(num_mask_token)
174
 
175
- # 构造全长置信度和候选(非 mask 置 -inf / mask_token)
176
- full_conf = torch.full_like(x, -torch.inf, device=device, dtype=confidence.dtype)
177
- x_temp = torch.zeros_like(x, device=device, dtype=torch.long) + mask_token_id
178
- full_conf[mask_index] = confidence
179
- x_temp[mask_index] = x0.clone()
180
 
181
  for j in range(B):
182
- masked_j = int(mask_index[j].sum().item())
183
- k_j = min(number_transfer_tokens, masked_j)
184
-
185
- # 先选本步 top-k_j
186
- if k_j > 0:
187
- _, select_idx = torch.topk(full_conf[j], k=k_j, largest=True)
188
- x[j, select_idx] = x_temp[j, select_idx]
189
- overtime_confidence[j, select_idx] = full_conf[j, select_idx].clone().float()
190
-
191
- # 累计目标(与 baseline 对齐)
192
- if step < total_steps - 1:
193
- target_cum = int(num_mask_token * (1 - s / t))
194
- # 改动 2:用有限性判断“已生成”,而不是 > 0
195
- gen_mask = torch.isfinite(overtime_confidence[j])
196
- current_gen = int(gen_mask.sum().item())
197
-
198
- to_remask = max(0, current_gen - target_cum)
199
- if to_remask > 0:
200
- gen_indices = torch.where(gen_mask)[0]
201
- if gen_indices.numel() > 0:
202
- gen_conf = overtime_confidence[j, gen_indices]
203
- to_remask = min(to_remask, int(gen_indices.numel()))
204
- _, local_low = torch.topk(gen_conf, k=to_remask, largest=False)
205
- low_global = gen_indices[local_low]
206
- x[j, low_global] = mask_token_id
207
- overtime_confidence[j, low_global] = float("-inf")
208
 
209
  def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
210
  if is_torchdynamo_compiling():
@@ -229,12 +212,9 @@ class DreamGenerationMixin:
229
  if not has_default_max_length and generation_config.max_length is not None:
230
  logger.warning(
231
  f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
232
- f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
233
- "Please refer to the documentation for more information. "
234
- "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
235
  )
236
  generation_config.max_length = generation_config.max_new_tokens + input_ids_length
237
-
238
  elif has_default_max_length:
239
  if generation_config.max_length == DreamGenerationConfig().max_length:
240
  generation_config.max_length = generation_config.max_length + input_ids_length
@@ -261,7 +241,6 @@ class DreamGenerationMixin:
261
  generation_config.pad_token_id = self.generation_config.pad_token_id
262
  if generation_config.mask_token_id is None:
263
  generation_config.mask_token_id = self.generation_config.mask_token_id
264
-
265
  return generation_config
266
 
267
  def _prepare_special_tokens(self, generation_config: DreamGenerationConfig, device: Optional[Union[torch.device, str]] = None):
@@ -314,18 +293,13 @@ class DreamGenerationMixin:
314
  has_default_max_length=has_default_max_length,
315
  input_ids_length=input_ids_length,
316
  )
317
-
318
  self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
319
 
320
  if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
321
  warnings.warn(
322
  "You are calling .generate() with the `input_ids` being on a device type different"
323
  f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
324
- f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
325
- " Please make sure that you have put `input_ids` to the"
326
- f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
327
- " running `.generate()`.",
328
- UserWarning,
329
  )
330
  if (
331
  hasattr(generation_config, "pad_token_id")
@@ -333,9 +307,7 @@ class DreamGenerationMixin:
333
  and attention_mask is None
334
  ):
335
  warnings.warn(
336
- "Padding was detected but no attention mask is passed here. For correct "
337
- "generation results, please set `attention_mask` when batch-padding inputs.",
338
- UserWarning,
339
  )
340
 
341
  input_ids, attention_mask = self._expand_inputs_for_generation(
@@ -359,9 +331,8 @@ class DreamGenerationMixin:
359
  attention_mask: Optional[torch.LongTensor],
360
  generation_config: DreamGenerationConfig,
361
  generation_tokens_hook_func,
362
- generation_logits_hook_func
363
  ) -> Union[DreamModelOutput, torch.LongTensor]:
364
- # ---- 原参数 ----
365
  output_history = generation_config.output_history
366
  return_dict_in_generate = generation_config.return_dict_in_generate
367
  max_length = generation_config.max_length
@@ -374,13 +345,12 @@ class DreamGenerationMixin:
374
  top_p = generation_config.top_p
375
  top_k = generation_config.top_k
376
 
377
- # ---- RCR 参数 ----
378
  rcr = generation_config.rcr
379
  conf_alg = generation_config.conf_alg
380
 
381
  histories = [] if (return_dict_in_generate and output_history) else None
382
 
383
- # pad input_ids to max_length
384
  x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
385
 
386
  if attention_mask is not None and torch.any(attention_mask == 0.0):
@@ -397,54 +367,57 @@ class DreamGenerationMixin:
397
 
398
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
399
 
400
- # 改动 2:仅在 rcr=True 时,用 -inf 初始化,后续用 isfinite 判断
401
- overtime_confidence = torch.full_like(x, float("-inf"), dtype=torch.float32) if rcr else None
 
402
 
403
- # this allows user-defined token control of the intermediate steps
404
  x = generation_tokens_hook_func(None, x, None)
405
  for i in range(steps):
406
  mask_index = (x == mask_token_id)
407
  logits = self(x, attention_mask, tok_idx).logits
408
  logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
409
 
410
- # this allows user-defined logits control of the intermediate steps
411
  logits = generation_logits_hook_func(i, x, logits)
412
 
413
  mask_logits = logits[mask_index]
414
  t = timesteps[i]
415
  s = timesteps[i + 1]
416
 
417
- if alg == 'origin':
418
- # 原版 origin 分支:保持不变
419
  p_transfer = 1 - s / t if i < steps - 1 else 1
420
  x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
421
  transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
422
  _, x0[transfer_index_t_s] = sample_tokens(
423
- mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k
 
 
 
424
  )
425
  x[mask_index] = x0.clone()
 
 
426
  else:
427
- # rcr=False:沿用 alg 指定的置信度算法
428
- # rcr=True :使用 conf_alg 指定的置信度算法(不改变 rcr=False 的行为)
429
- if (not rcr and alg == 'maskgit_plus') or (rcr and conf_alg == 'maskgit_plus'):
430
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
431
- elif (not rcr and alg == 'topk_margin') or (rcr and conf_alg == 'topk_margin'):
432
  confidence, x0 = sample_tokens(
433
  mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True
434
  )
435
- elif (not rcr and alg == 'entropy') or (rcr and conf_alg == 'entropy'):
436
  confidence, x0 = sample_tokens(
437
  mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True
438
  )
439
  else:
440
  if rcr:
441
- if alg == 'maskgit_plus':
442
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
443
- elif alg == 'topk_margin':
444
  confidence, x0 = sample_tokens(
445
  mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True
446
  )
447
- elif alg == 'entropy':
448
  confidence, x0 = sample_tokens(
449
  mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True
450
  )
@@ -453,41 +426,53 @@ class DreamGenerationMixin:
453
  else:
454
  raise RuntimeError(f"Unknown alg: {alg}")
455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
  if rcr:
457
- # 仅在 rcr=True:应用 RCR
458
- print("[RCR] step", i)
459
- self._apply_rcr_logic(
460
- x, x0, confidence, mask_index, overtime_confidence,
461
- mask_token_id, i, steps, s, t
 
 
 
 
 
462
  )
463
- else:
464
- # 原版 Dream 逻辑:保持不变
465
- num_mask_token = mask_index.sum() / mask_index.shape[0]
466
- number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
467
- full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
468
- full_confidence[mask_index] = confidence
469
- if number_transfer_tokens > 0:
470
- if alg_temp is None or alg_temp == 0:
471
- _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
472
- else:
473
- full_confidence = full_confidence / alg_temp
474
- full_confidence = F.softmax(full_confidence, dim=-1)
475
- transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
476
- x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
477
- x_[mask_index] = x0.clone()
478
- row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
479
- x[row_indices, transfer_index] = x_[row_indices, transfer_index]
480
-
481
- # this allows user-defined token control of the intermediate steps
482
  x = generation_tokens_hook_func(i, x, logits)
483
 
484
  if histories is not None:
485
  histories.append(x.clone())
486
 
487
  if return_dict_in_generate:
488
- return DreamModelOutput(
489
- sequences=x,
490
- history=histories,
491
- )
492
  else:
493
  return x
 
1
  # coding=utf-8
2
+ # Copyright ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import warnings
4
  import copy
5
  from dataclasses import dataclass
 
19
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
20
  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
21
  sorted_indices_to_remove = cumulative_probs > top_p
 
22
  sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
23
  sorted_indices_to_remove[..., 0] = 0
 
24
  mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
25
  mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
26
  logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
 
28
 
29
 
30
  def top_k_logits(logits, top_k=None):
31
+ top_k = min(top_k, logits.size(-1))
 
32
  indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
33
  logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
34
  return logits
35
 
36
 
37
+ def sample_tokens(
38
+ logits,
39
+ temperature=0.0,
40
+ top_p=None,
41
+ top_k=None,
42
+ margin_confidence=False,
43
+ neg_entropy=False,
44
+ ):
45
  if temperature > 0:
46
  logits = logits / temperature
47
  if top_p is not None and top_p < 1:
48
  logits = top_p_logits(logits, top_p)
49
  if top_k is not None:
50
  logits = top_k_logits(logits, top_k)
51
+
52
  probs = torch.softmax(logits, dim=-1)
53
 
54
  if temperature > 0:
 
67
  confidence = top1_probs - top2_probs
68
 
69
  if neg_entropy:
70
+ # 保持你原来的“熵”定义(注意它是负值;不改符号,避免影响 baseline)
71
  epsilon = 1e-10
72
  log_probs = torch.log(probs + epsilon)
73
+ confidence = torch.sum(probs * log_probs, dim=-1)
 
74
 
75
  return confidence, x0
76
 
 
88
  self.top_k: Optional[int] = kwargs.pop("top_k", None)
89
  self.max_length = kwargs.pop("max_length", 20)
90
  self.max_new_tokens = kwargs.pop("max_new_tokens", None)
91
+ # diffusion specific
92
  self.eps: float = kwargs.pop("eps", 1e-3)
93
  self.steps: int = kwargs.pop("steps", 512)
94
+ self.alg: str = kwargs.pop("alg", "origin")
95
  self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
96
 
97
+ # RCR:默认关闭;开启后只做“选后回遮”,不动 baseline 行为
98
  self.rcr: bool = kwargs.pop("rcr", False)
99
+ self.conf_alg: str = kwargs.pop("conf_alg", "maskgit_plus")
100
 
101
+ # generate 输出控制
102
  self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
103
  self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
104
  self.output_history: bool = kwargs.pop("output_history", False)
105
 
106
+ # special tokens
107
  self.mask_token_id = kwargs.pop("mask_token_id", None)
108
  self.pad_token_id = kwargs.pop("pad_token_id", None)
109
  self.bos_token_id = kwargs.pop("bos_token_id", None)
110
  self.eos_token_id = kwargs.pop("eos_token_id", None)
111
 
 
112
  self.generation_kwargs = kwargs.pop("generation_kwargs", {})
113
 
 
114
  self._from_model_config = kwargs.pop("_from_model_config", False)
115
  self._commit_hash = kwargs.pop("_commit_hash", None)
116
  self.transformers_version = kwargs.pop("transformers_version", __version__)
 
134
  def _expand_inputs_for_generation(
135
  expand_size: int = 1,
136
  input_ids: Optional[torch.LongTensor] = None,
137
+ attention_mask: Optional[torch.LongTensor] = None,
138
  ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
139
  if expand_size == 1:
140
  return input_ids, attention_mask
 
144
  attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
145
  return input_ids, attention_mask
146
 
147
+ # rcr=True 使用;不改变 baseline 的选入逻辑
148
+ def _rcr_remask_after_selection(
149
+ self,
150
+ x, # [B, L] 当前序列
151
+ mask_token_id: int,
152
+ step: int,
153
+ steps: int,
154
+ s: torch.Tensor,
155
+ t: torch.Tensor,
156
+ is_fixed: torch.Tensor, # [B, L] bool,已“确定”的位置
157
+ fixed_conf: torch.Tensor # [B, L] float,已确定位置的置信度(其余为 -inf)
158
+ ):
159
  """
160
+ 在已经“按 baseline 完成选入”之后,按累计目标回遮最低置信度的超额 token。
161
+ —— 极小侵入:不改变 baseline 的挑选,只在其后做回遮。
 
 
 
162
  """
163
+ B, L = x.shape
164
+ # 计算“批均值语义”的 num_mask_token(与 baseline 保持一致)
165
+ # 注意这里基于当前 x 的 [MASK] 数量计算
166
+ mask_index = (x == mask_token_id)
167
+ num_mask_token = (mask_index.sum() / mask_index.shape[0]).item()
 
168
 
169
+ # Dream 原调度:到本步为止应累计确定的目标总量
170
+ target_cum = int(num_mask_token * (1 - (s / t).item())) if step < steps - 1 else int(num_mask_token)
 
 
 
171
 
172
  for j in range(B):
173
+ # 当前累计(已确定)的数量
174
+ fixed_j = is_fixed[j]
175
+ current_gen = int(fixed_j.sum().item())
176
+ # 如果超额,回遮最低置信度的那部分
177
+ to_remask = max(0, current_gen - target_cum)
178
+ if to_remask > 0:
179
+ cand_idx = torch.where(fixed_j)[0]
180
+ if cand_idx.numel() == 0:
181
+ continue
182
+ conf_vals = fixed_conf[j, cand_idx]
183
+ # 取最小的 to_remask
184
+ k = min(to_remask, int(cand_idx.numel()))
185
+ _, local_low = torch.topk(conf_vals, k=k, largest=False)
186
+ low_global = cand_idx[local_low]
187
+ # 打回 [MASK],并清空标记
188
+ x[j, low_global] = mask_token_id
189
+ is_fixed[j, low_global] = False
190
+ fixed_conf[j, low_global] = float("-inf")
 
 
 
 
 
 
 
 
191
 
192
  def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
193
  if is_torchdynamo_compiling():
 
212
  if not has_default_max_length and generation_config.max_length is not None:
213
  logger.warning(
214
  f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
215
+ f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence."
 
 
216
  )
217
  generation_config.max_length = generation_config.max_new_tokens + input_ids_length
 
218
  elif has_default_max_length:
219
  if generation_config.max_length == DreamGenerationConfig().max_length:
220
  generation_config.max_length = generation_config.max_length + input_ids_length
 
241
  generation_config.pad_token_id = self.generation_config.pad_token_id
242
  if generation_config.mask_token_id is None:
243
  generation_config.mask_token_id = self.generation_config.mask_token_id
 
244
  return generation_config
245
 
246
  def _prepare_special_tokens(self, generation_config: DreamGenerationConfig, device: Optional[Union[torch.device, str]] = None):
 
293
  has_default_max_length=has_default_max_length,
294
  input_ids_length=input_ids_length,
295
  )
 
296
  self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
297
 
298
  if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
299
  warnings.warn(
300
  "You are calling .generate() with the `input_ids` being on a device type different"
301
  f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
302
+ f" is on {self.device.type}."
 
 
 
 
303
  )
304
  if (
305
  hasattr(generation_config, "pad_token_id")
 
307
  and attention_mask is None
308
  ):
309
  warnings.warn(
310
+ "Padding was detected but no attention mask is passed here. For correct results, please set `attention_mask`."
 
 
311
  )
312
 
313
  input_ids, attention_mask = self._expand_inputs_for_generation(
 
331
  attention_mask: Optional[torch.LongTensor],
332
  generation_config: DreamGenerationConfig,
333
  generation_tokens_hook_func,
334
+ generation_logits_hook_func,
335
  ) -> Union[DreamModelOutput, torch.LongTensor]:
 
336
  output_history = generation_config.output_history
337
  return_dict_in_generate = generation_config.return_dict_in_generate
338
  max_length = generation_config.max_length
 
345
  top_p = generation_config.top_p
346
  top_k = generation_config.top_k
347
 
 
348
  rcr = generation_config.rcr
349
  conf_alg = generation_config.conf_alg
350
 
351
  histories = [] if (return_dict_in_generate and output_history) else None
352
 
353
+ # pad to max_length
354
  x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
355
 
356
  if attention_mask is not None and torch.any(attention_mask == 0.0):
 
367
 
368
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
369
 
370
+ # rcr=True:引入轻量跟踪,不影响 baseline
371
+ is_fixed = torch.zeros_like(x, dtype=torch.bool) if rcr else None
372
+ fixed_conf = torch.full_like(x, float("-inf")) if rcr else None # 存放已确定位置的置信度
373
 
 
374
  x = generation_tokens_hook_func(None, x, None)
375
  for i in range(steps):
376
  mask_index = (x == mask_token_id)
377
  logits = self(x, attention_mask, tok_idx).logits
378
  logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
379
 
 
380
  logits = generation_logits_hook_func(i, x, logits)
381
 
382
  mask_logits = logits[mask_index]
383
  t = timesteps[i]
384
  s = timesteps[i + 1]
385
 
386
+ if alg == "origin":
387
+ # 完全保持原始
388
  p_transfer = 1 - s / t if i < steps - 1 else 1
389
  x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
390
  transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
391
  _, x0[transfer_index_t_s] = sample_tokens(
392
+ mask_logits[transfer_index_t_s],
393
+ temperature=temperature,
394
+ top_p=top_p,
395
+ top_k=top_k,
396
  )
397
  x[mask_index] = x0.clone()
398
+
399
+ # origin 分支不做 RCR(与原版一致)
400
  else:
401
+ # 置信度算法:rcr=False alg;rcr=True 用 conf_alg(与之前一致)
402
+ if (not rcr and alg == "maskgit_plus") or (rcr and conf_alg == "maskgit_plus"):
 
403
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
404
+ elif (not rcr and alg == "topk_margin") or (rcr and conf_alg == "topk_margin"):
405
  confidence, x0 = sample_tokens(
406
  mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True
407
  )
408
+ elif (not rcr and alg == "entropy") or (rcr and conf_alg == "entropy"):
409
  confidence, x0 = sample_tokens(
410
  mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True
411
  )
412
  else:
413
  if rcr:
414
+ if alg == "maskgit_plus":
415
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
416
+ elif alg == "topk_margin":
417
  confidence, x0 = sample_tokens(
418
  mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True
419
  )
420
+ elif alg == "entropy":
421
  confidence, x0 = sample_tokens(
422
  mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True
423
  )
 
426
  else:
427
  raise RuntimeError(f"Unknown alg: {alg}")
428
 
429
+ # ===== baseline 的“选入”逻辑:原样保留 =====
430
+ num_mask_token = mask_index.sum() / mask_index.shape[0]
431
+ number_transfer_tokens = (
432
+ int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
433
+ )
434
+ full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
435
+ full_confidence[mask_index] = confidence
436
+
437
+ if number_transfer_tokens > 0:
438
+ if alg_temp is None or alg_temp == 0:
439
+ _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
440
+ else:
441
+ full_confidence = full_confidence / alg_temp
442
+ full_confidence = F.softmax(full_confidence, dim=-1)
443
+ transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
444
+
445
+ x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
446
+ x_[mask_index] = x0.clone()
447
+ row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
448
+ x[row_indices, transfer_index] = x_[row_indices, transfer_index]
449
+
450
+ # ===== 仅 rcr=True:记录“已确定”的位置与它们的置信度(用于后续回遮)=====
451
+ if rcr:
452
+ is_fixed[row_indices, transfer_index] = True
453
+ # 注意:这里存 baseline 使用的 full_confidence(与 baseline 完全一致)
454
+ fixed_conf[row_indices, transfer_index] = full_confidence[row_indices, transfer_index]
455
+
456
+ # ===== 仅 rcr=True:在“选入”之后按累计目标回遮最低置信度的超额部分 =====
457
  if rcr:
458
+ # 这一步只回遮,完全不改变 baseline 的选入行为
459
+ self._rcr_remask_after_selection(
460
+ x=x,
461
+ mask_token_id=mask_token_id,
462
+ step=i,
463
+ steps=steps,
464
+ s=s,
465
+ t=t,
466
+ is_fixed=is_fixed,
467
+ fixed_conf=fixed_conf,
468
  )
469
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
  x = generation_tokens_hook_func(i, x, logits)
471
 
472
  if histories is not None:
473
  histories.append(x.clone())
474
 
475
  if return_dict_in_generate:
476
+ return DreamModelOutput(sequences=x, history=histories)
 
 
 
477
  else:
478
  return x