# Zenith-7b-V1 / modeling_zenith.py
# Uploaded by Zandy-Wandy — commit 1ea8a03 ("Upload Zenith-7B model", verified)
"""
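Zenith Model for Hugging Face Transformers

This module provides the Zenith model implementation, compatible with
Hugging Face's Transformers library for training, inference, and
deployment.

Zenith features:
- Hybrid Dense+MoE architecture
- Ring attention for long contexts
- EQ adapters for emotional intelligence
- Curriculum learning and quality filtering
- OpenThoughts-1.2M integration

Note: in this module, ring attention and the Tenstorrent/parallelism options
exist only as configuration fields, and curriculum learning, quality
filtering and OpenThoughts-1.2M integration are not referenced by the code;
those features, if present, live elsewhere in the project.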
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, List, Dict, Any
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput
class ZenithConfig(PretrainedConfig):
    """Configuration class for Zenith models.

    Holds the hyper-parameters read by ``ZenithModel``, ``ZenithLayer``,
    ``MoELayer`` and ``EQAdapter``. Unknown keyword arguments are forwarded to
    ``PretrainedConfig``.

    The options ``use_eq_recurrence``, ``eq_state_dim``,
    ``use_eq_attention_bias``, ``use_eq_gated_ffn``, ``eq_consistency_weight``
    and ``head_dim`` are read by the model code. They are declared here,
    appended after the original parameters so positional callers are
    unaffected, so that a default-constructed config defines them.

    Several fields (``num_key_value_heads``, ``hidden_act``,
    ``max_position_embeddings``, ``moe_capacity_factor``,
    ``moe_router_learning_rate``, ``emotion_loss_weight``,
    ``frustration_loss_weight``, the ring-attention and Tenstorrent options)
    are stored but not read by the model classes in this module.

    Raises:
        ValueError: if ``hidden_size`` cannot be split into ``num_heads``
            heads of ``head_dim`` channels each.
    """

    model_type = "zenith"

    def __init__(
        self,
        # Architecture
        hidden_size: int = 4096,
        num_layers: int = 32,
        num_heads: int = 32,
        num_key_value_heads: Optional[int] = None,
        intermediate_size: int = 11008,
        hidden_act: str = "silu",
        max_position_embeddings: int = 8192,
        vocab_size: int = 32000,
        # MoE
        num_experts: int = 0,
        moe_top_k: int = 2,
        moe_capacity_factor: float = 1.0,
        moe_load_balancing_weight: float = 0.01,
        moe_router_learning_rate: float = 1e-3,
        moe_layers: Optional[List[int]] = None,  # Which layers use MoE (0-indexed); empty = all
        # EQ Adapter
        use_eq_adapter: bool = False,
        eq_adapter_hidden_size: int = 64,
        eq_loss_weight: float = 0.1,
        emotion_loss_weight: float = 0.1,
        frustration_loss_weight: float = 0.1,
        # Ring Attention
        use_ring_attention: bool = False,
        ring_attention_chunk_size: int = 8192,
        ring_attention_overlap: int = 2048,
        # p300 optimizations
        use_tenstorrent_optimizations: bool = False,
        tensor_parallel_size: int = 1,
        pipeline_parallel_size: int = 1,
        noc_optimization: bool = False,
        # Training
        gradient_checkpointing: bool = False,
        use_cache: bool = True,
        # EQ adapter extensions read by EQAdapter / ZenithLayer / ZenithModel
        use_eq_recurrence: bool = False,
        eq_state_dim: Optional[int] = None,  # None -> eq_adapter_hidden_size
        use_eq_attention_bias: bool = False,
        use_eq_gated_ffn: bool = False,
        # NOTE(review): default chosen to match the other EQ loss weights — confirm.
        eq_consistency_weight: float = 0.1,
        # Per-head width; None -> hidden_size // num_heads
        head_dim: Optional[int] = None,
        **kwargs
    ):
        super().__init__(**kwargs)

        # The attention code splits the hidden_size-wide Q/K/V projections into
        # num_heads * head_dim, so the three values must agree.
        if head_dim is None:
            if hidden_size % num_heads != 0:
                raise ValueError(
                    f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
                )
            head_dim = hidden_size // num_heads
        elif head_dim * num_heads != hidden_size:
            raise ValueError(
                f"head_dim ({head_dim}) * num_heads ({num_heads}) must equal "
                f"hidden_size ({hidden_size})"
            )

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads or num_heads
        self.head_dim = head_dim
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.vocab_size = vocab_size
        self.num_experts = num_experts
        self.moe_top_k = moe_top_k
        self.moe_capacity_factor = moe_capacity_factor
        self.moe_load_balancing_weight = moe_load_balancing_weight
        self.moe_router_learning_rate = moe_router_learning_rate
        self.moe_layers = moe_layers or []
        self.use_eq_adapter = use_eq_adapter
        self.eq_adapter_hidden_size = eq_adapter_hidden_size
        self.eq_loss_weight = eq_loss_weight
        self.emotion_loss_weight = emotion_loss_weight
        self.frustration_loss_weight = frustration_loss_weight
        self.use_eq_recurrence = use_eq_recurrence
        self.eq_state_dim = eq_state_dim if eq_state_dim is not None else eq_adapter_hidden_size
        self.use_eq_attention_bias = use_eq_attention_bias
        self.use_eq_gated_ffn = use_eq_gated_ffn
        self.eq_consistency_weight = eq_consistency_weight
        self.use_ring_attention = use_ring_attention
        self.ring_attention_chunk_size = ring_attention_chunk_size
        self.ring_attention_overlap = ring_attention_overlap
        self.use_tenstorrent_optimizations = use_tenstorrent_optimizations
        self.tensor_parallel_size = tensor_parallel_size
        self.pipeline_parallel_size = pipeline_parallel_size
        self.noc_optimization = noc_optimization
        self.gradient_checkpointing = gradient_checkpointing
        self.use_cache = use_cache
class MoELayer(nn.Module):
    """Token-level top-k mixture-of-experts feed-forward block.

    A bias-free linear router scores each token against ``num_experts``
    SiLU MLP experts. Each token is processed by its ``moe_top_k``
    highest-probability experts, and their outputs are summed with the top-k
    probabilities renormalised to 1.
    """

    def __init__(self, config: "ZenithConfig"):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.top_k = config.moe_top_k
        # Router: one logit per expert for every token
        self.router = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        # Experts: identical hidden -> intermediate -> hidden MLPs
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size),
                nn.SiLU(),
                nn.Linear(config.intermediate_size, config.hidden_size),
            )
            for _ in range(config.num_experts)
        )

    def forward(self, hidden_states: torch.Tensor):
        """Route tokens through their top-k experts.

        Args:
            hidden_states: [batch, seq_len, hidden_size]

        Returns:
            (output, aux_loss): output has the input's shape. aux_loss is
            ``moe_load_balancing_weight * sum(p * log(p + 1e-10))`` over the
            batch-mean router probabilities p, i.e. a scaled negative entropy
            that is smallest when routing is uniform.
        """
        bsz, seq, width = hidden_states.shape
        tokens = hidden_states.view(-1, width)

        probs = F.softmax(self.router(tokens), dim=-1)
        top_p, top_e = torch.topk(probs, k=self.top_k, dim=-1)
        top_p = top_p / top_p.sum(dim=-1, keepdim=True)

        mixed = torch.zeros_like(tokens)
        for e, expert in enumerate(self.experts):
            hits = top_e == e  # [tokens, top_k]; at most one True per row
            routed = hits.any(dim=-1)
            if not routed.any():
                continue
            # Pick out each routed token's weight for expert e.
            gate = (top_p * hits).sum(dim=-1)[routed]
            mixed[routed] += gate.unsqueeze(-1) * expert(tokens[routed])

        mean_prob = probs.mean(dim=0)
        neg_entropy = (mean_prob * torch.log(mean_prob + 1e-10)).sum()
        return mixed.view(bsz, seq, width), self.config.moe_load_balancing_weight * neg_entropy
class EQAdapter(nn.Module):
    """Emotional-intelligence ("EQ") adapter.

    Mean-pools hidden states over the sequence. From the pooled vector it
    derives a frustration score, 8-way emotion logits and an "EQ state". The
    EQ state can optionally be carried from layer to layer through a GRU
    cell, and projected to a per-head attention bias and/or a per-channel
    FFN gate.
    """

    def __init__(self, config: "ZenithConfig"):
        super().__init__()
        self.config = config
        # Optional features are read with getattr so configs saved without
        # these keys still construct (they default to "off").
        use_recurrence = bool(getattr(config, "use_eq_recurrence", False))
        use_attn_bias = bool(getattr(config, "use_eq_attention_bias", False))
        use_gated_ffn = bool(getattr(config, "use_eq_gated_ffn", False))
        eq_state_dim = getattr(config, "eq_state_dim", None)
        if eq_state_dim is None:
            eq_state_dim = config.eq_adapter_hidden_size
        # Width of the EQ state fed to the bias / gate projections below.
        state_size = eq_state_dim if use_recurrence else config.eq_adapter_hidden_size

        # Frustration detection (regression, squashed to (0, 1) by Sigmoid)
        self.frustration_detector = nn.Sequential(
            nn.Linear(config.hidden_size, config.eq_adapter_hidden_size),
            nn.Tanh(),
            nn.Linear(config.eq_adapter_hidden_size, 1),
            nn.Sigmoid()
        )
        # Emotion classification (8 classes, raw logits)
        self.emotion_classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.eq_adapter_hidden_size),
            nn.Tanh(),
            nn.Linear(config.eq_adapter_hidden_size, 8)
        )
        # Recurrent EQ state (GRU) for layer-to-layer consistency
        if use_recurrence:
            self.eq_gru = nn.GRUCell(
                input_size=config.eq_adapter_hidden_size,
                hidden_size=eq_state_dim
            )
            # Builds the initial state from pooled features (first EQ layer)
            self.state_projection = nn.Linear(config.hidden_size, eq_state_dim)
            # Reduces pooled features to the GRU input size
            self.gru_input_proj = nn.Linear(config.hidden_size, config.eq_adapter_hidden_size)
        else:
            self.eq_gru = None
            self.state_projection = None
            self.gru_input_proj = None
        # EQ state -> one scalar attention bias per head
        if use_attn_bias:
            self.attn_bias_proj = nn.Linear(state_size, config.num_heads, bias=False)
        else:
            self.attn_bias_proj = None
        # EQ state -> per-channel gate on the FFN intermediate activations
        if use_gated_ffn:
            self.ffn_gate_proj = nn.Linear(state_size, config.intermediate_size, bias=False)
        else:
            self.ffn_gate_proj = None
        # Without recurrence the pooled features are hidden_size wide, but the
        # projections above take state_size inputs. This layer reduces them.
        # It is only created when needed, so existing checkpoints keep their keys.
        if not use_recurrence and (use_attn_bias or use_gated_ffn):
            self.eq_state_proj = nn.Linear(config.hidden_size, state_size)
        else:
            self.eq_state_proj = None

    def forward(self, hidden_states: torch.Tensor, prev_eq_state: Optional[torch.Tensor] = None):
        """
        Args:
            hidden_states: [batch, seq_len, hidden_size]
            prev_eq_state: [batch, eq_state_dim] EQ state from the previous
                layer (only used when recurrence is enabled)

        Returns:
            frustration: [batch, 1] in (0, 1)
            emotion_logits: [batch, 8]
            eq_state: [batch, eq_state_dim] with recurrence;
                [batch, eq_adapter_hidden_size] without recurrence when bias/gate
                is enabled; otherwise tanh of the pooled [batch, hidden_size] features
            attn_bias: [batch, num_heads] or None
            ffn_gate: [batch, intermediate_size] in (0, 1), or None
        """
        # Pool over sequence dimension (padding positions are included)
        pooled = hidden_states.mean(dim=1)
        frustration = self.frustration_detector(pooled)
        emotion_logits = self.emotion_classifier(pooled)

        if self.eq_gru is not None:
            if prev_eq_state is None:
                # First EQ layer: initialise the state from the pooled features
                eq_state = torch.tanh(self.state_projection(pooled))
            else:
                gru_input = torch.tanh(self.gru_input_proj(pooled))
                eq_state = self.eq_gru(gru_input, prev_eq_state)
        elif self.eq_state_proj is not None:
            eq_state = torch.tanh(self.eq_state_proj(pooled))
        else:
            # No recurrence and no consumer of the state: use pooled features directly
            eq_state = torch.tanh(pooled)

        attn_bias = None
        if self.attn_bias_proj is not None:
            attn_bias = self.attn_bias_proj(eq_state)  # [batch, num_heads]

        ffn_gate = None
        if self.ffn_gate_proj is not None:
            ffn_gate = torch.sigmoid(self.ffn_gate_proj(eq_state))

        return frustration, emotion_logits, eq_state, attn_bias, ffn_gate
class ZenithLayer(nn.Module):
    """Single pre-norm transformer layer with optional MoE and EQ adapter.

    The layer applies causal multi-head self-attention, then a feed-forward
    network (dense, EQ-gated dense, or mixture-of-experts). Each sub-block
    applies LayerNorm to its input and adds its dropout-regularised output
    back through a residual connection.
    """

    def __init__(self, config: "ZenithConfig", layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Optional EQ flags are read with getattr so configs saved without
        # these keys still construct (they default to "off").
        self.use_eq_gated_ffn = bool(getattr(config, "use_eq_gated_ffn", False))
        self.use_eq_recurrence = bool(getattr(config, "use_eq_recurrence", False))

        # Q/K/V are hidden_size-wide projections split evenly across heads,
        # so the head width is derived here rather than read from the config.
        if config.hidden_size % config.num_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_heads ({config.num_heads})"
            )
        self.num_heads = config.num_heads
        self.head_dim = config.hidden_size // config.num_heads

        # MoE is used when experts exist and either no layer list was given
        # (meaning every layer) or this layer is listed.
        self.use_moe = (
            config.num_experts > 0 and
            (not config.moe_layers or layer_idx in config.moe_layers)
        )

        # Self attention projections
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
        # Attention dropout
        self.attn_dropout = nn.Dropout(0.1)

        # MoE or dense feed-forward
        if self.use_moe:
            self.mlp = MoELayer(config)
        elif self.use_eq_gated_ffn:
            # Gated MLP: the EQ gate multiplies the intermediate activations
            self.mlp = nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size),
                nn.SiLU(),
            )
            # NOTE(review): gate_proj is never used in forward (the gate comes
            # from the EQ adapter); it is kept so existing checkpoints still load.
            self.gate_proj = nn.Linear(config.intermediate_size, config.intermediate_size)
            self.out_proj_mlp = nn.Linear(config.intermediate_size, config.hidden_size)
        else:
            self.mlp = nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size),
                nn.SiLU(),
                nn.Linear(config.intermediate_size, config.hidden_size)
            )

        # Layer norm
        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size)
        # Residual dropout
        self.dropout = nn.Dropout(0.1)
        # EQ adapter (if enabled)
        if config.use_eq_adapter:
            self.eq_adapter = EQAdapter(config)
        else:
            self.eq_adapter = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        prev_eq_state: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Args:
            hidden_states: [batch, seq_len, hidden_size]
            attention_mask: optional. A 2-D [batch, seq_len] keep-mask
                (1/True = attend, 0/False = padding), a non-float keep-mask
                broadcastable to [batch, num_heads, seq_len, seq_len], or a
                float additive mask of rank >= 3 with that broadcast shape.
                Causal masking is always applied in addition.
            output_attentions: whether to return the attention probabilities
            prev_eq_state: [batch, eq_state_dim] EQ state from the previous layer

        Returns:
            hidden_states: [batch, seq_len, hidden_size]
            attn_weights: [batch, num_heads, seq_len, seq_len] (pre-dropout)
                when output_attentions, else None
            moe_loss: scalar or None
            eq_state: EQ state tensor or None
            consistency_loss: scalar or None
        """
        eq_state = None
        attn_bias = None
        ffn_gate = None
        consistency_loss = None
        if self.eq_adapter is not None:
            _frustration, _emotion_logits, eq_state, attn_bias, ffn_gate = self.eq_adapter(
                hidden_states, prev_eq_state
            )
            # Penalise drift from the previous layer's (detached) EQ state
            if self.use_eq_recurrence and prev_eq_state is not None:
                consistency_loss = F.mse_loss(eq_state, prev_eq_state.detach())

        # Self attention with residual
        residual = hidden_states
        attn_output, attn_weights = self._self_attention(
            self.norm1(hidden_states), attention_mask, attn_bias, output_attentions
        )
        hidden_states = residual + self.dropout(attn_output)

        # Feed-forward with residual
        residual = hidden_states
        mlp_output, moe_loss = self._feed_forward(self.norm2(hidden_states), ffn_gate)
        hidden_states = residual + self.dropout(mlp_output)

        return hidden_states, attn_weights, moe_loss, eq_state, consistency_loss

    def _self_attention(self, hidden_states, attention_mask, attn_bias, output_attentions):
        """Causal multi-head attention. Returns (projected output, probabilities or None)."""
        bsz, seq_len, _ = hidden_states.shape

        def split_heads(x):
            # [batch, seq, hidden] -> [batch, heads, seq, head_dim]
            return x.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(hidden_states))
        k = split_heads(self.k_proj(hidden_states))
        v = split_heads(self.v_proj(hidden_states))
        dropout_p = self.attn_dropout.p if self.training else 0.0

        attn_weights = None
        if output_attentions:
            # Manual path: SDPA cannot return the probabilities.
            mask = self._additive_mask(attention_mask, attn_bias, seq_len, q.dtype, q.device)
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) + mask
            attn_weights = F.softmax(scores, dim=-1)
            context = torch.matmul(self.attn_dropout(attn_weights), v)
        elif attention_mask is None and attn_bias is None:
            context = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=True)
        else:
            # SDPA rejects attn_mask together with is_causal, so the causal
            # part is folded into the additive mask instead.
            mask = self._additive_mask(attention_mask, attn_bias, seq_len, q.dtype, q.device)
            context = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_p)

        context = context.transpose(1, 2).contiguous().view(bsz, seq_len, self.config.hidden_size)
        return self.out_proj(context), attn_weights

    def _additive_mask(self, attention_mask, attn_bias, seq_len, dtype, device):
        """Combine causal, padding and EQ-bias terms into one additive float mask.

        The result broadcasts to [batch, num_heads, seq_len, seq_len]: 0 for
        visible keys, ``finfo(dtype).min`` for hidden keys, plus the EQ bias.
        """
        min_value = torch.finfo(dtype).min
        future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).triu(1)
        mask = torch.zeros(seq_len, seq_len, dtype=dtype, device=device).masked_fill(future, min_value)
        mask = mask[None, None]  # [1, 1, seq, seq]
        if attention_mask is not None:
            if attention_mask.dim() == 2 or not attention_mask.is_floating_point():
                keep = attention_mask.to(device=device).bool()
                if keep.dim() == 2:
                    keep = keep[:, None, None, :]  # [batch, 1, 1, seq]
                padding = torch.zeros(keep.shape, dtype=dtype, device=device).masked_fill(~keep, min_value)
                mask = mask + padding
            else:
                mask = mask + attention_mask.to(dtype=dtype, device=device)
        if attn_bias is not None:
            # NOTE(review): a per-(batch, head) constant added to every score in
            # a row is cancelled by softmax, so this bias cannot change the
            # attention output — the design likely needs revisiting.
            mask = mask + attn_bias[:, :, None, None].to(dtype)
        # Summing two min_value terms overflows to -inf. Clamping keeps fully
        # masked rows finite, so softmax gives a uniform row instead of NaN.
        return mask.clamp(min=min_value)

    def _feed_forward(self, hidden_states, ffn_gate):
        """Apply the MoE / gated / dense FFN. Returns (output, moe_loss or None)."""
        if self.use_moe:
            return self.mlp(hidden_states)
        if self.use_eq_gated_ffn:
            intermediate = self.mlp(hidden_states)  # [batch, seq, intermediate_size]
            # The gate only exists when an EQ adapter produced one
            if ffn_gate is not None:
                intermediate = intermediate * ffn_gate.unsqueeze(1)
            return self.out_proj_mlp(intermediate), None
        return self.mlp(hidden_states), None
class ZenithPreTrainedModel(PreTrainedModel):
    """Common base for Zenith models.

    Binds ``ZenithConfig`` to the Transformers loading machinery and defines
    the weight initialisation applied by ``ZenithModel``.
    """

    config_class = ZenithConfig
    base_model_prefix = "zenith"
    # NOTE(review): no module in this file defines a `gradient_checkpointing`
    # attribute or checkpoints its layers, so enabling checkpointing through
    # the Transformers API may raise or have no effect — confirm.
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02); zero Linear biases and
        the Embedding padding row. Other module types are left untouched."""
        std = 0.02
        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return
        if not isinstance(module, nn.Linear):
            return
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
class ZenithModel(ZenithPreTrainedModel):
    """Full Zenith decoder-only language model.

    Token embedding -> stack of ``ZenithLayer`` -> final LayerNorm -> LM head
    whose weight is tied to the embedding matrix. This module has no
    positional-encoding component and no KV cache.
    """

    def __init__(self, config: ZenithConfig):
        super().__init__(config)
        self.config = config
        # Token embedding
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        # Transformer layers
        self.layers = nn.ModuleList([
            ZenithLayer(config, i)
            for i in range(config.num_layers)
        ])
        # Final layer norm
        self.norm = nn.LayerNorm(config.hidden_size)
        # LM head, sharing its weight with the token embedding
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.lm_head.weight = self.embed_tokens.weight
        # NOTE(review): this model-level adapter is never called in forward
        # (each layer has its own); it only gates the consistency loss below.
        # It is kept so existing checkpoints still load.
        if config.use_eq_adapter:
            self.eq_adapter = EQAdapter(config)
        else:
            self.eq_adapter = None
        # Initialize weights
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the decoder stack and, when ``labels`` is given, the LM loss.

        Args:
            input_ids: token ids [batch, seq]; mutually exclusive with inputs_embeds.
            attention_mask: passed unchanged to every ZenithLayer.
            position_ids, past_key_values, use_cache: accepted for Transformers
                API compatibility but not read here (no positional encoding
                module and no KV cache exist in this model).
            inputs_embeds: pre-computed embeddings [batch, seq, hidden_size].
            labels: target ids, shifted by one position inside; -100 is ignored.
                MoE and EQ-consistency auxiliary losses are added only when
                labels are given.
            output_attentions / output_hidden_states / return_dict: default to
                the config values.

        Returns:
            ``CausalLMOutputWithPast`` (``past_key_values`` is always None), or,
            when return_dict is False, the tuple
            ``(loss?, logits, hidden_states?, attentions?)`` with None entries
            omitted. ``hidden_states`` holds the input to every layer plus the
            final normalised output.

        Raises:
            ValueError: if both or neither of input_ids / inputs_embeds are given.
        """
        config = self.config
        output_attentions = output_attentions if output_attentions is not None else config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        hidden_states = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

        track_consistency = bool(config.use_eq_adapter and getattr(config, "use_eq_recurrence", False))
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        moe_losses = []
        consistency_losses = []
        # Recurrent EQ state threaded from layer to layer
        prev_eq_state = None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            hidden_states, attn_weights, moe_loss, eq_state, consistency_loss = layer(
                hidden_states,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                prev_eq_state=prev_eq_state,
            )
            if config.use_eq_adapter and eq_state is not None:
                prev_eq_state = eq_state
            if track_consistency and consistency_loss is not None:
                consistency_losses.append(consistency_loss)
            if output_attentions:
                all_self_attns += (attn_weights,)
            if moe_loss is not None:
                moe_losses.append(moe_loss)

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            # Transformers convention: the last entry is the final output
            all_hidden_states += (hidden_states,)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )
            # Auxiliary losses (out-of-place adds keep the autograd graph simple)
            if moe_losses:
                loss = loss + torch.stack(moe_losses).mean()
            if self.eq_adapter is not None and consistency_losses:
                # NOTE(review): falls back to eq_loss_weight for configs that
                # predate eq_consistency_weight — confirm intended default.
                consistency_weight = getattr(config, "eq_consistency_weight", config.eq_loss_weight)
                loss = loss + consistency_weight * torch.stack(consistency_losses).mean()

        if not return_dict:
            output = tuple(v for v in (logits, all_hidden_states, all_self_attns) if v is not None)
            return ((loss,) + output) if loss is not None else output

        from transformers.modeling_outputs import CausalLMOutputWithPast
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            past_key_values=None
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs
    ):
        """Build forward() kwargs for one generation step.

        forward() does not read past_key_values (there is no KV cache), so the
        full prefix is passed every step. Trimming to the last token would
        silently drop all context.
        """
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
        }
# Register the model with the Transformers Auto* factories so that
# AutoConfig resolves model_type "zenith" to ZenithConfig, and
# AutoModelForCausalLM maps ZenithConfig to ZenithModel.
# These calls run as a side effect of importing this module.
from transformers import AutoConfig, AutoModelForCausalLM
AutoConfig.register("zenith", ZenithConfig)
AutoModelForCausalLM.register(ZenithConfig, ZenithModel)