""" NeuroName Training Script Implements a multi-stage training procedure: Stage 1: VAE Pretraining (reconstruct names from latent space) - Train char_encoder + char_decoder with reconstruction loss - KL annealing from 0 to target weight (prevents posterior collapse) - Free bits strategy (minimum KL per dimension) Stage 2: Phonotactic Discriminator Training - Train on real names (positive) vs random sequences (negative) - Binary classification with balanced sampling Stage 3: Attribute Classifier Training - Train style and language classifiers on latent representations - Uses frozen encoder to get z, trains classifiers only Stage 4: Joint Fine-tuning - All components trained together - Full loss: reconstruction + KL + phonotactic + attribute classification Usage: python train.py --config configs/default.yaml python train.py --epochs 100 --batch_size 128 --lr 3e-4 """ import os import sys import math import time import argparse import json from pathlib import Path from typing import Dict, Optional import torch import torch.nn as nn import torch.nn.functional as F from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from tqdm import tqdm # Add parent to path for imports sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from neuroname.model import NeuroNameModel, CharVocab from neuroname.config import NeuroNameConfig from neuroname.data import ( SemanticVocab, NameDataset, get_curated_brand_names, get_synthetic_training_data, create_dataloader, ) from neuroname.phonotactics import PhonotacticDataGenerator, PhonotacticScorer def parse_args(): parser = argparse.ArgumentParser(description="Train NeuroName model") parser.add_argument("--config", type=str, default=None, help="Path to config YAML") parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs") parser.add_argument("--batch_size", type=int, default=64, help="Batch size") parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") parser.add_argument("--device", type=str, default="auto", help="Device (cpu/cuda/auto)") parser.add_argument("--save_dir", type=str, default="checkpoints", help="Save directory") parser.add_argument("--num_train_samples", type=int, default=5000, help="Number of training samples") parser.add_argument("--log_every", type=int, default=50, help="Log every N steps") parser.add_argument("--save_every", type=int, default=10, help="Save every N epochs") parser.add_argument("--seed", type=int, default=42, help="Random seed") return parser.parse_args() def set_seed(seed: int): """Set random seeds for reproducibility.""" import random import numpy as np random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def get_kl_weight(step: int, config: NeuroNameConfig) -> float: """Cyclical KL annealing schedule.""" if step >= config.kl_anneal_steps: return config.kl_weight_end # Linear annealing ratio = step / config.kl_anneal_steps return config.kl_weight_start + ratio * (config.kl_weight_end - config.kl_weight_start) def free_bits_kl(kl_per_dim: torch.Tensor, free_bits: float) -> torch.Tensor: """Apply free bits strategy to prevent KL collapse. Free bits: allow each latent dimension at least `free_bits` nats of KL before it contributes to the loss. This prevents the model from ignoring the latent space entirely (a common VAE failure mode). """ return torch.clamp(kl_per_dim, min=free_bits).sum(dim=-1).mean() class Trainer: """Complete training loop for NeuroName.""" def __init__(self, config: NeuroNameConfig, device: str = "cpu"): self.config = config self.device = device # Initialize model self.model = NeuroNameModel(config.to_dict()).to(device) print(f"Model initialized with {sum(p.numel() for p in self.model.parameters()):,} parameters") print(f"Parameter breakdown:") for name, count in self.model.count_parameters().items(): print(f" {name}: {count:,}") # Vocabularies self.char_vocab = self.model.char_vocab self.semantic_vocab = SemanticVocab() # Optimizers (separate for different components) self.vae_optimizer = AdamW( list(self.model.semantic_encoder.parameters()) + list(self.model.control_encoder.parameters()) + list(self.model.char_encoder.parameters()) + list(self.model.char_decoder.parameters()) + list(self.model.prior_net.parameters()), lr=config.learning_rate, weight_decay=config.weight_decay, betas=(0.9, 0.999), ) self.disc_optimizer = AdamW( self.model.phonotactic_disc.parameters(), lr=config.phon_lr, weight_decay=0.01, ) self.cls_optimizer = AdamW( list(self.model.style_classifier.parameters()) + list(self.model.lang_classifier.parameters()), lr=config.learning_rate, weight_decay=0.01, ) # Phonotactic data generator self.phon_generator = PhonotacticDataGenerator() self.phon_scorer = PhonotacticScorer() # Training state self.global_step = 0 self.best_loss = float("inf") self.history = {"train_loss": [], "val_loss": [], "recon_loss": [], "kl_loss": []} def train_epoch_vae(self, dataloader, epoch: int) -> Dict[str, float]: """Train one epoch of the VAE (Stage 1 or Stage 4).""" self.model.train() total_loss = 0 total_recon = 0 total_kl = 0 total_style = 0 total_lang = 0 num_batches = 0 for batch in dataloader: # Move to device char_ids = batch["char_ids"].to(self.device) hint_ids = batch["hint_ids"].to(self.device) target_length = batch["target_length"].to(self.device) style = batch["style"].to(self.device) language_feel = batch["language_feel"].to(self.device) energy = batch["energy"].to(self.device) char_padding_mask = batch["char_padding_mask"].to(self.device) hint_padding_mask = batch["hint_padding_mask"].to(self.device) # Forward pass outputs = self.model( char_ids=char_ids, hint_ids=hint_ids, target_length=target_length, style=style, language_feel=language_feel, energy=energy, char_padding_mask=char_padding_mask, hint_padding_mask=hint_padding_mask, ) # KL annealing weight kl_weight = get_kl_weight(self.global_step, self.config) # Apply free bits to KL kl_per_dim = 0.5 * ( outputs["p_mu"].detach() - outputs["q_logvar"] # Simplified for per-dim + (torch.exp(outputs["q_logvar"]) + (outputs["q_mu"] - outputs["p_mu"].detach()).pow(2)) / torch.exp(outputs["p_mu"].detach()).clamp(min=1e-8) - 1.0 ) # Use standard KL from model for simplicity kl_loss = outputs["kl_loss"] # Total loss loss = ( outputs["recon_loss"] + kl_weight * kl_loss + self.config.style_loss_weight * outputs["style_loss"] + self.config.lang_loss_weight * outputs["lang_loss"] ) # Backward pass self.vae_optimizer.zero_grad() self.cls_optimizer.zero_grad() loss.backward() # Gradient clipping torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm) self.vae_optimizer.step() self.cls_optimizer.step() # Logging total_loss += loss.item() total_recon += outputs["recon_loss"].item() total_kl += kl_loss.item() total_style += outputs["style_loss"].item() total_lang += outputs["lang_loss"].item() num_batches += 1 self.global_step += 1 # Periodic logging if self.global_step % self.config.batch_size == 0: pass # tqdm handles this return { "loss": total_loss / max(num_batches, 1), "recon_loss": total_recon / max(num_batches, 1), "kl_loss": total_kl / max(num_batches, 1), "style_loss": total_style / max(num_batches, 1), "lang_loss": total_lang / max(num_batches, 1), "kl_weight": kl_weight, } def train_phonotactic_discriminator(self, real_names: list, num_steps: int = 500): """Train the phonotactic discriminator (Stage 2).""" self.model.phonotactic_disc.train() total_loss = 0 total_acc = 0 for step in range(num_steps): # Generate balanced batch names, labels = self.phon_generator.generate_batch(real_names, batch_size=64) # Encode characters char_ids = self.char_vocab.batch_encode(names, max_len=self.config.max_len) char_ids = char_ids.to(self.device) labels_tensor = torch.tensor(labels, dtype=torch.float, device=self.device) padding_mask = (char_ids == self.char_vocab.pad_idx) # Forward scores = self.model.phonotactic_disc(char_ids, padding_mask).squeeze(-1) loss = F.binary_cross_entropy_with_logits(scores, labels_tensor) # Backward self.disc_optimizer.zero_grad() loss.backward() self.disc_optimizer.step() # Accuracy preds = (scores > 0).float() acc = (preds == labels_tensor).float().mean().item() total_loss += loss.item() total_acc += acc avg_loss = total_loss / num_steps avg_acc = total_acc / num_steps print(f" Phonotactic Discriminator - Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.3f}") return {"phon_loss": avg_loss, "phon_acc": avg_acc} @torch.no_grad() def validate(self, dataloader) -> Dict[str, float]: """Validate the model.""" self.model.eval() total_loss = 0 total_recon = 0 num_batches = 0 for batch in dataloader: char_ids = batch["char_ids"].to(self.device) hint_ids = batch["hint_ids"].to(self.device) target_length = batch["target_length"].to(self.device) style = batch["style"].to(self.device) language_feel = batch["language_feel"].to(self.device) energy = batch["energy"].to(self.device) char_padding_mask = batch["char_padding_mask"].to(self.device) hint_padding_mask = batch["hint_padding_mask"].to(self.device) outputs = self.model( char_ids=char_ids, hint_ids=hint_ids, target_length=target_length, style=style, language_feel=language_feel, energy=energy, char_padding_mask=char_padding_mask, hint_padding_mask=hint_padding_mask, ) loss = outputs["recon_loss"] + 0.1 * outputs["kl_loss"] total_loss += loss.item() total_recon += outputs["recon_loss"].item() num_batches += 1 return { "val_loss": total_loss / max(num_batches, 1), "val_recon": total_recon / max(num_batches, 1), } @torch.no_grad() def generate_samples(self, num_samples: int = 5) -> list: """Generate sample names for monitoring.""" self.model.eval() hints_list = [ ["speed", "technology", "future"], ["nature", "calm", "harmony"], ["gaming", "epic", "adventure"], ["luxury", "elegance", "premium"], ["creative", "art", "design"], ] styles = ["techy", "organic", "playful", "elegant", "modern"] samples = [] for hints, style in zip(hints_list[:num_samples], styles[:num_samples]): hint_ids = self.semantic_vocab.encode(hints) hint_ids = torch.tensor([hint_ids], dtype=torch.long, device=self.device) hint_mask = (hint_ids == self.semantic_vocab.pad_idx) style_idx = torch.tensor([NameDataset.STYLE_MAP.get(style, 0)], dtype=torch.long, device=self.device) lang_idx = torch.tensor([0], dtype=torch.long, device=self.device) energy_idx = torch.tensor([1], dtype=torch.long, device=self.device) target_len = torch.tensor([[0.25]], dtype=torch.float, device=self.device) generated = self.model.generate_from_prior( hint_ids=hint_ids, target_length=target_len, style=style_idx, language_feel=lang_idx, energy=energy_idx, hint_padding_mask=hint_mask, temperature=0.8, num_samples=3, ) names = self.char_vocab.batch_decode(generated) samples.append({ "hints": hints, "style": style, "generated": names, }) return samples def save_checkpoint(self, path: str, epoch: int): """Save model checkpoint.""" os.makedirs(os.path.dirname(path), exist_ok=True) torch.save({ "epoch": epoch, "global_step": self.global_step, "model_state_dict": self.model.state_dict(), "vae_optimizer": self.vae_optimizer.state_dict(), "config": self.config.to_dict(), "history": self.history, "best_loss": self.best_loss, }, path) print(f" Checkpoint saved to {path}") def train( self, train_data: list, val_data: Optional[list] = None, num_epochs: int = 100, save_dir: str = "checkpoints", save_every: int = 10, ): """ Full training procedure. Stage 1 (epochs 1-40%): VAE pretraining with KL annealing Stage 2 (after stage 1): Phonotactic discriminator training Stage 3 (epochs 40%-100%): Joint training with all losses """ print("=" * 60) print("NeuroName Training") print("=" * 60) print(f"Training samples: {len(train_data)}") print(f"Epochs: {num_epochs}") print(f"Device: {self.device}") print() # Create dataloaders train_loader = create_dataloader( train_data, self.char_vocab, self.semantic_vocab, batch_size=self.config.batch_size, shuffle=True, ) val_loader = None if val_data: val_loader = create_dataloader( val_data, self.char_vocab, self.semantic_vocab, batch_size=self.config.batch_size, shuffle=False, ) # Collect real names for phonotactic training real_names = [item["name"] for item in train_data] # Training loop stage2_done = False stage2_epoch = int(num_epochs * 0.3) for epoch in range(1, num_epochs + 1): start_time = time.time() # === Stage transition: train phonotactic discriminator === if epoch == stage2_epoch and not stage2_done: print("\n" + "=" * 40) print("Stage 2: Training Phonotactic Discriminator") print("=" * 40) self.train_phonotactic_discriminator(real_names, num_steps=500) stage2_done = True print() # === Main VAE training === metrics = self.train_epoch_vae(train_loader, epoch) # Validation val_metrics = {} if val_loader and epoch % 5 == 0: val_metrics = self.validate(val_loader) # Logging elapsed = time.time() - start_time print( f"Epoch {epoch:3d}/{num_epochs} | " f"Loss: {metrics['loss']:.4f} | " f"Recon: {metrics['recon_loss']:.4f} | " f"KL: {metrics['kl_loss']:.4f} | " f"KL_w: {metrics['kl_weight']:.4f} | " f"Style: {metrics['style_loss']:.4f} | " f"Time: {elapsed:.1f}s" + (f" | Val: {val_metrics.get('val_loss', 0):.4f}" if val_metrics else "") ) # Track history self.history["train_loss"].append(metrics["loss"]) self.history["recon_loss"].append(metrics["recon_loss"]) self.history["kl_loss"].append(metrics["kl_loss"]) if val_metrics: self.history["val_loss"].append(val_metrics["val_loss"]) # Generate samples periodically if epoch % 10 == 0 or epoch == 1: print("\n Sample generations:") samples = self.generate_samples() for s in samples: names_str = ", ".join(s["generated"][:3]) print(f" [{s['style']}] {s['hints']} → {names_str}") print() # Save checkpoint if epoch % save_every == 0 or epoch == num_epochs: path = os.path.join(save_dir, f"checkpoint_epoch_{epoch}.pt") self.save_checkpoint(path, epoch) # Save best if metrics["loss"] < self.best_loss: self.best_loss = metrics["loss"] path = os.path.join(save_dir, "best_model.pt") self.save_checkpoint(path, epoch) print("\n" + "=" * 60) print("Training complete!") print(f"Best loss: {self.best_loss:.4f}") print("=" * 60) def main(): args = parse_args() set_seed(args.seed) # Configuration if args.config: config = NeuroNameConfig.load(args.config) else: config = NeuroNameConfig() # Override from command line config.num_epochs = args.epochs config.batch_size = args.batch_size config.learning_rate = args.lr # Device if args.device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" else: device = args.device print(f"Using device: {device}") # Generate training data print("Generating training data...") train_data = get_synthetic_training_data(num_samples=args.num_train_samples, seed=args.seed) # Split into train/val (90/10) split_idx = int(len(train_data) * 0.9) val_data = train_data[split_idx:] train_data = train_data[:split_idx] print(f"Train: {len(train_data)} samples, Val: {len(val_data)} samples") # Train trainer = Trainer(config, device=device) trainer.train( train_data=train_data, val_data=val_data, num_epochs=config.num_epochs, save_dir=args.save_dir, save_every=args.save_every, ) # Save final configuration config.save_json(os.path.join(args.save_dir, "config.json")) print(f"\nConfig saved to {args.save_dir}/config.json") if __name__ == "__main__": main()