ikaganacar commited on
Commit
36b9687
·
1 Parent(s): 7e94c65
Model_Architecture/model.py CHANGED
@@ -505,6 +505,7 @@ class ismail(nn.Module):
505
  self.layers = nn.ModuleList([Block(i, args) for i in range(args.n_layers)])
506
  self.norm = RMSNorm(args.dim)
507
  self.output = Linear(args.dim, args.vocab_size, bias=False)
 
508
 
509
  self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
510
 
@@ -527,13 +528,13 @@ class ismail(nn.Module):
527
  mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
528
 
529
  total_lb_loss = 0.0
530
-
531
  for layer in self.layers:
532
  layer.start_pos = start_pos
533
  layer.freqs_cis = freqs_cis
534
  layer.mask = mask
535
 
536
- if self.training and True: # Enable gradient checkpointing during training
537
  from torch.utils.checkpoint import checkpoint
538
  h, lb_loss = checkpoint(layer.checkpoint_forward, h)
539
  else:
 
505
  self.layers = nn.ModuleList([Block(i, args) for i in range(args.n_layers)])
506
  self.norm = RMSNorm(args.dim)
507
  self.output = Linear(args.dim, args.vocab_size, bias=False)
508
+ self.use_checkpointing = False
509
 
510
  self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
511
 
 
528
  mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
529
 
530
  total_lb_loss = 0.0
531
+
532
  for layer in self.layers:
533
  layer.start_pos = start_pos
534
  layer.freqs_cis = freqs_cis
535
  layer.mask = mask
536
 
537
+ if self.training and self.use_checkpointing:
538
  from torch.utils.checkpoint import checkpoint
539
  h, lb_loss = checkpoint(layer.checkpoint_forward, h)
540
  else:
Model_Architecture/train.py CHANGED
@@ -136,28 +136,16 @@ def load_config(args):
136
 
137
 
138
  def setup_model(config, device):
139
- """Initialize model and print size estimate"""
140
  args = ModelArgs(**config["model"])
141
-
142
- print("\n" + "="*70)
143
- print("MODEL INITIALIZATION")
144
- print("="*70 + "\n")
145
-
146
- # Estimate size
147
- #size_info = estimate_model_size(args)
148
-
149
  model = ismail(args).to(device)
150
-
151
- if config["training"].get("use_checkpointing", True):
152
- for layer in model.layers:
153
- layer.forward = lambda *args, layer=layer: checkpoint(layer._forward, *args)
154
- print("✅ Gradient checkpointing enabled")
155
 
156
- # Compile for speed (PyTorch 2.0+)
 
 
157
  if config["training"]["compile"]:
158
  try:
159
  model = torch.compile(model)
160
- print("✅ Model compiled with torch.compile()\n")
161
  except Exception as e:
162
  print(f"⚠️ Compilation failed: {e}\n")
163
 
 
136
 
137
 
138
  def setup_model(config, device):
 
139
  args = ModelArgs(**config["model"])
 
 
 
 
 
 
 
 
140
  model = ismail(args).to(device)
 
 
 
 
 
141
 
142
+ # Add this line to enable checkpointing
143
+ model.use_checkpointing = config["training"].get("use_checkpointing", True)
144
+
145
  if config["training"]["compile"]:
146
  try:
147
  model = torch.compile(model)
148
+ print("✅ Model compiled\n")
149
  except Exception as e:
150
  print(f"⚠️ Compilation failed: {e}\n")
151