import math
import os

import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset

from tokenizers import Tokenizer, models, trainers, pre_tokenizers
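
# Hyperparameters for the tokenizer, the encoder, and the contrastive training run.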
DATA_PATH = "dataset_clean.txt"

VOCAB_LIMIT = None
MODEL_DIM = 256
NUM_LAYERS = 6
NUM_HEADS = 4
FF_DIM = 1024
SEQ_LEN = 128

BATCH_SIZE = 64
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 0.01
WARMUP_STEPS = 50
MAX_STEPS = 100
TEMPERATURE = 0.05

OPTIMIZER = "adamw"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
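

# Rough parameter estimate: the embedding tables plus, per encoder layer, the four
# attention projection matrices (4 * d_model^2) and the two feed-forward matrices
# (2 * d_model * ff_dim). Biases and LayerNorm weights are ignored, so this
# slightly undercounts the real model.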
def estimate_params(vocab_size, model_dim, ff_dim, num_layers, seq_len):
    emb_params = vocab_size * model_dim
    pos_params = seq_len * model_dim

    attn_params = 4 * (model_dim ** 2)
    ff_params = 2 * model_dim * ff_dim
    per_layer = attn_params + ff_params
    encoder_params = num_layers * per_layer

    total = emb_params + pos_params + encoder_params
    return {
        "embeddings": emb_params,
        "positional": pos_params,
        "encoder_layers": encoder_params,
        "total": total
    }
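

# Train a word-level tokenizer on the corpus and save it to tokenizer/tokenizer.json.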
def build_tokenizer(data_path, vocab_limit=None):
    tokenizer = Tokenizer(models.WordLevel(unk_token="[UNK]"))
    if vocab_limit is not None:
        trainer = trainers.WordLevelTrainer(
            vocab_size=vocab_limit,
            min_frequency=1,
            special_tokens=["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"]
        )
    else:
        trainer = trainers.WordLevelTrainer(
            min_frequency=1,
            special_tokens=["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"]
        )
    tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

    with open(data_path, "r", encoding="utf-8") as f:
        lines = [line.strip() for line in f if line.strip()]

    tokenizer.train_from_iterator(lines, trainer=trainer)
    os.makedirs("tokenizer", exist_ok=True)
    tokenizer.save("tokenizer/tokenizer.json")
    return tokenizer
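

# Build the vocabulary and report the estimated parameter count for this configuration.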
tokenizer = build_tokenizer(DATA_PATH, VOCAB_LIMIT)
VOCAB_SIZE = tokenizer.get_vocab_size()
print(f"[INFO] Custom vocab size: {VOCAB_SIZE}")

est = estimate_params(VOCAB_SIZE, MODEL_DIM, FF_DIM, NUM_LAYERS, SEQ_LEN)
print("Parameter estimate:")
for k, v in est.items():
    print(f"{k:15}: {v:,}")
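

# Each example is a single line of text, truncated and padded to SEQ_LEN token ids.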
class TextDataset(Dataset):
    def __init__(self, path, tokenizer, seq_len):
        with open(path, "r", encoding="utf-8") as f:
            self.lines = [line.strip() for line in f if line.strip()]
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.pad_id = self.tokenizer.token_to_id("[PAD]")

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        tokens = self.tokenizer.encode(self.lines[idx]).ids
        tokens = tokens[:self.seq_len]
        tokens += [self.pad_id] * (self.seq_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)


dataset = TextDataset(DATA_PATH, tokenizer, SEQ_LEN)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
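

# Small Transformer encoder: learned token and position embeddings, a stack of
# standard encoder layers, and mean pooling over the sequence to produce one
# vector per input. No attention mask is passed, so padding positions are
# included in the pooled average.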
class TransformerEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(VOCAB_SIZE, MODEL_DIM)
        self.pos_emb = nn.Embedding(SEQ_LEN, MODEL_DIM)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=MODEL_DIM,
            nhead=NUM_HEADS,
            dim_feedforward=FF_DIM,
            activation="gelu",
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=NUM_LAYERS)
        self.norm = nn.LayerNorm(MODEL_DIM)

    def forward(self, x):
        positions = torch.arange(0, x.size(1), device=x.device).unsqueeze(0)
        h = self.token_emb(x) + self.pos_emb(positions)
        h = self.encoder(h)
        h = self.norm(h)
        return h.mean(dim=1)
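

# InfoNCE-style loss: the i-th row of z1 should be most similar to the i-th row
# of z2, and every other row in the batch serves as a negative.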
def contrastive_loss(z1, z2, temperature=TEMPERATURE):
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    logits = z1 @ z2.t() / temperature
    labels = torch.arange(z1.size(0), device=z1.device)
    return F.cross_entropy(logits, labels)
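

# Build the encoder and the selected optimizer.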
model = TransformerEncoder().to(DEVICE)

if OPTIMIZER == "adamw":
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
elif OPTIMIZER == "muon":
    # Optional dependency; some Muon implementations expect the 2-D weight
    # matrices to be passed separately from the other parameters, so this
    # AdamW-style call may need adjusting for the installed version.
    from muon import Muon
    optimizer = Muon(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
else:
    raise ValueError("Invalid optimizer")
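

# Linear warmup for WARMUP_STEPS steps, then cosine decay to zero by MAX_STEPS.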
def lr_lambda(step):
    if step < WARMUP_STEPS:
        return float(step) / float(max(1, WARMUP_STEPS))
    progress = float(step - WARMUP_STEPS) / float(max(1, MAX_STEPS - WARMUP_STEPS))
    return 0.5 * (1.0 + math.cos(math.pi * progress))


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
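

# SimCSE-style training loop: the same batch is encoded twice, and the two views
# differ only through dropout inside the encoder layers, which makes the diagonal
# pairs the positives for the contrastive loss.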
model.train()  # dropout must be active so the two forward passes differ

step = 0
while step < MAX_STEPS:
    for batch in loader:
        if step >= MAX_STEPS:
            break

        x = batch.to(DEVICE)

        z1 = model(x)
        z2 = model(x)

        loss = contrastive_loss(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if step % 100 == 0:
            print(f"Step {step}: loss={loss.item():.4f}, lr={scheduler.get_last_lr()[0]:.6f}")

        step += 1

print("[DONE] Training complete")
print("[INFO] Saving model...")
torch.save(model.state_dict(), "ckpt.pt")
print("[DONE] Model saved to ckpt.pt")