Commit e38bdfb
Parent(s): 5db5e42
Fixes

Files changed:
- Model_Architecture/model.py +17 -2
- Model_Architecture/train.py +4 -5
Model_Architecture/model.py
@@ -484,6 +484,10 @@ class Block(nn.Module):
 
         x = x + ffn_out
         return x, lb_loss
+
+    def checkpoint_forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Wrapper for gradient checkpointing that captures other args"""
+        return self.forward(x, self.start_pos, self.freqs_cis, self.mask)
 
 
 #####################################
@@ -491,11 +495,12 @@ class Block(nn.Module):
 #####################################
 
 class ismail(nn.Module):
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: ModelArgs, use_checkpointing: bool = False):
         super().__init__()
         self.args = args
         self.vocab_size = args.vocab_size
         self.n_layers = args.n_layers
+        self.use_checkpointing = use_checkpointing
 
         self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
         self.layers = nn.ModuleList([Block(i, args) for i in range(args.n_layers)])
@@ -523,8 +528,18 @@ class ismail(nn.Module):
         mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
 
         total_lb_loss = 0.0
+
         for layer in self.layers:
-            h, lb_loss = layer(h, start_pos, freqs_cis, mask)
+            layer.start_pos = start_pos
+            layer.freqs_cis = freqs_cis
+            layer.mask = mask
+
+            if self.training and self.use_checkpointing:
+                from torch.utils.checkpoint import checkpoint
+                h, lb_loss = checkpoint(layer.checkpoint_forward, h)
+            else:
+                h, lb_loss = layer(h, start_pos, freqs_cis, mask)
+
             if lb_loss is not None:
                 total_lb_loss += lb_loss
 
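Note on the pattern: torch.utils.checkpoint.checkpoint frees a block's intermediate activations after the forward pass and recomputes them during backward, trading compute for memory. Because the checkpointed callable is replayed later, everything it needs must still be reachable at replay time; stashing the per-step extras (start_pos, freqs_cis, mask) on each layer and checkpointing a single-tensor wrapper, as this commit does, is one common way to arrange that. Below is a minimal, self-contained sketch of the same pattern. TinyBlock and its fields are invented for illustration; only torch.utils.checkpoint.checkpoint is real API.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.ffn = nn.Linear(dim, dim)
        self.mask = None  # stashed by the caller, like layer.mask in the commit

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        return x + torch.relu(self.ffn(x)) * mask

    def checkpoint_forward(self, x: torch.Tensor) -> torch.Tensor:
        # Single-tensor entry point that reads back the stashed state,
        # mirroring Block.checkpoint_forward above.
        return self.forward(x, self.mask)


block = TinyBlock(8)
block.mask = torch.ones(2, 8)
x = torch.randn(2, 8, requires_grad=True)
y = checkpoint(block.checkpoint_forward, x, use_reentrant=False)
y.sum().backward()  # activations inside the block are recomputed here

One detail worth flagging: recent PyTorch releases warn when checkpoint is called without an explicit use_reentrant argument (as in the committed call) and recommend use_reentrant=False, so the sketch passes it explicitly.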
Model_Architecture/train.py
@@ -145,12 +145,11 @@ def setup_model(config, device):
 
     # Estimate size
     #size_info = estimate_model_size(args)
-
-    model = ismail(args).to(device)
 
-
-
-
+    use_checkpointing = config["training"].get("use_checkpointing", False)
+    model = ismail(args, use_checkpointing=use_checkpointing).to(device)
+
+    if use_checkpointing:
         print("✅ Gradient checkpointing enabled")
 
     # Compile for speed (PyTorch 2.0+)
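On the train.py side, the flag is simply threaded through from the run config. A hypothetical sketch of the plumbing follows; the nested dict shape is inferred from the config["training"].get("use_checkpointing", False) lookup above, and everything else is illustrative (a YAML or JSON config with the same structure would load into this dict).

config = {"training": {"use_checkpointing": True}}

use_checkpointing = config["training"].get("use_checkpointing", False)
# model = ismail(args, use_checkpointing=use_checkpointing).to(device)  # as in the diff
if use_checkpointing:
    print("✅ Gradient checkpointing enabled")

Because the model's forward pass gates on self.training and self.use_checkpointing, the recompute path is active only under model.train(); calling model.eval() for validation or inference falls back to the plain layer call with no extra compute.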