autoprogrammer committed on
Commit
8c62663
·
verified ·
1 Parent(s): 8ddc04c

Update generation_utils.py

Browse files
Files changed (1) hide show
  1. generation_utils.py +242 -190
generation_utils.py CHANGED
@@ -1,5 +1,18 @@
1
  # coding=utf-8
2
- # Copyright 2024 ...
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import warnings
4
  import copy
5
  from dataclasses import dataclass
@@ -9,106 +22,75 @@ import torch
9
  import torch.distributions as dists
10
  from torch.nn import functional as F
11
  from transformers import __version__
12
- from transformers.generation.configuration_utils import GenerationConfig
13
- from transformers.utils import ModelOutput, is_torchdynamo_compiling, logging
 
 
 
 
 
 
14
 
15
  logger = logging.get_logger(__name__)
16
 
17
 
18
  def top_p_logits(logits, top_p=None):
19
- if top_p is None or top_p >= 1:
20
- return logits
21
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
22
  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
23
  sorted_indices_to_remove = cumulative_probs > top_p
 
24
  sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
25
  sorted_indices_to_remove[..., 0] = 0
 
26
  mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
27
  mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
28
- return logits.masked_fill(mask, torch.finfo(logits.dtype).min)
29
-
30
 
31
  def top_k_logits(logits, top_k=None):
32
- if top_k is None:
33
- return logits
34
- top_k = min(int(top_k), logits.size(-1))
35
- thresh = torch.topk(logits, top_k)[0][..., -1, None]
36
- indices_to_remove = logits < thresh
37
- return logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
38
-
39
-
40
- def sample_tokens(
41
- logits,
42
- temperature=0.0,
43
- top_p=None,
44
- top_k=None,
45
- margin_confidence=False,
46
- neg_entropy=False,
47
- ):
48
- # temperature
49
- if temperature > 0:
50
- logits = logits / temperature
51
 
52
- # filtering
53
- logits = top_p_logits(logits, top_p)
54
- logits = top_k_logits(logits, top_k)
55
 
56
- # probs
 
 
 
 
 
 
 
57
  probs = torch.softmax(logits, dim=-1)
58
 
59
- # sample or argmax
60
  if temperature > 0:
61
  try:
62
  x0 = dists.Categorical(probs=probs).sample()
63
  confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
64
- except Exception:
65
  confidence, x0 = probs.max(dim=-1)
66
  else:
67
  confidence, x0 = probs.max(dim=-1)
68
-
69
- # confidence variants
70
  if margin_confidence:
71
  sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
72
- top1_probs = sorted_probs[..., 0]
73
- top2_probs = sorted_probs[..., 1]
74
- confidence = top1_probs - top2_probs
75
-
 
 
76
  if neg_entropy:
77
  epsilon = 1e-10
78
  log_probs = torch.log(probs + epsilon)
79
- # 注意:neg_entropy 越大代表越“确定”
80
  confidence = -(probs * log_probs).sum(dim=-1)
81
-
82
  return confidence, x0
83
 
84
 
85
- def get_num_transfer_tokens_maskgit(mask_index: torch.Tensor, steps: int, mode: str = "linear") -> torch.Tensor:
86
- """
87
- LLaDA 风格:预计算每一步要“转移(解码)”的 token 数(逐样本),保证总量等于总 mask 数。
88
- mask_index: [B, L] bool
89
- return: [B, steps] long
90
- """
91
- device = mask_index.device
92
- num_masked_tokens = mask_index.sum(dim=-1, keepdim=True).float() # [B,1]
93
-
94
- t = torch.linspace(0, 1, steps + 1, device=device)[1:] # (steps,)
95
- if mode == "linear":
96
- ratio = t
97
- elif mode == "cosine":
98
- ratio = 1 - torch.cos(t * torch.pi / 2)
99
- elif mode == "pow2":
100
- ratio = t ** 2
101
- elif mode == "sqrt":
102
- ratio = torch.sqrt(t)
103
- else:
104
- raise ValueError(f"Unknown mode: {mode}")
105
-
106
- # 累积配额(四舍五入),再做差得到每步配额
107
- cum = (ratio.unsqueeze(0) * num_masked_tokens).round().long() # [B, steps]
108
- per_step = torch.diff(cum, dim=-1, prepend=torch.zeros_like(cum[:, :1]))
109
- return per_step # [B, steps], 每行之和 ≈ num_masked_tokens(四舍五入引入±1 误差)
110
-
111
-
112
  @dataclass
113
  class DreamModelOutput(ModelOutput):
114
  sequences: torch.LongTensor = None
@@ -125,20 +107,19 @@ class DreamGenerationConfig(GenerationConfig):
125
  # diffusion specific params
126
  self.eps: float = kwargs.pop("eps", 1e-3)
127
  self.steps: int = kwargs.pop("steps", 512)
128
- self.alg: str = kwargs.pop("alg", "origin")
129
  self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
130
 
131
  # RCR specific parameters
132
  self.rcr: bool = kwargs.pop("rcr", False)
133
- self.conf_alg: str = kwargs.pop("conf_alg", "maskgit_plus")
134
- self.mode: str = kwargs.pop("mode", "linear") # LLaDA 调度
135
 
136
- # Output control
137
  self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
138
  self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
139
  self.output_history: bool = kwargs.pop("output_history", False)
140
 
141
- # Special tokens
142
  self.mask_token_id = kwargs.pop("mask_token_id", None)
143
  self.pad_token_id = kwargs.pop("pad_token_id", None)
144
  self.bos_token_id = kwargs.pop("bos_token_id", None)
@@ -147,12 +128,16 @@ class DreamGenerationConfig(GenerationConfig):
147
  # Wild card
148
  self.generation_kwargs = kwargs.pop("generation_kwargs", {})
149
 
150
- # Hub info
 
151
  self._from_model_config = kwargs.pop("_from_model_config", False)
152
  self._commit_hash = kwargs.pop("_commit_hash", None)
153
  self.transformers_version = kwargs.pop("transformers_version", __version__)
154
 
 
155
  if not self._from_model_config:
 
 
156
  for key, value in kwargs.items():
157
  try:
158
  setattr(self, key, value)
@@ -160,19 +145,22 @@ class DreamGenerationConfig(GenerationConfig):
160
  logger.error(f"Can't set {key} with value {value} for {self}")
161
  raise err
162
 
 
163
  self.validate(is_init=True)
164
 
165
  def validate(self, is_init=False):
166
  pass
167
 
168
-
169
  class DreamGenerationMixin:
170
  @staticmethod
171
  def _expand_inputs_for_generation(
172
  expand_size: int = 1,
173
  input_ids: Optional[torch.LongTensor] = None,
174
- attention_mask: Optional[torch.LongTensor] = None,
175
  ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
 
 
 
176
  if expand_size == 1:
177
  return input_ids, attention_mask
178
  if input_ids is not None:
@@ -181,47 +169,129 @@ class DreamGenerationMixin:
181
  attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
182
  return input_ids, attention_mask
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
 
 
 
185
  if is_torchdynamo_compiling():
186
  return
 
 
187
  if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
 
188
  warnings.warn(
189
- f"Using the model-agnostic default `max_length` (={generation_config.max_length}); "
190
- f"prefer setting `max_new_tokens`.",
 
191
  UserWarning,
192
  )
193
  if input_ids_length >= generation_config.max_length:
 
194
  raise ValueError(
195
- f"Input length is {input_ids_length}, but `max_length` is set to {generation_config.max_length}. "
196
- f"Increase `max_length` or set `max_new_tokens`."
 
197
  )
198
 
199
- def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
 
 
 
 
 
 
 
200
  if generation_config.max_new_tokens is not None:
201
  if not has_default_max_length and generation_config.max_length is not None:
202
  logger.warning(
203
- f"Both `max_new_tokens` and `max_length` set. `max_new_tokens` takes precedence."
 
 
 
204
  )
205
  generation_config.max_length = generation_config.max_new_tokens + input_ids_length
 
206
  elif has_default_max_length:
207
  if generation_config.max_length == DreamGenerationConfig().max_length:
208
  generation_config.max_length = generation_config.max_length + input_ids_length
209
  max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
210
  if max_position_embeddings is not None:
211
  generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
 
212
  return generation_config
213
 
214
  def _prepare_generation_config(
215
  self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict
216
  ) -> DreamGenerationConfig:
 
 
 
 
 
217
  using_model_generation_config = False
218
  if generation_config is None:
219
  generation_config = DreamGenerationConfig.from_model_config(self.config)
220
  using_model_generation_config = True
221
 
 
 
 
222
  if not is_torchdynamo_compiling():
223
  generation_config = copy.deepcopy(generation_config)
224
  _kwargs = generation_config.update(**kwargs)
 
225
  if not using_model_generation_config:
226
  if generation_config.bos_token_id is None:
227
  generation_config.bos_token_id = self.generation_config.bos_token_id
@@ -239,9 +309,20 @@ class DreamGenerationMixin:
239
  generation_config: DreamGenerationConfig,
240
  device: Optional[Union[torch.device, str]] = None,
241
  ):
 
 
 
 
 
 
 
 
 
 
242
  def _tensor_or_none(token, device=None):
243
  if token is None:
244
  return token
 
245
  device = device if device is not None else self.device
246
  if isinstance(token, torch.Tensor):
247
  return token.to(device)
@@ -252,13 +333,19 @@ class DreamGenerationMixin:
252
  pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
253
  mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)
254
 
 
255
  if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
256
  eos_token_tensor = eos_token_tensor.unsqueeze(0)
257
 
 
258
  if pad_token_tensor is None and eos_token_tensor is not None:
259
  pad_token_tensor = eos_token_tensor[0]
260
  logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
261
 
 
 
 
 
262
  generation_config._bos_token_tensor = bos_token_tensor
263
  generation_config._eos_token_tensor = eos_token_tensor
264
  generation_config._pad_token_tensor = pad_token_tensor
@@ -271,16 +358,19 @@ class DreamGenerationMixin:
271
  generation_config: Optional[DreamGenerationConfig] = None,
272
  **kwargs,
273
  ) -> Union[DreamModelOutput, torch.LongTensor]:
 
274
  generation_config = self._prepare_generation_config(generation_config, **kwargs)
275
  generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
276
  generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
277
 
 
278
  assert inputs is not None
279
  input_ids = inputs
280
  device = input_ids.device
281
  attention_mask = kwargs.pop("attention_mask", None)
282
  self._prepare_special_tokens(generation_config, device=device)
283
 
 
284
  input_ids_length = input_ids.shape[-1]
285
  has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
286
  generation_config = self._prepare_generated_length(
@@ -288,23 +378,35 @@ class DreamGenerationMixin:
288
  has_default_max_length=has_default_max_length,
289
  input_ids_length=input_ids_length,
290
  )
291
- self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
292
 
 
 
 
293
  if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
294
  warnings.warn(
295
- "You are calling .generate() with input_ids on a different device than the model.",
 
 
 
 
 
296
  UserWarning,
297
  )
298
- if hasattr(generation_config, "pad_token_id") and torch.any(input_ids == generation_config.pad_token_id) and attention_mask is None:
 
 
 
 
299
  warnings.warn(
300
- "Padding detected but no attention_mask is passed. For correct results, pass attention_mask.",
 
301
  UserWarning,
302
  )
303
 
304
  input_ids, attention_mask = self._expand_inputs_for_generation(
305
  expand_size=generation_config.num_return_sequences,
306
  input_ids=input_ids,
307
- attention_mask=attention_mask,
308
  )
309
 
310
  result = self._sample(
@@ -312,7 +414,7 @@ class DreamGenerationMixin:
312
  attention_mask=attention_mask,
313
  generation_config=generation_config,
314
  generation_tokens_hook_func=generation_tokens_hook_func,
315
- generation_logits_hook_func=generation_logits_hook_func,
316
  )
317
  return result
318
 
@@ -322,10 +424,9 @@ class DreamGenerationMixin:
322
  attention_mask: Optional[torch.LongTensor],
323
  generation_config: DreamGenerationConfig,
324
  generation_tokens_hook_func,
325
- generation_logits_hook_func,
326
  ) -> Union[DreamModelOutput, torch.LongTensor]:
327
-
328
- # --- init values ---
329
  output_history = generation_config.output_history
330
  return_dict_in_generate = generation_config.return_dict_in_generate
331
  max_length = generation_config.max_length
@@ -338,20 +439,22 @@ class DreamGenerationMixin:
338
  top_p = generation_config.top_p
339
  top_k = generation_config.top_k
340
 
341
- # RCR specific
342
  rcr = generation_config.rcr
343
  conf_alg = generation_config.conf_alg
344
- mode = generation_config.mode
345
 
346
  histories = [] if (return_dict_in_generate and output_history) else None
347
 
348
- # pad to max_length with [MASK]
349
  x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
350
 
351
  if attention_mask is not None and torch.any(attention_mask == 0.0):
 
352
  attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
353
  tok_idx = attention_mask.long().cumsum(-1) - 1
354
  tok_idx.masked_fill_(attention_mask == 0, 1)
 
 
355
  attention_mask = torch.logical_and(
356
  attention_mask.unsqueeze(1).unsqueeze(-2),
357
  attention_mask.unsqueeze(1).unsqueeze(-1),
@@ -360,126 +463,75 @@ class DreamGenerationMixin:
360
  tok_idx = None
361
  attention_mask = "full"
362
 
363
- # global linear schedule 1 -> eps
364
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
365
 
366
- # 初始 mask(用于预分配 per-step token 预算;与 LLaDA 类似)
367
- initial_mask_index = (x == mask_token_id) # [B, L]
368
- per_step_tokens = get_num_transfer_tokens_maskgit(initial_mask_index, steps, mode=mode) # [B, steps]
369
-
370
- # RCR tracking
371
  overtime_confidence = torch.zeros_like(x, dtype=torch.float32) if rcr else None
372
 
373
- # user-defined token control
374
  x = generation_tokens_hook_func(None, x, None)
375
-
376
  for i in range(steps):
377
- # 当前还未确定的 mask 位置
378
- mask_index = (x == mask_token_id) # [B, L]
379
-
380
- # 模型 logits(单步预测 + 向右对齐)
381
  logits = self(x, attention_mask, tok_idx).logits
382
- logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
383
 
384
- # user-defined logits control
385
  logits = generation_logits_hook_func(i, x, logits)
386
 
387
- # 只取 mask 位置对应的 logits 参与采样
388
- mask_logits = logits[mask_index] # [M, V] (M=mask 个数)
389
  t = timesteps[i]
390
  s = timesteps[i + 1]
391
-
392
- if alg == "origin":
393
- # Dream 迁移(保留)
394
- p_transfer = 1 - (s / t).item() if i < steps - 1 else 1.0
395
- x0 = torch.zeros_like(x[mask_index], device=x.device, dtype=torch.long) + mask_token_id
396
- transfer_index_t_s = (torch.rand(*x0.shape, device=x.device) < p_transfer)
397
- _, x0[transfer_index_t_s] = sample_tokens(
398
- mask_logits[transfer_index_t_s],
399
- temperature=temperature,
400
- top_p=top_p,
401
- top_k=top_k,
402
- )
403
  x[mask_index] = x0.clone()
404
  else:
405
- # 选择置信度算法:RCR 时优先 conf_alg;非 RCR 时用 alg 的同名变体
406
- choose = conf_alg if rcr else alg
407
- if choose == "maskgit_plus":
408
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
409
- elif choose == "topk_margin":
410
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
411
- elif choose == "entropy":
412
- confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
413
  else:
414
- raise RuntimeError(f"Unknown alg/conf_alg: {choose}")
415
-
416
- # 将预测/置信度写回到全长(非 mask 位置用原 token / -inf)
417
- full_conf = torch.full_like(x, -torch.inf, device=x.device, dtype=logits.dtype) # [B, L]
418
- x_temp = torch.zeros_like(x, device=x.device, dtype=torch.long) + mask_token_id # [B, L]
419
- x_temp[mask_index] = x0.clone()
420
- full_conf[mask_index] = confidence
421
-
422
- if not rcr:
423
- # ---------- 非 RCR:逐样本的“当步配额” ----------
424
- k_per_row = per_step_tokens[:, i] # [B]
425
- B = x.size(0)
426
- for j in range(B):
427
- k_j = int(k_per_row[j].item())
428
- if k_j <= 0:
429
- continue
430
- # clamp:不能超过当前样本剩余 mask 数
431
- masked_count_j = mask_index[j].sum().item()
432
- k_j = min(k_j, int(masked_count_j))
433
- if k_j <= 0:
434
- continue
435
- # 只在 mask 内选 topk(full_conf 的非 mask 处已是 -inf)
436
- _, select_idx = torch.topk(full_conf[j], k_j, largest=True)
437
- x[j, select_idx] = x_temp[j, select_idx]
438
 
 
 
 
 
 
439
  else:
440
- # ---------- RCR:LLaDA 风格的“累积选取 + 下一步反遮盖” ----------
441
- B = x.size(0)
442
- for j in range(B):
443
- # 当步+未来的总剩余配额(从第 i 步到最后一步)
444
- total_remaining_tokens = int(per_step_tokens[j, i:].sum().item())
445
- if total_remaining_tokens <= 0:
446
- continue
447
-
448
- masked_count_j = mask_index[j].sum().item()
449
- k_total = min(total_remaining_tokens, int(masked_count_j))
450
- if k_total <= 0:
451
- continue
452
-
453
- # 1) 累积选取:一次性选出“当步至结尾”应确定的 token 集合
454
- # (在 mask 内的 topk)
455
- _, select_indices = torch.topk(full_conf[j], k_total, largest=True)
456
- x[j, select_indices] = x_temp[j, select_indices]
457
- overtime_confidence[j, select_indices] = full_conf[j, select_indices].clone().float()
458
-
459
- # 2) 下一步前:把“下一步之后还应保留给未来步数的那部分”按最低置信度反遮盖回去
460
- if i < (steps - 1):
461
- next_to_keep_for_future = int(per_step_tokens[j, i + 1 :].sum().item())
462
- if next_to_keep_for_future > 0:
463
- # 仅在“已选中的位置”(overtime_confidence>0)里,反遮盖最低置信度的那部分
464
- current_conf = overtime_confidence[j]
465
- # 把 0 置信度(未生成)位置临时设成 +inf,避免被误选为“最低”
466
- safe_conf = torch.where(current_conf == 0.0, torch.tensor(float("inf"), device=x.device), current_conf)
467
- # 需要反遮盖的数量不应超过当前已选中的数
468
- gen_count = (safe_conf != float("inf")).sum().item()
469
- k_remask = min(next_to_keep_for_future, int(gen_count))
470
- if k_remask > 0:
471
- # 选“最不自信”的 k_remask 个
472
- _, local_mask_indices = torch.topk(safe_conf, k_remask, largest=False)
473
- x[j, local_mask_indices] = mask_token_id
474
- overtime_confidence[j, local_mask_indices] = 0.0 # 清零表示撤回
475
-
476
- # user-defined token control
477
  x = generation_tokens_hook_func(i, x, logits)
478
 
479
  if histories is not None:
480
  histories.append(x.clone())
481
-
482
  if return_dict_in_generate:
483
- return DreamModelOutput(sequences=x, history=histories)
 
 
 
484
  else:
485
  return x
 
1
  # coding=utf-8
2
+ # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
  import warnings
17
  import copy
18
  from dataclasses import dataclass
 
22
  import torch.distributions as dists
23
  from torch.nn import functional as F
24
  from transformers import __version__
25
+ from transformers.generation.configuration_utils import (
26
+ GenerationConfig
27
+ )
28
+ from transformers.utils import (
29
+ ModelOutput,
30
+ is_torchdynamo_compiling,
31
+ logging,
32
+ )
33
 
34
  logger = logging.get_logger(__name__)
35
 
36
 
37
def top_p_logits(logits, top_p=None):
    """Apply nucleus (top-p) filtering to `logits` along the last dimension.

    Tokens outside the smallest set whose cumulative probability exceeds
    `top_p` are masked to the dtype's minimum value so softmax gives them
    (effectively) zero probability.

    Args:
        logits: scores of shape (..., vocab_size).
        top_p: nucleus threshold in (0, 1]; `None` or >= 1 disables filtering.

    Returns:
        The filtered logits tensor.
    """
    # Guard: with the default `top_p=None` (or a no-op threshold) return
    # unchanged — comparing cumulative probs against None would raise.
    if top_p is None or top_p >= 1:
        return logits
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the indices to the right to keep the first token above the threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Scatter the removal mask back to the original (unsorted) token order.
    mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
    return logits
49
 
50
def top_k_logits(logits, top_k=None):
    """Apply top-k filtering to `logits` along the last dimension.

    All tokens scoring below the k-th highest logit are masked to the dtype's
    minimum value.

    Args:
        logits: scores of shape (..., vocab_size).
        top_k: number of tokens to keep; `None` disables filtering.

    Returns:
        The filtered logits tensor.
    """
    # Guard: the default `top_k=None` must be a no-op — `min(None, ...)`
    # would raise a TypeError.
    if top_k is None:
        return logits
    top_k = min(int(top_k), logits.size(-1))  # Safety check
    # Remove all tokens with a probability less than the last token of the top-k
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
    return logits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
 
 
 
57
 
58
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    """Sample token ids from `logits` and return a per-position confidence.

    Args:
        logits: unnormalized scores, shape (..., vocab_size).
        temperature: > 0 enables stochastic (Categorical) sampling; 0 means
            greedy argmax.
        top_p: nucleus filtering threshold, applied only when < 1.
        top_k: top-k filtering, applied only when not None.
        margin_confidence: if True, confidence = top-1 prob minus top-2 prob.
        neg_entropy: if True, confidence = negative entropy of the
            distribution (larger means more certain); overrides the others.

    Returns:
        Tuple `(confidence, x0)` of confidence scores and sampled token ids,
        each of shape (...,).
    """
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            # Categorical can fail on degenerate rows (e.g. NaN probabilities);
            # fall back to greedy decoding instead of crashing. A bare
            # `except:` here would also swallow KeyboardInterrupt/SystemExit.
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        # Confidence = gap between the two most likely tokens. `...` indexing
        # keeps this correct for inputs of any batch rank, not only 2-D.
        confidence = sorted_probs[..., 0] - sorted_probs[..., 1]

    if neg_entropy:
        epsilon = 1e-10
        log_probs = torch.log(probs + epsilon)
        # NOTE: negative entropy — larger (closer to 0) means a more peaked,
        # i.e. more confident, distribution.
        confidence = -(probs * log_probs).sum(dim=-1)

    return confidence, x0
92
 
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  @dataclass
95
  class DreamModelOutput(ModelOutput):
96
  sequences: torch.LongTensor = None
 
107
  # diffusion specific params
108
  self.eps: float = kwargs.pop("eps", 1e-3)
109
  self.steps: int = kwargs.pop("steps", 512)
110
+ self.alg: str = kwargs.pop("alg", 'origin')
111
  self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
112
 
113
  # RCR specific parameters
114
  self.rcr: bool = kwargs.pop("rcr", False)
115
+ self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')
 
116
 
117
+ # Parameters that define the output variables of `generate`
118
  self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
119
  self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
120
  self.output_history: bool = kwargs.pop("output_history", False)
121
 
122
+ # Special tokens that can be used at generation time
123
  self.mask_token_id = kwargs.pop("mask_token_id", None)
124
  self.pad_token_id = kwargs.pop("pad_token_id", None)
125
  self.bos_token_id = kwargs.pop("bos_token_id", None)
 
128
  # Wild card
129
  self.generation_kwargs = kwargs.pop("generation_kwargs", {})
130
 
131
+ # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
132
+ # interface.
133
  self._from_model_config = kwargs.pop("_from_model_config", False)
134
  self._commit_hash = kwargs.pop("_commit_hash", None)
135
  self.transformers_version = kwargs.pop("transformers_version", __version__)
136
 
137
+ # Additional attributes without default values
138
  if not self._from_model_config:
139
+ # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
140
+ # model's default configuration file
141
  for key, value in kwargs.items():
142
  try:
143
  setattr(self, key, value)
 
145
  logger.error(f"Can't set {key} with value {value} for {self}")
146
  raise err
147
 
148
+ # Validate the values of the attributes
149
  self.validate(is_init=True)
150
 
151
  def validate(self, is_init=False):
152
  pass
153
 
 
154
  class DreamGenerationMixin:
155
  @staticmethod
156
  def _expand_inputs_for_generation(
157
  expand_size: int = 1,
158
  input_ids: Optional[torch.LongTensor] = None,
159
+ attention_mask: Optional[torch.LongTensor] = None
160
  ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
161
+ """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
162
+ # Do not call torch.repeat_interleave if expand_size is 1 because it clones
163
+ # the input tensor and thus requires more memory although no change is applied
164
  if expand_size == 1:
165
  return input_ids, attention_mask
166
  if input_ids is not None:
 
169
  attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
170
  return input_ids, attention_mask
171
 
172
    def _apply_rcr_logic(self, x, x0, confidence, mask_index, overtime_confidence,
                         mask_token_id, step, total_steps, s, t):
        """
        RCR: a minimally-invasive change on top of Dream's original logic so that it actually takes effect.
        - Keeps Dream's schedule: this step's global k = num_mask_token * (1 - s/t)
        - Clamps k per sample, so the batch-mean k cannot exceed a sample's remaining masks
        - Cumulative target constraint: by this step the cumulative number of generated tokens
          should be target_cum = num_mask_token * (1 - s/t). If the current cumulative count
          exceeds the target, the lowest-confidence tokens are remasked back to [MASK].

        Mutates `x` and `overtime_confidence` in place; returns nothing.
        NOTE(review): `mask_index` is used to boolean-index `x`, so it is assumed to be a
        boolean tensor of the same shape as `x` (batch, length) — confirm against caller.
        """
        device = x.device
        B, L = x.shape

        # Consistent with Dream: use the batch-mean mask count and the (1 - s/t) schedule.
        num_mask_token = (mask_index.sum() / mask_index.shape[0]).item()
        k_global = int(num_mask_token * (1 - (s / t).item())) if step < total_steps - 1 else int(num_mask_token)

        # Build full-length confidence and candidate tensors
        # (non-mask positions get -inf / mask_token respectively).
        full_conf = torch.full_like(x, -torch.inf, device=device, dtype=confidence.dtype)
        x_temp = torch.zeros_like(x, device=device, dtype=torch.long) + mask_token_id
        full_conf[mask_index] = confidence
        x_temp[mask_index] = x0.clone()

        for j in range(B):
            # Per-sample clamp
            masked_count_j = int(mask_index[j].sum().item())
            k_j = min(k_global, masked_count_j)
            if k_j > 0:
                # Top-k restricted to mask positions (non-mask entries of full_conf
                # are -inf, so they cannot be selected).
                _, select_idx = torch.topk(full_conf[j], k_j, largest=True)
                x[j, select_idx] = x_temp[j, select_idx]
                overtime_confidence[j, select_idx] = full_conf[j, select_idx].clone().float()

            # ===== Cumulative target constraint + remasking =====
            if step < total_steps - 1:
                # Dream's target for "how many tokens should have been generated by this step".
                target_cum = int(num_mask_token * (1 - (s / t).item()))
                # Count of already-generated tokens (positions with overtime_confidence > 0
                # are treated as committed).
                gen_mask = overtime_confidence[j] > 0
                current_gen = int(gen_mask.sum().item())

                # If over target, remask the lowest-confidence tokens so the
                # cumulative count stays approximately at the target.
                to_remask = max(0, current_gen - target_cum)
                if to_remask > 0:
                    gen_indices = torch.where(gen_mask)[0]
                    if gen_indices.numel() > 0:
                        gen_conf = overtime_confidence[j, gen_indices]
                        to_remask = min(to_remask, int(gen_indices.numel()))
                        # Pick the `to_remask` least-confident committed tokens.
                        _, local_low = torch.topk(gen_conf, k=to_remask, largest=False)
                        low_global = gen_indices[local_low]
                        x[j, low_global] = mask_token_id
                        # Zeroed confidence marks the position as withdrawn.
                        overtime_confidence[j, low_global] = 0.0
224
  def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
225
+ """Performs validation related to the resulting generated length"""
226
+
227
+ # Can't throw warnings/exceptions during compilation
228
  if is_torchdynamo_compiling():
229
  return
230
+
231
+ # 1. Max length warnings related to poor parameterization
232
  if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
233
+ # 20 is the default max_length of the generation config
234
  warnings.warn(
235
+ f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
236
+ "generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
237
+ "generation.",
238
  UserWarning,
239
  )
240
  if input_ids_length >= generation_config.max_length:
241
+ input_ids_string = "input_ids"
242
  raise ValueError(
243
+ f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
244
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
245
+ " increasing `max_length` or, better yet, setting `max_new_tokens`."
246
  )
247
 
248
+ def _prepare_generated_length(
249
+ self,
250
+ generation_config,
251
+ has_default_max_length,
252
+ input_ids_length,
253
+ ):
254
+ """Prepared max and min length in generation configs to avoid clashes between similar attributes"""
255
+
256
  if generation_config.max_new_tokens is not None:
257
  if not has_default_max_length and generation_config.max_length is not None:
258
  logger.warning(
259
+ f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
260
+ f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
261
+ "Please refer to the documentation for more information. "
262
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
263
  )
264
  generation_config.max_length = generation_config.max_new_tokens + input_ids_length
265
+
266
  elif has_default_max_length:
267
  if generation_config.max_length == DreamGenerationConfig().max_length:
268
  generation_config.max_length = generation_config.max_length + input_ids_length
269
  max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
270
  if max_position_embeddings is not None:
271
  generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
272
+
273
  return generation_config
274
 
275
  def _prepare_generation_config(
276
  self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict
277
  ) -> DreamGenerationConfig:
278
+ """
279
+ Prepares the base generation config, then applies any generation configuration options from kwargs. This
280
+ function handles retrocompatibility with respect to configuration files.
281
+ """
282
+ # priority: `generation_config` argument > `model.generation_config` (the default generation config)
283
  using_model_generation_config = False
284
  if generation_config is None:
285
  generation_config = DreamGenerationConfig.from_model_config(self.config)
286
  using_model_generation_config = True
287
 
288
+ # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
289
+ # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
290
+ # exception will be raised in `_validate_model_kwargs`
291
  if not is_torchdynamo_compiling():
292
  generation_config = copy.deepcopy(generation_config)
293
  _kwargs = generation_config.update(**kwargs)
294
+ # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
295
  if not using_model_generation_config:
296
  if generation_config.bos_token_id is None:
297
  generation_config.bos_token_id = self.generation_config.bos_token_id
 
309
  generation_config: DreamGenerationConfig,
310
  device: Optional[Union[torch.device, str]] = None,
311
  ):
312
+ """
313
+ Prepares the special tokens for generation, overwriting the generation config with their processed versions
314
+ converted to tensor.
315
+
316
+ Note that `generation_config` is changed in place and stops being serializable after this method is called.
317
+ That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the
318
+ function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
319
+ """
320
+
321
+ # Convert special tokens to tensors
322
  def _tensor_or_none(token, device=None):
323
  if token is None:
324
  return token
325
+
326
  device = device if device is not None else self.device
327
  if isinstance(token, torch.Tensor):
328
  return token.to(device)
 
333
  pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
334
  mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)
335
 
336
+ # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
337
  if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
338
  eos_token_tensor = eos_token_tensor.unsqueeze(0)
339
 
340
+ # Set pad token if unset (and there are conditions to do so)
341
  if pad_token_tensor is None and eos_token_tensor is not None:
342
  pad_token_tensor = eos_token_tensor[0]
343
  logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
344
 
345
+ # Update generation config with the updated special tokens tensors
346
+ # NOTE: this must be written into a different attribute name than the one holding the original special tokens
347
+ # (in their non-tensor form), in order to enable end-to-end compilation. See
348
+ # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
349
  generation_config._bos_token_tensor = bos_token_tensor
350
  generation_config._eos_token_tensor = eos_token_tensor
351
  generation_config._pad_token_tensor = pad_token_tensor
 
358
  generation_config: Optional[DreamGenerationConfig] = None,
359
  **kwargs,
360
  ) -> Union[DreamModelOutput, torch.LongTensor]:
361
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
362
  generation_config = self._prepare_generation_config(generation_config, **kwargs)
363
  generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
364
  generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
365
 
366
+ # 2. Define model inputs
367
  assert inputs is not None
368
  input_ids = inputs
369
  device = input_ids.device
370
  attention_mask = kwargs.pop("attention_mask", None)
371
  self._prepare_special_tokens(generation_config, device=device)
372
 
373
+ # 3. Prepare `max_length`.
374
  input_ids_length = input_ids.shape[-1]
375
  has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
376
  generation_config = self._prepare_generated_length(
 
378
  has_default_max_length=has_default_max_length,
379
  input_ids_length=input_ids_length,
380
  )
 
381
 
382
+ self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
383
+
384
+ # 4. Check input_ids
385
  if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
386
  warnings.warn(
387
+ "You are calling .generate() with the `input_ids` being on a device type different"
388
+ f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
389
+ f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
390
+ " Please make sure that you have put `input_ids` to the"
391
+ f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
392
+ " running `.generate()`.",
393
  UserWarning,
394
  )
395
+ if (
396
+ hasattr(generation_config, "pad_token_id") and
397
+ torch.any(input_ids == generation_config.pad_token_id) and
398
+ attention_mask is None
399
+ ):
400
  warnings.warn(
401
+ "Padding was detected but no attention mask is passed here. For correct "
402
+ "generation results, please set `attention_mask` when batch-padding inputs.",
403
  UserWarning,
404
  )
405
 
406
  input_ids, attention_mask = self._expand_inputs_for_generation(
407
  expand_size=generation_config.num_return_sequences,
408
  input_ids=input_ids,
409
+ attention_mask=attention_mask
410
  )
411
 
412
  result = self._sample(
 
414
  attention_mask=attention_mask,
415
  generation_config=generation_config,
416
  generation_tokens_hook_func=generation_tokens_hook_func,
417
+ generation_logits_hook_func=generation_logits_hook_func
418
  )
419
  return result
420
 
 
424
  attention_mask: Optional[torch.LongTensor],
425
  generation_config: DreamGenerationConfig,
426
  generation_tokens_hook_func,
427
+ generation_logits_hook_func
428
  ) -> Union[DreamModelOutput, torch.LongTensor]:
429
+ # init values
 
430
  output_history = generation_config.output_history
431
  return_dict_in_generate = generation_config.return_dict_in_generate
432
  max_length = generation_config.max_length
 
439
  top_p = generation_config.top_p
440
  top_k = generation_config.top_k
441
 
442
+ # RCR specific values
443
  rcr = generation_config.rcr
444
  conf_alg = generation_config.conf_alg
 
445
 
446
  histories = [] if (return_dict_in_generate and output_history) else None
447
 
448
+ # pad input_ids to max_length
449
  x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
450
 
451
  if attention_mask is not None and torch.any(attention_mask == 0.0):
452
+ # we do not mask the [MASK] tokens so value = 1.0
453
  attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
454
  tok_idx = attention_mask.long().cumsum(-1) - 1
455
  tok_idx.masked_fill_(attention_mask == 0, 1)
456
+ # attention_mask is of shape [B, N]
457
+ # broadcast to [B, 1, N, N]
458
  attention_mask = torch.logical_and(
459
  attention_mask.unsqueeze(1).unsqueeze(-2),
460
  attention_mask.unsqueeze(1).unsqueeze(-1),
 
463
  tok_idx = None
464
  attention_mask = "full"
465
 
 
466
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
467
 
468
+ # RCR tracking - initialize overtime confidence tracking
 
 
 
 
469
  overtime_confidence = torch.zeros_like(x, dtype=torch.float32) if rcr else None
470
 
471
+ # this allows user-defined token control of the intermediate steps
472
  x = generation_tokens_hook_func(None, x, None)
 
473
  for i in range(steps):
474
+ mask_index = (x == mask_token_id)
 
 
 
475
  logits = self(x, attention_mask, tok_idx).logits
476
+ logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
477
 
478
+ # this allows user-defined logits control of the intermediate steps
479
  logits = generation_logits_hook_func(i, x, logits)
480
 
481
+ mask_logits = logits[mask_index]
 
482
  t = timesteps[i]
483
  s = timesteps[i + 1]
484
+
485
+ if alg == 'origin':
486
+ p_transfer = 1 - s / t if i < steps - 1 else 1
487
+ x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
488
+ transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
489
+ _, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
 
 
 
 
 
 
490
  x[mask_index] = x0.clone()
491
  else:
492
+ if alg == 'maskgit_plus' or (rcr and conf_alg == 'maskgit_plus'):
 
 
493
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
494
+ elif alg == 'topk_margin' or (rcr and conf_alg == 'topk_margin'):
495
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
496
+ elif alg == 'entropy' or (rcr and conf_alg == 'entropy'):
497
+ confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
498
  else:
499
+ raise RuntimeError(f"Unknown alg: {alg}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
 
501
+ # Apply RCR logic if enabled
502
+ if rcr:
503
+ print(f"[RCR EXEC] Step {i}: RCR logic executed")
504
+ self._apply_rcr_logic(x, x0, confidence, mask_index, overtime_confidence,
505
+ mask_token_id, i, steps, s, t)
506
  else:
507
+ # Original Dream sampling logic
508
+ num_mask_token = mask_index.sum() / mask_index.shape[0]
509
+ number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
510
+ # --------- minor fix here only: use x.device so the tensor is created on the same device as x, avoiding cross-device ops ----------
511
+ full_confidence = torch.full_like(x, -torch.inf, device=x.device, dtype=logits.dtype)
512
+ full_confidence[mask_index] = confidence
513
+ if number_transfer_tokens > 0:
514
+ if alg_temp is None or alg_temp == 0:
515
+ _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
516
+ else:
517
+ full_confidence = full_confidence / alg_temp
518
+ full_confidence = F.softmax(full_confidence, dim=-1)
519
+ transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
520
+ x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
521
+ x_[mask_index] = x0.clone()
522
+ row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
523
+ x[row_indices,transfer_index] = x_[row_indices,transfer_index]
524
+
525
+ # this allows user-defined token control of the intermediate steps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
  x = generation_tokens_hook_func(i, x, logits)
527
 
528
  if histories is not None:
529
  histories.append(x.clone())
530
+
531
  if return_dict_in_generate:
532
+ return DreamModelOutput(
533
+ sequences=x,
534
+ history=histories,
535
+ )
536
  else:
537
  return x