Commit 431091f
Parent(s): 83a7e87

Fix Fix Fix

Model_Architecture/train.py  CHANGED  (+58 -58)
@@ -293,24 +293,12 @@ def save_checkpoint(model, optimizer, step, config, expert_idx=None):
     print(f"💾 Checkpoint saved: {ckpt_path}")


-def train_step(model, batch, device, config, accum_step, accum_steps, scaler=None):
-    """Process a SINGLE micro-batch"""
-    input_ids, target_ids = batch
+def train_step(model, input_mb, target_mb, device, config, scaler=None):
+    """Process a SINGLE micro-batch (already sliced)"""

-
-
-
-    # Calculate slice indices
-    start_idx = micro_batch_size * accum_step
-    end_idx = min(start_idx + micro_batch_size, batch_size)
-
-    # 🚨 CRITICAL: Skip if this micro-batch is empty (last iteration)
-    if start_idx >= batch_size:
-        return 0.0, 0.0  # Return zero loss, will be divided later
-
-    # Get micro-batch slices
-    input_mb = input_ids[start_idx:end_idx].to(device, non_blocking=True)
-    target_mb = target_ids[start_idx:end_idx].to(device, non_blocking=True)
+    # Move data to device
+    input_mb = input_mb.to(device, non_blocking=True)
+    target_mb = target_mb.to(device, non_blocking=True)

     # Forward pass
     with torch.amp.autocast(device_type='cuda', enabled=(config["training"]["dtype"] == "bf16")):
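Worth noting: the deleted body sliced with micro_batch_size and batch_size, neither of which was a parameter of the old signature, so it depended on names leaking in from elsewhere. The refactor makes the contract explicit: the caller slices, and train_step moves the slice to the device and runs forward/backward. A minimal usage sketch (hypothetical shapes; model, device, config, and scaler as set up in main(); the real call site appears in the accumulation loop below):

    input_mb = input_ids[0:4]        # caller owns the slicing...
    target_mb = target_ids[0:4]
    lm_loss, lb_loss = train_step(model, input_mb, target_mb, device, config, scaler)
    # ...train_step handles device transfer, forward, and backward.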
@@ -328,18 +316,20 @@ def train_step(model, batch, device, config, accum_step, accum_steps, scaler=None):
             ignore_index=-1,
         )

+        # Normalize for accumulation (divide by accum_steps)
         if isinstance(lb_loss, float):
-            total_loss = lm_loss / accum_steps
+            total_loss = lm_loss / config["training"]["gradient_accumulation_steps"]
         else:
             lb_loss_coef = config["training"].get("lb_loss_coef", 0.01)
-            total_loss = (lm_loss + lb_loss_coef * lb_loss) / accum_steps
+            total_loss = (lm_loss + lb_loss_coef * lb_loss) / config["training"]["gradient_accumulation_steps"]

-    # Backward
+    # Backward pass (automatically frees graph after backward)
     if config["training"]["dtype"] == "bf16":
         scaler.scale(total_loss).backward()
     else:
         total_loss.backward()

+    # Return raw values for logging
     return lm_loss.item(), lb_loss if isinstance(lb_loss, float) else lb_loss.item()
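Why dividing each micro-batch loss by the accumulation count is the right normalization: summing gradients of loss_i / accum_steps over equal-sized micro-batches reproduces the gradient of the mean loss over the full batch. A self-contained toy check (a linear model stands in for the real one):

    import torch

    torch.manual_seed(0)
    w = torch.randn(4, requires_grad=True)
    x, y = torch.randn(8, 4), torch.randn(8)      # toy batch of 8 examples
    accum_steps, micro = 4, 2

    ((x @ w - y) ** 2).mean().backward()          # full-batch gradient
    g_full = w.grad.clone()
    w.grad = None

    for i in range(accum_steps):                  # accumulate normalized micro-losses
        xb, yb = x[i*micro:(i+1)*micro], y[i*micro:(i+1)*micro]
        (((xb @ w - yb) ** 2).mean() / accum_steps).backward()

    print(torch.allclose(w.grad, g_full, atol=1e-6))  # True

One caveat: with the remainder handling introduced further down, the last micro-batch can be larger than the others, so the equality is exact only when accum_steps divides the batch size; otherwise dividing by accum_steps mildly under-weights the examples in the larger final slice.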
@@ -349,19 +339,13 @@ def main():

     # Device setup
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
-    if torch.cuda.is_available():
-        torch.backends.cudnn.conv.fp32_precision = 'tf32'
-        torch.backends.cuda.matmul.fp32_precision = 'tf32'
+    torch.backends.cudnn.conv.fp32_precision = 'tf32'
+    torch.backends.cuda.matmul.fp32_precision = 'tf32'

     # Wandb setup
     if config["logging"]["use_wandb"] and HAS_WANDB:
-        wandb.init(
-            project=config["logging"]["project_name"],
-            name=config["logging"]["run_name"],
-            config=config,
-        )
+        wandb.init(project=config["logging"]["project_name"],
+                   name=config["logging"]["run_name"], config=config)

     # Model setup
     model, model_args = setup_model(config, device)
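The fp32_precision assignments above use the newer PyTorch precision-control API and will fail with AttributeError on builds that predate it. On older versions the same TF32 opt-in is spelled with the long-standing allow_tf32 flags (a sketch, assuming an older PyTorch; both APIs set process-wide state):

    import torch

    if torch.cuda.is_available():   # keeps the intent explicit on CPU-only machines
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True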
@@ -385,25 +369,29 @@ def main():
         step = ckpt["step"]
         print(f"✅ Resumed from step {step}\n")

-    # Gradient scaler
+    # Gradient scaler
     scaler = torch.amp.GradScaler(device='cuda', enabled=(config["training"]["dtype"] == "bf16"))

-    # Expert rotation
+    # Expert rotation
     current_expert = 0
     rotation_steps = config["training"]["expert_rotation_steps"]
-
-    # Set initial expert
     model.set_active_expert(current_expert)
     print(f"🎯 Training expert {current_expert}/{model_args.n_routed_experts - 1}")

-    #
+    # ✅ DEFINE VARIABLES HERE - outside the loop
+    accum_steps = config["training"]["gradient_accumulation_steps"]
+    total_steps = config["training"]["total_steps"]
+    grad_clip = config["training"]["grad_clip"]
+    dtype_bf16 = config["training"]["dtype"] == "bf16"
+
     print("\n" + "="*70)
     print("TRAINING STARTED")
     print("="*70 + "\n")

     model.train()

-
+    # ✅ MAIN TRAINING LOOP
+    while step < total_steps:
         step_start = time.time()

         # Expert rotation
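For reference, a hypothetical excerpt of the training config showing the keys the hoisted variables read (values are illustrative; the real ones live in the repo's config file):

    config = {
        "training": {
            "gradient_accumulation_steps": 8,   # illustrative value
            "total_steps": 10_000,
            "grad_clip": 1.0,
            "dtype": "bf16",
        },
    }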
@@ -411,42 +399,56 @@ def main():
             current_expert = (current_expert + 1) % model_args.n_routed_experts
             model.set_active_expert(current_expert)
             print(f"\n🔄 Rotating to expert {current_expert}/{model_args.n_routed_experts - 1}")
-
-            # Clear gradients after rotation
             optimizer.zero_grad(set_to_none=True)

-        # Get batch
+        # Get batch
         try:
             batch = next(train_iter)
         except StopIteration:
             train_iter = iter(train_loader)
             batch = next(train_iter)

-        #
-
-
+        # ✅ SPLIT BATCH OUTSIDE ACCUMULATION LOOP
+        input_ids, target_ids = batch
+        batch_size = input_ids.size(0)
+        micro_batch_size = batch_size // accum_steps
+
+        # Initialize accumulators
         lm_loss_accum = 0.0
         lb_loss_accum = 0.0
-
+
+        # ✅ GRADIENT ACCUMULATION LOOP
         for accum_step in range(accum_steps):
-
-
-            #
-            if
-
-
-
+            # Calculate slice indices
+            start_idx = micro_batch_size * accum_step
+
+            # Handle last micro-batch (includes remainder)
+            if accum_step == accum_steps - 1:
+                end_idx = batch_size
+            else:
+                end_idx = start_idx + micro_batch_size
+
+            # Extract micro-batch
+            input_mb = input_ids[start_idx:end_idx]
+            target_mb = target_ids[start_idx:end_idx]
+
+            # Process micro-batch
+            lm_loss, lb_loss = train_step(
+                model, input_mb, target_mb, device, config, scaler
+            )
+
+            # Accumulate losses (normalized by accum_steps)
+            lm_loss_accum += lm_loss / accum_steps
+            lb_loss_accum += lb_loss / accum_steps

         # Gradient clipping
-        if config["training"]["grad_clip"] > 0:
-            if config["training"]["dtype"] == "bf16":
+        if grad_clip > 0:
+            if dtype_bf16:
                 scaler.unscale_(optimizer)
-            torch.nn.utils.clip_grad_norm_(model.parameters(), config["training"]["grad_clip"])
+            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

         # Optimizer step
-        if config["training"]["dtype"] == "bf16":
+        if dtype_bf16:
             scaler.step(optimizer)
             scaler.update()
         else:
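A quick trace of the slicing above with hypothetical sizes: batch_size=10 and accum_steps=4 give micro_batch_size = 10 // 4 = 2, and the final slice absorbs the remainder instead of being silently dropped (or yielding an empty slice, which the deleted code had to guard against):

    batch_size, accum_steps = 10, 4               # hypothetical sizes
    micro_batch_size = batch_size // accum_steps  # -> 2
    for accum_step in range(accum_steps):
        start_idx = micro_batch_size * accum_step
        end_idx = batch_size if accum_step == accum_steps - 1 else start_idx + micro_batch_size
        print(accum_step, (start_idx, end_idx))
    # 0 (0, 2)
    # 1 (2, 4)
    # 2 (4, 6)
    # 3 (6, 10)  <- the 4-row remainder lands in the last micro-batch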
@@ -462,7 +464,7 @@ def main():
         # Logging
         if step % config["training"]["log_every"] == 0:
             step_time = time.time() - step_start
-            tokens_per_sec = (
+            tokens_per_sec = (batch_size * model_args.max_seq_len) / step_time

             print(f"Step {step:6d} | "
                   f"Loss: {lm_loss_accum:.4f} | "
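Worked example for the throughput line (hypothetical numbers): a batch of 32 sequences at max_seq_len=1024, finishing a step in 0.5 s, reports (32 × 1024) / 0.5 = 65,536 tokens/sec:

    batch_size, max_seq_len, step_time = 32, 1024, 0.5   # hypothetical
    tokens_per_sec = (batch_size * max_seq_len) / step_time
    print(f"{tokens_per_sec:,.0f} tokens/sec")           # 65,536 tokens/sec

Note this counts a full step's tokens, i.e. all accum_steps micro-batches, against the full step time.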
@@ -476,7 +478,6 @@ def main():
                 "step": step,
                 "loss": lm_loss_accum,
                 "load_balance_loss": lb_loss_accum,
-                "total_loss": total_loss_accum,
                 "learning_rate": lr,
                 "active_expert": current_expert,
                 "tokens_per_sec": tokens_per_sec,
@@ -492,7 +493,6 @@ def main():
         if config["logging"]["use_wandb"] and HAS_WANDB:
             wandb.log({"val_loss": val_loss, "val_perplexity": math.exp(val_loss)})

-        # Save best model
         if val_loss < best_val_loss:
             best_val_loss = val_loss
             save_checkpoint(model, optimizer, step, config, expert_idx="best")