kgrabko committed
Commit ea7894f · verified · 1 Parent(s): ae3a2ac

Update fine_tune_jit_with_validation_gpt2.py

Files changed (1)
  1. fine_tune_jit_with_validation_gpt2.py +292 -276
fine_tune_jit_with_validation_gpt2.py CHANGED
@@ -1,277 +1,293 @@
- # Copyright (c) 2025 CMS Manhattan
- # All rights reserved.
- #
- # This file is part of a project authored by CMS Manhattan. You may use, distribute, and modify
- # this code under the terms of the APACHE 2.0 license.
-
- import os
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.utils.data import Dataset, DataLoader
- from transformers import GPT2TokenizerFast
- from tqdm import tqdm
- import shutil
- import math
- from pathlib import Path
- import re
- from typing import Optional, List, Tuple
-
- # from gpt_pytorch_33b import GPTPyTorch # REMOVED: Now loaded via JIT/TorchScript!
-
- # ============================= SETTINGS =============================
- TRAIN_SEQ_LEN = 256  # Context length
- BATCH_SIZE = 12
- EPOCHS = 50
- LEARNING_RATE = 6e-6
- WEIGHT_DECAY = 0.01
- GRAD_CLIP = 1.0
- KEEP_LAST_EPOCHS = 3
- VAL_SPLIT_RATIO = 0.05
-
- # === MODEL PATHS (ADAPTED FOR JIT) ===
- # NOTE: Model must be saved by torch.jit.trace() or torch.jit.script()
- BASE_MODEL_PATH = Path("models/JiRack_H12_L6_V50257_D768_MSL8192_FF768x4.script.pt")
- LAST_TRAINED_PATH = Path("models/JiRack_last_H12_L6_V50257_D768_MSL8192_FF768x4.script.pt")
- BACKUP_DIR = Path("models/backups")
- BACKUP_DIR.mkdir(exist_ok=True)
-
- # === AUTOCLEAN DATASET (Improved reliability) ===
- RAW_PATH = Path("datasets/dialogues_text.txt")
- CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")
-
- # Flag to control cleaning process
- force_clean = False
- if not CLEAN_PATH.exists():
-     print("Cleaned dataset not found. Performing initial cleaning...")
-     force_clean = True
- else:
-     try:
-         if RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime:
-             print("Detected changes in the raw dataset. Re-cleaning...")
-             force_clean = True
-         else:
-             print(f"Using existing cleaned dataset → {CLEAN_PATH}")
-     except FileNotFoundError:
-         print("File system synchronization error. Performing re-cleaning for safety...")
-         force_clean = True
-
- if force_clean:
-     if not RAW_PATH.exists():
-         raise FileNotFoundError(f"ERROR: Source file {RAW_PATH} not found. Check the path.")
-
-     print("Cleaning up the dataset from garbage (wrong separators, extra spaces)...")
-     text = RAW_PATH.read_text(encoding="utf-8")
-
-     # 3. General cleanup: remove multiple spaces and spaces around newlines
-     text = re.sub(r' {2,}', ' ', text)
-     text = text.replace(" \n", "\n").replace("\n ", "\n")
-
-     CLEAN_PATH.write_text(text, encoding="utf-8")
-     print(f"Dataset successfully cleaned and saved → {CLEAN_PATH}")
-
- DATASET_PATH = CLEAN_PATH
-
- OUTPUT_DIR = Path("build/fine_tuning_output")
- MODEL_SAVE_NAME = "gpt_finetuned.script.pt"  # Changed to JIT format
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- print(f"Using device: {device}")
-
- # ============================= DATASET =============================
- class TextDataset(Dataset):
-     def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, tokenizer_name="gpt2", split_type='train', val_ratio=VAL_SPLIT_RATIO):
-         self.seq_len = seq_len
-         self.tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_name)
-         self.tokenizer.pad_token = self.tokenizer.eos_token
-         self.split_type = split_type
-
-         print(f"Loading text from {text_file} for {split_type} split...")
-         text = Path(text_file).read_text(encoding="utf-8")
-         tokens = self.tokenizer.encode(text)
-
-         if len(tokens) < seq_len * 2:
-             raise ValueError("Text too short!")
-
-         all_inputs = []
-         all_labels = []
-
-         for i in range(0, len(tokens) - seq_len, seq_len):
-             all_inputs.append(tokens[i:i + seq_len])
-             all_labels.append(tokens[i + 1:i + seq_len + 1])
-
-         total_sequences = len(all_inputs)
-         val_size = int(total_sequences * val_ratio)
-         train_size = total_sequences - val_size
-
-         if self.split_type == 'train':
-             self.inputs = all_inputs[:train_size]
-             self.labels = all_labels[:train_size]
-         elif self.split_type == 'val':
-             self.inputs = all_inputs[train_size:]
-             self.labels = all_labels[train_size:]
-         else:
-             raise ValueError("Invalid split_type. Must be 'train' or 'val'.")
-
-         print(f"Created {len(self.inputs):,} sequences for {self.split_type} split.")
-
-     def __len__(self):
-         return len(self.inputs)
-
-     def __getitem__(self, idx):
-         return (torch.tensor(self.inputs[idx], dtype=torch.long),
-                 torch.tensor(self.labels[idx], dtype=torch.long))
-
- # ----------------------------------------------------------------------------------------------------------------------
-
- # ============================= EVALUATION (VALIDATION) =============================
- def get_logits_from_model(model, inputs):
-     """
-     Adapted model invocation handling a possible output of (logits, new_kv)
-     or just logits for JIT models.
-     """
-     try:
-         # Try to call as the original model (Logits, KV-cache)
-         logits, _ = model(inputs)
-     except Exception:
-         # If JIT model returns only Logits (most likely)
-         logits = model(inputs)
-     return logits
-
-
- def evaluate(model, dataloader, criterion, device):
-     """Evaluates the model on the validation dataset."""
-     model.eval()
-     total_loss = 0.0
-
-     with torch.no_grad():
-         for inputs, targets in dataloader:
-             inputs, targets = inputs.to(device), targets.to(device)
-
-             logits = get_logits_from_model(model, inputs)
-             loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
-             total_loss += loss.item()
-
-     avg_loss = total_loss / len(dataloader)
-     model.train()
-     return avg_loss
-
- # ----------------------------------------------------------------------------------------------------------------------
-
- # ============================= CLEANUP OLD EPOCHS =============================
- def cleanup_old_epochs(keep_last=KEEP_LAST_EPOCHS):
-     epochs = sorted([p for p in OUTPUT_DIR.glob("epoch*") if p.is_dir()],
-                     key=lambda x: int(x.name.replace("epoch", "")))
-     for old in epochs[:-keep_last]:
-         if old.exists():
-             shutil.rmtree(old)
-             print(f"Old epoch deleted: {old.name}")
-
- # ----------------------------------------------------------------------------------------------------------------------
-
- # ============================= TRAINING =============================
- def train():
-     OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
-
-     print("Loading model...")
-     model = None
-
-     # === SMART JIT MODEL LOADING ===
-     if LAST_TRAINED_PATH.exists():
-         print(f"Continuing training from last JIT model: {LAST_TRAINED_PATH}")
-         model = torch.jit.load(LAST_TRAINED_PATH, map_location=device)
-     elif BASE_MODEL_PATH.exists():
-         print(f"Starting from base JIT model: {BASE_MODEL_PATH}")
-         model = torch.jit.load(BASE_MODEL_PATH, map_location=device)
-     else:
-         print("ERROR: JIT model not found. Cannot start training without source code or JIT file.")
-         return
-
-     model.train()
-
-     # Create datasets and dataloaders (Train and Validation)
-     train_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, split_type='train', val_ratio=VAL_SPLIT_RATIO)
-     val_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, split_type='val', val_ratio=VAL_SPLIT_RATIO)
-
-     train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
-     val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
-
-     # NOTE: .parameters() should work for JIT models if saved with parameters.
-     optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
-     criterion = nn.CrossEntropyLoss()
-
-     total_steps = len(train_dataloader) * EPOCHS
-     print(f"\n=== BEGINNING LONG-TERM TRAINING ===")
-     print(f"Epochs: {EPOCHS} | Steps (Train): {total_steps} | Examples (Train): {len(train_dataset)}")
-
-     global_step = 0
-     for epoch in range(1, EPOCHS + 1):
-         print(f"\n--- Epoch {epoch}/{EPOCHS} ---")
-         epoch_loss = 0.0
-
-         # ====================== TRAINING LOOP ======================
-         with tqdm(train_dataloader, desc=f"Epoch {epoch} [TRAIN]", leave=False) as pbar:
-             for inputs, targets in pbar:
-                 inputs, targets = inputs.to(device), targets.to(device)
-
-                 optimizer.zero_grad()
-
-                 # ADAPTED MODEL CALL
-                 logits = get_logits_from_model(model, inputs)
-
-                 loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
-                 loss.backward()
-                 torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
-                 optimizer.step()
-
-                 loss_val = loss.item()
-                 epoch_loss += loss_val
-                 global_step += 1
-
-                 pbar.set_postfix({
-                     "loss": f"{loss_val:.3f}",
-                     "ppl": f"{math.exp(min(loss_val, 10)):.1f}",
-                     "step": f"{global_step}/{total_steps}"
-                 })
-
-         avg_train_loss = epoch_loss / len(train_dataloader)
-         print(f" [TRAIN] Average loss: {avg_train_loss:.3f} | PPL: {math.exp(avg_train_loss):.1f}")
-
-         # ====================== VALIDATION LOOP ======================
-         print(" [VALIDATION] Starting evaluation...")
-         val_loss = evaluate(model, val_dataloader, criterion, device)
-         print(f" [VALIDATION] Average loss: {val_loss:.3f} | PPL: {math.exp(val_loss):.1f}")
-
-         # ====================== SAVING MODEL (ADAPTED) ======================
-         epoch_dir = OUTPUT_DIR / f"epoch{epoch}"
-         epoch_dir.mkdir(exist_ok=True)
-         # Save JIT model: use .save() instead of torch.save(state_dict)
-         model.save(epoch_dir / MODEL_SAVE_NAME)
-         print(f"Model saved: {epoch_dir / MODEL_SAVE_NAME}")
-         cleanup_old_epochs()
-
-     # === FINAL SAVE ===
-     final_dir = OUTPUT_DIR / "final"
-     final_dir.mkdir(exist_ok=True)
-     model.save(final_dir / MODEL_SAVE_NAME)  # Save final JIT model
-     train_dataset.tokenizer.save_pretrained(final_dir)
-
-     # === AUTO-SAVE LAST MODEL + BACKUP (ADAPTED) ===
-     if LAST_TRAINED_PATH.exists():
-         backup_path = BACKUP_DIR / f"gpt_last_trained_backup_{int(os.path.getmtime(LAST_TRAINED_PATH))}.script.pt"
-         shutil.copy(LAST_TRAINED_PATH, backup_path)
-         print(f"Backup of previous model created → {backup_path.name}")
-
-     shutil.copy(final_dir / MODEL_SAVE_NAME, LAST_TRAINED_PATH)
-     print(f"Last trained model saved → {LAST_TRAINED_PATH}")
-
-     print(f"\nTRAINING COMPLETED! Model ready:")
-     print(f"  • For chat: {final_dir / MODEL_SAVE_NAME}")
-     print(f"  • For further fine-tuning: {LAST_TRAINED_PATH}")
-
- if __name__ == "__main__":
-     if not RAW_PATH.exists():
-         print(f"ERROR: No file {RAW_PATH}")
-         print("Put your text into datasets/dialogues_text.txt")
-     else:
-         train()
+ # Copyright (c) 2025 CMS Manhattan
+ # All rights reserved.
+ # Author: Konstantin Vladimirovich Grabko
+ # Email: grabko@cmsmanhattan.com
+ # Phone: +1(516)777-0945
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU General Public License as published by
+ # the Free Software Foundation, version 3 of the License.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+ #
+ # Additional terms:
+ # Any commercial use or distribution of this software or derivative works
+ # requires explicit written permission from the copyright holder.
+
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import GPT2TokenizerFast
+ from tqdm import tqdm
+ import shutil
+ import math
+ from pathlib import Path
+ import re
+ from typing import Optional, List, Tuple
+
+ # from gpt_pytorch_33b import GPTPyTorch # REMOVED: Now loaded via JIT/TorchScript!
+
+ # ============================= SETTINGS =============================
+ TRAIN_SEQ_LEN = 256  # Context length
+ BATCH_SIZE = 12
+ EPOCHS = 50
+ LEARNING_RATE = 6e-6
+ WEIGHT_DECAY = 0.01
+ GRAD_CLIP = 1.0
+ KEEP_LAST_EPOCHS = 3
+ VAL_SPLIT_RATIO = 0.05
+
+ # === MODEL PATHS (ADAPTED FOR JIT) ===
+ # NOTE: the model must have been saved with torch.jit.trace() or torch.jit.script()
+ BASE_MODEL_PATH = Path("models/JiRack_H12_L6_V50257_D768_MSL8192_FF768x4.script.pt")
+ LAST_TRAINED_PATH = Path("models/JiRack_last_H12_L6_V50257_D768_MSL8192_FF768x4.script.pt")
+ BACKUP_DIR = Path("models/backups")
+ BACKUP_DIR.mkdir(parents=True, exist_ok=True)  # parents=True so a missing models/ dir does not crash at import time
+
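The NOTE above implies the base checkpoint was exported beforehand. For readers who do not have one, a minimal sketch of how such a .script.pt file could be produced; GPTPyTorch comes from the commented-out import above, and its constructor arguments are hypothetical since the model source is not part of this commit:

# Hypothetical export sketch -- the model class and its arguments are placeholders.
import torch
from gpt_pytorch_33b import GPTPyTorch  # assumed module, not shipped with this script

model = GPTPyTorch()  # construct with your real hyperparameters
model.eval()
scripted = torch.jit.script(model)  # or torch.jit.trace(model, example_input)
scripted.save("models/JiRack_H12_L6_V50257_D768_MSL8192_FF768x4.script.pt")
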
+ # === AUTOCLEAN DATASET (Improved reliability) ===
+ RAW_PATH = Path("datasets/dialogues_text.txt")
+ CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")
+
+ # Flag to control the cleaning process
+ force_clean = False
+ if not CLEAN_PATH.exists():
+     print("Cleaned dataset not found. Performing initial cleaning...")
+     force_clean = True
+ else:
+     try:
+         if RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime:
+             print("Detected changes in the raw dataset. Re-cleaning...")
+             force_clean = True
+         else:
+             print(f"Using existing cleaned dataset → {CLEAN_PATH}")
+     except FileNotFoundError:
+         print("Could not stat the raw dataset. Re-cleaning to be safe...")
+         force_clean = True
+
+ if force_clean:
+     if not RAW_PATH.exists():
+         raise FileNotFoundError(f"ERROR: Source file {RAW_PATH} not found. Check the path.")
+
+     print("Cleaning the dataset (stray separators, extra whitespace)...")
+     text = RAW_PATH.read_text(encoding="utf-8")
+
+     # General cleanup: collapse runs of spaces and strip spaces around newlines
+     text = re.sub(r' {2,}', ' ', text)
+     text = text.replace(" \n", "\n").replace("\n ", "\n")
+
+     CLEAN_PATH.write_text(text, encoding="utf-8")
+     print(f"Dataset successfully cleaned and saved → {CLEAN_PATH}")
+
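A quick worked example of what the two normalization steps above actually do:

import re

text = "Hi  there.   \nHow are you? \n Fine."
text = re.sub(r' {2,}', ' ', text)                     # collapse runs of spaces
text = text.replace(" \n", "\n").replace("\n ", "\n")  # trim spaces around newlines
print(repr(text))  # 'Hi there.\nHow are you?\nFine.'
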
+ DATASET_PATH = CLEAN_PATH
+
+ OUTPUT_DIR = Path("build/fine_tuning_output")
+ MODEL_SAVE_NAME = "gpt_finetuned.script.pt"  # Changed to JIT format
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"Using device: {device}")
+
+ # ============================= DATASET =============================
+ class TextDataset(Dataset):
+     def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, tokenizer_name="gpt2", split_type='train', val_ratio=VAL_SPLIT_RATIO):
+         self.seq_len = seq_len
+         self.tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_name)
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+         self.split_type = split_type
+
+         print(f"Loading text from {text_file} for {split_type} split...")
+         text = Path(text_file).read_text(encoding="utf-8")
+         tokens = self.tokenizer.encode(text)
+
+         if len(tokens) < seq_len * 2:
+             raise ValueError(f"Text too short: {len(tokens)} tokens, need at least {seq_len * 2}.")
+
+         all_inputs = []
+         all_labels = []
+
+         for i in range(0, len(tokens) - seq_len, seq_len):
+             all_inputs.append(tokens[i:i + seq_len])
+             all_labels.append(tokens[i + 1:i + seq_len + 1])
+
+         total_sequences = len(all_inputs)
+         val_size = int(total_sequences * val_ratio)
+         train_size = total_sequences - val_size
+
+         if self.split_type == 'train':
+             self.inputs = all_inputs[:train_size]
+             self.labels = all_labels[:train_size]
+         elif self.split_type == 'val':
+             self.inputs = all_inputs[train_size:]
+             self.labels = all_labels[train_size:]
+         else:
+             raise ValueError("Invalid split_type. Must be 'train' or 'val'.")
+
+         print(f"Created {len(self.inputs):,} sequences for {self.split_type} split.")
+
+     def __len__(self):
+         return len(self.inputs)
+
+     def __getitem__(self, idx):
+         return (torch.tensor(self.inputs[idx], dtype=torch.long),
+                 torch.tensor(self.labels[idx], dtype=torch.long))
+
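The windowing in __init__ pairs each input chunk with the same chunk shifted one token to the right, the standard next-token-prediction target. A tiny illustration with made-up token IDs:

tokens = [10, 11, 12, 13, 14, 15, 16, 17, 18]
seq_len = 4
for i in range(0, len(tokens) - seq_len, seq_len):
    print(tokens[i:i + seq_len], "->", tokens[i + 1:i + seq_len + 1])
# [10, 11, 12, 13] -> [11, 12, 13, 14]
# [14, 15, 16, 17] -> [15, 16, 17, 18]
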
+ # ----------------------------------------------------------------------------------------------------------------------
+
+ # ============================= EVALUATION (VALIDATION) =============================
+ def get_logits_from_model(model, inputs):
+     """
+     Adapted model invocation handling a possible output of (logits, new_kv)
+     or just logits for JIT models.
+     """
+     out = model(inputs)
+     # Some model variants return (logits, kv_cache); JIT exports often return
+     # logits alone. Checking the output type avoids the extra forward pass
+     # that a try/except fallback would cost, and cannot silently mis-unpack
+     # a bare tensor along its batch dimension.
+     if isinstance(out, tuple):
+         return out[0]
+     return out
+
+
+ def evaluate(model, dataloader, criterion, device):
+     """Evaluates the model on the validation dataset."""
+     model.eval()
+     total_loss = 0.0
+
+     with torch.no_grad():
+         for inputs, targets in dataloader:
+             inputs, targets = inputs.to(device), targets.to(device)
+
+             logits = get_logits_from_model(model, inputs)
+             loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
+             total_loss += loss.item()
+
+     avg_loss = total_loss / len(dataloader)
+     model.train()
+     return avg_loss
+
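The PPL figures this script logs are simply exp() of the mean cross-entropy loss, which nn.CrossEntropyLoss reports in nats; for example:

import math
avg_loss = 3.2                           # hypothetical validation loss
print(f"PPL: {math.exp(avg_loss):.1f}")  # PPL: 24.5
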
+ # ----------------------------------------------------------------------------------------------------------------------
+
+ # ============================= CLEANUP OLD EPOCHS =============================
+ def cleanup_old_epochs(keep_last=KEEP_LAST_EPOCHS):
+     epochs = sorted([p for p in OUTPUT_DIR.glob("epoch*") if p.is_dir()],
+                     key=lambda x: int(x.name.replace("epoch", "")))
+     for old in epochs[:-keep_last]:
+         if old.exists():
+             shutil.rmtree(old)
+             print(f"Old epoch deleted: {old.name}")
+
+ # ----------------------------------------------------------------------------------------------------------------------
+
+ # ============================= TRAINING =============================
+ def train():
+     OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+     print("Loading model...")
+     model = None
+
+     # === SMART JIT MODEL LOADING ===
+     if LAST_TRAINED_PATH.exists():
+         print(f"Continuing training from last JIT model: {LAST_TRAINED_PATH}")
+         model = torch.jit.load(LAST_TRAINED_PATH, map_location=device)
+     elif BASE_MODEL_PATH.exists():
+         print(f"Starting from base JIT model: {BASE_MODEL_PATH}")
+         model = torch.jit.load(BASE_MODEL_PATH, map_location=device)
+     else:
+         print("ERROR: JIT model not found. Cannot start training without a TorchScript checkpoint.")
+         return
+
+     model.train()
+
+     # Create datasets and dataloaders (train and validation)
+     train_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, split_type='train', val_ratio=VAL_SPLIT_RATIO)
+     val_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, split_type='val', val_ratio=VAL_SPLIT_RATIO)
+
+     train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
+     val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
+
+     # NOTE: .parameters() works for JIT models saved together with their parameters.
+     optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
+     criterion = nn.CrossEntropyLoss()
+
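The NOTE above holds: a module loaded with torch.jit.load exposes .parameters() with requires_grad preserved, so AdamW can optimize it directly. A small sanity check one could run before committing to a 50-epoch job (a sketch, assuming the base checkpoint exists at the path above):

import torch

m = torch.jit.load("models/JiRack_H12_L6_V50257_D768_MSL8192_FF768x4.script.pt", map_location="cpu")
n_total = sum(p.numel() for p in m.parameters())
n_trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
print(f"params: {n_total:,} | trainable: {n_trainable:,}")
assert n_trainable > 0, "checkpoint has no trainable parameters"
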
+     total_steps = len(train_dataloader) * EPOCHS
+     print("\n=== BEGINNING LONG-TERM TRAINING ===")
+     print(f"Epochs: {EPOCHS} | Steps (Train): {total_steps} | Examples (Train): {len(train_dataset)}")
+
+     global_step = 0
+     for epoch in range(1, EPOCHS + 1):
+         print(f"\n--- Epoch {epoch}/{EPOCHS} ---")
+         epoch_loss = 0.0
+
+         # ====================== TRAINING LOOP ======================
+         with tqdm(train_dataloader, desc=f"Epoch {epoch} [TRAIN]", leave=False) as pbar:
+             for inputs, targets in pbar:
+                 inputs, targets = inputs.to(device), targets.to(device)
+
+                 optimizer.zero_grad()
+
+                 # ADAPTED MODEL CALL
+                 logits = get_logits_from_model(model, inputs)
+
+                 loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
+                 loss.backward()
+                 torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
+                 optimizer.step()
+
+                 loss_val = loss.item()
+                 epoch_loss += loss_val
+                 global_step += 1
+
+                 pbar.set_postfix({
+                     "loss": f"{loss_val:.3f}",
+                     "ppl": f"{math.exp(min(loss_val, 10)):.1f}",  # clamp so exp() cannot overflow
+                     "step": f"{global_step}/{total_steps}"
+                 })
+
+         avg_train_loss = epoch_loss / len(train_dataloader)
+         print(f" [TRAIN] Average loss: {avg_train_loss:.3f} | PPL: {math.exp(avg_train_loss):.1f}")
+
+         # ====================== VALIDATION LOOP ======================
+         print(" [VALIDATION] Starting evaluation...")
+         val_loss = evaluate(model, val_dataloader, criterion, device)
+         print(f" [VALIDATION] Average loss: {val_loss:.3f} | PPL: {math.exp(val_loss):.1f}")
+
+         # ====================== SAVING MODEL (ADAPTED) ======================
+         epoch_dir = OUTPUT_DIR / f"epoch{epoch}"
+         epoch_dir.mkdir(exist_ok=True)
+         # Save JIT model: use .save() instead of torch.save(state_dict)
+         model.save(epoch_dir / MODEL_SAVE_NAME)
+         print(f"Model saved: {epoch_dir / MODEL_SAVE_NAME}")
+         cleanup_old_epochs()
+
+     # === FINAL SAVE ===
+     final_dir = OUTPUT_DIR / "final"
+     final_dir.mkdir(exist_ok=True)
+     model.save(final_dir / MODEL_SAVE_NAME)  # Save final JIT model
+     train_dataset.tokenizer.save_pretrained(final_dir)
+
+     # === AUTO-SAVE LAST MODEL + BACKUP (ADAPTED) ===
+     if LAST_TRAINED_PATH.exists():
+         backup_path = BACKUP_DIR / f"gpt_last_trained_backup_{int(os.path.getmtime(LAST_TRAINED_PATH))}.script.pt"
+         shutil.copy(LAST_TRAINED_PATH, backup_path)
+         print(f"Backup of previous model created → {backup_path.name}")
+
+     shutil.copy(final_dir / MODEL_SAVE_NAME, LAST_TRAINED_PATH)
+     print(f"Last trained model saved → {LAST_TRAINED_PATH}")
+
+     print("\nTRAINING COMPLETED! Model ready:")
+     print(f"  • For chat: {final_dir / MODEL_SAVE_NAME}")
+     print(f"  • For further fine-tuning: {LAST_TRAINED_PATH}")
+
+ if __name__ == "__main__":
+     if not RAW_PATH.exists():
+         print(f"ERROR: No file {RAW_PATH}")
+         print("Put your text into datasets/dialogues_text.txt")
+     else:
+         train()
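
To try the fine-tuned checkpoint afterwards, a minimal greedy-decoding sketch; the paths follow the script's output layout and the tuple/tensor unpacking mirrors get_logits_from_model above, but everything else is an assumption about the exported model's interface:

import torch
from transformers import GPT2TokenizerFast

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.jit.load("build/fine_tuning_output/final/gpt_finetuned.script.pt", map_location=device)
model.eval()
tokenizer = GPT2TokenizerFast.from_pretrained("build/fine_tuning_output/final")

ids = torch.tensor([tokenizer.encode("Hello, how are you?")], dtype=torch.long, device=device)
with torch.no_grad():
    for _ in range(40):  # generate up to 40 new tokens
        out = model(ids)
        logits = out[0] if isinstance(out, tuple) else out
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)
print(tokenizer.decode(ids[0].tolist()))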