Update generation_utils.py
Browse files- generation_utils.py +31 -34
generation_utils.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
# coding=utf-8
|
| 2 |
-
# Copyright 2024 The Dream team, HKUNLP Group and the
|
|
|
|
| 3 |
#
|
| 4 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
#
|
| 6 |
# You may obtain a copy of the License at
|
| 7 |
#
|
| 8 |
# http://www.apache.org/licenses/LICENSE-2.0
|
|
@@ -77,7 +78,8 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confid
|
|
| 77 |
if neg_entropy:
|
| 78 |
epsilon = 1e-10
|
| 79 |
log_probs = torch.log(probs + epsilon)
|
| 80 |
-
|
|
|
|
| 81 |
|
| 82 |
return confidence, x0
|
| 83 |
|
|
@@ -101,9 +103,8 @@ class DreamGenerationConfig(GenerationConfig):
|
|
| 101 |
self.alg: str = kwargs.pop("alg", 'origin')
|
| 102 |
self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
|
| 103 |
|
| 104 |
-
# === RCR
|
| 105 |
self.rcr: bool = kwargs.pop("rcr", False)
|
| 106 |
-
# 仅在 rcr=True 时用于选择置信度算法;rcr=False 不读取它
|
| 107 |
self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')
|
| 108 |
|
| 109 |
# Parameters that define the output variables of `generate`
|
|
@@ -120,7 +121,7 @@ class DreamGenerationConfig(GenerationConfig):
|
|
| 120 |
# Wild card
|
| 121 |
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
|
| 122 |
|
| 123 |
-
#
|
| 124 |
self._from_model_config = kwargs.pop("_from_model_config", False)
|
| 125 |
self._commit_hash = kwargs.pop("_commit_hash", None)
|
| 126 |
self.transformers_version = kwargs.pop("transformers_version", __version__)
|
|
@@ -154,48 +155,46 @@ class DreamGenerationMixin:
|
|
| 154 |
attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
|
| 155 |
return input_ids, attention_mask
|
| 156 |
|
| 157 |
-
# ===
|
| 158 |
def _apply_rcr_logic(self, x, x0, confidence, mask_index, overtime_confidence,
|
| 159 |
mask_token_id, step, total_steps, s, t):
|
| 160 |
"""
|
| 161 |
-
|
| 162 |
-
-
|
| 163 |
-
-
|
| 164 |
-
-
|
| 165 |
-
若当前累计 >
|
| 166 |
-
说明:只影响 rcr=True 的路径;rcr=False 时完全不调用本函数。
|
| 167 |
"""
|
| 168 |
device = x.device
|
| 169 |
B = x.shape[0]
|
| 170 |
|
| 171 |
-
# 与 Dream
|
| 172 |
num_mask_token = mask_index.sum() / mask_index.shape[0]
|
| 173 |
-
# 本步的转移数量(按 Dream 调度)
|
| 174 |
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if step < total_steps - 1 else int(num_mask_token)
|
| 175 |
|
| 176 |
-
#
|
| 177 |
full_conf = torch.full_like(x, -torch.inf, device=device, dtype=confidence.dtype)
|
| 178 |
x_temp = torch.zeros_like(x, device=device, dtype=torch.long) + mask_token_id
|
| 179 |
full_conf[mask_index] = confidence
|
| 180 |
x_temp[mask_index] = x0.clone()
|
| 181 |
|
| 182 |
for j in range(B):
|
| 183 |
-
# 逐样本 clamp,避免 batch 均值带来越界
|
| 184 |
masked_j = int(mask_index[j].sum().item())
|
| 185 |
k_j = min(number_transfer_tokens, masked_j)
|
| 186 |
|
| 187 |
-
#
|
| 188 |
if k_j > 0:
|
| 189 |
_, select_idx = torch.topk(full_conf[j], k=k_j, largest=True)
|
| 190 |
x[j, select_idx] = x_temp[j, select_idx]
|
| 191 |
overtime_confidence[j, select_idx] = full_conf[j, select_idx].clone().float()
|
| 192 |
|
| 193 |
-
#
|
| 194 |
if step < total_steps - 1:
|
| 195 |
-
target_cum = int(num_mask_token * (1 - s))
|
| 196 |
-
|
|
|
|
| 197 |
current_gen = int(gen_mask.sum().item())
|
| 198 |
-
|
| 199 |
to_remask = max(0, current_gen - target_cum)
|
| 200 |
if to_remask > 0:
|
| 201 |
gen_indices = torch.where(gen_mask)[0]
|
|
@@ -205,7 +204,7 @@ class DreamGenerationMixin:
|
|
| 205 |
_, local_low = torch.topk(gen_conf, k=to_remask, largest=False)
|
| 206 |
low_global = gen_indices[local_low]
|
| 207 |
x[j, low_global] = mask_token_id
|
| 208 |
-
overtime_confidence[j, low_global] =
|
| 209 |
|
| 210 |
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
|
| 211 |
if is_torchdynamo_compiling():
|
|
@@ -362,7 +361,7 @@ class DreamGenerationMixin:
|
|
| 362 |
generation_tokens_hook_func,
|
| 363 |
generation_logits_hook_func
|
| 364 |
) -> Union[DreamModelOutput, torch.LongTensor]:
|
| 365 |
-
#
|
| 366 |
output_history = generation_config.output_history
|
| 367 |
return_dict_in_generate = generation_config.return_dict_in_generate
|
| 368 |
max_length = generation_config.max_length
|
|
@@ -375,7 +374,7 @@ class DreamGenerationMixin:
|
|
| 375 |
top_p = generation_config.top_p
|
| 376 |
top_k = generation_config.top_k
|
| 377 |
|
| 378 |
-
#
|
| 379 |
rcr = generation_config.rcr
|
| 380 |
conf_alg = generation_config.conf_alg
|
| 381 |
|
|
@@ -398,8 +397,8 @@ class DreamGenerationMixin:
|
|
| 398 |
|
| 399 |
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
| 400 |
|
| 401 |
-
#
|
| 402 |
-
overtime_confidence = torch.
|
| 403 |
|
| 404 |
# this allows user-defined token control of the intermediate steps
|
| 405 |
x = generation_tokens_hook_func(None, x, None)
|
|
@@ -416,7 +415,7 @@ class DreamGenerationMixin:
|
|
| 416 |
s = timesteps[i + 1]
|
| 417 |
|
| 418 |
if alg == 'origin':
|
| 419 |
-
#
|
| 420 |
p_transfer = 1 - s / t if i < steps - 1 else 1
|
| 421 |
x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
|
| 422 |
transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
|
|
@@ -425,8 +424,7 @@ class DreamGenerationMixin:
|
|
| 425 |
)
|
| 426 |
x[mask_index] = x0.clone()
|
| 427 |
else:
|
| 428 |
-
#
|
| 429 |
-
# rcr=False:保持原有使用 alg 的置信度算法
|
| 430 |
# rcr=True :使用 conf_alg 指定的置信度算法(不改变 rcr=False 的行为)
|
| 431 |
if (not rcr and alg == 'maskgit_plus') or (rcr and conf_alg == 'maskgit_plus'):
|
| 432 |
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
|
|
@@ -439,7 +437,6 @@ class DreamGenerationMixin:
|
|
| 439 |
mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True
|
| 440 |
)
|
| 441 |
else:
|
| 442 |
-
# 兼容:如果 rcr=True 但 conf_alg 非上述三者,回退到 alg 指定
|
| 443 |
if rcr:
|
| 444 |
if alg == 'maskgit_plus':
|
| 445 |
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
|
|
@@ -457,14 +454,14 @@ class DreamGenerationMixin:
|
|
| 457 |
raise RuntimeError(f"Unknown alg: {alg}")
|
| 458 |
|
| 459 |
if rcr:
|
| 460 |
-
#
|
| 461 |
-
print("
|
| 462 |
self._apply_rcr_logic(
|
| 463 |
x, x0, confidence, mask_index, overtime_confidence,
|
| 464 |
mask_token_id, i, steps, s, t
|
| 465 |
)
|
| 466 |
else:
|
| 467 |
-
#
|
| 468 |
num_mask_token = mask_index.sum() / mask_index.shape[0]
|
| 469 |
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
|
| 470 |
full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
|
|
|
|
| 1 |
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The Dream team, HKUNLP Group and the
|
| 3 |
+
# HuggingFace Inc. team. All rights reserved.
|
| 4 |
#
|
| 5 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# You may not use this file except in compliance with the License.
|
| 7 |
# You may obtain a copy of the License at
|
| 8 |
#
|
| 9 |
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
| 78 |
if neg_entropy:
|
| 79 |
epsilon = 1e-10
|
| 80 |
log_probs = torch.log(probs + epsilon)
|
| 81 |
+
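# Change 1 (NOTE): the expression below is the Shannon entropy H = -sum(p * log p), which is non-negative and LARGER when the distribution is MORE uncertain (less confident). Negative entropy, where larger means more confident like the other confidence scores, would be (probs * log_probs).sum(dim=-1) — confirm which direction is intended.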
|
| 82 |
+
confidence = -(probs * log_probs).sum(dim=-1)
|
| 83 |
|
| 84 |
return confidence, x0
|
| 85 |
|
|
|
|
| 103 |
self.alg: str = kwargs.pop("alg", 'origin')
|
| 104 |
self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
|
| 105 |
|
| 106 |
+
# === RCR 相关参数(默认不影响原逻辑) ===
|
| 107 |
self.rcr: bool = kwargs.pop("rcr", False)
|
|
|
|
| 108 |
self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')
|
| 109 |
|
| 110 |
# Parameters that define the output variables of `generate`
|
|
|
|
| 121 |
# Wild card
|
| 122 |
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
|
| 123 |
|
| 124 |
+
# hub interface
|
| 125 |
self._from_model_config = kwargs.pop("_from_model_config", False)
|
| 126 |
self._commit_hash = kwargs.pop("_commit_hash", None)
|
| 127 |
self.transformers_version = kwargs.pop("transformers_version", __version__)
|
|
|
|
| 155 |
attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
|
| 156 |
return input_ids, attention_mask
|
| 157 |
|
| 158 |
+
# === RCR:仅在 rcr=True 时调用;不改动 baseline 分支 ===
|
| 159 |
def _apply_rcr_logic(self, x, x0, confidence, mask_index, overtime_confidence,
|
| 160 |
mask_token_id, step, total_steps, s, t):
|
| 161 |
"""
|
| 162 |
+
Running Confidence Remasking:
|
| 163 |
+
- 采用 Dream 的调度:k_step = num_mask_token * (1 - s/t)
|
| 164 |
+
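- This step first picks the top-k_step [MASK] positions by confidence, writes the predictions there, and records (overwrites, does not sum) their confidence in overtime_confidence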
|
| 165 |
+
- 再施加“累计目标”约束:target_cum = num_mask_token * (1 - s/t)
|
| 166 |
+
若当前累计 > 目标,则把最低置信度的 token 反遮回 [MASK]
|
|
|
|
| 167 |
"""
|
| 168 |
device = x.device
|
| 169 |
B = x.shape[0]
|
| 170 |
|
| 171 |
+
# 与 Dream 一致的“批均值”口径
|
| 172 |
num_mask_token = mask_index.sum() / mask_index.shape[0]
|
|
|
|
| 173 |
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if step < total_steps - 1 else int(num_mask_token)
|
| 174 |
|
| 175 |
+
# 构造全长置信度和候选(非 mask 置 -inf / mask_token)
|
| 176 |
full_conf = torch.full_like(x, -torch.inf, device=device, dtype=confidence.dtype)
|
| 177 |
x_temp = torch.zeros_like(x, device=device, dtype=torch.long) + mask_token_id
|
| 178 |
full_conf[mask_index] = confidence
|
| 179 |
x_temp[mask_index] = x0.clone()
|
| 180 |
|
| 181 |
for j in range(B):
|
|
|
|
| 182 |
masked_j = int(mask_index[j].sum().item())
|
| 183 |
k_j = min(number_transfer_tokens, masked_j)
|
| 184 |
|
| 185 |
+
# 先选本步 top-k_j
|
| 186 |
if k_j > 0:
|
| 187 |
_, select_idx = torch.topk(full_conf[j], k=k_j, largest=True)
|
| 188 |
x[j, select_idx] = x_temp[j, select_idx]
|
| 189 |
overtime_confidence[j, select_idx] = full_conf[j, select_idx].clone().float()
|
| 190 |
|
| 191 |
+
# 累计目标(与 baseline 对齐)
|
| 192 |
if step < total_steps - 1:
|
| 193 |
+
target_cum = int(num_mask_token * (1 - s / t))
|
| 194 |
+
# 改动 2:用有限性判断“已生成”,而不是 > 0
|
| 195 |
+
gen_mask = torch.isfinite(overtime_confidence[j])
|
| 196 |
current_gen = int(gen_mask.sum().item())
|
| 197 |
+
|
| 198 |
to_remask = max(0, current_gen - target_cum)
|
| 199 |
if to_remask > 0:
|
| 200 |
gen_indices = torch.where(gen_mask)[0]
|
|
|
|
| 204 |
_, local_low = torch.topk(gen_conf, k=to_remask, largest=False)
|
| 205 |
low_global = gen_indices[local_low]
|
| 206 |
x[j, low_global] = mask_token_id
|
| 207 |
+
overtime_confidence[j, low_global] = float("-inf")
|
| 208 |
|
| 209 |
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
|
| 210 |
if is_torchdynamo_compiling():
|
|
|
|
| 361 |
generation_tokens_hook_func,
|
| 362 |
generation_logits_hook_func
|
| 363 |
) -> Union[DreamModelOutput, torch.LongTensor]:
|
| 364 |
+
# ---- 原参数 ----
|
| 365 |
output_history = generation_config.output_history
|
| 366 |
return_dict_in_generate = generation_config.return_dict_in_generate
|
| 367 |
max_length = generation_config.max_length
|
|
|
|
| 374 |
top_p = generation_config.top_p
|
| 375 |
top_k = generation_config.top_k
|
| 376 |
|
| 377 |
+
# ---- RCR 参数 ----
|
| 378 |
rcr = generation_config.rcr
|
| 379 |
conf_alg = generation_config.conf_alg
|
| 380 |
|
|
|
|
| 397 |
|
| 398 |
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
| 399 |
|
| 400 |
+
# 改动 2:仅在 rcr=True 时,用 -inf 初始化,后续用 isfinite 判断
|
| 401 |
+
overtime_confidence = torch.full_like(x, float("-inf"), dtype=torch.float32) if rcr else None
|
| 402 |
|
| 403 |
# this allows user-defined token control of the intermediate steps
|
| 404 |
x = generation_tokens_hook_func(None, x, None)
|
|
|
|
| 415 |
s = timesteps[i + 1]
|
| 416 |
|
| 417 |
if alg == 'origin':
|
| 418 |
+
# 原版 origin 分支:保持不变
|
| 419 |
p_transfer = 1 - s / t if i < steps - 1 else 1
|
| 420 |
x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
|
| 421 |
transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
|
|
|
|
| 424 |
)
|
| 425 |
x[mask_index] = x0.clone()
|
| 426 |
else:
|
| 427 |
+
# rcr=False:沿用 alg 指定的置信度算法
|
|
|
|
| 428 |
# rcr=True :使用 conf_alg 指定的置信度算法(不改变 rcr=False 的行为)
|
| 429 |
if (not rcr and alg == 'maskgit_plus') or (rcr and conf_alg == 'maskgit_plus'):
|
| 430 |
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
|
|
|
|
| 437 |
mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True
|
| 438 |
)
|
| 439 |
else:
|
|
|
|
| 440 |
if rcr:
|
| 441 |
if alg == 'maskgit_plus':
|
| 442 |
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
|
|
|
|
| 454 |
raise RuntimeError(f"Unknown alg: {alg}")
|
| 455 |
|
| 456 |
if rcr:
|
| 457 |
+
# 仅在 rcr=True:应用 RCR
|
| 458 |
+
print("[RCR] step", i)
|
| 459 |
self._apply_rcr_logic(
|
| 460 |
x, x0, confidence, mask_index, overtime_confidence,
|
| 461 |
mask_token_id, i, steps, s, t
|
| 462 |
)
|
| 463 |
else:
|
| 464 |
+
# 原版 Dream 逻辑:保持不变
|
| 465 |
num_mask_token = mask_index.sum() / mask_index.shape[0]
|
| 466 |
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
|
| 467 |
full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
|