ikaganacar committed
Commit 5db5e42 · Parent: 89731d3

Fixes - I guess
Model_Architecture/config.json CHANGED
@@ -1,26 +1,26 @@
 {
   "model": {
-    "max_batch_size": 4,
-    "max_seq_len": 1024,
+    "max_batch_size": 2,
+    "max_seq_len": 512,
     "dtype": "bf16",
     "scale_fmt": null,
     "vocab_size": 32768,
-    "dim": 768,
+    "dim": 512,
     "inter_dim": 4096,
-    "moe_inter_dim": 768,
-    "n_layers": 20,
+    "moe_inter_dim": 512,
+    "n_layers": 16,
     "n_dense_layers": 3,
     "n_heads": 12,
-    "n_routed_experts": 6,
+    "n_routed_experts": 4,
     "n_shared_experts": 1,
     "n_activated_experts": 2,
     "route_scale": 1.0,
     "use_routing_bias": true,
     "q_lora_rank": 0,
-    "kv_lora_rank": 512,
-    "qk_nope_head_dim": 128,
+    "kv_lora_rank": 256,
+    "qk_nope_head_dim": 96,
     "qk_rope_head_dim": 64,
-    "v_head_dim": 128,
+    "v_head_dim": 96,
     "original_seq_len": 4096,
     "rope_theta": 10000.0,
     "rope_factor": 40,
@@ -36,9 +36,10 @@
     "beta2": 0.95,
     "grad_clip": 1.0,
     "warmup_steps": 1000,
-    "total_steps": 50000,
-    "expert_rotation_steps": 2000,
-    "gradient_accumulation_steps": 16,
+    "total_steps": 100000,
+    "use_checkpointing": true,
+    "expert_rotation_steps": 5000,
+    "gradient_accumulation_steps": 8,
     "eval_every": 1000,
     "save_every": 5000,
     "save_dir": "./checkpoints",
 
 
Model_Architecture/train.py CHANGED
@@ -1,16 +1,20 @@
-#!/usr/bin/env python3
-"""
-Sequential Expert Training Script for MoE on Single GPU
-Memory Usage: ~7.2GB (vs 10.9GB for full MoE)
-"""
 
 import argparse
 import torch
+
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+torch.backends.cudnn.benchmark = True
+
+torch.cuda.empty_cache()
+
 import torch.nn.functional as F
 from pathlib import Path
 import json
 import time
 import math
+from torch.utils.checkpoint import checkpoint
+
 
 # Import your model
 from model import ismail, ModelArgs
@@ -143,6 +147,11 @@ def setup_model(config, device):
     #size_info = estimate_model_size(args)
 
     model = ismail(args).to(device)
+
+    if config["training"].get("use_checkpointing", True):
+        for layer in model.layers:
+            layer.forward = lambda *args, layer=layer: checkpoint(layer._forward, *args)
+        print("✅ Gradient checkpointing enabled")
 
     # Compile for speed (PyTorch 2.0+)
     if config["training"]["compile"]:
@@ -296,41 +305,48 @@ def save_checkpoint(model, optimizer, step, config, expert_idx=None):
     print(f"💾 Checkpoint saved: {ckpt_path}")
 
 
-def train_step(model, batch, device, config, scaler=None):
-    """Single training step"""
+def train_step(model, batch, device, config, accum_step, accum_steps, scaler=None):
+    """Process a MICRO-batch for gradient accumulation"""
     input_ids, target_ids = batch
-    input_ids = input_ids.to(device, non_blocking=True)
-    target_ids = target_ids.to(device, non_blocking=True)
+
+    # Split batch into micro-batches
+    micro_batch_size = input_ids.size(0) // accum_steps
+    start_idx = micro_batch_size * accum_step
+    end_idx = start_idx + micro_batch_size
+
+    # Get micro-batch slices
+    input_mb = input_ids[start_idx:end_idx].to(device, non_blocking=True)
+    target_mb = target_ids[start_idx:end_idx].to(device, non_blocking=True)
 
-    # Forward pass (using new torch.amp API)
+    # Forward pass
     with torch.amp.autocast(device_type='cuda', enabled=(config["training"]["dtype"] == "bf16")):
-        output = model(input_ids, start_pos=0)
-
-    # Handle model output (tuple in training mode with MoE, single tensor otherwise)
+        output = model(input_mb, start_pos=0)
+
     if isinstance(output, tuple):
         logits, lb_loss = output
     else:
         logits = output
         lb_loss = 0.0
-
-    # Main language modeling loss
+
     lm_loss = F.cross_entropy(
         logits.view(-1, logits.size(-1)),
-        target_ids.view(-1),
+        target_mb.view(-1),
        ignore_index=-1,
     )
-
-    # Total loss with load balancing
-    # Ensure lb_loss is a tensor for proper gradient computation
+
    if isinstance(lb_loss, float):
-        lb_loss_coef = 0.0  # If no MoE layer, no load balance loss
-        total_loss = lm_loss
+        total_loss = lm_loss / accum_steps
    else:
        lb_loss_coef = config["training"].get("lb_loss_coef", 0.01)
-        total_loss = lm_loss + lb_loss_coef * lb_loss
+        total_loss = (lm_loss + lb_loss_coef * lb_loss) / accum_steps
 
-    # Return scalar values for logging
-    return total_loss, lm_loss.item(), lb_loss if isinstance(lb_loss, float) else lb_loss.item()
+    # Backward
+    if config["training"]["dtype"] == "bf16":
+        scaler.scale(total_loss).backward()
+    else:
+        total_loss.backward()
+
+    return lm_loss.item(), lb_loss if isinstance(lb_loss, float) else lb_loss.item()
 
 
 def main():
@@ -417,24 +433,13 @@ def main():
         total_loss_accum = 0.0
         lm_loss_accum = 0.0
         lb_loss_accum = 0.0
-
-        for accum_step in range(accum_steps):
-            # Split batch for micro-batching (if needed)
-            # For now, process full batch
-            loss, lm_loss, lb_loss = train_step(model, batch, device, config, scaler)
-
-            # Normalize for accumulation
-            loss = loss / accum_steps
 
-            # Backward pass
-            if config["training"]["dtype"] == "bf16":
-                scaler.scale(loss).backward()
-            else:
-                loss.backward()
+        for accum_step in range(accum_steps):
+            lm_loss, lb_loss = train_step(model, batch, device, config,
+                                          accum_step, accum_steps, scaler)
+            lm_loss_accum += lm_loss / accum_steps
+            lb_loss_accum += lb_loss / accum_steps
 
-            total_loss_accum += loss.item()
-            lm_loss_accum += lm_loss / accum_steps  # Already a float from train_step
-            lb_loss_accum += lb_loss / accum_steps  # Already a float from train_step
 
         # Gradient clipping
         if config["training"]["grad_clip"] > 0:
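The new import-time block opts into TF32 matmuls, enables cuDNN autotuning, and clears the CUDA allocator cache. A guarded variant (a sketch, not a required change) keeps the module importable on CPU-only machines and documents why each flag is safe here:

```python
import torch

if torch.cuda.is_available():
    # TF32 trades a few mantissa bits for large matmul speedups on Ampere+.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Autotune cuDNN kernels; pays off because max_batch_size and
    # max_seq_len are fixed, so input shapes never change.
    torch.backends.cudnn.benchmark = True
    # Release cached allocator blocks left over from a previous run.
    torch.cuda.empty_cache()
```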
 
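The checkpointing hook rebinds each layer's forward so activations are recomputed during backward instead of stored, which is where most of the memory saving comes from; the `layer=layer` default argument correctly pins each lambda to its own layer instead of the loop's last one. One caveat: calling `checkpoint` without `use_reentrant` warns on recent PyTorch and falls back to the legacy reentrant path. A self-contained sketch of the same pattern on a toy block (`_forward` is assumed, as the commit assumes it exists on every model layer; `use_reentrant=False` is my addition, being the variant PyTorch currently recommends):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    """Toy stand-in for one transformer layer with a separate _forward body."""
    def __init__(self, dim: int = 512):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def _forward(self, x):
        return x + self.ff(x)

    def forward(self, x):
        return self._forward(x)

layers = nn.ModuleList(Block() for _ in range(3))
for layer in layers:
    # Recompute this layer's activations in backward instead of keeping them.
    layer.forward = lambda *args, layer=layer: checkpoint(
        layer._forward, *args, use_reentrant=False
    )

x = torch.randn(2, 16, 512, requires_grad=True)
for layer in layers:
    x = layer(x)
x.sum().backward()  # each _forward runs a second time here to rebuild activations
```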
 
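train_step now slices the host-side batch into accum_steps micro-batches, moves only one slice to the GPU at a time, scales each loss by 1/accum_steps, and runs backward inside the function, so at most one micro-batch of activations is alive at once. A quick self-contained check (toy linear model, not the repo's ismail) that this accumulation matches a single full-batch step for a mean-reduced loss over equal-sized slices:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(8, 4)
x, y = torch.randn(16, 8), torch.randint(0, 4, (16,))

# Reference: one full-batch backward.
model.zero_grad()
F.cross_entropy(model(x), y).backward()
ref = model.weight.grad.clone()

# Accumulation: four micro-batches, each loss scaled by 1/accum_steps.
model.zero_grad()
accum_steps = 4
mb = x.size(0) // accum_steps
for step in range(accum_steps):
    s, e = step * mb, (step + 1) * mb
    (F.cross_entropy(model(x[s:e]), y[s:e]) / accum_steps).backward()

print(torch.allclose(ref, model.weight.grad, atol=1e-6))  # True
```

One leftover in main(): total_loss_accum is still initialized but the new loop never updates it, so logging now has to rely on lm_loss_accum and lb_loss_accum.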
 
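One carried-over wrinkle in train_step: the bf16 branch still routes backward through scaler.scale(). GradScaler exists for fp16's narrow exponent range; bf16 shares fp32's exponent range, so loss scaling is normally unnecessary. The code stays correct provided the scaler is constructed disabled (how main() builds it is not shown in this diff, so that is an assumption on my part), because a disabled scaler's scale() is a pass-through:

```python
import torch

# A disabled GradScaler makes scale() return the loss unchanged, so
# scaler.scale(loss).backward() degenerates to plain loss.backward().
scaler = torch.amp.GradScaler('cuda', enabled=False)

w = torch.randn(3, requires_grad=True)
loss = (w * w).sum()
scaler.scale(loss).backward()
print(w.grad)  # identical to loss.backward(): 2 * w
```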