kgrabko commited on
Commit
54b8edf
·
verified ·
1 Parent(s): 9a7388e

Update fine_tune_jit_with_validation_cuda_1b.py

Browse files
Files changed (1) hide show
  1. fine_tune_jit_with_validation_cuda_1b.py +486 -464
fine_tune_jit_with_validation_cuda_1b.py CHANGED
@@ -1,465 +1,487 @@
1
- #!/usr/bin/env python3
2
- """
3
- # install tokenizer before run
4
- mkdir -p tokenizer
5
- wget -O tokenizer/tokenizer.json https://huggingface.co/gpt2/resolve/main/tokenizer.json
6
- wget -O tokenizer/vocab.json https://huggingface.co/gpt2/resolve/main/vocab.json
7
- wget -O tokenizer/merges.txt https://huggingface.co/gpt2/resolve/main/merges.txt
8
- wget -O tokenizer/tokenizer_config.json https://huggingface.co/gpt2/resolve/main/tokenizer_config.json
9
-
10
- Updated fine-tuning script, version "prefer Python nn.Module with gradient checkpointing".
11
-
12
- What it does:
13
- - Tries to load a local Python implementation of the model (as torch.nn.Module). If found β€” uses it and
14
- enables gradient_checkpointing (if the model supports it).
15
- - If no Python model class is found β€” falls back to JIT ScriptModule (as before).
16
- - If the original weights are only available as JIT, attempts to extract state_dict() from the ScriptModule
17
- and load it into the nn.Module (best-effort).
18
- - Saves the final trained model as a JIT torch.jit.save at the end, or as state_dict if something fails.
19
- - Saves the tokenizer locally (./tokenizer) and uses it. Gives a helpful message if the tokenizer is missing.
20
- - Supports AMP (autocast + GradScaler) on GPU.
21
- - Optional support for bitsandbytes 8-bit optimizer (if installed).
22
- Comments and console messages are in English.
23
-
24
- Before running: if you have a Python file with the model implementation
25
- (for example gpt_modern_1b.py or gpt_modern_1b_class.py), place it in the same folder
26
- and make sure it contains a class named JiRackPyTorch (or one of the other names the script looks for).
27
- If no such file exists β€” the script will just use the JIT model as before.
28
- """
29
-
30
- import os
31
- os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128,garbage_collection_threshold:0.6")
32
-
33
- import sys
34
- import importlib
35
- import math
36
- import shutil
37
- import re
38
- from pathlib import Path
39
- from typing import Optional
40
-
41
- import torch
42
- import torch.nn as nn
43
- import torch.optim as optim
44
- from torch.utils.data import IterableDataset, DataLoader
45
- from transformers import GPT2TokenizerFast
46
- from tqdm import tqdm
47
- from torch.cuda.amp import GradScaler, autocast
48
-
49
# ========================= SETTINGS =========================
# Every hyper-parameter can be overridden through an environment variable,
# so runs can be tuned without editing the script.
TRAIN_SEQ_LEN = int(os.environ.get("TRAIN_SEQ_LEN", 64))    # tokens per training window
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 1))
EPOCHS = int(os.environ.get("EPOCHS", 999))                 # effectively "until stopped"
LEARNING_RATE = float(os.environ.get("LEARNING_RATE", 6e-6))
WEIGHT_DECAY = float(os.environ.get("WEIGHT_DECAY", 0.01))
GRAD_CLIP = float(os.environ.get("GRAD_CLIP", 1.0))         # max gradient norm for clipping
KEEP_LAST_EPOCHS = int(os.environ.get("KEEP_LAST_EPOCHS", 3))  # epoch checkpoints to retain
VAL_SPLIT_RATIO = float(os.environ.get("VAL_SPLIT_RATIO", 0.05))

# Checkpoint locations: TorchScript archives plus a plain state_dict variant.
BASE_MODEL_PATH = Path("models/gpt_modern_1b_class.script.pt")
LAST_TRAINED_PATH = Path("models/gpt_1b_last_trained.script.pt")
PT_STATE_DICT_PATH = Path("models/gpt_modern_1b_class.state_dict.pt")
BACKUP_DIR = Path("models/backups")
BACKUP_DIR.mkdir(parents=True, exist_ok=True)  # NOTE: side effect at import time

# Dataset and tokenizer locations.
RAW_PATH = Path("datasets/dialogues_text.txt")
CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")
TOKENIZER_LOCAL_DIR = Path("./tokenizer")

OUTPUT_DIR = Path("build/fine_tuning_output")
MODEL_SAVE_NAME = "gpt_finetuned.script.pt"

# Single-GPU (cuda:0) when available, otherwise CPU fallback.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
74
-
75
- # ========================= Tokenizer helper =========================
76
def _load_tokenizer_local(tokenizer_name: str = "gpt2"):
    """Load a GPT-2 tokenizer from local files only.

    Candidate locations are probed in order: $TOKENIZER_PATH (when set),
    ./tokenizer, the given tokenizer name, and the current directory.
    The pad token is set to EOS when the tokenizer defines none.

    Raises:
        RuntimeError: when no candidate location yields a tokenizer.
    """
    search_paths = [str(TOKENIZER_LOCAL_DIR), tokenizer_name, "./"]
    env_override = os.environ.get("TOKENIZER_PATH")
    if env_override:
        # The environment override takes priority over the defaults.
        search_paths.insert(0, env_override)

    for location in search_paths:
        try:
            tok = GPT2TokenizerFast.from_pretrained(location, local_files_only=True)
            if getattr(tok, "pad_token", None) is None:
                tok.pad_token = tok.eos_token
            print(f"Tokenizer loaded from: {location}")
            return tok
        except Exception:
            # Best-effort probing: silently move on to the next candidate.
            continue

    raise RuntimeError(
        "Local tokenizer not found. Place tokenizer.json or (vocab.json + merges.txt) into ./tokenizer\n"
        "OR set the path via TOKENIZER_PATH environment variable.\n"
        "Example: export TOKENIZER_PATH=/path/to/tokenizer\n"
        "If you have internet access, you can temporarily use transformers.GPT2TokenizerFast.from_pretrained('gpt2')"
    )
103
-
104
- # ========================= Dataset =========================
105
class LazyTextDataset(IterableDataset):
    """Iterable LM dataset over a text file tokenized once up front.

    The token stream is split into contiguous train/val windows of
    ``seq_len`` tokens according to ``val_ratio`` (train windows first);
    iteration yields (input_ids, next-token label_ids) tensor pairs.
    """

    def __init__(self, text_file: Path, seq_len: int = TRAIN_SEQ_LEN, tokenizer_name: str = "gpt2",
                 split_type: str = 'train', val_ratio: float = VAL_SPLIT_RATIO):
        self.seq_len = seq_len
        self.tokenizer = _load_tokenizer_local(tokenizer_name)
        self.text_file = Path(text_file)
        self.split_type = split_type
        self.val_ratio = val_ratio

        print(f"Loading and tokenizing {self.text_file} (one-time tokenization into ids)...")
        with open(self.text_file, "r", encoding="utf-8") as fh:
            raw_text = fh.read()
        self.tokens = self.tokenizer.encode(raw_text)

        # One spare token is needed for the shifted label of the last window.
        usable_tokens = max(0, len(self.tokens) - 1)
        n_windows = usable_tokens // seq_len if seq_len > 0 else 0
        n_val = int(n_windows * val_ratio)
        n_train = n_windows - n_val
        if split_type == 'train':
            self.start, self.stop = 0, n_train
        elif split_type == 'val':
            self.start, self.stop = n_train, n_train + n_val
        else:
            raise ValueError("split_type must be 'train' or 'val'")
        self.total_sequences = max(0, self.stop - self.start)
        print(f"Split {split_type}: {self.total_sequences} sequences (out of {n_windows})")

    def __iter__(self):
        step = self.seq_len
        for offset in range(self.start * step, self.stop * step, step):
            # Stop early if the shifted label would run past the stream end.
            if offset + step + 1 > len(self.tokens):
                break
            window = self.tokens[offset: offset + step]
            shifted = self.tokens[offset + 1: offset + step + 1]
            yield (torch.tensor(window, dtype=torch.long),
                   torch.tensor(shifted, dtype=torch.long))

    def __len__(self):
        return self.total_sequences
144
-
145
- # ========================= Try to load Python nn.Module model =========================
146
def try_load_python_model():
    """
    Attempt to find and import a local Python model implementation (nn.Module).

    Looks for several possible module and class names; the first class that
    instantiates successfully wins.

    Returns:
        (model_instance, source_description) on success, else (None, None).
    """
    # BUGFIX: the file only does `import importlib`, which does not guarantee
    # that the `importlib.util` submodule is bound; without this explicit
    # import, `importlib.util.find_spec` could raise AttributeError and the
    # broad `except` below would silently skip every candidate module.
    import importlib.util

    candidates_modules = [
        "gpt_modern_1b_class",
        "gpt_modern_1b",
        "gpt_modern_1b_class_fixed",
        "model", "ji_rack_model"
    ]
    candidates_class_names = [
        "JiRackPyTorch",
        "JiRackPyTorch1B",
        "GPTModel",
        "JiRackModel"
    ]

    for modname in candidates_modules:
        try:
            spec = importlib.util.find_spec(modname)
            if spec is None:
                continue
            mod = importlib.import_module(modname)
        except Exception:
            # Broken candidate (import error, syntax error, ...) — try the next.
            continue
        for cls_name in candidates_class_names:
            if hasattr(mod, cls_name):
                cls = getattr(mod, cls_name)
                try:
                    inst = cls()
                    print(f"Loaded Python model class {cls_name} from module {modname}")
                    return inst, f"python:{modname}.{cls_name}"
                except Exception as e:
                    print(f"Found class {cls_name} in {modname} but instantiation failed: {e}")
                    continue
    return None, None
184
-
185
- # ========================= Utility: load weights from JIT script into nn.Module =========================
186
def load_weights_from_script_to_module(script_path: Path, module_model: nn.Module):
    """
    Best-effort: extract state_dict from a ScriptModule and load it into a regular nn.Module.
    Returns True on success.
    """
    # Stage 1: deserialize the TorchScript archive on CPU.
    try:
        scripted = torch.jit.load(script_path, map_location="cpu")
    except Exception as err:
        print(f"Cannot load script {script_path}: {err}")
        return False
    # Stage 2: pull its parameters and buffers.
    try:
        weights = scripted.state_dict()
    except Exception as err:
        print(f"ScriptModule.state_dict() failed: {err}")
        return False
    # Stage 3: non-strict load so partially matching key sets still succeed.
    try:
        module_model.load_state_dict(weights, strict=False)
    except Exception as err:
        print(f"load_state_dict failed: {err}")
        return False
    print("Weights successfully loaded from ScriptModule into Python nn.Module (strict=False).")
    return True
208
-
209
- # ========================= Helper to get logits from any model type =========================
210
def get_logits_from_model(model, inputs: torch.Tensor):
    """Run a forward pass and normalize the result to a logits tensor.

    Models may return either a bare tensor or a tuple/list whose first
    element is the logits; both shapes are handled here.
    """
    output = model(inputs.to(device))
    return output[0] if isinstance(output, (tuple, list)) else output
216
-
217
- # ========================= Evaluation =========================
218
def evaluate(model, dataloader, criterion):
    """Return the mean per-batch loss of `model` over `dataloader`.

    Runs in eval mode without gradients, then restores train mode.
    The divisor is clamped to 1, so an empty loader yields 0.0.
    """
    model.eval()
    running_loss, n_batches = 0.0, 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            flat_logits = get_logits_from_model(model, inputs).contiguous()
            flat_logits = flat_logits.view(-1, flat_logits.size(-1))
            # Truncate targets to match the flattened logits length.
            flat_targets = targets.contiguous().view(-1)[:flat_logits.shape[0]]
            running_loss += float(criterion(flat_logits, flat_targets).item())
            n_batches += 1
    model.train()
    return running_loss / max(1, n_batches)
233
-
234
- # ========================= Training loop =========================
235
def train() -> None:
    """Fine-tune the model and save per-epoch plus final checkpoints.

    Flow: resolve a model (Python nn.Module preferred, JIT ScriptModule as
    fallback) → prepare/clean the dataset → build loaders and optimizer →
    run the epoch loop with AMP on CUDA → save per-epoch and final
    TorchScript exports (state_dict fallback) → back up and update the
    "last trained" checkpoint.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    print("Loading model...")

    python_model, python_source = try_load_python_model()
    model = None
    model_source = None

    # Prefer Python nn.Module if available
    if python_model is not None:
        model = python_model
        model_source = python_source
        loaded = False
        # Try to load latest weights (state_dict first, then JIT → state_dict)
        if PT_STATE_DICT_PATH.exists():
            try:
                sd = torch.load(PT_STATE_DICT_PATH, map_location="cpu")
                model.load_state_dict(sd, strict=False)
                print(f"Loaded state_dict from {PT_STATE_DICT_PATH}")
                loaded = True
            except Exception as e:
                print(f"Failed to load state_dict from {PT_STATE_DICT_PATH}: {e}")
        if not loaded and LAST_TRAINED_PATH.exists():
            if load_weights_from_script_to_module(LAST_TRAINED_PATH, model):
                loaded = True
        if not loaded and BASE_MODEL_PATH.exists():
            if load_weights_from_script_to_module(BASE_MODEL_PATH, model):
                loaded = True
    else:
        # Fallback to JIT ScriptModule
        if LAST_TRAINED_PATH.exists():
            model = torch.jit.load(LAST_TRAINED_PATH, map_location=device)
            model_source = f"jit:{LAST_TRAINED_PATH}"
        elif BASE_MODEL_PATH.exists():
            model = torch.jit.load(BASE_MODEL_PATH, map_location=device)
            model_source = f"jit:{BASE_MODEL_PATH}"
        else:
            print("ERROR: No model found (neither Python module nor JIT). Place a model file or Python implementation.")
            return

    print(f"Model loaded from: {model_source}")

    # If we are using a real nn.Module → move to device + enable gradient checkpointing if possible
    # NOTE(review): torch.jit ScriptModules are also nn.Module subclasses, so
    # this flag is True on the JIT path too and the `else` branch below looks
    # unreachable — confirm whether `isinstance(model, torch.jit.ScriptModule)`
    # was intended.
    is_python_module = isinstance(model, nn.Module)
    if is_python_module:
        model.to(device)
        model.train()
        try:
            model.gradient_checkpointing_enable()
            print("Gradient checkpointing ENABLED on Python nn.Module.")
        except Exception:
            try:
                model.gradient_checkpointing = True
                print("Set attribute gradient_checkpointing = True (best-effort).")
            except Exception:
                print("Gradient checkpointing not available on this Python model.")
    else:
        # ScriptModule path
        try:
            model.to(device)
        except Exception:
            print("Warning: model.to(device) failed for ScriptModule; trying best-effort buffer move.")
        model.train()
        print("Training on ScriptModule (gradient checkpointing not available).")

    # ========================= Dataset preparation =========================
    if not CLEAN_PATH.exists():
        if not RAW_PATH.exists():
            raise FileNotFoundError(f"Missing dataset {RAW_PATH}")
        print("Cleaning raw dataset → cleaned version...")
        text = RAW_PATH.read_text(encoding="utf-8")
        # Collapse runs of spaces and strip spaces adjacent to newlines.
        text = re.sub(r" {2,}", " ", text)
        text = text.replace(" \n", "\n").replace("\n ", "\n")
        CLEAN_PATH.write_text(text, encoding="utf-8")
        print(f"Cleaned dataset saved → {CLEAN_PATH}")

    train_dataset = LazyTextDataset(CLEAN_PATH, seq_len=TRAIN_SEQ_LEN, split_type='train', val_ratio=VAL_SPLIT_RATIO)
    val_dataset = LazyTextDataset(CLEAN_PATH, seq_len=TRAIN_SEQ_LEN, split_type='val', val_ratio=VAL_SPLIT_RATIO)

    # IterableDataset → no shuffling; num_workers=0 avoids duplicating the
    # eagerly tokenized corpus across worker processes.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, num_workers=0)

    # ========================= Optimizer (try 8-bit first) =========================
    try:
        import bitsandbytes as bnb  # type: ignore
        try:
            optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        except Exception:
            optimizer = bnb.optim.Adam8bit(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        print("Using bitsandbytes 8-bit optimizer.")
    except Exception:
        optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        print("Using standard torch.optim.AdamW (bitsandbytes not available).")

    criterion = nn.CrossEntropyLoss()
    # GradScaler is a no-op when disabled (CPU path).
    scaler = GradScaler(enabled=(device.type == 'cuda'))

    if device.type == 'cuda':
        torch.cuda.empty_cache()

    total_steps = (len(train_dataset) // BATCH_SIZE) * EPOCHS if len(train_dataset) > 0 else 0
    print(f"\nSTARTING training: epochs={EPOCHS}, approx. steps={total_steps}, examples={len(train_dataset)}")
    print(f"Batch size={BATCH_SIZE}, seq_len={TRAIN_SEQ_LEN}, device={device}, AMP={'on' if device.type=='cuda' else 'off'}")

    global_step = 0
    for epoch in range(1, EPOCHS + 1):
        print(f"\n=== Epoch {epoch}/{EPOCHS} ===")
        epoch_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch} [TRAIN]", leave=False)
        for inputs, targets in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad(set_to_none=True)

            with autocast(enabled=(device.type == 'cuda')):
                logits = get_logits_from_model(model, inputs)
                logits = logits.contiguous().view(-1, logits.size(-1))
                # Truncate targets to the flattened logits length.
                targets_view = targets.contiguous().view(-1)[:logits.shape[0]]
                loss = criterion(logits, targets_view)

            # Backward pass (AMP-safe)
            if device.type == 'cuda':
                try:
                    scaler.scale(loss).backward()
                    # Unscale before clipping so the clip threshold applies
                    # to true gradient magnitudes.
                    scaler.unscale_(optimizer)
                except Exception as e:
                    print("Scaled backward failed:", e)
                    loss.backward()
                try:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                except Exception:
                    pass
                try:
                    scaler.step(optimizer)
                    scaler.update()
                except RuntimeError as e:
                    print("RuntimeError in scaler.step():", e)
                    print(torch.cuda.memory_summary())
                    # Fallback without scaler
                    try:
                        scaler.unscale_(optimizer)
                        optimizer.step()
                    except Exception as e2:
                        print("Fallback optimizer.step() failed:", e2)
                        raise e
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                optimizer.step()

            # NOTE: emptying the cache every step trades speed for headroom
            # on small GPUs.
            if device.type == 'cuda':
                torch.cuda.empty_cache()

            loss_val = float(loss.item())
            epoch_loss += loss_val
            global_step += 1
            # exp() is clamped at 10 to keep the progress-bar ppl finite.
            pbar.set_postfix({"loss": f"{loss_val:.4f}", "ppl": f"{math.exp(min(loss_val, 10)):.2f}", "step": global_step})

        avg_train_loss = epoch_loss / max(1, len(train_dataset) // BATCH_SIZE)
        print(f"[TRAIN] Avg loss: {avg_train_loss:.4f} | Perplexity: {math.exp(avg_train_loss):.2f}")

        print("Running validation...")
        val_loss = evaluate(model, val_loader, criterion)
        print(f"[VAL] Avg loss: {val_loss:.4f} | Perplexity: {math.exp(val_loss):.2f}")

        # Save checkpoint for this epoch
        epoch_dir = OUTPUT_DIR / f"epoch{epoch}"
        epoch_dir.mkdir(parents=True, exist_ok=True)
        try:
            if is_python_module:
                model.eval()
                # Dummy token batch for tracing; 50257 is the GPT-2 vocab size.
                dummy = torch.randint(0, 50257, (1, min(32, TRAIN_SEQ_LEN)), device=device)
                try:
                    traced = torch.jit.trace(model, dummy, strict=False)
                    torch.jit.save(traced, epoch_dir / MODEL_SAVE_NAME)
                    print(f"Exported traced JIT → {epoch_dir / MODEL_SAVE_NAME}")
                except Exception as e:
                    torch.save(model.state_dict(), epoch_dir / "state_dict.pt")
                    print(f"Saved state_dict (trace failed): {e}")
                model.train()
            else:
                torch.jit.save(model, epoch_dir / MODEL_SAVE_NAME)
                print(f"Saved ScriptModule → {epoch_dir / MODEL_SAVE_NAME}")
        except Exception as e:
            print("Error while saving epoch model:", e)

        # NOTE(review): cleanup_old_epochs is not defined anywhere in this
        # file as shown — this raises NameError at the end of epoch 1 unless
        # it is defined elsewhere. Presumably it prunes checkpoints beyond
        # KEEP_LAST_EPOCHS; verify.
        cleanup_old_epochs()

    # ========================= Final model save =========================
    final_dir = OUTPUT_DIR / "final"
    final_dir.mkdir(parents=True, exist_ok=True)
    try:
        if is_python_module:
            model.eval()
            dummy = torch.randint(0, 50257, (1, min(32, TRAIN_SEQ_LEN)), device=device)
            traced = torch.jit.trace(model, dummy, strict=False)
            torch.jit.save(traced, final_dir / MODEL_SAVE_NAME)
            print(f"Final traced JIT saved → {final_dir / MODEL_SAVE_NAME}")
        else:
            torch.jit.save(model, final_dir / MODEL_SAVE_NAME)
            print(f"Final ScriptModule saved → {final_dir / MODEL_SAVE_NAME}")
    except Exception:
        torch.save(model.state_dict(), final_dir / "state_dict.pt")
        print("Final model saved as state_dict (trace failed).")

    # Save tokenizer with the final model
    try:
        train_dataset.tokenizer.save_pretrained(final_dir)
    except Exception:
        pass

    # Backup previous last-trained model and update the "current" symlink/file
    if LAST_TRAINED_PATH.exists():
        backup_path = BACKUP_DIR / f"gpt_last_trained_backup_{int(LAST_TRAINED_PATH.stat().st_mtime)}.script.pt"
        shutil.copy(LAST_TRAINED_PATH, backup_path)
        print(f"Backed up previous last_trained → {backup_path}")

    if (final_dir / MODEL_SAVE_NAME).exists():
        shutil.copy(final_dir / MODEL_SAVE_NAME, LAST_TRAINED_PATH)
        print(f"Copied final model → {LAST_TRAINED_PATH}")
    elif (final_dir / "state_dict.pt").exists():
        shutil.copy(final_dir / "state_dict.pt", LAST_TRAINED_PATH.with_suffix(".state_dict.pt"))

    print("TRAINING COMPLETED.")
459
-
460
- # ========================= Entrypoint =========================
461
if __name__ == "__main__":
    # Fail fast: training is pointless without the raw corpus on disk.
    if not RAW_PATH.exists():
        print(f"ERROR: dataset {RAW_PATH} not found. Place your training text there.")
        sys.exit(1)
    train()
 
1
+ # Copyright (c) 2025 CMS Manhattan
2
+ # All rights reserved.
3
+ # Author: Konstantin Vladimirovich Grabko
4
+ # Email: grabko@cmsmanhattan.com
5
+ # Phone: +1(516)777-0945
6
+ #
7
+ # This program is free software: you can redistribute it and/or modify
8
+ # it under the terms of the GNU General Public License as published by
9
+ # the Free Software Foundation, version 3 of the License.
10
+ #
11
+ # This program is distributed in the hope that it will be useful,
12
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
13
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
+ # GNU General Public License for more details.
15
+ #
16
+ # You should have received a copy of the GNU General Public License
17
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
18
+ #
19
+ # Additional terms:
20
+ # Any commercial use or distribution of this software or derivative works
21
+ # requires explicit written permission from the copyright holder.
22
+
23
+ #!/usr/bin/env python3
24
+ """
25
+ # install tokenizer before run
26
+ mkdir -p tokenizer
27
+ wget -O tokenizer/tokenizer.json https://huggingface.co/gpt2/resolve/main/tokenizer.json
28
+ wget -O tokenizer/vocab.json https://huggingface.co/gpt2/resolve/main/vocab.json
29
+ wget -O tokenizer/merges.txt https://huggingface.co/gpt2/resolve/main/merges.txt
30
+ wget -O tokenizer/tokenizer_config.json https://huggingface.co/gpt2/resolve/main/tokenizer_config.json
31
+
32
+ Updated fine-tuning script, version "prefer Python nn.Module with gradient checkpointing".
33
+
34
+ What it does:
35
+ - Tries to load a local Python implementation of the model (as torch.nn.Module). If found β€” uses it and
36
+ enables gradient_checkpointing (if the model supports it).
37
+ - If no Python model class is found β€” falls back to JIT ScriptModule (as before).
38
+ - If the original weights are only available as JIT, attempts to extract state_dict() from the ScriptModule
39
+ and load it into the nn.Module (best-effort).
40
+ - Saves the final trained model as a JIT torch.jit.save at the end, or as state_dict if something fails.
41
+ - Saves the tokenizer locally (./tokenizer) and uses it. Gives a helpful message if the tokenizer is missing.
42
+ - Supports AMP (autocast + GradScaler) on GPU.
43
+ - Optional support for bitsandbytes 8-bit optimizer (if installed).
44
+ - Comments and console messages are in English.
45
+
46
+ Before running: if you have a Python file with the model implementation
47
+ (for example gpt_modern_1b.py or gpt_modern_1b_class.py), place it in the same folder
48
+ and make sure it contains a class named JiRackPyTorch (or one of the other names the script looks for).
49
+ If no such file exists β€” the script will just use the JIT model as before.
50
+ """
51
+
52
+ import os
53
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128,garbage_collection_threshold:0.6")
54
+
55
+ import sys
56
+ import importlib
57
+ import math
58
+ import shutil
59
+ import re
60
+ from pathlib import Path
61
+ from typing import Optional
62
+
63
+ import torch
64
+ import torch.nn as nn
65
+ import torch.optim as optim
66
+ from torch.utils.data import IterableDataset, DataLoader
67
+ from transformers import GPT2TokenizerFast
68
+ from tqdm import tqdm
69
+ from torch.cuda.amp import GradScaler, autocast
70
+
71
# ========================= SETTINGS =========================
# Every hyper-parameter can be overridden through an environment variable,
# so runs can be tuned without editing the script.
TRAIN_SEQ_LEN = int(os.environ.get("TRAIN_SEQ_LEN", 64))    # tokens per training window
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 1))
EPOCHS = int(os.environ.get("EPOCHS", 999))                 # effectively "until stopped"
LEARNING_RATE = float(os.environ.get("LEARNING_RATE", 6e-6))
WEIGHT_DECAY = float(os.environ.get("WEIGHT_DECAY", 0.01))
GRAD_CLIP = float(os.environ.get("GRAD_CLIP", 1.0))         # max gradient norm for clipping
KEEP_LAST_EPOCHS = int(os.environ.get("KEEP_LAST_EPOCHS", 3))  # epoch checkpoints to retain
VAL_SPLIT_RATIO = float(os.environ.get("VAL_SPLIT_RATIO", 0.05))

# Checkpoint locations: TorchScript archives plus a plain state_dict variant.
BASE_MODEL_PATH = Path("models/gpt_modern_1b_class.script.pt")
LAST_TRAINED_PATH = Path("models/gpt_1b_last_trained.script.pt")
PT_STATE_DICT_PATH = Path("models/gpt_modern_1b_class.state_dict.pt")
BACKUP_DIR = Path("models/backups")
BACKUP_DIR.mkdir(parents=True, exist_ok=True)  # NOTE: side effect at import time

# Dataset and tokenizer locations.
RAW_PATH = Path("datasets/dialogues_text.txt")
CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")
TOKENIZER_LOCAL_DIR = Path("./tokenizer")

OUTPUT_DIR = Path("build/fine_tuning_output")
MODEL_SAVE_NAME = "gpt_finetuned.script.pt"

# Single-GPU (cuda:0) when available, otherwise CPU fallback.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
96
+
97
+ # ========================= Tokenizer helper =========================
98
def _load_tokenizer_local(tokenizer_name: str = "gpt2"):
    """Load a GPT-2 tokenizer from local files only.

    Candidate locations are probed in order: $TOKENIZER_PATH (when set),
    ./tokenizer, the given tokenizer name, and the current directory.
    The pad token is set to EOS when the tokenizer defines none.

    Raises:
        RuntimeError: when no candidate location yields a tokenizer.
    """
    search_paths = [str(TOKENIZER_LOCAL_DIR), tokenizer_name, "./"]
    env_override = os.environ.get("TOKENIZER_PATH")
    if env_override:
        # The environment override takes priority over the defaults.
        search_paths.insert(0, env_override)

    for location in search_paths:
        try:
            tok = GPT2TokenizerFast.from_pretrained(location, local_files_only=True)
            if getattr(tok, "pad_token", None) is None:
                tok.pad_token = tok.eos_token
            print(f"Tokenizer loaded from: {location}")
            return tok
        except Exception:
            # Best-effort probing: silently move on to the next candidate.
            continue

    raise RuntimeError(
        "Local tokenizer not found. Place tokenizer.json or (vocab.json + merges.txt) into ./tokenizer\n"
        "OR set the path via TOKENIZER_PATH environment variable.\n"
        "Example: export TOKENIZER_PATH=/path/to/tokenizer\n"
        "If you have internet access, you can temporarily use transformers.GPT2TokenizerFast.from_pretrained('gpt2')"
    )
125
+
126
+ # ========================= Dataset =========================
127
class LazyTextDataset(IterableDataset):
    """Iterable LM dataset over a text file tokenized once up front.

    The token stream is split into contiguous train/val windows of
    ``seq_len`` tokens according to ``val_ratio`` (train windows first);
    iteration yields (input_ids, next-token label_ids) tensor pairs.
    """

    def __init__(self, text_file: Path, seq_len: int = TRAIN_SEQ_LEN, tokenizer_name: str = "gpt2",
                 split_type: str = 'train', val_ratio: float = VAL_SPLIT_RATIO):
        self.seq_len = seq_len
        self.tokenizer = _load_tokenizer_local(tokenizer_name)
        self.text_file = Path(text_file)
        self.split_type = split_type
        self.val_ratio = val_ratio

        print(f"Loading and tokenizing {self.text_file} (one-time tokenization into ids)...")
        with open(self.text_file, "r", encoding="utf-8") as fh:
            raw_text = fh.read()
        self.tokens = self.tokenizer.encode(raw_text)

        # One spare token is needed for the shifted label of the last window.
        usable_tokens = max(0, len(self.tokens) - 1)
        n_windows = usable_tokens // seq_len if seq_len > 0 else 0
        n_val = int(n_windows * val_ratio)
        n_train = n_windows - n_val
        if split_type == 'train':
            self.start, self.stop = 0, n_train
        elif split_type == 'val':
            self.start, self.stop = n_train, n_train + n_val
        else:
            raise ValueError("split_type must be 'train' or 'val'")
        self.total_sequences = max(0, self.stop - self.start)
        print(f"Split {split_type}: {self.total_sequences} sequences (out of {n_windows})")

    def __iter__(self):
        step = self.seq_len
        for offset in range(self.start * step, self.stop * step, step):
            # Stop early if the shifted label would run past the stream end.
            if offset + step + 1 > len(self.tokens):
                break
            window = self.tokens[offset: offset + step]
            shifted = self.tokens[offset + 1: offset + step + 1]
            yield (torch.tensor(window, dtype=torch.long),
                   torch.tensor(shifted, dtype=torch.long))

    def __len__(self):
        return self.total_sequences
166
+
167
+ # ========================= Try to load Python nn.Module model =========================
168
def try_load_python_model():
    """
    Attempt to find and import a local Python model implementation (nn.Module).

    Looks for several possible module and class names; the first class that
    instantiates successfully wins.

    Returns:
        (model_instance, source_description) on success, else (None, None).
    """
    # BUGFIX: the file only does `import importlib`, which does not guarantee
    # that the `importlib.util` submodule is bound; without this explicit
    # import, `importlib.util.find_spec` could raise AttributeError and the
    # broad `except` below would silently skip every candidate module.
    import importlib.util

    candidates_modules = [
        "gpt_modern_1b_class",
        "gpt_modern_1b",
        "gpt_modern_1b_class_fixed",
        "model", "ji_rack_model"
    ]
    candidates_class_names = [
        "JiRackPyTorch",
        "JiRackPyTorch1B",
        "GPTModel",
        "JiRackModel"
    ]

    for modname in candidates_modules:
        try:
            spec = importlib.util.find_spec(modname)
            if spec is None:
                continue
            mod = importlib.import_module(modname)
        except Exception:
            # Broken candidate (import error, syntax error, ...) — try the next.
            continue
        for cls_name in candidates_class_names:
            if hasattr(mod, cls_name):
                cls = getattr(mod, cls_name)
                try:
                    inst = cls()
                    print(f"Loaded Python model class {cls_name} from module {modname}")
                    return inst, f"python:{modname}.{cls_name}"
                except Exception as e:
                    print(f"Found class {cls_name} in {modname} but instantiation failed: {e}")
                    continue
    return None, None
206
+
207
+ # ========================= Utility: load weights from JIT script into nn.Module =========================
208
def load_weights_from_script_to_module(script_path: Path, module_model: nn.Module):
    """
    Best-effort: extract state_dict from a ScriptModule and load it into a regular nn.Module.
    Returns True on success.
    """
    # Stage 1: deserialize the TorchScript archive on CPU.
    try:
        scripted = torch.jit.load(script_path, map_location="cpu")
    except Exception as err:
        print(f"Cannot load script {script_path}: {err}")
        return False
    # Stage 2: pull its parameters and buffers.
    try:
        weights = scripted.state_dict()
    except Exception as err:
        print(f"ScriptModule.state_dict() failed: {err}")
        return False
    # Stage 3: non-strict load so partially matching key sets still succeed.
    try:
        module_model.load_state_dict(weights, strict=False)
    except Exception as err:
        print(f"load_state_dict failed: {err}")
        return False
    print("Weights successfully loaded from ScriptModule into Python nn.Module (strict=False).")
    return True
230
+
231
+ # ========================= Helper to get logits from any model type =========================
232
def get_logits_from_model(model, inputs: torch.Tensor):
    """Run a forward pass and normalize the result to a logits tensor.

    Models may return either a bare tensor or a tuple/list whose first
    element is the logits; both shapes are handled here.
    """
    output = model(inputs.to(device))
    return output[0] if isinstance(output, (tuple, list)) else output
238
+
239
+ # ========================= Evaluation =========================
240
def evaluate(model, dataloader, criterion):
    """Return the mean per-batch loss of `model` over `dataloader`.

    Runs in eval mode without gradients, then restores train mode.
    The divisor is clamped to 1, so an empty loader yields 0.0.
    """
    model.eval()
    running_loss, n_batches = 0.0, 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            flat_logits = get_logits_from_model(model, inputs).contiguous()
            flat_logits = flat_logits.view(-1, flat_logits.size(-1))
            # Truncate targets to match the flattened logits length.
            flat_targets = targets.contiguous().view(-1)[:flat_logits.shape[0]]
            running_loss += float(criterion(flat_logits, flat_targets).item())
            n_batches += 1
    model.train()
    return running_loss / max(1, n_batches)
255
+
256
# ========================= Training loop =========================
def train():
    """End-to-end fine-tuning entry point.

    Loads the model (preferring a Python nn.Module over a JIT ScriptModule),
    prepares/cleans the dataset, trains for EPOCHS epochs with optional AMP
    and an optional 8-bit optimizer, validates after every epoch, checkpoints
    per epoch, and finally exports the model as traced JIT (falling back to a
    raw state_dict) while updating the "last trained" path with a backup.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    print("Loading model...")

    # python_model is a real nn.Module when a local implementation was found,
    # otherwise None (then we fall back to a JIT ScriptModule below).
    python_model, python_source = try_load_python_model()
    model = None
    model_source = None

    # Prefer the Python nn.Module if available (enables gradient checkpointing).
    if python_model is not None:
        model = python_model
        model_source = python_source
        loaded = False
        # Weight-loading cascade: state_dict file first, then JIT checkpoints
        # converted to a state_dict (best-effort, strict=False everywhere).
        if PT_STATE_DICT_PATH.exists():
            try:
                sd = torch.load(PT_STATE_DICT_PATH, map_location="cpu")
                model.load_state_dict(sd, strict=False)
                print(f"Loaded state_dict from {PT_STATE_DICT_PATH}")
                loaded = True
            except Exception as e:
                print(f"Failed to load state_dict from {PT_STATE_DICT_PATH}: {e}")
        if not loaded and LAST_TRAINED_PATH.exists():
            if load_weights_from_script_to_module(LAST_TRAINED_PATH, model):
                loaded = True
        if not loaded and BASE_MODEL_PATH.exists():
            if load_weights_from_script_to_module(BASE_MODEL_PATH, model):
                loaded = True
    else:
        # Fallback to a JIT ScriptModule: last-trained checkpoint first, then base.
        if LAST_TRAINED_PATH.exists():
            model = torch.jit.load(LAST_TRAINED_PATH, map_location=device)
            model_source = f"jit:{LAST_TRAINED_PATH}"
        elif BASE_MODEL_PATH.exists():
            model = torch.jit.load(BASE_MODEL_PATH, map_location=device)
            model_source = f"jit:{BASE_MODEL_PATH}"
        else:
            print("ERROR: No model found (neither Python module nor JIT). Place a model file or Python implementation.")
            return

    print(f"Model loaded from: {model_source}")

    # If we hold a real nn.Module -> move to device and try to enable gradient
    # checkpointing (method first, bare attribute as a best-effort fallback).
    is_python_module = isinstance(model, nn.Module)
    if is_python_module:
        model.to(device)
        model.train()
        try:
            model.gradient_checkpointing_enable()
            print("Gradient checkpointing ENABLED on Python nn.Module.")
        except Exception:
            try:
                model.gradient_checkpointing = True
                print("Set attribute gradient_checkpointing = True (best-effort).")
            except Exception:
                print("Gradient checkpointing not available on this Python model.")
    else:
        # ScriptModule path: .to(device) can fail on some serialized graphs.
        try:
            model.to(device)
        except Exception:
            print("Warning: model.to(device) failed for ScriptModule; trying best-effort buffer move.")
        model.train()
        print("Training on ScriptModule (gradient checkpointing not available).")

    # ========================= Dataset preparation =========================
    # Build the cleaned dataset once: collapse repeated spaces and strip
    # spaces adjacent to newlines, then cache to CLEAN_PATH.
    if not CLEAN_PATH.exists():
        if not RAW_PATH.exists():
            raise FileNotFoundError(f"Missing dataset {RAW_PATH}")
        print("Cleaning raw dataset β†’ cleaned version...")
        text = RAW_PATH.read_text(encoding="utf-8")
        text = re.sub(r" {2,}", " ", text)
        text = text.replace(" \n", "\n").replace("\n ", "\n")
        CLEAN_PATH.write_text(text, encoding="utf-8")
        print(f"Cleaned dataset saved β†’ {CLEAN_PATH}")

    train_dataset = LazyTextDataset(CLEAN_PATH, seq_len=TRAIN_SEQ_LEN, split_type='train', val_ratio=VAL_SPLIT_RATIO)
    val_dataset = LazyTextDataset(CLEAN_PATH, seq_len=TRAIN_SEQ_LEN, split_type='val', val_ratio=VAL_SPLIT_RATIO)

    # num_workers=0: data loading happens in the main process (no worker forks).
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, num_workers=0)

    # ========================= Optimizer (try 8-bit first) =========================
    # Prefer bitsandbytes 8-bit AdamW (then Adam8bit) to cut optimizer-state
    # memory; fall back to standard AdamW when bitsandbytes is unavailable.
    try:
        import bitsandbytes as bnb  # type: ignore
        try:
            optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        except Exception:
            optimizer = bnb.optim.Adam8bit(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        print("Using bitsandbytes 8-bit optimizer.")
    except Exception:
        optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        print("Using standard torch.optim.AdamW (bitsandbytes not available).")

    criterion = nn.CrossEntropyLoss()
    # AMP loss scaler is active only on CUDA; on CPU it is a no-op.
    scaler = GradScaler(enabled=(device.type == 'cuda'))

    if device.type == 'cuda':
        torch.cuda.empty_cache()

    # Approximate step count; drop_last=True makes this match the loader length.
    total_steps = (len(train_dataset) // BATCH_SIZE) * EPOCHS if len(train_dataset) > 0 else 0
    print(f"\nSTARTING training: epochs={EPOCHS}, approx. steps={total_steps}, examples={len(train_dataset)}")
    print(f"Batch size={BATCH_SIZE}, seq_len={TRAIN_SEQ_LEN}, device={device}, AMP={'on' if device.type=='cuda' else 'off'}")

    global_step = 0
    for epoch in range(1, EPOCHS + 1):
        print(f"\n=== Epoch {epoch}/{EPOCHS} ===")
        epoch_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch} [TRAIN]", leave=False)
        for inputs, targets in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad(set_to_none=True)

            with autocast(enabled=(device.type == 'cuda')):
                logits = get_logits_from_model(model, inputs)
                logits = logits.contiguous().view(-1, logits.size(-1))
                # Truncate flattened targets to the flattened logits length.
                targets_view = targets.contiguous().view(-1)[:logits.shape[0]]
                loss = criterion(logits, targets_view)

            # Backward pass (AMP-safe): scaled backward + unscale, with an
            # unscaled loss.backward() fallback if the scaled path raises.
            if device.type == 'cuda':
                try:
                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                except Exception as e:
                    print("Scaled backward failed:", e)
                    loss.backward()
                try:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                except Exception:
                    pass
                try:
                    scaler.step(optimizer)
                    scaler.update()
                except RuntimeError as e:
                    print("RuntimeError in scaler.step():", e)
                    print(torch.cuda.memory_summary())
                    # Fallback without scaler; re-raise the original error if
                    # even the plain optimizer.step() fails.
                    try:
                        scaler.unscale_(optimizer)
                        optimizer.step()
                    except Exception as e2:
                        print("Fallback optimizer.step() failed:", e2)
                        raise e
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                optimizer.step()

            # NOTE(review): emptying the CUDA cache every step is very slow;
            # presumably done here to survive tight-memory setups — confirm.
            if device.type == 'cuda':
                torch.cuda.empty_cache()

            loss_val = float(loss.item())
            epoch_loss += loss_val
            global_step += 1
            # min(loss, 10) keeps math.exp from overflowing in the progress bar.
            pbar.set_postfix({"loss": f"{loss_val:.4f}", "ppl": f"{math.exp(min(loss_val, 10)):.2f}", "step": global_step})

        avg_train_loss = epoch_loss / max(1, len(train_dataset) // BATCH_SIZE)
        print(f"[TRAIN] Avg loss: {avg_train_loss:.4f} | Perplexity: {math.exp(avg_train_loss):.2f}")

        print("Running validation...")
        val_loss = evaluate(model, val_loader, criterion)
        print(f"[VAL] Avg loss: {val_loss:.4f} | Perplexity: {math.exp(val_loss):.2f}")

        # Save a checkpoint for this epoch: trace nn.Module to JIT (state_dict
        # on trace failure); ScriptModules are saved directly.
        epoch_dir = OUTPUT_DIR / f"epoch{epoch}"
        epoch_dir.mkdir(parents=True, exist_ok=True)
        try:
            if is_python_module:
                model.eval()
                # 50257 = GPT-2 vocab size; short dummy sequence for tracing.
                dummy = torch.randint(0, 50257, (1, min(32, TRAIN_SEQ_LEN)), device=device)
                try:
                    traced = torch.jit.trace(model, dummy, strict=False)
                    torch.jit.save(traced, epoch_dir / MODEL_SAVE_NAME)
                    print(f"Exported traced JIT β†’ {epoch_dir / MODEL_SAVE_NAME}")
                except Exception as e:
                    torch.save(model.state_dict(), epoch_dir / "state_dict.pt")
                    print(f"Saved state_dict (trace failed): {e}")
                model.train()
            else:
                torch.jit.save(model, epoch_dir / MODEL_SAVE_NAME)
                print(f"Saved ScriptModule β†’ {epoch_dir / MODEL_SAVE_NAME}")
        except Exception as e:
            print("Error while saving epoch model:", e)

        # Prune older per-epoch checkpoint directories (defined elsewhere).
        cleanup_old_epochs()

    # ========================= Final model save =========================
    final_dir = OUTPUT_DIR / "final"
    final_dir.mkdir(parents=True, exist_ok=True)
    try:
        if is_python_module:
            model.eval()
            dummy = torch.randint(0, 50257, (1, min(32, TRAIN_SEQ_LEN)), device=device)
            traced = torch.jit.trace(model, dummy, strict=False)
            torch.jit.save(traced, final_dir / MODEL_SAVE_NAME)
            print(f"Final traced JIT saved β†’ {final_dir / MODEL_SAVE_NAME}")
        else:
            torch.jit.save(model, final_dir / MODEL_SAVE_NAME)
            print(f"Final ScriptModule saved β†’ {final_dir / MODEL_SAVE_NAME}")
    except Exception:
        torch.save(model.state_dict(), final_dir / "state_dict.pt")
        print("Final model saved as state_dict (trace failed).")

    # Save the tokenizer next to the final model (best-effort).
    try:
        train_dataset.tokenizer.save_pretrained(final_dir)
    except Exception:
        pass

    # Backup the previous last-trained model (timestamped by mtime) before
    # replacing it with the freshly trained one.
    if LAST_TRAINED_PATH.exists():
        backup_path = BACKUP_DIR / f"gpt_last_trained_backup_{int(LAST_TRAINED_PATH.stat().st_mtime)}.script.pt"
        shutil.copy(LAST_TRAINED_PATH, backup_path)
        print(f"Backed up previous last_trained β†’ {backup_path}")

    # Promote the final artifact: JIT file if it exists, else the state_dict
    # under a ".state_dict.pt" suffix.
    if (final_dir / MODEL_SAVE_NAME).exists():
        shutil.copy(final_dir / MODEL_SAVE_NAME, LAST_TRAINED_PATH)
        print(f"Copied final model β†’ {LAST_TRAINED_PATH}")
    elif (final_dir / "state_dict.pt").exists():
        shutil.copy(final_dir / "state_dict.pt", LAST_TRAINED_PATH.with_suffix(".state_dict.pt"))

    print("TRAINING COMPLETED.")
482
# ========================= Entrypoint =========================
if __name__ == "__main__":
    # Refuse to start without the raw training corpus in place.
    if RAW_PATH.exists():
        train()
    else:
        print(f"ERROR: dataset {RAW_PATH} not found. Place your training text there.")
        sys.exit(1)