ikaganacar committed on
Commit d7d2fb2 · 1 Parent(s): 9dd056d

Things got Messy

Model_Architecture/data.py ADDED
@@ -0,0 +1,188 @@
+ import tiktoken
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ from typing import Tuple, Optional, List
+ from pathlib import Path
+ from tqdm import tqdm
+ import mmap
+ import numpy as np
+
+ from model import ModelArgs
+
+ #####################################
+ # DATA
+ #####################################
+ class TextDataset(Dataset):
+     def __init__(self, txt: str, tokenizer, args: ModelArgs, stride: Optional[int] = None, max_samples: Optional[int] = None):
+         """
+         Optimized text dataset with memory-mapped reading and batched tokenization.
+
+         Args:
+             txt: Text content or path to a file
+             tokenizer: Pretrained tokenizer with an .encode() method
+             args: ModelArgs containing max_seq_len and max_batch_size
+             stride: Sliding-window stride. Defaults to max_seq_len // 2
+             max_samples: Limit the number of samples for quick testing
+         """
+         self.max_seq_len = args.max_seq_len
+         self.stride = stride if stride is not None else self.max_seq_len // 2
+
+         # Treat short inputs as candidate file paths; long strings are raw text
+         # (Path.exists() can raise ENAMETOOLONG when handed multi-KB strings)
+         if len(txt) < 4096 and Path(txt).exists():
+             text_content = self._read_file_mmap(txt)
+         else:
+             text_content = txt
+
+         # Validate input (character count is a cheap lower bound on token count)
+         if not text_content or len(text_content.strip()) < self.max_seq_len:
+             raise ValueError(f"Text too short. Need at least {self.max_seq_len} chars, got {len(text_content)}")
+
+         print(f"📝 Tokenizing {len(text_content):,} characters...")
+
+         # Tokenize with a progress bar for large texts
+         token_ids = self._tokenize_with_progress(tokenizer, text_content)
+
+         # Create sliding windows with vectorized operations
+         self.samples = self._create_sliding_windows(token_ids, max_samples)
+
+         print(f"✅ Created {len(self.samples)} training samples")
+
+     def _read_file_mmap(self, file_path: str) -> str:
+         """Memory-efficient file reading for large files"""
+         try:
+             with open(file_path, 'r', encoding='utf-8') as f:
+                 with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
+                     return mm.read().decode('utf-8', errors='ignore')
+         except Exception as e:
+             raise RuntimeError(f"Failed to read file {file_path}: {e}")
+
+     def _tokenize_with_progress(self, tokenizer, text: str) -> List[int]:
+         """Tokenize with a progress bar for large texts"""
+         # Process in chunks for memory efficiency
+         chunk_size = 10_000_000  # 10M characters per chunk
+         tokens = []
+
+         if len(text) > chunk_size:
+             # Process large texts in chunks
+             pbar = tqdm(total=len(text), desc="Tokenizing", unit="char")
+             for i in range(0, len(text), chunk_size):
+                 chunk = text[i:i + chunk_size]
+                 chunk_tokens = tokenizer.encode(chunk, allowed_special={"<|endoftext|>"})
+                 tokens.extend(chunk_tokens)
+                 pbar.update(len(chunk))
+             pbar.close()
+         else:
+             # Single pass for smaller texts
+             tokens = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
+
+         if not tokens:
+             raise ValueError("No tokens generated from input text")
+
+         return tokens
+
+     def _create_sliding_windows(self, token_ids: List[int], max_samples: Optional[int]) -> torch.Tensor:
+         """Create overlapping sequences using vectorized operations"""
+         if len(token_ids) < self.max_seq_len + 1:
+             raise ValueError(f"Not enough tokens. Need {self.max_seq_len + 1}, got {len(token_ids)}")
+
+         # Convert to numpy for faster slicing
+         tokens_array = np.array(token_ids, dtype=np.int64)
+
+         # Calculate the number of windows
+         num_windows = (len(tokens_array) - self.max_seq_len - 1) // self.stride + 1
+
+         if max_samples:
+             num_windows = min(num_windows, max_samples)
+
+         # Pre-allocate tensors
+         inputs = torch.zeros(num_windows, self.max_seq_len, dtype=torch.long)
+         targets = torch.zeros(num_windows, self.max_seq_len, dtype=torch.long)
+
+         # Fill tensors: targets are the inputs shifted one token to the right
+         for i in range(num_windows):
+             start = i * self.stride
+             inputs[i] = torch.from_numpy(tokens_array[start:start + self.max_seq_len])
+             targets[i] = torch.from_numpy(tokens_array[start + 1:start + self.max_seq_len + 1])
+
+         # Stack into (input, target) pairs — one contiguous tensor instead of two lists
+         return torch.stack([inputs, targets], dim=1)
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Return an (input_ids, target_ids) tuple"""
+         return self.samples[idx, 0], self.samples[idx, 1]
+
+
+ def create_dataloader(
+     txt: str,
+     args: ModelArgs,
+     stride: Optional[int] = None,
+     shuffle: bool = True,
+     drop_last: bool = True,
+     num_workers: int = 0,
+     pin_memory: bool = True,
+     persistent_workers: bool = False,
+     max_samples: Optional[int] = None
+ ) -> DataLoader:
+     """
+     Optimized DataLoader with proper memory pinning and worker settings.
+
+     Args:
+         txt: Text content or file path
+         args: ModelArgs configuration
+         stride: Sliding-window stride
+         shuffle: Whether to shuffle samples
+         drop_last: Drop incomplete batches
+         num_workers: Number of data-loading workers (0 = main process)
+         pin_memory: Pin memory for faster GPU transfer (recommended)
+         persistent_workers: Keep workers alive between epochs (if num_workers > 0)
+         max_samples: Limit samples for testing
+     """
+     # tiktoken's "gpt2" encoding is fast, well tested, and has a ~50k vocab;
+     # for multilingual or code data, consider "cl100k_base" or "o200k_base"
+     tokenizer_name = getattr(args, "tokenizer_name", "gpt2")
+     tokenizer = tiktoken.get_encoding(tokenizer_name)
+
+     # Create the dataset with size validation
+     try:
+         dataset = TextDataset(
+             txt=txt,
+             tokenizer=tokenizer,
+             args=args,
+             stride=stride,
+             max_samples=max_samples
+         )
+     except Exception as e:
+         raise RuntimeError(f"Failed to create dataset: {e}")
+
+     # Create the DataLoader with optimized settings
+     dataloader = DataLoader(
+         dataset,
+         batch_size=args.max_batch_size,
+         shuffle=shuffle,
+         drop_last=drop_last,
+         num_workers=num_workers,
+         pin_memory=pin_memory,
+         persistent_workers=persistent_workers if num_workers > 0 else False,
+         prefetch_factor=2 if num_workers > 0 else None,
+     )
+
+     return dataloader
+
+
+ # Convenience function for downloading sample data
+ def get_sample_data(url: str = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt") -> str:
+     """Download sample text data for testing"""
+     try:
+         import requests
+         response = requests.get(url, timeout=30)
+         response.raise_for_status()
+         return response.text
+     except Exception as e:
+         print(f"⚠️ Could not download sample data: {e}")
+         return ""
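
A minimal smoke test of the new pipeline (an editor's sketch, not part of this commit; the small ModelArgs values and the tiny-shakespeare sample are illustrative):

    from model import ModelArgs
    from data import create_dataloader, get_sample_data

    args = ModelArgs(max_seq_len=128, max_batch_size=4)  # deliberately small for a quick check

    text = get_sample_data()  # ~1.1M characters of tiny-shakespeare; returns "" on download failure
    loader = create_dataloader(text, args, stride=64)

    input_ids, target_ids = next(iter(loader))
    assert input_ids.shape == (4, 128)  # (max_batch_size, max_seq_len)
    assert (input_ids[:, 1:] == target_ids[:, :-1]).all()  # targets are inputs shifted by one token
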
Model_Architecture/generation.py CHANGED
@@ -1,6 +1,6 @@
  import torch
  import tiktoken
- from model import Transformer, ModelArgs
+ from model import ismail, ModelArgs
 
 
  #####################################
@@ -151,7 +151,7 @@ if __name__ == "__main__":
      # Initialize model and tokenizer
      print("Initializing model...")
      torch.manual_seed(123)
-     model = Transformer(args)
+     model = ismail(args)
      model.eval()
 
      tokenizer = tiktoken.get_encoding("gpt2")
Model_Architecture/model.py CHANGED
@@ -2,7 +2,8 @@ import tiktoken
  import torch
  import torch.nn as nn
  from torch.utils.data import Dataset, DataLoader
-
+ import torch.nn.functional as F
+ from contextlib import nullcontext
  import math
  from dataclasses import dataclass
  from typing import Tuple, Optional, Literal
@@ -18,31 +19,32 @@ from kernel import act_quant, weight_dequant, fp8_gemm
  @dataclass
  class ModelArgs:
      max_batch_size: int = 8
-     max_seq_len: int = 4096 * 4
+     max_seq_len: int = 2048
      dtype: Literal["bf16", "fp8"] = "bf16"
      scale_fmt: Optional[str] = None
+
      vocab_size: int = 102400
-     dim: int = 2048
-     inter_dim: int = 10944
-     moe_inter_dim: int = 1408
-     n_layers: int = 27
-     n_dense_layers: int = 1
-     n_heads: int = 16
+     dim: int = 1024
+     inter_dim: int = 4096
+     moe_inter_dim: int = 1024
+     n_layers: int = 20
+     n_dense_layers: int = 3
+     n_heads: int = 12
+
      # moe
-     n_routed_experts: int = 64
-     n_shared_experts: int = 2
-     n_activated_experts: int = 6
-     n_expert_groups: int = 1
-     n_limited_groups: int = 1
-     score_func: Literal["softmax", "sigmoid"] = "softmax"
+     n_routed_experts: int = 6
+     n_shared_experts: int = 1
+     n_activated_experts: int = 2
      route_scale: float = 1.
-     use_routing_bias: bool = False  # Enable routing bias for fine-tuning expert selection
+     use_routing_bias: bool = True  # Enable routing bias for fine-tuning expert selection
+
      # mla
      q_lora_rank: int = 0
      kv_lora_rank: int = 512
      qk_nope_head_dim: int = 128
      qk_rope_head_dim: int = 64
      v_head_dim: int = 128
+
      # yarn
      original_seq_len: int = 4096
      rope_theta: float = 10000.0
@@ -58,54 +60,7 @@ block_size = 128
  gemm_impl: Literal["bf16", "fp8"] = "bf16"
 
 
- #####################################
- # DATA
- #####################################
- class TextDataset(Dataset):
-     def __init__(self, txt, tokenizer, args: ModelArgs, stride: Optional[int] = None):
-         self.input_ids = []
-         self.target_ids = []
-
-         # Use max_seq_len from ModelArgs
-         max_length = args.max_seq_len
-         if stride is None:
-             stride = max_length // 2  # Default stride is half the sequence length
-
-         # Tokenize the entire text
-         token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})
-
-         # Use a sliding window to chunk the book into overlapping sequences of max_length
-         for i in range(0, len(token_ids) - max_length, stride):
-             input_chunk = token_ids[i:i + max_length]
-             target_chunk = token_ids[i + 1: i + max_length + 1]
-             self.input_ids.append(torch.tensor(input_chunk))
-             self.target_ids.append(torch.tensor(target_chunk))
-
-     def __len__(self):
-         return len(self.input_ids)
-
-     def __getitem__(self, idx):
-         return self.input_ids[idx], self.target_ids[idx]
-
-
- def create_dataloader(txt, args: ModelArgs, stride: Optional[int] = None,
-                       shuffle: bool = True, drop_last: bool = True, num_workers: int = 0):
-     # Initialize the tokenizer
-     tokenizer = tiktoken.get_encoding("gpt2")
-
-     # Create dataset with ModelArgs
-     dataset = TextDataset(txt, tokenizer, args, stride)
-
-     # Create dataloader using batch_size from ModelArgs
-     dataloader = DataLoader(
-         dataset,
-         batch_size=args.max_batch_size,
-         shuffle=shuffle,
-         drop_last=drop_last,
-         num_workers=num_workers
-     )
-
-     return dataloader
 
  #####################################
  # RoPE
@@ -321,9 +276,6 @@ class Gate(nn.Module):
          self.dim = args.dim
          self.n_routed_experts = args.n_routed_experts
          self.n_activated_experts = args.n_activated_experts
-         self.n_expert_groups = args.n_expert_groups
-         self.n_limited_groups = args.n_limited_groups
-         self.score_func = args.score_func
          self.route_scale = args.route_scale
 
          # Gate weight
@@ -341,10 +293,7 @@ class Gate(nn.Module):
          scores = linear(x, self.weight)
 
          # Apply scoring function
-         if self.score_func == "softmax":
-             scores = scores.softmax(dim=-1, dtype=torch.float32)
-         else:
-             scores = scores.sigmoid()
+         scores = scores.sigmoid()
 
          original_scores = scores
 
@@ -352,17 +301,6 @@ class Gate(nn.Module):
          if self.bias is not None:
              scores = scores + self.bias
 
-         # Expert grouping for load balancing
-         if self.n_expert_groups > 1:
-             scores = scores.view(x.size(0), self.n_expert_groups, -1)
-             if self.bias is None:
-                 group_scores = scores.amax(dim=-1)
-             else:
-                 group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
-             indices = group_scores.topk(self.n_limited_groups, dim=-1)[1]
-             mask = scores.new_ones(x.size(0), self.n_expert_groups, dtype=bool).scatter_(1, indices, False)
-             scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
-
          # Select top-k experts
          indices = torch.topk(scores, self.n_activated_experts, dim=-1)[1]
          weights = original_scores.gather(1, indices)
@@ -391,56 +329,116 @@ class Expert(nn.Module):
 
 
  class MoE(nn.Module):
-
      def __init__(self, args: ModelArgs):
          super().__init__()
          self.dim = args.dim
          self.n_routed_experts = args.n_routed_experts
          self.n_activated_experts = args.n_activated_experts
-
-         # Gate for routing
+         self.active_expert_idx = None  # None = all active (inference mode)
+
          self.gate = Gate(args)
-
-         # Routed experts
          self.experts = nn.ModuleList([
              Expert(args.dim, args.moe_inter_dim)
              for _ in range(args.n_routed_experts)
          ])
-
-         # Shared experts (always process all tokens)
          self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
+         self.ffn_norm = RMSNorm(args.dim)
+
+         # Load-balance loss coefficient
+         self.lb_loss_coef = 0.01
+
+     def set_active_expert(self, expert_idx: Optional[int]):
+         """Freeze all but the active expert to save optimizer memory"""
+         self.active_expert_idx = expert_idx
+
+         for i, expert in enumerate(self.experts):
+             requires_grad = (expert_idx is None) or (i == expert_idx)
+             for param in expert.parameters():
+                 param.requires_grad = requires_grad
+
+     def compute_load_balance_loss(self, router_probs, expert_indices):
+         """Encourage uniform expert utilization"""
+         # router_probs: [num_tokens, n_experts]
+         # expert_indices: [num_tokens, top_k]
+
+         # Token fraction per expert
+         tokens_per_expert = torch.zeros(self.n_routed_experts, device=router_probs.device)
+         indices_flat = expert_indices.view(-1)
+         ones = torch.ones_like(indices_flat, dtype=torch.float32)
+         tokens_per_expert.scatter_add_(0, indices_flat, ones)
+         tokens_per_expert = tokens_per_expert / (indices_flat.numel() + 1e-8)
+
+         # Average routing probability per expert
+         router_prob_per_expert = router_probs.mean(dim=0)
+
+         # Load-balancing loss (minimized when utilization is uniform)
+         loss = torch.mean(tokens_per_expert * router_prob_per_expert) * self.n_routed_experts
+         return loss
+
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
+     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
          original_shape = x.size()
          x = x.view(-1, self.dim)
-
-         # Route tokens to experts
-         weights, indices = self.gate(x)
-
-         # Initialize output for routed experts
+
+         # Always compute routing (even in sequential mode, to train the gate)
+         router_logits = F.linear(x, self.gate.weight)
+         router_probs = router_logits.sigmoid()
+
+         # The routing bias steers expert selection but not the mixing weights
+         selection_scores = router_probs if self.gate.bias is None else router_probs + self.gate.bias
+
+         # Select top-k experts on the biased scores, weight with the original probabilities
+         indices = torch.topk(selection_scores, self.n_activated_experts, dim=-1)[1]
+         weights = router_probs.gather(1, indices)
+
+         # Normalize the sigmoid weights and apply the routing scale
+         weights = weights / weights.sum(dim=-1, keepdim=True)
+         weights = weights * self.gate.route_scale
+
+         # Sequential training mode
+         if self.training and self.active_expert_idx is not None:
+             y = torch.zeros_like(x)
+
+             # Only compute gradients for the active expert
+             for i in range(self.n_routed_experts):
+                 idx, top = torch.where(indices == i)
+                 if idx.numel() == 0:
+                     continue
+
+                 # Frozen experts run forward-only under no_grad
+                 grad_context = nullcontext() if i == self.active_expert_idx else torch.no_grad()
+
+                 with grad_context:
+                     expert_out = self.experts[i](x[idx])
+                     y[idx] += expert_out * weights[idx, top, None]
+
+             # Load-balance loss (still needed for gate training)
+             lb_loss = self.compute_load_balance_loss(router_probs, indices)
+
+             # Shared experts always train
+             z = self.shared_experts(x)
+
+             return (y + z).view(original_shape), lb_loss
+
+         # Normal MoE mode (inference or full training)
          y = torch.zeros_like(x)
-
-         # Process each routed expert
          for i in range(self.n_routed_experts):
-             # Find tokens routed to this expert
              idx, top = torch.where(indices == i)
              if idx.numel() == 0:
                  continue
-
-             # Process tokens with this expert
-             expert_output = self.experts[i](x[idx])
-
-             # Weight and accumulate expert outputs
-             y[idx] += expert_output * weights[idx, top, None]
-
-         # Process all tokens with shared experts
+
+             expert_out = self.experts[i](x[idx])
+             y[idx] += expert_out * weights[idx, top, None]
+
          z = self.shared_experts(x)
-
-         # Combine routed and shared expert outputs
          output = (y + z).view(original_shape)
 
-         return output
+         if self.training:
+             lb_loss = self.compute_load_balance_loss(router_probs, indices)
+             return output, lb_loss
+         else:
+             return output, None
 
 
  #####################################
@@ -482,7 +479,7 @@ class Block(nn.Module):
  # TRANSFORMER MODEL
  #####################################
 
- class Transformer(nn.Module):
+ class ismail(nn.Module):
      def __init__(self, args: ModelArgs):
          super().__init__()
          self.args = args
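
The saving behind set_active_expert has two parts: requires_grad=False keeps frozen experts out of gradient and optimizer-state memory, and running them under torch.no_grad() skips storing their activations. A toy sanity check of that pattern (an editor's sketch with plain nn.Linear stand-ins, not the Expert class above):

    import torch
    import torch.nn as nn

    experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
    active = 2
    for i, expert in enumerate(experts):
        for p in expert.parameters():
            p.requires_grad = (i == active)  # mirrors MoE.set_active_expert

    x = torch.randn(16, 8)
    outs = []
    for i, expert in enumerate(experts):
        if i == active:
            outs.append(expert(x))           # tracked by autograd
        else:
            with torch.no_grad():            # forward only, no graph stored
                outs.append(expert(x))
    y = torch.stack(outs).sum(dim=0)
    y.sum().backward()

    assert experts[active].weight.grad is not None  # the active expert trains
    assert experts[0].weight.grad is None           # frozen experts accumulate nothing
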
Model_Architecture/model_size.py ADDED
@@ -0,0 +1,226 @@
+ import sys
+ from pathlib import Path
+
+ # Add the Model_Architecture directory to the path
+ sys.path.insert(0, str(Path(__file__).parent))
+
+ from model import ModelArgs
+
+ def estimate_model_size(args: ModelArgs):
+     """Calculate detailed model size and parameter counts"""
+
+     print(f"\n{'='*70}")
+     print("MODEL ARCHITECTURE ANALYSIS: ismail")
+     print(f"{'='*70}\n")
+
+     # Display configuration
+     print("📋 CONFIGURATION:")
+     print(f"   Model dimension (dim): {args.dim}")
+     print(f"   Vocabulary size: {args.vocab_size:,}")
+     print(f"   Number of layers: {args.n_layers}")
+     print(f"   Dense layers: {args.n_dense_layers}")
+     print(f"   MoE layers: {args.n_layers - args.n_dense_layers}")
+     print(f"   Attention heads: {args.n_heads}")
+     print(f"   Max sequence length: {args.max_seq_len}")
+     print(f"   Max batch size: {args.max_batch_size}")
+     print("\n   MoE Configuration:")
+     print(f"   Routed experts: {args.n_routed_experts}")
+     print(f"   Shared experts: {args.n_shared_experts}")
+     print(f"   Activated experts: {args.n_activated_experts}")
+     print("\n   MLA Configuration:")
+     print(f"   Q LoRA rank: {args.q_lora_rank}")
+     print(f"   KV LoRA rank: {args.kv_lora_rank}")
+     print(f"   QK nope head dim: {args.qk_nope_head_dim}")
+     print(f"   QK rope head dim: {args.qk_rope_head_dim}")
+     print(f"   V head dim: {args.v_head_dim}")
+
+     # Calculate parameters by component
+     print(f"\n{'='*70}")
+     print("🔢 PARAMETER COUNT BY COMPONENT:")
+     print(f"{'='*70}\n")
+
+     # 1. Embeddings
+     tok_embed_params = args.vocab_size * args.dim
+     output_params = args.vocab_size * args.dim
+     total_embed_params = tok_embed_params + output_params
+     print(f"   Token Embeddings: {tok_embed_params:>15,} params")
+     print(f"   Output Layer:     {output_params:>15,} params")
+     print(f"   {'─' * 50}")
+     print(f"   Total Embeddings: {total_embed_params:>15,} params\n")
+
+     # 2. Attention (per layer)
+     if args.q_lora_rank == 0:
+         wq_params = args.dim * args.n_heads * (args.qk_nope_head_dim + args.qk_rope_head_dim)
+         wq_norm_params = 0
+     else:
+         wq_params = args.dim * args.q_lora_rank + args.q_lora_rank * args.n_heads * (args.qk_nope_head_dim + args.qk_rope_head_dim)
+         wq_norm_params = args.q_lora_rank
+
+     wkv_a_params = args.dim * (args.kv_lora_rank + args.qk_rope_head_dim)
+     kv_norm_params = args.kv_lora_rank
+     wkv_b_params = args.kv_lora_rank * args.n_heads * (args.qk_nope_head_dim + args.v_head_dim)
+     wo_params = args.n_heads * args.v_head_dim * args.dim
+     attn_norm_params = args.dim
+
+     attn_params_per_layer = wq_params + wq_norm_params + wkv_a_params + kv_norm_params + wkv_b_params + wo_params + attn_norm_params
+
+     print("   Attention (per layer):")
+     if args.q_lora_rank > 0:
+         print(f"     WQ (LoRA): {wq_params:>15,} params")
+         print(f"     Q Norm:    {wq_norm_params:>15,} params")
+     else:
+         print(f"     WQ:        {wq_params:>15,} params")
+     print(f"     WKV_A:     {wkv_a_params:>15,} params")
+     print(f"     KV Norm:   {kv_norm_params:>15,} params")
+     print(f"     WKV_B:     {wkv_b_params:>15,} params")
+     print(f"     WO:        {wo_params:>15,} params")
+     print(f"     Attn Norm: {attn_norm_params:>15,} params")
+     print(f"   {'─' * 50}")
+     print(f"   Subtotal: {attn_params_per_layer:>15,} params\n")
+
+     # 3. Dense FFN
+     dense_w1_params = args.dim * args.inter_dim
+     dense_w2_params = args.inter_dim * args.dim
+     dense_w3_params = args.dim * args.inter_dim
+     ffn_norm_params = args.dim
+     dense_ffn_per_layer = dense_w1_params + dense_w2_params + dense_w3_params + ffn_norm_params
+
+     print("   Dense FFN (per layer):")
+     print(f"     FC1 (W1): {dense_w1_params:>15,} params")
+     print(f"     FC2 (W3): {dense_w3_params:>15,} params")
+     print(f"     FC3 (W2): {dense_w2_params:>15,} params")
+     print(f"     FFN Norm: {ffn_norm_params:>15,} params")
+     print(f"   {'─' * 50}")
+     print(f"   Subtotal: {dense_ffn_per_layer:>15,} params\n")
+
+     # 4. MoE FFN
+     gate_params = args.n_routed_experts * args.dim
+     if args.use_routing_bias:
+         gate_params += args.n_routed_experts
+
+     expert_w1_params = args.dim * args.moe_inter_dim
+     expert_w2_params = args.moe_inter_dim * args.dim
+     expert_w3_params = args.dim * args.moe_inter_dim
+     per_expert_params = expert_w1_params + expert_w2_params + expert_w3_params
+     routed_experts_params = args.n_routed_experts * per_expert_params
+
+     shared_w1_params = args.dim * (args.n_shared_experts * args.moe_inter_dim)
+     shared_w2_params = (args.n_shared_experts * args.moe_inter_dim) * args.dim
+     shared_w3_params = args.dim * (args.n_shared_experts * args.moe_inter_dim)
+     shared_experts_params = shared_w1_params + shared_w2_params + shared_w3_params
+
+     moe_ffn_per_layer = gate_params + routed_experts_params + shared_experts_params + ffn_norm_params
+
+     print("   MoE FFN (per layer):")
+     print(f"     Gate:                    {gate_params:>15,} params")
+     print(f"     Routed Experts ({args.n_routed_experts}x):    {routed_experts_params:>15,} params")
+     print(f"       Per expert:            {per_expert_params:>15,} params")
+     print(f"     Shared Experts:          {shared_experts_params:>15,} params")
+     print(f"     FFN Norm:                {ffn_norm_params:>15,} params")
+     print(f"   {'─' * 50}")
+     print(f"   Subtotal: {moe_ffn_per_layer:>15,} params\n")
+
+     # 5. Final Norm
+     final_norm_params = args.dim
+
+     # Total calculation
+     dense_layer_params = attn_params_per_layer + dense_ffn_per_layer
+     moe_layer_params = attn_params_per_layer + moe_ffn_per_layer
+
+     total_dense_params = args.n_dense_layers * dense_layer_params
+     total_moe_params = (args.n_layers - args.n_dense_layers) * moe_layer_params
+
+     total_params = total_embed_params + total_dense_params + total_moe_params + final_norm_params
+
+     print("   Layer Summary:")
+     print(f"     Dense layers ({args.n_dense_layers}x): {total_dense_params:>15,} params")
+     print(f"     MoE layers ({args.n_layers - args.n_dense_layers}x):  {total_moe_params:>15,} params")
+     print(f"     Final Norm: {final_norm_params:>15,} params")
+
+     print(f"\n{'='*70}")
+     print(f"📊 TOTAL PARAMETERS: {total_params:>15,} ({total_params/1e6:.2f}M)")
+     print(f"{'='*70}\n")
+
+     # Memory calculations
+     print(f"{'='*70}")
+     print("💾 MEMORY USAGE:")
+     print(f"{'='*70}\n")
+
+     bytes_per_param_bf16 = 2
+     bytes_per_param_fp32 = 4
+
+     # Model weights
+     weight_memory_bf16 = total_params * bytes_per_param_bf16 / (1024**3)
+     weight_memory_fp32 = total_params * bytes_per_param_fp32 / (1024**3)
+
+     print("   Model Weights:")
+     print(f"     BF16 (inference): {weight_memory_bf16:>10.3f} GB")
+     print(f"     FP32 (training):  {weight_memory_fp32:>10.3f} GB\n")
+
+     # KV cache (MLA caches the compressed KV latent plus the RoPE key per token)
+     kv_cache_per_layer = args.max_batch_size * args.max_seq_len * (args.kv_lora_rank + args.qk_rope_head_dim)
+     total_kv_cache = kv_cache_per_layer * args.n_layers * bytes_per_param_bf16 / (1024**3)
+
+     print("   KV Cache (BF16):")
+     print(f"     Per layer: {kv_cache_per_layer * bytes_per_param_bf16 / (1024**3):>10.3f} GB")
+     print(f"     Total ({args.n_layers} layers): {total_kv_cache:>10.3f} GB\n")
+
+     # Activations (rough estimate)
+     activation_memory = (args.max_batch_size * args.max_seq_len * args.dim * args.n_layers * 4) / (1024**3)
+
+     print(f"   Activations (estimate): {activation_memory:>10.3f} GB\n")
+
+     # Training overhead
+     gradients_memory = weight_memory_fp32       # Same size as the weights
+     optimizer_states = weight_memory_fp32 * 2   # Adam: momentum + variance
+     training_overhead = gradients_memory + optimizer_states
+
+     print("   Training Overhead (FP32):")
+     print(f"     Gradients:               {gradients_memory:>10.3f} GB")
+     print(f"     Optimizer states (Adam): {optimizer_states:>10.3f} GB")
+     print(f"     Total overhead:          {training_overhead:>10.3f} GB\n")
+
+     # Total estimates
+     inference_total = weight_memory_bf16 + total_kv_cache + activation_memory
+     training_total = weight_memory_fp32 + total_kv_cache + activation_memory + training_overhead
+
+     print(f"{'='*70}")
+     print(f"   INFERENCE (BF16):       {inference_total:>10.3f} GB")
+     print(f"   TRAINING (FP32 + Adam): {training_total:>10.3f} GB")
+     print(f"{'='*70}\n")
+
+     # Memory analysis
+     print(f"{'='*70}")
+     print("🎯 MEMORY ANALYSIS:")
+     print(f"{'='*70}\n")
+
+     gpu_sizes = [(8, "8GB"), (16, "16GB"), (24, "24GB"), (32, "32GB"), (40, "40GB"), (48, "48GB"), (80, "80GB")]
+
+     for threshold, name in gpu_sizes:
+         if inference_total <= threshold:
+             print(f"   ✅ Inference fits in a {name} GPU")
+             break
+     else:
+         print("   ❌ Inference requires a >80GB GPU")
+
+     for threshold, name in gpu_sizes:
+         if training_total <= threshold:
+             print(f"   ✅ Training fits in a {name} GPU")
+             break
+     else:
+         print("   ❌ Training requires a >80GB GPU")
+
+     print(f"\n{'='*70}\n")
+
+     return {
+         'total_params': total_params,
+         'weight_memory_gb': weight_memory_bf16,
+         'inference_memory_gb': inference_total,
+         'training_memory_gb': training_total
+     }
+
+
+ if __name__ == "__main__":
+     # Load the default configuration
+     args = ModelArgs()
+
+     # Run the estimation
+     results = estimate_model_size(args)
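
Spot-checking the MoE arithmetic against the new defaults in model.py (dim=1024, moe_inter_dim=1024, 6 routed and 1 shared expert; an editor's sketch mirroring the formulas in estimate_model_size):

    dim, moe_inter = 1024, 1024
    per_expert = 3 * dim * moe_inter         # w1, w2, w3 of one SwiGLU expert
    routed = 6 * per_expert                  # all routed experts in one MoE layer
    shared = 3 * dim * (1 * moe_inter)       # one shared expert
    gate = 6 * dim + 6                       # gate weight plus routing bias
    print(per_expert, routed, shared, gate)  # 3145728 18874368 3145728 6150
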
Model_Architecture/train.py ADDED
@@ -0,0 +1,483 @@
+ #!/usr/bin/env python3
+ """
+ Sequential Expert Training Script for MoE on a Single GPU
+ Memory usage: ~7.2GB (vs 10.9GB for full MoE training)
+ """
+
+ import argparse
+ import torch
+ import torch.nn.functional as F
+ from pathlib import Path
+ import json
+ import time
+ import math
+
+ # Import the model
+ from model import ismail, ModelArgs
+ from model_size import estimate_model_size
+
+ # Try to import optional dependencies
+ try:
+     import wandb
+     HAS_WANDB = True
+ except ImportError:
+     HAS_WANDB = False
+     print("⚠️ wandb not installed. Run 'pip install wandb' for experiment tracking.")
+
+ try:
+     import bitsandbytes as bnb
+     HAS_BNB = True
+ except ImportError:
+     HAS_BNB = False
+     print("⚠️ bitsandbytes not installed. Run 'pip install bitsandbytes' for a memory-efficient optimizer.")
+
+ # Configuration
+ DEFAULT_CONFIG = {
+     "model": {
+         "vocab_size": 32000,        # Reduced from 102400
+         "dim": 1024,
+         "inter_dim": 4096,
+         "moe_inter_dim": 1024,
+         "n_layers": 16,
+         "n_dense_layers": 1,        # Only the first layer is dense
+         "n_heads": 16,              # Increased for better parallelism
+         # MoE
+         "n_routed_experts": 6,
+         "n_shared_experts": 1,
+         "n_activated_experts": 2,
+         # MLA
+         "q_lora_rank": 128,         # Enable Q LoRA
+         "kv_lora_rank": 512,
+         "qk_nope_head_dim": 64,
+         "qk_rope_head_dim": 32,
+         "v_head_dim": 64,
+         # Sequence
+         "max_seq_len": 2048,        # Start shorter
+         "max_batch_size": 4,
+     },
+     "training": {
+         "learning_rate": 3e-4,
+         "weight_decay": 0.1,
+         "beta1": 0.9,
+         "beta2": 0.95,
+         "grad_clip": 1.0,
+         "warmup_steps": 1000,
+         "total_steps": 50000,
+         "expert_rotation_steps": 2000,      # Rotate the active expert every N steps
+         "gradient_accumulation_steps": 16,
+         "eval_every": 1000,
+         "save_every": 5000,
+         "save_dir": "./checkpoints",
+         "log_every": 100,
+         "dtype": "bf16",
+         "compile": True,                    # PyTorch 2.0+ compilation
+     },
+     "data": {
+         "train_file": "./data/train.txt",
+         "val_file": "./data/val.txt",
+         "stride": 512,
+     },
+     "logging": {
+         "use_wandb": HAS_WANDB,
+         "project_name": "sequential-moe",
+         "run_name": "moe-12gb-gpu",
+     }
+ }
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Train a MoE model with sequential experts")
+     parser.add_argument("--config", type=str, help="Path to a config JSON")
+     parser.add_argument("--train_file", type=str, help="Training text file")
+     parser.add_argument("--val_file", type=str, help="Validation text file")
+     parser.add_argument("--save_dir", type=str, default="./checkpoints")
+     parser.add_argument("--resume", type=str, help="Checkpoint to resume from")
+     parser.add_argument("--no_wandb", action="store_true", help="Disable wandb")
+     return parser.parse_args()
+
+
+ def load_config(args):
+     """Load and merge configuration"""
+     config = DEFAULT_CONFIG.copy()
+
+     if args.config and Path(args.config).exists():
+         with open(args.config) as f:
+             user_config = json.load(f)
+         # Deep merge
+         for key, value in user_config.items():
+             if key in config and isinstance(value, dict):
+                 config[key].update(value)
+             else:
+                 config[key] = value
+
+     # Override from CLI args
+     if args.train_file:
+         config["data"]["train_file"] = args.train_file
+     if args.val_file:
+         config["data"]["val_file"] = args.val_file
+     if args.save_dir:
+         config["training"]["save_dir"] = args.save_dir
+     if args.no_wandb:
+         config["logging"]["use_wandb"] = False
+
+     return config
+
+
+ def setup_model(config, device):
+     """Initialize the model and print a size estimate"""
+     args = ModelArgs(**config["model"])
+
+     print("\n" + "="*70)
+     print("MODEL INITIALIZATION")
+     print("="*70 + "\n")
+
+     # Estimate size
+     size_info = estimate_model_size(args)
+
+     model = ismail(args).to(device)
+
+     # Compile for speed (PyTorch 2.0+)
+     if config["training"]["compile"]:
+         try:
+             model = torch.compile(model)
+             print("✅ Model compiled with torch.compile()\n")
+         except Exception as e:
+             print(f"⚠️ Compilation failed: {e}\n")
+
+     return model, args
+
+
+ def setup_optimizer(model, config):
+     """Set up a memory-efficient optimizer"""
+     training_cfg = config["training"]
+
+     # Separate parameter groups
+     expert_params = []
+     base_params = []
+     router_params = []
+
+     for name, param in model.named_parameters():
+         if "experts" in name and "shared" not in name:
+             expert_params.append(param)
+         elif "gate" in name:
+             router_params.append(param)
+         else:
+             base_params.append(param)
+
+     # Use 8-bit Adam if available
+     if HAS_BNB:
+         optimizer_class = bnb.optim.AdamW8bit
+         print("✅ Using AdamW8bit for memory efficiency")
+     else:
+         optimizer_class = torch.optim.AdamW
+         print("⚠️ Using standard AdamW (install bitsandbytes for memory savings)")
+
+     optimizer = optimizer_class(
+         [
+             {"params": base_params, "weight_decay": training_cfg["weight_decay"]},
+             {"params": expert_params, "weight_decay": training_cfg["weight_decay"]},
+             {"params": router_params, "weight_decay": 0.0},  # Usually no weight decay for the router
+         ],
+         lr=training_cfg["learning_rate"],
+         betas=(training_cfg["beta1"], training_cfg["beta2"]),
+     )
+
+     return optimizer
+
+
+ def get_lr(step, config):
+     """Learning-rate schedule: linear warmup, then cosine decay"""
+     training_cfg = config["training"]
+     warmup_steps = training_cfg["warmup_steps"]
+     total_steps = training_cfg["total_steps"]
+     base_lr = training_cfg["learning_rate"]
+
+     if step < warmup_steps:
+         return base_lr * step / warmup_steps
+
+     # Cosine decay
+     progress = (step - warmup_steps) / (total_steps - warmup_steps)
+     return base_lr * 0.5 * (1 + math.cos(math.pi * progress))
+
+
+ def load_data(config):
+     """Create data loaders"""
+     data_cfg = config["data"]
+
+     print("\n" + "="*70)
+     print("DATA LOADING")
+     print("="*70 + "\n")
+
+     from data import create_dataloader
+
+     train_loader = create_dataloader(
+         txt=Path(data_cfg["train_file"]).read_text(encoding="utf-8"),
+         args=ModelArgs(**config["model"]),
+         stride=data_cfg["stride"],
+         shuffle=True,
+         drop_last=True,
+     )
+
+     val_loader = create_dataloader(
+         txt=Path(data_cfg["val_file"]).read_text(encoding="utf-8"),
+         args=ModelArgs(**config["model"]),
+         stride=data_cfg["stride"],
+         shuffle=False,
+         drop_last=True,
+     )
+
+     print(f"✅ Train batches: {len(train_loader)}")
+     print(f"✅ Val batches: {len(val_loader)}\n")
+
+     return train_loader, val_loader
+
+
+ def evaluate(model, val_loader, device, config):
+     """Evaluate the model on the validation set"""
+     model.eval()
+     total_loss = 0.0
+     total_tokens = 0
+
+     with torch.no_grad():
+         for input_ids, target_ids in val_loader:
+             input_ids = input_ids.to(device)
+             target_ids = target_ids.to(device)
+
+             logits, lb_loss = model(input_ids, start_pos=0)
+             loss = F.cross_entropy(
+                 logits.view(-1, logits.size(-1)),
+                 target_ids.view(-1),
+                 ignore_index=-1,
+             )
+
+             total_loss += loss.item() * target_ids.numel()
+             total_tokens += target_ids.numel()
+
+     model.train()
+     return total_loss / total_tokens
+
+
+ def save_checkpoint(model, optimizer, step, config, expert_idx=None):
+     """Save a model checkpoint"""
+     save_dir = Path(config["training"]["save_dir"])
+     save_dir.mkdir(parents=True, exist_ok=True)
+
+     # Create the checkpoint name
+     if expert_idx is not None:
+         ckpt_name = f"step_{step}_expert_{expert_idx}.pt"
+     else:
+         ckpt_name = f"step_{step}.pt"
+
+     ckpt_path = save_dir / ckpt_name
+
+     checkpoint = {
+         "step": step,
+         "model_state_dict": model.state_dict(),
+         "optimizer_state_dict": optimizer.state_dict(),
+         "config": config,
+     }
+
+     torch.save(checkpoint, ckpt_path)
+     print(f"💾 Checkpoint saved: {ckpt_path}")
+
+
+ def train_step(model, batch, device, config):
+     """Single training step"""
+     input_ids, target_ids = batch
+     input_ids = input_ids.to(device, non_blocking=True)
+     target_ids = target_ids.to(device, non_blocking=True)
+
+     # Forward pass (autocast must be told to use bfloat16; the CUDA default is float16)
+     with torch.autocast(device_type="cuda", dtype=torch.bfloat16,
+                         enabled=(config["training"]["dtype"] == "bf16")):
+         logits, lb_loss = model(input_ids, start_pos=0)
+
+         # Main language-modeling loss
+         lm_loss = F.cross_entropy(
+             logits.view(-1, logits.size(-1)),
+             target_ids.view(-1),
+             ignore_index=-1,
+         )
+
+         # Guard against layers that return no auxiliary loss
+         if lb_loss is None:
+             lb_loss = logits.new_zeros(())
+
+         # Total loss with load balancing
+         total_loss = lm_loss + config["training"].get("lb_loss_coef", 0.01) * lb_loss
+
+     return total_loss, lm_loss, lb_loss
+
+
+ def main():
+     args = parse_args()
+     config = load_config(args)
+
+     # Device setup
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     torch.backends.cuda.matmul.allow_tf32 = True
+     torch.backends.cudnn.allow_tf32 = True
+
+     # Wandb setup
+     if config["logging"]["use_wandb"] and HAS_WANDB:
+         wandb.init(
+             project=config["logging"]["project_name"],
+             name=config["logging"]["run_name"],
+             config=config,
+         )
+
+     # Model setup
+     model, model_args = setup_model(config, device)
+
+     # Optimizer setup
+     optimizer = setup_optimizer(model, config)
+
+     # Data setup
+     train_loader, val_loader = load_data(config)
+     train_iter = iter(train_loader)
+
+     # Training state
+     step = 0
+     best_val_loss = float("inf")
+
+     # Resume from checkpoint
+     if args.resume:
+         ckpt = torch.load(args.resume, map_location=device)
+         model.load_state_dict(ckpt["model_state_dict"])
+         optimizer.load_state_dict(ckpt["optimizer_state_dict"])
+         step = ckpt["step"]
+         print(f"✅ Resumed from step {step}\n")
+
+     # BF16 keeps FP32's exponent range, so loss scaling is unnecessary;
+     # the scaler stays as a disabled pass-through for the step/update plumbing
+     scaler = torch.cuda.amp.GradScaler(enabled=False)
+
+     # Expert rotation schedule
+     current_expert = 0
+     rotation_steps = config["training"]["expert_rotation_steps"]
+
+     # Set the initial expert
+     model.set_active_expert(current_expert)
+     print(f"🎯 Training expert {current_expert}/{model_args.n_routed_experts - 1}")
+
+     # Training loop
+     print("\n" + "="*70)
+     print("TRAINING STARTED")
+     print("="*70 + "\n")
+
+     model.train()
+
+     while step < config["training"]["total_steps"]:
+         step_start = time.time()
+
+         # Expert rotation
+         if step > 0 and step % rotation_steps == 0:
+             current_expert = (current_expert + 1) % model_args.n_routed_experts
+             model.set_active_expert(current_expert)
+             print(f"\n🔄 Rotating to expert {current_expert}/{model_args.n_routed_experts - 1}")
+
+             # Clear gradients after rotation
+             optimizer.zero_grad(set_to_none=True)
+
+         # Get a batch, cycling the loader when it is exhausted
+         try:
+             batch = next(train_iter)
+         except StopIteration:
+             train_iter = iter(train_loader)
+             batch = next(train_iter)
+
+         # Training step with gradient accumulation
+         accum_steps = config["training"]["gradient_accumulation_steps"]
+         total_loss_accum = 0.0
+         lm_loss_accum = 0.0
+         lb_loss_accum = 0.0
+
+         for accum_step in range(accum_steps):
+             # Split the batch for micro-batching (if needed);
+             # for now, the full batch is reused on every micro-step
+             loss, lm_loss, lb_loss = train_step(model, batch, device, config)
+
+             # Normalize for accumulation
+             loss = loss / accum_steps
+
+             # Backward pass (the disabled scaler is a no-op passthrough)
+             scaler.scale(loss).backward()
+
+             total_loss_accum += loss.item()
+             lm_loss_accum += lm_loss.item() / accum_steps
+             lb_loss_accum += lb_loss.item() / accum_steps
+
+         # Gradient clipping
+         if config["training"]["grad_clip"] > 0:
+             scaler.unscale_(optimizer)
+             torch.nn.utils.clip_grad_norm_(model.parameters(), config["training"]["grad_clip"])
+
+         # Optimizer step
+         scaler.step(optimizer)
+         scaler.update()
+
+         optimizer.zero_grad(set_to_none=True)
+
+         # LR scheduling
+         lr = get_lr(step, config)
+         for param_group in optimizer.param_groups:
+             param_group["lr"] = lr
+
+         # Logging
+         if step % config["training"]["log_every"] == 0:
+             step_time = time.time() - step_start
+             tokens_per_sec = (model_args.max_batch_size * model_args.max_seq_len) / step_time
+
+             print(f"Step {step:6d} | "
+                   f"Loss: {lm_loss_accum:.4f} | "
+                   f"LB Loss: {lb_loss_accum:.4f} | "
+                   f"LR: {lr:.2e} | "
+                   f"Expert: {current_expert} | "
+                   f"Tokens/s: {tokens_per_sec:.0f}")
+
+             if config["logging"]["use_wandb"] and HAS_WANDB:
+                 wandb.log({
+                     "step": step,
+                     "loss": lm_loss_accum,
+                     "load_balance_loss": lb_loss_accum,
+                     "total_loss": total_loss_accum,
+                     "learning_rate": lr,
+                     "active_expert": current_expert,
+                     "tokens_per_sec": tokens_per_sec,
+                     "gpu_memory_gb": torch.cuda.memory_allocated() / 1024**3,
+                 })
+
+         # Evaluation
+         if step % config["training"]["eval_every"] == 0 and step > 0:
+             print(f"\n📊 Evaluating at step {step}...")
+             val_loss = evaluate(model, val_loader, device, config)
+             print(f"Val Loss: {val_loss:.4f} | Perplexity: {math.exp(val_loss):.2f}\n")
+
+             if config["logging"]["use_wandb"] and HAS_WANDB:
+                 wandb.log({"val_loss": val_loss, "val_perplexity": math.exp(val_loss)})
+
+             # Save the best model
+             if val_loss < best_val_loss:
+                 best_val_loss = val_loss
+                 save_checkpoint(model, optimizer, step, config, expert_idx="best")
+
+         # Save checkpoint
+         if step % config["training"]["save_every"] == 0 and step > 0:
+             save_checkpoint(model, optimizer, step, config, expert_idx=current_expert)
+
+         step += 1
+
+     # Final save
+     save_checkpoint(model, optimizer, step, config, expert_idx="final")
+
+     if config["logging"]["use_wandb"] and HAS_WANDB:
+         wandb.finish()
+
+     print("\n" + "="*70)
+     print("TRAINING COMPLETED")
+     print("="*70)
+
+
+ if __name__ == "__main__":
+     main()
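
With the defaults above, each optimizer step sees max_batch_size × max_seq_len × gradient_accumulation_steps = 4 × 2048 × 16 = 131,072 tokens (the same batch is reused across micro-steps until micro-batching is added, as the loop comment notes). The schedule in get_lr can be sanity-checked in isolation (an editor's sketch with the default hyperparameters inlined):

    import math

    def lr_at(step, base_lr=3e-4, warmup=1000, total=50000):
        # mirrors train.py's get_lr: linear warmup, then cosine decay to zero
        if step < warmup:
            return base_lr * step / warmup
        progress = (step - warmup) / (total - warmup)
        return base_lr * 0.5 * (1 + math.cos(math.pi * progress))

    print(lr_at(500))    # 1.5e-04, halfway through warmup
    print(lr_at(1000))   # 3.0e-04, peak at the end of warmup
    print(lr_at(50000))  # ~0.0, fully decayed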