#!/usr/bin/env python3
"""Quick ablation: baseline → scaled data → scaled epochs.
Forces CPU to avoid MPS abort on macOS."""
import json, sys, time, random, warnings, os
warnings.filterwarnings('ignore')
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
import torch
def _detect_device() -> str:
if torch.cuda.is_available():
return "cuda"
return "cpu"
DEVICE = _detect_device()
import numpy as np
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from training.leaderboard_data import LeaderboardPair
from training.leaderboard_generator import KANJEPAStrategy
from training.core.generative_flywheel import score_generation
def load_pairs(n, data_path='training/kan_bench_results/sota_training_data.json',
               seed=42):
    """Load up to *n* shuffled question/gold pairs from a JSON results file.

    Args:
        n: Maximum number of records to consider after shuffling.
        data_path: Path to a JSON list of records carrying either
            'question'/'gold' or 'source'/'target' keys.  Defaults to the
            SOTA training data used by this ablation.
        seed: Shuffle seed; fixed by default so train/eval splits are
            reproducible across runs.  NOTE: this reseeds the *global*
            ``random`` RNG, matching the original behavior.

    Returns:
        list[LeaderboardPair]: at most *n* pairs; records missing either
        field are skipped, so fewer than *n* may be returned.
    """
    raw = json.loads(Path(data_path).read_text())
    random.seed(seed)
    random.shuffle(raw)
    pairs = []
    for i, r in enumerate(raw[:n]):
        # Support both naming schemes found in the data files.
        q = r.get('question', r.get('source', ''))
        g = r.get('gold', r.get('target', ''))
        if q and g:
            pairs.append(LeaderboardPair(
                question=q, gold=g, instance_id=f's{i}',
                benchmark='text2cypher', difficulty='medium'))
    return pairs
def run_config(name, n_pairs, epochs, d_model=128, preset='default', n_eval=15):
    """Train one ablation configuration and greedily evaluate it.

    Args:
        name: Label used in printed output and the returned result dict.
        n_pairs: Total pairs to load; split 80/20 into train/eval.
        epochs: Number of training epochs forwarded to the strategy.
        d_model: Model width forwarded to KANJEPAStrategy.
        preset: Strategy preset name forwarded to KANJEPAStrategy.
        n_eval: Maximum number of eval pairs to score.

    Returns:
        dict: mean metrics (bleu4, rouge_l, token_accuracy, exact_match,
        chrf, composite), the fraction of non-zero BLEU scores, plus
        bookkeeping fields (name, n_train, n_params, train_time_s).
    """
    pairs = load_pairs(n_pairs)
    # 80/20 split; shuffling already happened (seeded) inside load_pairs.
    n_train = int(len(pairs) * 0.8)
    train_pairs, eval_pairs = pairs[:n_train], pairs[n_train:]
    print(f'\n{"="*60}')
    print(f'{name}: {n_train} train, {len(eval_pairs)} eval, {epochs} ep, d={d_model}, preset={preset}')
    print(f'{"="*60}')
    strategy = KANJEPAStrategy(d_model=d_model, preset=preset)
    vocab = strategy.build_vocab(pairs)
    model = strategy.build_model(vocab, DEVICE)
    n_params = sum(p.numel() for p in model.parameters())
    print(f'Vocab: {len(vocab.idx2tok)}, Params: {n_params:,}', flush=True)
    t0 = time.time()
    # NOTE(review): `info` is never used afterwards — kept for parity with
    # the strategy API; confirm whether it should feed into the result dict.
    info = strategy.train(model, train_pairs, vocab, epochs=epochs, device=DEVICE)
    train_time = time.time() - t0
    print(f'Training: {train_time:.1f}s', flush=True)
    # Fast greedy eval
    print(f'Evaluating {min(n_eval, len(eval_pairs))} pairs...', flush=True)
    model.eval()
    # Per-metric score lists, averaged at the end.
    metrics = {'bleu4': [], 'rouge_l': [], 'token_accuracy': [],
               'exact_match': [], 'chrf': [], 'composite': []}
    for i, p in enumerate(eval_pairs[:n_eval]):
        # Truncate source to 64 tokens; fall back to BOS/EOS on empty encode.
        encoded = vocab.encode(p.question)[:64]
        if not encoded:
            encoded = [vocab.BOS, vocab.EOS]
        # NOTE(review): `src` stays on CPU — if DEVICE resolved to 'cuda'
        # this may mismatch the model's device unless generate_autoregressive
        # moves it internally; confirm.
        src = torch.tensor([encoded], dtype=torch.long)
        with torch.no_grad():
            # temperature=0.0 → greedy decoding; confidence is unused here.
            text, conf, _ = model.generate_autoregressive(
                src, kan_features=None, vocab=vocab, max_len=64, temperature=0.0)
        ms = score_generation(p.gold, text)
        metrics['bleu4'].append(ms.bleu4)
        metrics['rouge_l'].append(ms.rouge_l)
        metrics['token_accuracy'].append(ms.token_accuracy)
        metrics['exact_match'].append(ms.exact_match)
        metrics['chrf'].append(ms.chrf)
        metrics['composite'].append(ms.composite)
        # Print the first few predictions for a qualitative spot-check.
        if i < 3:
            print(f' [{i}] BLEU={ms.bleu4:.3f} comp={ms.composite:.3f}', flush=True)
            print(f' pred={text[:80]}', flush=True)
            print(f' gold={p.gold[:80]}', flush=True)
    result = {k: float(np.mean(v)) for k, v in metrics.items()}
    # Fraction of eval pairs with any n-gram overlap; max() guards n_eval=0.
    result['non_zero_bleu'] = sum(1 for b in metrics['bleu4'] if b > 0) / max(len(metrics['bleu4']), 1)
    result['name'] = name
    result['n_train'] = n_train
    result['n_params'] = n_params
    result['train_time_s'] = train_time
    print(f'\n--- {name} ---')
    for k in ['bleu4', 'rouge_l', 'token_accuracy', 'exact_match', 'chrf', 'composite']:
        print(f' {k:<16}: {result[k]*100:.2f}%')
    print(f' non_zero_bleu : {result["non_zero_bleu"]*100:.1f}%')
    return result
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='baseline',
                        choices=['baseline', 'scaled', 'all'])
    args = parser.parse_args()
    out_path = Path('training/kan_bench_results/quick_ablation.json')
    results = []
    # Each entry is (name, n_pairs, epochs, d_model, preset).
    configs = {
        # Keep within 8-min CPU budget (488s for 100 pairs × 49ep)
        'baseline': [('small_100_49ep', 100, 49, 64, 'small')],
        'scaled': [('small_200_20ep', 200, 20, 64, 'small')],
        'all': [
            ('small_100_20ep', 100, 20, 64, 'small'),
            ('small_200_20ep', 200, 20, 64, 'small'),
            ('small_100_49ep', 100, 49, 64, 'small'),
        ],
    }
    for name, n, ep, d, preset in configs[args.config]:
        result = run_config(name, n_pairs=n, epochs=ep, d_model=d, preset=preset)
        results.append(result)
        # Save after each config so we don't lose results
        out_path.write_text(json.dumps(results, indent=2))
        print(f'[Saved intermediate to {out_path}]', flush=True)
    # Summary
    print(f'\n{"="*80}')
    print(f'{"Config":<22} {"Train":>6} {"BLEU":>8} {"ROUGE":>8} {"TokAcc":>8} {"EM":>8} {"ChrF":>8} {"Comp":>8}')
    print(f'{"-"*22} {"-"*6} {"-"*8} {"-"*8} {"-"*8} {"-"*8} {"-"*8} {"-"*8}')
    for r in results:
        print(f'{r["name"]:<22} {r["n_train"]:>6} '
              f'{r["bleu4"]*100:>7.2f}% {r["rouge_l"]*100:>7.2f}% '
              f'{r["token_accuracy"]*100:>7.2f}% {r["exact_match"]*100:>7.2f}% '
              f'{r["chrf"]*100:>7.2f}% {r["composite"]*100:>7.2f}%')
    print(f'{"="*80}')