#!/usr/bin/env python3
"""
Training script for Circuit Transformer.

Usage:
    python circuits/train.py --data hf:roneneldan/TinyStories --preset tiny --epochs 1 --gpu 0
    python circuits/train.py --data path/to/corpus.txt --dims 256 --layers 6 --fp16
"""

import gc
import os
import time
import math
import random
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
# FIX: torch.cuda.amp.GradScaler is deprecated in favor of torch.amp.GradScaler,
# whose `device` argument defaults to 'cuda' — existing GradScaler() call sites
# keep working unchanged.
from torch.amp import GradScaler, autocast

from .config import CircuitConfig, parse_args
from .model import CircuitTransformer, count_parameters
from .mirrored import MirroredConfig, MirroredTransformer, count_mirrored_parameters
from .graft_g2lu import G2LU_GraftedModel, save_g2lu_checkpoint
from .layers import build_word_start_table, compute_word_positions
from .data import get_tokenizer, load_data, create_dataloader


def corrupt_tokens(input_ids, ratio, vocab_size):
    """Replace random tokens with random vocab tokens for denoising autoencoder.

    Args:
        input_ids: Integer token tensor; corruption is sampled independently
            per position. Assumed shape (batch, seq) — first column is spared.
        ratio: Per-position corruption probability in [0, 1].
        vocab_size: Exclusive upper bound for replacement token ids.

    Returns:
        (corrupted_ids, mask) where mask is True at corrupted positions.
    """
    mask = torch.rand(input_ids.shape, device=input_ids.device) < ratio
    mask[:, 0] = False  # never corrupt first token (BOS/start)
    random_tokens = torch.randint(0, vocab_size, input_ids.shape, device=input_ids.device)
    corrupted = input_ids.clone()
    corrupted[mask] = random_tokens[mask]
    return corrupted, mask
""" mask = torch.rand(input_ids.shape, device=input_ids.device) < ratio mask[:, 0] = False # never corrupt first token (BOS/start) random_tokens = torch.randint(0, vocab_size, input_ids.shape, device=input_ids.device) corrupted = input_ids.clone() corrupted[mask] = random_tokens[mask] return corrupted, mask @torch.no_grad() def evaluate(config, model, dataloader, device, use_amp=False, amp_dtype=torch.float16, mid_run_eval=False, word_start_table=None): """Run validation and return avg loss + perplexity.""" model.eval() total_loss = 0.0 n_batches = 0 for batch in dataloader: input_ids = batch["input_ids"].to(device) labels = batch["labels"].to(device) word_positions = None if word_start_table is not None: word_positions = compute_word_positions(input_ids, word_start_table) if use_amp: with autocast('cuda', dtype=amp_dtype): output = model(input_ids, labels=labels, word_positions=word_positions) else: output = model(input_ids, labels=labels, word_positions=word_positions) total_loss += output["loss"].item() n_batches += 1 if n_batches % (config.log_every * 10) == 0: avg_loss = total_loss / max(n_batches, 1) ppl = math.exp(min(avg_loss, 20)) print( f"batch {n_batches:6d}/{len(dataloader):6d} | " f"Loss {total_loss / n_batches:.4f} | " f"PPL {ppl:8.2f}" ) if mid_run_eval and n_batches >= 1500 : break if not mid_run_eval: model.train() avg_loss = total_loss / max(n_batches, 1) ppl = math.exp(min(avg_loss, 20)) # cap to avoid overflow return avg_loss, ppl def get_lr(step: int, warmup_steps: int, max_steps: int, max_lr: float, min_lr: float = 0.0, delay: int = 0) -> float: """Cosine learning rate schedule with warmup and optional delay. 
With delay > 0, the schedule is shifted: Steps 0..delay: LR = 0 (frozen) Steps delay..delay+warmup: linear ramp 0 → max_lr Steps delay+warmup..max_steps: cosine decay max_lr → min_lr """ if step < delay: return 0.0 effective_step = step - delay effective_max = max(1, max_steps - delay) if effective_step < warmup_steps: return max_lr * effective_step / warmup_steps if effective_step >= effective_max: return min_lr progress = (effective_step - warmup_steps) / (effective_max - warmup_steps) return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress)) def save_checkpoint( model: nn.Module, optimizer: torch.optim.Optimizer, step: int, epoch: int, loss: float, config, path: str, model_type: str = "standard", epoch_step: int = 0, best_val_loss: float | None = None, scaler=None, tokenizer_name: str = "gpt2", ): """Save training checkpoint. Args: epoch: Next epoch to start on resume (completed epoch count). epoch_step: Batches already processed in `epoch` (0 if epoch is complete). optimizer_mid: Middle optimizer for dual-path training (optional). """ checkpoint = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": step, "epoch": epoch, "epoch_step": epoch_step, "loss": loss, "config": config.to_dict(), "model_type": model_type, "tokenizer_name": tokenizer_name, } if best_val_loss is not None: checkpoint["best_val_loss"] = best_val_loss if scaler is not None: checkpoint["scaler"] = scaler.state_dict() torch.save(checkpoint, path) def _migrate_state_dict(state_dict: dict, model: nn.Module) -> dict: """Migrate checkpoint state_dict to match current model architecture. Handles upgrades like SwiGLU → MirroredSwiGLU (dual_gate_middle). 
""" model_keys = set(model.state_dict().keys()) ckpt_keys = set(state_dict.keys()) missing = model_keys - ckpt_keys unexpected = ckpt_keys - model_keys if not missing and not unexpected: return state_dict # perfect match, no migration needed migrated = dict(state_dict) migrations = [] # SwiGLU → MirroredSwiGLU: w3 → gate_expand (dual_gate_middle upgrade) for key in list(unexpected): if ".ffn.gate_expand.weight" in key: new_key = key.replace(".ffn.gate_expand.weight", ".ffn.w3.weight") if new_key in missing: migrated[new_key] = migrated.pop(key) missing.discard(new_key) unexpected.discard(key) migrations.append(f" {key} → {new_key}") if ".ffn.gate_compress.weight" in key: new_key = key.replace(".ffn.gate_compress.weight", ".ffn.w4.weight") if new_key in missing: migrated[new_key] = migrated.pop(key) missing.discard(new_key) unexpected.discard(key) migrations.append(f" {key} → {new_key}") if migrations: print(f"State dict migration ({len(migrations)} keys renamed):") for m in migrations: print(m) # Report remaining missing keys (freshly initialized) still_missing = model_keys - set(migrated.keys()) if still_missing: print(f" New parameters (freshly initialized): {len(still_missing)}") for k in sorted(still_missing): print(f" {k}") return migrated def load_checkpoint(path: str, model: nn.Module, optimizer: torch.optim.Optimizer = None, scaler=None, reset:bool = False): """Load training checkpoint. 
def train():
    """Top-level training entry point: parse args, build data/model/optimizer,
    run the epoch loop with gradient accumulation, validation and checkpointing."""
    config, args = parse_args()

    # Setup device
    device = torch.device(f"cuda:{config.gpu}" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Load tokenizer and data
    print(f"Loading data from: {args.data}")
    model_type = args.arch
    tokenizer_name = getattr(args, 'tokenizer', 'gpt2')
    if model_type == "graft_g2lu":
        # Grafted models reuse the pretrained model's own tokenizer.
        tokenizer_name = args.pretrained
    tokenizer = get_tokenizer(tokenizer_name)
    config.vocab_size = len(tokenizer)
    print(f"Tokenizer: {tokenizer_name} (vocab_size={config.vocab_size})")
    cache_dir = None if args.no_cache else args.cache_dir
    dataset = load_data(
        args.data,
        tokenizer,
        config.max_seq_len,
        text_column=args.text_column,
        num_samples=args.num_samples,
        cache_dir=cache_dir,
        data_format=args.data_format,
    )
    print(f"Loaded {len(dataset):,} chunks")

    # Train/val split (skipped for tiny datasets)
    val_split = args.val_split
    if val_split > 0 and len(dataset) > 20:
        train_dataset, val_dataset = dataset.split(val_split)
        print(f"Split: {len(train_dataset):,} train / {len(val_dataset):,} val ({val_split:.0%})")
    else:
        train_dataset = dataset
        val_dataset = None

    # Create dataloaders
    dataloader = create_dataloader(
        train_dataset,
        config.batch_size,
        shuffle=True,
    )
    val_dataloader = None
    if val_dataset is not None:
        val_dataloader = create_dataloader(
            val_dataset,
            config.batch_size,
            shuffle=False,
        )

    # Create model — one of three architectures selected by --arch.
    if model_type == "mirrored":
        model_config = MirroredConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            num_layers=config.num_layers,
            n_middle=args.n_middle,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            use_g2lu=not getattr(args, 'no_g2lu', False),
            aux_skip_k=getattr(args, 'aux_skip', 0),
            aux_skip_weight=getattr(args, 'aux_weight', 0.1),
            word_rope_dims=getattr(config, 'word_rope_dims', 0),
            word_rope_base=getattr(config, 'word_rope_base', 10.0),
            embed_dim=getattr(config, 'embed_dim', 0),
            head_dim=getattr(config, 'head_dim', 0),
        )
        model = MirroredTransformer(model_config).to(device)
        param_info = count_mirrored_parameters(model)
        num_params = param_info["unique"]
        print(f"Model: MirroredTransformer")
        print(f"  Virtual layers: {model.total_virtual_layers} ({model_config.n_mirror} mirror pairs + {model_config.n_middle} middle)")
        print(f"  Parameters: {num_params:,} ({num_params/1e6:.1f}M unique)")
        print(f"  Shared FFN base: {param_info['shared_ffn_base']:,}")
        print(f"  Direction gates: {param_info['direction_gates']:,}")
        print(f"  FFN gating: {'G²LU (nested dual gate)' if model_config.use_g2lu else 'SwiGLU (vanilla)'}")
        if model_config.num_kv_heads is not None:
            print(f"  GQA: {model_config.num_heads}Q / {model_config.num_kv_heads}KV ({model_config.num_heads // model_config.num_kv_heads}:1 ratio)")
        if model_config.aux_skip_k > 0:
            print(f"  Aux skip prediction: t+{model_config.aux_skip_k} (weight={model_config.aux_skip_weight})")
        if getattr(model_config, 'embed_dim', 0) > 0:
            # Report parameter savings of the factorized embedding vs a dense one.
            std_embed = config.vocab_size * config.hidden_size
            fact_embed = config.vocab_size * model_config.embed_dim + model_config.embed_dim * config.hidden_size
            print(f"  Factorized embedding: {model_config.embed_dim} → {config.hidden_size} (saves {(std_embed - fact_embed):,} params)")
        if getattr(model_config, 'head_dim', 0) > 0:
            std_head = config.hidden_size * config.vocab_size
            mlp_head = config.hidden_size * model_config.head_dim + model_config.head_dim * config.vocab_size
            print(f"  MLP head: {config.hidden_size} → {model_config.head_dim} → vocab (saves {(std_head - mlp_head):,} params)")
    elif model_type == "graft_g2lu":
        assert args.pretrained, "--pretrained is required for graft_g2lu architecture"
        amp_dtype = torch.bfloat16 if config.bf16 else (torch.float16 if config.fp16 else torch.float32)
        model = G2LU_GraftedModel(
            pretrained_name=args.pretrained,
            align_weight=args.align_weight,
            warmup_steps=args.graft_warmup,
            device=device,
            dtype=amp_dtype,
        )
        model_config = None  # No CircuitConfig for HF models
        # Only the grafted (trainable) parameters count here.
        num_params = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
    else:
        model_config = config
        model = CircuitTransformer(config).to(device)
        num_params = count_parameters(model)
        print(f"Model: CircuitTransformer")
        print(f"  Parameters: {num_params:,} ({num_params/1e6:.1f}M)")
        if getattr(config, 'aux_skip_k', 0) > 0:
            print(f"  Aux skip prediction: t+{config.aux_skip_k} (weight={config.aux_skip_weight})")
        if getattr(config, 'embed_dim', 0) > 0:
            std_embed = config.vocab_size * config.hidden_size
            fact_embed = config.vocab_size * config.embed_dim + config.embed_dim * config.hidden_size
            print(f"  Factorized embedding: {config.embed_dim} → {config.hidden_size} (saves {(std_embed - fact_embed):,} params)")
        if getattr(config, 'head_dim', 0) > 0:
            std_head = config.hidden_size * config.vocab_size
            mlp_head = config.hidden_size * config.head_dim + config.head_dim * config.vocab_size
            print(f"  MLP head: {config.hidden_size} → {config.head_dim} → vocab (saves {(std_head - mlp_head):,} params)")

    # Build word-position table if enabled
    word_rope_dims = getattr(config, 'word_rope_dims', 0)
    if word_rope_dims > 0:
        word_start_table = build_word_start_table(tokenizer, len(tokenizer)).to(device)
        print(f"  Word-position RoPE: {word_rope_dims} dims, base={getattr(config, 'word_rope_base', 10.0)}")
        print(f"  Word starters in vocab: {word_start_table.sum().item():,} / {len(tokenizer):,}")
    else:
        word_start_table = None

    # Keep raw reference for set_gate_step (torch.compile wraps the model)
    raw_model = model

    # Optionally compile
    if config.compile and hasattr(torch, "compile"):
        print("Compiling model with torch.compile...")
        model = torch.compile(raw_model)

    # Optimizer — with optional staggered warmup and dual-path training
    grad_accum = getattr(args, 'grad_accum', 1)
    # Grafted models expose only their trainable subset to the optimizer.
    opt_params = list(raw_model.trainable_parameters()) if model_type == "graft_g2lu" else model.parameters()
    optimizer = torch.optim.AdamW(
        opt_params,
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        betas=(0.9, 0.95),
    )

    # Mixed precision — GradScaler only for FP16; BF16 needs no loss scaling.
    use_amp = (config.fp16 or config.bf16) and device.type == "cuda"
    amp_dtype = torch.bfloat16 if config.bf16 else torch.float16
    scaler = GradScaler() if (config.fp16 and use_amp) else None
    if use_amp:
        print(f"  Mixed precision: {'BF16' if config.bf16 else 'FP16'}" + (" (no scaler)" if scaler is None else " (with GradScaler)"))

    # Resume from checkpoint (with --reset, only weights are restored)
    start_step = 0
    start_epoch = 0
    skip_batches = 0
    best_val_loss = float("inf")
    if args.resume:
        print(f"Resuming from: {args.resume}")
        resume_info = load_checkpoint(args.resume, model, optimizer, scaler, args.reset)
        if not args.reset:
            start_step = resume_info["step"]
            start_epoch = resume_info["epoch"]
            skip_batches = resume_info["epoch_step"]
            best_val_loss = resume_info["best_val_loss"]
            print(f"Resumed at step {start_step}, epoch {start_epoch}" + (f", skipping {skip_batches} batches" if skip_batches > 0 else ""))
            if best_val_loss < float("inf"):
                print(f"  Best val loss so far: {best_val_loss:.4f} (PPL {math.exp(min(best_val_loss, 20)):.2f})")

    # Setup checkpoint directory
    checkpoint_dir = Path(config.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Training loop bookkeeping. A "step" is one optimizer update
    # (= grad_accum micro-batches).
    steps_per_epoch = math.ceil(len(dataloader) / grad_accum)
    max_steps = config.epochs * steps_per_epoch
    tokens_per_step = config.batch_size * grad_accum * config.max_seq_len
    total_train_tokens = config.epochs * len(dataloader) * config.batch_size * config.max_seq_len
    step = start_step
    model.train()

    print(f"\nStarting training:")
    print(f"  Epochs: {config.epochs}")
    print(f"  Batch size: {config.batch_size}" + (f" x {grad_accum} accum = {config.batch_size * grad_accum} effective" if grad_accum > 1 else ""))
    print(f"  Steps per epoch: {steps_per_epoch}" + (f" ({len(dataloader)} micro-batches)" if grad_accum > 1 else ""))
    print(f"  Total steps: {max_steps}")
    print(f"  Total tokens: {total_train_tokens:,} ({total_train_tokens/1e6:.1f}M)")
    if num_params > 0:
        print(f"  Tokens/param ratio: {total_train_tokens/num_params:.1f}x (Chinchilla=20x)")
    print(f"  Learning rate: {config.learning_rate}" + (f" → {config.min_lr}" if config.min_lr > 0 else ""))
    print(f"  Mixed precision: {use_amp}")
    print(f"  Validation: {'enabled' if val_dataloader else 'disabled'}")
    print()

    total_loss = 0.0
    log_steps = 0
    total_tokens_seen = step * tokens_per_step
    # best_val_loss already set in resume section above
    # NOTE(review): h_mid_buffer / last_align_val are never read below —
    # presumably vestiges of a removed dual-path training mode.
    h_mid_buffer = None
    last_align_val = float("inf")
    start_time = time.time()

    for epoch in range(start_epoch, config.epochs):
        epoch_start = time.time()
        epoch_loss = 0.0
        epoch_steps = 0
        micro_batches = []
        # On the resumed epoch, count the batches we are about to skip.
        epoch_micro_batches = skip_batches if epoch == start_epoch else 0

        for batch_idx, batch in enumerate(dataloader):
            # Skip already-processed batches on resume
            if epoch == start_epoch and batch_idx < skip_batches:
                continue
            micro_batches.append(batch)
            epoch_micro_batches += 1

            # Accumulate micro-batches (flush at accum boundary or epoch end)
            if len(micro_batches) < grad_accum and batch_idx < len(dataloader) - 1:
                continue

            n_micro = len(micro_batches)
            actual_tokens = n_micro * config.batch_size * config.max_seq_len

            # Update learning rate (per-group delays for staggered warmup)
            for param_group in optimizer.param_groups:
                delay = param_group.get("delay", 0)
                param_group["lr"] = get_lr(step, config.warmup_steps, max_steps, config.learning_rate, min_lr=config.min_lr, delay=delay)
            lr = optimizer.param_groups[0]["lr"]  # for logging

            # NOTE(review): these four are assigned but never used below —
            # likely leftovers from the dual-path training code path.
            loss_ed_val = None
            loss_align_val = None
            grad_norm_mid = None
            absorb_loss_val = None

            # Update blend alpha for G²LU grafting
            if model_type == "graft_g2lu":
                raw_model.set_step(step)

            # === Standard single-path training with accumulation ===
            optimizer.zero_grad()
            accum_loss = 0.0
            accum_aux = 0.0
            accum_align = 0.0
            for mb in micro_batches:
                mb_ids = mb["input_ids"].to(device)
                mb_labels = mb["labels"].to(device)
                word_positions = None
                if word_start_table is not None:
                    word_positions = compute_word_positions(mb_ids, word_start_table)
                if use_amp:
                    with autocast('cuda', dtype=amp_dtype):
                        output = model(mb_ids, labels=mb_labels, word_positions=word_positions)
                else:
                    output = model(mb_ids, labels=mb_labels, word_positions=word_positions)
                # Divide by n_micro so accumulated gradients average over micro-batches.
                if scaler:
                    scaler.scale(output["loss"] / n_micro).backward()
                else:
                    (output["loss"] / n_micro).backward()
                accum_loss += output["loss"].item()
                if "aux_loss" in output:
                    accum_aux += output["aux_loss"].item()
                if "align_loss" in output:
                    accum_align += output["align_loss"].item()

            # Unscale before clipping so the clip threshold applies to true grads.
            if scaler:
                scaler.unscale_(optimizer)
            clip_params = list(raw_model.trainable_parameters()) if model_type == "graft_g2lu" else model.parameters()
            grad_norm = nn.utils.clip_grad_norm_(clip_params, config.grad_clip).item()
            if scaler:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

            loss_val = accum_loss / n_micro
            aux_loss_val = accum_aux / n_micro if accum_aux > 0 else None
            align_loss_val = accum_align / n_micro if accum_align > 0 else None

            total_loss += loss_val
            epoch_loss += loss_val
            epoch_steps += 1
            log_steps += 1
            total_tokens_seen += actual_tokens
            step += 1

            # Logging (averages reset after each log line)
            if step % config.log_every == 0:
                avg_loss = total_loss / max(log_steps, 1)
                ppl = math.exp(min(avg_loss, 20))
                elapsed = time.time() - start_time
                tok_s = (log_steps * tokens_per_step) / max(elapsed, 1e-6)
                extra = ""
                if aux_loss_val is not None:
                    extra += f" | Aux {aux_loss_val:.3f}"
                if align_loss_val is not None:
                    extra += f" | Align {align_loss_val:.4f}"
                print(
                    f"Step {step:6d} | "
                    f"Epoch {epoch+1}/{config.epochs} | "
                    f"Loss {avg_loss:.4f} | "
                    f"PPL {ppl:8.2f} | "
                    f"GradN {grad_norm:.3f} | "
                    f"LR {lr:.2e} | "
                    f"Tok/s {tok_s:.0f}"
                    f"{extra}"
                )
                total_loss = 0.0
                log_steps = 0
                start_time = time.time()

            # Checkpointing
            if step % config.save_every == 0:
                ckpt_path = checkpoint_dir / f"step_{step:06d}.pt"
                if model_type == "graft_g2lu":
                    save_g2lu_checkpoint(raw_model, optimizer, step, epoch, loss_val, str(ckpt_path), epoch_step=epoch_micro_batches, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
                else:
                    save_checkpoint(model, optimizer, step, epoch, loss_val, model_config, str(ckpt_path), model_type, epoch_step=epoch_micro_batches, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
                print(f"  Saved checkpoint: {ckpt_path}")
                gc.collect()
                torch.cuda.empty_cache()

            # Mid-training validation (capped pass; see evaluate's mid_run_eval)
            val_every = getattr(args, 'val_every', 0)
            if val_every > 0 and step % val_every == 0 and val_dataloader:
                val_loss, val_ppl = evaluate(config, model, val_dataloader, device, use_amp, amp_dtype, mid_run_eval=True, word_start_table=word_start_table)
                avg_train = epoch_loss / max(epoch_steps, 1)
                gap = val_loss - avg_train
                print(f"  [Val @ step {step}] Loss: {val_loss:.4f} | PPL: {val_ppl:.2f} | Gap: {gap:+.4f}")
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_path = checkpoint_dir / "best.pt"
                    if model_type == "graft_g2lu":
                        save_g2lu_checkpoint(raw_model, optimizer, step, epoch, val_loss, str(best_path), epoch_step=epoch_micro_batches, best_val_loss=val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
                    else:
                        save_checkpoint(model, optimizer, step, epoch, val_loss, model_config, str(best_path), model_type, epoch_step=epoch_micro_batches, best_val_loss=val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
                    print(f"  New best! Saved: {best_path}")
                gc.collect()
                torch.cuda.empty_cache()

            micro_batches = []

        # --- Epoch summary ---
        epoch_elapsed = time.time() - epoch_start
        avg_epoch_loss = epoch_loss / max(epoch_steps, 1)
        epoch_ppl = math.exp(min(avg_epoch_loss, 20))
        print(f"\n{'='*70}")
        print(f"Epoch {epoch+1}/{config.epochs} complete in {epoch_elapsed:.0f}s")
        print(f"  Train loss: {avg_epoch_loss:.4f} | Train PPL: {epoch_ppl:.2f}")
        print(f"  Tokens seen: {total_tokens_seen:,} ({total_tokens_seen/1e6:.1f}M)")

        # Validation (full pass; epoch saved as epoch+1 so resume starts fresh)
        if val_dataloader:
            val_loss, val_ppl = evaluate(config, model, val_dataloader, device, use_amp, amp_dtype, word_start_table=word_start_table)
            gap = val_loss - avg_epoch_loss
            print(f"  Val loss: {val_loss:.4f} | Val PPL: {val_ppl:.2f} | Gap: {gap:+.4f}")
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_path = checkpoint_dir / "best.pt"
                if model_type == "graft_g2lu":
                    save_g2lu_checkpoint(raw_model, optimizer, step, epoch + 1, val_loss, str(best_path), epoch_step=0, best_val_loss=val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
                else:
                    save_checkpoint(model, optimizer, step, epoch + 1, val_loss, model_config, str(best_path), model_type, epoch_step=0, best_val_loss=val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
                print(f"  New best! Saved: {best_path}")
            # Free validation tensors
            gc.collect()
            torch.cuda.empty_cache()
        print(f"{'='*70}\n")

        # Save epoch checkpoint
        ckpt_path = checkpoint_dir / f"epoch_{epoch+1:02d}.pt"
        if model_type == "graft_g2lu":
            save_g2lu_checkpoint(raw_model, optimizer, step, epoch + 1, avg_epoch_loss, str(ckpt_path), epoch_step=0, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
        else:
            save_checkpoint(model, optimizer, step, epoch + 1, avg_epoch_loss, model_config, str(ckpt_path), model_type, epoch_step=0, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
        gc.collect()
        torch.cuda.empty_cache()

    # Save final checkpoint (skipped when no optimizer step was taken)
    if step == start_step:
        print(f"\nNo training performed (already at step {step}/{max_steps}).")
        print(f"  To train more epochs, increase --epochs beyond {config.epochs}.")
    else:
        final_path = checkpoint_dir / "latest.pt"
        if model_type == "graft_g2lu":
            save_g2lu_checkpoint(raw_model, optimizer, step, config.epochs, avg_epoch_loss, str(final_path), epoch_step=0, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
        else:
            save_checkpoint(model, optimizer, step, config.epochs, avg_epoch_loss, model_config, str(final_path), model_type, epoch_step=0, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
        print(f"\nTraining complete.")
        print(f"  Final train loss: {avg_epoch_loss:.4f} | PPL: {epoch_ppl:.2f}")
        if val_dataloader:
            print(f"  Best val loss: {best_val_loss:.4f} | PPL: {math.exp(min(best_val_loss, 20)):.2f}")
        print(f"  Total tokens: {total_tokens_seen:,}")
        print(f"  Checkpoints: {final_path}")


if __name__ == "__main__":
    train()