| """Zenith Model Architectures - 7B, 28B, 32B, 70B configurations"""
|
|
|
| from dataclasses import dataclass, field
|
| from typing import Optional, List, Dict, Any
|
|
|
|
|
@dataclass
class ZenithConfig:
    """Base configuration for Zenith models.

    Holds the architecture hyperparameters (attention, MoE routing, EQ
    adapter, parallelism/runtime switches) and derives layer counts plus
    parameter-count estimates in ``__post_init__``.
    """

    # --- Core transformer dimensions ---
    model_type: str = "zenith"
    d_model: int = 2048
    d_ff: int = 8192
    num_layers: int = 24
    num_heads: int = 32
    num_kv_heads: int = 8  # KV heads for grouped-query attention
    head_dim: int = 64
    vocab_size: int = 32000
    max_seq_len: int = 8192
    rope_theta: float = 10000.0
    rope_scaling: Optional[Dict[str, Any]] = None

    # --- Mixture-of-Experts routing ---
    num_experts: int = 8
    top_k: int = 2  # experts routed per token
    moe_layer_frequency: int = 2  # every Nth layer is an MoE layer
    shared_experts: int = 2
    capacity_factor: float = 1.0
    aux_loss_weight: float = 0.01
    router_z_loss_weight: float = 1e-4

    # --- EQ adapter ---
    # FIX: renamed from eq_adapter_hidden_size. Every other reference in
    # this module (the parameter estimators and all get_*_config
    # factories) uses eq_adapter_hidden_dim, so the old field name made
    # default construction raise AttributeError in __post_init__ and made
    # the factories fail with an unexpected-keyword TypeError.
    use_eq_adapter: bool = True
    eq_adapter_hidden_dim: int = 512
    eq_num_emotions: int = 8
    eq_frustration_dim: int = 256
    eq_dropout: float = 0.1

    # --- Experimental EQ integrations (all off by default) ---
    use_eq_attention_bias: bool = False
    use_eq_gated_ffn: bool = False
    use_eq_recurrence: bool = False
    eq_consistency_weight: float = 0.02
    eq_state_dim: int = 256

    # --- Normalization / regularization ---
    rms_norm_eps: float = 1e-6
    dropout: float = 0.0
    attention_dropout: float = 0.0

    # --- Training / runtime switches ---
    initializer_range: float = 0.02
    use_gradient_checkpointing: bool = False
    use_flash_attention: bool = False

    # --- Parallelism / hardware ---
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    use_tenstorrent_optimizations: bool = False
    noc_optimization: bool = False

    # --- Ring attention (long-context) ---
    use_ring_attention: bool = False
    ring_attention_chunk_size: int = 8192
    ring_attention_overlap: int = 2048

    def __post_init__(self) -> None:
        """Compute derived layer counts and parameter estimates."""
        # Every moe_layer_frequency-th layer is MoE; the rest are dense.
        self.dense_layers = self.num_layers - (self.num_layers // self.moe_layer_frequency)
        self.moe_layers = self.num_layers // self.moe_layer_frequency

        self.total_params = self._compute_total_params()
        self.active_params = self._compute_active_params()

    def _compute_total_params(self) -> int:
        """Estimate total parameters.

        NOTE(review): when num_experts > 1 this estimate treats *every*
        layer as MoE — moe_layer_frequency is not factored in. Kept
        as-is: the resulting totals line up with the model-size names
        used by the factory functions below.
        """
        embedding_params = self.vocab_size * self.d_model

        layer_params = 0
        for _ in range(self.num_layers):
            # Q, K, V and output projections (GQA head sharing ignored
            # by this rough estimate).
            attn_params = 4 * self.d_model * self.d_model
            # Two RMSNorm weight vectors per layer.
            norm_params = 2 * self.d_model

            if self.num_experts > 1:
                # Each expert carries an up- and a down-projection.
                ff_params = self.num_experts * (self.d_model * self.d_ff + self.d_ff * self.d_model)
            else:
                ff_params = self.d_model * self.d_ff + self.d_ff * self.d_model

            if self.use_eq_adapter:
                # FIX: the original expression was missing an operator and
                # multiplied hidden_dim * 2 * hidden_dim * d_model together
                # instead of summing the adapter's down- and up-projections.
                eq_params = (
                    self.d_model * self.eq_adapter_hidden_dim +
                    self.eq_adapter_hidden_dim * self.d_model
                )
            else:
                eq_params = 0

            layer_params += attn_params + norm_params + ff_params + eq_params

        # Untied LM head projection.
        output_params = self.vocab_size * self.d_model

        return embedding_params + layer_params + output_params

    def _compute_active_params(self) -> int:
        """Estimate active parameters during inference (top-k MoE).

        Subtracts the FFN weights of the experts that are *not* routed to
        for a token: (num_experts - top_k) experts per layer, each with
        2 * d_model * d_ff weights, matching _compute_total_params.
        """
        if self.num_experts <= 1:
            return self.total_params

        active_ff_ratio = self.top_k / self.num_experts
        # FIX: the original omitted the num_experts factor, so it removed
        # only a fraction of a single expert's FFN rather than all of the
        # inactive experts' FFNs.
        inactive_ff_params = (
            self.num_layers
            * self.num_experts
            * self.d_model * self.d_ff * 2
            * (1 - active_ff_ratio)
        )
        return int(self.total_params - inactive_ff_params)
|
|
|
|
|
def get_7b_config() -> ZenithConfig:
    """7B parameter model configuration for standard GPUs."""
    settings: Dict[str, Any] = {
        "model_type": "zenith-7b",
        "d_model": 2048,
        "d_ff": 8192,
        "num_layers": 24,
        "num_heads": 16,
        "num_kv_heads": 8,
        "head_dim": 128,
        "vocab_size": 32000,
        "max_seq_len": 8192,
        "num_experts": 8,
        "top_k": 2,
        "moe_layer_frequency": 2,
        "shared_experts": 2,
        "use_eq_adapter": True,
        "eq_adapter_hidden_dim": 512,
        "use_flash_attention": True,
        "use_gradient_checkpointing": True,
    }
    return ZenithConfig(**settings)
|
|
|
|
|
def get_28b_config() -> ZenithConfig:
    """28B parameter model configuration optimized for p300.
    Based on Qwen3.5-27B-Claude-4.6-Opus-Reasoning-Distilled."""
    settings: Dict[str, Any] = {
        "model_type": "zenith-28b",
        "d_model": 3072,
        "d_ff": 12288,
        "num_layers": 36,
        "num_heads": 24,
        "num_kv_heads": 8,
        "head_dim": 128,
        "vocab_size": 32000,
        "max_seq_len": 32768,
        "num_experts": 8,
        "top_k": 2,
        "moe_layer_frequency": 2,
        "shared_experts": 2,
        "use_eq_adapter": True,
        "eq_adapter_hidden_dim": 768,
        "use_flash_attention": True,
        "use_gradient_checkpointing": True,
        # Long-context ring attention and Tenstorrent-specific settings.
        "use_ring_attention": True,
        "ring_attention_chunk_size": 8192,
        "ring_attention_overlap": 2048,
        "use_tenstorrent_optimizations": True,
        "tensor_parallel_size": 8,
        "pipeline_parallel_size": 4,
        "noc_optimization": True,
    }
    return ZenithConfig(**settings)
|
|
|
|
|
def get_32b_config() -> ZenithConfig:
    """32B parameter model configuration."""
    settings: Dict[str, Any] = {
        "model_type": "zenith-32b",
        "d_model": 2560,
        "d_ff": 10240,
        "num_layers": 32,
        "num_heads": 20,
        "num_kv_heads": 10,
        "head_dim": 128,
        "vocab_size": 32000,
        "max_seq_len": 8192,
        "num_experts": 8,
        "top_k": 2,
        "moe_layer_frequency": 2,
        "shared_experts": 2,
        "use_eq_adapter": True,
        "eq_adapter_hidden_dim": 640,
        "use_flash_attention": True,
        "use_gradient_checkpointing": True,
    }
    return ZenithConfig(**settings)
|
|
|
|
|
def get_70b_config() -> ZenithConfig:
    """70B parameter model configuration based on DeepSeek-R1-Distill-Llama-70B."""
    settings: Dict[str, Any] = {
        "model_type": "zenith-70b",
        "d_model": 4096,
        "d_ff": 16384,
        "num_layers": 48,
        "num_heads": 32,
        "num_kv_heads": 8,
        "head_dim": 128,
        "vocab_size": 32000,
        "max_seq_len": 32768,
        "num_experts": 8,
        "top_k": 2,
        "moe_layer_frequency": 2,
        "shared_experts": 2,
        "use_eq_adapter": True,
        "eq_adapter_hidden_dim": 1024,
        "use_flash_attention": True,
        "use_gradient_checkpointing": True,
        # Long-context ring attention and Tenstorrent-specific settings.
        "use_ring_attention": True,
        "ring_attention_chunk_size": 8192,
        "ring_attention_overlap": 2048,
        "use_tenstorrent_optimizations": True,
        "tensor_parallel_size": 8,
        "pipeline_parallel_size": 4,
        "noc_optimization": True,
    }
    return ZenithConfig(**settings)
|
|
|
|
|
def get_p300_optimized_config() -> ZenithConfig:
    """32B model optimized specifically for Tenstorrent p300a."""
    overrides: Dict[str, Any] = {
        "model_type": "zenith-32b-p300",
        "use_tenstorrent_optimizations": True,
        "tensor_parallel_size": 8,
        "pipeline_parallel_size": 4,
        "noc_optimization": True,
        "max_seq_len": 32768,
        "use_ring_attention": True,
        "ring_attention_chunk_size": 8192,
        "ring_attention_overlap": 2048,
    }
    config = get_32b_config()
    # Patch attributes in place; __post_init__ has already run, so the
    # derived parameter counts are not recomputed (none of these fields
    # affect them anyway).
    for attr_name, attr_value in overrides.items():
        setattr(config, attr_name, attr_value)
    return config
|
|
|
|
|
def get_28b_p300_config() -> ZenithConfig:
    """28B model optimized specifically for Tenstorrent p300a.
    Based on Qwen3.5-27B-Claude-4.6-Opus-Reasoning-Distilled."""
    overrides: Dict[str, Any] = {
        "model_type": "zenith-28b-p300",
        "use_tenstorrent_optimizations": True,
        "tensor_parallel_size": 8,
        "pipeline_parallel_size": 4,
        "noc_optimization": True,
        "max_seq_len": 32768,
        "use_ring_attention": True,
        "ring_attention_chunk_size": 8192,
        "ring_attention_overlap": 2048,
    }
    config = get_28b_config()
    # Patch attributes in place; __post_init__ has already run, so the
    # derived parameter counts are not recomputed (none of these fields
    # affect them anyway).
    for attr_name, attr_value in overrides.items():
        setattr(config, attr_name, attr_value)
    return config
|
|
|