| | """ |
| | RND1 Generation Configuration. |
| | |
| | This module defines the generation configuration for RND1 models, |
| | controlling the diffusion-based generation process. |
| | """ |
| |
|
| | from typing import Optional |
| | from transformers.generation.configuration_utils import GenerationConfig |
| |
|
| |
|
class RND1GenerationConfig(GenerationConfig):
    """
    Configuration class for RND1 generation parameters.

    This class extends the base GenerationConfig to include parameters
    specific to diffusion-based language generation.

    Args:
        max_length: Maximum sequence length.
        num_diffusion_steps: Number of denoising steps in the diffusion process.
        mask_token_id: Token ID used for masking during diffusion.
            NOTE(review): the default 151669 is presumably the mask token of
            the RND1 tokenizer — confirm against the tokenizer in use.
        temperature: Temperature for sampling (higher = more random). Stored
            as a float.
        top_k: Optional top-k filtering.
        top_p: Optional nucleus (top-p) filtering.
        greedy: Whether to use greedy decoding (True) or stochastic sampling
            (False). The base-class ``do_sample`` flag is derived from it as
            ``do_sample = not greedy``.
        bos_token_id: Optional beginning-of-sequence token ID, forwarded to
            GenerationConfig.
        eos_token_id: Optional end-of-sequence token ID, forwarded to
            GenerationConfig.
        pad_token_id: Optional padding token ID, forwarded to GenerationConfig.
        use_cache: Accepted so callers passing it do not fail, but ignored:
            ``use_cache=False`` is always forwarded to GenerationConfig.
        **kwargs: Additional arguments passed to GenerationConfig. A
            ``do_sample`` entry (for example one coming from a serialized
            config) is discarded; ``greedy`` decides sampling mode.
    """

    def __init__(
        self,
        max_length: int = 256,
        num_diffusion_steps: int = 256,
        mask_token_id: int = 151669,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        greedy: bool = True,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        use_cache: bool = False,
        **kwargs,
    ):
        # `do_sample` is passed explicitly below (derived from `greedy`). If it
        # also arrived through **kwargs -- e.g. when rebuilding the config from
        # a dict that includes base-class fields -- the call would raise
        # "got multiple values for keyword argument 'do_sample'". `greedy` is
        # the source of truth, so the duplicate is dropped.
        kwargs.pop("do_sample", None)

        # `use_cache` is a named parameter, so it never lands in **kwargs; it
        # is intentionally not forwarded -- caching is always disabled here.
        super().__init__(
            max_length=max_length,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=not greedy,
            use_cache=False,
            **kwargs,
        )

        # Diffusion-specific settings not known to the base class.
        self.num_diffusion_steps = num_diffusion_steps
        self.mask_token_id = mask_token_id
        self.greedy = greedy
        # Normalize to float (e.g. an int temperature read from a config file).
        self.temperature = float(temperature)

    def to_dict(self):
        """
        Convert configuration to dictionary.

        Starts from the base-class serialization and explicitly writes the
        diffusion-specific fields so they are always present in the output.

        Returns:
            dict: The serialized configuration.
        """
        output = super().to_dict()
        output["num_diffusion_steps"] = self.num_diffusion_steps
        output["mask_token_id"] = self.mask_token_id
        output["greedy"] = self.greedy
        return output