| """ |
| NeuroLex v4 — Training Script |
| ============================== |
| Optimized for free Google Colab (T4 16GB GPU). |
| |
| Training time: ~20-30 minutes for 'base' model (12M params) |
| Memory usage: <8GB GPU RAM with batch_size=256 |
| |
| Usage: |
| python train.py --size base --epochs 30 --batch_size 256 |
| |
| Or in Colab: |
| %run train.py --size base --epochs 30 |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR |
| import time |
| import argparse |
| import os |
| import json |
| import math |
| from pathlib import Path |
|
|
| from neurolex_v4_model import NeuroLexV4, NeuroLexConfig, CharTokenizer, create_model |
| from neurolex_v4_model import DOMAINS, STYLES, LANGUAGES, DOMAIN_TO_ID, STYLE_TO_ID, LANG_TO_ID |
| from neurolex_v4_dataset import create_dataloaders, NeuroLexDataset |
|
|
|
|
| |
| |
| |
|
|
def _positive_int(text):
    """Parse *text* as an integer >= 1 (argparse ``type=`` callback).

    Raises:
        argparse.ArgumentTypeError: if *text* is not an integer or is < 1.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value}")
    return value


def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser(description='Train NeuroLex v4')
    parser.add_argument('--size', type=str, default='base', choices=['tiny', 'small', 'base', 'large'])
    # These four values are used as divisors / loop bounds by the trainer, so
    # zero or negative input is rejected here instead of crashing mid-run.
    parser.add_argument('--epochs', type=_positive_int, default=30)
    parser.add_argument('--batch_size', type=_positive_int, default=256)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--warmup_steps', type=int, default=500)
    parser.add_argument('--n_samples', type=int, default=100000)
    parser.add_argument('--streaming', action='store_true', help='Use streaming dataset (infinite)')
    parser.add_argument('--save_dir', type=str, default='./checkpoints')
    parser.add_argument('--log_every', type=_positive_int, default=100)
    parser.add_argument('--sample_every', type=_positive_int, default=3, help='Sample every N epochs')
    parser.add_argument('--device', type=str, default='auto')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--gradient_clip', type=float, default=1.0)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--num_workers', type=int, default=2)
    return parser.parse_args(argv)
|
|
|
|
| |
| |
| |
|
|
class Trainer:
    """Run optimisation, validation, sampling and checkpointing for one model.

    The wrapped model must provide ``compute_loss(input_ids, domain, style,
    language, length)``, ``generate(...)`` and ``count_parameters()``; these
    are the only model methods this class calls.

    After every epoch a checkpoint ``neurolex_v4_epoch{N}.pt`` is written to
    ``args.save_dir``; the epoch with the lowest validation loss is also
    written to ``neurolex_v4_best.pt``.
    """

    # Batch keys, in the positional order expected by ``model.compute_loss``.
    _BATCH_FIELDS = ('input_ids', 'domain', 'style', 'language', 'length')

    # (domain, style, language, target_length) conditions sampled for monitoring.
    _SAMPLE_CONFIGS = (
        ('tech', 'sharp', 'english', 8),
        ('tech', 'futuristic', 'japanese', 7),
        ('food', 'warm', 'french', 7),
        ('gaming', 'bold', 'english', 9),
        ('luxury', 'elegant', 'italian', 8),
        ('health', 'organic', 'hawaiian', 7),
        ('crypto', 'futuristic', 'greek', 8),
        ('music', 'playful', 'spanish', 7),
        ('ai', 'sharp', 'latin', 6),
        ('eco', 'warm', 'swedish', 7),
        ('fitness', 'bold', 'german', 8),
        ('social', 'playful', 'korean', 6),
    )

    # Console glyph per domain for the final showcase.
    # NOTE(review): the source text was mis-encoded. The glyphs for 'food',
    # 'luxury' and 'social' here, and the dash, arrow, rule and chart/ruler
    # glyphs in the print calls below, could not be recovered byte-exactly and
    # are best-effort reconstructions -- confirm against the original file.
    _DOMAIN_EMOJI = {
        'tech': '💻', 'food': '🍕', 'gaming': '🎮', 'luxury': '💎',
        'health': '🌿', 'crypto': '🪙', 'music': '🎵', 'ai': '🤖',
        'eco': '♻️', 'fitness': '💪', 'social': '🌐', 'general': '⭐',
    }

    def __init__(self, model, config, args, device):
        """Move *model* to *device* and build optimiser, scheduler and state.

        Args:
            model: Model to train (see class docstring for required methods).
            config: Model config; ``cfg_scale``, ``temperature``, ``odd_alpha``,
                ``cfg_dropout`` and ``noise_schedule`` are read from it, and
                ``vars(config)`` is stored in checkpoints.
            args: Parsed command-line namespace (see ``parse_args``).
            device: ``torch.device`` used for the model and every batch.
        """
        self.model = model.to(device)
        self.config = config
        self.args = args
        self.device = device
        self.tokenizer = CharTokenizer()

        self.optimizer = AdamW(
            self.model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
            betas=(0.9, 0.98),
            eps=1e-8,
        )
        self.scheduler = self._build_scheduler()

        # Running state; all of it is persisted by save_checkpoint().
        self.global_step = 0
        self.best_loss = float('inf')
        self.train_losses = []
        self.val_losses = []

        os.makedirs(args.save_dir, exist_ok=True)

    def _build_scheduler(self):
        """Return a linear-warmup -> cosine-decay schedule for ``self.optimizer``.

        The planned step count is estimated from ``n_samples // batch_size``,
        not from the real loader length, so it is only an approximation of the
        number of steps ``train`` will actually take.
        """
        args = self.args
        planned_steps = args.epochs * (args.n_samples // args.batch_size)
        # CosineAnnealingLR divides by T_max; on short runs where the planned
        # steps do not exceed the warmup this would otherwise be <= 0.
        decay_steps = max(1, planned_steps - args.warmup_steps)

        warmup = LinearLR(
            self.optimizer, start_factor=0.01, end_factor=1.0,
            total_iters=args.warmup_steps,
        )
        cosine = CosineAnnealingLR(
            self.optimizer, T_max=decay_steps, eta_min=args.lr * 0.01,
        )
        return SequentialLR(
            self.optimizer, schedulers=[warmup, cosine],
            milestones=[args.warmup_steps],
        )

    def _batch_loss(self, batch):
        """Move one batch's tensors to ``self.device`` and return the model loss."""
        tensors = [batch[name].to(self.device) for name in self._BATCH_FIELDS]
        return self.model.compute_loss(*tensors)

    def train_step(self, batch):
        """Run one optimisation step on *batch* and return the loss as a float."""
        self.model.train()

        loss = self._batch_loss(batch)

        self.optimizer.zero_grad()
        loss.backward()

        # A non-positive --gradient_clip disables clipping.
        if self.args.gradient_clip > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip)

        self.optimizer.step()
        self.scheduler.step()  # the schedule advances per step, not per epoch

        self.global_step += 1
        return loss.item()

    @torch.no_grad()
    def validate(self, val_loader):
        """Return the mean per-batch loss over *val_loader* (0 if it is empty)."""
        self.model.eval()
        total_loss = 0
        n_batches = 0

        for batch in val_loader:
            total_loss += self._batch_loss(batch).item()
            n_batches += 1

        return total_loss / max(n_batches, 1)

    @torch.no_grad()
    def generate_samples(self, n_per_category: int = 8):
        """Generate names for each monitoring condition.

        Args:
            n_per_category: ``batch_size`` passed to ``model.generate`` for
                every entry of ``_SAMPLE_CONFIGS``.

        Returns:
            Tuple ``(results, uniqueness, avg_len, all_names)``: ``results``
            maps ``"domain/style/lang"`` to the generated names, ``uniqueness``
            is the percentage of case-insensitively distinct names,
            ``avg_len`` is the mean name length, and ``all_names`` is the flat
            list of every generated name.
        """
        self.model.eval()

        all_names = []
        results = {}

        for domain, style, lang, length in self._SAMPLE_CONFIGS:
            names = self.model.generate(
                domain_id=DOMAIN_TO_ID[domain],
                style_id=STYLE_TO_ID[style],
                lang_id=LANG_TO_ID[lang],
                target_length=length,
                batch_size=n_per_category,
                cfg_scale=self.config.cfg_scale,
                temperature=self.config.temperature,
                n_steps=60,
                odd_alpha=self.config.odd_alpha,
                device=str(self.device),
            )
            results[f"{domain}/{style}/{lang}"] = names
            all_names.extend(names)

        # max(..., 1) keeps both ratios defined when nothing was generated.
        n_total = max(len(all_names), 1)
        uniqueness = len({n.lower() for n in all_names}) / n_total * 100
        avg_len = sum(len(n) for n in all_names) / n_total

        return results, uniqueness, avg_len, all_names

    def save_checkpoint(self, path: str, is_best: bool = False):
        """Write model, optimiser, scheduler and training state to *path*.

        Args:
            path: Destination file for ``torch.save``.
            is_best: When true, the same checkpoint is also written to
                ``neurolex_v4_best.pt`` inside ``args.save_dir``.
        """
        checkpoint = {
            'config': vars(self.config),
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'global_step': self.global_step,
            'best_loss': self.best_loss,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'vocab_size': self.tokenizer.vocab_size,
            'vocab': self.tokenizer.vocab,
        }
        torch.save(checkpoint, path)

        if is_best:
            best_path = os.path.join(self.args.save_dir, 'neurolex_v4_best.pt')
            torch.save(checkpoint, best_path)

    def _print_banner(self, steps_per_epoch, total_steps):
        """Print the run configuration shown before the first epoch."""
        print("=" * 70)
        print(" NEUROLEX v4 — TRAINING")
        print("=" * 70)
        print(f" Model size: {self.model.count_parameters():,} parameters")
        print(f" Epochs: {self.args.epochs}")
        print(f" Batch size: {self.args.batch_size}")
        print(f" Steps/epoch: {steps_per_epoch}")
        print(f" Total steps: {total_steps}")
        print(f" Learning rate: {self.args.lr} (cosine decay)")
        print(f" Warmup steps: {self.args.warmup_steps}")
        print(f" Device: {self.device}")
        print(f" CFG dropout: {self.config.cfg_dropout}")
        print(f" Noise schedule: {self.config.noise_schedule}")
        print("=" * 70)

    def _print_summary(self, total_time):
        """Print elapsed time, best validation loss and the train-loss curve."""
        print("=" * 70)
        print(" TRAINING COMPLETE")
        print("=" * 70)
        print(f" Total time: {total_time/60:.1f} minutes")
        print(f" Best val loss: {self.best_loss:.4f}")
        if self.train_losses:
            head = ' -> '.join(f'{value:.3f}' for value in self.train_losses[:5])
            print(f" Final train loss: {self.train_losses[-1]:.4f}")
            print(f" Loss curve: {head}")
        if len(self.train_losses) > 5:
            tail = ' -> '.join(f'{value:.3f}' for value in self.train_losses[-3:])
            print(f" ... -> {tail}")

    def _print_showcase(self):
        """Generate 12 names per condition and print up to 8 of each."""
        print("\n" + "=" * 70)
        print(" FINAL GENERATION SHOWCASE")
        print("=" * 70)
        results, uniqueness, avg_len, all_names = self.generate_samples(n_per_category=12)

        for key, names in results.items():
            if not names:
                continue
            domain, style, lang = key.split('/')
            emoji = self._DOMAIN_EMOJI.get(domain, '•')
            print(f"\n {emoji} {domain.upper()} ({style}, {lang}):")
            for name in names[:8]:
                print(f" → {name}")

        n_unique = len({n.lower() for n in all_names})
        print(f"\n {'─' * 50}")
        print(f" 📊 DIVERSITY SCORE: {uniqueness:.1f}%")
        print(f" 📏 AVG LENGTH: {avg_len:.1f} chars")
        print(f" 🎯 TOTAL UNIQUE: {n_unique}/{len(all_names)}")
        print("=" * 70)

    def train(self, train_loader, val_loader):
        """Train for ``args.epochs`` epochs and return the best validation loss.

        Each epoch: one pass over *train_loader*, validation on *val_loader*,
        a checkpoint, and -- every ``args.sample_every`` epochs and on the last
        epoch -- a printed batch of sample generations.
        """
        # NOTE(review): len() presumably fails for an un-sized (streaming)
        # loader -- confirm what create_dataloaders returns for --streaming.
        steps_per_epoch = len(train_loader)
        total_steps = self.args.epochs * steps_per_epoch
        self._print_banner(steps_per_epoch, total_steps)

        start_time = time.time()

        for epoch in range(1, self.args.epochs + 1):
            epoch_start = time.time()
            epoch_loss = 0
            n_steps_epoch = 0

            for batch in train_loader:
                loss = self.train_step(batch)
                epoch_loss += loss
                n_steps_epoch += 1

                if self.global_step % self.args.log_every == 0:
                    avg_loss = epoch_loss / n_steps_epoch
                    lr = self.scheduler.get_last_lr()[0]
                    elapsed = time.time() - start_time
                    # Linear extrapolation from the mean time per step so far.
                    eta = elapsed / self.global_step * (total_steps - self.global_step)

                    print(f" Step {self.global_step:>6}/{total_steps} | "
                          f"Loss={loss:.4f} | Avg={avg_loss:.4f} | "
                          f"LR={lr:.2e} | ETA={eta/60:.0f}min")

            epoch_time = time.time() - epoch_start
            # Guarded like validate(): an empty loader yields 0, not a crash.
            avg_epoch_loss = epoch_loss / max(n_steps_epoch, 1)
            self.train_losses.append(avg_epoch_loss)

            val_loss = self.validate(val_loader)
            self.val_losses.append(val_loss)

            print(f"\n Epoch {epoch}/{self.args.epochs} | "
                  f"Train={avg_epoch_loss:.4f} | Val={val_loss:.4f} | "
                  f"Time={epoch_time:.0f}s")

            is_best = val_loss < self.best_loss
            if is_best:
                self.best_loss = val_loss
                print(f" ✅ New best validation loss: {val_loss:.4f}")

            self.save_checkpoint(
                os.path.join(self.args.save_dir, f'neurolex_v4_epoch{epoch}.pt'),
                is_best=is_best,
            )

            if epoch % self.args.sample_every == 0 or epoch == self.args.epochs:
                print(f"\n --- Sample Generations (Epoch {epoch}) ---")
                results, uniqueness, avg_len, all_names = self.generate_samples(n_per_category=8)

                for key, names in results.items():
                    if names:
                        print(f" {key:30s}: {', '.join(names[:5])}")

                n_unique = len({n.lower() for n in all_names})
                print(f"\n Diversity: {uniqueness:.1f}% unique | Avg length: {avg_len:.1f}")
                print(f" Total generated: {len(all_names)} | Unique: {n_unique}")

            print()

        self._print_summary(time.time() - start_time)
        self._print_showcase()

        return self.best_loss
|
|
|
|
| |
| |
| |
|
|
def main():
    """Script entry point: parse options, build model and data, train, save.

    Seeds torch, selects the device (``--device auto`` picks CUDA when
    available), trains via ``Trainer.train``, writes the final checkpoint
    ``neurolex_v4_final.pt`` into ``--save_dir`` and prints a usage snippet.
    """
    args = parse_args()

    torch.manual_seed(args.seed)

    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)

    print(f"Using device: {device}")
    if device.type == 'cuda':
        # Query the device actually selected (not a hard-coded index 0). The
        # properties attribute is ``total_memory``; ``total_mem`` does not
        # exist and raised AttributeError.
        props = torch.cuda.get_device_properties(device)
        print(f" GPU: {props.name}")
        print(f" Memory: {props.total_memory / 1e9:.1f} GB")

    model, config = create_model(args.size)

    # Mirror the CLI hyper-parameters into the config so they are stored in
    # every checkpoint (save_checkpoint persists vars(config)).
    config.learning_rate = args.lr
    config.batch_size = args.batch_size
    config.epochs = args.epochs
    config.warmup_steps = args.warmup_steps

    train_loader, val_loader = create_dataloaders(
        batch_size=args.batch_size,
        n_samples=args.n_samples,
        num_workers=args.num_workers,
        streaming=args.streaming,
    )

    trainer = Trainer(model, config, args, device)
    best_loss = trainer.train(train_loader, val_loader)

    final_path = os.path.join(args.save_dir, 'neurolex_v4_final.pt')
    trainer.save_checkpoint(final_path)
    print(f"\nFinal model saved to: {final_path}")
    print(f"Best model saved to: {os.path.join(args.save_dir, 'neurolex_v4_best.pt')}")

    print(f"""
{'=' * 70}
HOW TO USE THE TRAINED MODEL
{'=' * 70}

# Load the model:
from neurolex_v4_model import NeuroLexV4, NeuroLexConfig, CharTokenizer
from neurolex_v4_model import DOMAIN_TO_ID, STYLE_TO_ID, LANG_TO_ID

checkpoint = torch.load('{final_path}')
config = NeuroLexConfig(**checkpoint['config'])
model = NeuroLexV4(config)
model.load_state_dict(checkpoint['state_dict'])
model.eval()

# Generate names:
names = model.generate(
    domain_id=DOMAIN_TO_ID['tech'],
    style_id=STYLE_TO_ID['sharp'],
    lang_id=LANG_TO_ID['english'],
    target_length=8,
    batch_size=20,
    cfg_scale=2.5, # higher = more condition-faithful
    temperature=0.9, # higher = more creative
    n_steps=80, # more steps = better quality
    odd_alpha=8.0, # higher = more diversity between samples
    device='cuda'
)
print(names)
""")
|
|
|
|
# Run training only when executed as a script (or via ``%run``), not on import.
if __name__ == "__main__":
    main()
|
|