autoprogrammer committed on
Commit
34c9b0b
·
verified ·
1 Parent(s): 0d07b88

Update generation_utils.py

Browse files
Files changed (1) hide show
  1. generation_utils.py +160 -130
generation_utils.py CHANGED
@@ -1,5 +1,7 @@
1
  # coding=utf-8
2
- # Copyright ...
 
 
3
  import warnings
4
  import copy
5
  from dataclasses import dataclass
@@ -16,33 +18,38 @@ logger = logging.get_logger(__name__)
16
 
17
 
18
  def top_p_logits(logits, top_p=None):
 
 
19
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
20
  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
21
  sorted_indices_to_remove = cumulative_probs > top_p
 
22
  sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
23
  sorted_indices_to_remove[..., 0] = 0
24
  mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
25
  mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
26
- logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
27
- return logits
28
 
29
 
30
  def top_k_logits(logits, top_k=None):
 
 
31
  top_k = min(top_k, logits.size(-1))
32
- indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
33
- logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
34
- return logits
35
 
36
 
37
  def sample_tokens(
38
  logits,
39
- temperature=0.0,
40
- top_p=None,
41
- top_k=None,
42
- margin_confidence=False,
43
- neg_entropy=False,
44
  ):
45
- if temperature > 0:
 
46
  logits = logits / temperature
47
  if top_p is not None and top_p < 1:
48
  logits = top_p_logits(logits, top_p)
@@ -51,25 +58,27 @@ def sample_tokens(
51
 
52
  probs = torch.softmax(logits, dim=-1)
53
 
54
- if temperature > 0:
 
55
  try:
56
  x0 = dists.Categorical(probs=probs).sample()
57
  confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
58
  except Exception:
59
  confidence, x0 = probs.max(dim=-1)
60
  else:
 
61
  confidence, x0 = probs.max(dim=-1)
62
 
63
  if margin_confidence:
64
  sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
65
- top1_probs = sorted_probs[:, 0]
66
- top2_probs = sorted_probs[:, 1]
67
  confidence = top1_probs - top2_probs
68
 
69
  if neg_entropy:
70
- # 保持你原来的“熵”定义(注意它是负值;不改符号,避免影响 baseline)
71
- epsilon = 1e-10
72
- log_probs = torch.log(probs + epsilon)
73
  confidence = torch.sum(probs * log_probs, dim=-1)
74
 
75
  return confidence, x0
@@ -88,13 +97,14 @@ class DreamGenerationConfig(GenerationConfig):
88
  self.top_k: Optional[int] = kwargs.pop("top_k", None)
89
  self.max_length = kwargs.pop("max_length", 20)
90
  self.max_new_tokens = kwargs.pop("max_new_tokens", None)
91
- # diffusion specific
 
92
  self.eps: float = kwargs.pop("eps", 1e-3)
93
  self.steps: int = kwargs.pop("steps", 512)
94
  self.alg: str = kwargs.pop("alg", "origin")
95
  self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
96
 
97
- # RCR:默认关闭;开启后只做“选后回遮”,不动 baseline 行为
98
  self.rcr: bool = kwargs.pop("rcr", False)
99
  self.conf_alg: str = kwargs.pop("conf_alg", "maskgit_plus")
100
 
@@ -111,6 +121,7 @@ class DreamGenerationConfig(GenerationConfig):
111
 
112
  self.generation_kwargs = kwargs.pop("generation_kwargs", {})
113
 
 
114
  self._from_model_config = kwargs.pop("_from_model_config", False)
115
  self._commit_hash = kwargs.pop("_commit_hash", None)
116
  self.transformers_version = kwargs.pop("transformers_version", __version__)
@@ -126,6 +137,7 @@ class DreamGenerationConfig(GenerationConfig):
126
  self.validate(is_init=True)
127
 
128
  def validate(self, is_init=False):
 
129
  pass
130
 
131
 
@@ -144,50 +156,70 @@ class DreamGenerationMixin:
144
  attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
145
  return input_ids, attention_mask
146
 
147
- # 仅 rcr=True 使用;不改变 baseline 的选入逻辑
148
- def _rcr_remask_after_selection(
149
  self,
150
- x, # [B, L] 当前序列
 
 
 
 
151
  mask_token_id: int,
152
  step: int,
153
- steps: int,
154
  s: torch.Tensor,
155
  t: torch.Tensor,
156
- is_fixed: torch.Tensor, # [B, L] bool,已“确定”的位置
157
- fixed_conf: torch.Tensor # [B, L] float,已确定位置的置信度(其余为 -inf)
158
  ):
159
  """
160
- 在已经“按 baseline 完成选入”之后,按累计目标回遮最低置信度的超额 token。
161
- —— 极小侵入:不改变 baseline 的挑选,只在其后做回遮。
 
 
 
162
  """
163
- B, L = x.shape
164
- # 计算“批均值语义”的 num_mask_token(与 baseline 保持一致)
165
- # 注意这里基于当前 x[MASK] 数量计算
166
- mask_index = (x == mask_token_id)
167
- num_mask_token = (mask_index.sum() / mask_index.shape[0]).item()
 
 
 
168
 
169
- # Dream 原调度:到本步为止应累计确定的目标总量
170
- target_cum = int(num_mask_token * (1 - (s / t).item())) if step < steps - 1 else int(num_mask_token)
 
 
 
171
 
172
  for j in range(B):
173
- # 当前累计(已确定)的数量
174
- fixed_j = is_fixed[j]
175
- current_gen = int(fixed_j.sum().item())
176
- # 如果超额,回遮最低置信度的那部分
177
- to_remask = max(0, current_gen - target_cum)
178
- if to_remask > 0:
179
- cand_idx = torch.where(fixed_j)[0]
180
- if cand_idx.numel() == 0:
181
- continue
182
- conf_vals = fixed_conf[j, cand_idx]
183
- # 取最小的 to_remask
184
- k = min(to_remask, int(cand_idx.numel()))
185
- _, local_low = torch.topk(conf_vals, k=k, largest=False)
186
- low_global = cand_idx[local_low]
187
- # 打回 [MASK],并清空标记
188
- x[j, low_global] = mask_token_id
189
- is_fixed[j, low_global] = False
190
- fixed_conf[j, low_global] = float("-inf")
 
 
 
 
 
 
 
 
 
 
191
 
192
  def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
193
  if is_torchdynamo_compiling():
@@ -200,11 +232,9 @@ class DreamGenerationMixin:
200
  UserWarning,
201
  )
202
  if input_ids_length >= generation_config.max_length:
203
- input_ids_string = "input_ids"
204
  raise ValueError(
205
- f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
206
- f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
207
- " increasing `max_length` or, better yet, setting `max_new_tokens`."
208
  )
209
 
210
  def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
@@ -218,12 +248,14 @@ class DreamGenerationMixin:
218
  elif has_default_max_length:
219
  if generation_config.max_length == DreamGenerationConfig().max_length:
220
  generation_config.max_length = generation_config.max_length + input_ids_length
221
- max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
222
- if max_position_embeddings is not None:
223
- generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
224
  return generation_config
225
 
226
- def _prepare_generation_config(self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict) -> DreamGenerationConfig:
 
 
227
  using_model_generation_config = False
228
  if generation_config is None:
229
  generation_config = DreamGenerationConfig.from_model_config(self.config)
@@ -231,7 +263,7 @@ class DreamGenerationMixin:
231
 
232
  if not is_torchdynamo_compiling():
233
  generation_config = copy.deepcopy(generation_config)
234
- _kwargs = generation_config.update(**kwargs)
235
  if not using_model_generation_config:
236
  if generation_config.bos_token_id is None:
237
  generation_config.bos_token_id = self.generation_config.bos_token_id
@@ -243,7 +275,9 @@ class DreamGenerationMixin:
243
  generation_config.mask_token_id = self.generation_config.mask_token_id
244
  return generation_config
245
 
246
- def _prepare_special_tokens(self, generation_config: DreamGenerationConfig, device: Optional[Union[torch.device, str]] = None):
 
 
247
  def _tensor_or_none(token, device=None):
248
  if token is None:
249
  return token
@@ -293,21 +327,24 @@ class DreamGenerationMixin:
293
  has_default_max_length=has_default_max_length,
294
  input_ids_length=input_ids_length,
295
  )
 
296
  self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
297
 
298
  if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
299
  warnings.warn(
300
- "You are calling .generate() with the `input_ids` being on a device type different"
301
- f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
302
- f" is on {self.device.type}."
303
  )
 
304
  if (
305
  hasattr(generation_config, "pad_token_id")
306
  and torch.any(input_ids == generation_config.pad_token_id)
307
  and attention_mask is None
308
  ):
309
  warnings.warn(
310
- "Padding was detected but no attention mask is passed here. For correct results, please set `attention_mask`."
 
311
  )
312
 
313
  input_ids, attention_mask = self._expand_inputs_for_generation(
@@ -333,6 +370,7 @@ class DreamGenerationMixin:
333
  generation_tokens_hook_func,
334
  generation_logits_hook_func,
335
  ) -> Union[DreamModelOutput, torch.LongTensor]:
 
336
  output_history = generation_config.output_history
337
  return_dict_in_generate = generation_config.return_dict_in_generate
338
  max_length = generation_config.max_length
@@ -345,12 +383,13 @@ class DreamGenerationMixin:
345
  top_p = generation_config.top_p
346
  top_k = generation_config.top_k
347
 
 
348
  rcr = generation_config.rcr
349
  conf_alg = generation_config.conf_alg
350
 
351
  histories = [] if (return_dict_in_generate and output_history) else None
352
 
353
- # pad to max_length
354
  x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
355
 
356
  if attention_mask is not None and torch.any(attention_mask == 0.0):
@@ -367,107 +406,98 @@ class DreamGenerationMixin:
367
 
368
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
369
 
370
- # rcr=True:引入轻量跟踪,不影响 baseline
371
- is_fixed = torch.zeros_like(x, dtype=torch.bool) if rcr else None
372
- fixed_conf = torch.full(x.shape, float("-inf"), device=x.device, dtype=torch.float32) if rcr else None
373
- # 存放已确定位置的置信度
374
 
 
375
  x = generation_tokens_hook_func(None, x, None)
 
376
  for i in range(steps):
377
  mask_index = (x == mask_token_id)
 
378
  logits = self(x, attention_mask, tok_idx).logits
379
  logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
380
 
 
381
  logits = generation_logits_hook_func(i, x, logits)
382
 
383
  mask_logits = logits[mask_index]
384
  t = timesteps[i]
385
  s = timesteps[i + 1]
386
 
 
 
 
 
387
  if alg == "origin":
388
- # 完全保持原始
389
  p_transfer = 1 - s / t if i < steps - 1 else 1
390
- x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
391
  transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
392
  _, x0[transfer_index_t_s] = sample_tokens(
393
- mask_logits[transfer_index_t_s],
394
- temperature=temperature,
395
- top_p=top_p,
396
- top_k=top_k,
397
  )
398
  x[mask_index] = x0.clone()
399
 
400
- # origin 分支不做 RCR(与原版一致)
401
  else:
402
- # 置信度算法:rcr=False 用 alg;rcr=True 用 conf_alg(与之前一致)
403
- if (not rcr and alg == "maskgit_plus") or (rcr and conf_alg == "maskgit_plus"):
 
 
 
 
 
404
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
405
- elif (not rcr and alg == "topk_margin") or (rcr and conf_alg == "topk_margin"):
406
  confidence, x0 = sample_tokens(
407
  mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True
408
  )
409
- elif (not rcr and alg == "entropy") or (rcr and conf_alg == "entropy"):
410
  confidence, x0 = sample_tokens(
411
  mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True
412
  )
413
  else:
414
- if rcr:
415
- if alg == "maskgit_plus":
416
- confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
417
- elif alg == "topk_margin":
418
- confidence, x0 = sample_tokens(
419
- mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True
420
- )
421
- elif alg == "entropy":
422
- confidence, x0 = sample_tokens(
423
- mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True
424
- )
425
- else:
426
- raise RuntimeError(f"Unknown alg: {alg}")
427
- else:
428
- raise RuntimeError(f"Unknown alg: {alg}")
429
-
430
- # ===== baseline 的“选入”逻辑:原样保留 =====
431
- num_mask_token = mask_index.sum() / mask_index.shape[0]
432
- number_transfer_tokens = (
433
- int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
434
  )
435
- full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
436
  full_confidence[mask_index] = confidence
437
 
438
- if number_transfer_tokens > 0:
439
- if alg_temp is None or alg_temp == 0:
440
- _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
441
- else:
442
- full_confidence = full_confidence / alg_temp
443
- full_confidence = F.softmax(full_confidence, dim=-1)
444
- transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
445
-
446
- x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
447
- x_[mask_index] = x0.clone()
448
- row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
449
- x[row_indices, transfer_index] = x_[row_indices, transfer_index]
450
-
451
- # ===== 仅 rcr=True:记录“已确定”的位置与它们的置信度(用于后续回遮)=====
452
- if rcr:
453
- is_fixed[row_indices, transfer_index] = True
454
- # 注意:这里存 baseline 使用的 full_confidence(与 baseline 完全一致)
455
- fixed_conf[row_indices, transfer_index] = full_confidence[row_indices, transfer_index]
456
-
457
- # ===== 仅 rcr=True:在“选入”之后按累计目标回遮最低置信度的超额部分 =====
458
  if rcr:
459
- # 这一步只回遮,完全不改变 baseline 的选入行为
460
- self._rcr_remask_after_selection(
461
  x=x,
 
 
 
 
462
  mask_token_id=mask_token_id,
463
  step=i,
464
- steps=steps,
465
  s=s,
466
  t=t,
467
- is_fixed=is_fixed,
468
- fixed_conf=fixed_conf,
469
  )
470
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  x = generation_tokens_hook_func(i, x, logits)
472
 
473
  if histories is not None:
 
1
  # coding=utf-8
2
+ # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace
3
+ # Licensed under the Apache License, Version 2.0
4
+
5
  import warnings
6
  import copy
7
  from dataclasses import dataclass
 
18
 
19
 
20
  def top_p_logits(logits, top_p=None):
21
+ if top_p is None or top_p >= 1:
22
+ return logits
23
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
24
  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
25
  sorted_indices_to_remove = cumulative_probs > top_p
26
+ # keep first token above threshold
27
  sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
28
  sorted_indices_to_remove[..., 0] = 0
29
  mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
30
  mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
31
+ return logits.masked_fill(mask, torch.finfo(logits.dtype).min)
 
32
 
33
 
34
  def top_k_logits(logits, top_k=None):
35
+ if top_k is None:
36
+ return logits
37
  top_k = min(top_k, logits.size(-1))
38
+ thresh = torch.topk(logits, top_k)[0][..., -1, None]
39
+ indices_to_remove = logits < thresh
40
+ return logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
41
 
42
 
43
  def sample_tokens(
44
  logits,
45
+ temperature: float = 0.0,
46
+ top_p: Optional[float] = None,
47
+ top_k: Optional[int] = None,
48
+ margin_confidence: bool = False,
49
+ neg_entropy: bool = False,
50
  ):
51
+ # 保持 dtype 与 logits 一致(包含 bf16/fp16)
52
+ if temperature and temperature > 0:
53
  logits = logits / temperature
54
  if top_p is not None and top_p < 1:
55
  logits = top_p_logits(logits, top_p)
 
58
 
59
  probs = torch.softmax(logits, dim=-1)
60
 
61
+ if temperature and temperature > 0:
62
+ # 采样
63
  try:
64
  x0 = dists.Categorical(probs=probs).sample()
65
  confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
66
  except Exception:
67
  confidence, x0 = probs.max(dim=-1)
68
  else:
69
+ # 贪心
70
  confidence, x0 = probs.max(dim=-1)
71
 
72
  if margin_confidence:
73
  sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
74
+ top1_probs = sorted_probs[..., 0]
75
+ top2_probs = sorted_probs[..., 1]
76
  confidence = top1_probs - top2_probs
77
 
78
  if neg_entropy:
79
+ eps = probs.new_tensor(1e-10)
80
+ log_probs = torch.log(probs + eps)
81
+ # 负熵(和为负数),数值上越大(绝对值越小)表示不确定;此处直接用于排序
82
  confidence = torch.sum(probs * log_probs, dim=-1)
83
 
84
  return confidence, x0
 
97
  self.top_k: Optional[int] = kwargs.pop("top_k", None)
98
  self.max_length = kwargs.pop("max_length", 20)
99
  self.max_new_tokens = kwargs.pop("max_new_tokens", None)
100
+
101
+ # diffusion specific params
102
  self.eps: float = kwargs.pop("eps", 1e-3)
103
  self.steps: int = kwargs.pop("steps", 512)
104
  self.alg: str = kwargs.pop("alg", "origin")
105
  self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
106
 
107
+ # RCR 参数(默认不生效)
108
  self.rcr: bool = kwargs.pop("rcr", False)
109
  self.conf_alg: str = kwargs.pop("conf_alg", "maskgit_plus")
110
 
 
121
 
122
  self.generation_kwargs = kwargs.pop("generation_kwargs", {})
123
 
124
+ # hub meta
125
  self._from_model_config = kwargs.pop("_from_model_config", False)
126
  self._commit_hash = kwargs.pop("_commit_hash", None)
127
  self.transformers_version = kwargs.pop("transformers_version", __version__)
 
137
  self.validate(is_init=True)
138
 
139
  def validate(self, is_init=False):
140
+ # 保留空实现,兼容 upstream
141
  pass
142
 
143
 
 
156
  attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
157
  return input_ids, attention_mask
158
 
159
+ def _apply_rcr_logic(
 
160
  self,
161
+ x: torch.LongTensor,
162
+ x0_sel: torch.LongTensor,
163
+ conf_sel: torch.Tensor,
164
+ mask_index: torch.Tensor,
165
+ overtime_confidence: torch.Tensor,
166
  mask_token_id: int,
167
  step: int,
168
+ total_steps: int,
169
  s: torch.Tensor,
170
  t: torch.Tensor,
 
 
171
  ):
172
  """
173
+ Running Confidence Remasking (RCR)
174
+ - Dream 原调度计算每步应转移的 token 数;
175
+ - 先把本步最高置信度的若干个位置从 [MASK] 转为预测;
176
+ - 再根据“截至本步的目标累计数量”,把最低置信度的多余部分回遮回 [MASK]。
177
+ 仅在 rcr=True 时调用。
178
  """
179
+ device = x.device
180
+ dtype = overtime_confidence.dtype # == logits.dtype
181
+ B = x.shape[0]
182
+
183
+ # 当前 batch 平均剩余 mask 数
184
+ num_mask_token = mask_index.sum() / mask_index.shape[0]
185
+ # 本步的转移数量(与 Dream 调度一致)
186
+ number_transfer_tokens = int(num_mask_token * (1 - s / t)) if step < total_steps - 1 else int(num_mask_token)
187
 
188
+ # 构造“全长”置信度与候选 token(非 mask 位置分别设为 -inf / mask_token_id)
189
+ full_conf = torch.full(x.shape, float("-inf"), device=device, dtype=dtype)
190
+ x_temp = torch.full_like(x, fill_value=mask_token_id, dtype=torch.long, device=device)
191
+ full_conf[mask_index] = conf_sel
192
+ x_temp[mask_index] = x0_sel
193
 
194
  for j in range(B):
195
+ masked_j = int(mask_index[j].sum().item())
196
+ if masked_j == 0:
197
+ continue
198
+ k_j = min(number_transfer_tokens, masked_j)
199
+
200
+ if k_j > 0:
201
+ # 选出本步 top-k_j 的位置
202
+ _, select_idx = torch.topk(full_conf[j], k=k_j, largest=True)
203
+ x[j, select_idx] = x_temp[j, select_idx]
204
+ # 记录这些位置的置信度,用于累计与回遮判断
205
+ overtime_confidence[j, select_idx] = full_conf[j, select_idx]
206
+
207
+ # 目标累计(与原 Dream 线性进度对齐)
208
+ if step < total_steps - 1:
209
+ target_cum = int(num_mask_token * (1 - s / t)) # 累计目标到当前步
210
+ gen_mask = overtime_confidence[j] > overtime_confidence.new_tensor(0)
211
+ current_gen = int(gen_mask.sum().item())
212
+ overflow = max(0, current_gen - target_cum)
213
+ if overflow > 0:
214
+ gen_indices = torch.where(gen_mask)[0]
215
+ if gen_indices.numel() > 0:
216
+ gen_conf = overtime_confidence[j, gen_indices]
217
+ overflow = min(overflow, int(gen_indices.numel()))
218
+ # 选“最低置信度”的 overflow 个位置回遮
219
+ _, low_local = torch.topk(gen_conf, k=overflow, largest=False)
220
+ low_global = gen_indices[low_local]
221
+ x[j, low_global] = mask_token_id
222
+ overtime_confidence[j, low_global] = overtime_confidence.new_zeros(low_global.shape)
223
 
224
  def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
225
  if is_torchdynamo_compiling():
 
232
  UserWarning,
233
  )
234
  if input_ids_length >= generation_config.max_length:
 
235
  raise ValueError(
236
+ f"Input length is {input_ids_length}, but `max_length` is {generation_config.max_length}. "
237
+ "Consider increasing `max_length` or setting `max_new_tokens`."
 
238
  )
239
 
240
  def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
 
248
  elif has_default_max_length:
249
  if generation_config.max_length == DreamGenerationConfig().max_length:
250
  generation_config.max_length = generation_config.max_length + input_ids_length
251
+ mpe = getattr(self.config, "max_position_embeddings", None)
252
+ if mpe is not None:
253
+ generation_config.max_length = min(generation_config.max_length, mpe)
254
  return generation_config
255
 
256
+ def _prepare_generation_config(
257
+ self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict
258
+ ) -> DreamGenerationConfig:
259
  using_model_generation_config = False
260
  if generation_config is None:
261
  generation_config = DreamGenerationConfig.from_model_config(self.config)
 
263
 
264
  if not is_torchdynamo_compiling():
265
  generation_config = copy.deepcopy(generation_config)
266
+ _ = generation_config.update(**kwargs)
267
  if not using_model_generation_config:
268
  if generation_config.bos_token_id is None:
269
  generation_config.bos_token_id = self.generation_config.bos_token_id
 
275
  generation_config.mask_token_id = self.generation_config.mask_token_id
276
  return generation_config
277
 
278
+ def _prepare_special_tokens(
279
+ self, generation_config: DreamGenerationConfig, device: Optional[Union[torch.device, str]] = None
280
+ ):
281
  def _tensor_or_none(token, device=None):
282
  if token is None:
283
  return token
 
327
  has_default_max_length=has_default_max_length,
328
  input_ids_length=input_ids_length,
329
  )
330
+
331
  self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
332
 
333
  if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
334
  warnings.warn(
335
+ "You are calling .generate() with `input_ids` on a device type different than your model's device. "
336
+ f"`input_ids` is on {input_ids.device.type}, model is on {self.device.type}.",
337
+ UserWarning,
338
  )
339
+
340
  if (
341
  hasattr(generation_config, "pad_token_id")
342
  and torch.any(input_ids == generation_config.pad_token_id)
343
  and attention_mask is None
344
  ):
345
  warnings.warn(
346
+ "Padding was detected but no attention mask is passed. For correct results, set `attention_mask` when batch-padding inputs.",
347
+ UserWarning,
348
  )
349
 
350
  input_ids, attention_mask = self._expand_inputs_for_generation(
 
370
  generation_tokens_hook_func,
371
  generation_logits_hook_func,
372
  ) -> Union[DreamModelOutput, torch.LongTensor]:
373
+ # 原变量
374
  output_history = generation_config.output_history
375
  return_dict_in_generate = generation_config.return_dict_in_generate
376
  max_length = generation_config.max_length
 
383
  top_p = generation_config.top_p
384
  top_k = generation_config.top_k
385
 
386
+ # RCR 控制
387
  rcr = generation_config.rcr
388
  conf_alg = generation_config.conf_alg
389
 
390
  histories = [] if (return_dict_in_generate and output_history) else None
391
 
392
+ # pad max_length
393
  x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
394
 
395
  if attention_mask is not None and torch.any(attention_mask == 0.0):
 
406
 
407
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
408
 
409
+ # 置信度累计缓冲,延迟到拿到 logits.dtype 后再初始化,避免 dtype 错误
410
+ overtime_confidence = None # dtype = logits.dtype(初始化时设置)
 
 
411
 
412
+ # 允许用户控制中间 tokens
413
  x = generation_tokens_hook_func(None, x, None)
414
+
415
  for i in range(steps):
416
  mask_index = (x == mask_token_id)
417
+
418
  logits = self(x, attention_mask, tok_idx).logits
419
  logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
420
 
421
+ # 允许用户控制中间 logits
422
  logits = generation_logits_hook_func(i, x, logits)
423
 
424
  mask_logits = logits[mask_index]
425
  t = timesteps[i]
426
  s = timesteps[i + 1]
427
 
428
+ # 首次根据 logits.dtype 初始化 overtime_confidence(避免 Float/BFloat16 冲突)
429
+ if rcr and overtime_confidence is None:
430
+ overtime_confidence = torch.zeros_like(x, dtype=logits.dtype, device=x.device)
431
+
432
  if alg == "origin":
433
+ # 原始 Dream 逻辑(不动)
434
  p_transfer = 1 - s / t if i < steps - 1 else 1
435
+ x0 = torch.full_like(x[mask_index], fill_value=mask_token_id, dtype=torch.long, device=self.device)
436
  transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
437
  _, x0[transfer_index_t_s] = sample_tokens(
438
+ mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k
 
 
 
439
  )
440
  x[mask_index] = x0.clone()
441
 
 
442
  else:
443
+ # 选择置信度算法
444
+ use_alg = alg
445
+ if rcr:
446
+ # rcr=True 时,置信度算法由 conf_alg 决定(不影响 baseline)
447
+ use_alg = conf_alg
448
+
449
+ if use_alg == "maskgit_plus":
450
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
451
+ elif use_alg == "topk_margin":
452
  confidence, x0 = sample_tokens(
453
  mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True
454
  )
455
+ elif use_alg == "entropy":
456
  confidence, x0 = sample_tokens(
457
  mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True
458
  )
459
  else:
460
+ raise RuntimeError(f"Unknown alg: {alg}")
461
+
462
+ # 统一 full_confidence dtype = logits.dtype(避免 int/float 混合)
463
+ full_confidence = torch.full(
464
+ x.shape, float("-inf"), device=self.device, dtype=logits.dtype
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465
  )
 
466
  full_confidence[mask_index] = confidence
467
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
  if rcr:
469
+ # === RCR 分支:先转移 top-k,再根据累计目标回遮 ===
470
+ self._apply_rcr_logic(
471
  x=x,
472
+ x0_sel=x0,
473
+ conf_sel=confidence,
474
+ mask_index=mask_index,
475
+ overtime_confidence=overtime_confidence,
476
  mask_token_id=mask_token_id,
477
  step=i,
478
+ total_steps=steps,
479
  s=s,
480
  t=t,
 
 
481
  )
482
+ else:
483
+ # === baseline 分支:保持 Dream 逻辑不变 ===
484
+ num_mask_token = mask_index.sum() / mask_index.shape[0]
485
+ number_transfer_tokens = (
486
+ int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
487
+ )
488
+ if number_transfer_tokens > 0:
489
+ if alg_temp is None or alg_temp == 0:
490
+ _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
491
+ else:
492
+ fc = full_confidence / alg_temp
493
+ fc = F.softmax(fc, dim=-1)
494
+ transfer_index = torch.multinomial(fc, num_samples=number_transfer_tokens)
495
+ x_ = torch.full_like(x, fill_value=mask_token_id, dtype=torch.long, device=self.device)
496
+ x_[mask_index] = x0.clone()
497
+ row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
498
+ x[row_indices, transfer_index] = x_[row_indices, transfer_index]
499
+
500
+ # 允许用户控制中间 tokens
501
  x = generation_tokens_hook_func(i, x, logits)
502
 
503
  if histories is not None: