ikaganacar committed on
Commit
1461702
·
1 Parent(s): 36b9687
Files changed (1) hide show
  1. Model_Architecture/train.py +17 -6
Model_Architecture/train.py CHANGED
@@ -294,13 +294,20 @@ def save_checkpoint(model, optimizer, step, config, expert_idx=None):
294
 
295
 
296
  def train_step(model, batch, device, config, accum_step, accum_steps, scaler=None):
297
- """Process a MICRO-batch for gradient accumulation"""
298
  input_ids, target_ids = batch
299
 
300
- # Split batch into micro-batches
301
- micro_batch_size = input_ids.size(0) // accum_steps
 
 
 
302
  start_idx = micro_batch_size * accum_step
303
- end_idx = start_idx + micro_batch_size
 
 
 
 
304
 
305
  # Get micro-batch slices
306
  input_mb = input_ids[start_idx:end_idx].to(device, non_blocking=True)
@@ -425,8 +432,12 @@ def main():
425
  for accum_step in range(accum_steps):
426
  lm_loss, lb_loss = train_step(model, batch, device, config,
427
  accum_step, accum_steps, scaler)
428
- lm_loss_accum += lm_loss / accum_steps
429
- lb_loss_accum += lb_loss / accum_steps
 
 
 
 
430
 
431
 
432
  # Gradient clipping
 
294
 
295
 
296
  def train_step(model, batch, device, config, accum_step, accum_steps, scaler=None):
297
+ """Process a SINGLE micro-batch for gradient accumulation"""
298
  input_ids, target_ids = batch
299
 
300
+ batch_size = input_ids.size(0)
301
+ micro_batch_size = max(1, batch_size // accum_steps)
302
+ print(f"Batch size: {batch_size}, Micro-batch size: {micro_batch_size}")
303
+
304
+ # Calculate slice indices
305
  start_idx = micro_batch_size * accum_step
306
+ end_idx = min(start_idx + micro_batch_size, batch_size)
307
+
308
+ # 🚨 CRITICAL: Skip if this micro-batch is empty (last iteration)
309
+ if start_idx >= batch_size:
310
+ return 0.0, 0.0 # Return zero loss, will be divided later
311
 
312
  # Get micro-batch slices
313
  input_mb = input_ids[start_idx:end_idx].to(device, non_blocking=True)
 
432
  for accum_step in range(accum_steps):
433
  lm_loss, lb_loss = train_step(model, batch, device, config,
434
  accum_step, accum_steps, scaler)
435
+
436
+ # Only accumulate if not empty
437
+ if lm_loss > 0:
438
+ lm_loss_accum += lm_loss / accum_steps
439
+ lb_loss_accum += lb_loss / accum_steps
440
+ total_loss_accum += (lm_loss + lb_loss) / accum_steps
441
 
442
 
443
  # Gradient clipping