|
|
"""New file, authored myself: Hugging Face configuration file for StripedHyena-2 |
|
|
(Evo 2) models. Based on Together's configuration_hyena.py but extended to cover |
|
|
all Evo 2-specific hyper-parameters from the provided config.json. |
|
|
""" |
|
|
|
|
|
from typing import List, Optional |
|
|
import json |
|
|
from transformers.configuration_utils import PretrainedConfig |
|
|
|
|
|
|
|
|
class Evo2Config(PretrainedConfig):
    """Configuration class for Evo 2 (StripedHyena-2) causal-LM checkpoints.

    The keyword defaults are intended to mirror the reference ``config.json``.
    Keys that belong to the Hugging Face ``PretrainedConfig`` base class
    (e.g. *bos_token_id*, *eos_token_id*) can still be supplied via
    ``**kwargs``; they are forwarded unchanged to the base constructor.

    Every explicitly named constructor argument is stored as an instance
    attribute of the same name.
    """

    model_type: str = "evo2"

    def __init__(
        self,
        vocab_size: int = 512,
        hidden_size: int = 4096,
        tie_embeddings: bool = True,
        make_vocab_size_divisible_by: int = 8,
        num_hidden_layers: int = 32,
        num_layers: int = 32,
        num_attention_heads: int = 32,
        num_filters: int = 4096,
        inner_mlp_size: int = 11264,
        inner_size_multiple_of: int = 16,
        hcl_layer_idxs: Optional[List[int]] = None,
        hcm_layer_idxs: Optional[List[int]] = None,
        hcs_layer_idxs: Optional[List[int]] = None,
        attn_layer_idxs: Optional[List[int]] = None,
        hcm_filter_length: int = 128,
        hcl_filter_groups: int = 4096,
        hcm_filter_groups: int = 256,
        hcs_filter_groups: int = 256,
        hcs_filter_length: int = 7,
        short_filter_length: int = 3,
        short_filter_bias: bool = False,
        proj_groups: int = 1,
        hyena_filter_groups: int = 1,
        column_split_hyena: bool = False,
        column_split: bool = True,
        interleave: bool = True,
        mha_out_proj_bias: bool = True,
        hyena_out_proj_bias: bool = True,
        qkv_proj_bias: bool = False,
        use_fp8_input_projections: bool = True,
        mlp_init_method: str = "torch.nn.init.zeros_",
        mlp_output_init_method: str = "torch.nn.init.zeros_",
        eps: float = 1e-6,
        state_size: int = 16,
        rotary_emb_base: int = 100_000_000_000,
        rotary_emb_scaling_factor: int = 128,
        use_interpolated_rotary_pos_emb: bool = True,
        max_seqlen: int = 1_048_576,
        max_batch_size: int = 1,
        model_parallel_size: int = 1,
        pipe_parallel_size: int = 1,
        final_norm: bool = True,
        use_flash_attn: bool = True,
        use_flash_rmsnorm: bool = False,
        use_flash_depthwise: bool = False,
        use_flashfft: bool = False,
        use_laughing_hyena: bool = False,
        inference_mode: bool = True,
        tokenizer_type: str = "CharLevelTokenizer",
        prefill_style: str = "fft",
        mlp_activation: str = "gelu",
        print_activations: bool = False,
        log_intermediate_values: bool = False,
        hyena_flip_x1x2: bool = False,
        use_cache: bool = True,
        **kwargs,
    ) -> None:
        """Build the configuration.

        Each named argument is stored on the instance under the same name.
        The four ``*_layer_idxs`` arguments default to ``None`` (to avoid
        mutable default arguments) and are replaced by built-in index lists
        when left unset.

        Args:
            **kwargs: Extra options forwarded to
                ``PretrainedConfig.__init__``.
        """
        # Resolve the ``None`` sentinels to the default layer layouts.
        if hcl_layer_idxs is None:
            hcl_layer_idxs = [2, 6, 9, 13, 16, 20, 23, 27, 30]
        if hcm_layer_idxs is None:
            hcm_layer_idxs = [1, 5, 8, 12, 15, 19, 22, 26, 29]
        if hcs_layer_idxs is None:
            hcs_layer_idxs = [0, 4, 7, 11, 14, 18, 21, 25, 28]
        if attn_layer_idxs is None:
            attn_layer_idxs = [3, 10, 17, 24, 31]

        # Snapshot the named arguments. This must stay the first statement
        # that introduces a new local, so only call arguments are captured.
        params = dict(locals())
        # Drop names that are not hyper-parameters. ``__class__`` is the
        # implicit closure cell created by the zero-argument ``super()``
        # call below; without this filter the class object would be written
        # into ``self.__dict__`` and leak into serialization and comparisons.
        for transient in ("self", "kwargs", "__class__"):
            params.pop(transient, None)

        super().__init__(**kwargs)

        # Applied after the base constructor so the explicit arguments take
        # precedence over anything the base class set under the same name.
        self.__dict__.update(params)

    def to_dict(self):
        """Serialize this instance to a Python dictionary.

        Defers to the base implementation and defensively removes a
        ``'__class__'`` entry if one is present, since a class object is not
        JSON-serializable.

        Returns:
            The dictionary produced by ``PretrainedConfig.to_dict`` without a
            ``'__class__'`` key.
        """
        output = super().to_dict()
        output.pop("__class__", None)
        return output

    @classmethod
    def from_original_config(cls, config_path: str, **kwargs):
        """Load a config directly from a JSON file and convert it.

        Args:
            config_path: Path to a UTF-8 encoded JSON file whose top-level
                object holds the constructor keyword arguments.
            **kwargs: Overrides applied on top of the file contents. A key
                present in both places takes the value given here.

        Returns:
            A new instance of ``cls`` built from the merged settings.
        """
        with open(config_path, "r", encoding="utf-8") as fp:
            original = json.load(fp)
        # Merge before unpacking: ``cls(**original, **kwargs)`` would raise
        # ``TypeError`` on any key that appears in both mappings.
        merged = {**original, **kwargs}
        return cls(**merged)
|
|
|