|
|
print("Starting...")

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
# UTF-8 text corpus that is read, tokenized and split into train/val.
TXT_PATH = "data.txt"
# Fraction of the corpus (by character count, taken from the start) that is
# kept; values >= 1.0 keep the whole text.
DATA_PCT = 0.001
# Encoding name handed to tiktoken.get_encoding().
TOKENIZER_NAME = "gpt2"
# NOTE(review): not referenced in the code visible here; the vocabulary is
# always remapped to the tokens that occur in the corpus.
REDUCE_VOCAB = True
# Where the token-id remapping table is written with torch.save().
VOCAB_SAVE_PATH = "vocab_map.pt"

# ---------------------------------------------------------------------------
# Optimisation
# ---------------------------------------------------------------------------
EPOCHS = 25
# Batch size of both DataLoaders (one forward/backward pass).
MICRO_BATCH_SIZE = 1
# Number of micro-batches whose gradients are accumulated per optimizer step;
# each micro-batch loss is divided by this value.
GRAD_ACCUM_STEPS = 8
# Learning rate passed to AdamW.
LEARNING_RATE = 3e-4

# ---------------------------------------------------------------------------
# Model size
# ---------------------------------------------------------------------------
# Width of the token/position embeddings and of every layer.
D_MODEL = 256
# Number of Block layers stacked in the model.
N_LAYERS = 4
# Length of each training window and size of the positional embedding table.
MAX_SEQ_LEN = 1024

# ---------------------------------------------------------------------------
# Convolutions
# ---------------------------------------------------------------------------
# Kernel width of the depthwise convolution in LocalConv1D.
LOCAL_KERNEL_SIZE = 5
# Kernel length of the FFT-based GlobalConv1D.
GLOBAL_KERNEL_SIZE = 256
# Layer i gets a global convolution when i % USE_GLOBAL_EVERY_N_LAYERS == 0.
USE_GLOBAL_EVERY_N_LAYERS = 2

# FFT length used for each overlap-save segment in GlobalConv1D.
FFT_SIZE = 1024

# ---------------------------------------------------------------------------
# Checkpointing
# ---------------------------------------------------------------------------
# NOTE(review): not referenced in the code visible here; train() derives its
# own checkpoint filename from the parameter count.
SAVE_PATH = "model.pt"
# Write a checkpoint every N epochs; a falsy value disables periodic saves.
SAVE_N_EPOCHS = 1

# ---------------------------------------------------------------------------
# Runtime
# ---------------------------------------------------------------------------
# NOTE(review): not referenced in the code visible here; train() auto-detects
# the device (cuda, then mps, then cpu).
USE_DEVICE = "cuda"
# Enable autocast + GradScaler; only takes effect when running on CUDA.
USE_AMP = True
# NOTE(review): not referenced in the code visible here.
USE_ACTIVATION_CHECKPOINTING = False

# ---------------------------------------------------------------------------
# torch.compile
# ---------------------------------------------------------------------------
# Wrap the model with torch.compile(); only takes effect when running on CUDA.
COMPILE = False
# Forwarded as torch.compile(mode=...).
COMPILE_MODE = "reduce-overhead"
# Forwarded as torch.compile(backend=...).
COMPILE_BACKEND = "eager"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import contextlib
import os

# Kept ahead of `import torch` so the allocator setting is already in the
# environment before PyTorch initialises CUDA. Skipped on Windows
# (os.name == "nt"); setdefault() leaves a user-provided value untouched.
if os.name != "nt":
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import tiktoken
|
|
|
|
|
|
|
|
# On CUDA machines, allow TF32 for float32 matmuls and cuDNN convolutions.
# Nothing is changed when CUDA is unavailable.
if torch.cuda.is_available():
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Reserved ids at the bottom of the remapped vocabulary.
PAD_ID = 0  # ignored by the cross-entropy loss; fallback for unmapped tokens
SEP_ID = 1  # stored in the vocab file; not otherwise used in the visible code
EOS_ID = 2  # appended once at the very end of the token stream
# First id handed to a real corpus token (ids below OFFSET are reserved).
OFFSET = 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_dataset_vocab(txt_path, tokenizer, save_path):
    """Build and persist a compact vocabulary from the tokens in a corpus.

    The text at ``txt_path`` is read as UTF-8, truncated to its leading
    ``DATA_PCT`` fraction of characters when ``DATA_PCT < 1.0``, and encoded
    with ``tokenizer.encode``. Every distinct token id is assigned a new
    dense id starting at ``OFFSET`` (in ascending order of the original id),
    leaving the ids below ``OFFSET`` for the special tokens.

    Args:
        txt_path: Path of the text corpus.
        tokenizer: Object exposing ``encode(text)`` that returns token ids.
        save_path: Destination for the mapping, written with ``torch.save``.

    Returns:
        ``(used, id2new)`` where ``used`` is the sorted list of original
        token ids and ``id2new`` maps each of them to its new id.
    """
    # Context manager so the file handle is closed deterministically.
    with open(txt_path, "r", encoding="utf-8") as fh:
        text = fh.read()

    if DATA_PCT < 1.0:
        text = text[:int(len(text) * DATA_PCT)]

    token_ids = tokenizer.encode(text)
    used = sorted(set(token_ids))

    # Dense remapping: the k-th smallest original id becomes k + OFFSET.
    id2new = {tok: i + OFFSET for i, tok in enumerate(used)}

    torch.save({
        "used_tokens": used,
        "id2new": id2new,
        "PAD_ID": PAD_ID,
        "SEP_ID": SEP_ID,
        "EOS_ID": EOS_ID,
    }, save_path)

    print(f"[OK] Vocab size: {len(used) + OFFSET}")
    return used, id2new
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RemappedTextDataset(Dataset):
    """Sliding-window next-token dataset over a flat sequence of token ids.

    Item ``i`` is the pair ``(ids[i:i+max_len], ids[i+1:i+max_len+1])``,
    i.e. an input window and the same window shifted by one position.
    """

    def __init__(self, ids, max_len):
        """Store the token sequence and the window length.

        Args:
            ids: Sliceable sequence of integer token ids.
            max_len: Number of tokens in each input/target window.
        """
        self.ids = ids
        self.max_len = max_len

    def __len__(self):
        """Return the number of complete (input, target) windows.

        Window ``i`` needs ``ids[i + max_len]`` to exist, so valid starts are
        ``0 .. len(ids) - max_len - 1``, which is ``len(ids) - max_len``
        windows (0 when the sequence is too short).
        """
        return max(0, len(self.ids) - self.max_len)

    def __getitem__(self, i):
        """Return window ``i`` as two ``torch.long`` tensors ``(x, y)``."""
        x = self.ids[i:i + self.max_len]
        y = self.ids[i + 1:i + self.max_len + 1]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GlobalConv1D(nn.Module):
    """Per-channel causal convolution with a long learned kernel, via FFT.

    The convolution is evaluated segment by segment with the overlap-save
    scheme: each segment of ``fft_size`` samples is multiplied with the
    kernel in the frequency domain and only the samples unaffected by
    circular wrap-around are kept.
    """

    def __init__(self, d_model, kernel_size, fft_size):
        """Create the learnable kernel.

        Args:
            d_model: Number of channels; each has its own kernel row.
            kernel_size: Maximum number of kernel taps.
            fft_size: FFT length used for every segment.
        """
        super().__init__()
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        """Apply the causal convolution along the last axis.

        Args:
            x: Tensor unpacked as ``(B, C, T)``; the convolution runs over
                ``T``. ``C`` is expected to match ``d_model`` so the kernel
                broadcasts against it.

        Returns:
            Tensor with the same shape as ``x`` where
            ``out[..., t] = sum_m kernel[:, m] * x[..., t - m]`` and samples
            before ``t = 0`` are treated as zero.

        Raises:
            ValueError: If the effective kernel length exceeds ``fft_size``.
        """
        _, _, T = x.shape
        if T == 0:
            # Nothing to convolve; the output has the same (empty) shape.
            return x

        # Taps beyond the sequence length can never touch real samples.
        K = min(self.kernel_size, T)
        if K > self.fft_size:
            # The step below would be <= 0 and the segment loop would never
            # advance, so fail loudly instead of hanging.
            raise ValueError(
                f"effective kernel length {K} exceeds fft_size {self.fft_size}"
            )

        overlap = K - 1                  # samples each segment shares with the previous one
        block = self.fft_size - overlap  # new output samples produced per segment

        # Left-pad with zeros so the first outputs only see "past" zeros.
        x = F.pad(x, (overlap, 0))

        # rfft(n=...) zero-pads the kernel up to fft_size before transforming.
        k_f = torch.fft.rfft(self.kernel[:, :K], n=self.fft_size).unsqueeze(0)

        outs = []
        for pos in range(0, T, block):
            # A short final segment is zero-padded by rfft(n=...).
            seg = x[..., pos:pos + self.fft_size]
            y = torch.fft.irfft(
                torch.fft.rfft(seg, n=self.fft_size) * k_f,
                n=self.fft_size,
            )
            # The first `overlap` samples are contaminated by circular
            # wrap-around; keep only the valid tail.
            outs.append(y[..., overlap:overlap + block])

        # The last segment may overshoot the sequence; trim back to T.
        return torch.cat(outs, dim=-1)[..., :T]
|
|
|
|
|
|
|
|
class LocalConv1D(nn.Module):
    """Causal depthwise convolution followed by ReLU and a pointwise mix."""

    def __init__(self, d_model, k):
        """Build the two convolutions.

        Args:
            d_model: Number of input and output channels.
            k: Kernel width of the depthwise convolution.
        """
        super().__init__()
        self.k = k
        # groups=d_model: every channel is filtered independently.
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)
        # Kernel size 1: mixes channels at each position.
        self.pw = nn.Conv1d(d_model, d_model, 1)

    def forward(self, x):
        """Return the convolved tensor with the same length as ``x``.

        ``x`` is left-padded with ``k - 1`` zeros along its last axis, so the
        depthwise convolution at position ``t`` only sees positions ``<= t``
        and the output keeps the input length.
        """
        x = F.pad(x, (self.k - 1, 0))
        return self.pw(F.relu(self.dw(x)))
|
|
|
|
|
|
|
|
class Block(nn.Module):
    """Residual layer: local conv, optional global conv, then feed-forward.

    Each sub-layer is applied to a LayerNorm of its input and added back to
    the running activation (pre-norm residual connections).
    """

    def __init__(self, d_model, use_global):
        """Create the sub-layers.

        Args:
            d_model: Feature width of the activations.
            use_global: When true, a GlobalConv1D sub-layer (with its own
                LayerNorm) is inserted between the local conv and the MLP.
        """
        super().__init__()
        self.use_global = use_global

        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)

        # ln2 / global_conv only exist on layers that use the global branch.
        if use_global:
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)

        self.ln3 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )

    def forward(self, x):
        """Run the layer and return a tensor shaped like ``x``.

        The convolutions operate on axis 2, so the normalised activations
        are transposed (axes 1 and 2 swapped) on the way in and transposed
        back on the way out.
        """
        x = x + self.local(self.ln1(x).transpose(1, 2)).transpose(1, 2)
        if self.use_global:
            x = x + self.global_conv(self.ln2(x).transpose(1, 2)).transpose(1, 2)
        return x + self.ff(self.ln3(x))
|
|
|
|
|
|
|
|
class GCLM(nn.Module):
    """Convolutional language model: embeddings, a stack of Blocks, LM head."""

    def __init__(self, vocab):
        """Build the model.

        Args:
            vocab: Vocabulary size (rows of the embedding and output head).
        """
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        # Learned absolute position embeddings for positions 0..MAX_SEQ_LEN-1.
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)

        # Layer i carries a global convolution when
        # i % USE_GLOBAL_EVERY_N_LAYERS == 0 (so layer 0 always does).
        self.layers = nn.ModuleList([
            Block(D_MODEL, i % USE_GLOBAL_EVERY_N_LAYERS == 0)
            for i in range(N_LAYERS)
        ])

        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)

        # Weight tying: the output projection shares the embedding matrix.
        self.head.weight = self.emb.weight

    def forward(self, x):
        """Return next-token logits for a batch of token ids.

        Args:
            x: Integer tensor of token ids; axis 1 is the sequence axis.

        Returns:
            Logits with a trailing axis of size ``vocab``.

        Raises:
            ValueError: If the sequence is longer than ``MAX_SEQ_LEN``, which
                would index past the positional embedding table.
        """
        T = x.size(1)
        if T > MAX_SEQ_LEN:
            raise ValueError(
                f"sequence length {T} exceeds MAX_SEQ_LEN={MAX_SEQ_LEN}"
            )

        # Position embeddings broadcast over the batch axis.
        h = self.emb(x) + self.pos(torch.arange(T, device=x.device))
        for layer in self.layers:
            h = layer(h)
        return self.head(self.ln(h))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def format_params(num):
    """Format a parameter count with one decimal and a K/M/B suffix.

    Counts of at least one billion use ``B``, at least one million use
    ``M``, and everything else (including values below 1000) uses ``K``.
    The output is kept stable because it is embedded in checkpoint file
    names.
    """
    if num >= 1_000_000_000:
        return f"{num/1_000_000_000:.1f}B"
    if num >= 1_000_000:
        return f"{num/1_000_000:.1f}M"
    return f"{num/1_000:.1f}K"
|
|
|
|
|
@torch.no_grad()
def estimate_loss(model, dl, device, ctx, max_batches=50):
    """Estimate the mean cross-entropy loss over the first batches of ``dl``.

    Args:
        model: Module mapping a batch of token ids to logits.
        dl: Iterable yielding ``(x, y)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.
        ctx: Re-enterable context manager wrapped around each forward pass
            and loss computation (e.g. an autocast context or a null
            context).
        max_batches: Upper bound on the number of batches evaluated.

    Returns:
        The mean of the per-batch losses as a float, or ``0.0`` when ``dl``
        yields no batches. Targets equal to ``PAD_ID`` are ignored.
    """
    # Remember the current mode so evaluation does not leave the model in a
    # different state than it found it, even if a batch raises.
    was_training = model.training
    model.eval()
    losses = []
    try:
        for i, (x, y) in enumerate(dl):
            if i >= max_batches:
                break
            x, y = x.to(device), y.to(device)
            with ctx:
                logits = model(x)
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    y.reshape(-1),
                    ignore_index=PAD_ID,
                )
            losses.append(loss.item())
    finally:
        model.train(was_training)
    return sum(losses) / len(losses) if losses else 0.0
|
|
|
|
|
def _select_device():
    """Return "cuda", "mps" or "cpu", preferring them in that order."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _strip_compile_prefix(state_dict):
    """Return ``state_dict`` with any leading ``_orig_mod.`` removed from keys.

    Checkpoints written from a torch.compile-wrapped module carry that
    prefix; removing it lets them load into the plain module.
    """
    prefix = "_orig_mod."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def train():
    """Train the language model on ``TXT_PATH`` and write checkpoints.

    Steps: pick a device, build the compact vocabulary, tokenize and remap
    the corpus, split it 90/10 into train/validation, build the model
    (resuming from an existing checkpoint with the same derived file name),
    then run ``EPOCHS`` epochs with gradient accumulation, optional AMP and
    a validation estimate after each epoch.

    Raises:
        ValueError: If the training split is too short to yield a single
            window of ``MAX_SEQ_LEN`` tokens.
    """
    device = _select_device()
    print("[INFO] Device:", device)

    tok = tiktoken.get_encoding(TOKENIZER_NAME)

    used, id2new = build_dataset_vocab(TXT_PATH, tok, VOCAB_SAVE_PATH)
    vocab = len(used) + OFFSET

    print("[INFO] Loading and tokenizing text...")
    # Context manager so the file handle is closed deterministically.
    with open(TXT_PATH, "r", encoding="utf-8") as fh:
        text = fh.read()
    if DATA_PCT < 1.0:
        text = text[:int(len(text) * DATA_PCT)]

    raw_ids = tok.encode(text)

    # Remap to the compact ids; a single EOS closes the stream.
    ids = [id2new.get(i, PAD_ID) for i in raw_ids] + [EOS_ID]

    n = len(ids)
    split_idx = int(n * 0.9)
    train_ids = ids[:split_idx]
    val_ids = ids[split_idx:]

    print(f"[INFO] Tokens: {n} | Train: {len(train_ids)} | Val: {len(val_ids)}")

    train_ds = RemappedTextDataset(train_ids, MAX_SEQ_LEN)
    val_ds = RemappedTextDataset(val_ids, MAX_SEQ_LEN)

    if len(train_ds) == 0:
        # A shuffled DataLoader over an empty dataset fails with an opaque
        # sampler error; explain the actual cause instead.
        raise ValueError(
            f"training split of {len(train_ids)} tokens yields no windows of "
            f"MAX_SEQ_LEN={MAX_SEQ_LEN}; use more data or a shorter window"
        )

    train_dl = DataLoader(train_ds, batch_size=MICRO_BATCH_SIZE, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=MICRO_BATCH_SIZE, shuffle=False)

    model = GCLM(vocab).to(device)
    # Keep a handle on the plain module: checkpoints are always written from
    # it so their keys do not depend on whether torch.compile is enabled.
    raw_model = model

    num_params = sum(p.numel() for p in model.parameters())
    param_str = format_params(num_params)
    save_path = f"chatgclm_base_{param_str}.pt"
    print(f"[INFO] Model parameters: {num_params:,} ({param_str})")
    print(f"[INFO] Save path: {save_path}")

    if os.path.exists(save_path):
        # weights_only=True restricts unpickling to tensors/primitive
        # containers, which is all a state_dict holds.
        state = torch.load(save_path, map_location=device, weights_only=True)
        raw_model.load_state_dict(_strip_compile_prefix(state))
        print(f"[RESUME] Loaded existing checkpoint from {save_path}")

    if device == "cuda" and COMPILE:
        print("[INFO] Compiling model with torch.compile...")
        model = torch.compile(
            model,
            mode=COMPILE_MODE,
            fullgraph=False,
            backend=COMPILE_BACKEND,
        )

    opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)

    # Mixed precision is only enabled on CUDA; elsewhere use a no-op context.
    if device == "cuda" and USE_AMP:
        ctx = torch.amp.autocast(device)
        scaler = torch.amp.GradScaler(device)
    else:
        ctx = contextlib.nullcontext()
        scaler = None

    n_batches = len(train_dl)

    for ep in range(EPOCHS):
        print(f"\nEpoch {ep+1}/{EPOCHS}")
        opt.zero_grad(set_to_none=True)

        pbar = tqdm(train_dl, desc="Training")
        running_loss = 0.0

        for i, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)

            with ctx:
                logits = model(x)
                loss = loss_fn(logits.reshape(-1, vocab), y.reshape(-1))

            loss_val = loss.item()
            # Scale so the accumulated gradient averages over the group.
            loss = loss / GRAD_ACCUM_STEPS

            if scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # Step after every full accumulation group, and also on the last
            # batch of the epoch so a trailing partial group is applied
            # rather than being thrown away by the next zero_grad().
            is_last_batch = (i + 1) == n_batches
            if (i + 1) % GRAD_ACCUM_STEPS == 0 or is_last_batch:
                if scaler:
                    scaler.step(opt)
                    scaler.update()
                else:
                    opt.step()
                opt.zero_grad(set_to_none=True)

            # Exponential moving average of the unscaled loss for display;
            # the first batch seeds it.
            running_loss = 0.9 * running_loss + 0.1 * loss_val if running_loss > 0 else loss_val
            pbar.set_postfix(loss=f"{running_loss:.4f}")

        val_loss = estimate_loss(model, val_dl, device, ctx)
        print(f"Epoch {ep+1} finished. Train Loss: {running_loss:.4f} | Val Loss: {val_loss:.4f}")

        if SAVE_N_EPOCHS and (ep+1) % SAVE_N_EPOCHS == 0:
            torch.save(raw_model.state_dict(), save_path)
            print(f"[OK] Saved checkpoint to {save_path}")

    torch.save(raw_model.state_dict(), save_path)
    print("[DONE] Training complete.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    train()
|
|
|