ikaganacar committed
Commit e38bdfb · 1 Parent(s): 5db5e42
Model_Architecture/model.py CHANGED
@@ -484,6 +484,10 @@ class Block(nn.Module):
 
         x = x + ffn_out
         return x, lb_loss
+
+    def checkpoint_forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Wrapper for gradient checkpointing that captures other args"""
+        return self.forward(x, self.start_pos, self.freqs_cis, self.mask)
 
 
 #####################################
@@ -491,11 +495,12 @@ class Block(nn.Module):
 #####################################
 
 class ismail(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: ModelArgs, use_checkpointing: bool = False):
         super().__init__()
         self.args = args
         self.vocab_size = args.vocab_size
         self.n_layers = args.n_layers
+        self.use_checkpointing = use_checkpointing
 
         self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
         self.layers = nn.ModuleList([Block(i, args) for i in range(args.n_layers)])
@@ -523,8 +528,18 @@ class ismail(nn.Module):
             mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
 
         total_lb_loss = 0.0
+
         for layer in self.layers:
-            h, lb_loss = layer(h, start_pos, freqs_cis, mask)
+            layer.start_pos = start_pos
+            layer.freqs_cis = freqs_cis
+            layer.mask = mask
+
+            if self.training and self.use_checkpointing:
+                from torch.utils.checkpoint import checkpoint
+                h, lb_loss = checkpoint(layer.checkpoint_forward, h)
+            else:
+                h, lb_loss = layer(h, start_pos, freqs_cis, mask)
+
             if lb_loss is not None:
                 total_lb_loss += lb_loss
 
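The wrapper pattern above exists because torch.utils.checkpoint.checkpoint replays the wrapped function during the backward pass, and it is simplest to route only the activation tensor through it: the commit stashes start_pos, freqs_cis, and mask as attributes on each Block and checkpoints a one-argument bound method. A minimal runnable sketch of the same pattern, using a hypothetical ToyBlock in place of the project's Block:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for the project's Block; only the
# attribute-stashing pattern is the point here.
class ToyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.mask = None  # set by the caller before the checkpointed call

    def forward(self, x: torch.Tensor, mask) -> torch.Tensor:
        out = torch.relu(self.proj(x))
        if mask is not None:
            out = out + mask
        return out

    def checkpoint_forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only the activation tensor flows through checkpoint();
        # the extra argument is read from an attribute instead.
        return self.forward(x, self.mask)

block = ToyBlock(16)
x = torch.randn(4, 16, requires_grad=True)
block.mask = torch.zeros(4, 16)

# use_reentrant=False selects the non-reentrant implementation that
# the PyTorch docs recommend for new code.
y = checkpoint(block.checkpoint_forward, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)  # torch.Size([4, 16])

Note that recent PyTorch versions warn unless use_reentrant is passed explicitly; the commit's checkpoint(layer.checkpoint_forward, h) call omits it, so adding use_reentrant=False there would pin the recommended behavior and silence the warning.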
 
Model_Architecture/train.py CHANGED
@@ -145,12 +145,11 @@ def setup_model(config, device):
 
     # Estimate size
     #size_info = estimate_model_size(args)
-
-    model = ismail(args).to(device)
 
-    if config["training"].get("use_checkpointing", True):
-        for layer in model.layers:
-            layer.forward = lambda *args, layer=layer: checkpoint(layer._forward, *args)
+    use_checkpointing = config["training"].get("use_checkpointing", False)
+    model = ismail(args, use_checkpointing=use_checkpointing).to(device)
+
+    if use_checkpointing:
         print("✅ Gradient checkpointing enabled")
 
     # Compile for speed (PyTorch 2.0+)
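On the train.py side, the removed code monkey-patched each layer's forward with a lambda calling checkpoint(layer._forward, *args), but no _forward method appears anywhere in this diff, and per-instance forward patching also tends to break tracing and compilation; the new code instead threads a use_checkpointing flag from the training config into the model constructor. A minimal sketch of that flow, with a hypothetical in-memory config dict standing in for whatever file the project actually loads:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Hypothetical config dict; the real project presumably loads this from
# a file before calling setup_model(config, device).
config = {"training": {"use_checkpointing": True}}

class TinyStack(nn.Module):
    # Stand-in for ismail: checkpoint each layer only when the model is
    # training AND the flag is set, mirroring the commit's control flow.
    def __init__(self, dim: int, n_layers: int, use_checkpointing: bool = False):
        super().__init__()
        self.use_checkpointing = use_checkpointing
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if self.training and self.use_checkpointing:
                # Intermediate activations are freed after forward and
                # recomputed during backward, trading compute for memory.
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return x

use_checkpointing = config["training"].get("use_checkpointing", False)
model = TinyStack(dim=32, n_layers=4, use_checkpointing=use_checkpointing)
if use_checkpointing:
    print("✅ Gradient checkpointing enabled")

model.train()
out = model(torch.randn(8, 32, requires_grad=True))
out.mean().backward()

Because the gate in model.py also checks self.training, switching to model.eval() for inference bypasses the recompute path without touching the config.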