anthonym21 commited on
Commit
dd9d5c4
·
verified ·
1 Parent(s): fd5b405

Upload folder using huggingface_hub

Browse files
Files changed (6) hide show
  1. config.json +14 -0
  2. generate.py +85 -0
  3. modeling_eve.py +286 -0
  4. pytorch_model.bin +3 -0
  5. requirements.txt +5 -0
  6. train.py +482 -0
config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architecture": "Eve-2-MoE",
3
+ "vocab_size": 50304,
4
+ "n_layer": 12,
5
+ "n_embd": 512,
6
+ "n_head": 8,
7
+ "head_dim": 64,
8
+ "block_size": 2048,
9
+ "num_experts": 8,
10
+ "top_k": 2,
11
+ "expert_intermediate_size": 1408,
12
+ "shared_expert_intermediate_size": 1408,
13
+ "rope_theta": 10000.0
14
+ }
generate.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Eve-2-MoE Inference
3
+ ===================
4
+ Quick generation script. Works with local weights or HuggingFace download.
5
+
6
+ Usage:
7
+ python generate.py --prompt "The future of AI is"
8
+ python generate.py --prompt "The future of AI is" --model_path ./model_final/pytorch_model.bin
9
+ python generate.py --prompt "The future of AI is" --hf_repo anthonym21/Eve-2-MoE-250M
10
+ """
11
+
12
+ import argparse
13
+ import torch
14
+ import tiktoken
15
+ from modeling_eve import ModelConfig, DeepSeekMoE
16
+
17
+
18
def load_model(model_path: str | None = None, hf_repo: str | None = None, device: str = "cuda"):
    """Build an Eve-2-MoE model and optionally load pretrained weights.

    Args:
        model_path: Local path to a ``pytorch_model.bin`` state dict.
        hf_repo: HuggingFace repo id. When given, the weights are downloaded
            and override any locally supplied ``model_path``.
        device: Target device string ("cuda" or "cpu").

    Returns:
        The model moved to ``device`` and switched to eval mode.
        NOTE: if both ``model_path`` and ``hf_repo`` are None, the model is
        returned with its random initialization — no warning is emitted.
    """
    config = ModelConfig()
    model = DeepSeekMoE(config)

    if hf_repo:
        # The downloaded checkpoint takes precedence over a local path.
        from huggingface_hub import hf_hub_download
        model_path = hf_hub_download(repo_id=hf_repo, filename="pytorch_model.bin")

    if model_path:
        # weights_only=True refuses arbitrary pickled objects in the checkpoint.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)

    return model.to(device).eval()
31
+
32
+
33
def generate_streaming(model, prompt: str, max_tokens: int = 200,
                       temperature: float = 0.8, top_k: int = 50, device: str = "cuda"):
    """Sample up to ``max_tokens`` continuation tokens of ``prompt``, printing
    each token as it is produced (streaming effect).

    Temperature divides the logits before sampling; ``top_k`` truncates the
    distribution to the k most likely tokens. Uses the GPT-2 BPE vocabulary
    via tiktoken, matching the model's training tokenizer.
    """
    enc = tiktoken.get_encoding("gpt2")
    tokens = torch.tensor(enc.encode(prompt), dtype=torch.long, device=device).unsqueeze(0)

    print(prompt, end="", flush=True)

    # Fix: derive autocast's device_type from the actual device instead of
    # hard-coding "cuda". The old `enabled=(device == "cuda")` also disabled
    # bf16 for explicit device strings like "cuda:0".
    autocast_device = "cuda" if str(device).startswith("cuda") else "cpu"

    with torch.no_grad():
        for _ in range(max_tokens):
            # Crop the running context to the model's maximum sequence length.
            idx_cond = tokens[:, -model.config.block_size:]

            with torch.amp.autocast(device_type=autocast_device, dtype=torch.bfloat16,
                                    enabled=(autocast_device == "cuda")):
                logits, _ = model(idx_cond)

            # Only the final position predicts the next token.
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                # Mask everything below the k-th largest logit.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")

            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat((tokens, idx_next), dim=1)

            print(enc.decode([idx_next.item()]), end="", flush=True)

    print("\n")
60
+
61
+
62
def main():
    """CLI entry point: parse flags, load weights, and stream one completion."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default="The future of artificial intelligence is")
    parser.add_argument("--model_path", type=str, default=None)
    parser.add_argument("--hf_repo", type=str, default=None)
    parser.add_argument("--max_tokens", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    # With no explicit weight source, fall back to the published checkpoint.
    if not (args.model_path or args.hf_repo):
        args.hf_repo = "anthonym21/Eve-2-MoE-250M"

    print(f"Loading model on {args.device}...")
    model = load_model(args.model_path, args.hf_repo, args.device)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {param_count / 1e6:.2f}M\n")

    generate_streaming(model, args.prompt, args.max_tokens, args.temperature, args.top_k, args.device)


if __name__ == "__main__":
    main()
modeling_eve.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Eve-2-MoE — Custom Mixture of Experts Language Model
3
+ Architecture: DeepSeek-V3 style Shared Expert + Top-K Routed Experts + RoPE
4
+ Author: Anthony Maio / Making Minds AI Research
5
+ License: MIT
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import math
12
+ from dataclasses import dataclass
13
+
14
+
15
@dataclass
class ModelConfig:
    """Configuration for Eve-2-MoE.

    Defaults match the shipped config.json for the ~250M release; train.py
    overrides a subset from the CLI and serializes these fields on save.
    """

    # Model dimensions
    vocab_size: int = 50304
    n_layer: int = 12
    n_embd: int = 512          # must equal n_head * head_dim (attention reshapes rely on it)
    n_head: int = 8
    head_dim: int = 64
    block_size: int = 2048     # max sequence length; also sizes the RoPE table

    # MoE settings
    num_experts: int = 8       # routed experts per layer
    top_k: int = 2             # experts selected per token by the router
    expert_intermediate_size: int = 1408
    shared_expert_intermediate_size: int = 1408
    router_aux_loss_coef: float = 0.01  # weight of the load-balancing loss added in DeepSeekMoE.forward

    # Training settings
    use_checkpointing: bool = False  # Gradient checkpointing (saves VRAM, costs speed)

    # RoPE settings
    rope_theta: float = 10000.0  # base of the rotary frequency spectrum
39
+
40
+
41
class RMSNorm(nn.Module):
    """Root-mean-square normalization: scale features by 1/RMS, then by a learned gain."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialized to identity.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean of squares over the feature axis, stabilized by eps.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return self.weight * (x * inv_rms)
51
+
52
+
53
def precompute_rope_freqs(head_dim: int, max_seq_len: int, theta: float = 10000.0,
                          device: torch.device = None) -> torch.Tensor:
    """Build the complex rotation table for RoPE.

    Entry [t, j] is exp(i * t * theta^(-2j / head_dim)), so position 0 maps
    to the identity rotation.

    Returns:
        Complex64 tensor of shape (max_seq_len, head_dim // 2).
    """
    exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_seq_len, device=device).float()
    angles = torch.outer(positions, inv_freq)
    # Unit-magnitude complex numbers carrying only the phase.
    return torch.polar(torch.ones_like(angles), angles)
63
+
64
+
65
def apply_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate features with precomputed rotary position embeddings.

    Args:
        x: real tensor of shape (B, n_head, T, head_dim).
        freqs_cis: complex tensor of shape (>=T, head_dim // 2).
    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    batch, heads, seq_len, dim = x.shape
    # Pair adjacent channels and reinterpret each pair as one complex number.
    as_complex = torch.view_as_complex(x.float().reshape(batch, heads, seq_len, dim // 2, 2))
    # Trim the table to the current length and add broadcast axes for (B, H).
    rotation = freqs_cis[:seq_len].unsqueeze(0).unsqueeze(0)
    rotated = as_complex * rotation
    # Unpack complex back into interleaved real channels, restoring dtype.
    return torch.view_as_real(rotated).reshape(batch, heads, seq_len, dim).type_as(x)
82
+
83
+
84
class MLP(nn.Module):
    """SwiGLU feed-forward block: down( silu(gate(x)) * up(x) ), all bias-free."""

    def __init__(self, config: ModelConfig, intermediate_size: int = None):
        super().__init__()
        # Fall back to the config's expert width when no override is given.
        width = intermediate_size or config.expert_intermediate_size
        self.w1 = nn.Linear(config.n_embd, width, bias=False)      # gate projection
        self.w2 = nn.Linear(config.n_embd, width, bias=False)      # up projection
        self.c_proj = nn.Linear(width, config.n_embd, bias=False)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.w1(x)) * self.w2(x)
        return self.c_proj(gated)
96
+
97
+
98
class SharedMoE(nn.Module):
    """Mixture of Experts with one shared expert and K routed experts.

    DeepSeek-V3 style: a shared expert processes all tokens while a top-k
    router selects from a pool of specialized experts per token.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.top_k = config.top_k

        # Shared expert (always active)
        self.shared_expert = MLP(config, config.shared_expert_intermediate_size)

        # Routed experts
        self.experts = nn.ModuleList([MLP(config) for _ in range(config.num_experts)])
        self.router = nn.Linear(config.n_embd, config.num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Returns (shared + routed expert output, router load-balancing loss).
        B, T, C = x.shape

        # Shared path
        shared_out = self.shared_expert(x)

        # Router
        logits = self.router(x)
        probs = F.softmax(logits, dim=-1)

        # Top-K selection with normalized weights
        top_k_weights, top_k_indices = torch.topk(probs, self.top_k, dim=-1)
        # Renormalize so the k selected gate values sum to 1 per token.
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

        # Load balancing auxiliary loss
        # Sum of squared mean gate probabilities scaled by num_experts; since
        # expert_usage sums to 1, this is minimized when usage is uniform.
        flat_probs = probs.view(-1, self.config.num_experts)
        expert_usage = flat_probs.mean(dim=0)
        aux_loss = torch.sum(expert_usage * expert_usage) * self.config.num_experts

        # Route tokens to experts
        routed_out = torch.zeros_like(x)
        flat_x = x.view(-1, C)
        flat_indices = top_k_indices.view(-1, self.top_k)
        flat_weights = top_k_weights.view(-1, self.top_k)

        for i, expert in enumerate(self.experts):
            # Rows (tokens) routed to expert i, and at which of their k slots.
            mask = flat_indices == i
            batch_idx, rank_idx = torch.where(mask)

            if batch_idx.numel() > 0:
                expert_input = flat_x[batch_idx]
                expert_output = expert(expert_input)
                weight = flat_weights[batch_idx, rank_idx].unsqueeze(-1)
                # Scatter-accumulate the gated expert output into token rows.
                routed_out.view(-1, C).index_add_(0, batch_idx, expert_output * weight)

        return shared_out + routed_out, aux_loss
153
+
154
+
155
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention; queries and keys are rotated with RoPE."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_embd = config.n_embd

        # Fused QKV projection plus an output projection, both bias-free.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        batch, seq_len, width = x.shape

        # One matmul produces Q, K, V; split along the channel axis.
        query, key, value = self.c_attn(x).split(self.n_embd, dim=2)

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, self.head_dim).transpose(1, 2)

        query, key, value = to_heads(query), to_heads(key), to_heads(value)

        # Rotary position embeddings on queries and keys only.
        query = apply_rope(query, freqs_cis)
        key = apply_rope(key, freqs_cis)

        # SDPA dispatches to the fastest available kernel (flash/cuDNN/math).
        attended = F.scaled_dot_product_attention(query, key, value, is_causal=True)

        # Re-merge the heads and project back to the residual stream width.
        attended = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(attended)
186
+
187
+
188
class Block(nn.Module):
    """Pre-norm transformer layer: attention sublayer, then MoE feed-forward sublayer."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.ln_1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = RMSNorm(config.n_embd)
        self.mlp = SharedMoE(config)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Residual connection around attention.
        x = x + self.attn(self.ln_1(x), freqs_cis)
        # Residual connection around the MoE FFN; the router's balance loss
        # is handed upward for accumulation across layers.
        moe_out, balance_loss = self.mlp(self.ln_2(x))
        return x + moe_out, balance_loss
203
+
204
+
205
class DeepSeekMoE(nn.Module):
    """Eve-2-MoE: DeepSeek-V3 style Mixture of Experts language model.

    Architecture:
    - Token embeddings (no learned position embeddings — uses RoPE)
    - N transformer blocks with RoPE attention + shared MoE FFN
    - RMSNorm + tied linear head
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=RMSNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying
        self.transformer.wte.weight = self.lm_head.weight

        # Precompute RoPE frequencies (registered as buffer so they move with .to(device))
        freqs_cis = precompute_rope_freqs(config.head_dim, config.block_size, config.rope_theta)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

        # Initialize weights
        # NOTE(review): apply() runs after tying, so the shared embedding/head
        # tensor is re-drawn twice (once per module) — harmless but redundant.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # GPT-style init: N(0, 0.02) weights, zero biases.
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx: torch.Tensor, targets: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the transformer.

        Args:
            idx: (B, T) token ids with T <= block_size.
            targets: optional (B, T) next-token labels.
        Returns:
            (logits, loss) — loss is None without targets; otherwise cross
            entropy plus the router auxiliary loss summed over all layers.
        """
        B, T = idx.shape
        assert T <= self.config.block_size, f"Sequence length {T} exceeds block_size {self.config.block_size}"

        x = self.transformer.wte(idx)

        total_aux_loss = 0.0
        for block in self.transformer.h:
            if self.config.use_checkpointing and self.training:
                # Recompute activations in backward to trade compute for VRAM.
                x, aux_loss = torch.utils.checkpoint.checkpoint(
                    block, x, self.freqs_cis, use_reentrant=False
                )
            else:
                x, aux_loss = block(x, self.freqs_cis)
            total_aux_loss += aux_loss

        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            # Encourage balanced expert usage across the whole stack.
            loss = loss + self.config.router_aux_loss_coef * total_aux_loss

        return logits, loss

    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int,
                 temperature: float = 0.8, top_k: int = 50) -> torch.Tensor:
        """Autoregressive generation with temperature and top-k sampling."""
        for _ in range(max_new_tokens):
            # Keep only the most recent block_size tokens as context.
            idx_cond = idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                # Mask everything below the k-th largest logit.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68b3a3b00732a4977ef4c27d6dfbcc5ca70f73d47047103c108baac3a5d2108a
3
+ size 1088054098
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch>=2.2.0
2
+ tiktoken
3
+ datasets
4
+ huggingface_hub
5
+ wandb
train.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Eve-2-MoE Training Script — Multi-GPU DDP
3
+ ==========================================
4
+ Usage:
5
+ Single GPU: python train.py
6
+ Multi-GPU: torchrun --nproc_per_node=2 train.py
7
+ 4x GPU: torchrun --nproc_per_node=4 train.py
8
+
9
+ Override config: torchrun --nproc_per_node=2 train.py --max_steps 15000 --batch_size 48
10
+
11
+ Author: Anthony Maio / Making Minds AI Research
12
+ """
13
+
14
+ import os
15
+ import sys
16
+ import math
17
+ import time
18
+ import json
19
+ import argparse
20
+ import logging
21
+ from pathlib import Path
22
+ from contextlib import nullcontext
23
+
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+ import torch.distributed as dist
28
+ from torch.nn.parallel import DistributedDataParallel as DDP
29
+
30
+ import tiktoken
31
+ from datasets import load_dataset
32
+
33
+ from modeling_eve import ModelConfig, DeepSeekMoE
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # Distributed setup
37
+ # ---------------------------------------------------------------------------
38
+
39
def setup_distributed():
    """Detect torchrun vs. plain launch and initialize the process group.

    Under torchrun (RANK present in the environment) this joins the NCCL
    process group and pins the local CUDA device; otherwise it falls back to
    single-process mode on whichever device is available.

    Returns:
        (rank, world_size, local_rank, device, is_master)
    """
    launched_by_torchrun = "RANK" in os.environ
    if launched_by_torchrun:
        dist.init_process_group(backend="nccl")
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        rank, world_size, local_rank = 0, 1, 0
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    return rank, world_size, local_rank, device, rank == 0
56
+
57
+
58
def cleanup_distributed():
    """Tear down the DDP process group if one exists; safe to call unconditionally."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
61
+
62
+
63
+ # ---------------------------------------------------------------------------
64
+ # Data loading
65
+ # ---------------------------------------------------------------------------
66
+
67
class StreamingDataLoader:
    """Streams tokenized batches from FineWeb-Edu.

    Each DDP rank skips interleaved samples so no two GPUs see the same data.
    """

    def __init__(self, batch_size: int, block_size: int, rank: int = 0,
                 world_size: int = 1, dataset_name: str = "sample-10BT"):
        self.batch_size = batch_size
        self.block_size = block_size
        self.rank = rank
        self.world_size = world_size
        self.dataset_name = dataset_name
        # GPT-2 BPE tokenizer — matches the vocabulary the model is trained on.
        self.enc = tiktoken.get_encoding("gpt2")
        self._init_stream()

    def _init_stream(self):
        # (Re)open the streaming dataset; also used to restart after exhaustion.
        ds = load_dataset("HuggingFaceFW/fineweb-edu", name=self.dataset_name,
                          split="train", streaming=True)
        # Shard the stream across DDP ranks
        if self.world_size > 1:
            ds = ds.shard(num_shards=self.world_size, index=self.rank)
        self.iter_dataset = iter(ds)

    def get_batch(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return an (inputs, targets) pair of shape (batch_size, block_size).

        Documents are concatenated until batch_size * block_size + 1 tokens
        are collected; targets are inputs shifted right by one token.
        NOTE(review): tokens gathered past the cutoff are discarded (some
        data is skipped each batch), and no end-of-text separator is inserted
        between concatenated documents — confirm both are intentional.
        """
        total_tokens = self.batch_size * self.block_size

        batch_tokens = []
        while len(batch_tokens) < total_tokens + 1:
            try:
                text = next(self.iter_dataset)["text"]
                tokens = self.enc.encode(text, allowed_special={"<|endoftext|>"})
                batch_tokens.extend(tokens)
            except StopIteration:
                # Wrap around: reopen the stream and keep filling the batch.
                print(f"[Rank {self.rank}] Dataset exhausted, restarting stream...")
                self._init_stream()

        data = torch.tensor(batch_tokens[:total_tokens + 1], dtype=torch.long)
        x = data[:total_tokens].view(self.batch_size, self.block_size)
        y = data[1:total_tokens + 1].view(self.batch_size, self.block_size)
        return x, y
108
+
109
+
110
class ValidationLoader:
    """WikiText-2 validation set."""

    def __init__(self, block_size: int, device: torch.device):
        self.block_size = block_size
        self.device = device
        enc = tiktoken.get_encoding("gpt2")

        # Tokenize the entire test split once and keep it resident on `device`.
        ds = load_dataset("wikitext", "wikitext-2-v1", split="test")
        text = "\n\n".join(ds["text"])
        tokens = enc.encode(text, allowed_special={"<|endoftext|>"})
        self.data = torch.tensor(tokens, dtype=torch.long, device=device)

    @torch.no_grad()
    def estimate_loss(self, model, eval_iters: int = 50, batch_size: int = 32) -> float:
        """Average LM loss over `eval_iters` random windows; restores train mode.

        NOTE(review): autocast is hard-coded to device_type="cuda" — confirm
        this path is only ever reached on CUDA devices.
        """
        model.eval()
        losses = torch.zeros(eval_iters, device=self.device)

        for k in range(eval_iters):
            # Sample random window start positions from the token stream.
            ix = torch.randint(len(self.data) - self.block_size, (batch_size,))
            x = torch.stack([self.data[i:i + self.block_size] for i in ix])
            y = torch.stack([self.data[i + 1:i + self.block_size + 1] for i in ix])

            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                _, loss = model(x, y)
            losses[k] = loss.item()

        model.train()
        return losses.mean().item()
139
+
140
+
141
+ # ---------------------------------------------------------------------------
142
+ # Learning rate schedule
143
+ # ---------------------------------------------------------------------------
144
+
145
def get_lr(step: int, max_steps: int, warmup_steps: int, peak_lr: float, min_lr_ratio: float = 0.1) -> float:
    """Linear-warmup + cosine-decay learning-rate schedule.

    Ramps linearly up to ``peak_lr`` over ``warmup_steps``, then follows a
    half cosine down to ``peak_lr * min_lr_ratio`` at ``max_steps``; past
    that it stays clamped at the floor.
    """
    floor = peak_lr * min_lr_ratio

    if step < warmup_steps:
        # Warmup: step 0 already receives a non-zero learning rate.
        return peak_lr * (step + 1) / (warmup_steps + 1)
    if step > max_steps:
        # Past the schedule (shouldn't normally happen) — clamp to the floor.
        return floor

    # Cosine decay from peak down to the floor.
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return floor + cosine * (peak_lr - floor)
161
+
162
+
163
+ # ---------------------------------------------------------------------------
164
+ # Checkpointing
165
+ # ---------------------------------------------------------------------------
166
+
167
def save_checkpoint(model, optimizer, step: int, loss: float, val_loss: float,
                    config: ModelConfig, checkpoint_dir: Path, is_ddp: bool):
    """Write a resumable checkpoint: weights, optimizer state, and a config snapshot."""
    # Unwrap the DDP container so keys match the bare module.
    target = model.module if is_ddp else model
    arch_fields = (
        "vocab_size", "n_layer", "n_embd", "n_head", "head_dim", "block_size",
        "num_experts", "top_k", "expert_intermediate_size",
        "shared_expert_intermediate_size", "rope_theta",
    )
    payload = {
        "step": step,
        "model_state_dict": target.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "train_loss": loss,
        "val_loss": val_loss,
        # Snapshot the architecture so the checkpoint is self-describing.
        "config": {name: getattr(config, name) for name in arch_fields},
    }
    step_path = checkpoint_dir / f"step_{step}.pt"
    torch.save(payload, step_path)
    print(f" Checkpoint saved: {step_path}")

    # Duplicate under a stable name so --resume can always point at latest.pt.
    torch.save(payload, checkpoint_dir / "latest.pt")
198
+
199
+
200
def save_final_model(model, config: ModelConfig, output_dir: Path, is_ddp: bool):
    """Export bare weights plus a config.json, laid out for HuggingFace upload."""
    # Unwrap the DDP container so keys match the bare module.
    target = model.module if is_ddp else model
    output_dir.mkdir(parents=True, exist_ok=True)

    torch.save(target.state_dict(), output_dir / "pytorch_model.bin")

    arch_fields = (
        "vocab_size", "n_layer", "n_embd", "n_head", "head_dim", "block_size",
        "num_experts", "top_k", "expert_intermediate_size",
        "shared_expert_intermediate_size", "rope_theta",
    )
    manifest = {"architecture": "Eve-2-MoE"}
    manifest.update({name: getattr(config, name) for name in arch_fields})
    with open(output_dir / "config.json", "w") as f:
        json.dump(manifest, f, indent=2)

    print(f" Final model saved to {output_dir}")
225
+
226
+
227
+ # ---------------------------------------------------------------------------
228
+ # Main training loop
229
+ # ---------------------------------------------------------------------------
230
+
231
def parse_args():
    """Build and parse the training CLI; defaults reproduce the 250M run."""
    parser = argparse.ArgumentParser(description="Eve-2-MoE Training")

    # Architecture (defaults match 250M config)
    for flag, default in (("--n_layer", 12), ("--n_embd", 512), ("--n_head", 8),
                          ("--num_experts", 8), ("--block_size", 2048)):
        parser.add_argument(flag, type=int, default=default)

    # Training
    parser.add_argument("--max_steps", type=int, default=7500,
                        help="Total training steps. 7500 steps ≈ 500M tokens (1hr single B200)")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Per-GPU batch size")
    parser.add_argument("--learning_rate", type=float, default=5e-4)
    parser.add_argument("--warmup_steps", type=int, default=200)
    parser.add_argument("--weight_decay", type=float, default=0.1)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--min_lr_ratio", type=float, default=0.1,
                        help="Minimum LR as fraction of peak (cosine decay floor)")

    # Data
    parser.add_argument("--dataset", type=str, default="sample-10BT",
                        help="FineWeb-Edu subset name")

    # Checkpointing
    parser.add_argument("--save_every", type=int, default=500)
    parser.add_argument("--val_every", type=int, default=500)
    parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
    parser.add_argument("--output_dir", type=str, default="model_final")

    # Misc
    parser.add_argument("--compile", action="store_true", default=True,
                        help="Use torch.compile (recommended for B200/H100)")
    parser.add_argument("--no_compile", action="store_true",
                        help="Disable torch.compile")
    parser.add_argument("--wandb_project", type=str, default="Eve-2-MoE",
                        help="WandB project name (empty to disable)")
    parser.add_argument("--wandb_run", type=str, default=None,
                        help="WandB run name")
    parser.add_argument("--resume", type=str, default=None,
                        help="Path to checkpoint to resume from")
    parser.add_argument("--use_checkpointing", action="store_true",
                        help="Enable gradient checkpointing (saves VRAM)")

    return parser.parse_args()
278
+
279
+
280
def main():
    """Full training run: DDP setup, model build, streaming data, train loop,
    periodic validation/checkpointing, final export for HuggingFace upload."""
    args = parse_args()

    # --- Distributed setup ---
    rank, world_size, local_rank, device, is_master = setup_distributed()

    if is_master:
        print(f"{'=' * 60}")
        print(f" Eve-2-MoE Training")
        print(f" GPUs: {world_size} | Device: {torch.cuda.get_device_name(device)}")
        print(f" Steps: {args.max_steps} | Batch/GPU: {args.batch_size}")
        print(f" Global batch: {args.batch_size * world_size} × {args.block_size} = "
              f"{args.batch_size * world_size * args.block_size:,} tokens/step")
        print(f" Total tokens: ~{args.max_steps * args.batch_size * world_size * args.block_size / 1e9:.1f}B")
        print(f"{'=' * 60}")

    # --- Model ---
    config = ModelConfig(
        n_layer=args.n_layer,
        n_embd=args.n_embd,
        n_head=args.n_head,
        num_experts=args.num_experts,
        block_size=args.block_size,
        use_checkpointing=args.use_checkpointing,
    )

    model = DeepSeekMoE(config).to(device)

    if is_master:
        param_count = sum(p.numel() for p in model.parameters())
        print(f" Parameters: {param_count / 1e6:.2f}M")

    # Compile
    # NOTE(review): --compile defaults to True, so compilation is only skipped
    # via --no_compile. Checkpoints saved from a compiled model may carry
    # "_orig_mod." key prefixes — confirm --resume round-trips.
    if args.compile and not args.no_compile:
        if is_master:
            print(" Compiling model with torch.compile...")
        model = torch.compile(model)

    # DDP wrapper
    # find_unused_parameters=True: with top-k routing, some experts may get no
    # tokens in a step and thus receive no gradient.
    is_ddp = world_size > 1
    if is_ddp:
        model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)

    raw_model = model.module if is_ddp else model

    # --- Optimizer ---
    optimizer = torch.optim.AdamW(
        raw_model.parameters(),
        lr=args.learning_rate,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
    )

    # --- Resume from checkpoint ---
    start_step = 0
    if args.resume:
        if is_master:
            print(f" Resuming from {args.resume}...")
        ckpt = torch.load(args.resume, map_location=device)
        raw_model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        start_step = ckpt["step"] + 1
        if is_master:
            print(f" Resumed at step {start_step}")

    # --- Data ---
    train_loader = StreamingDataLoader(
        batch_size=args.batch_size,
        block_size=config.block_size,
        rank=rank,
        world_size=world_size,
        dataset_name=args.dataset,
    )

    # Validation runs only on rank 0.
    val_loader = None
    if is_master:
        val_loader = ValidationLoader(config.block_size, device)

    # --- Checkpoint directory ---
    checkpoint_dir = Path(args.checkpoint_dir)
    if is_master:
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # --- WandB ---
    wandb_enabled = False
    if is_master and args.wandb_project:
        try:
            import wandb
            wandb.init(
                project=args.wandb_project,
                name=args.wandb_run or f"eve2-{world_size}gpu-{args.max_steps}steps",
                config=vars(args),
            )
            wandb_enabled = True
        except ImportError:
            print(" WandB not installed, skipping.")

    # --- Training loop ---
    model.train()
    tokens_per_step = args.batch_size * world_size * config.block_size

    if is_master:
        print(f"\n Starting training from step {start_step}...\n")

    for step in range(start_step, args.max_steps):
        t0 = time.time()

        # Learning rate schedule
        lr = get_lr(step, args.max_steps, args.warmup_steps, args.learning_rate, args.min_lr_ratio)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        # Get batch
        x, y = train_loader.get_batch()
        x, y = x.to(device), y.to(device)

        # Forward
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits, loss = model(x, y)

        # Backward
        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        # Gradient clipping
        if args.grad_clip > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(raw_model.parameters(), args.grad_clip)
        else:
            grad_norm = None

        optimizer.step()

        # Timing
        # Synchronize so wall-clock time reflects finished GPU work.
        torch.cuda.synchronize()
        t1 = time.time()
        dt_ms = (t1 - t0) * 1000
        tok_per_sec = tokens_per_step / (t1 - t0)

        # --- Logging ---
        if is_master and step % 10 == 0:
            grad_str = f" | Grad: {grad_norm:.2f}" if grad_norm is not None else ""
            print(f" Step {step:>6d}/{args.max_steps} | Loss: {loss.item():.4f} | "
                  f"LR: {lr:.2e} | {tok_per_sec:,.0f} tok/s | {dt_ms:.0f}ms{grad_str}")

            if wandb_enabled:
                import wandb
                log = {
                    "train_loss": loss.item(),
                    "lr": lr,
                    "tokens_per_sec": tok_per_sec,
                    "step_time_ms": dt_ms,
                }
                if grad_norm is not None:
                    log["grad_norm"] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
                wandb.log(log, step=step)

        # --- Validation ---
        if is_master and val_loader and step > 0 and step % args.val_every == 0:
            val_loss = val_loader.estimate_loss(raw_model)
            print(f" >>> Validation Loss: {val_loss:.4f}")
            if wandb_enabled:
                wandb.log({"val_loss": val_loss}, step=step)

            # Save checkpoint
            save_checkpoint(model, optimizer, step, loss.item(), val_loss,
                            config, checkpoint_dir, is_ddp)

        # --- Periodic save (no val) ---
        elif is_master and step > 0 and step % args.save_every == 0 and step % args.val_every != 0:
            save_checkpoint(model, optimizer, step, loss.item(), -1.0,
                            config, checkpoint_dir, is_ddp)

    # --- Final validation & save ---
    # NOTE(review): `loss` is unbound here if start_step >= max_steps (e.g.
    # resuming a finished run) — the final save would raise NameError.
    if is_master:
        print(f"\n{'=' * 60}")
        print(" Training complete!")

        if val_loader:
            final_val = val_loader.estimate_loss(raw_model)
            print(f" Final Val Loss: {final_val:.4f}")

        # Save final model for HF upload
        output_dir = Path(args.output_dir)
        save_final_model(model, config, output_dir, is_ddp)

        # Save final checkpoint too
        save_checkpoint(model, optimizer, args.max_steps, loss.item(),
                        final_val if val_loader else -1.0,
                        config, checkpoint_dir, is_ddp)

        print(f"\n Upload to HuggingFace:")
        print(f" huggingface-cli upload anthonym21/Eve-2-MoE-250M {output_dir}/")
        print(f"{'=' * 60}")

    if wandb_enabled:
        import wandb
        wandb.finish()

    cleanup_distributed()


if __name__ == "__main__":
    main()