File size: 5,485 Bytes
201cf4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#!/usr/bin/env python3
"""Quick ablation: baseline → scaled data → scaled epochs.
Forces CPU to avoid MPS abort on macOS."""
import json, sys, time, random, warnings, os
warnings.filterwarnings('ignore')
# These env vars are set *before* `import torch` below so torch reads them at
# import time. They are macOS/MPS workarounds (see module docstring):
# enable CPU fallback for unsupported MPS ops, and disable the MPS
# high-watermark memory cap.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'

import torch

def _detect_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

DEVICE = _detect_device()

import numpy as np
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from training.leaderboard_data import LeaderboardPair
from training.leaderboard_generator import KANJEPAStrategy
from training.core.generative_flywheel import score_generation


def load_pairs(n):
    """Load up to ``n`` deterministic-shuffled (question, gold) pairs.

    Fix: the original sliced ``raw[:n]`` *before* filtering out records
    with a missing question/gold, silently returning fewer than ``n``
    pairs whenever the head of the shuffled list contained invalid
    records. We now scan the whole list and stop once ``n`` valid pairs
    are collected. ``instance_id`` still uses the record's raw index, so
    IDs are unchanged when the first ``n`` records are all valid.

    Args:
        n: Maximum number of pairs to return.

    Returns:
        List of ``LeaderboardPair`` (may still be shorter than ``n`` if
        the dataset has too few valid records).
    """
    raw = json.loads(Path('training/kan_bench_results/sota_training_data.json').read_text())
    random.seed(42)  # fixed seed: identical shuffle across runs/configs
    random.shuffle(raw)
    pairs = []
    for i, r in enumerate(raw):
        if len(pairs) >= n:
            break
        # Records may use either (question, gold) or (source, target) keys.
        q = r.get('question', r.get('source', ''))
        g = r.get('gold', r.get('target', ''))
        if q and g:
            pairs.append(LeaderboardPair(
                question=q, gold=g, instance_id=f's{i}',
                benchmark='text2cypher', difficulty='medium'))
    return pairs


def run_config(name: str, n_pairs: int, epochs: int, d_model: int = 128,
               preset: str = 'default', n_eval: int = 15) -> dict:
    """Train one KAN-JEPA configuration and score it with greedy decoding.

    Args:
        name: Label used in printed output and in the returned dict.
        n_pairs: Number of pairs to load; split 80/20 into train/eval.
        epochs: Training epochs, forwarded to the strategy's ``train``.
        d_model: Model width forwarded to ``KANJEPAStrategy``.
        preset: Strategy preset name forwarded to ``KANJEPAStrategy``.
        n_eval: Max number of eval pairs actually scored.

    Returns:
        Dict of mean metrics (bleu4, rouge_l, token_accuracy, exact_match,
        chrf, composite), a ``non_zero_bleu`` fraction, plus bookkeeping
        fields: name, n_train, n_params, train_time_s.
    """
    pairs = load_pairs(n_pairs)
    # 80/20 split. NOTE: the vocab below is built on ALL pairs (train+eval).
    n_train = int(len(pairs) * 0.8)
    train_pairs, eval_pairs = pairs[:n_train], pairs[n_train:]

    print(f'\n{"="*60}')
    print(f'{name}: {n_train} train, {len(eval_pairs)} eval, {epochs} ep, d={d_model}, preset={preset}')
    print(f'{"="*60}')

    strategy = KANJEPAStrategy(d_model=d_model, preset=preset)
    vocab = strategy.build_vocab(pairs)
    model = strategy.build_model(vocab, DEVICE)
    n_params = sum(p.numel() for p in model.parameters())
    print(f'Vocab: {len(vocab.idx2tok)}, Params: {n_params:,}', flush=True)

    t0 = time.time()
    # NOTE(review): `info` is currently unused — only wall time is reported.
    info = strategy.train(model, train_pairs, vocab, epochs=epochs, device=DEVICE)
    train_time = time.time() - t0
    print(f'Training: {train_time:.1f}s', flush=True)

    # Fast greedy eval
    print(f'Evaluating {min(n_eval, len(eval_pairs))} pairs...', flush=True)
    model.eval()
    metrics = {'bleu4': [], 'rouge_l': [], 'token_accuracy': [],
               'exact_match': [], 'chrf': [], 'composite': []}

    for i, p in enumerate(eval_pairs[:n_eval]):
        # Truncate the encoded question to 64 tokens to bound decode cost.
        encoded = vocab.encode(p.question)[:64]
        if not encoded:
            # Empty encoding: fall back to a minimal BOS/EOS sequence.
            encoded = [vocab.BOS, vocab.EOS]
        src = torch.tensor([encoded], dtype=torch.long)
        with torch.no_grad():
            # temperature=0.0 → greedy (argmax) decoding.
            text, conf, _ = model.generate_autoregressive(
                src, kan_features=None, vocab=vocab, max_len=64, temperature=0.0)
        ms = score_generation(p.gold, text)
        metrics['bleu4'].append(ms.bleu4)
        metrics['rouge_l'].append(ms.rouge_l)
        metrics['token_accuracy'].append(ms.token_accuracy)
        metrics['exact_match'].append(ms.exact_match)
        metrics['chrf'].append(ms.chrf)
        metrics['composite'].append(ms.composite)
        if i < 3:
            # Show the first three predictions for a quick sanity check.
            print(f'  [{i}] BLEU={ms.bleu4:.3f} comp={ms.composite:.3f}', flush=True)
            print(f'       pred={text[:80]}', flush=True)
            print(f'       gold={p.gold[:80]}', flush=True)

    result = {k: float(np.mean(v)) for k, v in metrics.items()}
    # Fraction of eval pairs with any BLEU signal; max(...) guards n_eval=0.
    result['non_zero_bleu'] = sum(1 for b in metrics['bleu4'] if b > 0) / max(len(metrics['bleu4']), 1)
    result['name'] = name
    result['n_train'] = n_train
    result['n_params'] = n_params
    result['train_time_s'] = train_time

    print(f'\n--- {name} ---')
    for k in ['bleu4', 'rouge_l', 'token_accuracy', 'exact_match', 'chrf', 'composite']:
        print(f'  {k:<16}: {result[k]*100:.2f}%')
    print(f'  non_zero_bleu   : {result["non_zero_bleu"]*100:.1f}%')
    return result


if __name__ == '__main__':
    import argparse

    ap = argparse.ArgumentParser()
    ap.add_argument('--config', default='baseline',
                    choices=['baseline', 'scaled', 'all'])
    cli = ap.parse_args()

    out_path = Path('training/kan_bench_results/quick_ablation.json')

    # Each tuple: (name, n_pairs, epochs, d_model, preset).
    configs = {
        # Keep within 8-min CPU budget (488s for 100 pairs × 49ep)
        'baseline': [('small_100_49ep', 100, 49, 64, 'small')],
        'scaled': [('small_200_20ep', 200, 20, 64, 'small')],
        'all': [
            ('small_100_20ep', 100, 20, 64, 'small'),
            ('small_200_20ep', 200, 20, 64, 'small'),
            ('small_100_49ep', 100, 49, 64, 'small'),
        ],
    }

    results = []
    for cfg_name, n, ep, d, cfg_preset in configs[cli.config]:
        results.append(run_config(cfg_name, n_pairs=n, epochs=ep,
                                  d_model=d, preset=cfg_preset))
        # Persist after every config so a crash mid-run keeps earlier results.
        out_path.write_text(json.dumps(results, indent=2))
        print(f'[Saved intermediate to {out_path}]', flush=True)

    # Final summary table across all configs run.
    print(f'\n{"="*80}')
    print(f'{"Config":<22} {"Train":>6} {"BLEU":>8} {"ROUGE":>8} {"TokAcc":>8} {"EM":>8} {"ChrF":>8} {"Comp":>8}')
    print(f'{"-"*22} {"-"*6} {"-"*8} {"-"*8} {"-"*8} {"-"*8} {"-"*8} {"-"*8}')
    for r in results:
        print(f'{r["name"]:<22} {r["n_train"]:>6} '
              f'{r["bleu4"]*100:>7.2f}% {r["rouge_l"]*100:>7.2f}% '
              f'{r["token_accuracy"]*100:>7.2f}% {r["exact_match"]*100:>7.2f}% '
              f'{r["chrf"]*100:>7.2f}% {r["composite"]*100:>7.2f}%')
    print(f'{"="*80}')