"""Zenith Model - Wrapper for DeepSeek Base Models with MoE and EQ"""

import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput

from ..configs import ZenithConfig
from .moe_wrapper import MoELayer
from .eq_adapter_wrapper import EQAdapterWrapper

logger = logging.getLogger(__name__)


@dataclass
class ZenithModelOutput(CausalLMOutput):
    """Output for Zenith model with multi-task heads."""
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    thoughts_logits: Optional[torch.FloatTensor] = None
    emotion_logits: Optional[torch.FloatTensor] = None
    frustration_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    moe_aux_loss: Optional[torch.FloatTensor] = None
    eq_loss: Optional[torch.FloatTensor] = None


class ZenithModel(PreTrainedModel):
    """Zenith model with hybrid MoE and EQ adapters built on DeepSeek base."""

    config_class = ZenithConfig
    base_model_prefix = "zenith"

    def __init__(
        self,
        config: ZenithConfig,
        base_model: Optional[PreTrainedModel] = None,
    ):
        super().__init__(config)

        self.config = config

        # Load or initialize base model
        if base_model is not None:
            logger.info(f"Using provided base model: {base_model.__class__.__name__}")
            self.transformer = base_model
        else:
            # Initialize from scratch (for training from scratch)
            logger.info("Initializing new model from scratch")
            self._init_transformer()

        # Apply MoE modifications if configured
        if config.num_experts > 1:
            self._apply_moe_conversion()

        # Apply EQ adapter wrapper if configured
        if config.use_eq_adapter:
            self.eq_wrapper = EQAdapterWrapper(
                config.d_model,
                config.eq_adapter_hidden_dim,
                config.eq_num_emotions,
                config.eq_frustration_dim,
                config.eq_dropout,
            )
        else:
            self.eq_wrapper = None

        # Multi-task heads (optional)
        self.thoughts_head = None
        self.emotion_head = None
        self.frustration_head = None

        logger.info(f"ZenithModel initialized: {config.model_type}, "
                    f"params={config.total_params / 1e9:.1f}B")

    def _init_transformer(self):
        """Initialize transformer from config."""
        # This would create a transformer from scratch
        # For now, we'll rely on loading a pretrained base
        raise NotImplementedError("Please provide a base_model or load from pretrained")

    def _apply_moe_conversion(self):
        """Convert some dense layers to MoE layers."""
        logger.info(f"Converting to MoE with {self.config.num_experts} experts")
        # This would replace some layers with MoELayer
        # Implementation depends on base model architecture
        pass
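
    # Hedged sketch (hypothetical helper, not in the original file): for a
    # DeepSeek/LLaMA-style decoder, the conversion described above could look
    # roughly like this. The `.model.layers` / `.mlp` attribute paths and the
    # MoELayer(...) keyword signature are assumptions, not confirmed by
    # moe_wrapper.
    def _apply_moe_conversion_sketch(self):
        """Illustrative only: swap the MLP of every other block for an MoE layer."""
        decoder_layers = self.transformer.model.layers  # assumed attribute path
        for idx, layer in enumerate(decoder_layers):
            if idx % 2 == 1:  # hybrid pattern: convert odd-indexed blocks only
                layer.mlp = MoELayer(
                    d_model=self.config.d_model,
                    d_ff=self.config.d_ff,
                    num_experts=self.config.num_experts,
                    top_k=getattr(self.config, "moe_top_k", 2),
                )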

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        thoughts_labels: Optional[torch.FloatTensor] = None,
        emotion_labels: Optional[torch.LongTensor] = None,
        frustration_labels: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_moe_aux_loss: Optional[bool] = True,
        output_eq_loss: Optional[bool] = True,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> ZenithModelOutput:
        """Forward pass with optional multi-task outputs."""

        # Forward through base transformer
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,  # Need hidden states for adapters
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = transformer_outputs.hidden_states[-1]  # Last layer
        moe_aux_loss = getattr(transformer_outputs, "moe_aux_loss", None) if output_moe_aux_loss else None

        # Apply EQ adapter if present
        eq_loss = None
        if self.eq_wrapper is not None:
            hidden_states, eq_loss = self.eq_wrapper(hidden_states, attention_mask)
            # Override last hidden state
            # Note: This is simplified - in practice need to modify transformer output properly

        # Compute language modeling loss
        lm_logits = self.transformer.lm_head(hidden_states)
        loss = None
        if labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # Add auxiliary losses (weighted in both branches so the scale is consistent)
        if moe_aux_loss is not None and self.config.aux_loss_weight > 0:
            weighted_moe_loss = self.config.aux_loss_weight * moe_aux_loss
            loss = loss + weighted_moe_loss if loss is not None else weighted_moe_loss

        if eq_loss is not None and self.config.use_eq_adapter:
            eq_loss_weight = getattr(self.config, "eq_loss_weight", 0.1)  # configurable via ZenithConfig, defaults to 0.1
            weighted_eq_loss = eq_loss_weight * eq_loss
            loss = loss + weighted_eq_loss if loss is not None else weighted_eq_loss

        return ZenithModelOutput(
            loss=loss,
            logits=lm_logits,
            hidden_states=transformer_outputs.hidden_states if output_hidden_states else None,
            attentions=transformer_outputs.attentions,
            moe_aux_loss=moe_aux_loss,
            eq_loss=eq_loss,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        """Prepare inputs for text generation."""
        # Use transformer's implementation
        return self.transformer.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            **kwargs,
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        config: Optional[ZenithConfig] = None,
        **kwargs,
    ) -> "ZenithModel":
        """Load from pretrained DeepSeek base model."""
        # Load base model
        logger.info(f"Loading base model: {pretrained_model_name_or_path}")
        base_model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            **kwargs,
        )

        # Get or create config
        if config is None:
            # Infer config from base model
            base_config = base_model.config
            config = ZenithConfig(
                model_type=f"zenith-{base_config.hidden_size // 256}B",
                d_model=base_config.hidden_size,
                d_ff=base_config.intermediate_size,
                num_layers=base_config.num_hidden_layers,
                num_heads=base_config.num_attention_heads,
                num_kv_heads=getattr(base_config, "num_key_value_heads", base_config.num_attention_heads),
                head_dim=base_config.hidden_size // base_config.num_attention_heads,
                vocab_size=base_config.vocab_size,
                max_seq_len=getattr(base_config, "max_position_embeddings", 8192),
                rope_theta=getattr(base_config, "rope_theta", 10000.0),
            )

        # Create Zenith model
        model = cls(config, base_model=base_model)

        return model

    def save_pretrained(self, save_directory: str):
        """Save model."""
        # Save base transformer
        self.transformer.save_pretrained(save_directory)

        # Save config
        self.config.save_pretrained(save_directory)

        # Save additional modules
        if self.eq_wrapper is not None:
            torch.save(
                self.eq_wrapper.state_dict(),
                f"{save_directory}/eq_adapter.pt",
            )
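
    # Hedged sketch (hypothetical helper, not in the original file): since the
    # EQ adapter weights are written to a separate "eq_adapter.pt" above, a
    # loader along these lines could restore them after from_pretrained(). The
    # method name is an assumption, not part of the original API.
    def load_eq_adapter(self, save_directory: str) -> None:
        """Illustrative only: reload EQ adapter weights saved by save_pretrained()."""
        if self.eq_wrapper is not None:
            state_dict = torch.load(
                f"{save_directory}/eq_adapter.pt",
                map_location="cpu",
            )
            self.eq_wrapper.load_state_dict(state_dict)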


class ZenithForCausalLM(PreTrainedModel):
    """Zenith model with LM head (compatibility wrapper)."""

    config_class = ZenithConfig

    def __init__(self, config: ZenithConfig, base_model: Optional[PreTrainedModel] = None):
        super().__init__(config)
        self.model = ZenithModel(config, base_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Tie weights if base model has tied embeddings
        if hasattr(self.model.transformer, "get_input_embeddings"):
            self.lm_head.weight = self.model.transformer.get_input_embeddings().weight

    def forward(self, **kwargs):
        outputs = self.model(**kwargs)
        return CausalLMOutput(
            loss=outputs.loss,
            logits=outputs.logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def generate(self, **kwargs):
        return self.model.generate(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, config: Optional[ZenithConfig] = None, **kwargs):
        model = super().from_pretrained(pretrained_model_name_or_path, config=config, **kwargs)
        return model
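

# Hedged usage sketch (not part of the original module): load a DeepSeek base
# checkpoint into ZenithModel and run one forward pass. The checkpoint path is
# a placeholder and the dtype choice is an assumption.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    base_path = "path/to/deepseek-base"  # placeholder: point at a real checkpoint
    tokenizer = AutoTokenizer.from_pretrained(base_path)
    model = ZenithModel.from_pretrained(base_path, torch_dtype=torch.bfloat16)
    model.eval()

    batch = tokenizer("Hello, Zenith!", return_tensors="pt")
    with torch.no_grad():
        out = model(**batch, labels=batch["input_ids"])
    print("loss:", out.loss, "logits shape:", tuple(out.logits.shape))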