Commit ·
d7d2fb2
1
Parent(s): 9dd056d
Things got Messy
Browse files- Model_Architecture/data.py +188 -0
- Model_Architecture/generation.py +2 -2
- Model_Architecture/model.py +106 -110
- Model_Architecture/model_size.py +226 -0
- Model_Architecture/train.py +483 -0
Model_Architecture/data.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tiktoken
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data import Dataset, DataLoader
|
| 4 |
+
from typing import Tuple, Optional, Literal, List
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import mmap
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from model import ModelArgs
|
| 11 |
+
|
| 12 |
+
#####################################
|
| 13 |
+
# DATA
|
| 14 |
+
#####################################
|
| 15 |
+
class TextDataset(Dataset):
    """Sliding-window language-modeling dataset over a tokenized text.

    Accepts either raw text or a path to a UTF-8 text file. The text is
    tokenized once, then chunked into overlapping (input, target) windows
    where each target sequence is the input shifted right by one token.
    """

    def __init__(self, txt: str, tokenizer, args: "ModelArgs",
                 stride: Optional[int] = None, max_samples: Optional[int] = None):
        """
        Args:
            txt: Text content or path to a text file.
            tokenizer: Tokenizer exposing .encode() (tiktoken-style).
            args: ModelArgs providing max_seq_len.
            stride: Sliding-window stride. Defaults to max_seq_len // 2.
            max_samples: Optional cap on number of windows (quick testing).

        Raises:
            ValueError: if the text is too short to form one window.
        """
        self.max_seq_len = args.max_seq_len
        self.stride = stride if stride is not None else self.max_seq_len // 2

        # Treat `txt` as a file path only when it plausibly is one.
        # Probing Path(<very long raw text>) can raise OSError ("file name
        # too long") on some platforms, so guard the existence check.
        text_content = txt
        if len(txt) < 4096 and "\n" not in txt:
            try:
                if Path(txt).exists():
                    text_content = self._read_file_mmap(txt)
            except OSError:
                pass  # not a valid path -> treat as raw text

        # Validate input (character count is a cheap lower bound on tokens)
        if not text_content or len(text_content.strip()) < self.max_seq_len:
            raise ValueError(f"Text too short. Need at least {self.max_seq_len} chars, got {len(text_content)}")

        print(f"📝 Tokenizing {len(text_content):,} characters...")

        # Tokenize (chunked, with a progress bar for large inputs)
        token_ids = self._tokenize_with_progress(tokenizer, text_content)

        # Build overlapping (input, target) windows; stored once, here.
        self.samples = self._create_sliding_windows(token_ids, max_samples)

        print(f"✅ Created {len(self.samples)} training samples")

    def _read_file_mmap(self, file_path: str) -> str:
        """Memory-efficient file reading for large files."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
                    return mm.read().decode('utf-8', errors='ignore')
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Failed to read file {file_path}: {e}") from e

    def _tokenize_with_progress(self, tokenizer, text: str) -> List[int]:
        """Tokenize `text`, showing a progress bar for very large inputs."""
        chunk_size = 10_000_000  # process 10M-char chunks to bound memory
        tokens: List[int] = []

        if len(text) > chunk_size:
            pbar = tqdm(total=len(text), desc="Tokenizing", unit="char")
            for i in range(0, len(text), chunk_size):
                chunk = text[i:i + chunk_size]
                tokens.extend(tokenizer.encode(chunk, allowed_special={"<|endoftext|>"}))
                pbar.update(len(chunk))
            pbar.close()
        else:
            tokens = tokenizer.encode(text, allowed_special={"<|endoftext|>"})

        if not tokens:
            raise ValueError("No tokens generated from input text")

        return tokens

    def _create_sliding_windows(self, token_ids: List[int], max_samples: Optional[int]) -> torch.Tensor:
        """Return a (num_windows, 2, max_seq_len) tensor of (input, target) pairs.

        Windows start every `stride` tokens; targets are the same window
        shifted right by one token.
        """
        if len(token_ids) < self.max_seq_len + 1:
            raise ValueError(f"Not enough tokens. Need {self.max_seq_len + 1}, got {len(token_ids)}")

        tokens = torch.tensor(token_ids, dtype=torch.long)

        num_windows = (len(tokens) - self.max_seq_len - 1) // self.stride + 1
        if max_samples:
            num_windows = min(num_windows, max_samples)

        # Vectorized gather: idx[i, j] = i * stride + j
        starts = torch.arange(num_windows, dtype=torch.long) * self.stride
        offsets = torch.arange(self.max_seq_len, dtype=torch.long)
        idx = starts[:, None] + offsets[None, :]
        inputs = tokens[idx]
        targets = tokens[idx + 1]

        # Pair inputs with their shifted targets along dim 1.
        return torch.stack([inputs, targets], dim=1)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (input_ids, target_ids) tuple."""
        return self.samples[idx, 0], self.samples[idx, 1]
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def create_dataloader(
    txt: str,
    args: "ModelArgs",
    stride: Optional[int] = None,
    shuffle: bool = True,
    drop_last: bool = True,
    num_workers: int = 0,
    pin_memory: bool = True,
    persistent_workers: bool = False,
    max_samples: Optional[int] = None
) -> DataLoader:
    """Build a DataLoader over sliding-window language-modeling samples.

    Args:
        txt: Text content or file path.
        args: ModelArgs configuration (max_batch_size, max_seq_len).
        stride: Sliding window stride.
        shuffle: Whether to shuffle samples.
        drop_last: Drop incomplete batches.
        num_workers: Number of data loading workers (0 = main process).
        pin_memory: Pin memory for faster GPU transfer (recommended).
        persistent_workers: Keep workers alive between epochs (only
            honored when num_workers > 0).
        max_samples: Limit samples for testing.

    Raises:
        RuntimeError: if the dataset cannot be constructed; the original
            exception is chained as __cause__.
    """
    # tiktoken's "gpt2" encoding is fast and has a reasonable vocab (~50k).
    # For multilingual or code, set args.tokenizer_name to e.g.
    # "cl100k_base" or "o200k_base".
    tokenizer_name = getattr(args, "tokenizer_name", "gpt2")
    tokenizer = tiktoken.get_encoding(tokenizer_name)

    try:
        dataset = TextDataset(
            txt=txt,
            tokenizer=tokenizer,
            args=args,
            stride=stride,
            max_samples=max_samples
        )
    except Exception as e:
        # `from e` preserves the original traceback for debugging.
        raise RuntimeError(f"Failed to create dataset: {e}") from e

    # persistent_workers / prefetch_factor are only valid with workers > 0.
    dataloader = DataLoader(
        dataset,
        batch_size=args.max_batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers if num_workers > 0 else False,
        prefetch_factor=2 if num_workers > 0 else None,
    )

    return dataloader
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# Convenience function for downloading sample data
def get_sample_data(url: str = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt") -> str:
    """Fetch sample text (tiny-shakespeare by default) for quick tests.

    Best-effort helper: returns the downloaded text, or an empty string
    if the download fails for any reason.
    """
    try:
        import requests

        resp = requests.get(url)
        resp.raise_for_status()
        return resp.text
    except Exception as e:
        print(f"⚠️ Could not download sample data: {e}")
        return ""
|
Model_Architecture/generation.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import torch
|
| 2 |
import tiktoken
|
| 3 |
-
from model import
|
| 4 |
|
| 5 |
|
| 6 |
#####################################
|
|
@@ -151,7 +151,7 @@ if __name__ == "__main__":
|
|
| 151 |
# Initialize model and tokenizer
|
| 152 |
print("Initializing model...")
|
| 153 |
torch.manual_seed(123)
|
| 154 |
-
model =
|
| 155 |
model.eval()
|
| 156 |
|
| 157 |
tokenizer = tiktoken.get_encoding("gpt2")
|
|
|
|
| 1 |
import torch
|
| 2 |
import tiktoken
|
| 3 |
+
from model import ismail, ModelArgs
|
| 4 |
|
| 5 |
|
| 6 |
#####################################
|
|
|
|
| 151 |
# Initialize model and tokenizer
|
| 152 |
print("Initializing model...")
|
| 153 |
torch.manual_seed(123)
|
| 154 |
+
model = ismail(args)
|
| 155 |
model.eval()
|
| 156 |
|
| 157 |
tokenizer = tiktoken.get_encoding("gpt2")
|
Model_Architecture/model.py
CHANGED
|
@@ -2,7 +2,7 @@ import tiktoken
|
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
from torch.utils.data import Dataset, DataLoader
|
| 5 |
-
|
| 6 |
import math
|
| 7 |
from dataclasses import dataclass
|
| 8 |
from typing import Tuple, Optional, Literal
|
|
@@ -18,31 +18,32 @@ from kernel import act_quant, weight_dequant, fp8_gemm
|
|
| 18 |
@dataclass
|
| 19 |
class ModelArgs:
|
| 20 |
max_batch_size: int = 8
|
| 21 |
-
max_seq_len: int =
|
| 22 |
dtype: Literal["bf16", "fp8"] = "bf16"
|
| 23 |
scale_fmt: Optional[str] = None
|
|
|
|
| 24 |
vocab_size: int = 102400
|
| 25 |
-
dim: int =
|
| 26 |
-
inter_dim: int =
|
| 27 |
-
moe_inter_dim: int =
|
| 28 |
-
n_layers: int =
|
| 29 |
-
n_dense_layers: int =
|
| 30 |
-
n_heads: int =
|
|
|
|
| 31 |
# moe
|
| 32 |
-
n_routed_experts: int =
|
| 33 |
-
n_shared_experts: int =
|
| 34 |
-
n_activated_experts: int =
|
| 35 |
-
n_expert_groups: int = 1
|
| 36 |
-
n_limited_groups: int = 1
|
| 37 |
-
score_func: Literal["softmax", "sigmoid"] = "softmax"
|
| 38 |
route_scale: float = 1.
|
| 39 |
-
use_routing_bias: bool =
|
|
|
|
| 40 |
# mla
|
| 41 |
q_lora_rank: int = 0
|
| 42 |
kv_lora_rank: int = 512
|
| 43 |
qk_nope_head_dim: int = 128
|
| 44 |
qk_rope_head_dim: int = 64
|
| 45 |
v_head_dim: int = 128
|
|
|
|
| 46 |
# yarn
|
| 47 |
original_seq_len: int = 4096
|
| 48 |
rope_theta: float = 10000.0
|
|
@@ -58,54 +59,7 @@ block_size = 128
|
|
| 58 |
gemm_impl: Literal["bf16", "fp8"] = "bf16"
|
| 59 |
|
| 60 |
|
| 61 |
-
#####################################
|
| 62 |
-
# DATA
|
| 63 |
-
#####################################
|
| 64 |
-
class TextDataset(Dataset):
|
| 65 |
-
def __init__(self, txt, tokenizer, args: ModelArgs, stride: Optional[int] = None):
|
| 66 |
-
self.input_ids = []
|
| 67 |
-
self.target_ids = []
|
| 68 |
-
|
| 69 |
-
# Use max_seq_len from ModelArgs
|
| 70 |
-
max_length = args.max_seq_len
|
| 71 |
-
if stride is None:
|
| 72 |
-
stride = max_length // 2 # Default stride is half the sequence length
|
| 73 |
-
|
| 74 |
-
# Tokenize the entire text
|
| 75 |
-
token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})
|
| 76 |
-
|
| 77 |
-
# Use a sliding window to chunk the book into overlapping sequences of max_length
|
| 78 |
-
for i in range(0, len(token_ids) - max_length, stride):
|
| 79 |
-
input_chunk = token_ids[i:i + max_length]
|
| 80 |
-
target_chunk = token_ids[i + 1: i + max_length + 1]
|
| 81 |
-
self.input_ids.append(torch.tensor(input_chunk))
|
| 82 |
-
self.target_ids.append(torch.tensor(target_chunk))
|
| 83 |
|
| 84 |
-
def __len__(self):
|
| 85 |
-
return len(self.input_ids)
|
| 86 |
-
|
| 87 |
-
def __getitem__(self, idx):
|
| 88 |
-
return self.input_ids[idx], self.target_ids[idx]
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
def create_dataloader(txt, args: ModelArgs, stride: Optional[int] = None,
|
| 92 |
-
shuffle: bool = True, drop_last: bool = True, num_workers: int = 0):
|
| 93 |
-
# Initialize the tokenizer
|
| 94 |
-
tokenizer = tiktoken.get_encoding("gpt2")
|
| 95 |
-
|
| 96 |
-
# Create dataset with ModelArgs
|
| 97 |
-
dataset = TextDataset(txt, tokenizer, args, stride)
|
| 98 |
-
|
| 99 |
-
# Create dataloader using batch_size from ModelArgs
|
| 100 |
-
dataloader = DataLoader(
|
| 101 |
-
dataset,
|
| 102 |
-
batch_size=args.max_batch_size,
|
| 103 |
-
shuffle=shuffle,
|
| 104 |
-
drop_last=drop_last,
|
| 105 |
-
num_workers=num_workers
|
| 106 |
-
)
|
| 107 |
-
|
| 108 |
-
return dataloader
|
| 109 |
|
| 110 |
#####################################
|
| 111 |
# RoPE
|
|
@@ -321,9 +275,6 @@ class Gate(nn.Module):
|
|
| 321 |
self.dim = args.dim
|
| 322 |
self.n_routed_experts = args.n_routed_experts
|
| 323 |
self.n_activated_experts = args.n_activated_experts
|
| 324 |
-
self.n_expert_groups = args.n_expert_groups
|
| 325 |
-
self.n_limited_groups = args.n_limited_groups
|
| 326 |
-
self.score_func = args.score_func
|
| 327 |
self.route_scale = args.route_scale
|
| 328 |
|
| 329 |
# Gate weight
|
|
@@ -341,10 +292,7 @@ class Gate(nn.Module):
|
|
| 341 |
scores = linear(x, self.weight)
|
| 342 |
|
| 343 |
# Apply scoring function
|
| 344 |
-
|
| 345 |
-
scores = scores.softmax(dim=-1, dtype=torch.float32)
|
| 346 |
-
else:
|
| 347 |
-
scores = scores.sigmoid()
|
| 348 |
|
| 349 |
original_scores = scores
|
| 350 |
|
|
@@ -352,17 +300,6 @@ class Gate(nn.Module):
|
|
| 352 |
if self.bias is not None:
|
| 353 |
scores = scores + self.bias
|
| 354 |
|
| 355 |
-
# Expert grouping for load balancing
|
| 356 |
-
if self.n_expert_groups > 1:
|
| 357 |
-
scores = scores.view(x.size(0), self.n_expert_groups, -1)
|
| 358 |
-
if self.bias is None:
|
| 359 |
-
group_scores = scores.amax(dim=-1)
|
| 360 |
-
else:
|
| 361 |
-
group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
|
| 362 |
-
indices = group_scores.topk(self.n_limited_groups, dim=-1)[1]
|
| 363 |
-
mask = scores.new_ones(x.size(0), self.n_expert_groups, dtype=bool).scatter_(1, indices, False)
|
| 364 |
-
scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
|
| 365 |
-
|
| 366 |
# Select top-k experts
|
| 367 |
indices = torch.topk(scores, self.n_activated_experts, dim=-1)[1]
|
| 368 |
weights = original_scores.gather(1, indices)
|
|
@@ -391,56 +328,115 @@ class Expert(nn.Module):
|
|
| 391 |
|
| 392 |
|
| 393 |
class MoE(nn.Module):
|
| 394 |
-
|
| 395 |
def __init__(self, args: ModelArgs):
|
| 396 |
super().__init__()
|
| 397 |
self.dim = args.dim
|
| 398 |
self.n_routed_experts = args.n_routed_experts
|
| 399 |
self.n_activated_experts = args.n_activated_experts
|
| 400 |
-
|
| 401 |
-
|
| 402 |
self.gate = Gate(args)
|
| 403 |
-
|
| 404 |
-
# Routed experts
|
| 405 |
self.experts = nn.ModuleList([
|
| 406 |
Expert(args.dim, args.moe_inter_dim)
|
| 407 |
for _ in range(args.n_routed_experts)
|
| 408 |
])
|
| 409 |
-
|
| 410 |
-
# Shared experts (always process all tokens)
|
| 411 |
self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
original_shape = x.size()
|
| 416 |
x = x.view(-1, self.dim)
|
| 417 |
-
|
| 418 |
-
#
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
y = torch.zeros_like(x)
|
| 423 |
-
|
| 424 |
-
# Process each routed expert
|
| 425 |
for i in range(self.n_routed_experts):
|
| 426 |
-
# Find tokens routed to this expert
|
| 427 |
idx, top = torch.where(indices == i)
|
| 428 |
if idx.numel() == 0:
|
| 429 |
continue
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
# Weight and accumulate expert outputs
|
| 435 |
-
y[idx] += expert_output * weights[idx, top, None]
|
| 436 |
-
|
| 437 |
-
# Process all tokens with shared experts
|
| 438 |
z = self.shared_experts(x)
|
| 439 |
-
|
| 440 |
-
# Combine routed and shared expert outputs
|
| 441 |
output = (y + z).view(original_shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 442 |
|
| 443 |
-
return output
|
| 444 |
|
| 445 |
|
| 446 |
#####################################
|
|
@@ -482,7 +478,7 @@ class Block(nn.Module):
|
|
| 482 |
# TRANSFORMER MODEL
|
| 483 |
#####################################
|
| 484 |
|
| 485 |
-
class
|
| 486 |
def __init__(self, args: ModelArgs):
|
| 487 |
super().__init__()
|
| 488 |
self.args = args
|
|
|
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
from torch.utils.data import Dataset, DataLoader
|
| 5 |
+
from contextlib import nullcontext
|
| 6 |
import math
|
| 7 |
from dataclasses import dataclass
|
| 8 |
from typing import Tuple, Optional, Literal
|
|
|
|
| 18 |
@dataclass
|
| 19 |
class ModelArgs:
|
| 20 |
max_batch_size: int = 8
|
| 21 |
+
max_seq_len: int = 2048
|
| 22 |
dtype: Literal["bf16", "fp8"] = "bf16"
|
| 23 |
scale_fmt: Optional[str] = None
|
| 24 |
+
|
| 25 |
vocab_size: int = 102400
|
| 26 |
+
dim: int = 1024
|
| 27 |
+
inter_dim: int = 4096
|
| 28 |
+
moe_inter_dim: int = 1024
|
| 29 |
+
n_layers: int = 20
|
| 30 |
+
n_dense_layers: int = 3
|
| 31 |
+
n_heads: int = 12
|
| 32 |
+
|
| 33 |
# moe
|
| 34 |
+
n_routed_experts: int = 6
|
| 35 |
+
n_shared_experts: int = 1
|
| 36 |
+
n_activated_experts: int = 2
|
|
|
|
|
|
|
|
|
|
| 37 |
route_scale: float = 1.
|
| 38 |
+
use_routing_bias: bool = True # Enable routing bias for fine-tuning expert selection
|
| 39 |
+
|
| 40 |
# mla
|
| 41 |
q_lora_rank: int = 0
|
| 42 |
kv_lora_rank: int = 512
|
| 43 |
qk_nope_head_dim: int = 128
|
| 44 |
qk_rope_head_dim: int = 64
|
| 45 |
v_head_dim: int = 128
|
| 46 |
+
|
| 47 |
# yarn
|
| 48 |
original_seq_len: int = 4096
|
| 49 |
rope_theta: float = 10000.0
|
|
|
|
| 59 |
gemm_impl: Literal["bf16", "fp8"] = "bf16"
|
| 60 |
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
#####################################
|
| 65 |
# RoPE
|
|
|
|
| 275 |
self.dim = args.dim
|
| 276 |
self.n_routed_experts = args.n_routed_experts
|
| 277 |
self.n_activated_experts = args.n_activated_experts
|
|
|
|
|
|
|
|
|
|
| 278 |
self.route_scale = args.route_scale
|
| 279 |
|
| 280 |
# Gate weight
|
|
|
|
| 292 |
scores = linear(x, self.weight)
|
| 293 |
|
| 294 |
# Apply scoring function
|
| 295 |
+
scores = scores.sigmoid()
|
|
|
|
|
|
|
|
|
|
| 296 |
|
| 297 |
original_scores = scores
|
| 298 |
|
|
|
|
| 300 |
if self.bias is not None:
|
| 301 |
scores = scores + self.bias
|
| 302 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
# Select top-k experts
|
| 304 |
indices = torch.topk(scores, self.n_activated_experts, dim=-1)[1]
|
| 305 |
weights = original_scores.gather(1, indices)
|
|
|
|
| 328 |
|
| 329 |
|
| 330 |
class MoE(nn.Module):
    """Mixture-of-Experts layer with an optional sequential-training mode.

    forward() returns (output, lb_loss); lb_loss is None in eval mode.
    When `active_expert_idx` is set during training, only that routed
    expert accumulates gradients (the others run under no_grad), which
    keeps optimizer state small.
    """

    def __init__(self, args: "ModelArgs"):
        super().__init__()
        self.dim = args.dim
        self.n_routed_experts = args.n_routed_experts
        self.n_activated_experts = args.n_activated_experts
        self.active_expert_idx = None  # None = all active (inference mode)

        self.gate = Gate(args)
        self.experts = nn.ModuleList([
            Expert(args.dim, args.moe_inter_dim)
            for _ in range(args.n_routed_experts)
        ])
        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
        self.ffn_norm = RMSNorm(args.dim)

        # Load balance loss coefficient
        self.lb_loss_coef = 0.01

    def set_active_expert(self, expert_idx: Optional[int]):
        """Freeze all but the active expert to save optimizer memory."""
        self.active_expert_idx = expert_idx

        for i, expert in enumerate(self.experts):
            requires_grad = (expert_idx is None) or (i == expert_idx)
            for param in expert.parameters():
                param.requires_grad = requires_grad

    def compute_load_balance_loss(self, router_probs, expert_indices):
        """Auxiliary loss encouraging uniform expert utilization.

        Args:
            router_probs: [num_tokens, n_experts] routing probabilities.
            expert_indices: [num_tokens, top_k] selected expert ids.
        """
        # Fraction of routed tokens landing on each expert
        tokens_per_expert = torch.zeros(self.n_routed_experts, device=router_probs.device)
        indices_flat = expert_indices.view(-1)
        ones = torch.ones_like(indices_flat, dtype=torch.float32)
        tokens_per_expert.scatter_add_(0, indices_flat, ones)
        tokens_per_expert = tokens_per_expert / (indices_flat.numel() + 1e-8)

        # Average routing probability per expert
        router_prob_per_expert = router_probs.mean(dim=0)

        # Load balancing loss (minimize correlation between the two)
        loss = torch.mean(tokens_per_expert * router_prob_per_expert) * self.n_routed_experts
        return loss

    def _run_routed(self, x: torch.Tensor, weights: torch.Tensor,
                    indices: torch.Tensor, grad_expert: Optional[int] = None) -> torch.Tensor:
        """Dispatch tokens to their top-k experts and sum weighted outputs.

        If grad_expert is set, all other experts run under torch.no_grad()
        (sequential-training mode).
        """
        y = torch.zeros_like(x)
        for i in range(self.n_routed_experts):
            idx, top = torch.where(indices == i)
            if idx.numel() == 0:
                continue
            ctx = nullcontext() if (grad_expert is None or i == grad_expert) else torch.no_grad()
            with ctx:
                expert_out = self.experts[i](x[idx])
                y[idx] += expert_out * weights[idx, top, None]
        return y

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        original_shape = x.size()
        x = x.view(-1, self.dim)

        # Routing. Use the module-level `linear` helper (as Gate.forward
        # does): torch.nn.functional is not imported as `F` in this file.
        router_logits = linear(x, self.gate.weight)
        router_probs = router_logits.sigmoid()

        # Bias steers *selection* only; combination weights come from the
        # unbiased probabilities (bias-based load balancing). The previous
        # code added the bias after selection, so it had no effect.
        selection_scores = router_probs
        if self.gate.bias is not None:
            selection_scores = router_probs + self.gate.bias

        # Select top-k experts by (possibly biased) score, weight by prob.
        _, indices = torch.topk(selection_scores, self.n_activated_experts, dim=-1)
        weights = router_probs.gather(1, indices)

        # Normalize sigmoid scores over the selected experts.
        # NOTE(review): score_func was removed from ModelArgs/Gate in this
        # revision; default to "sigmoid" to avoid an AttributeError.
        if getattr(self.gate, "score_func", "sigmoid") == "sigmoid":
            weights = weights / weights.sum(dim=-1, keepdim=True)
        weights = weights * self.gate.route_scale

        # Sequential Training Mode: gradients only for the active expert
        if self.training and self.active_expert_idx is not None:
            y = self._run_routed(x, weights, indices, grad_expert=self.active_expert_idx)

            # Load balance loss (still needed for gate training)
            lb_loss = self.compute_load_balance_loss(router_probs, indices)

            # Shared experts always train
            z = self.shared_experts(x)

            return (y + z).view(original_shape), lb_loss

        # Normal MoE Mode (inference or full training)
        y = self._run_routed(x, weights, indices)
        z = self.shared_experts(x)
        output = (y + z).view(original_shape)

        if self.training:
            lb_loss = self.compute_load_balance_loss(router_probs, indices)
            return output, lb_loss
        return output, None
|
| 439 |
|
|
|
|
| 440 |
|
| 441 |
|
| 442 |
#####################################
|
|
|
|
| 478 |
# TRANSFORMER MODEL
|
| 479 |
#####################################
|
| 480 |
|
| 481 |
+
class ismail(nn.Module):
|
| 482 |
def __init__(self, args: ModelArgs):
|
| 483 |
super().__init__()
|
| 484 |
self.args = args
|
Model_Architecture/model_size.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
# Add the Model_Architecture directory to path
|
| 5 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 6 |
+
|
| 7 |
+
from model import ModelArgs
|
| 8 |
+
|
| 9 |
+
def estimate_model_size(args: ModelArgs):
    """Calculate detailed model size and parameter count.

    Prints a component-by-component parameter breakdown (embeddings, MLA
    attention, dense FFN, MoE FFN) and rough memory estimates for both
    inference (BF16 weights) and training (FP32 weights + Adam states).

    Args:
        args: architecture hyper-parameters; must expose dim, vocab_size,
            n_layers, n_dense_layers, n_heads, inter_dim, moe_inter_dim,
            the MoE expert counts and the MLA LoRA ranks / head dims.

    Returns:
        dict with 'total_params', 'weight_memory_gb',
        'inference_memory_gb' and 'training_memory_gb'.
    """

    def _report_gpu_fit(label: str, required_gb: float) -> None:
        # Print the smallest common GPU size the given memory budget fits in.
        # (The original duplicated this threshold loop for inference/training.)
        for threshold in (8, 16, 24, 32, 40, 48, 80):
            if required_gb <= threshold:
                print(f" ✅ {label} fits in {threshold}GB GPU")
                break
        else:
            print(f" ❌ {label} requires >80GB GPU")

    print(f"\n{'='*70}")
    print(f"MODEL ARCHITECTURE ANALYSIS: ismail")
    print(f"{'='*70}\n")

    # Display configuration
    print(f"📋 CONFIGURATION:")
    print(f" Model dimension (dim): {args.dim}")
    print(f" Vocabulary size: {args.vocab_size:,}")
    print(f" Number of layers: {args.n_layers}")
    print(f" Dense layers: {args.n_dense_layers}")
    print(f" MoE layers: {args.n_layers - args.n_dense_layers}")
    print(f" Attention heads: {args.n_heads}")
    print(f" Max sequence length: {args.max_seq_len}")
    print(f" Max batch size: {args.max_batch_size}")
    print(f" \nMoE Configuration:")
    print(f" Routed experts: {args.n_routed_experts}")
    print(f" Shared experts: {args.n_shared_experts}")
    print(f" Activated experts: {args.n_activated_experts}")
    print(f" \nMLA Configuration:")
    print(f" Q LoRA rank: {args.q_lora_rank}")
    print(f" KV LoRA rank: {args.kv_lora_rank}")
    print(f" QK nope head dim: {args.qk_nope_head_dim}")
    print(f" QK rope head dim: {args.qk_rope_head_dim}")
    print(f" V head dim: {args.v_head_dim}")

    # Calculate parameters by component
    print(f"\n{'='*70}")
    print(f"🔢 PARAMETER COUNT BY COMPONENT:")
    print(f"{'='*70}\n")

    # 1. Embeddings (input table + untied output projection, same size each)
    tok_embed_params = args.vocab_size * args.dim
    output_params = args.vocab_size * args.dim
    total_embed_params = tok_embed_params + output_params
    print(f" Token Embeddings: {tok_embed_params:>15,} params")
    print(f" Output Layer: {output_params:>15,} params")
    print(f" {'─' * 50}")
    print(f" Total Embeddings: {total_embed_params:>15,} params\n")

    # 2. MLA attention (per layer). q_lora_rank == 0 means a plain
    # full-rank query projection with no extra Q norm.
    if args.q_lora_rank == 0:
        wq_params = args.dim * args.n_heads * (args.qk_nope_head_dim + args.qk_rope_head_dim)
        wq_norm_params = 0
    else:
        wq_params = args.dim * args.q_lora_rank + args.q_lora_rank * args.n_heads * (args.qk_nope_head_dim + args.qk_rope_head_dim)
        wq_norm_params = args.q_lora_rank

    wkv_a_params = args.dim * (args.kv_lora_rank + args.qk_rope_head_dim)
    kv_norm_params = args.kv_lora_rank
    wkv_b_params = args.kv_lora_rank * args.n_heads * (args.qk_nope_head_dim + args.v_head_dim)
    wo_params = args.n_heads * args.v_head_dim * args.dim
    attn_norm_params = args.dim

    attn_params_per_layer = wq_params + wq_norm_params + wkv_a_params + kv_norm_params + wkv_b_params + wo_params + attn_norm_params

    print(f" Attention (per layer):")
    if args.q_lora_rank > 0:
        print(f" WQ (LoRA): {wq_params:>15,} params")
        print(f" Q Norm: {wq_norm_params:>15,} params")
    else:
        print(f" WQ: {wq_params:>15,} params")
    print(f" WKV_A: {wkv_a_params:>15,} params")
    print(f" KV Norm: {kv_norm_params:>15,} params")
    print(f" WKV_B: {wkv_b_params:>15,} params")
    print(f" WO: {wo_params:>15,} params")
    print(f" Attn Norm: {attn_norm_params:>15,} params")
    print(f" {'─' * 50}")
    print(f" Subtotal: {attn_params_per_layer:>15,} params\n")

    # 3. Dense (gated) FFN: three projections + pre-FFN norm
    dense_w1_params = args.dim * args.inter_dim
    dense_w2_params = args.inter_dim * args.dim
    dense_w3_params = args.dim * args.inter_dim
    ffn_norm_params = args.dim
    dense_ffn_per_layer = dense_w1_params + dense_w2_params + dense_w3_params + ffn_norm_params

    print(f" Dense FFN (per layer):")
    print(f" FC1 (W1): {dense_w1_params:>15,} params")
    print(f" FC2 (W3): {dense_w3_params:>15,} params")
    print(f" FC3 (W2): {dense_w2_params:>15,} params")
    print(f" FFN Norm: {ffn_norm_params:>15,} params")
    print(f" {'─' * 50}")
    print(f" Subtotal: {dense_ffn_per_layer:>15,} params\n")

    # 4. MoE FFN: router gate + routed experts + always-on shared experts
    gate_params = args.n_routed_experts * args.dim
    if args.use_routing_bias:
        gate_params += args.n_routed_experts

    expert_w1_params = args.dim * args.moe_inter_dim
    expert_w2_params = args.moe_inter_dim * args.dim
    expert_w3_params = args.dim * args.moe_inter_dim
    per_expert_params = expert_w1_params + expert_w2_params + expert_w3_params
    routed_experts_params = args.n_routed_experts * per_expert_params

    shared_w1_params = args.dim * (args.n_shared_experts * args.moe_inter_dim)
    shared_w2_params = (args.n_shared_experts * args.moe_inter_dim) * args.dim
    shared_w3_params = args.dim * (args.n_shared_experts * args.moe_inter_dim)
    shared_experts_params = shared_w1_params + shared_w2_params + shared_w3_params

    moe_ffn_per_layer = gate_params + routed_experts_params + shared_experts_params + ffn_norm_params

    print(f" MoE FFN (per layer):")
    print(f" Gate: {gate_params:>15,} params")
    print(f" Routed Experts ({args.n_routed_experts}x): {routed_experts_params:>15,} params")
    print(f" Per expert: {per_expert_params:>15,} params")
    print(f" Shared Experts: {shared_experts_params:>15,} params")
    print(f" FFN Norm: {ffn_norm_params:>15,} params")
    print(f" {'─' * 50}")
    print(f" Subtotal: {moe_ffn_per_layer:>15,} params\n")

    # 5. Final Norm
    final_norm_params = args.dim

    # Totals: the first n_dense_layers use the dense FFN, the rest use MoE.
    dense_layer_params = attn_params_per_layer + dense_ffn_per_layer
    moe_layer_params = attn_params_per_layer + moe_ffn_per_layer

    total_dense_params = args.n_dense_layers * dense_layer_params
    total_moe_params = (args.n_layers - args.n_dense_layers) * moe_layer_params

    total_params = total_embed_params + total_dense_params + total_moe_params + final_norm_params

    print(f" Layer Summary:")
    print(f" Dense layers ({args.n_dense_layers}x): {total_dense_params:>15,} params")
    print(f" MoE layers ({args.n_layers - args.n_dense_layers}x): {total_moe_params:>15,} params")
    print(f" Final Norm: {final_norm_params:>15,} params")

    print(f"\n{'='*70}")
    print(f"📊 TOTAL PARAMETERS: {total_params:>15,} ({total_params/1e6:.2f}M)")
    print(f"{'='*70}\n")

    # Memory calculations
    print(f"{'='*70}")
    print(f"💾 MEMORY USAGE:")
    print(f"{'='*70}\n")

    bytes_per_param_bf16 = 2
    bytes_per_param_fp32 = 4

    # Model weights
    weight_memory_bf16 = total_params * bytes_per_param_bf16 / (1024**3)
    weight_memory_fp32 = total_params * bytes_per_param_fp32 / (1024**3)

    print(f" Model Weights:")
    print(f" BF16 (inference): {weight_memory_bf16:>10.3f} GB")
    print(f" FP32 (training): {weight_memory_fp32:>10.3f} GB\n")

    # KV cache: MLA caches the compressed KV latent plus the rope part only.
    kv_cache_per_layer = args.max_batch_size * args.max_seq_len * (args.kv_lora_rank + args.qk_rope_head_dim)
    total_kv_cache = kv_cache_per_layer * args.n_layers * bytes_per_param_bf16 / (1024**3)

    print(f" KV Cache (BF16):")
    print(f" Per layer: {kv_cache_per_layer * bytes_per_param_bf16 / (1024**3):>10.3f} GB")
    print(f" Total ({args.n_layers} layers): {total_kv_cache:>10.3f} GB\n")

    # Activations (rough estimate: 4 bytes per dim per token per layer)
    activation_memory = (args.max_batch_size * args.max_seq_len * args.dim * args.n_layers * 4) / (1024**3)

    print(f" Activations (estimate): {activation_memory:>10.3f} GB\n")

    # Training overhead: FP32 gradients plus two Adam moment buffers
    gradients_memory = weight_memory_fp32  # Same size as weights
    optimizer_states = weight_memory_fp32 * 2  # Adam: 2x for momentum + variance
    training_overhead = gradients_memory + optimizer_states

    print(f" Training Overhead (FP32):")
    print(f" Gradients: {gradients_memory:>10.3f} GB")
    print(f" Optimizer states (Adam): {optimizer_states:>10.3f} GB")
    print(f" Total overhead: {training_overhead:>10.3f} GB\n")

    # Total estimates
    inference_total = weight_memory_bf16 + total_kv_cache + activation_memory
    training_total = weight_memory_fp32 + total_kv_cache + activation_memory + training_overhead

    print(f"{'='*70}")
    print(f" INFERENCE (BF16): {inference_total:>10.3f} GB")
    print(f" TRAINING (FP32 + Adam): {training_total:>10.3f} GB")
    print(f"{'='*70}\n")

    # Memory analysis
    print(f"{'='*70}")
    print(f"🎯 MEMORY ANALYSIS:")
    print(f"{'='*70}\n")

    # Deduplicated: one helper replaces the two identical threshold loops.
    _report_gpu_fit("Inference", inference_total)
    _report_gpu_fit("Training", training_total)

    print(f"\n{'='*70}\n")

    return {
        'total_params': total_params,
        'weight_memory_gb': weight_memory_bf16,
        'inference_memory_gb': inference_total,
        'training_memory_gb': training_total,
    }
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
if __name__ == "__main__":
    # Script entry point: report size/memory for the default configuration.
    default_args = ModelArgs()
    results = estimate_model_size(default_args)
|
Model_Architecture/train.py
ADDED
|
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Sequential Expert Training Script for MoE on Single GPU
|
| 4 |
+
Memory Usage: ~7.2GB (vs 10.9GB for full MoE)
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import json
|
| 12 |
+
import time
|
| 13 |
+
import math
|
| 14 |
+
|
| 15 |
+
# Import your model
|
| 16 |
+
from model import ismail, ModelArgs
|
| 17 |
+
from model_size import estimate_model_size
|
| 18 |
+
|
| 19 |
+
# Try to import optional dependencies
|
| 20 |
+
# Optional dependency probes: record availability instead of hard-failing,
# so training still runs without experiment tracking / 8-bit optimizers.
try:
    import wandb
except ImportError:
    HAS_WANDB = False
    print("⚠️ wandb not installed. Run 'pip install wandb' for experiment tracking.")
else:
    HAS_WANDB = True

try:
    import bitsandbytes as bnb
except ImportError:
    HAS_BNB = False
    print("⚠️ bitsandbytes not installed. Run 'pip install bitsandbytes' for memory-efficient optimizer.")
else:
    HAS_BNB = True
|
| 33 |
+
|
| 34 |
+
# Configuration
|
| 35 |
+
# Default hyper-parameters. Every field may be overridden by a JSON config
# file and/or CLI flags (see load_config); values are sized for a ~12GB GPU.
DEFAULT_CONFIG = {
    # Architecture
    "model": {
        "vocab_size": 32000,  # Reduced from 102400
        "dim": 1024,
        "inter_dim": 4096,
        "moe_inter_dim": 1024,
        "n_layers": 16,
        "n_dense_layers": 1,  # Only first layer dense
        "n_heads": 16,  # Increased for better parallelism
        # MoE
        "n_routed_experts": 6,
        "n_shared_experts": 1,
        "n_activated_experts": 2,
        # MLA
        "q_lora_rank": 128,  # Enable Q LoRA
        "kv_lora_rank": 512,
        "qk_nope_head_dim": 64,
        "qk_rope_head_dim": 32,
        "v_head_dim": 64,
        # Sequence
        "max_seq_len": 2048,  # Start shorter
        "max_batch_size": 4,
    },
    # Optimization schedule
    "training": {
        "learning_rate": 3e-4,
        "weight_decay": 0.1,
        "beta1": 0.9,
        "beta2": 0.95,
        "grad_clip": 1.0,
        "warmup_steps": 1000,
        "total_steps": 50000,
        "expert_rotation_steps": 2000,  # Rotate expert every N steps
        "gradient_accumulation_steps": 16,
        "eval_every": 1000,
        "save_every": 5000,
        "save_dir": "./checkpoints",
        "log_every": 100,
        "dtype": "bf16",
        "compile": True,  # PyTorch 2.0+ compilation
    },
    # Corpus locations
    "data": {
        "train_file": "./data/train.txt",
        "val_file": "./data/val.txt",
        "stride": 512,
    },
    # Experiment tracking
    "logging": {
        "use_wandb": HAS_WANDB,
        "project_name": "sequential-moe",
        "run_name": "moe-12gb-gpu",
    },
}
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def parse_args():
    """Parse command-line options for the sequential-expert trainer."""
    parser = argparse.ArgumentParser(description="Train MoE model with sequential experts")
    # Declarative flag table: (flag, add_argument kwargs).
    for flag, kwargs in (
        ("--config", {"type": str, "help": "Path to config JSON"}),
        ("--train_file", {"type": str, "help": "Training text file"}),
        ("--val_file", {"type": str, "help": "Validation text file"}),
        ("--save_dir", {"type": str, "default": "./checkpoints"}),
        ("--resume", {"type": str, "help": "Checkpoint to resume from"}),
        ("--no_wandb", {"action": "store_true", "help": "Disable wandb"}),
    ):
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def load_config(args):
    """Build the effective configuration.

    Starts from DEFAULT_CONFIG, layers an optional user JSON file on top
    (section-wise dict merge), then applies CLI overrides last.
    """
    config = DEFAULT_CONFIG.copy()

    if args.config and Path(args.config).exists():
        with open(args.config) as f:
            user_config = json.load(f)
        # Section-wise merge: nested dict sections update in place,
        # anything else replaces wholesale.
        for key, value in user_config.items():
            if key in config and isinstance(value, dict):
                config[key].update(value)
            else:
                config[key] = value

    # CLI arguments win over file and default values.
    if args.train_file:
        config["data"]["train_file"] = args.train_file
    if args.val_file:
        config["data"]["val_file"] = args.val_file
    if args.save_dir:
        config["training"]["save_dir"] = args.save_dir
    if args.no_wandb:
        config["logging"]["use_wandb"] = False

    return config
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def setup_model(config, device):
    """Instantiate the ismail model on `device` and print a size report.

    Applies torch.compile best-effort when config["training"]["compile"]
    is set. Returns (model, model_args).
    """
    model_args = ModelArgs(**config["model"])

    print("\n" + "="*70)
    print("MODEL INITIALIZATION")
    print("="*70 + "\n")

    # Print the parameter/memory breakdown before allocating the weights.
    estimate_model_size(model_args)

    net = ismail(model_args).to(device)

    if config["training"]["compile"]:
        # Compilation is optional (PyTorch 2.0+); fall back gracefully.
        try:
            net = torch.compile(net)
            print("✅ Model compiled with torch.compile()\n")
        except Exception as exc:
            print(f"⚠️ Compilation failed: {exc}\n")

    return net, model_args
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def setup_optimizer(model, config):
    """Create the optimizer (AdamW8bit when bitsandbytes is available).

    Parameters are split into three groups: base weights, routed-expert
    weights, and router/gate weights (which get no weight decay).
    """
    training_cfg = config["training"]

    # Bucket parameters by role based on their qualified name.
    expert_params, base_params, router_params = [], [], []
    for name, param in model.named_parameters():
        if "experts" in name and "shared" not in name:
            expert_params.append(param)
        elif "gate" in name:
            router_params.append(param)
        else:
            base_params.append(param)

    if HAS_BNB:
        optimizer_class = bnb.optim.AdamW8bit
        print("✅ Using AdamW8bit for memory efficiency")
    else:
        optimizer_class = torch.optim.AdamW
        print("⚠️ Using standard AdamW (install bitsandbytes for memory savings)")

    wd = training_cfg["weight_decay"]
    return optimizer_class(
        [
            {"params": base_params, "weight_decay": wd},
            {"params": expert_params, "weight_decay": wd},
            {"params": router_params, "weight_decay": 0.0},  # Usually no WD for router
        ],
        lr=training_cfg["learning_rate"],
        betas=(training_cfg["beta1"], training_cfg["beta2"]),
    )
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def get_lr(step, config):
    """Learning-rate schedule: linear warmup then cosine decay to zero.

    Args:
        step: current optimizer step (0-based).
        config: full config dict; reads training.warmup_steps,
            training.total_steps and training.learning_rate.

    Returns:
        The learning rate for this step (float, always >= 0).
    """
    training_cfg = config["training"]
    warmup_steps = training_cfg["warmup_steps"]
    total_steps = training_cfg["total_steps"]
    base_lr = training_cfg["learning_rate"]

    # Linear warmup. Guard warmup_steps == 0 to avoid ZeroDivisionError.
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * step / warmup_steps

    # Cosine decay. Clamp progress so steps past total_steps hold at zero:
    # the unclamped cosine oscillates (and goes negative) beyond the schedule.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return base_lr * 0.5 * (1 + math.cos(math.pi * progress))
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def load_data(config):
    """Build train/val DataLoaders from the text files in config["data"]."""
    data_cfg = config["data"]

    print("\n" + "="*70)
    print("DATA LOADING")
    print("="*70 + "\n")

    from data import create_dataloader

    def _make_loader(path, shuffle):
        # Both loaders differ only in source file and shuffling.
        return create_dataloader(
            txt=Path(path).read_text(encoding="utf-8"),
            args=ModelArgs(**config["model"]),
            stride=data_cfg["stride"],
            shuffle=shuffle,
            drop_last=True,
        )

    train_loader = _make_loader(data_cfg["train_file"], True)
    val_loader = _make_loader(data_cfg["val_file"], False)

    print(f"✅ Train batches: {len(train_loader)}")
    print(f"✅ Val batches: {len(val_loader)}\n")

    return train_loader, val_loader
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def evaluate(model, val_loader, device, config):
    """Return the token-averaged cross-entropy loss over the validation set.

    Switches the model to eval mode for the duration and restores train
    mode before returning. Target tokens equal to -1 are ignored.
    """
    model.eval()
    loss_sum, token_count = 0.0, 0

    with torch.no_grad():
        for input_ids, target_ids in val_loader:
            input_ids = input_ids.to(device)
            target_ids = target_ids.to(device)

            logits, _ = model(input_ids, start_pos=0)
            batch_loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                target_ids.view(-1),
                ignore_index=-1,
            )

            # Weight each batch by its token count for an exact global mean.
            loss_sum += batch_loss.item() * target_ids.numel()
            token_count += target_ids.numel()

    model.train()
    return loss_sum / token_count
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def save_checkpoint(model, optimizer, step, config, expert_idx=None):
    """Serialize model/optimizer/config to <save_dir>/step_<N>[_expert_<i>].pt."""
    save_dir = Path(config["training"]["save_dir"])
    save_dir.mkdir(parents=True, exist_ok=True)

    # Tag the filename with the active expert when one is supplied.
    suffix = f"_expert_{expert_idx}" if expert_idx is not None else ""
    ckpt_path = save_dir / f"step_{step}{suffix}.pt"

    torch.save(
        {
            "step": step,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "config": config,
        },
        ckpt_path,
    )
    print(f"💾 Checkpoint saved: {ckpt_path}")
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def train_step(model, batch, device, config, scaler=None):
    """Run one forward pass and return (total_loss, lm_loss, lb_loss).

    total_loss = cross-entropy LM loss + lb_loss_coef * load-balancing loss.
    The caller handles backward/optimizer stepping; `scaler` is accepted
    for interface compatibility but unused here.

    Args:
        model: callable returning (logits, load_balance_loss).
        batch: (input_ids, target_ids) LongTensor pair.
        device: torch.device (or device string) to move the batch onto.
        config: full config dict; reads training.dtype and training.lb_loss_coef.
        scaler: optional GradScaler (unused).
    """
    input_ids, target_ids = batch
    input_ids = input_ids.to(device, non_blocking=True)
    target_ids = target_ids.to(device, non_blocking=True)

    use_bf16 = config["training"]["dtype"] == "bf16"
    device_type = device.type if isinstance(device, torch.device) else str(device)

    # Fix: the old torch.cuda.amp.autocast(enabled=...) autocasts to
    # float16 by default, contradicting the configured "bf16" dtype
    # (and is deprecated). Request bfloat16 explicitly, on the right device.
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=use_bf16):
        logits, lb_loss = model(input_ids, start_pos=0)

        # Main language modeling loss (targets of -1 are ignored).
        lm_loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            target_ids.view(-1),
            ignore_index=-1,
        )

        # Total loss with auxiliary load-balancing term.
        total_loss = lm_loss + config["training"].get("lb_loss_coef", 0.01) * lb_loss

    return total_loss, lm_loss, lb_loss
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def main():
    """Entry point: sequential-expert MoE training loop on a single GPU."""
    args = parse_args()
    config = load_config(args)

    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Wandb setup
    if config["logging"]["use_wandb"] and HAS_WANDB:
        wandb.init(
            project=config["logging"]["project_name"],
            name=config["logging"]["run_name"],
            config=config,
        )

    # Model / optimizer / data setup
    model, model_args = setup_model(config, device)
    optimizer = setup_optimizer(model, config)
    train_loader, val_loader = load_data(config)
    train_iter = iter(train_loader)

    # Training state
    step = 0
    best_val_loss = float("inf")

    # Resume from checkpoint
    if args.resume:
        ckpt = torch.load(args.resume, map_location=device)
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        step = ckpt["step"]
        print(f"✅ Resumed from step {step}\n")

    # Fix: GradScaler exists to prevent FP16 gradient underflow; BF16 has the
    # same exponent range as FP32 and must NOT use loss scaling (the original
    # enabled the scaler for "bf16"). A disabled scaler makes
    # scale()/unscale_()/step()/update() transparent no-ops, so the single
    # code path below is correct for every dtype.
    scaler = torch.cuda.amp.GradScaler(enabled=(config["training"]["dtype"] == "fp16"))

    # Expert rotation schedule
    current_expert = 0
    rotation_steps = config["training"]["expert_rotation_steps"]
    model.set_active_expert(current_expert)
    print(f"🎯 Training expert {current_expert}/{model_args.n_routed_experts - 1}")

    print("\n" + "="*70)
    print("TRAINING STARTED")
    print("="*70 + "\n")

    model.train()

    while step < config["training"]["total_steps"]:
        step_start = time.time()

        # Rotate the trained expert every rotation_steps steps.
        if step > 0 and step % rotation_steps == 0:
            current_expert = (current_expert + 1) % model_args.n_routed_experts
            model.set_active_expert(current_expert)
            print(f"\n🔄 Rotating to expert {current_expert}/{model_args.n_routed_experts - 1}")
            # Drop gradients accumulated for the previous expert.
            optimizer.zero_grad(set_to_none=True)

        # Get batch, cycling the loader when exhausted.
        try:
            batch = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            batch = next(train_iter)

        # Gradient accumulation. NOTE(review): the same batch is reused for
        # every accumulation sub-step (micro-batch splitting not implemented).
        accum_steps = config["training"]["gradient_accumulation_steps"]
        total_loss_accum = 0.0
        lm_loss_accum = 0.0
        lb_loss_accum = 0.0

        for _ in range(accum_steps):
            loss, lm_loss, lb_loss = train_step(model, batch, device, config, scaler)
            loss = loss / accum_steps  # normalize for accumulation
            # scale() is a pass-through when the scaler is disabled.
            scaler.scale(loss).backward()

            total_loss_accum += loss.item()
            lm_loss_accum += lm_loss.item() / accum_steps
            lb_loss_accum += lb_loss.item() / accum_steps

        # Gradient clipping (unscale first so the clip sees true gradients).
        if config["training"]["grad_clip"] > 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config["training"]["grad_clip"])

        # Fix: set the LR for THIS step before stepping — the original
        # updated it after optimizer.step(), so warmup lagged one step.
        lr = get_lr(step, config)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

        # Logging
        if step % config["training"]["log_every"] == 0:
            step_time = time.time() - step_start
            tokens_per_sec = (model_args.max_batch_size * model_args.max_seq_len) / step_time

            print(f"Step {step:6d} | "
                  f"Loss: {lm_loss_accum:.4f} | "
                  f"LB Loss: {lb_loss_accum:.4f} | "
                  f"LR: {lr:.2e} | "
                  f"Expert: {current_expert} | "
                  f"Tokens/s: {tokens_per_sec:.0f}")

            if config["logging"]["use_wandb"] and HAS_WANDB:
                wandb.log({
                    "step": step,
                    "loss": lm_loss_accum,
                    "load_balance_loss": lb_loss_accum,
                    "total_loss": total_loss_accum,
                    "learning_rate": lr,
                    "active_expert": current_expert,
                    "tokens_per_sec": tokens_per_sec,
                    "gpu_memory_gb": torch.cuda.memory_allocated() / 1024**3,
                })

        # Evaluation
        if step % config["training"]["eval_every"] == 0 and step > 0:
            print(f"\n📊 Evaluating at step {step}...")
            val_loss = evaluate(model, val_loader, device, config)
            print(f"Val Loss: {val_loss:.4f} | Perplexity: {math.exp(val_loss):.2f}\n")

            if config["logging"]["use_wandb"] and HAS_WANDB:
                wandb.log({"val_loss": val_loss, "val_perplexity": math.exp(val_loss)})

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_checkpoint(model, optimizer, step, config, expert_idx="best")

        # Periodic checkpoint
        if step % config["training"]["save_every"] == 0 and step > 0:
            save_checkpoint(model, optimizer, step, config, expert_idx=current_expert)

        step += 1

    # Final save
    save_checkpoint(model, optimizer, step, config, expert_idx="final")

    if config["logging"]["use_wandb"] and HAS_WANDB:
        wandb.finish()

    print("\n" + "="*70)
    print("TRAINING COMPLETED")
    print("="*70)
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
# Run the trainer only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|