# Source: ml-intern / neuroname / train.py (uploaded by asdf98)
# Commit b41f1fb (verified): Add complete training script with staged training,
# KL annealing, and all losses
"""
NeuroName Training Script

Implements a multi-stage training procedure:

Stage 1: VAE Pretraining (reconstruct names from latent space)
    - Train char_encoder + char_decoder with reconstruction loss
    - KL annealing from the start weight to the target weight
      (prevents posterior collapse)
    - Free bits strategy (minimum KL per dimension)
Stage 2: Phonotactic Discriminator Training
    - Train on real names (positive) vs random sequences (negative)
    - Binary classification with balanced sampling
Stage 3: Attribute Classifier Training
    - Train style and language classifiers on latent representations
    - Uses frozen encoder to get z, trains classifiers only
Stage 4: Joint Fine-tuning
    - All components trained together
    - Full loss: reconstruction + KL + phonotactic + attribute classification

Note: the list above describes the intended procedure. As implemented in this
file, Trainer.train() optimises reconstruction + KL + style + language losses
jointly in every epoch and trains the phonotactic discriminator once, at
roughly 30% of the epochs; the free-bits helper is defined but is not applied
to the loss.

Usage:
    python train.py --config configs/default.yaml
    python train.py --epochs 100 --batch_size 128 --lr 3e-4
"""
import os
import sys
import math
import time
import argparse
import json
from pathlib import Path
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from tqdm import tqdm
# Add parent to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from neuroname.model import NeuroNameModel, CharVocab
from neuroname.config import NeuroNameConfig
from neuroname.data import (
SemanticVocab,
NameDataset,
get_curated_brand_names,
get_synthetic_training_data,
create_dataloader,
)
from neuroname.phonotactics import PhonotacticDataGenerator, PhonotacticScorer
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``config``, ``epochs``,
        ``batch_size``, ``lr``, ``device``, ``save_dir``,
        ``num_train_samples``, ``log_every``, ``save_every`` and ``seed``.
    """
    # (flag, type, default, help) -- one row per supported option.
    option_specs = (
        ("--config", str, None, "Path to config YAML"),
        ("--epochs", int, 100, "Number of training epochs"),
        ("--batch_size", int, 64, "Batch size"),
        ("--lr", float, 3e-4, "Learning rate"),
        ("--device", str, "auto", "Device (cpu/cuda/auto)"),
        ("--save_dir", str, "checkpoints", "Save directory"),
        ("--num_train_samples", int, 5000, "Number of training samples"),
        ("--log_every", int, 50, "Log every N steps"),
        ("--save_every", int, 10, "Save every N epochs"),
        ("--seed", int, 42, "Random seed"),
    )

    cli = argparse.ArgumentParser(description="Train NeuroName model")
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli.parse_args()
def set_seed(seed: int):
    """Seed Python's ``random``, NumPy and PyTorch RNGs with the same value.

    Args:
        seed: Integer seed passed unchanged to every generator.
    """
    import random
    import numpy as np

    # CPU-side generators first, in the same order every run.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Seed every visible GPU as well when CUDA is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def get_kl_weight(step: int, config: NeuroNameConfig) -> float:
    """Return the KL loss weight for a given optimisation step.

    The schedule is a single linear ramp (not cyclical): the weight moves from
    ``config.kl_weight_start`` at step 0 to ``config.kl_weight_end`` at
    ``config.kl_anneal_steps`` and stays at the end value afterwards.

    Args:
        step: Global optimisation step counter.
        config: Object exposing ``kl_anneal_steps``, ``kl_weight_start`` and
            ``kl_weight_end``.

    Returns:
        The interpolated KL weight.
    """
    anneal_steps = config.kl_anneal_steps
    final_weight = config.kl_weight_end

    # Past the ramp (this also covers anneal_steps == 0): hold the target weight.
    if step >= anneal_steps:
        return final_weight

    initial_weight = config.kl_weight_start
    progress = step / anneal_steps
    return initial_weight + progress * (final_weight - initial_weight)
def free_bits_kl(kl_per_dim: torch.Tensor, free_bits: float) -> torch.Tensor:
    """Compute a KL loss with a per-dimension floor ("free bits").

    Every entry of ``kl_per_dim`` is raised to at least ``free_bits``, the
    result is summed over the last axis and then averaged over all remaining
    entries. Entries below the floor contribute a constant, so they produce no
    gradient pressure to shrink further -- the usual remedy for a VAE that
    learns to ignore its latent space.

    Args:
        kl_per_dim: Tensor of KL values whose last axis indexes latent
            dimensions.
        free_bits: Minimum value (in nats) each entry is clamped to.

    Returns:
        Scalar tensor with the clamped, summed and averaged KL.
    """
    floored = kl_per_dim.clamp(min=free_bits)
    kl_per_sample = floored.sum(dim=-1)
    return kl_per_sample.mean()
class Trainer:
    """Training driver for NeuroName.

    Owns the model, three optimizers (VAE components, phonotactic
    discriminator, attribute classifiers), the vocabularies and the running
    training state (``global_step``, ``best_loss``, ``history``).
    """

    # Batch entries that are moved to the device and passed to the model as
    # keyword arguments of the same name.
    _BATCH_KEYS = (
        "char_ids",
        "hint_ids",
        "target_length",
        "style",
        "language_feel",
        "energy",
        "char_padding_mask",
        "hint_padding_mask",
    )

    # Fixed KL weight used for the validation loss so that validation numbers
    # stay comparable across epochs while the training weight is annealed.
    _VAL_KL_WEIGHT = 0.1

    def __init__(self, config: NeuroNameConfig, device: str = "cpu"):
        """Build the model, optimizers and helpers.

        Args:
            config: Training/model configuration; ``config.to_dict()`` is
                handed to ``NeuroNameModel``.
            device: Device string the model and all batches are moved to.
        """
        self.config = config
        self.device = device

        # Initialize model
        self.model = NeuroNameModel(config.to_dict()).to(device)
        print(f"Model initialized with {sum(p.numel() for p in self.model.parameters()):,} parameters")
        print("Parameter breakdown:")
        for name, count in self.model.count_parameters().items():
            print(f" {name}: {count:,}")

        # Vocabularies
        self.char_vocab = self.model.char_vocab
        self.semantic_vocab = SemanticVocab()

        # Optimizers (separate for different components)
        self.vae_optimizer = AdamW(
            list(self.model.semantic_encoder.parameters())
            + list(self.model.control_encoder.parameters())
            + list(self.model.char_encoder.parameters())
            + list(self.model.char_decoder.parameters())
            + list(self.model.prior_net.parameters()),
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
            betas=(0.9, 0.999),
        )
        self.disc_optimizer = AdamW(
            self.model.phonotactic_disc.parameters(),
            lr=config.phon_lr,
            weight_decay=0.01,
        )
        self.cls_optimizer = AdamW(
            list(self.model.style_classifier.parameters())
            + list(self.model.lang_classifier.parameters()),
            lr=config.learning_rate,
            weight_decay=0.01,
        )

        # Phonotactic data generator
        self.phon_generator = PhonotacticDataGenerator()
        self.phon_scorer = PhonotacticScorer()

        # Training state
        self.global_step = 0
        self.best_loss = float("inf")
        self.history = {"train_loss": [], "val_loss": [], "recon_loss": [], "kl_loss": []}

    def _forward_batch(self, batch):
        """Move the model inputs of ``batch`` to the device and run the model."""
        inputs = {key: batch[key].to(self.device) for key in self._BATCH_KEYS}
        return self.model(**inputs)

    def train_epoch_vae(self, dataloader, epoch: int) -> Dict[str, float]:
        """Run one epoch of VAE + attribute-classifier training.

        The loss is ``recon + kl_weight * kl + style_w * style + lang_w * lang``
        where ``kl_weight`` follows ``get_kl_weight(self.global_step, ...)``.

        Args:
            dataloader: Iterable of batches containing the keys in
                ``_BATCH_KEYS``.
            epoch: Current epoch number (kept for API compatibility; unused).

        Returns:
            Epoch means of ``loss``, ``recon_loss``, ``kl_loss``,
            ``style_loss`` and ``lang_loss`` plus the last ``kl_weight`` used.
        """
        self.model.train()
        totals = {
            "loss": 0.0,
            "recon_loss": 0.0,
            "kl_loss": 0.0,
            "style_loss": 0.0,
            "lang_loss": 0.0,
        }
        num_batches = 0

        # Defined up front so the returned metrics are valid even when the
        # dataloader yields no batches.
        kl_weight = get_kl_weight(self.global_step, self.config)

        for batch in dataloader:
            outputs = self._forward_batch(batch)

            # KL annealing weight for this step
            kl_weight = get_kl_weight(self.global_step, self.config)

            # NOTE(review): the model's own KL term is used as-is; the
            # free-bits floor (free_bits_kl) is not applied here.
            kl_loss = outputs["kl_loss"]

            loss = (
                outputs["recon_loss"]
                + kl_weight * kl_loss
                + self.config.style_loss_weight * outputs["style_loss"]
                + self.config.lang_loss_weight * outputs["lang_loss"]
            )

            self.vae_optimizer.zero_grad()
            self.cls_optimizer.zero_grad()
            loss.backward()
            # Clip before stepping so one bad batch cannot blow up the weights.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
            self.vae_optimizer.step()
            self.cls_optimizer.step()

            totals["loss"] += loss.item()
            totals["recon_loss"] += outputs["recon_loss"].item()
            totals["kl_loss"] += kl_loss.item()
            totals["style_loss"] += outputs["style_loss"].item()
            totals["lang_loss"] += outputs["lang_loss"].item()
            num_batches += 1
            self.global_step += 1

        denom = max(num_batches, 1)
        metrics = {key: value / denom for key, value in totals.items()}
        metrics["kl_weight"] = kl_weight
        return metrics

    def train_phonotactic_discriminator(
        self, real_names: list, num_steps: int = 500, batch_size: int = 64
    ):
        """Train the phonotactic discriminator (Stage 2).

        Each step draws a batch from ``self.phon_generator`` and optimises a
        binary cross-entropy loss on the discriminator logits.

        Args:
            real_names: Names handed to the batch generator as the pool of
                real examples.
            num_steps: Number of optimisation steps; ``0`` is allowed and
                yields zero metrics.
            batch_size: Batch size requested from the generator.

        Returns:
            Dict with the mean ``phon_loss`` and ``phon_acc`` over all steps.
        """
        self.model.phonotactic_disc.train()
        total_loss = 0.0
        total_acc = 0.0

        for _ in range(num_steps):
            names, labels = self.phon_generator.generate_batch(real_names, batch_size=batch_size)

            # Encode characters
            char_ids = self.char_vocab.batch_encode(names, max_len=self.config.max_len)
            char_ids = char_ids.to(self.device)
            labels_tensor = torch.tensor(labels, dtype=torch.float, device=self.device)
            padding_mask = (char_ids == self.char_vocab.pad_idx)

            scores = self.model.phonotactic_disc(char_ids, padding_mask).squeeze(-1)
            loss = F.binary_cross_entropy_with_logits(scores, labels_tensor)

            self.disc_optimizer.zero_grad()
            loss.backward()
            self.disc_optimizer.step()

            # A logit above 0 corresponds to a probability above 0.5.
            preds = (scores > 0).float()
            total_acc += (preds == labels_tensor).float().mean().item()
            total_loss += loss.item()

        # Guard the averages so num_steps == 0 does not divide by zero.
        denom = max(num_steps, 1)
        avg_loss = total_loss / denom
        avg_acc = total_acc / denom
        print(f" Phonotactic Discriminator - Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.3f}")
        return {"phon_loss": avg_loss, "phon_acc": avg_acc}

    @torch.no_grad()
    def validate(self, dataloader) -> Dict[str, float]:
        """Evaluate reconstruction and KL on ``dataloader`` without gradients.

        Returns:
            ``val_loss`` (recon + 0.1 * KL) and ``val_recon``, both averaged
            over batches.
        """
        self.model.eval()
        total_loss = 0.0
        total_recon = 0.0
        num_batches = 0

        for batch in dataloader:
            outputs = self._forward_batch(batch)
            loss = outputs["recon_loss"] + self._VAL_KL_WEIGHT * outputs["kl_loss"]
            total_loss += loss.item()
            total_recon += outputs["recon_loss"].item()
            num_batches += 1

        denom = max(num_batches, 1)
        return {
            "val_loss": total_loss / denom,
            "val_recon": total_recon / denom,
        }

    @torch.no_grad()
    def generate_samples(self, num_samples: int = 5) -> list:
        """Generate names for a few fixed prompts, for monitoring.

        Args:
            num_samples: How many of the five built-in (hints, style) prompts
                to use; three names are requested per prompt.

        Returns:
            List of dicts with keys ``hints``, ``style`` and ``generated``.
        """
        self.model.eval()
        hints_list = [
            ["speed", "technology", "future"],
            ["nature", "calm", "harmony"],
            ["gaming", "epic", "adventure"],
            ["luxury", "elegance", "premium"],
            ["creative", "art", "design"],
        ]
        styles = ["techy", "organic", "playful", "elegant", "modern"]

        samples = []
        for hints, style in zip(hints_list[:num_samples], styles[:num_samples]):
            hint_ids = self.semantic_vocab.encode(hints)
            hint_ids = torch.tensor([hint_ids], dtype=torch.long, device=self.device)
            hint_mask = (hint_ids == self.semantic_vocab.pad_idx)
            style_idx = torch.tensor(
                [NameDataset.STYLE_MAP.get(style, 0)],
                dtype=torch.long,
                device=self.device,
            )
            lang_idx = torch.tensor([0], dtype=torch.long, device=self.device)
            energy_idx = torch.tensor([1], dtype=torch.long, device=self.device)
            target_len = torch.tensor([[0.25]], dtype=torch.float, device=self.device)

            generated = self.model.generate_from_prior(
                hint_ids=hint_ids,
                target_length=target_len,
                style=style_idx,
                language_feel=lang_idx,
                energy=energy_idx,
                hint_padding_mask=hint_mask,
                temperature=0.8,
                num_samples=3,
            )
            names = self.char_vocab.batch_decode(generated)
            samples.append({
                "hints": hints,
                "style": style,
                "generated": names,
            })
        return samples

    def save_checkpoint(self, path: str, epoch: int):
        """Write model, optimizer and training state to ``path``.

        Args:
            path: Destination file. Its parent directory is created when the
                path has one; a bare filename is written to the current
                directory.
            epoch: Epoch number stored in the checkpoint.
        """
        # os.makedirs("") raises, so only create a directory when there is one.
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        torch.save({
            "epoch": epoch,
            "global_step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "vae_optimizer": self.vae_optimizer.state_dict(),
            # Stored so that a resumed run keeps all optimizer moments.
            "disc_optimizer": self.disc_optimizer.state_dict(),
            "cls_optimizer": self.cls_optimizer.state_dict(),
            "config": self.config.to_dict(),
            "history": self.history,
            "best_loss": self.best_loss,
        }, path)
        print(f" Checkpoint saved to {path}")

    def train(
        self,
        train_data: list,
        val_data: Optional[list] = None,
        num_epochs: int = 100,
        save_dir: str = "checkpoints",
        save_every: int = 10,
    ):
        """Run the full training procedure.

        Every epoch runs ``train_epoch_vae`` (reconstruction + annealed KL +
        style + language losses). Once, at the start of the epoch at roughly
        30% of ``num_epochs`` (never earlier than epoch 1), the phonotactic
        discriminator is trained for 500 steps. Validation runs every 5th
        epoch when ``val_data`` is given; sample names are printed at epoch 1
        and every 10th epoch.

        Args:
            train_data: List of dicts; each must contain at least ``"name"``.
            val_data: Optional held-out data in the same format.
            num_epochs: Number of epochs to train.
            save_dir: Directory for periodic and best checkpoints.
            save_every: Periodic checkpoint interval in epochs.

        Note:
            ``best_model.pt`` is selected on the *training* loss.
        """
        print("=" * 60)
        print("NeuroName Training")
        print("=" * 60)
        print(f"Training samples: {len(train_data)}")
        print(f"Epochs: {num_epochs}")
        print(f"Device: {self.device}")
        print()

        # Create dataloaders
        train_loader = create_dataloader(
            train_data, self.char_vocab, self.semantic_vocab,
            batch_size=self.config.batch_size, shuffle=True,
        )
        val_loader = None
        if val_data:
            val_loader = create_dataloader(
                val_data, self.char_vocab, self.semantic_vocab,
                batch_size=self.config.batch_size, shuffle=False,
            )

        # Collect real names for phonotactic training
        real_names = [item["name"] for item in train_data]

        # Epochs are 1-based; clamp so short runs (<= 3 epochs) still train
        # the discriminator instead of silently skipping it.
        stage2_epoch = max(1, int(num_epochs * 0.3))

        for epoch in range(1, num_epochs + 1):
            start_time = time.time()

            # === Stage transition: train phonotactic discriminator ===
            if epoch == stage2_epoch:
                print("\n" + "=" * 40)
                print("Stage 2: Training Phonotactic Discriminator")
                print("=" * 40)
                self.train_phonotactic_discriminator(real_names, num_steps=500)
                print()

            # === Main VAE training ===
            metrics = self.train_epoch_vae(train_loader, epoch)

            # Validation
            val_metrics = {}
            if val_loader and epoch % 5 == 0:
                val_metrics = self.validate(val_loader)

            # Logging
            elapsed = time.time() - start_time
            summary = (
                f"Epoch {epoch:3d}/{num_epochs} | "
                f"Loss: {metrics['loss']:.4f} | "
                f"Recon: {metrics['recon_loss']:.4f} | "
                f"KL: {metrics['kl_loss']:.4f} | "
                f"KL_w: {metrics['kl_weight']:.4f} | "
                f"Style: {metrics['style_loss']:.4f} | "
                f"Time: {elapsed:.1f}s"
            )
            if val_metrics:
                summary += f" | Val: {val_metrics.get('val_loss', 0):.4f}"
            print(summary)

            # Track history
            self.history["train_loss"].append(metrics["loss"])
            self.history["recon_loss"].append(metrics["recon_loss"])
            self.history["kl_loss"].append(metrics["kl_loss"])
            if val_metrics:
                self.history["val_loss"].append(val_metrics["val_loss"])

            # Generate samples periodically
            if epoch % 10 == 0 or epoch == 1:
                print("\n Sample generations:")
                for sample in self.generate_samples():
                    names_str = ", ".join(sample["generated"][:3])
                    # Separator keeps the hint list and the names readable.
                    print(f" [{sample['style']}] {sample['hints']} -> {names_str}")
                print()

            # Save checkpoint
            if epoch % save_every == 0 or epoch == num_epochs:
                path = os.path.join(save_dir, f"checkpoint_epoch_{epoch}.pt")
                self.save_checkpoint(path, epoch)

            # Save best
            if metrics["loss"] < self.best_loss:
                self.best_loss = metrics["loss"]
                path = os.path.join(save_dir, "best_model.pt")
                self.save_checkpoint(path, epoch)

        print("\n" + "=" * 60)
        print("Training complete!")
        print(f"Best loss: {self.best_loss:.4f}")
        print("=" * 60)
def main():
    """Script entry point: parse arguments, build data, train, save config.

    Steps: seed the RNGs, load or create a ``NeuroNameConfig``, apply the
    command-line overrides, pick the device, generate synthetic training data
    with a 90/10 train/validation split, run ``Trainer.train`` and finally
    write the configuration to ``<save_dir>/config.json``.
    """
    args = parse_args()
    set_seed(args.seed)

    # Configuration
    config = NeuroNameConfig.load(args.config) if args.config else NeuroNameConfig()

    # Override from command line
    # NOTE(review): these are applied unconditionally, so the argparse
    # defaults replace epochs/batch_size/lr loaded from --config even when the
    # flags were not given on the command line -- confirm this is intended.
    config.num_epochs = args.epochs
    config.batch_size = args.batch_size
    config.learning_rate = args.lr

    # Device
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    print(f"Using device: {device}")

    # Generate training data
    print("Generating training data...")
    train_data = get_synthetic_training_data(num_samples=args.num_train_samples, seed=args.seed)

    # Split into train/val (90/10)
    split_idx = int(len(train_data) * 0.9)
    val_data = train_data[split_idx:]
    train_data = train_data[:split_idx]
    print(f"Train: {len(train_data)} samples, Val: {len(val_data)} samples")

    # Train
    trainer = Trainer(config, device=device)
    trainer.train(
        train_data=train_data,
        val_data=val_data,
        num_epochs=config.num_epochs,
        save_dir=args.save_dir,
        save_every=args.save_every,
    )

    # Save final configuration. The directory normally exists because the last
    # epoch writes a checkpoint, but it does not when no epoch ran
    # (e.g. --epochs 0), so create it explicitly.
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
    config.save_json(os.path.join(args.save_dir, "config.json"))
    print(f"\nConfig saved to {args.save_dir}/config.json")


if __name__ == "__main__":
    main()