| import os |
| import csv |
| import json |
| import time |
| import argparse |
| import subprocess |
| import math |
| import sys |
| import torch |
| import torch.nn.functional as F |
|
|
| from GPT_model import ( |
| GPT, |
| device, |
| DEFAULT_CONFIG, |
| GPTConfig, |
| config_from_dict, |
| SimpleBPETokenizer as BPETokenizer, |
| ) |
|
|
# Repository root: the parent of the directory that contains this script.
PROJECT_ROOT = os.path.normpath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
|
|
|
|
def parse_args(argv=None):
    """Build the command-line parser and parse the training options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly like the previous zero-argument version.
            Passing a list lets tests and notebooks drive the parser directly.

    Returns:
        argparse.Namespace with one attribute per option below. The model-size
        options use ``0`` (or a negative dropout) to mean "take the value from the
        default config or from the resumed checkpoint".
    """
    p = argparse.ArgumentParser(description="CPU GPT trainer")
    p.add_argument("--train-data", default=os.path.join("data", "jarvis_train.txt"))
    p.add_argument("--val-data", default=os.path.join("data", "jarvis_val.txt"))
    p.add_argument("--prepare-data", action="store_true")
    p.add_argument("--n-embd", type=int, default=0, help="Model embedding size. 0 uses default/ckpt.")
    p.add_argument("--n-head", type=int, default=0, help="Attention heads. 0 uses default/ckpt.")
    p.add_argument("--n-layer", type=int, default=0, help="Transformer layers. 0 uses default/ckpt.")
    p.add_argument("--block-size", type=int, default=0, help="Context length. 0 uses default/ckpt.")
    p.add_argument("--dropout", type=float, default=-1.0, help="Dropout in [0,0.5]. <0 uses default/ckpt.")
    p.add_argument("--run-steps", type=int, default=None, help="Train this many steps from current checkpoint.")
    p.add_argument("--max-steps", type=int, default=230_000, help="Absolute max step index fallback.")
    p.add_argument("--batch-size", type=int, default=4)
    p.add_argument("--accum-steps", type=int, default=4)
    p.add_argument("--lr", type=float, default=3e-5)
    p.add_argument("--warmup-steps", type=int, default=200)
    p.add_argument("--eval-every", type=int, default=100)
    p.add_argument("--eval-batches", type=int, default=8)
    p.add_argument("--save-every", type=int, default=200)
    p.add_argument("--sample-every", type=int, default=200)
    p.add_argument("--log-every", type=int, default=20)
    p.add_argument("--grad-clip", type=float, default=1.0)
    p.add_argument("--label-smoothing", type=float, default=0.0)
    p.add_argument("--early-stop-patience", type=int, default=0, help="Stop after this many evals without val improvement. 0 disables.")
    # Leave a couple of cores free by default, capped at 6 compute threads.
    p.add_argument("--threads", type=int, default=max(1, min(6, (os.cpu_count() or 4) - 2)))
    p.add_argument("--interop-threads", type=int, default=1)
    p.add_argument("--ckpt-path", default="cpu_gpt_jarvis_rebuild_l6_v2048.pth")
    p.add_argument("--best-path", default="cpu_gpt_jarvis_rebuild_l6_v2048_best.pth")
    p.add_argument("--metrics-csv", default="cpu_gpt_jarvis_rebuild_l6_v2048_metrics.csv")
    p.add_argument("--sample-temperature", type=float, default=0.75)
    p.add_argument("--sample-top-k", type=int, default=40)
    p.add_argument("--sample-top-p", type=float, default=0.9)
    p.add_argument("--seed", type=int, default=1337)
    p.add_argument("--reset-best-val", action=argparse.BooleanOptionalAction, default=False)
    p.add_argument("--reset-optimizer", action=argparse.BooleanOptionalAction, default=False)
    return p.parse_args(argv)
|
|
|
|
def ensure_data_ready(args):
    """Run the matching data-preparation script(s) when the data files are missing.

    Preparation runs when ``args.prepare_data`` is set or either
    ``args.train_data`` / ``args.val_data`` does not exist. The scripts are
    looked up next to this file and chosen from the (lower-cased) data file names:

    * a name containing ``jarvis_mix``    -> prepare_refine_data.py, then build_mixed_refine_data.py
    * a name containing ``jarvis_refine`` -> prepare_refine_data.py
    * anything else                       -> prepare_data.py

    Raises:
        RuntimeError: if a script exits with a non-zero status, or if the
            requested train/val files still do not exist after all scripts ran.
    """
    need_prepare = (
        args.prepare_data
        or not os.path.exists(args.train_data)
        or not os.path.exists(args.val_data)
    )
    if not need_prepare:
        return

    train_name = os.path.basename(args.train_data).lower()
    val_name = os.path.basename(args.val_data).lower()
    target = f"{train_name} {val_name}"

    if "jarvis_mix" in target:
        scripts = ["prepare_refine_data.py", "build_mixed_refine_data.py"]
    elif "jarvis_refine" in target:
        scripts = ["prepare_refine_data.py"]
    else:
        scripts = ["prepare_data.py"]

    script_dir = os.path.dirname(os.path.abspath(__file__))
    for script in scripts:
        print(f"Preparing data with {script} ...")
        # Run with the same interpreter as this trainer; argument list, no shell.
        cmd = [sys.executable, os.path.join(script_dir, script)]
        res = subprocess.run(cmd, check=False, capture_output=True, text=True)
        if res.stdout:
            print(res.stdout.strip())
        if res.returncode != 0:
            if res.stderr:
                print(res.stderr.strip(), file=sys.stderr)
            raise RuntimeError(f"{script} failed (exit code {res.returncode})")

    # The scripts choose their own output locations. Check that they produced the
    # files this run was asked to read, so a path mismatch is reported here
    # rather than as a bare FileNotFoundError when the dataset is loaded.
    still_missing = [p for p in (args.train_data, args.val_data) if not os.path.exists(p)]
    if still_missing:
        raise RuntimeError(
            "Data preparation finished but these files are still missing: "
            + ", ".join(still_missing)
        )
|
|
|
|
def load_tokenizer():
    """Load the BPE tokenizer's merges and vocabulary from ``bpe_vocab.json``.

    Search order: ``<PROJECT_ROOT>/data/bpe_vocab.json``, then ``bpe_vocab.json``
    in the current working directory.

    The JSON must contain:
      * ``"merges"``: keys are comma-separated integer ids (e.g. ``"97,98"``),
        converted to int tuples; values are stored unchanged.
      * ``"vocab"``: keys are integer ids (as strings); values are latin-1 strings
        converted back to ``bytes``.

    Returns:
        A ``BPETokenizer`` with ``merges`` and ``vocab`` populated.

    Raises:
        FileNotFoundError: if neither candidate file exists. The message lists
            every path that was searched.
    """
    tokenizer = BPETokenizer()
    candidates = [os.path.join(PROJECT_ROOT, "data", "bpe_vocab.json"), "bpe_vocab.json"]
    vocab_path = next((c for c in candidates if os.path.exists(c)), None)
    if vocab_path is None:
        raise FileNotFoundError("BPE vocab not found; looked in: " + ", ".join(candidates))

    with open(vocab_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    tokenizer.merges = {
        tuple(map(int, k.split(","))): v
        for k, v in data["merges"].items()
    }
    tokenizer.vocab = {
        int(k): bytes(v, "latin1")
        for k, v in data["vocab"].items()
    }
    # NOTE(review): `_encode_cached` looks like an lru_cache-wrapped encoder on
    # the tokenizer. Clear it so nothing cached before the vocab was swapped
    # survives. This relies on a private attribute of BPETokenizer.
    tokenizer._encode_cached.cache_clear()
    print("Vocab size:", len(tokenizer.vocab))
    return tokenizer
|
|
|
|
class TokenWindowDataset:
    """Random-access next-token windows over one tokenized text file.

    Each non-blank line is stripped, encoded, and followed by the encoding of
    ``"\\n"``; all lines are concatenated into one token stream. Window ``i`` uses
    ``tokens[i : i + block_size]`` as input and the same span shifted by one
    token as target. A stream of ``N`` tokens therefore yields
    ``N - block_size`` windows.
    """

    def __init__(self, path, tokenizer, block_size: int):
        """Read and tokenize ``path`` and build the window views.

        Args:
            path: Text file to read as UTF-8. Undecodable bytes are skipped.
            tokenizer: Object with ``encode(str) -> list[int]``.
            block_size: Window length in tokens. Must be >= 1.

        Raises:
            ValueError: if ``block_size`` < 1.
            RuntimeError: if the file yields fewer than ``block_size + 1`` tokens,
                which is too few for even one input/target window.
        """
        self.path = path
        self.block_size = int(block_size)
        if self.block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {self.block_size}")

        tokens = []
        newline_tokens = tokenizer.encode("\n")
        if not newline_tokens:
            # Fallback separator id 10 (the byte value of "\n"); presumably the
            # vocab's base ids are raw bytes. Confirm against the tokenizer.
            newline_tokens = [10]
        with open(path, "r", encoding="utf-8", errors="ignore") as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue
                tokens.extend(tokenizer.encode(stripped))
                tokens.extend(newline_tokens)

        token_tensor = torch.tensor(tokens, dtype=torch.long)
        # One window needs block_size input tokens plus one extra target token.
        if token_tensor.numel() < self.block_size + 1:
            raise RuntimeError(f"Not enough tokens in {path} for block_size={self.block_size}")

        # unfold() returns overlapping views into the token stream (no per-window
        # copies), shape (num_windows, block_size). Targets are the stream shifted
        # by one token.
        self.x_windows = token_tensor[:-1].unfold(0, self.block_size, 1)
        self.y_windows = token_tensor[1:].unfold(0, self.block_size, 1)
        self.num_windows = int(self.x_windows.size(0))
        print(f"Loaded {os.path.basename(path)}: tokens={token_tensor.numel()} windows={self.num_windows}")

    def get_batch(self, batch_size):
        """Sample ``batch_size`` windows uniformly at random, with replacement.

        Returns:
            ``(xb, yb)``: long tensors of shape ``(batch_size, block_size)``,
            moved to ``device``.
        """
        starts = torch.randint(0, self.num_windows, (batch_size,), dtype=torch.long)
        xb = self.x_windows.index_select(0, starts).to(device)
        yb = self.y_windows.index_select(0, starts).to(device)
        return xb, yb
|
|
|
|
@torch.no_grad()
def evaluate(model, dataset, batch_size, num_batches):
    """Return the mean loss over ``num_batches`` random batches from ``dataset``.

    The model runs in eval mode for the measurement. Its previous train/eval mode
    is restored afterwards, even if a batch raises. Gradients are disabled by
    ``torch.no_grad``.

    Args:
        model: Module called as ``model(xb, yb)`` and returning ``(_, loss)``.
        dataset: Object with ``get_batch(batch_size) -> (xb, yb)``.
        batch_size: Batch size per evaluation batch.
        num_batches: Number of batches to average. ``<= 0`` returns 0.0.
    """
    was_training = model.training
    model.eval()
    total = 0.0
    try:
        for _ in range(num_batches):
            xb, yb = dataset.get_batch(batch_size)
            _, loss = model(xb, yb)
            total += loss.item()
    finally:
        model.train(was_training)
    return total / max(1, num_batches)
|
|
|
|
def apply_top_p(logits, top_p):
    """Nucleus (top-p) filtering along the last dimension of ``logits``.

    Keeps the highest-probability entries up to and including the first one at
    which the cumulative probability exceeds ``top_p``. The single most likely
    entry is always kept. Every other logit is set to -1e9.

    Returns ``logits`` itself when ``top_p`` is None or >= 1.0. Otherwise returns
    a new tensor and leaves the input unmodified.
    """
    if top_p is None or top_p >= 1.0:
        return logits

    ranked_logits, order = logits.sort(dim=-1, descending=True)
    running_mass = ranked_logits.softmax(dim=-1).cumsum(dim=-1)

    # Drop rank i only when the mass of ranks < i already exceeds top_p. This
    # keeps the entry that crosses the threshold, and rank 0 is never dropped.
    over = running_mass > top_p
    drop = torch.zeros_like(over)
    drop[..., 1:] = over[..., :-1]

    ranked_logits = ranked_logits.masked_fill(drop, -1e9)
    restored = torch.full_like(logits, -1e9)
    return restored.scatter(-1, order, ranked_logits)
|
|
|
|
@torch.no_grad()
def sample(
    model,
    tokenizer,
    prompt="User: Hello\nAssistant:",
    max_new_tokens=48,
    temperature=0.8,
    top_k=50,
    top_p=0.9,
):
    """Generate a continuation of ``prompt`` and return only the newly generated text.

    Each new token is produced as follows:
      1. Run the model on the last ``model.cfg.block_size`` tokens of context.
      2. Divide the last position's logits by ``temperature`` (floored at 1e-6).
      3. Keep only the ``top_k`` largest logits. ``None`` or ``<= 0`` disables
         this filter.
      4. Apply nucleus filtering with ``top_p`` via ``apply_top_p``.
      5. Draw the token with ``torch.multinomial``.

    The model runs in eval mode during generation. Its previous train/eval mode
    is restored afterwards, even if generation raises.
    """
    was_training = model.training
    model.eval()
    try:
        idx = torch.tensor(tokenizer.encode(prompt), dtype=torch.long, device=device)[None, :]
        start_len = idx.size(1)

        for _ in range(max_new_tokens):
            idx_cond = idx[:, -model.cfg.block_size :]
            logits, _ = model(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-6)

            # top_k == 0 used to crash: topk(k=0) is empty, so v[:, [-1]] raised.
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits = logits.masked_fill(logits < v[:, [-1]], -1e9)

            logits = apply_top_p(logits, top_p)
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, 1)
            idx = torch.cat([idx, idx_next], dim=1)
    finally:
        model.train(was_training)

    new_tokens = idx[0, start_len:].tolist()
    return tokenizer.decode(new_tokens)
|
|
|
|
def save_checkpoint(path, model, optimizer, step, ema, best_val):
    """Atomically write a training checkpoint (format_version 2) to ``path``.

    The payload holds:
      * the vocab size (``model.head.out_features``) and ``model.cfg.to_dict()``
      * the model and optimizer state dicts
      * ``step``, stored as-is. For format_version >= 2, ``main`` resumes
        training directly at this step.
      * the loss EMA ``ema`` and ``best_val``

    The data is first written to ``path + ".tmp"`` and then moved over ``path``
    with ``os.replace``. An interruption mid-save (e.g. Ctrl+C during a long run)
    therefore leaves the previous checkpoint intact instead of a truncated file.
    """
    payload = {
        "format_version": 2,
        "vocab_size": model.head.out_features,
        "model_config": {
            **model.cfg.to_dict(),
        },
        "model": model.state_dict(),
        "opt": optimizer.state_dict(),
        "step": step,
        "ema": ema,
        "best_val": best_val,
    }
    tmp_path = f"{path}.tmp"
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the partial temp file. The existing checkpoint at
        # `path` has not been touched.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
|
|
|
|
def write_sample_snapshot(
    model,
    tokenizer,
    step: int,
    checkpoint_path: str,
    reason: str,
    args,
):
    """Draw one sample and append it, under a header line, to ``samples.txt``.

    Sampling uses ``args.sample_temperature``, ``args.sample_top_k`` and
    ``args.sample_top_p``. The header records ``step``, ``reason`` and
    ``checkpoint_path``. ``samples.txt`` is opened relative to the current working
    directory. Returns the sampled text.
    """
    text = sample(
        model,
        tokenizer,
        temperature=args.sample_temperature,
        top_k=args.sample_top_k,
        top_p=args.sample_top_p,
    )
    header = f"\n--- step {step} | {reason} | checkpoint: {checkpoint_path} ---\n"
    with open("samples.txt", "a", encoding="utf-8") as log:
        log.write(header + f"{text}\n")
    return text
|
|
|
|
def config_from_args(args) -> GPTConfig:
    """Build a GPTConfig from CLI args, falling back to DEFAULT_CONFIG per field.

    A size field (n_embd, n_head, n_layer, block_size) counts as requested when
    it is > 0. Dropout counts as requested when it is >= 0.0. Any field not
    requested takes its value from DEFAULT_CONFIG.
    """

    def _size(field):
        requested = int(getattr(args, field))
        if requested > 0:
            return requested
        return getattr(DEFAULT_CONFIG, field)

    def _dropout():
        requested = float(args.dropout)
        # Written as "not >= 0.0" so that NaN also falls back to the default.
        if not requested >= 0.0:
            return float(DEFAULT_CONFIG.dropout)
        return requested

    return GPTConfig(
        n_embd=_size("n_embd"),
        n_head=_size("n_head"),
        n_layer=_size("n_layer"),
        block_size=_size("block_size"),
        dropout=_dropout(),
    )
|
|
|
|
def _validate_loop_args(args):
    """Reject training-loop settings that would crash or silently misbehave mid-run.

    The ``*_every`` intervals are used as ``step % interval``. A value of 0 raises
    ZeroDivisionError, but only after the data has loaded and a step has run.
    ``accum_steps`` < 1 skips every backward pass. ``batch_size`` < 1 yields
    empty or invalid batches. ``eval_batches`` < 1 makes every validation loss 0.0.
    """
    for name in (
        "batch_size",
        "accum_steps",
        "eval_every",
        "eval_batches",
        "log_every",
        "save_every",
        "sample_every",
    ):
        value = getattr(args, name)
        if value < 1:
            raise ValueError(f"--{name.replace('_', '-')} must be >= 1 (got {value})")


def _lr_scale(step, start_step, warmup_steps, run_steps, max_steps):
    """Return the learning-rate multiplier for ``step``.

    Phases:
      * Warmup: linear over the first ``warmup_steps`` steps of this run
        (counted from ``start_step``).
      * Decay: cosine, with progress measured over this run when ``run_steps``
        is given, otherwise over absolute steps ``0..max_steps``.

    The result is floored at 0.1 in both phases.
    """
    offset = step - start_step
    if offset < warmup_steps:
        scale = (offset + 1) / max(1, warmup_steps)
    else:
        if run_steps is not None:
            progress = (offset + 1) / max(1, run_steps)
        else:
            progress = (step + 1) / max(1, max_steps)
        scale = 0.5 * (1.0 + math.cos(progress * math.pi))
    return max(0.1, scale)


def main():
    """Command-line entry point: prepare data, build or resume the model, then train.

    Flow:
      1. Parse and sanity-check CLI args. Seed torch and set CPU thread counts.
      2. Prepare the data if needed and load the BPE tokenizer.
      3. If ``--ckpt-path`` exists, resume from it. Its vocab size must match the
         tokenizer, and its model config must match any explicit size flags.
         Otherwise build a fresh model from the CLI/default config.
      4. Train from the resume step up to ``start + --run-steps`` (or
         ``--max-steps``). The loop uses:
           * warmup + cosine LR and gradient accumulation
           * periodic validation, saving a best checkpoint on improvement
           * periodic samples and checkpoints
           * CSV metrics
           * optional early stopping
      5. Save a final checkpoint and a final sample.
    """
    args = parse_args()
    _validate_loop_args(args)

    torch.manual_seed(args.seed)
    torch.set_float32_matmul_precision("high")
    torch.set_num_threads(args.threads)
    torch.set_num_interop_threads(args.interop_threads)
    print("PyTorch threads:", torch.get_num_threads())
    print("Interop threads:", torch.get_num_interop_threads())

    ensure_data_ready(args)
    tokenizer = load_tokenizer()
    vocab_size = len(tokenizer.vocab)

    ckpt = None
    cfg: GPTConfig
    if os.path.exists(args.ckpt_path):
        ckpt = torch.load(args.ckpt_path, map_location=device)
        ckpt_vocab = ckpt.get("vocab_size")
        if ckpt_vocab is not None and int(ckpt_vocab) != vocab_size:
            raise RuntimeError(
                f"Checkpoint/tokenizer mismatch: ckpt vocab_size={ckpt_vocab}, tokenizer vocab_size={vocab_size}. "
                "Start a fresh checkpoint path for the new tokenizer."
            )
        cfg = config_from_dict(ckpt.get("model_config"))

        # Explicit size flags must agree with the checkpoint's config. Any
        # disagreement is an error rather than being silently ignored.
        requested = config_from_args(args)
        overrides = []
        if int(args.n_embd) > 0 and requested.n_embd != cfg.n_embd:
            overrides.append(f"n_embd={requested.n_embd} (ckpt {cfg.n_embd})")
        if int(args.n_head) > 0 and requested.n_head != cfg.n_head:
            overrides.append(f"n_head={requested.n_head} (ckpt {cfg.n_head})")
        if int(args.n_layer) > 0 and requested.n_layer != cfg.n_layer:
            overrides.append(f"n_layer={requested.n_layer} (ckpt {cfg.n_layer})")
        if int(args.block_size) > 0 and requested.block_size != cfg.block_size:
            overrides.append(f"block_size={requested.block_size} (ckpt {cfg.block_size})")
        if float(args.dropout) >= 0.0 and abs(requested.dropout - cfg.dropout) > 1e-9:
            overrides.append(f"dropout={requested.dropout} (ckpt {cfg.dropout})")
        if overrides:
            raise RuntimeError(
                "You are resuming from an existing checkpoint, but you also requested a different model size. "
                "Use a new --ckpt-path/--best-path to start fresh, or remove the size overrides. "
                "Mismatches: " + ", ".join(overrides)
            )
        print("Resuming checkpoint model_config:", cfg.to_dict())
    else:
        cfg = config_from_args(args)
        cfg.validate()
        print("Fresh model_config:", cfg.to_dict())

    train_ds = TokenWindowDataset(args.train_data, tokenizer, block_size=cfg.block_size)
    val_ds = TokenWindowDataset(args.val_data, tokenizer, block_size=cfg.block_size)

    model = GPT(vocab_size, cfg=cfg).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.1)

    start_step = 0
    ema_loss = None
    best_val = float("inf")

    if ckpt is not None:
        try:
            model.load_state_dict(ckpt["model"], strict=True)
        except Exception as exc:
            raise RuntimeError(
                f"Checkpoint is incompatible with current model: {exc}. "
                "Use a new --ckpt-path for fresh training."
            ) from exc

        if args.reset_optimizer:
            print("Optimizer reset requested; starting with fresh optimizer state.")
        elif "opt" in ckpt:
            try:
                optimizer.load_state_dict(ckpt["opt"])
                print("Optimizer state restored")
            except Exception as exc:
                print(f"Optimizer state incompatible, starting fresh optimizer: {exc}")
        else:
            print("Optimizer state missing, starting fresh optimizer")

        raw_step = int(ckpt.get("step", 0))
        fmt = int(ckpt.get("format_version", 1))
        # format_version >= 2 stores the next step to run (saves pass step + 1).
        # Older checkpoints are shifted by one, presumably because they stored
        # the last completed step.
        start_step = raw_step if fmt >= 2 else raw_step + (1 if raw_step > 0 else 0)
        ema_loss = ckpt.get("ema", None)
        best_val = float(ckpt.get("best_val", best_val))
        if args.reset_best_val:
            best_val = float("inf")
            print("Best validation reset requested; best_val=inf")
        print(f"Resumed from step {start_step}")
        # The weights have been copied into the model. Drop our reference to the
        # loaded dict so its copy of the weights can be freed during training.
        ckpt = None
    else:
        print("Fresh start")

    if args.run_steps is not None:
        end_step = start_step + args.run_steps
    else:
        end_step = args.max_steps

    if start_step >= end_step:
        print(f"Nothing to do: start_step={start_step} >= end_step={end_step}")
        return

    run_span = end_step - start_step
    effective_warmup = min(args.warmup_steps, max(1, run_span // 10))

    print(f"TRAINING STARTED | from {start_step} to {end_step - 1} | warmup={effective_warmup}")
    tokens_per_step = args.batch_size * args.accum_steps * cfg.block_size
    label_smoothing = max(0.0, min(0.2, args.label_smoothing))
    wall_t0 = time.time()
    log_t0 = wall_t0
    steps_since_log = 0

    metrics_header_needed = not os.path.exists(args.metrics_csv)
    no_improve_evals = 0
    should_stop_early = False
    last_step = start_step
    with open(args.metrics_csv, "a", encoding="utf-8", newline="") as csv_file:
        writer = csv.writer(csv_file)
        if metrics_header_needed:
            writer.writerow(["step", "loss", "ema_loss", "val_loss", "lr", "tokens_per_sec"])

        for step in range(start_step, end_step):
            model.train()
            optimizer.zero_grad(set_to_none=True)

            lr_scale = _lr_scale(step, start_step, effective_warmup, args.run_steps, args.max_steps)
            for pg in optimizer.param_groups:
                pg["lr"] = args.lr * lr_scale

            micro_losses = []
            for _ in range(args.accum_steps):
                xb, yb = train_ds.get_batch(args.batch_size)
                logits, _ = model(xb, None)
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    yb.view(-1),
                    label_smoothing=label_smoothing,
                )
                micro_losses.append(float(loss.item()))
                # Divide so the accumulated gradient is the mean over micro-batches.
                (loss / args.accum_steps).backward()

            if args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()
            last_step = step + 1
            steps_since_log += 1

            loss_val = sum(micro_losses) / max(1, len(micro_losses))
            if not math.isfinite(loss_val):
                raise RuntimeError(f"Non-finite loss encountered at step {step}: {loss_val}")
            ema_loss = loss_val if ema_loss is None else (0.95 * ema_loss + 0.05 * loss_val)

            val_loss = None
            if step % args.eval_every == 0:
                val_loss = evaluate(model, val_ds, args.batch_size, args.eval_batches)
                print(
                    f"Step {step:7d} | train {loss_val:.4f} | ema {ema_loss:.4f} "
                    f"| val {val_loss:.4f} | lr {optimizer.param_groups[0]['lr']:.6f}"
                )
                if val_loss < best_val:
                    best_val = val_loss
                    no_improve_evals = 0
                    save_checkpoint(args.best_path, model, optimizer, step + 1, ema_loss, best_val)
                    write_sample_snapshot(
                        model,
                        tokenizer,
                        step + 1,
                        args.best_path,
                        "new best",
                        args,
                    )
                    print(f"New best checkpoint saved to {args.best_path}")
                else:
                    no_improve_evals += 1
                    if args.early_stop_patience > 0 and no_improve_evals >= args.early_stop_patience:
                        should_stop_early = True
                        print(
                            f"Early stop triggered at step {step}: "
                            f"no val improvement for {no_improve_evals} evals."
                        )

            # Write a metrics row on every log step and on every eval step, so
            # validation losses never go unrecorded. Throughput is measured only on
            # log steps; eval-only rows leave tokens_per_sec empty.
            is_log_step = step % args.log_every == 0 and step > start_step
            if is_log_step or val_loss is not None:
                tps_field = ""
                if is_log_step:
                    now = time.time()
                    # Divide by the steps actually run since the last measurement.
                    # The first interval after a (re)start is rarely exactly
                    # log_every steps long.
                    tps = (tokens_per_step * steps_since_log) / max(1e-6, now - log_t0)
                    log_t0 = now
                    steps_since_log = 0
                    tps_field = f"{tps:.2f}"
                writer.writerow([
                    step,
                    f"{loss_val:.6f}",
                    f"{ema_loss:.6f}",
                    "" if val_loss is None else f"{val_loss:.6f}",
                    f"{optimizer.param_groups[0]['lr']:.8f}",
                    tps_field,
                ])
                csv_file.flush()

            if step % args.sample_every == 0 and step > start_step:
                write_sample_snapshot(
                    model,
                    tokenizer,
                    step,
                    "(scheduled sample)",
                    "sample interval",
                    args,
                )

            if step % args.save_every == 0 and step > start_step:
                save_checkpoint(args.ckpt_path, model, optimizer, step + 1, ema_loss, best_val)
                write_sample_snapshot(
                    model,
                    tokenizer,
                    step + 1,
                    args.ckpt_path,
                    "checkpoint save",
                    args,
                )

            if should_stop_early:
                break

    final_step = last_step if should_stop_early else end_step
    save_checkpoint(args.ckpt_path, model, optimizer, final_step, ema_loss, best_val)
    write_sample_snapshot(
        model,
        tokenizer,
        final_step,
        args.ckpt_path,
        "final save",
        args,
    )
    elapsed_total = time.time() - wall_t0
    print(
        f"TRAINING COMPLETE | elapsed={elapsed_total/60.0:.2f} min "
        f"| final_step={final_step} | best_val={best_val:.4f}"
    )
|
|
|
|
# Script entry point: train only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
|
|