"""Zenith model for Hugging Face Transformers.

This module provides a Zenith causal language model that plugs into Hugging
Face's Transformers library (``AutoConfig`` / ``AutoModelForCausalLM``), so it
can be trained, evaluated and used with ``generate``.

Implemented in this module:
- A hybrid dense + Mixture-of-Experts (MoE) transformer stack.
- Optional EQ ("emotional intelligence") adapters, with an optional recurrent
  EQ state that can bias attention scores and gate the dense feed-forward block.

The configuration also carries flags for ring attention, Tenstorrent (p300)
optimizations, grouped-query attention (``num_key_value_heads``) and MoE
capacity / router learning rate; the modules defined here do not read them.
Curriculum learning, quality filtering and OpenThoughts-1.2M integration are
not implemented in this module either.

NOTE(review): the model has no positional encoding (no learned or rotary
embeddings; ``position_ids`` is accepted but ignored), so position information
comes only from the causal attention mask.
"""

import math
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    GenerationMixin,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast


class ZenithConfig(PretrainedConfig):
    """Configuration class for Zenith models.

    Holds every hyper-parameter read by the modules in this file, plus a few
    flags (ring attention, Tenstorrent options, ``num_key_value_heads``,
    ``hidden_act``, MoE capacity / router LR) that are stored so configs
    round-trip but are not consumed here.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``hidden_size`` (attention splits the hidden size evenly per head).
    """

    model_type = "zenith"
    # Conventional Transformers attribute names, aliased to Zenith's own fields
    # so generic library code (generation, cache helpers) can read them.
    attribute_map = {
        "num_hidden_layers": "num_layers",
        "num_attention_heads": "num_heads",
    }

    def __init__(
        self,
        # Architecture
        hidden_size: int = 4096,
        num_layers: int = 32,
        num_heads: int = 32,
        num_key_value_heads: Optional[int] = None,
        intermediate_size: int = 11008,
        hidden_act: str = "silu",
        max_position_embeddings: int = 8192,
        vocab_size: int = 32000,
        # MoE
        num_experts: int = 0,
        moe_top_k: int = 2,
        moe_capacity_factor: float = 1.0,
        moe_load_balancing_weight: float = 0.01,
        moe_router_learning_rate: float = 1e-3,
        moe_layers: Optional[List[int]] = None,  # Which layers use MoE (0-indexed); empty = all
        # EQ Adapter
        use_eq_adapter: bool = False,
        eq_adapter_hidden_size: int = 64,
        eq_loss_weight: float = 0.1,
        emotion_loss_weight: float = 0.1,
        frustration_loss_weight: float = 0.1,
        # Ring Attention
        use_ring_attention: bool = False,
        ring_attention_chunk_size: int = 8192,
        ring_attention_overlap: int = 2048,
        # p300 optimizations
        use_tenstorrent_optimizations: bool = False,
        tensor_parallel_size: int = 1,
        pipeline_parallel_size: int = 1,
        noc_optimization: bool = False,
        # Training
        gradient_checkpointing: bool = False,
        use_cache: bool = True,
        # EQ integration (read by EQAdapter / ZenithLayer / ZenithModel)
        use_eq_recurrence: bool = False,
        eq_state_dim: int = 64,
        use_eq_attention_bias: bool = False,
        use_eq_gated_ffn: bool = False,
        eq_consistency_weight: float = 0.1,
        # Regularization
        attention_dropout: float = 0.1,
        hidden_dropout: float = 0.1,
        **kwargs,
    ):
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads or num_heads
        # Per-head width used to split Q/K/V in ZenithLayer.
        self.head_dim = hidden_size // num_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.vocab_size = vocab_size

        self.num_experts = num_experts
        self.moe_top_k = moe_top_k
        self.moe_capacity_factor = moe_capacity_factor
        self.moe_load_balancing_weight = moe_load_balancing_weight
        self.moe_router_learning_rate = moe_router_learning_rate
        self.moe_layers = moe_layers or []

        self.use_eq_adapter = use_eq_adapter
        self.eq_adapter_hidden_size = eq_adapter_hidden_size
        self.eq_loss_weight = eq_loss_weight
        self.emotion_loss_weight = emotion_loss_weight
        self.frustration_loss_weight = frustration_loss_weight
        self.use_eq_recurrence = use_eq_recurrence
        self.eq_state_dim = eq_state_dim
        self.use_eq_attention_bias = use_eq_attention_bias
        self.use_eq_gated_ffn = use_eq_gated_ffn
        self.eq_consistency_weight = eq_consistency_weight

        self.use_ring_attention = use_ring_attention
        self.ring_attention_chunk_size = ring_attention_chunk_size
        self.ring_attention_overlap = ring_attention_overlap

        self.use_tenstorrent_optimizations = use_tenstorrent_optimizations
        self.tensor_parallel_size = tensor_parallel_size
        self.pipeline_parallel_size = pipeline_parallel_size
        self.noc_optimization = noc_optimization

        self.gradient_checkpointing = gradient_checkpointing
        self.use_cache = use_cache

        self.attention_dropout = attention_dropout
        self.hidden_dropout = hidden_dropout


class MoELayer(nn.Module):
    """Top-k routed Mixture-of-Experts feed-forward block.

    Each token is sent to its ``moe_top_k`` highest-probability experts; their
    outputs are mixed with the renormalized router probabilities.
    NOTE(review): ``moe_capacity_factor`` is not enforced (no token dropping).
    """

    def __init__(self, config: ZenithConfig):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.top_k = config.moe_top_k

        # Router: one logit per expert for every token.
        self.router = nn.Linear(config.hidden_size, config.num_experts, bias=False)

        # Experts: independent two-layer SiLU MLPs.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size),
                nn.SiLU(),
                nn.Linear(config.intermediate_size, config.hidden_size),
            )
            for _ in range(config.num_experts)
        ])

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route tokens through the experts.

        Args:
            hidden_states: [batch, seq_len, hidden_size]

        Returns:
            expert_outputs: [batch, seq_len, hidden_size]
            load_balance_loss: scalar, ``moe_load_balancing_weight`` times the
                negative entropy of the batch-mean router distribution
                (minimizing it pushes average expert usage toward uniform).
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape
        # reshape (not view) so non-contiguous inputs are accepted.
        flat_hidden = hidden_states.reshape(-1, hidden_dim)

        router_probs = F.softmax(self.router(flat_hidden), dim=-1)

        # Top-k routing; k cannot exceed the number of experts.
        k = min(self.top_k, self.num_experts)
        topk_values, topk_indices = torch.topk(router_probs, k=k, dim=-1)
        topk_values = topk_values / topk_values.sum(dim=-1, keepdim=True)

        expert_outputs = torch.zeros_like(flat_hidden)
        for expert_idx, expert in enumerate(self.experts):
            # topk indices are distinct per token, so each token appears at
            # most once per expert: (token row, top-k slot) pairs.
            token_idx, slot_idx = torch.nonzero(topk_indices == expert_idx, as_tuple=True)
            if token_idx.numel() == 0:
                continue
            weights = topk_values[token_idx, slot_idx].unsqueeze(-1)
            contribution = weights * expert(flat_hidden[token_idx])
            # Cast in case autocast produced a different dtype than the input.
            expert_outputs = expert_outputs.index_add(
                0, token_idx, contribution.to(expert_outputs.dtype)
            )

        expert_outputs = expert_outputs.view(batch_size, seq_len, hidden_dim)

        router_probs_mean = router_probs.mean(dim=0)
        load_balance_loss = self.config.moe_load_balancing_weight * (
            router_probs_mean * torch.log(router_probs_mean + 1e-10)
        ).sum()

        return expert_outputs, load_balance_loss


class EQAdapter(nn.Module):
    """Emotional-intelligence adapter with optional recurrent state.

    From sequence-mean-pooled hidden states it produces a frustration score, 8
    emotion logits, an EQ state vector and (optionally) a per-head attention
    bias and an FFN gate derived from that state.
    """

    def __init__(self, config: ZenithConfig):
        super().__init__()
        self.config = config

        # Frustration detection (regression to [0, 1])
        self.frustration_detector = nn.Sequential(
            nn.Linear(config.hidden_size, config.eq_adapter_hidden_size),
            nn.Tanh(),
            nn.Linear(config.eq_adapter_hidden_size, 1),
            nn.Sigmoid(),
        )

        # Emotion classification (8 classes)
        self.emotion_classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.eq_adapter_hidden_size),
            nn.Tanh(),
            nn.Linear(config.eq_adapter_hidden_size, 8),
        )

        # Reduces pooled features to eq_adapter_hidden_size. It feeds the GRU
        # when recurrence is on, and IS the EQ state when recurrence is off, so
        # the state width always matches the downstream projections below.
        self.gru_input_proj = nn.Linear(config.hidden_size, config.eq_adapter_hidden_size)

        if config.use_eq_recurrence:
            # Recurrent EQ state carried from layer to layer.
            self.eq_gru = nn.GRUCell(
                input_size=config.eq_adapter_hidden_size,
                hidden_size=config.eq_state_dim,
            )
            # Initial state (first layer) from pooled features.
            self.state_projection = nn.Linear(config.hidden_size, config.eq_state_dim)
        else:
            self.eq_gru = None
            self.state_projection = None

        eq_feature_dim = (
            config.eq_state_dim if config.use_eq_recurrence else config.eq_adapter_hidden_size
        )

        # EQ state -> one scalar attention bias per head
        self.attn_bias_proj = (
            nn.Linear(eq_feature_dim, config.num_heads, bias=False)
            if config.use_eq_attention_bias
            else None
        )

        # EQ state -> gate over the FFN intermediate activations
        self.ffn_gate_proj = (
            nn.Linear(eq_feature_dim, config.intermediate_size, bias=False)
            if config.use_eq_gated_ffn
            else None
        )

    def forward(self, hidden_states: torch.Tensor, prev_eq_state: Optional[torch.Tensor] = None):
        """
        Args:
            hidden_states: [batch, seq_len, hidden_size]
            prev_eq_state: [batch, eq_state_dim] state from the previous layer,
                or None (first layer / recurrence disabled).

        Returns:
            frustration: [batch, 1]
            emotion_logits: [batch, 8]
            eq_state: [batch, eq_state_dim] with recurrence, otherwise
                [batch, eq_adapter_hidden_size]
            attn_bias: [batch, num_heads] or None
            ffn_gate: [batch, intermediate_size] in (0, 1), or None
        """
        # Pool over the sequence dimension (padding positions are included).
        pooled = hidden_states.mean(dim=1)

        frustration = self.frustration_detector(pooled)
        emotion_logits = self.emotion_classifier(pooled)

        reduced = torch.tanh(self.gru_input_proj(pooled))
        if self.eq_gru is not None:
            if prev_eq_state is None:
                eq_state = torch.tanh(self.state_projection(pooled))
            else:
                eq_state = self.eq_gru(reduced, prev_eq_state)
        else:
            eq_state = reduced

        attn_bias = self.attn_bias_proj(eq_state) if self.attn_bias_proj is not None else None
        ffn_gate = (
            torch.sigmoid(self.ffn_gate_proj(eq_state)) if self.ffn_gate_proj is not None else None
        )

        return frustration, emotion_logits, eq_state, attn_bias, ffn_gate


def _prepare_attention_mask(
    attention_mask: Optional[torch.Tensor],
    batch_size: int,
    seq_len: int,
    device: torch.device,
) -> torch.Tensor:
    """Build a boolean "may attend" mask of shape [batch or 1, 1, seq_len, seq_len].

    The result is always causal. ``attention_mask`` may be:
      * None                         -> causal mask only;
      * [batch, seq_len] padding mask -> nonzero = real token (HF convention);
      * 4-D bool mask                -> True = may attend;
      * 4-D float additive mask      -> entries above finfo.min / 2 may attend
        (only the masking semantics are honored, not soft biases);
      * 4-D integer mask             -> nonzero = may attend.

    Raises:
        ValueError: for any other rank, or a 2-D mask of the wrong shape.
    """
    allowed = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()[None, None]
    if attention_mask is None:
        return allowed

    attention_mask = attention_mask.to(device)
    if attention_mask.dim() == 2:
        if tuple(attention_mask.shape) != (batch_size, seq_len):
            raise ValueError(
                f"2-D attention_mask must have shape {(batch_size, seq_len)}, "
                f"got {tuple(attention_mask.shape)}"
            )
        return allowed & attention_mask.bool()[:, None, None, :]
    if attention_mask.dim() == 4:
        if attention_mask.dtype == torch.bool:
            keep = attention_mask
        elif attention_mask.is_floating_point():
            keep = attention_mask > torch.finfo(attention_mask.dtype).min / 2
        else:
            keep = attention_mask != 0
        return allowed & keep
    raise ValueError(
        f"attention_mask must be 2-D or 4-D, got {attention_mask.dim()}-D"
    )


class ZenithLayer(nn.Module):
    """Pre-norm transformer layer with causal attention, a dense or MoE
    feed-forward block and an optional EQ adapter."""

    def __init__(self, config: ZenithConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # MoE is used when experts exist and either no layer list is given
        # (all layers) or this layer is listed.
        self.use_moe = (
            config.num_experts > 0
            and (not config.moe_layers or layer_idx in config.moe_layers)
        )

        # Self-attention projections
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)

        self.attn_dropout = nn.Dropout(config.attention_dropout)

        if self.use_moe:
            self.mlp = MoELayer(config)
        elif config.use_eq_gated_ffn:
            # Gated MLP: the EQ gate multiplies the intermediate activations.
            self.mlp = nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size),
                nn.SiLU(),
            )
            # NOTE(review): gate_proj is never used in forward; kept so
            # existing checkpoints keep loading without missing/unexpected keys.
            self.gate_proj = nn.Linear(config.intermediate_size, config.intermediate_size)
            self.out_proj_mlp = nn.Linear(config.intermediate_size, config.hidden_size)
        else:
            self.mlp = nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size),
                nn.SiLU(),
                nn.Linear(config.intermediate_size, config.hidden_size),
            )

        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout)

        self.eq_adapter = EQAdapter(config) if config.use_eq_adapter else None

    def _self_attention(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        attn_bias: Optional[torch.Tensor],
        output_attentions: bool,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Causal multi-head self-attention; returns (output, weights or None)."""
        batch_size, seq_len, _ = hidden_states.shape
        num_heads = self.config.num_heads
        head_dim = self.config.head_dim

        def split_heads(x: torch.Tensor) -> torch.Tensor:
            # [batch, seq, hidden] -> [batch, heads, seq, head_dim]
            return x.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(hidden_states))
        k = split_heads(self.k_proj(hidden_states))
        v = split_heads(self.v_proj(hidden_states))

        allowed = _prepare_attention_mask(attention_mask, batch_size, seq_len, hidden_states.device)

        if attn_bias is not None or output_attentions:
            # Explicit path: needed to add the EQ bias and/or return weights.
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
            if attn_bias is not None:
                # [batch, heads] -> [batch, heads, 1, 1].
                # NOTE(review): a constant per (batch, head) shift is
                # softmax-invariant, so this bias currently has no effect on
                # the output and attn_bias_proj receives no gradient.
                scores = scores + attn_bias[:, :, None, None].to(scores.dtype)
            # Mask after biasing; finfo.min (not -inf) keeps fully-masked rows NaN-free.
            scores = scores.masked_fill(~allowed, torch.finfo(scores.dtype).min)
            attn_weights = self.attn_dropout(F.softmax(scores, dim=-1))
            context = torch.matmul(attn_weights, v)
        else:
            # SDPA needs a bool mask or a float mask in the query dtype; the
            # causal structure is already in `allowed`, so is_causal stays False.
            additive_mask = torch.zeros(
                allowed.shape, dtype=q.dtype, device=q.device
            ).masked_fill(~allowed, torch.finfo(q.dtype).min)
            context = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=additive_mask,
                dropout_p=self.attn_dropout.p if self.training else 0.0,
            )
            attn_weights = None

        context = context.transpose(1, 2).reshape(batch_size, seq_len, self.config.hidden_size)
        return self.out_proj(context), attn_weights

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        prev_eq_state: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Args:
            hidden_states: [batch, seq_len, hidden_size]
            attention_mask: None, a [batch, seq_len] padding mask, or a 4-D
                mask (see ``_prepare_attention_mask``); causality is always applied.
            output_attentions: whether to compute and return attention weights
            prev_eq_state: EQ state from the previous layer, or None

        Returns:
            hidden_states: [batch, seq_len, hidden_size]
            attn_weights: [batch, num_heads, seq_len, seq_len] when computed, else None
            moe_loss: scalar or None
            eq_state: EQ state or None
            consistency_loss: scalar or None
        """
        eq_state = None
        attn_bias = None
        ffn_gate = None
        consistency_loss = None

        if self.eq_adapter is not None:
            _, _, eq_state, attn_bias, ffn_gate = self.eq_adapter(hidden_states, prev_eq_state)
            # Keep the EQ state stable across layers (previous state is a fixed target).
            if self.config.use_eq_recurrence and prev_eq_state is not None:
                consistency_loss = F.mse_loss(eq_state, prev_eq_state.detach())

        # Self-attention with residual
        residual = hidden_states
        attn_output, attn_weights = self._self_attention(
            self.norm1(hidden_states), attention_mask, attn_bias, output_attentions
        )
        hidden_states = residual + self.dropout(attn_output)

        # Feed-forward with residual
        residual = hidden_states
        normed = self.norm2(hidden_states)
        moe_loss = None
        if self.use_moe:
            mlp_output, moe_loss = self.mlp(normed)
        elif self.config.use_eq_gated_ffn:
            intermediate = self.mlp(normed)  # [batch, seq_len, intermediate_size]
            if ffn_gate is not None:
                # [batch, intermediate] broadcast over the sequence dimension.
                intermediate = intermediate * ffn_gate.unsqueeze(1)
            # Without an EQ adapter there is no gate: the block runs ungated.
            mlp_output = self.out_proj_mlp(intermediate)
        else:
            mlp_output = self.mlp(normed)

        hidden_states = residual + self.dropout(mlp_output)

        return hidden_states, attn_weights, moe_loss, eq_state, consistency_loss


class ZenithPreTrainedModel(PreTrainedModel):
    """Base class connecting Zenith to ``PreTrainedModel`` (config class,
    weight init, gradient-checkpointing support)."""

    config_class = ZenithConfig
    base_model_prefix = "zenith"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, 0.02); zero biases and the padding row."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class ZenithModel(ZenithPreTrainedModel, GenerationMixin):
    """Full Zenith causal LM: embeddings, layer stack, final norm and a tied LM head."""

    # lm_head shares its weight with embed_tokens.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: ZenithConfig):
        super().__init__(config)
        self.config = config

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            ZenithLayer(config, i) for i in range(config.num_layers)
        ])
        self.norm = nn.LayerNorm(config.hidden_size)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.lm_head.weight = self.embed_tokens.weight

        # NOTE(review): this model-level adapter is never run in forward (each
        # layer has its own); it only signals that EQ losses are active.
        self.eq_adapter = EQAdapter(config) if config.use_eq_adapter else None

        # Read by HF's gradient_checkpointing_enable(); also honors the config flag.
        self.gradient_checkpointing = bool(getattr(config, "gradient_checkpointing", False))

        self.apply(self._init_weights)

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the model and optionally compute the LM loss.

        ``position_ids``, ``past_key_values`` and ``use_cache`` are accepted for
        API compatibility only: no KV cache or positional encoding is implemented.

        Returns:
            ``CausalLMOutputWithPast`` (or its ``to_tuple()`` when
            ``return_dict`` is False). ``hidden_states`` holds the embedding
            output, each layer's output, and the final normed output. The loss
            is shifted next-token cross-entropy (ignore_index=-100) plus the
            mean MoE load-balancing loss and the weighted EQ consistency loss.

        Raises:
            ValueError: if both or neither of input_ids / inputs_embeds are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds")
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = inputs_embeds.shape[:2]
        hidden_states = inputs_embeds

        # Build the combined causal + padding mask once for all layers.
        layer_mask = _prepare_attention_mask(
            attention_mask, batch_size, seq_length, hidden_states.device
        )

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        moe_losses = []
        consistency_losses = []
        prev_eq_state = None
        use_checkpointing = self.gradient_checkpointing and self.training

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if use_checkpointing:
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    layer, hidden_states, layer_mask, output_attentions, prev_eq_state,
                    use_reentrant=False,
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=layer_mask,
                    output_attentions=output_attentions,
                    prev_eq_state=prev_eq_state,
                )

            hidden_states, attn_weights, moe_loss, eq_state, consistency_loss = layer_outputs

            if eq_state is not None:
                prev_eq_state = eq_state  # carried into the next layer's adapter
            if consistency_loss is not None:
                consistency_losses.append(consistency_loss)
            if output_attentions:
                all_self_attns += (attn_weights,)
            if moe_loss is not None:
                moe_losses.append(moe_loss)

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
            if moe_losses:
                loss = loss + torch.stack(moe_losses).mean()
            if consistency_losses:
                loss = loss + self.config.eq_consistency_weight * torch.stack(consistency_losses).mean()

        outputs = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
        # to_tuple() drops None fields, matching the Transformers tuple convention.
        return outputs if return_dict else outputs.to_tuple()

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, **kwargs
    ):
        """Return model inputs for one generation step.

        No KV cache is implemented (forward ignores ``past_key_values``), so
        the full sequence is fed every step; slicing to the last token would
        discard the context.
        """
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
        }


# Register the model with the Auto* factories.
AutoConfig.register("zenith", ZenithConfig)
AutoModelForCausalLM.register(ZenithConfig, ZenithModel)