kgrabko commited on
Commit
a22c288
·
verified ·
1 Parent(s): 6552c3a

Update fine_tune_with_validation.py

Browse files
Files changed (1) hide show
  1. fine_tune_with_validation.py +258 -236
fine_tune_with_validation.py CHANGED
@@ -1,237 +1,259 @@
1
- import os
2
- import torch
3
- import torch.nn as nn
4
- import torch.optim as optim
5
- from torch.utils.data import Dataset, DataLoader
6
- import tiktoken # New: OpenAI's fast BPE tokenizer (open-source, auto-downloads once)
7
- from tqdm import tqdm
8
- import shutil
9
- import math
10
- from pathlib import Path
11
- import re
12
-
13
- from gpt_pytorch import GPTPyTorch # Your model import
14
-
15
- # ============================= SETTINGS =============================
16
- TRAIN_SEQ_LEN = 256 # Context length (increased for better coherence)
17
- BATCH_SIZE = 12
18
- EPOCHS = 50
19
- LEARNING_RATE = 6e-6
20
- WEIGHT_DECAY = 0.01
21
- GRAD_CLIP = 1.0
22
- KEEP_LAST_EPOCHS = 3
23
- VAL_SPLIT_RATIO = 0.05 # 5% of data used for validation
24
-
25
- # === MODEL PATHS ===
26
- BASE_MODEL_PATH = Path("models/JiRack_H12_L6_V50257_D768_MSL8192_FF768x4.pt")
27
- LAST_TRAINED_PATH = Path("models/JiRack_last_H12_L6_V50257_D768_MSL8192_FF768x4.pt")
28
- BACKUP_DIR = Path("models/backups")
29
- BACKUP_DIR.mkdir(exist_ok=True)
30
-
31
- # === DATASET AUTO-CLEANING ===
32
- RAW_PATH = Path("datasets/dialogues_text_clean.txt")
33
- CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")
34
-
35
- force_clean = False
36
- if not CLEAN_PATH.exists():
37
- print("Clean dataset not found. Performing initial cleaning...")
38
- force_clean = True
39
- else:
40
- try:
41
- if RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime:
42
- print("Changes detected in the source dataset. Performing re-cleaning...")
43
- force_clean = True
44
- else:
45
- print(f"Using existing clean dataset {CLEAN_PATH}")
46
- except FileNotFoundError:
47
- print("File system synchronization error. Performing re-cleaning for safety...")
48
- force_clean = True
49
-
50
- if force_clean:
51
- if not RAW_PATH.exists():
52
- raise FileNotFoundError(f"ERROR: Source file {RAW_PATH} not found. Check the path.")
53
-
54
- print("Cleaning dataset from garbage (extra spaces, incorrect separators)...")
55
- text = RAW_PATH.read_text(encoding="utf-8")
56
-
57
- text = re.sub(r' {2,}', ' ', text)
58
- text = text.replace(" \n", "\n").replace("\n ", "\n")
59
-
60
- CLEAN_PATH.write_text(text, encoding="utf-8")
61
- print(f"Dataset successfully cleaned and saved → {CLEAN_PATH}")
62
-
63
- DATASET_PATH = CLEAN_PATH
64
-
65
- OUTPUT_DIR = Path("build/fine_tuning_output")
66
- MODEL_SAVE_NAME = "gpt_finetuned.pt"
67
-
68
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
69
- print(f"Using device: {device}")
70
-
71
- # ============================= DATASET =============================
72
- class TextDataset(Dataset):
73
- def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, encoding_name="gpt2", split_type='train', val_ratio=VAL_SPLIT_RATIO):
74
- self.seq_len = seq_len
75
-
76
- # New: tiktoken exact GPT-2 encoding, fast, auto-downloads small .tiktoken file once
77
- print(f"Loading tiktoken encoding '{encoding_name}' (small file auto-downloads on first run if needed)...")
78
- self.enc = tiktoken.get_encoding(encoding_name) # "gpt2" is built-in and matches GPT-2 vocab perfectly
79
-
80
- self.split_type = split_type
81
-
82
- print(f"Loading text from {text_file} for {split_type} split...")
83
- text = Path(text_file).read_text(encoding="utf-8")
84
- tokens = self.enc.encode(text) # List of ints (exact same as GPT2Tokenizer)
85
-
86
- if len(tokens) < seq_len * 2:
87
- raise ValueError("Text too short!")
88
-
89
- all_inputs = []
90
- all_labels = []
91
-
92
- for i in range(0, len(tokens) - seq_len, seq_len):
93
- all_inputs.append(tokens[i:i + seq_len])
94
- all_labels.append(tokens[i + 1:i + seq_len + 1])
95
-
96
- total_sequences = len(all_inputs)
97
- val_size = int(total_sequences * val_ratio)
98
- train_size = total_sequences - val_size
99
-
100
- if self.split_type == 'train':
101
- self.inputs = all_inputs[:train_size]
102
- self.labels = all_labels[:train_size]
103
- elif self.split_type == 'val':
104
- self.inputs = all_inputs[train_size:]
105
- self.labels = all_labels[train_size:]
106
- else:
107
- raise ValueError("Invalid split_type. Must be 'train' or 'val'.")
108
-
109
- print(f"Created {len(self.inputs):,} sequences for {self.split_type} split.")
110
-
111
- def __len__(self):
112
- return len(self.inputs)
113
-
114
- def __getitem__(self, idx):
115
- return (torch.tensor(self.inputs[idx], dtype=torch.long),
116
- torch.tensor(self.labels[idx], dtype=torch.long))
117
-
118
- # ============================= EVALUATION =============================
119
- def evaluate(model, dataloader, criterion, device):
120
- model.eval()
121
- total_loss = 0.0
122
-
123
- with torch.no_grad():
124
- for inputs, targets in dataloader:
125
- inputs, targets = inputs.to(device), targets.to(device)
126
- logits, _ = model(inputs)
127
- loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
128
- total_loss += loss.item()
129
-
130
- avg_loss = total_loss / len(dataloader)
131
- model.train()
132
- return avg_loss
133
-
134
- # ============================= CLEANUP OLD EPOCHS =============================
135
- def cleanup_old_epochs(keep_last=KEEP_LAST_EPOCHS):
136
- epochs = sorted([p for p in OUTPUT_DIR.glob("epoch*") if p.is_dir()],
137
- key=lambda x: int(x.name.replace("epoch", "")))
138
- for old in epochs[:-keep_last]:
139
- if old.exists():
140
- shutil.rmtree(old)
141
- print(f"Deleted old epoch: {old.name}")
142
-
143
- # ============================= TRAINING =============================
144
- def train():
145
- OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
146
-
147
- print("Loading model...")
148
- model = GPTPyTorch().to(device)
149
-
150
- # Safer loading (silences FutureWarning)
151
- load_kwargs = {"map_location": device, "weights_only": True}
152
- if LAST_TRAINED_PATH.exists():
153
- print(f"Resuming training from last model: {LAST_TRAINED_PATH}")
154
- model.load_state_dict(torch.load(LAST_TRAINED_PATH, **load_kwargs))
155
- elif BASE_MODEL_PATH.exists():
156
- print(f"Starting from base model: {BASE_MODEL_PATH}")
157
- model.load_state_dict(torch.load(BASE_MODEL_PATH, **load_kwargs))
158
- else:
159
- print("No models found — initializing from scratch")
160
-
161
- model.train()
162
-
163
- train_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, encoding_name="gpt2", split_type='train', val_ratio=VAL_SPLIT_RATIO)
164
- val_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, encoding_name="gpt2", split_type='val', val_ratio=VAL_SPLIT_RATIO)
165
-
166
- train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
167
- val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
168
-
169
- optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
170
- criterion = nn.CrossEntropyLoss()
171
-
172
- total_steps = len(train_dataloader) * EPOCHS
173
- print(f"\n=== STARTING LONG-TERM TRAINING ===")
174
- print(f"Epochs: {EPOCHS} | Steps (Train): {total_steps} | Examples (Train): {len(train_dataset)}")
175
-
176
- global_step = 0
177
- for epoch in range(1, EPOCHS + 1):
178
- print(f"\n--- Epoch {epoch}/{EPOCHS} ---")
179
- epoch_loss = 0.0
180
-
181
- with tqdm(train_dataloader, desc=f"Epoch {epoch} [TRAIN]", leave=False) as pbar:
182
- for inputs, targets in pbar:
183
- inputs, targets = inputs.to(device), targets.to(device)
184
-
185
- optimizer.zero_grad()
186
- logits, _ = model(inputs)
187
- loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
188
- loss.backward()
189
- torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
190
- optimizer.step()
191
-
192
- loss_val = loss.item()
193
- epoch_loss += loss_val
194
- global_step += 1
195
-
196
- pbar.set_postfix({
197
- "loss": f"{loss_val:.3f}",
198
- "ppl": f"{math.exp(min(loss_val, 10)):.1f}",
199
- "step": f"{global_step}/{total_steps}"
200
- })
201
-
202
- avg_train_loss = epoch_loss / len(train_dataloader)
203
- print(f" [TRAIN] Average loss: {avg_train_loss:.3f} | PPL: {math.exp(avg_train_loss):.1f}")
204
-
205
- print(" [VALIDATION] Running evaluation...")
206
- val_loss = evaluate(model, val_dataloader, criterion, device)
207
- print(f" [VALIDATION] Average loss: {val_loss:.3f} | PPL: {math.exp(val_loss):.1f}")
208
-
209
- epoch_dir = OUTPUT_DIR / f"epoch{epoch}"
210
- epoch_dir.mkdir(exist_ok=True)
211
- torch.save(model.state_dict(), epoch_dir / MODEL_SAVE_NAME)
212
- print(f"Model saved: {epoch_dir / MODEL_SAVE_NAME}")
213
- cleanup_old_epochs()
214
-
215
- # Final saving – note: no tokenizer.save_pretrained anymore (tiktoken doesn't need it)
216
- final_dir = OUTPUT_DIR / "final"
217
- final_dir.mkdir(exist_ok=True)
218
- torch.save(model.state_dict(), final_dir / MODEL_SAVE_NAME)
219
-
220
- if LAST_TRAINED_PATH.exists():
221
- backup_path = BACKUP_DIR / f"gpt_last_trained_backup_{int(os.path.getmtime(LAST_TRAINED_PATH))}.pt"
222
- shutil.copy(LAST_TRAINED_PATH, backup_path)
223
- print(f"Backup of previous model created → {backup_path.name}")
224
-
225
- shutil.copy(final_dir / MODEL_SAVE_NAME, LAST_TRAINED_PATH)
226
- print(f"Last trained model saved → {LAST_TRAINED_PATH}")
227
-
228
- print(f"\nTRAINING COMPLETED! Model is ready:")
229
- print(f" For chat/inference: {final_dir / MODEL_SAVE_NAME}")
230
- print(f" • For continued fine-tuning: {LAST_TRAINED_PATH}")
231
-
232
- if __name__ == "__main__":
233
- if not RAW_PATH.exists():
234
- print(f"ERROR: File {RAW_PATH} not found")
235
- print("Place your text in datasets/dialogues_text.txt")
236
- else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  train()
 
1
+ # Copyright (c) 2025 CMS Manhattan
2
+ # All rights reserved.
3
+ # Author: Konstantin Vladimirovich Grabko
4
+ # Email: grabko@cmsmanhattan.com
5
+ # Phone: +1(516)777-0945
6
+ #
7
+ # This program is free software: you can redistribute it and/or modify
8
+ # it under the terms of the GNU General Public License as published by
9
+ # the Free Software Foundation, version 3 of the License.
10
+ #
11
+ # This program is distributed in the hope that it will be useful,
12
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
13
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
+ # GNU General Public License for more details.
15
+ #
16
+ # You should have received a copy of the GNU General Public License
17
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
18
+ #
19
+ # Additional terms:
20
+ # Any commercial use or distribution of this software or derivative works
21
+ # requires explicit written permission from the copyright holder.
22
+
23
+ import os
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.optim as optim
27
+ from torch.utils.data import Dataset, DataLoader
28
+ import tiktoken # New: OpenAI's fast BPE tokenizer (open-source, auto-downloads once)
29
+ from tqdm import tqdm
30
+ import shutil
31
+ import math
32
+ from pathlib import Path
33
+ import re
34
+
35
+ from gpt_pytorch import GPTPyTorch # Your model import
36
+
37
# ============================= SETTINGS =============================
TRAIN_SEQ_LEN = 256   # Context length (increased for better coherence)
BATCH_SIZE = 12
EPOCHS = 50
LEARNING_RATE = 6e-6
WEIGHT_DECAY = 0.01
GRAD_CLIP = 1.0       # max grad norm for clip_grad_norm_
KEEP_LAST_EPOCHS = 3  # how many per-epoch checkpoints to retain
VAL_SPLIT_RATIO = 0.05  # 5% of data used for validation

# === MODEL PATHS ===
BASE_MODEL_PATH = Path("models/JiRack_H12_L6_V50257_D768_MSL8192_FF768x4.pt")
LAST_TRAINED_PATH = Path("models/JiRack_last_H12_L6_V50257_D768_MSL8192_FF768x4.pt")
BACKUP_DIR = Path("models/backups")
# parents=True: don't crash on first run when models/ itself does not exist yet.
BACKUP_DIR.mkdir(parents=True, exist_ok=True)

# === DATASET AUTO-CLEANING ===
# BUG FIX: RAW_PATH previously pointed at the *clean* file too, so the
# raw-vs-clean mtime comparison compared a file with itself and the
# initial-cleaning branch raised FileNotFoundError for the wrong path.
# The __main__ guard's error message expects the raw corpus here:
RAW_PATH = Path("datasets/dialogues_text.txt")
CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")

force_clean = False
if not CLEAN_PATH.exists():
    print("Clean dataset not found. Performing initial cleaning...")
    force_clean = True
else:
    try:
        # Re-clean whenever the raw source is newer than the cleaned copy.
        if RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime:
            print("Changes detected in the source dataset. Performing re-cleaning...")
            force_clean = True
        else:
            print(f"Using existing clean dataset → {CLEAN_PATH}")
    except FileNotFoundError:
        # One of the files vanished between exists() and stat(); re-clean to be safe.
        print("File system synchronization error. Performing re-cleaning for safety...")
        force_clean = True

if force_clean:
    if not RAW_PATH.exists():
        raise FileNotFoundError(f"ERROR: Source file {RAW_PATH} not found. Check the path.")

    print("Cleaning dataset from garbage (extra spaces, incorrect separators)...")
    text = RAW_PATH.read_text(encoding="utf-8")

    # Collapse runs of spaces, then strip spaces adjacent to newlines.
    text = re.sub(r' {2,}', ' ', text)
    text = text.replace(" \n", "\n").replace("\n ", "\n")

    CLEAN_PATH.write_text(text, encoding="utf-8")
    print(f"Dataset successfully cleaned and saved → {CLEAN_PATH}")

DATASET_PATH = CLEAN_PATH

OUTPUT_DIR = Path("build/fine_tuning_output")
MODEL_SAVE_NAME = "gpt_finetuned.pt"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
92
+
93
# ============================= DATASET =============================
class TextDataset(Dataset):
    """GPT-2-tokenized corpus chopped into fixed-length next-token pairs.

    The whole file is encoded once, cut into non-overlapping windows of
    ``seq_len`` tokens, and the tail ``val_ratio`` fraction of windows is
    reserved for the 'val' split; the rest form the 'train' split.
    """

    def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, encoding_name="gpt2", split_type='train', val_ratio=VAL_SPLIT_RATIO):
        self.seq_len = seq_len

        # tiktoken ships the exact GPT-2 BPE; its small vocab file is fetched once.
        print(f"Loading tiktoken encoding '{encoding_name}' (small file auto-downloads on first run if needed)...")
        self.enc = tiktoken.get_encoding(encoding_name)

        self.split_type = split_type

        print(f"Loading text from {text_file} for {split_type} split...")
        corpus = Path(text_file).read_text(encoding="utf-8")
        token_ids = self.enc.encode(corpus)  # same ids as the HF GPT2Tokenizer

        if len(token_ids) < seq_len * 2:
            raise ValueError("Text too short!")

        # Non-overlapping windows; labels are the inputs shifted right by one.
        windows, shifted = [], []
        for start in range(0, len(token_ids) - seq_len, seq_len):
            windows.append(token_ids[start:start + seq_len])
            shifted.append(token_ids[start + 1:start + seq_len + 1])

        # Deterministic tail split so train and val never overlap.
        n_val = int(len(windows) * val_ratio)
        n_train = len(windows) - n_val

        if self.split_type == 'train':
            self.inputs, self.labels = windows[:n_train], shifted[:n_train]
        elif self.split_type == 'val':
            self.inputs, self.labels = windows[n_train:], shifted[n_train:]
        else:
            raise ValueError("Invalid split_type. Must be 'train' or 'val'.")

        print(f"Created {len(self.inputs):,} sequences for {self.split_type} split.")

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        x = torch.tensor(self.inputs[idx], dtype=torch.long)
        y = torch.tensor(self.labels[idx], dtype=torch.long)
        return x, y
139
+
140
# ============================= EVALUATION =============================
def evaluate(model, dataloader, criterion, device):
    """Return the mean per-batch loss of `model` over `dataloader`.

    Runs under torch.no_grad() with the model switched to eval mode, and
    restores train mode before returning. Batches are counted as they are
    consumed, so this also works for loaders without __len__ and — unlike
    the original len(dataloader) division — raises a clear ValueError for
    an empty loader (possible when the validation split is smaller than one
    batch and the loader was built with drop_last=True) instead of a bare
    ZeroDivisionError.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits, _ = model(inputs)  # model returns (logits, extra)
            loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
            total_loss += loss.item()
            num_batches += 1

    model.train()
    if num_batches == 0:
        raise ValueError("evaluate() got an empty dataloader "
                         "(validation split smaller than one batch?)")
    return total_loss / num_batches
155
+
156
# ============================= CLEANUP OLD EPOCHS =============================
def cleanup_old_epochs(keep_last=KEEP_LAST_EPOCHS):
    """Delete epoch checkpoint dirs under OUTPUT_DIR, keeping the newest `keep_last`.

    Only directories named "epoch<digits>" are considered; anything else the
    glob matches (e.g. "epoch_backup") is skipped instead of crashing int(),
    which the original assumed would always succeed. keep_last <= 0 removes
    every epoch directory (the original's epochs[:-0] silently kept them all).
    """
    numbered = []
    for p in OUTPUT_DIR.glob("epoch*"):
        if p.is_dir():
            suffix = p.name[len("epoch"):]
            if suffix.isdigit():  # ignore non-numeric directory names
                numbered.append((int(suffix), p))
    numbered.sort()

    ordered = [p for _, p in numbered]
    doomed = ordered if keep_last <= 0 else ordered[:-keep_last]
    for old in doomed:
        if old.exists():
            shutil.rmtree(old)
            print(f"Deleted old epoch: {old.name}")
164
+
165
# ============================= TRAINING =============================
def train():
    """Fine-tune GPTPyTorch on DATASET_PATH with per-epoch validation and checkpoints.

    Resumes from LAST_TRAINED_PATH if it exists, else loads BASE_MODEL_PATH,
    else starts from random init. Each epoch: train over the full loader,
    run validation, save a checkpoint (pruned to the last KEEP_LAST_EPOCHS),
    then write a final copy and rotate LAST_TRAINED_PATH with a timestamped
    backup of the previous weights.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    print("Loading model...")
    model = GPTPyTorch().to(device)

    # weights_only=True: safer deserialization, silences torch's FutureWarning.
    load_kwargs = {"map_location": device, "weights_only": True}
    if LAST_TRAINED_PATH.exists():
        print(f"Resuming training from last model: {LAST_TRAINED_PATH}")
        model.load_state_dict(torch.load(LAST_TRAINED_PATH, **load_kwargs))
    elif BASE_MODEL_PATH.exists():
        print(f"Starting from base model: {BASE_MODEL_PATH}")
        model.load_state_dict(torch.load(BASE_MODEL_PATH, **load_kwargs))
    else:
        # FIX: restored the dash the diff dropped from this message.
        print("No models found — initializing from scratch")

    model.train()

    train_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, encoding_name="gpt2", split_type='train', val_ratio=VAL_SPLIT_RATIO)
    val_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, encoding_name="gpt2", split_type='val', val_ratio=VAL_SPLIT_RATIO)

    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    # BUG FIX: drop_last=False for validation. With drop_last=True a val split
    # smaller than one batch yields an empty loader and evaluate() divides by
    # zero; there is also no reason to discard validation examples.
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    criterion = nn.CrossEntropyLoss()

    total_steps = len(train_dataloader) * EPOCHS
    print(f"\n=== STARTING LONG-TERM TRAINING ===")
    print(f"Epochs: {EPOCHS} | Steps (Train): {total_steps} | Examples (Train): {len(train_dataset)}")

    global_step = 0
    for epoch in range(1, EPOCHS + 1):
        print(f"\n--- Epoch {epoch}/{EPOCHS} ---")
        epoch_loss = 0.0

        with tqdm(train_dataloader, desc=f"Epoch {epoch} [TRAIN]", leave=False) as pbar:
            for inputs, targets in pbar:
                inputs, targets = inputs.to(device), targets.to(device)

                optimizer.zero_grad()
                logits, _ = model(inputs)
                loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                optimizer.step()

                loss_val = loss.item()
                epoch_loss += loss_val
                global_step += 1

                pbar.set_postfix({
                    "loss": f"{loss_val:.3f}",
                    # min(loss, 10) caps exp() so a loss spike can't overflow the display
                    "ppl": f"{math.exp(min(loss_val, 10)):.1f}",
                    "step": f"{global_step}/{total_steps}"
                })

        avg_train_loss = epoch_loss / len(train_dataloader)
        print(f"   [TRAIN] Average loss: {avg_train_loss:.3f} | PPL: {math.exp(avg_train_loss):.1f}")

        print("   [VALIDATION] Running evaluation...")
        val_loss = evaluate(model, val_dataloader, criterion, device)
        print(f"   [VALIDATION] Average loss: {val_loss:.3f} | PPL: {math.exp(val_loss):.1f}")

        # Per-epoch checkpoint; older ones are pruned to KEEP_LAST_EPOCHS.
        epoch_dir = OUTPUT_DIR / f"epoch{epoch}"
        epoch_dir.mkdir(exist_ok=True)
        torch.save(model.state_dict(), epoch_dir / MODEL_SAVE_NAME)
        print(f"Model saved: {epoch_dir / MODEL_SAVE_NAME}")
        cleanup_old_epochs()

    # Final saving – note: no tokenizer.save_pretrained anymore (tiktoken doesn't need it)
    final_dir = OUTPUT_DIR / "final"
    final_dir.mkdir(exist_ok=True)
    torch.save(model.state_dict(), final_dir / MODEL_SAVE_NAME)

    # Keep a timestamped backup of the previous "last trained" weights before overwriting.
    if LAST_TRAINED_PATH.exists():
        backup_path = BACKUP_DIR / f"gpt_last_trained_backup_{int(os.path.getmtime(LAST_TRAINED_PATH))}.pt"
        shutil.copy(LAST_TRAINED_PATH, backup_path)
        print(f"Backup of previous model created → {backup_path.name}")

    shutil.copy(final_dir / MODEL_SAVE_NAME, LAST_TRAINED_PATH)
    print(f"Last trained model saved → {LAST_TRAINED_PATH}")

    print(f"\nTRAINING COMPLETED! Model is ready:")
    print(f"   • For chat/inference: {final_dir / MODEL_SAVE_NAME}")
    print(f"   • For continued fine-tuning: {LAST_TRAINED_PATH}")
253
+
254
+ if __name__ == "__main__":
255
+ if not RAW_PATH.exists():
256
+ print(f"ERROR: File {RAW_PATH} not found")
257
+ print("Place your text in datasets/dialogues_text.txt")
258
+ else:
259
  train()