import warnings
import copy
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.distributions as dists
from torch.nn import functional as F
from transformers import __version__
from transformers.generation.configuration_utils import GenerationConfig
from transformers.utils import ModelOutput, is_torchdynamo_compiling, logging

logger = logging.get_logger(__name__)


def _apply_top_p_k_temp(logits, temperature=0.0, top_p=None, top_k=None):
    """Apply temperature scaling, then top-p, then top-k filtering to ``logits``.

    All operations act on the last dimension (the vocabulary). Filtered-out entries are
    set to ``torch.finfo(logits.dtype).min`` rather than ``-inf``.

    Args:
        logits: logits tensor whose last dimension is the vocabulary.
        temperature: if truthy and > 0, logits are divided by it; otherwise left unscaled.
        top_p: if not None and < 1, nucleus filtering. A token is removed when the
            cumulative probability of the tokens ranked above it already exceeds
            ``top_p``; the top-ranked token is always kept.
        top_k: if not None, clamped to the vocabulary size; when > 0 only logits at least
            as large as the k-th largest are kept.

    Returns:
        The filtered logits. The input tensor is never modified in place; it may be
        returned as-is when no option applies.
    """
    if temperature and temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:  # top-p
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the token that crosses the threshold is itself kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Map the removal flags from sorted order back to vocabulary order.
        mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
        mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
    if top_k is not None:  # top-k
        top_k = int(min(top_k, logits.size(-1)))
        if top_k > 0:
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
    return logits


@dataclass
class DreamModelOutput(ModelOutput):
    """Structured result of ``diffusion_generate`` when ``return_dict_in_generate`` is set."""

    # Final token ids, shape [B, max_length].
    sequences: torch.LongTensor = None
    # When ``output_history`` is also set: a clone of the token ids after every denoising
    # step (built as a list of integer tensors at runtime, despite the annotation).
    history: Optional[Tuple[torch.FloatTensor]] = None


class DreamGenerationConfig(GenerationConfig):
    """Generation settings for masked-diffusion decoding via ``diffusion_generate``.

    Each known option is popped from ``kwargs`` with a default. Any leftover keyword is
    set as an attribute unless ``_from_model_config`` is true. ``GenerationConfig.__init__``
    is not called.
    """

    def __init__(self, **kwargs):
        # sampling
        self.temperature: float = kwargs.pop("temperature", 0.0)
        self.top_p: Optional[float] = kwargs.pop("top_p", None)
        self.top_k: Optional[int] = kwargs.pop("top_k", None)
        # length
        self.max_length = kwargs.pop("max_length", 20)
        self.max_new_tokens = kwargs.pop("max_new_tokens", None)
        # diffusion specific params
        self.eps: float = kwargs.pop("eps", 1e-3)  # final timestep of the schedule
        self.steps: int = kwargs.pop("steps", 512)  # number of denoising iterations
        self.alg: str = kwargs.pop("alg", 'origin')  # used by the vanilla (non-RCR) path
        self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
        # RCR: re-masking driven by each position's running-max selected-token probability
        self.rcr: bool = kwargs.pop("rcr", False)
        # Note: paper-version RCR ignores conf_alg and always uses the selected-token
        # probability for its running max (conf_alg is not read by the sampling code here).
        self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')
        # outputs
        self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
        self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
        self.output_history: bool = kwargs.pop("output_history", False)
        # special tokens
        self.mask_token_id = kwargs.pop("mask_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)
        # misc
        self.generation_kwargs = kwargs.pop("generation_kwargs", {})
        # bookkeeping
        self._from_model_config = kwargs.pop("_from_model_config", False)
        self._commit_hash = kwargs.pop("_commit_hash", None)
        self.transformers_version = kwargs.pop("transformers_version", __version__)
        # Unknown keys are only accepted when the config was not derived from a model config.
        if not self._from_model_config:
            for key, value in kwargs.items():
                try:
                    setattr(self, key, value)
                except AttributeError as err:
                    logger.error(f"Can't set {key} with value {value} for {self}")
                    raise err
        self.validate(is_init=True)

    def validate(self, is_init=False):
        """No-op: this config performs no validation."""
        pass


class DreamGenerationMixin:
    """Adds masked-diffusion text generation (``diffusion_generate``) to a model class.

    The host class must provide ``self.config``, ``self.generation_config``, ``self.device``
    and a forward ``self(x, attention_mask, tok_idx)`` whose result has a ``.logits``
    tensor indexed below as [B, L, V].
    """

    @staticmethod
    def _expand_inputs_for_generation(
        expand_size: int = 1,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None
    ):
        """Repeat each batch row ``expand_size`` times (for ``num_return_sequences``)."""
        if expand_size == 1:
            return input_ids, attention_mask
        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
        if attention_mask is not None:
            attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
        return input_ids, attention_mask

    # ====== Paper-version RCR: running-max confidence + directly choose which positions to re-mask ======
    def _apply_rcr_logic_paper(
        self,
        x: torch.Tensor,  # [B, L] token ids, modified in place
        rmax_conf: torch.Tensor,  # [B, L], float32, running max of selected-token prob
        init_mask_bool: torch.Tensor,  # [B, L], initial generation region (positions masked at the start)
        init_mask_count: torch.Tensor,  # [B], number of initially masked positions M0
        mask_token_id: int,
        step: int,
        total_steps: int,
        s: torch.Tensor,
        t: torch.Tensor,
    ) -> None:
        """Re-mask low-confidence positions of the generation region, in place.

        Inside the initial generation region (``init_mask_bool``), the number of committed
        (unmasked) tokens follows vanilla's linear schedule. *Which* positions stay
        committed is chosen by the running-max confidence ``rmax_conf``: the highest-scoring
        positions are kept and every other region position is reset to ``mask_token_id``.

        For each row ``j`` with ``M0 = init_mask_count[j]``:
            target_cum = floor(M0 * (1 - s))   # M0 on the final step

        Why ``1 - s``: the schedule built in ``_sample`` runs from t=1 down to ``eps``.
        Vanilla's per-step transfer ratio ``1 - s/t`` therefore telescopes to a masked
        fraction of ``s`` after the step, i.e. a committed fraction of ``1 - s``. Using the
        per-step ratio ``1 - s/t`` as a cumulative target would keep almost nothing until
        the last few steps.

        Positions whose ``rmax_conf`` was never updated (still 0.0) are re-masked first.

        Args:
            x: current token ids; re-masked positions are overwritten in place.
            rmax_conf: per-position running max of the selected-token probability.
            init_mask_bool: boolean mask of the generation region.
            init_mask_count: per-row size of the generation region.
            mask_token_id: id written into re-masked positions.
            step: index of the current step (0-based).
            total_steps: total number of steps; the last step commits the whole region.
            s: next timestep (scalar tensor).
            t: current timestep; unused, kept for signature compatibility.
        """
        # Cumulative committed fraction after this step; identical for all rows.
        if step < total_steps - 1:
            committed_frac = 1.0 - s.item()
        else:
            committed_frac = 1.0

        for j in range(x.shape[0]):
            region_idx = torch.where(init_mask_bool[j])[0]
            n_region = int(region_idx.numel())
            if n_region == 0:
                continue
            m0 = int(init_mask_count[j].item())
            target_cum = min(int(m0 * committed_frac), n_region)
            if target_cum <= 0:
                # Nothing should be committed yet: mask the whole region.
                x[j, region_idx] = mask_token_id
                continue
            # Keep the target_cum region positions with the highest running-max confidence.
            scores = rmax_conf[j, region_idx]
            _, keep_local = torch.topk(scores, k=target_cum, largest=True)
            remask = torch.ones_like(region_idx, dtype=torch.bool, device=x.device)
            remask[keep_local] = False
            remask_idx = region_idx[remask]
            if remask_idx.numel() > 0:
                x[j, remask_idx] = mask_token_id
            # Kept positions retain the token currently written in x.

    def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
        """Warn about a default ``max_length`` and reject prompts that do not fit.

        Skipped entirely while torchdynamo is compiling.

        Raises:
            ValueError: if ``input_ids_length >= generation_config.max_length``.
        """
        if is_torchdynamo_compiling():
            return
        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
            warnings.warn(
                f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
                "generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
                "generation.",
                UserWarning,
            )
        if input_ids_length >= generation_config.max_length:
            raise ValueError(
                f"Input length is {input_ids_length}, but `max_length` is {generation_config.max_length}. "
                "Increase `max_length` or set `max_new_tokens`."
            )

    def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
        """Resolve the final total sequence length into ``generation_config.max_length``.

        ``max_new_tokens`` takes precedence (``max_length = max_new_tokens + prompt length``).
        Otherwise, if ``max_length`` is still the class default, the prompt length is added
        and the result is capped by ``self.config.max_position_embeddings`` when that exists.
        """
        if generation_config.max_new_tokens is not None:
            if not has_default_max_length and generation_config.max_length is not None:
                logger.warning(
                    f"Both `max_new_tokens` and `max_length` are set. `max_new_tokens` takes precedence."
                )
            generation_config.max_length = generation_config.max_new_tokens + input_ids_length
        elif has_default_max_length:
            if generation_config.max_length == DreamGenerationConfig().max_length:
                generation_config.max_length = generation_config.max_length + input_ids_length
                mpe = getattr(self.config, "max_position_embeddings", None)
                if mpe is not None:
                    generation_config.max_length = min(generation_config.max_length, mpe)
        return generation_config

    def _prepare_generation_config(self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict) -> DreamGenerationConfig:
        """Resolve the config used for one ``diffusion_generate`` call.

        If ``generation_config`` is None, one is built from ``self.config``. Outside
        torchdynamo compilation, the config is deep-copied (the caller's object is not
        mutated) and updated with ``kwargs`` via ``GenerationConfig.update``. For a
        caller-supplied config, unset bos/eos/pad/mask token ids are then filled from
        ``self.generation_config``.
        """
        using_model_generation_config = False
        if generation_config is None:
            generation_config = DreamGenerationConfig.from_model_config(self.config)
            using_model_generation_config = True
        if not is_torchdynamo_compiling():
            generation_config = copy.deepcopy(generation_config)
            _ = generation_config.update(**kwargs)
            if not using_model_generation_config:
                if generation_config.bos_token_id is None:
                    generation_config.bos_token_id = self.generation_config.bos_token_id
                if generation_config.eos_token_id is None:
                    generation_config.eos_token_id = self.generation_config.eos_token_id
                if generation_config.pad_token_id is None:
                    generation_config.pad_token_id = self.generation_config.pad_token_id
                if generation_config.mask_token_id is None:
                    generation_config.mask_token_id = self.generation_config.mask_token_id
        return generation_config

    def _prepare_special_tokens(self, generation_config: DreamGenerationConfig, device=None):
        """Store tensor versions of the special token ids on ``generation_config``.

        Sets ``_bos_token_tensor``, ``_eos_token_tensor`` (always at least 1-D),
        ``_pad_token_tensor`` and ``_mask_token_tensor``. If no pad token is set but an EOS
        token is, the first EOS id is used as the pad tensor.
        """
        def _tensor_or_none(token, device=None):
            if token is None:
                return token
            device = device if device is not None else self.device
            if isinstance(token, torch.Tensor):
                return token.to(device)
            return torch.tensor(token, device=device, dtype=torch.long)

        bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
        eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
        pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
        mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)

        if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
            eos_token_tensor = eos_token_tensor.unsqueeze(0)

        if pad_token_tensor is None and eos_token_tensor is not None:
            pad_token_tensor = eos_token_tensor[0]
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")

        generation_config._bos_token_tensor = bos_token_tensor
        generation_config._eos_token_tensor = eos_token_tensor
        generation_config._pad_token_tensor = pad_token_tensor
        generation_config._mask_token_tensor = mask_token_tensor

    @torch.no_grad()
    def diffusion_generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[DreamGenerationConfig] = None,
        **kwargs,
    ):
        """Generate by iterative masked-diffusion denoising.

        Args:
            inputs: prompt token ids (required).
            generation_config: optional ``DreamGenerationConfig``; ``kwargs`` override it.
            **kwargs: config overrides, plus optional ``attention_mask``,
                ``generation_tokens_hook_func(step, x, logits) -> x`` and
                ``generation_logits_hook_func(step, x, logits) -> logits``.

        Returns:
            The ``[B * num_return_sequences, max_length]`` token tensor, or a
            ``DreamModelOutput`` when ``return_dict_in_generate`` is set.
        """
        generation_config = self._prepare_generation_config(generation_config, **kwargs)
        generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
        generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)

        assert inputs is not None
        input_ids = inputs
        device = input_ids.device
        attention_mask = kwargs.pop("attention_mask", None)
        self._prepare_special_tokens(generation_config, device=device)

        input_ids_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            has_default_max_length=has_default_max_length,
            input_ids_length=input_ids_length,
        )
        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

        if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
            warnings.warn(
                "You are calling .generate() with `input_ids` on a different device than the model.",
                UserWarning,
            )
        # Compare against the pad id only when one is set: `tensor == None` yields a plain
        # Python bool, which `torch.any` rejects.
        pad_token_id = getattr(generation_config, "pad_token_id", None)
        if (
            pad_token_id is not None
            and attention_mask is None
            and torch.any(input_ids == pad_token_id)
        ):
            warnings.warn(
                "Padding detected but no attention mask was passed. Set `attention_mask` for correct generation.",
                UserWarning,
            )

        input_ids, attention_mask = self._expand_inputs_for_generation(
            expand_size=generation_config.num_return_sequences,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        return self._sample(
            input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            generation_tokens_hook_func=generation_tokens_hook_func,
            generation_logits_hook_func=generation_logits_hook_func,
        )

    def _sample(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor],
        generation_config: DreamGenerationConfig,
        generation_tokens_hook_func,
        generation_logits_hook_func
    ):
        """Run the denoising loop.

        ``input_ids`` is right-padded with ``mask_token_id`` up to ``max_length``, then
        refined for ``steps`` iterations over ``timesteps = linspace(1, eps, steps + 1)``.
        Each iteration runs the model on the full sequence (t = timesteps[i],
        s = timesteps[i + 1]) and fills masked positions:

        * ``rcr=False, alg == 'origin'``: each masked position is unmasked independently
          with probability ``1 - s/t`` (1 on the last step).
        * ``rcr=False``, any other ``alg``: ``int(avg_masked * (1 - s/t))`` positions per
          row (all remaining on the last step) are committed. They are picked by top-k
          selected-token probability, or sampled from ``softmax(conf / alg_temp)`` when
          ``alg_temp`` is set.
        * ``rcr=True``: every masked position is (re)sampled, its running-max selected-token
          probability is updated, and ``_apply_rcr_logic_paper`` re-masks the lowest-scoring
          region positions so the committed count follows the schedule.

        In all paths a token is sampled when ``temperature > 0`` and chosen greedily otherwise.

        Returns:
            ``DreamModelOutput`` if ``return_dict_in_generate``, else the final token tensor.
        """
        output_history = generation_config.output_history
        return_dict_in_generate = generation_config.return_dict_in_generate
        max_length = generation_config.max_length
        mask_token_id = generation_config.mask_token_id
        steps = generation_config.steps
        eps = generation_config.eps
        alg = generation_config.alg
        alg_temp = generation_config.alg_temp
        temperature = generation_config.temperature
        top_p = generation_config.top_p
        top_k = generation_config.top_k
        rcr = generation_config.rcr  # True -> paper-version RCR (running max of selected-token probability)

        histories = [] if (return_dict_in_generate and output_history) else None

        # pad input_ids to max_length with mask tokens; the appended span is what gets generated
        x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)

        if attention_mask is not None and torch.any(attention_mask == 0.0):
            # Padded batch: extend the mask over the generated span, derive position ids that
            # skip padding, and build a pairwise boolean mask (assumes a [B, L] input mask,
            # giving [B, 1, L, L]).
            attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
            tok_idx = attention_mask.long().cumsum(-1) - 1
            tok_idx.masked_fill_(attention_mask == 0, 1)
            attention_mask = torch.logical_and(
                attention_mask.unsqueeze(1).unsqueeze(-2),
                attention_mask.unsqueeze(1).unsqueeze(-1),
            )
        else:
            tok_idx = None
            attention_mask = "full"

        timesteps = torch.linspace(1, eps, steps + 1, device=x.device)

        x = generation_tokens_hook_func(None, x, None)

        if rcr:
            # Generation region = positions still masked after the initial hook (normally the
            # span appended after the prompt). Computing it after the hook keeps positions the
            # hook pre-fills out of the schedule count and out of re-masking.
            init_mask_bool = (x == mask_token_id)  # [B, L]
            init_mask_count = init_mask_bool.sum(dim=1)  # [B]
            # Running max of the selected-token probability per position (float32).
            rmax_conf = torch.zeros_like(x, dtype=torch.float32, device=x.device)
            logger.warning(
                "[RCR] Using PAPER version: running-max of SELECTED-TOKEN PROB; "
                "this overrides `conf_alg` (e.g., entropy) for remasking decisions."
            )

        for i in range(steps):
            mask_index = (x == mask_token_id)

            # forward pass
            logits = self(x, attention_mask, tok_idx).logits
            # Shift right by one: position k is scored with the output at k-1 (position 0
            # reuses its own). Presumably the model head predicts the next token — confirm.
            logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
            logits = generation_logits_hook_func(i, x, logits)

            t = timesteps[i]
            s = timesteps[i + 1]

            if not rcr:
                # ===== vanilla path =====
                mask_logits = logits[mask_index]
                if alg == 'origin':
                    p_transfer = 1 - s / t if i < steps - 1 else 1
                    x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
                    transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
                    if transfer_index_t_s.any():
                        logits_sub = _apply_top_p_k_temp(mask_logits[transfer_index_t_s], temperature, top_p, top_k)
                        probs_sub = torch.softmax(logits_sub, dim=-1)
                        if temperature and temperature > 0:
                            try:
                                x0_sel = dists.Categorical(probs=probs_sub).sample()
                            except Exception:
                                # Best-effort fallback (e.g. invalid probabilities): greedy.
                                x0_sel = probs_sub.argmax(dim=-1)
                        else:
                            # temperature <= 0 means greedy, as in the other decoding paths.
                            x0_sel = probs_sub.argmax(dim=-1)
                        x0[transfer_index_t_s] = x0_sel
                    x[mask_index] = x0.clone()
                else:
                    # Confidence-based unmasking with top-k or alg_temp sampling.
                    mask_logits = _apply_top_p_k_temp(mask_logits, temperature, top_p, top_k)
                    probs = torch.softmax(mask_logits, dim=-1)
                    if temperature and temperature > 0:
                        try:
                            x0 = dists.Categorical(probs=probs).sample()
                            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
                        except Exception:
                            confidence, x0 = probs.max(dim=-1)
                    else:
                        confidence, x0 = probs.max(dim=-1)

                    avg_mask_now = (mask_index.sum().item() / max(1, mask_index.shape[0]))
                    ratio = (1.0 - (s.item() / t.item())) if i < steps - 1 else 1.0
                    number_transfer_tokens = int(avg_mask_now * ratio)

                    full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
                    full_confidence[mask_index] = confidence
                    if number_transfer_tokens > 0:
                        if alg_temp is None or alg_temp == 0:
                            _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
                        else:
                            full_confidence = full_confidence / alg_temp
                            full_confidence = F.softmax(full_confidence, dim=-1)
                            transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
                        x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
                        x_[mask_index] = x0.clone()
                        row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
                        # When rows hold different numbers of masks, top-k can pick unmasked
                        # (-inf confidence) positions; only overwrite positions that were masked
                        # so prompt / committed tokens are never replaced by mask_token_id.
                        selected_was_masked = mask_index[row_indices, transfer_index]
                        x[row_indices, transfer_index] = torch.where(
                            selected_was_masked,
                            x_[row_indices, transfer_index],
                            x[row_indices, transfer_index],
                        )
            else:
                # ===== paper-version RCR =====
                # 1) For every currently masked position: top_p/top_k/temperature filtering,
                #    then sample (temperature > 0) or pick greedily.
                mask_logits = logits[mask_index]
                mask_logits = _apply_top_p_k_temp(mask_logits, temperature, top_p, top_k)
                probs = torch.softmax(mask_logits, dim=-1)
                if temperature and temperature > 0:
                    try:
                        x0 = dists.Categorical(probs=probs).sample()
                    except Exception:
                        x0 = probs.argmax(dim=-1)
                else:
                    x0 = probs.argmax(dim=-1)

                # Probability of the selected token: this step's confidence for the running max.
                p_sel = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)  # [M]

                # Write the selected tokens into the masked positions.
                x_maskwrite = torch.full_like(x, mask_token_id, dtype=torch.long)
                x_maskwrite[mask_index] = x0
                x = torch.where(mask_index, x_maskwrite, x)

                # Update the running-max confidence (scattered to full length, float32).
                full_p_sel = torch.zeros_like(x, dtype=torch.float32)
                full_p_sel[mask_index] = p_sel.to(torch.float32)
                rmax_conf = torch.maximum(rmax_conf, full_p_sel)

                # 2) Keep the schedule's number of committed tokens (highest rmax_conf) and
                #    re-mask the rest of the generation region.
                self._apply_rcr_logic_paper(
                    x=x,
                    rmax_conf=rmax_conf,
                    init_mask_bool=init_mask_bool,
                    init_mask_count=init_mask_count,
                    mask_token_id=mask_token_id,
                    step=i,
                    total_steps=steps,
                    s=s,
                    t=t,
                )

            x = generation_tokens_hook_func(i, x, logits)
            if histories is not None:
                histories.append(x.clone())

        if return_dict_in_generate:
            return DreamModelOutput(sequences=x, history=histories)
        else:
            return x