Spaces:
Running on Zero
Running on Zero
"""
config.py -- SpikeWhale: combined config from SpikeTransformer (My Project) + NanoWhale (DeepSeek-V4).

Features carried from My Project (not in NanoWhale):
  - DERF attention: erf(alpha*score+bias)*gamma replaces softmax
  - XSA (Exclusive Self-Attention): orthogonality correction removes self-echo from attn output
  - Engram N-gram module: hash-table N-gram lookup with DERF gate injected into embeddings
  - Three-tier optimizer: embed/table params trained at lower LR

Features carried from NanoWhale (not in My Project):
  - MLA (Multi-Head Latent Attention): low-rank Q projection + direct K,V
    (MQA by default; GQA when num_key_value_heads > 1)
  - Partial RoPE: rotary embeddings on only qk_rope_head_dim dims of Q and K
  - Low-rank grouped output projection (o_lora_rank)
  - Hyper-Connections: hc_mult residual streams with learned routing between layers
  - Shared expert in MoE (always-active expert alongside routed experts)
  - sqrtsoftplus expert scoring (vs softmax in My Project)
  - Hash-based routing for the first num_hash_layers layers
  - norm_topk_prob + routed_scaling_factor
  - Multi-Token Prediction (MTP): extra heads predict k steps ahead
  - torch.compile, FineWeb-Edu streaming, Trackio, YAML configs (in train.py)

Additional options defined in this config (all off by default):
  - HRM refinement (use_hrm_refine, hrm_refine_steps, hrm_refine_dim)
  - ModularMind-on-V2 latent I/O (use_latent_io, d_latent). When enabled,
    max_position_embeddings and rope_theta are derived from the chain slot:
    base_context * 2**chain_position and base_rope_theta * 2**chain_position.
"""
| from transformers import PretrainedConfig | |
class SpikeWhaleConfig(PretrainedConfig):
    """Configuration for SpikeWhale (SpikeTransformer + NanoWhale hybrid).

    Every hyperparameter is stored as a plain attribute on the config, so it
    serializes through the standard ``PretrainedConfig`` machinery.

    Derived attributes (always recomputed here; any stored value is overwritten):

    * ``nope_head_dim = head_dim - qk_rope_head_dim``: the per-head dims that
      do not receive RoPE.
    * ``moe_layers`` defaults to every layer index ``0 .. num_hidden_layers - 1``
      and is copied into a new list, so later edits do not alias the caller's list.
    * When ``use_latent_io`` is True (ModularMind specialist mode),
      ``max_position_embeddings`` and ``rope_theta`` are replaced by
      ``base_context * 2**chain_position`` and
      ``base_rope_theta * 2**chain_position``. The explicitly passed
      ``max_position_embeddings`` / ``rope_theta`` are ignored in that mode.

    Raises:
        ValueError: if ``qk_rope_head_dim`` is outside ``[0, head_dim]``, which
            would make ``nope_head_dim`` negative or larger than ``head_dim``.
            Also raised if ``use_latent_io`` is set and ``chain_position`` is
            negative, which would make the context-doubling scale fractional.
    """

    model_type = "spike_whale"

    def __init__(
        self,
        # Standard
        vocab_size: int = 129280,
        hidden_size: int = 2048,
        num_hidden_layers: int = 11,
        max_position_embeddings: int = 8192,  # overridden when use_latent_io=True
        rms_norm_eps: float = 1e-6,
        initializer_range: float = 0.02,
        tie_word_embeddings: bool = False,
        hidden_dropout: float = 0.0,
        bos_token_id: int = 0,
        eos_token_id: int = 1,
        # MLA Attention (NanoWhale)
        num_attention_heads: int = 8,
        num_key_value_heads: int = 1,  # 1 = MQA; >1 = GQA
        q_lora_rank: int = 160,  # low-rank Q: hidden -> q_lora_rank -> num_heads*head_dim
        head_dim: int = 96,  # total per-head dim = nope_head_dim + qk_rope_head_dim
        qk_rope_head_dim: int = 32,  # RoPE applied only to these dims
        o_lora_rank: int = 80,  # low-rank output: num_heads*head_dim -> o_lora_rank -> hidden
        attention_dropout: float = 0.0,
        rope_theta: float = 10000.0,  # overridden when use_latent_io=True
        # DERF + XSA (My Project)
        use_derf: bool = True,
        use_xsa: bool = True,
        # MoE (combined)
        use_moe: bool = True,
        moe_intermediate_size: int = 640,
        n_routed_experts: int = 4,
        n_shared_experts: int = 1,  # NanoWhale: always-active shared expert
        num_experts_per_tok: int = 2,
        norm_topk_prob: bool = True,  # NanoWhale: normalize top-k routing weights
        scoring_func: str = "sqrtsoftplus",  # NanoWhale: sqrt(softplus(x)) vs softmax
        routed_scaling_factor: float = 1.0,  # NanoWhale: scale routed expert weights
        num_hash_layers: int = 2,  # NanoWhale: first N layers use hash routing
        moe_aux_loss_coef: float = 0.01,
        moe_layers: "list[int] | None" = None,  # None -> all layer indices 0..num_hidden_layers-1
        # Hyper-Connections (NanoWhale)
        use_hyper_connections: bool = True,
        hc_mult: int = 4,  # number of parallel residual streams
        hc_sinkhorn_iters: int = 20,
        hc_eps: float = 1e-6,
        # Multi-Token Prediction (NanoWhale)
        num_nextn_predict_layers: int = 1,  # extra MTP heads (0 = disabled)
        # Engram N-gram module (My Project)
        use_engram: bool = True,
        engram_compress_dim: int = 64,
        engram_num_heads: int = 4,
        engram_table_size: int = 8192,
        engram_max_ngram: int = 3,
        engram_gate_init_bias: float = -4.0,
        # HRM refinement (disabled by default)
        use_hrm_refine: bool = False,
        hrm_refine_steps: int = 3,
        hrm_refine_dim: int = 256,
        # --- ModularMind-on-V2 additions (off/unused unless enabled) ---
        use_latent_io: bool = False,  # add latent output head + injection input path
        d_latent: int = 256,  # RecursiveLink contract dim (fixed across chain)
        chain_position: int = 0,  # context-doubling slot: ctx & theta scale by 2^pos
        base_context: int = 8192,  # ctx at position 0 (>= training --seq-len)
        base_rope_theta: float = 10000.0,
        **kwargs,
    ):
        # Validate before touching any state: a RoPE slice wider than the head
        # (or negative) would yield a nonsensical nope_head_dim below.
        if not 0 <= qk_rope_head_dim <= head_dim:
            raise ValueError(
                f"qk_rope_head_dim ({qk_rope_head_dim}) must be in [0, head_dim={head_dim}] "
                "so that nope_head_dim = head_dim - qk_rope_head_dim is non-negative"
            )

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

        # Standard
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.initializer_range = initializer_range
        self.hidden_dropout = hidden_dropout

        # MLA attention
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.q_lora_rank = q_lora_rank
        self.head_dim = head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.nope_head_dim = head_dim - qk_rope_head_dim  # derived; never read from kwargs
        self.o_lora_rank = o_lora_rank
        self.attention_dropout = attention_dropout
        self.rope_theta = rope_theta

        # DERF + XSA
        self.use_derf = use_derf
        self.use_xsa = use_xsa

        # MoE
        self.use_moe = use_moe
        self.moe_intermediate_size = moe_intermediate_size
        self.n_routed_experts = n_routed_experts
        self.n_shared_experts = n_shared_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.routed_scaling_factor = routed_scaling_factor
        self.num_hash_layers = num_hash_layers
        self.moe_aux_loss_coef = moe_aux_loss_coef
        # Copy so later edits to config.moe_layers do not mutate the caller's list
        # (also normalizes tuples to a list, as stored configs expect).
        self.moe_layers = (
            list(moe_layers) if moe_layers is not None else list(range(num_hidden_layers))
        )

        # Hyper-Connections
        self.use_hyper_connections = use_hyper_connections
        self.hc_mult = hc_mult
        self.hc_sinkhorn_iters = hc_sinkhorn_iters
        self.hc_eps = hc_eps

        # Multi-Token Prediction
        self.num_nextn_predict_layers = num_nextn_predict_layers

        # Engram
        self.use_engram = use_engram
        self.engram_compress_dim = engram_compress_dim
        self.engram_num_heads = engram_num_heads
        self.engram_table_size = engram_table_size
        self.engram_max_ngram = engram_max_ngram
        self.engram_gate_init_bias = engram_gate_init_bias

        # HRM refinement
        self.use_hrm_refine = use_hrm_refine
        self.hrm_refine_steps = hrm_refine_steps
        self.hrm_refine_dim = hrm_refine_dim

        # --- ModularMind-on-V2 additions ---
        self.use_latent_io = use_latent_io
        self.d_latent = d_latent
        self.chain_position = chain_position
        self.base_context = base_context
        self.base_rope_theta = base_rope_theta

        # Context-doubling: each chain slot doubles ctx and rope theta.
        # position 0 -> (8192, 10000); position 1 -> (16384, 20000); etc.
        # Applied only when latent IO is on (i.e. this is a ModularMind specialist),
        # so plain V2 keeps its own max_position_embeddings/rope_theta untouched.
        # Values are derived from base_* each time, so reloading a saved config
        # recomputes the same result rather than compounding the scale.
        if use_latent_io:
            if chain_position < 0:
                # 2 ** negative is a float < 1, which would silently make
                # max_position_embeddings fractional.
                raise ValueError(
                    f"chain_position must be >= 0 when use_latent_io=True, got {chain_position}"
                )
            scale = 2 ** chain_position
            self.max_position_embeddings = base_context * scale
            self.rope_theta = base_rope_theta * scale