# ModuleMind/agents/modmind/config.py
# Uploaded by Quazim0t0 with the upload-large-folder tool (commit 45e7dfb, 7.32 kB).
"""
config.py -- SpikeWhale: combined config from SpikeTransformer (My Project) + NanoWhale (DeepSeek-V4).
Features carried from My Project (not in NanoWhale):
- DERF attention: erf(alpha*score+bias)*gamma replaces softmax
- XSA (Exclusive Self-Attention): orthogonality correction removes self-echo from attn output
- Engram N-gram module: hash-table N-gram lookup with DERF gate injected into embeddings
- Three-tier optimizer: embed/table params trained at lower LR
Features carried from NanoWhale (not in My Project):
- MLA (Multi-Head Latent Attention): low-rank Q projection + direct K,V (MQA)
- Partial RoPE: rotary embeddings on only qk_rope_head_dim dims of Q and K
- Low-rank grouped output projection (o_lora_rank)
- Hyper-Connections: hc_mult residual streams with learned routing between layers
- Shared expert in MoE (always-active expert alongside routed experts)
- sqrtsoftplus expert scoring (vs softmax in My Project)
- Hash-based routing for first num_hash_layers layers
- norm_topk_prob + routed_scaling_factor
- Multi-Token Prediction (MTP): extra heads predict k steps ahead
- torch.compile, FineWeb-Edu streaming, Trackio, YAML configs in train.py
"""
from typing import Optional

from transformers import PretrainedConfig
class SpikeWhaleConfig(PretrainedConfig):
    """Hyper-parameter container for the SpikeWhale model.

    Every constructor argument is stored on the instance under the same
    name. ``bos_token_id``, ``eos_token_id``, ``tie_word_embeddings`` and any
    extra ``**kwargs`` are forwarded to ``PretrainedConfig.__init__``.

    Derived / adjusted attributes:
        nope_head_dim: ``head_dim - qk_rope_head_dim`` (the per-head dims
            that do not receive RoPE).
        moe_layers: defaults to every layer index
            (``list(range(num_hidden_layers))``) when not given.
        max_position_embeddings, rope_theta: when ``use_latent_io`` is True
            these are overridden with ``base_context * 2**chain_position``
            and ``base_rope_theta * 2**chain_position``; the values passed
            for them are ignored in that case.

    Raises:
        ValueError: if ``qk_rope_head_dim`` exceeds ``head_dim`` (which would
            give a negative ``nope_head_dim``), or if ``use_latent_io`` is
            True and ``chain_position`` is negative (which would give a
            fractional context length).
    """

    model_type = "spike_whale"

    def __init__(
        self,
        # Standard
        vocab_size: int = 129280,
        hidden_size: int = 2048,
        num_hidden_layers: int = 11,
        max_position_embeddings: int = 8192,
        rms_norm_eps: float = 1e-6,
        initializer_range: float = 0.02,
        tie_word_embeddings: bool = False,
        hidden_dropout: float = 0.0,
        bos_token_id: int = 0,
        eos_token_id: int = 1,
        # MLA Attention (NanoWhale)
        num_attention_heads: int = 8,
        num_key_value_heads: int = 1,  # 1 = MQA; >1 = GQA
        q_lora_rank: int = 160,  # low-rank Q: hidden -> q_lora_rank -> num_heads*head_dim
        head_dim: int = 96,  # total per-head dim = nope_head_dim + qk_rope_head_dim
        qk_rope_head_dim: int = 32,  # RoPE applied only to these dims
        o_lora_rank: int = 80,  # low-rank output: num_heads*head_dim -> o_lora_rank -> hidden
        attention_dropout: float = 0.0,
        rope_theta: float = 10000.0,
        # DERF + XSA (My Project)
        use_derf: bool = True,
        use_xsa: bool = True,
        # MoE (combined)
        use_moe: bool = True,
        moe_intermediate_size: int = 640,
        n_routed_experts: int = 4,
        n_shared_experts: int = 1,  # NanoWhale: always-active shared expert
        num_experts_per_tok: int = 2,
        norm_topk_prob: bool = True,  # NanoWhale: normalize top-k routing weights
        scoring_func: str = "sqrtsoftplus",  # NanoWhale: sqrt(softplus(x)) vs softmax
        routed_scaling_factor: float = 1.0,  # NanoWhale: scale routed expert weights
        num_hash_layers: int = 2,  # NanoWhale: first N layers use hash routing
        moe_aux_loss_coef: float = 0.01,
        moe_layers: Optional[list] = None,  # None -> all layers
        # Hyper-Connections (NanoWhale)
        use_hyper_connections: bool = True,
        hc_mult: int = 4,  # number of parallel residual streams
        hc_sinkhorn_iters: int = 20,
        hc_eps: float = 1e-6,
        # Multi-Token Prediction (NanoWhale)
        num_nextn_predict_layers: int = 1,  # extra MTP heads (0 = disabled)
        # Engram N-gram module (My Project)
        use_engram: bool = True,
        engram_compress_dim: int = 64,
        engram_num_heads: int = 4,
        engram_table_size: int = 8192,
        engram_max_ngram: int = 3,
        engram_gate_init_bias: float = -4.0,
        use_hrm_refine: bool = False,
        hrm_refine_steps: int = 3,
        hrm_refine_dim: int = 256,
        # --- ModularMind-on-V2 additions (off/unused unless enabled) ---
        use_latent_io: bool = False,  # add latent output head + injection input path
        d_latent: int = 256,  # RecursiveLink contract dim (fixed across chain)
        chain_position: int = 0,  # context-doubling slot: ctx & theta scale by 2^pos
        base_context: int = 8192,  # ctx at position 0 (>= training --seq-len)
        base_rope_theta: float = 10000.0,
        **kwargs,
    ):
        # Fail fast: a RoPE slice wider than the head would make the derived
        # nope_head_dim negative and only surface later as a shape error.
        if qk_rope_head_dim > head_dim:
            raise ValueError(
                f"qk_rope_head_dim ({qk_rope_head_dim}) must not exceed "
                f"head_dim ({head_dim})"
            )

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

        # Standard
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.initializer_range = initializer_range
        self.hidden_dropout = hidden_dropout

        # MLA attention
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.q_lora_rank = q_lora_rank
        self.head_dim = head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        # Derived: per-head dims left without rotary embedding.
        self.nope_head_dim = head_dim - qk_rope_head_dim
        self.o_lora_rank = o_lora_rank
        self.attention_dropout = attention_dropout
        self.rope_theta = rope_theta

        # DERF + XSA
        self.use_derf = use_derf
        self.use_xsa = use_xsa

        # MoE
        self.use_moe = use_moe
        self.moe_intermediate_size = moe_intermediate_size
        self.n_routed_experts = n_routed_experts
        self.n_shared_experts = n_shared_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.routed_scaling_factor = routed_scaling_factor
        self.num_hash_layers = num_hash_layers
        self.moe_aux_loss_coef = moe_aux_loss_coef
        # Copy the caller's sequence so later mutation on either side does
        # not leak into the other; default is "every layer".
        if moe_layers is None:
            self.moe_layers = list(range(num_hidden_layers))
        else:
            self.moe_layers = list(moe_layers)

        # Hyper-Connections
        self.use_hyper_connections = use_hyper_connections
        self.hc_mult = hc_mult
        self.hc_sinkhorn_iters = hc_sinkhorn_iters
        self.hc_eps = hc_eps

        # Multi-Token Prediction
        self.num_nextn_predict_layers = num_nextn_predict_layers

        # Engram N-gram module
        self.use_engram = use_engram
        self.engram_compress_dim = engram_compress_dim
        self.engram_num_heads = engram_num_heads
        self.engram_table_size = engram_table_size
        self.engram_max_ngram = engram_max_ngram
        self.engram_gate_init_bias = engram_gate_init_bias
        self.use_hrm_refine = use_hrm_refine
        self.hrm_refine_steps = hrm_refine_steps
        self.hrm_refine_dim = hrm_refine_dim

        # --- ModularMind-on-V2 additions ---
        self.use_latent_io = use_latent_io
        self.d_latent = d_latent
        self.chain_position = chain_position
        self.base_context = base_context
        self.base_rope_theta = base_rope_theta

        # Context-doubling: each chain slot doubles ctx and rope theta.
        # position 0 -> (8192, 10000); position 1 -> (16384, 20000); etc.
        # Applied only when latent IO is on (i.e. this is a ModularMind specialist),
        # so plain V2 keeps its own max_position_embeddings/rope_theta untouched.
        if use_latent_io:
            # A negative slot would give a fractional scale and therefore a
            # non-integer max_position_embeddings.
            if chain_position < 0:
                raise ValueError(
                    f"chain_position must be >= 0 when use_latent_io is enabled, "
                    f"got {chain_position}"
                )
            scale = 2 ** chain_position
            self.max_position_embeddings = base_context * scale
            self.rope_theta = base_rope_theta * scale