import math
import os
import re
import shutil
from pathlib import Path

import tiktoken
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from gpt_pytorch import GPTPyTorch
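
# --- Training hyperparameters ---
# Conservative settings for continued fine-tuning: a low learning rate with
# AdamW weight decay and gradient-norm clipping (see the training loop below).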
TRAIN_SEQ_LEN = 256        # context window used for fine-tuning
BATCH_SIZE = 12
EPOCHS = 50
LEARNING_RATE = 6e-6
WEIGHT_DECAY = 0.01
GRAD_CLIP = 1.0            # max gradient norm
KEEP_LAST_EPOCHS = 3       # epoch checkpoint directories to keep on disk
VAL_SPLIT_RATIO = 0.05     # fraction of sequences held out for validation
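
# --- Checkpoint paths ---
# The file names appear to encode the architecture (12 heads, 6 layers,
# 50257-token vocab, 768-dim embeddings, 8192 max seq len, 4x FFN width).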
BASE_MODEL_PATH = Path("models/JiRack_H12_L6_V50257_D768_MSL8192_FF768x4.pt")
LAST_TRAINED_PATH = Path("models/JiRack_last_H12_L6_V50257_D768_MSL8192_FF768x4.pt")
BACKUP_DIR = Path("models/backups")
BACKUP_DIR.mkdir(parents=True, exist_ok=True)  # parents=True so "models/" is created too
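
# --- Dataset cleaning ---
# Regenerate CLEAN_PATH from RAW_PATH whenever the clean file is missing
# or older than the source.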
RAW_PATH = Path("datasets/dialogues_text.txt")
CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")

force_clean = False
if not CLEAN_PATH.exists():
    print("Clean dataset not found. Performing initial cleaning...")
    force_clean = True
else:
    try:
        if RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime:
            print("Changes detected in the source dataset. Performing re-cleaning...")
            force_clean = True
        else:
            print(f"Using existing clean dataset → {CLEAN_PATH}")
    except FileNotFoundError:
        print("Could not compare file timestamps. Performing re-cleaning for safety...")
        force_clean = True

if force_clean:
    if not RAW_PATH.exists():
        raise FileNotFoundError(f"ERROR: Source file {RAW_PATH} not found. Check the path.")

    print("Cleaning dataset (removing extra spaces and stray separators)...")
    text = RAW_PATH.read_text(encoding="utf-8")

    # Collapse runs of spaces and strip spaces adjacent to newlines.
    text = re.sub(r' {2,}', ' ', text)
    text = text.replace(" \n", "\n").replace("\n ", "\n")

    CLEAN_PATH.write_text(text, encoding="utf-8")
    print(f"Dataset successfully cleaned and saved → {CLEAN_PATH}")

DATASET_PATH = CLEAN_PATH
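
# --- Output locations and device ---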
OUTPUT_DIR = Path("build/fine_tuning_output")
MODEL_SAVE_NAME = "gpt_finetuned.pt"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


class TextDataset(Dataset):
    """Tokenizes a text file and serves fixed-length (input, shifted-label) pairs."""

    def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, encoding_name="gpt2",
                 split_type='train', val_ratio=VAL_SPLIT_RATIO):
        self.seq_len = seq_len

        print(f"Loading tiktoken encoding '{encoding_name}' (small file auto-downloads on first run if needed)...")
        self.enc = tiktoken.get_encoding(encoding_name)

        self.split_type = split_type

        print(f"Loading text from {text_file} for {split_type} split...")
        text = Path(text_file).read_text(encoding="utf-8")
        tokens = self.enc.encode(text)

        if len(tokens) < seq_len * 2:
            raise ValueError(f"Text too short: {len(tokens)} tokens, need at least {seq_len * 2}.")

        # Slice the token stream into non-overlapping windows; labels are the
        # inputs shifted right by one token (next-token prediction).
        all_inputs = []
        all_labels = []
        for i in range(0, len(tokens) - seq_len, seq_len):
            all_inputs.append(tokens[i:i + seq_len])
            all_labels.append(tokens[i + 1:i + seq_len + 1])

        # Deterministic tail split: the last val_ratio of sequences form the
        # validation set, so train and val never overlap.
        total_sequences = len(all_inputs)
        val_size = int(total_sequences * val_ratio)
        train_size = total_sequences - val_size

        if self.split_type == 'train':
            self.inputs = all_inputs[:train_size]
            self.labels = all_labels[:train_size]
        elif self.split_type == 'val':
            self.inputs = all_inputs[train_size:]
            self.labels = all_labels[train_size:]
        else:
            raise ValueError("Invalid split_type. Must be 'train' or 'val'.")

        print(f"Created {len(self.inputs):,} sequences for {self.split_type} split.")

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return (torch.tensor(self.inputs[idx], dtype=torch.long),
                torch.tensor(self.labels[idx], dtype=torch.long))
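
# Hypothetical quick check of the dataset (names/paths as defined above):
#     ds = TextDataset(DATASET_PATH, split_type='val')
#     x, y = ds[0]  # LongTensors of shape (TRAIN_SEQ_LEN,); y is x shifted by one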


def evaluate(model, dataloader, criterion, device):
    """Average cross-entropy loss over a dataloader, without gradient updates."""
    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits, _ = model(inputs)
            loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
            total_loss += loss.item()

    # Guard against an empty loader (possible when the val split holds fewer
    # sequences than one batch, since drop_last=True).
    avg_loss = total_loss / max(len(dataloader), 1)
    model.train()
    return avg_loss


def cleanup_old_epochs(keep_last=KEEP_LAST_EPOCHS):
    """Remove all but the newest `keep_last` epoch checkpoint directories."""
    epochs = sorted([p for p in OUTPUT_DIR.glob("epoch*") if p.is_dir()],
                    key=lambda x: int(x.name.replace("epoch", "")))
    for old in epochs[:-keep_last]:
        if old.exists():
            shutil.rmtree(old)
            print(f"Deleted old epoch: {old.name}")


def train():
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    print("Loading model...")
    model = GPTPyTorch().to(device)

    # Prefer resuming from the last fine-tuned checkpoint, then the base
    # model; otherwise start from randomly initialized weights.
    load_kwargs = {"map_location": device, "weights_only": True}
    if LAST_TRAINED_PATH.exists():
        print(f"Resuming training from last model: {LAST_TRAINED_PATH}")
        model.load_state_dict(torch.load(LAST_TRAINED_PATH, **load_kwargs))
    elif BASE_MODEL_PATH.exists():
        print(f"Starting from base model: {BASE_MODEL_PATH}")
        model.load_state_dict(torch.load(BASE_MODEL_PATH, **load_kwargs))
    else:
        print("No models found — initializing from scratch")

    model.train()

    train_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, encoding_name="gpt2",
                                split_type='train', val_ratio=VAL_SPLIT_RATIO)
    val_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, encoding_name="gpt2",
                              split_type='val', val_ratio=VAL_SPLIT_RATIO)

    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    criterion = nn.CrossEntropyLoss()

    total_steps = len(train_dataloader) * EPOCHS
    print("\n=== STARTING LONG-TERM TRAINING ===")
    print(f"Epochs: {EPOCHS} | Steps (Train): {total_steps} | Examples (Train): {len(train_dataset)}")

    global_step = 0

    for epoch in range(1, EPOCHS + 1):
        print(f"\n--- Epoch {epoch}/{EPOCHS} ---")
        epoch_loss = 0.0

        with tqdm(train_dataloader, desc=f"Epoch {epoch} [TRAIN]", leave=False) as pbar:
            for inputs, targets in pbar:
                inputs, targets = inputs.to(device), targets.to(device)

                optimizer.zero_grad()
                logits, _ = model(inputs)
                loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                optimizer.step()

                loss_val = loss.item()
                epoch_loss += loss_val
                global_step += 1

                pbar.set_postfix({
                    "loss": f"{loss_val:.3f}",
                    # Cap the exponent so the perplexity display cannot overflow.
                    "ppl": f"{math.exp(min(loss_val, 10)):.1f}",
                    "step": f"{global_step}/{total_steps}"
                })

        avg_train_loss = epoch_loss / len(train_dataloader)
        print(f" [TRAIN] Average loss: {avg_train_loss:.3f} | PPL: {math.exp(avg_train_loss):.1f}")

        print(" [VALIDATION] Running evaluation...")
        val_loss = evaluate(model, val_dataloader, criterion, device)
        print(f" [VALIDATION] Average loss: {val_loss:.3f} | PPL: {math.exp(val_loss):.1f}")

        # Checkpoint after every epoch, then prune old epoch directories.
        epoch_dir = OUTPUT_DIR / f"epoch{epoch}"
        epoch_dir.mkdir(exist_ok=True)
        torch.save(model.state_dict(), epoch_dir / MODEL_SAVE_NAME)
        print(f"Model saved: {epoch_dir / MODEL_SAVE_NAME}")
        cleanup_old_epochs()

    final_dir = OUTPUT_DIR / "final"
    final_dir.mkdir(exist_ok=True)
    torch.save(model.state_dict(), final_dir / MODEL_SAVE_NAME)

    # Back up the previous "last trained" checkpoint before overwriting it.
    if LAST_TRAINED_PATH.exists():
        backup_path = BACKUP_DIR / f"gpt_last_trained_backup_{int(os.path.getmtime(LAST_TRAINED_PATH))}.pt"
        shutil.copy(LAST_TRAINED_PATH, backup_path)
        print(f"Backup of previous model created → {backup_path.name}")

    shutil.copy(final_dir / MODEL_SAVE_NAME, LAST_TRAINED_PATH)
    print(f"Last trained model saved → {LAST_TRAINED_PATH}")

    print("\nTRAINING COMPLETED! Model is ready:")
    print(f" • For chat/inference: {final_dir / MODEL_SAVE_NAME}")
    print(f" • For continued fine-tuning: {LAST_TRAINED_PATH}")


if __name__ == "__main__":
    if not RAW_PATH.exists():
        print(f"ERROR: File {RAW_PATH} not found")
        print("Place your text in datasets/dialogues_text.txt")
    else:
        train()
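

# A minimal sketch of reloading the fine-tuned weights elsewhere for
# inference (assumes the same GPTPyTorch architecture/config as above):
#
#     model = GPTPyTorch().to(device)
#     state = torch.load("build/fine_tuning_output/final/gpt_finetuned.pt",
#                        map_location=device, weights_only=True)
#     model.load_state_dict(state)
#     model.eval()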