File size: 11,004 Bytes
60ddd88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
"""
Train H4 cross-encoder reranker on SQuAD.

Uses the PPL 10.0 TinyStories checkpoint as backbone.
Fine-tunes on binary classification: does this passage answer this question?

For each SQuAD example:
  - Positive: [question SEP correct_passage] -> label 1
  - Negative: [question SEP wrong_passage] -> label 0

The H4 attention heads directly attend from question tokens to passage tokens
within the same sequence — this is why cross-encoders beat bi-encoders.

Pipeline integration:
  1. Bi-encoder retrieves top-5 (R@5=100%, 20ms)
  2. Cross-encoder reranks 5 candidates (5 forward passes, ~50ms)
  3. Return highest-scoring → R@1 should reach 80-90%+
"""

import os
import math
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

from rag.cross_encoder import H4CrossEncoder
from rag.prepare_qa import download_squad_dev
from rag.tokenizer import BPETokenizer

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------

# Training budget in seconds, compared against accumulated step time in main()
# (eval passes are not counted). Override with the CE_TIME environment variable.
TIME_BUDGET = int(os.environ.get('CE_TIME', 3600))

# Backbone shape. Presumably has to agree with the checkpoint at CHECKPOINT_PATH
# (NOTE(review): load_lm_backbone is not shown here -- confirm).
D_MODEL, N_HEADS, N_LAYERS = 512, 8, 8
USE_BITLINEAR = True

# Optimisation. LR is the full fine-tuning rate; the head-only warmup phase in
# main() runs at 10x this value.
LR = 5e-4
WEIGHT_DECAY = 1e-2
GRAD_CLIP = 1.0

# Batching / evaluation.
BATCH_SIZE = 8       # questions per step; each yields one positive and one negative pair
MAX_SEQ_LEN = 192    # total token budget for [question SEP passage]
EVAL_INTERVAL = 100  # steps between validation passes

# Pre-trained language-model backbone, resolved relative to this file.
CHECKPOINT_PATH = os.path.join(
    os.path.dirname(__file__), '..', '..', 'checkpoints', 'h4_fullscale_final.pt'
)


def pad_tokens(ids, max_len):
    """Truncate ``ids`` to ``max_len`` tokens and right-pad with 0 up to ``max_len``.

    Token id 0 is used as padding; presumably it is the tokenizer's pad id
    (NOTE(review): confirm against rag.tokenizer.BPETokenizer).
    """
    ids = ids[:max_len]
    return ids + [0] * (max_len - len(ids))


def make_pair(tokenizer, question, passage, max_len, sep_id=2):
    """Encode ``[question SEP passage]`` as one fixed-length list of token ids.

    The question is capped at ``max_len // 3`` tokens. The passage then gets
    every remaining position (minus one for the separator), so a short question
    no longer leaves budget as padding while the passage is cut early.

    Args:
        tokenizer: object exposing ``encode(str) -> list[int]``.
        question: query text.
        passage: candidate passage text.
        max_len: length of the returned sequence.
        sep_id: separator token id. Defaults to 2, the value this script has
            always used (NOTE(review): confirm it matches the tokenizer's SEP).

    Returns:
        list[int] of exactly ``max_len`` ids (for ``max_len >= 0``), right-padded with 0.
    """
    q_ids = tokenizer.encode(question)[:max_len // 3]
    # The passage inherits whatever the (possibly short) question left unused.
    # Clamp at 0: a negative bound would slice tokens off the passage's *end*
    # instead of keeping none.
    passage_budget = max(max_len - len(q_ids) - 1, 0)
    p_ids = tokenizer.encode(passage)[:passage_budget]
    return pad_tokens(q_ids + [sep_id] + p_ids, max_len)


# Seed for validation sampling. Every eval pass rebuilds its RNG from this, so
# all passes score the same negatives (numbers comparable across steps) and
# evaluation never consumes the global ``random`` stream that drives training.
_EVAL_SEED = 1234


def _distinct_contexts(data):
    """Return the unique ``'context'`` strings in ``data``, in first-seen order."""
    return list(dict.fromkeys(qa['context'] for qa in data))


def _sample_negative_context(pool, context, rng):
    """Draw a passage from ``pool`` (list of QA dicts) that differs from ``context``.

    Rejection-samples ``pool`` entries, so a passage shared by several
    questions is drawn proportionally more often. This is the same
    distribution the original training sampler used. The caller must
    guarantee ``pool`` holds at least two distinct contexts, otherwise this
    never returns.
    """
    while True:
        candidate = rng.choice(pool)['context']
        if candidate != context:
            return candidate


def _encode_pairs(tokenizer, batch_qa, pool, rng):
    """Build one positive and one negative pair for each question in ``batch_qa``.

    Rows alternate positive (label 1.0) and negative (label 0.0). Negatives come
    from ``pool`` and never reuse the question's own passage.

    Returns:
        (input_ids, labels): a long tensor with one ``MAX_SEQ_LEN``-id row per
        pair, and a float32 label tensor.
    """
    input_ids, labels = [], []
    for qa in batch_qa:
        input_ids.append(make_pair(tokenizer, qa['question'], qa['context'], MAX_SEQ_LEN))
        labels.append(1.0)
        negative = _sample_negative_context(pool, qa['context'], rng)
        input_ids.append(make_pair(tokenizer, qa['question'], negative, MAX_SEQ_LEN))
        labels.append(0.0)
    return (torch.tensor(input_ids, dtype=torch.long),
            torch.tensor(labels, dtype=torch.float32))


def _binary_accuracy(model, tokenizer, data, limit, rng):
    """Return (correct, total) for pos/neg classification of the first ``limit`` questions.

    Intended to be called under ``torch.no_grad()`` with the model in eval mode.
    """
    subset = data[:limit]
    correct = total = 0
    for start in range(0, len(subset), BATCH_SIZE):
        ids, labels = _encode_pairs(tokenizer, subset[start:start + BATCH_SIZE], data, rng)
        preds = (model(ids) > 0).float()
        correct += (preds == labels).sum().item()
        total += len(labels)
    return correct, total


def _rerank_hits(model, tokenizer, data, limit, rng, n_negatives=4):
    """Return (hits, total) for R@1 reranking over the first ``limit`` questions.

    Each question's own passage is scored against up to ``n_negatives``
    *distinct* other passages from ``data``. A hit requires the correct
    passage to score strictly above every negative. With ``argmax``, a tie
    resolves to index 0, where the correct passage sits, which would inflate
    R@1 for a model that cannot tell passages apart. Questions with no
    available negative are skipped.

    Intended to be called under ``torch.no_grad()`` with the model in eval mode.
    """
    contexts = _distinct_contexts(data)
    hits = total = 0
    for qa in data[:limit]:
        others = [c for c in contexts if c != qa['context']]
        negatives = rng.sample(others, min(n_negatives, len(others)))
        if not negatives:
            continue
        batch = [make_pair(tokenizer, qa['question'], passage, MAX_SEQ_LEN)
                 for passage in [qa['context'], *negatives]]
        scores = model(torch.tensor(batch, dtype=torch.long))
        if scores[0].item() > scores[1:].max().item():
            hits += 1
        total += 1
    return hits, total


def main():
    """Fine-tune H4CrossEncoder as a question/passage relevance classifier on SQuAD.

    Phase 1 trains only the score head (backbone frozen, LR * 10) for
    ``UNFREEZE_STEP`` steps. Phase 2 unfreezes the backbone and trains everything
    at ``LR``. Training stops once accumulated step time (eval excluded)
    reaches ``TIME_BUDGET``. The script then reports rerank R@1 on held-out
    questions and saves a checkpoint.
    """
    t_start = time.time()
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)

    # Load SQuAD
    squad = download_squad_dev()
    if len(squad) < 100:
        print("SQuAD not available. Run: python rag/prepare_qa.py")
        return
    print(f"SQuAD: {len(squad)} QA pairs")

    # Build BPE tokenizer.
    # NOTE(review): the vocab is fit on squad[:2000] *before* the train/val
    # split, so validation text can shape the vocabulary (not the labels).
    tokenizer = BPETokenizer(max_vocab=8192)
    all_texts = [qa['context'] + ' ' + qa['question'] for qa in squad[:2000]]
    tokenizer.build_vocab(all_texts)

    # Split. Cap validation at 20% so a small dataset (100-200 pairs passes the
    # check above) still leaves training data. With 1000+ pairs this is the
    # original fixed 200.
    indices = list(range(len(squad)))
    random.shuffle(indices)
    n_val = min(200, len(squad) // 5)
    train_data = [squad[i] for i in indices[n_val:]]
    val_data = [squad[i] for i in indices[:n_val]]
    print(f"Train: {len(train_data)}, Val: {len(val_data)}")
    # Negative sampling rejects the question's own passage, so each split needs
    # at least two distinct passages or the sampler could never terminate.
    if len(_distinct_contexts(train_data)) < 2 or len(_distinct_contexts(val_data)) < 2:
        print("Need at least two distinct passages in both train and val splits.")
        return

    # Create cross-encoder
    model = H4CrossEncoder(
        vocab_size=tokenizer.vocab_size,
        d_model=D_MODEL,
        n_heads=N_HEADS,
        n_layers=N_LAYERS,
        use_bitlinear=USE_BITLINEAR,
        max_seq_len=MAX_SEQ_LEN,
    )
    print(f"Model: {model.count_params():,} params")

    # Load pre-trained backbone
    if os.path.exists(CHECKPOINT_PATH):
        model.load_lm_backbone(CHECKPOINT_PATH)
        print(f"Loaded backbone from {CHECKPOINT_PATH}")
    else:
        print(f"No checkpoint at {CHECKPOINT_PATH}, training from scratch")

    # Phase 1: freeze the backbone and train only the score head; the backbone
    # is unfrozen at UNFREEZE_STEP below.
    for param in model.lm.parameters():
        param.requires_grad = False
    trainable_head = sum(p.numel() for p in model.score_head.parameters())
    print(f"Phase 1: training score head only ({trainable_head:,} params)")

    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LR * 10,  # higher LR for head-only phase
        weight_decay=WEIGHT_DECAY,
    )

    # Training loop
    model.train()
    step = 0
    total_training_time = 0.0
    best_acc = 0.0
    unfrozen = False
    UNFREEZE_STEP = 200

    print(f"\nTraining for {TIME_BUDGET}s")
    print(f"{'step':>6} {'loss':>8} {'acc':>8} {'val_acc':>8} {'phase':>10}")
    print("-" * 48)

    while True:
        t0 = time.time()

        # Phase 2: unfreeze the backbone. A fresh optimizer is built because
        # the phase-1 one only tracks the head parameters.
        if step == UNFREEZE_STEP and not unfrozen:
            for param in model.lm.parameters():
                param.requires_grad = True
            unfrozen = True
            total_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
            print(f"\n  Phase 2: unfreezing backbone ({total_trainable:,} trainable params)")
            optimizer = torch.optim.AdamW(
                model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.95))

        # Sample batch: one positive and one negative pair per question. The
        # ``random`` module itself is the RNG, as in the seeded setup above.
        batch_qa = random.sample(train_data, min(BATCH_SIZE, len(train_data)))
        input_ids, labels = _encode_pairs(tokenizer, batch_qa, train_data, random)

        # Forward
        scores = model(input_ids)
        loss = F.binary_cross_entropy_with_logits(scores, labels)

        # Accuracy
        with torch.no_grad():
            preds = (scores > 0).float()
            acc = (preds == labels).float().mean().item()

        optimizer.zero_grad()
        loss.backward()
        if GRAD_CLIP > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()

        # The first few steps are excluded from the budget (warm-up overhead).
        dt = time.time() - t0
        if step > 2:
            total_training_time += dt

        # Eval (not counted toward the time budget)
        if step % EVAL_INTERVAL == 0:
            model.eval()
            eval_rng = random.Random(_EVAL_SEED)
            with torch.no_grad():
                val_correct, val_total = _binary_accuracy(
                    model, tokenizer, val_data, 100, eval_rng)
                val_r1, val_r1_total = _rerank_hits(
                    model, tokenizer, val_data, 50, eval_rng)

            val_acc = val_correct / max(val_total, 1)
            rerank_r1 = val_r1 / max(val_r1_total, 1)
            best_acc = max(best_acc, val_acc)

            phase = "head-only" if not unfrozen else "full"
            print(f"{step:6d} {loss.item():8.4f} {acc:8.3f} {val_acc:8.3f} {phase:>10}"
                  f"  rerank_R@1={rerank_r1:.3f}")
            model.train()

        step += 1
        if step > 2 and total_training_time >= TIME_BUDGET:
            break

    # Final evaluation
    model.eval()
    print("\n" + "=" * 60)
    print("FINAL CROSS-ENCODER EVALUATION:")

    with torch.no_grad():
        final_r1, final_total = _rerank_hits(
            model, tokenizer, val_data, 100, random.Random(_EVAL_SEED))

    rerank_r1 = final_r1 / max(final_total, 1)

    print(f"  Rerank R@1 (top-5): {rerank_r1:.1%} ({final_r1}/{final_total})")
    print(f"  Best binary acc: {best_acc:.1%}")
    print("=" * 60)

    print("\n---")
    print(f"rerank_r1:        {rerank_r1:.4f}")
    print(f"best_binary_acc:  {best_acc:.4f}")
    print(f"training_seconds: {total_training_time:.1f}")
    print(f"total_seconds:    {time.time() - t_start:.1f}")
    print(f"num_steps:        {step}")
    print(f"num_params:       {model.count_params()}")
    print(f"ternary:          {'yes' if USE_BITLINEAR else 'no'}")

    # Save checkpoint (final weights, not the best-val ones).
    # NOTE(review): this path is relative to the current working directory,
    # while CHECKPOINT_PATH is resolved relative to this file. The two only
    # coincide when the script is launched from the matching directory.
    os.makedirs('checkpoints', exist_ok=True)
    ckpt_path = os.path.join('checkpoints', 'h4_cross_encoder.pt')
    torch.save({
        'model_state': model.state_dict(),
        'rerank_r1': rerank_r1,
        'step': step,
        'config': {
            'd_model': D_MODEL, 'n_layers': N_LAYERS, 'n_heads': N_HEADS,
            'vocab_size': tokenizer.vocab_size, 'use_bitlinear': USE_BITLINEAR,
            # Needed to rebuild H4CrossEncoder with the same max_seq_len.
            'max_seq_len': MAX_SEQ_LEN,
        },
    }, ckpt_path)
    print(f"Saved: {ckpt_path}")


if __name__ == '__main__':
    # Run training only when executed as a script, not when imported.
    main()