# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch>=2.0",
#     "numpy",
#     "accelerate",
#     "datasets",
#     "huggingface-hub>=0.20",
#     "psutil",
# ]
# ///
"""One-command GPU training for KAN-JEPA SOTA pipeline.

Works on: Google Colab, HuggingFace Jobs, local GPU (CUDA/MPS), CPU fallback.

Usage:
    # Colab (in notebook cell):
    !python scripts/gpu_train.py --preset text2cypher --epochs 100

    # HF Jobs:
    hf jobs run scripts/gpu_train.py --hardware t4-small --secret HF_TOKEN

    # Local:
    python3 scripts/gpu_train.py --preset text2cypher --epochs 50

    # Multi-domain SOTA:
    python3 scripts/gpu_train.py --preset all --epochs 200 --n-pairs 20000

All checkpoints are pushed to HF Hub as PRIVATE repos (whenever --hub-repo is
set, which it is by default).
"""
from __future__ import annotations

import argparse
import json
import os
import random
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Tuple

import numpy as np
import torch


# ---------------------------------------------------------------------------
# Environment detection
# ---------------------------------------------------------------------------

def detect_env() -> Dict[str, Any]:
    """Auto-detect the runtime environment and accelerator.

    Returns:
        A dict with keys:
        - ``device``: "cuda", "mps" or "cpu".
        - ``gpu_name``: CUDA device name, or None.
        - ``vram_gb``: total device memory in GB (0 when not on CUDA).
        - ``mixed_precision``: "bf16" on compute capability >= 8, "fp16" on
          older CUDA devices, otherwise "no".
        - ``env``: "colab", "hf_jobs" or "local".
    """
    info: Dict[str, Any] = {"device": "cpu", "gpu_name": None, "vram_gb": 0,
                            "mixed_precision": "no", "env": "local"}

    if "COLAB_GPU" in os.environ or os.path.exists("/content"):
        info["env"] = "colab"
    elif os.environ.get("HF_JOBS"):
        info["env"] = "hf_jobs"

    if torch.cuda.is_available():
        info["device"] = "cuda"
        info["gpu_name"] = torch.cuda.get_device_name(0)
        props = torch.cuda.get_device_properties(0)
        # The attribute is ``total_memory``; ``total_mem`` does not exist and
        # raised AttributeError on every CUDA machine.
        info["vram_gb"] = props.total_memory / 1e9
        cap = torch.cuda.get_device_capability(0)
        info["mixed_precision"] = "bf16" if cap[0] >= 8 else "fp16"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        info["device"] = "mps"

    return info


def _redact(text: str, secret: str) -> str:
    """Return *text* with every occurrence of *secret* replaced by ``***``.

    An empty *secret* leaves *text* unchanged.
    """
    return text.replace(secret, "***") if secret else text


def ensure_imports() -> None:
    """Make the ``training`` package importable (handles HF Jobs cloning).

    Uses the local checkout (the parent of this script's directory) when it
    contains ``training/core``. Otherwise it shallow-clones the Hub repo into
    ``/tmp/ane-repo``, reusing an existing clone. In both cases the chosen root
    is prepended to ``sys.path``.

    Raises:
        RuntimeError: if ``git clone`` fails. The message carries git's stderr
            with ``HF_TOKEN`` redacted.
    """
    local_root = Path(__file__).resolve().parent.parent
    if (local_root / "training" / "core").exists():
        sys.path.insert(0, str(local_root))
        return

    # HF Jobs: clone from Hub
    repo_dir = Path("/tmp/ane-repo")
    if not (repo_dir / "training" / "core").exists():
        import subprocess

        token = os.environ.get("HF_TOKEN", "")
        public_url = "https://huggingface.co/JohnGenetica/ane-sota-kan-v2"
        auth = f"user:{token}@" if token else ""
        url = f"https://{auth}huggingface.co/JohnGenetica/ane-sota-kan-v2"
        branch = os.environ.get("REPO_BRANCH", "x/ane-local-backport")
        try:
            subprocess.run(
                ["git", "clone", "--depth", "1", "--branch", branch,
                 url, str(repo_dir)],
                check=True, capture_output=True, text=True)
        except subprocess.CalledProcessError as err:
            # The CalledProcessError message contains the full command line,
            # including the tokenized URL. Re-raise without chaining it.
            detail = _redact(err.stderr or "", token).strip()
            raise RuntimeError(
                f"git clone of {public_url} (branch {branch}) failed: {detail}"
            ) from None
        # Do not leave the token in the clone's stored remote URL on disk.
        subprocess.run(
            ["git", "-C", str(repo_dir), "remote", "set-url", "origin",
             public_url],
            check=False, capture_output=True)
    sys.path.insert(0, str(repo_dir))


# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------

def _dedup_pairs(pairs: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Drop duplicate (question, gold) pairs, keeping first occurrences in order.

    Two pairs count as duplicates when they are equal after stripping
    surrounding whitespace from both fields. The kept pairs are returned
    unstripped.
    """
    seen = set()
    deduped: List[Tuple[str, str]] = []
    for q, g in pairs:
        key = (q.strip(), g.strip())
        if key not in seen:
            seen.add(key)
            deduped.append((q, g))
    return deduped


def load_data(preset: str, n_pairs: int,
              include_hf: bool) -> Tuple[List[Tuple[str, str]], Dict[str, int]]:
    """Load training data from all sources: curated + benchmarks + SOTA file.

    Each source is best-effort: a failing source is reported and skipped.

    Args:
        preset: Preset name. It selects benchmarks via ``_get_benchmarks``.
        n_pairs: Per-benchmark load limit, and the cap on the final total.
        include_hf: Forwarded as ``use_hf`` to the text2cypher adapter only.

    Returns:
        ``(pairs, stats)``. *pairs* are unique (question, gold) tuples, shuffled
        with the module-level ``random`` state and capped at *n_pairs*. *stats*
        maps each source name to the number of pairs it contributed.
    """
    pairs: List[Tuple[str, str]] = []
    stats: Dict[str, int] = {}

    # Source 1: DatasetBuilder (curated edge cases, rewards, all 10 domains)
    try:
        from training.core.dataset_builder import DatasetBuilder
        db = DatasetBuilder()
        base = db.build_all()
        expanded = db.expand_parametric(base, factor=5)
        curated = db.to_pairs(expanded)
        pairs.extend(curated)
        stats["curated"] = len(curated)
        print(f" DatasetBuilder: {len(curated)} pairs")
    except Exception as e:
        print(f" DatasetBuilder: {e}")

    # Source 2: Benchmark adapters
    benchmarks = _get_benchmarks(preset)
    try:
        from training.leaderboard_data import AdapterRegistry
        for bench in benchmarks:
            t0 = time.time()
            try:
                kwargs = {"use_hf": include_hf} if bench == "text2cypher" else {}
                adapter = AdapterRegistry.get(bench, **kwargs)
                data = adapter.load(split="train", max_n=n_pairs)
                bp = [(d.question, d.gold) for d in data if d.question and d.gold]
                pairs.extend(bp)
                stats[bench] = len(bp)
                print(f" {bench}: {len(bp)} pairs ({time.time()-t0:.1f}s)")
            except Exception as e:
                print(f" {bench}: FAILED - {e}")
    except ImportError:
        print(" AdapterRegistry not available")

    # Source 3: Pre-generated SOTA data
    for sota_path in ["training/kan_bench_results/sota_training_data.json",
                      "training/kan_bench_results/sota_curated_training_data.json"]:
        path = Path(sota_path)
        if not path.exists():
            continue
        try:
            raw = json.loads(path.read_text(encoding="utf-8"))
            sp = [(d["question"], d["gold"]) for d in raw
                  if d.get("question") and d.get("gold")]
            pairs.extend(sp)
            stats[path.stem] = len(sp)
            print(f" {path.name}: {len(sp)} pairs")
        except Exception as e:
            print(f" {path.name}: {e}")

    deduped = _dedup_pairs(pairs)
    random.shuffle(deduped)
    final = deduped[:n_pairs]
    print(f" Total: {len(final)} unique pairs (from {len(pairs)} raw)")
    return final, stats


def _get_benchmarks(preset: str) -> List[str]:
    """Map preset to list of benchmarks (unknown presets -> text2cypher)."""
    mapping = {
        "text2cypher": ["text2cypher"],
        "spider2": ["spider2", "bird_sql"],
        "swebench": ["swebench"],
        "code": ["humaneval", "mbpp", "livecodebench"],
        "all": ["text2cypher", "spider2", "swebench", "humaneval", "mbpp",
                "gaia", "bird_sql", "livecodebench", "text2gql", "mmlu_pro",
                "gpqa_diamond"],
    }
    return mapping.get(preset, ["text2cypher"])


# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------

def _split_pairs(pairs: List[Any], train_frac: float = 0.9,
                 min_train: int = 10) -> Tuple[List[Any], List[Any]]:
    """Split *pairs* into (train, val) at ``max(min_train, int(len * train_frac))``.

    With ``min_train`` items or fewer, everything goes to train and val is
    empty.
    """
    split_idx = max(min_train, int(len(pairs) * train_frac))
    return pairs[:split_idx], pairs[split_idx:]


def _create_model(preset: str, vocab_size: int):
    """Build a ``UnifiedSuperModel`` for *preset*.

    "all" uses ``for_sota`` with the "code" benchmark. Unknown presets fall
    back to ``for_text2cypher``.
    """
    from training.core.super_model import UnifiedSuperModel

    if preset == "all":
        return UnifiedSuperModel.for_sota(vocab_size=vocab_size, benchmark="code")
    factories = {
        "text2cypher": UnifiedSuperModel.for_text2cypher,
        "spider2": UnifiedSuperModel.for_spider2,
        "swebench": UnifiedSuperModel.for_swebench,
        "code": UnifiedSuperModel.for_code,
    }
    factory = factories.get(preset, UnifiedSuperModel.for_text2cypher)
    return factory(vocab_size=vocab_size)


def _evaluate(model, vocab, eval_pairs: List[Tuple[str, str]],
              max_len: int = 128) -> Dict[str, float]:
    """Greedily decode each question and score it against its gold answer.

    Args:
        model: Object exposing ``generator`` (a torch module) and ``generate``.
        vocab: Vocabulary with ``encode``; it is passed through to ``generate``.
        eval_pairs: (question, gold) pairs to score.
        max_len: Truncation length for the encoded source, and the generation
            length limit.

    Returns:
        Mean of each metric (bleu4, rouge_l, exact_match, chrf, composite),
        in percent. A metric is 0.0 when no pair could be scored.
    """
    from training.core.generative_flywheel import score_generation

    model.generator.eval()
    dev = next(model.generator.parameters()).device
    scores: Dict[str, List[float]] = {
        k: [] for k in ("bleu4", "rouge_l", "exact_match", "chrf", "composite")}
    n_failed = 0

    # Inference only: skip autograd bookkeeping during generation.
    with torch.no_grad():
        for q, gold in eval_pairs:
            enc = vocab.encode(q)[:max_len]
            if len(enc) < 2:
                continue
            try:
                src = torch.tensor([enc], dtype=torch.long, device=dev)
                text, _, _ = model.generate(src, vocab, max_len=max_len,
                                            temperature=0.0)
                sc = score_generation(gold, text)
                row = {k: getattr(sc, k, 0.0) * 100 for k in scores}
            except Exception as e:
                # Keep evaluating, but do not hide the failure the way the old
                # bare ``pass`` did.
                n_failed += 1
                if n_failed == 1:
                    print(f" Eval: first failure: {e}")
                continue
            for k, v in row.items():
                scores[k].append(v)

    if n_failed:
        print(f" Eval: {n_failed}/{len(eval_pairs)} pairs failed to generate/score")
    if not scores["composite"]:
        print(" Eval: no validation pairs were scored; metrics default to 0.0")
    return {k: float(np.mean(v)) if v else 0.0 for k, v in scores.items()}


def train(args: argparse.Namespace, env_info: Dict[str, Any]) -> Dict[str, float]:
    """Full training pipeline.

    Steps:
    1. Load data and split it into train/val.
    2. Build the vocabulary and the model.
    3. Phase 1: base training with AccelerateTrainer.
    4. Phase 2: flywheel self-learning (optional, on every device).
    5. Phase 3: evolutionary search (CUDA only).
    6. Evaluate on up to 200 validation pairs.
    7. Save a checkpoint and results JSON under ``args.save_dir``, then
       optionally push to the Hub as a private repo.

    Returns:
        The evaluation metrics dict.
    """
    device = env_info["device"]

    # Load data
    print("\n--- Data Loading ---")
    all_pairs, data_stats = load_data(args.preset, args.n_pairs, args.include_hf)
    train_pairs, val_pairs = _split_pairs(all_pairs)
    print(f" Split: {len(train_pairs)} train / {len(val_pairs)} val")

    # Build vocab (from all pairs, including validation, as before)
    from training.core.bidirectional_generator import SimpleVocab
    texts = [q for q, _ in all_pairs] + [g for _, g in all_pairs]
    vocab_cap = 4096 if len(all_pairs) < 2000 else (8192 if len(all_pairs) < 10000 else 16384)
    vocab = SimpleVocab.build_from_corpus(texts, max_size=min(vocab_cap, args.vocab_size))
    print(f" Vocab: {len(vocab)} tokens")

    # Create model
    print("\n--- Model Creation ---")
    model = _create_model(args.preset, len(vocab))
    if device == "cuda":
        model = model.cuda()
    n_params = sum(p.numel() for p in model.parameters())
    n_gen = sum(p.numel() for p in model.generator.parameters())
    print(f" {n_gen:,} generator / {n_params:,} total params")

    # Phase 1: AccelerateTrainer
    print("\n" + "=" * 60)
    print(" PHASE 1: Base Training")
    print("=" * 60)
    from training.core.accelerate_trainer import AccelerateTrainer
    trainer_kwargs = dict(
        epochs=args.epochs,
        lr=args.lr,
        batch_size=args.batch_size,
    )
    if args.hub_repo:
        trainer_kwargs.update(push_to_hub=True, hub_repo=args.hub_repo)

    if device == "cuda":
        make_trainer = AccelerateTrainer.for_colab
    else:
        make_trainer = AccelerateTrainer.for_local
        if device != "mps":
            # Plain CPU: cap epochs so the run finishes in reasonable time.
            trainer_kwargs["epochs"] = min(args.epochs, 30)
    trainer = make_trainer(model.generator, vocab, train_pairs, **trainer_kwargs)

    base_result = trainer.train(verbose=True)
    print(f"\n Phase 1: loss={base_result['final_loss']:.4f} "
          f"BLEU={base_result.get('final_bleu', 0):.1f} "
          f"({base_result['training_time_s']:.0f}s on {base_result['device']})")

    # Phase 2: Flywheel (runs on every device unless --no-flywheel; CUDA gets
    # the heavier "sota" preset)
    if not args.no_flywheel:
        print("\n" + "=" * 60)
        print(" PHASE 2: Flywheel Self-Learning")
        print("=" * 60)
        try:
            model.train_generative_with_flywheel(
                train_pairs[:5000], vocab, base_epochs=0,
                flywheel_preset="sota" if device == "cuda" else "default",
                verbose=True,
            )
        except Exception as e:
            print(f" Flywheel: {e}")

    # Phase 3: Evolution (optional, CUDA only)
    if args.evolve and device == "cuda":
        print("\n" + "=" * 60)
        print(" PHASE 3: Evolutionary Search")
        print("=" * 60)
        try:
            evo = model.evolve_generator(
                train_pairs[:2000], vocab, val_pairs,
                population_size=4, generations=3, verbose=True)
            print(f" Best: {evo.get('best_score', 0):.4f}")
        except Exception as e:
            print(f" Evolution: {e}")

    # Evaluate
    print("\n" + "=" * 60)
    print(" EVALUATION")
    print("=" * 60)
    final_metrics = _evaluate(model, vocab, val_pairs[:200])

    print(f"\n{'Metric':<15} {'Score':>8}")
    print("-" * 25)
    for k, v in final_metrics.items():
        print(f" {k:<13} {v:>7.1f}%")

    # Save + push
    print("\n--- Saving ---")
    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    tag = datetime.now().strftime("%Y%m%d_%H%M")

    ckpt_path = save_dir / f"gpu_{args.preset}_{tag}.pt"
    torch.save({
        "model_state_dict": model.state_dict(),
        "generator_state_dict": model.generator.state_dict(),
        "vocab_size": len(vocab),
        "preset": args.preset,
        "metrics": final_metrics,
        "data_stats": data_stats,
        "config": vars(args),
        "env": env_info,
        "timestamp": datetime.now().isoformat(),
    }, ckpt_path)
    print(f" Checkpoint: {ckpt_path}")

    # Results JSON
    results_path = save_dir / f"gpu_{args.preset}_{tag}_results.json"
    results_path.write_text(json.dumps({
        "metrics": final_metrics,
        "data_stats": data_stats,
        "n_train": len(train_pairs),
        "n_val": len(val_pairs),
        "n_params": n_params,
        "env": env_info,
        "novelty_claims": model.novelty_summary(),
    }, indent=2, default=str), encoding="utf-8")
    print(f" Results: {results_path}")

    # Push to Hub (PRIVATE)
    if args.hub_repo:
        try:
            url = model.push_to_hub(
                args.hub_repo, private=True,
                metrics=final_metrics, vocab=vocab)
            print(f" Hub: {url}")
        except Exception as e:
            print(f" Hub push failed: {e}")

    # ``gpu_name`` is always present (None off-GPU), so ``.get`` with a default
    # would print "None"; fall back explicitly.
    gpu_label = env_info.get("gpu_name") or "N/A"
    print(f"\nDone! ({env_info['env']}, {env_info['device']}, {gpu_label})")
    return final_metrics


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def main() -> None:
    """Parse CLI flags (overridable via env vars for HF Jobs), seed, and train."""
    parser = argparse.ArgumentParser(
        description="KAN-JEPA GPU Training",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__)
    parser.add_argument("--preset", default="text2cypher",
                        choices=["text2cypher", "spider2", "swebench", "code", "all"])
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--n-pairs", type=int, default=5000)
    parser.add_argument("--lr", type=float, default=2e-3)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--vocab-size", type=int, default=16384)
    parser.add_argument("--include-hf", action="store_true", default=True,
                        help="Load HF data for text2cypher (default: on; use --no-hf to disable)")
    parser.add_argument("--no-hf", dest="include_hf", action="store_false")
    parser.add_argument("--no-flywheel", action="store_true")
    parser.add_argument("--evolve", action="store_true", default=True,
                        help="Run evolutionary search on CUDA (default: on; use --no-evolve to disable)")
    parser.add_argument("--no-evolve", dest="evolve", action="store_false")
    parser.add_argument("--hub-repo", default="JohnGenetica/ane-sota-kan-v2")
    parser.add_argument("--save-dir", default="checkpoints")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Also accept env vars (for HF Jobs); a non-empty value wins over the flag.
    env_overrides = (
        ("PRESET", "preset", str),
        ("EPOCHS", "epochs", int),
        ("N_PAIRS", "n_pairs", int),
        ("LR", "lr", float),
        ("BATCH_SIZE", "batch_size", int),
        ("HF_REPO", "hub_repo", str),
    )
    for env_name, attr, cast in env_overrides:
        value = os.environ.get(env_name)
        if value:
            setattr(args, attr, cast(value))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print("=" * 60)
    print(" KAN-JEPA SOTA GPU Training")
    print("=" * 60)

    env_info = detect_env()
    print(f" Env: {env_info['env']}")
    print(f" Device: {env_info['device']}")
    if env_info["gpu_name"]:
        print(f" GPU: {env_info['gpu_name']} ({env_info['vram_gb']:.1f} GB)")
    print(f" Mixed: {env_info['mixed_precision']}")
    print(f" Preset: {args.preset}")
    print(f" Epochs: {args.epochs}")
    print(f" Pairs: {args.n_pairs}")

    ensure_imports()
    train(args, env_info)


if __name__ == "__main__":
    main()