""" config.py -- SpikeWhale: combined config from SpikeTransformer (My Project) + NanoWhale (DeepSeek-V4). Features carried from My Project (not in NanoWhale): - DERF attention: erf(alpha*score+bias)*gamma replaces softmax - XSA (Exclusive Self-Attention): orthogonality correction removes self-echo from attn output - Engram N-gram module: hash-table N-gram lookup with DERF gate injected into embeddings - Three-tier optimizer: embed/table params trained at lower LR Features carried from NanoWhale (not in My Project): - MLA (Multi-Head Latent Attention): low-rank Q projection + direct K,V (MQA) - Partial RoPE: rotary embeddings on only qk_rope_head_dim dims of Q and K - Low-rank grouped output projection (o_lora_rank) - Hyper-Connections: hc_mult residual streams with learned routing between layers - Shared expert in MoE (always-active expert alongside routed experts) - sqrtsoftplus expert scoring (vs softmax in My Project) - Hash-based routing for first num_hash_layers layers - norm_topk_prob + routed_scaling_factor - Multi-Token Prediction (MTP): extra heads predict k steps ahead - torch.compile, FineWeb-Edu streaming, Trackio, YAML configs in train.py """ from transformers import PretrainedConfig class SpikeWhaleConfig(PretrainedConfig): model_type = "spike_whale" def __init__( self, # Standard vocab_size: int = 129280, hidden_size: int = 2048, num_hidden_layers: int = 11, max_position_embeddings: int = 8192, rms_norm_eps: float = 1e-6, initializer_range: float = 0.02, tie_word_embeddings: bool = False, hidden_dropout: float = 0.0, bos_token_id: int = 0, eos_token_id: int = 1, # MLA Attention (NanoWhale) num_attention_heads: int = 8, num_key_value_heads: int = 1, # 1 = MQA; >1 = GQA q_lora_rank: int = 160, # low-rank Q: hidden -> q_lora_rank -> num_heads*head_dim head_dim: int = 96, # total per-head dim = nope_head_dim + qk_rope_head_dim qk_rope_head_dim: int = 32, # RoPE applied only to these dims o_lora_rank: int = 80, # low-rank output: num_heads*head_dim -> o_lora_rank -> hidden attention_dropout: float = 0.0, rope_theta: float = 10000.0, # DERF + XSA (My Project) use_derf: bool = True, use_xsa: bool = True, # MoE (combined) use_moe: bool = True, moe_intermediate_size: int = 640, n_routed_experts: int = 4, n_shared_experts: int = 1, # NanoWhale: always-active shared expert num_experts_per_tok: int = 2, norm_topk_prob: bool = True, # NanoWhale: normalize top-k routing weights scoring_func: str = "sqrtsoftplus", # NanoWhale: sqrt(softplus(x)) vs softmax routed_scaling_factor: float = 1.0, # NanoWhale: scale routed expert weights num_hash_layers: int = 2, # NanoWhale: first N layers use hash routing moe_aux_loss_coef: float = 0.01, moe_layers: list = None, # Hyper-Connections (NanoWhale) use_hyper_connections: bool = True, hc_mult: int = 4, # number of parallel residual streams hc_sinkhorn_iters: int = 20, hc_eps: float = 1e-6, # Multi-Token Prediction (NanoWhale) num_nextn_predict_layers: int = 1, # extra MTP heads (0 = disabled) # Engram N-gram module (My Project) use_engram: bool = True, engram_compress_dim: int = 64, engram_num_heads: int = 4, engram_table_size: int = 8192, engram_max_ngram: int = 3, engram_gate_init_bias: float = -4.0, use_hrm_refine: bool = False, hrm_refine_steps: int = 3, hrm_refine_dim: int = 256, # --- ModularMind-on-V2 additions (off/unused unless enabled) --- use_latent_io: bool = False, # add latent output head + injection input path d_latent: int = 256, # RecursiveLink contract dim (fixed across chain) chain_position: int = 0, # context-doubling slot: ctx & theta scale by 2^pos base_context: int = 8192, # ctx at position 0 (>= training --seq-len) base_rope_theta: float = 10000.0, **kwargs, ): super().__init__( bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs, ) self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.max_position_embeddings = max_position_embeddings self.rms_norm_eps = rms_norm_eps self.initializer_range = initializer_range self.hidden_dropout = hidden_dropout self.num_attention_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.q_lora_rank = q_lora_rank self.head_dim = head_dim self.qk_rope_head_dim = qk_rope_head_dim self.nope_head_dim = head_dim - qk_rope_head_dim self.o_lora_rank = o_lora_rank self.attention_dropout = attention_dropout self.rope_theta = rope_theta self.use_derf = use_derf self.use_xsa = use_xsa self.use_moe = use_moe self.moe_intermediate_size = moe_intermediate_size self.n_routed_experts = n_routed_experts self.n_shared_experts = n_shared_experts self.num_experts_per_tok = num_experts_per_tok self.norm_topk_prob = norm_topk_prob self.scoring_func = scoring_func self.routed_scaling_factor = routed_scaling_factor self.num_hash_layers = num_hash_layers self.moe_aux_loss_coef = moe_aux_loss_coef self.moe_layers = moe_layers if moe_layers is not None else list(range(num_hidden_layers)) self.use_hyper_connections = use_hyper_connections self.hc_mult = hc_mult self.hc_sinkhorn_iters = hc_sinkhorn_iters self.hc_eps = hc_eps self.num_nextn_predict_layers = num_nextn_predict_layers self.use_engram = use_engram self.engram_compress_dim = engram_compress_dim self.engram_num_heads = engram_num_heads self.engram_table_size = engram_table_size self.engram_max_ngram = engram_max_ngram self.engram_gate_init_bias = engram_gate_init_bias self.use_hrm_refine = use_hrm_refine self.hrm_refine_steps = hrm_refine_steps self.hrm_refine_dim = hrm_refine_dim # --- ModularMind-on-V2 additions --- self.use_latent_io = use_latent_io self.d_latent = d_latent self.chain_position = chain_position self.base_context = base_context self.base_rope_theta = base_rope_theta # Context-doubling: each chain slot doubles ctx and rope theta. # position 0 -> (8192, 10000); position 1 -> (16384, 20000); etc. # Applied only when latent IO is on (i.e. this is a ModularMind specialist), # so plain V2 keeps its own max_position_embeddings/rope_theta untouched. if use_latent_io: scale = 2 ** chain_position self.max_position_embeddings = base_context * scale self.rope_theta = base_rope_theta * scale