Commit
·
36b9687
1
Parent(s):
7e94c65
Fixes
Browse files
- Model_Architecture/model.py +3 -2
- Model_Architecture/train.py +4 -16
Model_Architecture/model.py
CHANGED
|
@@ -505,6 +505,7 @@ class ismail(nn.Module):
|
|
| 505 |
self.layers = nn.ModuleList([Block(i, args) for i in range(args.n_layers)])
|
| 506 |
self.norm = RMSNorm(args.dim)
|
| 507 |
self.output = Linear(args.dim, args.vocab_size, bias=False)
|
|
|
|
| 508 |
|
| 509 |
self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
|
| 510 |
|
|
@@ -527,13 +528,13 @@ class ismail(nn.Module):
|
|
| 527 |
mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
|
| 528 |
|
| 529 |
total_lb_loss = 0.0
|
| 530 |
-
|
| 531 |
for layer in self.layers:
|
| 532 |
layer.start_pos = start_pos
|
| 533 |
layer.freqs_cis = freqs_cis
|
| 534 |
layer.mask = mask
|
| 535 |
|
| 536 |
-
if self.training and
|
| 537 |
from torch.utils.checkpoint import checkpoint
|
| 538 |
h, lb_loss = checkpoint(layer.checkpoint_forward, h)
|
| 539 |
else:
|
|
|
|
| 505 |
self.layers = nn.ModuleList([Block(i, args) for i in range(args.n_layers)])
|
| 506 |
self.norm = RMSNorm(args.dim)
|
| 507 |
self.output = Linear(args.dim, args.vocab_size, bias=False)
|
| 508 |
+
self.use_checkpointing = False
|
| 509 |
|
| 510 |
self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
|
| 511 |
|
|
|
|
| 528 |
mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
|
| 529 |
|
| 530 |
total_lb_loss = 0.0
|
| 531 |
+
|
| 532 |
for layer in self.layers:
|
| 533 |
layer.start_pos = start_pos
|
| 534 |
layer.freqs_cis = freqs_cis
|
| 535 |
layer.mask = mask
|
| 536 |
|
| 537 |
+
if self.training and self.use_checkpointing:
|
| 538 |
from torch.utils.checkpoint import checkpoint
|
| 539 |
h, lb_loss = checkpoint(layer.checkpoint_forward, h)
|
| 540 |
else:
|
Model_Architecture/train.py
CHANGED
|
@@ -136,28 +136,16 @@ def load_config(args):
|
|
| 136 |
|
| 137 |
|
| 138 |
def setup_model(config, device):
|
| 139 |
-
"""Initialize model and print size estimate"""
|
| 140 |
args = ModelArgs(**config["model"])
|
| 141 |
-
|
| 142 |
-
print("\n" + "="*70)
|
| 143 |
-
print("MODEL INITIALIZATION")
|
| 144 |
-
print("="*70 + "\n")
|
| 145 |
-
|
| 146 |
-
# Estimate size
|
| 147 |
-
#size_info = estimate_model_size(args)
|
| 148 |
-
|
| 149 |
model = ismail(args).to(device)
|
| 150 |
-
|
| 151 |
-
if config["training"].get("use_checkpointing", True):
|
| 152 |
-
for layer in model.layers:
|
| 153 |
-
layer.forward = lambda *args, layer=layer: checkpoint(layer._forward, *args)
|
| 154 |
-
print("✅ Gradient checkpointing enabled")
|
| 155 |
|
| 156 |
-
#
|
|
|
|
|
|
|
| 157 |
if config["training"]["compile"]:
|
| 158 |
try:
|
| 159 |
model = torch.compile(model)
|
| 160 |
-
print("✅ Model compiled
|
| 161 |
except Exception as e:
|
| 162 |
print(f"⚠️ Compilation failed: {e}\n")
|
| 163 |
|
|
|
|
| 136 |
|
| 137 |
|
| 138 |
def setup_model(config, device):
|
|
|
|
| 139 |
args = ModelArgs(**config["model"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
model = ismail(args).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
+
# Turn gradient checkpointing on or off from the training config (defaults to on when the key is absent)
|
| 143 |
+
model.use_checkpointing = config["training"].get("use_checkpointing", True)
|
| 144 |
+
|
| 145 |
if config["training"]["compile"]:
|
| 146 |
try:
|
| 147 |
model = torch.compile(model)
|
| 148 |
+
print("✅ Model compiled\n")
|
| 149 |
except Exception as e:
|
| 150 |
print(f"⚠️ Compilation failed: {e}\n")
|
| 151 |
|