import dataclasses
import warnings
from typing import Dict, Any

from transformers.utils import is_flash_attn_2_available

from .block_config import BlockConfig
from .transformers_4_44_2__configuration_llama import LlamaConfig
from .transformers_4_44_2__modeling_rope_utils import rope_config_validation

rope_config_validation  # bare reference so formatters/linters do not strip the seemingly unused import


class DeciLMConfig(LlamaConfig):
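    """Configuration for DeciLM ("nemotron-nas") models.

    Extends ``LlamaConfig`` with ``block_configs``: one entry per hidden layer
    describing that layer's attention and FFN block. Entries may be passed as
    plain dicts and are normalized to ``BlockConfig`` instances on init.
    """
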
    model_type = "nemotron-nas"

    def __init__(
        self,
        block_configs: list[dict] | list[BlockConfig] | None = None,
        **kwargs,
    ):
        # Default to FlashAttention-2 when available and no implementation was requested.
        attn_implementation = kwargs.pop("attn_implementation", None)
        if attn_implementation is None and is_flash_attn_2_available():
            attn_implementation = "flash_attention_2"

        if block_configs is not None:
            # Allow plain dicts for convenience; normalize to BlockConfig instances.
            if isinstance(block_configs[0], dict):
                block_configs = [BlockConfig(**conf) for conf in block_configs]

            # Unshifted-sink attention requires the eager implementation; force it and warn.
            using_unshifted_sink = any(block_config.attention.unshifted_sink for block_config in block_configs)
            if using_unshifted_sink and attn_implementation != "eager":
                warnings.warn("Forcing attn_implementation='eager' since some attention layers use unshifted sink")
                attn_implementation = "eager"

        super().__init__(attn_implementation=attn_implementation, **kwargs)

        # Superseded by the per-block settings carried in block_configs.
        self.intermediate_size = None
        self.num_key_value_heads = None

        if block_configs is not None:
            assert len(block_configs) == self.num_hidden_layers, "Expected one block config per hidden layer"

        self.block_configs: list[BlockConfig] = block_configs

    def to_dict(self) -> Dict[str, Any]:
        self_dict = super().to_dict()
        if self.block_configs is not None:
            # Serialize BlockConfig dataclasses back to plain dicts so the config stays JSON-compatible.
            self_dict["block_configs"] = [dataclasses.asdict(conf) for conf in self.block_configs]
        return self_dict
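

# Usage sketch (illustrative, not part of this module's API). The "attention"
# and "ffn" keys below are placeholder names, not the real BlockConfig schema;
# see block_config.py for the actual fields.
#
#     config = DeciLMConfig(
#         num_hidden_layers=2,
#         block_configs=[
#             {"attention": {...}, "ffn": {...}},  # one entry per hidden layer
#             {"attention": {...}, "ffn": {...}},
#         ],
#     )
#     blob = config.to_dict()          # BlockConfig entries serialize to dicts
#     restored = DeciLMConfig(**blob)  # and normalize back on reconstruction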