import os
import torch
from torch import nn
from torch.optim import AdamW
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tokenizers import Tokenizer, models, trainers, pre_tokenizers
import math

# =========================
# Juicy variables
# =========================
DATA_PATH = "dataset_clean.txt"  # one text per line
VOCAB_LIMIT = None               # None = all tokens, or int = cap vocab
MODEL_DIM = 256
NUM_LAYERS = 6
NUM_HEADS = 4
FF_DIM = 1024
SEQ_LEN = 128
BATCH_SIZE = 64
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 0.01
WARMUP_STEPS = 50
MAX_STEPS = 100
TEMPERATURE = 0.05
OPTIMIZER = "adamw"  # "adamw" or "muon"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def estimate_params(vocab_size, model_dim, ff_dim, num_layers, seq_len):
    # Embedding + positional
    emb_params = vocab_size * model_dim
    pos_params = seq_len * model_dim

    # Per-layer Transformer block
    # Attention projections (Q, K, V, O): 4 * d^2
    attn_params = 4 * (model_dim ** 2)
    # Feed-forward (two linear layers): 2 * d * ff_dim
    ff_params = 2 * model_dim * ff_dim
    # LayerNorms ~2 * d, negligible compared to above
    per_layer = attn_params + ff_params

    # Multiply by number of layers
    encoder_params = num_layers * per_layer

    total = emb_params + pos_params + encoder_params
    return {
        "embeddings": emb_params,
        "positional": pos_params,
        "encoder_layers": encoder_params,
        "total": total,
    }


# -------------------------
# Build tokenizer from dataset
# -------------------------
def build_tokenizer(data_path, vocab_limit=None):
    tokenizer = Tokenizer(models.WordLevel(unk_token="[UNK]"))
    if vocab_limit is not None:
        trainer = trainers.WordLevelTrainer(
            vocab_size=vocab_limit,
            min_frequency=1,
            special_tokens=["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"],
        )
    else:
        trainer = trainers.WordLevelTrainer(
            min_frequency=1,
            special_tokens=["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"],
        )
    tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

    with open(data_path, "r", encoding="utf-8") as f:
        lines = [line.strip() for line in f if line.strip()]

    tokenizer.train_from_iterator(lines, trainer=trainer)

    os.makedirs("tokenizer", exist_ok=True)
    tokenizer.save("tokenizer/tokenizer.json")
    return tokenizer


tokenizer = build_tokenizer(DATA_PATH, VOCAB_LIMIT)
VOCAB_SIZE = tokenizer.get_vocab_size()
print(f"[INFO] Custom vocab size: {VOCAB_SIZE}")

est = estimate_params(VOCAB_SIZE, MODEL_DIM, FF_DIM, NUM_LAYERS, SEQ_LEN)
print("Parameter estimate:")
for k, v in est.items():
    print(f"{k:15}: {v:,}")


# -------------------------
# Dataset wrapper
# -------------------------
class TextDataset(Dataset):
    def __init__(self, path, tokenizer, seq_len):
        with open(path, "r", encoding="utf-8") as f:
            self.lines = [line.strip() for line in f if line.strip()]
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.pad_id = self.tokenizer.token_to_id("[PAD]")

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        tokens = self.tokenizer.encode(self.lines[idx]).ids
        # pad / truncate to a fixed length
        tokens = tokens[:self.seq_len]
        tokens += [self.pad_id] * (self.seq_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)


dataset = TextDataset(DATA_PATH, tokenizer, SEQ_LEN)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)


# -------------------------
# Transformer Encoder
# -------------------------
class TransformerEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(VOCAB_SIZE, MODEL_DIM)
        self.pos_emb = nn.Embedding(SEQ_LEN, MODEL_DIM)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=MODEL_DIM,
            nhead=NUM_HEADS,
            dim_feedforward=FF_DIM,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=NUM_LAYERS)
        self.norm = nn.LayerNorm(MODEL_DIM)

    def forward(self, x):
        positions = torch.arange(0, x.size(1), device=x.device).unsqueeze(0)
        h = self.token_emb(x) + self.pos_emb(positions)
        h = self.encoder(h)
        h = self.norm(h)
        # Pooled embedding: mean over the sequence dimension
        # (note: [PAD] positions are included in the average).
        return h.mean(dim=1)


# -------------------------
# Contrastive loss
# -------------------------
def contrastive_loss(z1, z2, temperature=TEMPERATURE):
    # In-batch InfoNCE: the i-th row of z1 should match the i-th row of z2.
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    logits = z1 @ z2.t() / temperature
    labels = torch.arange(z1.size(0), device=z1.device)
    return F.cross_entropy(logits, labels)


# -------------------------
# Setup
# -------------------------
model = TransformerEncoder().to(DEVICE)
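# Optional sanity check (illustrative addition, not required for training):
# compare the rough analytic figure from estimate_params() with the exact
# count PyTorch reports. The two will differ somewhat because the estimate
# ignores biases and LayerNorm parameters.
actual_params = sum(p.numel() for p in model.parameters())
print(f"[INFO] Exact parameter count: {actual_params:,} (estimate: {est['total']:,})")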
if OPTIMIZER == "adamw":
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
elif OPTIMIZER == "muon":
    from muon import Muon
    optimizer = Muon(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
else:
    raise ValueError("Invalid optimizer")


def lr_lambda(step):
    # Linear warmup followed by cosine decay to zero at MAX_STEPS.
    if step < WARMUP_STEPS:
        return float(step) / float(max(1, WARMUP_STEPS))
    progress = float(step - WARMUP_STEPS) / float(max(1, MAX_STEPS - WARMUP_STEPS))
    return 0.5 * (1.0 + math.cos(math.pi * progress))


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


# -------------------------
# Training loop
# -------------------------
step = 0
while step < MAX_STEPS:
    for batch in loader:
        if step >= MAX_STEPS:
            break
        x = batch.to(DEVICE)

        # "Augment": both views come from the same batch and differ only through
        # the encoder's dropout (SimCSE-style). Replace with explicit dropout/noise
        # augmentation if you want stronger views.
        z1 = model(x)
        z2 = model(x)
        loss = contrastive_loss(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if step % 100 == 0:
            print(f"Step {step}: loss={loss.item():.4f}, lr={scheduler.get_last_lr()[0]:.6f}")
        step += 1

print("[DONE] Training complete")
print("[INFO] Saving model...")
torch.save(model.state_dict(), "ckpt.pt")
print("[DONE] Model saved to ckpt.pt")
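
# -------------------------
# Optional usage sketch (illustrative, not part of the training loop above):
# reload the saved checkpoint and embed a single line of text. "hello world"
# is just a placeholder input; any whitespace-tokenizable string works.
# -------------------------
encoder = TransformerEncoder().to(DEVICE)
encoder.load_state_dict(torch.load("ckpt.pt", map_location=DEVICE))
encoder.eval()

pad_id = tokenizer.token_to_id("[PAD]")
ids = tokenizer.encode("hello world").ids[:SEQ_LEN]
ids += [pad_id] * (SEQ_LEN - len(ids))

with torch.no_grad():
    embedding = encoder(torch.tensor([ids], dtype=torch.long, device=DEVICE))
print(f"[INFO] Example embedding shape: {tuple(embedding.shape)}")  # (1, MODEL_DIM)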