anthonym21 committed on
Commit
cf4bcac
·
verified ·
1 Parent(s): c0522ad

Rewrite modeling_eve.py with HF-compatible EveMoEForCausalLM

Browse files
Files changed (1) hide show
  1. modeling_eve.py +370 -55
modeling_eve.py CHANGED
@@ -1,169 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
- from transformers import PreTrainedModel
6
- from .configuration_eve import EveConfig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  class RMSNorm(nn.Module):
9
- def __init__(self, dim, eps=1e-5):
 
 
10
  super().__init__()
11
  self.eps = eps
12
  self.weight = nn.Parameter(torch.ones(dim))
13
- def forward(self, x):
 
14
  return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
15
 
16
- def precompute_rope_freqs(head_dim, max_seq_len, theta=10000.0, device=None):
 
 
 
 
 
 
17
  freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
18
  t = torch.arange(max_seq_len, device=device).float()
19
  freqs = torch.outer(t, freqs)
20
- return torch.polar(torch.ones_like(freqs), freqs)
 
21
 
22
- def apply_rope(x, freqs_cis):
 
 
 
 
 
 
 
 
 
23
  B, H, T, D = x.shape
24
  x_complex = torch.view_as_complex(x.float().reshape(B, H, T, D // 2, 2))
25
- freqs_cis = freqs_cis[:T].view(1, 1, T, D // 2)
 
26
  x_rotated = x_complex * freqs_cis
 
27
  return torch.view_as_real(x_rotated).reshape(B, H, T, D).type_as(x)
28
 
 
29
  class MLP(nn.Module):
30
- def __init__(self, config, intermediate_size=None):
 
 
31
  super().__init__()
32
  hidden_dim = intermediate_size or config.expert_intermediate_size
33
- self.w1 = nn.Linear(config.n_embd, hidden_dim, bias=False)
34
- self.w2 = nn.Linear(config.n_embd, hidden_dim, bias=False)
35
- self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
36
- def forward(self, x):
 
37
  return self.c_proj(F.silu(self.w1(x)) * self.w2(x))
38
 
 
39
  class SharedMoE(nn.Module):
 
 
 
 
 
 
40
  def __init__(self, config):
41
  super().__init__()
42
  self.config = config
43
  self.top_k = config.top_k
 
 
44
  self.shared_expert = MLP(config, config.shared_expert_intermediate_size)
 
 
45
  self.experts = nn.ModuleList([MLP(config) for _ in range(config.num_experts)])
46
  self.router = nn.Linear(config.n_embd, config.num_experts, bias=False)
47
 
48
- def forward(self, x):
49
  B, T, C = x.shape
 
 
50
  shared_out = self.shared_expert(x)
 
 
51
  logits = self.router(x)
52
  probs = F.softmax(logits, dim=-1)
 
 
53
  top_k_weights, top_k_indices = torch.topk(probs, self.top_k, dim=-1)
54
  top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
55
-
 
56
  flat_probs = probs.view(-1, self.config.num_experts)
57
  expert_usage = flat_probs.mean(dim=0)
58
  aux_loss = torch.sum(expert_usage * expert_usage) * self.config.num_experts
59
-
 
60
  routed_out = torch.zeros_like(x)
61
  flat_x = x.view(-1, C)
62
  flat_indices = top_k_indices.view(-1, self.top_k)
63
  flat_weights = top_k_weights.view(-1, self.top_k)
64
-
65
  for i, expert in enumerate(self.experts):
66
  mask = flat_indices == i
67
  batch_idx, rank_idx = torch.where(mask)
 
68
  if batch_idx.numel() > 0:
69
  expert_input = flat_x[batch_idx]
70
  expert_output = expert(expert_input)
71
  weight = flat_weights[batch_idx, rank_idx].unsqueeze(-1)
72
  routed_out.view(-1, C).index_add_(0, batch_idx, expert_output * weight)
 
73
  return shared_out + routed_out, aux_loss
74
 
 
75
  class CausalSelfAttention(nn.Module):
 
 
76
  def __init__(self, config):
77
  super().__init__()
78
  self.n_head = config.n_head
79
  self.head_dim = config.head_dim
80
  self.n_embd = config.n_embd
 
81
  self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
82
  self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
83
 
84
- def forward(self, x, freqs_cis):
85
  B, T, C = x.shape
 
86
  qkv = self.c_attn(x)
87
  q, k, v = qkv.split(self.n_embd, dim=2)
 
88
  q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
89
  k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
90
  v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
 
 
91
  q = apply_rope(q, freqs_cis)
92
  k = apply_rope(k, freqs_cis)
 
 
93
  y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
 
94
  y = y.transpose(1, 2).contiguous().view(B, T, C)
95
  return self.c_proj(y)
96
 
 
97
  class Block(nn.Module):
 
 
98
  def __init__(self, config):
99
  super().__init__()
100
  self.ln_1 = RMSNorm(config.n_embd)
 
101
  self.ln_2 = RMSNorm(config.n_embd)
102
- self.attn = CausalSelfAttention(config) # Named 'attn' to match safetensors
103
  self.mlp = SharedMoE(config)
104
 
105
- def forward(self, x, freqs_cis):
106
- attn_out = self.attn(self.ln_1(x), freqs_cis)
107
- x = x + attn_out
108
  mlp_out, aux_loss = self.mlp(self.ln_2(x))
109
  x = x + mlp_out
110
  return x, aux_loss
111
 
112
- class DeepSeekMoE(PreTrainedModel):
113
- config_class = EveConfig
114
- _tied_weights_keys = ["lm_head.weight"]
115
-
116
- def __init__(self, config):
117
- super().__init__(config)
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  self.config = config
 
119
  self.transformer = nn.ModuleDict(dict(
120
  wte=nn.Embedding(config.vocab_size, config.n_embd),
121
  h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
122
  ln_f=RMSNorm(config.n_embd),
123
  ))
124
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
125
-
126
- # Tie weights
127
  self.transformer.wte.weight = self.lm_head.weight
128
-
 
129
  freqs_cis = precompute_rope_freqs(config.head_dim, config.block_size, config.rope_theta)
130
  self.register_buffer("freqs_cis", freqs_cis, persistent=False)
131
 
132
- def get_input_embeddings(self):
133
- return self.transformer.wte
134
-
135
- def set_input_embeddings(self, value):
136
- self.transformer.wte = value
137
-
138
- def get_output_embeddings(self):
139
- return self.lm_head
140
 
141
- def set_output_embeddings(self, new_embeddings):
142
- self.lm_head = new_embeddings
 
 
 
 
 
143
 
144
- def forward(self, input_ids=None, idx=None, labels=None, targets=None, **kwargs):
145
- if idx is None: idx = input_ids
146
- if targets is None: targets = labels
147
-
148
  B, T = idx.shape
 
 
149
  x = self.transformer.wte(idx)
 
150
  total_aux_loss = 0.0
151
-
152
- freqs_cis = self.freqs_cis.to(x.device)
153
-
154
  for block in self.transformer.h:
155
- x, aux_loss = block(x, freqs_cis[:T])
 
 
 
 
 
156
  total_aux_loss += aux_loss
157
-
158
  x = self.transformer.ln_f(x)
159
  logits = self.lm_head(x)
160
-
161
  loss = None
162
  if targets is not None:
163
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
164
  loss = loss + self.config.router_aux_loss_coef * total_aux_loss
165
-
166
- return (loss, logits) if loss is not None else logits
167
 
168
- def prepare_inputs_for_generation(self, input_ids, **kwargs):
169
- return {"input_ids": input_ids}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Eve-2-MoE — Custom Mixture of Experts Language Model
3
+ =====================================================
4
+ Architecture: DeepSeek-V3 style Shared Expert + Top-K Routed Experts + RoPE
5
+ Author: Anthony Maio / Making Minds AI Research
6
+ License: MIT
7
+
8
+ Usage (HuggingFace):
9
+ from transformers import AutoModelForCausalLM
10
+ model = AutoModelForCausalLM.from_pretrained(
11
+ "anthonym21/Eve-2-MoE-272M", trust_remote_code=True
12
+ )
13
+
14
+ Usage (standalone):
15
+ from modeling_eve import ModelConfig, DeepSeekMoE
16
+ model = DeepSeekMoE(ModelConfig())
17
+ """
18
 
19
  import torch
20
  import torch.nn as nn
21
  import torch.nn.functional as F
22
+ import math
23
+ from collections import OrderedDict
24
+ from dataclasses import dataclass
25
+
26
+
27
+ # ============================================================
28
+ # Standalone config (no transformers dependency)
29
+ # ============================================================
30
+
31
@dataclass
class ModelConfig:
    """Standalone configuration for Eve-2-MoE (no transformers dependency).

    Mirrors the fields of the HF EveConfig so the model can be built and
    trained without the transformers library installed.
    """

    # --- Transformer dimensions ---
    vocab_size: int = 50304
    n_layer: int = 12
    n_embd: int = 512
    n_head: int = 8
    head_dim: int = 64
    block_size: int = 2048

    # --- Mixture-of-Experts ---
    num_experts: int = 8
    top_k: int = 2
    expert_intermediate_size: int = 1408
    shared_expert_intermediate_size: int = 1408
    router_aux_loss_coef: float = 0.01

    # --- Training ---
    # Gradient checkpointing: trades recompute time for lower peak VRAM.
    use_checkpointing: bool = False

    # --- Rotary position embeddings ---
    rope_theta: float = 10000.0
55
+
56
+
57
+ # ============================================================
58
+ # Utility: strip torch.compile prefix from state dicts
59
+ # ============================================================
60
+
61
+ def _strip_orig_mod_prefix(state_dict):
62
+ """Remove '_orig_mod.' prefix from keys saved by torch.compile'd models."""
63
+ cleaned = OrderedDict()
64
+ for k, v in state_dict.items():
65
+ cleaned[k.replace("_orig_mod.", "")] = v
66
+ return cleaned
67
+
68
+
69
+ # ============================================================
70
+ # Building blocks (shared by standalone and HF models)
71
+ # ============================================================
72
 
73
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: rescales by the RMS of the last dim,
    with a learned per-feature gain and no mean-centering."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 1 / sqrt(mean(x^2) + eps), broadcast over the feature dimension.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight
83
 
84
+
85
def precompute_rope_freqs(head_dim: int, max_seq_len: int, theta: float = 10000.0,
                          device: torch.device = None) -> torch.Tensor:
    """Build the complex rotation table for RoPE.

    Entry [t, j] is exp(i * t * theta^(-2j/head_dim)), so multiplying a
    complex-paired feature by it rotates the pair by the position-dependent
    angle.  Returns a complex64 tensor of shape (max_seq_len, head_dim // 2).
    """
    exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_seq_len, device=device).float()
    angles = torch.outer(positions, inv_freq)
    # Unit-magnitude complex numbers carrying only the phase.
    return torch.polar(torch.ones_like(angles), angles)
95
+
96
 
97
def apply_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate features by their positions' complex RoPE phases.

    Args:
        x: real tensor of shape (B, n_head, T, head_dim); head_dim must be even.
        freqs_cis: complex tensor of shape (>=T, head_dim // 2).
    Returns:
        Tensor of the same shape and dtype as ``x`` with the rotation applied.
    """
    batch, heads, seq_len, dim = x.shape
    # Pair up consecutive features and view them as complex numbers.
    paired = x.float().reshape(batch, heads, seq_len, dim // 2, 2)
    as_complex = torch.view_as_complex(paired)
    # Slice the table to the current length; add broadcast dims for (B, H).
    phases = freqs_cis[:seq_len].unsqueeze(0).unsqueeze(0)
    rotated = as_complex * phases
    # Unpack back to interleaved real features, restoring the input dtype.
    return torch.view_as_real(rotated).reshape(batch, heads, seq_len, dim).type_as(x)
114
 
115
+
116
class MLP(nn.Module):
    """SwiGLU feed-forward network: down( silu(gate(x)) * up(x) )."""

    def __init__(self, config, intermediate_size: int = None):
        super().__init__()
        # Falls back to the routed-expert width when no override is given.
        width = intermediate_size or config.expert_intermediate_size
        self.w1 = nn.Linear(config.n_embd, width, bias=False)      # gate projection
        self.w2 = nn.Linear(config.n_embd, width, bias=False)      # up projection
        self.c_proj = nn.Linear(width, config.n_embd, bias=False)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.w1(x)) * self.w2(x)
        return self.c_proj(gated)
128
 
129
+
130
class SharedMoE(nn.Module):
    """Mixture of Experts with one always-on shared expert and top-k routing.

    DeepSeek-V3 style: every token passes through the shared expert, while a
    softmax router selects ``top_k`` of ``num_experts`` specialist MLPs per
    token; their outputs are blended with renormalized routing weights.
    Returns (output, aux_loss) where aux_loss is a load-balancing penalty.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.top_k
        # Dense path: processes all tokens unconditionally.
        self.shared_expert = MLP(config, config.shared_expert_intermediate_size)
        # Sparse path: pool of routed experts plus the linear router.
        self.experts = nn.ModuleList([MLP(config) for _ in range(config.num_experts)])
        self.router = nn.Linear(config.n_embd, config.num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        bsz, seq_len, dim = x.shape
        num_experts = self.config.num_experts

        # Dense path.
        shared_out = self.shared_expert(x)

        # Routing distribution per token, then keep the k largest and
        # renormalize so the kept weights sum to one.
        probs = F.softmax(self.router(x), dim=-1)
        topk_weights, topk_indices = torch.topk(probs, self.top_k, dim=-1)
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        # Load-balancing auxiliary loss: sum of squared mean usage scaled by
        # the expert count — minimized when usage is uniform across experts.
        usage = probs.view(-1, num_experts).mean(dim=0)
        aux_loss = torch.sum(usage * usage) * num_experts

        # Dispatch each token to its selected experts, expert by expert.
        tokens = x.view(-1, dim)
        flat_indices = topk_indices.view(-1, self.top_k)
        flat_weights = topk_weights.view(-1, self.top_k)
        routed_flat = torch.zeros_like(tokens)

        for expert_id, expert in enumerate(self.experts):
            token_idx, slot_idx = torch.where(flat_indices == expert_id)
            if token_idx.numel() == 0:
                continue
            scale = flat_weights[token_idx, slot_idx].unsqueeze(-1)
            routed_flat.index_add_(0, token_idx, expert(tokens[token_idx]) * scale)

        return shared_out + routed_flat.view(bsz, seq_len, dim), aux_loss
185
 
186
+
187
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_embd = config.n_embd
        # One fused projection producing Q, K and V; separate output projection.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        bsz, seq_len, dim = x.shape

        queries, keys, values = self.c_attn(x).split(self.n_embd, dim=2)

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(bsz, seq_len, self.n_head, self.head_dim).transpose(1, 2)

        # RoPE is applied to queries and keys only.
        queries = apply_rope(split_heads(queries), freqs_cis)
        keys = apply_rope(split_heads(keys), freqs_cis)
        values = split_heads(values)

        # Fused SDPA handles causal masking and softmax scaling, dispatching
        # to Flash/cuDNN kernels where available.
        attended = F.scaled_dot_product_attention(queries, keys, values, is_causal=True)

        # Merge heads back: (B, n_head, T, head_dim) -> (B, T, C).
        attended = attended.transpose(1, 2).contiguous().view(bsz, seq_len, dim)
        return self.c_proj(attended)
218
 
219
+
220
class Block(nn.Module):
    """Pre-norm transformer layer: RMSNorm → attention → RMSNorm → MoE FFN.

    Returns the updated hidden states together with the MoE auxiliary loss.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = RMSNorm(config.n_embd)
        self.mlp = SharedMoE(config)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Residual connection around attention.
        hidden = x + self.attn(self.ln_1(x), freqs_cis)
        # Residual connection around the MoE feed-forward.
        moe_out, aux_loss = self.mlp(self.ln_2(hidden))
        return hidden + moe_out, aux_loss
235
 
236
+
237
# ============================================================
# Standalone model (backward compatible, no HF dependency)
# ============================================================

class DeepSeekMoE(nn.Module):
    """Eve-2-MoE: DeepSeek-V3 style Mixture of Experts language model.

    Standalone nn.Module — works without the transformers library.
    For HuggingFace integration, use EveMoEForCausalLM instead.

    Architecture:
        - Token embeddings (no learned position embeddings — uses RoPE)
        - N transformer blocks with RoPE attention + shared MoE FFN
        - Final RMSNorm + linear head tied to the embedding matrix
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=RMSNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: embedding and output head share one matrix.
        self.transformer.wte.weight = self.lm_head.weight

        # RoPE table as a non-persistent buffer: follows .to(device) but is
        # not written into checkpoints.
        freqs_cis = precompute_rope_freqs(config.head_dim, config.block_size, config.rope_theta)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """GPT-2 style init: N(0, 0.02) weights, zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx: torch.Tensor, targets: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the model over a batch of token ids.

        Args:
            idx: (B, T) token ids, T <= config.block_size.
            targets: optional (B, T) label ids aligned position-for-position
                with ``idx`` (no internal shift — callers pass pre-shifted
                labels, unlike the HF wrapper).
        Returns:
            (logits, loss); loss is None when targets is None, otherwise the
            cross-entropy plus the weighted router auxiliary loss.
        Raises:
            ValueError: if T exceeds config.block_size.
        """
        B, T = idx.shape
        if T > self.config.block_size:
            # Explicit raise instead of `assert`, which `python -O` strips.
            raise ValueError(
                f"Sequence length {T} exceeds block_size {self.config.block_size}"
            )

        x = self.transformer.wte(idx)
        total_aux_loss = 0.0

        for block in self.transformer.h:
            if self.config.use_checkpointing and self.training:
                # Recompute activations during backward to cut peak memory.
                x, aux_loss = torch.utils.checkpoint.checkpoint(
                    block, x, self.freqs_cis, use_reentrant=False
                )
            else:
                x, aux_loss = block(x, self.freqs_cis)
            total_aux_loss += aux_loss

        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            loss = loss + self.config.router_aux_loss_coef * total_aux_loss

        return logits, loss

    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int,
                 temperature: float = 0.8, top_k: int = 50) -> torch.Tensor:
        """Autoregressive sampling with temperature and optional top-k filtering."""
        for _ in range(max_new_tokens):
            # No KV cache: re-run the model on the (cropped) context each step.
            idx_cond = idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                # Mask everything below the k-th largest logit.
                kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
                logits[logits < kth] = -float("Inf")

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_token), dim=1)

        return idx
326
+
327
+
328
# ============================================================
# HuggingFace PreTrainedModel integration
# (only available when transformers is installed)
# ============================================================

try:
    from transformers import PreTrainedModel
    from transformers.modeling_outputs import CausalLMOutputWithPast

    try:
        # Package-relative import when loaded via trust_remote_code.
        from .configuration_eve import EveConfig
    except ImportError:
        # Fallback for running this file directly from its directory.
        from configuration_eve import EveConfig

    class EveMoEPreTrainedModel(PreTrainedModel):
        """Base class for Eve-2-MoE HuggingFace models."""

        config_class = EveConfig
        base_model_prefix = "transformer"
        supports_gradient_checkpointing = True
        _no_split_modules = ["Block"]

        def _init_weights(self, module):
            # GPT-2 style init: N(0, 0.02) weights, zero biases.
            std = 0.02
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=std)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=std)

    class EveMoEForCausalLM(EveMoEPreTrainedModel):
        """Eve-2-MoE for causal language modeling (HuggingFace compatible).

        This model has the same weights and architecture as DeepSeekMoE but
        follows HuggingFace conventions for from_pretrained() and generate().

        Usage:
            from transformers import AutoModelForCausalLM
            model = AutoModelForCausalLM.from_pretrained(
                "anthonym21/Eve-2-MoE-272M", trust_remote_code=True
            )
            output = model.generate(input_ids, max_new_tokens=100)
        """

        _tied_weights_keys = ["lm_head.weight"]

        def __init__(self, config: EveConfig):
            super().__init__(config)

            self.transformer = nn.ModuleDict(dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=RMSNorm(config.n_embd),
            ))
            self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

            # RoPE table; non-persistent so it is rebuilt, not loaded.
            freqs_cis = precompute_rope_freqs(config.head_dim, config.block_size, config.rope_theta)
            self.register_buffer("freqs_cis", freqs_cis, persistent=False)

            # Initialize weights and tie embeddings (HF post-processing).
            self.post_init()

        def get_input_embeddings(self):
            return self.transformer.wte

        def set_input_embeddings(self, value):
            self.transformer.wte = value

        def get_output_embeddings(self):
            return self.lm_head

        def set_output_embeddings(self, new_embeddings):
            self.lm_head = new_embeddings

        def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: torch.Tensor = None,
            labels: torch.LongTensor = None,
            return_dict: bool = None,
            **kwargs,
        ):
            """
            Args:
                input_ids: Token IDs, shape (batch, seq_len).
                attention_mask: Ignored (model uses causal mask via Flash Attention).
                    Accepted for pipeline/generate() compatibility.
                labels: Language modeling labels. Same shape as input_ids.
                    The loss is computed with internal shift (labels[..., 1:]
                    predicted from input[..., :-1]), following HF convention.
                return_dict: Whether to return a CausalLMOutputWithPast or a tuple.
            Raises:
                ValueError: if seq_len exceeds config.block_size.
            """
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

            B, T = input_ids.shape
            if T > self.config.block_size:
                # Explicit raise instead of `assert`, which `python -O` strips.
                raise ValueError(
                    f"Sequence length {T} exceeds block_size {self.config.block_size}"
                )

            x = self.transformer.wte(input_ids)

            total_aux_loss = 0.0
            # NOTE(review): EveConfig may not declare `use_checkpointing`;
            # default to off instead of risking AttributeError — confirm
            # against configuration_eve.py.
            use_checkpointing = getattr(self.config, "use_checkpointing", False)
            for block in self.transformer.h:
                if use_checkpointing and self.training:
                    x, aux_loss = torch.utils.checkpoint.checkpoint(
                        block, x, self.freqs_cis, use_reentrant=False
                    )
                else:
                    x, aux_loss = block(x, self.freqs_cis)
                total_aux_loss += aux_loss

            x = self.transformer.ln_f(x)
            logits = self.lm_head(x)

            loss = None
            if labels is not None:
                # Shift so that tokens < n predict n (HF convention).
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                loss = F.cross_entropy(
                    shift_logits.view(-1, self.config.vocab_size),
                    shift_labels.view(-1),
                )
                loss = loss + self.config.router_aux_loss_coef * total_aux_loss

            if not return_dict:
                output = (logits,)
                return (loss,) + output if loss is not None else output

            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
            )

        def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
            """Crop context to block_size — no KV cache, full recompute per step."""
            if input_ids.shape[1] > self.config.block_size:
                input_ids = input_ids[:, -self.config.block_size:]
                if attention_mask is not None:
                    attention_mask = attention_mask[:, -self.config.block_size:]

            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
            }

        def load_state_dict(self, state_dict, *args, **kwargs):
            """Accept checkpoints saved from torch.compile'd models.

            NOTE(review): from_pretrained() loads weights through its own
            internal routine and may bypass this override — confirm whether
            compiled checkpoints also need stripping on that path.
            """
            # Strip _orig_mod. prefix if present (torch.compile artifact).
            if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
                state_dict = _strip_orig_mod_prefix(state_dict)
            return super().load_state_dict(state_dict, *args, **kwargs)

except ImportError:
    # transformers not installed — standalone usage only (DeepSeekMoE + ModelConfig)
    pass