# JiRack_empty / fine_tune_with_validation.py
# Uploaded by kgrabko — commit a22c288 ("Update fine_tune_with_validation.py", verified)
# Copyright (c) 2025 CMS Manhattan
# All rights reserved.
# Author: Konstantin Vladimirovich Grabko
# Email: grabko@cmsmanhattan.com
# Phone: +1(516)777-0945
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
# Additional terms:
# Any commercial use or distribution of this software or derivative works
# requires explicit written permission from the copyright holder.
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import tiktoken # New: OpenAI's fast BPE tokenizer (open-source, auto-downloads once)
from tqdm import tqdm
import shutil
import math
from pathlib import Path
import re
from gpt_pytorch import GPTPyTorch # Your model import
# ============================= SETTINGS =============================
TRAIN_SEQ_LEN = 256  # Context length (increased for better coherence)
BATCH_SIZE = 12
EPOCHS = 50
LEARNING_RATE = 6e-6  # AdamW learning rate
WEIGHT_DECAY = 0.01  # AdamW weight decay
GRAD_CLIP = 1.0  # max total gradient norm passed to clip_grad_norm_
KEEP_LAST_EPOCHS = 3  # how many per-epoch checkpoint directories are kept on disk
VAL_SPLIT_RATIO = 0.05  # 5% of data used for validation
# === MODEL PATHS ===
BASE_MODEL_PATH = Path("models/JiRack_H12_L6_V50257_D768_MSL8192_FF768x4.pt")
LAST_TRAINED_PATH = Path("models/JiRack_last_H12_L6_V50257_D768_MSL8192_FF768x4.pt")
BACKUP_DIR = Path("models/backups")
# parents=True: on a fresh checkout "models/" may not exist yet; without it
# mkdir() raises FileNotFoundError as soon as this module is imported.
BACKUP_DIR.mkdir(parents=True, exist_ok=True)
# === DATASET AUTO-CLEANING ===
# The raw dialogue text is normalised into a separate "clean" file, which is
# what training actually reads (DATASET_PATH). The clean copy is regenerated
# whenever it is missing or older than the raw file.
RAW_PATH = Path("datasets/dialogues_text.txt")
CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")


def _clean_text(text):
    """Return *text* with runs of spaces collapsed and spaces next to newlines removed."""
    text = re.sub(r' {2,}', ' ', text)
    # After collapsing, at most one space can sit next to a newline.
    return text.replace(" \n", "\n").replace("\n ", "\n")


def _needs_cleaning(raw_path, clean_path):
    """Decide whether *clean_path* must be (re)generated from *raw_path*.

    Returns True when the clean file is missing or older than the raw file,
    and False when an up-to-date clean file can be used as-is (including the
    case where only the clean file is present).

    Raises:
        FileNotFoundError: if neither file exists.
    """
    if not raw_path.exists():
        if clean_path.exists():
            print(f"Source dataset {raw_path} not found; using existing clean dataset → {clean_path}")
            return False
        raise FileNotFoundError(f"ERROR: Source file {raw_path} not found. Check the path.")
    if not clean_path.exists():
        print("Clean dataset not found. Performing initial cleaning...")
        return True
    try:
        if raw_path.stat().st_mtime > clean_path.stat().st_mtime:
            print("Changes detected in the source dataset. Performing re-cleaning...")
            return True
    except FileNotFoundError:
        # A file vanished between exists() and stat(); rebuild to be safe.
        print("File system synchronization error. Performing re-cleaning for safety...")
        return True
    print(f"Using existing clean dataset → {clean_path}")
    return False


force_clean = _needs_cleaning(RAW_PATH, CLEAN_PATH)
if force_clean:
    print("Cleaning dataset from garbage (extra spaces, incorrect separators)...")
    CLEAN_PATH.write_text(_clean_text(RAW_PATH.read_text(encoding="utf-8")), encoding="utf-8")
    print(f"Dataset successfully cleaned and saved → {CLEAN_PATH}")
DATASET_PATH = CLEAN_PATH
# Per-epoch and final checkpoints are written below this directory,
# each under the file name MODEL_SAVE_NAME.
OUTPUT_DIR = Path("build") / "fine_tuning_output"
MODEL_SAVE_NAME = "gpt_finetuned.pt"

# Use the GPU when PyTorch reports one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
# ============================= DATASET =============================
class TextDataset(Dataset):
    """Next-token-prediction samples cut from one UTF-8 text file.

    The file is tokenized with a tiktoken encoding and split into
    non-overlapping windows of ``seq_len`` tokens. Each input window is
    paired with the same window shifted right by one token as its labels.
    The last ``int(n_windows * val_ratio)`` windows (the tail of the file)
    form the ``'val'`` split; all earlier windows form ``'train'``.
    """

    def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, encoding_name="gpt2", split_type='train', val_ratio=VAL_SPLIT_RATIO):
        """Tokenize *text_file* and keep only the windows of *split_type*.

        Args:
            text_file: path (str or Path) of the UTF-8 text to read.
            seq_len: number of tokens per sample.
            encoding_name: name passed to ``tiktoken.get_encoding``.
            split_type: ``'train'`` or ``'val'``.
            val_ratio: fraction of windows reserved for the ``'val'`` split.

        Raises:
            ValueError: on an unknown *split_type*, or if the text has fewer
                than ``2 * seq_len`` tokens.
        """
        # Validate before doing the (potentially slow) read + tokenization.
        if split_type not in ('train', 'val'):
            raise ValueError("Invalid split_type. Must be 'train' or 'val'.")
        self.seq_len = seq_len
        self.split_type = split_type
        # tiktoken: fast BPE; the small .tiktoken file auto-downloads on first use.
        print(f"Loading tiktoken encoding '{encoding_name}' (small file auto-downloads on first run if needed)...")
        self.enc = tiktoken.get_encoding(encoding_name)
        print(f"Loading text from {text_file} for {split_type} split...")
        text = Path(text_file).read_text(encoding="utf-8")
        # allowed_special="all": a literal "<|endoftext|>" in the corpus is encoded
        # as the single special token instead of making encode() raise ValueError
        # (tiktoken disallows special tokens by default).
        tokens = self.enc.encode(text, allowed_special="all")
        if len(tokens) < seq_len * 2:
            raise ValueError("Text too short!")
        # Window start offsets; each window needs seq_len + 1 tokens
        # (inputs plus the one-token-shifted labels).
        starts = range(0, len(tokens) - seq_len, seq_len)
        val_size = int(len(starts) * val_ratio)
        train_size = len(starts) - val_size
        # Materialize only the windows of the requested split.
        selected = starts[:train_size] if split_type == 'train' else starts[train_size:]
        self.inputs = [tokens[i:i + seq_len] for i in selected]
        self.labels = [tokens[i + 1:i + seq_len + 1] for i in selected]
        print(f"Created {len(self.inputs):,} sequences for {self.split_type} split.")
        if not self.inputs:
            print(f"WARNING: {self.split_type} split is empty "
                  f"(val_ratio={val_ratio}, {len(starts)} windows in total).")

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return ``(inputs, labels)`` as int64 tensors of length ``seq_len``."""
        return (torch.tensor(self.inputs[idx], dtype=torch.long),
                torch.tensor(self.labels[idx], dtype=torch.long))
# ============================= EVALUATION =============================
def evaluate(model, dataloader, criterion, device):
    """Return the mean per-token loss of *model* over *dataloader*.

    The model is expected to return ``(logits, extra)``; logits are flattened
    to ``(-1, vocab)`` and targets to ``(-1,)`` before calling *criterion*.
    Each batch's loss is weighted by its number of target tokens, so a final
    partial batch does not skew the average. This weighting assumes
    *criterion* averages over all target tokens (e.g. ``CrossEntropyLoss``
    with default reduction and no ignored targets) — confirm for other losses.

    The model's training/eval mode is restored to what it was on entry.

    Returns:
        The average loss as a float, or ``float('nan')`` if *dataloader*
        yields no batches.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    try:
        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs, targets = inputs.to(device), targets.to(device)
                logits, _ = model(inputs)
                # reshape (not view): works even if the model returns non-contiguous logits.
                loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
                n_tokens = targets.numel()
                total_loss += loss.item() * n_tokens
                total_tokens += n_tokens
    finally:
        model.train(was_training)
    if total_tokens == 0:
        return float("nan")
    return total_loss / total_tokens
# ============================= CLEANUP OLD EPOCHS =============================
def cleanup_old_epochs(keep_last=KEEP_LAST_EPOCHS, output_dir=None):
    """Delete old per-epoch checkpoint directories, keeping the newest *keep_last*.

    Only sub-directories named exactly ``epoch<N>`` (N a non-negative integer)
    are considered, and "newest" means highest N. Other entries matching
    ``epoch*`` (e.g. ``epoch_notes``) are ignored rather than crashing the
    numeric sort.

    Args:
        keep_last: how many of the highest-numbered epoch directories to keep;
            0 or less deletes all of them.
        output_dir: directory to scan; defaults to OUTPUT_DIR.
    """
    root = OUTPUT_DIR if output_dir is None else Path(output_dir)
    numbered = []
    for path in root.glob("epoch*"):
        match = re.fullmatch(r"epoch(\d+)", path.name)
        if match is not None and path.is_dir():
            numbered.append((int(match.group(1)), path))
    numbered.sort(key=lambda item: item[0])
    # Compute the count explicitly: a slice like [:-keep_last] would delete
    # nothing when keep_last == 0.
    n_delete = max(len(numbered) - max(keep_last, 0), 0)
    for _, old in numbered[:n_delete]:
        if old.exists():
            shutil.rmtree(old)
            print(f"Deleted old epoch: {old.name}")
# ============================= TRAINING =============================
def _perplexity(loss):
    """Return exp(*loss*), or +inf instead of raising OverflowError for very large losses."""
    try:
        return math.exp(loss)
    except OverflowError:
        return float("inf")


def _load_initial_weights(model):
    """Load starting weights into *model*, preferring the last fine-tuned checkpoint.

    Tries LAST_TRAINED_PATH, then BASE_MODEL_PATH. If neither exists, the
    model keeps the initialisation produced by its constructor.
    """
    # weights_only=True: safer loading (restricted unpickler); also silences
    # torch.load's FutureWarning.
    load_kwargs = {"map_location": device, "weights_only": True}
    if LAST_TRAINED_PATH.exists():
        print(f"Resuming training from last model: {LAST_TRAINED_PATH}")
        model.load_state_dict(torch.load(LAST_TRAINED_PATH, **load_kwargs))
    elif BASE_MODEL_PATH.exists():
        print(f"Starting from base model: {BASE_MODEL_PATH}")
        model.load_state_dict(torch.load(BASE_MODEL_PATH, **load_kwargs))
    else:
        print("No models found — initializing from scratch")


def _run_train_epoch(model, dataloader, optimizer, criterion, epoch, global_step, total_steps):
    """Run one optimisation pass over *dataloader*.

    Returns:
        ``(average per-batch training loss, updated global step counter)``.
        The caller must ensure *dataloader* yields at least one batch.
    """
    epoch_loss = 0.0
    with tqdm(dataloader, desc=f"Epoch {epoch} [TRAIN]", leave=False) as pbar:
        for inputs, targets in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            logits, _ = model(inputs)
            loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            optimizer.step()
            loss_val = loss.item()
            epoch_loss += loss_val
            global_step += 1
            pbar.set_postfix({
                "loss": f"{loss_val:.3f}",
                # Clamp before exp so the progress bar stays readable early on.
                "ppl": f"{math.exp(min(loss_val, 10)):.1f}",
                "step": f"{global_step}/{total_steps}"
            })
    return epoch_loss / len(dataloader), global_step


def _publish_final_model(model):
    """Save the final weights and promote them to LAST_TRAINED_PATH.

    Any existing LAST_TRAINED_PATH is first copied into BACKUP_DIR under a
    name derived from its modification time.

    Returns:
        Path of the final checkpoint written under OUTPUT_DIR/final/.
    """
    # Final saving – note: no tokenizer.save_pretrained anymore (tiktoken doesn't need it)
    final_dir = OUTPUT_DIR / "final"
    final_dir.mkdir(parents=True, exist_ok=True)
    final_path = final_dir / MODEL_SAVE_NAME
    torch.save(model.state_dict(), final_path)
    if LAST_TRAINED_PATH.exists():
        BACKUP_DIR.mkdir(parents=True, exist_ok=True)
        backup_path = BACKUP_DIR / f"gpt_last_trained_backup_{int(os.path.getmtime(LAST_TRAINED_PATH))}.pt"
        shutil.copy(LAST_TRAINED_PATH, backup_path)
        print(f"Backup of previous model created → {backup_path.name}")
    # The models/ directory may not exist when training started from scratch.
    LAST_TRAINED_PATH.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy(final_path, LAST_TRAINED_PATH)
    print(f"Last trained model saved → {LAST_TRAINED_PATH}")
    return final_path


def train():
    """Fine-tune GPTPyTorch on DATASET_PATH for EPOCHS epochs with per-epoch validation.

    Each epoch's weights are saved to OUTPUT_DIR/epoch<N>/MODEL_SAVE_NAME, and
    older epoch directories are pruned by cleanup_old_epochs(). The final
    weights go to OUTPUT_DIR/final/ and are copied to LAST_TRAINED_PATH.

    Raises:
        ValueError: if the training split holds fewer than BATCH_SIZE sequences.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    print("Loading model...")
    model = GPTPyTorch().to(device)
    _load_initial_weights(model)
    model.train()

    train_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, encoding_name="gpt2", split_type='train', val_ratio=VAL_SPLIT_RATIO)
    val_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, encoding_name="gpt2", split_type='val', val_ratio=VAL_SPLIT_RATIO)
    # drop_last=True keeps every optimisation step at a full batch.
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    # Validation keeps its final partial batch. Dropping it would discard
    # held-out data, and a val split smaller than BATCH_SIZE would yield zero batches.
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
    if len(train_dataloader) == 0:
        raise ValueError(
            f"Training split has {len(train_dataset)} sequences, fewer than "
            f"BATCH_SIZE={BATCH_SIZE}; nothing to train on."
        )

    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    criterion = nn.CrossEntropyLoss()
    total_steps = len(train_dataloader) * EPOCHS
    print("\n=== STARTING LONG-TERM TRAINING ===")
    print(f"Epochs: {EPOCHS} | Steps (Train): {total_steps} | Examples (Train): {len(train_dataset)}")

    global_step = 0
    for epoch in range(1, EPOCHS + 1):
        print(f"\n--- Epoch {epoch}/{EPOCHS} ---")
        avg_train_loss, global_step = _run_train_epoch(
            model, train_dataloader, optimizer, criterion, epoch, global_step, total_steps
        )
        print(f" [TRAIN] Average loss: {avg_train_loss:.3f} | PPL: {_perplexity(avg_train_loss):.1f}")
        if len(val_dataloader) > 0:
            print(" [VALIDATION] Running evaluation...")
            val_loss = evaluate(model, val_dataloader, criterion, device)
            print(f" [VALIDATION] Average loss: {val_loss:.3f} | PPL: {_perplexity(val_loss):.1f}")
        else:
            print(" [VALIDATION] Skipped: validation split is empty.")
        epoch_dir = OUTPUT_DIR / f"epoch{epoch}"
        epoch_dir.mkdir(exist_ok=True)
        torch.save(model.state_dict(), epoch_dir / MODEL_SAVE_NAME)
        print(f"Model saved: {epoch_dir / MODEL_SAVE_NAME}")
        cleanup_old_epochs()

    final_path = _publish_final_model(model)
    print("\nTRAINING COMPLETED! Model is ready:")
    print(f" • For chat/inference: {final_path}")
    print(f" • For continued fine-tuning: {LAST_TRAINED_PATH}")
if __name__ == "__main__":
    # Training reads DATASET_PATH (the cleaned copy), so check that file,
    # not the raw source, before starting.
    if DATASET_PATH.exists():
        train()
    else:
        print(f"ERROR: File {DATASET_PATH} not found")
        print("Place your text in datasets/dialogues_text.txt")