ikaganacar committed
Commit 431091f · 1 Parent(s): 83a7e87

Fix Fix Fix

Files changed (1)
  1. Model_Architecture/train.py +58 -58
Model_Architecture/train.py CHANGED
@@ -293,24 +293,12 @@ def save_checkpoint(model, optimizer, step, config, expert_idx=None):
     print(f"💾 Checkpoint saved: {ckpt_path}")
 
 
-def train_step(model, batch, device, config, accum_step, accum_steps, scaler=None):
-    """Process a SINGLE micro-batch for gradient accumulation"""
-    input_ids, target_ids = batch
-
-    batch_size = input_ids.size(0)
-    micro_batch_size = max(1, batch_size // accum_steps)
-
-    # Calculate slice indices
-    start_idx = micro_batch_size * accum_step
-    end_idx = min(start_idx + micro_batch_size, batch_size)
-
-    # 🚨 CRITICAL: Skip if this micro-batch is empty (last iteration)
-    if start_idx >= batch_size:
-        return 0.0, 0.0  # Return zero loss, will be divided later
-
-    # Get micro-batch slices
-    input_mb = input_ids[start_idx:end_idx].to(device, non_blocking=True)
-    target_mb = target_ids[start_idx:end_idx].to(device, non_blocking=True)
+def train_step(model, input_mb, target_mb, device, config, scaler=None):
+    """Process a SINGLE micro-batch (already sliced)"""
+
+    # Move data to device
+    input_mb = input_mb.to(device, non_blocking=True)
+    target_mb = target_mb.to(device, non_blocking=True)
 
     # Forward pass
     with torch.amp.autocast(device_type='cuda', enabled=(config["training"]["dtype"] == "bf16")):
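
The rewritten `train_step` no longer slices the batch itself; it receives a pre-sliced micro-batch and moves it to the GPU with `non_blocking=True`. That flag only yields a truly asynchronous copy when the source tensors sit in pinned (page-locked) host memory, which the DataLoader must opt into. A minimal, self-contained sketch of that setup — the toy tensors, shapes, and loader settings here are illustrative, not this repository's data pipeline:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Toy stand-ins for the real tokenized corpus (shapes are illustrative).
    inputs = torch.randint(0, 1000, (64, 128))
    targets = torch.randint(0, 1000, (64, 128))

    # pin_memory=True page-locks the host tensors so that
    # .to(device, non_blocking=True) can overlap the copy with compute.
    loader = DataLoader(TensorDataset(inputs, targets),
                        batch_size=8, shuffle=True, pin_memory=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for input_ids, target_ids in loader:
        input_mb = input_ids.to(device, non_blocking=True)
        target_mb = target_ids.to(device, non_blocking=True)
        # ... forward/backward on the micro-batch ...
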
@@ -328,18 +316,20 @@ def train_step(model, batch, device, config, accum_step, accum_steps, scaler=None):
             ignore_index=-1,
         )
 
+    # Normalize for accumulation (divide by accum_steps)
     if isinstance(lb_loss, float):
-        total_loss = lm_loss / accum_steps
+        total_loss = lm_loss / config["training"]["gradient_accumulation_steps"]
     else:
         lb_loss_coef = config["training"].get("lb_loss_coef", 0.01)
-        total_loss = (lm_loss + lb_loss_coef * lb_loss) / accum_steps
+        total_loss = (lm_loss + lb_loss_coef * lb_loss) / config["training"]["gradient_accumulation_steps"]
 
-    # Backward
+    # Backward pass (automatically frees graph after backward)
     if config["training"]["dtype"] == "bf16":
         scaler.scale(total_loss).backward()
     else:
         total_loss.backward()
 
+    # Return raw values for logging
    return lm_loss.item(), lb_loss if isinstance(lb_loss, float) else lb_loss.item()
 
 
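The division by `gradient_accumulation_steps` before `backward()` is what makes accumulated micro-batch gradients match a single full-batch step: `backward()` sums gradients into `.grad` across calls, so scaling each micro-loss by `1/accum_steps` reproduces the full-batch mean. A standalone sketch of that equivalence with a toy weight vector (all names here are illustrative):

    import torch

    torch.manual_seed(0)
    w = torch.randn(4, requires_grad=True)
    x, y = torch.randn(8, 4), torch.randn(8)
    accum_steps = 4

    # Reference: one full-batch backward pass.
    ((x @ w - y) ** 2).mean().backward()
    g_full, w.grad = w.grad.clone(), None

    # Accumulated: four micro-batches, each loss scaled by 1/accum_steps.
    for x_mb, y_mb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
        (((x_mb @ w - y_mb) ** 2).mean() / accum_steps).backward()

    print(torch.allclose(g_full, w.grad, atol=1e-6))  # True
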
@@ -349,19 +339,13 @@ def main():
 
     # Device setup
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    # Enable TF32 for better performance on Ampere+ GPUs (using new API)
-    if torch.cuda.is_available():
-        torch.backends.cudnn.conv.fp32_precision = 'tf32'
-        torch.backends.cuda.matmul.fp32_precision = 'tf32'
+    torch.backends.cudnn.conv.fp32_precision = 'tf32'
+    torch.backends.cuda.matmul.fp32_precision = 'tf32'
 
     # Wandb setup
     if config["logging"]["use_wandb"] and HAS_WANDB:
-        wandb.init(
-            project=config["logging"]["project_name"],
-            name=config["logging"]["run_name"],
-            config=config,
-        )
+        wandb.init(project=config["logging"]["project_name"],
+                   name=config["logging"]["run_name"], config=config)
 
     # Model setup
     model, model_args = setup_model(config, device)
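
This hunk drops the `torch.cuda.is_available()` guard and keeps the newer `fp32_precision` attributes. On PyTorch builds that predate `fp32_precision`, the long-standing equivalent switches are the `allow_tf32` flags; a guarded sketch using that older, widely documented API (shown for reference, not this commit's code):

    import torch

    # Older, widely available TF32 switches, equivalent in effect to the
    # fp32_precision = 'tf32' assignments above; guarded for CPU-only builds.
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True  # TF32 in fp32 matmuls
        torch.backends.cudnn.allow_tf32 = True        # TF32 in cuDNN convolutions
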
@@ -385,25 +369,29 @@ def main():
         step = ckpt["step"]
         print(f"✅ Resumed from step {step}\n")
 
-    # Gradient scaler for mixed precision (using new torch.amp API)
+    # Gradient scaler
     scaler = torch.amp.GradScaler(device='cuda', enabled=(config["training"]["dtype"] == "bf16"))
 
-    # Expert rotation schedule
+    # Expert rotation
     current_expert = 0
     rotation_steps = config["training"]["expert_rotation_steps"]
-
-    # Set initial expert
     model.set_active_expert(current_expert)
     print(f"🎯 Training expert {current_expert}/{model_args.n_routed_experts - 1}")
 
-    # Training loop
+    # DEFINE VARIABLES HERE - outside the loop
+    accum_steps = config["training"]["gradient_accumulation_steps"]
+    total_steps = config["training"]["total_steps"]
+    grad_clip = config["training"]["grad_clip"]
+    dtype_bf16 = config["training"]["dtype"] == "bf16"
+
     print("\n" + "="*70)
     print("TRAINING STARTED")
     print("="*70 + "\n")
 
     model.train()
 
-    while step < config["training"]["total_steps"]:
+    # MAIN TRAINING LOOP
+    while step < total_steps:
         step_start = time.time()
 
         # Expert rotation
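
The optimizer-side hunks below follow the canonical `torch.amp` ordering: scale the loss, run backward, unscale before clipping so the threshold sees true gradient norms, then step and update the scaler. A minimal sketch of that sequence with a toy model (note that `GradScaler` mainly matters for fp16; bf16 shares fp32's exponent range, so scaling there is largely an unnecessary safeguard):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.nn.Linear(16, 16).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scaler = torch.amp.GradScaler(device='cuda', enabled=torch.cuda.is_available())

    x = torch.randn(4, 16, device=device)
    with torch.amp.autocast(device_type='cuda', enabled=torch.cuda.is_available()):
        loss = model(x).pow(2).mean()

    scaler.scale(loss).backward()   # backward on the scaled loss
    scaler.unscale_(optimizer)      # unscale so clipping sees true norms
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)          # skips the step on inf/NaN gradients
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
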
@@ -411,42 +399,56 @@ def main():
             current_expert = (current_expert + 1) % model_args.n_routed_experts
             model.set_active_expert(current_expert)
             print(f"\n🔄 Rotating to expert {current_expert}/{model_args.n_routed_experts - 1}")
-
-            # Clear gradients after rotation
         optimizer.zero_grad(set_to_none=True)
 
-        # Get batch with cycle handling
+        # Get batch
         try:
             batch = next(train_iter)
         except StopIteration:
             train_iter = iter(train_loader)
             batch = next(train_iter)
 
-        # Training step with gradient accumulation
-        accum_steps = config["training"]["gradient_accumulation_steps"]
-        total_loss_accum = 0.0
+        # SPLIT BATCH OUTSIDE ACCUMULATION LOOP
+        input_ids, target_ids = batch
+        batch_size = input_ids.size(0)
+        micro_batch_size = batch_size // accum_steps
+
+        # Initialize accumulators
         lm_loss_accum = 0.0
         lb_loss_accum = 0.0
-
+
+        # ✅ GRADIENT ACCUMULATION LOOP
         for accum_step in range(accum_steps):
-            lm_loss, lb_loss = train_step(model, batch, device, config,
-                                          accum_step, accum_steps, scaler)
-
-            # Only accumulate if not empty
-            if lm_loss > 0:
-                lm_loss_accum += lm_loss / accum_steps
-                lb_loss_accum += lb_loss / accum_steps
-                total_loss_accum += (lm_loss + lb_loss) / accum_steps
-
+            # Calculate slice indices
+            start_idx = micro_batch_size * accum_step
+
+            # Handle last micro-batch (includes remainder)
+            if accum_step == accum_steps - 1:
+                end_idx = batch_size
+            else:
+                end_idx = start_idx + micro_batch_size
+
+            # Extract micro-batch
+            input_mb = input_ids[start_idx:end_idx]
+            target_mb = target_ids[start_idx:end_idx]
+
+            # Process micro-batch
+            lm_loss, lb_loss = train_step(
+                model, input_mb, target_mb, device, config, scaler
+            )
+
+            # Accumulate losses (normalized by accum_steps)
+            lm_loss_accum += lm_loss / accum_steps
+            lb_loss_accum += lb_loss / accum_steps
 
         # Gradient clipping
-        if config["training"]["grad_clip"] > 0:
-            if config["training"]["dtype"] == "bf16":
+        if grad_clip > 0:
+            if dtype_bf16:
                 scaler.unscale_(optimizer)
-            torch.nn.utils.clip_grad_norm_(model.parameters(), config["training"]["grad_clip"])
+            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
 
         # Optimizer step
-        if config["training"]["dtype"] == "bf16":
+        if dtype_bf16:
            scaler.step(optimizer)
            scaler.update()
        else:
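
Folding the remainder into the last micro-batch keeps every sample when `batch_size` is not divisible by `accum_steps`. For example, `batch_size = 10` with `accum_steps = 4` gives `micro_batch_size = 2` and slice sizes 2, 2, 2, 4. A standalone sketch of just the slicing arithmetic:

    batch_size, accum_steps = 10, 4
    micro_batch_size = batch_size // accum_steps  # 2

    slices = []
    for accum_step in range(accum_steps):
        start_idx = micro_batch_size * accum_step
        # Last micro-batch absorbs the remainder.
        end_idx = batch_size if accum_step == accum_steps - 1 else start_idx + micro_batch_size
        slices.append((start_idx, end_idx))

    print(slices)  # [(0, 2), (2, 4), (4, 6), (6, 10)]

One consequence worth noting: each micro-batch contributes the same `1/accum_steps` weight to the accumulated gradient regardless of its size, so samples in an oversized final slice are individually down-weighted; with divisible batch sizes the slices are equal and the weighting is exact.
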
@@ -462,7 +464,7 @@ def main():
         # Logging
         if step % config["training"]["log_every"] == 0:
             step_time = time.time() - step_start
-            tokens_per_sec = (model_args.max_batch_size * model_args.max_seq_len) / step_time
+            tokens_per_sec = (batch_size * model_args.max_seq_len) / step_time
 
             print(f"Step {step:6d} | "
                   f"Loss: {lm_loss_accum:.4f} | "
@@ -476,7 +478,6 @@ def main():
                     "step": step,
                     "loss": lm_loss_accum,
                     "load_balance_loss": lb_loss_accum,
-                    "total_loss": total_loss_accum,
                     "learning_rate": lr,
                     "active_expert": current_expert,
                     "tokens_per_sec": tokens_per_sec,
@@ -492,7 +493,6 @@ def main():
             if config["logging"]["use_wandb"] and HAS_WANDB:
                 wandb.log({"val_loss": val_loss, "val_perplexity": math.exp(val_loss)})
 
-            # Save best model
             if val_loss < best_val_loss:
                 best_val_loss = val_loss
                 save_checkpoint(model, optimizer, step, config, expert_idx="best")
 