| """Zenith Model - Wrapper for DeepSeek Base Models with MoE and EQ"""
|
|
|
| import logging
|
| from dataclasses import dataclass
|
| from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
| import torch
|
| import torch.nn as nn
|
| from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
|
| from transformers.modeling_outputs import CausalLMOutput
|
|
|
| from ..configs import ZenithConfig
|
| from .moe_wrapper import MoELayer
|
| from .eq_adapter_wrapper import EQAdapterWrapper
|
|
|
| logger = logging.getLogger(__name__)
|
|
|
|
|
@dataclass
class ZenithModelOutput(CausalLMOutput):
    """Output for the Zenith model with multi-task heads."""

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    thoughts_logits: Optional[torch.FloatTensor] = None
    emotion_logits: Optional[torch.FloatTensor] = None
    frustration_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    moe_aux_loss: Optional[torch.FloatTensor] = None
    eq_loss: Optional[torch.FloatTensor] = None


class ZenithModel(PreTrainedModel):
    """Zenith model with hybrid MoE and EQ adapters built on a DeepSeek base."""

    config_class = ZenithConfig
    base_model_prefix = "zenith"

    def __init__(
        self,
        config: ZenithConfig,
        base_model: Optional[PreTrainedModel] = None,
    ):
        super().__init__(config)

        self.config = config

        # Reuse a provided base transformer, or build one from the config.
        if base_model is not None:
            logger.info(f"Using provided base model: {base_model.__class__.__name__}")
            self.transformer = base_model
        else:
            logger.info("Initializing new model from scratch")
            self._init_transformer()

        # Optionally convert dense feed-forward layers to mixture-of-experts layers.
        if config.num_experts > 1:
            self._apply_moe_conversion()

        # Optional emotional-intelligence (EQ) adapter applied to the final hidden states.
        if config.use_eq_adapter:
            self.eq_wrapper = EQAdapterWrapper(
                config.d_model,
                config.eq_adapter_hidden_dim,
                config.eq_num_emotions,
                config.eq_frustration_dim,
                config.eq_dropout,
            )
        else:
            self.eq_wrapper = None

        # Auxiliary task heads; not yet wired into forward(), so they remain unset.
        self.thoughts_head = None
        self.emotion_head = None
        self.frustration_head = None

        logger.info(
            f"ZenithModel initialized: {config.model_type}, "
            f"params={config.total_params / 1e9:.1f}B"
        )

    def _init_transformer(self):
        """Initialize the transformer from config (not supported yet)."""
        raise NotImplementedError("Please provide a base_model or load via from_pretrained")

    def _apply_moe_conversion(self):
        """Convert some dense feed-forward layers to MoE layers (placeholder)."""
        logger.info(f"Converting to MoE with {self.config.num_experts} experts")
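        # The actual layer replacement is not implemented yet. A minimal sketch of
        # what it could look like is shown below; it assumes the base transformer
        # exposes its decoder blocks as `self.transformer.model.layers` with an
        # `mlp` submodule, and that MoELayer takes (d_model, d_ff, num_experts).
        # Both are assumptions about this repo, not confirmed interfaces.
        #
        #     for layer in self.transformer.model.layers:
        #         layer.mlp = MoELayer(
        #             self.config.d_model,
        #             self.config.d_ff,
        #             self.config.num_experts,
        #         )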
        pass

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        thoughts_labels: Optional[torch.FloatTensor] = None,
        emotion_labels: Optional[torch.LongTensor] = None,
        frustration_labels: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_moe_aux_loss: Optional[bool] = True,
        output_eq_loss: Optional[bool] = True,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> ZenithModelOutput:
        """Forward pass with optional multi-task outputs."""
        # Hidden states are always requested from the base model: the last layer
        # feeds the EQ adapter and the LM head below.
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = transformer_outputs.hidden_states[-1]
        moe_aux_loss = getattr(transformer_outputs, "moe_aux_loss", None) if output_moe_aux_loss else None

        # Route the final hidden states through the EQ adapter, if enabled.
        eq_loss = None
        if self.eq_wrapper is not None:
            hidden_states, eq_loss = self.eq_wrapper(hidden_states, attention_mask)

        # Language-modeling logits from the base model's LM head.
        lm_logits = self.transformer.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # Standard causal-LM shift: predict token t+1 from position t.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # Add the MoE load-balancing auxiliary loss, weighted by the config.
        if moe_aux_loss is not None and self.config.aux_loss_weight > 0:
            loss = (loss + self.config.aux_loss_weight * moe_aux_loss) if loss is not None else moe_aux_loss

        # Add the EQ adapter loss with a fixed weight.
        if eq_loss is not None and self.config.use_eq_adapter:
            eq_loss_weight = 0.1
            loss = (loss + eq_loss_weight * eq_loss) if loss is not None else eq_loss

        return ZenithModelOutput(
            loss=loss,
            logits=lm_logits,
            hidden_states=transformer_outputs.hidden_states if output_hidden_states else None,
            attentions=transformer_outputs.attentions,
            moe_aux_loss=moe_aux_loss,
            eq_loss=eq_loss,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        """Prepare inputs for text generation."""
        return self.transformer.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            **kwargs,
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        config: Optional[ZenithConfig] = None,
        **kwargs,
    ) -> "ZenithModel":
        """Load from a pretrained DeepSeek base model."""
        logger.info(f"Loading base model: {pretrained_model_name_or_path}")
        base_model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            **kwargs,
        )

        # Derive a ZenithConfig from the base model's config when none is provided.
        if config is None:
            base_config = base_model.config
            config = ZenithConfig(
                model_type=f"zenith-{base_config.hidden_size // 256}B",
                d_model=base_config.hidden_size,
                d_ff=base_config.intermediate_size,
                num_layers=base_config.num_hidden_layers,
                num_heads=base_config.num_attention_heads,
                num_kv_heads=getattr(base_config, "num_key_value_heads", base_config.num_attention_heads),
                head_dim=base_config.hidden_size // base_config.num_attention_heads,
                vocab_size=base_config.vocab_size,
                max_seq_len=getattr(base_config, "max_position_embeddings", 8192),
                rope_theta=getattr(base_config, "rope_theta", 10000.0),
            )

        model = cls(config, base_model=base_model)
        return model

    def save_pretrained(self, save_directory: str):
        """Save the base transformer, the Zenith config, and the EQ adapter weights."""
        self.transformer.save_pretrained(save_directory)
        self.config.save_pretrained(save_directory)

        # The EQ adapter is stored separately, alongside the base model files.
        if self.eq_wrapper is not None:
            torch.save(
                self.eq_wrapper.state_dict(),
                f"{save_directory}/eq_adapter.pt",
            )


class ZenithForCausalLM(PreTrainedModel):
    """Zenith model with LM head (compatibility wrapper)."""

    def __init__(self, config: ZenithConfig, base_model: Optional[PreTrainedModel] = None):
        super().__init__(config)
        self.model = ZenithModel(config, base_model)
        # Kept for interface compatibility; forward() returns the wrapped model's logits.
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Tie the LM head to the input embeddings when the base model exposes them.
        if hasattr(self.model.transformer, "get_input_embeddings"):
            self.lm_head.weight = self.model.transformer.get_input_embeddings().weight

    def forward(self, **kwargs):
        outputs = self.model(**kwargs)
        return CausalLMOutput(
            loss=outputs.loss,
            logits=outputs.logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def generate(self, **kwargs):
        return self.model.generate(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, config: Optional[ZenithConfig] = None, **kwargs):
        model = super().from_pretrained(pretrained_model_name_or_path, config=config, **kwargs)
        return model
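

# Minimal usage sketch (illustrative only). The checkpoint path is a placeholder,
# and the snippet assumes a DeepSeek-style checkpoint whose config exposes the
# fields that ZenithModel.from_pretrained reads (hidden_size, intermediate_size, ...).
if __name__ == "__main__":
    # Wrap a pretrained base model with the Zenith MoE/EQ machinery.
    model = ZenithModel.from_pretrained("path/to/deepseek-base-checkpoint")

    # Tiny smoke test: random token ids, labels reused to get a quick loss value.
    input_ids = torch.randint(0, model.config.vocab_size, (1, 16))
    outputs = model(input_ids=input_ids, labels=input_ids)
    print(f"loss={outputs.loss}, logits shape={tuple(outputs.logits.shape)}")

    # Persist the base weights, the Zenith config, and the EQ adapter (if any).
    model.save_pretrained("./zenith-checkpoint")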