Upload fine_tune_jit_with_validation_1b.py
fine_tune_jit_with_validation_1b.py
CHANGED
@@ -1,14 +1,8 @@
-# Copyright (c) 2025 CMS Manhattan
-# All rights reserved.
-#
-# This file is part of a project authored by CMS Manhattan. You may use, distribute, and modify
-# this code under the terms of the APACHE 2.0 license.
-
 import os
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from torch.utils.data import
+from torch.utils.data import IterableDataset, DataLoader
 from transformers import GPT2TokenizerFast
 from tqdm import tqdm
 import shutil

@@ -66,55 +60,53 @@ DATASET_PATH = CLEAN_PATH
 OUTPUT_DIR = Path("build/fine_tuning_output")
 MODEL_SAVE_NAME = "gpt_finetuned.script.pt"
 
-device = torch.device("
+device = torch.device("cpu")
 print(f"Using device: {device}")
 
-# ============================= DATASET =============================
+# ============================= DATASET (LAZY) =============================
 
-class
+class LazyTextDataset(IterableDataset):
+    """Lazy memory-efficient dataset, splits on-the-fly into train and val."""
     def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, tokenizer_name="gpt2", split_type='train', val_ratio=VAL_SPLIT_RATIO):
         self.seq_len = seq_len
         self.tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_name)
         self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.text_file = text_file
         self.split_type = split_type
-        if self.split_type == 'train':
-            self.inputs = all_inputs[:train_size]
-            self.labels = all_labels[:train_size]
-        elif self.split_type == 'val':
-            self.inputs = all_inputs[train_size:]
-            self.labels = all_labels[train_size:]
+        self.val_ratio = val_ratio
+
+        print(f"Loading and tokenizing text from {text_file}")
+        with open(text_file, "r", encoding="utf-8") as f:
+            self.data = f.read()
+        self.tokens = self.tokenizer.encode(self.data)
+
+        # Work out split indices
+        total_tokens = len(self.tokens) - 1  # because label sequence shifted
+        total_batches = total_tokens // seq_len
+        val_size = int(total_batches * self.val_ratio)
+        train_size = total_batches - val_size
+        if split_type == 'train':
+            self.start = 0
+            self.stop = train_size
+        elif split_type == 'val':
+            self.start = train_size
+            self.stop = train_size + val_size
         else:
-            raise ValueError("
-
-        print(f"
+            raise ValueError(f"split_type should be 'train' or 'val', got {split_type}")
+        self.total_sequences = self.stop - self.start
+        print(f"Lazy dataset: {self.total_sequences:,} sequences for {split_type} split (from {total_batches:,} total)")
+
+    def __iter__(self):
+        for i in range(self.start * self.seq_len, self.stop * self.seq_len, self.seq_len):
+            # Make sure last batch fits
+            if i + self.seq_len + 1 > len(self.tokens):
+                break
+            input_seq = torch.tensor(self.tokens[i : i + self.seq_len], dtype=torch.long)
+            label_seq = torch.tensor(self.tokens[i + 1 : i + self.seq_len + 1], dtype=torch.long)
+            yield input_seq, label_seq
 
     def __len__(self):
-        return
-
-    def __getitem__(self, idx):
-        return (
-            torch.tensor(self.inputs[idx], dtype=torch.long),
-            torch.tensor(self.labels[idx], dtype=torch.long)
-        )
+        return self.total_sequences
 
 # ============================= GET LOGITS UTIL =============================
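The new LazyTextDataset tokenizes the corpus once and then yields fixed-length (input, label) windows on demand, with the train/val boundary expressed as sequence indices instead of materialized lists. A minimal sketch of how the split could be exercised, assuming the class as defined above and a small UTF-8 text file at the hypothetical path data/clean.txt:

# Sketch only: the path and the 10% validation ratio are illustrative values.
train_ds = LazyTextDataset("data/clean.txt", seq_len=128, split_type="train", val_ratio=0.1)
val_ds = LazyTextDataset("data/clean.txt", seq_len=128, split_type="val", val_ratio=0.1)

# Each yielded item is an (input_seq, label_seq) pair; labels are the inputs shifted by one token.
inputs, labels = next(iter(train_ds))
assert inputs.shape == (128,) and labels.shape == (128,)

# __len__ reports sequences per split, so the two splits together cover the token stream.
print(len(train_ds), len(val_ds))

Because the class is an IterableDataset, indexing such as train_ds[0] is not available; consumers iterate it, typically through a DataLoader as done later in train().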
@@ -130,6 +122,7 @@ def get_logits_from_model(model, inputs):
 def evaluate(model, dataloader, criterion, device):
     model.eval()
     total_loss = 0.0
+    count = 0
     with torch.no_grad():
         for inputs, targets in dataloader:
             inputs, targets = inputs.to(device), targets.to(device)
 
@@ -138,7 +131,8 @@ def evaluate(model, dataloader, criterion, device):
             targets = targets.contiguous().view(-1)[:logits.shape[0]]
             loss = criterion(logits, targets)
             total_loss += loss.item()
-
+            count += 1
+    avg_loss = total_loss / max(count, 1)
     model.train()
     return avg_loss
 
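With the batch counter added, evaluate() returns a well-defined average even when the validation loader is short. The returned loss converts to perplexity the same way the training progress bar does; a minimal sketch, assuming model, val_dataloader, criterion, and device are already set up as in train():

# Sketch: turn the averaged validation loss into perplexity.
val_loss = evaluate(model, val_dataloader, criterion, device)
val_ppl = math.exp(min(val_loss, 10))  # clamp the exponent, as the training loop does
print(f"[VALIDATION] loss: {val_loss:.3f} | PPL: {val_ppl:.1f}")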
@@ -185,16 +179,17 @@ def train():
     except AttributeError:
         print("⚠️ Warning: model.gradient_checkpointing_enable() not found on JIT model. Training will proceed without GC.")
 
-    train_dataset =
-    val_dataset =
+    train_dataset = LazyTextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, split_type='train', val_ratio=VAL_SPLIT_RATIO)
+    val_dataset = LazyTextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, split_type='val', val_ratio=VAL_SPLIT_RATIO)
 
+    # IterableDataset: must use drop_last=True and shuffle=False, num_workers=0 on CPU
+    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, num_workers=0)
+    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, num_workers=0)
 
     optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
     criterion = nn.CrossEntropyLoss()
 
-    total_steps = len(
+    total_steps = (len(train_dataset) // BATCH_SIZE) * EPOCHS
     print(f"\n=== BEGINNING LONG-TERM TRAINING ===")
     print(f"Epochs: {EPOCHS} | Steps (Train): {total_steps} | Examples (Train): {len(train_dataset)}")
     print(f"Batch Size (Effective): {BATCH_SIZE} | Precision: FP32")
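The num_workers=0 noted in the comment is not incidental: with an IterableDataset, every DataLoader worker runs the same __iter__, so multiple workers would feed duplicate sequences. If multi-worker loading were ever wanted, the iterator would have to shard itself; one possible variant of LazyTextDataset.__iter__, sketched here as an assumption rather than something the script does, uses torch.utils.data.get_worker_info():

# Hypothetical sharded replacement for LazyTextDataset.__iter__ (for num_workers > 0).
def __iter__(self):
    worker = torch.utils.data.get_worker_info()
    start, step = self.start, 1
    if worker is not None:
        start = self.start + worker.id   # each worker begins at a different offset
        step = worker.num_workers        # and strides over the shared index range
    for seq_idx in range(start, self.stop, step):
        i = seq_idx * self.seq_len
        if i + self.seq_len + 1 > len(self.tokens):
            break
        yield (torch.tensor(self.tokens[i : i + self.seq_len], dtype=torch.long),
               torch.tensor(self.tokens[i + 1 : i + self.seq_len + 1], dtype=torch.long))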
@@ -223,10 +218,10 @@ def train():
             pbar.set_postfix({
                 "loss": f"{loss_val:.3f}",
                 "ppl": f"{math.exp(min(loss_val, 10)):.1f}",
-                "step": f"{global_step}
+                "step": f"{global_step}"
             })
 
-        avg_train_loss = epoch_loss / len(
+        avg_train_loss = epoch_loss / max(1, len(train_dataset) // BATCH_SIZE)
         print(f" [TRAIN] Average loss: {avg_train_loss:.3f} | PPL: {math.exp(avg_train_loss):.1f}")
 
         print(" [VALIDATION] Starting evaluation...")
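The epoch loop then hands val_dataloader to evaluate(). The checkpointing that follows is outside the hunks shown here; a sketch of the usual best-checkpoint pattern, assuming a best_val_loss variable initialized to float("inf") before the epoch loop and a model loaded as a TorchScript module:

# Assumed continuation (not part of the hunks above): keep the best validation checkpoint.
val_loss = evaluate(model, val_dataloader, criterion, device)
print(f" [VALIDATION] Average loss: {val_loss:.3f} | PPL: {math.exp(val_loss):.1f}")
if val_loss < best_val_loss:
    best_val_loss = val_loss
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    torch.jit.save(model, str(OUTPUT_DIR / MODEL_SAVE_NAME))  # model is a ScriptModule here
    print(f"Saved new best checkpoint to {OUTPUT_DIR / MODEL_SAVE_NAME}")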