import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2TokenizerFast
from tqdm import tqdm
import shutil
import math
from pathlib import Path
import re
from typing import Optional, List, Tuple

# --- Training hyperparameters ---
TRAIN_SEQ_LEN = 256
BATCH_SIZE = 12
EPOCHS = 50
LEARNING_RATE = 6e-6
WEIGHT_DECAY = 0.01
GRAD_CLIP = 1.0
KEEP_LAST_EPOCHS = 3
VAL_SPLIT_RATIO = 0.05
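# With these settings each optimizer step processes BATCH_SIZE * TRAIN_SEQ_LEN
# = 12 * 256 = 3,072 tokens.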

# --- Model paths ---
BASE_MODEL_PATH = Path("models/JiRack_H12_L6_V50257_D768_MSL8192_FF768x4.script.pt")
LAST_TRAINED_PATH = Path("models/JiRack_last_H12_L6_V50257_D768_MSL8192_FF768x4.script.pt")
BACKUP_DIR = Path("models/backups")
BACKUP_DIR.mkdir(parents=True, exist_ok=True)

# --- Dataset paths ---
RAW_PATH = Path("datasets/dialogues_text.txt")
CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")

# Decide whether the raw dataset needs (re-)cleaning.
force_clean = False
if not CLEAN_PATH.exists():
    print("Cleaned dataset not found. Performing initial cleaning...")
    force_clean = True
else:
    try:
        if RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime:
            print("Detected changes in the raw dataset. Re-cleaning...")
            force_clean = True
        else:
            print(f"Using existing cleaned dataset → {CLEAN_PATH}")
    except FileNotFoundError:
        print("Could not compare file timestamps. Re-cleaning for safety...")
        force_clean = True

if force_clean:
    if not RAW_PATH.exists():
        raise FileNotFoundError(f"ERROR: Source file {RAW_PATH} not found. Check the path.")

    print("Cleaning the dataset (removing wrong separators and extra spaces)...")
    text = RAW_PATH.read_text(encoding="utf-8")

    # Collapse runs of spaces and strip spaces around newlines.
    text = re.sub(r' {2,}', ' ', text)
    text = text.replace(" \n", "\n").replace("\n ", "\n")

    CLEAN_PATH.write_text(text, encoding="utf-8")
    print(f"Dataset successfully cleaned and saved → {CLEAN_PATH}")

DATASET_PATH = CLEAN_PATH

OUTPUT_DIR = Path("build/fine_tuning_output")
MODEL_SAVE_NAME = "gpt_finetuned.script.pt"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# --- Dataset ---
class TextDataset(Dataset):
    def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, tokenizer_name="gpt2",
                 split_type='train', val_ratio=VAL_SPLIT_RATIO):
        self.seq_len = seq_len
        self.tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.split_type = split_type

        print(f"Loading text from {text_file} for {split_type} split...")
        text = Path(text_file).read_text(encoding="utf-8")
        tokens = self.tokenizer.encode(text)

        if len(tokens) < seq_len * 2:
            raise ValueError("Text too short!")

        all_inputs = []
        all_labels = []

        for i in range(0, len(tokens) - seq_len, seq_len):
            all_inputs.append(tokens[i:i + seq_len])
            all_labels.append(tokens[i + 1:i + seq_len + 1])

        total_sequences = len(all_inputs)
        val_size = int(total_sequences * val_ratio)
        train_size = total_sequences - val_size

        if self.split_type == 'train':
            self.inputs = all_inputs[:train_size]
            self.labels = all_labels[:train_size]
        elif self.split_type == 'val':
            self.inputs = all_inputs[train_size:]
            self.labels = all_labels[train_size:]
        else:
            raise ValueError("Invalid split_type. Must be 'train' or 'val'.")

        print(f"Created {len(self.inputs):,} sequences for {self.split_type} split.")

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return (torch.tensor(self.inputs[idx], dtype=torch.long),
                torch.tensor(self.labels[idx], dtype=torch.long))
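
# Note: chunks are non-overlapping (stride == seq_len), labels are the inputs
# shifted by one token, and the train/val split is chronological: the last
# val_ratio fraction of chunks (the tail of the corpus) is held out for
# validation rather than a random sample.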


# --- Helpers ---
def get_logits_from_model(model, inputs):
    """
    Call the model and return logits only, handling models that return a
    (logits, new_kv) tuple as well as JIT models that return bare logits.
    """
    out = model(inputs)
    if isinstance(out, tuple):
        # Model returned (logits, new_kv_cache); keep only the logits.
        return out[0]
    # Model returned logits directly.
    return out


def evaluate(model, dataloader, criterion, device):
    """Evaluates the model on the validation dataset."""
    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            logits = get_logits_from_model(model, inputs)
            loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
            total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    model.train()
    return avg_loss
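
# Note: evaluate() averages the per-batch loss; because the validation loader
# is built with drop_last=True, the validation split must yield at least one
# full batch (BATCH_SIZE sequences) or the division above raises ZeroDivisionError.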


def cleanup_old_epochs(keep_last=KEEP_LAST_EPOCHS):
    """Keep only the checkpoints of the most recent `keep_last` epochs."""
    epochs = sorted([p for p in OUTPUT_DIR.glob("epoch*") if p.is_dir()],
                    key=lambda x: int(x.name.replace("epoch", "")))
    for old in epochs[:-keep_last]:
        if old.exists():
            shutil.rmtree(old)
            print(f"Old epoch deleted: {old.name}")


# --- Training ---
def train():
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    print("Loading model...")
    model = None

    # Resume from the last fine-tuned checkpoint if present, otherwise start
    # from the base JIT model.
    if LAST_TRAINED_PATH.exists():
        print(f"Continuing training from last JIT model: {LAST_TRAINED_PATH}")
        model = torch.jit.load(LAST_TRAINED_PATH, map_location=device)
    elif BASE_MODEL_PATH.exists():
        print(f"Starting from base JIT model: {BASE_MODEL_PATH}")
        model = torch.jit.load(BASE_MODEL_PATH, map_location=device)
    else:
        print("ERROR: JIT model not found. Cannot start training without a JIT model file.")
        return

    model.train()

    # Datasets and loaders.
    train_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, split_type='train', val_ratio=VAL_SPLIT_RATIO)
    val_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, split_type='val', val_ratio=VAL_SPLIT_RATIO)

    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

    # Optimizer and loss.
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    criterion = nn.CrossEntropyLoss()

    total_steps = len(train_dataloader) * EPOCHS
    print("\n=== BEGINNING LONG-TERM TRAINING ===")
    print(f"Epochs: {EPOCHS} | Steps (Train): {total_steps} | Examples (Train): {len(train_dataset)}")

    global_step = 0
    for epoch in range(1, EPOCHS + 1):
        print(f"\n--- Epoch {epoch}/{EPOCHS} ---")
        epoch_loss = 0.0

        # Training pass over the epoch.
        with tqdm(train_dataloader, desc=f"Epoch {epoch} [TRAIN]", leave=False) as pbar:
            for inputs, targets in pbar:
                inputs, targets = inputs.to(device), targets.to(device)

                optimizer.zero_grad()

                # Forward pass (handles both model output formats).
                logits = get_logits_from_model(model, inputs)

                loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                optimizer.step()

                loss_val = loss.item()
                epoch_loss += loss_val
                global_step += 1

                pbar.set_postfix({
                    "loss": f"{loss_val:.3f}",
                    "ppl": f"{math.exp(min(loss_val, 10)):.1f}",
                    "step": f"{global_step}/{total_steps}"
                })

        avg_train_loss = epoch_loss / len(train_dataloader)
        print(f" [TRAIN] Average loss: {avg_train_loss:.3f} | PPL: {math.exp(avg_train_loss):.1f}")

        # Validation after each epoch.
        print(" [VALIDATION] Starting evaluation...")
        val_loss = evaluate(model, val_dataloader, criterion, device)
        print(f" [VALIDATION] Average loss: {val_loss:.3f} | PPL: {math.exp(val_loss):.1f}")

        # Save this epoch's checkpoint and prune old ones.
        epoch_dir = OUTPUT_DIR / f"epoch{epoch}"
        epoch_dir.mkdir(exist_ok=True)

        model.save(str(epoch_dir / MODEL_SAVE_NAME))
        print(f"Model saved: {epoch_dir / MODEL_SAVE_NAME}")
        cleanup_old_epochs()

    # Final export: model + tokenizer.
    final_dir = OUTPUT_DIR / "final"
    final_dir.mkdir(exist_ok=True)
    model.save(str(final_dir / MODEL_SAVE_NAME))
    train_dataset.tokenizer.save_pretrained(final_dir)

    # Back up the previous "last trained" model before overwriting it.
    if LAST_TRAINED_PATH.exists():
        backup_path = BACKUP_DIR / f"gpt_last_trained_backup_{int(os.path.getmtime(LAST_TRAINED_PATH))}.script.pt"
        shutil.copy(LAST_TRAINED_PATH, backup_path)
        print(f"Backup of previous model created → {backup_path.name}")

    shutil.copy(final_dir / MODEL_SAVE_NAME, LAST_TRAINED_PATH)
    print(f"Last trained model saved → {LAST_TRAINED_PATH}")

    print("\nTRAINING COMPLETED! Model ready:")
    print(f" • For chat: {final_dir / MODEL_SAVE_NAME}")
    print(f" • For further fine-tuning: {LAST_TRAINED_PATH}")


if __name__ == "__main__":
    if not RAW_PATH.exists():
        print(f"ERROR: No file {RAW_PATH}")
        print("Put your text into datasets/dialogues_text.txt")
    else:
        train()
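
# ---------------------------------------------------------------------------
# Usage sketch (commented out, not part of training): loading the fine-tuned
# model for greedy generation. This is a minimal sketch, assuming the exported
# TorchScript model accepts a [batch, seq] LongTensor of token ids and returns
# logits (or a (logits, kv) tuple), i.e. the same contract that
# get_logits_from_model() relies on above.
#
#   model = torch.jit.load("build/fine_tuning_output/final/gpt_finetuned.script.pt",
#                          map_location=device).eval()
#   tokenizer = GPT2TokenizerFast.from_pretrained("build/fine_tuning_output/final")
#   ids = torch.tensor([tokenizer.encode("Hello, how are you?")], device=device)
#   with torch.no_grad():
#       for _ in range(50):
#           logits = get_logits_from_model(model, ids)
#           next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
#           ids = torch.cat([ids, next_id], dim=1)
#   print(tokenizer.decode(ids[0]))
# ---------------------------------------------------------------------------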