Prisma / train.py
y3i12's picture
Initial commit
56e82ec
#!/usr/bin/env python3
"""
Training script for Circuit Transformer.
Usage:
python -m circuits.train --data hf:roneneldan/TinyStories --preset tiny --epochs 1 --gpu 0
python -m circuits.train --data path/to/corpus.txt --dims 256 --layers 6 --fp16
"""
import gc
import os
import time
import math
import random
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler
from torch.amp import autocast
from .config import CircuitConfig, parse_args
from .model import CircuitTransformer, count_parameters
from .mirrored import MirroredConfig, MirroredTransformer, count_mirrored_parameters
from .graft_g2lu import G2LU_GraftedModel, save_g2lu_checkpoint
from .layers import build_word_start_table, compute_word_positions
from .data import get_tokenizer, load_data, create_dataloader
def corrupt_tokens(input_ids, ratio, vocab_size):
    """Randomly swap a fraction of tokens for uniformly drawn vocabulary ids.

    Noising step for a denoising-autoencoder objective.

    Args:
        input_ids: Integer token tensor. It is indexed as ``[:, 0]``, so it must be
            at least 2-D (presumably ``(batch, seq)`` — confirm with callers).
        ratio: Per-position probability of being selected for replacement.
        vocab_size: Replacement ids are drawn uniformly from ``[0, vocab_size)``.

    Returns:
        ``(corrupted_ids, mask)``: a corrupted copy of ``input_ids`` and a boolean
        tensor that is True at positions selected for replacement. Column 0 is
        never selected. ``input_ids`` itself is not modified.
    """
    shape, dev = input_ids.shape, input_ids.device
    corrupt_mask = torch.rand(shape, device=dev) < ratio
    # Keep the first position intact (BOS/start token).
    corrupt_mask[:, 0] = False
    replacements = torch.randint(0, vocab_size, shape, device=dev)
    noisy = input_ids.clone()
    noisy[corrupt_mask] = replacements[corrupt_mask]
    return noisy, corrupt_mask
@torch.no_grad()
def evaluate(config, model, dataloader, device, use_amp=False, amp_dtype=torch.float16, mid_run_eval=False,
             word_start_table=None, max_mid_run_batches=1500):
    """Run validation and return ``(avg_loss, perplexity)``.

    Args:
        config: Object exposing ``log_every``. Progress is printed every
            ``log_every * 10`` batches; a non-positive interval disables it.
        model: Module called as ``model(input_ids, labels=..., word_positions=...)``
            and returning a mapping with a scalar ``"loss"`` tensor.
        dataloader: Iterable of batches with ``"input_ids"`` and ``"labels"`` tensors.
        device: Device the batch tensors are moved to.
        use_amp: Run the forward pass under CUDA autocast.
        amp_dtype: Autocast dtype used when ``use_amp`` is True.
        mid_run_eval: If True, stop after ``max_mid_run_batches`` batches to get a
            quick estimate during training.
        word_start_table: Optional table passed to ``compute_word_positions``; when
            None, no word positions are computed.
        max_mid_run_batches: Batch cap applied when ``mid_run_eval`` is True.

    Returns:
        ``(avg_loss, ppl)`` with ``ppl = exp(min(avg_loss, 20))`` (capped to avoid
        overflow). An empty loader yields ``(0.0, 1.0)``.

    The model is switched to eval mode for the pass. Its previous train/eval mode is
    restored afterwards, even if an exception is raised, so evaluation never leaves a
    training model with dropout disabled.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    n_batches = 0
    log_interval = config.log_every * 10
    try:
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            word_positions = None
            if word_start_table is not None:
                word_positions = compute_word_positions(input_ids, word_start_table)
            if use_amp:
                with autocast('cuda', dtype=amp_dtype):
                    output = model(input_ids, labels=labels, word_positions=word_positions)
            else:
                output = model(input_ids, labels=labels, word_positions=word_positions)
            total_loss += output["loss"].item()
            n_batches += 1
            # Guard against a zero/negative interval, which would break the modulo.
            if log_interval > 0 and n_batches % log_interval == 0:
                running_loss = total_loss / n_batches
                running_ppl = math.exp(min(running_loss, 20))
                print(
                    f"batch {n_batches:6d}/{len(dataloader):6d} | "
                    f"Loss {running_loss:.4f} | "
                    f"PPL {running_ppl:8.2f}"
                )
            if mid_run_eval and n_batches >= max_mid_run_batches:
                break
    finally:
        model.train(was_training)
    avg_loss = total_loss / max(n_batches, 1)
    ppl = math.exp(min(avg_loss, 20))  # cap to avoid overflow
    return avg_loss, ppl
def get_lr(step: int, warmup_steps: int, max_steps: int, max_lr: float, min_lr: float = 0.0, delay: int = 0) -> float:
    """Learning rate at ``step`` for a delayed linear-warmup + cosine-decay schedule.

    All phases are shifted right by ``delay`` steps:
      * ``step < delay``: 0.0 (frozen).
      * The next ``warmup_steps`` steps: linear ramp from 0 toward ``max_lr``.
      * Then, until ``max_steps``: cosine decay from ``max_lr`` to ``min_lr``.
      * From ``step - delay >= max(1, max_steps - delay)`` on: ``min_lr``.
    """
    if step < delay:
        return 0.0
    t = step - delay
    horizon = max(1, max_steps - delay)
    in_warmup = t < warmup_steps
    if not in_warmup and t < horizon:
        # Cosine phase: progress runs from 0 (end of warmup) toward 1 (horizon).
        progress = (t - warmup_steps) / (horizon - warmup_steps)
        return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
    if in_warmup:
        return max_lr * t / warmup_steps
    return min_lr
def save_checkpoint(
model: nn.Module,
optimizer: torch.optim.Optimizer,
step: int,
epoch: int,
loss: float,
config,
path: str,
model_type: str = "standard",
epoch_step: int = 0,
best_val_loss: float | None = None,
scaler=None,
tokenizer_name: str = "gpt2",
):
"""Save training checkpoint.
Args:
epoch: Next epoch to start on resume (completed epoch count).
epoch_step: Batches already processed in `epoch` (0 if epoch is complete).
optimizer_mid: Middle optimizer for dual-path training (optional).
"""
checkpoint = {
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"step": step,
"epoch": epoch,
"epoch_step": epoch_step,
"loss": loss,
"config": config.to_dict(),
"model_type": model_type,
"tokenizer_name": tokenizer_name,
}
if best_val_loss is not None:
checkpoint["best_val_loss"] = best_val_loss
if scaler is not None:
checkpoint["scaler"] = scaler.state_dict()
torch.save(checkpoint, path)
def _migrate_state_dict(state_dict: dict, model: nn.Module) -> dict:
    """Rename legacy checkpoint keys so they match ``model``'s state-dict names.

    For each checkpoint key that the model does not have, the following substring
    renames are tried:
      * ``.ffn.gate_expand.weight``   -> ``.ffn.w3.weight``
      * ``.ffn.gate_compress.weight`` -> ``.ffn.w4.weight``
    A rename is applied only when the resulting name is a key the model expects and
    the checkpoint lacks.

    Args:
        state_dict: Checkpoint state dict. It is not modified.
        model: Module whose ``state_dict()`` keys define the target names.

    Returns:
        ``state_dict`` itself when its keys already match the model exactly.
        Otherwise, a shallow copy with the renames applied.

    Also reports on stdout, since a non-strict load would otherwise hide them:
      * model keys still absent (they keep their fresh initialization);
      * checkpoint keys the model does not use (they are ignored).
    """
    # (checkpoint substring, model substring) pairs.
    renames = (
        (".ffn.gate_expand.weight", ".ffn.w3.weight"),
        (".ffn.gate_compress.weight", ".ffn.w4.weight"),
    )
    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(state_dict.keys())
    missing = model_keys - ckpt_keys
    unexpected = ckpt_keys - model_keys
    if not missing and not unexpected:
        return state_dict  # perfect match, no migration needed

    migrated = dict(state_dict)
    migrations = []
    # Sorted for deterministic log output.
    for key in sorted(unexpected):
        for old, new in renames:
            if old not in key:
                continue
            new_key = key.replace(old, new)
            if new_key in missing:
                migrated[new_key] = migrated.pop(key)
                missing.discard(new_key)
                unexpected.discard(key)
                migrations.append(f" {key} → {new_key}")
                break

    if migrations:
        print(f"State dict migration ({len(migrations)} keys renamed):")
        for line in migrations:
            print(line)
    # Report remaining missing keys (freshly initialized)
    still_missing = model_keys - set(migrated.keys())
    if still_missing:
        print(f" New parameters (freshly initialized): {len(still_missing)}")
        for k in sorted(still_missing):
            print(f" {k}")
    # Report checkpoint entries that a non-strict load will silently drop.
    if unexpected:
        print(f" Ignored checkpoint keys (not in model): {len(unexpected)}")
        for k in sorted(unexpected):
            print(f" {k}")
    return migrated
def load_checkpoint(path: str, model: nn.Module, optimizer: torch.optim.Optimizer | None = None,
                    scaler=None, reset: bool = False):
    """Load a training checkpoint into ``model`` and, optionally, optimizer/scaler.

    Args:
        path: Checkpoint file produced by ``torch.save``.
        model: Module to load weights into. It may be a ``torch.compile`` wrapper;
            the weights are then loaded into the wrapped module.
        optimizer: If given and ``reset`` is False, restored from the checkpoint's
            ``"optimizer"`` entry when present.
        scaler: If given and ``reset`` is False, restored from the checkpoint's
            ``"scaler"`` entry when present.
        reset: Load model weights only; skip optimizer and scaler state.

    Returns:
        dict with ``step``, ``epoch`` and ``epoch_step`` (default 0 each) and
        ``best_val_loss`` (default ``inf``) read from the checkpoint. These are
        returned even when ``reset`` is True; the caller decides whether to use them.

    Weights are loaded non-strictly after ``_migrate_state_dict``. Model keys
    missing from the checkpoint keep their current values.
    """
    # NOTE: weights_only=False unpickles arbitrary Python objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location="cpu", weights_only=False)

    # torch.compile wraps the model in a module whose state_dict keys carry an
    # "_orig_mod." prefix. Load into the wrapped module with un-prefixed keys so
    # checkpoints saved with and without compilation are interchangeable; a
    # prefix mismatch would otherwise load nothing under strict=False.
    compile_prefix = "_orig_mod."
    target = getattr(model, "_orig_mod", model)
    raw_state = checkpoint["model"]
    if any(key.startswith(compile_prefix) for key in raw_state):
        raw_state = {key.removeprefix(compile_prefix): value for key, value in raw_state.items()}

    state_dict = _migrate_state_dict(raw_state, target)
    target.load_state_dict(state_dict, strict=False)
    if not reset:
        if optimizer is not None and "optimizer" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer"])
        if scaler is not None and "scaler" in checkpoint:
            scaler.load_state_dict(checkpoint["scaler"])
    return {
        "step": checkpoint.get("step", 0),
        "epoch": checkpoint.get("epoch", 0),
        "epoch_step": checkpoint.get("epoch_step", 0),
        "best_val_loss": checkpoint.get("best_val_loss", float("inf")),
    }
def train():
    """Command-line entry point: build data, model and optimizer, then train.

    Steps:
      * Parse arguments with ``parse_args()``.
      * Load the tokenizer and dataset, with an optional train/val split.
      * Build the architecture selected by ``args.arch``: ``"mirrored"``,
        ``"graft_g2lu"``, or otherwise a ``CircuitTransformer``.
      * Optionally ``torch.compile`` the model and optionally resume from
        ``args.resume``.
      * Run the epoch loop with gradient accumulation, periodic logging,
        step checkpoints and mid-run/end-of-epoch validation.

    Checkpoints are written under ``config.checkpoint_dir``:
      * ``step_XXXXXX.pt`` every ``save_every`` steps;
      * ``best.pt`` on a new best validation loss;
      * ``epoch_XX.pt`` after each epoch;
      * ``latest.pt`` at the end.
    If no optimizer step is executed (e.g. resuming at the final epoch), nothing is
    saved and the function returns after printing a notice.
    """
    config, args = parse_args()
    # Setup device
    device = torch.device(f"cuda:{config.gpu}" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Load tokenizer and data
    print(f"Loading data from: {args.data}")
    model_type = args.arch
    tokenizer_name = getattr(args, 'tokenizer', 'gpt2')
    if model_type == "graft_g2lu":
        tokenizer_name = args.pretrained
    tokenizer = get_tokenizer(tokenizer_name)
    config.vocab_size = len(tokenizer)
    print(f"Tokenizer: {tokenizer_name} (vocab_size={config.vocab_size})")
    cache_dir = None if args.no_cache else args.cache_dir
    dataset = load_data(
        args.data,
        tokenizer,
        config.max_seq_len,
        text_column=args.text_column,
        num_samples=args.num_samples,
        cache_dir=cache_dir,
        data_format=args.data_format,
    )
    print(f"Loaded {len(dataset):,} chunks")

    # Train/val split
    val_split = args.val_split
    if val_split > 0 and len(dataset) > 20:
        train_dataset, val_dataset = dataset.split(val_split)
        print(f"Split: {len(train_dataset):,} train / {len(val_dataset):,} val ({val_split:.0%})")
    else:
        train_dataset = dataset
        val_dataset = None

    # Create dataloaders
    dataloader = create_dataloader(
        train_dataset,
        config.batch_size,
        shuffle=True,
    )
    val_dataloader = None
    if val_dataset is not None:
        val_dataloader = create_dataloader(
            val_dataset,
            config.batch_size,
            shuffle=False,
        )

    # Create model
    if model_type == "mirrored":
        model_config = MirroredConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            num_layers=config.num_layers,
            n_middle=args.n_middle,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            use_g2lu=not getattr(args, 'no_g2lu', False),
            aux_skip_k=getattr(args, 'aux_skip', 0),
            aux_skip_weight=getattr(args, 'aux_weight', 0.1),
            word_rope_dims=getattr(config, 'word_rope_dims', 0),
            word_rope_base=getattr(config, 'word_rope_base', 10.0),
            embed_dim=getattr(config, 'embed_dim', 0),
            head_dim=getattr(config, 'head_dim', 0),
        )
        model = MirroredTransformer(model_config).to(device)
        param_info = count_mirrored_parameters(model)
        num_params = param_info["unique"]
        print(f"Model: MirroredTransformer")
        print(f" Virtual layers: {model.total_virtual_layers} ({model_config.n_mirror} mirror pairs + {model_config.n_middle} middle)")
        print(f" Parameters: {num_params:,} ({num_params/1e6:.1f}M unique)")
        print(f" Shared FFN base: {param_info['shared_ffn_base']:,}")
        print(f" Direction gates: {param_info['direction_gates']:,}")
        print(f" FFN gating: {'G²LU (nested dual gate)' if model_config.use_g2lu else 'SwiGLU (vanilla)'}")
        if model_config.num_kv_heads is not None:
            print(f" GQA: {model_config.num_heads}Q / {model_config.num_kv_heads}KV ({model_config.num_heads // model_config.num_kv_heads}:1 ratio)")
        if model_config.aux_skip_k > 0:
            print(f" Aux skip prediction: t+{model_config.aux_skip_k} (weight={model_config.aux_skip_weight})")
        if getattr(model_config, 'embed_dim', 0) > 0:
            std_embed = config.vocab_size * config.hidden_size
            fact_embed = config.vocab_size * model_config.embed_dim + model_config.embed_dim * config.hidden_size
            print(f" Factorized embedding: {model_config.embed_dim}{config.hidden_size} (saves {(std_embed - fact_embed):,} params)")
        if getattr(model_config, 'head_dim', 0) > 0:
            std_head = config.hidden_size * config.vocab_size
            mlp_head = config.hidden_size * model_config.head_dim + model_config.head_dim * config.vocab_size
            print(f" MLP head: {config.hidden_size}{model_config.head_dim} → vocab (saves {(std_head - mlp_head):,} params)")
    elif model_type == "graft_g2lu":
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not args.pretrained:
            raise ValueError("--pretrained is required for graft_g2lu architecture")
        amp_dtype = torch.bfloat16 if config.bf16 else (torch.float16 if config.fp16 else torch.float32)
        model = G2LU_GraftedModel(
            pretrained_name=args.pretrained,
            align_weight=args.align_weight,
            warmup_steps=args.graft_warmup,
            device=device,
            dtype=amp_dtype,
        )
        model_config = None  # No CircuitConfig for HF models
        num_params = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
    else:
        model_config = config
        model = CircuitTransformer(config).to(device)
        num_params = count_parameters(model)
        print(f"Model: CircuitTransformer")
        print(f" Parameters: {num_params:,} ({num_params/1e6:.1f}M)")
        if getattr(config, 'aux_skip_k', 0) > 0:
            print(f" Aux skip prediction: t+{config.aux_skip_k} (weight={config.aux_skip_weight})")
        if getattr(config, 'embed_dim', 0) > 0:
            std_embed = config.vocab_size * config.hidden_size
            fact_embed = config.vocab_size * config.embed_dim + config.embed_dim * config.hidden_size
            print(f" Factorized embedding: {config.embed_dim}{config.hidden_size} (saves {(std_embed - fact_embed):,} params)")
        if getattr(config, 'head_dim', 0) > 0:
            std_head = config.hidden_size * config.vocab_size
            mlp_head = config.hidden_size * config.head_dim + config.head_dim * config.vocab_size
            print(f" MLP head: {config.hidden_size}{config.head_dim} → vocab (saves {(std_head - mlp_head):,} params)")

    # Build word-position table if enabled
    word_rope_dims = getattr(config, 'word_rope_dims', 0)
    if word_rope_dims > 0:
        word_start_table = build_word_start_table(tokenizer, len(tokenizer)).to(device)
        print(f" Word-position RoPE: {word_rope_dims} dims, base={getattr(config, 'word_rope_base', 10.0)}")
        print(f" Word starters in vocab: {word_start_table.sum().item():,} / {len(tokenizer):,}")
    else:
        word_start_table = None

    # Keep raw reference for set_step / trainable_parameters (torch.compile wraps the model)
    raw_model = model
    # Optionally compile
    if config.compile and hasattr(torch, "compile"):
        print("Compiling model with torch.compile...")
        model = torch.compile(raw_model)

    # Optimizer (param groups may carry a "delay" key used for staggered warmup)
    grad_accum = getattr(args, 'grad_accum', 1)
    opt_params = list(raw_model.trainable_parameters()) if model_type == "graft_g2lu" else model.parameters()
    optimizer = torch.optim.AdamW(
        opt_params,
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        betas=(0.9, 0.95),
    )

    # Mixed precision
    use_amp = (config.fp16 or config.bf16) and device.type == "cuda"
    amp_dtype = torch.bfloat16 if config.bf16 else torch.float16
    scaler = GradScaler() if (config.fp16 and use_amp) else None
    if use_amp:
        print(f" Mixed precision: {'BF16' if config.bf16 else 'FP16'}" +
              (" (no scaler)" if scaler is None else " (with GradScaler)"))

    # Resume from checkpoint
    start_step = 0
    start_epoch = 0
    skip_batches = 0
    best_val_loss = float("inf")
    if args.resume:
        print(f"Resuming from: {args.resume}")
        resume_info = load_checkpoint(args.resume, model, optimizer, scaler, args.reset)
        if not args.reset:
            start_step = resume_info["step"]
            start_epoch = resume_info["epoch"]
            skip_batches = resume_info["epoch_step"]
            best_val_loss = resume_info["best_val_loss"]
            print(f"Resumed at step {start_step}, epoch {start_epoch}" +
                  (f", skipping {skip_batches} batches" if skip_batches > 0 else ""))
            if best_val_loss < float("inf"):
                print(f" Best val loss so far: {best_val_loss:.4f} (PPL {math.exp(min(best_val_loss, 20)):.2f})")

    # Setup checkpoint directory
    checkpoint_dir = Path(config.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def _save(path, at_step, epoch_value, loss_value, epoch_step, best):
        """Write a checkpoint with the saver that matches ``model_type``."""
        if model_type == "graft_g2lu":
            save_g2lu_checkpoint(raw_model, optimizer, at_step, epoch_value, loss_value, str(path),
                                 epoch_step=epoch_step, best_val_loss=best, scaler=scaler,
                                 tokenizer_name=tokenizer_name)
        else:
            save_checkpoint(model, optimizer, at_step, epoch_value, loss_value, model_config, str(path),
                            model_type, epoch_step=epoch_step, best_val_loss=best, scaler=scaler,
                            tokenizer_name=tokenizer_name)

    # Training loop
    steps_per_epoch = math.ceil(len(dataloader) / grad_accum)
    max_steps = config.epochs * steps_per_epoch
    tokens_per_step = config.batch_size * grad_accum * config.max_seq_len
    total_train_tokens = config.epochs * len(dataloader) * config.batch_size * config.max_seq_len
    step = start_step
    model.train()
    print(f"\nStarting training:")
    print(f" Epochs: {config.epochs}")
    print(f" Batch size: {config.batch_size}" +
          (f" x {grad_accum} accum = {config.batch_size * grad_accum} effective" if grad_accum > 1 else ""))
    print(f" Steps per epoch: {steps_per_epoch}" +
          (f" ({len(dataloader)} micro-batches)" if grad_accum > 1 else ""))
    print(f" Total steps: {max_steps}")
    print(f" Total tokens: {total_train_tokens:,} ({total_train_tokens/1e6:.1f}M)")
    if num_params > 0:
        print(f" Tokens/param ratio: {total_train_tokens/num_params:.1f}x (Chinchilla=20x)")
    print(f" Learning rate: {config.learning_rate}" +
          (f" → {config.min_lr}" if config.min_lr > 0 else ""))
    print(f" Mixed precision: {use_amp}")
    print(f" Validation: {'enabled' if val_dataloader else 'disabled'}")
    print()

    total_loss = 0.0
    log_steps = 0
    log_tokens = 0  # tokens actually processed since the last log line
    total_tokens_seen = step * tokens_per_step
    # best_val_loss already set in resume section above
    val_every = getattr(args, 'val_every', 0)
    start_time = time.time()
    for epoch in range(start_epoch, config.epochs):
        epoch_start = time.time()
        epoch_loss = 0.0
        epoch_steps = 0
        micro_batches = []
        epoch_micro_batches = skip_batches if epoch == start_epoch else 0
        for batch_idx, batch in enumerate(dataloader):
            # Skip already-processed batches on resume
            if epoch == start_epoch and batch_idx < skip_batches:
                continue
            micro_batches.append(batch)
            epoch_micro_batches += 1
            # Accumulate micro-batches (flush at accum boundary or epoch end)
            if len(micro_batches) < grad_accum and batch_idx < len(dataloader) - 1:
                continue
            n_micro = len(micro_batches)
            actual_tokens = n_micro * config.batch_size * config.max_seq_len

            # Update learning rate (per-group delays for staggered warmup)
            for param_group in optimizer.param_groups:
                delay = param_group.get("delay", 0)
                param_group["lr"] = get_lr(step, config.warmup_steps, max_steps, config.learning_rate,
                                           min_lr=config.min_lr, delay=delay)
            lr = optimizer.param_groups[0]["lr"]  # for logging

            # Update blend alpha for G²LU grafting
            if model_type == "graft_g2lu":
                raw_model.set_step(step)

            # === Standard single-path training with accumulation ===
            optimizer.zero_grad()
            accum_loss = 0.0
            accum_aux = 0.0
            accum_align = 0.0
            for mb in micro_batches:
                mb_ids = mb["input_ids"].to(device)
                mb_labels = mb["labels"].to(device)
                word_positions = None
                if word_start_table is not None:
                    word_positions = compute_word_positions(mb_ids, word_start_table)
                if use_amp:
                    with autocast('cuda', dtype=amp_dtype):
                        output = model(mb_ids, labels=mb_labels, word_positions=word_positions)
                else:
                    output = model(mb_ids, labels=mb_labels, word_positions=word_positions)
                # Divide so gradients average over the micro-batches in this step.
                if scaler:
                    scaler.scale(output["loss"] / n_micro).backward()
                else:
                    (output["loss"] / n_micro).backward()
                accum_loss += output["loss"].item()
                if "aux_loss" in output:
                    accum_aux += output["aux_loss"].item()
                if "align_loss" in output:
                    accum_align += output["align_loss"].item()

            # Unscale before clipping so the clip threshold applies to true gradients.
            if scaler:
                scaler.unscale_(optimizer)
            clip_params = list(raw_model.trainable_parameters()) if model_type == "graft_g2lu" else model.parameters()
            grad_norm = nn.utils.clip_grad_norm_(clip_params, config.grad_clip).item()
            if scaler:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

            loss_val = accum_loss / n_micro
            aux_loss_val = accum_aux / n_micro if accum_aux > 0 else None
            align_loss_val = accum_align / n_micro if accum_align > 0 else None
            total_loss += loss_val
            epoch_loss += loss_val
            epoch_steps += 1
            log_steps += 1
            log_tokens += actual_tokens
            total_tokens_seen += actual_tokens
            step += 1

            # Logging
            if step % config.log_every == 0:
                avg_loss = total_loss / max(log_steps, 1)
                ppl = math.exp(min(avg_loss, 20))
                elapsed = time.time() - start_time
                tok_s = log_tokens / max(elapsed, 1e-6)
                extra = ""
                if aux_loss_val is not None:
                    extra += f" | Aux {aux_loss_val:.3f}"
                if align_loss_val is not None:
                    extra += f" | Align {align_loss_val:.4f}"
                print(
                    f"Step {step:6d} | "
                    f"Epoch {epoch+1}/{config.epochs} | "
                    f"Loss {avg_loss:.4f} | "
                    f"PPL {ppl:8.2f} | "
                    f"GradN {grad_norm:.3f} | "
                    f"LR {lr:.2e} | "
                    f"Tok/s {tok_s:.0f}"
                    f"{extra}"
                )
                total_loss = 0.0
                log_steps = 0
                log_tokens = 0
                start_time = time.time()

            # Checkpointing
            if step % config.save_every == 0:
                ckpt_path = checkpoint_dir / f"step_{step:06d}.pt"
                _save(ckpt_path, step, epoch, loss_val, epoch_micro_batches, best_val_loss)
                print(f" Saved checkpoint: {ckpt_path}")
                gc.collect()
                torch.cuda.empty_cache()

            # Mid-training validation
            if val_every > 0 and step % val_every == 0 and val_dataloader:
                val_loss, val_ppl = evaluate(config, model, val_dataloader, device, use_amp, amp_dtype,
                                             mid_run_eval=True, word_start_table=word_start_table)
                # Make sure training resumes in train mode (dropout etc. re-enabled).
                model.train()
                avg_train = epoch_loss / max(epoch_steps, 1)
                gap = val_loss - avg_train
                print(f" [Val @ step {step}] Loss: {val_loss:.4f} | PPL: {val_ppl:.2f} | Gap: {gap:+.4f}")
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_path = checkpoint_dir / "best.pt"
                    _save(best_path, step, epoch, val_loss, epoch_micro_batches, val_loss)
                    print(f" New best! Saved: {best_path}")
                gc.collect()
                torch.cuda.empty_cache()
            micro_batches = []

        # --- Epoch summary ---
        epoch_elapsed = time.time() - epoch_start
        avg_epoch_loss = epoch_loss / max(epoch_steps, 1)
        epoch_ppl = math.exp(min(avg_epoch_loss, 20))
        print(f"\n{'='*70}")
        print(f"Epoch {epoch+1}/{config.epochs} complete in {epoch_elapsed:.0f}s")
        print(f" Train loss: {avg_epoch_loss:.4f} | Train PPL: {epoch_ppl:.2f}")
        print(f" Tokens seen: {total_tokens_seen:,} ({total_tokens_seen/1e6:.1f}M)")
        # Validation
        if val_dataloader:
            val_loss, val_ppl = evaluate(config, model, val_dataloader, device, use_amp, amp_dtype,
                                         word_start_table=word_start_table)
            gap = val_loss - avg_epoch_loss
            print(f" Val loss: {val_loss:.4f} | Val PPL: {val_ppl:.2f} | Gap: {gap:+.4f}")
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_path = checkpoint_dir / "best.pt"
                _save(best_path, step, epoch + 1, val_loss, 0, val_loss)
                print(f" New best! Saved: {best_path}")
            # Free validation tensors
            gc.collect()
            torch.cuda.empty_cache()
        print(f"{'='*70}\n")
        # Save epoch checkpoint
        ckpt_path = checkpoint_dir / f"epoch_{epoch+1:02d}.pt"
        _save(ckpt_path, step, epoch + 1, avg_epoch_loss, 0, best_val_loss)
        gc.collect()
        torch.cuda.empty_cache()

    # Nothing was trained: there is no final loss or checkpoint to report, so stop
    # here instead of falling through to the summary (which would hit undefined names).
    if step == start_step:
        print(f"\nNo training performed (already at step {step}/{max_steps}).")
        print(f" To train more epochs, increase --epochs beyond {config.epochs}.")
        return

    # Save final checkpoint
    final_path = checkpoint_dir / "latest.pt"
    _save(final_path, step, config.epochs, avg_epoch_loss, 0, best_val_loss)
    print(f"\nTraining complete.")
    print(f" Final train loss: {avg_epoch_loss:.4f} | PPL: {epoch_ppl:.2f}")
    if val_dataloader:
        print(f" Best val loss: {best_val_loss:.4f} | PPL: {math.exp(min(best_val_loss, 20)):.2f}")
    print(f" Total tokens: {total_tokens_seen:,}")
    print(f" Checkpoints: {final_path}")
if __name__ == "__main__":
    # This module uses package-relative imports (`from .config import ...`), so it
    # must be launched as a module (e.g. `python -m circuits.train ...`), not by file path.
    train()