autoprogrammer committed on
Commit
78f65a4
·
verified ·
1 Parent(s): fd52594

Update generation_utils.py

Browse files
Files changed (1) hide show
  1. generation_utils.py +136 -170
generation_utils.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import warnings
2
  import copy
3
  from dataclasses import dataclass
@@ -17,7 +18,6 @@ def _apply_top_p_k_temp(logits, temperature=0.0, top_p=None, top_k=None):
17
  if temperature and temperature > 0:
18
  logits = logits / temperature
19
  if top_p is not None and top_p < 1:
20
- # top-p
21
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
22
  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
23
  sorted_indices_to_remove = cumulative_probs > top_p
@@ -27,7 +27,6 @@ def _apply_top_p_k_temp(logits, temperature=0.0, top_p=None, top_k=None):
27
  mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
28
  logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
29
  if top_k is not None:
30
- # top-k
31
  top_k = int(min(top_k, logits.size(-1)))
32
  if top_k > 0:
33
  indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
@@ -35,6 +34,26 @@ def _apply_top_p_k_temp(logits, temperature=0.0, top_p=None, top_k=None):
35
  return logits
36
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  @dataclass
39
  class DreamModelOutput(ModelOutput):
40
  sequences: torch.LongTensor = None
@@ -52,16 +71,23 @@ class DreamGenerationConfig(GenerationConfig):
52
  self.max_length = kwargs.pop("max_length", 20)
53
  self.max_new_tokens = kwargs.pop("max_new_tokens", None)
54
 
55
- # diffusion specific params
56
  self.eps: float = kwargs.pop("eps", 1e-3)
57
  self.steps: int = kwargs.pop("steps", 512)
58
- self.alg: str = kwargs.pop("alg", 'origin') # vanilla 使用
 
 
59
  self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
60
 
61
- # RCR
62
  self.rcr: bool = kwargs.pop("rcr", False)
63
- # 注意:论文版 RCR 会忽略这里的 conf_alg,并统一用“选中 token 概率”做 running max
64
- self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')
 
 
 
 
 
65
 
66
  # outputs
67
  self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
@@ -93,7 +119,9 @@ class DreamGenerationConfig(GenerationConfig):
93
  self.validate(is_init=True)
94
 
95
  def validate(self, is_init=False):
96
- pass
 
 
97
 
98
 
99
  class DreamGenerationMixin:
@@ -111,70 +139,12 @@ class DreamGenerationMixin:
111
  attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
112
  return input_ids, attention_mask
113
 
114
- # =============== 论文版 RCR:运行最大置信度 + 直接选 n_t 回遮 ===============
115
- def _apply_rcr_logic_paper(
116
- self,
117
- x: torch.Tensor, # [B, L]
118
- rmax_conf: torch.Tensor, # [B, L], float32, running max of selected-token prob
119
- init_mask_bool: torch.Tensor, # [B, L], 初始生成区域(最开始是 MASK 的位置)
120
- init_mask_count: torch.Tensor, # [B], 初始 MASK 数 M0
121
- mask_token_id: int,
122
- step: int,
123
- total_steps: int,
124
- s: torch.Tensor,
125
- t: torch.Tensor,
126
- ):
127
- """
128
- 目标:在“初始生成区域”(init_mask_bool) 内,让“已确认个数”符合 vanilla 的线性进度;
129
- 但位置选择依据“历史最大置信度 rmax_conf”——每步保留 rmax_conf 高的,回遮 rmax_conf 低的。
130
-
131
- 做法:
132
- target_cum = floor(M0 * (1 - s/t)) # 最后一步 = M0
133
- 在 init_mask_bool[j] 内按 rmax_conf[j] 降序选 target_cum 个 => 保持已确认(不 mask)
134
- 其余位置设为 mask_token_id
135
- """
136
- B, L = x.shape
137
- for j in range(B):
138
- M0 = int(init_mask_count[j].item())
139
- if step < total_steps - 1:
140
- target_cum = int(M0 * (1.0 - (s.item() / t.item())))
141
- else:
142
- target_cum = M0
143
-
144
- # 在初始生成区域内排序
145
- region_idx = torch.where(init_mask_bool[j])[0]
146
- if region_idx.numel() == 0:
147
- continue
148
-
149
- # rmax_conf 越大越稳,保留前 target_cum 个
150
- scores = rmax_conf[j, region_idx] # float32
151
- # 防御:若还没更新过,rmax_conf 初始 0.0,会被优先回遮(符合“历史没自信过”的直觉)
152
- target_cum = min(target_cum, int(region_idx.numel()))
153
- if target_cum <= 0:
154
- # 全部保持 mask
155
- x[j, region_idx] = mask_token_id
156
- continue
157
-
158
- _, keep_local = torch.topk(scores, k=target_cum, largest=True)
159
- keep_global = region_idx[keep_local]
160
-
161
- # 其余回遮
162
- mask_global = torch.ones_like(region_idx, dtype=torch.bool, device=x.device)
163
- mask_global[keep_local] = False
164
- remask_idx = region_idx[mask_global]
165
-
166
- if remask_idx.numel() > 0:
167
- x[j, remask_idx] = mask_token_id
168
- # keep_global 上保持当前写入的 token,不动
169
-
170
  def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
171
  if is_torchdynamo_compiling():
172
  return
173
  if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
174
  warnings.warn(
175
- f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
176
- "generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
177
- "generation.",
178
  UserWarning,
179
  )
180
  if input_ids_length >= generation_config.max_length:
@@ -186,9 +156,7 @@ class DreamGenerationMixin:
186
  def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
187
  if generation_config.max_new_tokens is not None:
188
  if not has_default_max_length and generation_config.max_length is not None:
189
- logger.warning(
190
- f"Both `max_new_tokens` and `max_length` are set. `max_new_tokens` takes precedence."
191
- )
192
  generation_config.max_length = generation_config.max_new_tokens + input_ids_length
193
  elif has_default_max_length:
194
  if generation_config.max_length == DreamGenerationConfig().max_length:
@@ -273,7 +241,7 @@ class DreamGenerationMixin:
273
 
274
  if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
275
  warnings.warn(
276
- "You are calling .generate() with `input_ids` on a different device than the model.",
277
  UserWarning,
278
  )
279
  if (
@@ -320,7 +288,15 @@ class DreamGenerationMixin:
320
  top_p = generation_config.top_p
321
  top_k = generation_config.top_k
322
 
323
- rcr = generation_config.rcr # 打开则走论文版 RCR(历史最大 top-1 概率)
 
 
 
 
 
 
 
 
324
  histories = [] if (return_dict_in_generate and output_history) else None
325
 
326
  # pad input_ids to max_length
@@ -340,120 +316,110 @@ class DreamGenerationMixin:
340
 
341
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
342
 
 
343
  if rcr:
344
- # 初始生成区域(prompt 之外扩展出来的那一段)
345
- init_mask_bool = (x == mask_token_id) # [B, L]
346
- init_mask_count = init_mask_bool.sum(dim=1) # [B]
347
- # 历史最大“被选 token 概率”(float32)
348
- rmax_conf = torch.zeros_like(x, dtype=torch.float32, device=x.device)
349
- logger.warning(
350
- "[RCR] Using PAPER version: running-max of SELECTED-TOKEN PROB; "
351
- "this overrides `conf_alg` (e.g., entropy) for remasking decisions."
352
- )
353
 
354
  x = generation_tokens_hook_func(None, x, None)
355
 
356
  for i in range(steps):
357
  mask_index = (x == mask_token_id)
358
 
359
- # 前向
360
  logits = self(x, attention_mask, tok_idx).logits
361
  logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
362
  logits = generation_logits_hook_func(i, x, logits)
363
 
 
364
  t = timesteps[i]
365
  s = timesteps[i + 1]
366
 
367
- if not rcr:
368
- # ===== vanilla 路径(保持你原来的实现)=====
369
- mask_logits = logits[mask_index]
370
- if alg == 'origin':
371
- p_transfer = 1 - s / t if i < steps - 1 else 1
372
- x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
373
- transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
374
- if transfer_index_t_s.any():
375
- logits_sub = mask_logits[transfer_index_t_s]
376
- logits_sub = _apply_top_p_k_temp(logits_sub, temperature, top_p, top_k)
377
- probs_sub = torch.softmax(logits_sub, dim=-1)
378
- try:
379
- x0_sel = dists.Categorical(probs=probs_sub).sample()
380
- except Exception:
381
- x0_sel = probs_sub.argmax(dim=-1)
382
- x0[transfer_index_t_s] = x0_sel
383
- x[mask_index] = x0.clone()
384
- else:
385
- # 按你 vanilla 的 top-k / alg_temp 逻辑
386
- mask_logits = _apply_top_p_k_temp(logits[mask_index], temperature, top_p, top_k)
387
- probs = torch.softmax(mask_logits, dim=-1)
388
- if temperature and temperature > 0:
389
- try:
390
- x0 = dists.Categorical(probs=probs).sample()
391
- confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
392
- except Exception:
393
- confidence, x0 = probs.max(dim=-1)
394
- else:
395
- confidence, x0 = probs.max(dim=-1)
396
-
397
- avg_mask_now = (mask_index.sum().item() / max(1, mask_index.shape[0]))
398
- ratio = (1.0 - (s.item() / t.item())) if i < steps - 1 else 1.0
399
- number_transfer_tokens = int(avg_mask_now * ratio)
400
-
401
- full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
402
- full_confidence[mask_index] = confidence
403
-
404
- if number_transfer_tokens > 0:
405
- if alg_temp is None or alg_temp == 0:
406
- _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
407
- else:
408
- full_confidence = full_confidence / alg_temp
409
- full_confidence = F.softmax(full_confidence, dim=-1)
410
- transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
411
- x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
412
- x_[mask_index] = x0.clone()
413
- row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
414
- x[row_indices, transfer_index] = x_[row_indices, transfer_index]
415
 
416
- else:
417
- # ===== 论文版 RCR =====
418
- # 1) 仅对当前 mask 的位置,做 top_p/top_k/temperature 过滤后采样(或贪心)
419
- mask_logits = logits[mask_index]
420
- mask_logits = _apply_top_p_k_temp(mask_logits, temperature, top_p, top_k)
421
- probs = torch.softmax(mask_logits, dim=-1)
422
-
423
- # 采样 / 贪心
424
- if temperature and temperature > 0:
425
- try:
426
- x0 = dists.Categorical(probs=probs).sample()
427
- except Exception:
428
- x0 = probs.argmax(dim=-1)
429
- else:
430
- x0 = probs.argmax(dim=-1)
431
 
432
- # 被选 token 的概率 p_sel(论文要求用这个做“历史置信度”)
433
- p_sel = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1) # [M], float32
434
-
435
- # 写入选中的 token
436
- x_maskwrite = torch.full_like(x, mask_token_id, dtype=torch.long)
437
- x_maskwrite[mask_index] = x0
438
- x = torch.where(mask_index, x_maskwrite, x)
439
-
440
- # 更新 running-max 置信度(float32)
441
- # 先铺到全长
442
- full_p_sel = torch.zeros_like(x, dtype=torch.float32)
443
- full_p_sel[mask_index] = p_sel.to(torch.float32)
444
- rmax_conf = torch.maximum(rmax_conf, full_p_sel)
445
-
446
- # 2) 基于 rmax_conf 直接确定“下一步要保留的已确认个数”,其余全部回遮
447
- self._apply_rcr_logic_paper(
448
- x=x,
449
- rmax_conf=rmax_conf,
450
- init_mask_bool=init_mask_bool,
451
- init_mask_count=init_mask_count,
452
- mask_token_id=mask_token_id,
453
- step=i,
454
- total_steps=steps,
455
- s=s, t=t,
456
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
 
458
  x = generation_tokens_hook_func(i, x, logits)
459
  if histories is not None:
 
1
+ # coding=utf-8
2
  import warnings
3
  import copy
4
  from dataclasses import dataclass
 
18
  if temperature and temperature > 0:
19
  logits = logits / temperature
20
  if top_p is not None and top_p < 1:
 
21
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
22
  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
23
  sorted_indices_to_remove = cumulative_probs > top_p
 
27
  mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
28
  logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
29
  if top_k is not None:
 
30
  top_k = int(min(top_k, logits.size(-1)))
31
  if top_k > 0:
32
  indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
 
34
  return logits
35
 
36
 
37
+ def _confidence_from_probs(
38
+ probs: torch.Tensor, # [..., V]
39
+ chosen_ids: Optional[torch.Tensor], # [...]
40
+ mode: str # 'entropy' | 'maskgit_plus' | 'topk_margin'
41
+ ) -> torch.Tensor:
42
+ """返回“越大越自信”的标量分数,与解码一致。"""
43
+ if mode == "entropy":
44
+ eps = 1e-10
45
+ logp = torch.log(probs + eps)
46
+ return -(probs * logp).sum(dim=-1) # -H(p)
47
+ elif mode == "maskgit_plus":
48
+ assert chosen_ids is not None, "maskgit_plus 需要 chosen_ids"
49
+ return torch.gather(probs, -1, chosen_ids.unsqueeze(-1)).squeeze(-1) # p(x0)
50
+ elif mode == "topk_margin":
51
+ sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
52
+ return sorted_probs[..., 0] - sorted_probs[..., 1] # top1 - top2
53
+ else:
54
+ raise ValueError(f"Unknown conf mode: {mode}")
55
+
56
+
57
  @dataclass
58
  class DreamModelOutput(ModelOutput):
59
  sequences: torch.LongTensor = None
 
71
  self.max_length = kwargs.pop("max_length", 20)
72
  self.max_new_tokens = kwargs.pop("max_new_tokens", None)
73
 
74
+ # diffusion
75
  self.eps: float = kwargs.pop("eps", 1e-3)
76
  self.steps: int = kwargs.pop("steps", 512)
77
+
78
+ # vanilla 的打分算法(rcr=False 时使用)
79
+ self.alg: str = kwargs.pop("alg", 'maskgit_plus') # 'origin' | 'maskgit_plus' | 'topk_margin' | 'entropy'
80
  self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
81
 
82
+ # === RCR ===
83
  self.rcr: bool = kwargs.pop("rcr", False)
84
+ # rcr=True 时用于解码 & 历史分一致的置信度定义
85
+ self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus') # 'maskgit_plus' | 'topk_margin' | 'entropy'
86
+ # 注意:下两项会被 _sample 内部“写死”为 1/4 到 3/4,总是覆盖
87
+ self.rcr_start_step: int = kwargs.pop("rcr_start_step", 0)
88
+ self.rcr_end_step: int = kwargs.pop("rcr_end_step", None) or self.steps
89
+ # 是否保护“本步刚写”的 token 不被回遮
90
+ self.rcr_protect_current_step: bool = kwargs.pop("rcr_protect_current_step", False)
91
 
92
  # outputs
93
  self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
 
119
  self.validate(is_init=True)
120
 
121
  def validate(self, is_init=False):
122
+ # 简单边界
123
+ self.rcr_start_step = max(0, int(self.rcr_start_step))
124
+ self.rcr_end_step = max(self.rcr_start_step, int(self.rcr_end_step))
125
 
126
 
127
  class DreamGenerationMixin:
 
139
  attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
140
  return input_ids, attention_mask
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
143
  if is_torchdynamo_compiling():
144
  return
145
  if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
146
  warnings.warn(
147
+ f"Using default `max_length` (={generation_config.max_length}). Prefer `max_new_tokens`.",
 
 
148
  UserWarning,
149
  )
150
  if input_ids_length >= generation_config.max_length:
 
156
  def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
157
  if generation_config.max_new_tokens is not None:
158
  if not has_default_max_length and generation_config.max_length is not None:
159
+ logger.warning("Both `max_new_tokens` and `max_length` are set. `max_new_tokens` takes precedence.")
 
 
160
  generation_config.max_length = generation_config.max_new_tokens + input_ids_length
161
  elif has_default_max_length:
162
  if generation_config.max_length == DreamGenerationConfig().max_length:
 
241
 
242
  if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
243
  warnings.warn(
244
+ "You are calling .generate() with `input_ids` on a device different from the model.",
245
  UserWarning,
246
  )
247
  if (
 
288
  top_p = generation_config.top_p
289
  top_k = generation_config.top_k
290
 
291
+ rcr = generation_config.rcr
292
+ conf_alg = generation_config.conf_alg if rcr else generation_config.alg
293
+
294
+ # === 写死 RCR 生效窗口:总步数的 1/4 到 3/4(左闭右开 [start, end))===
295
+ rcr_start = max(0, steps // 4)
296
+ rcr_end = max(rcr_start, min(steps, (3 * steps) // 4))
297
+
298
+ protect_cur = bool(generation_config.rcr_protect_current_step)
299
+
300
  histories = [] if (return_dict_in_generate and output_history) else None
301
 
302
  # pad input_ids to max_length
 
316
 
317
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
318
 
319
+ # ==== RCR 状态 ====
320
  if rcr:
321
+ init_mask_bool = (x == mask_token_id) # 初始生成区域
322
+ init_mask_count = init_mask_bool.sum(dim=1) # [B]
323
+ hist_conf = torch.zeros_like(x, dtype=torch.float32, device=x.device) # 历史最大置信度
324
+ gen_mask = torch.zeros_like(x, dtype=torch.bool, device=x.device) # 已确认位置
325
+ written_step = torch.full_like(x, -1, dtype=torch.int32, device=x.device)
 
 
 
 
326
 
327
  x = generation_tokens_hook_func(None, x, None)
328
 
329
  for i in range(steps):
330
  mask_index = (x == mask_token_id)
331
 
332
+ # 前向 + Dream 的右移对齐
333
  logits = self(x, attention_mask, tok_idx).logits
334
  logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
335
  logits = generation_logits_hook_func(i, x, logits)
336
 
337
+ # 时间步
338
  t = timesteps[i]
339
  s = timesteps[i + 1]
340
 
341
+ # —— 仅抽出 mask 位置的 logits 并做过滤 ——
342
+ mask_logits = logits[mask_index]
343
+ if mask_logits.numel() == 0:
344
+ x = generation_tokens_hook_func(i, x, logits)
345
+ if histories is not None:
346
+ histories.append(x.clone())
347
+ continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
 
349
+ mask_logits = _apply_top_p_k_temp(mask_logits, temperature, top_p, top_k)
350
+ probs = torch.softmax(mask_logits, dim=-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
351
 
352
+ # 采样 / 贪心拿到 x0
353
+ if temperature and temperature > 0:
354
+ try:
355
+ x0 = dists.Categorical(probs=probs).sample()
356
+ except Exception:
357
+ x0 = probs.argmax(dim=-1)
358
+ else:
359
+ x0 = probs.argmax(dim=-1)
360
+
361
+ # 统一置信度(与解码一致)
362
+ conf_now = _confidence_from_probs(
363
+ probs=probs,
364
+ chosen_ids=x0 if conf_alg == "maskgit_plus" else None,
365
+ mode=conf_alg
366
+ ).to(torch.float32) # [M]
367
+
368
+ # ====== 计算当步写入配额 k_t(与 vanilla 一致)======
369
+ Mt = mask_index.sum().item()
370
+ ratio = (1.0 - (s.item() / t.item())) if i < steps - 1 else 1.0
371
+ k_t = int(Mt * ratio)
372
+
373
+ # —— 写入:top-k_t ——(无论 RCR 窗口与否,先写)
374
+ full_conf_now = torch.full((x.size(0), x.size(1)), -1e9, dtype=torch.float32, device=x.device)
375
+ full_x0 = torch.full_like(x, mask_token_id, dtype=torch.long)
376
+ full_conf_now[mask_index] = conf_now
377
+ full_x0[mask_index] = x0
378
+
379
+ for b in range(x.size(0)):
380
+ masked_b = int(mask_index[b].sum().item())
381
+ if masked_b == 0 or k_t <= 0:
382
+ continue
383
+ k_b = min(k_t, masked_b)
384
+ _, sel_idx = torch.topk(full_conf_now[b], k=k_b, largest=True)
385
+ x[b, sel_idx] = full_x0[b, sel_idx]
386
+ if rcr:
387
+ gen_mask[b, sel_idx] = True
388
+ written_step[b, sel_idx] = i
389
+ # 更新历史最大置信度(与解码同定义)
390
+ hist_conf[b, sel_idx] = torch.maximum(hist_conf[b, sel_idx], full_conf_now[b, sel_idx])
391
+
392
+ # —— RCR 窗口外:不回遮,仅跟踪历史;窗口内:执行回遮到目标累计 ——
393
+ if rcr and (rcr_start <= i < rcr_end):
394
+ for b in range(x.size(0)):
395
+ M0 = int(init_mask_count[b].item())
396
+ target_cum = M0 if i >= steps - 1 else int(M0 * (1.0 - (s.item() / t.item())))
397
+ # 当前累计确认:初始生成区域内的已确认数
398
+ C_t = int((gen_mask[b] & init_mask_bool[b]).sum().item())
399
+ over = max(0, C_t - target_cum)
400
+ if over <= 0:
401
+ continue
402
+
403
+ # 候选:初始区域 ∧ 已确认(可选:排除本步刚写)
404
+ cand = torch.where(gen_mask[b] & init_mask_bool[b])[0]
405
+ if cand.numel() == 0:
406
+ continue
407
+ if protect_cur:
408
+ mask_old = (written_step[b, cand] < i)
409
+ cand = cand[mask_old]
410
+ if cand.numel() == 0:
411
+ # 全是本步写的,且要求保护,则跳过回遮
412
+ continue
413
+
414
+ over = min(over, int(cand.numel()))
415
+ scores = hist_conf[b, cand] # 越大越自信
416
+ _, low_local = torch.topk(scores, k=over, largest=False)
417
+ low_global = cand[low_local]
418
+
419
+ # 回遮
420
+ x[b, low_global] = mask_token_id
421
+ gen_mask[b, low_global] = False
422
+ # 历史分数与 written_step 保留
423
 
424
  x = generation_tokens_hook_func(i, x, logits)
425
  if histories is not None: