kgrabko committed on
Commit
355f97d
·
verified ·
1 Parent(s): 9a02192

Upload fine_tune_jit_with_validation_torch_script_cuda_33b.py

Browse files
fine_tune_jit_with_validation_torch_script_cuda_33b.py ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) 2025 CMS Manhattan
3
+ # All rights reserved.
4
+ #
5
+ # This file is part of a project authored by CMS Manhattan. You may use, distribute, and modify
6
+ # this code under the terms of the GNU GENERAL PUBLIC LICENSE, Version 3, 29 June 2007
7
+ # please read <http://www.gnu.org/licenses/>.
8
+ """
9
+ mkdir -p tokenizer
10
+ wget -O tokenizer/tokenizer.json https://huggingface.co/gpt2/resolve/main/tokenizer.json
11
+ wget -O tokenizer/vocab.json https://huggingface.co/gpt2/resolve/main/vocab.json
12
+ wget -O tokenizer/merges.txt https://huggingface.co/gpt2/resolve/main/merges.txt
13
+ wget -O tokenizer/tokenizer_config.json https://huggingface.co/gpt2/resolve/main/tokenizer_config.json
14
+
15
+ Updated fine-tuning script, version "prefer Python nn.Module with gradient checkpointing".
16
+
17
+ What it does:
18
+ - Tries to load a local Python model implementation (nn.Module). If found — uses it and
19
+ enables gradient_checkpointing (if implemented).
20
+ - If the Python model class is not found, falls back to JIT ScriptModule (as before).
21
+ - If the original weights are only available as JIT, attempts to extract state_dict() from the ScriptModule
22
+ and load it into the nn.Module (best-effort).
23
+ - Saves the final trained model as a JIT (torch.jit.save) at the end, or saves state_dict if an error occurs.
24
+ - Saves the tokenizer locally (./tokenizer) and uses it. If the tokenizer is missing, gives a helpful hint.
25
+ - Supports AMP (autocast + GradScaler) for GPU.
26
+ - Optional support for bitsandbytes 8-bit optimizer (if available).
27
+ - Comments and messages are in Russian.
28
+
29
+ Before running: if you have a file with the model implementation (e.g., gpt_modern_1b.py or gpt_modern_1b_class.py),
30
+ place it in the same directory and make sure it contains a class named JiRackPyTorch or another name we're looking for.
31
+ If not — the script will fall back to using the JIT model as before.
32
+ """
33
+
34
+ import os
35
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128,garbage_collection_threshold:0.6")
36
+
37
+ import sys
38
+ import importlib
39
+ import math
40
+ import shutil
41
+ import re
42
+ from pathlib import Path
43
+ from typing import Optional
44
+
45
+ import torch
46
+ import torch.nn as nn
47
+ import torch.optim as optim
48
+ from torch.utils.data import IterableDataset, DataLoader
49
+ from transformers import GPT2TokenizerFast
50
+ from tqdm import tqdm
51
+ from torch.cuda.amp import GradScaler, autocast
52
+
53
+ # ========================= SETTINGS =========================
54
+ TRAIN_SEQ_LEN = int(os.environ.get("TRAIN_SEQ_LEN", 64))
55
+ BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 1))
56
+ EPOCHS = int(os.environ.get("EPOCHS", 999))
57
+ LEARNING_RATE = float(os.environ.get("LEARNING_RATE", 6e-6))
58
+ WEIGHT_DECAY = float(os.environ.get("WEIGHT_DECAY", 0.01))
59
+ GRAD_CLIP = float(os.environ.get("GRAD_CLIP", 1.0))
60
+ KEEP_LAST_EPOCHS = int(os.environ.get("KEEP_LAST_EPOCHS", 3))
61
+ VAL_SPLIT_RATIO = float(os.environ.get("VAL_SPLIT_RATIO", 0.05))
62
+
63
+ BASE_MODEL_PATH = Path("models/gpt_modern_33b_class.script.pt")
64
+ LAST_TRAINED_PATH = Path("models/gpt_33b_last_trained.script.pt")
65
+ PT_STATE_DICT_PATH = Path("models/gpt_modern_33b_class.state_dict.pt")
66
+ BACKUP_DIR = Path("models/backups")
67
+ BACKUP_DIR.mkdir(parents=True, exist_ok=True)
68
+
69
+ RAW_PATH = Path("datasets/dialogues_text.txt")
70
+ CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")
71
+ TOKENIZER_LOCAL_DIR = Path("./tokenizer")
72
+
73
+ OUTPUT_DIR = Path("build/fine_tuning_output")
74
+ MODEL_SAVE_NAME = "gpt_finetuned.script.pt"
75
+
76
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
77
+ print(f"Using device: {device}")
78
+
79
+ # ========================= Tokenizer helper =========================
80
+ def _load_tokenizer_local(tokenizer_name: str = "gpt2"):
81
+ """
82
+ Try to load the tokenizer locally. If not found — give user instructions.
83
+ """
84
+ candidates = []
85
+ env_path = os.environ.get("TOKENIZER_PATH")
86
+ if env_path:
87
+ candidates.append(env_path)
88
+ candidates.append(str(TOKENIZER_LOCAL_DIR))
89
+ candidates.append(tokenizer_name)
90
+ candidates.append("./")
91
+ for cand in candidates:
92
+ try:
93
+ tok = GPT2TokenizerFast.from_pretrained(cand, local_files_only=True)
94
+ if getattr(tok, "pad_token", None) is None:
95
+ tok.pad_token = tok.eos_token
96
+ print(f"Tokenizer loaded from: {cand}")
97
+ return tok
98
+ except Exception:
99
+ continue
100
+
101
+ raise RuntimeError(
102
+ "Local tokenizer not found. Place tokenizer.json or (vocab.json + merges.txt) in ./tokenizer\n"
103
+ "OR specify the path via TOKENIZER_PATH environment variable.\n"
104
+ "Example: export TOKENIZER_PATH=/path/to/tokenizer\n"
105
+ "If you have internet access, you can temporarily use transformers.GPT2TokenizerFast.from_pretrained('gpt2')"
106
+ )
107
+
108
+ # ========================= Dataset =========================
109
class LazyTextDataset(IterableDataset):
    """
    Iterable dataset over a plain-text file.

    The whole file is tokenized once in __init__; iteration then yields
    (input_ids, label_ids) pairs of length seq_len, where labels are the
    inputs shifted by one token (next-token prediction). The token stream
    is split positionally into a train head and a validation tail according
    to val_ratio.
    """

    def __init__(self, text_file: Path, seq_len: int = TRAIN_SEQ_LEN, tokenizer_name: str = "gpt2",
                 split_type: str = 'train', val_ratio: float = VAL_SPLIT_RATIO):
        self.seq_len = seq_len
        self.tokenizer = _load_tokenizer_local(tokenizer_name)
        self.text_file = Path(text_file)
        self.split_type = split_type
        self.val_ratio = val_ratio

        print(f"Loading and tokenizing {self.text_file} (one-time tokenization into ids)...")
        with open(self.text_file, "r", encoding="utf-8") as handle:
            raw_text = handle.read()
        self.tokens = self.tokenizer.encode(raw_text)

        # One extra token is needed for the shifted label, hence len - 1.
        usable_tokens = max(0, len(self.tokens) - 1)
        num_sequences = usable_tokens // seq_len if seq_len > 0 else 0
        num_val = int(num_sequences * val_ratio)
        num_train = num_sequences - num_val
        if split_type == 'train':
            self.start, self.stop = 0, num_train
        elif split_type == 'val':
            self.start, self.stop = num_train, num_train + num_val
        else:
            raise ValueError("split_type must be 'train' or 'val'")
        self.total_sequences = max(0, self.stop - self.start)
        print(f"Split {split_type}: {self.total_sequences} sequences (out of {num_sequences})")

    def __iter__(self):
        token_ids = self.tokens
        step = self.seq_len
        for offset in range(self.start * step, self.stop * step, step):
            # Guard against running off the end of the token stream.
            if offset + step + 1 > len(token_ids):
                break
            inputs = torch.tensor(token_ids[offset: offset + step], dtype=torch.long)
            labels = torch.tensor(token_ids[offset + 1: offset + step + 1], dtype=torch.long)
            yield inputs, labels

    def __len__(self):
        return self.total_sequences
148
+
149
+ # ========================= Attempt to load Python nn.Module model =========================
150
def try_load_python_model():
    """
    Try to find a local module/class implementing the model (nn.Module).

    Searches a fixed list of candidate module names for a fixed list of
    candidate class names and instantiates the first match with no arguments.

    Returns:
        (model_instance, "python:<module>.<class>") on success,
        (None, None) when nothing importable/instantiable is found.
    """
    # Bug fix: `import importlib` at the top of the file does NOT guarantee the
    # `importlib.util` submodule is loaded, so `importlib.util.find_spec` could
    # raise AttributeError — which the broad `except Exception` below silently
    # swallowed, making this function always return (None, None). Import the
    # submodule explicitly.
    import importlib.util

    candidates_modules = [
        "gpt_modern_1b_class",
        "gpt_modern_1b",
        "gpt_modern_1b_class_fixed",
        "model", "ji_rack_model"
    ]
    candidates_class_names = [
        "JiRackPyTorch",
        "JiRackPyTorch1B",
        "GPTModel",
        "JiRackModel"
    ]

    for modname in candidates_modules:
        try:
            spec = importlib.util.find_spec(modname)
            if spec is None:
                continue
            mod = importlib.import_module(modname)
            for cls_name in candidates_class_names:
                if hasattr(mod, cls_name):
                    cls = getattr(mod, cls_name)
                    try:
                        inst = cls()
                        print(f"Loaded Python model class {cls_name} from module {modname}")
                        return inst, f"python:{modname}.{cls_name}"
                    except Exception as e:
                        # Keep scanning: another class name in the same module may work.
                        print(f"Found class {cls_name} in {modname} but failed to instantiate: {e}")
                        continue
        except Exception:
            # Best-effort discovery: any import error just moves to the next candidate.
            continue
    return None, None
188
+
189
+ # ========================= Utility: load weights from script -> module =========================
190
def load_weights_from_script_to_module(script_path: Path, module_model: nn.Module):
    """
    Best-effort weight transfer from a saved TorchScript module into a Python nn.Module.

    Loads the ScriptModule on CPU, extracts its state_dict, and applies it to
    ``module_model`` with strict=False (mismatched keys are tolerated).

    Returns:
        True when the state_dict was applied, False on any failure (a message
        is printed describing which step failed).
    """
    try:
        scripted = torch.jit.load(script_path, map_location="cpu")
    except Exception as e:
        print(f"Cannot load script at {script_path}: {e}")
        return False

    try:
        weights = scripted.state_dict()
    except Exception as e:
        print(f"ScriptModule.state_dict() failed: {e}")
        return False

    try:
        module_model.load_state_dict(weights, strict=False)
        print("Weights successfully loaded from ScriptModule.state_dict() into Python nn.Module (strict=False).")
        return True
    except Exception as e:
        print(f"load_state_dict failed: {e}")
        return False
213
+
214
+ # ========================= get_logits helper =========================
215
def get_logits_from_model(model, inputs: torch.Tensor):
    """
    Run ``model`` on ``inputs`` (moved to the module-level device first) and
    return just the logits tensor. Models that return a tuple/list are assumed
    to carry logits in position 0.
    """
    output = model(inputs.to(device))
    return output[0] if isinstance(output, (tuple, list)) else output
221
+
222
+ # ========================= EVALUATION =========================
223
def evaluate(model, dataloader, criterion):
    """
    Average the per-batch loss of ``model`` over ``dataloader`` without gradients.

    Flattens logits to (N, vocab) and truncates targets to match, mirroring the
    training loop. The model is switched back to train mode before returning.

    Returns:
        Mean loss across batches (0.0 for an empty loader).
    """
    model.eval()
    running_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch_inputs, batch_targets in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            logits = get_logits_from_model(model, batch_inputs)
            flat_logits = logits.contiguous().view(-1, logits.size(-1))
            flat_targets = batch_targets.contiguous().view(-1)[:flat_logits.shape[0]]
            running_loss += float(criterion(flat_logits, flat_targets).item())
            num_batches += 1
    model.train()
    return running_loss / max(1, num_batches)
238
+
239
+ # ========================= TRAINING LOOP =========================
240
def cleanup_old_epochs():
    """
    Prune per-epoch checkpoint directories under OUTPUT_DIR, keeping only the
    KEEP_LAST_EPOCHS most recent ones (ordered by the epoch number embedded in
    the directory name, e.g. "epoch12").

    NOTE(review): train() called this function but it was never defined anywhere
    in the original file, so every epoch ended in a NameError right after the
    checkpoint was written. This best-effort implementation restores the intended
    retention policy implied by KEEP_LAST_EPOCHS; deletion failures are ignored,
    matching the file's best-effort style elsewhere.
    """
    def _epoch_number(path):
        digits = re.sub(r"\D", "", path.name)
        return int(digits) if digits else 0

    epoch_dirs = sorted((d for d in OUTPUT_DIR.glob("epoch*") if d.is_dir()),
                        key=_epoch_number)
    stale = epoch_dirs[:-KEEP_LAST_EPOCHS] if KEEP_LAST_EPOCHS > 0 else epoch_dirs
    for directory in stale:
        try:
            shutil.rmtree(directory)
            print(f"Removed old checkpoint dir: {directory}")
        except Exception:
            pass


def train():
    """
    Fine-tuning entry point.

    Loads the model (preferring a local Python nn.Module implementation, falling
    back to a JIT ScriptModule), builds train/val splits from the cleaned dataset,
    runs an AMP-aware training loop with per-epoch validation and checkpointing,
    then saves the final model (traced JIT when possible, otherwise a state_dict)
    along with the tokenizer, and refreshes LAST_TRAINED_PATH with a backup of the
    previous one.

    Fixes vs. the original version:
    - cleanup_old_epochs() was referenced but never defined (NameError after the
      first epoch checkpoint); it is now implemented above.
    - isinstance(model, nn.Module) is also True for torch.jit.load results, since
      ScriptModule subclasses nn.Module — the JIT fallback was therefore wrongly
      routed through the Python-module path (gradient checkpointing attempts,
      re-tracing on save). The check now excludes ScriptModule explicitly.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    print("Loading model...")

    python_model, python_source = try_load_python_model()
    model = None
    model_source = None

    # If we have a Python model class, instantiate and try to load weights
    if python_model is not None:
        model = python_model
        model_source = python_source
        # Weight sources in order of preference: plain state_dict, last trained JIT, base JIT.
        loaded = False
        if PT_STATE_DICT_PATH.exists():
            try:
                sd = torch.load(PT_STATE_DICT_PATH, map_location="cpu")
                model.load_state_dict(sd, strict=False)
                print(f"Loaded state_dict from {PT_STATE_DICT_PATH}")
                loaded = True
            except Exception as e:
                print(f"Failed to load state_dict from {PT_STATE_DICT_PATH}: {e}")
        if (not loaded) and LAST_TRAINED_PATH.exists():
            if load_weights_from_script_to_module(LAST_TRAINED_PATH, model):
                loaded = True
        if (not loaded) and BASE_MODEL_PATH.exists():
            if load_weights_from_script_to_module(BASE_MODEL_PATH, model):
                loaded = True
    else:
        # Fallback to ScriptModule (JIT)
        if LAST_TRAINED_PATH.exists():
            model = torch.jit.load(LAST_TRAINED_PATH, map_location=device)
            model_source = f"jit:{LAST_TRAINED_PATH}"
        elif BASE_MODEL_PATH.exists():
            model = torch.jit.load(BASE_MODEL_PATH, map_location=device)
            model_source = f"jit:{BASE_MODEL_PATH}"
        else:
            print("ERROR: No model found (neither Python module nor JIT). Place a model file or Python module.")
            return

    print(f"Model loaded from: {model_source}")

    # ScriptModule subclasses nn.Module, so a bare isinstance(model, nn.Module)
    # would misclassify JIT-loaded models; exclude ScriptModule explicitly.
    is_python_module = isinstance(model, nn.Module) and not isinstance(model, torch.jit.ScriptModule)
    if is_python_module:
        model.to(device)
        model.train()
        try:
            model.gradient_checkpointing_enable()
            print("Gradient checkpointing ENABLED on Python nn.Module.")
        except Exception:
            # try alternative attribute
            try:
                model.gradient_checkpointing = True
                print("Set attribute gradient_checkpointing = True (best-effort).")
            except Exception:
                print("Gradient checkpointing not available on this Python model.")
    else:
        # ScriptModule
        try:
            model.to(device)
        except Exception:
            print("Warning: model.to(device) failed for ScriptModule; attempting best-effort buffer moves.")
        model.train()
        print("Training on ScriptModule (gradient checkpointing not available).")

    # Dataset preparation: normalize whitespace in the raw text once, then reuse.
    if not CLEAN_PATH.exists():
        # Try to clean RAW -> CLEAN
        if not RAW_PATH.exists():
            raise FileNotFoundError(f"Missing dataset {RAW_PATH}")
        print("Cleaning raw dataset to create cleaned version...")
        text = RAW_PATH.read_text(encoding="utf-8")
        text = re.sub(r" {2,}", " ", text)
        text = text.replace(" \n", "\n").replace("\n ", "\n")
        CLEAN_PATH.write_text(text, encoding="utf-8")
        print(f"Saved cleaned dataset → {CLEAN_PATH}")

    train_dataset = LazyTextDataset(CLEAN_PATH, seq_len=TRAIN_SEQ_LEN, split_type='train', val_ratio=VAL_SPLIT_RATIO)
    val_dataset = LazyTextDataset(CLEAN_PATH, seq_len=TRAIN_SEQ_LEN, split_type='val', val_ratio=VAL_SPLIT_RATIO)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, num_workers=0)

    # Optimizer: try bitsandbytes 8-bit first (if available) to reduce memory.
    try:
        import bitsandbytes as bnb  # type: ignore
        try:
            optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        except Exception:
            optimizer = bnb.optim.Adam8bit(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        print("Using bitsandbytes 8-bit optimizer.")
    except Exception:
        optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        print("Using torch.optim.AdamW (bitsandbytes not available).")

    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler(enabled=(device.type == 'cuda'))

    # Pre-clean GPU memory
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    total_steps = (len(train_dataset) // BATCH_SIZE) * EPOCHS if len(train_dataset) > 0 else 0
    print(f"\nSTARTING training: epochs={EPOCHS}, approx steps={total_steps}, examples={len(train_dataset)}")
    print(f"Batch size={BATCH_SIZE}, seq_len={TRAIN_SEQ_LEN}, device={device}, AMP={'on' if device.type=='cuda' else 'off'}")

    global_step = 0
    for epoch in range(1, EPOCHS + 1):
        print(f"\n=== Epoch {epoch}/{EPOCHS} ===")
        epoch_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch} [TRAIN]", leave=False)
        for inputs, targets in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad(set_to_none=True)

            with autocast(enabled=(device.type == 'cuda')):
                logits = get_logits_from_model(model, inputs)
                logits = logits.contiguous().view(-1, logits.size(-1))
                # Truncate targets to the flattened logits length (defensive).
                targets_view = targets.contiguous().view(-1)[:logits.shape[0]]
                loss = criterion(logits, targets_view)

            # Backward pass + optimizer step (AMP-safe)
            if device.type == 'cuda':
                try:
                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                except Exception as e:
                    print("Scaled backward failed:", e)
                    loss.backward()
                try:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                except Exception:
                    pass
                try:
                    scaler.step(optimizer)
                    scaler.update()
                except RuntimeError as e:
                    print("RuntimeError during scaler.step():", e)
                    print(torch.cuda.memory_summary())
                    # Fallback: regular step
                    try:
                        scaler.unscale_(optimizer)
                        optimizer.step()
                    except Exception as e2:
                        print("Fallback optimizer.step() failed:", e2)
                        raise e
            else:
                loss.backward()
                try:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                except Exception:
                    pass
                optimizer.step()

            if device.type == 'cuda':
                torch.cuda.empty_cache()

            loss_val = float(loss.item())
            epoch_loss += loss_val
            global_step += 1
            pbar.set_postfix({"loss": f"{loss_val:.4f}", "ppl": f"{math.exp(min(loss_val, 10)):.2f}", "step": f"{global_step}"})

        avg_train_loss = epoch_loss / max(1, len(train_dataset) // BATCH_SIZE)
        print(f"[TRAIN] Avg loss: {avg_train_loss:.4f} | PPL: {math.exp(avg_train_loss):.2f}")

        print("Running validation...")
        val_loss = evaluate(model, val_loader, criterion)
        print(f"[VAL] Avg loss: {val_loss:.4f} | PPL: {math.exp(val_loss):.2f}")

        # Save current epoch: trace Python modules to JIT; save ScriptModules directly.
        epoch_dir = OUTPUT_DIR / f"epoch{epoch}"
        epoch_dir.mkdir(parents=True, exist_ok=True)
        try:
            if is_python_module:
                model.eval()
                dummy = torch.randint(0, 50257, (1, min(32, TRAIN_SEQ_LEN)), device=device)
                try:
                    traced = torch.jit.trace(model, dummy, strict=False)
                    torch.jit.save(traced, epoch_dir / MODEL_SAVE_NAME)
                    print(f"Exported traced JIT to {epoch_dir / MODEL_SAVE_NAME}")
                except Exception as e:
                    torch.save(model.state_dict(), epoch_dir / "state_dict.pt")
                    print(f"Saved state_dict due to trace failure: {e}")
                model.train()
            else:
                torch.jit.save(model, epoch_dir / MODEL_SAVE_NAME)
                print(f"Saved ScriptModule to {epoch_dir / MODEL_SAVE_NAME}")
        except Exception as e:
            print("Exception during model save:", e)

        cleanup_old_epochs()

    # Final model save
    final_dir = OUTPUT_DIR / "final"
    final_dir.mkdir(parents=True, exist_ok=True)
    try:
        if is_python_module:
            model.eval()
            dummy = torch.randint(0, 50257, (1, min(32, TRAIN_SEQ_LEN)), device=device)
            try:
                traced = torch.jit.trace(model, dummy, strict=False)
                torch.jit.save(traced, final_dir / MODEL_SAVE_NAME)
                print(f"Final JIT saved to {final_dir / MODEL_SAVE_NAME}")
            except Exception:
                torch.save(model.state_dict(), final_dir / "state_dict.pt")
                print("Final model saved as state_dict (trace failed).")
        else:
            torch.jit.save(model, final_dir / MODEL_SAVE_NAME)
            print(f"Final ScriptModule saved to {final_dir / MODEL_SAVE_NAME}")
    except Exception:
        try:
            torch.save(model.state_dict(), final_dir / "state_dict.pt")
        except Exception:
            pass

    # Save tokenizer with the final model
    try:
        train_dataset.tokenizer.save_pretrained(final_dir)
    except Exception:
        pass

    # Backup previous last_trained and update with new one
    if LAST_TRAINED_PATH.exists():
        backup_path = BACKUP_DIR / f"gpt_last_trained_backup_{int(LAST_TRAINED_PATH.stat().st_mtime)}.script.pt"
        try:
            shutil.copy(LAST_TRAINED_PATH, backup_path)
            print(f"Backed up previous last_trained -> {backup_path}")
        except Exception:
            pass
    try:
        if (final_dir / MODEL_SAVE_NAME).exists():
            shutil.copy(final_dir / MODEL_SAVE_NAME, LAST_TRAINED_PATH)
            print(f"Copied final model to {LAST_TRAINED_PATH}")
        elif (final_dir / "state_dict.pt").exists():
            shutil.copy(final_dir / "state_dict.pt", LAST_TRAINED_PATH.with_suffix(".state_dict.pt"))
            print("Copied final state_dict to LAST_TRAINED_PATH with .state_dict.pt suffix")
    except Exception:
        pass

    print("TRAINING COMPLETED.")
482
+
483
+ # ========================= Entrypoint =========================
484
if __name__ == "__main__":
    # Refuse to start when the raw training text is absent; train() needs it.
    if RAW_PATH.exists():
        train()
    else:
        print(f"ERROR: dataset {RAW_PATH} not found. Put your training text there.")
        sys.exit(1)