kgrabko commited on
Commit
6552c3a
·
verified ·
1 Parent(s): f95d60f

Upload fine_tune_with_validation.py

Browse files
Files changed (1) hide show
  1. fine_tune_with_validation.py +237 -0
fine_tune_with_validation.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.utils.data import Dataset, DataLoader
6
+ import tiktoken # New: OpenAI's fast BPE tokenizer (open-source, auto-downloads once)
7
+ from tqdm import tqdm
8
+ import shutil
9
+ import math
10
+ from pathlib import Path
11
+ import re
12
+
13
+ from gpt_pytorch import GPTPyTorch # Your model import
14
+
15
# ============================= SETTINGS =============================
TRAIN_SEQ_LEN = 256     # Context length (increased for better coherence)
BATCH_SIZE = 12
EPOCHS = 50
LEARNING_RATE = 6e-6
WEIGHT_DECAY = 0.01     # AdamW decoupled weight decay
GRAD_CLIP = 1.0         # Max gradient norm applied before each optimizer step
KEEP_LAST_EPOCHS = 3    # Per-epoch checkpoint directories to retain on disk
VAL_SPLIT_RATIO = 0.05  # 5% of data used for validation

# === MODEL PATHS ===
BASE_MODEL_PATH = Path("models/JiRack_H12_L6_V50257_D768_MSL8192_FF768x4.pt")
LAST_TRAINED_PATH = Path("models/JiRack_last_H12_L6_V50257_D768_MSL8192_FF768x4.pt")
BACKUP_DIR = Path("models/backups")
# parents=True: without it this mkdir raises FileNotFoundError on a fresh
# checkout where the models/ directory does not exist yet.
BACKUP_DIR.mkdir(parents=True, exist_ok=True)
30
+
31
# === DATASET AUTO-CLEANING ===
# Source corpus (as supplied by the user) and its cleaned counterpart. These
# must be two DISTINCT files: the mtime comparison below decides when
# re-cleaning is needed, which is meaningless when both names point at the
# same path (the original code used dialogues_text_clean.txt for both).
# The __main__ guard at the bottom of this file tells the user to place the
# raw text in datasets/dialogues_text.txt, so that is the raw path used here.
RAW_PATH = Path("datasets/dialogues_text.txt")
CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")

force_clean = False
if not CLEAN_PATH.exists():
    print("Clean dataset not found. Performing initial cleaning...")
    force_clean = True
else:
    try:
        # Re-clean whenever the raw corpus is newer than the cleaned copy.
        if RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime:
            print("Changes detected in the source dataset. Performing re-cleaning...")
            force_clean = True
        else:
            print(f"Using existing clean dataset → {CLEAN_PATH}")
    except FileNotFoundError:
        # RAW_PATH disappeared between exists() and stat() (or was never
        # there): rebuild rather than trust a possibly-stale clean file.
        print("File system synchronization error. Performing re-cleaning for safety...")
        force_clean = True

if force_clean:
    if not RAW_PATH.exists():
        raise FileNotFoundError(f"ERROR: Source file {RAW_PATH} not found. Check the path.")

    print("Cleaning dataset from garbage (extra spaces, incorrect separators)...")
    text = RAW_PATH.read_text(encoding="utf-8")

    # Collapse runs of spaces, then strip spaces that touch a newline.
    text = re.sub(r' {2,}', ' ', text)
    text = text.replace(" \n", "\n").replace("\n ", "\n")

    CLEAN_PATH.write_text(text, encoding="utf-8")
    print(f"Dataset successfully cleaned and saved → {CLEAN_PATH}")

DATASET_PATH = CLEAN_PATH
64
+
65
# Where fine-tuning artifacts (per-epoch checkpoints and the final model)
# are written.
OUTPUT_DIR = Path("build/fine_tuning_output")
MODEL_SAVE_NAME = "gpt_finetuned.pt"

# Prefer GPU when available; all models and batches below are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
70
+
71
+ # ============================= DATASET =============================
72
class TextDataset(Dataset):
    """Fixed-length next-token-prediction windows over a tokenized text file.

    The corpus is encoded with tiktoken's GPT-2 BPE, cut into non-overlapping
    windows of ``seq_len`` tokens (labels are the same window shifted one
    token to the right), and the windows are partitioned deterministically:
    the trailing ``val_ratio`` fraction forms the validation split, the rest
    the training split.
    """

    def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, encoding_name="gpt2", split_type='train', val_ratio=VAL_SPLIT_RATIO):
        self.seq_len = seq_len

        # tiktoken's built-in "gpt2" encoding reproduces the GPT-2 vocabulary;
        # the small .tiktoken file is fetched automatically on first use.
        print(f"Loading tiktoken encoding '{encoding_name}' (small file auto-downloads on first run if needed)...")
        self.enc = tiktoken.get_encoding(encoding_name)

        self.split_type = split_type

        print(f"Loading text from {text_file} for {split_type} split...")
        corpus = Path(text_file).read_text(encoding="utf-8")
        token_ids = self.enc.encode(corpus)

        if len(token_ids) < seq_len * 2:
            raise ValueError("Text too short!")

        # Non-overlapping window starts; each input window is paired with the
        # window shifted right by one token as its label sequence.
        starts = range(0, len(token_ids) - seq_len, seq_len)
        all_inputs = [token_ids[s:s + seq_len] for s in starts]
        all_labels = [token_ids[s + 1:s + seq_len + 1] for s in starts]

        # Hold out the last val_ratio fraction of windows for validation.
        val_size = int(len(all_inputs) * val_ratio)
        cut = len(all_inputs) - val_size

        if self.split_type == 'train':
            self.inputs, self.labels = all_inputs[:cut], all_labels[:cut]
        elif self.split_type == 'val':
            self.inputs, self.labels = all_inputs[cut:], all_labels[cut:]
        else:
            raise ValueError("Invalid split_type. Must be 'train' or 'val'.")

        print(f"Created {len(self.inputs):,} sequences for {self.split_type} split.")

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        window = torch.tensor(self.inputs[idx], dtype=torch.long)
        shifted = torch.tensor(self.labels[idx], dtype=torch.long)
        return window, shifted
117
+
118
+ # ============================= EVALUATION =============================
119
def evaluate(model, dataloader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    Runs in eval mode without gradient tracking, then restores train mode
    (this function is called between training epochs).

    Args:
        model: network whose forward returns ``(logits, ...)`` with logits
            of shape (batch, seq, vocab).
        dataloader: iterable yielding ``(inputs, targets)`` long-tensor batches.
        criterion: loss over flattened ``(N, vocab)`` logits and ``(N,)`` targets.
        device: torch.device to move each batch to.

    Raises:
        ValueError: if the dataloader yields no batches. The original code
            divided by ``len(dataloader)`` unguarded, raising a confusing
            ZeroDivisionError when the validation split was smaller than one
            batch (easy with VAL_SPLIT_RATIO=0.05 and drop_last=True).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0  # counted explicitly so any iterable works, not only DataLoader

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits, _ = model(inputs)
            loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError(
            "evaluate() received an empty dataloader; check VAL_SPLIT_RATIO / "
            "BATCH_SIZE (drop_last=True discards partial batches).")

    avg_loss = total_loss / num_batches
    model.train()
    return avg_loss
133
+
134
+ # ============================= CLEANUP OLD EPOCHS =============================
135
def cleanup_old_epochs(keep_last=KEEP_LAST_EPOCHS):
    """Delete stale per-epoch checkpoint directories under OUTPUT_DIR.

    Directories named ``epoch<N>`` are sorted by their numeric suffix and all
    but the newest ``keep_last`` are removed.

    Args:
        keep_last: number of most recent epoch directories to retain.
            Note the original ``epochs[:-keep_last]`` slice silently kept
            EVERYTHING when ``keep_last == 0`` (``[:-0]`` is the empty
            slice); this version deletes all of them in that case, as the
            parameter name implies.
    """
    epochs = sorted((p for p in OUTPUT_DIR.glob("epoch*") if p.is_dir()),
                    key=lambda x: int(x.name.replace("epoch", "")))
    stale = epochs[:-keep_last] if keep_last > 0 else epochs
    for old in stale:
        if old.exists():  # tolerate concurrent/manual deletion between glob and rmtree
            shutil.rmtree(old)
            print(f"Deleted old epoch: {old.name}")
142
+
143
+ # ============================= TRAINING =============================
144
def train() -> None:
    """Fine-tune the GPT model with a train/validation split.

    Loads the most recent checkpoint if one exists (falling back to the base
    model, then to random init), trains for EPOCHS over DATASET_PATH, saves a
    checkpoint after every epoch (pruning old ones), and finally copies the
    last weights to LAST_TRAINED_PATH, backing up any previous copy.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    print("Loading model...")
    model = GPTPyTorch().to(device)

    # weights_only=True: safer state-dict loading, also silences the
    # torch.load FutureWarning about pickled objects.
    load_kwargs = {"map_location": device, "weights_only": True}
    # Resume priority: last fine-tuned checkpoint > base model > from scratch.
    if LAST_TRAINED_PATH.exists():
        print(f"Resuming training from last model: {LAST_TRAINED_PATH}")
        model.load_state_dict(torch.load(LAST_TRAINED_PATH, **load_kwargs))
    elif BASE_MODEL_PATH.exists():
        print(f"Starting from base model: {BASE_MODEL_PATH}")
        model.load_state_dict(torch.load(BASE_MODEL_PATH, **load_kwargs))
    else:
        print("No models found — initializing from scratch")

    model.train()

    # Both splits re-tokenize the same file; the split_type argument selects
    # which deterministic slice of sequences each dataset keeps.
    train_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, encoding_name="gpt2", split_type='train', val_ratio=VAL_SPLIT_RATIO)
    val_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, encoding_name="gpt2", split_type='val', val_ratio=VAL_SPLIT_RATIO)

    # drop_last=True keeps batch shapes uniform; note this makes the val
    # loader empty if the val split is smaller than one batch.
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    criterion = nn.CrossEntropyLoss()

    total_steps = len(train_dataloader) * EPOCHS
    print(f"\n=== STARTING LONG-TERM TRAINING ===")
    print(f"Epochs: {EPOCHS} | Steps (Train): {total_steps} | Examples (Train): {len(train_dataset)}")

    global_step = 0
    for epoch in range(1, EPOCHS + 1):
        print(f"\n--- Epoch {epoch}/{EPOCHS} ---")
        epoch_loss = 0.0

        with tqdm(train_dataloader, desc=f"Epoch {epoch} [TRAIN]", leave=False) as pbar:
            for inputs, targets in pbar:
                inputs, targets = inputs.to(device), targets.to(device)

                # Standard step: forward, flatten logits/targets for CE loss,
                # backward, clip gradient norm, update.
                optimizer.zero_grad()
                logits, _ = model(inputs)
                loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                optimizer.step()

                loss_val = loss.item()
                epoch_loss += loss_val
                global_step += 1

                pbar.set_postfix({
                    "loss": f"{loss_val:.3f}",
                    # min(loss, 10) caps exp() so the progress bar never overflows
                    "ppl": f"{math.exp(min(loss_val, 10)):.1f}",
                    "step": f"{global_step}/{total_steps}"
                })

        avg_train_loss = epoch_loss / len(train_dataloader)
        print(f" [TRAIN] Average loss: {avg_train_loss:.3f} | PPL: {math.exp(avg_train_loss):.1f}")

        print(" [VALIDATION] Running evaluation...")
        val_loss = evaluate(model, val_dataloader, criterion, device)
        print(f" [VALIDATION] Average loss: {val_loss:.3f} | PPL: {math.exp(val_loss):.1f}")

        # Checkpoint every epoch, then prune all but the newest KEEP_LAST_EPOCHS.
        epoch_dir = OUTPUT_DIR / f"epoch{epoch}"
        epoch_dir.mkdir(exist_ok=True)
        torch.save(model.state_dict(), epoch_dir / MODEL_SAVE_NAME)
        print(f"Model saved: {epoch_dir / MODEL_SAVE_NAME}")
        cleanup_old_epochs()

    # Final saving – note: no tokenizer.save_pretrained anymore (tiktoken doesn't need it)
    final_dir = OUTPUT_DIR / "final"
    final_dir.mkdir(exist_ok=True)
    torch.save(model.state_dict(), final_dir / MODEL_SAVE_NAME)

    # Back up the previous "last trained" weights (timestamped by mtime)
    # before overwriting them with this run's result.
    if LAST_TRAINED_PATH.exists():
        backup_path = BACKUP_DIR / f"gpt_last_trained_backup_{int(os.path.getmtime(LAST_TRAINED_PATH))}.pt"
        shutil.copy(LAST_TRAINED_PATH, backup_path)
        print(f"Backup of previous model created → {backup_path.name}")

    shutil.copy(final_dir / MODEL_SAVE_NAME, LAST_TRAINED_PATH)
    print(f"Last trained model saved → {LAST_TRAINED_PATH}")

    print(f"\nTRAINING COMPLETED! Model is ready:")
    print(f" • For chat/inference: {final_dir / MODEL_SAVE_NAME}")
    print(f" • For continued fine-tuning: {LAST_TRAINED_PATH}")
231
+
232
if __name__ == "__main__":
    # Refuse to start without the source corpus; otherwise kick off training.
    if RAW_PATH.exists():
        train()
    else:
        print(f"ERROR: File {RAW_PATH} not found")
        print("Place your text in datasets/dialogues_text.txt")