# h4-polytopic-attention — python/rag/train_cross_encoder.py
# Uploaded by grapheneaffiliates via huggingface_hub (commit 60ddd88, verified)
"""
Train H4 cross-encoder reranker on SQuAD.
Uses the PPL 10.0 TinyStories checkpoint as backbone.
Fine-tunes on binary classification: does this passage answer this question?
For each SQuAD example:
- Positive: [question SEP correct_passage] -> label 1
- Negative: [question SEP wrong_passage] -> label 0
The H4 attention heads directly attend from question tokens to passage tokens
within the same sequence — this is why cross-encoders beat bi-encoders.
Pipeline integration:
1. Bi-encoder retrieves top-5 (R@5=100%, 20ms)
2. Cross-encoder reranks 5 candidates (5 forward passes, ~50ms)
3. Return highest-scoring → R@1 should reach 80-90%+
"""
import os
import math
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from rag.cross_encoder import H4CrossEncoder
from rag.prepare_qa import download_squad_dev
from rag.tokenizer import BPETokenizer
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
TIME_BUDGET = int(os.environ.get('CE_TIME', 3600))  # training-time budget in seconds (1 hour default), overridable via CE_TIME
D_MODEL = 512  # transformer hidden width
N_HEADS = 8  # attention heads per layer
N_LAYERS = 8  # transformer layers in the backbone
USE_BITLINEAR = True  # ternary BitLinear layers (reported as "ternary: yes" in the summary)
LR = 5e-4 # lower LR for fine-tuning (backbone is pre-trained); head-only phase uses LR * 10
WEIGHT_DECAY = 0.01  # AdamW weight decay
GRAD_CLIP = 1.0  # max gradient norm; values <= 0 disable clipping
BATCH_SIZE = 8 # QA pairs per batch (each pair yields 1 positive + 1 negative example)
MAX_SEQ_LEN = 192 # question + SEP + passage combined, in BPE token ids
EVAL_INTERVAL = 100  # steps between validation passes
CHECKPOINT_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'checkpoints', 'h4_fullscale_final.pt')  # pre-trained LM backbone weights
def pad_tokens(ids, max_len):
    """Truncate *ids* to *max_len*, then right-pad with id 0 to that length."""
    clipped = ids[:max_len]
    return clipped + [0] * (max_len - len(clipped))
def make_pair(tokenizer, question, passage, max_len):
    """Encode ``question [SEP] passage`` as one fixed-length id sequence.

    Token budget: one third of *max_len* goes to the question; the
    remainder, minus one slot for the separator, goes to the passage.
    The result is truncated and zero-padded to exactly *max_len* ids.
    """
    q_budget = max_len // 3
    p_budget = max_len - q_budget - 1
    seq = (
        tokenizer.encode(question)[:q_budget]
        + [2]  # 2 = SEP token id
        + tokenizer.encode(passage)[:p_budget]
    )
    seq = seq[:max_len]
    return seq + [0] * (max_len - len(seq))
def _sample_negative(data, context):
    """Return a random QA pair from *data* whose context differs from *context*.

    Assumes *data* contains at least two distinct contexts — TODO confirm
    against the SQuAD loader (holds for real SQuAD splits).
    """
    neg_qa = random.choice(data)
    while neg_qa['context'] == context:
        neg_qa = random.choice(data)
    return neg_qa


def _rerank_r1(model, tokenizer, pool, n_queries):
    """Measure rerank R@1 over ``pool[:n_queries]``.

    For each question, score its correct passage (candidate index 0) against
    up to 4 random wrong passages from *pool* and count how often the correct
    one gets the highest score.  Returns ``(hits, total)``.  The caller is
    responsible for ``model.eval()`` and ``torch.no_grad()``.
    """
    hits = 0
    total = 0
    for qa in pool[:n_queries]:
        # Simulate a retriever's top-5: 1 correct + 4 wrong passages.
        candidates = [qa['context']]
        neg_pool = [q for q in pool if q['context'] != qa['context']]
        for neg in random.sample(neg_pool, min(4, len(neg_pool))):
            candidates.append(neg['context'])
        c_ids = torch.tensor(
            [make_pair(tokenizer, qa['question'], p, MAX_SEQ_LEN) for p in candidates],
            dtype=torch.long)
        c_scores = model(c_ids)
        if c_scores.argmax().item() == 0:  # correct passage was ranked first
            hits += 1
        total += 1
    return hits, total


def main():
    """Fine-tune the H4 cross-encoder reranker on SQuAD.

    Two-phase schedule: train only the score head with the backbone frozen
    for UNFREEZE_STEP steps (at 10x LR), then unfreeze everything and
    continue at LR until TIME_BUDGET seconds of measured training time
    elapse.  Prints periodic binary-classification and rerank-R@1 metrics,
    then a final evaluation, and saves a checkpoint.
    """
    t_start = time.time()
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)

    # --- Data --------------------------------------------------------------
    squad = download_squad_dev()
    if len(squad) < 100:
        print("SQuAD not available. Run: python rag/prepare_qa.py")
        return
    print(f"SQuAD: {len(squad)} QA pairs")

    # Build BPE tokenizer from a sample of contexts + questions.
    tokenizer = BPETokenizer(max_vocab=8192)
    all_texts = [qa['context'] + ' ' + qa['question'] for qa in squad[:2000]]
    tokenizer.build_vocab(all_texts)

    # Shuffled train/val split.
    indices = list(range(len(squad)))
    random.shuffle(indices)
    n_val = 200
    train_data = [squad[i] for i in indices[n_val:]]
    val_data = [squad[i] for i in indices[:n_val]]
    print(f"Train: {len(train_data)}, Val: {len(val_data)}")

    # --- Model -------------------------------------------------------------
    model = H4CrossEncoder(
        vocab_size=tokenizer.vocab_size,
        d_model=D_MODEL,
        n_heads=N_HEADS,
        n_layers=N_LAYERS,
        use_bitlinear=USE_BITLINEAR,
        max_seq_len=MAX_SEQ_LEN,
    )
    print(f"Model: {model.count_params():,} params")

    # Load the pre-trained LM backbone if the checkpoint exists.
    if os.path.exists(CHECKPOINT_PATH):
        model.load_lm_backbone(CHECKPOINT_PATH)
        print(f"Loaded backbone from {CHECKPOINT_PATH}")
    else:
        print(f"No checkpoint at {CHECKPOINT_PATH}, training from scratch")

    # Phase 1: freeze the backbone; only the score head is trainable.
    for param in model.lm.parameters():
        param.requires_grad = False
    trainable_head = sum(p.numel() for p in model.score_head.parameters())
    print(f"Phase 1: training score head only ({trainable_head:,} params)")
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LR * 10,  # higher LR is safe while only the head trains
        weight_decay=WEIGHT_DECAY,
    )

    # --- Training loop -----------------------------------------------------
    model.train()
    step = 0
    total_training_time = 0.0
    best_acc = 0.0
    unfrozen = False
    UNFREEZE_STEP = 200
    print(f"\nTraining for {TIME_BUDGET}s")
    print(f"{'step':>6} {'loss':>8} {'acc':>8} {'val_acc':>8} {'phase':>10}")
    print("-" * 48)
    while True:
        t0 = time.time()

        # Phase 2: unfreeze the backbone after warmup and rebuild the
        # optimizer over all parameters at the lower fine-tuning LR.
        if step == UNFREEZE_STEP and not unfrozen:
            for param in model.lm.parameters():
                param.requires_grad = True
            unfrozen = True
            total_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
            print(f"\n Phase 2: unfreezing backbone ({total_trainable:,} trainable params)")
            optimizer = torch.optim.AdamW(
                model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.95))

        # Sample a batch: one positive and one negative pair per QA example.
        batch_qa = random.sample(train_data, min(BATCH_SIZE, len(train_data)))
        input_ids = []
        labels = []
        for qa in batch_qa:
            # Positive: question + its own passage.
            input_ids.append(make_pair(tokenizer, qa['question'], qa['context'], MAX_SEQ_LEN))
            labels.append(1.0)
            # Negative: question + a passage from a different context.
            neg_qa = _sample_negative(train_data, qa['context'])
            input_ids.append(make_pair(tokenizer, qa['question'], neg_qa['context'], MAX_SEQ_LEN))
            labels.append(0.0)
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        labels = torch.tensor(labels, dtype=torch.float32)

        # Forward / backward.
        scores = model(input_ids)
        loss = F.binary_cross_entropy_with_logits(scores, labels)
        with torch.no_grad():
            preds = (scores > 0).float()
            acc = (preds == labels).float().mean().item()
        optimizer.zero_grad()
        loss.backward()
        if GRAD_CLIP > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()

        dt = time.time() - t0
        if step > 2:  # skip the first few (warmup) steps in the time budget
            total_training_time += dt

        # Periodic evaluation.
        if step % EVAL_INTERVAL == 0:
            model.eval()
            val_correct = 0
            val_total = 0
            with torch.no_grad():
                # Binary pair-classification accuracy.
                for vi in range(0, min(len(val_data), 100), BATCH_SIZE):
                    vbatch = val_data[vi:vi + BATCH_SIZE]
                    v_ids = []
                    v_labels = []
                    for qa in vbatch:
                        v_ids.append(make_pair(tokenizer, qa['question'], qa['context'], MAX_SEQ_LEN))
                        v_labels.append(1.0)
                        # BUGFIX: resample until the "negative" really is a
                        # different context, matching the training-side guard
                        # (previously the positive passage could be labeled 0).
                        neg_qa = _sample_negative(val_data, qa['context'])
                        v_ids.append(make_pair(tokenizer, qa['question'], neg_qa['context'], MAX_SEQ_LEN))
                        v_labels.append(0.0)
                    v_ids = torch.tensor(v_ids, dtype=torch.long)
                    v_labels = torch.tensor(v_labels)
                    v_preds = (model(v_ids) > 0).float()
                    val_correct += (v_preds == v_labels).sum().item()
                    val_total += len(v_labels)
                # Reranking accuracy (R@1 on simulated top-5 candidates).
                val_r1, val_r1_total = _rerank_r1(model, tokenizer, val_data, 50)
            val_acc = val_correct / max(val_total, 1)
            rerank_r1 = val_r1 / max(val_r1_total, 1)
            if val_acc > best_acc:
                best_acc = val_acc
            phase = "head-only" if not unfrozen else "full"
            print(f"{step:6d} {loss.item():8.4f} {acc:8.3f} {val_acc:8.3f} {phase:>10}"
                  f" rerank_R@1={rerank_r1:.3f}")
            model.train()

        step += 1
        if step > 2 and total_training_time >= TIME_BUDGET:
            break

    # --- Final evaluation ---------------------------------------------------
    model.eval()
    print("\n" + "=" * 60)
    print("FINAL CROSS-ENCODER EVALUATION:")
    with torch.no_grad():
        final_r1, final_total = _rerank_r1(model, tokenizer, val_data, 100)
    rerank_r1 = final_r1 / max(final_total, 1)
    print(f" Rerank R@1 (top-5): {rerank_r1:.1%} ({final_r1}/{final_total})")
    print(f" Best binary acc: {best_acc:.1%}")
    print("=" * 60)
    print("\n---")
    print(f"rerank_r1: {rerank_r1:.4f}")
    print(f"best_binary_acc: {best_acc:.4f}")
    print(f"training_seconds: {total_training_time:.1f}")
    print(f"total_seconds: {time.time() - t_start:.1f}")
    print(f"num_steps: {step}")
    print(f"num_params: {model.count_params()}")
    print(f"ternary: {'yes' if USE_BITLINEAR else 'no'}")

    # Save checkpoint.  NOTE(review): saved relative to the CWD, while
    # CHECKPOINT_PATH resolves relative to this file — confirm intended.
    os.makedirs('checkpoints', exist_ok=True)
    ckpt_path = os.path.join('checkpoints', 'h4_cross_encoder.pt')
    torch.save({
        'model_state': model.state_dict(),
        'rerank_r1': rerank_r1,
        'step': step,
        'config': {
            'd_model': D_MODEL, 'n_layers': N_LAYERS, 'n_heads': N_HEADS,
            'vocab_size': tokenizer.vocab_size, 'use_bitlinear': USE_BITLINEAR,
        },
    }, ckpt_path)
    print(f"Saved: {ckpt_path}")


if __name__ == '__main__':
    main()