import os
import torch
from torch import nn
from torch.optim import AdamW
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tokenizers import Tokenizer, models, trainers, pre_tokenizers
import math
# =========================
# Juicy variables
# =========================
DATA_PATH = "dataset_clean.txt" # one text per line
VOCAB_LIMIT = None # None = all tokens, or int = cap vocab
MODEL_DIM = 256
NUM_LAYERS = 6
NUM_HEADS = 4
FF_DIM = 1024
SEQ_LEN = 128
BATCH_SIZE = 64
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 0.01
WARMUP_STEPS = 50
MAX_STEPS = 100
TEMPERATURE = 0.05
OPTIMIZER = "adamw" # "adamw" or "muon"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def estimate_params(vocab_size, model_dim, ff_dim, num_layers, seq_len):
    # Embedding + positional
    emb_params = vocab_size * model_dim
    pos_params = seq_len * model_dim
    # Per-layer Transformer block
    # Attention projections (Q, K, V, O): 4 * d^2
    attn_params = 4 * (model_dim ** 2)
    # Feed-forward (two linear layers): 2 * d * ff_dim
    ff_params = 2 * model_dim * ff_dim
    # LayerNorms add ~2 * d each, negligible compared to the above
    per_layer = attn_params + ff_params
    # Multiply by number of layers
    encoder_params = num_layers * per_layer
    total = emb_params + pos_params + encoder_params
    return {
        "embeddings": emb_params,
        "positional": pos_params,
        "encoder_layers": encoder_params,
        "total": total,
    }
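# Worked example of the estimate (hypothetical 10,000-token vocab, config above):
#   embeddings: 10,000 * 256 = 2,560,000
#   positional:    128 * 256 =    32,768
#   per layer:  4 * 256^2 + 2 * 256 * 1024 = 262,144 + 524,288 = 786,432
#   6 layers:   4,718,592
#   total ~ 7,311,360 (biases and LayerNorm weights excluded, so the real
#   count from PyTorch will run slightly higher)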
# -------------------------
# Build tokenizer from dataset
# -------------------------
def build_tokenizer(data_path, vocab_limit=None):
    tokenizer = Tokenizer(models.WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
    trainer_kwargs = {
        "min_frequency": 1,
        "special_tokens": ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"],
    }
    if vocab_limit is not None:
        trainer_kwargs["vocab_size"] = vocab_limit
    trainer = trainers.WordLevelTrainer(**trainer_kwargs)
    with open(data_path, "r", encoding="utf-8") as f:
        lines = [line.strip() for line in f if line.strip()]
    tokenizer.train_from_iterator(lines, trainer=trainer)
    os.makedirs("tokenizer", exist_ok=True)
    tokenizer.save("tokenizer/tokenizer.json")
    return tokenizer
tokenizer = build_tokenizer(DATA_PATH, VOCAB_LIMIT)
VOCAB_SIZE = tokenizer.get_vocab_size()
print(f"[INFO] Custom vocab size: {VOCAB_SIZE}")
est = estimate_params(VOCAB_SIZE, MODEL_DIM, FF_DIM, NUM_LAYERS, SEQ_LEN)
print("Parameter estimate:")
for k, v in est.items():
    print(f"{k:15}: {v:,}")
# -------------------------
# Dataset wrapper
# -------------------------
class TextDataset(Dataset):
    def __init__(self, path, tokenizer, seq_len):
        with open(path, "r", encoding="utf-8") as f:
            self.lines = [line.strip() for line in f if line.strip()]
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.pad_id = self.tokenizer.token_to_id("[PAD]")

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        tokens = self.tokenizer.encode(self.lines[idx]).ids
        # truncate to seq_len, then right-pad with [PAD]
        tokens = tokens[:self.seq_len]
        tokens += [self.pad_id] * (self.seq_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)
dataset = TextDataset(DATA_PATH, tokenizer, SEQ_LEN)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
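# Quick shape check (optional sketch): the first batch should be a LongTensor
# of shape (BATCH_SIZE, SEQ_LEN); the last batch of an epoch may be smaller.
#   sample = next(iter(loader))
#   print(sample.shape, sample.dtype)  # e.g. torch.Size([64, 128]) torch.int64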
# -------------------------
# Transformer Encoder
# -------------------------
class TransformerEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(VOCAB_SIZE, MODEL_DIM)
        self.pos_emb = nn.Embedding(SEQ_LEN, MODEL_DIM)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=MODEL_DIM,
            nhead=NUM_HEADS,
            dim_feedforward=FF_DIM,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=NUM_LAYERS)
        self.norm = nn.LayerNorm(MODEL_DIM)

    def forward(self, x):
        positions = torch.arange(0, x.size(1), device=x.device).unsqueeze(0)
        h = self.token_emb(x) + self.pos_emb(positions)
        h = self.encoder(h)
        h = self.norm(h)
        # mean pool over all positions, PAD included (see the masked variant below)
        return h.mean(dim=1)
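# Optional refinement (a sketch, not wired in below): exclude [PAD] positions
# from both attention and the mean pool via src_key_padding_mask. pad_id is
# assumed to be tokenizer.token_to_id("[PAD]"), i.e. dataset.pad_id.
def masked_mean_forward(model, x, pad_id):
    positions = torch.arange(0, x.size(1), device=x.device).unsqueeze(0)
    h = model.token_emb(x) + model.pos_emb(positions)
    pad_mask = x.eq(pad_id)                       # True where x is padding
    h = model.norm(model.encoder(h, src_key_padding_mask=pad_mask))
    keep = (~pad_mask).unsqueeze(-1).float()      # 1.0 at real tokens
    return (h * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)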
# -------------------------
# Contrastive loss
# -------------------------
def contrastive_loss(z1, z2, temperature=TEMPERATURE):
    # InfoNCE: row i's positive is column i (the paired view of the same text);
    # every other row in the batch serves as a negative.
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    logits = z1 @ z2.t() / temperature
    labels = torch.arange(z1.size(0), device=z1.device)
    return F.cross_entropy(logits, labels)
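# A common symmetric variant (optional sketch, CLIP-style, not used below):
# score both matching directions and average the two cross-entropies.
def symmetric_contrastive_loss(z1, z2, temperature=TEMPERATURE):
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    logits = z1 @ z2.t() / temperature
    labels = torch.arange(z1.size(0), device=z1.device)
    return 0.5 * (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels))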
# -------------------------
# Setup
# -------------------------
model = TransformerEncoder().to(DEVICE)
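# Sanity check (optional): compare the rough estimate above with the exact
# count reported by PyTorch; the exact number is slightly higher because the
# estimate ignores biases and LayerNorm weights.
actual_params = sum(p.numel() for p in model.parameters())
print(f"[INFO] Exact parameter count: {actual_params:,} (estimate: {est['total']:,})")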
if OPTIMIZER == "adamw":
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
elif OPTIMIZER == "muon":
    from muon import Muon
    optimizer = Muon(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
else:
    raise ValueError(f"Invalid optimizer: {OPTIMIZER!r}")

# Linear warmup for WARMUP_STEPS, then cosine decay to zero at MAX_STEPS
def lr_lambda(step):
    if step < WARMUP_STEPS:
        return float(step) / float(max(1, WARMUP_STEPS))
    progress = float(step - WARMUP_STEPS) / float(max(1, MAX_STEPS - WARMUP_STEPS))
    return 0.5 * (1.0 + math.cos(math.pi * progress))
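# Shape of the schedule with WARMUP_STEPS=50, MAX_STEPS=100 (worked values of
# the multiplier, each applied to LEARNING_RATE):
#   step 0 -> 0.00, step 25 -> 0.50, step 50 -> 1.00 (warmup peak),
#   step 75 -> 0.50, step 100 -> 0.00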
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# -------------------------
# Training loop
# -------------------------
step = 0
while step < MAX_STEPS:
    for batch in loader:
        if step >= MAX_STEPS:
            break
        x = batch.to(DEVICE)
        # Two stochastic views of the same batch: in train mode the encoder's
        # dropout (0.1 by default in nn.TransformerEncoderLayer) makes the two
        # forward passes differ. Swap in stronger augmentation if you want.
        z1 = model(x)
        z2 = model(x)
        loss = contrastive_loss(z1, z2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        if step % 100 == 0:
            print(f"Step {step}: loss={loss.item():.4f}, lr={scheduler.get_last_lr()[0]:.6f}")
        step += 1
print("[DONE] Training complete")
print("[INFO] Saving model...")
torch.save(model.state_dict(), "ckpt.pt")
print("[DONE] Model saved to ckpt.pt")