# nemotron-kan-350m / train.py
# icarus112's picture
# feat(arch): sync with Nemotron-4 architecture alignment (d823c2c)
# eca97aa verified
"""NemotronKAN pretraining script — streams FineWeb-Edu, supports Accelerate bf16/DDP.
Usage:
python3 train.py [--resume PATH]
RTX 3060 Laptop (6GB VRAM). Nemotron recipe hyperparams.
"""
from __future__ import annotations
import argparse
import math
import os
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator
try:
import bitsandbytes as bnb
except ImportError:
bnb = None
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Module-level defaults. main() copies the batch/sequence/LR/step/interval
# values below only for the "124m" preset; the "350m" preset hard-codes its
# own values for those and shares the remaining constants.
BATCH_SIZE = 8  # sequences per micro-batch
SEQ_LEN = 512  # tokens per training sequence (also used as block_size for "124m")
GRAD_ACCUM_STEPS = 4  # effective batch = 8*4*512 = 16384 tokens
MAX_STEPS = 50_000  # optimizer steps before the training loop exits
WARMUP_STEPS = 100  # linear LR warmup length, in optimizer steps
ADAM_LR = 1e-3  # peak learning rate (multiplied by get_lr_factor each step)
WEIGHT_DECAY = 0.1  # applied only to the param group holding 2-D nn.Linear weights
DECAY_START = 0.7  # fraction of max_steps after which the LR decay phase begins
MAX_GRAD_NORM = 1.0  # max norm passed to accelerator.clip_grad_norm_
EVAL_INTERVAL = 500  # optimizer steps between validation runs
EVAL_STEPS = 20  # validation batches averaged per evaluation
CHECKPOINT_DIR = "checkpoints"  # directory for best.pt / step_*.pt / interrupted.pt
CHECKPOINT_INTERVAL = 2000  # optimizer steps between periodic checkpoints
LOG_INTERVAL = 1  # optimizer steps between per-step log lines
# main() reads the dataset from --dataset / --dataset-config, whose argparse
# defaults repeat the two values below.
DATASET_NAME = "HuggingFaceFW/fineweb-edu"
DATASET_CONFIG = "sample-10BT"
MONITOR_INTERVAL_S = 20  # minimum seconds between "[MONITOR]" heartbeat lines
# ---------------------------------------------------------------------------
# Streaming token buffer
# ---------------------------------------------------------------------------
class StreamingTokenBuffer:
    """Stream text from a HuggingFace dataset and serve fixed-size token batches.

    Rows are read lazily (``streaming=True``); the ``"text"`` field of each
    row is tokenised with the tiktoken ``gpt2`` encoding and appended to one
    flat token buffer.  Documents are concatenated back to back with no
    separator token.  ``get_batch`` cuts ``(input, target)`` tensors out of
    that buffer, the target being the input shifted by one token.

    When the stream runs out it is restarted from the beginning, so batches
    can be served indefinitely as long as the dataset yields some text.
    """

    def __init__(
        self,
        split: str,
        block_size: int,
        batch_size: int,
        dataset_name: str,
        dataset_config: str,
        accelerator: Accelerator | None = None,
    ):
        """Open the streaming dataset and set up an empty token buffer.

        Args:
            split: Dataset split name passed to ``load_dataset``.
            block_size: Number of tokens in each input sequence.
            batch_size: Number of sequences per batch.
            dataset_name: HuggingFace dataset identifier.
            dataset_config: Dataset configuration name.
            accelerator: Optional Accelerate handle.  When it reports more
                than one process, the stream is sharded per process with
                ``split_dataset_by_node``.
        """
        # Imported lazily so the module can be imported without these packages.
        import tiktoken
        from datasets import load_dataset
        from datasets.distributed import split_dataset_by_node

        self.enc = tiktoken.get_encoding("gpt2")
        self.block_size = block_size
        self.batch_size = batch_size
        self.ds = load_dataset(
            dataset_name,
            dataset_config,
            split=split,
            streaming=True,
        )
        rank = accelerator.process_index if accelerator else 0
        world_size = accelerator.num_processes if accelerator else 1
        if world_size > 1:
            self.ds = split_dataset_by_node(self.ds, rank=rank, world_size=world_size)
        self._buf: list[int] = []
        self._iter = None

    def _ensure_iter(self):
        """Create the dataset iterator on first use."""
        if self._iter is None:
            self._iter = iter(self.ds)

    def _fill(self, need: int):
        """Read rows until the buffer holds at least ``need`` tokens.

        The stream is restarted whenever it is exhausted.

        Raises:
            RuntimeError: If a complete pass over the restarted stream adds no
                tokens (empty dataset, or every row blank).  Without this
                check the loop would either leak a bare ``StopIteration`` or
                spin forever.
        """
        self._ensure_iter()
        assert self._iter is not None
        # None until the stream has been restarted inside this call; after
        # that, the number of tokens added since the most recent restart.
        added_since_restart: int | None = None
        while len(self._buf) < need:
            try:
                row = next(self._iter)
            except StopIteration:
                if added_since_restart == 0:
                    raise RuntimeError(
                        "dataset stream produced no tokens in a full pass; "
                        "cannot fill the token buffer"
                    ) from None
                self._iter = iter(self.ds)
                added_since_restart = 0
                continue
            # A missing or None "text" field is treated as an empty document.
            text = row.get("text") or ""
            if text.strip():
                ids = self.enc.encode_ordinary(text)
                self._buf.extend(ids)
                if added_since_restart is not None:
                    added_since_restart += len(ids)

    def prefill(self, num_batches: int = 64):
        """Buffer enough tokens up front to serve ``num_batches`` batches."""
        need = num_batches * self.batch_size * (self.block_size + 1)
        self._fill(need)

    def get_batch(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the next ``(x, y)`` batch, both ``(batch_size, block_size)`` long tensors.

        Each row is cut from ``block_size + 1`` consecutive buffered tokens;
        ``x`` drops the last token and ``y`` drops the first, so ``y`` is the
        next-token target for ``x``.
        """
        need = self.batch_size * (self.block_size + 1)
        self._fill(need)
        tokens = self._buf[:need]
        # Drop the consumed prefix in place.  Rebinding to `self._buf[need:]`
        # would build a brand-new list of everything still buffered on every
        # batch, which is costly right after a large prefill.
        del self._buf[:need]
        t = torch.tensor(tokens, dtype=torch.long).view(
            self.batch_size, self.block_size + 1
        )
        x = t[:, :-1].to(device, non_blocking=True)
        y = t[:, 1:].to(device, non_blocking=True)
        return x, y
# ---------------------------------------------------------------------------
# Learning rate schedule (WSD: warmup-stable-decay)
# ---------------------------------------------------------------------------
def get_lr_factor(
    step: int, warmup_steps: int, max_steps: int, decay_start: float
) -> float:
    """Return the learning-rate multiplier for ``step`` (warmup-stable-decay).

    The multiplier rises linearly from 0 to 1 over ``warmup_steps``, stays at
    1.0 until ``int(max_steps * decay_start)``, then falls linearly to 0.1 at
    ``max_steps`` and remains 0.1 afterwards.
    """
    # Warmup phase: linear ramp.
    if step < warmup_steps:
        return step / warmup_steps

    knee = int(max_steps * decay_start)
    if step >= knee:
        # Past the end of the schedule the floor value is held.
        if step >= max_steps:
            return 0.1
        # Decay phase: interpolate from 1.0 down to 0.1.
        steps_into_decay = step - knee
        decay_span = max_steps - knee
        return 1.0 - 0.9 * (steps_into_decay / decay_span)

    # Stable phase.
    return 1.0
@torch.no_grad()
def sample(
    model: nn.Module,
    device: torch.device,
    seq_len: int,
    prompt: str,
    max_new_tokens: int,
    temperature: float = 0.8,
    top_k: int = 40,
) -> str:
    """Generate text from ``prompt`` with temperature and top-k sampling.

    Args:
        model: Module called as ``model(idx)`` and expected to return a
            3-tuple whose first element is a ``(batch, time, vocab)`` logits
            tensor.
        device: Device on which the token tensor is created.
        seq_len: Maximum context length; only the last ``seq_len`` tokens are
            fed to the model at each step.
        prompt: Text to condition on, encoded with the tiktoken ``gpt2``
            encoding.
            NOTE(review): a prompt that encodes to zero tokens gives the model
            a zero-length context; presumably unsupported, confirm.
        max_new_tokens: Number of tokens to sample.
        temperature: Logit divisor; clamped to at least ``1e-6``.
        top_k: Keep only the ``top_k`` most likely tokens at each step.  A
            value ``<= 0`` disables top-k filtering.

    Returns:
        The decoded prompt plus generated continuation.  Token ids ``>= 50257``
        are dropped before decoding (presumably padding ids above the GPT-2
        vocabulary; main() configures ``vocab_size=50304``).
    """
    import tiktoken

    enc = tiktoken.get_encoding("gpt2")
    idx = torch.tensor(enc.encode_ordinary(prompt), dtype=torch.long, device=device)[
        None, :
    ]
    was_training = model.training
    model.eval()
    try:
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -seq_len:]
            logits, _, _ = model(idx_cond)
            # Only the final position predicts the next token.
            logits = logits[:, -1, :] / max(temperature, 1e-6)
            k = min(top_k, logits.size(-1))
            if k > 0:
                # Mask everything below the k-th largest logit.  With k == 0
                # there is no k-th value to compare against, so filtering is
                # skipped instead of failing on an empty top-k result.
                v, _ = torch.topk(logits, k)
                logits[logits < v[:, [-1]]] = -float("inf")
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
    finally:
        # Restore training mode even if generation raised, so a failed sample
        # cannot leave the caller's model stuck in eval mode.
        if was_training:
            model.train()
    tokens = [t for t in idx[0].tolist() if t < 50257]
    return enc.decode(tokens)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Parse CLI flags, build model/optimizer/data streams, and run training.

    Steps, in order:
      1. Pick hyperparameters for the ``--model-size`` preset ("350m" values
         are hard-coded here; "124m" uses the module-level constants).
      2. Create the Accelerator (bf16 when CUDA is available) and build a
         NemotronKAN model from the preset's config.
      3. Split parameters into a weight-decay group (2-D ``nn.Linear``
         weights) and a no-decay group, then build AdamW ("350m") or
         bitsandbytes ``Adam8bit`` ("124m").
      4. Open streaming train/validation token buffers and optionally resume
         from ``--resume``.
      5. Run the gradient-accumulation training loop with periodic logging,
         evaluation, sampling and checkpointing.
      6. On exit (normal, Ctrl-C or error) print a summary, a final sample,
         and optionally push the model to the Hub.

    Raises:
        ImportError: If the "124m" preset is selected and bitsandbytes is not
            installed.
    """
    # Ensure nvcc is on PATH for torch.compile inductor backend
    _path = os.environ.get("PATH", "")
    if "/usr/local/cuda/bin" not in _path:
        os.environ["PATH"] = "/usr/local/cuda/bin:" + _path

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--resume", type=str, default=None, help="Checkpoint to resume from"
    )
    parser.add_argument("--model-size", choices=["124m", "350m"], default="350m")
    parser.add_argument("--dataset", default="HuggingFaceFW/fineweb-edu")
    parser.add_argument("--dataset-config", default="sample-10BT")
    parser.add_argument("--val-dataset", default="wikitext")
    parser.add_argument("--val-dataset-config", default="wikitext-103-raw-v1")
    parser.add_argument("--push-to-hub", action="store_true")
    parser.add_argument("--hub-model-id", type=str, default=None)
    args = parser.parse_args()

    if args.model_size == "350m":
        batch_size = 64
        seq_len = 1024
        grad_accum_steps = 8
        adam_lr = 3e-4
        max_steps = 19_073
        warmup_steps = 2000
        eval_interval = 250
        checkpoint_interval = 1000
    else:
        batch_size = BATCH_SIZE
        seq_len = SEQ_LEN
        grad_accum_steps = GRAD_ACCUM_STEPS
        adam_lr = ADAM_LR
        max_steps = MAX_STEPS
        warmup_steps = WARMUP_STEPS
        eval_interval = EVAL_INTERVAL
        checkpoint_interval = CHECKPOINT_INTERVAL

    accelerator = Accelerator(
        mixed_precision="bf16" if torch.cuda.is_available() else "no"
    )
    device = accelerator.device
    if accelerator.is_main_process:
        print(f"Device: {device}")
    if device.type == "cuda":
        if accelerator.is_main_process:
            print(f"GPU: {torch.cuda.get_device_name(device)}")
            print(
                f"VRAM: {torch.cuda.get_device_properties(device).total_memory / 1024**3:.2f} GB"
            )
        os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")

    # --- Model ---------------------------------------------------------------
    from nemotron_kan import NemotronKANConfig, NemotronKAN

    if args.model_size == "350m":
        config = NemotronKANConfig(
            n_layer=24,
            n_embd=1024,
            n_head=16,
            n_kv_head=4,
            initializer_range=0.014,
            vocab_size=50304,
            block_size=1024,
            dropout=0.0,
            kan_type="grkan",
            kan_num_grids=4,
            kan_hidden_mult=4,
            kan_grkan_num_groups=8,
            hc_num_streams=2,
            mhc=True,
            gradient_checkpointing=True,
            use_engram=False,
            use_sdr_compression=False,
            use_axiomatic_attention=False,
            use_holographic_memory=False,
            use_jit_cache=False,
            use_synaptic_offload=False,
            use_fused_ce=True,
        )
    else:
        config = NemotronKANConfig(
            n_layer=12,
            n_embd=768,
            n_head=12,
            n_kv_head=4,
            vocab_size=50304,
            block_size=seq_len,
            dropout=0.0,
            kan_type="grkan",
            kan_num_grids=4,
            kan_hidden_mult=4,
            kan_grkan_num_groups=8,
            hc_num_streams=2,
            mhc=True,
            gradient_checkpointing=True,
            use_engram=False,
            use_sdr_compression=False,
            use_axiomatic_attention=False,
            use_holographic_memory=False,
            use_jit_cache=False,
            use_synaptic_offload=False,
            use_fused_ce=True,
        )
    model = NemotronKAN(config)
    num_params = sum(p.numel() for p in model.parameters())
    if accelerator.is_main_process:
        print(f"Model params: {num_params:,} ({num_params / 1e6:.1f}M)")
    model = model.to(device)

    # --- Optimizer (AdamW for "350m", 8-bit Adam for "124m") -----------------
    # Weight decay is applied only to 2-D weights that belong to nn.Linear
    # modules; every other trainable parameter goes in the no-decay group.
    linear_weight_names = set()
    for module_name, module in model.named_modules():
        if isinstance(module, nn.Linear) and module.weight is not None:
            weight_name = f"{module_name}.weight" if module_name else "weight"
            linear_weight_names.add(weight_name)
    decay_names: list[str] = []
    no_decay_names: list[str] = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim == 2 and name in linear_weight_names:
            decay_names.append(name)
        else:
            no_decay_names.append(name)
    all_names = set(decay_names) | set(no_decay_names)
    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
    assert len(set(decay_names) & set(no_decay_names)) == 0, (
        "Overlap in decay/no_decay groups"
    )
    assert all_names == set(param_dict.keys()), "Missing params in optimizer groups"
    param_groups = [
        dict(
            params=[param_dict[pn] for pn in sorted(decay_names)],
            lr=adam_lr,
            weight_decay=WEIGHT_DECAY,
        ),
        dict(
            params=[param_dict[pn] for pn in sorted(no_decay_names)],
            lr=adam_lr,
            weight_decay=0.0,
        ),
    ]
    if args.model_size == "350m":
        if device.type == "cuda":
            try:
                optimizer = torch.optim.AdamW(
                    param_groups,
                    lr=adam_lr,
                    betas=(0.9, 0.95),
                    eps=1e-8,
                    fused=True,
                )
            except TypeError:
                # Fall back when this torch build rejects the `fused` keyword.
                optimizer = torch.optim.AdamW(
                    param_groups,
                    lr=adam_lr,
                    betas=(0.9, 0.95),
                    eps=1e-8,
                )
        else:
            optimizer = torch.optim.AdamW(
                param_groups,
                lr=adam_lr,
                betas=(0.9, 0.95),
                eps=1e-8,
            )
        if accelerator.is_main_process:
            print("Optimizer: AdamW")
    else:
        if bnb is None:
            raise ImportError(
                "bitsandbytes required for Adam8bit. Install: pip install bitsandbytes"
            )
        optimizer = bnb.optim.Adam8bit(
            param_groups, lr=adam_lr, betas=(0.9, 0.95), eps=1e-10
        )
        if accelerator.is_main_process:
            print("Optimizer: 8-bit Adam (single optimizer)")

    if device.type == "cuda":
        model = torch.compile(model, backend="inductor")
    model, optimizer = accelerator.prepare(model, optimizer)
    if accelerator.is_main_process:
        print(f"AMP: {accelerator.mixed_precision}")

    # --- Data ----------------------------------------------------------------
    if accelerator.is_main_process:
        print(f"Initializing streaming data ({args.dataset}/{args.dataset_config})...")
    train_data = StreamingTokenBuffer(
        "train",
        seq_len,
        batch_size,
        args.dataset,
        args.dataset_config,
        accelerator,
    )
    val_data = StreamingTokenBuffer(
        "validation",
        seq_len,
        batch_size,
        args.val_dataset,
        args.val_dataset_config,
        accelerator,
    )
    if accelerator.is_main_process:
        print("Pre-filling token buffer...")
    train_data.prefill(128)
    if accelerator.is_main_process:
        print(f"Data stream ready ({len(train_data._buf):,} tokens buffered).")

    # --- Resume from checkpoint -----------------------------------------------
    global_step = 0
    best_val_loss = float("inf")
    unwrapped_model = accelerator.unwrap_model(model)
    if args.resume:
        if os.path.isfile(args.resume):
            if accelerator.is_main_process:
                print(f"Resuming from {args.resume}...")
            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # only resume from checkpoints you trust.
            ckpt = torch.load(args.resume, map_location="cpu", weights_only=False)
            unwrapped_model.load_state_dict(ckpt["model_state_dict"])
            try:
                optimizer.load_state_dict(ckpt["optimizer_state_dict"])
            except Exception as exc:
                # Best effort: keep the model weights and start with a fresh
                # optimizer state.
                if accelerator.is_main_process:
                    print(f"Warning: could not load optimizer state ({exc})")
            global_step = ckpt.get("global_step", 0)
            best_val_loss = ckpt.get("best_val_loss", float("inf"))
            del ckpt  # Free ~1.7GB CPU memory immediately
            torch.cuda.empty_cache()
            if accelerator.is_main_process:
                print(f"Resumed at step {global_step}, best_val_loss={best_val_loss:.4f}")
        elif accelerator.is_main_process:
            # Make a mistyped path visible instead of silently retraining
            # from step 0.
            print(
                f"Warning: --resume path not found ({args.resume}); "
                "starting from scratch"
            )

    # --- Training loop -------------------------------------------------------
    if accelerator.is_main_process:
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    model.train()
    optimizer.zero_grad(set_to_none=True)
    accum_count = 0
    running_loss = 0.0
    running_tokens = 0
    running_micro_batches = 0  # micro-batches summed into running_loss
    t_start = time.time()
    t_last_monitor = t_start
    t_last_step = t_start
    total_tokens = 0
    loss_history: list[float] = []
    if accelerator.is_main_process:
        print(f"\n{'=' * 70}")
        print(f"Training NemotronKAN — {num_params / 1e6:.1f}M params")
        print(
            f" model_size={args.model_size} | world_size={accelerator.num_processes}"
        )
        print(f" batch={batch_size}, seq={seq_len}, grad_accum={grad_accum_steps}")
        print(
            f" effective_batch_tokens={batch_size * grad_accum_steps * seq_len * accelerator.num_processes:,}"
        )
        print(f" adam_lr={adam_lr}, warmup={warmup_steps}, max_steps={max_steps}")
        print(" AMP bf16 (cuda), gradient checkpointing, WSD, fused_ce")
        print(f"{'=' * 70}\n")

    try:
        while global_step < max_steps:
            # ---- one micro-batch: forward + backward -----------------------
            x, y = train_data.get_batch(device)
            tokens_in_batch = x.numel()
            with accelerator.autocast():
                # The returned logits are not needed here: the loss is already
                # computed by the model, so nothing is done with them.
                _logits, loss, _sdr = model(x, y)
            scaled_loss = loss / grad_accum_steps
            accelerator.backward(scaled_loss)
            running_loss += loss.item()
            running_tokens += tokens_in_batch
            running_micro_batches += 1
            total_tokens += tokens_in_batch
            accum_count += 1

            # ---- optimizer step once enough micro-batches are accumulated --
            stepped = accum_count >= grad_accum_steps
            if stepped:
                lr_factor = get_lr_factor(
                    global_step, warmup_steps, max_steps, DECAY_START
                )
                lr = adam_lr * lr_factor
                for pg in optimizer.param_groups:
                    pg["lr"] = lr
                grad_norm = accelerator.clip_grad_norm_(
                    model.parameters(), MAX_GRAD_NORM
                ).item()
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1
                accum_count = 0

                if global_step % LOG_INTERVAL == 0 and accelerator.is_main_process:
                    # Average over the micro-batches actually accumulated so
                    # the value stays a per-micro-batch mean for any
                    # LOG_INTERVAL.
                    avg_loss = running_loss / max(running_micro_batches, 1)
                    t_now = time.time()
                    dt = t_now - t_last_step
                    tok_s = running_tokens / max(dt, 1e-6)
                    elapsed = t_now - t_start
                    loss_history.append(avg_loss)
                    trend = ""
                    if len(loss_history) >= 10:
                        recent = loss_history[-10:]
                        if recent[-1] < recent[0]:
                            trend = " [decreasing]"
                        else:
                            trend = " [WARNING: not decreasing]"
                    gpu_mem = ""
                    if device.type == "cuda":
                        mem_used = torch.cuda.max_memory_allocated() / 1024**3
                        gpu_mem = f" | GPU {mem_used:.2f}GB"
                    print(
                        f"step {global_step:>6d} | loss {avg_loss:.4f}{trend} | "
                        f"lr {lr:.2e} | gnorm {grad_norm:.2f} | "
                        f"tok/s {tok_s:,.0f} | elapsed {elapsed:.0f}s{gpu_mem}"
                    )
                    sys.stdout.flush()
                    running_loss = 0.0
                    running_tokens = 0
                    running_micro_batches = 0
                    t_last_step = t_now

            # ---- wall-clock heartbeat (checked on every micro-batch) -------
            t_now = time.time()
            if (
                t_now - t_last_monitor >= MONITOR_INTERVAL_S
                and accelerator.is_main_process
            ):
                elapsed = t_now - t_start
                avg_tok_s = total_tokens / max(elapsed, 1e-6)
                if device.type == "cuda":
                    mem_alloc = torch.cuda.memory_allocated() / 1024**3
                    mem_max = torch.cuda.max_memory_allocated() / 1024**3
                    print(
                        f" [MONITOR] step={global_step} | "
                        f"total_tok={total_tokens:,} | avg_tok/s={avg_tok_s:,.0f} | "
                        f"GPU_alloc={mem_alloc:.2f}GB | GPU_peak={mem_max:.2f}GB"
                    )
                else:
                    print(
                        f" [MONITOR] step={global_step} | "
                        f"total_tok={total_tokens:,} | avg_tok/s={avg_tok_s:,.0f}"
                    )
                sys.stdout.flush()
                t_last_monitor = t_now

            # ---- evaluation -------------------------------------------------
            # Gated on `stepped` so an evaluation runs once per qualifying
            # optimizer step rather than on every micro-batch during which
            # global_step happens to sit on a multiple of eval_interval.
            if stepped and global_step % eval_interval == 0:
                model.eval()
                val_loss_total = 0.0
                with torch.no_grad():
                    for _ in range(EVAL_STEPS):
                        vx, vy = val_data.get_batch(device)
                        with accelerator.autocast():
                            _, vloss, _ = model(vx, vy)
                        # Average the loss across processes so every rank
                        # makes the same "improved" decision.
                        gathered_vloss = accelerator.gather(vloss.detach())
                        val_loss_total += gathered_vloss.mean().item()
                val_loss = val_loss_total / EVAL_STEPS
                # Clamp the exponent so a diverged loss cannot overflow.
                ppl = math.exp(min(val_loss, 20.0))
                improved = val_loss < best_val_loss
                if improved:
                    best_val_loss = val_loss
                if accelerator.is_main_process:
                    print(
                        f" >>> EVAL step {global_step} | val_loss {val_loss:.4f} | "
                        f"ppl {ppl:.1f} | best {best_val_loss:.4f}"
                        f"{' [NEW BEST - saving]' if improved else ''}"
                    )
                    sys.stdout.flush()
                if improved and accelerator.is_main_process:
                    _save_checkpoint(
                        accelerator,
                        model,
                        optimizer,
                        global_step,
                        best_val_loss,
                        os.path.join(CHECKPOINT_DIR, "best.pt"),
                    )
                if accelerator.is_main_process:
                    eval_sample = sample(
                        accelerator.unwrap_model(model),
                        device,
                        seq_len,
                        prompt="The meaning of life is",
                        max_new_tokens=60,
                        temperature=0.8,
                        top_k=40,
                    )
                    print(f" >>> SAMPLE step {global_step}: {eval_sample}")
                model.train()

            # ---- periodic checkpoint (same once-per-step gating) ------------
            if (
                stepped
                and global_step % checkpoint_interval == 0
                and accelerator.is_main_process
            ):
                _save_checkpoint(
                    accelerator,
                    model,
                    optimizer,
                    global_step,
                    best_val_loss,
                    os.path.join(CHECKPOINT_DIR, f"step_{global_step}.pt"),
                )
    except KeyboardInterrupt:
        if accelerator.is_main_process:
            print("\nInterrupted — saving checkpoint...")
            _save_checkpoint(
                accelerator,
                model,
                optimizer,
                global_step,
                best_val_loss,
                os.path.join(CHECKPOINT_DIR, "interrupted.pt"),
            )
    finally:
        elapsed = time.time() - t_start
        avg_tok_s = total_tokens / max(elapsed, 1e-6)
        if accelerator.is_main_process:
            print(f"\nTraining ended at step {global_step}")
            print(
                f"Total tokens: {total_tokens:,} in {elapsed:.0f}s ({avg_tok_s:,.0f} tok/s)"
            )
        if loss_history and accelerator.is_main_process:
            print(
                f"Final loss: {loss_history[-1]:.4f} (started at {loss_history[0]:.4f})"
            )
        if accelerator.is_main_process:
            # This block also runs when training died with an exception; the
            # closing sample is best-effort so a failure here cannot replace
            # the original error.
            try:
                final_sample = sample(
                    accelerator.unwrap_model(model),
                    device,
                    seq_len,
                    prompt="The meaning of life is",
                    max_new_tokens=200,
                    temperature=0.8,
                    top_k=40,
                )
                print(f"Final sample: {final_sample}")
            except Exception as exc:
                print(f"Warning: final sample failed ({exc})")
            # Only the main process uploads, and only when a token is present.
            if args.push_to_hub and args.hub_model_id and os.environ.get("HF_TOKEN"):
                try:
                    accelerator.unwrap_model(model).push_to_hub(args.hub_model_id)
                    print(f"Pushed model to hub: {args.hub_model_id}")
                except Exception as exc:
                    print(f"Warning: push_to_hub failed ({exc})")
def _save_checkpoint(accelerator, model, optimizer, global_step, best_val_loss, path):
    """Write a training checkpoint to ``path`` (main process only).

    The file is a ``torch.save`` dict with keys ``model_state_dict`` (from the
    unwrapped model), ``optimizer_state_dict``, ``global_step`` and
    ``best_val_loss``.  Non-main processes return without doing anything.

    The checkpoint is first written to ``<path>.tmp`` and then moved into
    place with ``os.replace``, so an interruption during the write cannot
    leave a truncated file at ``path`` or destroy a previous good checkpoint.

    Args:
        accelerator: Object exposing ``is_main_process`` and ``unwrap_model``.
        model: The (possibly wrapped) model to save.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        global_step: Number of optimizer steps completed.
        best_val_loss: Best validation loss seen so far.
        path: Destination file path; its parent directory is created if
            missing.
    """
    if not accelerator.is_main_process:
        return
    os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
    state = {
        "model_state_dict": accelerator.unwrap_model(model).state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "global_step": global_step,
        "best_val_loss": best_val_loss,
    }
    tmp_path = f"{path}.tmp"
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file is gone; this only cleans
        # up a partial write left behind by a failure or interrupt.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f" Saved checkpoint: {path}")
    # TODO: push checkpoint to HF Hub using safetensors when push_to_hub is enabled and HF_TOKEN is present.
if __name__ == "__main__":
main()