Update generation_utils.py
generation_utils.py (changed): +136 insertions, -170 deletions
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import warnings
|
| 2 |
import copy
|
| 3 |
from dataclasses import dataclass
|
|
@@ -17,7 +18,6 @@ def _apply_top_p_k_temp(logits, temperature=0.0, top_p=None, top_k=None):
|
|
| 17 |
if temperature and temperature > 0:
|
| 18 |
logits = logits / temperature
|
| 19 |
if top_p is not None and top_p < 1:
|
| 20 |
-
# top-p
|
| 21 |
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 22 |
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 23 |
sorted_indices_to_remove = cumulative_probs > top_p
|
|
@@ -27,7 +27,6 @@ def _apply_top_p_k_temp(logits, temperature=0.0, top_p=None, top_k=None):
|
|
| 27 |
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
|
| 28 |
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
|
| 29 |
if top_k is not None:
|
| 30 |
-
# top-k
|
| 31 |
top_k = int(min(top_k, logits.size(-1)))
|
| 32 |
if top_k > 0:
|
| 33 |
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
|
@@ -35,6 +34,26 @@ def _apply_top_p_k_temp(logits, temperature=0.0, top_p=None, top_k=None):
|
|
| 35 |
return logits
|
| 36 |
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
@dataclass
|
| 39 |
class DreamModelOutput(ModelOutput):
|
| 40 |
sequences: torch.LongTensor = None
|
|
@@ -52,16 +71,23 @@ class DreamGenerationConfig(GenerationConfig):
|
|
| 52 |
self.max_length = kwargs.pop("max_length", 20)
|
| 53 |
self.max_new_tokens = kwargs.pop("max_new_tokens", None)
|
| 54 |
|
| 55 |
-
# diffusion
|
| 56 |
self.eps: float = kwargs.pop("eps", 1e-3)
|
| 57 |
self.steps: int = kwargs.pop("steps", 512)
|
| 58 |
-
|
|
|
|
|
|
|
| 59 |
self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
|
| 60 |
|
| 61 |
-
# RCR
|
| 62 |
self.rcr: bool = kwargs.pop("rcr", False)
|
| 63 |
-
#
|
| 64 |
-
self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
# outputs
|
| 67 |
self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
|
|
@@ -93,7 +119,9 @@ class DreamGenerationConfig(GenerationConfig):
|
|
| 93 |
self.validate(is_init=True)
|
| 94 |
|
| 95 |
def validate(self, is_init=False):
|
| 96 |
-
|
|
|
|
|
|
|
| 97 |
|
| 98 |
|
| 99 |
class DreamGenerationMixin:
|
|
@@ -111,70 +139,12 @@ class DreamGenerationMixin:
|
|
| 111 |
attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
|
| 112 |
return input_ids, attention_mask
|
| 113 |
|
| 114 |
-
# =============== 论文版 RCR:运行最大置信度 + 直接选 n_t 回遮 ===============
|
| 115 |
-
def _apply_rcr_logic_paper(
|
| 116 |
-
self,
|
| 117 |
-
x: torch.Tensor, # [B, L]
|
| 118 |
-
rmax_conf: torch.Tensor, # [B, L], float32, running max of selected-token prob
|
| 119 |
-
init_mask_bool: torch.Tensor, # [B, L], 初始生成区域(最开始是 MASK 的位置)
|
| 120 |
-
init_mask_count: torch.Tensor, # [B], 初始 MASK 数 M0
|
| 121 |
-
mask_token_id: int,
|
| 122 |
-
step: int,
|
| 123 |
-
total_steps: int,
|
| 124 |
-
s: torch.Tensor,
|
| 125 |
-
t: torch.Tensor,
|
| 126 |
-
):
|
| 127 |
-
"""
|
| 128 |
-
目标:在“初始生成区域”(init_mask_bool) 内,让“已确认个数”符合 vanilla 的线性进度;
|
| 129 |
-
但位置选择依据“历史最大置信度 rmax_conf”——每步保留 rmax_conf 高的,回遮 rmax_conf 低的。
|
| 130 |
-
|
| 131 |
-
做法:
|
| 132 |
-
target_cum = floor(M0 * (1 - s/t)) # 最后一步 = M0
|
| 133 |
-
在 init_mask_bool[j] 内按 rmax_conf[j] 降序选 target_cum 个 => 保持已确认(不 mask)
|
| 134 |
-
其余位置设为 mask_token_id
|
| 135 |
-
"""
|
| 136 |
-
B, L = x.shape
|
| 137 |
-
for j in range(B):
|
| 138 |
-
M0 = int(init_mask_count[j].item())
|
| 139 |
-
if step < total_steps - 1:
|
| 140 |
-
target_cum = int(M0 * (1.0 - (s.item() / t.item())))
|
| 141 |
-
else:
|
| 142 |
-
target_cum = M0
|
| 143 |
-
|
| 144 |
-
# 在初始生成区域内排序
|
| 145 |
-
region_idx = torch.where(init_mask_bool[j])[0]
|
| 146 |
-
if region_idx.numel() == 0:
|
| 147 |
-
continue
|
| 148 |
-
|
| 149 |
-
# rmax_conf 越大越稳,保留前 target_cum 个
|
| 150 |
-
scores = rmax_conf[j, region_idx] # float32
|
| 151 |
-
# 防御:若还没更新过,rmax_conf 初始 0.0,会被优先回遮(符合“历史没自信过”的直觉)
|
| 152 |
-
target_cum = min(target_cum, int(region_idx.numel()))
|
| 153 |
-
if target_cum <= 0:
|
| 154 |
-
# 全部保持 mask
|
| 155 |
-
x[j, region_idx] = mask_token_id
|
| 156 |
-
continue
|
| 157 |
-
|
| 158 |
-
_, keep_local = torch.topk(scores, k=target_cum, largest=True)
|
| 159 |
-
keep_global = region_idx[keep_local]
|
| 160 |
-
|
| 161 |
-
# 其余回遮
|
| 162 |
-
mask_global = torch.ones_like(region_idx, dtype=torch.bool, device=x.device)
|
| 163 |
-
mask_global[keep_local] = False
|
| 164 |
-
remask_idx = region_idx[mask_global]
|
| 165 |
-
|
| 166 |
-
if remask_idx.numel() > 0:
|
| 167 |
-
x[j, remask_idx] = mask_token_id
|
| 168 |
-
# keep_global 上保持当前写入的 token,不动
|
| 169 |
-
|
| 170 |
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
|
| 171 |
if is_torchdynamo_compiling():
|
| 172 |
return
|
| 173 |
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
|
| 174 |
warnings.warn(
|
| 175 |
-
f"Using
|
| 176 |
-
"generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
|
| 177 |
-
"generation.",
|
| 178 |
UserWarning,
|
| 179 |
)
|
| 180 |
if input_ids_length >= generation_config.max_length:
|
|
@@ -186,9 +156,7 @@ class DreamGenerationMixin:
|
|
| 186 |
def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
|
| 187 |
if generation_config.max_new_tokens is not None:
|
| 188 |
if not has_default_max_length and generation_config.max_length is not None:
|
| 189 |
-
logger.warning(
|
| 190 |
-
f"Both `max_new_tokens` and `max_length` are set. `max_new_tokens` takes precedence."
|
| 191 |
-
)
|
| 192 |
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
|
| 193 |
elif has_default_max_length:
|
| 194 |
if generation_config.max_length == DreamGenerationConfig().max_length:
|
|
@@ -273,7 +241,7 @@ class DreamGenerationMixin:
|
|
| 273 |
|
| 274 |
if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
|
| 275 |
warnings.warn(
|
| 276 |
-
"You are calling .generate() with `input_ids` on a different
|
| 277 |
UserWarning,
|
| 278 |
)
|
| 279 |
if (
|
|
@@ -320,7 +288,15 @@ class DreamGenerationMixin:
|
|
| 320 |
top_p = generation_config.top_p
|
| 321 |
top_k = generation_config.top_k
|
| 322 |
|
| 323 |
-
rcr = generation_config.rcr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
histories = [] if (return_dict_in_generate and output_history) else None
|
| 325 |
|
| 326 |
# pad input_ids to max_length
|
|
@@ -340,120 +316,110 @@ class DreamGenerationMixin:
|
|
| 340 |
|
| 341 |
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
| 342 |
|
|
|
|
| 343 |
if rcr:
|
| 344 |
-
#
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
logger.warning(
|
| 350 |
-
"[RCR] Using PAPER version: running-max of SELECTED-TOKEN PROB; "
|
| 351 |
-
"this overrides `conf_alg` (e.g., entropy) for remasking decisions."
|
| 352 |
-
)
|
| 353 |
|
| 354 |
x = generation_tokens_hook_func(None, x, None)
|
| 355 |
|
| 356 |
for i in range(steps):
|
| 357 |
mask_index = (x == mask_token_id)
|
| 358 |
|
| 359 |
-
# 前向
|
| 360 |
logits = self(x, attention_mask, tok_idx).logits
|
| 361 |
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
|
| 362 |
logits = generation_logits_hook_func(i, x, logits)
|
| 363 |
|
|
|
|
| 364 |
t = timesteps[i]
|
| 365 |
s = timesteps[i + 1]
|
| 366 |
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
if transfer_index_t_s.any():
|
| 375 |
-
logits_sub = mask_logits[transfer_index_t_s]
|
| 376 |
-
logits_sub = _apply_top_p_k_temp(logits_sub, temperature, top_p, top_k)
|
| 377 |
-
probs_sub = torch.softmax(logits_sub, dim=-1)
|
| 378 |
-
try:
|
| 379 |
-
x0_sel = dists.Categorical(probs=probs_sub).sample()
|
| 380 |
-
except Exception:
|
| 381 |
-
x0_sel = probs_sub.argmax(dim=-1)
|
| 382 |
-
x0[transfer_index_t_s] = x0_sel
|
| 383 |
-
x[mask_index] = x0.clone()
|
| 384 |
-
else:
|
| 385 |
-
# 按你 vanilla 的 top-k / alg_temp 逻辑
|
| 386 |
-
mask_logits = _apply_top_p_k_temp(logits[mask_index], temperature, top_p, top_k)
|
| 387 |
-
probs = torch.softmax(mask_logits, dim=-1)
|
| 388 |
-
if temperature and temperature > 0:
|
| 389 |
-
try:
|
| 390 |
-
x0 = dists.Categorical(probs=probs).sample()
|
| 391 |
-
confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
|
| 392 |
-
except Exception:
|
| 393 |
-
confidence, x0 = probs.max(dim=-1)
|
| 394 |
-
else:
|
| 395 |
-
confidence, x0 = probs.max(dim=-1)
|
| 396 |
-
|
| 397 |
-
avg_mask_now = (mask_index.sum().item() / max(1, mask_index.shape[0]))
|
| 398 |
-
ratio = (1.0 - (s.item() / t.item())) if i < steps - 1 else 1.0
|
| 399 |
-
number_transfer_tokens = int(avg_mask_now * ratio)
|
| 400 |
-
|
| 401 |
-
full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
|
| 402 |
-
full_confidence[mask_index] = confidence
|
| 403 |
-
|
| 404 |
-
if number_transfer_tokens > 0:
|
| 405 |
-
if alg_temp is None or alg_temp == 0:
|
| 406 |
-
_, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
|
| 407 |
-
else:
|
| 408 |
-
full_confidence = full_confidence / alg_temp
|
| 409 |
-
full_confidence = F.softmax(full_confidence, dim=-1)
|
| 410 |
-
transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
|
| 411 |
-
x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
|
| 412 |
-
x_[mask_index] = x0.clone()
|
| 413 |
-
row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
|
| 414 |
-
x[row_indices, transfer_index] = x_[row_indices, transfer_index]
|
| 415 |
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
# 1) 仅对当前 mask 的位置,做 top_p/top_k/temperature 过滤后采样(或贪心)
|
| 419 |
-
mask_logits = logits[mask_index]
|
| 420 |
-
mask_logits = _apply_top_p_k_temp(mask_logits, temperature, top_p, top_k)
|
| 421 |
-
probs = torch.softmax(mask_logits, dim=-1)
|
| 422 |
-
|
| 423 |
-
# 采样 / 贪心
|
| 424 |
-
if temperature and temperature > 0:
|
| 425 |
-
try:
|
| 426 |
-
x0 = dists.Categorical(probs=probs).sample()
|
| 427 |
-
except Exception:
|
| 428 |
-
x0 = probs.argmax(dim=-1)
|
| 429 |
-
else:
|
| 430 |
-
x0 = probs.argmax(dim=-1)
|
| 431 |
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 457 |
|
| 458 |
x = generation_tokens_hook_func(i, x, logits)
|
| 459 |
if histories is not None:
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
import warnings
|
| 3 |
import copy
|
| 4 |
from dataclasses import dataclass
|
|
|
|
| 18 |
if temperature and temperature > 0:
|
| 19 |
logits = logits / temperature
|
| 20 |
if top_p is not None and top_p < 1:
|
|
|
|
| 21 |
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 22 |
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 23 |
sorted_indices_to_remove = cumulative_probs > top_p
|
|
|
|
| 27 |
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
|
| 28 |
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
|
| 29 |
if top_k is not None:
|
|
|
|
| 30 |
top_k = int(min(top_k, logits.size(-1)))
|
| 31 |
if top_k > 0:
|
| 32 |
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
|
|
|
| 34 |
return logits
|
| 35 |
|
| 36 |
|
| 37 |
+
def _confidence_from_probs(
|
| 38 |
+
probs: torch.Tensor, # [..., V]
|
| 39 |
+
chosen_ids: Optional[torch.Tensor], # [...]
|
| 40 |
+
mode: str # 'entropy' | 'maskgit_plus' | 'topk_margin'
|
| 41 |
+
) -> torch.Tensor:
|
| 42 |
+
"""返回“越大越自信”的标量分数,与解码一致。"""
|
| 43 |
+
if mode == "entropy":
|
| 44 |
+
eps = 1e-10
|
| 45 |
+
logp = torch.log(probs + eps)
|
| 46 |
+
return -(probs * logp).sum(dim=-1) # -H(p)
|
| 47 |
+
elif mode == "maskgit_plus":
|
| 48 |
+
assert chosen_ids is not None, "maskgit_plus 需要 chosen_ids"
|
| 49 |
+
return torch.gather(probs, -1, chosen_ids.unsqueeze(-1)).squeeze(-1) # p(x0)
|
| 50 |
+
elif mode == "topk_margin":
|
| 51 |
+
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
|
| 52 |
+
return sorted_probs[..., 0] - sorted_probs[..., 1] # top1 - top2
|
| 53 |
+
else:
|
| 54 |
+
raise ValueError(f"Unknown conf mode: {mode}")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
@dataclass
|
| 58 |
class DreamModelOutput(ModelOutput):
|
| 59 |
sequences: torch.LongTensor = None
|
|
|
|
| 71 |
self.max_length = kwargs.pop("max_length", 20)
|
| 72 |
self.max_new_tokens = kwargs.pop("max_new_tokens", None)
|
| 73 |
|
| 74 |
+
# diffusion
|
| 75 |
self.eps: float = kwargs.pop("eps", 1e-3)
|
| 76 |
self.steps: int = kwargs.pop("steps", 512)
|
| 77 |
+
|
| 78 |
+
# vanilla 的打分算法(rcr=False 时使用)
|
| 79 |
+
self.alg: str = kwargs.pop("alg", 'maskgit_plus') # 'origin' | 'maskgit_plus' | 'topk_margin' | 'entropy'
|
| 80 |
self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
|
| 81 |
|
| 82 |
+
# === RCR ===
|
| 83 |
self.rcr: bool = kwargs.pop("rcr", False)
|
| 84 |
+
# rcr=True 时用于解码 & 历史分一致的置信度定义
|
| 85 |
+
self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus') # 'maskgit_plus' | 'topk_margin' | 'entropy'
|
| 86 |
+
# 注意:下两项会被 _sample 内部“写死”为 1/4 到 3/4,总是覆盖
|
| 87 |
+
self.rcr_start_step: int = kwargs.pop("rcr_start_step", 0)
|
| 88 |
+
self.rcr_end_step: int = kwargs.pop("rcr_end_step", None) or self.steps
|
| 89 |
+
# 是否保护“本步刚写”的 token 不被回遮
|
| 90 |
+
self.rcr_protect_current_step: bool = kwargs.pop("rcr_protect_current_step", False)
|
| 91 |
|
| 92 |
# outputs
|
| 93 |
self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
|
|
|
|
| 119 |
self.validate(is_init=True)
|
| 120 |
|
| 121 |
def validate(self, is_init=False):
    """Clamp the RCR step window to sane bounds: start >= 0 and end >= start."""
    start = int(self.rcr_start_step)
    if start < 0:
        start = 0
    self.rcr_start_step = start
    end = int(self.rcr_end_step)
    self.rcr_end_step = end if end > start else start
|
| 125 |
|
| 126 |
|
| 127 |
class DreamGenerationMixin:
|
|
|
|
| 139 |
attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
|
| 140 |
return input_ids, attention_mask
|
| 141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
|
| 143 |
if is_torchdynamo_compiling():
|
| 144 |
return
|
| 145 |
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
|
| 146 |
warnings.warn(
|
| 147 |
+
f"Using default `max_length` (={generation_config.max_length}). Prefer `max_new_tokens`.",
|
|
|
|
|
|
|
| 148 |
UserWarning,
|
| 149 |
)
|
| 150 |
if input_ids_length >= generation_config.max_length:
|
|
|
|
| 156 |
def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
|
| 157 |
if generation_config.max_new_tokens is not None:
|
| 158 |
if not has_default_max_length and generation_config.max_length is not None:
|
| 159 |
+
logger.warning("Both `max_new_tokens` and `max_length` are set. `max_new_tokens` takes precedence.")
|
|
|
|
|
|
|
| 160 |
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
|
| 161 |
elif has_default_max_length:
|
| 162 |
if generation_config.max_length == DreamGenerationConfig().max_length:
|
|
|
|
| 241 |
|
| 242 |
if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
|
| 243 |
warnings.warn(
|
| 244 |
+
"You are calling .generate() with `input_ids` on a device different from the model.",
|
| 245 |
UserWarning,
|
| 246 |
)
|
| 247 |
if (
|
|
|
|
| 288 |
top_p = generation_config.top_p
|
| 289 |
top_k = generation_config.top_k
|
| 290 |
|
| 291 |
+
rcr = generation_config.rcr
|
| 292 |
+
conf_alg = generation_config.conf_alg if rcr else generation_config.alg
|
| 293 |
+
|
| 294 |
+
# === 写死 RCR 生效窗口:总步数的 1/4 到 3/4(左闭右开 [start, end))===
|
| 295 |
+
rcr_start = max(0, steps // 4)
|
| 296 |
+
rcr_end = max(rcr_start, min(steps, (3 * steps) // 4))
|
| 297 |
+
|
| 298 |
+
protect_cur = bool(generation_config.rcr_protect_current_step)
|
| 299 |
+
|
| 300 |
histories = [] if (return_dict_in_generate and output_history) else None
|
| 301 |
|
| 302 |
# pad input_ids to max_length
|
|
|
|
| 316 |
|
| 317 |
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
| 318 |
|
| 319 |
+
# ==== RCR 状态 ====
|
| 320 |
if rcr:
|
| 321 |
+
init_mask_bool = (x == mask_token_id) # 初始生成区域
|
| 322 |
+
init_mask_count = init_mask_bool.sum(dim=1) # [B]
|
| 323 |
+
hist_conf = torch.zeros_like(x, dtype=torch.float32, device=x.device) # 历史最大置信度
|
| 324 |
+
gen_mask = torch.zeros_like(x, dtype=torch.bool, device=x.device) # 已确认位置
|
| 325 |
+
written_step = torch.full_like(x, -1, dtype=torch.int32, device=x.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
|
| 327 |
x = generation_tokens_hook_func(None, x, None)
|
| 328 |
|
| 329 |
for i in range(steps):
|
| 330 |
mask_index = (x == mask_token_id)
|
| 331 |
|
| 332 |
+
# 前向 + Dream 的右移对齐
|
| 333 |
logits = self(x, attention_mask, tok_idx).logits
|
| 334 |
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
|
| 335 |
logits = generation_logits_hook_func(i, x, logits)
|
| 336 |
|
| 337 |
+
# 时间步
|
| 338 |
t = timesteps[i]
|
| 339 |
s = timesteps[i + 1]
|
| 340 |
|
| 341 |
+
# —— 仅抽出 mask 位置的 logits 并做过滤 ——
|
| 342 |
+
mask_logits = logits[mask_index]
|
| 343 |
+
if mask_logits.numel() == 0:
|
| 344 |
+
x = generation_tokens_hook_func(i, x, logits)
|
| 345 |
+
if histories is not None:
|
| 346 |
+
histories.append(x.clone())
|
| 347 |
+
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
|
| 349 |
+
mask_logits = _apply_top_p_k_temp(mask_logits, temperature, top_p, top_k)
|
| 350 |
+
probs = torch.softmax(mask_logits, dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
|
| 352 |
+
# 采样 / 贪心拿到 x0
|
| 353 |
+
if temperature and temperature > 0:
|
| 354 |
+
try:
|
| 355 |
+
x0 = dists.Categorical(probs=probs).sample()
|
| 356 |
+
except Exception:
|
| 357 |
+
x0 = probs.argmax(dim=-1)
|
| 358 |
+
else:
|
| 359 |
+
x0 = probs.argmax(dim=-1)
|
| 360 |
+
|
| 361 |
+
# 统一置信度(与解码一致)
|
| 362 |
+
conf_now = _confidence_from_probs(
|
| 363 |
+
probs=probs,
|
| 364 |
+
chosen_ids=x0 if conf_alg == "maskgit_plus" else None,
|
| 365 |
+
mode=conf_alg
|
| 366 |
+
).to(torch.float32) # [M]
|
| 367 |
+
|
| 368 |
+
# ====== 计算当步写入配额 k_t(与 vanilla 一致)======
|
| 369 |
+
Mt = mask_index.sum().item()
|
| 370 |
+
ratio = (1.0 - (s.item() / t.item())) if i < steps - 1 else 1.0
|
| 371 |
+
k_t = int(Mt * ratio)
|
| 372 |
+
|
| 373 |
+
# —— 写入:top-k_t ——(无论 RCR 窗口与否,先写)
|
| 374 |
+
full_conf_now = torch.full((x.size(0), x.size(1)), -1e9, dtype=torch.float32, device=x.device)
|
| 375 |
+
full_x0 = torch.full_like(x, mask_token_id, dtype=torch.long)
|
| 376 |
+
full_conf_now[mask_index] = conf_now
|
| 377 |
+
full_x0[mask_index] = x0
|
| 378 |
+
|
| 379 |
+
for b in range(x.size(0)):
|
| 380 |
+
masked_b = int(mask_index[b].sum().item())
|
| 381 |
+
if masked_b == 0 or k_t <= 0:
|
| 382 |
+
continue
|
| 383 |
+
k_b = min(k_t, masked_b)
|
| 384 |
+
_, sel_idx = torch.topk(full_conf_now[b], k=k_b, largest=True)
|
| 385 |
+
x[b, sel_idx] = full_x0[b, sel_idx]
|
| 386 |
+
if rcr:
|
| 387 |
+
gen_mask[b, sel_idx] = True
|
| 388 |
+
written_step[b, sel_idx] = i
|
| 389 |
+
# 更新历史最大置信度(与解码同定义)
|
| 390 |
+
hist_conf[b, sel_idx] = torch.maximum(hist_conf[b, sel_idx], full_conf_now[b, sel_idx])
|
| 391 |
+
|
| 392 |
+
# —— RCR 窗口外:不回遮,仅跟踪历史;窗口内:执行回遮到目标累计 ——
|
| 393 |
+
if rcr and (rcr_start <= i < rcr_end):
|
| 394 |
+
for b in range(x.size(0)):
|
| 395 |
+
M0 = int(init_mask_count[b].item())
|
| 396 |
+
target_cum = M0 if i >= steps - 1 else int(M0 * (1.0 - (s.item() / t.item())))
|
| 397 |
+
# 当前累计确认:初始生成区域内的已确认数
|
| 398 |
+
C_t = int((gen_mask[b] & init_mask_bool[b]).sum().item())
|
| 399 |
+
over = max(0, C_t - target_cum)
|
| 400 |
+
if over <= 0:
|
| 401 |
+
continue
|
| 402 |
+
|
| 403 |
+
# 候选:初始区域 ∧ 已确认(可选:排除本步刚写)
|
| 404 |
+
cand = torch.where(gen_mask[b] & init_mask_bool[b])[0]
|
| 405 |
+
if cand.numel() == 0:
|
| 406 |
+
continue
|
| 407 |
+
if protect_cur:
|
| 408 |
+
mask_old = (written_step[b, cand] < i)
|
| 409 |
+
cand = cand[mask_old]
|
| 410 |
+
if cand.numel() == 0:
|
| 411 |
+
# 全是本步写的,且要求保护,则跳过回遮
|
| 412 |
+
continue
|
| 413 |
+
|
| 414 |
+
over = min(over, int(cand.numel()))
|
| 415 |
+
scores = hist_conf[b, cand] # 越大越自信
|
| 416 |
+
_, low_local = torch.topk(scores, k=over, largest=False)
|
| 417 |
+
low_global = cand[low_local]
|
| 418 |
+
|
| 419 |
+
# 回遮
|
| 420 |
+
x[b, low_global] = mask_token_id
|
| 421 |
+
gen_mask[b, low_global] = False
|
| 422 |
+
# 历史分数与 written_step 保留
|
| 423 |
|
| 424 |
x = generation_tokens_hook_func(i, x, logits)
|
| 425 |
if histories is not None:
|