#!/usr/bin/env python3
"""Quick ablation: baseline → scaled data → scaled epochs.
Forces CPU to avoid MPS abort on macOS."""
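# Usage (a sketch; the --config choices match the argparse flags defined below):
#   python scripts/_quick_baseline.py --config baseline   # single 100-pair run
#   python scripts/_quick_baseline.py --config all        # full three-way ablation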
import json, sys, time, random, warnings, os

warnings.filterwarnings('ignore')

# Set MPS-related env vars before importing torch so they take effect.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
import torch


def _detect_device() -> str:
    """Prefer CUDA; deliberately skip MPS, which aborts on macOS here."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


DEVICE = _detect_device()
import numpy as np
from pathlib import Path

# Make the repo root importable so the `training` package resolves.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from training.leaderboard_data import LeaderboardPair
from training.leaderboard_generator import KANJEPAStrategy
from training.core.generative_flywheel import score_generation
def load_pairs(n):
    """Load up to n shuffled (question, gold) pairs from the SOTA data dump."""
    raw = json.loads(Path('training/kan_bench_results/sota_training_data.json').read_text())
    random.seed(42)  # fixed seed: every config sees the same shuffle/split
    random.shuffle(raw)
    pairs = []
    for i, r in enumerate(raw[:n]):
        q = r.get('question', r.get('source', ''))
        g = r.get('gold', r.get('target', ''))
        if q and g:
            pairs.append(LeaderboardPair(
                question=q, gold=g, instance_id=f's{i}',
                benchmark='text2cypher', difficulty='medium'))
    return pairs
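# Expected record shape in sota_training_data.json (inferred from the .get()
# fallbacks above; the exact schema is an assumption):
#   {"question": "...", "gold": "..."}  or  {"source": "...", "target": "..."}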
def run_config(name, n_pairs, epochs, d_model=128, preset='default', n_eval=15):
    pairs = load_pairs(n_pairs)
    n_train = int(len(pairs) * 0.8)  # 80/20 train/eval split
    train_pairs, eval_pairs = pairs[:n_train], pairs[n_train:]
    print(f'\n{"="*60}')
    print(f'{name}: {n_train} train, {len(eval_pairs)} eval, {epochs} ep, d={d_model}, preset={preset}')
    print(f'{"="*60}')
    strategy = KANJEPAStrategy(d_model=d_model, preset=preset)
    vocab = strategy.build_vocab(pairs)
    model = strategy.build_model(vocab, DEVICE)
    n_params = sum(p.numel() for p in model.parameters())
    print(f'Vocab: {len(vocab.idx2tok)}, Params: {n_params:,}', flush=True)
    t0 = time.time()
    strategy.train(model, train_pairs, vocab, epochs=epochs, device=DEVICE)
    train_time = time.time() - t0
    print(f'Training: {train_time:.1f}s', flush=True)
    # Fast greedy eval: temperature=0.0 gives deterministic decoding.
    print(f'Evaluating {min(n_eval, len(eval_pairs))} pairs...', flush=True)
    model.eval()
    metrics = {'bleu4': [], 'rouge_l': [], 'token_accuracy': [],
               'exact_match': [], 'chrf': [], 'composite': []}
    for i, p in enumerate(eval_pairs[:n_eval]):
        encoded = vocab.encode(p.question)[:64]  # truncate source to 64 tokens
        if not encoded:
            encoded = [vocab.BOS, vocab.EOS]  # empty encoding: fall back to BOS/EOS
        src = torch.tensor([encoded], dtype=torch.long)
        with torch.no_grad():
            text, conf, _ = model.generate_autoregressive(
                src, kan_features=None, vocab=vocab, max_len=64, temperature=0.0)
        ms = score_generation(p.gold, text)
        metrics['bleu4'].append(ms.bleu4)
        metrics['rouge_l'].append(ms.rouge_l)
        metrics['token_accuracy'].append(ms.token_accuracy)
        metrics['exact_match'].append(ms.exact_match)
        metrics['chrf'].append(ms.chrf)
        metrics['composite'].append(ms.composite)
        if i < 3:  # spot-check the first few generations
            print(f'  [{i}] BLEU={ms.bleu4:.3f} comp={ms.composite:.3f}', flush=True)
            print(f'      pred={text[:80]}', flush=True)
            print(f'      gold={p.gold[:80]}', flush=True)
    result = {k: float(np.mean(v)) for k, v in metrics.items()}
    result['non_zero_bleu'] = sum(1 for b in metrics['bleu4'] if b > 0) / max(len(metrics['bleu4']), 1)
    result['name'] = name
    result['n_train'] = n_train
    result['n_params'] = n_params
    result['train_time_s'] = train_time
    print(f'\n--- {name} ---')
    for k in ['bleu4', 'rouge_l', 'token_accuracy', 'exact_match', 'chrf', 'composite']:
        print(f'  {k:<16}: {result[k]*100:.2f}%')
    print(f'  {"non_zero_bleu":<16}: {result["non_zero_bleu"]*100:.1f}%')
    return result
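# Example direct call (hypothetical parameters; the __main__ block below drives
# this via --config instead):
#   run_config('small_50_5ep', n_pairs=50, epochs=5, d_model=64, preset='small', n_eval=5)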
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='baseline',
                        choices=['baseline', 'scaled', 'all'])
    args = parser.parse_args()
    out_path = Path('training/kan_bench_results/quick_ablation.json')
    results = []
    # Each entry: (name, n_pairs, epochs, d_model, preset).
    configs = {
        # Keep within the 8-min CPU budget (488s for 100 pairs × 49 epochs).
        'baseline': [('small_100_49ep', 100, 49, 64, 'small')],
        'scaled': [('small_200_20ep', 200, 20, 64, 'small')],
        'all': [
            ('small_100_20ep', 100, 20, 64, 'small'),
            ('small_200_20ep', 200, 20, 64, 'small'),
            ('small_100_49ep', 100, 49, 64, 'small'),
        ],
    }
    for name, n, ep, d, preset in configs[args.config]:
        result = run_config(name, n_pairs=n, epochs=ep, d_model=d, preset=preset)
        results.append(result)
        # Save after each config so a crash doesn't lose earlier results.
        out_path.write_text(json.dumps(results, indent=2))
        print(f'[Saved intermediate to {out_path}]', flush=True)
    # Summary table (one row per config).
    print(f'\n{"="*80}')
    print(f'{"Config":<22} {"Train":>6} {"BLEU":>8} {"ROUGE":>8} {"TokAcc":>8} {"EM":>8} {"ChrF":>8} {"Comp":>8}')
    print(f'{"-"*22} {"-"*6} {"-"*8} {"-"*8} {"-"*8} {"-"*8} {"-"*8} {"-"*8}')
    for r in results:
        print(f'{r["name"]:<22} {r["n_train"]:>6} '
              f'{r["bleu4"]*100:>7.2f}% {r["rouge_l"]*100:>7.2f}% '
              f'{r["token_accuracy"]*100:>7.2f}% {r["exact_match"]*100:>7.2f}% '
              f'{r["chrf"]*100:>7.2f}% {r["composite"]*100:>7.2f}%')
    print(f'{"="*80}')