ikaganacar committed on
Commit
4e97b2c
·
1 Parent(s): 986639c
Files changed (1) hide show
  1. Model_Architecture/train.py +21 -14
Model_Architecture/train.py CHANGED
@@ -368,8 +368,13 @@ def main():
368
  step = ckpt["step"]
369
  print(f"✅ Resumed from step {step}\n")
370
 
371
- # Gradient scaler
372
- scaler = torch.amp.GradScaler(device='cuda', enabled=(config["training"]["dtype"] == "bf16"))
 
 
 
 
 
373
 
374
  # Expert rotation
375
  current_expert = 0
@@ -377,11 +382,10 @@ def main():
377
  model.set_active_expert(current_expert)
378
  print(f"🎯 Training expert {current_expert}/{model_args.n_routed_experts - 1}")
379
 
380
- # DEFINE VARIABLES HERE - outside the loop
381
  accum_steps = config["training"]["gradient_accumulation_steps"]
382
  total_steps = config["training"]["total_steps"]
383
  grad_clip = config["training"]["grad_clip"]
384
- dtype_bf16 = config["training"]["dtype"] == "bf16"
385
 
386
  print("\n" + "="*70)
387
  print("TRAINING STARTED")
@@ -389,7 +393,7 @@ def main():
389
 
390
  model.train()
391
 
392
- # MAIN TRAINING LOOP
393
  while step < total_steps:
394
  step_start = time.time()
395
 
@@ -407,7 +411,7 @@ def main():
407
  train_iter = iter(train_loader)
408
  batch = next(train_iter)
409
 
410
- # SPLIT BATCH OUTSIDE ACCUMULATION LOOP
411
  input_ids, target_ids = batch
412
  batch_size = input_ids.size(0)
413
  micro_batch_size = batch_size // accum_steps
@@ -416,12 +420,12 @@ def main():
416
  lm_loss_accum = 0.0
417
  lb_loss_accum = 0.0
418
 
419
- # GRADIENT ACCUMULATION LOOP
420
  for accum_step in range(accum_steps):
421
  # Calculate slice indices
422
  start_idx = micro_batch_size * accum_step
423
 
424
- # Handle last micro-batch (includes remainder)
425
  if accum_step == accum_steps - 1:
426
  end_idx = batch_size
427
  else:
@@ -436,22 +440,25 @@ def main():
436
  model, input_mb, target_mb, device, config, scaler
437
  )
438
 
439
- # Accumulate losses (normalized by accum_steps)
440
  lm_loss_accum += lm_loss / accum_steps
441
  lb_loss_accum += lb_loss / accum_steps
442
 
443
- # Gradient clipping
444
  if grad_clip > 0:
445
- if dtype_bf16:
 
446
  scaler.unscale_(optimizer)
447
  torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
448
 
449
- # Optimizer step
450
  if dtype_bf16:
 
 
 
 
451
  scaler.step(optimizer)
452
  scaler.update()
453
- else:
454
- optimizer.step()
455
 
456
  optimizer.zero_grad(set_to_none=True)
457
 
 
368
  step = ckpt["step"]
369
  print(f"✅ Resumed from step {step}\n")
370
 
371
+ # FIX: Only create scaler for FP16, not BF16
372
+ dtype_bf16 = config["training"]["dtype"] == "bf16"
373
+ if dtype_bf16:
374
+ scaler = None
375
+ print("⚠️ BF16 mode: Disabling GradScaler (not needed/supported)\n")
376
+ else:
377
+ scaler = torch.amp.GradScaler(device='cuda', enabled=True)
378
 
379
  # Expert rotation
380
  current_expert = 0
 
382
  model.set_active_expert(current_expert)
383
  print(f"🎯 Training expert {current_expert}/{model_args.n_routed_experts - 1}")
384
 
385
+ # Define variables
386
  accum_steps = config["training"]["gradient_accumulation_steps"]
387
  total_steps = config["training"]["total_steps"]
388
  grad_clip = config["training"]["grad_clip"]
 
389
 
390
  print("\n" + "="*70)
391
  print("TRAINING STARTED")
 
393
 
394
  model.train()
395
 
396
+ # MAIN TRAINING LOOP
397
  while step < total_steps:
398
  step_start = time.time()
399
 
 
411
  train_iter = iter(train_loader)
412
  batch = next(train_iter)
413
 
414
+ # Split batch
415
  input_ids, target_ids = batch
416
  batch_size = input_ids.size(0)
417
  micro_batch_size = batch_size // accum_steps
 
420
  lm_loss_accum = 0.0
421
  lb_loss_accum = 0.0
422
 
423
+ # Gradient accumulation loop
424
  for accum_step in range(accum_steps):
425
  # Calculate slice indices
426
  start_idx = micro_batch_size * accum_step
427
 
428
+ # Handle last micro-batch
429
  if accum_step == accum_steps - 1:
430
  end_idx = batch_size
431
  else:
 
440
  model, input_mb, target_mb, device, config, scaler
441
  )
442
 
443
+ # Accumulate losses
444
  lm_loss_accum += lm_loss / accum_steps
445
  lb_loss_accum += lb_loss / accum_steps
446
 
447
+ # Gradient clipping (if enabled)
448
  if grad_clip > 0:
449
+ # Skip unscale for BF16
450
+ if not dtype_bf16:
451
  scaler.unscale_(optimizer)
452
  torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
453
 
454
+ # FIX: Conditional optimizer step
455
  if dtype_bf16:
456
+ # BF16: Direct step
457
+ optimizer.step()
458
+ else:
459
+ # FP16: Scaled step
460
  scaler.step(optimizer)
461
  scaler.update()
 
 
462
 
463
  optimizer.zero_grad(set_to_none=True)
464