Spaces:
Build error
Build error
| # /// script | |
| # requires-python = ">=3.10" | |
| # dependencies = [ | |
| # "torch>=2.0", | |
| # "numpy", | |
| # "accelerate", | |
| # "datasets", | |
| # "huggingface-hub>=0.20", | |
| # "psutil", | |
| # ] | |
| # /// | |
| """One-command GPU training for KAN-JEPA SOTA pipeline. | |
| Works on: Google Colab, HuggingFace Jobs, local GPU (CUDA/MPS), CPU fallback. | |
| Usage: | |
| # Colab (in notebook cell): | |
| !python scripts/gpu_train.py --preset text2cypher --epochs 100 | |
| # HF Jobs: | |
| hf jobs run scripts/gpu_train.py --hardware t4-small --secret HF_TOKEN | |
| # Local: | |
| python3 scripts/gpu_train.py --preset text2cypher --epochs 50 | |
| # Multi-domain SOTA: | |
| python3 scripts/gpu_train.py --preset all --epochs 200 --n-pairs 20000 | |
| All checkpoints are pushed to HF Hub as PRIVATE repos. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import random | |
| import sys | |
| import time | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Tuple | |
| import numpy as np | |
| import torch | |
| # --------------------------------------------------------------------------- | |
| # Environment detection | |
| # --------------------------------------------------------------------------- | |
def detect_env() -> Dict[str, Any]:
    """Auto-detect the runtime environment and the best available device.

    Returns:
        A dict with the keys:

        * ``env`` -- ``"colab"`` when ``COLAB_GPU`` is set or ``/content``
          exists, ``"hf_jobs"`` when ``HF_JOBS`` is set, else ``"local"``.
        * ``device`` -- ``"cuda"``, ``"mps"`` or ``"cpu"``.
        * ``gpu_name`` -- name of CUDA device 0, or ``None`` without CUDA.
        * ``vram_gb`` -- total memory of CUDA device 0 in units of 1e9
          bytes, or ``0`` without CUDA.
        * ``mixed_precision`` -- ``"bf16"`` for compute capability >= 8,
          ``"fp16"`` for older CUDA devices, ``"no"`` otherwise.
    """
    info: Dict[str, Any] = {
        "device": "cpu",
        "gpu_name": None,
        "vram_gb": 0,
        "mixed_precision": "no",
        "env": "local",
    }

    if "COLAB_GPU" in os.environ or os.path.exists("/content"):
        info["env"] = "colab"
    elif os.environ.get("HF_JOBS"):
        info["env"] = "hf_jobs"

    if torch.cuda.is_available():
        info["device"] = "cuda"
        info["gpu_name"] = torch.cuda.get_device_name(0)
        props = torch.cuda.get_device_properties(0)
        # The properties object exposes `total_memory` (bytes); the previous
        # `total_mem` spelling raised AttributeError on every CUDA machine.
        info["vram_gb"] = props.total_memory / 1e9
        # Query device 0 explicitly so all CUDA facts describe the same GPU.
        major, _minor = torch.cuda.get_device_capability(0)
        info["mixed_precision"] = "bf16" if major >= 8 else "fp16"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        info["device"] = "mps"

    return info
def ensure_imports():
    """Make the ``training`` package importable, cloning the repo if needed.

    If ``training/core`` exists two directories above this file, that
    directory is prepended to ``sys.path``. Otherwise (e.g. on HF Jobs, where
    only this script is shipped) the repository is shallow-cloned from the
    Hub into ``/tmp/ane-repo`` -- unless a previous clone is already there --
    and that directory is prepended to ``sys.path`` instead.

    Environment variables read:
        HF_TOKEN: optional token embedded in the clone URL.
        REPO_BRANCH: branch to clone (default ``x/ane-local-backport``).

    Raises:
        subprocess.CalledProcessError: if ``git clone`` fails. The command
            and stderr carried by the exception have the token redacted.
    """
    project_root = Path(__file__).resolve().parent.parent
    if (project_root / "training" / "core").exists():
        sys.path.insert(0, str(project_root))
        return

    # HF Jobs: clone from Hub
    repo_dir = Path("/tmp/ane-repo")
    if not (repo_dir / "training" / "core").exists():
        import subprocess

        token = os.environ.get("HF_TOKEN", "")
        url = f"https://{'user:' + token + '@' if token else ''}huggingface.co/JohnGenetica/ane-sota-kan-v2"
        branch = os.environ.get("REPO_BRANCH", "x/ane-local-backport")
        cmd = ["git", "clone", "--depth", "1", "--branch", branch,
               url, str(repo_dir)]

        def _redact(text: str) -> str:
            # Keep the secret out of logs and tracebacks.
            return text.replace(token, "***") if token else text

        try:
            subprocess.run(cmd, check=True, capture_output=True)
        except subprocess.CalledProcessError as err:
            # capture_output hides git's diagnostics, and the default
            # exception text includes the full command (token and all), so
            # surface stderr and re-raise a sanitized copy of the error.
            stderr_text = _redact(
                err.stderr.decode("utf-8", errors="replace") if err.stderr else "")
            print(f" git clone failed: {stderr_text.strip()}", file=sys.stderr)
            raise subprocess.CalledProcessError(
                err.returncode, [_redact(part) for part in cmd],
                output=err.output, stderr=stderr_text) from None

    sys.path.insert(0, str(repo_dir))
| # --------------------------------------------------------------------------- | |
| # Data loading | |
| # --------------------------------------------------------------------------- | |
def load_data(preset: str, n_pairs: int, include_hf: bool) -> Tuple[List[Tuple[str, str]], Dict[str, int]]:
    """Load (question, gold) training pairs from every available source.

    Sources, in order: the curated ``DatasetBuilder`` output, the benchmark
    adapters selected by ``preset``, and two pre-generated SOTA JSON files.
    Each source is best-effort: a failure is printed and the source skipped.
    The combined pairs are de-duplicated (comparing whitespace-stripped
    text, keeping the first occurrence), shuffled with the global ``random``
    state and truncated to ``n_pairs``.

    Args:
        preset: Preset name, mapped to benchmark names by ``_get_benchmarks``.
        n_pairs: Passed as ``max_n`` to each benchmark adapter and used as
            the cap on the number of pairs returned.
        include_hf: Forwarded as ``use_hf`` to the ``text2cypher`` adapter
            only.

    Returns:
        ``(pairs, stats)`` where ``stats`` maps each source name to the
        number of pairs it contributed before de-duplication.
    """
    pairs: List[Tuple[str, str]] = []
    stats: Dict[str, int] = {}

    # Source 1: DatasetBuilder (curated edge cases, rewards, all 10 domains)
    try:
        from training.core.dataset_builder import DatasetBuilder
        db = DatasetBuilder()
        base = db.build_all()
        expanded = db.expand_parametric(base, factor=5)
        curated = db.to_pairs(expanded)
        pairs.extend(curated)
        stats["curated"] = len(curated)
        print(f" DatasetBuilder: {len(curated)} pairs")
    except Exception as e:
        print(f" DatasetBuilder: {e}")

    # Source 2: Benchmark adapters
    benchmarks = _get_benchmarks(preset)
    try:
        from training.leaderboard_data import AdapterRegistry
    except ImportError:
        print(" AdapterRegistry not available")
    else:
        for bench in benchmarks:
            t0 = time.time()
            try:
                kwargs = {"use_hf": include_hf} if bench == "text2cypher" else {}
                adapter = AdapterRegistry.get(bench, **kwargs)
                data = adapter.load(split="train", max_n=n_pairs)
                bp = [(d.question, d.gold) for d in data if d.question and d.gold]
                pairs.extend(bp)
                stats[bench] = len(bp)
                print(f" {bench}: {len(bp)} pairs ({time.time()-t0:.1f}s)")
            except Exception as e:
                # One broken benchmark must not abort the others.
                print(f" {bench}: FAILED - {e}")

    # Source 3: Pre-generated SOTA data
    # NOTE: these are relative paths, so they resolve against the current
    # working directory; files are silently skipped when run from elsewhere.
    for sota_path in ["training/kan_bench_results/sota_training_data.json",
                      "training/kan_bench_results/sota_curated_training_data.json"]:
        sota_file = Path(sota_path)
        if not sota_file.exists():
            continue
        try:
            # JSON is UTF-8 by specification; do not rely on the locale default.
            raw = json.loads(sota_file.read_text(encoding="utf-8"))
            sp = [(d["question"], d["gold"]) for d in raw
                  if d.get("question") and d.get("gold")]
            pairs.extend(sp)
            stats[sota_file.stem] = len(sp)
            print(f" {sota_file.name}: {len(sp)} pairs")
        except Exception as e:
            print(f" {sota_file.name}: {e}")

    # Dedup on stripped text; dicts keep insertion order, so the first
    # occurrence of each pair wins and its original text is preserved.
    unique: Dict[Tuple[str, str], Tuple[str, str]] = {}
    for q, g in pairs:
        unique.setdefault((q.strip(), g.strip()), (q, g))
    deduped = list(unique.values())

    random.shuffle(deduped)
    final = deduped[:n_pairs]
    print(f" Total: {len(final)} unique pairs (from {len(pairs)} raw)")
    return final, stats
| def _get_benchmarks(preset: str) -> List[str]: | |
| """Map preset to list of benchmarks.""" | |
| mapping = { | |
| "text2cypher": ["text2cypher"], | |
| "spider2": ["spider2", "bird_sql"], | |
| "swebench": ["swebench"], | |
| "code": ["humaneval", "mbpp", "livecodebench"], | |
| "all": ["text2cypher", "spider2", "swebench", "humaneval", "mbpp", | |
| "gaia", "bird_sql", "livecodebench", "text2gql", | |
| "mmlu_pro", "gpqa_diamond"], | |
| } | |
| return mapping.get(preset, ["text2cypher"]) | |
| # --------------------------------------------------------------------------- | |
| # Training | |
| # --------------------------------------------------------------------------- | |
def train(args, env_info: Dict[str, Any]):
    """Run the full pipeline: data, model, training phases, eval, saving.

    Steps: load and split the data 90/10, build a vocabulary, create the
    model for the preset, run base training (phase 1), optionally run the
    flywheel (phase 2) and evolutionary search (phase 3, CUDA only), score
    up to 200 validation pairs, then write a checkpoint and a results JSON
    and optionally push to the Hub. Phases 2 and 3 and the Hub push are
    best-effort: a failure is printed and the run continues.

    Args:
        args: Parsed CLI namespace. Reads ``preset``, ``n_pairs``,
            ``include_hf``, ``vocab_size``, ``epochs``, ``lr``,
            ``batch_size``, ``hub_repo``, ``no_flywheel``, ``evolve`` and
            ``save_dir``.
        env_info: Environment description; reads ``device``, ``env`` and
            ``gpu_name``. The whole dict is stored in the saved artifacts.

    Returns:
        Dict mapping each metric name to its mean score (scaled by 100)
        over the evaluated validation pairs, or ``0.0`` when none were scored.
    """
    device = env_info["device"]

    # Load data
    print("\n--- Data Loading ---")
    all_pairs, data_stats = load_data(args.preset, args.n_pairs, args.include_hf)
    # At least 10 pairs go to training, so with ~10 or fewer pairs in total
    # the validation split is empty and every metric is reported as 0.
    split_idx = max(10, int(len(all_pairs) * 0.9))
    train_pairs = all_pairs[:split_idx]
    val_pairs = all_pairs[split_idx:]
    print(f" Split: {len(train_pairs)} train / {len(val_pairs)} val")

    # Build vocab
    from training.core.bidirectional_generator import SimpleVocab
    texts = [q for q, _ in all_pairs] + [g for _, g in all_pairs]
    # Scale the vocabulary cap with corpus size; --vocab-size is an upper bound.
    vocab_cap = 4096 if len(all_pairs) < 2000 else (8192 if len(all_pairs) < 10000 else 16384)
    vocab = SimpleVocab.build_from_corpus(texts, max_size=min(vocab_cap, args.vocab_size))
    print(f" Vocab: {len(vocab)} tokens")

    # Create model
    print("\n--- Model Creation ---")
    from training.core.super_model import UnifiedSuperModel
    benchmark_map = {
        "text2cypher": "text2cypher", "spider2": "spider2",
        "swebench": "swebench", "code": "code", "all": "code",
    }
    benchmark = benchmark_map.get(args.preset, "text2cypher")
    if args.preset == "all":
        model = UnifiedSuperModel.for_sota(vocab_size=len(vocab), benchmark=benchmark)
    else:
        factories = {
            "text2cypher": UnifiedSuperModel.for_text2cypher,
            "spider2": UnifiedSuperModel.for_spider2,
            "swebench": UnifiedSuperModel.for_swebench,
            "code": UnifiedSuperModel.for_code,
        }
        factory = factories.get(args.preset, UnifiedSuperModel.for_text2cypher)
        model = factory(vocab_size=len(vocab))
    if device == "cuda":
        model = model.cuda()
    n_params = sum(p.numel() for p in model.parameters())
    n_gen = sum(p.numel() for p in model.generator.parameters())
    print(f" {n_gen:,} generator / {n_params:,} total params")

    # Phase 1: AccelerateTrainer
    print("\n" + "=" * 60)
    print(" PHASE 1: Base Training")
    print("=" * 60)
    from training.core.accelerate_trainer import AccelerateTrainer
    trainer_kwargs = dict(
        epochs=args.epochs, lr=args.lr, batch_size=args.batch_size,
    )
    if args.hub_repo:
        trainer_kwargs.update(push_to_hub=True, hub_repo=args.hub_repo)
    if device == "cuda":
        trainer = AccelerateTrainer.for_colab(
            model.generator, vocab, train_pairs, **trainer_kwargs)
    else:
        if device != "mps":
            # CPU fallback: cap the epoch count to keep the run tractable.
            trainer_kwargs["epochs"] = min(args.epochs, 30)
        trainer = AccelerateTrainer.for_local(
            model.generator, vocab, train_pairs, **trainer_kwargs)
    base_result = trainer.train(verbose=True)
    print(f"\n Phase 1: loss={base_result['final_loss']:.4f} "
          f"BLEU={base_result.get('final_bleu', 0):.1f} "
          f"({base_result['training_time_s']:.0f}s on {base_result['device']})")

    # Phase 2: Flywheel (runs on any device unless --no-flywheel is given)
    if not args.no_flywheel:
        print("\n" + "=" * 60)
        print(" PHASE 2: Flywheel Self-Learning")
        print("=" * 60)
        try:
            model.train_generative_with_flywheel(
                train_pairs[:5000], vocab,
                base_epochs=0,
                flywheel_preset="sota" if device == "cuda" else "default",
                verbose=True,
            )
        except Exception as e:
            print(f" Flywheel: {e}")

    # Phase 3: Evolution (optional, CUDA only)
    if args.evolve and device == "cuda":
        print("\n" + "=" * 60)
        print(" PHASE 3: Evolutionary Search")
        print("=" * 60)
        try:
            evo = model.evolve_generator(
                train_pairs[:2000], vocab, val_pairs,
                population_size=4, generations=3, verbose=True)
            print(f" Best: {evo.get('best_score', 0):.4f}")
        except Exception as e:
            print(f" Evolution: {e}")

    # Evaluate
    print("\n" + "=" * 60)
    print(" EVALUATION")
    print("=" * 60)
    from training.core.generative_flywheel import score_generation
    model.generator.eval()
    dev = next(model.generator.parameters()).device
    eval_pairs = val_pairs[:200]
    scores = {"bleu4": [], "rouge_l": [], "exact_match": [], "chrf": [], "composite": []}
    n_failed = 0
    first_error = None
    for q, gold in eval_pairs:
        enc = vocab.encode(q)[:128]
        if len(enc) < 2:
            continue
        try:
            src = torch.tensor([enc], dtype=torch.long, device=dev)
            text, _, _ = model.generate(src, vocab, max_len=128, temperature=0.0)
            sc = score_generation(gold, text)
            # Build the whole row before appending so the per-metric lists
            # can never end up with different lengths.
            row = {k: getattr(sc, k, 0.0) * 100 for k in scores}
        except Exception as e:
            # Keep evaluating, but do not hide the failure: an all-zero
            # metrics table is otherwise indistinguishable from a bad model.
            n_failed += 1
            if first_error is None:
                first_error = e
            continue
        for k, v in row.items():
            scores[k].append(v)
    if n_failed:
        print(f" Eval: {n_failed}/{len(eval_pairs)} examples failed "
              f"(first error: {first_error!r})")
    final_metrics = {k: float(np.mean(v)) if v else 0.0 for k, v in scores.items()}
    print(f"\n{'Metric':<15} {'Score':>8}")
    print("-" * 25)
    for k, v in final_metrics.items():
        print(f" {k:<13} {v:>7.1f}%")

    # Save + push
    print("\n--- Saving ---")
    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    tag = datetime.now().strftime("%Y%m%d_%H%M")
    ckpt_path = save_dir / f"gpu_{args.preset}_{tag}.pt"
    torch.save({
        "model_state_dict": model.state_dict(),
        "generator_state_dict": model.generator.state_dict(),
        "vocab_size": len(vocab),
        "preset": args.preset,
        "metrics": final_metrics,
        "data_stats": data_stats,
        "config": vars(args),
        "env": env_info,
        "timestamp": datetime.now().isoformat(),
    }, ckpt_path)
    print(f" Checkpoint: {ckpt_path}")

    # Results JSON
    results_path = save_dir / f"gpu_{args.preset}_{tag}_results.json"
    results_path.write_text(json.dumps({
        "metrics": final_metrics, "data_stats": data_stats,
        "n_train": len(train_pairs), "n_val": len(val_pairs),
        "n_params": n_params, "env": env_info,
        "novelty_claims": model.novelty_summary(),
    }, indent=2, default=str))
    print(f" Results: {results_path}")

    # Push to Hub (PRIVATE)
    if args.hub_repo:
        try:
            url = model.push_to_hub(
                args.hub_repo, private=True,
                metrics=final_metrics, vocab=vocab)
            print(f" Hub: {url}")
        except Exception as e:
            print(f" Hub push failed: {e}")

    # `gpu_name` may be present with value None, so a .get() default would
    # never apply; fall back on falsiness instead.
    print(f"\nDone! ({env_info['env']}, {env_info['device']}, {env_info.get('gpu_name') or 'N/A'})")
    return final_metrics
| # --------------------------------------------------------------------------- | |
| # CLI | |
| # --------------------------------------------------------------------------- | |
def main():
    """CLI entry point: parse options, seed the RNGs, detect the device, train.

    Options come from the command line and may be overridden by environment
    variables (used on HF Jobs, where passing flags is awkward): ``PRESET``,
    ``EPOCHS``, ``N_PAIRS``, ``LR``, ``BATCH_SIZE`` and ``HF_REPO``. A
    non-empty environment variable takes precedence over the matching flag.
    Invalid values, from either source, end the program with a usage error.
    """
    presets = ["text2cypher", "spider2", "swebench", "code", "all"]

    parser = argparse.ArgumentParser(
        description="KAN-JEPA GPU Training",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__)
    parser.add_argument("--preset", default="text2cypher", choices=presets)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--n-pairs", type=int, default=5000)
    parser.add_argument("--lr", type=float, default=2e-3)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--vocab-size", type=int, default=16384)
    parser.add_argument("--include-hf", action="store_true", default=True)
    parser.add_argument("--no-hf", dest="include_hf", action="store_false")
    parser.add_argument("--no-flywheel", action="store_true")
    parser.add_argument("--evolve", action="store_true", default=True)
    parser.add_argument("--no-evolve", dest="evolve", action="store_false")
    parser.add_argument("--hub-repo", default="JohnGenetica/ane-sota-kan-v2")
    parser.add_argument("--save-dir", default="checkpoints")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Also accept env vars (for HF Jobs): (variable, args attribute, converter)
    env_overrides = (
        ("PRESET", "preset", str),
        ("EPOCHS", "epochs", int),
        ("N_PAIRS", "n_pairs", int),
        ("LR", "lr", float),
        ("BATCH_SIZE", "batch_size", int),
        ("HF_REPO", "hub_repo", str),
    )
    for var, attr, convert in env_overrides:
        raw = os.environ.get(var)
        if not raw:
            continue
        try:
            setattr(args, attr, convert(raw))
        except ValueError:
            parser.error(f"invalid value for environment variable {var}: {raw!r}")

    # Environment values bypass argparse's `choices`, so validate here: an
    # unknown preset would otherwise silently train the text2cypher default
    # while naming the checkpoint after the bogus preset.
    if args.preset not in presets:
        parser.error(f"invalid preset {args.preset!r} (choose from {', '.join(presets)})")

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print("=" * 60)
    print(" KAN-JEPA SOTA GPU Training")
    print("=" * 60)
    env_info = detect_env()
    print(f" Env: {env_info['env']}")
    print(f" Device: {env_info['device']}")
    if env_info["gpu_name"]:
        print(f" GPU: {env_info['gpu_name']} ({env_info['vram_gb']:.1f} GB)")
    print(f" Mixed: {env_info['mixed_precision']}")
    print(f" Preset: {args.preset}")
    print(f" Epochs: {args.epochs}")
    print(f" Pairs: {args.n_pairs}")

    # Must run before train(), which imports from the `training` package.
    ensure_imports()
    train(args, env_info)
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()