from transformers import PretrainedConfig


class NSAConfig(PretrainedConfig):
    """Configuration for an NSA (sparse-attention) model.

    Stores the backbone hyperparameters plus an ``nsa`` dict describing the
    three attention branches: compressed ("cmp"), selected ("sel"), and
    sliding-window ("win").
    """

    model_type = "nsa"

    def __init__(
        self,
        vocab_size=50257,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        n_kv_groups=1,  # number of KV groups shared across query heads (GQA)
        d_k=64,  # per-head key/query dimension
        d_v=64,  # per-head value dimension
        max_position_embeddings=2048,
        rope_theta=10000,  # RoPE base frequency
        nsa=None,  # optional override for the NSA branch settings below
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_kv_groups = n_kv_groups
        self.d_k = d_k
        self.d_v = d_v
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.nsa = nsa or {
            "branches": ["cmp", "sel", "win"],  # compression, selection, sliding window
            "window": 512,  # sliding-window size for the "win" branch
            "gqa_groups": n_kv_groups,  # KV groups mirrored into the NSA branches
            "block": 32,  # compression block size
            "stride": 16,  # compression stride
            "sel_block": 64,  # selection block size
            "sel_top_n": 16,  # number of top-scoring blocks kept by selection
        }
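

# A minimal usage sketch: shows that the config round-trips through the
# standard PretrainedConfig save/load API, and that constructor arguments
# propagate into the default `nsa` dict. The "./nsa-config" path and the
# override values are hypothetical, chosen only for illustration.
if __name__ == "__main__":
    config = NSAConfig(hidden_size=1024, num_attention_heads=16, n_kv_groups=2)
    assert config.nsa["gqa_groups"] == 2  # default dict picks up n_kv_groups
    print(config.nsa["window"])  # 512: default sliding-window size

    config.save_pretrained("./nsa-config")  # writes config.json with model_type="nsa"
    reloaded = NSAConfig.from_pretrained("./nsa-config")
    assert reloaded.hidden_size == 1024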