| """ |
| NeuroName Training Script |
| |
| Implements a multi-stage training procedure: |
| |
| Stage 1: VAE Pretraining (reconstruct names from latent space) |
| - Train char_encoder + char_decoder with reconstruction loss |
| - KL annealing from 0 to target weight (prevents posterior collapse) |
| - Free bits strategy (minimum KL per dimension) |
| |
| Stage 2: Phonotactic Discriminator Training |
| - Train on real names (positive) vs random sequences (negative) |
| - Binary classification with balanced sampling |
| |
| Stage 3: Attribute Classifier Training |
| - Train style and language classifiers on latent representations |
| - Uses frozen encoder to get z, trains classifiers only |
| |
| Stage 4: Joint Fine-tuning |
| - All components trained together |
| - Full loss: reconstruction + KL + phonotactic + attribute classification |
| |
| Usage: |
| python train.py --config configs/default.yaml |
| python train.py --epochs 100 --batch_size 128 --lr 3e-4 |
| """ |
|
|
| import os |
| import sys |
| import math |
| import time |
| import argparse |
| import json |
| from pathlib import Path |
| from typing import Dict, Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR |
| from tqdm import tqdm |
|
|
| |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
| from neuroname.model import NeuroNameModel, CharVocab |
| from neuroname.config import NeuroNameConfig |
| from neuroname.data import ( |
| SemanticVocab, |
| NameDataset, |
| get_curated_brand_names, |
| get_synthetic_training_data, |
| create_dataloader, |
| ) |
| from neuroname.phonotactics import PhonotacticDataGenerator, PhonotacticScorer |
|
|
|
|
def parse_args():
    """Parse the training script's command-line options.

    Reads ``sys.argv`` and returns the ``argparse.Namespace`` with the
    attributes ``config``, ``epochs``, ``batch_size``, ``lr``, ``device``,
    ``save_dir``, ``num_train_samples``, ``log_every``, ``save_every`` and
    ``seed``.
    """
    # (flag, type, default, help) -- one row per supported option.
    option_specs = (
        ("--config", str, None, "Path to config YAML"),
        ("--epochs", int, 100, "Number of training epochs"),
        ("--batch_size", int, 64, "Batch size"),
        ("--lr", float, 3e-4, "Learning rate"),
        ("--device", str, "auto", "Device (cpu/cuda/auto)"),
        ("--save_dir", str, "checkpoints", "Save directory"),
        ("--num_train_samples", int, 5000, "Number of training samples"),
        ("--log_every", int, 50, "Log every N steps"),
        ("--save_every", int, 10, "Save every N epochs"),
        ("--seed", int, 42, "Random seed"),
    )

    cli = argparse.ArgumentParser(description="Train NeuroName model")
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli.parse_args()
|
|
|
|
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch random number generators.

    The CUDA generators are seeded as well whenever CUDA is available.
    """
    import random
    import numpy as np

    # Same seeding order as before: stdlib, NumPy, then torch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
|
def get_kl_weight(step: int, config: NeuroNameConfig) -> float:
    """Return the KL weight for training step *step* (linear warm-up).

    The weight moves linearly from ``config.kl_weight_start`` at step 0 to
    ``config.kl_weight_end`` at ``config.kl_anneal_steps`` and stays at
    ``config.kl_weight_end`` for every later step. The schedule does not
    restart; it is a single ramp.
    """
    anneal_steps = config.kl_anneal_steps
    if step >= anneal_steps:
        # Warm-up finished (this also covers anneal_steps == 0 at step 0).
        return config.kl_weight_end

    progress = step / anneal_steps
    return config.kl_weight_start + progress * (config.kl_weight_end - config.kl_weight_start)
|
|
|
|
def free_bits_kl(kl_per_dim: torch.Tensor, free_bits: float) -> torch.Tensor:
    """Reduce a per-dimension KL tensor to a scalar with a free-bits floor.

    Every element of ``kl_per_dim`` is first raised to at least ``free_bits``
    nats. The floored values are summed over the last axis and the resulting
    sums are averaged into a single scalar. Elements below the floor become
    the constant ``free_bits`` and therefore receive no gradient, which is
    what keeps the optimizer from pushing those dimensions' KL toward zero.
    """
    floored = kl_per_dim.clamp(min=free_bits)
    summed_over_latent = floored.sum(dim=-1)
    return summed_over_latent.mean()
|
|
|
|
class Trainer:
    """Training driver for the NeuroName model.

    Owns the model, its vocabularies, three AdamW optimizers (VAE core,
    phonotactic discriminator, attribute classifiers) and the running
    training state: ``global_step``, ``best_loss`` and the per-epoch
    ``history`` lists.
    """

    # Batch entries that are moved to the device and forwarded to the model
    # as keyword arguments of the same name.
    _MODEL_INPUT_KEYS = (
        "char_ids",
        "hint_ids",
        "target_length",
        "style",
        "language_feel",
        "energy",
        "char_padding_mask",
        "hint_padding_mask",
    )

    def __init__(self, config: NeuroNameConfig, device: str = "cpu"):
        """Build the model, optimizers and helper objects.

        Args:
            config: Hyper-parameters; ``config.to_dict()`` is handed to the model.
            device: Device string the model and all batches are moved to.

        Prints the total parameter count and the per-component breakdown
        reported by ``model.count_parameters()``.
        """
        self.config = config
        self.device = device

        self.model = NeuroNameModel(config.to_dict()).to(device)
        total_params = sum(p.numel() for p in self.model.parameters())
        print(f"Model initialized with {total_params:,} parameters")
        print("Parameter breakdown:")
        for name, count in self.model.count_parameters().items():
            print(f" {name}: {count:,}")

        self.char_vocab = self.model.char_vocab
        self.semantic_vocab = SemanticVocab()

        # VAE core: both encoders, the char encoder/decoder and the prior net.
        self.vae_optimizer = AdamW(
            list(self.model.semantic_encoder.parameters())
            + list(self.model.control_encoder.parameters())
            + list(self.model.char_encoder.parameters())
            + list(self.model.char_decoder.parameters())
            + list(self.model.prior_net.parameters()),
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
            betas=(0.9, 0.999),
        )

        # The discriminator has its own learning rate (config.phon_lr).
        self.disc_optimizer = AdamW(
            self.model.phonotactic_disc.parameters(),
            lr=config.phon_lr,
            weight_decay=0.01,
        )

        self.cls_optimizer = AdamW(
            list(self.model.style_classifier.parameters())
            + list(self.model.lang_classifier.parameters()),
            lr=config.learning_rate,
            weight_decay=0.01,
        )

        self.phon_generator = PhonotacticDataGenerator()
        self.phon_scorer = PhonotacticScorer()

        self.global_step = 0
        self.best_loss = float("inf")
        self.history = {"train_loss": [], "val_loss": [], "recon_loss": [], "kl_loss": []}

    def _forward_batch(self, batch):
        """Move the model inputs of *batch* to ``self.device`` and run the model."""
        inputs = {key: batch[key].to(self.device) for key in self._MODEL_INPUT_KEYS}
        return self.model(**inputs)

    def train_epoch_vae(self, dataloader, epoch: int) -> Dict[str, float]:
        """Run one training epoch over *dataloader*.

        The optimized loss is ``recon_loss + kl_weight * kl_loss +
        style_loss_weight * style_loss + lang_loss_weight * lang_loss``, all
        taken from the model's output dict. Both the VAE optimizer and the
        classifier optimizer are stepped; gradients of all model parameters
        are clipped to ``config.max_grad_norm``.

        Args:
            dataloader: Iterable of batches containing the model input keys.
            epoch: Current epoch number (not used by the computation).

        Returns:
            Per-batch averages of ``loss``, ``recon_loss``, ``kl_loss``,
            ``style_loss`` and ``lang_loss`` plus the last ``kl_weight`` used.
            With an empty dataloader the averages are 0 and ``kl_weight`` is
            the weight for the current global step.
        """
        self.model.train()
        totals = {
            "loss": 0.0,
            "recon_loss": 0.0,
            "kl_loss": 0.0,
            "style_loss": 0.0,
            "lang_loss": 0.0,
        }
        num_batches = 0
        # Defined before the loop so an empty dataloader still produces a
        # complete metrics dict instead of an unbound-variable error.
        kl_weight = get_kl_weight(self.global_step, self.config)

        for batch in dataloader:
            outputs = self._forward_batch(batch)
            kl_weight = get_kl_weight(self.global_step, self.config)

            # NOTE(review): the model's own kl_loss is used as-is; the
            # free-bits floor (free_bits_kl) is not applied in this loop.
            kl_loss = outputs["kl_loss"]

            loss = (
                outputs["recon_loss"]
                + kl_weight * kl_loss
                + self.config.style_loss_weight * outputs["style_loss"]
                + self.config.lang_loss_weight * outputs["lang_loss"]
            )

            self.vae_optimizer.zero_grad()
            self.cls_optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)

            self.vae_optimizer.step()
            self.cls_optimizer.step()

            totals["loss"] += loss.item()
            totals["recon_loss"] += outputs["recon_loss"].item()
            totals["kl_loss"] += kl_loss.item()
            totals["style_loss"] += outputs["style_loss"].item()
            totals["lang_loss"] += outputs["lang_loss"].item()
            num_batches += 1
            self.global_step += 1

        denom = max(num_batches, 1)
        metrics = {key: value / denom for key, value in totals.items()}
        metrics["kl_weight"] = kl_weight
        return metrics

    def train_phonotactic_discriminator(self, real_names: list, num_steps: int = 500):
        """Train the phonotactic discriminator for *num_steps* steps.

        Each step draws a labelled batch of 64 names from the phonotactic
        data generator, encodes it with the character vocabulary and
        minimizes binary cross-entropy on the discriminator logits. Only the
        discriminator optimizer is stepped.

        Args:
            real_names: Names handed to the generator's ``generate_batch``.
            num_steps: Number of optimization steps; 0 performs no updates.

        Returns:
            ``{"phon_loss": mean loss, "phon_acc": mean accuracy}``; both are
            0.0 when ``num_steps`` is 0.
        """
        self.model.phonotactic_disc.train()

        total_loss = 0.0
        total_acc = 0.0

        for _ in range(num_steps):
            names, labels = self.phon_generator.generate_batch(real_names, batch_size=64)

            char_ids = self.char_vocab.batch_encode(names, max_len=self.config.max_len)
            char_ids = char_ids.to(self.device)
            labels_tensor = torch.tensor(labels, dtype=torch.float, device=self.device)
            padding_mask = char_ids == self.char_vocab.pad_idx

            scores = self.model.phonotactic_disc(char_ids, padding_mask).squeeze(-1)
            loss = F.binary_cross_entropy_with_logits(scores, labels_tensor)

            self.disc_optimizer.zero_grad()
            loss.backward()
            self.disc_optimizer.step()

            # Logit > 0 is equivalent to sigmoid probability > 0.5.
            preds = (scores > 0).float()
            total_acc += (preds == labels_tensor).float().mean().item()
            total_loss += loss.item()

        # Guard the averages so num_steps == 0 does not divide by zero.
        denom = max(num_steps, 1)
        avg_loss = total_loss / denom
        avg_acc = total_acc / denom
        print(f" Phonotactic Discriminator - Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.3f}")
        return {"phon_loss": avg_loss, "phon_acc": avg_acc}

    @torch.no_grad()
    def validate(self, dataloader) -> Dict[str, float]:
        """Evaluate the model on *dataloader* without gradient tracking.

        The validation loss is ``recon_loss + 0.1 * kl_loss``; the 0.1 KL
        weight is fixed and independent of the annealing schedule.

        Returns:
            ``{"val_loss": ..., "val_recon": ...}`` averaged over batches
            (0 for an empty dataloader).
        """
        self.model.eval()
        total_loss = 0.0
        total_recon = 0.0
        num_batches = 0

        for batch in dataloader:
            outputs = self._forward_batch(batch)
            loss = outputs["recon_loss"] + 0.1 * outputs["kl_loss"]
            total_loss += loss.item()
            total_recon += outputs["recon_loss"].item()
            num_batches += 1

        denom = max(num_batches, 1)
        return {
            "val_loss": total_loss / denom,
            "val_recon": total_recon / denom,
        }

    @torch.no_grad()
    def generate_samples(self, num_samples: int = 5) -> list:
        """Generate names for a fixed set of hint/style presets.

        Args:
            num_samples: How many of the five built-in presets to use.

        Returns:
            One dict per preset with keys ``hints``, ``style`` and
            ``generated`` (the decoded output of ``generate_from_prior``,
            which is called with ``num_samples=3`` and ``temperature=0.8``).
        """
        self.model.eval()

        hints_list = [
            ["speed", "technology", "future"],
            ["nature", "calm", "harmony"],
            ["gaming", "epic", "adventure"],
            ["luxury", "elegance", "premium"],
            ["creative", "art", "design"],
        ]
        styles = ["techy", "organic", "playful", "elegant", "modern"]

        samples = []
        for hints, style in zip(hints_list[:num_samples], styles[:num_samples]):
            hint_ids = self.semantic_vocab.encode(hints)
            hint_ids = torch.tensor([hint_ids], dtype=torch.long, device=self.device)
            hint_mask = hint_ids == self.semantic_vocab.pad_idx

            # Unknown styles fall back to index 0.
            style_idx = torch.tensor(
                [NameDataset.STYLE_MAP.get(style, 0)],
                dtype=torch.long,
                device=self.device,
            )
            lang_idx = torch.tensor([0], dtype=torch.long, device=self.device)
            energy_idx = torch.tensor([1], dtype=torch.long, device=self.device)
            target_len = torch.tensor([[0.25]], dtype=torch.float, device=self.device)

            generated = self.model.generate_from_prior(
                hint_ids=hint_ids,
                target_length=target_len,
                style=style_idx,
                language_feel=lang_idx,
                energy=energy_idx,
                hint_padding_mask=hint_mask,
                temperature=0.8,
                num_samples=3,
            )

            samples.append({
                "hints": hints,
                "style": style,
                "generated": self.char_vocab.batch_decode(generated),
            })

        return samples

    def save_checkpoint(self, path: str, epoch: int):
        """Write a checkpoint to *path*.

        The checkpoint holds the epoch, global step, model weights, the state
        of all three optimizers, the config dict, the history and the best
        loss. The parent directory is created when *path* has one; a bare
        filename is written to the current working directory.
        """
        directory = os.path.dirname(path)
        if directory:
            # os.makedirs("") raises, so only create a real parent directory.
            os.makedirs(directory, exist_ok=True)

        torch.save({
            "epoch": epoch,
            "global_step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "vae_optimizer": self.vae_optimizer.state_dict(),
            # Stored so a resumed run can restore every optimizer, not only
            # the VAE one.
            "disc_optimizer": self.disc_optimizer.state_dict(),
            "cls_optimizer": self.cls_optimizer.state_dict(),
            "config": self.config.to_dict(),
            "history": self.history,
            "best_loss": self.best_loss,
        }, path)
        print(f" Checkpoint saved to {path}")

    def train(
        self,
        train_data: list,
        val_data: Optional[list] = None,
        num_epochs: int = 100,
        save_dir: str = "checkpoints",
        save_every: int = 10,
    ):
        """Run the full training loop.

        Every epoch trains once through ``train_epoch_vae`` (reconstruction,
        annealed KL, style and language losses). In addition:

        - At the start of epoch ``max(1, int(0.3 * num_epochs))`` the
          phonotactic discriminator is trained once for 500 steps.
        - Every 5th epoch the model is validated when ``val_data`` is given.
        - On epoch 1 and every 10th epoch sample generations are printed.
        - Every ``save_every`` epochs and on the final epoch a checkpoint is
          written to ``save_dir``; whenever the epoch's training loss beats
          ``self.best_loss``, ``best_model.pt`` is written as well.

        Args:
            train_data: Training items; each must provide a ``"name"`` entry.
            val_data: Optional validation items.
            num_epochs: Number of epochs to run.
            save_dir: Directory that receives the checkpoints.
            save_every: Periodic checkpoint interval in epochs.
        """
        print("=" * 60)
        print("NeuroName Training")
        print("=" * 60)
        print(f"Training samples: {len(train_data)}")
        print(f"Epochs: {num_epochs}")
        print(f"Device: {self.device}")
        print()

        train_loader = create_dataloader(
            train_data, self.char_vocab, self.semantic_vocab,
            batch_size=self.config.batch_size, shuffle=True,
        )

        val_loader = None
        if val_data:
            val_loader = create_dataloader(
                val_data, self.char_vocab, self.semantic_vocab,
                batch_size=self.config.batch_size, shuffle=False,
            )

        real_names = [item["name"] for item in train_data]

        # Clamp to 1 so short runs (num_epochs < 4) still train the
        # discriminator; int(num_epochs * 0.3) would be 0 and never match.
        stage2_epoch = max(1, int(num_epochs * 0.3))

        for epoch in range(1, num_epochs + 1):
            start_time = time.time()

            if epoch == stage2_epoch:
                print("\n" + "=" * 40)
                print("Stage 2: Training Phonotactic Discriminator")
                print("=" * 40)
                self.train_phonotactic_discriminator(real_names, num_steps=500)
                print()

            metrics = self.train_epoch_vae(train_loader, epoch)

            val_metrics = {}
            if val_loader and epoch % 5 == 0:
                val_metrics = self.validate(val_loader)

            elapsed = time.time() - start_time
            print(
                f"Epoch {epoch:3d}/{num_epochs} | "
                f"Loss: {metrics['loss']:.4f} | "
                f"Recon: {metrics['recon_loss']:.4f} | "
                f"KL: {metrics['kl_loss']:.4f} | "
                f"KL_w: {metrics['kl_weight']:.4f} | "
                f"Style: {metrics['style_loss']:.4f} | "
                f"Time: {elapsed:.1f}s"
                + (f" | Val: {val_metrics.get('val_loss', 0):.4f}" if val_metrics else "")
            )

            self.history["train_loss"].append(metrics["loss"])
            self.history["recon_loss"].append(metrics["recon_loss"])
            self.history["kl_loss"].append(metrics["kl_loss"])
            if val_metrics:
                self.history["val_loss"].append(val_metrics["val_loss"])

            if epoch % 10 == 0 or epoch == 1:
                print("\n Sample generations:")
                for s in self.generate_samples():
                    names_str = ", ".join(s["generated"][:3])
                    print(f" [{s['style']}] {s['hints']} → {names_str}")
                print()

            if epoch % save_every == 0 or epoch == num_epochs:
                path = os.path.join(save_dir, f"checkpoint_epoch_{epoch}.pt")
                self.save_checkpoint(path, epoch)

            # "Best" is judged on the training loss, which includes the
            # annealed KL term.
            if metrics["loss"] < self.best_loss:
                self.best_loss = metrics["loss"]
                path = os.path.join(save_dir, "best_model.pt")
                self.save_checkpoint(path, epoch)

        print("\n" + "=" * 60)
        print("Training complete!")
        print(f"Best loss: {self.best_loss:.4f}")
        print("=" * 60)
|
|
|
|
def _cli_flag_given(flag: str) -> bool:
    """Return True when *flag* (or an argparse prefix of it) appears in sys.argv."""
    for token in sys.argv[1:]:
        # Accept both "--flag value" and "--flag=value" spellings.
        name = token.split("=", 1)[0]
        if len(name) > 2 and name.startswith("--") and flag.startswith(name):
            return True
    return False


def main():
    """Script entry point: parse options, build the data and train.

    Steps: parse the CLI options, seed the RNGs, load the config (from
    ``--config`` when given, otherwise the defaults), apply CLI overrides,
    pick the device, generate synthetic data with a 90/10 train/validation
    split, train, and finally write the config to ``<save_dir>/config.json``.

    ``--epochs``, ``--batch_size`` and ``--lr`` always apply when no config
    file is given. With ``--config`` they only override the loaded values
    when they were actually passed on the command line, so argparse defaults
    do not silently replace what the config file says.
    """
    args = parse_args()
    set_seed(args.seed)

    if args.config:
        config = NeuroNameConfig.load(args.config)
    else:
        config = NeuroNameConfig()

    cli_overrides = (
        ("num_epochs", "--epochs", args.epochs),
        ("batch_size", "--batch_size", args.batch_size),
        ("learning_rate", "--lr", args.lr),
    )
    for attr, flag, value in cli_overrides:
        if not args.config or _cli_flag_given(flag):
            setattr(config, attr, value)

    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    print(f"Using device: {device}")

    print("Generating training data...")
    train_data = get_synthetic_training_data(num_samples=args.num_train_samples, seed=args.seed)

    # Last 10% of the generated samples are held out for validation.
    split_idx = int(len(train_data) * 0.9)
    val_data = train_data[split_idx:]
    train_data = train_data[:split_idx]
    print(f"Train: {len(train_data)} samples, Val: {len(val_data)} samples")

    trainer = Trainer(config, device=device)
    trainer.train(
        train_data=train_data,
        val_data=val_data,
        num_epochs=config.num_epochs,
        save_dir=args.save_dir,
        save_every=args.save_every,
    )

    # Training only creates save_dir when it writes a checkpoint (e.g. not
    # with zero epochs), so make sure it exists before saving the config.
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
    config.save_json(os.path.join(args.save_dir, "config.json"))
    print(f"\nConfig saved to {args.save_dir}/config.json")


if __name__ == "__main__":
    main()
|
|