autoprogrammer committed on
Commit
8ddc04c
·
verified ·
1 Parent(s): 4e351cb

Update generation_utils.py

Browse files
Files changed (1) hide show
  1. generation_utils.py +192 -245
generation_utils.py CHANGED
@@ -1,18 +1,5 @@
1
  # coding=utf-8
2
- # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
  import warnings
17
  import copy
18
  from dataclasses import dataclass
@@ -22,74 +9,106 @@ import torch
22
  import torch.distributions as dists
23
  from torch.nn import functional as F
24
  from transformers import __version__
25
- from transformers.generation.configuration_utils import (
26
- GenerationConfig
27
- )
28
- from transformers.utils import (
29
- ModelOutput,
30
- is_torchdynamo_compiling,
31
- logging,
32
- )
33
 
34
  logger = logging.get_logger(__name__)
35
 
36
 
37
  def top_p_logits(logits, top_p=None):
 
 
38
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
39
  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
40
  sorted_indices_to_remove = cumulative_probs > top_p
41
- # Shift the indices to the right to keep the first token above the threshold
42
  sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
43
  sorted_indices_to_remove[..., 0] = 0
44
-
45
  mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
46
  mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
47
- logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
48
- return logits
49
-
50
- def top_k_logits(logits, top_k=None):
51
- top_k = min(top_k, logits.size(-1)) # Safety check
52
- # Remove all tokens with a probability less than the last token of the top-k
53
- indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
54
- logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
55
- return logits
56
-
57
 
58
- def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  if temperature > 0:
61
  logits = logits / temperature
62
- if top_p is not None and top_p < 1:
63
- logits = top_p_logits(logits, top_p)
64
- if top_k is not None:
65
- logits = top_k_logits(logits, top_k)
 
 
66
  probs = torch.softmax(logits, dim=-1)
67
 
 
68
  if temperature > 0:
69
  try:
70
  x0 = dists.Categorical(probs=probs).sample()
71
  confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
72
- except:
73
  confidence, x0 = probs.max(dim=-1)
74
  else:
75
  confidence, x0 = probs.max(dim=-1)
76
-
 
77
  if margin_confidence:
78
  sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
79
- # Extract top1 and top2 probabilities
80
- top1_probs = sorted_probs[:, 0]
81
- top2_probs = sorted_probs[:, 1]
82
- # Calculate confidence as top1 - top2
83
- confidence = top1_probs - top2_probs
84
-
85
  if neg_entropy:
86
  epsilon = 1e-10
87
  log_probs = torch.log(probs + epsilon)
88
- confidence = torch.sum(probs * log_probs, dim=-1)
89
-
 
90
  return confidence, x0
91
 
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  @dataclass
94
  class DreamModelOutput(ModelOutput):
95
  sequences: torch.LongTensor = None
@@ -106,19 +125,20 @@ class DreamGenerationConfig(GenerationConfig):
106
  # diffusion specific params
107
  self.eps: float = kwargs.pop("eps", 1e-3)
108
  self.steps: int = kwargs.pop("steps", 512)
109
- self.alg: str = kwargs.pop("alg", 'origin')
110
  self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
111
 
112
  # RCR specific parameters
113
  self.rcr: bool = kwargs.pop("rcr", False)
114
- self.conf_alg: str = kwargs.pop("conf_alg", 'maskgit_plus')
 
115
 
116
- # Parameters that define the output variables of `generate`
117
  self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
118
  self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
119
  self.output_history: bool = kwargs.pop("output_history", False)
120
 
121
- # Special tokens that can be used at generation time
122
  self.mask_token_id = kwargs.pop("mask_token_id", None)
123
  self.pad_token_id = kwargs.pop("pad_token_id", None)
124
  self.bos_token_id = kwargs.pop("bos_token_id", None)
@@ -127,16 +147,12 @@ class DreamGenerationConfig(GenerationConfig):
127
  # Wild card
128
  self.generation_kwargs = kwargs.pop("generation_kwargs", {})
129
 
130
- # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
131
- # interface.
132
  self._from_model_config = kwargs.pop("_from_model_config", False)
133
  self._commit_hash = kwargs.pop("_commit_hash", None)
134
  self.transformers_version = kwargs.pop("transformers_version", __version__)
135
 
136
- # Additional attributes without default values
137
  if not self._from_model_config:
138
- # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
139
- # model's default configuration file
140
  for key, value in kwargs.items():
141
  try:
142
  setattr(self, key, value)
@@ -144,22 +160,19 @@ class DreamGenerationConfig(GenerationConfig):
144
  logger.error(f"Can't set {key} with value {value} for {self}")
145
  raise err
146
 
147
- # Validate the values of the attributes
148
  self.validate(is_init=True)
149
 
150
  def validate(self, is_init=False):
151
  pass
152
 
 
153
  class DreamGenerationMixin:
154
  @staticmethod
155
  def _expand_inputs_for_generation(
156
  expand_size: int = 1,
157
  input_ids: Optional[torch.LongTensor] = None,
158
- attention_mask: Optional[torch.LongTensor] = None
159
  ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
160
- """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
161
- # Do not call torch.repeat_interleave if expand_size is 1 because it clones
162
- # the input tensor and thus requires more memory although no change is applied
163
  if expand_size == 1:
164
  return input_ids, attention_mask
165
  if input_ids is not None:
@@ -168,132 +181,47 @@ class DreamGenerationMixin:
168
  attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
169
  return input_ids, attention_mask
170
 
171
- def _apply_rcr_logic(self, x, x0, confidence, mask_index, overtime_confidence,
172
- mask_token_id, step, total_steps, s, t):
173
- """
174
- Apply Running Confidence Remasking (RCR) logic adapted for Dream model.
175
- """
176
- batch_size = x.shape[0]
177
-
178
- # Calculate number of tokens to transfer using Dream's scheduling
179
- num_mask_token = mask_index.sum() / mask_index.shape[0]
180
- number_transfer_tokens = int(num_mask_token * (1 - s / t)) if step < total_steps - 1 else int(num_mask_token)
181
-
182
- # Create full confidence tensor matching x dimensions
183
- full_confidence = torch.full_like(x, -torch.inf, device=x.device, dtype=confidence.dtype)
184
-
185
- # Create temporary tensor for x0 that matches x dimensions
186
- x_temp = torch.zeros_like(x, device=x.device, dtype=torch.long) + mask_token_id
187
-
188
- # Fill masked positions with x0 and confidence
189
- x_temp[mask_index] = x0.clone()
190
- full_confidence[mask_index] = confidence
191
-
192
- # RCR: Select tokens based on cumulative confidence
193
- for j in range(batch_size):
194
- if number_transfer_tokens > 0:
195
- batch_full_confidence = full_confidence[j]
196
-
197
- # Select top confident tokens to transfer
198
- _, select_indices = torch.topk(batch_full_confidence, k=number_transfer_tokens, largest=True)
199
- x[j, select_indices] = x_temp[j, select_indices]
200
- overtime_confidence[j, select_indices] = batch_full_confidence[select_indices].clone().float()
201
-
202
- # RCR: Re-mask lowest confidence tokens for next steps
203
- if step < total_steps - 1:
204
- # Find tokens that have been generated (non-zero confidence)
205
- generated_mask = overtime_confidence[j] > 0
206
- if generated_mask.any():
207
- # Calculate tokens to re-mask for next iteration
208
- next_num_mask_tokens = int(num_mask_token * (1 - torch.linspace(1, s, total_steps + 1, device=x.device)[step + 2] / t))
209
-
210
- if next_num_mask_tokens > 0:
211
- # Get confidence of generated tokens
212
- generated_confidence = overtime_confidence[j][generated_mask]
213
- generated_indices = torch.where(generated_mask)[0]
214
-
215
- if len(generated_confidence) >= next_num_mask_tokens:
216
- # Re-mask lowest confidence tokens
217
- _, local_mask_indices = torch.topk(
218
- generated_confidence,
219
- k=next_num_mask_tokens,
220
- largest=False
221
- )
222
- global_mask_indices = generated_indices[local_mask_indices]
223
- x[j, global_mask_indices] = mask_token_id
224
- overtime_confidence[j, global_mask_indices] = 0.0
225
-
226
  def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
227
- """Performs validation related to the resulting generated length"""
228
-
229
- # Can't throw warnings/exceptions during compilation
230
  if is_torchdynamo_compiling():
231
  return
232
-
233
- # 1. Max length warnings related to poor parameterization
234
  if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
235
- # 20 is the default max_length of the generation config
236
  warnings.warn(
237
- f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
238
- "generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
239
- "generation.",
240
  UserWarning,
241
  )
242
  if input_ids_length >= generation_config.max_length:
243
- input_ids_string = "input_ids"
244
  raise ValueError(
245
- f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
246
- f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
247
- " increasing `max_length` or, better yet, setting `max_new_tokens`."
248
  )
249
 
250
- def _prepare_generated_length(
251
- self,
252
- generation_config,
253
- has_default_max_length,
254
- input_ids_length,
255
- ):
256
- """Prepared max and min length in generation configs to avoid clashes between similar attributes"""
257
-
258
  if generation_config.max_new_tokens is not None:
259
  if not has_default_max_length and generation_config.max_length is not None:
260
  logger.warning(
261
- f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
262
- f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
263
- "Please refer to the documentation for more information. "
264
- "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
265
  )
266
  generation_config.max_length = generation_config.max_new_tokens + input_ids_length
267
-
268
  elif has_default_max_length:
269
  if generation_config.max_length == DreamGenerationConfig().max_length:
270
  generation_config.max_length = generation_config.max_length + input_ids_length
271
  max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
272
  if max_position_embeddings is not None:
273
  generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
274
-
275
  return generation_config
276
 
277
  def _prepare_generation_config(
278
  self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict
279
  ) -> DreamGenerationConfig:
280
- """
281
- Prepares the base generation config, then applies any generation configuration options from kwargs. This
282
- function handles retrocompatibility with respect to configuration files.
283
- """
284
- # priority: `generation_config` argument > `model.generation_config` (the default generation config)
285
  using_model_generation_config = False
286
  if generation_config is None:
287
  generation_config = DreamGenerationConfig.from_model_config(self.config)
288
  using_model_generation_config = True
289
 
290
- # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
291
- # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
292
- # exception will be raised in `_validate_model_kwargs`
293
  if not is_torchdynamo_compiling():
294
  generation_config = copy.deepcopy(generation_config)
295
  _kwargs = generation_config.update(**kwargs)
296
- # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
297
  if not using_model_generation_config:
298
  if generation_config.bos_token_id is None:
299
  generation_config.bos_token_id = self.generation_config.bos_token_id
@@ -311,20 +239,9 @@ class DreamGenerationMixin:
311
  generation_config: DreamGenerationConfig,
312
  device: Optional[Union[torch.device, str]] = None,
313
  ):
314
- """
315
- Prepares the special tokens for generation, overwriting the generation config with their processed versions
316
- converted to tensor.
317
-
318
- Note that `generation_config` is changed in place and stops being serializable after this method is called.
319
- That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the
320
- function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
321
- """
322
-
323
- # Convert special tokens to tensors
324
  def _tensor_or_none(token, device=None):
325
  if token is None:
326
  return token
327
-
328
  device = device if device is not None else self.device
329
  if isinstance(token, torch.Tensor):
330
  return token.to(device)
@@ -335,19 +252,13 @@ class DreamGenerationMixin:
335
  pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
336
  mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)
337
 
338
- # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
339
  if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
340
  eos_token_tensor = eos_token_tensor.unsqueeze(0)
341
 
342
- # Set pad token if unset (and there are conditions to do so)
343
  if pad_token_tensor is None and eos_token_tensor is not None:
344
  pad_token_tensor = eos_token_tensor[0]
345
  logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
346
 
347
- # Update generation config with the updated special tokens tensors
348
- # NOTE: this must be written into a different attribute name than the one holding the original special tokens
349
- # (in their non-tensor form), in order to enable end-to-end compilation. See
350
- # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
351
  generation_config._bos_token_tensor = bos_token_tensor
352
  generation_config._eos_token_tensor = eos_token_tensor
353
  generation_config._pad_token_tensor = pad_token_tensor
@@ -360,19 +271,16 @@ class DreamGenerationMixin:
360
  generation_config: Optional[DreamGenerationConfig] = None,
361
  **kwargs,
362
  ) -> Union[DreamModelOutput, torch.LongTensor]:
363
- # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
364
  generation_config = self._prepare_generation_config(generation_config, **kwargs)
365
  generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
366
  generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
367
 
368
- # 2. Define model inputs
369
  assert inputs is not None
370
  input_ids = inputs
371
  device = input_ids.device
372
  attention_mask = kwargs.pop("attention_mask", None)
373
  self._prepare_special_tokens(generation_config, device=device)
374
 
375
- # 3. Prepare `max_length`.
376
  input_ids_length = input_ids.shape[-1]
377
  has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
378
  generation_config = self._prepare_generated_length(
@@ -380,35 +288,23 @@ class DreamGenerationMixin:
380
  has_default_max_length=has_default_max_length,
381
  input_ids_length=input_ids_length,
382
  )
383
-
384
  self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
385
-
386
- # 4. Check input_ids
387
  if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
388
  warnings.warn(
389
- "You are calling .generate() with the `input_ids` being on a device type different"
390
- f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
391
- f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
392
- " Please make sure that you have put `input_ids` to the"
393
- f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
394
- " running `.generate()`.",
395
  UserWarning,
396
  )
397
- if (
398
- hasattr(generation_config, "pad_token_id") and
399
- torch.any(input_ids == generation_config.pad_token_id) and
400
- attention_mask is None
401
- ):
402
  warnings.warn(
403
- "Padding was detected but no attention mask is passed here. For correct "
404
- "generation results, please set `attention_mask` when batch-padding inputs.",
405
  UserWarning,
406
  )
407
 
408
  input_ids, attention_mask = self._expand_inputs_for_generation(
409
  expand_size=generation_config.num_return_sequences,
410
  input_ids=input_ids,
411
- attention_mask=attention_mask
412
  )
413
 
414
  result = self._sample(
@@ -416,7 +312,7 @@ class DreamGenerationMixin:
416
  attention_mask=attention_mask,
417
  generation_config=generation_config,
418
  generation_tokens_hook_func=generation_tokens_hook_func,
419
- generation_logits_hook_func=generation_logits_hook_func
420
  )
421
  return result
422
 
@@ -426,9 +322,10 @@ class DreamGenerationMixin:
426
  attention_mask: Optional[torch.LongTensor],
427
  generation_config: DreamGenerationConfig,
428
  generation_tokens_hook_func,
429
- generation_logits_hook_func
430
  ) -> Union[DreamModelOutput, torch.LongTensor]:
431
- # init values
 
432
  output_history = generation_config.output_history
433
  return_dict_in_generate = generation_config.return_dict_in_generate
434
  max_length = generation_config.max_length
@@ -441,22 +338,20 @@ class DreamGenerationMixin:
441
  top_p = generation_config.top_p
442
  top_k = generation_config.top_k
443
 
444
- # RCR specific values
445
  rcr = generation_config.rcr
446
  conf_alg = generation_config.conf_alg
 
447
 
448
  histories = [] if (return_dict_in_generate and output_history) else None
449
 
450
- # pad input_ids to max_length
451
  x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
452
 
453
  if attention_mask is not None and torch.any(attention_mask == 0.0):
454
- # we do not mask the [MASK] tokens so value = 1.0
455
  attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
456
  tok_idx = attention_mask.long().cumsum(-1) - 1
457
  tok_idx.masked_fill_(attention_mask == 0, 1)
458
- # attention_mask is of shape [B, N]
459
- # broadcast to [B, 1, N, N]
460
  attention_mask = torch.logical_and(
461
  attention_mask.unsqueeze(1).unsqueeze(-2),
462
  attention_mask.unsqueeze(1).unsqueeze(-1),
@@ -465,74 +360,126 @@ class DreamGenerationMixin:
465
  tok_idx = None
466
  attention_mask = "full"
467
 
 
468
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
469
 
470
- # RCR tracking - initialize overtime confidence tracking
 
 
 
 
471
  overtime_confidence = torch.zeros_like(x, dtype=torch.float32) if rcr else None
472
 
473
- # this allows user-defined token control of the intermediate steps
474
  x = generation_tokens_hook_func(None, x, None)
 
475
  for i in range(steps):
476
- mask_index = (x == mask_token_id)
 
 
 
477
  logits = self(x, attention_mask, tok_idx).logits
478
- logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
479
 
480
- # this allows user-defined logits control of the intermediate steps
481
  logits = generation_logits_hook_func(i, x, logits)
482
 
483
- mask_logits = logits[mask_index]
 
484
  t = timesteps[i]
485
  s = timesteps[i + 1]
486
-
487
- if alg == 'origin':
488
- p_transfer = 1 - s / t if i < steps - 1 else 1
489
- x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
490
- transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
491
- _, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
 
 
 
 
 
 
492
  x[mask_index] = x0.clone()
493
  else:
494
- if alg == 'maskgit_plus' or (rcr and conf_alg == 'maskgit_plus'):
 
 
495
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
496
- elif alg == 'topk_margin' or (rcr and conf_alg == 'topk_margin'):
497
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
498
- elif alg == 'entropy' or (rcr and conf_alg == 'entropy'):
499
- confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
500
  else:
501
- raise RuntimeError(f"Unknown alg: {alg}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
502
 
503
- # Apply RCR logic if enabled
504
- if rcr:
505
- print(f"[RCR EXEC] Step {i}: RCR logic executed")
506
- self._apply_rcr_logic(x, x0, confidence, mask_index, overtime_confidence,
507
- mask_token_id, i, steps, s, t)
508
  else:
509
- # Original Dream sampling logic
510
- num_mask_token = mask_index.sum() / mask_index.shape[0]
511
- number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
512
- full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
513
- full_confidence[mask_index] = confidence
514
- if number_transfer_tokens > 0:
515
- if alg_temp is None or alg_temp == 0:
516
- _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
517
- else:
518
- full_confidence = full_confidence / alg_temp
519
- full_confidence = F.softmax(full_confidence, dim=-1)
520
- transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
521
- x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
522
- x_[mask_index] = x0.clone()
523
- row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
524
- x[row_indices,transfer_index] = x_[row_indices,transfer_index]
525
-
526
- # this allows user-defined token control of the intermediate steps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
  x = generation_tokens_hook_func(i, x, logits)
528
 
529
  if histories is not None:
530
  histories.append(x.clone())
531
-
532
  if return_dict_in_generate:
533
- return DreamModelOutput(
534
- sequences=x,
535
- history=histories,
536
- )
537
  else:
538
- return x
 
1
  # coding=utf-8
2
+ # Copyright 2024 ...
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import warnings
4
  import copy
5
  from dataclasses import dataclass
 
9
  import torch.distributions as dists
10
  from torch.nn import functional as F
11
  from transformers import __version__
12
+ from transformers.generation.configuration_utils import GenerationConfig
13
+ from transformers.utils import ModelOutput, is_torchdynamo_compiling, logging
 
 
 
 
 
 
14
 
15
  logger = logging.get_logger(__name__)
16
 
17
 
18
def top_p_logits(logits, top_p=None):
    """Apply nucleus (top-p) filtering to a logits tensor.

    Tokens outside the smallest set whose cumulative probability exceeds
    ``top_p`` are pushed to the dtype minimum, so a later softmax gives them
    (near-)zero mass. The single most likely token is always kept. ``top_p``
    of ``None`` or ``>= 1`` disables filtering and returns ``logits`` as-is.
    """
    if top_p is None or top_p >= 1:
        return logits

    desc_logits, desc_order = torch.sort(logits, descending=True)
    cum_probs = F.softmax(desc_logits, dim=-1).cumsum(dim=-1)

    # Positions (in descending-probability order) past the nucleus boundary.
    drop_sorted = cum_probs > top_p
    # Shift right by one so the first token that crosses the threshold stays.
    drop_sorted = torch.roll(drop_sorted, shifts=1, dims=-1)
    drop_sorted[..., 0] = False

    # Scatter the sorted-order decisions back to vocabulary order.
    drop = torch.zeros_like(logits, dtype=torch.bool).scatter(-1, desc_order, drop_sorted)
    return logits.masked_fill(drop, torch.finfo(logits.dtype).min)
 
 
 
 
 
 
 
 
 
29
 
 
30
 
31
def top_k_logits(logits, top_k=None):
    """Apply top-k filtering to a logits tensor.

    Every token scoring below the k-th highest logit is pushed to the dtype
    minimum so a later softmax gives it (near-)zero mass. ``top_k`` of
    ``None`` or ``<= 0`` disables filtering (previously ``top_k=0`` crashed
    when indexing the empty result of ``torch.topk``).
    """
    if top_k is None or top_k <= 0:
        return logits
    top_k = min(int(top_k), logits.size(-1))  # never ask for more than vocab size
    # Threshold = the smallest logit still inside the top-k set.
    thresh = torch.topk(logits, top_k)[0][..., -1, None]
    indices_to_remove = logits < thresh
    return logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
38
+
39
+
40
def sample_tokens(
    logits,
    temperature=0.0,
    top_p=None,
    top_k=None,
    margin_confidence=False,
    neg_entropy=False,
):
    """Pick one token per position and return a confidence score for it.

    Args:
        logits: ``[..., vocab]`` unnormalized scores.
        temperature: ``> 0`` enables stochastic sampling after scaling;
            ``0`` (default) means greedy argmax.
        top_p: optional nucleus filtering threshold (``None`` disables).
        top_k: optional top-k filtering (``None`` disables).
        margin_confidence: if True, confidence = top-1 prob minus top-2 prob.
        neg_entropy: if True, confidence = negative entropy ``sum(p * log p)``.

    Returns:
        ``(confidence, x0)`` where ``x0`` is the chosen token id per position
        and ``confidence`` is the selected scoring signal.
    """
    if temperature > 0:
        logits = logits / temperature

    # Filtering is a no-op when top_p / top_k are disabled.
    logits = top_p_logits(logits, top_p)
    logits = top_k_logits(logits, top_k)

    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            # Degenerate rows (e.g. NaN/inf probabilities) fall back to argmax.
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        # Confidence as the gap between the two most likely tokens.
        confidence = sorted_probs[..., 0] - sorted_probs[..., 1]

    if neg_entropy:
        epsilon = 1e-10
        log_probs = torch.log(probs + epsilon)
        # NOTE: negative entropy = sum(p * log p) <= 0; values closer to 0
        # mean a more peaked (more certain) distribution. The previous code
        # returned +entropy (`-(p * log p).sum()`), which inverted the
        # ranking used by confidence-based decoding.
        confidence = (probs * log_probs).sum(dim=-1)

    return confidence, x0
83
 
84
 
85
def get_num_transfer_tokens_maskgit(mask_index: torch.Tensor, steps: int, mode: str = "linear") -> torch.Tensor:
    """LLaDA-style schedule: precompute, per sample, how many masked tokens to
    decode ("transfer") at each step, guaranteeing the per-step counts sum to
    the total number of masked tokens.

    Args:
        mask_index: ``[B, L]`` bool tensor marking masked positions.
        steps: number of decoding steps.
        mode: schedule shape — "linear", "cosine", "pow2" or "sqrt".

    Returns:
        ``[B, steps]`` long tensor of per-step budgets. Rounding the cumulative
        quota may shift individual steps by +/-1, but each row still sums to
        that sample's mask count.
    """
    total_masked = mask_index.sum(dim=-1, keepdim=True).float()  # [B, 1]
    frac = torch.linspace(0, 1, steps + 1, device=mask_index.device)[1:]  # (steps,)

    schedules = {
        "linear": lambda u: u,
        "cosine": lambda u: 1 - torch.cos(u * torch.pi / 2),
        "pow2": lambda u: u ** 2,
        "sqrt": lambda u: torch.sqrt(u),
    }
    if mode not in schedules:
        raise ValueError(f"Unknown mode: {mode}")
    ratio = schedules[mode](frac)

    # Round the cumulative quota, then difference it into per-step counts.
    cumulative = (ratio.unsqueeze(0) * total_masked).round().long()  # [B, steps]
    return torch.diff(cumulative, dim=-1, prepend=torch.zeros_like(cumulative[:, :1]))
110
+
111
+
112
  @dataclass
113
  class DreamModelOutput(ModelOutput):
114
  sequences: torch.LongTensor = None
 
125
  # diffusion specific params
126
  self.eps: float = kwargs.pop("eps", 1e-3)
127
  self.steps: int = kwargs.pop("steps", 512)
128
+ self.alg: str = kwargs.pop("alg", "origin")
129
  self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
130
 
131
  # RCR specific parameters
132
  self.rcr: bool = kwargs.pop("rcr", False)
133
+ self.conf_alg: str = kwargs.pop("conf_alg", "maskgit_plus")
134
+ self.mode: str = kwargs.pop("mode", "linear") # LLaDA 调度
135
 
136
+ # Output control
137
  self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
138
  self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
139
  self.output_history: bool = kwargs.pop("output_history", False)
140
 
141
+ # Special tokens
142
  self.mask_token_id = kwargs.pop("mask_token_id", None)
143
  self.pad_token_id = kwargs.pop("pad_token_id", None)
144
  self.bos_token_id = kwargs.pop("bos_token_id", None)
 
147
  # Wild card
148
  self.generation_kwargs = kwargs.pop("generation_kwargs", {})
149
 
150
+ # Hub info
 
151
  self._from_model_config = kwargs.pop("_from_model_config", False)
152
  self._commit_hash = kwargs.pop("_commit_hash", None)
153
  self.transformers_version = kwargs.pop("transformers_version", __version__)
154
 
 
155
  if not self._from_model_config:
 
 
156
  for key, value in kwargs.items():
157
  try:
158
  setattr(self, key, value)
 
160
  logger.error(f"Can't set {key} with value {value} for {self}")
161
  raise err
162
 
 
163
  self.validate(is_init=True)
164
 
165
  def validate(self, is_init=False):
166
  pass
167
 
168
+
169
  class DreamGenerationMixin:
170
  @staticmethod
171
  def _expand_inputs_for_generation(
172
  expand_size: int = 1,
173
  input_ids: Optional[torch.LongTensor] = None,
174
+ attention_mask: Optional[torch.LongTensor] = None,
175
  ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
 
 
 
176
  if expand_size == 1:
177
  return input_ids, attention_mask
178
  if input_ids is not None:
 
181
  attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
182
  return input_ids, attention_mask
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
    """Validate the resolved generation length against the prompt length.

    Warns when only the library default `max_length` (20) is in effect, and
    raises if the prompt alone already reaches `max_length` (nothing could be
    generated).
    """
    # Can't emit warnings/exceptions while being compiled by torchdynamo.
    if is_torchdynamo_compiling():
        return
    # 20 is the library-wide default max_length of the generation config.
    if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
        warnings.warn(
            f"Using the model-agnostic default `max_length` (={generation_config.max_length}); "
            f"prefer setting `max_new_tokens`.",
            UserWarning,
        )
    if input_ids_length >= generation_config.max_length:
        raise ValueError(
            f"Input length is {input_ids_length}, but `max_length` is set to {generation_config.max_length}. "
            f"Increase `max_length` or set `max_new_tokens`."
        )
198
 
199
def _prepare_generated_length(self, generation_config, has_default_max_length, input_ids_length):
    """Resolve `generation_config.max_length`, reconciling it with `max_new_tokens`.

    `max_new_tokens`, when set, takes precedence and is converted to an
    absolute `max_length` by adding the prompt length. Returns the (mutated)
    `generation_config`.
    """
    if generation_config.max_new_tokens is not None:
        # Both knobs set by the user: `max_new_tokens` wins.
        if not has_default_max_length and generation_config.max_length is not None:
            logger.warning(
                f"Both `max_new_tokens` and `max_length` set. `max_new_tokens` takes precedence."
            )
        generation_config.max_length = generation_config.max_new_tokens + input_ids_length
    elif has_default_max_length:
        # The untouched default `max_length` is treated as a budget of new
        # tokens, so shift it by the prompt length.
        if generation_config.max_length == DreamGenerationConfig().max_length:
            generation_config.max_length = generation_config.max_length + input_ids_length
        # Never exceed the model's positional capacity, when it is known.
        max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
        if max_position_embeddings is not None:
            generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
    return generation_config
213
 
214
  def _prepare_generation_config(
215
  self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict
216
  ) -> DreamGenerationConfig:
 
 
 
 
 
217
  using_model_generation_config = False
218
  if generation_config is None:
219
  generation_config = DreamGenerationConfig.from_model_config(self.config)
220
  using_model_generation_config = True
221
 
 
 
 
222
  if not is_torchdynamo_compiling():
223
  generation_config = copy.deepcopy(generation_config)
224
  _kwargs = generation_config.update(**kwargs)
 
225
  if not using_model_generation_config:
226
  if generation_config.bos_token_id is None:
227
  generation_config.bos_token_id = self.generation_config.bos_token_id
 
239
  generation_config: DreamGenerationConfig,
240
  device: Optional[Union[torch.device, str]] = None,
241
  ):
 
 
 
 
 
 
 
 
 
 
242
  def _tensor_or_none(token, device=None):
243
  if token is None:
244
  return token
 
245
  device = device if device is not None else self.device
246
  if isinstance(token, torch.Tensor):
247
  return token.to(device)
 
252
  pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
253
  mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)
254
 
 
255
  if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
256
  eos_token_tensor = eos_token_tensor.unsqueeze(0)
257
 
 
258
  if pad_token_tensor is None and eos_token_tensor is not None:
259
  pad_token_tensor = eos_token_tensor[0]
260
  logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
261
 
 
 
 
 
262
  generation_config._bos_token_tensor = bos_token_tensor
263
  generation_config._eos_token_tensor = eos_token_tensor
264
  generation_config._pad_token_tensor = pad_token_tensor
 
271
  generation_config: Optional[DreamGenerationConfig] = None,
272
  **kwargs,
273
  ) -> Union[DreamModelOutput, torch.LongTensor]:
 
274
  generation_config = self._prepare_generation_config(generation_config, **kwargs)
275
  generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
276
  generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
277
 
 
278
  assert inputs is not None
279
  input_ids = inputs
280
  device = input_ids.device
281
  attention_mask = kwargs.pop("attention_mask", None)
282
  self._prepare_special_tokens(generation_config, device=device)
283
 
 
284
  input_ids_length = input_ids.shape[-1]
285
  has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
286
  generation_config = self._prepare_generated_length(
 
288
  has_default_max_length=has_default_max_length,
289
  input_ids_length=input_ids_length,
290
  )
 
291
  self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
292
+
 
293
  if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
294
  warnings.warn(
295
+ "You are calling .generate() with input_ids on a different device than the model.",
 
 
 
 
 
296
  UserWarning,
297
  )
298
+ if hasattr(generation_config, "pad_token_id") and torch.any(input_ids == generation_config.pad_token_id) and attention_mask is None:
 
 
 
 
299
  warnings.warn(
300
+ "Padding detected but no attention_mask is passed. For correct results, pass attention_mask.",
 
301
  UserWarning,
302
  )
303
 
304
  input_ids, attention_mask = self._expand_inputs_for_generation(
305
  expand_size=generation_config.num_return_sequences,
306
  input_ids=input_ids,
307
+ attention_mask=attention_mask,
308
  )
309
 
310
  result = self._sample(
 
312
  attention_mask=attention_mask,
313
  generation_config=generation_config,
314
  generation_tokens_hook_func=generation_tokens_hook_func,
315
+ generation_logits_hook_func=generation_logits_hook_func,
316
  )
317
  return result
318
 
 
322
  attention_mask: Optional[torch.LongTensor],
323
  generation_config: DreamGenerationConfig,
324
  generation_tokens_hook_func,
325
+ generation_logits_hook_func,
326
  ) -> Union[DreamModelOutput, torch.LongTensor]:
327
+
328
+ # --- init values ---
329
  output_history = generation_config.output_history
330
  return_dict_in_generate = generation_config.return_dict_in_generate
331
  max_length = generation_config.max_length
 
338
  top_p = generation_config.top_p
339
  top_k = generation_config.top_k
340
 
341
+ # RCR specific
342
  rcr = generation_config.rcr
343
  conf_alg = generation_config.conf_alg
344
+ mode = generation_config.mode
345
 
346
  histories = [] if (return_dict_in_generate and output_history) else None
347
 
348
+ # pad to max_length with [MASK]
349
  x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
350
 
351
  if attention_mask is not None and torch.any(attention_mask == 0.0):
 
352
  attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
353
  tok_idx = attention_mask.long().cumsum(-1) - 1
354
  tok_idx.masked_fill_(attention_mask == 0, 1)
 
 
355
  attention_mask = torch.logical_and(
356
  attention_mask.unsqueeze(1).unsqueeze(-2),
357
  attention_mask.unsqueeze(1).unsqueeze(-1),
 
360
  tok_idx = None
361
  attention_mask = "full"
362
 
363
+ # global linear schedule 1 -> eps
364
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
365
 
366
+ # Initial mask (used to pre-allocate the per-step token budget; similar to LLaDA)
367
+ initial_mask_index = (x == mask_token_id) # [B, L]
368
+ per_step_tokens = get_num_transfer_tokens_maskgit(initial_mask_index, steps, mode=mode) # [B, steps]
369
+
370
+ # RCR tracking
371
  overtime_confidence = torch.zeros_like(x, dtype=torch.float32) if rcr else None
372
 
373
+ # user-defined token control
374
  x = generation_tokens_hook_func(None, x, None)
375
+
376
  for i in range(steps):
377
+ # Mask positions that are still undecided at this step
378
+ mask_index = (x == mask_token_id) # [B, L]
379
+
380
+ # Model logits (single-step prediction + shift right to align)
381
  logits = self(x, attention_mask, tok_idx).logits
382
+ logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
383
 
384
+ # user-defined logits control
385
  logits = generation_logits_hook_func(i, x, logits)
386
 
387
+ # Only take the logits at mask positions for sampling
388
+ mask_logits = logits[mask_index] # [M, V] (M=mask 个数)
389
  t = timesteps[i]
390
  s = timesteps[i + 1]
391
+
392
+ if alg == "origin":
393
+ # Dream-style transfer (kept as-is)
394
+ p_transfer = 1 - (s / t).item() if i < steps - 1 else 1.0
395
+ x0 = torch.zeros_like(x[mask_index], device=x.device, dtype=torch.long) + mask_token_id
396
+ transfer_index_t_s = (torch.rand(*x0.shape, device=x.device) < p_transfer)
397
+ _, x0[transfer_index_t_s] = sample_tokens(
398
+ mask_logits[transfer_index_t_s],
399
+ temperature=temperature,
400
+ top_p=top_p,
401
+ top_k=top_k,
402
+ )
403
  x[mask_index] = x0.clone()
404
  else:
405
+ # Choose the confidence algorithm: with RCR prefer conf_alg; otherwise use the variant named by alg
406
+ choose = conf_alg if rcr else alg
407
+ if choose == "maskgit_plus":
408
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
409
+ elif choose == "topk_margin":
410
  confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
411
+ elif choose == "entropy":
412
+ confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
413
  else:
414
+ raise RuntimeError(f"Unknown alg/conf_alg: {choose}")
415
+
416
+ # Write predictions/confidences back at full length (non-mask positions keep the original token / -inf)
417
+ full_conf = torch.full_like(x, -torch.inf, device=x.device, dtype=logits.dtype) # [B, L]
418
+ x_temp = torch.zeros_like(x, device=x.device, dtype=torch.long) + mask_token_id # [B, L]
419
+ x_temp[mask_index] = x0.clone()
420
+ full_conf[mask_index] = confidence
421
+
422
+ if not rcr:
423
+ # ---------- Non-RCR: per-sample token quota for the current step ----------
424
+ k_per_row = per_step_tokens[:, i] # [B]
425
+ B = x.size(0)
426
+ for j in range(B):
427
+ k_j = int(k_per_row[j].item())
428
+ if k_j <= 0:
429
+ continue
430
+ # clamp: must not exceed the sample's remaining mask count
431
+ masked_count_j = mask_index[j].sum().item()
432
+ k_j = min(k_j, int(masked_count_j))
433
+ if k_j <= 0:
434
+ continue
435
+ # Only take top-k inside the mask (non-mask entries of full_conf are already -inf)
436
+ _, select_idx = torch.topk(full_conf[j], k_j, largest=True)
437
+ x[j, select_idx] = x_temp[j, select_idx]
438
 
 
 
 
 
 
439
  else:
440
+ # ---------- RCR: LLaDA-style "cumulative selection + re-masking before the next step" ----------
441
+ B = x.size(0)
442
+ for j in range(B):
443
+ # Total remaining quota for the current and all future steps (from step i to the last)
444
+ total_remaining_tokens = int(per_step_tokens[j, i:].sum().item())
445
+ if total_remaining_tokens <= 0:
446
+ continue
447
+
448
+ masked_count_j = mask_index[j].sum().item()
449
+ k_total = min(total_remaining_tokens, int(masked_count_j))
450
+ if k_total <= 0:
451
+ continue
452
+
453
+ # 1) Cumulative selection: pick in one shot the token set that should be fixed from this step to the end
454
+ # (top-k within the mask)
455
+ _, select_indices = torch.topk(full_conf[j], k_total, largest=True)
456
+ x[j, select_indices] = x_temp[j, select_indices]
457
+ overtime_confidence[j, select_indices] = full_conf[j, select_indices].clone().float()
458
+
459
+ # 2) Before the next step: re-mask, by lowest confidence, the portion that should stay reserved for future steps
460
+ if i < (steps - 1):
461
+ next_to_keep_for_future = int(per_step_tokens[j, i + 1 :].sum().item())
462
+ if next_to_keep_for_future > 0:
463
+ # Only re-mask the lowest-confidence entries among already-selected positions (overtime_confidence > 0)
464
+ current_conf = overtime_confidence[j]
465
+ # Temporarily set zero-confidence (not yet generated) positions to +inf so they are never picked as the "lowest"
466
+ safe_conf = torch.where(current_conf == 0.0, torch.tensor(float("inf"), device=x.device), current_conf)
467
+ # The number to re-mask must not exceed the number currently selected
468
+ gen_count = (safe_conf != float("inf")).sum().item()
469
+ k_remask = min(next_to_keep_for_future, int(gen_count))
470
+ if k_remask > 0:
471
+ # Pick the k_remask least-confident ones
472
+ _, local_mask_indices = torch.topk(safe_conf, k_remask, largest=False)
473
+ x[j, local_mask_indices] = mask_token_id
474
+ overtime_confidence[j, local_mask_indices] = 0.0 # reset to zero marks the position as withdrawn
475
+
476
+ # user-defined token control
477
  x = generation_tokens_hook_func(i, x, logits)
478
 
479
  if histories is not None:
480
  histories.append(x.clone())
481
+
482
  if return_dict_in_generate:
483
+ return DreamModelOutput(sequences=x, history=histories)
 
 
 
484
  else:
485
+ return x