"""Zenith Model - Wrapper for DeepSeek Base Models with MoE and EQ""" import logging from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel from transformers.modeling_outputs import CausalLMOutput from ..configs import ZenithConfig from .moe_wrapper import MoELayer from .eq_adapter_wrapper import EQAdapterWrapper logger = logging.getLogger(__name__) @dataclass class ZenithModelOutput(CausalLMOutput): """Output for Zenith model with multi-task heads.""" loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None thoughts_logits: Optional[torch.FloatTensor] = None emotion_logits: Optional[torch.FloatTensor] = None frustration_logits: Optional[torch.FloatTensor] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None moe_aux_loss: Optional[torch.FloatTensor] = None eq_loss: Optional[torch.FloatTensor] = None class ZenithModel(PreTrainedModel): """Zenith model with hybrid MoE and EQ adapters built on DeepSeek base.""" config_class = ZenithConfig base_model_prefix = "zenith" def __init__( self, config: ZenithConfig, base_model: Optional[PreTrainedModel] = None, ): super().__init__(config) self.config = config # Load or initialize base model if base_model is not None: logger.info(f"Using provided base model: {base_model.__class__.__name__}") self.transformer = base_model else: # Initialize from scratch (for training from scratch) logger.info("Initializing new model from scratch") self._init_transformer() # Apply MoE modifications if configured if config.num_experts > 1: self._apply_moe_conversion() # Apply EQ adapter wrapper if configured if config.use_eq_adapter: self.eq_wrapper = EQAdapterWrapper( config.d_model, config.eq_adapter_hidden_dim, config.eq_num_emotions, config.eq_frustration_dim, config.eq_dropout, ) else: self.eq_wrapper = None # Multi-task heads (optional) self.thoughts_head = None self.emotion_head = None self.frustration_head = None logger.info(f"ZenithModel initialized: {config.model_type}, " f"params={config.total_params / 1e9:.1f}B") def _init_transformer(self): """Initialize transformer from config.""" # This would create a transformer from scratch # For now, we'll rely on loading a pretrained base raise NotImplementedError("Please provide a base_model or load from pretrained") def _apply_moe_conversion(self): """Convert some dense layers to MoE layers.""" logger.info(f"Converting to MoE with {self.config.num_experts} experts") # This would replace some layers with MoELayer # Implementation depends on base model architecture pass def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, thoughts_labels: Optional[torch.FloatTensor] = None, emotion_labels: Optional[torch.LongTensor] = None, frustration_labels: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_moe_aux_loss: Optional[bool] = True, output_eq_loss: Optional[bool] = True, use_cache: Optional[bool] = None, **kwargs, ) -> ZenithModelOutput: """Forward pass with optional multi-task outputs.""" # Forward through base transformer transformer_outputs = 

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        thoughts_labels: Optional[torch.FloatTensor] = None,
        emotion_labels: Optional[torch.LongTensor] = None,
        frustration_labels: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_moe_aux_loss: Optional[bool] = True,
        output_eq_loss: Optional[bool] = True,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> ZenithModelOutput:
        """Forward pass with optional multi-task outputs."""
        # Forward through the base transformer
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,  # Adapters need the hidden states
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = transformer_outputs.hidden_states[-1]  # Last layer
        moe_aux_loss = (
            getattr(transformer_outputs, "moe_aux_loss", None) if output_moe_aux_loss else None
        )

        # Apply EQ adapter if present
        eq_loss = None
        if self.eq_wrapper is not None:
            # Override the last hidden state with the adapted one.
            # Note: this is simplified; in practice the transformer output
            # needs to be modified properly.
            hidden_states, eq_loss = self.eq_wrapper(hidden_states, attention_mask)

        # Compute language-modeling logits with the base model's LM head
        lm_logits = self.transformer.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        # Add auxiliary losses
        if moe_aux_loss is not None and self.config.aux_loss_weight > 0:
            weighted_aux = self.config.aux_loss_weight * moe_aux_loss
            loss = weighted_aux if loss is None else loss + weighted_aux

        if eq_loss is not None and self.config.use_eq_adapter:
            # Falls back to 0.1 if the config does not define eq_loss_weight
            eq_loss_weight = getattr(self.config, "eq_loss_weight", 0.1)
            weighted_eq = eq_loss_weight * eq_loss
            loss = weighted_eq if loss is None else loss + weighted_eq

        return ZenithModelOutput(
            loss=loss,
            logits=lm_logits,
            hidden_states=transformer_outputs.hidden_states if output_hidden_states else None,
            attentions=transformer_outputs.attentions,
            moe_aux_loss=moe_aux_loss,
            eq_loss=eq_loss,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        """Prepare inputs for text generation."""
        # Delegate to the base transformer's implementation
        return self.transformer.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            **kwargs,
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        config: Optional[ZenithConfig] = None,
        **kwargs,
    ) -> "ZenithModel":
        """Load from a pretrained DeepSeek base model."""
        # Load base model
        logger.info(f"Loading base model: {pretrained_model_name_or_path}")
        base_model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            **kwargs,
        )

        # Get or create config
        if config is None:
            # Infer config from the base model
            base_config = base_model.config
            config = ZenithConfig(
                model_type=f"zenith-{base_config.hidden_size // 256}B",
                d_model=base_config.hidden_size,
                d_ff=base_config.intermediate_size,
                num_layers=base_config.num_hidden_layers,
                num_heads=base_config.num_attention_heads,
                num_kv_heads=getattr(
                    base_config, "num_key_value_heads", base_config.num_attention_heads
                ),
                head_dim=base_config.hidden_size // base_config.num_attention_heads,
                vocab_size=base_config.vocab_size,
                max_seq_len=getattr(base_config, "max_position_embeddings", 8192),
                rope_theta=getattr(base_config, "rope_theta", 10000.0),
            )

        # Create Zenith model
        model = cls(config, base_model=base_model)
        return model

    def save_pretrained(self, save_directory: str):
        """Save model."""
        # Save base transformer
        self.transformer.save_pretrained(save_directory)
        # Save Zenith config (note: this writes config.json, overwriting the
        # base transformer's config saved just above)
        self.config.save_pretrained(save_directory)
        # Save additional modules
        if self.eq_wrapper is not None:
            torch.save(
                self.eq_wrapper.state_dict(),
                f"{save_directory}/eq_adapter.pt",
            )
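

# Example (illustrative): wrapping a pretrained DeepSeek-style causal LM and
# computing a language-modeling loss. The checkpoint name below is a placeholder;
# extra kwargs (dtype, device_map, ...) follow the usual
# ``AutoModelForCausalLM.from_pretrained`` conventions.
#
#     model = ZenithModel.from_pretrained("deepseek-ai/deepseek-llm-7b-base")  # placeholder
#     batch = {
#         "input_ids": torch.tensor([[1, 2, 3, 4]]),
#         "labels": torch.tensor([[1, 2, 3, 4]]),
#     }
#     outputs = model(**batch)
#     print(outputs.loss, outputs.logits.shape)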


class ZenithForCausalLM(PreTrainedModel):
    """Zenith model with LM head (compatibility wrapper)."""

    def __init__(
        self,
        config: ZenithConfig,
        base_model: Optional[PreTrainedModel] = None,
    ):
        super().__init__(config)
        self.model = ZenithModel(config, base_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Tie weights if the base model uses tied input/output embeddings
        if getattr(self.model.transformer.config, "tie_word_embeddings", False):
            self.lm_head.weight = self.model.transformer.get_input_embeddings().weight

    def forward(self, **kwargs):
        outputs = self.model(**kwargs)
        return CausalLMOutput(
            loss=outputs.loss,
            logits=outputs.logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def generate(self, **kwargs):
        return self.model.generate(**kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        config: Optional[ZenithConfig] = None,
        **kwargs,
    ):
        model = super().from_pretrained(pretrained_model_name_or_path, config=config, **kwargs)
        return model
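

# Sketch of a loading counterpart to ``ZenithModel.save_pretrained``:
# ``save_pretrained`` writes the EQ adapter weights to ``eq_adapter.pt``, but
# nothing in this module reloads them. The helper below is illustrative only; it
# assumes the model was built with ``use_eq_adapter=True`` (so ``eq_wrapper``
# exists) and that its architecture matches the saved state dict.
def load_eq_adapter_weights(model: ZenithModel, save_directory: str) -> ZenithModel:
    """Hypothetical helper: restore EQ adapter weights saved by save_pretrained."""
    if model.eq_wrapper is not None:
        state_dict = torch.load(f"{save_directory}/eq_adapter.pt", map_location="cpu")
        model.eq_wrapper.load_state_dict(state_dict)
    return model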