| """
|
| Zenith Model for Hugging Face Transformers
|
|
|
| This module provides the Zenith model implementation that is compatible with
|
| Hugging Face's Transformers library, allowing for easy training, inference,
|
| and deployment.
|
|
|
| Zenith features:
|
| - Hybrid Dense+MoE architecture
|
| - Ring attention for long contexts
|
| - EQ adapters for emotional intelligence
|
| - Curriculum learning and quality filtering
|
| - OpenThoughts-1.2M integration
|
| """
|
|
|
import math
from typing import Optional, Tuple, List, Dict, Any

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
|
|
|
|
|
class ZenithConfig(PretrainedConfig):
    """Configuration class for Zenith models.

    Groups the base transformer hyperparameters with the Zenith-specific
    options: Mixture-of-Experts routing, the EQ (emotional intelligence)
    adapter, ring attention, and Tenstorrent deployment knobs.

    Bug fix vs. the previous revision: the EQ-adapter sub-options
    (``use_eq_recurrence``, ``eq_state_dim``, ``use_eq_attention_bias``,
    ``use_eq_gated_ffn``, ``eq_consistency_weight``) and the derived
    ``head_dim`` attribute were read by ``EQAdapter`` / ``ZenithLayer`` /
    ``ZenithModel`` but never defined here, causing ``AttributeError`` at
    model construction or forward time. They are now declared with
    backward-compatible defaults (all features off by default).
    """

    model_type = "zenith"

    def __init__(
        self,
        # --- base transformer ---
        hidden_size: int = 4096,
        num_layers: int = 32,
        num_heads: int = 32,
        num_key_value_heads: Optional[int] = None,
        intermediate_size: int = 11008,
        hidden_act: str = "silu",
        max_position_embeddings: int = 8192,
        vocab_size: int = 32000,
        # --- Mixture of Experts (num_experts == 0 disables MoE) ---
        num_experts: int = 0,
        moe_top_k: int = 2,
        moe_capacity_factor: float = 1.0,
        moe_load_balancing_weight: float = 0.01,
        moe_router_learning_rate: float = 1e-3,
        moe_layers: Optional[List[int]] = None,
        # --- EQ adapter ---
        use_eq_adapter: bool = False,
        eq_adapter_hidden_size: int = 64,
        eq_loss_weight: float = 0.1,
        emotion_loss_weight: float = 0.1,
        frustration_loss_weight: float = 0.1,
        # EQ sub-options consumed by EQAdapter / ZenithLayer / ZenithModel.
        use_eq_recurrence: bool = False,
        eq_state_dim: int = 64,
        use_eq_attention_bias: bool = False,
        use_eq_gated_ffn: bool = False,
        eq_consistency_weight: float = 0.1,
        # --- ring attention (long context) ---
        use_ring_attention: bool = False,
        ring_attention_chunk_size: int = 8192,
        ring_attention_overlap: int = 2048,
        # --- Tenstorrent / parallelism ---
        use_tenstorrent_optimizations: bool = False,
        tensor_parallel_size: int = 1,
        pipeline_parallel_size: int = 1,
        noc_optimization: bool = False,
        # --- training/runtime ---
        gradient_checkpointing: bool = False,
        use_cache: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads or num_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.vocab_size = vocab_size
        # Derived per-head dimension; ZenithLayer reads config.head_dim.
        self.head_dim = hidden_size // num_heads

        self.num_experts = num_experts
        self.moe_top_k = moe_top_k
        self.moe_capacity_factor = moe_capacity_factor
        self.moe_load_balancing_weight = moe_load_balancing_weight
        self.moe_router_learning_rate = moe_router_learning_rate
        # Empty list means "every layer is an MoE layer" (see ZenithLayer).
        self.moe_layers = moe_layers or []

        self.use_eq_adapter = use_eq_adapter
        self.eq_adapter_hidden_size = eq_adapter_hidden_size
        self.eq_loss_weight = eq_loss_weight
        self.emotion_loss_weight = emotion_loss_weight
        self.frustration_loss_weight = frustration_loss_weight
        self.use_eq_recurrence = use_eq_recurrence
        self.eq_state_dim = eq_state_dim
        self.use_eq_attention_bias = use_eq_attention_bias
        self.use_eq_gated_ffn = use_eq_gated_ffn
        self.eq_consistency_weight = eq_consistency_weight

        self.use_ring_attention = use_ring_attention
        self.ring_attention_chunk_size = ring_attention_chunk_size
        self.ring_attention_overlap = ring_attention_overlap

        self.use_tenstorrent_optimizations = use_tenstorrent_optimizations
        self.tensor_parallel_size = tensor_parallel_size
        self.pipeline_parallel_size = pipeline_parallel_size
        self.noc_optimization = noc_optimization

        self.gradient_checkpointing = gradient_checkpointing
        self.use_cache = use_cache
|
|
|
|
|
class MoELayer(nn.Module):
    """Mixture of Experts layer.

    Each token is routed to the ``moe_top_k`` experts with the highest
    softmax router probability; expert outputs are combined with the
    (renormalized) routing weights. Returns an auxiliary load-balancing
    loss alongside the layer output.

    NOTE(review): this is dense dispatch — every expert is applied to the
    tokens that selected it via boolean masking; ``moe_capacity_factor``
    from the config is not enforced here.
    """

    def __init__(self, config: ZenithConfig):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.top_k = config.moe_top_k

        # Router: one logit per expert per token (no bias).
        self.router = nn.Linear(config.hidden_size, config.num_experts, bias=False)

        # Each expert is an independent two-layer SiLU MLP with the same
        # shape as the dense FFN it replaces.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size),
                nn.SiLU(),
                nn.Linear(config.intermediate_size, config.hidden_size)
            )
            for _ in range(config.num_experts)
        ])

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route tokens to experts and combine their outputs.

        Args:
            hidden_states: [batch, seq_len, hidden_size]

        Returns:
            expert_outputs: [batch, seq_len, hidden_size] combined expert output
            load_balance_loss: scalar auxiliary routing loss
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape

        # Flatten to [batch * seq_len, hidden] so routing is per-token.
        flat_hidden = hidden_states.view(-1, hidden_dim)

        # Routing distribution over experts for every token.
        router_logits = self.router(flat_hidden)
        router_probs = F.softmax(router_logits, dim=-1)

        # Top-k expert indices and their probabilities per token.
        topk_values, topk_indices = torch.topk(
            router_probs,
            k=self.top_k,
            dim=-1
        )

        # Renormalize the selected probabilities so they sum to 1 per token.
        topk_values = topk_values / topk_values.sum(dim=-1, keepdim=True)

        expert_outputs = torch.zeros_like(flat_hidden)

        for expert_idx in range(self.num_experts):
            # Tokens that selected this expert in any of their top-k slots.
            expert_mask = (topk_indices == expert_idx).any(dim=-1)
            if expert_mask.any():
                expert_input = flat_hidden[expert_mask]
                expert_output = self.experts[expert_idx](expert_input)
                # For each selected token, look up the weight in the slot that
                # matched this expert (top-k indices are unique per token, so
                # exactly one slot matches) and accumulate the weighted output.
                expert_outputs[expert_mask] += topk_values[expert_mask,
                    (topk_indices[expert_mask] == expert_idx).nonzero(as_tuple=True)[1]
                ].unsqueeze(-1) * expert_output

        expert_outputs = expert_outputs.view(batch_size, seq_len, hidden_dim)

        # Load-balancing loss: weighted negative entropy of the mean routing
        # distribution — minimizing it pushes average routing toward uniform.
        router_probs_mean = router_probs.mean(dim=0)
        load_balance_loss = self.config.moe_load_balancing_weight * (
            router_probs_mean * torch.log(router_probs_mean + 1e-10)
        ).sum()

        return expert_outputs, load_balance_loss
|
|
|
|
|
class EQAdapter(nn.Module):
    """Emotional Intelligence adapter with optional recurrent EQ state.

    From the mean-pooled hidden states it produces:
      * a scalar frustration score in [0, 1],
      * 8-way emotion logits,
      * an "EQ state" vector, optionally carried layer-to-layer through a
        GRU cell, that can bias attention scores and gate the FFN.

    Bug fix vs. the previous revision: in the non-recurrent path the EQ
    state is ``tanh(pooled)`` with ``hidden_size`` dimensions, but the
    attention-bias / FFN-gate projections were sized for
    ``eq_adapter_hidden_size`` input — a guaranteed shape mismatch at
    runtime. The projections are now sized to match the state that is
    actually produced.
    """

    def __init__(self, config: ZenithConfig):
        super().__init__()
        self.config = config

        # getattr guards keep the adapter constructible with configs that
        # predate the recurrence / bias / gating options.
        self.use_recurrence = getattr(config, "use_eq_recurrence", False)
        use_attn_bias = getattr(config, "use_eq_attention_bias", False)
        use_gated_ffn = getattr(config, "use_eq_gated_ffn", False)
        eq_state_dim = getattr(config, "eq_state_dim", config.eq_adapter_hidden_size)

        # Scalar frustration score, squashed to [0, 1].
        self.frustration_detector = nn.Sequential(
            nn.Linear(config.hidden_size, config.eq_adapter_hidden_size),
            nn.Tanh(),
            nn.Linear(config.eq_adapter_hidden_size, 1),
            nn.Sigmoid()
        )

        # 8-way emotion classifier (raw logits; loss applied by the caller).
        self.emotion_classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.eq_adapter_hidden_size),
            nn.Tanh(),
            nn.Linear(config.eq_adapter_hidden_size, 8)
        )

        if self.use_recurrence:
            # GRU carries the EQ state across layers/steps.
            self.eq_gru = nn.GRUCell(
                input_size=config.eq_adapter_hidden_size,
                hidden_size=eq_state_dim
            )
            # Bootstraps the state on the first step (no previous state).
            self.state_projection = nn.Linear(config.hidden_size, eq_state_dim)
            # Compresses pooled hidden states into the GRU input.
            self.gru_input_proj = nn.Linear(config.hidden_size, config.eq_adapter_hidden_size)
        else:
            self.eq_gru = None
            self.state_projection = None
            self.gru_input_proj = None

        # Dimensionality of the EQ state forward() actually emits:
        # eq_state_dim when recurrent, hidden_size otherwise (tanh(pooled)).
        state_dim = eq_state_dim if self.use_recurrence else config.hidden_size

        # Per-head additive attention bias derived from the EQ state.
        if use_attn_bias:
            self.attn_bias_proj = nn.Linear(state_dim, config.num_heads, bias=False)
        else:
            self.attn_bias_proj = None

        # Sigmoid gate over the FFN intermediate activation.
        if use_gated_ffn:
            self.ffn_gate_proj = nn.Linear(state_dim, config.intermediate_size, bias=False)
        else:
            self.ffn_gate_proj = None

    def forward(self, hidden_states: torch.Tensor, prev_eq_state: Optional[torch.Tensor] = None):
        """Compute EQ signals from the current hidden states.

        Args:
            hidden_states: [batch, seq_len, hidden_size]
            prev_eq_state: [batch, eq_state_dim] previous EQ state, or None
                on the first step / when recurrence is disabled.

        Returns:
            frustration: [batch, 1] score in [0, 1]
            emotion_logits: [batch, 8]
            eq_state: [batch, eq_state_dim] (recurrent) or
                [batch, hidden_size] (non-recurrent fallback)
            attn_bias: [batch, num_heads] or None
            ffn_gate: [batch, intermediate_size] in (0, 1) or None
        """
        # Mean-pool over the sequence: [batch, seq, hidden] -> [batch, hidden].
        pooled = hidden_states.mean(dim=1)

        frustration = self.frustration_detector(pooled)
        emotion_logits = self.emotion_classifier(pooled)

        if self.eq_gru is not None:
            if prev_eq_state is None:
                # First step: bootstrap the state from the pooled representation.
                eq_state = torch.tanh(self.state_projection(pooled))
            else:
                gru_input = torch.tanh(self.gru_input_proj(pooled))
                eq_state = self.eq_gru(gru_input, prev_eq_state)
        else:
            # Non-recurrent fallback: squashed pooled representation
            # (note: dimension is hidden_size, not eq_state_dim).
            eq_state = torch.tanh(pooled)

        attn_bias = self.attn_bias_proj(eq_state) if self.attn_bias_proj is not None else None
        ffn_gate = (
            torch.sigmoid(self.ffn_gate_proj(eq_state))
            if self.ffn_gate_proj is not None else None
        )

        return frustration, emotion_logits, eq_state, attn_bias, ffn_gate
|
|
|
|
|
class ZenithLayer(nn.Module):
    """Single pre-norm transformer layer with optional MoE FFN and EQ adapter.

    Bug fixes vs. the previous revision:
      * ``F.scaled_dot_product_attention`` returns a single tensor, but the
        code unpacked it into ``(attn_output, attn_weights)`` — a runtime
        failure on every non-biased forward pass.
      * SDPA forbids passing ``attn_mask`` together with ``is_causal=True``;
        the causal mask is now merged into the explicit mask instead.
      * The manual (EQ-bias) attention path applied no causal mask while
        the fused path was causal; both paths are now causal.
      * ``moe_loss`` could be referenced before assignment in the gated-FFN
        branch, and ``ffn_gate`` could be ``None`` when gating was enabled
        without an EQ adapter; both are now guarded.
    """

    def __init__(self, config: ZenithConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # Derived locally so the layer works even if the config does not
        # expose a precomputed head_dim attribute.
        self.head_dim = config.hidden_size // config.num_heads

        # MoE applies when experts exist and this layer is listed in
        # moe_layers (an empty list means every layer uses MoE).
        self.use_moe = (
            config.num_experts > 0 and
            (not config.moe_layers or layer_idx in config.moe_layers)
        )

        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)

        self.attn_dropout = nn.Dropout(0.1)

        self.use_eq_gated_ffn = getattr(config, "use_eq_gated_ffn", False)

        if self.use_moe:
            self.mlp = MoELayer(config)
        elif self.use_eq_gated_ffn:
            # FFN split in two so the EQ gate can modulate the intermediate
            # activation before the down-projection.
            self.mlp = nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size),
                nn.SiLU(),
            )
            # NOTE(review): gate_proj is defined but never used in forward
            # (the gate comes from the EQ adapter); kept for checkpoint
            # compatibility — confirm whether it can be removed.
            self.gate_proj = nn.Linear(config.intermediate_size, config.intermediate_size)
            self.out_proj_mlp = nn.Linear(config.intermediate_size, config.hidden_size)
        else:
            self.mlp = nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size),
                nn.SiLU(),
                nn.Linear(config.intermediate_size, config.hidden_size)
            )

        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(0.1)

        self.eq_adapter = EQAdapter(config) if config.use_eq_adapter else None

    def _causal_mask(self, seq_len: int, device, dtype) -> torch.Tensor:
        """Additive [seq, seq] causal mask: 0 on/below diagonal, -inf above."""
        mask = torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=dtype)
        return torch.triu(mask, diagonal=1)

    def _shape_qkv(self, hidden_states: torch.Tensor, batch_size: int, seq_len: int):
        """Project hidden states and reshape to [batch, heads, seq, head_dim]."""
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        shape = (batch_size, seq_len, self.config.num_heads, self.head_dim)
        return (
            q.view(*shape).transpose(1, 2),
            k.view(*shape).transpose(1, 2),
            v.view(*shape).transpose(1, 2),
        )

    def _attention(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        attn_bias: Optional[torch.Tensor],
        output_attentions: bool,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Causal self-attention; returns (output, weights-or-None)."""
        batch_size, seq_len, _ = hidden_states.shape
        q, k, v = self._shape_qkv(hidden_states, batch_size, seq_len)
        dropout_p = 0.1 if self.training else 0.0

        if attn_bias is not None or output_attentions:
            # Explicit path: needed when a per-head EQ bias must be added to
            # the scores or when attention weights must be returned.
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            if attn_bias is not None:
                # attn_bias: [batch, num_heads] -> broadcast over query/key dims.
                attn_scores = attn_scores + attn_bias.unsqueeze(-1).unsqueeze(-1)
            attn_scores = attn_scores + self._causal_mask(
                seq_len, attn_scores.device, attn_scores.dtype
            )
            if attention_mask is not None:
                # assumes an additive mask (0 keep, -inf drop) — TODO confirm
                # the caller's mask convention.
                attn_scores = attn_scores + attention_mask
            attn_weights = F.softmax(attn_scores, dim=-1)
            attn_weights = self.attn_dropout(attn_weights)
            attn_output = torch.matmul(attn_weights, v)
        else:
            # Fused path. SDPA returns a single tensor and disallows
            # attn_mask together with is_causal=True, so the causal mask is
            # merged into the explicit mask when one is supplied.
            attn_weights = None
            if attention_mask is None:
                attn_output = F.scaled_dot_product_attention(
                    q, k, v, dropout_p=dropout_p, is_causal=True
                )
            else:
                merged = attention_mask + self._causal_mask(
                    seq_len, q.device, q.dtype
                )
                attn_output = F.scaled_dot_product_attention(
                    q, k, v, attn_mask=merged, dropout_p=dropout_p
                )

        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.config.hidden_size
        )
        return self.out_proj(attn_output), attn_weights

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        prev_eq_state: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Args:
            hidden_states: [batch, seq_len, hidden_size]
            attention_mask: additive attention mask or None
            output_attentions: whether to return attention weights
            prev_eq_state: [batch, eq_state_dim] EQ state from the previous layer

        Returns:
            hidden_states: [batch, seq_len, hidden_size]
            attn_weights: [batch, num_heads, seq_len, seq_len] or None
            moe_loss: scalar or None
            eq_state: [batch, eq_state_dim] or None
            consistency_loss: scalar or None
        """
        eq_state = None
        attn_bias = None
        ffn_gate = None
        consistency_loss = None
        moe_loss = None

        if self.eq_adapter is not None:
            # frustration / emotion_logits are computed by the adapter but
            # consumed by the training loop, not by this layer.
            frustration, emotion_logits, eq_state, attn_bias, ffn_gate = self.eq_adapter(
                hidden_states, prev_eq_state
            )
            # Penalize abrupt EQ-state changes between consecutive layers.
            if getattr(self.config, "use_eq_recurrence", False) and prev_eq_state is not None:
                consistency_loss = F.mse_loss(eq_state, prev_eq_state.detach())

        # Self-attention sub-block (pre-norm residual).
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        attn_output, attn_weights = self._attention(
            hidden_states, attention_mask, attn_bias, output_attentions
        )
        hidden_states = residual + self.dropout(attn_output)

        # FFN / MoE sub-block (pre-norm residual).
        residual = hidden_states
        hidden_states = self.norm2(hidden_states)

        if self.use_moe:
            mlp_output, moe_loss = self.mlp(hidden_states)
        elif self.use_eq_gated_ffn:
            intermediate = self.mlp(hidden_states)
            if ffn_gate is not None:
                # Gate [batch, d_ff] broadcasts across the sequence dimension.
                intermediate = intermediate * ffn_gate.unsqueeze(1)
            mlp_output = self.out_proj_mlp(intermediate)
        else:
            mlp_output = self.mlp(hidden_states)

        hidden_states = residual + self.dropout(mlp_output)

        return hidden_states, attn_weights, moe_loss, eq_state, consistency_loss
|
|
|
|
|
class ZenithPreTrainedModel(PreTrainedModel):
    """Base class for Zenith models.

    Wires the Zenith configuration into the Hugging Face PreTrainedModel
    machinery and provides the shared weight-initialization scheme.
    """

    config_class = ZenithConfig
    base_model_prefix = "zenith"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize weights: N(0, 0.02) for Linear/Embedding, zero biases,
        and a zeroed padding row for embeddings."""
        std = 0.02
        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
|
|
|
|
|
class ZenithModel(ZenithPreTrainedModel):
    """Full Zenith causal language model (embedding, layer stack, tied LM head).

    Bug fixes vs. the previous revision:
      * The ``return_dict=False`` path concatenated ``all_hidden_states`` and
        ``all_self_attns`` into the output tuple even when they were ``None``
        (TypeError whenever those outputs were not requested).
      * Auxiliary MoE / EQ-consistency losses could be added to ``loss`` while
        it was still ``None`` (no labels); they now only contribute when a
        language-modeling loss exists.
      * The final post-norm hidden state is appended to ``all_hidden_states``,
        matching the Transformers convention.
    """

    def __init__(self, config: ZenithConfig):
        super().__init__(config)
        self.config = config

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        self.layers = nn.ModuleList([
            ZenithLayer(config, i) for i in range(config.num_layers)
        ])

        self.norm = nn.LayerNorm(config.hidden_size)

        # LM head tied to the input embedding matrix.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.lm_head.weight = self.embed_tokens.weight

        # NOTE(review): this model-level adapter is not used in forward()
        # (each layer owns its own EQAdapter); kept for checkpoint
        # compatibility — confirm whether it can be removed.
        self.eq_adapter = EQAdapter(config) if config.use_eq_adapter else None

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the model; with ``labels`` also computes the shifted LM loss
        plus auxiliary MoE load-balancing and EQ-consistency losses.

        Note: ``position_ids``, ``past_key_values`` and ``use_cache`` are
        accepted for interface compatibility but no KV cache is implemented
        yet (``past_key_values`` is always returned as ``None``).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        moe_losses: List[torch.Tensor] = []
        consistency_losses: List[torch.Tensor] = []

        # EQ state threaded from each layer into the next.
        prev_eq_state = None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            hidden_states, attn_weights, moe_loss, eq_state, consistency_loss = layer(
                hidden_states,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                prev_eq_state=prev_eq_state,
            )

            if eq_state is not None:
                prev_eq_state = eq_state
            if consistency_loss is not None:
                consistency_losses.append(consistency_loss)
            if output_attentions:
                all_self_attns = all_self_attns + (attn_weights,)
            if moe_loss is not None:
                moe_losses.append(moe_loss)

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            # Include the final post-norm state (Transformers convention).
            all_hidden_states = all_hidden_states + (hidden_states,)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that position i predicts token i+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

            # Auxiliary losses only contribute when a LM loss exists; this
            # also avoids `None + tensor` when labels are absent.
            if moe_losses:
                loss = loss + torch.stack(moe_losses).mean()
            if consistency_losses:
                weight = getattr(self.config, "eq_consistency_weight", 0.0)
                loss = loss + weight * torch.stack(consistency_losses).mean()

        if not return_dict:
            output = (logits,)
            if all_hidden_states is not None:
                output = output + (all_hidden_states,)
            if all_self_attns is not None:
                output = output + (all_self_attns,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            past_key_values=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs
    ):
        """Trim input_ids to the last token when a cache is supplied.

        NOTE(review): forward() never produces past_key_values, so the trim
        branch is currently dead — confirm once KV caching is implemented.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
        }
|
|
|
|
|
|
|
# Register Zenith with the Auto* factories so checkpoints saved with
# model_type="zenith" resolve via AutoConfig.from_pretrained and
# AutoModelForCausalLM.from_pretrained.
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("zenith", ZenithConfig)
AutoModelForCausalLM.register(ZenithConfig, ZenithModel)