Commit
·
1c3be6f
1
Parent(s):
431091f
Fixes are lies
Browse files
- Model_Architecture/config.json +3 -3
- Model_Architecture/model.py +3 -13
- Model_Architecture/train.py +0 -2
Model_Architecture/config.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
{
|
| 2 |
"model": {
|
| 3 |
-
"max_batch_size":
|
| 4 |
"max_seq_len": 512,
|
| 5 |
"dtype": "bf16",
|
| 6 |
"scale_fmt": null,
|
|
@@ -37,7 +37,7 @@
|
|
| 37 |
"grad_clip": 1.0,
|
| 38 |
"warmup_steps": 1000,
|
| 39 |
"total_steps": 100000,
|
| 40 |
-
"use_checkpointing":
|
| 41 |
"expert_rotation_steps": 5000,
|
| 42 |
"gradient_accumulation_steps": 8,
|
| 43 |
"eval_every": 1000,
|
|
@@ -45,7 +45,7 @@
|
|
| 45 |
"save_dir": "./checkpoints",
|
| 46 |
"log_every": 100,
|
| 47 |
"dtype": "bf16",
|
| 48 |
-
"compile":
|
| 49 |
},
|
| 50 |
"data": {
|
| 51 |
"train_file": "./data/train.txt",
|
|
|
|
| 1 |
{
|
| 2 |
"model": {
|
| 3 |
+
"max_batch_size": 8,
|
| 4 |
"max_seq_len": 512,
|
| 5 |
"dtype": "bf16",
|
| 6 |
"scale_fmt": null,
|
|
|
|
| 37 |
"grad_clip": 1.0,
|
| 38 |
"warmup_steps": 1000,
|
| 39 |
"total_steps": 100000,
|
| 40 |
+
"use_checkpointing": false,
|
| 41 |
"expert_rotation_steps": 5000,
|
| 42 |
"gradient_accumulation_steps": 8,
|
| 43 |
"eval_every": 1000,
|
|
|
|
| 45 |
"save_dir": "./checkpoints",
|
| 46 |
"log_every": 100,
|
| 47 |
"dtype": "bf16",
|
| 48 |
+
"compile": false
|
| 49 |
},
|
| 50 |
"data": {
|
| 51 |
"train_file": "./data/train.txt",
|
Model_Architecture/model.py
CHANGED
|
@@ -520,7 +520,6 @@ class ismail(nn.Module):
|
|
| 520 |
h = self.tok_embeddings(tokens).to(Linear.dtype)
|
| 521 |
freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]
|
| 522 |
|
| 523 |
-
# Create causal mask
|
| 524 |
mask = None
|
| 525 |
if seqlen > 1:
|
| 526 |
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
|
|
@@ -528,25 +527,16 @@ class ismail(nn.Module):
|
|
| 528 |
mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
|
| 529 |
|
| 530 |
total_lb_loss = 0.0
|
| 531 |
-
|
|
|
|
| 532 |
for layer in self.layers:
|
| 533 |
-
layer
|
| 534 |
-
layer.freqs_cis = freqs_cis
|
| 535 |
-
layer.mask = mask
|
| 536 |
-
|
| 537 |
-
if self.training and self.use_checkpointing:
|
| 538 |
-
from torch.utils.checkpoint import checkpoint
|
| 539 |
-
h, lb_loss = checkpoint(layer.checkpoint_forward, h, use_reentrant=False )
|
| 540 |
-
else:
|
| 541 |
-
h, lb_loss = layer(h, start_pos, freqs_cis, mask)
|
| 542 |
-
|
| 543 |
if lb_loss is not None:
|
| 544 |
total_lb_loss += lb_loss
|
| 545 |
|
| 546 |
h = self.norm(h)
|
| 547 |
output = self.output(h)
|
| 548 |
|
| 549 |
-
# Return output and total load balancing loss if training
|
| 550 |
if self.training and total_lb_loss > 0:
|
| 551 |
return output, total_lb_loss
|
| 552 |
return output
|
|
|
|
| 520 |
h = self.tok_embeddings(tokens).to(Linear.dtype)
|
| 521 |
freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]
|
| 522 |
|
|
|
|
| 523 |
mask = None
|
| 524 |
if seqlen > 1:
|
| 525 |
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
|
|
|
|
| 527 |
mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
|
| 528 |
|
| 529 |
total_lb_loss = 0.0
|
| 530 |
+
|
| 531 |
+
# ✅ SIMPLE forward pass - no checkpointing
|
| 532 |
for layer in self.layers:
|
| 533 |
+
h, lb_loss = layer(h, start_pos, freqs_cis, mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 534 |
if lb_loss is not None:
|
| 535 |
total_lb_loss += lb_loss
|
| 536 |
|
| 537 |
h = self.norm(h)
|
| 538 |
output = self.output(h)
|
| 539 |
|
|
|
|
| 540 |
if self.training and total_lb_loss > 0:
|
| 541 |
return output, total_lb_loss
|
| 542 |
return output
|
Model_Architecture/train.py
CHANGED
|
@@ -13,12 +13,10 @@ from pathlib import Path
|
|
| 13 |
import json
|
| 14 |
import time
|
| 15 |
import math
|
| 16 |
-
from torch.utils.checkpoint import checkpoint
|
| 17 |
|
| 18 |
|
| 19 |
# Import your model
|
| 20 |
from model import ismail, ModelArgs
|
| 21 |
-
from model_size import estimate_model_size
|
| 22 |
|
| 23 |
# Try to import optional dependencies
|
| 24 |
try:
|
|
|
|
| 13 |
import json
|
| 14 |
import time
|
| 15 |
import math
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
# Import your model
|
| 19 |
from model import ismail, ModelArgs
|
|
|
|
| 20 |
|
| 21 |
# Try to import optional dependencies
|
| 22 |
try:
|