# krystv's picture
# Add training script optimized for free Colab
# 322ae07 verified
"""
NeuroLex v4 — Training Script
==============================
Optimized for free Google Colab (T4 16GB GPU).
Training time: ~20-30 minutes for 'base' model (12M params)
Memory usage: <8GB GPU RAM with batch_size=256
Usage:
python train.py --size base --epochs 30 --batch_size 256
Or in Colab:
%run train.py --size base --epochs 30
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
import time
import argparse
import os
import json
import math
from pathlib import Path
from neurolex_v4_model import NeuroLexV4, NeuroLexConfig, CharTokenizer, create_model
from neurolex_v4_model import DOMAINS, STYLES, LANGUAGES, DOMAIN_TO_ID, STYLE_TO_ID, LANG_TO_ID
from neurolex_v4_dataset import create_dataloaders, NeuroLexDataset
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# TRAINING CONFIGURATION
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
def parse_args(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            ``parse_args()`` callers behave exactly as before. Passing a
            list lets notebooks and other code drive the parser without
            touching ``sys.argv``.

    Returns:
        argparse.Namespace: One attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description='Train NeuroLex v4')

    # Model / run size.
    parser.add_argument('--size', type=str, default='base', choices=['tiny', 'small', 'base', 'large'])
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch_size', type=int, default=256)

    # Optimisation hyper-parameters.
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--warmup_steps', type=int, default=500)
    parser.add_argument('--gradient_clip', type=float, default=1.0)
    parser.add_argument('--weight_decay', type=float, default=0.01)

    # Data.
    parser.add_argument('--n_samples', type=int, default=100000)
    parser.add_argument('--streaming', action='store_true', help='Use streaming dataset (infinite)')
    parser.add_argument('--num_workers', type=int, default=2)

    # Output, logging and runtime.
    parser.add_argument('--save_dir', type=str, default='./checkpoints')
    parser.add_argument('--log_every', type=int, default=100)
    parser.add_argument('--sample_every', type=int, default=3, help='Sample every N epochs')
    parser.add_argument('--device', type=str, default='auto')
    parser.add_argument('--seed', type=int, default=42)

    return parser.parse_args(argv)
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# TRAINING LOOP
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
class Trainer:
    """Run optimisation, validation, preview sampling and checkpointing.

    The trainer owns an AdamW optimizer and a linear-warmup -> cosine-decay
    learning-rate schedule, and records the mean train and validation loss
    of every epoch in ``train_losses`` / ``val_losses``.
    """

    # Batch entries passed to ``model.compute_loss``, in positional order.
    _BATCH_KEYS = ('input_ids', 'domain', 'style', 'language', 'length')

    def __init__(self, model, config, args, device):
        """Create the optimizer, LR schedule, tracking state and save directory.

        Args:
            model: Model to train; it is moved to ``device``. This class calls
                its ``compute_loss``, ``generate``, ``count_parameters``,
                ``parameters`` and ``state_dict`` methods.
            config: Model configuration object. ``cfg_scale``, ``temperature``,
                ``odd_alpha``, ``cfg_dropout`` and ``noise_schedule`` are read,
                and ``vars(config)`` is stored in checkpoints.
            args: Parsed command-line namespace (``lr``, ``weight_decay``,
                ``epochs``, ``n_samples``, ``batch_size``, ``warmup_steps``,
                ``gradient_clip``, ``log_every``, ``sample_every`` and
                ``save_dir`` are read).
            device: ``torch.device`` used for the model and every batch.
        """
        self.model = model.to(device)
        self.config = config
        self.args = args
        self.device = device
        self.tokenizer = CharTokenizer()

        self.optimizer = AdamW(
            model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
            betas=(0.9, 0.98),
            eps=1e-8,
        )

        # The schedule length is estimated from the CLI arguments because the
        # dataloader is not available here.
        # NOTE(review): this can differ from ``epochs * len(train_loader)``
        # used for progress reporting in ``train`` -- confirm against
        # ``create_dataloaders``.
        total_steps = args.epochs * (args.n_samples // args.batch_size)
        self.scheduler = self._build_scheduler(
            self.optimizer, args.lr, args.warmup_steps, total_steps
        )

        # Progress tracking.
        self.global_step = 0
        self.best_loss = float('inf')
        self.train_losses = []
        self.val_losses = []

        os.makedirs(args.save_dir, exist_ok=True)

    @staticmethod
    def _build_scheduler(optimizer, base_lr, warmup_steps, total_steps):
        """Build the linear-warmup -> cosine-decay learning-rate schedule.

        Args:
            optimizer: Optimizer whose learning rate is scheduled.
            base_lr: Peak learning rate; the cosine floor is 1% of it.
            warmup_steps: Number of warmup steps, ramping from 1% to 100% of
                the peak. Values <= 0 disable warmup.
            total_steps: Expected total number of optimizer steps.

        Returns:
            A scheduler exposing ``step()``, ``get_last_lr()`` and
            ``state_dict()``.
        """
        # T_max must be >= 1: with T_max == 0 the cosine scheduler divides by
        # zero, which used to happen whenever total_steps <= warmup_steps.
        cosine_steps = max(1, total_steps - max(warmup_steps, 0))
        eta_min = base_lr * 0.01

        if warmup_steps <= 0:
            # A zero-length LinearLR would still apply start_factor and leave
            # the LR stuck at 1% of the peak, so skip warmup entirely.
            return CosineAnnealingLR(optimizer, T_max=cosine_steps, eta_min=eta_min)

        warmup_scheduler = LinearLR(
            optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_steps
        )
        cosine_scheduler = CosineAnnealingLR(
            optimizer, T_max=cosine_steps, eta_min=eta_min
        )
        return SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, cosine_scheduler],
            milestones=[warmup_steps],
        )

    def _batch_to_device(self, batch):
        """Return the model inputs of ``batch`` moved to ``self.device``.

        The order matches ``_BATCH_KEYS``, which is the positional order
        expected by ``model.compute_loss``.
        """
        return tuple(batch[key].to(self.device) for key in self._BATCH_KEYS)

    def train_step(self, batch):
        """Run one optimisation step on ``batch`` and return the loss as a float.

        Also advances the LR scheduler and ``global_step`` by one.
        """
        self.model.train()

        # The loss itself is defined by the model (the original comment calls
        # it the UDLM loss).
        loss = self.model.compute_loss(*self._batch_to_device(batch))

        self.optimizer.zero_grad()
        loss.backward()

        # A non-positive --gradient_clip disables clipping.
        if self.args.gradient_clip > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip)

        self.optimizer.step()
        self.scheduler.step()
        self.global_step += 1
        return loss.item()

    @torch.no_grad()
    def validate(self, val_loader):
        """Return the mean per-batch loss over ``val_loader`` (0.0 if it is empty)."""
        self.model.eval()
        total_loss = 0
        n_batches = 0
        for batch in val_loader:
            loss = self.model.compute_loss(*self._batch_to_device(batch))
            total_loss += loss.item()
            n_batches += 1
        return total_loss / max(n_batches, 1)

    @torch.no_grad()
    def generate_samples(self, n_per_category: int = 8):
        """Generate sample names for a fixed set of conditions.

        Args:
            n_per_category: Number of names requested per condition
                (passed to ``model.generate`` as ``batch_size``).

        Returns:
            Tuple ``(results, uniqueness, avg_len, all_names)`` where
            ``results`` maps ``"domain/style/language"`` to the generated
            names, ``uniqueness`` is the percentage of case-insensitively
            distinct names, ``avg_len`` is the mean name length, and
            ``all_names`` is every generated name in generation order.
        """
        self.model.eval()

        # (domain, style, language, target_length) combinations to preview.
        test_configs = [
            ('tech', 'sharp', 'english', 8),
            ('tech', 'futuristic', 'japanese', 7),
            ('food', 'warm', 'french', 7),
            ('gaming', 'bold', 'english', 9),
            ('luxury', 'elegant', 'italian', 8),
            ('health', 'organic', 'hawaiian', 7),
            ('crypto', 'futuristic', 'greek', 8),
            ('music', 'playful', 'spanish', 7),
            ('ai', 'sharp', 'latin', 6),
            ('eco', 'warm', 'swedish', 7),
            ('fitness', 'bold', 'german', 8),
            ('social', 'playful', 'korean', 6),
        ]

        all_names = []
        results = {}
        for domain, style, lang, length in test_configs:
            names = self.model.generate(
                domain_id=DOMAIN_TO_ID[domain],
                style_id=STYLE_TO_ID[style],
                lang_id=LANG_TO_ID[lang],
                target_length=length,
                batch_size=n_per_category,
                cfg_scale=self.config.cfg_scale,
                temperature=self.config.temperature,
                n_steps=60,  # fewer steps than the usage example, for a quick preview
                odd_alpha=self.config.odd_alpha,
                device=str(self.device),
            )
            results[f"{domain}/{style}/{lang}"] = names
            all_names.extend(names)

        # Diversity metrics; max(..., 1) guards against an empty result set.
        n_total = max(len(all_names), 1)
        unique_names = {n.lower() for n in all_names}
        uniqueness = len(unique_names) / n_total * 100
        avg_len = sum(len(n) for n in all_names) / n_total
        return results, uniqueness, avg_len, all_names

    def save_checkpoint(self, path: str, is_best: bool = False):
        """Write model, optimizer, scheduler and tracking state to ``path``.

        Args:
            path: Destination file for the checkpoint.
            is_best: When true, the same checkpoint is also written to
                ``<save_dir>/neurolex_v4_best.pt``.
        """
        checkpoint = {
            'config': vars(self.config),
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'global_step': self.global_step,
            'best_loss': self.best_loss,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'vocab_size': self.tokenizer.vocab_size,
            'vocab': self.tokenizer.vocab,
        }
        torch.save(checkpoint, path)
        if is_best:
            best_path = os.path.join(self.args.save_dir, 'neurolex_v4_best.pt')
            torch.save(checkpoint, best_path)

    def train(self, train_loader, val_loader):
        """Run the full training loop and return the best validation loss.

        Each epoch trains over ``train_loader``, validates on ``val_loader``
        and writes ``neurolex_v4_epoch{N}.pt`` (plus the "best" copy when the
        validation loss improves). Samples are printed every
        ``args.sample_every`` epochs and after the last epoch.

        Args:
            train_loader: Iterable of training batches; ``len()`` is called
                on it for progress reporting.
            val_loader: Iterable of validation batches.

        Returns:
            The lowest validation loss observed (``inf`` if no epoch ran).
        """
        total_steps = self.args.epochs * len(train_loader)
        rule = "=" * 70

        print(rule)
        print(" NEUROLEX v4 — TRAINING")
        print(rule)
        print(f" Model size: {self.model.count_parameters():,} parameters")
        print(f" Epochs: {self.args.epochs}")
        print(f" Batch size: {self.args.batch_size}")
        print(f" Steps/epoch: {len(train_loader)}")
        print(f" Total steps: {total_steps}")
        print(f" Learning rate: {self.args.lr} (cosine decay)")
        print(f" Warmup steps: {self.args.warmup_steps}")
        print(f" Device: {self.device}")
        print(f" CFG dropout: {self.config.cfg_dropout}")
        print(f" Noise schedule: {self.config.noise_schedule}")
        print(rule)

        start_time = time.time()
        for epoch in range(1, self.args.epochs + 1):
            epoch_start = time.time()
            epoch_loss = 0
            n_steps_epoch = 0

            for batch in train_loader:
                loss = self.train_step(batch)
                epoch_loss += loss
                n_steps_epoch += 1

                # A non-positive --log_every disables step logging instead of
                # raising ZeroDivisionError in the modulo.
                if self.args.log_every > 0 and self.global_step % self.args.log_every == 0:
                    avg_loss = epoch_loss / n_steps_epoch
                    lr = self.scheduler.get_last_lr()[0]
                    elapsed = time.time() - start_time
                    eta = elapsed / self.global_step * (total_steps - self.global_step)
                    print(f" Step {self.global_step:>6}/{total_steps} | "
                          f"Loss={loss:.4f} | Avg={avg_loss:.4f} | "
                          f"LR={lr:.2e} | ETA={eta/60:.0f}min")

            # End of epoch; max(..., 1) keeps an empty loader from dividing by
            # zero, mirroring validate().
            epoch_time = time.time() - epoch_start
            avg_epoch_loss = epoch_loss / max(n_steps_epoch, 1)
            self.train_losses.append(avg_epoch_loss)

            val_loss = self.validate(val_loader)
            self.val_losses.append(val_loss)
            print(f"\n Epoch {epoch}/{self.args.epochs} | "
                  f"Train={avg_epoch_loss:.4f} | Val={val_loss:.4f} | "
                  f"Time={epoch_time:.0f}s")

            is_best = val_loss < self.best_loss
            if is_best:
                self.best_loss = val_loss
                print(f" ★ New best validation loss: {val_loss:.4f}")
            self.save_checkpoint(
                os.path.join(self.args.save_dir, f'neurolex_v4_epoch{epoch}.pt'),
                is_best=is_best,
            )

            # Periodic preview, always including the last epoch.
            if epoch % self.args.sample_every == 0 or epoch == self.args.epochs:
                print(f"\n --- Sample Generations (Epoch {epoch}) ---")
                results, uniqueness, avg_len, all_names = self.generate_samples(n_per_category=8)
                for key, names in results.items():
                    if names:
                        print(f" {key:30s}: {', '.join(names[:5])}")
                n_unique = len({n.lower() for n in all_names})
                print(f"\n Diversity: {uniqueness:.1f}% unique | Avg length: {avg_len:.1f}")
                print(f" Total generated: {len(all_names)} | Unique: {n_unique}")
                print()

        # Final summary.
        total_time = time.time() - start_time
        print(rule)
        print(" TRAINING COMPLETE")
        print(rule)
        print(f" Total time: {total_time/60:.1f} minutes")
        print(f" Best val loss: {self.best_loss:.4f}")
        # With --epochs 0 there is no loss history to report.
        if self.train_losses:
            print(f" Final train loss: {self.train_losses[-1]:.4f}")
            print(f" Loss curve: {' -> '.join(f'{l:.3f}' for l in self.train_losses[:5])}")
            if len(self.train_losses) > 5:
                print(f" ... -> {' -> '.join(f'{l:.3f}' for l in self.train_losses[-3:])}")

        # Final generation showcase.
        print("\n" + rule)
        print(" FINAL GENERATION SHOWCASE")
        print(rule)
        emoji_map = {
            'tech': '💻', 'food': '🍜', 'gaming': '🎮', 'luxury': '💎',
            'health': '🌿', 'crypto': '🪙', 'music': '🎵', 'ai': '🤖',
            'eco': '♻️', 'fitness': '💪', 'social': '🌐', 'general': '⭐',
        }
        results, uniqueness, avg_len, all_names = self.generate_samples(n_per_category=12)
        for key, names in results.items():
            domain, style, lang = key.split('/')
            emoji = emoji_map.get(domain, '•')
            if names:
                print(f"\n {emoji} {domain.upper()} ({style}, {lang}):")
                for name in names[:8]:
                    print(f" → {name}")

        n_unique = len({n.lower() for n in all_names})
        print(f"\n {'─' * 50}")
        print(f" 📊 DIVERSITY SCORE: {uniqueness:.1f}%")
        print(f" 📏 AVG LENGTH: {avg_len:.1f} chars")
        print(f" 🎯 TOTAL UNIQUE: {n_unique}/{len(all_names)}")
        print(rule)
        return self.best_loss
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# MAIN
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
def main():
    """Command-line entry point: build model and data, train, save, print usage.

    Steps: parse the arguments, seed PyTorch, resolve the device, create the
    model and dataloaders, run ``Trainer.train``, write the final checkpoint
    to ``<save_dir>/neurolex_v4_final.pt`` and print loading instructions.
    """
    args = parse_args()

    torch.manual_seed(args.seed)

    # 'auto' picks CUDA when available; any other value is passed to
    # torch.device unchanged.
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)
    print(f"Using device: {device}")
    if device.type == 'cuda':
        # The attribute is ``total_memory``; ``total_mem`` does not exist and
        # raised AttributeError on every CUDA run. Query the selected device
        # rather than always index 0.
        props = torch.cuda.get_device_properties(device)
        print(f" GPU: {torch.cuda.get_device_name(device)}")
        print(f" Memory: {props.total_memory / 1e9:.1f} GB")

    model, config = create_model(args.size)

    # Mirror the CLI overrides into the config, which is what gets stored in
    # checkpoints.
    config.learning_rate = args.lr
    config.batch_size = args.batch_size
    config.epochs = args.epochs
    config.warmup_steps = args.warmup_steps

    train_loader, val_loader = create_dataloaders(
        batch_size=args.batch_size,
        n_samples=args.n_samples,
        num_workers=args.num_workers,
        streaming=args.streaming,
    )

    trainer = Trainer(model, config, args, device)
    trainer.train(train_loader, val_loader)

    # Save the final model (state after the last epoch, not necessarily the best).
    final_path = os.path.join(args.save_dir, 'neurolex_v4_final.pt')
    trainer.save_checkpoint(final_path)
    print(f"\nFinal model saved to: {final_path}")
    print(f"Best model saved to: {os.path.join(args.save_dir, 'neurolex_v4_best.pt')}")

    # Print model loading instructions.
    print(f"""
{'=' * 70}
HOW TO USE THE TRAINED MODEL
{'=' * 70}
# Load the model:
from neurolex_v4_model import NeuroLexV4, NeuroLexConfig, CharTokenizer
from neurolex_v4_model import DOMAIN_TO_ID, STYLE_TO_ID, LANG_TO_ID
checkpoint = torch.load('{final_path}')
config = NeuroLexConfig(**checkpoint['config'])
model = NeuroLexV4(config)
model.load_state_dict(checkpoint['state_dict'])
model.eval()
# Generate names:
names = model.generate(
domain_id=DOMAIN_TO_ID['tech'],
style_id=STYLE_TO_ID['sharp'],
lang_id=LANG_TO_ID['english'],
target_length=8,
batch_size=20,
cfg_scale=2.5, # higher = more condition-faithful
temperature=0.9, # higher = more creative
n_steps=80, # more steps = better quality
odd_alpha=8.0, # higher = more diversity between samples
device='cuda'
)
print(names)
""")


if __name__ == '__main__':
    main()