""" NeuroLex v4 — Training Script ============================== Optimized for free Google Colab (T4 16GB GPU). Training time: ~20-30 minutes for 'base' model (12M params) Memory usage: <8GB GPU RAM with batch_size=256 Usage: python train.py --size base --epochs 30 --batch_size 256 Or in Colab: %run train.py --size base --epochs 30 """ import torch import torch.nn as nn import torch.nn.functional as F from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR import time import argparse import os import json import math from pathlib import Path from neurolex_v4_model import NeuroLexV4, NeuroLexConfig, CharTokenizer, create_model from neurolex_v4_model import DOMAINS, STYLES, LANGUAGES, DOMAIN_TO_ID, STYLE_TO_ID, LANG_TO_ID from neurolex_v4_dataset import create_dataloaders, NeuroLexDataset # ═══════════════════════════════════════════════════════════════ # TRAINING CONFIGURATION # ═══════════════════════════════════════════════════════════════ def parse_args(): parser = argparse.ArgumentParser(description='Train NeuroLex v4') parser.add_argument('--size', type=str, default='base', choices=['tiny', 'small', 'base', 'large']) parser.add_argument('--epochs', type=int, default=30) parser.add_argument('--batch_size', type=int, default=256) parser.add_argument('--lr', type=float, default=3e-4) parser.add_argument('--warmup_steps', type=int, default=500) parser.add_argument('--n_samples', type=int, default=100000) parser.add_argument('--streaming', action='store_true', help='Use streaming dataset (infinite)') parser.add_argument('--save_dir', type=str, default='./checkpoints') parser.add_argument('--log_every', type=int, default=100) parser.add_argument('--sample_every', type=int, default=3, help='Sample every N epochs') parser.add_argument('--device', type=str, default='auto') parser.add_argument('--seed', type=int, default=42) parser.add_argument('--gradient_clip', type=float, default=1.0) parser.add_argument('--weight_decay', type=float, default=0.01) parser.add_argument('--num_workers', type=int, default=2) return parser.parse_args() # ═══════════════════════════════════════════════════════════════ # TRAINING LOOP # ═══════════════════════════════════════════════════════════════ class Trainer: def __init__(self, model, config, args, device): self.model = model.to(device) self.config = config self.args = args self.device = device self.tokenizer = CharTokenizer() # Optimizer: AdamW with weight decay self.optimizer = AdamW( model.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.98), # Standard for transformers eps=1e-8 ) # Learning rate schedule: linear warmup → cosine decay total_steps = args.epochs * (args.n_samples // args.batch_size) warmup_scheduler = LinearLR( self.optimizer, start_factor=0.01, end_factor=1.0, total_iters=args.warmup_steps ) cosine_scheduler = CosineAnnealingLR( self.optimizer, T_max=total_steps - args.warmup_steps, eta_min=args.lr * 0.01 ) self.scheduler = SequentialLR( self.optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[args.warmup_steps] ) # Tracking self.global_step = 0 self.best_loss = float('inf') self.train_losses = [] self.val_losses = [] # Save directory os.makedirs(args.save_dir, exist_ok=True) def train_step(self, batch): """Single training step.""" self.model.train() input_ids = batch['input_ids'].to(self.device) domain = batch['domain'].to(self.device) style = batch['style'].to(self.device) language = batch['language'].to(self.device) length = batch['length'].to(self.device) # Compute UDLM loss loss = self.model.compute_loss(input_ids, domain, style, language, length) # Backward pass self.optimizer.zero_grad() loss.backward() # Gradient clipping if self.args.gradient_clip > 0: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip) self.optimizer.step() self.scheduler.step() self.global_step += 1 return loss.item() @torch.no_grad() def validate(self, val_loader): """Run validation.""" self.model.eval() total_loss = 0 n_batches = 0 for batch in val_loader: input_ids = batch['input_ids'].to(self.device) domain = batch['domain'].to(self.device) style = batch['style'].to(self.device) language = batch['language'].to(self.device) length = batch['length'].to(self.device) loss = self.model.compute_loss(input_ids, domain, style, language, length) total_loss += loss.item() n_batches += 1 return total_loss / max(n_batches, 1) @torch.no_grad() def generate_samples(self, n_per_category: int = 8): """Generate sample names for quality monitoring.""" self.model.eval() # Test different domain/style/language combinations test_configs = [ ('tech', 'sharp', 'english', 8), ('tech', 'futuristic', 'japanese', 7), ('food', 'warm', 'french', 7), ('gaming', 'bold', 'english', 9), ('luxury', 'elegant', 'italian', 8), ('health', 'organic', 'hawaiian', 7), ('crypto', 'futuristic', 'greek', 8), ('music', 'playful', 'spanish', 7), ('ai', 'sharp', 'latin', 6), ('eco', 'warm', 'swedish', 7), ('fitness', 'bold', 'german', 8), ('social', 'playful', 'korean', 6), ] all_names = [] results = {} for domain, style, lang, length in test_configs: names = self.model.generate( domain_id=DOMAIN_TO_ID[domain], style_id=STYLE_TO_ID[style], lang_id=LANG_TO_ID[lang], target_length=length, batch_size=n_per_category, cfg_scale=self.config.cfg_scale, temperature=self.config.temperature, n_steps=60, # Fewer steps for quick preview odd_alpha=self.config.odd_alpha, device=str(self.device) ) key = f"{domain}/{style}/{lang}" results[key] = names all_names.extend(names) # Compute diversity metrics unique_names = set(n.lower() for n in all_names) uniqueness = len(unique_names) / max(len(all_names), 1) * 100 avg_len = sum(len(n) for n in all_names) / max(len(all_names), 1) return results, uniqueness, avg_len, all_names def save_checkpoint(self, path: str, is_best: bool = False): """Save model checkpoint.""" checkpoint = { 'config': vars(self.config), 'state_dict': self.model.state_dict(), 'optimizer': self.optimizer.state_dict(), 'scheduler': self.scheduler.state_dict(), 'global_step': self.global_step, 'best_loss': self.best_loss, 'train_losses': self.train_losses, 'val_losses': self.val_losses, 'vocab_size': self.tokenizer.vocab_size, 'vocab': self.tokenizer.vocab, } torch.save(checkpoint, path) if is_best: best_path = os.path.join(self.args.save_dir, 'neurolex_v4_best.pt') torch.save(checkpoint, best_path) def train(self, train_loader, val_loader): """Full training loop.""" total_steps = self.args.epochs * len(train_loader) print("=" * 70) print(" NEUROLEX v4 — TRAINING") print("=" * 70) print(f" Model size: {self.model.count_parameters():,} parameters") print(f" Epochs: {self.args.epochs}") print(f" Batch size: {self.args.batch_size}") print(f" Steps/epoch: {len(train_loader)}") print(f" Total steps: {total_steps}") print(f" Learning rate: {self.args.lr} (cosine decay)") print(f" Warmup steps: {self.args.warmup_steps}") print(f" Device: {self.device}") print(f" CFG dropout: {self.config.cfg_dropout}") print(f" Noise schedule: {self.config.noise_schedule}") print("=" * 70) start_time = time.time() for epoch in range(1, self.args.epochs + 1): epoch_start = time.time() epoch_loss = 0 n_steps_epoch = 0 for batch_idx, batch in enumerate(train_loader): loss = self.train_step(batch) epoch_loss += loss n_steps_epoch += 1 # Logging if self.global_step % self.args.log_every == 0: avg_loss = epoch_loss / n_steps_epoch lr = self.scheduler.get_last_lr()[0] elapsed = time.time() - start_time eta = elapsed / self.global_step * (total_steps - self.global_step) print(f" Step {self.global_step:>6}/{total_steps} | " f"Loss={loss:.4f} | Avg={avg_loss:.4f} | " f"LR={lr:.2e} | ETA={eta/60:.0f}min") # End of epoch epoch_time = time.time() - epoch_start avg_epoch_loss = epoch_loss / n_steps_epoch self.train_losses.append(avg_epoch_loss) # Validation val_loss = self.validate(val_loader) self.val_losses.append(val_loss) print(f"\n Epoch {epoch}/{self.args.epochs} | " f"Train={avg_epoch_loss:.4f} | Val={val_loss:.4f} | " f"Time={epoch_time:.0f}s") # Save best model is_best = val_loss < self.best_loss if is_best: self.best_loss = val_loss print(f" ★ New best validation loss: {val_loss:.4f}") self.save_checkpoint( os.path.join(self.args.save_dir, f'neurolex_v4_epoch{epoch}.pt'), is_best=is_best ) # Generate samples periodically if epoch % self.args.sample_every == 0 or epoch == self.args.epochs: print(f"\n --- Sample Generations (Epoch {epoch}) ---") results, uniqueness, avg_len, all_names = self.generate_samples(n_per_category=8) for key, names in results.items(): if names: print(f" {key:30s}: {', '.join(names[:5])}") print(f"\n Diversity: {uniqueness:.1f}% unique | Avg length: {avg_len:.1f}") print(f" Total generated: {len(all_names)} | Unique: {len(set(n.lower() for n in all_names))}") print() # Final summary total_time = time.time() - start_time print("=" * 70) print(" TRAINING COMPLETE") print("=" * 70) print(f" Total time: {total_time/60:.1f} minutes") print(f" Best val loss: {self.best_loss:.4f}") print(f" Final train loss: {self.train_losses[-1]:.4f}") print(f" Loss curve: {' -> '.join(f'{l:.3f}' for l in self.train_losses[:5])}") if len(self.train_losses) > 5: print(f" ... -> {' -> '.join(f'{l:.3f}' for l in self.train_losses[-3:])}") # Final generation showcase print("\n" + "=" * 70) print(" FINAL GENERATION SHOWCASE") print("=" * 70) results, uniqueness, avg_len, all_names = self.generate_samples(n_per_category=12) for key, names in results.items(): domain, style, lang = key.split('/') emoji_map = { 'tech': '💻', 'food': '🍜', 'gaming': '🎮', 'luxury': '💎', 'health': '🌿', 'crypto': '🪙', 'music': '🎵', 'ai': '🤖', 'eco': '♻️', 'fitness': '💪', 'social': '🌐', 'general': '⭐' } emoji = emoji_map.get(domain, '•') if names: print(f"\n {emoji} {domain.upper()} ({style}, {lang}):") for name in names[:8]: print(f" → {name}") print(f"\n {'─' * 50}") print(f" 📊 DIVERSITY SCORE: {uniqueness:.1f}%") print(f" 📏 AVG LENGTH: {avg_len:.1f} chars") print(f" 🎯 TOTAL UNIQUE: {len(set(n.lower() for n in all_names))}/{len(all_names)}") print("=" * 70) return self.best_loss # ═══════════════════════════════════════════════════════════════ # MAIN # ═══════════════════════════════════════════════════════════════ def main(): args = parse_args() # Set seed torch.manual_seed(args.seed) # Device if args.device == 'auto': device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') else: device = torch.device(args.device) print(f"Using device: {device}") if device.type == 'cuda': print(f" GPU: {torch.cuda.get_device_name()}") print(f" Memory: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB") # Create model model, config = create_model(args.size) # Override config with args config.learning_rate = args.lr config.batch_size = args.batch_size config.epochs = args.epochs config.warmup_steps = args.warmup_steps # Create dataloaders train_loader, val_loader = create_dataloaders( batch_size=args.batch_size, n_samples=args.n_samples, num_workers=args.num_workers, streaming=args.streaming ) # Create trainer and train trainer = Trainer(model, config, args, device) best_loss = trainer.train(train_loader, val_loader) # Save final model final_path = os.path.join(args.save_dir, 'neurolex_v4_final.pt') trainer.save_checkpoint(final_path) print(f"\nFinal model saved to: {final_path}") print(f"Best model saved to: {os.path.join(args.save_dir, 'neurolex_v4_best.pt')}") # Print model loading instructions print(f""" {'=' * 70} HOW TO USE THE TRAINED MODEL {'=' * 70} # Load the model: from neurolex_v4_model import NeuroLexV4, NeuroLexConfig, CharTokenizer from neurolex_v4_model import DOMAIN_TO_ID, STYLE_TO_ID, LANG_TO_ID checkpoint = torch.load('{final_path}') config = NeuroLexConfig(**checkpoint['config']) model = NeuroLexV4(config) model.load_state_dict(checkpoint['state_dict']) model.eval() # Generate names: names = model.generate( domain_id=DOMAIN_TO_ID['tech'], style_id=STYLE_TO_ID['sharp'], lang_id=LANG_TO_ID['english'], target_length=8, batch_size=20, cfg_scale=2.5, # higher = more condition-faithful temperature=0.9, # higher = more creative n_steps=80, # more steps = better quality odd_alpha=8.0, # higher = more diversity between samples device='cuda' ) print(names) """) if __name__ == '__main__': main()