#!/usr/bin/env python3
"""Quick ablation sweep: baseline → scaled data → scaled epochs.

Device selection prefers CUDA and otherwise falls back to CPU; MPS is
deliberately never selected, to avoid the MPS abort on macOS.
"""
import json, sys, time, random, warnings, os

warnings.filterwarnings('ignore')
# Must be set before `import torch` for the MPS knobs to take effect.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'

import torch


def _detect_device() -> str:
    # Deliberately never returns 'mps' (see module docstring).
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


DEVICE = _detect_device()

import numpy as np
from pathlib import Path

# Make the repo root importable before pulling in project modules.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from training.leaderboard_data import LeaderboardPair
from training.leaderboard_generator import KANJEPAStrategy
from training.core.generative_flywheel import score_generation


def load_pairs(n):
    """Load up to n (question, gold) pairs, shuffled with a fixed seed."""
    raw = json.loads(Path('training/kan_bench_results/sota_training_data.json').read_text())
    random.seed(42)
    random.shuffle(raw)
    pairs = []
    for i, r in enumerate(raw[:n]):
        q = r.get('question', r.get('source', ''))
        g = r.get('gold', r.get('target', ''))
        if q and g:  # skip records missing either side
            pairs.append(LeaderboardPair(
                question=q, gold=g, instance_id=f's{i}',
                benchmark='text2cypher', difficulty='medium'))
    return pairs


def run_config(name, n_pairs, epochs, d_model=128, preset='default', n_eval=15):
    pairs = load_pairs(n_pairs)
    n_train = int(len(pairs) * 0.8)  # 80/20 train/eval split
    train_pairs, eval_pairs = pairs[:n_train], pairs[n_train:]
    print(f'\n{"="*60}')
    print(f'{name}: {n_train} train, {len(eval_pairs)} eval, {epochs} ep, '
          f'd={d_model}, preset={preset}')
    print(f'{"="*60}')

    strategy = KANJEPAStrategy(d_model=d_model, preset=preset)
    vocab = strategy.build_vocab(pairs)
    model = strategy.build_model(vocab, DEVICE)
    n_params = sum(p.numel() for p in model.parameters())
    print(f'Vocab: {len(vocab.idx2tok)}, Params: {n_params:,}', flush=True)

    t0 = time.time()
    strategy.train(model, train_pairs, vocab, epochs=epochs, device=DEVICE)
    train_time = time.time() - t0
    print(f'Training: {train_time:.1f}s', flush=True)

    # Fast greedy eval (temperature=0.0 → deterministic decoding).
    print(f'Evaluating {min(n_eval, len(eval_pairs))} pairs...', flush=True)
    model.eval()
    metrics = {'bleu4': [], 'rouge_l': [], 'token_accuracy': [],
               'exact_match': [], 'chrf': [], 'composite': []}
    for i, p in enumerate(eval_pairs[:n_eval]):
        encoded = vocab.encode(p.question)[:64]  # truncate source to 64 tokens
        if not encoded:
            encoded = [vocab.BOS, vocab.EOS]
        src = torch.tensor([encoded], dtype=torch.long)
        with torch.no_grad():
            text, _conf, _ = model.generate_autoregressive(
                src, kan_features=None, vocab=vocab, max_len=64, temperature=0.0)
        ms = score_generation(p.gold, text)
        metrics['bleu4'].append(ms.bleu4)
        metrics['rouge_l'].append(ms.rouge_l)
        metrics['token_accuracy'].append(ms.token_accuracy)
        metrics['exact_match'].append(ms.exact_match)
        metrics['chrf'].append(ms.chrf)
        metrics['composite'].append(ms.composite)
        if i < 3:  # show the first few predictions for a quick sanity check
            print(f'  [{i}] BLEU={ms.bleu4:.3f} comp={ms.composite:.3f}', flush=True)
            print(f'      pred={text[:80]}', flush=True)
            print(f'      gold={p.gold[:80]}', flush=True)

    result = {k: float(np.mean(v)) for k, v in metrics.items()}
    # Fraction of eval pairs with a non-zero BLEU score.
    result['non_zero_bleu'] = (sum(1 for b in metrics['bleu4'] if b > 0)
                               / max(len(metrics['bleu4']), 1))
    result['name'] = name
    result['n_train'] = n_train
    result['n_params'] = n_params
    result['train_time_s'] = train_time

    print(f'\n--- {name} ---')
    for k in ['bleu4', 'rouge_l', 'token_accuracy', 'exact_match', 'chrf', 'composite']:
        print(f'  {k:<16}: {result[k]*100:.2f}%')
    print(f'  {"non_zero_bleu":<16}: {result["non_zero_bleu"]*100:.1f}%')
    return result
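
# Hypothetical convenience helper, not invoked by the ablation itself: a
# minimal sketch for diffing two saved runs, assuming the metric keys and
# JSON layout produced by run_config() and the __main__ loop below.
def compare_runs(path, name_a, name_b):
    runs = {r['name']: r for r in json.loads(Path(path).read_text())}
    a, b = runs[name_a], runs[name_b]
    print(f'{name_b} vs {name_a}:')
    for k in ('bleu4', 'rouge_l', 'token_accuracy', 'exact_match', 'chrf', 'composite'):
        print(f'  {k:<16}: {(b[k] - a[k])*100:+.2f} pts')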

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='baseline',
                        choices=['baseline', 'scaled', 'all'])
    args = parser.parse_args()

    out_path = Path('training/kan_bench_results/quick_ablation.json')
    results = []
    # Each entry is (name, n_pairs, epochs, d_model, preset).
    # Keep within 8-min CPU budget (488s for 100 pairs × 49 ep).
    configs = {
        'baseline': [('small_100_49ep', 100, 49, 64, 'small')],
        'scaled': [('small_200_20ep', 200, 20, 64, 'small')],
        'all': [
            ('small_100_20ep', 100, 20, 64, 'small'),
            ('small_200_20ep', 200, 20, 64, 'small'),
            ('small_100_49ep', 100, 49, 64, 'small'),
        ],
    }
    for name, n, ep, d, preset in configs[args.config]:
        result = run_config(name, n_pairs=n, epochs=ep, d_model=d, preset=preset)
        results.append(result)
        # Save after each config so we don't lose results
        out_path.write_text(json.dumps(results, indent=2))
        print(f'[Saved intermediate to {out_path}]', flush=True)

    # Summary table
    print(f'\n{"="*80}')
    print(f'{"Config":<22} {"Train":>6} {"BLEU":>8} {"ROUGE":>8} '
          f'{"TokAcc":>8} {"EM":>8} {"ChrF":>8} {"Comp":>8}')
    print(f'{"-"*22} {"-"*6} {"-"*8} {"-"*8} {"-"*8} {"-"*8} {"-"*8} {"-"*8}')
    for r in results:
        print(f'{r["name"]:<22} {r["n_train"]:>6} '
              f'{r["bleu4"]*100:>7.2f}% {r["rouge_l"]*100:>7.2f}% '
              f'{r["token_accuracy"]*100:>7.2f}% {r["exact_match"]*100:>7.2f}% '
              f'{r["chrf"]*100:>7.2f}% {r["composite"]*100:>7.2f}%')
    print(f'{"="*80}')
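
# Usage sketch (paths assume the script lives one directory below the repo
# root, matching the sys.path patch above):
#   python quick_ablation.py --config baseline   # single 100-pair / 49-epoch run
#   python quick_ablation.py --config all        # full three-config sweep
# Results accumulate in training/kan_bench_results/quick_ablation.json; that
# path can be passed to compare_runs() above to diff two completed configs.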