| from transformers import MistralConfig |
| from transformers.utils import logging |
|
|
| logger = logging.get_logger(__name__) |
|
|
# Default values for the ``ssm_cfg`` dictionary of ``AprielHConfig``.
# Any key a caller leaves out of ``ssm_cfg`` is filled in from this mapping.
ssm_config_default = dict(
    d_state=64,
    n_qk_heads=32,
    expand=1,
    chunk_size=128,
    activation="identity",
    bias=False,
    d_conv=4,
    d_inner=32 * 128,  # 4096; presumably n_qk_heads * 128 -- TODO confirm
    d_xb=None,
    dt_rank="auto",
    dt_min=0.001,
    dt_max=0.1,
    dt_init="random",
    dt_scale=1.0,
    dt_init_floor=1e-4,
    conv_bias=True,
)
|
|
|
|
class AprielHConfig(MistralConfig):
    """``MistralConfig`` extended with a hybrid block layout and SSM settings.

    Args:
        hybrid_block_layout: Value stored as ``self.hybrid_block_layout``.
            When omitted (``None``), a fresh ``["m2"]`` list is used.
        ssm_cfg: Optional mapping of overrides for ``ssm_config_default``.
            Keys that are not supplied take the module-level default value.
            The mapping passed in is never modified.
        **kwargs: Forwarded unchanged to ``MistralConfig.__init__``.
    """

    model_type = "apriel_h"

    def __init__(self, hybrid_block_layout=None, ssm_cfg=None, **kwargs):
        super().__init__(**kwargs)

        # Build the default per instance: a list literal in the signature
        # would be one object shared by every config created without it.
        self.hybrid_block_layout = (
            ["m2"] if hybrid_block_layout is None else hybrid_block_layout
        )

        # Fall back to hidden_size / num_attention_heads when the parent
        # config did not set a (truthy) head_dim; getattr guards against a
        # parent config that does not define the attribute at all.
        self.head_dim = (
            getattr(self, "head_dim", None)
            or self.hidden_size // self.num_attention_heads
        )

        # Merge into a new dict (defaults first, caller overrides second) so
        # that neither the caller's mapping nor the module-level defaults are
        # aliased or mutated by this instance.
        self.ssm_cfg = {**ssm_config_default, **(ssm_cfg or {})}
|
|