Upload load_JiRack5_SlimPajama_3b_safetensors.py
Browse files
load_JiRack5_SlimPajama_3b_safetensors.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 CMS Manhattan
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
# Author: Konstantin Vladimirovich Grabko
|
| 4 |
+
# Email: grabko@cmsmanhattan.com
|
| 5 |
+
# Phone: +1(516)777-0945
|
| 6 |
+
#
|
| 7 |
+
# This program is free software: you can redistribute it and/or modify
|
| 8 |
+
# it under the terms of the GNU General Public License as published by
|
| 9 |
+
# the Free Software Foundation, version 3 of the License.
|
| 10 |
+
#
|
| 11 |
+
# This program is distributed in the hope that it will be useful,
|
| 12 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 13 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 14 |
+
# GNU General Public License for more details.
|
| 15 |
+
#
|
| 16 |
+
# You should have received a copy of the GNU General Public License
|
| 17 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 18 |
+
#
|
| 19 |
+
# Additional terms:
|
| 20 |
+
# Any commercial use or distribution of this software or derivative works
|
| 21 |
+
# requires explicit written permission from the copyright holder.
|
| 22 |
+
|
| 23 |
+
import os
|
| 24 |
+
import torch
|
| 25 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 26 |
+
from datasets import load_dataset
|
| 27 |
+
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
|
| 28 |
+
from safetensors.torch import save_file, load_file
|
| 29 |
+
from JiRackPyTorch_GPT5_class_3b import JiRackPyTorch, VOCAB_SIZE, MAX_SEQ_LEN
|
| 30 |
+
|
| 31 |
+
# --- Checkpoint save settings ---
# Weights are split across two safetensors shards using the standard
# Hugging Face shard naming scheme (no index JSON is written; loading
# is handled by load_sharded_safetensors below).
SAVE_DIR = "jirack_weights"
SHARD_1 = "model-00001-of-00002.safetensors"  # first half of the state-dict keys
SHARD_2 = "model-00002-of-00002.safetensors"  # second half of the state-dict keys
|
| 35 |
+
|
| 36 |
+
def save_sharded_safetensors(model, directory):
    """Save the model's state dict as two safetensors shards in *directory*.

    Keys are split evenly by count (not by byte size) between the files
    named by SHARD_1 and SHARD_2. The directory is created when missing.

    Args:
        model: torch.nn.Module whose state_dict() is checkpointed.
        directory: target directory for the two shard files.
    """
    os.makedirs(directory, exist_ok=True)
    # Clone every tensor onto the CPU before saving: safetensors'
    # save_file raises on tensors that share storage (e.g. tied
    # embedding / LM-head weights) and on non-contiguous tensors;
    # cloning also detaches the copies from the autograd graph.
    state_dict = {
        k: v.detach().cpu().clone().contiguous()
        for k, v in model.state_dict().items()
    }
    keys = list(state_dict.keys())
    mid = len(keys) // 2

    shard1 = {k: state_dict[k] for k in keys[:mid]}
    shard2 = {k: state_dict[k] for k in keys[mid:]}

    save_file(shard1, os.path.join(directory, SHARD_1))
    save_file(shard2, os.path.join(directory, SHARD_2))
    print(f"--- [CHECKPOINT] Model shards saved to {directory} ---")
|
| 48 |
+
|
| 49 |
+
def load_sharded_safetensors(model, directory):
    """Restore model weights from the two shard files in *directory*.

    Returns:
        True when both shard files exist and were loaded into *model*;
        False when either file is missing (caller starts from scratch).
    """
    shard_paths = [os.path.join(directory, name) for name in (SHARD_1, SHARD_2)]
    if not all(os.path.exists(path) for path in shard_paths):
        return False

    print(f"--- [RESUME] Loading weights from {directory} ---")
    merged = {}
    for path in shard_paths:
        merged.update(load_file(path))
    model.load_state_dict(merged)
    return True
|
| 60 |
+
|
| 61 |
+
def train():
    """Run the streaming-SlimPajama training loop for JiRackPyTorch.

    Resumes from sharded safetensors checkpoints in SAVE_DIR when they
    exist, trains with bf16 autocast + gradient accumulation, and
    re-saves shards every 500 optimizer steps.
    """
    # 1. Setup device & model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = JiRackPyTorch().to(device)

    # Authorship banner before training starts
    print("--- JiRack Engine Start ---")
    print(f"--- {model.get_author_info()} ---")

    # Try to resume from a previous checkpoint (no-op when absent)
    load_sharded_safetensors(model, SAVE_DIR)

    # 2. Dataset & tokenizer (streaming: samples arrive one at a time)
    print("Loading Dataset: SlimPajama...")
    dataset = load_dataset("cerebras/SlimPajama-627B", split="train", streaming=True)
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token

    # 3. Hyperparameters
    lr = 2e-4
    batch_size = 2   # Micro-batch: samples per forward pass
    grad_accum = 16  # Effective batch = batch_size * grad_accum = 32
    max_steps = 100000

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.1)
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=max_steps)
    scaler = GradScaler()  # loss scaling for mixed-precision training

    model.train()
    data_iter = iter(dataset)

    def _next_texts(count):
        """Pull *count* raw text samples, restarting the stream when exhausted."""
        nonlocal data_iter
        texts = []
        while len(texts) < count:
            try:
                texts.append(next(data_iter)['text'])
            except StopIteration:
                data_iter = iter(dataset)
        return texts

    print(f"Starting Training Loop. Target steps: {max_steps}")

    for step in range(1, max_steps + 1):
        optimizer.zero_grad(set_to_none=True)
        total_loss = 0.0

        for _ in range(grad_accum):
            # BUG FIX: the original fed exactly one sample per micro-step,
            # so batch_size was declared but never used and the effective
            # batch was 16, not the advertised 32. Tokenize batch_size
            # texts together instead.
            texts = _next_texts(batch_size)
            tokens = tokenizer(texts, truncation=True, max_length=MAX_SEQ_LEN + 1,
                               padding="max_length", return_tensors="pt").input_ids.to(device)

            # Next-token prediction: shift inputs and targets by one
            x = tokens[:, :-1]
            y = tokens[:, 1:]

            with autocast(dtype=torch.bfloat16):
                logits, loss, _ = model(x, targets=y)
                loss = loss / grad_accum  # average over accumulation steps

            scaler.scale(loss).backward()
            total_loss += loss.item()

        # Optimizer step with gradient clipping on unscaled grads
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # Logging
        if step % 10 == 0:
            print(f"Step: {step} | Loss: {total_loss:.4f} | LR: {scheduler.get_last_lr()[0]:.6f}")

        # Sharded checkpointing
        if step % 500 == 0:
            save_sharded_safetensors(model, SAVE_DIR)
|
| 133 |
+
|
| 134 |
+
if __name__ == "__main__":
    try:
        train()
    except Exception as e:
        # Print the full traceback, not just the message: a bare
        # str(e) makes crashes deep in a 100k-step run undiagnosable.
        import traceback
        print(f"CRITICAL ERROR: {e}")
        traceback.print_exc()
        # NOTE(review): crash-recovery saving is not possible here as
        # written — `model` is local to train(), so the line below would
        # raise NameError; train() would need to expose the model first.
        # save_sharded_safetensors(model, "jirack_crash_recovery")
|