Commit ·
1461702
1
Parent(s): 36b9687
Fixes
Browse files- Model_Architecture/train.py +17 -6
Model_Architecture/train.py
CHANGED
|
@@ -294,13 +294,20 @@ def save_checkpoint(model, optimizer, step, config, expert_idx=None):
|
|
| 294 |
|
| 295 |
|
| 296 |
def train_step(model, batch, device, config, accum_step, accum_steps, scaler=None):
|
| 297 |
-
"""Process a
|
| 298 |
input_ids, target_ids = batch
|
| 299 |
|
| 300 |
-
|
| 301 |
-
micro_batch_size =
|
|
|
|
|
|
|
|
|
|
| 302 |
start_idx = micro_batch_size * accum_step
|
| 303 |
-
end_idx = start_idx + micro_batch_size
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
|
| 305 |
# Get micro-batch slices
|
| 306 |
input_mb = input_ids[start_idx:end_idx].to(device, non_blocking=True)
|
|
@@ -425,8 +432,12 @@ def main():
|
|
| 425 |
for accum_step in range(accum_steps):
|
| 426 |
lm_loss, lb_loss = train_step(model, batch, device, config,
|
| 427 |
accum_step, accum_steps, scaler)
|
| 428 |
-
|
| 429 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
|
| 431 |
|
| 432 |
# Gradient clipping
|
|
|
|
| 294 |
|
| 295 |
|
| 296 |
def train_step(model, batch, device, config, accum_step, accum_steps, scaler=None):
|
| 297 |
+
"""Process a SINGLE micro-batch for gradient accumulation"""
|
| 298 |
input_ids, target_ids = batch
|
| 299 |
|
| 300 |
+
batch_size = input_ids.size(0)
|
| 301 |
+
micro_batch_size = max(1, batch_size // accum_steps)
|
| 302 |
+
print(f"Batch size: {batch_size}, Micro-batch size: {micro_batch_size}")  # NOTE(review): original line referenced start_idx/end_idx/input_mb before they are defined below — NameError at runtime
|
| 303 |
+
|
| 304 |
+
# Calculate slice indices
|
| 305 |
start_idx = micro_batch_size * accum_step
|
| 306 |
+
end_idx = min(start_idx + micro_batch_size, batch_size)
|
| 307 |
+
|
| 308 |
+
# 🚨 CRITICAL: Skip if this micro-batch is empty (last iteration)
|
| 309 |
+
if start_idx >= batch_size:
|
| 310 |
+
return 0.0, 0.0 # Return zero loss, will be divided later
|
| 311 |
|
| 312 |
# Get micro-batch slices
|
| 313 |
input_mb = input_ids[start_idx:end_idx].to(device, non_blocking=True)
|
|
|
|
| 432 |
for accum_step in range(accum_steps):
|
| 433 |
lm_loss, lb_loss = train_step(model, batch, device, config,
|
| 434 |
accum_step, accum_steps, scaler)
|
| 435 |
+
|
| 436 |
+
# Only accumulate if not empty
|
| 437 |
+
if lm_loss > 0:
|
| 438 |
+
lm_loss_accum += lm_loss / accum_steps
|
| 439 |
+
lb_loss_accum += lb_loss / accum_steps
|
| 440 |
+
total_loss_accum += (lm_loss + lb_loss) / accum_steps
|
| 441 |
|
| 442 |
|
| 443 |
# Gradient clipping
|