kgrabko commited on
Commit
f47992d
·
verified ·
1 Parent(s): 5b024f4

Upload fine_tune_jit_with_validation_gpt2.py

Browse files
Files changed (1) hide show
  1. fine_tune_jit_with_validation_gpt2.py +277 -0
fine_tune_jit_with_validation_gpt2.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 CMS Manhattan
2
+ # All rights reserved.
3
+ #
4
+ # This file is part of a project authored by CMS Manhattan. You may use, distribute, and modify
5
+ # this code under the terms of the APACHE 2.0 license .
6
+
7
+ import os
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.optim as optim
11
+ from torch.utils.data import Dataset, DataLoader
12
+ from transformers import GPT2TokenizerFast
13
+ from tqdm import tqdm
14
+ import shutil
15
+ import math
16
+ from pathlib import Path
17
+ import re
18
+ from typing import Optional, List, Tuple
19
+
20
+ # from gpt_pytorch_33b import GPTPyTorch # REMOVED: Now loaded via JIT/TorchScript!
21
+
22
+ # ============================= SETTINGS =============================
23
+ TRAIN_SEQ_LEN = 256 # Context length
24
+ BATCH_SIZE = 12
25
+ EPOCHS = 50
26
+ LEARNING_RATE = 6e-6
27
+ WEIGHT_DECAY = 0.01
28
+ GRAD_CLIP = 1.0
29
+ KEEP_LAST_EPOCHS = 3
30
+ VAL_SPLIT_RATIO = 0.05
31
+
32
+ # === MODEL PATHS (ADAPTED FOR JIT) ===
33
+ # NOTE: Model must be saved by torch.jit.trace() or torch.jit.script()
34
+ BASE_MODEL_PATH = Path("models/JiRack_H12_L6_V50257_D768_MSL8192_FF768x4.script.pt")
35
+ LAST_TRAINED_PATH = Path("models/JiRack_last_H12_L6_V50257_D768_MSL8192_FF768x4.script.pt")
36
+ BACKUP_DIR = Path("models/backups")
37
+ BACKUP_DIR.mkdir(exist_ok=True)
38
+
39
+ # === AUTOCLEAN DATASET (Improved reliability) ===
40
+ RAW_PATH = Path("datasets/dialogues_text.txt")
41
+ CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")
42
+
43
+ # Flag to control cleaning process
44
+ force_clean = False
45
+ if not CLEAN_PATH.exists():
46
+ print("Cleaned dataset not found. Performing initial cleaning...")
47
+ force_clean = True
48
+ else:
49
+ try:
50
+ if RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime:
51
+ print("Detected changes in the raw dataset. Re-cleaning...")
52
+ force_clean = True
53
+ else:
54
+ print(f"Using existing cleaned dataset → {CLEAN_PATH}")
55
+ except FileNotFoundError:
56
+ print("File system synchronization error. Performing re-cleaning for safety...")
57
+ force_clean = True
58
+
59
+ if force_clean:
60
+ if not RAW_PATH.exists():
61
+ raise FileNotFoundError(f"ERROR: Source file {RAW_PATH} not found. Check the path.")
62
+
63
+ print("Cleaning up the dataset from garbage (wrong separators, extra spaces)...")
64
+ text = RAW_PATH.read_text(encoding="utf-8")
65
+
66
+ # 3. General cleanup: remove multiple spaces and spaces around newlines
67
+ text = re.sub(r' {2,}', ' ', text)
68
+ text = text.replace(" \n", "\n").replace("\n ", "\n")
69
+
70
+ CLEAN_PATH.write_text(text, encoding="utf-8")
71
+ print(f"Dataset successfully cleaned and saved → {CLEAN_PATH}")
72
+
73
+ DATASET_PATH = CLEAN_PATH
74
+
75
+ OUTPUT_DIR = Path("build/fine_tuning_output")
76
+ MODEL_SAVE_NAME = "gpt_finetuned.script.pt" # Changed to JIT format
77
+
78
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
79
+ print(f"Using device: {device}")
80
+
81
+ # ============================= DATASET =============================
82
+ class TextDataset(Dataset):
83
+ def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, tokenizer_name="gpt2", split_type='train', val_ratio=VAL_SPLIT_RATIO):
84
+ self.seq_len = seq_len
85
+ self.tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_name)
86
+ self.tokenizer.pad_token = self.tokenizer.eos_token
87
+ self.split_type = split_type
88
+
89
+ print(f"Loading text from {text_file} for {split_type} split...")
90
+ text = Path(text_file).read_text(encoding="utf-8")
91
+ tokens = self.tokenizer.encode(text)
92
+
93
+ if len(tokens) < seq_len * 2:
94
+ raise ValueError("Text too short!")
95
+
96
+ all_inputs = []
97
+ all_labels = []
98
+
99
+ for i in range(0, len(tokens) - seq_len, seq_len):
100
+ all_inputs.append(tokens[i:i + seq_len])
101
+ all_labels.append(tokens[i + 1:i + seq_len + 1])
102
+
103
+ total_sequences = len(all_inputs)
104
+ val_size = int(total_sequences * val_ratio)
105
+ train_size = total_sequences - val_size
106
+
107
+ if self.split_type == 'train':
108
+ self.inputs = all_inputs[:train_size]
109
+ self.labels = all_labels[:train_size]
110
+ elif self.split_type == 'val':
111
+ self.inputs = all_inputs[train_size:]
112
+ self.labels = all_labels[train_size:]
113
+ else:
114
+ raise ValueError("Invalid split_type. Must be 'train' or 'val'.")
115
+
116
+ print(f"Created {len(self.inputs):,} sequences for {self.split_type} split.")
117
+
118
+ def __len__(self):
119
+ return len(self.inputs)
120
+
121
+ def __getitem__(self, idx):
122
+ return (torch.tensor(self.inputs[idx], dtype=torch.long),
123
+ torch.tensor(self.labels[idx], dtype=torch.long))
124
+
125
+ # ----------------------------------------------------------------------------------------------------------------------
126
+
127
+ # ============================= EVALUATION (VALIDATION) =============================
128
+ def get_logits_from_model(model, inputs):
129
+ """
130
+ Adapted model invocation handling a possible output of (logits, new_kv)
131
+ or just logits for JIT models.
132
+ """
133
+ try:
134
+ # Try to call as the original model (Logits, KV-cache)
135
+ logits, _ = model(inputs)
136
+ except Exception:
137
+ # If JIT model returns only Logits (most likely)
138
+ logits = model(inputs)
139
+ return logits
140
+
141
+
142
+ def evaluate(model, dataloader, criterion, device):
143
+ """Evaluates the model on the validation dataset."""
144
+ model.eval()
145
+ total_loss = 0.0
146
+
147
+ with torch.no_grad():
148
+ for inputs, targets in dataloader:
149
+ inputs, targets = inputs.to(device), targets.to(device)
150
+
151
+ logits = get_logits_from_model(model, inputs)
152
+ loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
153
+ total_loss += loss.item()
154
+
155
+ avg_loss = total_loss / len(dataloader)
156
+ model.train()
157
+ return avg_loss
158
+
159
+ # ----------------------------------------------------------------------------------------------------------------------
160
+
161
+ # ============================= CLEANUP OLD EPOCHS =============================
162
+ def cleanup_old_epochs(keep_last=KEEP_LAST_EPOCHS):
163
+ epochs = sorted([p for p in OUTPUT_DIR.glob("epoch*") if p.is_dir()],
164
+ key=lambda x: int(x.name.replace("epoch", "")))
165
+ for old in epochs[:-keep_last]:
166
+ if old.exists():
167
+ shutil.rmtree(old)
168
+ print(f"Old epoch deleted: {old.name}")
169
+
170
+ # ----------------------------------------------------------------------------------------------------------------------
171
+
172
+ # ============================= TRAINING =============================
173
+ def train():
174
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
175
+
176
+ print("Loading model...")
177
+ model = None
178
+
179
+ # === SMART JIT MODEL LOADING ===
180
+ if LAST_TRAINED_PATH.exists():
181
+ print(f"Continuing training from last JIT model: {LAST_TRAINED_PATH}")
182
+ model = torch.jit.load(LAST_TRAINED_PATH, map_location=device)
183
+ elif BASE_MODEL_PATH.exists():
184
+ print(f"Starting from base JIT model: {BASE_MODEL_PATH}")
185
+ model = torch.jit.load(BASE_MODEL_PATH, map_location=device)
186
+ else:
187
+ print("ERROR: JIT model not found. Cannot start training without source code or JIT file.")
188
+ return
189
+
190
+ model.train()
191
+
192
+ # Create datasets and dataloaders (Train and Validation)
193
+ train_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, split_type='train', val_ratio=VAL_SPLIT_RATIO)
194
+ val_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, split_type='val', val_ratio=VAL_SPLIT_RATIO)
195
+
196
+ train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
197
+ val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
198
+
199
+ # NOTE: .parameters() should work for JIT models if saved with parameters.
200
+ optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
201
+ criterion = nn.CrossEntropyLoss()
202
+
203
+ total_steps = len(train_dataloader) * EPOCHS
204
+ print(f"\n=== BEGINNING LONG-TERM TRAINING ===")
205
+ print(f"Epochs: {EPOCHS} | Steps (Train): {total_steps} | Examples (Train): {len(train_dataset)}")
206
+
207
+ global_step = 0
208
+ for epoch in range(1, EPOCHS + 1):
209
+ print(f"\n--- Epoch {epoch}/{EPOCHS} ---")
210
+ epoch_loss = 0.0
211
+
212
+ # ====================== TRAINING LOOP ======================
213
+ with tqdm(train_dataloader, desc=f"Epoch {epoch} [TRAIN]", leave=False) as pbar:
214
+ for inputs, targets in pbar:
215
+ inputs, targets = inputs.to(device), targets.to(device)
216
+
217
+ optimizer.zero_grad()
218
+
219
+ # ADAPTED MODEL CALL
220
+ logits = get_logits_from_model(model, inputs)
221
+
222
+ loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
223
+ loss.backward()
224
+ torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
225
+ optimizer.step()
226
+
227
+ loss_val = loss.item()
228
+ epoch_loss += loss_val
229
+ global_step += 1
230
+
231
+ pbar.set_postfix({
232
+ "loss": f"{loss_val:.3f}",
233
+ "ppl": f"{math.exp(min(loss_val, 10)):.1f}",
234
+ "step": f"{global_step}/{total_steps}"
235
+ })
236
+
237
+ avg_train_loss = epoch_loss / len(train_dataloader)
238
+ print(f" [TRAIN] Average loss: {avg_train_loss:.3f} | PPL: {math.exp(avg_train_loss):.1f}")
239
+
240
+ # ====================== VALIDATION LOOP ======================
241
+ print(" [VALIDATION] Starting evaluation...")
242
+ val_loss = evaluate(model, val_dataloader, criterion, device)
243
+ print(f" [VALIDATION] Average loss: {val_loss:.3f} | PPL: {math.exp(val_loss):.1f}")
244
+
245
+ # ====================== SAVING MODEL (ADAPTED) ======================
246
+ epoch_dir = OUTPUT_DIR / f"epoch{epoch}"
247
+ epoch_dir.mkdir(exist_ok=True)
248
+ # Save JIT model: use .save() instead of torch.save(state_dict)
249
+ model.save(epoch_dir / MODEL_SAVE_NAME)
250
+ print(f"Model saved: {epoch_dir / MODEL_SAVE_NAME}")
251
+ cleanup_old_epochs()
252
+
253
+ # === FINAL SAVE ===
254
+ final_dir = OUTPUT_DIR / "final"
255
+ final_dir.mkdir(exist_ok=True)
256
+ model.save(final_dir / MODEL_SAVE_NAME) # Save final JIT model
257
+ train_dataset.tokenizer.save_pretrained(final_dir)
258
+
259
+ # === AUTO-SAVE LAST MODEL + BACKUP (ADAPTED) ===
260
+ if LAST_TRAINED_PATH.exists():
261
+ backup_path = BACKUP_DIR / f"gpt_last_trained_backup_{int(os.path.getmtime(LAST_TRAINED_PATH))}.script.pt"
262
+ shutil.copy(LAST_TRAINED_PATH, backup_path)
263
+ print(f"Backup of previous model created → {backup_path.name}")
264
+
265
+ shutil.copy(final_dir / MODEL_SAVE_NAME, LAST_TRAINED_PATH)
266
+ print(f"Last trained model saved → {LAST_TRAINED_PATH}")
267
+
268
+ print(f"\nTRAINING COMPLETED! Model ready:")
269
+ print(f" • For chat: {final_dir / MODEL_SAVE_NAME}")
270
+ print(f" • For further fine-tuning: {LAST_TRAINED_PATH}")
271
+
272
+ if __name__ == "__main__":
273
+ if not RAW_PATH.exists():
274
+ print(f"ERROR: No file {RAW_PATH}")
275
+ print("Put your text into datasets/dialogues_text.txt")
276
+ else:
277
+ train()