# Copyright (c) 2025 CMS Manhattan
# All rights reserved.
#
# This file is part of a project authored by CMS Manhattan. You may use, distribute, and modify
# this code under the terms of the GNU General Public License, Version 3, 29 June 2007.
# For details, see <http://www.gnu.org/licenses/>.
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2TokenizerFast
from tqdm import tqdm
import shutil
import math
from pathlib import Path
import re
from gpt_jit_modern_3b import JiRackPyTorch
# ============================= SETTINGS =============================
TRAIN_SEQ_LEN = 256
BATCH_SIZE = 2
ACCUM_STEPS = 16
EPOCHS = 500
LEARNING_RATE = 3e-5
WEIGHT_DECAY = 0.01
GRAD_CLIP = 1.0
VAL_SPLIT_RATIO = 0.05
KEEP_LAST_EPOCHS = 3
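# Note: with BATCH_SIZE = 2 and ACCUM_STEPS = 16, each optimizer step effectively
# sees 2 * 16 = 32 sequences of TRAIN_SEQ_LEN (256) tokens.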
# === PATHS ===
BASE_MODEL_PATH = Path("models/gpt_modern_3b_class.state_dict.pt")
LAST_TRAINED_PATH = Path("models/gpt_last_modern_3b_class.state_dict.pt")
BACKUP_DIR = Path("models/backups")
BACKUP_DIR.mkdir(exist_ok=True, parents=True)
RAW_PATH = Path("datasets/dialogues_text.txt")
CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
print(f"Using device: {device}")
# === DATASET CLEANING ===
if not CLEAN_PATH.exists() or RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime:
    print("Cleaning dataset...")
    text = RAW_PATH.read_text(encoding="utf-8")
    text = re.sub(r' {2,}', ' ', text)  # remove extra spaces
    text = text.replace(" \n", "\n").replace("\n ", "\n")
    CLEAN_PATH.write_text(text, encoding="utf-8")
    print(f"Done → {CLEAN_PATH}")
DATASET_PATH = CLEAN_PATH
OUTPUT_DIR = Path("build/fine_tuning_output")
MODEL_SAVE_NAME = "pytorch_model.bin"
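# Note: despite the Hugging Face-style filename, the checkpoints written below are
# plain PyTorch state_dicts (torch.save(model.state_dict(), ...)), not HF-format models.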
# ============================= DATASET =============================
class TextDataset(Dataset):
    def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, split='train'):
        self.seq_len = seq_len
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token
        text = Path(text_file).read_text(encoding="utf-8")
        tokens = tokenizer.encode(text)
        sequences = []
        for i in range(0, len(tokens) - seq_len, seq_len):
            sequences.append(tokens[i:i + seq_len + 1])  # +1 for labels
        split_idx = int(len(sequences) * (1 - VAL_SPLIT_RATIO))
        if split == 'train':
            self.data = sequences[:split_idx]
        else:
            self.data = sequences[split_idx:]
        print(f"{split.upper()} sequences: {len(self.data):,}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        seq = self.data[idx]
        return torch.tensor(seq[:-1], dtype=torch.long), torch.tensor(seq[1:], dtype=torch.long)
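# Illustration of the one-token shift returned by __getitem__ (hypothetical token ids):
#   stored sequence: [t0, t1, t2, ..., t256]
#   input  x       = [t0, t1, ..., t255]
#   target y       = [t1, t2, ..., t256]   # each position predicts the next token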
def evaluate(model, loader):
    """Compute the mean cross-entropy loss over a DataLoader without gradient tracking."""
    model.eval()
    total_loss = 0
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)
            loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
            total_loss += loss.item()
    model.train()
    return total_loss / len(loader)
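# Note: evaluate() returns the mean of per-batch losses; a final batch smaller than
# BATCH_SIZE is weighted the same as full batches, which is a minor approximation.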
def train():
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    print("Loading model...")
    model = JiRackPyTorch().to(device)
    if LAST_TRAINED_PATH.exists():
        print(f"Resuming from {LAST_TRAINED_PATH}")
        model.load_state_dict(torch.load(LAST_TRAINED_PATH, map_location=device))
    elif BASE_MODEL_PATH.exists():
        print(f"Starting from base model {BASE_MODEL_PATH}")
        model.load_state_dict(torch.load(BASE_MODEL_PATH, map_location=device))
    else:
        print("Starting from scratch — random weights")
    model.train()

    train_dataset = TextDataset(DATASET_PATH, split='train')
    val_dataset = TextDataset(DATASET_PATH, split='val')
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    criterion = nn.CrossEntropyLoss()

    print("\nFULL TRAINING STARTED! No LoRA, no compromises — we're training the whole thing!\n")
    for epoch in range(1, EPOCHS + 1):
        total_loss = 0
        for step, (x, y) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch}")):
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)
            loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
            # Scale the loss so that gradients accumulated over ACCUM_STEPS
            # mini-batches average out to one effective large batch.
            loss = loss / ACCUM_STEPS
            loss.backward()
            total_loss += loss.item() * ACCUM_STEPS
            # Step the optimizer every ACCUM_STEPS mini-batches (and on the last
            # batch of the epoch), clipping gradients first.
            if (step + 1) % ACCUM_STEPS == 0 or (step + 1) == len(train_loader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                optimizer.step()
                optimizer.zero_grad()
        avg_train_loss = total_loss / len(train_loader)
        val_loss = evaluate(model, val_loader)
        print(f"\nEpoch {epoch}")
        print(f"  Train loss: {avg_train_loss:.4f} | PPL: {math.exp(avg_train_loss):.2f}")
        print(f"  Val loss:   {val_loss:.4f} | PPL: {math.exp(val_loss):.2f}")

        # Save checkpoint
        save_dir = OUTPUT_DIR / f"epoch_{epoch}"
        save_dir.mkdir(exist_ok=True, parents=True)
        torch.save(model.state_dict(), save_dir / MODEL_SAVE_NAME)
        torch.save(model.state_dict(), LAST_TRAINED_PATH)
        # Keep only the last N epochs to save disk space. Sort numerically by epoch
        # number so that e.g. epoch_10 is treated as newer than epoch_9 (a plain
        # lexicographic sort would prune the wrong directories past epoch 9).
        epochs = sorted(
            [p for p in OUTPUT_DIR.iterdir() if p.is_dir() and p.name.startswith("epoch_")],
            key=lambda p: int(p.name.split("_")[1]),
        )
        for old in epochs[:-KEEP_LAST_EPOCHS]:
            shutil.rmtree(old)
print("\nDONE! Full model trained. You are now the emperor of fine-tuning.")
if __name__ == "__main__":
train()
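# A minimal sketch (not part of the training run) of how the last checkpoint saved
# above could be reloaded for inference, assuming JiRackPyTorch exposes the same
# forward signature used in evaluate():
#
#   model = JiRackPyTorch().to(device)
#   model.load_state_dict(torch.load(LAST_TRAINED_PATH, map_location=device))
#   model.eval()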