| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import sys |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import Dataset, DataLoader |
| from transformers import GPT2TokenizerFast |
| from tqdm import tqdm |
| import shutil |
| import math |
| from pathlib import Path |
| import re |
| import logging |
| from torch.amp import GradScaler, autocast |
|
|
| |
# Only ERROR and above are emitted by the transformers tokenization logger.
# NOTE(review): presumably this hides the over-length-sequence warning raised
# when the whole corpus is encoded in one call — confirm.
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
| |
|
|
| |
| from gpt_jit_modern_1b import JiRackPyTorch |
|
|
| |
| |
# --- Training hyper-parameters -------------------------------------------
TRAIN_SEQ_LEN = 64       # tokens per sample (default seq_len of TextDataset)
BATCH_SIZE = 1           # micro-batch size handed to the DataLoaders
ACCUM_STEPS = 32         # micro-batches accumulated before each optimizer step
EPOCHS = 500             # number of passes over the training split
LEARNING_RATE = 1e-6     # AdamW learning rate
WEIGHT_DECAY = 0.01      # AdamW weight decay
GRAD_CLIP = 1.0          # max gradient norm passed to clip_grad_norm_
VAL_SPLIT_RATIO = 0.05   # fraction of sequences held out for validation
KEEP_LAST_EPOCHS = 3     # number of per-epoch checkpoint directories to keep
| |
|
|
| |
# --- Device and mixed-precision selection --------------------------------
# The device is always either "cuda" or "cpu", so the AMP settings follow
# directly from that choice: FP16 autocast on GPU, plain FP32 on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

USE_AMP = device.type == 'cuda'
AUTOCAST_DTYPE = torch.float16 if USE_AMP else torch.float32

if USE_AMP:
    print(f"Using device: {device} (GPU). AMP (FP16) enabled for efficiency.")
else:
    print(f"Using device: {device} (CPU). WARNING: Training 1.2B model on CPU will be extremely slow.")
|
|
| |
# --- Model checkpoint locations ------------------------------------------
# NOTE(review): BASE_MODEL_PATH is not read anywhere in the visible code
# (training currently starts from random weights) — confirm it is still needed.
BASE_MODEL_PATH = Path("models/gpt_modern_1b_class.state_dict.pt")
# Overwritten with the latest weights at the end of every epoch.
LAST_TRAINED_PATH = Path("models/gpt_last_modern_1b_class.state_dict.pt")
BACKUP_DIR = Path("models/backups")
# Created eagerly at import time; this also creates the parent "models" dir.
BACKUP_DIR.mkdir(exist_ok=True, parents=True)
|
|
# --- Dataset locations and one-off cleaning ------------------------------
RAW_PATH = Path("datasets/dialogues_text.txt")
CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")

# Stat the raw file up front so that a missing raw file is handled in one
# place instead of surfacing as an uncaught error from the freshness check.
try:
    _raw_mtime = RAW_PATH.stat().st_mtime
except FileNotFoundError:
    _raw_mtime = None

if _raw_mtime is None:
    if not CLEAN_PATH.exists():
        # Neither raw nor cleaned data is available: nothing to train on.
        print(f"ERROR: Raw dataset not found at {RAW_PATH}")
        sys.exit(1)
    # A previously cleaned copy exists; training only reads that copy.
    print(f"Raw dataset not found at {RAW_PATH}; using existing {CLEAN_PATH}")
elif not CLEAN_PATH.exists() or _raw_mtime > CLEAN_PATH.stat().st_mtime:
    # (Re)build the cleaned file when it is missing or older than the raw one.
    print("Cleaning dataset...")
    try:
        text = RAW_PATH.read_text(encoding="utf-8")
    except FileNotFoundError:
        # Raw file vanished between the stat above and the read.
        print(f"ERROR: Raw dataset not found at {RAW_PATH}")
        sys.exit(1)
    text = re.sub(r' {2,}', ' ', text)                       # collapse runs of spaces
    text = text.replace(" \n", "\n").replace("\n ", "\n")   # trim spaces around newlines
    CLEAN_PATH.write_text(text, encoding="utf-8")
    print(f"Done → {CLEAN_PATH}")

DATASET_PATH = CLEAN_PATH                       # file the datasets are built from
OUTPUT_DIR = Path("build/fine_tuning_output")   # root of the per-epoch checkpoints
MODEL_SAVE_NAME = "pytorch_model.bin"           # weights file name inside each epoch dir
|
|
| |
class TextDataset(Dataset):
    """Next-token-prediction samples cut from a single text file.

    The file is tokenized in one pass with a GPT-2 tokenizer and sliced into
    windows of ``seq_len + 1`` tokens whose start positions are ``seq_len``
    apart (so neighbouring windows share one token). The first
    ``1 - VAL_SPLIT_RATIO`` share of the windows forms the training split;
    the remainder forms the validation split.

    Args:
        text_file: Path of the UTF-8 text file to tokenize.
        seq_len: Length of each input/target sequence returned by
            ``__getitem__``.
        split: ``'train'`` selects the training windows; any other value
            selects the validation windows.
    """

    def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, split='train'):
        self.seq_len = seq_len
        try:
            tokenizer = GPT2TokenizerFast.from_pretrained("./tokenizer", local_files_only=True)
        except Exception as exc:
            # Deliberate best-effort: any failure to load the local tokenizer
            # falls back to the stock "gpt2" one, but no longer silently.
            print(f"Could not load tokenizer from ./tokenizer ({exc}); falling back to 'gpt2'.")
            tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

        tokenizer.pad_token = tokenizer.eos_token

        text = Path(text_file).read_text(encoding="utf-8")
        tokens = tokenizer.encode(text)

        # Stopping at len(tokens) - seq_len guarantees every window is full
        # length (seq_len + 1 tokens); a shorter tail is dropped.
        sequences = [
            tokens[start:start + seq_len + 1]
            for start in range(0, len(tokens) - seq_len, seq_len)
        ]

        split_idx = int(len(sequences) * (1 - VAL_SPLIT_RATIO))
        self.data = sequences[:split_idx] if split == 'train' else sequences[split_idx:]

        print(f"{split.upper()} sequences: {len(self.data):,}")

    def __len__(self):
        """Return the number of windows in this split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(inputs, targets)`` long tensors; targets are inputs shifted by one token."""
        seq = self.data[idx]
        inputs = torch.tensor(seq[:-1], dtype=torch.long)
        targets = torch.tensor(seq[1:], dtype=torch.long)
        return inputs, targets
|
|
|
|
def evaluate(model, loader):
    """Return the mean per-batch cross-entropy loss of *model* over *loader*.

    The model is switched to eval mode for the duration of the call and is
    returned to whatever mode it was in beforehand, even if a batch raises.

    Args:
        model: Callable module returning logits, or a tuple whose first
            element is the logits.
        loader: Sized iterable yielding ``(inputs, targets)`` tensor pairs.

    Returns:
        The average of the per-batch losses as a float, or ``nan`` when
        *loader* contains no batches.
    """
    num_batches = len(loader)
    if num_batches == 0:
        # Nothing to evaluate; avoid dividing by zero below.
        return float("nan")

    was_training = model.training
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0

    try:
        with torch.no_grad(), autocast(device_type=device.type, enabled=USE_AMP, dtype=AUTOCAST_DTYPE):
            for x, y in loader:
                x, y = x.to(device), y.to(device)

                logits = model(x)
                if isinstance(logits, tuple):
                    logits = logits[0]

                # Flatten to (positions, vocab) / (positions,) for the loss.
                input_logits = logits.contiguous().view(-1, logits.size(-1))
                # Targets are truncated to the number of logit rows.
                # NOTE(review): this can mask a length mismatch between model
                # output and targets — confirm it is intended.
                target_labels = y.contiguous().view(-1)[:input_logits.size(0)]

                # The loss is computed on float32 logits.
                loss = criterion(input_logits.float(), target_labels)
                total_loss += loss.item()
    finally:
        model.train(was_training)

    return total_loss / num_batches
|
|
|
|
| def train(): |
| OUTPUT_DIR.mkdir(parents=True, exist_ok=True) |
|
|
| print("Loading model...") |
| model = JiRackPyTorch().to(device) |
| |
| |
| scaler = GradScaler(enabled=USE_AMP, device=device.type) |
|
|
| |
| |
| |
| print("Starting from scratch — random weights (Skipping state_dict load for stability test!)") |
| |
|
|
| model.train() |
|
|
| train_dataset = TextDataset(DATASET_PATH, split='train') |
| val_dataset = TextDataset(DATASET_PATH, split='val') |
|
|
| train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0) |
| val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=0) |
|
|
| optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY) |
| criterion = nn.CrossEntropyLoss() |
|
|
| print("\nFULL TRAINING STARTED! No LoRA, no compromises — we're training the whole thing!\n") |
| print(f"Batch size: {BATCH_SIZE * ACCUM_STEPS} (effective) | LR: {LEARNING_RATE} | AMP: {USE_AMP} ({AUTOCAST_DTYPE})") |
|
|
| for epoch in range(1, EPOCHS + 1): |
| total_loss = 0 |
| pbar = tqdm(train_loader, desc=f"Epoch {epoch} [TRAIN]") |
| |
| for step, (x, y) in enumerate(pbar): |
| x, y = x.to(device), y.to(device) |
|
|
| |
| with autocast(device_type=device.type, enabled=USE_AMP, dtype=AUTOCAST_DTYPE): |
| |
| logits = model(x) |
| if isinstance(logits, tuple): |
| logits = logits[0] |
| |
| input_logits = logits.contiguous().view(-1, logits.size(-1)) |
| target_labels = y.contiguous().view(-1)[:input_logits.size(0)] |
|
|
| loss = criterion(input_logits.float(), target_labels) |
| loss = loss / ACCUM_STEPS |
| |
| |
| if torch.isnan(loss).any(): |
| print(f"\n[FATAL ERROR] Loss became NaN at step {step}. Stopping training.") |
| raise RuntimeError("Loss became NaN during training, stopping.") |
|
|
| |
| scaler.scale(loss).backward() |
|
|
| total_loss += loss.item() * ACCUM_STEPS |
|
|
| if (step + 1) % ACCUM_STEPS == 0 or (step + 1) == len(train_loader): |
| |
| if USE_AMP: |
| scaler.unscale_(optimizer) |
| |
| torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP) |
| scaler.step(optimizer) |
| scaler.update() |
| optimizer.zero_grad() |
| |
| |
| current_avg_loss = total_loss / (step + 1) |
| ppl_val = math.exp(min(current_avg_loss, 10)) |
| pbar.set_postfix({"loss (avg)": f"{current_avg_loss:.4f}", "ppl": f"{ppl_val:.2f}"}) |
|
|
|
|
| avg_train_loss = total_loss / len(train_loader) |
| val_loss = evaluate(model, val_loader) |
|
|
| print(f"\nEpoch {epoch}") |
| print(f" Train loss: {avg_train_loss:.4f} | PPL: {math.exp(avg_train_loss):.2f}") |
| print(f" Val loss: {val_loss:.4f} | PPL: {math.exp(val_loss):.2f}") |
|
|
| |
| save_dir = OUTPUT_DIR / f"epoch_{epoch}" |
| save_dir.mkdir(exist_ok=True, parents=True) |
| |
| torch.save(model.state_dict(), save_dir / MODEL_SAVE_NAME) |
| torch.save(model.state_dict(), LAST_TRAINED_PATH) |
|
|
| |
| epochs_dirs = sorted([p for p in OUTPUT_DIR.iterdir() if p.is_dir() and p.name.startswith("epoch_")]) |
| for old in epochs_dirs[:-KEEP_LAST_EPOCHS]: |
| shutil.rmtree(old) |
|
|
| print("\nDONE! Full model trained. You are now the emperor of fine-tuning.") |
|
|
|
|
# Script entry point: run training and turn the two known failure modes into
# actionable messages with a non-zero exit code. Any other RuntimeError is
# re-raised unchanged.
if __name__ == "__main__":
    try:
        train()
    except RuntimeError as e:
        message = str(e)
        if "Loss became NaN" in message:
            print("\n[CRITICAL FAILURE] Training stopped due to NaN loss.")
            # Report the real setting rather than a hard-coded value.
            print(
                "Action: Revisit JiRackPyTorch weight initialization (reduce STD further) "
                f"or reduce LEARNING_RATE (currently {LEARNING_RATE})."
            )
            sys.exit(1)
        if "CUDA out of memory" in message:
            print("\n[CRITICAL FAILURE] CUDA Out of Memory.")
            print(
                f"Action: Current configuration is BATCH_SIZE={BATCH_SIZE}, AMP={USE_AMP}. "
                f"Try reducing TRAIN_SEQ_LEN (currently {TRAIN_SEQ_LEN})."
            )
            sys.exit(1)
        raise