kgrabko committed on
Commit 701b6b5 · verified · 1 Parent(s): 53bee7b

Upload load_JiRack5_SlimPajama_3b_safetensors.py

load_JiRack5_SlimPajama_3b_safetensors.py ADDED
@@ -0,0 +1,140 @@
+ # Copyright (c) 2025 CMS Manhattan
+ # All rights reserved.
+ # Author: Konstantin Vladimirovich Grabko
+ # Email: grabko@cmsmanhattan.com
+ # Phone: +1(516)777-0945
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU General Public License as published by
+ # the Free Software Foundation, version 3 of the License.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+ #
+ # Additional terms:
+ # Any commercial use or distribution of this software or derivative works
+ # requires explicit written permission from the copyright holder.
+
+ import os
+ import torch
+ # torch.cuda.amp is deprecated in recent PyTorch releases; torch.amp exposes
+ # the same autocast/GradScaler with an explicit device argument.
+ from torch.amp import autocast, GradScaler
+ from datasets import load_dataset
+ from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
+ from safetensors.torch import save_file, load_file
+ from JiRackPyTorch_GPT5_class_3b import JiRackPyTorch, VOCAB_SIZE, MAX_SEQ_LEN
+
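+ # JiRackPyTorch_GPT5_class_3b is a local module expected to sit next to this
+ # script; the training loop below assumes its forward pass returns
+ # (logits, loss, extra) when called with targets.
+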
+ # --- Checkpoint settings ---
+ SAVE_DIR = "jirack_weights"
+ SHARD_1 = "model-00001-of-00002.safetensors"
+ SHARD_2 = "model-00002-of-00002.safetensors"
+
+ def save_sharded_safetensors(model, directory):
+     os.makedirs(directory, exist_ok=True)
+     state_dict = model.state_dict()
+     keys = list(state_dict.keys())
+     mid = len(keys) // 2
+
+     shard1 = {k: state_dict[k] for k in keys[:mid]}
+     shard2 = {k: state_dict[k] for k in keys[mid:]}
+
+     save_file(shard1, os.path.join(directory, SHARD_1))
+     save_file(shard2, os.path.join(directory, SHARD_2))
+     print(f"--- [CHECKPOINT] Model shards saved to {directory} ---")
+
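+ # NOTE: the split above is by key count, not by tensor size, so the two shards
+ # can differ a lot in bytes (embedding tables dominate). The file names follow
+ # the Hugging Face sharding convention, which normally also writes a
+ # model.safetensors.index.json mapping each key to its shard; this script
+ # hard-codes both shard paths in the loader below instead.
+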
+ def load_sharded_safetensors(model, directory):
+     p1 = os.path.join(directory, SHARD_1)
+     p2 = os.path.join(directory, SHARD_2)
+     if os.path.exists(p1) and os.path.exists(p2):
+         print(f"--- [RESUME] Loading weights from {directory} ---")
+         sd = {}
+         sd.update(load_file(p1))
+         sd.update(load_file(p2))
+         model.load_state_dict(sd)
+         return True
+     return False
+
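+ # load_file() returns tensors on CPU by default; model.load_state_dict() then
+ # copies them into the (possibly CUDA-resident) parameters in place. The load
+ # is strict by default, so a mismatched or partial checkpoint raises instead
+ # of silently resuming from the wrong weights.
+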
+ def train():
+     # 1. Setup Device & Model
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = JiRackPyTorch().to(device)
+
+     # Print authorship info before starting
+     print("--- JiRack Engine Start ---")
+     print(f"--- {model.get_author_info()} ---")
+
+     # Try to resume training from an existing checkpoint
+     load_sharded_safetensors(model, SAVE_DIR)
+
+     # 2. Dataset & Tokenizer
+     print("Loading Dataset: SlimPajama...")
+     dataset = load_dataset("cerebras/SlimPajama-627B", split="train", streaming=True)
+     tokenizer = AutoTokenizer.from_pretrained("gpt2")
+     tokenizer.pad_token = tokenizer.eos_token
+
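+     # streaming=True iterates SlimPajama lazily instead of materializing the
+     # full 627B-token corpus on disk. Samples arrive in corpus order; an
+     # optional dataset.shuffle(buffer_size=...) would decorrelate neighboring
+     # documents at the cost of extra buffering.
+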
+     # 3. Hyperparameters
+     lr = 2e-4
+     batch_size = 2   # Micro-batch (samples per forward pass)
+     grad_accum = 16  # Effective batch = 2 * 16 = 32 samples
+     max_steps = 100000
+
+     optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.1)
+     scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=max_steps)
+     # Loss scaling for mixed precision. Scaling is what float16 needs; with
+     # the bfloat16 autocast below it is redundant but harmless.
+     scaler = GradScaler(device.type)
+
+     model.train()
+     data_iter = iter(dataset)
+
+     print(f"Starting Training Loop. Target steps: {max_steps}")
+
+     for step in range(1, max_steps + 1):
+         optimizer.zero_grad(set_to_none=True)
+         total_loss = 0
+
+         for _ in range(grad_accum):
+             # Pull `batch_size` samples from the stream, restarting the
+             # iterator if it is ever exhausted. (The original fetched a single
+             # sample here, leaving batch_size unused.)
+             texts = []
+             while len(texts) < batch_size:
+                 try:
+                     texts.append(next(data_iter)["text"])
+                 except StopIteration:
+                     data_iter = iter(dataset)
+
+             tokens = tokenizer(texts, truncation=True, max_length=MAX_SEQ_LEN + 1,
+                                padding="max_length", return_tensors="pt").input_ids.to(device)
+
+             # Next-token prediction: inputs and targets are shifted by one
+             x = tokens[:, :-1]
+             y = tokens[:, 1:]
+
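+             # NOTE: padding uses the eos token and no attention mask is passed
+             # to the model, so unless the model masks pad positions internally
+             # the loss is also computed on padding. Passing the tokenizer's
+             # attention_mask, or setting padded targets to an ignore index,
+             # would avoid that.
+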
+             with autocast(device.type, dtype=torch.bfloat16):
+                 logits, loss, _ = model(x, targets=y)
+                 loss = loss / grad_accum
+
+             scaler.scale(loss).backward()
+             total_loss += loss.item()
+
+         # Step Optimizer
+         scaler.unscale_(optimizer)
+         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+         scaler.step(optimizer)
+         scaler.update()
+         scheduler.step()
+
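+         # Note on ordering: scaler.unscale_() runs before clip_grad_norm_ so
+         # the 1.0 threshold applies to true gradient magnitudes rather than
+         # loss-scaled ones; scaler.step() then skips the update if any grads
+         # are inf/nan.
+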
+         # Logging
+         if step % 10 == 0:
+             print(f"Step: {step} | Loss: {total_loss:.4f} | LR: {scheduler.get_last_lr()[0]:.6f}")
+
+         # Sharded Saving
+         if step % 500 == 0:
+             save_sharded_safetensors(model, SAVE_DIR)
+
+ if __name__ == "__main__":
+     try:
+         train()
+     except Exception as e:
+         print(f"CRITICAL ERROR: {e}")
+         # Save weights even on a crash, if possible. As written this cannot
+         # work: `model` only exists inside train(), so the call below would
+         # raise a NameError if uncommented.
+         # save_sharded_safetensors(model, "jirack_crash_recovery")
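+
+ # A minimal sketch of workable crash recovery (a hypothetical restructuring,
+ # not the author's code): build the model at top level and pass it into
+ # train(model), so the except-block can still reach it:
+ #
+ #     model = JiRackPyTorch().to("cuda" if torch.cuda.is_available() else "cpu")
+ #     try:
+ #         train(model)  # train() would take the model as a parameter
+ #     except Exception:
+ #         save_sharded_safetensors(model, "jirack_crash_recovery")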