rp-yu committed · Commit 8527a33 · verified · 1 Parent(s): 742c817

Update generation_utils.py

Files changed (1)
  1. generation_utils.py +89 -0
generation_utils.py CHANGED
@@ -420,6 +420,95 @@ class DimpleGenerationMixin:
         # tokenizer=None, # only for debug, need to be removed !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
         **kwargs,
     ) -> Union[DimpleModelOutput, torch.LongTensor]:
+
+        """
+        Generates sequences using a diffusion-based masked-token denoising algorithm.
+
+        This method replaces masked tokens in `inputs` through iterative refinement, following a denoising
+        process inspired by diffusion models. At each step it uses confidence-based sampling to progressively
+        fill in masked tokens.
+
+        Args:
+            inputs (torch.Tensor):
+                Input token IDs.
+            generation_config (DimpleGenerationConfig, optional):
+                An instance of `DimpleGenerationConfig` containing generation hyperparameters. If not provided,
+                the model's default generation config is used.
+            **kwargs:
+                Additional generation parameters that override those in `generation_config`.
+
+        Returns:
+            `DimpleModelOutput` if `return_dict_in_generate=True`, else a `torch.LongTensor` of generated token IDs.
+
+        Key Parameters (either in `generation_config` or passed via kwargs):
+
+        - `max_new_tokens` (int, default=None):
+            The number of new tokens to generate or fill in. It is added to the input length to determine
+            the total sequence length.
+
+        - `output_history` (bool, default=False):
+            If `True`, returns the full sequence at each denoising step, which is useful for visualization
+            or debugging. Only returned if `return_dict_in_generate=True`.
+
+        - `return_dict_in_generate` (bool, default=False):
+            If `True`, returns a `DimpleModelOutput` dictionary containing the final sequences and,
+            optionally, the stepwise history. If `False`, returns a plain tensor of token IDs.
+
+        - `steps` (int, default=512):
+            The number of denoising steps to perform. Each step refines the sequence by replacing some
+            masked tokens according to the selected algorithm.
+
+        - `temperature` (float, default=0.0):
+            Sampling temperature applied to the logits before softmax. Lower values make outputs more
+            deterministic; higher values allow more randomness in token selection.
+
+        - `top_p` (float, default=None):
+            Nucleus sampling parameter. If set, only the smallest set of most probable tokens whose
+            cumulative probability exceeds `top_p` is considered during sampling.
+
+        - `alg` (str, default="origin"):
+            The denoising algorithm used to decide which tokens to replace at each step. Options include:
+            - `"origin"`: random token selection based on a probability ratio.
+            - `"origin-ratio"`: like `"origin"` but uses a continuous transfer ratio.
+            - `"autoregressive"`: always fills the left-most masked token.
+            - `"maskgit_plus"`: confidence-based selection similar to Google's MaskGIT.
+            - `"topk_margin"`: selection based on the margin between the top-1 and top-2 probabilities.
+            - `"entropy"`: prioritizes tokens with high negative entropy (i.e., low uncertainty).
+
+        - `use_cache` (bool, default=False):
+            Enables prefilling of past key values (KV cache) for efficient decoding.
+
+        - `alg_p_threshold` (float, optional, default=None):
+            A confidence threshold used to decide whether a token is confident enough to be selected. If a
+            token's confidence exceeds this value, it is unmasked and committed to the sequence, which helps
+            stabilize generation.
+
+        - `use_original_confidence` (bool, default=True):
+            If `True`, confidence scores are computed from the original (pre-sampling) probability
+            distribution; if `False`, from the current step's softmaxed logits. Can make token selection
+            more stable in some cases.
+
+        - `decoding_pipeline` (str, default="dim"):
+            The decoding pipeline to use:
+            - `"dim"`: Dimple decoding pipeline.
+            - `"dream"`: original DREAM token-selection pipeline.
+
+        Example:
+        ```python
+        output = model.diffusion_generate(
+            inputs=input_ids,
+            max_new_tokens=64,
+            output_history=True,
+            return_dict_in_generate=True,
+            steps=64,
+            temperature=0.2,
+            top_p=0.95,
+            alg="origin",
+            use_cache=True,
+            alg_p_threshold=0.95,
+            use_original_confidence=True,
+            decoding_pipeline="dim",
+        )
+        ```
+        """
+
         # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
         generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
         generation_tokens_hook_func = model_kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
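The `top_p` behavior documented above can be illustrated with a small, dependency-free sketch. This is the standard nucleus-sampling filter, not the repository's actual implementation; the function name `top_p_filter` is hypothetical:

```python
def top_p_filter(probs, top_p):
    """Keep the smallest set of tokens whose cumulative probability
    reaches top_p; zero out the rest and renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    total = sum(probs[i] for i in kept)
    return [probs[i] / total if i in kept else 0.0 for i in range(len(probs))]

# With top_p=0.75, the two most probable tokens are kept and renormalized;
# the tail tokens are zeroed out.
filtered = top_p_filter([0.5, 0.3, 0.15, 0.05], top_p=0.75)
```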
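The `topk_margin` and `entropy` confidence criteria in the docstring can be sketched from a per-position probability distribution. These are the usual textbook definitions (top-1 minus top-2 probability, and negative Shannon entropy); the actual selection logic lives in the repository's code:

```python
import math

def margin_confidence(probs):
    """Top-1 minus top-2 probability: a large margin means a confident position."""
    s = sorted(probs, reverse=True)
    return s[0] - s[1]

def neg_entropy_confidence(probs):
    """Negative entropy: peaked (low-entropy) distributions score higher."""
    return sum(p * math.log(p) for p in probs if p > 0)

# A peaked distribution scores higher than a flat one under both criteria,
# so its position would be unmasked earlier.
peaked = [0.9, 0.05, 0.05]
flat = [1 / 3, 1 / 3, 1 / 3]
assert margin_confidence(peaked) > margin_confidence(flat)
assert neg_entropy_confidence(peaked) > neg_entropy_confidence(flat)
```

Under either metric, the denoising loop would commit the highest-scoring masked positions first at each step.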
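The context lines show that `diffusion_generate` pops a `generation_tokens_hook_func` from `model_kwargs`, defaulting to `lambda step, x, logits: x`. A minimal sketch of a compatible hook, assuming only the signature visible in the diff (called per denoising step with the step index, the current token tensor, and the logits, and expected to return the token tensor):

```python
# Records which denoising steps ran and passes tokens through unchanged.
seen_steps = []

def logging_hook(step, x, logits):
    seen_steps.append(step)
    return x  # must return the (possibly modified) token tensor

# It would presumably be passed via kwargs, e.g.:
# model.diffusion_generate(..., generation_tokens_hook_func=logging_hook)
```

A hook that modifies `x` in place (e.g., forcing certain positions to fixed tokens) would follow the same shape, as long as it returns the tensor.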