|
|
import math |
|
|
import os |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from tqdm import tqdm |
|
|
import string |
|
|
import contextlib |
|
|
from model import ChatGCLM, MAX_SEQ_LEN |
|
|
|
|
|
# Process-wide runtime tuning, applied once at import time.
if os.name != "nt":
    # Non-Windows: ask the CUDA caching allocator for expandable segments.
    # setdefault leaves any value the user already exported untouched.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

if torch.cuda.is_available():
    # "high" lets float32 matmuls use TF32-class kernels; the two flags opt
    # matmul and cuDNN into TF32 explicitly (same order as before: matmul first).
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True
|
|
|
|
|
# --- Data selection -------------------------------------------------------
# True -> read *.txt files from "finetune", False -> from "data"
# (directory names are relative to the working directory).
FINETUNE = True
if FINETUNE:
    DATA_DIR = "finetune"
else:
    DATA_DIR = "data"
# Fraction of the tokenized corpus kept when 0 < DATA_PCT < 1; otherwise all of it.
DATA_PCT = 0.002

# --- Apple MPS overrides (sequence-length cap / max micro-batches per epoch) ---
MPS_SEQ_LEN = 512
MPS_STEPS_PER_EPOCH = 18

# --- CPU overrides (sequence-length cap / max micro-batches per epoch) -------
CPU_SEQ_LEN = 512
CPU_STEPS_PER_EPOCH = 48

# File that build_dataset_vocab writes the vocabulary metadata to.
VOCAB_SAVE_PATH = "vocab_map.pt"
|
|
|
|
|
# --- Optimization schedule ------------------------------------------------
EPOCHS = 100
# CUDA defaults. On MPS/CPU train() shrinks the micro-batch and raises the
# accumulation count so MICRO_BATCH_SIZE * GRAD_ACCUM_STEPS is preserved.
MICRO_BATCH_SIZE, GRAD_ACCUM_STEPS = 8, 4
STEPS_PER_EPOCH = 500  # upper bound on micro-batches consumed per epoch
LEARNING_RATE, MIN_LR = 5e-4, 1e-5  # AdamW learning rate and cosine-annealing floor

SAVE_N_EPOCHS = 1  # intended checkpoint interval, in epochs
|
|
|
|
|
# --- Character vocabulary -------------------------------------------------
# Ids 0..2 are reserved special tokens; the characters of string.printable
# (printable ASCII plus whitespace) occupy the ids after them.
PAD_ID, SEP_ID, EOS_ID = 0, 1, 2
OFFSET = EOS_ID + 1  # first id assigned to a real character (== 3)
CHARS = string.printable
VOCAB_SIZE = OFFSET + len(CHARS)
|
|
|
|
|
def encode(text):
    """Convert *text* into a list of token ids.

    Each character ``c`` that appears in ``CHARS`` becomes ``CHARS.index(c) + OFFSET``.
    Characters outside ``CHARS`` (anything that is not printable ASCII or
    whitespace) are silently dropped.
    """
    # Build the char -> id table once per call (|CHARS| entries) so each corpus
    # character costs one O(1) dict probe instead of two linear scans of CHARS.
    lookup = {}
    for idx, ch in enumerate(CHARS):
        # setdefault keeps the first occurrence, matching str.index semantics.
        lookup.setdefault(ch, idx + OFFSET)
    return [lookup[c] for c in text if c in lookup]
|
|
|
|
|
def decode(ids):
    """Turn token ids back into text.

    Ids below ``OFFSET`` (the reserved PAD/SEP/EOS slots) are skipped. Every other id
    maps to ``CHARS[id - OFFSET]``, so an id past the end of ``CHARS`` raises IndexError.
    """
    pieces = []
    for token in ids:
        if token < OFFSET:
            continue  # special token: contributes no text
        pieces.append(CHARS[token - OFFSET])
    return "".join(pieces)
|
|
|
|
|
def build_dataset_vocab(save_path):
    """Persist the vocabulary metadata with ``torch.save`` and return the vocabulary size.

    The saved dict holds, in this order: ``vocab_size``, ``PAD_ID``, ``SEP_ID``,
    ``EOS_ID`` and ``CHARS``.
    """
    metadata = dict(
        vocab_size=VOCAB_SIZE,
        PAD_ID=PAD_ID,
        SEP_ID=SEP_ID,
        EOS_ID=EOS_ID,
        CHARS=CHARS,
    )
    torch.save(metadata, save_path)
    return VOCAB_SIZE
|
|
|
|
|
class RemappedTextDataset(Dataset):
    """Consecutive, non-overlapping next-token windows over a flat list of token ids.

    Item ``i`` is ``(x, y)``: ``x`` holds ``ids[i*max_len : (i+1)*max_len]`` and ``y``
    is the same window shifted right by one token. Both are returned as ``torch.long``
    tensors. ``__len__`` counts only windows whose inputs and targets are fully
    available. The trailing remainder is dropped, so the ``PAD_ID`` padding in
    ``__getitem__`` is only reached for indices ``>= len(self)``.
    """

    def __init__(self, ids, max_len):
        self.ids = ids
        self.max_len = max_len

    def __len__(self):
        # The final token has no successor, so it can only ever serve as a target.
        usable = len(self.ids) - 1
        return max(0, usable // self.max_len)

    def __getitem__(self, i):
        width = self.max_len
        begin = i * width

        def pad(seq):
            # Right-pad a short slice with PAD_ID up to the window width.
            shortfall = width - len(seq)
            return seq + [PAD_ID] * shortfall if shortfall > 0 else seq

        # shift 0 -> model inputs, shift 1 -> next-token targets
        inputs, targets = (
            pad(self.ids[begin + shift : begin + shift + width]) for shift in (0, 1)
        )
        return torch.tensor(inputs, dtype=torch.long), torch.tensor(targets, dtype=torch.long)
|
|
|
|
|
def format_params(num):
    """Abbreviate a parameter count with one decimal, e.g. ``"12.3M"`` or ``"1.5B"``.

    Counts below one million are always expressed in thousands (``500`` -> ``"0.5K"``).
    The result becomes part of the checkpoint filename, so the format must stay stable.
    """
    for threshold, suffix in ((1_000_000_000, "B"), (1_000_000, "M")):
        if num >= threshold:
            return f"{num / threshold:.1f}{suffix}"
    return f"{num / 1_000:.1f}K"
|
|
|
|
|
@torch.no_grad()
def estimate_loss(model, dl, device, ctx, limit=50):
    """Return the mean loss of *model* over at most *limit* batches of *dl*.

    Each batch's loss is the token-mean cross-entropy with ``PAD_ID`` targets ignored.
    The result is the unweighted mean of those per-batch values, or ``0.0`` if *dl*
    yields no batches.

    Args:
        model: module whose output logits have the class dimension last.
        dl: iterable of ``(x, y)`` tensor pairs.
        device: device each batch is moved to.
        ctx: context manager wrapped around the forward pass (e.g. autocast or a
            nullcontext).
        limit: maximum number of batches to evaluate (default 50).

    The model runs in eval mode and is afterwards returned to the mode it had on
    entry, even if the forward pass raises.
    """
    was_training = model.training
    model.eval()
    losses = []
    try:
        for i, (x, y) in enumerate(dl):
            if i >= limit:
                break
            x, y = x.to(device), y.to(device)
            with ctx:
                logits = model(x)
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)), y.reshape(-1), ignore_index=PAD_ID
                )
            losses.append(loss.item())
    finally:
        # Restore the caller's mode rather than forcing train mode.
        model.train(was_training)
    return sum(losses) / len(losses) if losses else 0.0
|
|
|
|
|
def _read_corpus(data_dir):
    """Concatenate every ``*.txt`` file in *data_dir* (sorted by name) into one string.

    Each file's contents are followed by one newline. A file that fails to open or
    decode is reported and skipped instead of aborting the run.
    """
    target_files = sorted(f for f in os.listdir(data_dir) if f.endswith(".txt"))
    print(f"Loading {len(target_files)} text file(s) from {data_dir}...")
    pieces = []
    for f in target_files:
        fpath = os.path.join(data_dir, f)
        print(f" - Reading {f}...")
        try:
            with open(fpath, "r", encoding="utf-8") as file:
                content = file.read()
        except Exception as e:  # best-effort: one bad file should not kill the run
            print(f"Error reading {f}: {e}")
            continue
        pieces.append(content + "\n")
    # str.join is linear; repeated `+=` on a growing str can be quadratic.
    return "".join(pieces)


def _align_parallel_prefix(state_dict, wrapped):
    """Return *state_dict* with a leading ``module.`` on its keys iff *wrapped* is true.

    ``nn.DataParallel`` registers the inner model as ``module``, so its state-dict keys
    carry that prefix. Only a *leading* prefix is added or removed, so a submodule that
    happens to be named ``module`` deeper in the tree is left intact. An empty
    state dict passes through unchanged.
    """
    prefix = "module."
    has_prefix = next(iter(state_dict), "").startswith(prefix)
    if wrapped and not has_prefix:
        return {f"{prefix}{k}": v for k, v in state_dict.items()}
    if not wrapped and has_prefix:
        return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
    return state_dict


def _save_checkpoint(state_dict, path):
    """Write *state_dict* to *path* via a temporary sibling file and ``os.replace``.

    If the process dies mid-write, only the ``.tmp`` file is damaged. The previous
    checkpoint at *path* (which ``train`` resumes from) is left as it was.
    """
    tmp_path = f"{path}.tmp"
    torch.save(state_dict, tmp_path)
    os.replace(tmp_path, path)


def train():
    """Train, or resume training, ChatGCLM on the ``*.txt`` corpus in ``DATA_DIR``.

    1. Pick a device (CUDA, then MPS, then CPU). On MPS/CPU, shrink the micro-batch,
       sequence length and steps per epoch, and raise gradient accumulation to keep
       the effective batch at ``MICRO_BATCH_SIZE * GRAD_ACCUM_STEPS``.
    2. Save the vocabulary metadata. Read and tokenize the corpus, optionally keep
       only its leading ``DATA_PCT`` fraction, and split it 95/5 into train and
       validation windows.
    3. Build the model (wrapped in ``nn.DataParallel`` when more than one GPU is
       visible). Load weights from ``Turing_<params>.pt`` if that file exists and is
       non-empty. Only weights are restored; optimizer and LR-schedule state start
       fresh.
    4. Run ``EPOCHS`` epochs of AdamW with gradient accumulation, AMP on CUDA, a
       per-epoch cosine LR schedule and a validation pass. Save a checkpoint every
       ``SAVE_N_EPOCHS`` epochs and after the final epoch.

    Raises:
        ValueError: if the training split is too short to form a single window.
    """
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

    effective_batch_target = MICRO_BATCH_SIZE * GRAD_ACCUM_STEPS
    micro_batch_size = MICRO_BATCH_SIZE
    grad_accum_steps = GRAD_ACCUM_STEPS
    train_seq_len = MAX_SEQ_LEN
    steps_per_epoch = STEPS_PER_EPOCH

    # Smaller per-step memory on MPS/CPU; extra accumulation restores the batch size.
    if device == "mps":
        if hasattr(torch, "mps"):
            torch.mps.empty_cache()
        micro_batch_size = 1
        grad_accum_steps = max(1, math.ceil(effective_batch_target / micro_batch_size))
        train_seq_len = min(MAX_SEQ_LEN, MPS_SEQ_LEN)
        steps_per_epoch = min(STEPS_PER_EPOCH, MPS_STEPS_PER_EPOCH)
    elif device == "cpu":
        micro_batch_size = min(4, MICRO_BATCH_SIZE)
        grad_accum_steps = max(1, math.ceil(effective_batch_target / micro_batch_size))
        train_seq_len = min(MAX_SEQ_LEN, CPU_SEQ_LEN)
        steps_per_epoch = min(STEPS_PER_EPOCH, CPU_STEPS_PER_EPOCH)

    steps_per_epoch = max(1, steps_per_epoch)
    effective_batch_size = micro_batch_size * grad_accum_steps
    vocab = build_dataset_vocab(VOCAB_SAVE_PATH)

    full_text = _read_corpus(DATA_DIR)
    print(f"Total dataset size: {len(full_text):,} characters")
    ids = encode(full_text) + [EOS_ID]

    train_fraction = 0.95
    if 0 < DATA_PCT < 1.0:
        # Floor the subset so the train split (the first 95%) still holds at least
        # one full (train_seq_len + 1)-token window. A floor applied without regard
        # to the split leaves the train set empty on small corpora. The trailing +1
        # absorbs float rounding in int(n * train_fraction) below.
        min_tokens = max(MAX_SEQ_LEN + 1, math.ceil((train_seq_len + 1) / train_fraction) + 1)
        target_tokens = max(min_tokens, int(len(ids) * DATA_PCT))
        ids = ids[:target_tokens]
        print(f"Using {DATA_PCT*100:.2f}% of tokens -> {len(ids):,} tokens")
    else:
        print(f"Tokenized dataset -> {len(ids):,} tokens")

    split_idx = int(len(ids) * train_fraction)
    train_ds = RemappedTextDataset(ids[:split_idx], train_seq_len)
    val_ds = RemappedTextDataset(ids[split_idx:], train_seq_len)
    if len(train_ds) == 0:
        raise ValueError(
            f"Training split has {split_idx:,} tokens, fewer than the {train_seq_len + 1} "
            f"needed for one training window; add data to {DATA_DIR!r}."
        )
    if len(val_ds) == 0:
        print("Warning: validation split is shorter than one window; Val Loss will read 0.0000.")

    kwargs = {'num_workers': 4, 'pin_memory': True} if device == "cuda" else {}
    train_dl = DataLoader(train_ds, batch_size=micro_batch_size, shuffle=True, **kwargs)
    val_dl = DataLoader(val_ds, batch_size=micro_batch_size, shuffle=False, **kwargs)

    model = ChatGCLM(vocab).to(device)
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)

    num_params = sum(p.numel() for p in model.parameters())
    param_str = format_params(num_params)
    save_path = f"Turing_{param_str}.pt"

    print("-" * 30)
    print("Turing TRAINING START")
    print(f"Model ID: {save_path}")
    print(f"Parameters: {num_params:,}")
    print(f"Device: {device}")
    print(f"Vocab Size: {vocab}")
    print(f"Learning Rate: {LEARNING_RATE}")
    print(f"Micro Batch: {micro_batch_size}")
    print(f"Grad Accum: {grad_accum_steps}")
    print(f"Effective Batch: {effective_batch_size}")
    print(f"Train Seq: {train_seq_len}")
    print(f"Epoch Steps: {steps_per_epoch}")
    print(f"Epochs: {EPOCHS}")
    print("-" * 30)

    if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
        print(f" Found checkpoint at {save_path}, loading...")
        state_dict = torch.load(save_path, map_location=device)
        state_dict = _align_parallel_prefix(state_dict, isinstance(model, nn.DataParallel))
        model.load_state_dict(state_dict)
        print(" Model weights loaded successfully! Resuming training.")
    else:
        print(" No checkpoint found. Starting training from scratch.")

    opt_kwargs = {"lr": LEARNING_RATE}
    if device == "cuda":
        opt_kwargs["fused"] = True
    opt = torch.optim.AdamW(model.parameters(), **opt_kwargs)
    # Stepped once per epoch (end of the epoch loop), so the cosine spans EPOCHS.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS, eta_min=MIN_LR)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
    if device == "cuda":
        ctx = torch.amp.autocast(device_type="cuda")
        scaler = torch.amp.GradScaler()
    else:
        ctx = contextlib.nullcontext()
        scaler = None

    save_every = max(1, SAVE_N_EPOCHS)
    for ep in range(EPOCHS):
        model.train()
        opt.zero_grad(set_to_none=True)
        total_steps = min(len(train_dl), steps_per_epoch)
        pbar = tqdm(train_dl, desc=f"Epoch {ep+1}/{EPOCHS}", total=total_steps)
        running_loss = 0.0
        for step_idx, (x, y) in enumerate(pbar):
            if step_idx >= total_steps:
                break
            x, y = x.to(device), y.to(device)
            # Micro-batches in the accumulation window that contains this step. Only
            # the epoch's final window can be short. Every loss in a window is divided
            # by that window's true length, so the accumulated gradient is an exact mean.
            window_start = step_idx - step_idx % grad_accum_steps
            window_len = min(grad_accum_steps, total_steps - window_start)
            with ctx:
                logits = model(x)
                loss = loss_fn(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            loss_val = loss.item()
            loss = loss / window_len
            if scaler is not None:
                scaler.scale(loss).backward()
            else:
                loss.backward()
            if step_idx + 1 == window_start + window_len:
                if scaler is not None:
                    scaler.step(opt)
                    scaler.update()
                else:
                    opt.step()
                opt.zero_grad(set_to_none=True)
                if device == "mps" and hasattr(torch, "mps"):
                    torch.mps.empty_cache()
            # Exponential moving average of the unscaled micro-batch loss, for display.
            running_loss = 0.9 * running_loss + 0.1 * loss_val if running_loss > 0 else loss_val
            pbar.set_postfix(loss=f"{running_loss:.4f}")
        val_loss = estimate_loss(model, val_dl, device, ctx)
        current_lr = scheduler.get_last_lr()[0]
        print(f"Epoch {ep+1} | Train Loss: {running_loss:.4f} | Val Loss: {val_loss:.4f} | LR: {current_lr:.6f}")
        if (ep + 1) % save_every == 0 or ep + 1 == EPOCHS:
            _save_checkpoint(model.state_dict(), save_path)
            print(f" Model saved successfully after epoch {ep+1} to {save_path}")
        scheduler.step()
|
|
|
|
|
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train()
|
|
|