materials.str-bamba / str_bamba /bamba_config.py
victor-shirasuna
Upload files
3d83373
raw
history blame contribute delete
710 Bytes
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class BambaConfig:
    """Hyperparameter container for a single Bamba model stack.

    This is a plain dataclass: it stores values only and performs no
    validation or derivation of its own.  What each field means is
    decided by the code that consumes the config; the grouping notes
    below are inferred from the field names and should be confirmed
    against that code.
    """

    # -- Model size (presumably hidden width / MLP width / depth -- confirm) --
    d_model: int = 2_560
    d_intermediate: int = 0
    n_layer: int = 64

    # -- Vocabulary and sequence limits --
    vocab_size: int = 50_277
    max_position_embeddings: int = 262_144

    # -- Per-component option bags --
    # Built through default_factory so every instance owns a fresh, empty
    # container instead of sharing one mutable default.
    ssm_cfg: dict = field(default_factory=dict)
    attn_layer_idx: list = field(default_factory=list)
    attn_cfg: dict = field(default_factory=dict)

    # -- Boolean switches (all enabled by default) --
    rms_norm: bool = True
    residual_in_fp32: bool = True
    fused_add_norm: bool = True

    # -- Embedding handling --
    # NOTE(review): presumably the vocab is padded up to a multiple of
    # this value by the model code -- nothing in this class does so.
    pad_vocab_size_multiple: int = 8
    tie_embeddings: bool = True
@dataclass
class BambaEncoderDecoderConfig:
    """Settings bundle for an encoder/decoder pair of Bamba stacks.

    This is a plain dataclass: it stores values only.  Neither sub-config
    is created automatically; both stay ``None`` unless the caller
    supplies them, so consumers must handle the unset case.

    Attributes:
        encoder_config: ``BambaConfig`` for the encoder, or ``None``
            when not provided.
        decoder_config: ``BambaConfig`` for the decoder, or ``None``
            when not provided.
        tie_word_embeddings: Boolean switch, on by default.
            NOTE(review): presumably shares embedding weights between
            the two stacks -- confirm against the model code.
        seed: Integer seed, ``0`` by default.
            NOTE(review): presumably used for RNG initialisation --
            confirm against the caller.
    """

    # Annotated Optional because the default really is None; the bare
    # ``BambaConfig`` annotation misdescribed the default state.
    encoder_config: Optional[BambaConfig] = None
    decoder_config: Optional[BambaConfig] = None
    tie_word_embeddings: bool = True
    seed: int = 0