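"""Training script for the sequential Mixture-of-Experts model defined in model.py.

Supports single-GPU training with gradient accumulation, optional 8-bit AdamW
(bitsandbytes), optional wandb logging, bf16/fp16/fp32 autocast, expert rotation
in sequential mode, and checkpoint save/resume.
"""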
import argparse
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.cuda.empty_cache()
import torch.nn.functional as F
from pathlib import Path
import copy
import json
import time
import math
# Model definition (model.py in this repo)
from model import ismail, ModelArgs
# Try to import optional dependencies
try:
import wandb
HAS_WANDB = True
except ImportError:
HAS_WANDB = False
print("⚠️ wandb not installed. Run 'pip install wandb' for experiment tracking.")
try:
import bitsandbytes as bnb
HAS_BNB = True
except ImportError:
HAS_BNB = False
print("⚠️ bitsandbytes not installed. Run 'pip install bitsandbytes' for memory-efficient optimizer.")
# Configuration - matches ModelArgs defaults
DEFAULT_CONFIG = {
"model": {
"max_batch_size": 8,
"max_seq_len": 2048,
"dtype": "bf16",
"scale_fmt": None,
"vocab_size": 102400,
"dim": 1024,
"inter_dim": 4096,
"moe_inter_dim": 1024,
"n_layers": 20,
"n_dense_layers": 3,
"n_heads": 12,
"n_routed_experts": 6,
"n_shared_experts": 1,
"n_activated_experts": 2,
"route_scale": 1.0,
"use_routing_bias": True,
"q_lora_rank": 0,
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"original_seq_len": 4096,
"rope_theta": 10000.0,
"rope_factor": 40,
"beta_fast": 32,
"beta_slow": 1,
"mscale": 1.0,
"tokenizer_name": "gpt2",
},
"training": {
"learning_rate": 3e-4,
"weight_decay": 0.1,
"beta1": 0.9,
"beta2": 0.95,
"grad_clip": 1.0,
"warmup_steps": 1000,
"total_steps": 50000,
"expert_rotation_steps": 2000, # Rotate expert every N steps
"gradient_accumulation_steps": 16,
"eval_every": 1000,
"save_every": 5000,
"save_dir": "./checkpoints",
"log_every": 100,
"dtype": "bf16",
"compile": True, # PyTorch 2.0+ compilation
},
"data": {
"train_file": "./data/train.txt",
"val_file": "./data/val.txt",
"stride": 512,
},
"logging": {
"use_wandb": HAS_WANDB,
"project_name": "sequential-moe",
"run_name": "moe-12gb-gpu",
}
}
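# Illustrative user config for --config (a sketch, not a shipped file): the JSON
# only needs the keys being overridden; load_config() merges it into the nested
# defaults above, and CLI flags take precedence over both. For example:
#
#   {
#     "model":    {"max_seq_len": 1024, "n_layers": 12},
#     "training": {"total_steps": 10000, "gradient_accumulation_steps": 8},
#     "data":     {"train_file": "./data/my_train.txt"}
#   }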
def parse_args():
parser = argparse.ArgumentParser(description="Train MoE model with sequential experts")
parser.add_argument("--config", type=str, help="Path to config JSON")
parser.add_argument("--train_file", type=str, help="Training text file")
parser.add_argument("--val_file", type=str, help="Validation text file")
parser.add_argument("--save_dir", type=str, default="./checkpoints")
parser.add_argument("--resume", type=str, help="Checkpoint to resume from")
parser.add_argument("--no_wandb", action="store_true", help="Disable wandb")
return parser.parse_args()
def load_config(args):
"""Load and merge configuration"""
config = DEFAULT_CONFIG.copy()
if args.config and Path(args.config).exists():
with open(args.config) as f:
user_config = json.load(f)
# Deep merge
for key, value in user_config.items():
if key in config and isinstance(value, dict):
config[key].update(value)
else:
config[key] = value
# Override from CLI args
if args.train_file:
config["data"]["train_file"] = args.train_file
if args.val_file:
config["data"]["val_file"] = args.val_file
if args.save_dir:
config["training"]["save_dir"] = args.save_dir
if args.no_wandb:
config["logging"]["use_wandb"] = False
return config
def setup_model(config, device):
from model import Linear
args = ModelArgs(**config["model"])
# ✅ CRITICAL: Set the global dtype for Linear layers
training_dtype = config["training"]["dtype"].lower()
if training_dtype == "bf16":
Linear.dtype = torch.bfloat16
elif training_dtype == "fp16":
Linear.dtype = torch.float16
else:
Linear.dtype = torch.float32
model = ismail(args).to(device=device, dtype=Linear.dtype)
    # Enable gradient checkpointing (if the model supports it) to reduce activation memory
model.use_checkpointing = config["training"].get("use_checkpointing", True)
if config["training"]["compile"]:
try:
model = torch.compile(model)
print("✅ Model compiled\n")
except Exception as e:
print(f"⚠️ Compilation failed: {e}\n")
return model, args
def setup_optimizer(model, config):
"""Setup memory-efficient optimizer"""
training_cfg = config["training"]
# Separate parameter groups
expert_params = []
base_params = []
router_params = []
for name, param in model.named_parameters():
if "experts" in name and "shared" not in name:
expert_params.append(param)
elif "gate" in name:
router_params.append(param)
else:
base_params.append(param)
# Use 8-bit Adam if available
if HAS_BNB:
optimizer_class = bnb.optim.AdamW8bit
print("✅ Using AdamW8bit for memory efficiency")
else:
optimizer_class = torch.optim.AdamW
print("⚠️ Using standard AdamW (install bitsandbytes for memory savings)")
optimizer = optimizer_class(
[
{"params": base_params, "weight_decay": training_cfg["weight_decay"]},
{"params": expert_params, "weight_decay": training_cfg["weight_decay"]},
{"params": router_params, "weight_decay": 0.0}, # Usually no WD for router
],
lr=training_cfg["learning_rate"],
betas=(training_cfg["beta1"], training_cfg["beta2"]),
)
return optimizer
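# Example of how parameters fall into the three groups above, assuming DeepSeek-style
# naming in model.py (names are illustrative, not guaranteed):
#   "layers.5.ffn.experts.2.w1.weight"       -> expert_params  (contains "experts")
#   "layers.5.ffn.gate.weight"               -> router_params  (contains "gate")
#   "layers.5.attn.wq.weight"                -> base_params    (everything else)
#   "layers.5.ffn.shared_experts.w1.weight"  -> base_params    ("shared" is excluded from experts)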
def get_lr(step, config):
"""Learning rate scheduler with warmup and cosine decay"""
training_cfg = config["training"]
warmup_steps = training_cfg["warmup_steps"]
total_steps = training_cfg["total_steps"]
base_lr = training_cfg["learning_rate"]
if step < warmup_steps:
return base_lr * step / warmup_steps
# Cosine decay
progress = (step - warmup_steps) / (total_steps - warmup_steps)
return base_lr * 0.5 * (1 + math.cos(math.pi * progress))
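# Worked example with the defaults above (warmup_steps=1000, total_steps=50000, lr=3e-4):
#   step   500 -> 3e-4 * 500/1000                          = 1.5e-4  (linear warmup)
#   step  1000 -> progress = 0.0, 3e-4 * 0.5*(1+cos(0))    = 3.0e-4  (peak)
#   step 25500 -> progress = 0.5, 3e-4 * 0.5*(1+cos(pi/2)) = 1.5e-4
#   step 50000 -> progress = 1.0, 3e-4 * 0.5*(1+cos(pi))   = 0.0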
def load_data(config):
from data import create_dataloader
data_cfg = config["data"]
print("\n" + "="*70)
print("DATA LOADING")
print("="*70 + "\n")
    # ModelArgs is already imported at module level
args = ModelArgs(**config["model"])
train_loader, tokenizer = create_dataloader(
txt=str(data_cfg["train_file"]),
use_turkish_tokenizer=True,
args=args,
stride=data_cfg["stride"],
shuffle=True,
drop_last=True,
use_memory_efficient=True,
is_val=False
)
val_loader, tokenizer = create_dataloader(
txt=str(data_cfg["val_file"]),
use_turkish_tokenizer=True,
args=args,
stride=data_cfg["stride"],
shuffle=False,
drop_last=True,
use_memory_efficient=True,
is_val=True
)
print(f"✅ Train batches: {len(train_loader)}")
print(f"✅ Val batches: {len(val_loader)}\n")
return train_loader, val_loader, tokenizer # Return tokenizer
def evaluate(model, val_loader, device, config, tokenizer, active_expert=None):
    """Evaluate the model on the validation set.

    Args:
        active_expert: Index of the expert currently being trained in sequential
            mode, or None. The forward pass always routes through ALL experts; this
            value is only restored on the model after evaluation finishes.
    """
model.eval()
# CRITICAL FIX: Store original gradient requirements for experts
original_expert_grads = {}
for name, param in model.named_parameters():
if "experts" in name:
original_expert_grads[name] = param.requires_grad
# Enable gradients for all experts during evaluation
for name, param in model.named_parameters():
if "experts" in name:
param.requires_grad = True
# Clear caches...
for layer in model.layers:
if hasattr(layer.attn, 'kv_cache'):
layer.attn.kv_cache.zero_()
if hasattr(layer.attn, 'pe_cache'):
layer.attn.pe_cache.zero_()
# Set expert mode for validation
if hasattr(model, 'set_active_expert'):
# CRITICAL: For validation, temporarily set to None (all experts)
# even if we're in sequential training mode
        if active_expert is not None:
            # Sequential mode: remember which expert is being trained, but the
            # validation forward pass below still uses ALL experts.
            print(f"  Active training expert: {active_expert} (validating with ALL experts)")
            validation_expert = active_expert
        else:
            print("  Validating with ALL experts")
            validation_expert = None
# Always use all experts for validation forward pass
model.set_active_expert(None)
total_loss = 0.0
total_tokens = 0
max_batches = config["training"].get("max_val_batches", 200)
from tqdm import tqdm
pbar = tqdm(total=max_batches, desc="📊 Validating", ncols=80)
val_dtype = config["training"]["dtype"]
batch_losses = []
with torch.no_grad():
for i, (input_ids, target_ids) in enumerate(val_loader):
if i >= max_batches:
break
input_ids = input_ids.to(device, non_blocking=True)
target_ids = target_ids.to(device, non_blocking=True)
# CRITICAL: Use proper autocast settings based on dtype
use_autocast = val_dtype in ['bf16', 'fp16']
with torch.amp.autocast(device_type='cuda', enabled=use_autocast, dtype=torch.bfloat16 if val_dtype == 'bf16' else torch.float16):
output = model(input_ids, start_pos=0)
logits = output[0] if isinstance(output, tuple) else output
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
target_ids.view(-1),
ignore_index=-1,
)
batch_losses.append(loss.item())
total_loss += loss.item() * target_ids.numel()
total_tokens += target_ids.numel()
pbar.update(1)
pbar.set_postfix({'loss': f'{loss.item():.3f}'})
pbar.close()
# CRITICAL: Restore original gradient requirements
for name, param in model.named_parameters():
if name in original_expert_grads:
param.requires_grad = original_expert_grads[name]
# Restore the active expert if in sequential training mode
if hasattr(model, 'set_active_expert') and 'validation_expert' in locals():
model.set_active_expert(validation_expert)
model.train()
final_loss = total_loss / total_tokens
# Show loss variation stats
if len(batch_losses) > 1:
loss_std = torch.std(torch.tensor(batch_losses)).item()
print(f" Loss std dev: {loss_std:.6f} (should be >0.01)")
return final_loss
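# The returned value is a token-weighted mean cross-entropy; the caller reports
# perplexity as exp(loss), e.g. a val loss of 4.0 corresponds to ppl ~= 54.6.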
def save_checkpoint(model, optimizer, step, config, expert_idx=None):
"""Save model checkpoint"""
save_dir = Path(config["training"]["save_dir"])
save_dir.mkdir(parents=True, exist_ok=True)
ckpt_name = f"step_{step}_expert_{expert_idx}.pt" if expert_idx is not None else f"step_{step}.pt"
ckpt_path = save_dir / ckpt_name
# 🔥 Exclude cache buffers - they should be reinitialized from config
state_dict = model.state_dict()
filtered_state_dict = {k: v for k, v in state_dict.items() if 'cache' not in k.lower()}
checkpoint = {
"step": step,
"model_state_dict": filtered_state_dict,
"optimizer_state_dict": optimizer.state_dict(),
"config": config,
}
torch.save(checkpoint, ckpt_path)
print(f"💾 Checkpoint saved: {ckpt_path}")
def train_step(model, input_mb, target_mb, device, config, scaler=None):
"""Process a SINGLE micro-batch (already sliced)"""
# 🚨 Validate data with more detail
if input_mb.size(0) == 0:
print("🚨 Warning: Empty micro-batch received")
return 0.0, 0.0
vocab_size = config["model"]["vocab_size"]
input_max = input_mb.max().item()
target_max = target_mb.max().item()
if input_max >= vocab_size or target_max >= vocab_size:
print(f"🚨 Invalid token detected! "
f"Input max: {input_max}, Target max: {target_max}, "
f"Vocab size: {vocab_size}")
# Clamp tokens to valid range
input_mb = torch.clamp(input_mb, max=vocab_size-1)
target_mb = torch.clamp(target_mb, max=vocab_size-1)
# Check for NaN in data
if torch.isnan(input_mb).any() or torch.isnan(target_mb).any():
print("🚨 NaN detected in input data! Replacing with zeros")
input_mb = torch.nan_to_num(input_mb, nan=0)
target_mb = torch.nan_to_num(target_mb, nan=0)
input_mb = input_mb.to(device, non_blocking=True)
target_mb = target_mb.to(device, non_blocking=True)
training_dtype = config["training"]["dtype"].lower()
use_autocast = training_dtype in ['bf16', 'fp16']
autocast_dtype = torch.bfloat16 if training_dtype == 'bf16' else torch.float16
with torch.amp.autocast(device_type='cuda', enabled=use_autocast, dtype=autocast_dtype if use_autocast else None):
output = model(input_mb, start_pos=0)
if isinstance(output, tuple):
logits, lb_loss = output
else:
logits = output
lb_loss = 0.0
# 🚨 Check for NaN in logits before computing loss
if torch.isnan(logits).any():
print(f"🚨 NaN detected in logits! Scale: {logits.abs().max().item()}")
print(f" Input range: [{input_mb.min().item()}, {input_mb.max().item()}]")
return 0.0, 0.0
lm_loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
target_mb.view(-1),
ignore_index=-1,
)
# 🚨 Check for NaN in loss components
if torch.isnan(lm_loss):
print(f"🚨 NaN in lm_loss!")
return 0.0, 0.0
accum_steps = config["training"]["gradient_accumulation_steps"]
if isinstance(lb_loss, float):
total_loss = lm_loss / accum_steps
else:
if torch.isnan(lb_loss):
print(f"🚨 NaN in lb_loss! Setting to 0")
lb_loss = 0.0
lb_loss_coef = config["training"].get("lb_loss_coef", 0.01)
total_loss = (lm_loss + lb_loss_coef * lb_loss) / accum_steps
    # Backward pass (scaled when an fp16 GradScaler is in use)
if scaler is not None:
scaler.scale(total_loss).backward()
else:
total_loss.backward()
return lm_loss.item(), lb_loss if isinstance(lb_loss, float) else lb_loss.item()
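# Gradient-accumulation arithmetic: train_step() divides each micro-batch loss by
# gradient_accumulation_steps before backward(), so the gradients summed over one
# outer step approximate a single large-batch gradient. E.g. a dataloader batch of
# 64 sequences with accumulation=16 is sliced into 16 micro-batches of 4. If the
# dataloader batch is smaller than the accumulation steps, most slices are empty
# and train_step() just logs the empty-micro-batch warning above.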
def main():
args = parse_args()
config = load_config(args)
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # TF32 is already enabled via the allow_tf32 flags at import time; the
    # fp32_precision attributes only exist on newer PyTorch builds, so guard them.
    try:
        torch.backends.cudnn.conv.fp32_precision = 'tf32'
        torch.backends.cuda.matmul.fp32_precision = 'tf32'
    except AttributeError:
        pass
# Wandb setup
if config["logging"]["use_wandb"] and HAS_WANDB:
wandb.init(project=config["logging"]["project_name"],
name=config["logging"]["run_name"], config=config)
# Model setup
model, model_args = setup_model(config, device)
# Optimizer setup
optimizer = setup_optimizer(model, config)
# Data setup
train_loader, val_loader, tokenizer = load_data(config)
train_iter = iter(train_loader)
# Training state
step = 0
best_val_loss = float("inf")
# Resume from checkpoint
if args.resume:
print(f"📥 Loading checkpoint from {args.resume}...")
ckpt = torch.load(args.resume, map_location=device)
        # The model was already built above with this config, so it is NOT rebuilt
        # here; rebuilding it would leave the optimizer bound to the old parameters.
        # Load state dict but skip mismatched cache buffers (reinitialized from config)
        model_state_dict = model.state_dict()
loaded_state_dict = ckpt["model_state_dict"]
skip_count = 0
for name, param in loaded_state_dict.items():
if name in model_state_dict:
if model_state_dict[name].shape != param.shape:
if "cache" in name: # Skip cache buffers
skip_count += 1
continue
else:
raise RuntimeError(f"Shape mismatch {name}: {param.shape} vs {model_state_dict[name].shape}")
model_state_dict[name].copy_(param)
else:
print(f"⚠️ Unexpected parameter: {name}")
model.load_state_dict(model_state_dict, strict=False)
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
step = ckpt["step"]
print(f"✅ Resumed from step {step} (skipped {skip_count} cache buffers)\n")
# ✅ FIX: Only create scaler for FP16, not BF16 or FP32
training_dtype = config["training"]["dtype"].lower()
use_fp16 = training_dtype == "fp16"
use_bf16 = training_dtype == "bf16"
if use_fp16:
scaler = torch.amp.GradScaler(device='cuda', enabled=True)
print("✅ FP16 mode: Using GradScaler\n")
elif use_bf16:
scaler = None
print("⚠️ BF16 mode: Disabling GradScaler (not needed/supported)\n")
else: # FP32
scaler = None
print("✅ FP32 mode: No scaler needed\n")
# Expert rotation
current_expert = 0
rotation_steps = config["training"]["expert_rotation_steps"]
# Check if we should train all experts simultaneously
train_all_experts = config["training"].get("train_all_experts", False)
if train_all_experts:
print("🎯 Training ALL experts simultaneously\n")
model.set_active_expert(None) # None = all experts active
else:
print(f"🎯 Training expert {current_expert}/{model_args.n_routed_experts - 1} (sequential mode)\n")
model.set_active_expert(current_expert)
    # Training-loop constants
accum_steps = config["training"]["gradient_accumulation_steps"]
total_steps = config["training"]["total_steps"]
grad_clip = config["training"]["grad_clip"]
print("\n" + "="*70)
print("TRAINING STARTED")
print("="*70 + "\n")
model.train()
# MAIN TRAINING LOOP
while step < total_steps:
step_start = time.time()
# Expert rotation (only in sequential mode)
if not train_all_experts and step > 0 and step % rotation_steps == 0:
current_expert = (current_expert + 1) % model_args.n_routed_experts
model.set_active_expert(current_expert)
print(f"\n🔄 Rotating to expert {current_expert}/{model_args.n_routed_experts - 1}")
optimizer.zero_grad(set_to_none=True)
# Get batch
try:
batch = next(train_iter)
except StopIteration:
train_iter = iter(train_loader)
batch = next(train_iter)
# Split batch
input_ids, target_ids = batch
batch_size = input_ids.size(0)
micro_batch_size = batch_size // accum_steps
# Initialize accumulators
lm_loss_accum = 0.0
lb_loss_accum = 0.0
# Gradient accumulation loop
for accum_step in range(accum_steps):
# Calculate slice indices
start_idx = micro_batch_size * accum_step
# Handle last micro-batch
if accum_step == accum_steps - 1:
end_idx = batch_size
else:
end_idx = start_idx + micro_batch_size
# Extract micro-batch
input_mb = input_ids[start_idx:end_idx]
target_mb = target_ids[start_idx:end_idx]
# Process micro-batch
lm_loss, lb_loss = train_step(
model, input_mb, target_mb, device, config, scaler
)
# Accumulate losses
lm_loss_accum += lm_loss / accum_steps
lb_loss_accum += lb_loss / accum_steps
# Gradient clipping (if enabled)
if grad_clip > 0:
# Only unscale if using FP16 scaler
if scaler is not None:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
# ✅ FIX: Conditional optimizer step
if scaler is not None:
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
optimizer.zero_grad(set_to_none=True)
# LR scheduling
lr = get_lr(step, config)
for param_group in optimizer.param_groups:
param_group["lr"] = lr
# Logging
if step % config["training"]["log_every"] == 0:
            step_time = time.time() - step_start
            tokens_per_sec = (batch_size * input_ids.size(1)) / step_time
print(f"Step {step:6d} | "
f"Loss: {lm_loss_accum:.4f} | "
f"LB Loss: {lb_loss_accum:.4f} | "
f"LR: {lr:.2e} | "
f"Expert: {current_expert} | "
f"Tokens/s: {tokens_per_sec:.0f}")
if config["logging"]["use_wandb"] and HAS_WANDB:
wandb.log({
"step": step,
"loss": lm_loss_accum,
"load_balance_loss": lb_loss_accum,
"learning_rate": lr,
"active_expert": current_expert,
"tokens_per_sec": tokens_per_sec,
"gpu_memory_gb": torch.cuda.memory_allocated() / 1024**3,
})
# Evaluation
if step % config["training"]["eval_every"] == 0 and step > 0:
print(f"\n📊 Evaluating at step {step}...")
if train_all_experts:
# In all-experts mode, just validate with all experts
val_loss = evaluate(model, val_loader, device, config, tokenizer, active_expert=None)
print(f"Val Loss: {val_loss:.4f} | Perplexity: {math.exp(val_loss):.2f}\n")
if config["logging"]["use_wandb"] and HAS_WANDB:
wandb.log({"val_loss": val_loss, "val_perplexity": math.exp(val_loss)})
if val_loss < best_val_loss:
best_val_loss = val_loss
save_checkpoint(model, optimizer, step, config, expert_idx="best")
else:
# In sequential mode, validate both per-expert and all-experts
val_loss_active = evaluate(model, val_loader, device, config, tokenizer, active_expert=current_expert)
print(f"Val Loss (Expert {current_expert}): {val_loss_active:.4f} | Perplexity: {math.exp(val_loss_active):.2f}")
val_loss_all = evaluate(model, val_loader, device, config, tokenizer, active_expert=None)
print(f"Val Loss (All Experts): {val_loss_all:.4f} | Perplexity: {math.exp(val_loss_all):.2f}\n")
if config["logging"]["use_wandb"] and HAS_WANDB:
wandb.log({
f"val_loss_expert_{current_expert}": val_loss_active,
f"val_perplexity_expert_{current_expert}": math.exp(val_loss_active),
"val_loss_all_experts": val_loss_all,
"val_perplexity_all_experts": math.exp(val_loss_all)
})
# Save best based on active expert performance
if val_loss_active < best_val_loss:
best_val_loss = val_loss_active
save_checkpoint(model, optimizer, step, config, expert_idx="best")
# Save checkpoint
if step % config["training"]["save_every"] == 0 and step > 0:
save_checkpoint(model, optimizer, step, config, expert_idx=current_expert)
step += 1
# Final save
save_checkpoint(model, optimizer, step, config, expert_idx="final")
if config["logging"]["use_wandb"] and HAS_WANDB:
wandb.finish()
print("\n" + "="*70)
print("TRAINING COMPLETED")
print("="*70)
if __name__ == "__main__":
main()
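# Example invocations (the script filename is assumed; adjust to the actual name):
#
#   python train.py --train_file ./data/train.txt --val_file ./data/val.txt
#   python train.py --config my_config.json --save_dir ./checkpoints --no_wandb
#   python train.py --resume ./checkpoints/step_5000_expert_0.pt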