"""Zenith Model Architectures - 7B, 28B, 32B, 70B configurations"""
from dataclasses import dataclass, field
from typing import Optional, Dict, Any
@dataclass
class ZenithConfig:
"""Base configuration for Zenith models."""
# Model architecture
model_type: str = "zenith"
d_model: int = 2048
d_ff: int = 8192
num_layers: int = 24
num_heads: int = 32
num_kv_heads: int = 8
head_dim: int = 64
vocab_size: int = 32000
max_seq_len: int = 8192
rope_theta: float = 10000.0
rope_scaling: Optional[Dict[str, Any]] = None
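    # rope_scaling example (illustrative only; the accepted keys depend on how the
    # attention module consumes this field): {"type": "linear", "factor": 4.0}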
# MoE configuration
num_experts: int = 8
top_k: int = 2
moe_layer_frequency: int = 2 # Every Nth layer is MoE
shared_experts: int = 2
capacity_factor: float = 1.0
aux_loss_weight: float = 0.01
router_z_loss_weight: float = 1e-4
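    # With the defaults above (24 layers, moe_layer_frequency=2), half of the layers
    # carry MoE feed-forwards; each token is routed to the top-2 of 8 experts, with
    # the shared experts applied to every token (assuming the usual shared-expert semantics).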
# EQ Adapter configuration
use_eq_adapter: bool = True
    eq_adapter_hidden_dim: int = 512
eq_num_emotions: int = 8
eq_frustration_dim: int = 256
eq_dropout: float = 0.1
# EQ Engine advanced features
use_eq_attention_bias: bool = False
use_eq_gated_ffn: bool = False
use_eq_recurrence: bool = False
eq_consistency_weight: float = 0.02
eq_state_dim: int = 256 # Dimension of recurrent EQ state
# Normalization & dropout
rms_norm_eps: float = 1e-6
dropout: float = 0.0
attention_dropout: float = 0.0
# Initialization
initializer_range: float = 0.02
use_gradient_checkpointing: bool = False
use_flash_attention: bool = False
# Tenstorrent specific (p300)
tensor_parallel_size: int = 1
pipeline_parallel_size: int = 1
use_tenstorrent_optimizations: bool = False
noc_optimization: bool = False
    # Ring attention (enabled in the long-context presets)
use_ring_attention: bool = False
ring_attention_chunk_size: int = 8192
ring_attention_overlap: int = 2048
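    # Illustrative arithmetic for the 32k-context presets below: 32768 / 8192 = 4
    # attention chunks, with ring_attention_overlap presumably giving a 2048-token
    # overlap between neighbouring chunks (the exact semantics depend on the
    # ring-attention implementation, which is not part of this file).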
def __post_init__(self):
"""Compute derived parameters."""
# Compute dense vs MoE layer distribution
self.dense_layers = self.num_layers - (self.num_layers // self.moe_layer_frequency)
self.moe_layers = self.num_layers // self.moe_layer_frequency
# Compute parameters
self.total_params = self._compute_total_params()
self.active_params = self._compute_active_params()
def _compute_total_params(self) -> int:
"""Estimate total parameters."""
# Embedding
embedding_params = self.vocab_size * self.d_model
# Layers
layer_params = 0
for _ in range(self.num_layers):
            # Attention: Q/O use num_heads * head_dim, K/V use num_kv_heads * head_dim (GQA)
            attn_params = (
                2 * self.d_model * self.num_heads * self.head_dim
                + 2 * self.d_model * self.num_kv_heads * self.head_dim
            )
# RMS norm
norm_params = 2 * self.d_model
            # Feed-forward (this estimate treats every FFN block as MoE when
            # num_experts > 1, i.e. it does not apply moe_layer_frequency)
if self.num_experts > 1:
# MoE: num_experts * (d_model * d_ff + d_ff * d_model)
ff_params = self.num_experts * (self.d_model * self.d_ff + self.d_ff * self.d_model)
else:
ff_params = self.d_model * self.d_ff + self.d_ff * self.d_model
            # EQ adapter: down projection, prediction heads, up projection
            if self.use_eq_adapter:
                eq_params = (
                    self.d_model * self.eq_adapter_hidden_dim +
                    self.eq_adapter_hidden_dim * 2 +  # frustration + emotion heads
                    self.eq_adapter_hidden_dim * self.d_model
                )
else:
eq_params = 0
layer_params += attn_params + norm_params + ff_params + eq_params
        # Output head (lm_head, same shape as the embedding matrix)
output_params = self.vocab_size * self.d_model
return embedding_params + layer_params + output_params
def _compute_active_params(self) -> int:
"""Estimate active parameters during inference (top-k MoE)."""
if self.num_experts <= 1:
return self.total_params
        # MoE active params: only the top-k of num_experts expert FFNs run per token,
        # so (num_experts - top_k) expert FFNs per layer are inactive
        inactive_ff_per_layer = (self.num_experts - self.top_k) * 2 * self.d_model * self.d_ff
        active_total = self.total_params - self.num_layers * inactive_ff_per_layer
return int(active_total)
def get_7b_config() -> ZenithConfig:
"""7B parameter model configuration for standard GPUs."""
return ZenithConfig(
model_type="zenith-7b",
d_model=2048,
d_ff=8192,
num_layers=24,
num_heads=16,
num_kv_heads=8,
head_dim=128,
vocab_size=32000,
max_seq_len=8192,
num_experts=8,
top_k=2,
        moe_layer_frequency=2,  # Every 2nd layer is MoE
shared_experts=2,
use_eq_adapter=True,
eq_adapter_hidden_dim=512,
use_flash_attention=True,
use_gradient_checkpointing=True,
)
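
# Illustrative sketch (not part of the presets above): dataclasses.replace() can derive
# a one-off variant from a factory without editing it; replace() rebuilds the instance,
# so the parameter estimates from __post_init__ are recomputed. The variant name below
# is hypothetical.
#
#     from dataclasses import replace
#     zenith_7b_32k = replace(get_7b_config(), max_seq_len=32768, use_ring_attention=True)
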
def get_28b_config() -> ZenithConfig:
"""28B parameter model configuration optimized for p300.
Based on Qwen3.5-27B-Claude-4.6-Opus-Reasoning-Distilled."""
return ZenithConfig(
model_type="zenith-28b",
d_model=3072, # Qwen3.5-27B hidden size
d_ff=12288, # 4x d_model (typical for Qwen)
num_layers=36, # Approximate for 28B
num_heads=24, # d_model / 128 = 24
num_kv_heads=8,
head_dim=128,
vocab_size=32000,
max_seq_len=32768, # 32k context
num_experts=8,
top_k=2,
moe_layer_frequency=2,
shared_experts=2,
use_eq_adapter=True,
eq_adapter_hidden_dim=768, # 0.25 * d_model
use_flash_attention=True,
use_gradient_checkpointing=True,
use_ring_attention=True, # For 32k context
ring_attention_chunk_size=8192,
ring_attention_overlap=2048,
use_tenstorrent_optimizations=True,
tensor_parallel_size=8, # 8 cores/chip for TP
pipeline_parallel_size=4, # 4 cores/chip for PP
noc_optimization=True,
)
def get_32b_config() -> ZenithConfig:
"""32B parameter model configuration."""
return ZenithConfig(
model_type="zenith-32b",
d_model=2560,
d_ff=10240,
num_layers=32,
num_heads=20,
num_kv_heads=10,
head_dim=128,
vocab_size=32000,
max_seq_len=8192,
num_experts=8,
top_k=2,
moe_layer_frequency=2,
shared_experts=2,
use_eq_adapter=True,
eq_adapter_hidden_dim=640,
use_flash_attention=True,
use_gradient_checkpointing=True,
)
def get_70b_config() -> ZenithConfig:
"""70B parameter model configuration based on DeepSeek-R1-Distill-Llama-70B."""
return ZenithConfig(
model_type="zenith-70b",
d_model=4096,
d_ff=16384,
num_layers=48,
num_heads=32,
num_kv_heads=8,
head_dim=128,
vocab_size=32000,
max_seq_len=32768, # 32k context
num_experts=8,
top_k=2,
moe_layer_frequency=2,
shared_experts=2,
use_eq_adapter=True,
eq_adapter_hidden_dim=1024,
use_flash_attention=True,
use_gradient_checkpointing=True,
use_ring_attention=True, # For 32k context
ring_attention_chunk_size=8192,
ring_attention_overlap=2048,
use_tenstorrent_optimizations=True,
tensor_parallel_size=8,
pipeline_parallel_size=4,
noc_optimization=True,
)
def get_p300_optimized_config() -> ZenithConfig:
"""32B model optimized specifically for Tenstorrent p300a."""
base = get_32b_config()
base.model_type = "zenith-32b-p300"
base.use_tenstorrent_optimizations = True
base.tensor_parallel_size = 8 # 8 cores/chip for TP
base.pipeline_parallel_size = 4 # 4 cores/chip for PP
base.noc_optimization = True
base.max_seq_len = 32768 # 32k context with ring attention
base.use_ring_attention = True
base.ring_attention_chunk_size = 8192
base.ring_attention_overlap = 2048
return base
def get_28b_p300_config() -> ZenithConfig:
"""28B model optimized specifically for Tenstorrent p300a.
Based on Qwen3.5-27B-Claude-4.6-Opus-Reasoning-Distilled."""
config = get_28b_config()
config.model_type = "zenith-28b-p300"
config.use_tenstorrent_optimizations = True
config.tensor_parallel_size = 8 # 8 cores/chip for TP
config.pipeline_parallel_size = 4 # 4 cores/chip for PP
config.noc_optimization = True
config.max_seq_len = 32768 # 32k context with ring attention
config.use_ring_attention = True
config.ring_attention_chunk_size = 8192
config.ring_attention_overlap = 2048
return config
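
# Minimal usage sketch (illustrative; not part of the training or inference stack).
# Running this module directly prints the derived parameter estimates computed in
# __post_init__ for each preset, as a quick sanity check of the configurations.
if __name__ == "__main__":
    presets = [
        ("zenith-7b", get_7b_config),
        ("zenith-28b", get_28b_config),
        ("zenith-32b", get_32b_config),
        ("zenith-70b", get_70b_config),
        ("zenith-32b-p300", get_p300_optimized_config),
        ("zenith-28b-p300", get_28b_p300_config),
    ]
    for name, factory in presets:
        cfg = factory()
        print(
            f"{name:>16}: {cfg.num_layers} layers "
            f"({cfg.dense_layers} dense / {cfg.moe_layers} MoE), "
            f"~{cfg.total_params / 1e9:.1f}B total params, "
            f"~{cfg.active_params / 1e9:.1f}B active params"
        )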