kgrabko committed on
Commit
108847b
·
verified ·
1 Parent(s): 54b8edf

Update fine_tune_jit_with_validation_1b.py

Browse files
Files changed (1) hide show
  1. fine_tune_jit_with_validation_1b.py +355 -339
fine_tune_jit_with_validation_1b.py CHANGED
@@ -1,340 +1,356 @@
1
- # Copyright (c) 2025 CMS Manhattan
2
- # All rights reserved.
3
- #
4
- # This file is part of a project authored by CMS Manhattan. You may use, distribute, and modify
5
- # this code under the terms of the APACHE 2.0 license.
6
-
7
-
8
- """
9
- Before run this script, download the GPT-2 tokenizer files into a local folder named 'tokenizer':
10
- mkdir -p tokenizer
11
- wget -O tokenizer/tokenizer.json https://huggingface.co/gpt2/resolve/main/tokenizer.json
12
- wget -O tokenizer/vocab.json https://huggingface.co/gpt2/resolve/main/vocab.json
13
- wget -O tokenizer/merges.txt https://huggingface.co/gpt2/resolve/main/merges.txt
14
- wget -O tokenizer/tokenizer_config.json https://huggingface.co/gpt2/resolve/main/tokenizer_config.json
15
- """
16
- import os
17
- import torch
18
- import torch.nn as nn
19
- import torch.optim as optim
20
- from torch.utils.data import IterableDataset, DataLoader
21
- from transformers import GPT2TokenizerFast
22
- from tqdm import tqdm
23
- import shutil
24
- import math
25
- from pathlib import Path
26
- import re
27
-
28
- # ============================= SETTINGS =============================
29
- TRAIN_SEQ_LEN = 256
30
- BATCH_SIZE = 1
31
- EPOCHS = 1
32
- LEARNING_RATE = 6e-6
33
- WEIGHT_DECAY = 0.01
34
- GRAD_CLIP = 1.0
35
- KEEP_LAST_EPOCHS = 3
36
- VAL_SPLIT_RATIO = 0.05
37
- VOCAB_SIZE = 50257
38
-
39
- BASE_MODEL_PATH = Path("models/gpt_modern_1b_class.script.pt")
40
- LAST_TRAINED_PATH = Path("models/gpt_1b_last_trained.script.pt")
41
- BACKUP_DIR = Path("models/backups")
42
- BACKUP_DIR.mkdir(exist_ok=True)
43
-
44
- RAW_PATH = Path("datasets/dialogues_text.txt")
45
- CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")
46
-
47
- # Device selection
48
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
- print(f"Using device: {device}")
50
-
51
- # -- Dataset cleaning --
52
- force_clean = False
53
- if not CLEAN_PATH.exists():
54
- print("Cleaned dataset not found. Performing initial cleaning...")
55
- force_clean = True
56
- else:
57
- try:
58
- if RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime:
59
- print("Detected changes in the raw dataset. Re-cleaning...")
60
- force_clean = True
61
- else:
62
- print(f"Using existing cleaned dataset → {CLEAN_PATH}")
63
- except FileNotFoundError:
64
- print("File system synchronization error. Performing re-cleaning for safety...")
65
- force_clean = True
66
-
67
- if force_clean:
68
- if not RAW_PATH.exists():
69
- raise FileNotFoundError(f"ERROR: Source file {RAW_PATH} not found. Check the path.")
70
-
71
- print("Cleaning up the dataset from garbage (wrong separators, extra spaces)...")
72
- text = RAW_PATH.read_text(encoding="utf-8")
73
- text = re.sub(r' {2,}', ' ', text)
74
- text = text.replace(" \n", "\n").replace("\n ", "\n")
75
- CLEAN_PATH.write_text(text, encoding="utf-8")
76
- print(f"Dataset successfully cleaned and saved → {CLEAN_PATH}")
77
-
78
- DATASET_PATH = CLEAN_PATH
79
- OUTPUT_DIR = Path("build/fine_tuning_output")
80
- MODEL_SAVE_NAME = "gpt_finetuned.script.pt"
81
-
82
- # ============================= DATASET (LAZY) =============================
83
-
84
- class LazyTextDataset(IterableDataset):
85
- """Lazy memory-efficient dataset, splits on-the-fly into train and val."""
86
- # Обратите внимание: аргумент tokenizer_name по-прежнему имеет значение по умолчанию "gpt2",
87
- # но в функции train() мы теперь передаем локальный путь.
88
- def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, tokenizer_name="gpt2", split_type='train', val_ratio=VAL_SPLIT_RATIO):
89
- self.seq_len = seq_len
90
- # Эта строка теперь загружает токенизатор из локальной папки, если передан локальный путь.
91
- self.tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_name)
92
- self.tokenizer.pad_token = self.tokenizer.eos_token
93
- self.text_file = text_file
94
- self.split_type = split_type
95
- self.val_ratio = val_ratio
96
-
97
- print(f"Loading and tokenizing text from {text_file}")
98
- with open(text_file, "r", encoding="utf-8") as f:
99
- self.data = f.read()
100
- self.tokens = self.tokenizer.encode(self.data)
101
-
102
- # Work out split indices
103
- total_tokens = len(self.tokens) - 1 # because label sequence shifted
104
- total_batches = total_tokens // seq_len
105
- val_size = int(total_batches * self.val_ratio)
106
- train_size = total_batches - val_size
107
- if split_type == 'train':
108
- self.start = 0
109
- self.stop = train_size
110
- elif split_type == 'val':
111
- self.start = train_size
112
- self.stop = train_size + val_size
113
- else:
114
- raise ValueError(f"split_type should be 'train' or 'val', got {split_type}")
115
- self.total_sequences = self.stop - self.start
116
- print(f"Lazy dataset: {self.total_sequences:,} sequences for {split_type} split (from {total_batches:,} total)")
117
-
118
- def __iter__(self):
119
- for i in range(self.start * self.seq_len, self.stop * self.seq_len, self.seq_len):
120
- # Make sure last batch fits
121
- if i + self.seq_len + 1 > len(self.tokens):
122
- break
123
- input_seq = torch.tensor(self.tokens[i : i + self.seq_len], dtype=torch.long)
124
- label_seq = torch.tensor(self.tokens[i + 1 : i + self.seq_len + 1], dtype=torch.long)
125
- yield input_seq, label_seq
126
-
127
- def __len__(self):
128
- return self.total_sequences
129
-
130
- # ============================= GET LOGITS UTIL =============================
131
-
132
- def get_logits_from_model(model, inputs):
133
- """
134
- Robust wrapper to call either a scripted JIT model or nn.Module.
135
- Handles models that either return (logits, kv) or just logits.
136
- """
137
- # Ensure inputs on same device as model parameters/buffers
138
- inputs = inputs.to(device)
139
- try:
140
- out = model(inputs)
141
- except RuntimeError as e:
142
- # Some JIT modules expect plain tensor on CPU device for tracing path.
143
- # Re-raise if unrelated
144
- raise
145
-
146
- # Model may return logits or (logits, kv)
147
- if isinstance(out, tuple) or (isinstance(out, list) and len(out) >= 1):
148
- logits = out[0]
149
- else:
150
- logits = out
151
- return logits
152
-
153
- # ============================= EVALUATION (VALIDATION) =============================
154
-
155
- def evaluate(model, dataloader, criterion, device):
156
- model.eval()
157
- total_loss = 0.0
158
- count = 0
159
- with torch.no_grad():
160
- for inputs, targets in dataloader:
161
- inputs, targets = inputs.to(device), targets.to(device)
162
- logits = get_logits_from_model(model, inputs)
163
- logits = logits.contiguous().view(-1, logits.size(-1))
164
- targets = targets.contiguous().view(-1)[:logits.shape[0]]
165
- loss = criterion(logits, targets)
166
- total_loss += loss.item()
167
- count += 1
168
- avg_loss = total_loss / max(count, 1)
169
- model.train()
170
- return avg_loss
171
-
172
- # ============================= CLEANUP OLD EPOCHS =============================
173
-
174
- def cleanup_old_epochs(keep_last=KEEP_LAST_EPOCHS):
175
- epochs = sorted([p for p in OUTPUT_DIR.glob("epoch*") if p.is_dir()],
176
- key=lambda x: int(x.name.replace("epoch", "")))
177
- for old in epochs[:-keep_last]:
178
- if old.exists():
179
- shutil.rmtree(old)
180
- print(f"Old epoch deleted: {old.name}")
181
-
182
- # ============================= TRAINING =============================
183
-
184
- def train():
185
- OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
186
- print("Loading model...")
187
- model = None
188
- if LAST_TRAINED_PATH.exists():
189
- print(f"Continuing training from last JIT model: {LAST_TRAINED_PATH}")
190
- model = torch.jit.load(LAST_TRAINED_PATH, map_location=device)
191
- elif BASE_MODEL_PATH.exists():
192
- print(f"Starting from base JIT model: {BASE_MODEL_PATH}")
193
- model = torch.jit.load(BASE_MODEL_PATH, map_location=device)
194
- else:
195
- print(f"ERROR: JIT model not found. Checked paths: {BASE_MODEL_PATH} and {LAST_TRAINED_PATH}")
196
- print("Please run the JIT export script (e.g., 'model_export.py') first.")
197
- return
198
-
199
- # Sometimes torch.jit.load with map_location doesn't move every internal buffer.
200
- # Force a device move for ScriptModule, wrapped in try/except for compatibility.
201
- try:
202
- model.to(device)
203
- except Exception:
204
- # If ScriptModule.to fails, attempt moving by reloading state_dict -> module approach is not always possible.
205
- pass
206
-
207
- # As extra safety, try to move any freqs buffers inside submodules (best-effort).
208
- try:
209
- for name, buf in model.named_buffers():
210
- if buf is not None and buf.device != device:
211
- try:
212
- model.register_buffer(name, buf.to(device))
213
- except Exception:
214
- # Some ScriptModule buffers may not be re-registerable; ignore non-critical failures.
215
- pass
216
- except Exception:
217
- pass
218
-
219
- # Проверка весов на NaN/Inf
220
- try:
221
- for n, p in model.named_parameters():
222
- if torch.isnan(p).any():
223
- print(f"[FATAL] NaN in weights: {n}")
224
- exit(10)
225
- if torch.isinf(p).any():
226
- print(f"[FATAL] Inf in weights: {n}")
227
- exit(11)
228
- except Exception:
229
- # some JIT modules may not expose named_parameters() - ignore if unavailable
230
- pass
231
-
232
- model.train()
233
- try:
234
- model.gradient_checkpointing_enable()
235
- print("✅ Gradient Checkpointing Enabled.")
236
- except Exception:
237
- print("⚠️ Warning: model.gradient_checkpointing_enable() not found on JIT model. Training will proceed without GC.")
238
-
239
- # =========================================================================
240
- # ФИНАЛЬНОЕ ИСПРАВЛЕНИЕ: Используем ЛОКАЛЬНУЮ ПАПКУ токенизатора
241
- # =========================================================================
242
- LOCAL_TOKENIZER_PATH = "./tokenizer" # Путь к папке, куда вы загрузили файлы токенизатора
243
-
244
- train_dataset = LazyTextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, tokenizer_name=LOCAL_TOKENIZER_PATH, split_type='train', val_ratio=VAL_SPLIT_RATIO)
245
- val_dataset = LazyTextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, tokenizer_name=LOCAL_TOKENIZER_PATH, split_type='val', val_ratio=VAL_SPLIT_RATIO)
246
- # =========================================================================
247
-
248
- # IterableDataset: must use drop_last=True and shuffle=False, num_workers=0 on CPU/GPU
249
- train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, num_workers=0)
250
- val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, num_workers=0)
251
-
252
- optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
253
- criterion = nn.CrossEntropyLoss()
254
-
255
- total_steps = (len(train_dataset) // BATCH_SIZE) * EPOCHS
256
- print(f"\n=== BEGINNING LONG-TERM TRAINING ===")
257
- print(f"Epochs: {EPOCHS} | Steps (Train): {total_steps} | Examples (Train): {len(train_dataset)}")
258
- print(f"Batch Size (Effective): {BATCH_SIZE} | Precision: FP32")
259
-
260
- global_step = 0
261
- for epoch in range(1, EPOCHS + 1):
262
- print(f"\n--- Epoch {epoch}/{EPOCHS} ---")
263
- epoch_loss = 0.0
264
-
265
- with tqdm(train_dataloader, desc=f"Epoch {epoch} [TRAIN]", leave=False) as pbar:
266
- for inputs, targets in pbar:
267
- inputs, targets = inputs.to(device), targets.to(device)
268
- optimizer.zero_grad()
269
- logits = get_logits_from_model(model, inputs)
270
- logits = logits.contiguous().view(-1, logits.size(-1))
271
- targets_view = targets.contiguous().view(-1)[:logits.shape[0]]
272
- loss = criterion(logits, targets_view)
273
- loss.backward()
274
- try:
275
- torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
276
- except Exception:
277
- pass
278
- optimizer.step()
279
-
280
- loss_val = loss.item()
281
- epoch_loss += loss_val
282
- global_step += 1
283
-
284
- pbar.set_postfix({
285
- "loss": f"{loss_val:.3f}",
286
- "ppl": f"{math.exp(min(loss_val, 10)):.1f}",
287
- "step": f"{global_step}"
288
- })
289
-
290
- avg_train_loss = epoch_loss / max(1, len(train_dataset) // BATCH_SIZE)
291
- print(f" [TRAIN] Average loss: {avg_train_loss:.3f} | PPL: {math.exp(avg_train_loss):.1f}")
292
-
293
- print(" [VALIDATION] Starting evaluation...")
294
- val_loss = evaluate(model, val_dataloader, criterion, device)
295
- print(f" [VALIDATION] Average loss: {val_loss:.3f} | PPL: {math.exp(val_loss):.1f}")
296
-
297
- epoch_dir = OUTPUT_DIR / f"epoch{epoch}"
298
- epoch_dir.mkdir(exist_ok=True)
299
- try:
300
- torch.jit.save(model, epoch_dir / MODEL_SAVE_NAME)
301
- print(f"Model saved: {epoch_dir / MODEL_SAVE_NAME}")
302
- except Exception:
303
- # If saving scripted model fails, fallback to state_dict
304
- torch.save(model.state_dict(), epoch_dir / "state_dict.pt")
305
- print(f"State dict saved: {epoch_dir / 'state_dict.pt'}")
306
- cleanup_old_epochs()
307
-
308
- final_dir = OUTPUT_DIR / "final"
309
- final_dir.mkdir(exist_ok=True)
310
- try:
311
- torch.jit.save(model, final_dir / MODEL_SAVE_NAME)
312
- except Exception:
313
- torch.save(model.state_dict(), final_dir / "state_dict.pt")
314
- # Try to save tokenizer if available
315
- try:
316
- train_dataset.tokenizer.save_pretrained(final_dir)
317
- except Exception:
318
- pass
319
-
320
- if LAST_TRAINED_PATH.exists():
321
- backup_path = BACKUP_DIR / f"gpt_last_trained_backup_{int(os.path.getmtime(LAST_TRAINED_PATH))}.script.pt"
322
- shutil.copy(LAST_TRAINED_PATH, backup_path)
323
- print(f"Backup of previous model created → {backup_path.name}")
324
-
325
- try:
326
- shutil.copy(final_dir / MODEL_SAVE_NAME, LAST_TRAINED_PATH)
327
- print(f"Last trained model saved {LAST_TRAINED_PATH}")
328
- except Exception:
329
- pass
330
-
331
- print(f"\nTRAINING COMPLETED! Model ready:")
332
- print(f" • For chat: {final_dir / MODEL_SAVE_NAME}")
333
- print(f" • For further fine-tuning: {LAST_TRAINED_PATH}")
334
-
335
- if __name__ == "__main__":
336
- if not RAW_PATH.exists():
337
- print(f"ERROR: No file {RAW_PATH}")
338
- print("Put your text into datasets/dialogues_text.txt")
339
- else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  train()
 
1
+ # Copyright (c) 2025 CMS Manhattan
2
+ # All rights reserved.
3
+ # Author: Konstantin Vladimirovich Grabko
4
+ # Email: grabko@cmsmanhattan.com
5
+ # Phone: +1(516)777-0945
6
+ #
7
+ # This program is free software: you can redistribute it and/or modify
8
+ # it under the terms of the GNU General Public License as published by
9
+ # the Free Software Foundation, version 3 of the License.
10
+ #
11
+ # This program is distributed in the hope that it will be useful,
12
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
13
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
+ # GNU General Public License for more details.
15
+ #
16
+ # You should have received a copy of the GNU General Public License
17
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
18
+ #
19
+ # Additional terms:
20
+ # Any commercial use or distribution of this software or derivative works
21
+ # requires explicit written permission from the copyright holder.
22
+
23
+
24
+ """
25
+ Before running this script, download the GPT-2 tokenizer files into a local folder named 'tokenizer':
26
+ mkdir -p tokenizer
27
+ wget -O tokenizer/tokenizer.json https://huggingface.co/gpt2/resolve/main/tokenizer.json
28
+ wget -O tokenizer/vocab.json https://huggingface.co/gpt2/resolve/main/vocab.json
29
+ wget -O tokenizer/merges.txt https://huggingface.co/gpt2/resolve/main/merges.txt
30
+ wget -O tokenizer/tokenizer_config.json https://huggingface.co/gpt2/resolve/main/tokenizer_config.json
31
+ """
32
+ import os
33
+ import torch
34
+ import torch.nn as nn
35
+ import torch.optim as optim
36
+ from torch.utils.data import IterableDataset, DataLoader
37
+ from transformers import GPT2TokenizerFast
38
+ from tqdm import tqdm
39
+ import shutil
40
+ import math
41
+ from pathlib import Path
42
+ import re
43
+
44
# ============================= SETTINGS =============================
TRAIN_SEQ_LEN = 256     # context window length for each training example
BATCH_SIZE = 1
EPOCHS = 1
LEARNING_RATE = 6e-6
WEIGHT_DECAY = 0.01
GRAD_CLIP = 1.0         # max gradient norm for clipping
KEEP_LAST_EPOCHS = 3    # number of per-epoch checkpoint dirs to retain
VAL_SPLIT_RATIO = 0.05  # fraction of sequences held out for validation
VOCAB_SIZE = 50257      # GPT-2 vocabulary size (not referenced below; kept for reference)

BASE_MODEL_PATH = Path("models/gpt_modern_1b_class.script.pt")
LAST_TRAINED_PATH = Path("models/gpt_1b_last_trained.script.pt")
BACKUP_DIR = Path("models/backups")
# parents=True so this also works on a fresh checkout where "models/" does not exist yet;
# plain mkdir(exist_ok=True) raises FileNotFoundError when the parent is missing.
BACKUP_DIR.mkdir(parents=True, exist_ok=True)

RAW_PATH = Path("datasets/dialogues_text.txt")
CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")

# Device selection
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# -- Dataset cleaning --
# Re-clean when the cleaned file is missing or older than the raw file.
force_clean = False
if not CLEAN_PATH.exists():
    print("Cleaned dataset not found. Performing initial cleaning...")
    force_clean = True
else:
    try:
        if RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime:
            print("Detected changes in the raw dataset. Re-cleaning...")
            force_clean = True
        else:
            print(f"Using existing cleaned dataset → {CLEAN_PATH}")
    except FileNotFoundError:
        # stat() raced against a file disappearing; re-clean to be safe.
        print("File system synchronization error. Performing re-cleaning for safety...")
        force_clean = True

if force_clean:
    if not RAW_PATH.exists():
        raise FileNotFoundError(f"ERROR: Source file {RAW_PATH} not found. Check the path.")

    print("Cleaning up the dataset from garbage (wrong separators, extra spaces)...")
    text = RAW_PATH.read_text(encoding="utf-8")
    text = re.sub(r' {2,}', ' ', text)                      # collapse runs of spaces
    text = text.replace(" \n", "\n").replace("\n ", "\n")   # strip spaces adjacent to newlines
    CLEAN_PATH.write_text(text, encoding="utf-8")
    print(f"Dataset successfully cleaned and saved → {CLEAN_PATH}")

DATASET_PATH = CLEAN_PATH
OUTPUT_DIR = Path("build/fine_tuning_output")
MODEL_SAVE_NAME = "gpt_finetuned.script.pt"
97
+
98
+ # ============================= DATASET (LAZY) =============================
99
+
100
class LazyTextDataset(IterableDataset):
    """Token-sequence dataset over a text file, split into train and val parts.

    NOTE(review): despite the name, this is not fully lazy — __init__ reads
    and tokenizes the entire file up front; only the per-sequence
    (input, label) tensors are produced on the fly in __iter__.
    """
    # Note: the tokenizer_name argument still defaults to "gpt2",
    # but train() now passes a local folder path instead of the hub id.
    def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, tokenizer_name="gpt2", split_type='train', val_ratio=VAL_SPLIT_RATIO):
        self.seq_len = seq_len
        # from_pretrained loads the tokenizer from a local folder when a local path is passed.
        self.tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token  # GPT-2 ships without a pad token; reuse EOS
        self.text_file = text_file
        self.split_type = split_type
        self.val_ratio = val_ratio

        print(f"Loading and tokenizing text from {text_file}")
        with open(text_file, "r", encoding="utf-8") as f:
            self.data = f.read()
        self.tokens = self.tokenizer.encode(self.data)

        # Work out split indices; the validation split takes the tail of the token stream.
        total_tokens = len(self.tokens) - 1  # because label sequence shifted
        total_batches = total_tokens // seq_len
        val_size = int(total_batches * self.val_ratio)
        train_size = total_batches - val_size
        if split_type == 'train':
            self.start = 0
            self.stop = train_size
        elif split_type == 'val':
            self.start = train_size
            self.stop = train_size + val_size
        else:
            raise ValueError(f"split_type should be 'train' or 'val', got {split_type}")
        self.total_sequences = self.stop - self.start
        print(f"Lazy dataset: {self.total_sequences:,} sequences for {split_type} split (from {total_batches:,} total)")

    def __iter__(self):
        """Yield (input_seq, label_seq) LongTensor pairs of length seq_len,
        where label_seq is input_seq shifted right by one token."""
        for i in range(self.start * self.seq_len, self.stop * self.seq_len, self.seq_len):
            # Make sure last batch fits (need seq_len inputs plus one shifted label).
            if i + self.seq_len + 1 > len(self.tokens):
                break
            input_seq = torch.tensor(self.tokens[i : i + self.seq_len], dtype=torch.long)
            label_seq = torch.tensor(self.tokens[i + 1 : i + self.seq_len + 1], dtype=torch.long)
            yield input_seq, label_seq

    def __len__(self):
        # Number of sequences in this split. An explicit __len__ is unusual for
        # an IterableDataset but train() uses it for step-count reporting.
        return self.total_sequences
145
+
146
+ # ============================= GET LOGITS UTIL =============================
147
+
148
def get_logits_from_model(model, inputs):
    """Call a scripted JIT model or plain nn.Module and return its logits.

    Handles models that return either the logits tensor alone or a
    (logits, kv_cache) tuple/list; in the latter case the first element
    is taken as the logits.

    Args:
        model: callable model (torch.jit.ScriptModule or nn.Module).
        inputs: tensor of token ids; moved to the module-level ``device``.

    Returns:
        The logits tensor produced by the model (returned unchanged when the
        model's output is neither a tuple nor a list).
    """
    # Ensure inputs are on the same device as the model parameters/buffers.
    inputs = inputs.to(device)
    # The original wrapped this call in a try/except RuntimeError that only
    # re-raised (with an unused exception variable); that dead code is removed.
    out = model(inputs)

    # Model may return logits or (logits, kv). The length guard now covers the
    # tuple case too, so an empty tuple no longer raises IndexError here.
    if isinstance(out, (tuple, list)) and len(out) >= 1:
        return out[0]
    return out
168
+
169
+ # ============================= EVALUATION (VALIDATION) =============================
170
+
171
def evaluate(model, dataloader, criterion, device):
    """Run one full pass over *dataloader* and return the mean loss per batch.

    The model is put into eval mode for the pass and switched back to train
    mode before returning. Returns 0.0 if the dataloader yields no batches.
    """
    model.eval()
    running_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch_inputs, batch_targets in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            logits = get_logits_from_model(model, batch_inputs)
            # Flatten to (N, vocab) for CrossEntropyLoss; truncate targets in
            # case the model emits fewer positions than the input length.
            flat_logits = logits.contiguous().view(-1, logits.size(-1))
            flat_targets = batch_targets.contiguous().view(-1)[:flat_logits.shape[0]]
            running_loss += criterion(flat_logits, flat_targets).item()
            num_batches += 1
    model.train()
    return running_loss / max(num_batches, 1)
187
+
188
+ # ============================= CLEANUP OLD EPOCHS =============================
189
+
190
def cleanup_old_epochs(keep_last=KEEP_LAST_EPOCHS):
    """Delete old per-epoch checkpoint directories, keeping the newest *keep_last*.

    Directories are matched as OUTPUT_DIR/epoch<N> and ordered numerically by
    their suffix (so epoch10 sorts after epoch9, unlike a lexical sort).
    """
    epochs = sorted([p for p in OUTPUT_DIR.glob("epoch*") if p.is_dir()],
                    key=lambda x: int(x.name.replace("epoch", "")))
    # epochs[:-0] would be an empty slice, silently keeping everything;
    # treat keep_last <= 0 explicitly as "delete all".
    stale = epochs[:-keep_last] if keep_last > 0 else epochs
    for old in stale:
        if old.exists():
            shutil.rmtree(old)
            print(f"Old epoch deleted: {old.name}")
197
+
198
+ # ============================= TRAINING =============================
199
+
200
def train():
    """Fine-tune a TorchScript GPT model on the cleaned dialogue dataset.

    Loads the most recent JIT checkpoint (falling back to the base export),
    runs EPOCHS training epochs with a validation pass after each, saves a
    checkpoint per epoch plus a final one under OUTPUT_DIR, and updates
    LAST_TRAINED_PATH (backing up the previous copy into BACKUP_DIR).
    Returns early (None) if no JIT model file is found.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    print("Loading model...")
    model = None
    # Prefer resuming from the last fine-tuned checkpoint over the base export.
    if LAST_TRAINED_PATH.exists():
        print(f"Continuing training from last JIT model: {LAST_TRAINED_PATH}")
        model = torch.jit.load(LAST_TRAINED_PATH, map_location=device)
    elif BASE_MODEL_PATH.exists():
        print(f"Starting from base JIT model: {BASE_MODEL_PATH}")
        model = torch.jit.load(BASE_MODEL_PATH, map_location=device)
    else:
        print(f"ERROR: JIT model not found. Checked paths: {BASE_MODEL_PATH} and {LAST_TRAINED_PATH}")
        print("Please run the JIT export script (e.g., 'model_export.py') first.")
        return

    # Sometimes torch.jit.load with map_location doesn't move every internal buffer.
    # Force a device move for ScriptModule, wrapped in try/except for compatibility.
    try:
        model.to(device)
    except Exception:
        # If ScriptModule.to fails, attempt moving by reloading state_dict -> module approach is not always possible.
        pass

    # As extra safety, try to move any freqs buffers inside submodules (best-effort).
    try:
        for name, buf in model.named_buffers():
            if buf is not None and buf.device != device:
                try:
                    model.register_buffer(name, buf.to(device))
                except Exception:
                    # Some ScriptModule buffers may not be re-registerable; ignore non-critical failures.
                    pass
    except Exception:
        pass

    # Check weights for NaN/Inf so corrupt checkpoints abort before training.
    try:
        for n, p in model.named_parameters():
            if torch.isnan(p).any():
                print(f"[FATAL] NaN in weights: {n}")
                exit(10)
            if torch.isinf(p).any():
                print(f"[FATAL] Inf in weights: {n}")
                exit(11)
    except Exception:
        # some JIT modules may not expose named_parameters() - ignore if unavailable
        pass

    model.train()
    try:
        model.gradient_checkpointing_enable()
        print("✅ Gradient Checkpointing Enabled.")
    except Exception:
        print("⚠️ Warning: model.gradient_checkpointing_enable() not found on JIT model. Training will proceed without GC.")

    # =========================================================================
    # FINAL FIX: use the LOCAL tokenizer folder instead of the "gpt2" hub id.
    # =========================================================================
    LOCAL_TOKENIZER_PATH = "./tokenizer"  # Folder where the tokenizer files were downloaded

    train_dataset = LazyTextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, tokenizer_name=LOCAL_TOKENIZER_PATH, split_type='train', val_ratio=VAL_SPLIT_RATIO)
    val_dataset = LazyTextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, tokenizer_name=LOCAL_TOKENIZER_PATH, split_type='val', val_ratio=VAL_SPLIT_RATIO)
    # =========================================================================

    # IterableDataset: must use drop_last=True and shuffle=False, num_workers=0 on CPU/GPU
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, num_workers=0)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, num_workers=0)

    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    criterion = nn.CrossEntropyLoss()

    total_steps = (len(train_dataset) // BATCH_SIZE) * EPOCHS
    print(f"\n=== BEGINNING LONG-TERM TRAINING ===")
    print(f"Epochs: {EPOCHS} | Steps (Train): {total_steps} | Examples (Train): {len(train_dataset)}")
    print(f"Batch Size (Effective): {BATCH_SIZE} | Precision: FP32")

    global_step = 0
    for epoch in range(1, EPOCHS + 1):
        print(f"\n--- Epoch {epoch}/{EPOCHS} ---")
        epoch_loss = 0.0

        with tqdm(train_dataloader, desc=f"Epoch {epoch} [TRAIN]", leave=False) as pbar:
            for inputs, targets in pbar:
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()
                logits = get_logits_from_model(model, inputs)
                # Flatten to (batch*seq, vocab) for CrossEntropyLoss; targets are
                # truncated to match in case the model returns fewer positions.
                logits = logits.contiguous().view(-1, logits.size(-1))
                targets_view = targets.contiguous().view(-1)[:logits.shape[0]]
                loss = criterion(logits, targets_view)
                loss.backward()
                try:
                    # clip_grad_norm_ may fail on some ScriptModules; best-effort.
                    torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                except Exception:
                    pass
                optimizer.step()

                loss_val = loss.item()
                epoch_loss += loss_val
                global_step += 1

                # min(loss_val, 10) caps exp() so the postfix never overflows.
                pbar.set_postfix({
                    "loss": f"{loss_val:.3f}",
                    "ppl": f"{math.exp(min(loss_val, 10)):.1f}",
                    "step": f"{global_step}"
                })

        avg_train_loss = epoch_loss / max(1, len(train_dataset) // BATCH_SIZE)
        print(f" [TRAIN] Average loss: {avg_train_loss:.3f} | PPL: {math.exp(avg_train_loss):.1f}")

        print(" [VALIDATION] Starting evaluation...")
        val_loss = evaluate(model, val_dataloader, criterion, device)
        print(f" [VALIDATION] Average loss: {val_loss:.3f} | PPL: {math.exp(val_loss):.1f}")

        # Per-epoch checkpoint; prune old ones so disk usage stays bounded.
        epoch_dir = OUTPUT_DIR / f"epoch{epoch}"
        epoch_dir.mkdir(exist_ok=True)
        try:
            torch.jit.save(model, epoch_dir / MODEL_SAVE_NAME)
            print(f"Model saved: {epoch_dir / MODEL_SAVE_NAME}")
        except Exception:
            # If saving scripted model fails, fallback to state_dict
            torch.save(model.state_dict(), epoch_dir / "state_dict.pt")
            print(f"State dict saved: {epoch_dir / 'state_dict.pt'}")
        cleanup_old_epochs()

    final_dir = OUTPUT_DIR / "final"
    final_dir.mkdir(exist_ok=True)
    try:
        torch.jit.save(model, final_dir / MODEL_SAVE_NAME)
    except Exception:
        torch.save(model.state_dict(), final_dir / "state_dict.pt")
    # Try to save tokenizer if available
    try:
        train_dataset.tokenizer.save_pretrained(final_dir)
    except Exception:
        pass

    # Back up the previous "last trained" model before overwriting it below.
    if LAST_TRAINED_PATH.exists():
        backup_path = BACKUP_DIR / f"gpt_last_trained_backup_{int(os.path.getmtime(LAST_TRAINED_PATH))}.script.pt"
        shutil.copy(LAST_TRAINED_PATH, backup_path)
        print(f"Backup of previous model created → {backup_path.name}")

    try:
        shutil.copy(final_dir / MODEL_SAVE_NAME, LAST_TRAINED_PATH)
        print(f"Last trained model saved → {LAST_TRAINED_PATH}")
    except Exception:
        pass

    print(f"\nTRAINING COMPLETED! Model ready:")
    print(f" • For chat: {final_dir / MODEL_SAVE_NAME}")
    print(f" • For further fine-tuning: {LAST_TRAINED_PATH}")
350
+
351
if __name__ == "__main__":
    # Refuse to start without the raw dataset: the cleaning step at import
    # time and the tokenization in train() both depend on it.
    if not RAW_PATH.exists():
        print(f"ERROR: No file {RAW_PATH}")
        print("Put your text into datasets/dialogues_text.txt")
    else:
        train()