Commit: 4e97b2c
Parent(s): 986639c
Commit message: Fixes
Files changed (1): Model_Architecture/train.py (+21, −14)
Model_Architecture/train.py
CHANGED
|
@@ -368,8 +368,13 @@ def main():
|
|
| 368 |
step = ckpt["step"]
|
| 369 |
print(f"✅ Resumed from step {step}\n")
|
| 370 |
|
| 371 |
-
#
|
| 372 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
|
| 374 |
# Expert rotation
|
| 375 |
current_expert = 0
|
|
@@ -377,11 +382,10 @@ def main():
|
|
| 377 |
model.set_active_expert(current_expert)
|
| 378 |
print(f"🎯 Training expert {current_expert}/{model_args.n_routed_experts - 1}")
|
| 379 |
|
| 380 |
-
#
|
| 381 |
accum_steps = config["training"]["gradient_accumulation_steps"]
|
| 382 |
total_steps = config["training"]["total_steps"]
|
| 383 |
grad_clip = config["training"]["grad_clip"]
|
| 384 |
-
dtype_bf16 = config["training"]["dtype"] == "bf16"
|
| 385 |
|
| 386 |
print("\n" + "="*70)
|
| 387 |
print("TRAINING STARTED")
|
|
@@ -389,7 +393,7 @@ def main():
|
|
| 389 |
|
| 390 |
model.train()
|
| 391 |
|
| 392 |
-
#
|
| 393 |
while step < total_steps:
|
| 394 |
step_start = time.time()
|
| 395 |
|
|
@@ -407,7 +411,7 @@ def main():
|
|
| 407 |
train_iter = iter(train_loader)
|
| 408 |
batch = next(train_iter)
|
| 409 |
|
| 410 |
-
#
|
| 411 |
input_ids, target_ids = batch
|
| 412 |
batch_size = input_ids.size(0)
|
| 413 |
micro_batch_size = batch_size // accum_steps
|
|
@@ -416,12 +420,12 @@ def main():
|
|
| 416 |
lm_loss_accum = 0.0
|
| 417 |
lb_loss_accum = 0.0
|
| 418 |
|
| 419 |
-
#
|
| 420 |
for accum_step in range(accum_steps):
|
| 421 |
# Calculate slice indices
|
| 422 |
start_idx = micro_batch_size * accum_step
|
| 423 |
|
| 424 |
-
# Handle last micro-batch
|
| 425 |
if accum_step == accum_steps - 1:
|
| 426 |
end_idx = batch_size
|
| 427 |
else:
|
|
@@ -436,22 +440,25 @@ def main():
|
|
| 436 |
model, input_mb, target_mb, device, config, scaler
|
| 437 |
)
|
| 438 |
|
| 439 |
-
# Accumulate losses
|
| 440 |
lm_loss_accum += lm_loss / accum_steps
|
| 441 |
lb_loss_accum += lb_loss / accum_steps
|
| 442 |
|
| 443 |
-
# Gradient clipping
|
| 444 |
if grad_clip > 0:
|
| 445 |
-
|
|
|
|
| 446 |
scaler.unscale_(optimizer)
|
| 447 |
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
|
| 448 |
|
| 449 |
-
#
|
| 450 |
if dtype_bf16:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
scaler.step(optimizer)
|
| 452 |
scaler.update()
|
| 453 |
-
else:
|
| 454 |
-
optimizer.step()
|
| 455 |
|
| 456 |
optimizer.zero_grad(set_to_none=True)
|
| 457 |
|
|
|
|
| 368 |
step = ckpt["step"]
|
| 369 |
print(f"✅ Resumed from step {step}\n")
|
| 370 |
|
| 371 |
+
# ✅ FIX: Only create scaler for FP16, not BF16
|
| 372 |
+
dtype_bf16 = config["training"]["dtype"] == "bf16"
|
| 373 |
+
if dtype_bf16:
|
| 374 |
+
scaler = None
|
| 375 |
+
print("⚠️ BF16 mode: Disabling GradScaler (not needed/supported)\n")
|
| 376 |
+
else:
|
| 377 |
+
scaler = torch.amp.GradScaler(device='cuda', enabled=True)
|
| 378 |
|
| 379 |
# Expert rotation
|
| 380 |
current_expert = 0
|
|
|
|
| 382 |
model.set_active_expert(current_expert)
|
| 383 |
print(f"🎯 Training expert {current_expert}/{model_args.n_routed_experts - 1}")
|
| 384 |
|
| 385 |
+
# Define variables
|
| 386 |
accum_steps = config["training"]["gradient_accumulation_steps"]
|
| 387 |
total_steps = config["training"]["total_steps"]
|
| 388 |
grad_clip = config["training"]["grad_clip"]
|
|
|
|
| 389 |
|
| 390 |
print("\n" + "="*70)
|
| 391 |
print("TRAINING STARTED")
|
|
|
|
| 393 |
|
| 394 |
model.train()
|
| 395 |
|
| 396 |
+
# MAIN TRAINING LOOP
|
| 397 |
while step < total_steps:
|
| 398 |
step_start = time.time()
|
| 399 |
|
|
|
|
| 411 |
train_iter = iter(train_loader)
|
| 412 |
batch = next(train_iter)
|
| 413 |
|
| 414 |
+
# Split batch
|
| 415 |
input_ids, target_ids = batch
|
| 416 |
batch_size = input_ids.size(0)
|
| 417 |
micro_batch_size = batch_size // accum_steps
|
|
|
|
| 420 |
lm_loss_accum = 0.0
|
| 421 |
lb_loss_accum = 0.0
|
| 422 |
|
| 423 |
+
# Gradient accumulation loop
|
| 424 |
for accum_step in range(accum_steps):
|
| 425 |
# Calculate slice indices
|
| 426 |
start_idx = micro_batch_size * accum_step
|
| 427 |
|
| 428 |
+
# Handle last micro-batch
|
| 429 |
if accum_step == accum_steps - 1:
|
| 430 |
end_idx = batch_size
|
| 431 |
else:
|
|
|
|
| 440 |
model, input_mb, target_mb, device, config, scaler
|
| 441 |
)
|
| 442 |
|
| 443 |
+
# Accumulate losses
|
| 444 |
lm_loss_accum += lm_loss / accum_steps
|
| 445 |
lb_loss_accum += lb_loss / accum_steps
|
| 446 |
|
| 447 |
+
# Gradient clipping (if enabled)
|
| 448 |
if grad_clip > 0:
|
| 449 |
+
# Skip unscale for BF16
|
| 450 |
+
if not dtype_bf16:
|
| 451 |
scaler.unscale_(optimizer)
|
| 452 |
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
|
| 453 |
|
| 454 |
+
# ✅ FIX: Conditional optimizer step
|
| 455 |
if dtype_bf16:
|
| 456 |
+
# BF16: Direct step
|
| 457 |
+
optimizer.step()
|
| 458 |
+
else:
|
| 459 |
+
# FP16: Scaled step
|
| 460 |
scaler.step(optimizer)
|
| 461 |
scaler.update()
|
|
|
|
|
|
|
| 462 |
|
| 463 |
optimizer.zero_grad(set_to_none=True)
|
| 464 |
|