autoprogrammer committed on
Commit
32f22f3
·
verified ·
1 Parent(s): 762a23f

Update generation_utils.py

Browse files
Files changed (1) hide show
  1. generation_utils.py +31 -34
generation_utils.py CHANGED
@@ -1,8 +1,9 @@
1
  # coding=utf-8
2
- # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team.
 
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
  # You may obtain a copy of the License at
7
  #
8
  # http://www.apache.org/licenses/LICENSE-2.0
@@ -77,7 +78,8 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confid
77
  if neg_entropy:
78
  epsilon = 1e-10
79
  log_probs = torch.log(probs + epsilon)
80
- confidence = torch.sum(probs * log_probs, dim=-1)
 
81
 
82
  return confidence, x0
83
 
@@ -101,9 +103,8 @@ class DreamGenerationConfig(GenerationConfig):
101
  self.alg: str = kwargs.pop("alg", 'origin')
102
  self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
103
 
104
- # === RCR 相关参数(新增;默认不影响原逻辑) ===
105
  self.rcr: bool = kwargs.pop("rcr", False)
106
- # 仅在 rcr=True 时用于选择置信度算法;rcr=False 不读取它
107
  self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')
108
 
109
  # Parameters that define the output variables of `generate`
@@ -120,7 +121,7 @@ class DreamGenerationConfig(GenerationConfig):
120
  # Wild card
121
  self.generation_kwargs = kwargs.pop("generation_kwargs", {})
122
 
123
- # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub interface.
124
  self._from_model_config = kwargs.pop("_from_model_config", False)
125
  self._commit_hash = kwargs.pop("_commit_hash", None)
126
  self.transformers_version = kwargs.pop("transformers_version", __version__)
@@ -154,48 +155,46 @@ class DreamGenerationMixin:
154
  attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
155
  return input_ids, attention_mask
156
 
157
- # === 新增:RCR 逻辑,仅在 rcr=True 时被调用;不改动非 RCR 分支 ===
158
  def _apply_rcr_logic(self, x, x0, confidence, mask_index, overtime_confidence,
159
  mask_token_id, step, total_steps, s, t):
160
  """
161
- 在 Dream 的“maskgit”采样骨架上,执行 Running Confidence Remasking:
162
- - 本步采用 Dream 原调度:global_k = num_mask_token * (1 - s/t)
163
- - 先以当前置信度将 top-k token 从 [MASK] 转为预测 token,并累计它们的置信度
164
- - 再施加“目标累计”约束:截至本步应累计生成 target_cum = num_mask_token * (1 - s/t)
165
- 若当前累计 > 目标,则把最低置信度的那些 token 反遮盖回 [MASK]
166
- 说明:只影响 rcr=True 的路径;rcr=False 时完全不调用本函数。
167
  """
168
  device = x.device
169
  B = x.shape[0]
170
 
171
- # 与 Dream 一致的 num_mask_token(按 batch 平均)
172
  num_mask_token = mask_index.sum() / mask_index.shape[0]
173
- # 本步的转移数量(按 Dream 调度)
174
  number_transfer_tokens = int(num_mask_token * (1 - s / t)) if step < total_steps - 1 else int(num_mask_token)
175
 
176
- # 构造全长置信度和候选值(非 mask 位置分别设为 -inf / mask_token_id)
177
  full_conf = torch.full_like(x, -torch.inf, device=device, dtype=confidence.dtype)
178
  x_temp = torch.zeros_like(x, device=device, dtype=torch.long) + mask_token_id
179
  full_conf[mask_index] = confidence
180
  x_temp[mask_index] = x0.clone()
181
 
182
  for j in range(B):
183
- # 逐样本 clamp,避免 batch 均值带来越界
184
  masked_j = int(mask_index[j].sum().item())
185
  k_j = min(number_transfer_tokens, masked_j)
186
 
187
- # 先按置信度选出本步 top-k_j
188
  if k_j > 0:
189
  _, select_idx = torch.topk(full_conf[j], k=k_j, largest=True)
190
  x[j, select_idx] = x_temp[j, select_idx]
191
  overtime_confidence[j, select_idx] = full_conf[j, select_idx].clone().float()
192
 
193
- # 目标累计约束:截至本步应累计的生成数
194
  if step < total_steps - 1:
195
- target_cum = int(num_mask_token * (1 - s)) # 累计目标:随 s 递减而线性增长
196
- gen_mask = overtime_confidence[j] > 0
 
197
  current_gen = int(gen_mask.sum().item())
198
- # 若超额,则按最低置信度回遮
199
  to_remask = max(0, current_gen - target_cum)
200
  if to_remask > 0:
201
  gen_indices = torch.where(gen_mask)[0]
@@ -205,7 +204,7 @@ class DreamGenerationMixin:
205
  _, local_low = torch.topk(gen_conf, k=to_remask, largest=False)
206
  low_global = gen_indices[local_low]
207
  x[j, low_global] = mask_token_id
208
- overtime_confidence[j, low_global] = 0.0
209
 
210
  def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
211
  if is_torchdynamo_compiling():
@@ -362,7 +361,7 @@ class DreamGenerationMixin:
362
  generation_tokens_hook_func,
363
  generation_logits_hook_func
364
  ) -> Union[DreamModelOutput, torch.LongTensor]:
365
- # === 原变量 ===
366
  output_history = generation_config.output_history
367
  return_dict_in_generate = generation_config.return_dict_in_generate
368
  max_length = generation_config.max_length
@@ -375,7 +374,7 @@ class DreamGenerationMixin:
375
  top_p = generation_config.top_p
376
  top_k = generation_config.top_k
377
 
378
- # === 新增:RCR 控制变量(不会影响 rcr=False 的路径) ===
379
  rcr = generation_config.rcr
380
  conf_alg = generation_config.conf_alg
381
 
@@ -398,8 +397,8 @@ class DreamGenerationMixin:
398
 
399
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
400
 
401
- # === 仅在 rcr=True 时分配 Overtime Confidence(不影响 baseline) ===
402
- overtime_confidence = torch.zeros_like(x, dtype=torch.float32) if rcr else None
403
 
404
  # this allows user-defined token control of the intermediate steps
405
  x = generation_tokens_hook_func(None, x, None)
@@ -416,7 +415,7 @@ class DreamGenerationMixin:
416
  s = timesteps[i + 1]
417
 
418
  if alg == 'origin':
419
- # === 原版 origin 分支:保持不变 ===
420
  p_transfer = 1 - s / t if i < steps - 1 else 1
421
  x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
422
  transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
@@ -425,8 +424,7 @@ class DreamGenerationMixin:
425
  )
426
  x[mask_index] = x0.clone()
427
  else:
428
- # === origin 分支 ===
429
- # rcr=False:保持原有使用 alg 的置信度算法
430
  # rcr=True :使用 conf_alg 指定的置信度算法(不改变 rcr=False 的行为)
431
  if (not rcr and alg == 'maskgit_plus') or (rcr and conf_alg == 'maskgit_plus'):
432
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
@@ -439,7 +437,6 @@ class DreamGenerationMixin:
439
  mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True
440
  )
441
  else:
442
- # 兼容:如果 rcr=True 但 conf_alg 非上述三者,回退到 alg 指定
443
  if rcr:
444
  if alg == 'maskgit_plus':
445
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
@@ -457,14 +454,14 @@ class DreamGenerationMixin:
457
  raise RuntimeError(f"Unknown alg: {alg}")
458
 
459
  if rcr:
460
- # === 仅在 rcr=True 时:应用 RCR;不会触碰 baseline 分支实现 ===
461
- print("rcr")
462
  self._apply_rcr_logic(
463
  x, x0, confidence, mask_index, overtime_confidence,
464
  mask_token_id, i, steps, s, t
465
  )
466
  else:
467
- # === 原版 Dream 逻辑:保持不变(包括 device=self.device 等细节) ===
468
  num_mask_token = mask_index.sum() / mask_index.shape[0]
469
  number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
470
  full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
 
1
  # coding=utf-8
2
+ # Copyright 2024 The Dream team, HKUNLP Group and the
3
+ # HuggingFace Inc. team. All rights reserved.
4
  #
5
  # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # You may not use this file except in compliance with the License.
7
  # You may obtain a copy of the License at
8
  #
9
  # http://www.apache.org/licenses/LICENSE-2.0
 
78
  if neg_entropy:
79
  epsilon = 1e-10
80
  log_probs = torch.log(probs + epsilon)
81
+ # 改动 1:用“负熵”的正值(越大越自信),与其它置信度方向保持一致
82
+ confidence = -(probs * log_probs).sum(dim=-1)
83
 
84
  return confidence, x0
85
 
 
103
  self.alg: str = kwargs.pop("alg", 'origin')
104
  self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
105
 
106
+ # === RCR 相关参数(默认不影响原逻辑) ===
107
  self.rcr: bool = kwargs.pop("rcr", False)
 
108
  self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')
109
 
110
  # Parameters that define the output variables of `generate`
 
121
  # Wild card
122
  self.generation_kwargs = kwargs.pop("generation_kwargs", {})
123
 
124
+ # hub interface
125
  self._from_model_config = kwargs.pop("_from_model_config", False)
126
  self._commit_hash = kwargs.pop("_commit_hash", None)
127
  self.transformers_version = kwargs.pop("transformers_version", __version__)
 
155
  attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
156
  return input_ids, attention_mask
157
 
158
+ # === RCR:仅在 rcr=True 时调用;不改动 baseline 分支 ===
159
  def _apply_rcr_logic(self, x, x0, confidence, mask_index, overtime_confidence,
160
  mask_token_id, step, total_steps, s, t):
161
  """
162
+ Running Confidence Remasking:
163
+ - 采用 Dream 的调度:k_step = num_mask_token * (1 - s/t)
164
+ - 本步先按置信度从 [MASK] 中挑 top-k_step 写入预测,并把置信度累计到 overtime_confidence
165
+ - 再施加“累计目标”约束:target_cum = num_mask_token * (1 - s/t)
166
+ 若当前累计 > 目标,则把最低置信度的 token 反遮回 [MASK]
 
167
  """
168
  device = x.device
169
  B = x.shape[0]
170
 
171
+ # 与 Dream 一致的“批均值”口径
172
  num_mask_token = mask_index.sum() / mask_index.shape[0]
 
173
  number_transfer_tokens = int(num_mask_token * (1 - s / t)) if step < total_steps - 1 else int(num_mask_token)
174
 
175
+ # 构造全长置信度和候选(非 mask 位置分别设为 -inf / mask_token_id)
176
  full_conf = torch.full_like(x, -torch.inf, device=device, dtype=confidence.dtype)
177
  x_temp = torch.zeros_like(x, device=device, dtype=torch.long) + mask_token_id
178
  full_conf[mask_index] = confidence
179
  x_temp[mask_index] = x0.clone()
180
 
181
  for j in range(B):
 
182
  masked_j = int(mask_index[j].sum().item())
183
  k_j = min(number_transfer_tokens, masked_j)
184
 
185
+ # 先选本步 top-k_j
186
  if k_j > 0:
187
  _, select_idx = torch.topk(full_conf[j], k=k_j, largest=True)
188
  x[j, select_idx] = x_temp[j, select_idx]
189
  overtime_confidence[j, select_idx] = full_conf[j, select_idx].clone().float()
190
 
191
+ # 累计目标(与 baseline 对齐)
192
  if step < total_steps - 1:
193
+ target_cum = int(num_mask_token * (1 - s / t))
194
+ # 改动 2:用有限性判断“已生成”,而不是 > 0
195
+ gen_mask = torch.isfinite(overtime_confidence[j])
196
  current_gen = int(gen_mask.sum().item())
197
+
198
  to_remask = max(0, current_gen - target_cum)
199
  if to_remask > 0:
200
  gen_indices = torch.where(gen_mask)[0]
 
204
  _, local_low = torch.topk(gen_conf, k=to_remask, largest=False)
205
  low_global = gen_indices[local_low]
206
  x[j, low_global] = mask_token_id
207
+ overtime_confidence[j, low_global] = float("-inf")
208
 
209
  def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
210
  if is_torchdynamo_compiling():
 
361
  generation_tokens_hook_func,
362
  generation_logits_hook_func
363
  ) -> Union[DreamModelOutput, torch.LongTensor]:
364
+ # ---- 原参数 ----
365
  output_history = generation_config.output_history
366
  return_dict_in_generate = generation_config.return_dict_in_generate
367
  max_length = generation_config.max_length
 
374
  top_p = generation_config.top_p
375
  top_k = generation_config.top_k
376
 
377
+ # ---- RCR 参数 ----
378
  rcr = generation_config.rcr
379
  conf_alg = generation_config.conf_alg
380
 
 
397
 
398
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
399
 
400
+ # 改动 2:仅在 rcr=True 时,用 -inf 初始化,后续用 isfinite 判断
401
+ overtime_confidence = torch.full_like(x, float("-inf"), dtype=torch.float32) if rcr else None
402
 
403
  # this allows user-defined token control of the intermediate steps
404
  x = generation_tokens_hook_func(None, x, None)
 
415
  s = timesteps[i + 1]
416
 
417
  if alg == 'origin':
418
+ # 原版 origin 分支:保持不变
419
  p_transfer = 1 - s / t if i < steps - 1 else 1
420
  x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
421
  transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
 
424
  )
425
  x[mask_index] = x0.clone()
426
  else:
427
+ # rcr=False:沿用 alg 指定的置信度算法
 
428
  # rcr=True :使用 conf_alg 指定的置信度算法(不改变 rcr=False 的行为)
429
  if (not rcr and alg == 'maskgit_plus') or (rcr and conf_alg == 'maskgit_plus'):
430
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
 
437
  mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True
438
  )
439
  else:
 
440
  if rcr:
441
  if alg == 'maskgit_plus':
442
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
 
454
  raise RuntimeError(f"Unknown alg: {alg}")
455
 
456
  if rcr:
457
+ # 仅在 rcr=True:应用 RCR
458
+ print("[RCR] step", i)
459
  self._apply_rcr_logic(
460
  x, x0, confidence, mask_index, overtime_confidence,
461
  mask_token_id, i, steps, s, t
462
  )
463
  else:
464
+ # 原版 Dream 逻辑:保持不变
465
  num_mask_token = mask_index.sum() / mask_index.shape[0]
466
  number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
467
  full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)