"""
config.py -- SpikeWhale: combined config from SpikeTransformer (My Project) + NanoWhale (DeepSeek-V4).
Features carried from My Project (not in NanoWhale):
- DERF attention: erf(alpha*score+bias)*gamma replaces softmax
- XSA (Exclusive Self-Attention): orthogonality correction removes self-echo from attn output
- Engram N-gram module: hash-table N-gram lookup with DERF gate injected into embeddings
- Three-tier optimizer: embed/table params trained at lower LR
Features carried from NanoWhale (not in My Project):
- MLA (Multi-Head Latent Attention): low-rank Q projection + direct K,V (MQA)
- Partial RoPE: rotary embeddings on only qk_rope_head_dim dims of Q and K
- Low-rank grouped output projection (o_lora_rank)
- Hyper-Connections: hc_mult residual streams with learned routing between layers
- Shared expert in MoE (always-active expert alongside routed experts)
- sqrtsoftplus expert scoring (vs softmax in My Project)
- Hash-based routing for first num_hash_layers layers
- norm_topk_prob + routed_scaling_factor
- Multi-Token Prediction (MTP): extra heads predict k steps ahead
- torch.compile, FineWeb-Edu streaming, Trackio, YAML configs in train.py
"""
from typing import List, Optional

from transformers import PretrainedConfig
class SpikeWhaleConfig(PretrainedConfig):
    """Configuration for the SpikeWhale model.

    Combines SpikeTransformer features (DERF attention, XSA, Engram N-gram
    module) with NanoWhale / DeepSeek-V4 features (MLA attention, partial
    RoPE, low-rank output projection, Hyper-Connections, shared-expert MoE,
    Multi-Token Prediction).

    Derived attributes:
        nope_head_dim: ``head_dim - qk_rope_head_dim``, i.e. the per-head
            dims that do not receive rotary embeddings.
        moe_layers: a fresh list; defaults to every layer index
            ``0 .. num_hidden_layers - 1`` when not given.

    When ``use_latent_io`` is True (ModularMind specialist), the passed-in
    ``max_position_embeddings`` and ``rope_theta`` are ignored and replaced
    by ``base_context * 2**chain_position`` and
    ``base_rope_theta * 2**chain_position`` respectively.

    Raises:
        ValueError: if ``qk_rope_head_dim`` exceeds ``head_dim`` (which would
            make ``nope_head_dim`` negative), or if ``use_latent_io`` is True
            and ``chain_position`` is not a non-negative integer.
    """

    model_type = "spike_whale"

    def __init__(
        self,
        # Standard
        vocab_size: int = 129280,
        hidden_size: int = 2048,
        num_hidden_layers: int = 11,
        max_position_embeddings: int = 8192,
        rms_norm_eps: float = 1e-6,
        initializer_range: float = 0.02,
        tie_word_embeddings: bool = False,
        hidden_dropout: float = 0.0,
        bos_token_id: int = 0,
        eos_token_id: int = 1,
        # MLA Attention (NanoWhale)
        num_attention_heads: int = 8,
        num_key_value_heads: int = 1,  # 1 = MQA; >1 = GQA
        q_lora_rank: int = 160,  # low-rank Q: hidden -> q_lora_rank -> num_heads*head_dim
        head_dim: int = 96,  # total per-head dim = nope_head_dim + qk_rope_head_dim
        qk_rope_head_dim: int = 32,  # RoPE applied only to these dims
        o_lora_rank: int = 80,  # low-rank output: num_heads*head_dim -> o_lora_rank -> hidden
        attention_dropout: float = 0.0,
        rope_theta: float = 10000.0,
        # DERF + XSA (My Project)
        use_derf: bool = True,
        use_xsa: bool = True,
        # MoE (combined)
        use_moe: bool = True,
        moe_intermediate_size: int = 640,
        n_routed_experts: int = 4,
        n_shared_experts: int = 1,  # NanoWhale: always-active shared expert
        num_experts_per_tok: int = 2,
        norm_topk_prob: bool = True,  # NanoWhale: normalize top-k routing weights
        scoring_func: str = "sqrtsoftplus",  # NanoWhale: sqrt(softplus(x)) vs softmax
        routed_scaling_factor: float = 1.0,  # NanoWhale: scale routed expert weights
        num_hash_layers: int = 2,  # NanoWhale: first N layers use hash routing
        moe_aux_loss_coef: float = 0.01,
        moe_layers: Optional[List[int]] = None,  # None -> every layer
        # Hyper-Connections (NanoWhale)
        use_hyper_connections: bool = True,
        hc_mult: int = 4,  # number of parallel residual streams
        hc_sinkhorn_iters: int = 20,
        hc_eps: float = 1e-6,
        # Multi-Token Prediction (NanoWhale)
        num_nextn_predict_layers: int = 1,  # extra MTP heads (0 = disabled)
        # Engram N-gram module (My Project)
        use_engram: bool = True,
        engram_compress_dim: int = 64,
        engram_num_heads: int = 4,
        engram_table_size: int = 8192,
        engram_max_ngram: int = 3,
        engram_gate_init_bias: float = -4.0,
        use_hrm_refine: bool = False,
        hrm_refine_steps: int = 3,
        hrm_refine_dim: int = 256,
        # --- ModularMind-on-V2 additions (off/unused unless enabled) ---
        use_latent_io: bool = False,  # add latent output head + injection input path
        d_latent: int = 256,  # RecursiveLink contract dim (fixed across chain)
        chain_position: int = 0,  # context-doubling slot: ctx & theta scale by 2^pos
        base_context: int = 8192,  # ctx at position 0 (>= training --seq-len)
        base_rope_theta: float = 10000.0,
        **kwargs,
    ):
        # Fail fast on configurations that would silently produce nonsense
        # derived values (negative nope_head_dim, float context length).
        if qk_rope_head_dim > head_dim:
            raise ValueError(
                f"qk_rope_head_dim ({qk_rope_head_dim}) must not exceed "
                f"head_dim ({head_dim})"
            )
        if use_latent_io and (
            isinstance(chain_position, bool)
            or not isinstance(chain_position, int)
            or chain_position < 0
        ):
            raise ValueError(
                "chain_position must be a non-negative integer when "
                f"use_latent_io=True, got {chain_position!r}"
            )

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

        # Standard
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.initializer_range = initializer_range
        self.hidden_dropout = hidden_dropout

        # MLA attention
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.q_lora_rank = q_lora_rank
        self.head_dim = head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        # Per-head dims that receive no rotary embedding.
        self.nope_head_dim = head_dim - qk_rope_head_dim
        self.o_lora_rank = o_lora_rank
        self.attention_dropout = attention_dropout
        self.rope_theta = rope_theta

        # DERF + XSA
        self.use_derf = use_derf
        self.use_xsa = use_xsa

        # MoE
        self.use_moe = use_moe
        self.moe_intermediate_size = moe_intermediate_size
        self.n_routed_experts = n_routed_experts
        self.n_shared_experts = n_shared_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.routed_scaling_factor = routed_scaling_factor
        self.num_hash_layers = num_hash_layers
        self.moe_aux_loss_coef = moe_aux_loss_coef
        # Copy so later mutation of the caller's list (or a tuple argument)
        # never aliases this config's state.
        self.moe_layers = (
            list(moe_layers) if moe_layers is not None else list(range(num_hidden_layers))
        )

        # Hyper-Connections
        self.use_hyper_connections = use_hyper_connections
        self.hc_mult = hc_mult
        self.hc_sinkhorn_iters = hc_sinkhorn_iters
        self.hc_eps = hc_eps

        # Multi-Token Prediction
        self.num_nextn_predict_layers = num_nextn_predict_layers

        # Engram N-gram module + HRM refine
        self.use_engram = use_engram
        self.engram_compress_dim = engram_compress_dim
        self.engram_num_heads = engram_num_heads
        self.engram_table_size = engram_table_size
        self.engram_max_ngram = engram_max_ngram
        self.engram_gate_init_bias = engram_gate_init_bias
        self.use_hrm_refine = use_hrm_refine
        self.hrm_refine_steps = hrm_refine_steps
        self.hrm_refine_dim = hrm_refine_dim

        # --- ModularMind-on-V2 additions ---
        self.use_latent_io = use_latent_io
        self.d_latent = d_latent
        self.chain_position = chain_position
        self.base_context = base_context
        self.base_rope_theta = base_rope_theta

        # Context-doubling: each chain slot doubles ctx and rope theta.
        # position 0 -> (8192, 10000); position 1 -> (16384, 20000); etc.
        # Applied only when latent IO is on (i.e. this is a ModularMind
        # specialist), so plain V2 keeps its own
        # max_position_embeddings/rope_theta untouched. chain_position was
        # validated above as a non-negative int, so scale is an int.
        if use_latent_io:
            scale = 2 ** chain_position
            self.max_position_embeddings = base_context * scale
            self.rope_theta = base_rope_theta * scale
|