File size: 7,319 Bytes
45e7dfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""
config.py -- SpikeWhale: combined config from SpikeTransformer (My Project) + NanoWhale (DeepSeek-V4).

Features carried from My Project (not in NanoWhale):
  - DERF attention: erf(alpha*score+bias)*gamma replaces softmax
  - XSA (Exclusive Self-Attention): orthogonality correction removes self-echo from attn output
  - Engram N-gram module: hash-table N-gram lookup with DERF gate injected into embeddings
  - Three-tier optimizer: embed/table params trained at lower LR

Features carried from NanoWhale (not in My Project):
  - MLA (Multi-Head Latent Attention): low-rank Q projection + direct K,V (MQA)
  - Partial RoPE: rotary embeddings on only qk_rope_head_dim dims of Q and K
  - Low-rank grouped output projection (o_lora_rank)
  - Hyper-Connections: hc_mult residual streams with learned routing between layers
  - Shared expert in MoE (always-active expert alongside routed experts)
  - sqrtsoftplus expert scoring (vs softmax in My Project)
  - Hash-based routing for first num_hash_layers layers
  - norm_topk_prob + routed_scaling_factor
  - Multi-Token Prediction (MTP): extra heads predict k steps ahead
  - torch.compile, FineWeb-Edu streaming, Trackio, YAML configs in train.py
"""

from typing import Optional

from transformers import PretrainedConfig


class SpikeWhaleConfig(PretrainedConfig):
    """Hyper-parameter container for the SpikeWhale model.

    Every constructor argument is stored as an attribute of the same name.
    A few attributes are derived rather than passed in:

    * ``nope_head_dim`` is always ``head_dim - qk_rope_head_dim``.
    * ``moe_layers`` defaults to every layer index
      (``list(range(num_hidden_layers))``) when left as ``None``; a supplied
      sequence is copied so later mutation by the caller does not leak into
      the config.
    * When ``use_latent_io`` is true, ``max_position_embeddings`` and
      ``rope_theta`` are *overwritten* with ``base_context`` and
      ``base_rope_theta`` multiplied by ``2 ** chain_position``; the values
      passed for those two arguments are then ignored.

    Raises:
        ValueError: if ``qk_rope_head_dim`` is outside ``[0, head_dim]``, or
            if ``use_latent_io`` is true and ``chain_position`` is not a
            non-negative integer.
    """

    model_type = "spike_whale"

    def __init__(
        self,
        # Standard
        vocab_size: int = 129280,
        hidden_size: int = 2048,
        num_hidden_layers: int = 11,
        max_position_embeddings: int = 8192,
        rms_norm_eps: float = 1e-6,
        initializer_range: float = 0.02,
        tie_word_embeddings: bool = False,
        hidden_dropout: float = 0.0,
        bos_token_id: int = 0,
        eos_token_id: int = 1,
        # MLA Attention (NanoWhale)
        num_attention_heads: int = 8,
        num_key_value_heads: int = 1,   # 1 = MQA; >1 = GQA
        q_lora_rank: int = 160,         # low-rank Q: hidden -> q_lora_rank -> num_heads*head_dim
        head_dim: int = 96,             # total per-head dim = nope_head_dim + qk_rope_head_dim
        qk_rope_head_dim: int = 32,     # RoPE applied only to these dims
        o_lora_rank: int = 80,          # low-rank output: num_heads*head_dim -> o_lora_rank -> hidden
        attention_dropout: float = 0.0,
        rope_theta: float = 10000.0,
        # DERF + XSA (My Project)
        use_derf: bool = True,
        use_xsa: bool = True,
        # MoE (combined)
        use_moe: bool = True,
        moe_intermediate_size: int = 640,
        n_routed_experts: int = 4,
        n_shared_experts: int = 1,      # NanoWhale: always-active shared expert
        num_experts_per_tok: int = 2,
        norm_topk_prob: bool = True,    # NanoWhale: normalize top-k routing weights
        scoring_func: str = "sqrtsoftplus",  # NanoWhale: sqrt(softplus(x)) vs softmax
        routed_scaling_factor: float = 1.0,  # NanoWhale: scale routed expert weights
        num_hash_layers: int = 2,       # NanoWhale: first N layers use hash routing
        moe_aux_loss_coef: float = 0.01,
        moe_layers: Optional[list] = None,  # None -> every layer is an MoE layer
        # Hyper-Connections (NanoWhale)
        use_hyper_connections: bool = True,
        hc_mult: int = 4,               # number of parallel residual streams
        hc_sinkhorn_iters: int = 20,
        hc_eps: float = 1e-6,
        # Multi-Token Prediction (NanoWhale)
        num_nextn_predict_layers: int = 1,  # extra MTP heads (0 = disabled)
        # Engram N-gram module (My Project)
        use_engram: bool = True,
        engram_compress_dim: int = 64,
        engram_num_heads: int = 4,
        engram_table_size: int = 8192,
        engram_max_ngram: int = 3,
        engram_gate_init_bias: float = -4.0,
        use_hrm_refine: bool = False,
        hrm_refine_steps: int = 3,
        hrm_refine_dim: int = 256,
        # --- ModularMind-on-V2 additions (off/unused unless enabled) ---
        use_latent_io: bool = False,   # add latent output head + injection input path
        d_latent: int = 256,           # RecursiveLink contract dim (fixed across chain)
        chain_position: int = 0,       # context-doubling slot: ctx & theta scale by 2^pos
        base_context: int = 8192,      # ctx at position 0 (>= training --seq-len)
        base_rope_theta: float = 10000.0,
        **kwargs,
    ):
        # Fail fast on values that would otherwise yield a negative
        # nope_head_dim and only surface later as an obscure shape error.
        if not 0 <= qk_rope_head_dim <= head_dim:
            raise ValueError(
                f"qk_rope_head_dim ({qk_rope_head_dim}) must be between 0 and "
                f"head_dim ({head_dim})"
            )
        # chain_position is only consumed when latent IO is enabled, so it is
        # only validated in that case. A negative or non-integer exponent
        # would turn the scaled context length into a float / fraction.
        if use_latent_io and (
            not isinstance(chain_position, int) or chain_position < 0
        ):
            raise ValueError(
                "chain_position must be a non-negative integer when "
                f"use_latent_io is enabled, got {chain_position!r}"
            )

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

        # Standard
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.initializer_range = initializer_range
        self.hidden_dropout = hidden_dropout

        # Attention (MLA + DERF/XSA switches)
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.q_lora_rank = q_lora_rank
        self.head_dim = head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        # Derived: the per-head dims that do NOT receive rotary embeddings.
        self.nope_head_dim = head_dim - qk_rope_head_dim
        self.o_lora_rank = o_lora_rank
        self.attention_dropout = attention_dropout
        self.rope_theta = rope_theta
        self.use_derf = use_derf
        self.use_xsa = use_xsa

        # MoE
        self.use_moe = use_moe
        self.moe_intermediate_size = moe_intermediate_size
        self.n_routed_experts = n_routed_experts
        self.n_shared_experts = n_shared_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.routed_scaling_factor = routed_scaling_factor
        self.num_hash_layers = num_hash_layers
        self.moe_aux_loss_coef = moe_aux_loss_coef
        # Copy the caller's sequence so the config never aliases external
        # mutable state; default is "every layer".
        if moe_layers is None:
            self.moe_layers = list(range(num_hidden_layers))
        else:
            self.moe_layers = list(moe_layers)

        # Hyper-Connections
        self.use_hyper_connections = use_hyper_connections
        self.hc_mult = hc_mult
        self.hc_sinkhorn_iters = hc_sinkhorn_iters
        self.hc_eps = hc_eps

        # Multi-Token Prediction
        self.num_nextn_predict_layers = num_nextn_predict_layers

        # Engram N-gram module + HRM refine
        self.use_engram = use_engram
        self.engram_compress_dim = engram_compress_dim
        self.engram_num_heads = engram_num_heads
        self.engram_table_size = engram_table_size
        self.engram_max_ngram = engram_max_ngram
        self.engram_gate_init_bias = engram_gate_init_bias
        self.use_hrm_refine = use_hrm_refine
        self.hrm_refine_steps = hrm_refine_steps
        self.hrm_refine_dim = hrm_refine_dim

        # --- ModularMind-on-V2 additions ---
        self.use_latent_io = use_latent_io
        self.d_latent = d_latent
        self.chain_position = chain_position
        self.base_context = base_context
        self.base_rope_theta = base_rope_theta
        # Context-doubling: each chain slot doubles ctx and rope theta.
        # position 0 -> (8192, 10000); position 1 -> (16384, 20000); etc.
        # Applied only when latent IO is on (i.e. this is a ModularMind specialist),
        # so plain V2 keeps its own max_position_embeddings/rope_theta untouched.
        # Because the result is recomputed from base_context/base_rope_theta
        # (not from the already-scaled values), re-running __init__ with the
        # stored attributes yields the same numbers.
        if use_latent_io:
            scale = 2 ** chain_position
            self.max_position_embeddings = base_context * scale
            self.rope_theta = base_rope_theta * scale