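"""NemotronKAN pretraining script: streams FineWeb-Edu and supports Accelerate bf16/DDP.

Usage:
    python3 train.py [--model-size {124m,350m}] [--resume PATH]
                     [--dataset NAME] [--dataset-config CONFIG]
                     [--val-dataset NAME] [--val-dataset-config CONFIG]
                     [--push-to-hub --hub-model-id REPO_ID]

For multi-GPU DDP, launch the same command through ``accelerate launch``.

Developed on an RTX 3060 Laptop GPU (6 GB VRAM); hyperparameters follow the
Nemotron training recipe.
"""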
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import math |
| import os |
| import sys |
| import time |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from accelerate import Accelerator |
|
|
| try: |
| import bitsandbytes as bnb |
| except ImportError: |
| bnb = None |
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# Defaults for the "124m" preset. The "350m" preset in main() overrides the
# batch / sequence / accumulation / LR / step / warmup / eval / checkpoint
# values; the remaining constants are shared by both presets.
# ---------------------------------------------------------------------------

# Batch geometry.
BATCH_SIZE = 8
SEQ_LEN = 512
GRAD_ACCUM_STEPS = 4

# Optimisation schedule (warmup -> stable -> linear decay from DECAY_START).
MAX_STEPS = 50000
WARMUP_STEPS = 100
ADAM_LR = 0.001
WEIGHT_DECAY = 0.1
DECAY_START = 0.7  # fraction of MAX_STEPS at which the decay phase begins
MAX_GRAD_NORM = 1.0

# Evaluation, logging and checkpoint cadence (in optimizer steps).
EVAL_INTERVAL = 500
EVAL_STEPS = 20
LOG_INTERVAL = 1
CHECKPOINT_INTERVAL = 2000
CHECKPOINT_DIR = "checkpoints"
MONITOR_INTERVAL_S = 20  # wall-clock seconds between [MONITOR] lines

# Training corpus (Hugging Face Hub id and config name).
DATASET_NAME = "HuggingFaceFW/fineweb-edu"
DATASET_CONFIG = "sample-10BT"
|
|
|
|
| |
| |
| |
class StreamingTokenBuffer:
    """Stream text from a Hugging Face dataset and cut it into LM batches.

    Rows are tokenised on the fly with tiktoken's GPT-2 encoding and appended
    to a flat token buffer. Each :meth:`get_batch` call consumes
    ``batch_size * (block_size + 1)`` tokens and returns ``(x, y)`` of shape
    ``(batch_size, block_size)``, where ``y`` is ``x`` shifted left by one
    token (next-token targets). When the stream is exhausted it restarts from
    the beginning. With more than one process, each rank reads its own shard
    via ``split_dataset_by_node``.
    """

    def __init__(
        self,
        split: str,
        block_size: int,
        batch_size: int,
        dataset_name: str,
        dataset_config: str,
        accelerator: "Accelerator | None" = None,
    ):
        # Imported lazily so the module can be imported without these deps.
        import tiktoken
        from datasets import load_dataset
        from datasets.distributed import split_dataset_by_node

        self.enc = tiktoken.get_encoding("gpt2")
        self.block_size = block_size
        self.batch_size = batch_size
        self.ds = load_dataset(
            dataset_name,
            dataset_config,
            split=split,
            streaming=True,
        )
        rank = accelerator.process_index if accelerator else 0
        world_size = accelerator.num_processes if accelerator else 1
        if world_size > 1:
            self.ds = split_dataset_by_node(self.ds, rank=rank, world_size=world_size)
        self._buf: list[int] = []
        self._iter = None

    def _ensure_iter(self):
        """Lazily create the dataset iterator on first use."""
        if self._iter is None:
            self._iter = iter(self.ds)

    def _fill(self, need: int) -> None:
        """Top the buffer up to at least ``need`` tokens, cycling the stream.

        Rows whose ``text`` is missing, ``None`` or whitespace-only are skipped.

        Raises:
            RuntimeError: if a complete pass over the dataset produces no
                tokens (without this check the loop would never terminate).
        """
        self._ensure_iter()
        assert self._iter is not None
        # Whether the current pass has produced any tokens. It starts True
        # because the iterator may be part-way through a pass that already did.
        made_progress = True
        while len(self._buf) < need:
            try:
                row = next(self._iter)
            except StopIteration:
                if not made_progress:
                    raise RuntimeError(
                        "Dataset stream produced no tokens in a full pass"
                    ) from None
                self._iter = iter(self.ds)
                made_progress = False
                continue
            text = row.get("text") or ""  # tolerate missing / None text
            if text.strip():
                tokens = self.enc.encode_ordinary(text)
                if tokens:
                    self._buf.extend(tokens)
                    made_progress = True

    def prefill(self, num_batches: int = 64):
        """Buffer enough tokens for ``num_batches`` batches up front."""
        need = num_batches * self.batch_size * (self.block_size + 1)
        self._fill(need)

    def get_batch(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        """Pop one batch from the buffer and return ``(x, y)`` on ``device``."""
        need = self.batch_size * (self.block_size + 1)
        self._fill(need)
        chunk = torch.tensor(self._buf[:need], dtype=torch.long)
        # Delete the consumed prefix in place instead of rebinding to a sliced
        # copy, which would allocate a new list for the whole (possibly
        # multi-million-token) remainder on every call.
        del self._buf[:need]
        t = chunk.view(self.batch_size, self.block_size + 1)
        x = t[:, :-1].to(device, non_blocking=True)
        y = t[:, 1:].to(device, non_blocking=True)
        return x, y
|
|
|
|
| |
| |
| |
def get_lr_factor(
    step: int, warmup_steps: int, max_steps: int, decay_start: float
) -> float:
    """Return the LR multiplier for ``step`` under a warmup-stable-decay schedule.

    * Warmup: linear ramp ``(step + 1) / warmup_steps`` for the first
      ``warmup_steps`` steps, so step 0 already gets a non-zero LR and the last
      warmup step reaches exactly 1.0.
    * Stable: 1.0 until ``int(max_steps * decay_start)``.
    * Decay: linear from 1.0 towards 0.1 at ``max_steps``.
    * At or beyond ``max_steps``: 0.1.

    Args:
        step: zero-based optimizer step.
        warmup_steps: number of warmup steps (0 disables warmup).
        max_steps: total number of optimizer steps.
        decay_start: fraction of ``max_steps`` at which decay begins.
    """
    if step < warmup_steps:
        # (step + 1) rather than step: with `step` the very first optimizer
        # update would run at lr = 0 and be wasted.
        return (step + 1) / warmup_steps

    decay_start_step = int(max_steps * decay_start)
    if step < decay_start_step:
        return 1.0

    # Checked before the ratio below, which also guards the division when
    # decay_start_step == max_steps.
    if step >= max_steps:
        return 0.1

    decay_ratio = (step - decay_start_step) / (max_steps - decay_start_step)
    return 1.0 - 0.9 * decay_ratio
|
|
|
|
@torch.no_grad()
def sample(
    model: nn.Module,
    device: torch.device,
    seq_len: int,
    prompt: str,
    max_new_tokens: int,
    temperature: float = 0.8,
    top_k: int = 40,
    enc: "tiktoken.Encoding | None" = None,
) -> str:
    """Generate a continuation of ``prompt`` with temperature + top-k sampling.

    Args:
        model: called as ``model(idx)``. It must return a 3-tuple whose first
            element is logits indexed as ``[:, -1, :]`` — presumably
            (batch, time, vocab); confirm against the model.
        device: device the token ids are created on.
        seq_len: maximum context fed to the model (older tokens are cropped).
        prompt: text to continue. An empty prompt is seeded with GPT-2's
            ``<|endoftext|>`` id (50256), which is dropped from the output.
        max_new_tokens: number of tokens to sample.
        temperature: logits are divided by ``max(temperature, 1e-6)``.
        top_k: keep only the ``top_k`` highest logits; ``<= 0`` disables it.
        enc: tokenizer exposing ``encode_ordinary``/``decode``. Defaults to
            tiktoken's ``gpt2`` encoding.

    Returns:
        The decoded prompt plus generated text. Ids >= 50257 (outside GPT-2's
        vocabulary, e.g. from a padded output layer) are dropped.

    The model's train/eval mode is restored even if generation raises.
    """
    if enc is None:
        import tiktoken

        enc = tiktoken.get_encoding("gpt2")

    prompt_ids = enc.encode_ordinary(prompt)
    # A zero-length sequence cannot be fed to the model; seed with <|endoftext|>.
    seeded = not prompt_ids
    if seeded:
        prompt_ids = [50256]
    idx = torch.tensor(prompt_ids, dtype=torch.long, device=device)[None, :]

    was_training = model.training
    model.eval()
    try:
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -seq_len:]
            logits, _, _ = model(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-6)

            if top_k > 0:
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k)
                logits[logits < v[:, [-1]]] = -float("inf")

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
    finally:
        if was_training:
            model.train()

    generated = idx[0].tolist()[1 if seeded else 0:]
    tokens = [t for t in generated if t < 50257]
    return enc.decode(tokens)
|
|
|
|
| |
| |
| |
def _split_decay_params(model: nn.Module) -> tuple[list[str], list[str]]:
    """Partition trainable parameter names into ``(decay, no_decay)`` lists.

    A parameter is decayed only if it is 2-D *and* is the ``weight`` of an
    ``nn.Linear`` module; every other trainable parameter (biases, embedding
    tables, norm scales, weights of non-Linear modules, ...) is not decayed.
    """
    linear_weight_names = {
        f"{module_name}.weight" if module_name else "weight"
        for module_name, module in model.named_modules()
        if isinstance(module, nn.Linear) and module.weight is not None
    }
    decay_names: list[str] = []
    no_decay_names: list[str] = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim == 2 and name in linear_weight_names:
            decay_names.append(name)
        else:
            no_decay_names.append(name)
    return decay_names, no_decay_names


def _build_optimizer(
    param_groups: list[dict],
    adam_lr: float,
    model_size: str,
    device: torch.device,
    is_main_process: bool,
) -> torch.optim.Optimizer:
    """Create the optimizer for the chosen preset.

    ``350m`` uses ``torch.optim.AdamW``, with the fused kernel on CUDA when the
    installed PyTorch accepts ``fused=``. ``124m`` uses bitsandbytes'
    ``Adam8bit`` (8-bit optimizer states).

    Raises:
        ImportError: the ``124m`` preset was requested but bitsandbytes is
            not installed.
    """
    if model_size == "350m":
        adamw_kwargs = dict(lr=adam_lr, betas=(0.9, 0.95), eps=1e-8)
        optimizer = None
        if device.type == "cuda":
            try:
                optimizer = torch.optim.AdamW(param_groups, fused=True, **adamw_kwargs)
            except TypeError:
                # Older PyTorch without the `fused` keyword: fall back below.
                optimizer = None
        if optimizer is None:
            optimizer = torch.optim.AdamW(param_groups, **adamw_kwargs)
        if is_main_process:
            print("Optimizer: AdamW")
        return optimizer

    if bnb is None:
        raise ImportError(
            "bitsandbytes required for Adam8bit. Install: pip install bitsandbytes"
        )
    optimizer = bnb.optim.Adam8bit(
        param_groups, lr=adam_lr, betas=(0.9, 0.95), eps=1e-10
    )
    if is_main_process:
        print("Optimizer: 8-bit Adam (single optimizer)")
    return optimizer


def _evaluate(
    model: nn.Module,
    val_data: "StreamingTokenBuffer",
    accelerator: "Accelerator",
    device: torch.device,
    num_steps: int,
) -> float:
    """Return the mean validation loss over ``num_steps`` batches.

    Each batch's loss is gathered from every process and averaged, so this is
    a collective operation and must be called on all ranks. It leaves the
    model in eval mode; the caller is responsible for calling ``model.train()``.
    """
    model.eval()
    total = 0.0
    with torch.no_grad():
        for _ in range(num_steps):
            vx, vy = val_data.get_batch(device)
            with accelerator.autocast():
                _, vloss, _ = model(vx, vy)
            total += accelerator.gather(vloss.detach()).mean().item()
    return total / num_steps


def main():
    """Parse CLI args, build NemotronKAN plus its optimizer and data, then train.

    Runs until ``max_steps`` optimizer steps. Along the way it logs every
    ``LOG_INTERVAL`` steps, evaluates every ``eval_interval`` steps (saving
    ``best.pt`` on improvement), and saves ``step_N.pt`` every
    ``checkpoint_interval`` steps. It also saves ``final.pt`` on normal
    completion and ``interrupted.pt`` on Ctrl-C.
    """
    import contextlib

    _path = os.environ.get("PATH", "")
    if "/usr/local/cuda/bin" not in _path:
        os.environ["PATH"] = "/usr/local/cuda/bin:" + _path

    # The CUDA caching allocator reads its config when CUDA is first
    # initialised, so this must run before Accelerator()/get_device_name touch
    # the GPU. Recent PyTorch reads PYTORCH_ALLOC_CONF; older releases only
    # read PYTORCH_CUDA_ALLOC_CONF. A user-provided setting is left alone.
    if (
        "PYTORCH_ALLOC_CONF" not in os.environ
        and "PYTORCH_CUDA_ALLOC_CONF" not in os.environ
    ):
        os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--resume", type=str, default=None, help="Checkpoint to resume from"
    )
    parser.add_argument("--model-size", choices=["124m", "350m"], default="350m")
    parser.add_argument("--dataset", default=DATASET_NAME)
    parser.add_argument("--dataset-config", default=DATASET_CONFIG)
    parser.add_argument("--val-dataset", default="wikitext")
    parser.add_argument("--val-dataset-config", default="wikitext-103-raw-v1")
    parser.add_argument("--push-to-hub", action="store_true")
    parser.add_argument("--hub-model-id", type=str, default=None)
    args = parser.parse_args()

    if args.model_size == "350m":
        batch_size = 64
        seq_len = 1024
        grad_accum_steps = 8
        adam_lr = 3e-4
        max_steps = 19_073
        warmup_steps = 2000
        eval_interval = 250
        checkpoint_interval = 1000
    else:
        batch_size = BATCH_SIZE
        seq_len = SEQ_LEN
        grad_accum_steps = GRAD_ACCUM_STEPS
        adam_lr = ADAM_LR
        max_steps = MAX_STEPS
        warmup_steps = WARMUP_STEPS
        eval_interval = EVAL_INTERVAL
        checkpoint_interval = CHECKPOINT_INTERVAL

    accelerator = Accelerator(
        mixed_precision="bf16" if torch.cuda.is_available() else "no"
    )
    device = accelerator.device
    if accelerator.is_main_process:
        print(f"Device: {device}")
    if device.type == "cuda":
        if accelerator.is_main_process:
            print(f"GPU: {torch.cuda.get_device_name(device)}")
            print(
                f"VRAM: {torch.cuda.get_device_properties(device).total_memory / 1024**3:.2f} GB"
            )
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")

    from nemotron_kan import NemotronKANConfig, NemotronKAN

    # Settings shared by both presets; size-specific ones are added below.
    common_config = dict(
        n_kv_head=4,
        vocab_size=50304,
        dropout=0.0,
        kan_type="grkan",
        kan_num_grids=4,
        kan_hidden_mult=4,
        kan_grkan_num_groups=8,
        hc_num_streams=2,
        mhc=True,
        gradient_checkpointing=True,
        use_engram=False,
        use_sdr_compression=False,
        use_axiomatic_attention=False,
        use_holographic_memory=False,
        use_jit_cache=False,
        use_synaptic_offload=False,
        use_fused_ce=True,
    )
    if args.model_size == "350m":
        config = NemotronKANConfig(
            n_layer=24,
            n_embd=1024,
            n_head=16,
            initializer_range=0.014,
            block_size=1024,
            **common_config,
        )
    else:
        config = NemotronKANConfig(
            n_layer=12,
            n_embd=768,
            n_head=12,
            block_size=seq_len,
            **common_config,
        )

    model = NemotronKAN(config)
    num_params = sum(p.numel() for p in model.parameters())
    if accelerator.is_main_process:
        print(f"Model params: {num_params:,} ({num_params / 1e6:.1f}M)")
    model = model.to(device)

    decay_names, no_decay_names = _split_decay_params(model)
    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
    assert not set(decay_names) & set(no_decay_names), (
        "Overlap in decay/no_decay groups"
    )
    assert set(decay_names) | set(no_decay_names) == set(param_dict), (
        "Missing params in optimizer groups"
    )
    param_groups = [
        dict(
            params=[param_dict[pn] for pn in sorted(decay_names)],
            lr=adam_lr,
            weight_decay=WEIGHT_DECAY,
        ),
        dict(
            params=[param_dict[pn] for pn in sorted(no_decay_names)],
            lr=adam_lr,
            weight_decay=0.0,
        ),
    ]
    optimizer = _build_optimizer(
        param_groups, adam_lr, args.model_size, device, accelerator.is_main_process
    )

    if device.type == "cuda":
        model = torch.compile(model, backend="inductor")

    model, optimizer = accelerator.prepare(model, optimizer)
    if accelerator.is_main_process:
        print(f"AMP: {accelerator.mixed_precision}")

    if accelerator.is_main_process:
        print(f"Initializing streaming data ({args.dataset}/{args.dataset_config})...")
    train_data = StreamingTokenBuffer(
        "train",
        seq_len,
        batch_size,
        args.dataset,
        args.dataset_config,
        accelerator,
    )
    # NOTE(review): the validation stream keeps advancing between evals, so
    # successive val_loss values are measured on different batches.
    val_data = StreamingTokenBuffer(
        "validation",
        seq_len,
        batch_size,
        args.val_dataset,
        args.val_dataset_config,
        accelerator,
    )
    if accelerator.is_main_process:
        print("Pre-filling token buffer...")
    train_data.prefill(128)
    if accelerator.is_main_process:
        print(f"Data stream ready ({len(train_data._buf):,} tokens buffered).")

    global_step = 0
    best_val_loss = float("inf")
    unwrapped_model = accelerator.unwrap_model(model)
    if args.resume and not os.path.isfile(args.resume):
        # Previously ignored silently, which made a typo restart from scratch.
        if accelerator.is_main_process:
            print(f"Warning: --resume path {args.resume!r} not found; starting from scratch")
    elif args.resume:
        if accelerator.is_main_process:
            print(f"Resuming from {args.resume}...")
        # weights_only=False unpickles arbitrary objects: only resume from
        # checkpoints you trust.
        ckpt = torch.load(args.resume, map_location="cpu", weights_only=False)
        unwrapped_model.load_state_dict(ckpt["model_state_dict"])
        try:
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        except Exception as exc:
            if accelerator.is_main_process:
                print(f"Warning: could not load optimizer state ({exc})")
        global_step = ckpt.get("global_step", 0)
        best_val_loss = ckpt.get("best_val_loss", float("inf"))
        del ckpt
        torch.cuda.empty_cache()
        if accelerator.is_main_process:
            print(f"Resumed at step {global_step}, best_val_loss={best_val_loss:.4f}")

    if accelerator.is_main_process:
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    model.train()
    optimizer.zero_grad(set_to_none=True)

    accum_count = 0
    # Loss is summed on-device so micro-steps don't force a host sync.
    running_loss = torch.zeros((), device=device)
    running_micro_steps = 0
    running_tokens = 0
    t_start = time.time()
    t_last_monitor = t_start
    t_last_step = t_start
    total_tokens = 0
    loss_history: list[float] = []

    if accelerator.is_main_process:
        print(f"\n{'=' * 70}")
        print(f"Training NemotronKAN — {num_params / 1e6:.1f}M params")
        print(
            f" model_size={args.model_size} | world_size={accelerator.num_processes}"
        )
        print(f" batch={batch_size}, seq={seq_len}, grad_accum={grad_accum_steps}")
        print(
            f" effective_batch_tokens={batch_size * grad_accum_steps * seq_len * accelerator.num_processes:,}"
        )
        print(f" adam_lr={adam_lr}, warmup={warmup_steps}, max_steps={max_steps}")
        print(" AMP bf16 (cuda), gradient checkpointing, WSD, fused_ce")
        print(f"{'=' * 70}\n")

    try:
        while global_step < max_steps:
            x, y = train_data.get_batch(device)
            tokens_in_batch = x.numel()

            # Only the last micro-step of an accumulation window needs the DDP
            # gradient all-reduce; skip it on the others. no_sync() is a no-op
            # when not running distributed.
            sync_gradients = accum_count + 1 >= grad_accum_steps
            sync_ctx = (
                contextlib.nullcontext()
                if sync_gradients
                else accelerator.no_sync(model)
            )
            with sync_ctx:
                with accelerator.autocast():
                    # The model computes the loss itself; any logit
                    # soft-capping must happen inside it, before the loss.
                    _logits, loss, _sdr = model(x, y)
                accelerator.backward(loss / grad_accum_steps)

            running_loss += loss.detach().float()
            running_micro_steps += 1
            running_tokens += tokens_in_batch
            total_tokens += tokens_in_batch
            accum_count += 1

            if accum_count >= grad_accum_steps:
                lr_factor = get_lr_factor(
                    global_step, warmup_steps, max_steps, DECAY_START
                )
                lr = adam_lr * lr_factor
                for pg in optimizer.param_groups:
                    pg["lr"] = lr

                grad_norm = accelerator.clip_grad_norm_(
                    model.parameters(), MAX_GRAD_NORM
                ).item()
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1
                accum_count = 0

                if global_step % LOG_INTERVAL == 0:
                    t_now = time.time()
                    if accelerator.is_main_process:
                        # Average over every micro-step since the last log,
                        # which stays correct when LOG_INTERVAL > 1.
                        avg_loss = running_loss.item() / max(running_micro_steps, 1)
                        dt = t_now - t_last_step
                        tok_s = running_tokens / max(dt, 1e-6)
                        elapsed = t_now - t_start
                        loss_history.append(avg_loss)

                        trend = ""
                        if len(loss_history) >= 10:
                            recent = loss_history[-10:]
                            if recent[-1] < recent[0]:
                                trend = " [decreasing]"
                            else:
                                trend = " [WARNING: not decreasing]"

                        gpu_mem = ""
                        if device.type == "cuda":
                            mem_used = torch.cuda.max_memory_allocated() / 1024**3
                            gpu_mem = f" | GPU {mem_used:.2f}GB"

                        print(
                            f"step {global_step:>6d} | loss {avg_loss:.4f}{trend} | "
                            f"lr {lr:.2e} | gnorm {grad_norm:.2f} | "
                            f"tok/s {tok_s:,.0f} | elapsed {elapsed:.0f}s{gpu_mem}"
                        )
                        sys.stdout.flush()

                    # Reset on every rank so the accumulators stay bounded.
                    running_loss.zero_()
                    running_micro_steps = 0
                    running_tokens = 0
                    t_last_step = t_now

                t_now = time.time()
                if (
                    t_now - t_last_monitor >= MONITOR_INTERVAL_S
                    and accelerator.is_main_process
                ):
                    elapsed = t_now - t_start
                    avg_tok_s = total_tokens / max(elapsed, 1e-6)
                    if device.type == "cuda":
                        mem_alloc = torch.cuda.memory_allocated() / 1024**3
                        mem_max = torch.cuda.max_memory_allocated() / 1024**3
                        print(
                            f" [MONITOR] step={global_step} | "
                            f"total_tok={total_tokens:,} | avg_tok/s={avg_tok_s:,.0f} | "
                            f"GPU_alloc={mem_alloc:.2f}GB | GPU_peak={mem_max:.2f}GB"
                        )
                    else:
                        print(
                            f" [MONITOR] step={global_step} | "
                            f"total_tok={total_tokens:,} | avg_tok/s={avg_tok_s:,.0f}"
                        )
                    sys.stdout.flush()
                    t_last_monitor = t_now

                if global_step % eval_interval == 0:
                    # Collective (gathers across ranks): runs on every process.
                    val_loss = _evaluate(
                        model, val_data, accelerator, device, EVAL_STEPS
                    )
                    ppl = math.exp(min(val_loss, 20.0))
                    improved = val_loss < best_val_loss
                    if improved:
                        best_val_loss = val_loss
                    if accelerator.is_main_process:
                        print(
                            f" >>> EVAL step {global_step} | val_loss {val_loss:.4f} | "
                            f"ppl {ppl:.1f} | best {best_val_loss:.4f}"
                            f"{' [NEW BEST - saving]' if improved else ''}"
                        )
                        sys.stdout.flush()
                    if improved and accelerator.is_main_process:
                        _save_checkpoint(
                            accelerator,
                            model,
                            optimizer,
                            global_step,
                            best_val_loss,
                            os.path.join(CHECKPOINT_DIR, "best.pt"),
                        )

                    if accelerator.is_main_process:
                        eval_sample = sample(
                            accelerator.unwrap_model(model),
                            device,
                            seq_len,
                            prompt="The meaning of life is",
                            max_new_tokens=60,
                            temperature=0.8,
                            top_k=40,
                        )
                        print(f" >>> SAMPLE step {global_step}: {eval_sample}")
                    model.train()

                if (
                    global_step % checkpoint_interval == 0
                    and accelerator.is_main_process
                ):
                    _save_checkpoint(
                        accelerator,
                        model,
                        optimizer,
                        global_step,
                        best_val_loss,
                        os.path.join(CHECKPOINT_DIR, f"step_{global_step}.pt"),
                    )

        # Normal completion. max_steps is generally not a multiple of
        # checkpoint_interval, so persist the final weights explicitly.
        if accelerator.is_main_process:
            _save_checkpoint(
                accelerator,
                model,
                optimizer,
                global_step,
                best_val_loss,
                os.path.join(CHECKPOINT_DIR, "final.pt"),
            )

    except KeyboardInterrupt:
        if accelerator.is_main_process:
            print("\nInterrupted — saving checkpoint...")
            _save_checkpoint(
                accelerator,
                model,
                optimizer,
                global_step,
                best_val_loss,
                os.path.join(CHECKPOINT_DIR, "interrupted.pt"),
            )
    finally:
        elapsed = time.time() - t_start
        avg_tok_s = total_tokens / max(elapsed, 1e-6)
        if accelerator.is_main_process:
            print(f"\nTraining ended at step {global_step}")
            print(
                f"Total tokens: {total_tokens:,} in {elapsed:.0f}s ({avg_tok_s:,.0f} tok/s)"
            )
        if loss_history and accelerator.is_main_process:
            print(
                f"Final loss: {loss_history[-1]:.4f} (started at {loss_history[0]:.4f})"
            )
        if accelerator.is_main_process:
            try:
                final_sample = sample(
                    accelerator.unwrap_model(model),
                    device,
                    seq_len,
                    prompt="The meaning of life is",
                    max_new_tokens=200,
                    temperature=0.8,
                    top_k=40,
                )
                print(f"Final sample: {final_sample}")
            except Exception as exc:
                # Don't let a sampling failure mask whatever ended training.
                print(f"Warning: final sampling failed ({exc})")
            if args.push_to_hub and args.hub_model_id and os.environ.get("HF_TOKEN"):
                try:
                    accelerator.unwrap_model(model).push_to_hub(args.hub_model_id)
                    print(f"Pushed model to hub: {args.hub_model_id}")
                except Exception as exc:
                    print(f"Warning: push_to_hub failed ({exc})")
|
|
|
|
| def _save_checkpoint(accelerator, model, optimizer, global_step, best_val_loss, path): |
| if not accelerator.is_main_process: |
| return |
| os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True) |
| state = { |
| "model_state_dict": accelerator.unwrap_model(model).state_dict(), |
| "optimizer_state_dict": optimizer.state_dict(), |
| "global_step": global_step, |
| "best_val_loss": best_val_loss, |
| } |
| torch.save(state, path) |
| print(f" Saved checkpoint: {path}") |
| |
|
|
|
|
# Script entry point: run training only when executed directly (including via
# `accelerate launch`), not when this module is imported.
if __name__ == "__main__":
    main()
|
|