Commit ·
5db5e42
1
Parent(s): 89731d3
Fixes - I guess
Browse files- Model_Architecture/config.json +13 -12
- Model_Architecture/train.py +45 -40
Model_Architecture/config.json
CHANGED
|
@@ -1,26 +1,26 @@
|
|
| 1 |
{
|
| 2 |
"model": {
|
| 3 |
-
"max_batch_size":
|
| 4 |
-
"max_seq_len":
|
| 5 |
"dtype": "bf16",
|
| 6 |
"scale_fmt": null,
|
| 7 |
"vocab_size": 32768,
|
| 8 |
-
"dim":
|
| 9 |
"inter_dim": 4096,
|
| 10 |
-
"moe_inter_dim":
|
| 11 |
-
"n_layers":
|
| 12 |
"n_dense_layers": 3,
|
| 13 |
"n_heads": 12,
|
| 14 |
-
"n_routed_experts":
|
| 15 |
"n_shared_experts": 1,
|
| 16 |
"n_activated_experts": 2,
|
| 17 |
"route_scale": 1.0,
|
| 18 |
"use_routing_bias": true,
|
| 19 |
"q_lora_rank": 0,
|
| 20 |
-
"kv_lora_rank":
|
| 21 |
-
"qk_nope_head_dim":
|
| 22 |
"qk_rope_head_dim": 64,
|
| 23 |
-
"v_head_dim":
|
| 24 |
"original_seq_len": 4096,
|
| 25 |
"rope_theta": 10000.0,
|
| 26 |
"rope_factor": 40,
|
|
@@ -36,9 +36,10 @@
|
|
| 36 |
"beta2": 0.95,
|
| 37 |
"grad_clip": 1.0,
|
| 38 |
"warmup_steps": 1000,
|
| 39 |
-
"total_steps":
|
| 40 |
-
"
|
| 41 |
-
"
|
|
|
|
| 42 |
"eval_every": 1000,
|
| 43 |
"save_every": 5000,
|
| 44 |
"save_dir": "./checkpoints",
|
|
|
|
| 1 |
{
|
| 2 |
"model": {
|
| 3 |
+
"max_batch_size": 2,
|
| 4 |
+
"max_seq_len": 512,
|
| 5 |
"dtype": "bf16",
|
| 6 |
"scale_fmt": null,
|
| 7 |
"vocab_size": 32768,
|
| 8 |
+
"dim": 512,
|
| 9 |
"inter_dim": 4096,
|
| 10 |
+
"moe_inter_dim": 512,
|
| 11 |
+
"n_layers": 16,
|
| 12 |
"n_dense_layers": 3,
|
| 13 |
"n_heads": 12,
|
| 14 |
+
"n_routed_experts": 4,
|
| 15 |
"n_shared_experts": 1,
|
| 16 |
"n_activated_experts": 2,
|
| 17 |
"route_scale": 1.0,
|
| 18 |
"use_routing_bias": true,
|
| 19 |
"q_lora_rank": 0,
|
| 20 |
+
"kv_lora_rank": 256,
|
| 21 |
+
"qk_nope_head_dim": 96,
|
| 22 |
"qk_rope_head_dim": 64,
|
| 23 |
+
"v_head_dim": 96,
|
| 24 |
"original_seq_len": 4096,
|
| 25 |
"rope_theta": 10000.0,
|
| 26 |
"rope_factor": 40,
|
|
|
|
| 36 |
"beta2": 0.95,
|
| 37 |
"grad_clip": 1.0,
|
| 38 |
"warmup_steps": 1000,
|
| 39 |
+
"total_steps": 100000,
|
| 40 |
+
"use_checkpointing": true,
|
| 41 |
+
"expert_rotation_steps": 5000,
|
| 42 |
+
"gradient_accumulation_steps": 8,
|
| 43 |
"eval_every": 1000,
|
| 44 |
"save_every": 5000,
|
| 45 |
"save_dir": "./checkpoints",
|
Model_Architecture/train.py
CHANGED
|
@@ -1,16 +1,20 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
Sequential Expert Training Script for MoE on Single GPU
|
| 4 |
-
Memory Usage: ~7.2GB (vs 10.9GB for full MoE)
|
| 5 |
-
"""
|
| 6 |
|
| 7 |
import argparse
|
| 8 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import torch.nn.functional as F
|
| 10 |
from pathlib import Path
|
| 11 |
import json
|
| 12 |
import time
|
| 13 |
import math
|
|
|
|
|
|
|
| 14 |
|
| 15 |
# Import your model
|
| 16 |
from model import ismail, ModelArgs
|
|
@@ -143,6 +147,11 @@ def setup_model(config, device):
|
|
| 143 |
#size_info = estimate_model_size(args)
|
| 144 |
|
| 145 |
model = ismail(args).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
# Compile for speed (PyTorch 2.0+)
|
| 148 |
if config["training"]["compile"]:
|
|
@@ -296,41 +305,48 @@ def save_checkpoint(model, optimizer, step, config, expert_idx=None):
|
|
| 296 |
print(f"💾 Checkpoint saved: {ckpt_path}")
|
| 297 |
|
| 298 |
|
| 299 |
-
def train_step(model, batch, device, config, scaler=None):
|
| 300 |
-
"""
|
| 301 |
input_ids, target_ids = batch
|
| 302 |
-
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
|
| 305 |
-
# Forward pass
|
| 306 |
with torch.amp.autocast(device_type='cuda', enabled=(config["training"]["dtype"] == "bf16")):
|
| 307 |
-
output = model(
|
| 308 |
-
|
| 309 |
-
# Handle model output (tuple in training mode with MoE, single tensor otherwise)
|
| 310 |
if isinstance(output, tuple):
|
| 311 |
logits, lb_loss = output
|
| 312 |
else:
|
| 313 |
logits = output
|
| 314 |
lb_loss = 0.0
|
| 315 |
-
|
| 316 |
-
# Main language modeling loss
|
| 317 |
lm_loss = F.cross_entropy(
|
| 318 |
logits.view(-1, logits.size(-1)),
|
| 319 |
-
|
| 320 |
ignore_index=-1,
|
| 321 |
)
|
| 322 |
-
|
| 323 |
-
# Total loss with load balancing
|
| 324 |
-
# Ensure lb_loss is a tensor for proper gradient computation
|
| 325 |
if isinstance(lb_loss, float):
|
| 326 |
-
|
| 327 |
-
total_loss = lm_loss
|
| 328 |
else:
|
| 329 |
lb_loss_coef = config["training"].get("lb_loss_coef", 0.01)
|
| 330 |
-
total_loss = lm_loss + lb_loss_coef * lb_loss
|
| 331 |
|
| 332 |
-
#
|
| 333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
|
| 335 |
|
| 336 |
def main():
|
|
@@ -417,24 +433,13 @@ def main():
|
|
| 417 |
total_loss_accum = 0.0
|
| 418 |
lm_loss_accum = 0.0
|
| 419 |
lb_loss_accum = 0.0
|
| 420 |
-
|
| 421 |
-
for accum_step in range(accum_steps):
|
| 422 |
-
# Split batch for micro-batching (if needed)
|
| 423 |
-
# For now, process full batch
|
| 424 |
-
loss, lm_loss, lb_loss = train_step(model, batch, device, config, scaler)
|
| 425 |
-
|
| 426 |
-
# Normalize for accumulation
|
| 427 |
-
loss = loss / accum_steps
|
| 428 |
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
|
| 435 |
-
total_loss_accum += loss.item()
|
| 436 |
-
lm_loss_accum += lm_loss / accum_steps # Already a float from train_step
|
| 437 |
-
lb_loss_accum += lb_loss / accum_steps # Already a float from train_step
|
| 438 |
|
| 439 |
# Gradient clipping
|
| 440 |
if config["training"]["grad_clip"] > 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
|
| 2 |
import argparse
|
| 3 |
import torch
|
| 4 |
+
|
| 5 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 6 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 7 |
+
torch.backends.cudnn.benchmark = True
|
| 8 |
+
|
| 9 |
+
torch.cuda.empty_cache()
|
| 10 |
+
|
| 11 |
import torch.nn.functional as F
|
| 12 |
from pathlib import Path
|
| 13 |
import json
|
| 14 |
import time
|
| 15 |
import math
|
| 16 |
+
from torch.utils.checkpoint import checkpoint
|
| 17 |
+
|
| 18 |
|
| 19 |
# Import your model
|
| 20 |
from model import ismail, ModelArgs
|
|
|
|
| 147 |
#size_info = estimate_model_size(args)
|
| 148 |
|
| 149 |
model = ismail(args).to(device)
|
| 150 |
+
|
| 151 |
+
if config["training"].get("use_checkpointing", True):
|
| 152 |
+
for layer in model.layers:
|
| 153 |
+
layer.forward = lambda *args, layer=layer: checkpoint(layer._forward, *args)
|
| 154 |
+
print("✅ Gradient checkpointing enabled")
|
| 155 |
|
| 156 |
# Compile for speed (PyTorch 2.0+)
|
| 157 |
if config["training"]["compile"]:
|
|
|
|
| 305 |
print(f"💾 Checkpoint saved: {ckpt_path}")
|
| 306 |
|
| 307 |
|
| 308 |
+
def train_step(model, batch, device, config, accum_step, accum_steps, scaler=None):
|
| 309 |
+
"""Process a MICRO-batch for gradient accumulation"""
|
| 310 |
input_ids, target_ids = batch
|
| 311 |
+
|
| 312 |
+
# Split batch into micro-batches
|
| 313 |
+
micro_batch_size = input_ids.size(0) // accum_steps
|
| 314 |
+
start_idx = micro_batch_size * accum_step
|
| 315 |
+
end_idx = start_idx + micro_batch_size
|
| 316 |
+
|
| 317 |
+
# Get micro-batch slices
|
| 318 |
+
input_mb = input_ids[start_idx:end_idx].to(device, non_blocking=True)
|
| 319 |
+
target_mb = target_ids[start_idx:end_idx].to(device, non_blocking=True)
|
| 320 |
|
| 321 |
+
# Forward pass
|
| 322 |
with torch.amp.autocast(device_type='cuda', enabled=(config["training"]["dtype"] == "bf16")):
|
| 323 |
+
output = model(input_mb, start_pos=0)
|
| 324 |
+
|
|
|
|
| 325 |
if isinstance(output, tuple):
|
| 326 |
logits, lb_loss = output
|
| 327 |
else:
|
| 328 |
logits = output
|
| 329 |
lb_loss = 0.0
|
| 330 |
+
|
|
|
|
| 331 |
lm_loss = F.cross_entropy(
|
| 332 |
logits.view(-1, logits.size(-1)),
|
| 333 |
+
target_mb.view(-1),
|
| 334 |
ignore_index=-1,
|
| 335 |
)
|
| 336 |
+
|
|
|
|
|
|
|
| 337 |
if isinstance(lb_loss, float):
|
| 338 |
+
total_loss = lm_loss / accum_steps
|
|
|
|
| 339 |
else:
|
| 340 |
lb_loss_coef = config["training"].get("lb_loss_coef", 0.01)
|
| 341 |
+
total_loss = (lm_loss + lb_loss_coef * lb_loss) / accum_steps
|
| 342 |
|
| 343 |
+
# Backward
|
| 344 |
+
if config["training"]["dtype"] == "bf16":
|
| 345 |
+
scaler.scale(total_loss).backward()
|
| 346 |
+
else:
|
| 347 |
+
total_loss.backward()
|
| 348 |
+
|
| 349 |
+
return lm_loss.item(), lb_loss if isinstance(lb_loss, float) else lb_loss.item()
|
| 350 |
|
| 351 |
|
| 352 |
def main():
|
|
|
|
| 433 |
total_loss_accum = 0.0
|
| 434 |
lm_loss_accum = 0.0
|
| 435 |
lb_loss_accum = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
|
| 437 |
+
for accum_step in range(accum_steps):
|
| 438 |
+
lm_loss, lb_loss = train_step(model, batch, device, config,
|
| 439 |
+
accum_step, accum_steps, scaler)
|
| 440 |
+
lm_loss_accum += lm_loss / accum_steps
|
| 441 |
+
lb_loss_accum += lb_loss / accum_steps
|
| 442 |
|
|
|
|
|
|
|
|
|
|
| 443 |
|
| 444 |
# Gradient clipping
|
| 445 |
if config["training"]["grad_clip"] > 0:
|