# Zenith-70b-p300-V1 / modeling_zenith.py
# Zandy-Wandy's picture
# Upload Zenith-70b-V1-Tenstorrent-Blackhole-p300 model
# 02a1aee verified
"""
Zenith Model for Hugging Face Transformers - 70B-p300 Variant
This module provides the Zenith model implementation that is compatible with
Hugging Face's Transformers library, allowing for easy training, inference,
and deployment.
Zenith features:
- Hybrid Dense+MoE architecture
- Ring attention for long contexts (32K)
- EQ adapters for emotional intelligence
- Curriculum learning and quality filtering
- OpenThoughts-1.2M integration
- Tenstorrent p300 optimizations (TP=8, PP=4, NoC)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, List, Dict, Any
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput
class ZenithConfig(PretrainedConfig):
    """Configuration class for Zenith models.

    In addition to the original fields, this defines the EQ-state control
    flags (``use_eq_recurrence``, ``eq_state_dim``, ``use_eq_attention_bias``,
    ``use_eq_gated_ffn``, ``eq_consistency_weight``) and ``head_dim``.  These
    are read by ``EQAdapter`` / ``ZenithLayer`` / ``ZenithModel`` but were
    previously never set here, so constructing a model raised AttributeError.
    The new parameters default to "disabled"/neutral values, keeping older
    call sites backward compatible.
    """

    model_type = "zenith"

    def __init__(
        self,
        # Architecture
        hidden_size: int = 8192,
        num_layers: int = 80,
        num_heads: int = 64,
        num_key_value_heads: Optional[int] = None,
        intermediate_size: int = 22016,
        hidden_act: str = "silu",
        max_position_embeddings: int = 32768,
        vocab_size: int = 32000,
        # MoE
        num_experts: int = 12,
        moe_top_k: int = 2,
        moe_capacity_factor: float = 1.0,
        moe_load_balancing_weight: float = 0.01,
        moe_router_learning_rate: float = 1e-3,
        moe_layers: Optional[List[int]] = None,  # Which layers use MoE (0-indexed)
        # EQ Adapter
        use_eq_adapter: bool = True,
        eq_adapter_hidden_size: int = 64,
        eq_loss_weight: float = 0.1,
        emotion_loss_weight: float = 0.1,
        frustration_loss_weight: float = 0.1,
        use_eq_recurrence: bool = False,        # recurrent GRU EQ state across layers
        eq_state_dim: int = 64,                 # dimension of the recurrent EQ state
        use_eq_attention_bias: bool = False,    # per-head scalar attention bias from EQ state
        use_eq_gated_ffn: bool = False,         # gate the FFN intermediate with the EQ state
        eq_consistency_weight: float = 0.01,    # weight of the layer-to-layer EQ consistency loss
        # Ring Attention
        use_ring_attention: bool = True,
        ring_attention_chunk_size: int = 8192,
        ring_attention_overlap: int = 2048,
        # p300 optimizations
        use_tenstorrent_optimizations: bool = True,
        tensor_parallel_size: int = 8,
        pipeline_parallel_size: int = 4,
        noc_optimization: bool = True,
        # Training
        gradient_checkpointing: bool = False,
        use_cache: bool = True,
        **kwargs
    ):
        super().__init__(**kwargs)
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads or num_heads
        # Per-head dimension; ZenithLayer's attention reshapes rely on this.
        self.head_dim = hidden_size // num_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.vocab_size = vocab_size
        self.num_experts = num_experts
        self.moe_top_k = moe_top_k
        self.moe_capacity_factor = moe_capacity_factor
        self.moe_load_balancing_weight = moe_load_balancing_weight
        self.moe_router_learning_rate = moe_router_learning_rate
        self.moe_layers = moe_layers or list(range(32, 56))  # Middle 24 layers for 70B
        self.use_eq_adapter = use_eq_adapter
        self.eq_adapter_hidden_size = eq_adapter_hidden_size
        self.eq_loss_weight = eq_loss_weight
        self.emotion_loss_weight = emotion_loss_weight
        self.frustration_loss_weight = frustration_loss_weight
        self.use_eq_recurrence = use_eq_recurrence
        self.eq_state_dim = eq_state_dim
        self.use_eq_attention_bias = use_eq_attention_bias
        self.use_eq_gated_ffn = use_eq_gated_ffn
        self.eq_consistency_weight = eq_consistency_weight
        self.use_ring_attention = use_ring_attention
        self.ring_attention_chunk_size = ring_attention_chunk_size
        self.ring_attention_overlap = ring_attention_overlap
        self.use_tenstorrent_optimizations = use_tenstorrent_optimizations
        self.tensor_parallel_size = tensor_parallel_size
        self.pipeline_parallel_size = pipeline_parallel_size
        self.noc_optimization = noc_optimization
        self.gradient_checkpointing = gradient_checkpointing
        self.use_cache = use_cache
class MoELayer(nn.Module):
    """Mixture-of-Experts feed-forward layer with top-k routing.

    Fixes over the original:
    - ``forward`` now returns ``(output, aux_loss)`` — ``ZenithLayer`` unpacks
      two values from this call, so the original single-tensor return raised a
      ValueError;
    - the Switch-Transformer-style load-balancing loss is actually computed
      (previously counts were tracked but no loss was ever produced);
    - ``expert_counts`` is updated with one ``torch.bincount`` call instead of
      a Python loop over every routed token.
    """

    def __init__(self, config: "ZenithConfig"):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.top_k = config.moe_top_k
        # Expert FFNs (dense two-layer MLPs with SiLU).
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size),
                nn.SiLU(),
                nn.Linear(config.intermediate_size, config.hidden_size)
            )
            for _ in range(self.num_experts)
        ])
        # Router producing one logit per expert.
        self.router = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        # Running count of tokens routed to each expert (diagnostics only).
        self.register_buffer("expert_counts", torch.zeros(config.num_experts))

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route each token to its top-k experts.

        Args:
            hidden_states: [batch, seq_len, hidden_size]
        Returns:
            output: [batch, seq_len, hidden_size] weighted expert mixture
            aux_loss: scalar load-balancing loss (already scaled by
                ``moe_load_balancing_weight``)
        """
        routing_logits = self.router(hidden_states)           # [batch, seq, num_experts]
        routing_weights = F.softmax(routing_logits, dim=-1)
        top_k_weights, top_k_indices = torch.topk(
            routing_weights, self.top_k, dim=-1
        )                                                     # [batch, seq, top_k]
        # Renormalize so the selected experts' weights sum to 1 per token.
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

        output = torch.zeros_like(hidden_states)
        for i in range(self.top_k):
            expert_indices = top_k_indices[:, :, i]                    # [batch, seq]
            expert_weights = top_k_weights[:, :, i].unsqueeze(-1)      # [batch, seq, 1]
            for expert_idx in range(self.num_experts):
                mask = (expert_indices == expert_idx)
                if mask.any():
                    expert_output = self.experts[expert_idx](hidden_states[mask])
                    output[mask] += expert_weights[mask] * expert_output

        # Switch-style load balancing: num_experts * sum(f_i * P_i) where
        # f_i is the fraction of assignments to expert i (hard, non-diff)
        # and P_i is the mean router probability (differentiable).
        flat_assignments = top_k_indices.reshape(-1)
        assignment_counts = torch.bincount(
            flat_assignments, minlength=self.num_experts
        ).to(routing_weights.dtype)
        fraction_per_expert = assignment_counts / assignment_counts.sum().clamp(min=1)
        prob_per_expert = routing_weights.mean(dim=(0, 1))
        aux_loss = (
            self.num_experts
            * torch.sum(fraction_per_expert * prob_per_expert)
            * self.config.moe_load_balancing_weight
        )

        if self.training:
            with torch.no_grad():
                self.expert_counts += assignment_counts.to(self.expert_counts.dtype)

        return output, aux_loss
class EQAdapter(nn.Module):
    """Emotional-intelligence adapter: frustration/emotion heads plus an
    optional recurrent EQ state that can bias attention and gate the FFN.

    Fixes over the original:
    - the docstring shapes now match the code (``attn_bias`` is
      ``[batch, num_heads]``, not ``[batch, num_heads, head_dim]``);
    - the bias/gate projections in the non-recurrent path are sized for the
      actual state (``hidden_size``-dim pooled features) instead of
      ``eq_adapter_hidden_size``, which caused a runtime shape mismatch;
    - EQ flags absent from older configs default to disabled via ``getattr``
      instead of raising AttributeError;
    - the GRU input projection is no longer computed and discarded on the
      first step.
    """

    def __init__(self, config: "ZenithConfig"):
        super().__init__()
        self.config = config
        use_recurrence = getattr(config, "use_eq_recurrence", False)
        # Frustration detection head (regression to [0, 1]).
        self.frustration_detector = nn.Sequential(
            nn.Linear(config.hidden_size, config.eq_adapter_hidden_size),
            nn.Tanh(),
            nn.Linear(config.eq_adapter_hidden_size, 1),
            nn.Sigmoid()
        )
        # Emotion classification head (8 classes).
        self.emotion_classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.eq_adapter_hidden_size),
            nn.Tanh(),
            nn.Linear(config.eq_adapter_hidden_size, 8)
        )
        # Recurrent EQ state (GRU) for layer-to-layer consistency.
        if use_recurrence:
            self.eq_gru = nn.GRUCell(
                input_size=config.eq_adapter_hidden_size,
                hidden_size=config.eq_state_dim
            )
            self.state_projection = nn.Linear(config.hidden_size, config.eq_state_dim)
            self.gru_input_proj = nn.Linear(config.hidden_size, config.eq_adapter_hidden_size)
        else:
            self.eq_gru = None
            self.state_projection = None
            self.gru_input_proj = None
        # Without recurrence the EQ state is the pooled hidden vector, so the
        # downstream projections must consume hidden_size, not the adapter dim.
        state_dim = config.eq_state_dim if use_recurrence else config.hidden_size
        # EQ state -> per-head scalar attention bias.
        if getattr(config, "use_eq_attention_bias", False):
            self.attn_bias_proj = nn.Linear(state_dim, config.num_heads, bias=False)
        else:
            self.attn_bias_proj = None
        # EQ state -> FFN intermediate gate.
        if getattr(config, "use_eq_gated_ffn", False):
            self.ffn_gate_proj = nn.Linear(state_dim, config.intermediate_size, bias=False)
        else:
            self.ffn_gate_proj = None

    def forward(self, hidden_states: torch.Tensor, prev_eq_state: Optional[torch.Tensor] = None):
        """
        Args:
            hidden_states: [batch, seq_len, hidden_size]
            prev_eq_state: [batch, eq_state_dim] EQ state from the previous layer.
        Returns:
            frustration: [batch, 1] score in [0, 1]
            emotion_logits: [batch, 8]
            eq_state: [batch, eq_state_dim] (recurrent) or [batch, hidden_size]
            attn_bias: [batch, num_heads] per-head scalar bias, or None
            ffn_gate: [batch, intermediate_size] sigmoid gate, or None
        """
        # Mean-pool over the sequence dimension.
        pooled = hidden_states.mean(dim=1)
        frustration = self.frustration_detector(pooled)
        emotion_logits = self.emotion_classifier(pooled)
        # Update the recurrent EQ state (or fall back to the pooled features).
        if self.eq_gru is not None:
            if prev_eq_state is None:
                # First layer: initialize the state from the pooled features.
                eq_state = torch.tanh(self.state_projection(pooled))
            else:
                gru_input = torch.tanh(self.gru_input_proj(pooled))
                eq_state = self.eq_gru(gru_input, prev_eq_state)
        else:
            eq_state = torch.tanh(pooled)
        attn_bias = self.attn_bias_proj(eq_state) if self.attn_bias_proj is not None else None
        ffn_gate = (
            torch.sigmoid(self.ffn_gate_proj(eq_state))
            if self.ffn_gate_proj is not None
            else None
        )
        return frustration, emotion_logits, eq_state, attn_bias, ffn_gate
class ZenithLayer(nn.Module):
    """Single pre-norm transformer layer with optional MoE FFN and EQ adapter.

    Fixes over the original:
    - the EQ-bias attention path referenced ``self.self_attn.dropout`` /
      ``self.self_attn.out_proj``, attributes that never existed
      (AttributeError); it now uses ``self.attn_dropout`` / ``self.out_proj``;
    - ``F.scaled_dot_product_attention`` returns a single tensor, not an
      ``(output, weights)`` tuple;
    - ``is_causal=True`` is no longer combined with an explicit ``attn_mask``
      (unsupported), and the manual attention path now applies the same
      causal mask the SDPA path enforces;
    - the per-head dim is derived locally instead of reading a
      ``config.head_dim`` field that may not exist;
    - the FFN gate was applied a second time after the gated-MLP branch with
      mismatched shapes ([batch, intermediate_size] vs hidden_size output);
      the gate is now applied exactly once, inside the gated branch;
    - the MoE feed-forward's return is unpacked tolerantly (tensor or
      ``(tensor, aux_loss)`` tuple).
    """

    def __init__(self, config: "ZenithConfig", layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Derived locally so the layer works even when the config object does
        # not define head_dim.
        self.head_dim = config.hidden_size // config.num_heads
        # This layer uses MoE when experts exist and the index is listed
        # (or no explicit layer list is given).
        self.use_moe = (
            config.num_experts > 0
            and (config.moe_layers is None or layer_idx in config.moe_layers)
        )
        # Self-attention projections.
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = nn.Dropout(0.1)
        # MoE or dense feed-forward.
        if self.use_moe:
            self.mlp = MoELayer(config)
        elif getattr(config, "use_eq_gated_ffn", False):
            # Gated MLP: the EQ gate multiplies the intermediate activation,
            # then out_proj_mlp maps back to hidden_size.
            self.mlp = nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size),
                nn.SiLU(),
            )
            self.gate_proj = nn.Linear(config.intermediate_size, config.intermediate_size)
            self.out_proj_mlp = nn.Linear(config.intermediate_size, config.hidden_size)
        else:
            self.mlp = nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size),
                nn.SiLU(),
                nn.Linear(config.intermediate_size, config.hidden_size)
            )
        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(0.1)
        # EQ adapter (if enabled).
        self.eq_adapter = EQAdapter(config) if config.use_eq_adapter else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        prev_eq_state: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Args:
            hidden_states: [batch, seq_len, hidden_size]
            attention_mask: additive attention mask (assumed to already encode
                causality when provided — TODO confirm against callers)
            output_attentions: whether to return attention weights (forces the
                manual attention path; SDPA cannot return weights)
            prev_eq_state: [batch, eq_state_dim] EQ state from the previous layer
        Returns:
            hidden_states: [batch, seq_len, hidden_size]
            attn_weights: [batch, num_heads, seq_len, seq_len] or None
            moe_loss: scalar or None
            eq_state: updated EQ state or None
            consistency_loss: scalar or None
        """
        eq_state = None
        attn_bias = None
        ffn_gate = None
        consistency_loss = None
        if self.eq_adapter is not None:
            _frustration, _emotion_logits, eq_state, attn_bias, ffn_gate = self.eq_adapter(
                hidden_states, prev_eq_state
            )
            # Penalize layer-to-layer EQ state drift when recurrence is on.
            if getattr(self.config, "use_eq_recurrence", False) and prev_eq_state is not None:
                consistency_loss = F.mse_loss(eq_state, prev_eq_state.detach())

        # --- Self-attention with residual (pre-norm) ---
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        batch_size, seq_len, _ = hidden_states.shape
        num_heads = self.config.num_heads
        q = self.q_proj(hidden_states).view(batch_size, seq_len, num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(batch_size, seq_len, num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(batch_size, seq_len, num_heads, self.head_dim).transpose(1, 2)

        if attn_bias is not None or output_attentions:
            # Manual path: needed to inject the EQ bias and/or return weights.
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            if attn_bias is not None:
                # [batch, num_heads] -> [batch, num_heads, 1, 1], broadcast to all positions.
                attn_scores = attn_scores + attn_bias.unsqueeze(-1).unsqueeze(-1)
            # Causal mask, matching the is_causal semantics of the SDPA path.
            causal_mask = torch.full(
                (seq_len, seq_len), float("-inf"),
                device=hidden_states.device, dtype=attn_scores.dtype,
            ).triu(1)
            attn_scores = attn_scores + causal_mask
            if attention_mask is not None:
                attn_scores = attn_scores + attention_mask
            attn_weights = F.softmax(attn_scores, dim=-1)
            attn_weights = self.attn_dropout(attn_weights)
            attn_output = torch.matmul(attn_weights, v)
        else:
            # SDPA path: is_causal only when no explicit mask is supplied
            # (combining both is unsupported).
            attn_output = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attention_mask,
                dropout_p=0.1 if self.training else 0.0,
                is_causal=attention_mask is None,
            )
            attn_weights = None

        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.config.hidden_size
        )
        attn_output = self.out_proj(attn_output)
        hidden_states = residual + self.dropout(attn_output)

        # --- Feed-forward with residual (pre-norm) ---
        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        moe_loss = None
        if self.use_moe:
            mlp_result = self.mlp(hidden_states)
            # MoELayer may return (output, aux_loss); tolerate either form.
            if isinstance(mlp_result, tuple):
                mlp_output, moe_loss = mlp_result
            else:
                mlp_output = mlp_result
        elif getattr(self.config, "use_eq_gated_ffn", False):
            intermediate = self.mlp(hidden_states)  # [batch, seq, intermediate_size]
            if ffn_gate is not None:
                # Broadcast the per-example gate over the sequence dimension.
                intermediate = intermediate * ffn_gate.unsqueeze(1)
            mlp_output = self.out_proj_mlp(intermediate)
        else:
            mlp_output = self.mlp(hidden_states)
        hidden_states = residual + self.dropout(mlp_output)
        return hidden_states, attn_weights, moe_loss, eq_state, consistency_loss
class ZenithModel(PreTrainedModel):
    """Zenith decoder-only model with hybrid MoE layers and EQ adapters.

    Fixes over the original:
    - auxiliary (MoE / EQ-consistency) losses are only added when a primary
      LM loss exists (previously ``None += tensor`` raised TypeError when
      ``labels`` was not given);
    - ``use_eq_recurrence`` / ``eq_consistency_weight`` are read with
      ``getattr`` defaults so configs lacking those fields still work;
    - the consistency-loss list is always a list (appending to ``None`` could
      previously crash);
    - the tuple (``return_dict=False``) output path no longer concatenates
      ``None`` values.
    """

    config_class = ZenithConfig

    def __init__(self, config: ZenithConfig):
        super().__init__(config)
        self.config = config
        # Token embedding.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        # Transformer stack.
        self.layers = nn.ModuleList(
            [ZenithLayer(config, i) for i in range(config.num_layers)]
        )
        # Final layer norm.
        self.norm = nn.LayerNorm(config.hidden_size)
        # LM head, weight-tied to the embedding (see tie_weights).
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.tie_weights()

    def tie_weights(self):
        """Tie the LM head weight to the token embedding matrix."""
        self.lm_head.weight = self.embed_tokens.weight

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Run the model; returns ``CausalLMOutput`` (or a tuple).

        NOTE(review): ``position_ids`` and ``past_key_values`` are accepted
        for interface compatibility but are not used by the layers — confirm
        whether KV caching is expected before relying on generation speed.
        """
        # Embeddings: exactly one of input_ids / inputs_embeds must be given.
        if input_ids is not None:
            hidden_states = self.embed_tokens(input_ids)
        elif inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            raise ValueError("Must provide input_ids or inputs_embeds")

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_moe_losses = []
        all_consistency_losses = []
        prev_eq_state = None  # recurrent EQ state threaded through the layers

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                prev_eq_state=prev_eq_state,
            )
            hidden_states = layer_outputs[0]
            # Thread the EQ state to the next layer and collect its losses.
            if self.config.use_eq_adapter and len(layer_outputs) > 3:
                eq_state = layer_outputs[3]
                if eq_state is not None:
                    prev_eq_state = eq_state
                if len(layer_outputs) > 4 and layer_outputs[4] is not None:
                    all_consistency_losses.append(layer_outputs[4])
            if output_attentions:
                all_attentions += (layer_outputs[1],)
            if layer_outputs[2] is not None:  # MoE load-balancing loss
                all_moe_losses.append(layer_outputs[2])

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift logits/labels for next-token prediction.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
            # Auxiliary losses are only added on top of an existing LM loss.
            if all_moe_losses:
                loss = loss + torch.stack(all_moe_losses).mean()
            if all_consistency_losses:
                weight = getattr(self.config, "eq_consistency_weight", 0.0)
                loss = loss + weight * torch.stack(all_consistency_losses).mean()

        if not return_dict:
            output = (logits,)
            if all_hidden_states is not None:
                output += (all_hidden_states,)
            if all_attentions is not None:
                output += (all_attentions,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, **kwargs
    ):
        # Only feed the newest token when a cache is supplied (the cache is
        # currently unused by the layers — see forward's NOTE).
        if past_key_values:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
        }
class ZenithForCausalLM(PreTrainedModel):
    """Thin causal-LM compatibility wrapper around ``ZenithModel``.

    Fix: ``config_class`` was missing, which breaks ``from_pretrained`` /
    Auto-class resolution for this wrapper.  Generation-input preparation is
    also delegated so ``generate`` sees consistent inputs.
    """

    config_class = ZenithConfig

    def __init__(self, config: ZenithConfig):
        super().__init__(config)
        self.model = ZenithModel(config)
        self.config = config

    def forward(self, **kwargs):
        # Delegate directly; ZenithModel already produces logits and loss.
        return self.model(**kwargs)

    def generate(self, **kwargs):
        return self.model.generate(**kwargs)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        return self.model.prepare_inputs_for_generation(*args, **kwargs)
# Register the model
# Makes "zenith" resolvable through the Auto* factories (e.g.
# AutoModelForCausalLM.from_pretrained with this config type).
from transformers import AutoConfig, AutoModelForCausalLM
AutoConfig.register("zenith", ZenithConfig)
# NOTE(review): ZenithModel (which carries the LM head and loss) is registered
# here rather than the ZenithForCausalLM wrapper — confirm this is intentional.
AutoModelForCausalLM.register(ZenithConfig, ZenithModel)
if __name__ == "__main__":
    # Smoke test: build a model from the default config and report its size.
    cfg = ZenithConfig()
    net = ZenithForCausalLM(cfg)
    total_params = sum(p.numel() for p in net.parameters())
    trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f"Model parameters: {total_params:,}")
    print(f"Active parameters: {trainable_params:,}")