"""NemotronKAN pretraining script.

Streams a HuggingFace text dataset (FineWeb-Edu ``sample-10BT`` by default),
tokenises it on the fly with the GPT-2 tiktoken encoding, and trains a
NemotronKAN language model through Hugging Face Accelerate (bf16 autocast on
CUDA; DDP when started with ``accelerate launch``).

Usage:
    python3 train.py [--model-size {124m,350m}] [--resume PATH]
                     [--dataset NAME] [--dataset-config CONFIG]
                     [--val-dataset NAME] [--val-dataset-config CONFIG]
                     [--push-to-hub --hub-model-id REPO_ID]

The module-level constants configure the ``124m`` preset, which the author
sized for an RTX 3060 Laptop GPU (6 GB VRAM); the ``350m`` preset (the CLI
default) overrides them inside ``main``. Hyperparameters follow a
Nemotron-style recipe: warmup-stable-decay learning rate, Adam betas
(0.9, 0.95) and weight decay 0.1 applied to linear weights only.
"""

from __future__ import annotations

import argparse
import contextlib
import math
import os
import sys
import time
from typing import TYPE_CHECKING

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import bitsandbytes as bnb
except ImportError:
    bnb = None

if TYPE_CHECKING:
    # Imported lazily inside main() so the helpers in this module can be
    # imported (e.g. by analysis scripts) without accelerate installed.
    from accelerate import Accelerator

# ---------------------------------------------------------------------------
# Config (124m preset; the 350m preset overrides these in main())
# ---------------------------------------------------------------------------
BATCH_SIZE = 8
SEQ_LEN = 512
GRAD_ACCUM_STEPS = 4  # effective batch = 8*4*512 = 16384 tokens per process
MAX_STEPS = 50_000
WARMUP_STEPS = 100
ADAM_LR = 1e-3
WEIGHT_DECAY = 0.1
DECAY_START = 0.7  # fraction of max_steps after which the LR starts decaying
MAX_GRAD_NORM = 1.0
EVAL_INTERVAL = 500
EVAL_STEPS = 20
CHECKPOINT_DIR = "checkpoints"
CHECKPOINT_INTERVAL = 2000
LOG_INTERVAL = 1
DATASET_NAME = "HuggingFaceFW/fineweb-edu"
DATASET_CONFIG = "sample-10BT"
VAL_DATASET_NAME = "wikitext"
VAL_DATASET_CONFIG = "wikitext-103-raw-v1"
MONITOR_INTERVAL_S = 20

# Key prefix torch.compile's wrapper module adds to state_dict entries.
_COMPILE_PREFIX = "_orig_mod."


# ---------------------------------------------------------------------------
# Streaming token buffer
# ---------------------------------------------------------------------------
class StreamingTokenBuffer:
    """Stream text from a HuggingFace dataset and serve fixed-size token batches.

    Rows are tokenised with the GPT-2 tiktoken encoding as they arrive and
    appended to an in-memory token list. Each batch consumes
    ``batch_size * (block_size + 1)`` tokens, reshaped to
    ``(batch_size, block_size + 1)`` and split into next-token inputs and
    targets. When the stream is exhausted it restarts from the beginning.

    With more than one process, each rank reads its own shard of the stream
    via ``datasets.distributed.split_dataset_by_node``.
    """

    def __init__(
        self,
        split: str,
        block_size: int,
        batch_size: int,
        dataset_name: str,
        dataset_config: str,
        accelerator: Accelerator | None = None,
    ):
        import tiktoken
        from datasets import load_dataset
        from datasets.distributed import split_dataset_by_node

        self.enc = tiktoken.get_encoding("gpt2")
        self.block_size = block_size
        self.batch_size = batch_size
        self.ds = load_dataset(
            dataset_name,
            dataset_config,
            split=split,
            streaming=True,
        )
        rank = accelerator.process_index if accelerator else 0
        world_size = accelerator.num_processes if accelerator else 1
        if world_size > 1:
            self.ds = split_dataset_by_node(self.ds, rank=rank, world_size=world_size)
        self._buf: list[int] = []
        # Index of the first unconsumed token in ``_buf``. Consumed tokens are
        # dropped lazily so a batch does not copy the whole remaining buffer.
        self._pos = 0
        self._iter = None
        # Whether the current pass over the stream produced any tokens; lets
        # _fill fail loudly instead of spinning forever on an empty stream.
        self._pass_has_tokens = False

    @property
    def buffered_tokens(self) -> int:
        """Number of tokenised tokens not yet handed out by get_batch."""
        return len(self._buf) - self._pos

    def _ensure_iter(self):
        """Lazily start iterating the dataset stream."""
        if self._iter is None:
            self._iter = iter(self.ds)
            self._pass_has_tokens = False

    def _fill(self, need: int):
        """Tokenise rows until at least ``need`` unconsumed tokens are buffered.

        Raises:
            RuntimeError: if a complete pass over the stream yields no
                non-blank text, so the buffer could never be filled.
        """
        self._ensure_iter()
        assert self._iter is not None
        while self.buffered_tokens < need:
            try:
                row = next(self._iter)
            except StopIteration:
                if not self._pass_has_tokens:
                    raise RuntimeError(
                        "Dataset stream yielded no non-empty 'text' rows; "
                        "cannot fill the token buffer."
                    ) from None
                # Restart the stream for another epoch.
                self._iter = iter(self.ds)
                self._pass_has_tokens = False
                continue
            text = row.get("text", "")
            if text.strip():
                self._buf.extend(self.enc.encode_ordinary(text))
                self._pass_has_tokens = True

    def prefill(self, num_batches: int = 64):
        """Buffer enough tokens for ``num_batches`` batches up front."""
        need = num_batches * self.batch_size * (self.block_size + 1)
        self._fill(need)

    def get_batch(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(x, y)`` long tensors of shape ``(batch_size, block_size)``.

        ``y`` is ``x`` shifted left by one token (next-token targets).
        """
        need = self.batch_size * (self.block_size + 1)
        self._fill(need)
        start = self._pos
        tokens = self._buf[start : start + need]
        self._pos = start + need
        # Compact once at least half of the list is consumed: amortised O(1)
        # per token, instead of re-copying the remaining buffer every batch.
        if self._pos * 2 >= len(self._buf):
            del self._buf[: self._pos]
            self._pos = 0
        t = torch.tensor(tokens, dtype=torch.long).view(
            self.batch_size, self.block_size + 1
        )
        x = t[:, :-1].to(device, non_blocking=True)
        y = t[:, 1:].to(device, non_blocking=True)
        return x, y


# ---------------------------------------------------------------------------
# Learning rate schedule (WSD: warmup-stable-decay)
# ---------------------------------------------------------------------------
def get_lr_factor(
    step: int,
    warmup_steps: int,
    max_steps: int,
    decay_start: float,
    min_factor: float = 0.1,
) -> float:
    """Return the LR multiplier for ``step`` under a warmup-stable-decay schedule.

    - ``step < warmup_steps``: linear ramp ``step / warmup_steps`` (0 at step 0).
    - up to ``int(max_steps * decay_start)``: constant 1.0.
    - then linear decay from 1.0 towards ``min_factor`` at ``max_steps``.
    - ``step >= max_steps``: ``min_factor``.
    """
    if warmup_steps > 0 and step < warmup_steps:
        return step / warmup_steps
    decay_start_step = int(max_steps * decay_start)
    if step < decay_start_step:
        return 1.0
    if step >= max_steps:
        return min_factor
    decay_ratio = (step - decay_start_step) / (max_steps - decay_start_step)
    return 1.0 - (1.0 - min_factor) * decay_ratio


@torch.no_grad()
def sample(
    model: nn.Module,
    device: torch.device,
    seq_len: int,
    prompt: str,
    max_new_tokens: int,
    temperature: float = 0.8,
    top_k: int = 40,
) -> str:
    """Autoregressively sample ``max_new_tokens`` tokens after ``prompt``.

    Assumes ``model(idx)`` returns a 3-tuple whose first element is logits of
    shape ``(batch, time, vocab)``. Logits for ids beyond the GPT-2 vocabulary
    (the model's vocab is padded) are masked so they are never sampled.
    ``top_k <= 0`` disables top-k filtering. The model's train/eval mode is
    restored on exit.
    """
    import tiktoken

    enc = tiktoken.get_encoding("gpt2")
    vocab_size = enc.n_vocab
    # An empty prompt would give a (1, 0) input; start from <|endoftext|>.
    prompt_ids = enc.encode_ordinary(prompt) or [enc.eot_token]
    idx = torch.tensor(prompt_ids, dtype=torch.long, device=device)[None, :]
    was_training = model.training
    model.eval()
    try:
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -seq_len:]
            logits, _, _ = model(idx_cond)
            logits = logits[:, -1, :].float() / max(temperature, 1e-6)
            if logits.size(-1) > vocab_size:
                logits[:, vocab_size:] = -float("inf")
            if top_k > 0:
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k)
                logits[logits < v[:, [-1]]] = -float("inf")
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
    finally:
        if was_training:
            model.train()
    tokens = [t for t in idx[0].tolist() if t < vocab_size]
    return enc.decode(tokens)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _strip_compile_prefix(state_dict: dict) -> dict:
    """Drop torch.compile's ``_orig_mod.`` key prefix (keys without it pass through)."""
    n = len(_COMPILE_PREFIX)
    return {
        (key[n:] if key.startswith(_COMPILE_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def _base_model(accelerator, model) -> nn.Module:
    """Return the plain nn.Module under DDP and torch.compile wrappers.

    Depending on the Accelerate version, ``unwrap_model`` may keep the
    torch.compile wrapper; this also peels that off so checkpoint keys carry
    no ``_orig_mod.`` prefix and sampling avoids per-length recompilation.
    """
    unwrapped = accelerator.unwrap_model(model)
    return getattr(unwrapped, "_orig_mod", unwrapped)


def _build_model_config(config_cls, model_size: str, seq_len: int):
    """Build the NemotronKAN config for the ``124m`` or ``350m`` preset."""
    common = dict(
        vocab_size=50304,  # GPT-2's 50257 padded up to a multiple of 64
        block_size=seq_len,
        dropout=0.0,
        kan_type="grkan",
        kan_num_grids=4,
        kan_hidden_mult=4,
        kan_grkan_num_groups=8,
        hc_num_streams=2,
        mhc=True,
        gradient_checkpointing=True,
        use_engram=False,
        use_sdr_compression=False,
        use_axiomatic_attention=False,
        use_holographic_memory=False,
        use_jit_cache=False,
        use_synaptic_offload=False,
        use_fused_ce=True,
    )
    if model_size == "350m":
        return config_cls(
            n_layer=24,
            n_embd=1024,
            n_head=16,
            n_kv_head=4,
            initializer_range=0.014,
            **common,
        )
    return config_cls(n_layer=12, n_embd=768, n_head=12, n_kv_head=4, **common)


def _build_optimizer(model, model_size: str, device, adam_lr: float, is_main: bool):
    """Create the optimizer with weight decay on nn.Linear weights only.

    ``350m`` uses torch AdamW (fused on CUDA when supported); ``124m`` uses
    bitsandbytes 8-bit Adam.
    """
    linear_weight_names = {
        f"{module_name}.weight" if module_name else "weight"
        for module_name, module in model.named_modules()
        if isinstance(module, nn.Linear) and module.weight is not None
    }
    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
    decay_names = {
        pn for pn, p in param_dict.items() if p.ndim == 2 and pn in linear_weight_names
    }
    no_decay_names = set(param_dict) - decay_names
    param_groups = [
        dict(
            params=[param_dict[pn] for pn in sorted(decay_names)],
            lr=adam_lr,
            weight_decay=WEIGHT_DECAY,
        ),
        dict(
            params=[param_dict[pn] for pn in sorted(no_decay_names)],
            lr=adam_lr,
            weight_decay=0.0,
        ),
    ]

    if model_size == "350m":
        adamw_kwargs = dict(lr=adam_lr, betas=(0.9, 0.95), eps=1e-8)
        optimizer = None
        if device.type == "cuda":
            try:
                optimizer = torch.optim.AdamW(param_groups, fused=True, **adamw_kwargs)
            except TypeError:  # torch build without the `fused` argument
                optimizer = None
        if optimizer is None:
            optimizer = torch.optim.AdamW(param_groups, **adamw_kwargs)
        if is_main:
            print("Optimizer: AdamW")
        return optimizer

    if bnb is None:
        raise ImportError(
            "bitsandbytes required for Adam8bit. Install: pip install bitsandbytes"
        )
    optimizer = bnb.optim.Adam8bit(
        param_groups, lr=adam_lr, betas=(0.9, 0.95), eps=1e-10
    )
    if is_main:
        print("Optimizer: 8-bit Adam (single optimizer)")
    return optimizer


@torch.no_grad()
def _evaluate(model, accelerator, val_batches, device) -> float:
    """Mean validation loss over ``val_batches``, averaged across processes.

    Every process must call this (it gathers losses). Leaves the model in
    training mode.
    """
    if not val_batches:
        return float("nan")
    model.eval()
    total = 0.0
    try:
        for vx, vy in val_batches:
            vx = vx.to(device, non_blocking=True)
            vy = vy.to(device, non_blocking=True)
            with accelerator.autocast():
                _, vloss, _ = model(vx, vy)
            total += accelerator.gather(vloss.detach()).mean().item()
    finally:
        model.train()
    return total / len(val_batches)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Parse CLI flags, build model/optimizer/data streams, and run training.

    Checkpoints go to ``CHECKPOINT_DIR``: ``step_<N>.pt`` every
    ``checkpoint_interval`` optimizer steps and at the end of training,
    ``best.pt`` whenever validation loss improves, ``interrupted.pt`` on
    Ctrl-C.
    """
    from accelerate import Accelerator

    # Allocator options presumably take effect only if set before CUDA
    # allocates memory, so set them before Accelerate touches the device.
    # NOTE(review): older PyTorch releases read PYTORCH_CUDA_ALLOC_CONF
    # instead of PYTORCH_ALLOC_CONF — confirm for the installed version.
    os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")

    # Ensure nvcc is on PATH for the torch.compile inductor backend.
    path_env = os.environ.get("PATH", "")
    if "/usr/local/cuda/bin" not in path_env:
        os.environ["PATH"] = "/usr/local/cuda/bin:" + path_env

    parser = argparse.ArgumentParser(
        description="Pretrain NemotronKAN on a streamed text dataset."
    )
    parser.add_argument(
        "--resume", type=str, default=None, help="Checkpoint to resume from"
    )
    parser.add_argument("--model-size", choices=["124m", "350m"], default="350m")
    parser.add_argument("--dataset", default=DATASET_NAME)
    parser.add_argument("--dataset-config", default=DATASET_CONFIG)
    parser.add_argument("--val-dataset", default=VAL_DATASET_NAME)
    parser.add_argument("--val-dataset-config", default=VAL_DATASET_CONFIG)
    parser.add_argument("--push-to-hub", action="store_true")
    parser.add_argument("--hub-model-id", type=str, default=None)
    args = parser.parse_args()

    # Fail fast: a mistyped --resume path previously started a fresh run
    # silently, which then overwrote existing checkpoints.
    if args.resume and not os.path.isfile(args.resume):
        parser.error(f"--resume checkpoint not found: {args.resume}")

    if args.model_size == "350m":
        batch_size = 64
        seq_len = 1024
        grad_accum_steps = 8
        adam_lr = 3e-4
        max_steps = 19_073
        warmup_steps = 2000
        eval_interval = 250
        checkpoint_interval = 1000
    else:
        batch_size = BATCH_SIZE
        seq_len = SEQ_LEN
        grad_accum_steps = GRAD_ACCUM_STEPS
        adam_lr = ADAM_LR
        max_steps = MAX_STEPS
        warmup_steps = WARMUP_STEPS
        eval_interval = EVAL_INTERVAL
        checkpoint_interval = CHECKPOINT_INTERVAL

    accelerator = Accelerator(
        mixed_precision="bf16" if torch.cuda.is_available() else "no"
    )
    device = accelerator.device
    is_main = accelerator.is_main_process
    if is_main:
        print(f"Device: {device}")

    if device.type == "cuda":
        if is_main:
            print(f"GPU: {torch.cuda.get_device_name(device)}")
            print(
                f"VRAM: {torch.cuda.get_device_properties(device).total_memory / 1024**3:.2f} GB"
            )
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")

    # --- Model ---------------------------------------------------------------
    from nemotron_kan import NemotronKANConfig, NemotronKAN

    config = _build_model_config(NemotronKANConfig, args.model_size, seq_len)
    model = NemotronKAN(config)
    num_params = sum(p.numel() for p in model.parameters())
    if is_main:
        print(f"Model params: {num_params:,} ({num_params / 1e6:.1f}M)")
    model = model.to(device)

    # --- Optimizer -----------------------------------------------------------
    optimizer = _build_optimizer(model, args.model_size, device, adam_lr, is_main)

    if device.type == "cuda":
        model = torch.compile(model, backend="inductor")

    model, optimizer = accelerator.prepare(model, optimizer)
    if is_main:
        print(f"AMP: {accelerator.mixed_precision}")

    # --- Data ----------------------------------------------------------------
    if is_main:
        print(f"Initializing streaming data ({args.dataset}/{args.dataset_config})...")
    train_data = StreamingTokenBuffer(
        "train",
        seq_len,
        batch_size,
        args.dataset,
        args.dataset_config,
        accelerator,
    )
    val_data = StreamingTokenBuffer(
        "validation",
        seq_len,
        batch_size,
        args.val_dataset,
        args.val_dataset_config,
        accelerator,
    )
    if is_main:
        print("Pre-filling token buffer...")
    train_data.prefill(128)
    if is_main:
        print(f"Data stream ready ({train_data.buffered_tokens:,} tokens buffered).")

    # Evaluate on the same fixed batches every time (kept on CPU) so val_loss
    # values - and therefore best.pt selection - are comparable across evals
    # and across resumed runs, instead of drifting along the stream.
    cpu = torch.device("cpu")
    val_batches = [val_data.get_batch(cpu) for _ in range(EVAL_STEPS)]

    # --- Resume from checkpoint ----------------------------------------------
    global_step = 0
    best_val_loss = float("inf")
    if args.resume:
        if is_main:
            print(f"Resuming from {args.resume}...")
        # weights_only=False unpickles arbitrary objects: only resume from
        # checkpoints you trust.
        ckpt = torch.load(args.resume, map_location="cpu", weights_only=False)
        # Strip the prefix so checkpoints written with or without the
        # torch.compile wrapper both load.
        _base_model(accelerator, model).load_state_dict(
            _strip_compile_prefix(ckpt["model_state_dict"])
        )
        try:
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        except Exception as exc:  # best-effort: optimizer type/layout may differ
            if is_main:
                print(f"Warning: could not load optimizer state ({exc})")
        global_step = ckpt.get("global_step", 0)
        best_val_loss = ckpt.get("best_val_loss", float("inf"))
        del ckpt  # release the CPU copy of the checkpoint before training
        if device.type == "cuda":
            torch.cuda.empty_cache()
        if is_main:
            print(f"Resumed at step {global_step}, best_val_loss={best_val_loss:.4f}")
        # NOTE: the training stream restarts from its beginning on resume, so
        # already-seen documents are revisited.

    # --- Training loop -------------------------------------------------------
    if is_main:
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    model.train()
    optimizer.zero_grad(set_to_none=True)
    accum_count = 0
    # Accumulated on-device so logging does not force a GPU sync per micro-step.
    running_loss = torch.zeros((), device=device)
    running_micro_steps = 0
    running_tokens = 0
    t_start = time.time()
    t_last_monitor = t_start
    t_last_step = t_start
    total_tokens = 0
    loss_history: list[float] = []
    last_saved_step = global_step

    if is_main:
        print(f"\n{'=' * 70}")
        print(f"Training NemotronKAN — {num_params / 1e6:.1f}M params")
        print(
            f"  model_size={args.model_size} | world_size={accelerator.num_processes}"
        )
        print(f"  batch={batch_size}, seq={seq_len}, grad_accum={grad_accum_steps}")
        print(
            f"  effective_batch_tokens={batch_size * grad_accum_steps * seq_len * accelerator.num_processes:,}"
        )
        print(f"  adam_lr={adam_lr}, warmup={warmup_steps}, max_steps={max_steps}")
        print("  AMP bf16 (cuda), gradient checkpointing, WSD, fused_ce")
        print(f"{'=' * 70}\n")

    try:
        while global_step < max_steps:
            x, y = train_data.get_batch(device)
            tokens_in_batch = x.numel()

            # Skip DDP's gradient all-reduce on all but the last micro-step of
            # an accumulation window; gradients sync once per optimizer step.
            is_sync_step = accum_count + 1 >= grad_accum_steps
            sync_ctx = (
                contextlib.nullcontext()
                if is_sync_step
                else accelerator.no_sync(model)
            )
            with sync_ctx:
                with accelerator.autocast():
                    # Logits are not needed here; dropping the reference lets
                    # them be freed before backward.
                    _, loss, _ = model(x, y)
                accelerator.backward(loss / grad_accum_steps)

            running_loss += loss.detach().float()
            running_micro_steps += 1
            running_tokens += tokens_in_batch
            total_tokens += tokens_in_batch
            accum_count += 1
            if accum_count < grad_accum_steps:
                continue

            # --- Optimizer step ------------------------------------------
            lr = adam_lr * get_lr_factor(
                global_step, warmup_steps, max_steps, DECAY_START
            )
            for pg in optimizer.param_groups:
                pg["lr"] = lr
            grad_norm_t = accelerator.clip_grad_norm_(
                model.parameters(), MAX_GRAD_NORM
            )
            grad_norm = float(grad_norm_t) if grad_norm_t is not None else float("nan")
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1
            accum_count = 0
            t_now = time.time()

            # --- Logging (counters reset on every rank) --------------------
            if global_step % LOG_INTERVAL == 0:
                if is_main:
                    avg_loss = running_loss.item() / max(running_micro_steps, 1)
                    dt = t_now - t_last_step
                    tok_s = running_tokens / max(dt, 1e-6)
                    elapsed = t_now - t_start
                    loss_history.append(avg_loss)
                    trend = ""
                    if len(loss_history) >= 10:
                        recent = loss_history[-10:]
                        if recent[-1] < recent[0]:
                            trend = " [decreasing]"
                        else:
                            trend = " [WARNING: not decreasing]"
                    gpu_mem = ""
                    if device.type == "cuda":
                        mem_used = torch.cuda.max_memory_allocated() / 1024**3
                        gpu_mem = f" | GPU {mem_used:.2f}GB"
                    print(
                        f"step {global_step:>6d} | loss {avg_loss:.4f}{trend} | "
                        f"lr {lr:.2e} | gnorm {grad_norm:.2f} | "
                        f"tok/s {tok_s:,.0f} | elapsed {elapsed:.0f}s{gpu_mem}"
                    )
                    sys.stdout.flush()
                running_loss.zero_()
                running_micro_steps = 0
                running_tokens = 0
                t_last_step = t_now

            # --- Periodic resource monitor ---------------------------------
            if is_main and t_now - t_last_monitor >= MONITOR_INTERVAL_S:
                elapsed = t_now - t_start
                avg_tok_s = total_tokens / max(elapsed, 1e-6)
                if device.type == "cuda":
                    mem_alloc = torch.cuda.memory_allocated() / 1024**3
                    mem_max = torch.cuda.max_memory_allocated() / 1024**3
                    print(
                        f" [MONITOR] step={global_step} | "
                        f"total_tok={total_tokens:,} | avg_tok/s={avg_tok_s:,.0f} | "
                        f"GPU_alloc={mem_alloc:.2f}GB | GPU_peak={mem_max:.2f}GB"
                    )
                else:
                    print(
                        f" [MONITOR] step={global_step} | "
                        f"total_tok={total_tokens:,} | avg_tok/s={avg_tok_s:,.0f}"
                    )
                sys.stdout.flush()
                t_last_monitor = t_now

            # --- Evaluation (all ranks participate in the gather) ----------
            if global_step % eval_interval == 0:
                val_loss = _evaluate(model, accelerator, val_batches, device)
                ppl = math.exp(min(val_loss, 20.0))
                improved = val_loss < best_val_loss
                if improved:
                    best_val_loss = val_loss
                if is_main:
                    print(
                        f" >>> EVAL step {global_step} | val_loss {val_loss:.4f} | "
                        f"ppl {ppl:.1f} | best {best_val_loss:.4f}"
                        f"{' [NEW BEST - saving]' if improved else ''}"
                    )
                    sys.stdout.flush()
                    if improved:
                        _save_checkpoint(
                            accelerator,
                            model,
                            optimizer,
                            global_step,
                            best_val_loss,
                            os.path.join(CHECKPOINT_DIR, "best.pt"),
                        )
                    with accelerator.autocast():
                        eval_sample = sample(
                            _base_model(accelerator, model),
                            device,
                            seq_len,
                            prompt="The meaning of life is",
                            max_new_tokens=60,
                            temperature=0.8,
                            top_k=40,
                        )
                    print(f" >>> SAMPLE step {global_step}: {eval_sample}")

            # --- Periodic checkpoint (no-op on non-main ranks) -------------
            if global_step % checkpoint_interval == 0:
                _save_checkpoint(
                    accelerator,
                    model,
                    optimizer,
                    global_step,
                    best_val_loss,
                    os.path.join(CHECKPOINT_DIR, f"step_{global_step}.pt"),
                )
                last_saved_step = global_step

        # Save the final state unless the last periodic checkpoint covers it.
        if global_step != last_saved_step:
            _save_checkpoint(
                accelerator,
                model,
                optimizer,
                global_step,
                best_val_loss,
                os.path.join(CHECKPOINT_DIR, f"step_{global_step}.pt"),
            )
    except KeyboardInterrupt:
        if is_main:
            print("\nInterrupted — saving checkpoint...")
        _save_checkpoint(
            accelerator,
            model,
            optimizer,
            global_step,
            best_val_loss,
            os.path.join(CHECKPOINT_DIR, "interrupted.pt"),
        )
    finally:
        elapsed = time.time() - t_start
        avg_tok_s = total_tokens / max(elapsed, 1e-6)
        if is_main:
            print(f"\nTraining ended at step {global_step}")
            print(
                f"Total tokens: {total_tokens:,} in {elapsed:.0f}s ({avg_tok_s:,.0f} tok/s)"
            )
            if loss_history:
                print(
                    f"Final loss: {loss_history[-1]:.4f} (started at {loss_history[0]:.4f})"
                )
            # Best-effort: an error here must not mask an exception that is
            # already propagating out of the training loop.
            try:
                with accelerator.autocast():
                    final_sample = sample(
                        _base_model(accelerator, model),
                        device,
                        seq_len,
                        prompt="The meaning of life is",
                        max_new_tokens=200,
                        temperature=0.8,
                        top_k=40,
                    )
                print(f"Final sample: {final_sample}")
            except Exception as exc:
                print(f"Warning: final sampling failed ({exc})")
            if args.push_to_hub and args.hub_model_id and os.environ.get("HF_TOKEN"):
                try:
                    _base_model(accelerator, model).push_to_hub(args.hub_model_id)
                    print(f"Pushed model to hub: {args.hub_model_id}")
                except Exception as exc:
                    print(f"Warning: push_to_hub failed ({exc})")


def _save_checkpoint(accelerator, model, optimizer, global_step, best_val_loss, path):
    """Write model/optimizer state to ``path`` (main process only).

    The model state comes from the unwrapped, un-compiled module, so keys have
    no DDP or torch.compile prefixes. The file is written to ``path + '.tmp'``
    and then renamed over ``path``, so an interrupt mid-save cannot leave a
    truncated checkpoint in place of a good one.
    """
    if not accelerator.is_main_process:
        return
    directory = os.path.dirname(path)
    os.makedirs(directory if directory else ".", exist_ok=True)
    state = {
        "model_state_dict": _base_model(accelerator, model).state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "global_step": global_step,
        "best_val_loss": best_val_loss,
    }
    tmp_path = f"{path}.tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)
    print(f"  Saved checkpoint: {path}")
    # TODO: push checkpoint to HF Hub using safetensors when push_to_hub is
    # enabled and HF_TOKEN is present.


if __name__ == "__main__":
    main()