ikaganacar committed on
Commit
1c3be6f
·
1 Parent(s): 431091f

Fixes are lies

Browse files
Model_Architecture/config.json CHANGED
@@ -1,6 +1,6 @@
1
  {
2
  "model": {
3
- "max_batch_size": 2,
4
  "max_seq_len": 512,
5
  "dtype": "bf16",
6
  "scale_fmt": null,
@@ -37,7 +37,7 @@
37
  "grad_clip": 1.0,
38
  "warmup_steps": 1000,
39
  "total_steps": 100000,
40
- "use_checkpointing": true,
41
  "expert_rotation_steps": 5000,
42
  "gradient_accumulation_steps": 8,
43
  "eval_every": 1000,
@@ -45,7 +45,7 @@
45
  "save_dir": "./checkpoints",
46
  "log_every": 100,
47
  "dtype": "bf16",
48
- "compile": true
49
  },
50
  "data": {
51
  "train_file": "./data/train.txt",
 
1
  {
2
  "model": {
3
+ "max_batch_size": 8,
4
  "max_seq_len": 512,
5
  "dtype": "bf16",
6
  "scale_fmt": null,
 
37
  "grad_clip": 1.0,
38
  "warmup_steps": 1000,
39
  "total_steps": 100000,
40
+ "use_checkpointing": false,
41
  "expert_rotation_steps": 5000,
42
  "gradient_accumulation_steps": 8,
43
  "eval_every": 1000,
 
45
  "save_dir": "./checkpoints",
46
  "log_every": 100,
47
  "dtype": "bf16",
48
+ "compile": false
49
  },
50
  "data": {
51
  "train_file": "./data/train.txt",
Model_Architecture/model.py CHANGED
@@ -520,7 +520,6 @@ class ismail(nn.Module):
520
  h = self.tok_embeddings(tokens).to(Linear.dtype)
521
  freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]
522
 
523
- # Create causal mask
524
  mask = None
525
  if seqlen > 1:
526
  mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
@@ -528,25 +527,16 @@ class ismail(nn.Module):
528
  mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
529
 
530
  total_lb_loss = 0.0
531
-
 
532
  for layer in self.layers:
533
- layer.start_pos = start_pos
534
- layer.freqs_cis = freqs_cis
535
- layer.mask = mask
536
-
537
- if self.training and self.use_checkpointing:
538
- from torch.utils.checkpoint import checkpoint
539
- h, lb_loss = checkpoint(layer.checkpoint_forward, h, use_reentrant=False )
540
- else:
541
- h, lb_loss = layer(h, start_pos, freqs_cis, mask)
542
-
543
  if lb_loss is not None:
544
  total_lb_loss += lb_loss
545
 
546
  h = self.norm(h)
547
  output = self.output(h)
548
 
549
- # Return output and total load balancing loss if training
550
  if self.training and total_lb_loss > 0:
551
  return output, total_lb_loss
552
  return output
 
520
  h = self.tok_embeddings(tokens).to(Linear.dtype)
521
  freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]
522
 
 
523
  mask = None
524
  if seqlen > 1:
525
  mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
 
527
  mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
528
 
529
  total_lb_loss = 0.0
530
+
531
+ # ✅ SIMPLE forward pass - no checkpointing
532
  for layer in self.layers:
533
+ h, lb_loss = layer(h, start_pos, freqs_cis, mask)
 
 
 
 
 
 
 
 
 
534
  if lb_loss is not None:
535
  total_lb_loss += lb_loss
536
 
537
  h = self.norm(h)
538
  output = self.output(h)
539
 
 
540
  if self.training and total_lb_loss > 0:
541
  return output, total_lb_loss
542
  return output
Model_Architecture/train.py CHANGED
@@ -13,12 +13,10 @@ from pathlib import Path
13
  import json
14
  import time
15
  import math
16
- from torch.utils.checkpoint import checkpoint
17
 
18
 
19
  # Import your model
20
  from model import ismail, ModelArgs
21
- from model_size import estimate_model_size
22
 
23
  # Try to import optional dependencies
24
  try:
 
13
  import json
14
  import time
15
  import math
 
16
 
17
 
18
  # Import your model
19
  from model import ismail, ModelArgs
 
20
 
21
  # Try to import optional dependencies
22
  try: