# codsworth-3.8m / codsworth/scripts/train_full.py
# Jaqshanahan's picture
# Initial upload of Codsworth model
# b84d85a verified
"""
╔══════════════════════════════════════════════════════════════════════════════╗
β•‘ CODSWORTH TRAINING SCRIPT β•‘
β•‘ Transformer Language Model from Scratch β•‘
β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
"""
import sys
sys.path.insert(0, '.')
import json
import glob
import torch
from datetime import datetime
# ANSI SGR escape sequences used to style terminal output.
GREEN = "\x1b[92m"   # bright green foreground
YELLOW = "\x1b[93m"  # bright yellow foreground
RED = "\x1b[91m"     # bright red foreground
BLUE = "\x1b[94m"    # bright blue foreground
CYAN = "\x1b[96m"    # bright cyan foreground
BOLD = "\x1b[1m"     # bold / increased intensity
RESET = "\x1b[0m"    # clear every attribute set above
def color_print(text, color=GREEN):
    """Print ``text`` wrapped in an ANSI colour code.

    Args:
        text: Value to print; it is interpolated into an f-string.
        color: Escape sequence emitted before ``text`` (defaults to ``GREEN``).
            ``RESET`` is always emitted after ``text`` so the styling does not
            leak into later output.
    """
    styled = f"{color}{text}{RESET}"
    print(styled)
def header_print(text):
    """Print ``text`` as a bold cyan section banner.

    The banner is a 60-character ``=`` rule, the text centred in 60 columns,
    and a second rule, with one blank line before and one after.
    """
    style = f"{BOLD}{CYAN}"
    rule = f"{style}{'=' * 60}{RESET}"
    print(f"\n{rule}")
    print(f"{style}{text:^60}{RESET}")
    print(f"{rule}\n")
from codsworth.config import CodsworthConfig
from codsworth.model import CodsworthTransformer
from codsworth.utils import setup_logging, set_seed, get_device, AverageMeter
# Script initialisation: configure logging and seed with the fixed value 42
# before any model or data work happens.
setup_logging()
set_seed(42)
header_print("CODSWORTH TRAINER")

# Device info
# `device` is reused below to place the model and every input tensor.
device = get_device()
# The emoji was mis-encoded (mojibake) in the original message; restored here.
color_print(f"📱 Using device: {device}", BLUE)
# Load vocabulary
header_print("LOADING VOCABULARY")
# JSON text is UTF-8; pass the encoding explicitly so tokens containing
# non-ASCII characters load identically regardless of the platform default.
with open("tokenizer.json", encoding="utf-8") as f:
    vocab = json.load(f)
# `vocab` is used below as a word -> token-id mapping (see the `.get` lookups).
# NOTE(review): using len(vocab) as the model's vocab size assumes the ids are
# contiguous in 0..len(vocab)-1 — confirm against how tokenizer.json is built.
vocab_size = len(vocab)
color_print(f"📚 Loaded vocabulary: {vocab_size:,} words", GREEN)
# Model config
header_print("MODEL CONFIGURATION")
# Hyperparameters for this run. `context_length` is also the window size used
# by the training loop and the generation test further down.
config = CodsworthConfig(
    vocab_size=vocab_size,
    context_length=128,
    embedding_dim=256,
    num_heads=4,
    ffn_hidden_dim=512,
    num_layers=2,
    use_flash_attention=False,
    use_gradient_checkpointing=False,
    dropout=0.1,
)
# The emoji in the status messages below were mis-encoded (mojibake) in the
# original strings; they are restored to the intended characters.
color_print(
    f"🏗️ Model: {config.num_layers} layers, {config.embedding_dim}d embed, {config.num_heads} heads",
    BLUE,
)
color_print(f"📊 Parameters: {config.estimate_parameters():,}", YELLOW)

# Create model
# Build the network on the selected device and switch it to training mode.
model = CodsworthTransformer(config).to(device)
model.train()
color_print("✅ Model created successfully!", GREEN)
# Load data
header_print("LOADING TRAINING DATA")


def encode_text(text):
    """Convert ``text`` into a list of token ids.

    The text is lower-cased and split on whitespace, and each word is looked
    up in the module-level ``vocab`` mapping. Words missing from the
    vocabulary map to the id of ``"<unk>"``, or to ``1`` when the vocabulary
    has no ``"<unk>"`` entry.
    """
    # Resolve the fallback id once instead of once per word.
    unk_id = vocab.get("<unk>", 1)
    return [vocab.get(word, unk_id) for word in text.lower().split()]


all_tokens = []
# glob returns matches in arbitrary order; sort so that "the first file" is
# the same file on every run and on every machine.
train_files = sorted(glob.glob("data/train/*.txt"))
if not train_files:
    # Without any data the training loop below cannot build a single window.
    color_print("❌ No training files found matching data/train/*.txt", RED)
    sys.exit(1)
# Only the first file is used, and only its first 100,000 characters.
for path in train_files[:1]:
    with open(path, 'r', encoding='utf-8', errors='ignore') as file:
        text = file.read(100000)
    all_tokens.extend(encode_text(text))
color_print(f"📄 Loaded {len(all_tokens):,} tokens", GREEN)
# Training setup
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
meter = AverageMeter("loss")

# Every sample needs `context_length` input tokens plus one more token for the
# shifted labels. The window start is taken modulo `num_starts`, so it must be
# positive: zero would raise ZeroDivisionError, and a negative modulus yields
# negative offsets and therefore truncated or misaligned windows.
num_starts = len(all_tokens) - config.context_length - 1
if num_starts < 1:
    color_print(
        f"❌ Not enough training data: need at least {config.context_length + 2} tokens, "
        f"got {len(all_tokens)}",
        RED,
    )
    sys.exit(1)

# Training loop
header_print("TRAINING STARTED")
start_time = datetime.now()
naN_count = 0  # steps skipped because the loss was NaN or +/-Inf
for step in range(2000):
    # Walk through the corpus in strides of 64 tokens, wrapping at the end.
    idx = (step * 64) % num_starts
    input_ids = all_tokens[idx:idx + config.context_length]
    # Labels are the inputs shifted left by one token (next-token targets).
    labels = all_tokens[idx + 1:idx + config.context_length + 1]
    # Wrap each window in a list so the tensors carry a leading batch dim of 1.
    input_t = torch.tensor([input_ids], dtype=torch.long).to(device)
    labels_t = torch.tensor([labels], dtype=torch.long).to(device)
    outputs = model(input_ids=input_t, labels=labels_t)
    loss = outputs["loss"]

    # isfinite catches +/-Inf as well as NaN; back-propagating either one
    # would corrupt the weights.
    if not torch.isfinite(loss):
        naN_count += 1
        color_print(f"⚠️ Step {step + 1}: NaN/Inf loss detected ({naN_count}x)", RED)
        if naN_count >= 3:
            color_print("❌ Too many NaNs, aborting!", RED)
            break
        # Start over with a fresh optimizer (dropping its accumulated state)
        # at half the initial learning rate, and skip this sample.
        optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
        continue

    loss.backward()
    # Clip the global gradient norm to 1.0 before the update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad()
    meter.update(loss.item())

    if (step + 1) % 10 == 0:
        elapsed = (datetime.now() - start_time).total_seconds()
        # Guard against a zero reading from a coarse clock.
        speed = (step + 1) / elapsed if elapsed > 0 else float("inf")
        color_print(
            f"Step {step + 1:3d} | Loss: {loss.item():.4f} | Avg: {meter.avg:.4f} | Speed: {speed:.1f} step/s",
            GREEN,
        )

# Total wall-clock training time, reported in the final summary.
elapsed_time = (datetime.now() - start_time).total_seconds()
# Save model
header_print("SAVING MODEL")
# Only the weights (state dict) are written, to a file in the current
# working directory.
torch.save(model.state_dict(), "codsworth_model.pt")
# The emoji was mis-encoded (mojibake) in the original message; restored here.
color_print("💾 Model saved to: codsworth_model.pt", GREEN)
# Test generation
header_print("GENERATION TEST")
model.eval()
unk_id = vocab.get("<unk>", 1)
eos_id = vocab.get("<eos>", 3)
# Seed the sample with the word "the" (or the unknown id if it is missing).
prompt_ids = [vocab.get("the", unk_id)]
for _ in range(30):
    # Keep at most the last `context_length` tokens, then right-pad with id 0
    # so the input has the same fixed length used during training.
    # NOTE(review): id 0 is presumably the padding token — confirm.
    window = prompt_ids[-config.context_length:]
    inp = window + [0] * (config.context_length - len(window))
    with torch.no_grad():
        logits = model(torch.tensor([inp], dtype=torch.long).to(device))["logits"]
    # Read the prediction at the last REAL token. The original indexed
    # position -1, which is a padding slot until the prompt fills the whole
    # context, so tokens were sampled from the distribution that follows the
    # padding rather than the prompt.
    # NOTE(review): assumes causal attention, so the right-hand padding cannot
    # influence earlier positions — confirm against the model.
    probs = torch.softmax(logits[0, len(window) - 1], dim=-1)
    next_tok = torch.multinomial(probs, 1).item()
    prompt_ids.append(next_tok)
    if next_tok == eos_id:
        break

# Invert the vocabulary to turn the sampled ids back into words.
id_to_word = {v: k for k, v in vocab.items()}
words = [id_to_word.get(t, "<unk>") for t in prompt_ids]
color_print(f"📝 Generated: {' '.join(words)}", CYAN)
# Summary
header_print("📋 TRAINING SUMMARY")

# The original box was a hand-laid-out literal whose box-drawing characters
# were mis-encoded (mojibake). It is rebuilt here row by row with a fixed
# inner width so the right-hand border stays aligned whatever the values are.
summary_width = 65
label_width = 16
value_width = summary_width - 2 - label_width  # 2 = left indent inside the box

metric_rows = [
    # `step` is the last loop index reached, so this counts attempted steps,
    # including any that were skipped because of a bad loss.
    ("Steps Trained:", f"{step + 1:>5}"),
    # meter.avg is the same value the training log prints as "Avg", so it is
    # labelled as an average here rather than as the final loss.
    ("Avg Loss:", f"{meter.avg:>5.4f}"),
    ("Time Elapsed:", f"{elapsed_time:>5.1f}s"),
    ("NaN Count:", f"{naN_count:>5}"),
]
artifact_rows = [
    ("Model Saved:", "codsworth_model.pt"),
    ("Parameters:", f"{model.get_num_params():>10,}"),
    ("Vocabulary:", f"{vocab_size:>5} words"),
]

top_border = "┌" + "─" * summary_width + "┐"
divider = "├" + "─" * summary_width + "┤"
bottom_border = "└" + "─" * summary_width + "┘"

summary_lines = [top_border, f"│{'TRAINING COMPLETE':^{summary_width}}│", divider]
summary_lines += [
    f"│  {label:<{label_width}}{value:<{value_width}}│" for label, value in metric_rows
]
summary_lines.append(divider)
summary_lines += [
    f"│  {label:<{label_width}}{value:<{value_width}}│" for label, value in artifact_rows
]
summary_lines.append(bottom_border)

# Blank line before and after the box; the whole box is bold.
print(f"\n{BOLD}" + "\n".join(summary_lines) + f"{RESET}\n")

color_print("✨ All done! Model ready for inference.", GREEN)
color_print(" Run: python codsworth/scripts/inference.py --model codsworth_model.pt", BLUE)