| |
|
| | """
|
| | Training script for Circuit Transformer.
|
| |
|
| | Usage:
|
| | python circuits/train.py --data hf:roneneldan/TinyStories --preset tiny --epochs 1 --gpu 0
|
| | python circuits/train.py --data path/to/corpus.txt --dims 256 --layers 6 --fp16
|
| | """
|
| |
|
| | import gc
|
| | import os
|
| | import time
|
| | import math
|
| | import random
|
| | from pathlib import Path
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from torch.cuda.amp import GradScaler
|
| | from torch.amp import autocast
|
| |
|
| | from .config import CircuitConfig, parse_args
|
| | from .model import CircuitTransformer, count_parameters
|
| | from .mirrored import MirroredConfig, MirroredTransformer, count_mirrored_parameters
|
| | from .graft_g2lu import G2LU_GraftedModel, save_g2lu_checkpoint
|
| | from .layers import build_word_start_table, compute_word_positions
|
| | from .data import get_tokenizer, load_data, create_dataloader
|
| |
|
def corrupt_tokens(input_ids, ratio, vocab_size):
    """Randomly overwrite tokens with uniform-vocab noise (denoising objective).

    The first position of every sequence is always left intact.

    Args:
        input_ids: Integer token tensor of shape (batch, seq).
        ratio: Probability of replacing each token.
        vocab_size: Exclusive upper bound for the replacement token ids.

    Returns:
        (corrupted_ids, mask): mask is True exactly at replaced positions.
    """
    mask = torch.rand(input_ids.shape, device=input_ids.device) < ratio
    mask[:, 0] = False  # never corrupt the leading token
    noise = torch.randint(0, vocab_size, input_ids.shape, device=input_ids.device)
    corrupted = torch.where(mask, noise, input_ids)
    return corrupted, mask
|
| |
|
| |
|
@torch.no_grad()
def evaluate(config, model, dataloader, device, use_amp=False, amp_dtype=torch.float16, mid_run_eval=False,
             word_start_table=None):
    """Compute average loss and perplexity over `dataloader`.

    With mid_run_eval=True the pass is capped at 1500 batches and the model
    is left in eval mode (the caller's training loop restores train mode);
    otherwise train mode is restored here before returning.
    """
    model.eval()
    running = 0.0
    seen = 0

    for batch in dataloader:
        ids = batch["input_ids"].to(device)
        targets = batch["labels"].to(device)
        positions = (compute_word_positions(ids, word_start_table)
                     if word_start_table is not None else None)

        if use_amp:
            with autocast('cuda', dtype=amp_dtype):
                out = model(ids, labels=targets, word_positions=positions)
        else:
            out = model(ids, labels=targets, word_positions=positions)

        running += out["loss"].item()
        seen += 1

        # Periodic progress line during long validation passes.
        if seen % (config.log_every * 10) == 0:
            interim = running / max(seen, 1)
            print(
                f"batch {seen:6d}/{len(dataloader):6d} | "
                f"Loss {running / seen:.4f} | "
                f"PPL {math.exp(min(interim, 20)):8.2f}"
            )

        if mid_run_eval and seen >= 1500:
            break

    if not mid_run_eval:
        model.train()

    mean_loss = running / max(seen, 1)
    # Cap the exponent so perplexity stays finite on divergent runs.
    return mean_loss, math.exp(min(mean_loss, 20))
|
| |
|
| |
|
def get_lr(step: int, warmup_steps: int, max_steps: int, max_lr: float, min_lr: float = 0.0, delay: int = 0) -> float:
    """Cosine learning-rate schedule with linear warmup and an optional delay.

    Timeline (all in optimizer steps):
        [0, delay)                  → 0.0 (parameter group frozen)
        [delay, delay + warmup)     → linear ramp 0 → max_lr
        [delay + warmup, max_steps) → cosine decay max_lr → min_lr
        [max_steps, ∞)              → min_lr
    """
    if step < delay:
        return 0.0
    t = step - delay
    horizon = max(1, max_steps - delay)  # guard against a zero-length schedule
    if t < warmup_steps:
        return max_lr * t / warmup_steps
    if t >= horizon:
        return min_lr
    frac = (t - warmup_steps) / (horizon - warmup_steps)
    cosine = 0.5 * (1 + math.cos(math.pi * frac))
    return min_lr + (max_lr - min_lr) * cosine
|
| |
|
| |
|
| | def save_checkpoint(
|
| | model: nn.Module,
|
| | optimizer: torch.optim.Optimizer,
|
| | step: int,
|
| | epoch: int,
|
| | loss: float,
|
| | config,
|
| | path: str,
|
| | model_type: str = "standard",
|
| | epoch_step: int = 0,
|
| | best_val_loss: float | None = None,
|
| | scaler=None,
|
| | tokenizer_name: str = "gpt2",
|
| | ):
|
| | """Save training checkpoint.
|
| |
|
| | Args:
|
| | epoch: Next epoch to start on resume (completed epoch count).
|
| | epoch_step: Batches already processed in `epoch` (0 if epoch is complete).
|
| | optimizer_mid: Middle optimizer for dual-path training (optional).
|
| | """
|
| | checkpoint = {
|
| | "model": model.state_dict(),
|
| | "optimizer": optimizer.state_dict(),
|
| | "step": step,
|
| | "epoch": epoch,
|
| | "epoch_step": epoch_step,
|
| | "loss": loss,
|
| | "config": config.to_dict(),
|
| | "model_type": model_type,
|
| | "tokenizer_name": tokenizer_name,
|
| | }
|
| | if best_val_loss is not None:
|
| | checkpoint["best_val_loss"] = best_val_loss
|
| | if scaler is not None:
|
| | checkpoint["scaler"] = scaler.state_dict()
|
| | torch.save(checkpoint, path)
|
| |
|
| |
|
def _migrate_state_dict(state_dict: dict, model: nn.Module) -> dict:
    """Rename legacy checkpoint keys so they match the current model layout.

    Currently maps the old SwiGLU gate names onto their successors
    (``.ffn.gate_expand`` → ``.ffn.w3``, ``.ffn.gate_compress`` → ``.ffn.w4``).
    Returns the input unchanged when the key sets already agree.
    """
    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(state_dict.keys())

    missing = model_keys - ckpt_keys
    unexpected = ckpt_keys - model_keys
    if not missing and not unexpected:
        return state_dict

    # Old suffix → new suffix; a key can match at most one entry.
    renames = (
        (".ffn.gate_expand.weight", ".ffn.w3.weight"),
        (".ffn.gate_compress.weight", ".ffn.w4.weight"),
    )

    migrated = dict(state_dict)
    log_lines = []
    for old_key in list(unexpected):
        for old_suffix, new_suffix in renames:
            if old_suffix not in old_key:
                continue
            candidate = old_key.replace(old_suffix, new_suffix)
            if candidate in missing:
                migrated[candidate] = migrated.pop(old_key)
                missing.discard(candidate)
                unexpected.discard(old_key)
                log_lines.append(f" {old_key} → {candidate}")

    if log_lines:
        print(f"State dict migration ({len(log_lines)} keys renamed):")
        for line in log_lines:
            print(line)

    leftover = model_keys - set(migrated.keys())
    if leftover:
        print(f" New parameters (freshly initialized): {len(leftover)}")
        for k in sorted(leftover):
            print(f" {k}")

    return migrated
|
| |
|
| |
|
def load_checkpoint(path: str, model: nn.Module, optimizer: torch.optim.Optimizer = None,
                    scaler=None, reset: bool = False):
    """Restore model (and optionally optimizer/scaler) state from `path`.

    With reset=True only the model weights are loaded; optimizer and scaler
    state in the file is ignored. Returns a dict with step/epoch/epoch_step/
    best_val_loss so the caller can resume the training loop.
    """
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    # Translate legacy key names, then load non-strictly so freshly added
    # parameters keep their initialization.
    model.load_state_dict(_migrate_state_dict(ckpt["model"], model), strict=False)
    if not reset:
        if optimizer is not None and "optimizer" in ckpt:
            optimizer.load_state_dict(ckpt["optimizer"])
        if scaler is not None and "scaler" in ckpt:
            scaler.load_state_dict(ckpt["scaler"])
    resume = {key: ckpt.get(key, 0) for key in ("step", "epoch", "epoch_step")}
    resume["best_val_loss"] = ckpt.get("best_val_loss", float("inf"))
    return resume
|
| |
|
| |
|
def train():
    """End-to-end training entry point.

    Parses CLI args, builds tokenizer/data/model/optimizer, then runs the
    epoch loop with gradient accumulation, LR scheduling, mixed precision,
    periodic/mid-run validation, and checkpointing (step, epoch, best, latest).
    """
    config, args = parse_args()

    # ---- Device selection -------------------------------------------------
    device = torch.device(f"cuda:{config.gpu}" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # ---- Tokenizer & dataset ----------------------------------------------
    print(f"Loading data from: {args.data}")
    model_type = args.arch
    tokenizer_name = getattr(args, 'tokenizer', 'gpt2')
    if model_type == "graft_g2lu":
        # Grafting reuses the pretrained model's own tokenizer.
        tokenizer_name = args.pretrained
    tokenizer = get_tokenizer(tokenizer_name)
    config.vocab_size = len(tokenizer)
    print(f"Tokenizer: {tokenizer_name} (vocab_size={config.vocab_size})")
    cache_dir = None if args.no_cache else args.cache_dir
    dataset = load_data(
        args.data,
        tokenizer,
        config.max_seq_len,
        text_column=args.text_column,
        num_samples=args.num_samples,
        cache_dir=cache_dir,
        data_format=args.data_format,
    )
    print(f"Loaded {len(dataset):,} chunks")

    # ---- Train/val split (skipped for tiny datasets) -----------------------
    val_split = args.val_split
    if val_split > 0 and len(dataset) > 20:
        train_dataset, val_dataset = dataset.split(val_split)
        print(f"Split: {len(train_dataset):,} train / {len(val_dataset):,} val ({val_split:.0%})")
    else:
        train_dataset = dataset
        val_dataset = None

    dataloader = create_dataloader(
        train_dataset,
        config.batch_size,
        shuffle=True,
    )
    val_dataloader = None
    if val_dataset is not None:
        val_dataloader = create_dataloader(
            val_dataset,
            config.batch_size,
            shuffle=False,
        )

    # ---- Model construction: mirrored / graft_g2lu / standard ---------------
    if model_type == "mirrored":
        model_config = MirroredConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            num_layers=config.num_layers,
            n_middle=args.n_middle,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            use_g2lu=not getattr(args, 'no_g2lu', False),
            aux_skip_k=getattr(args, 'aux_skip', 0),
            aux_skip_weight=getattr(args, 'aux_weight', 0.1),
            word_rope_dims=getattr(config, 'word_rope_dims', 0),
            word_rope_base=getattr(config, 'word_rope_base', 10.0),
            embed_dim=getattr(config, 'embed_dim', 0),
            head_dim=getattr(config, 'head_dim', 0),
        )
        model = MirroredTransformer(model_config).to(device)
        # count_mirrored_parameters distinguishes unique vs shared weights.
        param_info = count_mirrored_parameters(model)
        num_params = param_info["unique"]
        print(f"Model: MirroredTransformer")
        print(f" Virtual layers: {model.total_virtual_layers} ({model_config.n_mirror} mirror pairs + {model_config.n_middle} middle)")
        print(f" Parameters: {num_params:,} ({num_params/1e6:.1f}M unique)")
        print(f" Shared FFN base: {param_info['shared_ffn_base']:,}")
        print(f" Direction gates: {param_info['direction_gates']:,}")
        print(f" FFN gating: {'G²LU (nested dual gate)' if model_config.use_g2lu else 'SwiGLU (vanilla)'}")
        if model_config.num_kv_heads is not None:
            print(f" GQA: {model_config.num_heads}Q / {model_config.num_kv_heads}KV ({model_config.num_heads // model_config.num_kv_heads}:1 ratio)")
        if model_config.aux_skip_k > 0:
            print(f" Aux skip prediction: t+{model_config.aux_skip_k} (weight={model_config.aux_skip_weight})")
        if getattr(model_config, 'embed_dim', 0) > 0:
            # Parameter savings from factorizing the embedding matrix.
            std_embed = config.vocab_size * config.hidden_size
            fact_embed = config.vocab_size * model_config.embed_dim + model_config.embed_dim * config.hidden_size
            print(f" Factorized embedding: {model_config.embed_dim} → {config.hidden_size} (saves {(std_embed - fact_embed):,} params)")
        if getattr(model_config, 'head_dim', 0) > 0:
            # Parameter savings from a bottlenecked LM head.
            std_head = config.hidden_size * config.vocab_size
            mlp_head = config.hidden_size * model_config.head_dim + model_config.head_dim * config.vocab_size
            print(f" MLP head: {config.hidden_size} → {model_config.head_dim} → vocab (saves {(std_head - mlp_head):,} params)")
    elif model_type == "graft_g2lu":
        assert args.pretrained, "--pretrained is required for graft_g2lu architecture"
        amp_dtype = torch.bfloat16 if config.bf16 else (torch.float16 if config.fp16 else torch.float32)
        model = G2LU_GraftedModel(
            pretrained_name=args.pretrained,
            align_weight=args.align_weight,
            warmup_steps=args.graft_warmup,
            device=device,
            dtype=amp_dtype,
        )
        model_config = None  # grafted model carries its own config
        # Only the grafted (trainable) parameters count here.
        num_params = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
    else:
        model_config = config
        model = CircuitTransformer(config).to(device)
        num_params = count_parameters(model)
        print(f"Model: CircuitTransformer")
        print(f" Parameters: {num_params:,} ({num_params/1e6:.1f}M)")
        if getattr(config, 'aux_skip_k', 0) > 0:
            print(f" Aux skip prediction: t+{config.aux_skip_k} (weight={config.aux_skip_weight})")
        if getattr(config, 'embed_dim', 0) > 0:
            std_embed = config.vocab_size * config.hidden_size
            fact_embed = config.vocab_size * config.embed_dim + config.embed_dim * config.hidden_size
            print(f" Factorized embedding: {config.embed_dim} → {config.hidden_size} (saves {(std_embed - fact_embed):,} params)")
        if getattr(config, 'head_dim', 0) > 0:
            std_head = config.hidden_size * config.vocab_size
            mlp_head = config.hidden_size * config.head_dim + config.head_dim * config.vocab_size
            print(f" MLP head: {config.hidden_size} → {config.head_dim} → vocab (saves {(std_head - mlp_head):,} params)")

    # ---- Optional word-position RoPE lookup table --------------------------
    word_rope_dims = getattr(config, 'word_rope_dims', 0)
    if word_rope_dims > 0:
        word_start_table = build_word_start_table(tokenizer, len(tokenizer)).to(device)
        print(f" Word-position RoPE: {word_rope_dims} dims, base={getattr(config, 'word_rope_base', 10.0)}")
        print(f" Word starters in vocab: {word_start_table.sum().item():,} / {len(tokenizer):,}")
    else:
        word_start_table = None

    # Keep a handle to the uncompiled module for checkpointing / set_step.
    raw_model = model

    if config.compile and hasattr(torch, "compile"):
        print("Compiling model with torch.compile...")
        model = torch.compile(raw_model)

    grad_accum = getattr(args, 'grad_accum', 1)

    # graft_g2lu only optimizes the grafted subset of parameters.
    opt_params = list(raw_model.trainable_parameters()) if model_type == "graft_g2lu" else model.parameters()
    optimizer = torch.optim.AdamW(
        opt_params,
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        betas=(0.9, 0.95),
    )

    # ---- Mixed precision: BF16 needs no loss scaling, FP16 does ------------
    use_amp = (config.fp16 or config.bf16) and device.type == "cuda"
    amp_dtype = torch.bfloat16 if config.bf16 else torch.float16
    scaler = GradScaler() if (config.fp16 and use_amp) else None
    if use_amp:
        print(f" Mixed precision: {'BF16' if config.bf16 else 'FP16'}" +
              (" (no scaler)" if scaler is None else " (with GradScaler)"))

    # ---- Optional resume ---------------------------------------------------
    start_step = 0
    start_epoch = 0
    skip_batches = 0
    best_val_loss = float("inf")
    if args.resume:
        print(f"Resuming from: {args.resume}")
        # With --reset, only weights are restored; counters stay at zero.
        resume_info = load_checkpoint(args.resume, model, optimizer, scaler, args.reset)
        if not args.reset:
            start_step = resume_info["step"]
            start_epoch = resume_info["epoch"]
            skip_batches = resume_info["epoch_step"]
            best_val_loss = resume_info["best_val_loss"]
            print(f"Resumed at step {start_step}, epoch {start_epoch}" +
                  (f", skipping {skip_batches} batches" if skip_batches > 0 else ""))
            if best_val_loss < float("inf"):
                print(f" Best val loss so far: {best_val_loss:.4f} (PPL {math.exp(min(best_val_loss, 20)):.2f})")

    checkpoint_dir = Path(config.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # ---- Schedule bookkeeping ---------------------------------------------
    steps_per_epoch = math.ceil(len(dataloader) / grad_accum)
    max_steps = config.epochs * steps_per_epoch
    tokens_per_step = config.batch_size * grad_accum * config.max_seq_len
    total_train_tokens = config.epochs * len(dataloader) * config.batch_size * config.max_seq_len
    step = start_step
    model.train()

    print(f"\nStarting training:")
    print(f" Epochs: {config.epochs}")
    print(f" Batch size: {config.batch_size}" +
          (f" x {grad_accum} accum = {config.batch_size * grad_accum} effective" if grad_accum > 1 else ""))
    print(f" Steps per epoch: {steps_per_epoch}" +
          (f" ({len(dataloader)} micro-batches)" if grad_accum > 1 else ""))
    print(f" Total steps: {max_steps}")
    print(f" Total tokens: {total_train_tokens:,} ({total_train_tokens/1e6:.1f}M)")
    if num_params > 0:
        print(f" Tokens/param ratio: {total_train_tokens/num_params:.1f}x (Chinchilla=20x)")
    print(f" Learning rate: {config.learning_rate}" +
          (f" → {config.min_lr}" if config.min_lr > 0 else ""))
    print(f" Mixed precision: {use_amp}")
    print(f" Validation: {'enabled' if val_dataloader else 'disabled'}")
    print()

    total_loss = 0.0
    log_steps = 0
    total_tokens_seen = step * tokens_per_step

    # NOTE(review): h_mid_buffer and last_align_val are assigned here but
    # never read in this function — possibly leftovers from dual-path training.
    h_mid_buffer = None
    last_align_val = float("inf")
    start_time = time.time()

    for epoch in range(start_epoch, config.epochs):
        epoch_start = time.time()
        epoch_loss = 0.0
        epoch_steps = 0

        micro_batches = []
        # Count micro-batches consumed in this epoch (for resume's epoch_step).
        epoch_micro_batches = skip_batches if epoch == start_epoch else 0

        for batch_idx, batch in enumerate(dataloader):

            # Fast-forward through batches already seen before a mid-epoch resume.
            if epoch == start_epoch and batch_idx < skip_batches:
                continue

            micro_batches.append(batch)
            epoch_micro_batches += 1

            # Accumulate until a full effective batch (or the epoch's last batch).
            if len(micro_batches) < grad_accum and batch_idx < len(dataloader) - 1:
                continue

            n_micro = len(micro_batches)
            actual_tokens = n_micro * config.batch_size * config.max_seq_len

            # Per-group LR: groups may carry a "delay" to stay frozen early on.
            for param_group in optimizer.param_groups:
                delay = param_group.get("delay", 0)
                param_group["lr"] = get_lr(step, config.warmup_steps, max_steps, config.learning_rate, min_lr=config.min_lr, delay=delay)
            lr = optimizer.param_groups[0]["lr"]

            # NOTE(review): these four are assigned but never read below —
            # apparent leftovers from a dual-path/absorption variant.
            loss_ed_val = None
            loss_align_val = None
            grad_norm_mid = None
            absorb_loss_val = None

            if model_type == "graft_g2lu":
                # Lets the grafted model ramp its align weight over warmup.
                raw_model.set_step(step)

            optimizer.zero_grad()
            accum_loss = 0.0
            accum_aux = 0.0
            accum_align = 0.0

            # ---- Forward/backward over the accumulated micro-batches -------
            for mb in micro_batches:
                mb_ids = mb["input_ids"].to(device)
                mb_labels = mb["labels"].to(device)
                word_positions = None
                if word_start_table is not None:
                    word_positions = compute_word_positions(mb_ids, word_start_table)
                if use_amp:
                    with autocast('cuda', dtype=amp_dtype):
                        output = model(mb_ids, labels=mb_labels, word_positions=word_positions)
                else:
                    output = model(mb_ids, labels=mb_labels, word_positions=word_positions)
                # Divide by n_micro so gradients match the effective batch mean.
                if scaler:
                    scaler.scale(output["loss"] / n_micro).backward()
                else:
                    (output["loss"] / n_micro).backward()
                accum_loss += output["loss"].item()
                if "aux_loss" in output:
                    accum_aux += output["aux_loss"].item()
                if "align_loss" in output:
                    accum_align += output["align_loss"].item()

            # ---- Clip + optimizer step -------------------------------------
            if scaler:
                scaler.unscale_(optimizer)  # unscale before clipping
            clip_params = list(raw_model.trainable_parameters()) if model_type == "graft_g2lu" else model.parameters()
            grad_norm = nn.utils.clip_grad_norm_(clip_params, config.grad_clip).item()
            if scaler:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

            loss_val = accum_loss / n_micro
            # A zero sum means the model emitted no aux/align loss this step.
            aux_loss_val = accum_aux / n_micro if accum_aux > 0 else None
            align_loss_val = accum_align / n_micro if accum_align > 0 else None

            total_loss += loss_val
            epoch_loss += loss_val
            epoch_steps += 1
            log_steps += 1
            total_tokens_seen += actual_tokens
            step += 1

            # ---- Periodic console logging ----------------------------------
            if step % config.log_every == 0:
                avg_loss = total_loss / max(log_steps, 1)
                ppl = math.exp(min(avg_loss, 20))  # cap exponent for stability
                elapsed = time.time() - start_time
                tok_s = (log_steps * tokens_per_step) / max(elapsed, 1e-6)

                extra = ""
                if aux_loss_val is not None:
                    extra += f" | Aux {aux_loss_val:.3f}"
                if align_loss_val is not None:
                    extra += f" | Align {align_loss_val:.4f}"

                print(
                    f"Step {step:6d} | "
                    f"Epoch {epoch+1}/{config.epochs} | "
                    f"Loss {avg_loss:.4f} | "
                    f"PPL {ppl:8.2f} | "
                    f"GradN {grad_norm:.3f} | "
                    f"LR {lr:.2e} | "
                    f"Tok/s {tok_s:.0f}"
                    f"{extra}"
                )

                # Reset the logging window.
                total_loss = 0.0
                log_steps = 0
                start_time = time.time()

            # ---- Periodic step checkpoint ----------------------------------
            if step % config.save_every == 0:
                ckpt_path = checkpoint_dir / f"step_{step:06d}.pt"
                # NOTE(review): when config.compile is set, `model` here is the
                # torch.compile wrapper, whose state_dict keys carry an
                # `_orig_mod.` prefix — confirm the load path tolerates this.
                if model_type == "graft_g2lu":
                    save_g2lu_checkpoint(raw_model, optimizer, step, epoch, loss_val, str(ckpt_path),
                                         epoch_step=epoch_micro_batches, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
                else:
                    save_checkpoint(model, optimizer, step, epoch, loss_val, model_config, str(ckpt_path), model_type,
                                    epoch_step=epoch_micro_batches, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
                print(f" Saved checkpoint: {ckpt_path}")
                gc.collect()
                torch.cuda.empty_cache()

            # ---- Optional mid-run validation (capped pass) ------------------
            val_every = getattr(args, 'val_every', 0)
            if val_every > 0 and step % val_every == 0 and val_dataloader:
                val_loss, val_ppl = evaluate(config, model, val_dataloader, device, use_amp, amp_dtype, mid_run_eval=True, word_start_table=word_start_table)
                avg_train = epoch_loss / max(epoch_steps, 1)
                gap = val_loss - avg_train  # generalization gap so far
                print(f" [Val @ step {step}] Loss: {val_loss:.4f} | PPL: {val_ppl:.2f} | Gap: {gap:+.4f}")
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_path = checkpoint_dir / "best.pt"
                    if model_type == "graft_g2lu":
                        save_g2lu_checkpoint(raw_model, optimizer, step, epoch, val_loss, str(best_path),
                                             epoch_step=epoch_micro_batches, best_val_loss=val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
                    else:
                        save_checkpoint(model, optimizer, step, epoch, val_loss, model_config, str(best_path), model_type,
                                        epoch_step=epoch_micro_batches, best_val_loss=val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
                    print(f" New best! Saved: {best_path}")
                gc.collect()
                torch.cuda.empty_cache()

            # Start accumulating the next effective batch.
            micro_batches = []

        # ---- End-of-epoch summary ------------------------------------------
        epoch_elapsed = time.time() - epoch_start
        avg_epoch_loss = epoch_loss / max(epoch_steps, 1)
        epoch_ppl = math.exp(min(avg_epoch_loss, 20))

        print(f"\n{'='*70}")
        print(f"Epoch {epoch+1}/{config.epochs} complete in {epoch_elapsed:.0f}s")
        print(f" Train loss: {avg_epoch_loss:.4f} | Train PPL: {epoch_ppl:.2f}")
        print(f" Tokens seen: {total_tokens_seen:,} ({total_tokens_seen/1e6:.1f}M)")

        # ---- Full validation pass at epoch end ------------------------------
        if val_dataloader:
            val_loss, val_ppl = evaluate(config, model, val_dataloader, device, use_amp, amp_dtype, word_start_table=word_start_table)
            gap = val_loss - avg_epoch_loss
            print(f" Val loss: {val_loss:.4f} | Val PPL: {val_ppl:.2f} | Gap: {gap:+.4f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_path = checkpoint_dir / "best.pt"
                # epoch + 1 / epoch_step=0: resume starts at the next epoch.
                if model_type == "graft_g2lu":
                    save_g2lu_checkpoint(raw_model, optimizer, step, epoch + 1, val_loss, str(best_path),
                                         epoch_step=0, best_val_loss=val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
                else:
                    save_checkpoint(model, optimizer, step, epoch + 1, val_loss, model_config, str(best_path), model_type,
                                    epoch_step=0, best_val_loss=val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
                print(f" New best! Saved: {best_path}")

            gc.collect()
            torch.cuda.empty_cache()
        print(f"{'='*70}\n")

        # ---- Epoch checkpoint ------------------------------------------------
        ckpt_path = checkpoint_dir / f"epoch_{epoch+1:02d}.pt"
        if model_type == "graft_g2lu":
            save_g2lu_checkpoint(raw_model, optimizer, step, epoch + 1, avg_epoch_loss, str(ckpt_path),
                                 epoch_step=0, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
        else:
            save_checkpoint(model, optimizer, step, epoch + 1, avg_epoch_loss, model_config, str(ckpt_path), model_type,
                            epoch_step=0, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
        gc.collect()
        torch.cuda.empty_cache()

    # ---- Final summary / latest.pt -----------------------------------------
    if step == start_step:
        # Resumed checkpoint already covered all requested epochs.
        print(f"\nNo training performed (already at step {step}/{max_steps}).")
        print(f" To train more epochs, increase --epochs beyond {config.epochs}.")
    else:
        final_path = checkpoint_dir / "latest.pt"
        if model_type == "graft_g2lu":
            save_g2lu_checkpoint(raw_model, optimizer, step, config.epochs, avg_epoch_loss, str(final_path),
                                 epoch_step=0, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
        else:
            save_checkpoint(model, optimizer, step, config.epochs, avg_epoch_loss, model_config, str(final_path), model_type,
                            epoch_step=0, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
        print(f"\nTraining complete.")
        print(f" Final train loss: {avg_epoch_loss:.4f} | PPL: {epoch_ppl:.2f}")
        if val_dataloader:
            print(f" Best val loss: {best_val_loss:.4f} | PPL: {math.exp(min(best_val_loss, 20)):.2f}")
        print(f" Total tokens: {total_tokens_seen:,}")
        print(f" Checkpoints: {final_path}")
|
| |
|
| |
|
# Script entry point (see module docstring for example invocations).
if __name__ == "__main__":
    train()
|
| |
|