# Source: Zenith-7b-V1 / models/zenith_model.py
# Uploaded by Zandy-Wandy in commit 8d18b7c ("Upload Zenith-7B model"), verified.
"""Zenith Model - Wrapper for DeepSeek Base Models with MoE and EQ"""
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from ..configs import ZenithConfig
from .moe_wrapper import MoELayer
from .eq_adapter_wrapper import EQAdapterWrapper
logger = logging.getLogger(__name__)
@dataclass
class ZenithModelOutput(CausalLMOutput):
    """Output for Zenith model with multi-task heads.

    Extends ``CausalLMOutput`` with slots for the auxiliary heads and the
    auxiliary loss terms. Every field defaults to ``None`` so callers can
    populate only what a given forward pass actually produced.
    """

    # Combined loss as assembled by ``ZenithModel.forward``.
    loss: Optional[torch.FloatTensor] = None
    # Language-modeling logits. Annotated Optional because the default is None.
    logits: Optional[torch.FloatTensor] = None
    # Logits of the optional multi-task heads; ``None`` when a head is unused.
    thoughts_logits: Optional[torch.FloatTensor] = None
    emotion_logits: Optional[torch.FloatTensor] = None
    frustration_logits: Optional[torch.FloatTensor] = None
    # Per-layer hidden states / attention maps forwarded from the base model.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Raw (unweighted) auxiliary terms, exposed separately from ``loss``.
    moe_aux_loss: Optional[torch.FloatTensor] = None
    eq_loss: Optional[torch.FloatTensor] = None
class ZenithModel(PreTrainedModel):
    """Zenith model with hybrid MoE and EQ adapters built on DeepSeek base.

    The class wraps an already-constructed causal-LM ``base_model`` (stored as
    ``self.transformer``), optionally post-processes its last hidden state
    with an ``EQAdapterWrapper``, and re-projects the result through the base
    model's ``lm_head``.
    """

    config_class = ZenithConfig
    base_model_prefix = "zenith"

    # File name used by ``save_pretrained`` / ``from_pretrained`` for the
    # EQ adapter weights.
    _EQ_ADAPTER_FILE = "eq_adapter.pt"
    # Fallback weight of the EQ loss when the config does not define
    # ``eq_loss_weight``.
    _DEFAULT_EQ_LOSS_WEIGHT = 0.1

    def __init__(
        self,
        config: ZenithConfig,
        base_model: Optional[PreTrainedModel] = None,
    ):
        """Build the wrapper around ``base_model``.

        Args:
            config: Zenith configuration. Reads ``num_experts``,
                ``use_eq_adapter``, the ``eq_*`` adapter settings,
                ``d_model``, ``model_type`` and ``total_params``.
            base_model: Causal-LM model to wrap. When ``None`` the model would
                be built from scratch, which is not implemented yet.

        Raises:
            NotImplementedError: If ``base_model`` is ``None``.
        """
        super().__init__(config)
        self.config = config

        # Load or initialize base model
        if base_model is not None:
            logger.info(f"Using provided base model: {base_model.__class__.__name__}")
            self.transformer = base_model
        else:
            # Initialize from scratch (for training from scratch)
            logger.info("Initializing new model from scratch")
            self._init_transformer()

        # Apply MoE modifications if configured
        if config.num_experts > 1:
            self._apply_moe_conversion()

        # Apply EQ adapter wrapper if configured
        if config.use_eq_adapter:
            self.eq_wrapper = EQAdapterWrapper(
                config.d_model,
                config.eq_adapter_hidden_dim,
                config.eq_num_emotions,
                config.eq_frustration_dim,
                config.eq_dropout,
            )
        else:
            self.eq_wrapper = None

        # Multi-task heads (optional); none are created here.
        self.thoughts_head = None
        self.emotion_head = None
        self.frustration_head = None

        logger.info(f"ZenithModel initialized: {config.model_type}, "
                    f"params={config.total_params / 1e9:.1f}B")

    def _init_transformer(self):
        """Initialize transformer from config.

        Raises:
            NotImplementedError: Always; building from scratch is not
                supported, a pretrained base model must be supplied.
        """
        # This would create a transformer from scratch
        # For now, we'll rely on loading a pretrained base
        raise NotImplementedError("Please provide a base_model or load from pretrained")

    def _apply_moe_conversion(self):
        """Convert some dense layers to MoE layers (currently a no-op)."""
        logger.info(f"Converting to MoE with {self.config.num_experts} experts")
        # This would replace some layers with MoELayer.
        # Implementation depends on base model architecture, so make the
        # missing conversion visible instead of silently pretending it ran.
        logger.warning(
            "MoE conversion is not implemented; the base model layers are left unchanged"
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        thoughts_labels: Optional[torch.FloatTensor] = None,
        emotion_labels: Optional[torch.LongTensor] = None,
        frustration_labels: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_moe_aux_loss: Optional[bool] = True,
        output_eq_loss: Optional[bool] = True,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> ZenithModelOutput:
        """Forward pass with optional multi-task outputs.

        Args:
            input_ids, attention_mask, position_ids, past_key_values,
            inputs_embeds, output_attentions, use_cache: Passed through to the
                wrapped base model unchanged.
            labels: Token labels for the shifted next-token cross-entropy;
                positions equal to ``-100`` are ignored.
            thoughts_labels, emotion_labels, frustration_labels: Accepted for
                interface compatibility; not used by this implementation.
            output_hidden_states: Whether the base model's hidden states are
                included in the returned output. They are always computed
                internally because the adapter needs the last layer.
            output_moe_aux_loss: When falsy, the MoE auxiliary loss is neither
                reported nor added to ``loss``.
            output_eq_loss: When falsy, the EQ loss is neither reported nor
                added to ``loss``.
            **kwargs: Extra keyword arguments for the base model. A
                ``return_dict`` entry is discarded because attribute access on
                the base output is required.

        Returns:
            ``ZenithModelOutput`` with ``loss``, ``logits``, ``hidden_states``,
            ``attentions``, ``moe_aux_loss`` and ``eq_loss`` populated.
        """
        # The code below reads attributes of the base output, so a tuple
        # return (return_dict=False) must never reach the base model.
        kwargs.pop("return_dict", None)

        # Forward through base transformer
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,  # Need hidden states for adapters
            use_cache=use_cache,
            return_dict=True,
            **kwargs,
        )
        hidden_states = transformer_outputs.hidden_states[-1]  # Last layer

        moe_aux_loss = None
        if output_moe_aux_loss:
            moe_aux_loss = getattr(transformer_outputs, "moe_aux_loss", None)

        # Apply EQ adapter if present
        eq_loss = None
        if self.eq_wrapper is not None:
            hidden_states, eq_loss = self.eq_wrapper(hidden_states, attention_mask)
            # NOTE: only the logits see the adapted hidden state; the
            # ``hidden_states`` tuple returned below is the base model's own.
            if not output_eq_loss:
                # Honor the flag the same way ``output_moe_aux_loss`` is honored.
                eq_loss = None

        # Compute language modeling loss
        lm_logits = self.transformer.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # Predict token t+1 from position t.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            # Labels may live on another device when the model is sharded.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        # Add auxiliary losses. Each term is weighted identically whether or
        # not a language-modeling loss is present.
        if moe_aux_loss is not None and self.config.aux_loss_weight > 0:
            weighted_moe = self.config.aux_loss_weight * moe_aux_loss
            loss = weighted_moe if loss is None else loss + weighted_moe
        if eq_loss is not None and self.config.use_eq_adapter:
            eq_loss_weight = getattr(
                self.config, "eq_loss_weight", self._DEFAULT_EQ_LOSS_WEIGHT
            )
            weighted_eq = eq_loss_weight * eq_loss
            loss = weighted_eq if loss is None else loss + weighted_eq

        # NOTE(review): ``past_key_values`` from the base model are not part of
        # the output, so cached generation through this wrapper likely needs
        # an output type that carries them — confirm.
        return ZenithModelOutput(
            loss=loss,
            logits=lm_logits,
            hidden_states=transformer_outputs.hidden_states if output_hidden_states else None,
            attentions=transformer_outputs.attentions,
            moe_aux_loss=moe_aux_loss,
            eq_loss=eq_loss,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        """Prepare inputs for text generation by delegating to the base model."""
        # Use transformer's implementation
        return self.transformer.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            **kwargs,
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        config: Optional[ZenithConfig] = None,
        **kwargs,
    ) -> "ZenithModel":
        """Load from pretrained DeepSeek base model.

        Args:
            pretrained_model_name_or_path: Identifier or directory handed to
                ``AutoModelForCausalLM.from_pretrained``.
            config: Zenith configuration. When ``None`` one is derived from the
                loaded base model's config.
            **kwargs: Forwarded to ``AutoModelForCausalLM.from_pretrained``.

        Returns:
            A ``ZenithModel`` wrapping the loaded base model. If the path is a
            local directory containing ``eq_adapter.pt`` and the EQ adapter is
            enabled, the adapter weights are restored from that file.
        """
        import os

        # Load base model
        logger.info(f"Loading base model: {pretrained_model_name_or_path}")
        base_model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            **kwargs,
        )

        # Get or create config
        if config is None:
            # Infer config from base model
            base_config = base_model.config
            num_heads = base_config.num_attention_heads
            config = ZenithConfig(
                # NOTE(review): hidden_size // 256 is only a label heuristic
                # (4096 -> "zenith-16B"); confirm the intended naming.
                model_type=f"zenith-{base_config.hidden_size // 256}B",
                d_model=base_config.hidden_size,
                d_ff=base_config.intermediate_size,
                num_layers=base_config.num_hidden_layers,
                num_heads=num_heads,
                num_kv_heads=getattr(base_config, "num_key_value_heads", num_heads),
                head_dim=base_config.hidden_size // num_heads,
                vocab_size=base_config.vocab_size,
                max_seq_len=getattr(base_config, "max_position_embeddings", 8192),
                rope_theta=getattr(base_config, "rope_theta", 10000.0),
            )

        # Create Zenith model
        model = cls(config, base_model=base_model)

        # Restore the adapter written by ``save_pretrained``; without this the
        # saved weights would be silently replaced by a fresh initialization.
        adapter_path = os.path.join(
            str(pretrained_model_name_or_path), cls._EQ_ADAPTER_FILE
        )
        if model.eq_wrapper is not None and os.path.isfile(adapter_path):
            logger.info(f"Loading EQ adapter weights: {adapter_path}")
            # weights_only=True avoids executing arbitrary pickled code.
            adapter_state = torch.load(
                adapter_path, map_location="cpu", weights_only=True
            )
            model.eq_wrapper.load_state_dict(adapter_state)
        return model

    @staticmethod
    def _strip_prefix(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
        """Return the entries of ``state_dict`` under ``prefix``, prefix removed."""
        cut = len(prefix)
        return {
            key[cut:]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save the base model, the Zenith config and the EQ adapter weights.

        Args:
            save_directory: Target directory.
            **kwargs: Forwarded to the base model's ``save_pretrained`` so
                callers that pass extra save options do not fail on an
                unexpected keyword. A ``state_dict`` entry is interpreted as
                the state dict of this wrapper and split by submodule prefix.
        """
        import os

        full_state = kwargs.pop("state_dict", None)
        adapter_state = None
        if full_state is not None:
            # Keys of this module are prefixed with the submodule attribute
            # names; the base model expects them without that prefix.
            kwargs["state_dict"] = self._strip_prefix(full_state, "transformer.")
            adapter_state = self._strip_prefix(full_state, "eq_wrapper.")

        # Save base transformer
        self.transformer.save_pretrained(save_directory, **kwargs)

        # Save config
        # NOTE(review): if ZenithConfig.save_pretrained writes ``config.json``
        # it replaces the base model's config written just above — confirm
        # that reloading the directory still works.
        self.config.save_pretrained(save_directory)

        # Save additional modules
        if self.eq_wrapper is not None:
            if not adapter_state:
                adapter_state = self.eq_wrapper.state_dict()
            torch.save(
                adapter_state,
                os.path.join(save_directory, self._EQ_ADAPTER_FILE),
            )
class ZenithForCausalLM(PreTrainedModel):
    """Zenith model with LM head (compatibility wrapper).

    Holds a ``ZenithModel`` as ``self.model`` and exposes its results as a
    plain ``CausalLMOutput``.
    """

    # Same config class as ZenithModel.
    config_class = ZenithConfig

    def __init__(self, config: ZenithConfig, base_model: Optional[PreTrainedModel] = None):
        """Wrap ``base_model`` in a ``ZenithModel`` and expose an LM head.

        Args:
            config: Zenith configuration; ``d_model`` and ``vocab_size`` size
                the LM head.
            base_model: Either the raw causal-LM model to wrap, or an already
                built ``ZenithModel``, which is then used as-is instead of
                being wrapped a second time.
        """
        super().__init__(config)
        if isinstance(base_model, ZenithModel):
            # Reuse a fully constructed (and possibly weight-loaded) model.
            self.model = base_model
        else:
            self.model = ZenithModel(config, base_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Tie weights if base model has tied embeddings
        if hasattr(self.model.transformer, "get_input_embeddings"):
            embeddings = self.model.transformer.get_input_embeddings()
            if embeddings is not None:
                self.lm_head.weight = embeddings.weight

    def forward(self, *args, **kwargs):
        """Run the wrapped ``ZenithModel`` and return a ``CausalLMOutput``.

        Positional and keyword arguments are forwarded unchanged. The
        auxiliary fields of the inner output are dropped.
        """
        outputs = self.model(*args, **kwargs)
        return CausalLMOutput(
            loss=outputs.loss,
            logits=outputs.logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def generate(self, *args, **kwargs):
        """Delegate generation to the wrapped ``ZenithModel``."""
        return self.model.generate(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, config: Optional[ZenithConfig] = None, **kwargs):
        """Load the base model through ``ZenithModel.from_pretrained`` and wrap it.

        The inherited loader would construct this class with no base model,
        which ends in ``ZenithModel._init_transformer`` raising
        ``NotImplementedError``. Delegating to ``ZenithModel.from_pretrained``
        loads the base model first.

        Args:
            pretrained_model_name_or_path: Identifier or directory of the base
                model.
            config: Optional Zenith configuration; inferred from the base model
                when ``None``.
            **kwargs: Forwarded to ``ZenithModel.from_pretrained``.

        Returns:
            A ``ZenithForCausalLM`` whose ``model`` is the loaded ``ZenithModel``.
        """
        zenith_model = ZenithModel.from_pretrained(
            pretrained_model_name_or_path,
            config=config,
            **kwargs,
        )
        return cls(zenith_model.config, base_model=zenith_model)