# NOTE(review): removed stray build-log residue ("Spaces:" / "Build error")
# that leaked into the file above the shebang — it is not Python and would
# break the script if left in place.
#!/usr/bin/env python3
"""Quick ablation: baseline → scaled data → scaled epochs.
Forces CPU to avoid MPS abort on macOS."""
import json, sys, time, random, warnings, os

# Silence library warnings so the ablation report stays readable.
warnings.filterwarnings('ignore')
# These must be set BEFORE `import torch`: allow unsupported MPS ops to
# fall back to CPU, and disable the MPS high-watermark memory cap
# (macOS workaround — see module docstring).
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
import torch
def _detect_device() -> str:
    """Return the compute device name: ``"cuda"`` when available, else ``"cpu"``.

    MPS is deliberately never returned — this script forces CPU on macOS
    to avoid an MPS abort (see the module docstring).
    """
    return "cuda" if torch.cuda.is_available() else "cpu"
# Resolved once at import time; passed to model construction and training.
DEVICE = _detect_device()
import numpy as np
from pathlib import Path
# Make the repository root importable so the `training` package resolves
# when this script is run from inside its own directory.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from training.leaderboard_data import LeaderboardPair
from training.leaderboard_generator import KANJEPAStrategy
from training.core.generative_flywheel import score_generation
def load_pairs(n):
    """Load up to *n* shuffled (question, gold) pairs from the SOTA data file.

    The records are shuffled with a fixed seed for reproducibility, then
    valid records (non-empty question AND gold) are collected until *n*
    pairs are found.  The previous version sliced ``raw[:n]`` *before*
    filtering, so any invalid record in the slice silently shrank the
    returned set below *n*; filtering first guarantees *n* usable pairs
    whenever the file contains that many.

    Args:
        n: maximum number of pairs to return.

    Returns:
        list[LeaderboardPair] of length ``min(n, #valid records)``.
    """
    raw = json.loads(Path('training/kan_bench_results/sota_training_data.json').read_text())
    random.seed(42)  # fixed seed so every config sees the same shuffle/split
    random.shuffle(raw)
    pairs = []
    for r in raw:
        if len(pairs) >= n:
            break
        # Accept either field-naming scheme: question/gold or source/target.
        q = r.get('question', r.get('source', ''))
        g = r.get('gold', r.get('target', ''))
        if q and g:
            pairs.append(LeaderboardPair(
                question=q, gold=g, instance_id=f's{len(pairs)}',
                benchmark='text2cypher', difficulty='medium'))
    return pairs
def run_config(name, n_pairs, epochs, d_model=128, preset='default', n_eval=15):
    """Train one ablation configuration and score it with greedy decoding.

    Args:
        name: label used in printed reports and in the saved JSON result.
        n_pairs: number of pairs to load; split 80/20 into train/eval.
        epochs: training epochs forwarded to the strategy.
        d_model: model width forwarded to KANJEPAStrategy.
        preset: strategy preset name forwarded to KANJEPAStrategy.
        n_eval: cap on how many eval pairs are scored (keeps eval fast).

    Returns:
        dict of mean generation metrics plus run metadata
        (name, n_train, n_params, train_time_s, non_zero_bleu).
    """
    pairs = load_pairs(n_pairs)
    # 80/20 train/eval split.
    # NOTE(review): the vocab below is built on ALL pairs, including the
    # eval split — token leakage; presumably acceptable for a quick
    # ablation, but confirm before citing these numbers.
    n_train = int(len(pairs) * 0.8)
    train_pairs, eval_pairs = pairs[:n_train], pairs[n_train:]
    print(f'\n{"="*60}')
    print(f'{name}: {n_train} train, {len(eval_pairs)} eval, {epochs} ep, d={d_model}, preset={preset}')
    print(f'{"="*60}')
    strategy = KANJEPAStrategy(d_model=d_model, preset=preset)
    vocab = strategy.build_vocab(pairs)
    model = strategy.build_model(vocab, DEVICE)
    n_params = sum(p.numel() for p in model.parameters())
    print(f'Vocab: {len(vocab.idx2tok)}, Params: {n_params:,}', flush=True)
    t0 = time.time()
    info = strategy.train(model, train_pairs, vocab, epochs=epochs, device=DEVICE)
    train_time = time.time() - t0
    print(f'Training: {train_time:.1f}s', flush=True)
    # Fast greedy eval
    print(f'Evaluating {min(n_eval, len(eval_pairs))} pairs...', flush=True)
    model.eval()
    metrics = {'bleu4': [], 'rouge_l': [], 'token_accuracy': [],
               'exact_match': [], 'chrf': [], 'composite': []}
    for i, p in enumerate(eval_pairs[:n_eval]):
        # Truncate the encoded question to 64 tokens; fall back to a
        # BOS/EOS pair so generation never receives a zero-length source.
        encoded = vocab.encode(p.question)[:64]
        if not encoded:
            encoded = [vocab.BOS, vocab.EOS]
        src = torch.tensor([encoded], dtype=torch.long)
        with torch.no_grad():
            # temperature=0.0 → greedy (deterministic) decoding.
            text, conf, _ = model.generate_autoregressive(
                src, kan_features=None, vocab=vocab, max_len=64, temperature=0.0)
        ms = score_generation(p.gold, text)
        metrics['bleu4'].append(ms.bleu4)
        metrics['rouge_l'].append(ms.rouge_l)
        metrics['token_accuracy'].append(ms.token_accuracy)
        metrics['exact_match'].append(ms.exact_match)
        metrics['chrf'].append(ms.chrf)
        metrics['composite'].append(ms.composite)
        if i < 3:
            # Print the first few predictions as a quick sanity check.
            print(f' [{i}] BLEU={ms.bleu4:.3f} comp={ms.composite:.3f}', flush=True)
            print(f' pred={text[:80]}', flush=True)
            print(f' gold={p.gold[:80]}', flush=True)
    result = {k: float(np.mean(v)) for k, v in metrics.items()}
    # Fraction of eval pairs with any 4-gram overlap; max(..., 1) guards
    # against division by zero when the eval set is empty.
    result['non_zero_bleu'] = sum(1 for b in metrics['bleu4'] if b > 0) / max(len(metrics['bleu4']), 1)
    result['name'] = name
    result['n_train'] = n_train
    result['n_params'] = n_params
    result['train_time_s'] = train_time
    print(f'\n--- {name} ---')
    for k in ['bleu4', 'rouge_l', 'token_accuracy', 'exact_match', 'chrf', 'composite']:
        print(f' {k:<16}: {result[k]*100:.2f}%')
    print(f' non_zero_bleu : {result["non_zero_bleu"]*100:.1f}%')
    return result
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='baseline',
                        choices=['baseline', 'scaled', 'all'])
    args = parser.parse_args()
    out_path = Path('training/kan_bench_results/quick_ablation.json')
    results = []
    # Each config entry is (name, n_pairs, epochs, d_model, preset).
    configs = {
        # Keep within 8-min CPU budget (488s for 100 pairs × 49ep)
        'baseline': [('small_100_49ep', 100, 49, 64, 'small')],
        'scaled': [('small_200_20ep', 200, 20, 64, 'small')],
        'all': [
            ('small_100_20ep', 100, 20, 64, 'small'),
            ('small_200_20ep', 200, 20, 64, 'small'),
            ('small_100_49ep', 100, 49, 64, 'small'),
        ],
    }
    for name, n, ep, d, preset in configs[args.config]:
        result = run_config(name, n_pairs=n, epochs=ep, d_model=d, preset=preset)
        results.append(result)
        # Save after each config so we don't lose results
        out_path.write_text(json.dumps(results, indent=2))
        print(f'[Saved intermediate to {out_path}]', flush=True)
    # Summary
    print(f'\n{"="*80}')
    print(f'{"Config":<22} {"Train":>6} {"BLEU":>8} {"ROUGE":>8} {"TokAcc":>8} {"EM":>8} {"ChrF":>8} {"Comp":>8}')
    print(f'{"-"*22} {"-"*6} {"-"*8} {"-"*8} {"-"*8} {"-"*8} {"-"*8} {"-"*8}')
    for r in results:
        print(f'{r["name"]:<22} {r["n_train"]:>6} '
              f'{r["bleu4"]*100:>7.2f}% {r["rouge_l"]*100:>7.2f}% '
              f'{r["token_accuracy"]*100:>7.2f}% {r["exact_match"]*100:>7.2f}% '
              f'{r["chrf"]*100:>7.2f}% {r["composite"]*100:>7.2f}%')
    print(f'{"="*80}')