"""Zenith Model - Wrapper for DeepSeek Base Models with MoE and EQ."""
import logging
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from ..configs import ZenithConfig
from .moe_wrapper import MoELayer
from .eq_adapter_wrapper import EQAdapterWrapper
logger = logging.getLogger(__name__)
@dataclass
class ZenithModelOutput(CausalLMOutput):
"""Output for Zenith model with multi-task heads."""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
thoughts_logits: Optional[torch.FloatTensor] = None
emotion_logits: Optional[torch.FloatTensor] = None
frustration_logits: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
moe_aux_loss: Optional[torch.FloatTensor] = None
eq_loss: Optional[torch.FloatTensor] = None
class ZenithModel(PreTrainedModel):
"""Zenith model with hybrid MoE and EQ adapters built on DeepSeek base."""
config_class = ZenithConfig
base_model_prefix = "zenith"
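    # Usage sketch (the checkpoint id below is illustrative, not a verified
    # model id, and the ZenithConfig fields are assumed to match configs.py):
    #
    #     base = AutoModelForCausalLM.from_pretrained("deepseek-ai/deepseek-llm-7b-base")
    #     model = ZenithModel(ZenithConfig(...), base_model=base)
    #     # or, inferring the config from the base checkpoint:
    #     model = ZenithModel.from_pretrained("deepseek-ai/deepseek-llm-7b-base")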
def __init__(
self,
config: ZenithConfig,
base_model: Optional[PreTrainedModel] = None,
):
super().__init__(config)
self.config = config
# Load or initialize base model
if base_model is not None:
logger.info(f"Using provided base model: {base_model.__class__.__name__}")
self.transformer = base_model
else:
            # Initialize a new transformer from the config (training from scratch)
logger.info("Initializing new model from scratch")
self._init_transformer()
# Apply MoE modifications if configured
if config.num_experts > 1:
self._apply_moe_conversion()
# Apply EQ adapter wrapper if configured
if config.use_eq_adapter:
self.eq_wrapper = EQAdapterWrapper(
config.d_model,
config.eq_adapter_hidden_dim,
config.eq_num_emotions,
config.eq_frustration_dim,
config.eq_dropout,
)
else:
self.eq_wrapper = None
# Multi-task heads (optional)
self.thoughts_head = None
self.emotion_head = None
self.frustration_head = None
logger.info(f"ZenithModel initialized: {config.model_type}, "
f"params={config.total_params / 1e9:.1f}B")
def _init_transformer(self):
"""Initialize transformer from config."""
# This would create a transformer from scratch
# For now, we'll rely on loading a pretrained base
raise NotImplementedError("Please provide a base_model or load from pretrained")
def _apply_moe_conversion(self):
"""Convert some dense layers to MoE layers."""
logger.info(f"Converting to MoE with {self.config.num_experts} experts")
# This would replace some layers with MoELayer
# Implementation depends on base model architecture
pass
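        # A minimal sketch of what the conversion could look like, assuming a
        # LLaMA-style module layout (`self.transformer.model.layers[i].mlp`)
        # and that MoELayer takes (d_model, d_ff, num_experts, top_k); the
        # layout and the `moe_layer_interval` / `moe_top_k` config fields are
        # assumptions, not verified against moe_wrapper:
        #
        #     for i, layer in enumerate(self.transformer.model.layers):
        #         if i % self.config.moe_layer_interval == 0:
        #             layer.mlp = MoELayer(
        #                 self.config.d_model,
        #                 self.config.d_ff,
        #                 self.config.num_experts,
        #                 self.config.moe_top_k,
        #             )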
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
thoughts_labels: Optional[torch.FloatTensor] = None,
emotion_labels: Optional[torch.LongTensor] = None,
frustration_labels: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_moe_aux_loss: Optional[bool] = True,
output_eq_loss: Optional[bool] = True,
use_cache: Optional[bool] = None,
**kwargs,
) -> ZenithModelOutput:
"""Forward pass with optional multi-task outputs."""
# Forward through base transformer
transformer_outputs = self.transformer(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=True, # Need hidden states for adapters
use_cache=use_cache,
**kwargs,
)
hidden_states = transformer_outputs.hidden_states[-1] # Last layer
moe_aux_loss = getattr(transformer_outputs, "moe_aux_loss", None) if output_moe_aux_loss else None
# Apply EQ adapter if present
eq_loss = None
if self.eq_wrapper is not None:
hidden_states, eq_loss = self.eq_wrapper(hidden_states, attention_mask)
# Override last hidden state
# Note: This is simplified - in practice need to modify transformer output properly
        # Compute language modeling loss (assumes the base model exposes lm_head)
        lm_logits = self.transformer.lm_head(hidden_states)
loss = None
if labels is not None:
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        # Add auxiliary losses (apply the configured weight in both branches)
        if moe_aux_loss is not None and self.config.aux_loss_weight > 0:
            weighted_aux = self.config.aux_loss_weight * moe_aux_loss
            loss = loss + weighted_aux if loss is not None else weighted_aux
        if eq_loss is not None and self.config.use_eq_adapter:
            # Falls back to 0.1 if the config does not define eq_loss_weight
            eq_loss_weight = getattr(self.config, "eq_loss_weight", 0.1)
            weighted_eq = eq_loss_weight * eq_loss
            loss = loss + weighted_eq if loss is not None else weighted_eq
return ZenithModelOutput(
loss=loss,
logits=lm_logits,
hidden_states=transformer_outputs.hidden_states if output_hidden_states else None,
attentions=transformer_outputs.attentions,
moe_aux_loss=moe_aux_loss,
eq_loss=eq_loss,
)
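    # Training-step sketch: with labels supplied, forward() returns the combined
    # LM + auxiliary loss (shapes assumed: (batch, seq_len) LongTensors):
    #
    #     out = model(input_ids=ids, attention_mask=mask, labels=ids)
    #     out.loss.backward()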
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
**kwargs,
):
"""Prepare inputs for text generation."""
# Use transformer's implementation
return self.transformer.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
**kwargs,
)
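    # Generation sketch (`tok` is an illustrative tokenizer loaded for the same
    # base checkpoint, not defined in this module):
    #
    #     ids = tok("Hello", return_tensors="pt").input_ids
    #     text = tok.decode(model.generate(input_ids=ids, max_new_tokens=32)[0])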
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
config: Optional[ZenithConfig] = None,
**kwargs,
) -> "ZenithModel":
"""Load from pretrained DeepSeek base model."""
# Load base model
logger.info(f"Loading base model: {pretrained_model_name_or_path}")
base_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path,
**kwargs,
)
# Get or create config
if config is None:
            # Infer a ZenithConfig from the base model's config
            base_config = base_model.config
            config = ZenithConfig(
                # Rough size tag derived from hidden width, not an exact parameter count
                model_type=f"zenith-{base_config.hidden_size // 256}B",
d_model=base_config.hidden_size,
d_ff=base_config.intermediate_size,
num_layers=base_config.num_hidden_layers,
num_heads=base_config.num_attention_heads,
num_kv_heads=getattr(base_config, "num_key_value_heads", base_config.num_attention_heads),
head_dim=base_config.hidden_size // base_config.num_attention_heads,
vocab_size=base_config.vocab_size,
max_seq_len=getattr(base_config, "max_position_embeddings", 8192),
rope_theta=getattr(base_config, "rope_theta", 10000.0),
)
# Create Zenith model
model = cls(config, base_model=base_model)
return model
    def save_pretrained(self, save_directory: str, **kwargs):
        """Save the base transformer, config, and extra Zenith modules."""
        # Save base transformer
        self.transformer.save_pretrained(save_directory, **kwargs)
# Save config
self.config.save_pretrained(save_directory)
# Save additional modules
if self.eq_wrapper is not None:
torch.save(
self.eq_wrapper.state_dict(),
f"{save_directory}/eq_adapter.pt",
)
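        # Note: from_pretrained does not restore these extra weights; a loader
        # would need the inverse step (a sketch, assuming the same file layout):
        #
        #     state = torch.load(f"{save_directory}/eq_adapter.pt", map_location="cpu")
        #     model.eq_wrapper.load_state_dict(state)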
class ZenithForCausalLM(PreTrainedModel):
"""Zenith model with LM head (compatibility wrapper)."""
def __init__(self, config: ZenithConfig, base_model: Optional[PreTrainedModel] = None):
super().__init__(config)
self.model = ZenithModel(config, base_model)
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Tie to the base embeddings when available; note that forward() returns
        # logits from the wrapped model's own head, so this head exists mainly
        # for weight tying and export compatibility.
        if hasattr(self.model.transformer, "get_input_embeddings"):
            self.lm_head.weight = self.model.transformer.get_input_embeddings().weight
def forward(self, **kwargs):
outputs = self.model(**kwargs)
return CausalLMOutput(
loss=outputs.loss,
logits=outputs.logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def generate(self, **kwargs):
return self.model.generate(**kwargs)
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, config: Optional[ZenithConfig] = None, **kwargs):
        # Route through ZenithModel.from_pretrained so a pretrained base is
        # loaded first; calling PreTrainedModel.from_pretrained directly would
        # construct ZenithModel without a base and raise NotImplementedError.
        zenith = ZenithModel.from_pretrained(pretrained_model_name_or_path, config=config, **kwargs)
        return cls(zenith.config, base_model=zenith.transformer)