anthonym21 commited on
Commit
1010007
·
1 Parent(s): 9f12aaa

Upload modeling_eve.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_eve.py +286 -0
modeling_eve.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Eve-2-MoE — Custom Mixture of Experts Language Model
3
+ Architecture: DeepSeek-V3 style Shared Expert + Top-K Routed Experts + RoPE
4
+ Author: Anthony Maio / Making Minds AI Research
5
+ License: MIT
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import math
12
+ from dataclasses import dataclass
13
+
14
+
15
@dataclass
class ModelConfig:
    """Configuration for Eve-2-MoE."""

    # Model dimensions
    vocab_size: int = 50304   # NOTE(review): presumably a GPT-2 vocab padded to a multiple of 64 — confirm tokenizer
    n_layer: int = 12         # number of transformer blocks
    n_embd: int = 512         # residual-stream width (equals n_head * head_dim here: 8 * 64)
    n_head: int = 8           # attention heads per layer
    head_dim: int = 64        # per-head dimension; also sets the RoPE rotation width
    block_size: int = 2048    # maximum sequence length (size of the precomputed RoPE table)

    # MoE settings
    num_experts: int = 8      # size of the routed-expert pool per layer
    top_k: int = 2            # routed experts activated per token
    expert_intermediate_size: int = 1408         # hidden width of each routed expert's FFN
    shared_expert_intermediate_size: int = 1408  # hidden width of the always-active shared expert
    router_aux_loss_coef: float = 0.01           # weight of the load-balancing auxiliary loss

    # Training settings
    use_checkpointing: bool = False  # Gradient checkpointing (saves VRAM, costs speed)

    # RoPE settings
    rope_theta: float = 10000.0      # base frequency for rotary position embeddings
39
+
40
+
41
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (no mean-centering, no bias)."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Learned per-channel gain, initialised to the identity.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Divide each vector by its root-mean-square (eps keeps it finite),
        # then apply the learned gain.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        normed = x * torch.rsqrt(mean_sq + self.eps)
        return normed * self.weight
51
+
52
+
53
def precompute_rope_freqs(head_dim: int, max_seq_len: int, theta: float = 10000.0,
                          device: torch.device = None) -> torch.Tensor:
    """Precompute the complex exponential frequencies for RoPE.

    Returns a (max_seq_len, head_dim // 2) complex tensor.
    """
    # One inverse frequency per rotation pair: theta^(-2i/d) for i = 0, 2, 4, ...
    exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    # Outer product positions x frequencies gives the rotation angle per (pos, pair).
    positions = torch.arange(max_seq_len, device=device).float()
    angles = torch.outer(positions, inv_freq)
    # Unit-magnitude complex numbers e^{i * angle} (complex64).
    return torch.polar(torch.ones_like(angles), angles)
63
+
64
+
65
def apply_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings to input tensor.

    Args:
        x: (B, n_head, T, head_dim)
        freqs_cis: (T, head_dim // 2) complex
    Returns:
        (B, n_head, T, head_dim) with rotary embeddings applied
    """
    B, H, T, D = x.shape
    # Pair up adjacent channels and view each pair as one complex number.
    pairs = x.float().reshape(B, H, T, D // 2, 2)
    x_complex = torch.view_as_complex(pairs)
    # Broadcast the per-position rotations over batch and head dims.
    rotations = freqs_cis[:T].view(1, 1, T, D // 2)
    rotated = x_complex * rotations
    # Unpack the complex pairs back into the real channel layout and dtype.
    return torch.view_as_real(rotated).reshape(B, H, T, D).type_as(x)
82
+
83
+
84
class MLP(nn.Module):
    """Feed-forward network with SwiGLU activation: down(silu(gate(x)) * up(x))."""

    def __init__(self, config: ModelConfig, intermediate_size: int = None):
        super().__init__()
        # Falls back to the routed-expert width when no explicit size is given.
        hidden_dim = intermediate_size or config.expert_intermediate_size
        self.w1 = nn.Linear(config.n_embd, hidden_dim, bias=False)      # Gate
        self.w2 = nn.Linear(config.n_embd, hidden_dim, bias=False)      # Up
        self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)  # Down

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        up = self.w2(x)
        return self.c_proj(gate * up)
96
+
97
+
98
class SharedMoE(nn.Module):
    """Mixture of Experts with one shared expert and K routed experts.

    DeepSeek-V3 style: a shared expert processes all tokens while a top-k
    router selects from a pool of specialized experts per token.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.top_k = config.top_k

        # Dense expert applied to every token, regardless of routing.
        self.shared_expert = MLP(config, config.shared_expert_intermediate_size)

        # Pool of routed experts plus the linear router that scores them.
        self.experts = nn.ModuleList(MLP(config) for _ in range(config.num_experts))
        self.router = nn.Linear(config.n_embd, config.num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        B, T, C = x.shape

        # Dense path: every token goes through the shared expert.
        shared_out = self.shared_expert(x)

        # Score all experts per token, keep the top-k, and renormalise
        # the kept weights so they sum to 1.
        probs = F.softmax(self.router(x), dim=-1)
        topk_w, topk_idx = torch.topk(probs, self.top_k, dim=-1)
        topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)

        # Load-balancing auxiliary loss: penalises concentrated mean router
        # probability mass (minimised when usage is uniform across experts).
        usage = probs.view(-1, self.config.num_experts).mean(dim=0)
        aux_loss = torch.sum(usage * usage) * self.config.num_experts

        # Sparse path: dispatch each token to its selected experts and
        # accumulate the weighted expert outputs in place.
        routed_out = torch.zeros_like(x)
        tokens = x.view(-1, C)
        idx_flat = topk_idx.view(-1, self.top_k)
        w_flat = topk_w.view(-1, self.top_k)
        out_flat = routed_out.view(-1, C)  # shares storage with routed_out

        for expert_id, expert in enumerate(self.experts):
            token_pos, slot_pos = torch.where(idx_flat == expert_id)
            if token_pos.numel() == 0:
                continue
            gate = w_flat[token_pos, slot_pos].unsqueeze(-1)
            out_flat.index_add_(0, token_pos, expert(tokens[token_pos]) * gate)

        return shared_out + routed_out, aux_loss
153
+
154
+
155
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with Rotary Position Embeddings."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_embd = config.n_embd

        # Fused QKV projection and the output projection, both bias-free.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape

        # Single projection, then slice into query / key / value.
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        q, k, v = to_heads(q), to_heads(k), to_heads(v)

        # Rotary embeddings go on queries and keys only; values are untouched.
        q = apply_rope(q, freqs_cis)
        k = apply_rope(k, freqs_cis)

        # PyTorch dispatches to fused (cuDNN/FlashAttention) kernels when available.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Merge heads back: (B, n_head, T, head_dim) -> (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
186
+
187
+
188
class Block(nn.Module):
    """Transformer block: RMSNorm → Attention → RMSNorm → MoE."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.ln_1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = RMSNorm(config.n_embd)
        self.mlp = SharedMoE(config)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Pre-norm attention sub-layer with a residual connection.
        x = x + self.attn(self.ln_1(x), freqs_cis)
        # Pre-norm MoE sub-layer; the router's auxiliary loss is passed up
        # to the model so it can be added to the training objective.
        ffn_out, aux_loss = self.mlp(self.ln_2(x))
        return x + ffn_out, aux_loss
203
+
204
+
205
class DeepSeekMoE(nn.Module):
    """Eve-2-MoE: DeepSeek-V3 style Mixture of Experts language model.

    Architecture:
        - Token embeddings (no learned position embeddings — uses RoPE)
        - N transformer blocks with RoPE attention + shared MoE FFN
        - RMSNorm + tied linear head
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=RMSNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: embedding matrix and output projection share one tensor.
        self.transformer.wte.weight = self.lm_head.weight

        # Precompute RoPE frequencies (registered as buffer so they move with .to(device));
        # non-persistent, so they are rebuilt rather than stored in checkpoints.
        freqs_cis = precompute_rope_freqs(config.head_dim, config.block_size, config.rope_theta)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

        # Initialize weights (applies to the tied tensor too; last writer wins,
        # but both branches use the same N(0, 0.02) init).
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # N(0, 0.02) for linear/embedding weights, zeros for any biases.
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx: torch.Tensor, targets: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the model.

        Args:
            idx: (B, T) token ids with T <= config.block_size.
            targets: optional (B, T) next-token labels; when given, loss is
                cross-entropy plus the weighted router auxiliary loss.
        Returns:
            (logits, loss) — loss is None when targets is None.
        """
        B, T = idx.shape
        assert T <= self.config.block_size, f"Sequence length {T} exceeds block_size {self.config.block_size}"

        x = self.transformer.wte(idx)

        # Sum the per-layer router load-balancing losses.
        total_aux_loss = 0.0
        for block in self.transformer.h:
            if self.config.use_checkpointing and self.training:
                # Recompute this block's activations during backward to save VRAM.
                x, aux_loss = torch.utils.checkpoint.checkpoint(
                    block, x, self.freqs_cis, use_reentrant=False
                )
            else:
                x, aux_loss = block(x, self.freqs_cis)
            total_aux_loss += aux_loss

        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            # Fold the load-balancing term into the training objective.
            loss = loss + self.config.router_aux_loss_coef * total_aux_loss

        return logits, loss

    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int,
                 temperature: float = 0.8, top_k: int = 50) -> torch.Tensor:
        """Autoregressive generation with temperature and top-k sampling."""
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens (RoPE table limit).
            idx_cond = idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            # Keep only the final position's distribution, temperature-scaled.
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                # Mask every logit below the k-th largest before sampling.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx