| """ |
| Train H4 cross-encoder reranker on SQuAD. |
| |
| Uses the PPL 10.0 TinyStories checkpoint as backbone. |
| Fine-tunes on binary classification: does this passage answer this question? |
| |
| For each SQuAD example: |
| - Positive: [question SEP correct_passage] -> label 1 |
| - Negative: [question SEP wrong_passage] -> label 0 |
| |
| The H4 attention heads directly attend from question tokens to passage tokens |
| within the same sequence — this is why cross-encoders beat bi-encoders. |
| |
| Pipeline integration: |
| 1. Bi-encoder retrieves top-5 (R@5=100%, 20ms) |
| 2. Cross-encoder reranks 5 candidates (5 forward passes, ~50ms) |
| 3. Return highest-scoring → R@1 should reach 80-90%+ |
| """ |
|
|
| import os |
| import math |
| import time |
| import random |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import sys |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) |
|
|
| from rag.cross_encoder import H4CrossEncoder |
| from rag.prepare_qa import download_squad_dev |
| from rag.tokenizer import BPETokenizer |
|
|
| |
| |
| |
|
|
# Training wall-clock budget in seconds; override with the CE_TIME env var.
TIME_BUDGET = int(os.environ.get('CE_TIME', 3600))
D_MODEL = 512        # transformer hidden width
N_HEADS = 8          # attention heads per layer
N_LAYERS = 8         # transformer layers
USE_BITLINEAR = True # use BitLinear layers (reported as "ternary" in the final summary)
LR = 5e-4            # base learning rate (head-only phase runs at LR * 10)
WEIGHT_DECAY = 0.01
GRAD_CLIP = 1.0      # max gradient norm; values <= 0 disable clipping
BATCH_SIZE = 8       # QA examples per step; each yields one positive and one negative pair
MAX_SEQ_LEN = 192    # token length of a [question SEP passage] sequence
EVAL_INTERVAL = 100  # run validation every this many steps
# Pretrained LM checkpoint used to warm-start the cross-encoder backbone.
CHECKPOINT_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'checkpoints', 'h4_fullscale_final.pt')
|
|
|
|
def pad_tokens(ids, max_len):
    """Truncate *ids* to *max_len*, then right-pad with zeros to exactly *max_len*."""
    clipped = ids[:max_len]
    padding = [0] * (max_len - len(clipped))
    return clipped + padding
|
|
|
|
def make_pair(tokenizer, question, passage, max_len):
    """Encode [question SEP passage] as one zero-padded id sequence of length *max_len*.

    The question gets up to a third of the token budget; the passage fills the
    remainder minus one slot reserved for the separator (token id 2).
    """
    q_budget = max_len // 3
    p_budget = max_len - q_budget - 1
    question_ids = tokenizer.encode(question)[:q_budget]
    passage_ids = tokenizer.encode(passage)[:p_budget]
    sequence = (question_ids + [2] + passage_ids)[:max_len]
    # Right-pad with zeros so every pair has identical length.
    return sequence + [0] * (max_len - len(sequence))
|
|
|
|
def main():
    """Fine-tune the H4 cross-encoder on SQuAD pair classification.

    Two-phase schedule:
      Phase 1 (steps < UNFREEZE_STEP): backbone frozen, only the score head
        trains, at 10x the base LR.
      Phase 2 (from UNFREEZE_STEP on): everything unfrozen with a fresh AdamW
        at the base LR.
    Periodically reports binary accuracy and top-5 rerank R@1 on a held-out
    split, then saves a checkpoint with the final metrics.
    """
    t_start = time.time()
    # Fix every RNG source so the split and negative sampling are reproducible.
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)

    squad = download_squad_dev()
    if len(squad) < 100:
        print("SQuAD not available. Run: python rag/prepare_qa.py")
        return
    print(f"SQuAD: {len(squad)} QA pairs")

    # Build the BPE vocab from the first 2000 context+question strings only.
    tokenizer = BPETokenizer(max_vocab=8192)
    all_texts = [qa['context'] + ' ' + qa['question'] for qa in squad[:2000]]
    tokenizer.build_vocab(all_texts)

    # Shuffled split: the first 200 shuffled examples are held out for validation.
    indices = list(range(len(squad)))
    random.shuffle(indices)
    n_val = 200
    train_data = [squad[i] for i in indices[n_val:]]
    val_data = [squad[i] for i in indices[:n_val]]
    print(f"Train: {len(train_data)}, Val: {len(val_data)}")

    model = H4CrossEncoder(
        vocab_size=tokenizer.vocab_size,
        d_model=D_MODEL,
        n_heads=N_HEADS,
        n_layers=N_LAYERS,
        use_bitlinear=USE_BITLINEAR,
        max_seq_len=MAX_SEQ_LEN,
    )
    print(f"Model: {model.count_params():,} params")

    # Warm-start the LM backbone from the pretrained checkpoint when present.
    if os.path.exists(CHECKPOINT_PATH):
        config = model.load_lm_backbone(CHECKPOINT_PATH)  # NOTE(review): `config` is never used
        print(f"Loaded backbone from {CHECKPOINT_PATH}")
    else:
        print(f"No checkpoint at {CHECKPOINT_PATH}, training from scratch")

    # Phase 1: freeze the backbone so only the score head receives gradients.
    for name, param in model.lm.named_parameters():
        param.requires_grad = False
    trainable_head = sum(p.numel() for p in model.score_head.parameters())
    print(f"Phase 1: training score head only ({trainable_head:,} params)")

    # Head-only optimizer at 10x LR; replaced wholesale at the unfreeze step.
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LR * 10,
        weight_decay=WEIGHT_DECAY,
    )

    model.train()
    step = 0
    total_training_time = 0.0
    best_acc = 0.0
    unfrozen = False
    UNFREEZE_STEP = 200

    print(f"\nTraining for {TIME_BUDGET}s")
    print(f"{'step':>6} {'loss':>8} {'acc':>8} {'val_acc':>8} {'phase':>10}")
    print("-" * 48)

    while True:
        t0 = time.time()

        # Phase transition: unfreeze the backbone and rebuild the optimizer
        # (new AdamW state, base LR, custom betas) over all parameters.
        if step == UNFREEZE_STEP and not unfrozen:
            for param in model.lm.parameters():
                param.requires_grad = True
            unfrozen = True
            total_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
            print(f"\n Phase 2: unfreezing backbone ({total_trainable:,} trainable params)")
            optimizer = torch.optim.AdamW(
                model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.95))

        # Sample a batch of QA examples; each contributes one positive and one
        # negative pair, so the tensor batch is 2 * BATCH_SIZE rows.
        batch_qa = random.sample(train_data, min(BATCH_SIZE, len(train_data)))
        input_ids = []
        labels = []

        for qa in batch_qa:
            # Positive: the question with its own context.
            pos = make_pair(tokenizer, qa['question'], qa['context'], MAX_SEQ_LEN)
            input_ids.append(pos)
            labels.append(1.0)

            # Negative: resample until the context differs from the positive's.
            neg_qa = random.choice(train_data)
            while neg_qa['context'] == qa['context']:
                neg_qa = random.choice(train_data)
            neg = make_pair(tokenizer, qa['question'], neg_qa['context'], MAX_SEQ_LEN)
            input_ids.append(neg)
            labels.append(0.0)

        input_ids = torch.tensor(input_ids, dtype=torch.long)
        labels = torch.tensor(labels, dtype=torch.float32)

        # Model emits one logit per pair; BCE-with-logits trains the binary task.
        scores = model(input_ids)
        loss = F.binary_cross_entropy_with_logits(scores, labels)

        # Train accuracy at the 0-logit (p=0.5) decision threshold.
        with torch.no_grad():
            preds = (scores > 0).float()
            acc = (preds == labels).float().mean().item()

        optimizer.zero_grad()
        loss.backward()
        if GRAD_CLIP > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()

        # Exclude the first 3 steps (warm-up) from the time-budget accounting.
        dt = time.time() - t0
        if step > 2:
            total_training_time += dt

        # Periodic validation: binary accuracy plus a small rerank-R@1 probe.
        if step % EVAL_INTERVAL == 0:
            model.eval()
            val_correct = 0
            val_total = 0
            val_r1 = 0
            val_r1_total = 0

            with torch.no_grad():
                # Binary pair accuracy over up to 100 validation examples.
                for vi in range(0, min(len(val_data), 100), BATCH_SIZE):
                    vbatch = val_data[vi:vi + BATCH_SIZE]
                    v_ids = []
                    v_labels = []
                    for qa in vbatch:
                        pos = make_pair(tokenizer, qa['question'], qa['context'], MAX_SEQ_LEN)
                        v_ids.append(pos)
                        v_labels.append(1.0)
                        # NOTE(review): negative is drawn without excluding qa's own
                        # context (unlike the training loop), so it can occasionally
                        # be a false negative — confirm this is intentional.
                        neg_qa = random.choice(val_data)
                        neg = make_pair(tokenizer, qa['question'], neg_qa['context'], MAX_SEQ_LEN)
                        v_ids.append(neg)
                        v_labels.append(0.0)
                    v_ids = torch.tensor(v_ids, dtype=torch.long)
                    v_labels = torch.tensor(v_labels)
                    v_scores = model(v_ids)
                    v_preds = (v_scores > 0).float()
                    val_correct += (v_preds == v_labels).sum().item()
                    val_total += len(v_labels)

                # Rerank probe: gold context at index 0 plus up to 4 distractors;
                # a hit means the model's top-scoring candidate is the gold one.
                for qa in val_data[:50]:
                    candidates = [qa['context']]
                    neg_pool = [q for q in val_data if q['context'] != qa['context']]
                    for neg in random.sample(neg_pool, min(4, len(neg_pool))):
                        candidates.append(neg['context'])

                    c_ids = []
                    for passage in candidates:
                        c_ids.append(make_pair(tokenizer, qa['question'], passage, MAX_SEQ_LEN))
                    c_ids = torch.tensor(c_ids, dtype=torch.long)
                    c_scores = model(c_ids)
                    top_idx = c_scores.argmax().item()
                    if top_idx == 0:
                        val_r1 += 1
                    val_r1_total += 1

            val_acc = val_correct / max(val_total, 1)
            rerank_r1 = val_r1 / max(val_r1_total, 1)

            if val_acc > best_acc:
                best_acc = val_acc

            phase = "head-only" if not unfrozen else "full"
            print(f"{step:6d} {loss.item():8.4f} {acc:8.3f} {val_acc:8.3f} {phase:>10}"
                  f" rerank_R@1={rerank_r1:.3f}")
            model.train()

        step += 1
        elapsed = time.time() - t_start  # NOTE(review): `elapsed` is computed but never used
        # Stop once the accumulated (post-warm-up) training time hits the budget.
        if step > 2 and total_training_time >= TIME_BUDGET:
            break

    # ---- Final evaluation: rerank R@1 over 100 held-out examples ----
    model.eval()
    print("\n" + "=" * 60)
    print("FINAL CROSS-ENCODER EVALUATION:")

    final_r1 = 0
    final_total = 0
    with torch.no_grad():
        for qa in val_data[:100]:
            # Same protocol as the periodic probe: gold at index 0, 4 distractors.
            candidates = [qa['context']]
            neg_pool = [q for q in val_data if q['context'] != qa['context']]
            for neg in random.sample(neg_pool, min(4, len(neg_pool))):
                candidates.append(neg['context'])

            c_ids = []
            for passage in candidates:
                c_ids.append(make_pair(tokenizer, qa['question'], passage, MAX_SEQ_LEN))
            c_ids = torch.tensor(c_ids, dtype=torch.long)
            c_scores = model(c_ids)
            if c_scores.argmax().item() == 0:
                final_r1 += 1
            final_total += 1

    rerank_r1 = final_r1 / max(final_total, 1)

    print(f" Rerank R@1 (top-5): {rerank_r1:.1%} ({final_r1}/{final_total})")
    print(f" Best binary acc: {best_acc:.1%}")
    print("=" * 60)

    # Machine-readable summary block for downstream log scraping.
    print("\n---")
    print(f"rerank_r1: {rerank_r1:.4f}")
    print(f"best_binary_acc: {best_acc:.4f}")
    print(f"training_seconds: {total_training_time:.1f}")
    print(f"total_seconds: {time.time() - t_start:.1f}")
    print(f"num_steps: {step}")
    print(f"num_params: {model.count_params()}")
    print(f"ternary: {'yes' if USE_BITLINEAR else 'no'}")

    # Persist final weights plus the config needed to rebuild the model.
    os.makedirs('checkpoints', exist_ok=True)
    ckpt_path = os.path.join('checkpoints', 'h4_cross_encoder.pt')
    torch.save({
        'model_state': model.state_dict(),
        'rerank_r1': rerank_r1,
        'step': step,
        'config': {
            'd_model': D_MODEL, 'n_layers': N_LAYERS, 'n_heads': N_HEADS,
            'vocab_size': tokenizer.vocab_size, 'use_bitlinear': USE_BITLINEAR,
        },
    }, ckpt_path)
    print(f"Saved: {ckpt_path}")
|
|
|
|
# Script entry point.
if __name__ == '__main__':
    main()
|
|