import json

from transformers import PretrainedConfig


class StripedHyenaConfig(PretrainedConfig):
    """Configuration for StripedHyena models, which interleave attention and Hyena layers."""

    model_type = "stripedhyena"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        num_filters=4096,
        inner_mlp_size=14336,
        attn_layer_idxs=None,
        hyena_layer_idxs=None,
        num_layers=32,
        tie_embeddings=False,
        short_filter_length=3,
        num_attention_heads=32,
        proj_groups=4,
        hyena_filter_groups=1,
        split_k0=True,
        column_split_hyena=True,
        column_split=False,
        model_parallel_size=1,
        pipe_parallel_size=1,
        short_filter_bias=True,
        mha_out_proj_bias=False,
        qkv_proj_bias=False,
        final_norm=True,
        use_cache=True,
        use_flash_attention_2=True,
        use_flash_rmsnorm=True,
        use_flash_depthwise=False,
        use_flashfft=False,
        inference_mode=False,
        prefill_style="fft",
        max_seqlen=32768,
        eps=1e-5,
        state_size=2,
        rotary_emb_base=500000,
        smeared_gqa=False,
        make_vocab_size_divisible_by=8,
        log_intermediate_values=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_filters = num_filters
        self.inner_mlp_size = inner_mlp_size
        # Default to fresh lists here to avoid the mutable-default-argument pitfall.
        self.attn_layer_idxs = attn_layer_idxs if attn_layer_idxs is not None else []
        self.hyena_layer_idxs = hyena_layer_idxs if hyena_layer_idxs is not None else []
        self.num_layers = num_layers
        self.tie_embeddings = tie_embeddings
        self.short_filter_length = short_filter_length
        self.num_attention_heads = num_attention_heads
        self.proj_groups = proj_groups
        self.hyena_filter_groups = hyena_filter_groups
        self.split_k0 = split_k0
        self.column_split_hyena = column_split_hyena
        self.column_split = column_split
        self.model_parallel_size = model_parallel_size
        self.pipe_parallel_size = pipe_parallel_size
        self.short_filter_bias = short_filter_bias
        self.mha_out_proj_bias = mha_out_proj_bias
        self.qkv_proj_bias = qkv_proj_bias
        self.final_norm = final_norm
        self.use_cache = use_cache
        self.use_flash_attention_2 = use_flash_attention_2
        self.use_flash_rmsnorm = use_flash_rmsnorm
        self.use_flash_depthwise = use_flash_depthwise
        self.use_flashfft = use_flashfft
        self.inference_mode = inference_mode
        self.prefill_style = prefill_style
        self.max_seqlen = max_seqlen
        self.eps = eps
        self.state_size = state_size
        self.rotary_emb_base = rotary_emb_base
        self.smeared_gqa = smeared_gqa
        self.make_vocab_size_divisible_by = make_vocab_size_divisible_by
        self.log_intermediate_values = log_intermediate_values
        super().__init__(**kwargs)

    def to_dict(self):
        # Return the raw instance attributes rather than the processed dict
        # produced by PretrainedConfig.to_dict.
        return {attr: getattr(self, attr) for attr in self.__dict__}

    @classmethod
    def from_original_config(cls, config_path, **kwargs):
        # Build a config from an original (non-Hugging Face) StripedHyena JSON config file.
        with open(config_path, "r") as f:
            config = json.load(f)

        return cls(**config, **kwargs)
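

# Minimal usage sketch (illustrative only, not part of the original module): it
# constructs a small config directly, inspects it via to_dict(), and shows how a
# config could be loaded from an original JSON file. The file name
# "original_config.json" below is a hypothetical placeholder.
if __name__ == "__main__":
    config = StripedHyenaConfig(
        num_layers=4,
        hidden_size=512,
        num_filters=512,
        attn_layer_idxs=[0, 2],
        hyena_layer_idxs=[1, 3],
    )
    print(config.model_type)               # "stripedhyena"
    print(config.to_dict()["num_layers"])  # 4

    # Loading from an original (non-Hugging Face) config file:
    # config = StripedHyenaConfig.from_original_config("original_config.json")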