Update generation_utils.py
Browse files- generation_utils.py +89 -0
generation_utils.py
CHANGED
|
@@ -420,6 +420,95 @@ class DimpleGenerationMixin:
|
|
| 420 |
# tokenizer=None, # only for debug, need to be removed !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
| 421 |
**kwargs,
|
| 422 |
) -> Union[DimpleModelOutput, torch.LongTensor]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
| 424 |
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
|
| 425 |
generation_tokens_hook_func = model_kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
|
|
|
|
| 420 |
# tokenizer=None, # only for debug, need to be removed !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
| 421 |
**kwargs,
|
| 422 |
) -> Union[DimpleModelOutput, torch.LongTensor]:
|
| 423 |
+
|
| 424 |
+
"""
|
| 425 |
+
Generates sequences using a diffusion-based masked token denoising algorithm.
|
| 426 |
+
|
| 427 |
+
This method replaces masked tokens in `inputs` through iterative refinement, based on a denoising process
|
| 428 |
+
inspired by diffusion models. It uses intermediate confidence-based sampling to progressively fill in masked tokens.
|
| 429 |
+
|
| 430 |
+
Args:
|
| 431 |
+
inputs (torch.Tensor):
|
| 432 |
+
Input token IDs.
|
| 433 |
+
generation_config (DimpleGenerationConfig, optional):
|
| 434 |
+
An instance of `DimpleGenerationConfig` containing generation hyperparameters. If not provided,
|
| 435 |
+
the default generation config from the model is used.
|
| 436 |
+
**kwargs:
|
| 437 |
+
Additional generation parameters that override those in `generation_config`.
|
| 438 |
+
|
| 439 |
+
Returns:
|
| 440 |
+
DimpleModelOutput if `return_dict_in_generate=True`, else `torch.LongTensor` of generated token IDs.
|
| 441 |
+
|
| 442 |
+
Key Parameters (either in `generation_config` or passed via kwargs):
|
| 443 |
+
|
| 444 |
+
- `max_new_tokens` (int, default=None):
|
| 445 |
+
The number of new tokens to generate or fill in. This sets the target length of the generated sequence beyond
|
| 446 |
+
the prompt. It is added to the input length to determine the total sequence length.
|
| 447 |
+
|
| 448 |
+
- `output_history` (bool, default=False):
|
| 449 |
+
If `True`, returns the full sequence history at each denoising step. This is useful for visualization or debugging
|
| 450 |
+
purposes. Only returned if `return_dict_in_generate=True`.
|
| 451 |
+
|
| 452 |
+
- `return_dict_in_generate` (bool, default=False):
|
| 453 |
+
If `True`, returns a `DimpleModelOutput` dictionary containing the final sequences and, optionally, the stepwise history.
|
| 454 |
+
If `False`, returns a plain tensor of token IDs.
|
| 455 |
+
|
| 456 |
+
- `steps` (int, default=512):
|
| 457 |
+
The number of denoising steps to perform during generation. Each step progressively refines the sequence by replacing
|
| 458 |
+
some masked tokens based on a sampling algorithm.
|
| 459 |
+
|
| 460 |
+
- `temperature` (float, default=0.0):
|
| 461 |
+
Sampling temperature applied to logits before softmax. Lower values make outputs more deterministic,
|
| 462 |
+
while higher values allow for more randomness in token selection.
|
| 463 |
+
|
| 464 |
+
- `top_p` (float, default=None):
|
| 465 |
+
Nucleus sampling parameter. If set, only the smallest set of most probable tokens whose cumulative probability exceeds `top_p`
|
| 466 |
+
are considered during sampling.
|
| 467 |
+
|
| 468 |
+
- `alg` (str, default="origin"):
|
| 469 |
+
The denoising algorithm to use for determining which tokens to replace at each step. Options include:
|
| 470 |
+
- `"origin"`: random token selection based on a probability ratio.
|
| 471 |
+
- `"origin-ratio"`: like `"origin"` but uses a continuous transfer ratio.
|
| 472 |
+
- `"autoregressive"`: always fills the left-most masked token.
|
| 473 |
+
- `"maskgit_plus"`: confidence-based selection similar to Google's MaskGIT.
|
| 474 |
+
- `"topk_margin"`: token selection based on margin (top1 - top2 probability).
|
| 475 |
+
- `"entropy"`: prioritizes tokens with high negative entropy (i.e., low uncertainty).
|
| 476 |
+
|
| 477 |
+
- `use_cache` (bool, default=False):
|
| 478 |
+
Enables prefilling of past key values (past KV) for efficient decoding.
|
| 479 |
+
|
| 480 |
+
- `alg_p_threshold` (float, optional, default=None):
|
| 481 |
+
A confidence threshold used to determine whether a token is confident enough to be selected. If the token's
|
| 482 |
+
confidence is above this value, it is unmasked and committed to the sequence. Helps stabilize generation.
|
| 483 |
+
|
| 484 |
+
- `use_original_confidence` (bool, default=True):
|
| 485 |
+
If `True`, confidence scores are computed using the original (pre-sampled) probability distribution.
|
| 486 |
+
If `False`, uses the current step's softmaxed logits. Enables more stable token selection in some cases.
|
| 487 |
+
|
| 488 |
+
- `decoding_pipeline` (str, default="dim"):
|
| 489 |
+
The generation decoding pipeline to use:
|
| 490 |
+
- `"dim"`: Dimple decoding pipeline.
|
| 491 |
+
- `"dream"`: Original DREAM token selection pipeline.
|
| 492 |
+
|
| 493 |
+
Example:
|
| 494 |
+
```python
|
| 495 |
+
output = model.diffusion_generate(
|
| 496 |
+
inputs=input_ids,
|
| 497 |
+
max_new_tokens=64,
|
| 498 |
+
output_history=True,
|
| 499 |
+
return_dict_in_generate=True,
|
| 500 |
+
steps=64,
|
| 501 |
+
temperature=0.2,
|
| 502 |
+
top_p=0.95,
|
| 503 |
+
alg="origin",
|
| 504 |
+
use_cache=True,
|
| 505 |
+
alg_p_threshold=0.95,
|
| 506 |
+
use_original_confidence=True,
|
| 507 |
+
decoding_pipeline="dim"
|
| 508 |
+
)
|
| 509 |
+
```
|
| 510 |
+
"""
|
| 511 |
+
|
| 512 |
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
| 513 |
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
|
| 514 |
generation_tokens_hook_func = model_kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
|