# h4-polytopic-attention / python/rag/train_ranker.py
# Uploaded by grapheneaffiliates via huggingface_hub (rev de3c862, verified)
"""
Train H4 geometric ranker with contrastive learning on SQuAD.
For each batch:
- Each question is paired with its correct passage (positive)
- All other passages in the batch are negatives (in-batch negatives)
- Loss: InfoNCE — correct passage should score highest
Metric: Recall@1 (does the top-ranked passage contain the answer?)
This is a much simpler task than extractive QA:
- Ranking maps two texts to a scalar (not text to text)
- 370K ternary params can learn this
- 5,928 SQuAD pairs provide enough signal
"""
import os
import math
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from rag.ranking_model import H4Ranker
from rag.prepare_qa import download_squad_dev
from rag.tokenizer import BPETokenizer
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
TIME_BUDGET = int(os.environ.get('RANKER_TIME', 600)) # default 10 min, override with env
D_MODEL = int(os.environ.get('RANKER_DMODEL', 128))  # embedding / hidden width
N_HEADS = 8  # attention heads per layer
N_LAYERS = int(os.environ.get('RANKER_LAYERS', 2))  # transformer depth
D_VALUE = D_MODEL // N_HEADS  # per-head value dimension
D_FFN = D_MODEL * 4  # feed-forward inner width (standard 4x expansion)
USE_BITLINEAR = True  # ternary (BitLinear) weights vs. full-precision float
LR = float(os.environ.get('RANKER_LR', 3e-3))  # peak AdamW learning rate
WEIGHT_DECAY = 0.01  # AdamW decoupled weight decay
GRAD_CLIP = 1.0  # global grad-norm clip; <=0 disables clipping
BATCH_SIZE = int(os.environ.get('RANKER_BATCH', 32))  # also the number of in-batch negatives + 1
MAX_Q_LEN = 64  # question token budget (pad/truncate)
MAX_P_LEN = 192  # passage token budget (pad/truncate)
TEMPERATURE = 0.15  # InfoNCE softmax temperature (lower = sharper)
EVAL_INTERVAL = 200  # steps between quick validation evals
def pad_tokens(ids, max_len):
    """Clip a token-id list to *max_len* and right-pad with zeros to that length."""
    clipped = ids[:max_len]
    n_pad = max_len - len(clipped)
    return clipped + [0] * n_pad
def contrastive_loss(q_h4, p_h4, temperature):
    """InfoNCE loss with in-batch negatives, plus ranking metrics.

    Args:
        q_h4: (B, D) question embeddings.
        p_h4: (B, D) passage embeddings; row i is the positive for question i.
        temperature: softmax temperature for the similarity logits.

    Returns:
        (loss, recall_at_1, recall_at_5, mrr) where loss is a scalar tensor
        and the three metrics are Python floats computed without gradients.
    """
    B = q_h4.shape[0]
    # Similarity matrix: (B, B); row i scores question i against every passage.
    sim = torch.mm(q_h4, p_h4.t()) / temperature
    # Labels: positive is on diagonal
    labels = torch.arange(B, device=sim.device)
    loss = F.cross_entropy(sim, labels)
    # Metrics (no grad needed — reporting only)
    with torch.no_grad():
        preds = sim.argmax(dim=1)
        recall_at_1 = (preds == labels).float().mean().item()
        # Vectorized recall@k: is the diagonal index among the top-k columns?
        # (replaces a Python-level per-row loop over top5[i].tolist())
        top5 = sim.topk(min(5, B), dim=1).indices
        recall_at_5 = (top5 == labels.unsqueeze(1)).any(dim=1).float().mean().item()
        # Rank of the correct passage in each row's descending sort (1-based).
        ranks = (sim.argsort(dim=1, descending=True) == labels.unsqueeze(1)).nonzero()[:, 1].float() + 1
        mrr = (1.0 / ranks).mean().item()
    return loss, recall_at_1, recall_at_5, mrr
def main():
    """Train the H4 ranker with in-batch InfoNCE on SQuAD, then report
    Recall@1/Recall@5/MRR on a held-out validation split."""
    t_start = time.time()
    # Fix all RNG seeds so the train/val split and batches are reproducible.
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)
    # Load SQuAD
    squad = download_squad_dev()
    if len(squad) < 100:
        print("SQuAD not available. Run: python rag/prepare_qa.py")
        return
    print(f"SQuAD: {len(squad)} QA pairs")
    # Build BPE tokenizer
    tokenizer = BPETokenizer(max_vocab=4096)
    # Vocab is fit on context+question text from the first 2000 pairs only —
    # presumably enough coverage for the rest; TODO confirm.
    all_texts = [qa['context'] + ' ' + qa['question'] for qa in squad[:2000]]
    tokenizer.build_vocab(all_texts)
    vocab_size = tokenizer.vocab_size
    # Split train/val
    indices = list(range(len(squad)))
    random.shuffle(indices)
    n_val = 200
    val_indices = set(indices[:n_val])
    train_data = [squad[i] for i in indices if i not in val_indices]
    val_data = [squad[i] for i in indices if i in val_indices]
    print(f"Train: {len(train_data)}, Val: {len(val_data)}")
    # Create model
    model = H4Ranker(
        vocab_size=vocab_size,
        d_model=D_MODEL,
        n_heads=N_HEADS,
        n_layers=N_LAYERS,
        d_value=D_VALUE,
        d_ffn=D_FFN,
        use_bitlinear=USE_BITLINEAR,
        max_seq_len=max(MAX_Q_LEN, MAX_P_LEN),
    )
    n_params = model.count_params()
    print(f"Model: {n_params:,} params ({'ternary' if USE_BITLINEAR else 'float'})")
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.95))

    def lr_schedule(step):
        # Linear warmup over the first 50 steps, then cosine decay from the
        # peak LR down to 10% of it over a nominal horizon of 5000 steps.
        if step < 50:
            return step / 50
        progress = (step - 50) / max(1, 5000 - 50)
        return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * min(progress, 1.0)))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
    # Training loop — runs until the wall-clock time budget is spent.
    model.train()
    step = 0
    total_training_time = 0.0
    best_r1 = 0.0
    print(f"\nTraining for {TIME_BUDGET}s, metric=Recall@1")
    print(f"{'step':>6} {'loss':>8} {'R@1':>8} {'R@5':>8} {'MRR':>8} {'lr':>10}")
    print("-" * 56)
    while True:
        t0 = time.time()
        # Sample batch (each question's positive is its own passage; all
        # other passages in the batch serve as negatives).
        batch = random.sample(train_data, min(BATCH_SIZE, len(train_data)))
        q_ids = torch.tensor(
            [pad_tokens(tokenizer.encode(qa['question']), MAX_Q_LEN) for qa in batch],
            dtype=torch.long,
        )
        p_ids = torch.tensor(
            [pad_tokens(tokenizer.encode(qa['context']), MAX_P_LEN) for qa in batch],
            dtype=torch.long,
        )
        # Encode
        q_h4 = model.encode(q_ids)
        p_h4 = model.encode(p_ids)
        # Loss
        loss, r1, r5, mrr = contrastive_loss(q_h4, p_h4, TEMPERATURE)
        optimizer.zero_grad()
        loss.backward()
        if GRAD_CLIP > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()
        scheduler.step()
        dt = time.time() - t0
        # Skip the first few steps when accounting against the time budget
        # (warm-up steps are unrepresentatively slow).
        if step > 2:
            total_training_time += dt
        if step % EVAL_INTERVAL == 0:
            # Quick val eval on up to 100 held-out pairs
            model.eval()
            val_r1s = []
            with torch.no_grad():
                for vi in range(0, min(len(val_data), 100), BATCH_SIZE):
                    vbatch = val_data[vi:vi + BATCH_SIZE]
                    # Ranking needs at least 2 candidates to be meaningful.
                    if len(vbatch) < 2:
                        continue
                    vq = torch.tensor(
                        [pad_tokens(tokenizer.encode(qa['question']), MAX_Q_LEN) for qa in vbatch],
                        dtype=torch.long,
                    )
                    vp = torch.tensor(
                        [pad_tokens(tokenizer.encode(qa['context']), MAX_P_LEN) for qa in vbatch],
                        dtype=torch.long,
                    )
                    vq_h4 = model.encode(vq)
                    vp_h4 = model.encode(vp)
                    _, vr1, _, _ = contrastive_loss(vq_h4, vp_h4, TEMPERATURE)
                    val_r1s.append(vr1)
            val_r1 = sum(val_r1s) / len(val_r1s) if val_r1s else 0
            if val_r1 > best_r1:
                best_r1 = val_r1
            current_lr = scheduler.get_last_lr()[0]
            # Note: R@1 column is validation; R@5/MRR are from the last train batch.
            print(f"{step:6d} {loss.item():8.4f} {val_r1:8.3f} {r5:8.3f} {mrr:8.3f} {current_lr:10.6f}")
            model.train()
        step += 1
        if step > 2 and total_training_time >= TIME_BUDGET:
            break
    # Final evaluation on up to 200 held-out pairs
    model.eval()
    print("\n" + "=" * 60)
    print("FINAL RANKING EVALUATION:")
    all_r1 = []
    all_r5 = []
    all_mrr = []
    with torch.no_grad():
        for vi in range(0, min(len(val_data), 200), BATCH_SIZE):
            vbatch = val_data[vi:vi + BATCH_SIZE]
            if len(vbatch) < 2:
                continue
            vq = torch.tensor(
                [pad_tokens(tokenizer.encode(qa['question']), MAX_Q_LEN) for qa in vbatch],
                dtype=torch.long,
            )
            vp = torch.tensor(
                [pad_tokens(tokenizer.encode(qa['context']), MAX_P_LEN) for qa in vbatch],
                dtype=torch.long,
            )
            vq_h4 = model.encode(vq)
            vp_h4 = model.encode(vp)
            _, vr1, vr5, vmrr = contrastive_loss(vq_h4, vp_h4, TEMPERATURE)
            all_r1.append(vr1)
            all_r5.append(vr5)
            all_mrr.append(vmrr)
    # Averages over per-batch metrics (batches may differ in size slightly).
    final_r1 = sum(all_r1) / len(all_r1) if all_r1 else 0
    final_r5 = sum(all_r5) / len(all_r5) if all_r5 else 0
    final_mrr = sum(all_mrr) / len(all_mrr) if all_mrr else 0
    # Show some examples
    print(f"\nSample rankings (batch of {min(BATCH_SIZE, 8)}):")
    sample_batch = val_data[:min(BATCH_SIZE, 8)]
    sq = torch.tensor([pad_tokens(tokenizer.encode(qa['question']), MAX_Q_LEN) for qa in sample_batch], dtype=torch.long)
    sp = torch.tensor([pad_tokens(tokenizer.encode(qa['context']), MAX_P_LEN) for qa in sample_batch], dtype=torch.long)
    with torch.no_grad():
        sq_h4 = model.encode(sq)
        sp_h4 = model.encode(sp)
        sim = torch.mm(sq_h4, sp_h4.t())
    for i in range(min(3, len(sample_batch))):
        scores = sim[i].tolist()
        # Passage indices sorted best-first for question i.
        ranked = sorted(range(len(scores)), key=lambda j: -scores[j])
        correct = i
        rank_of_correct = ranked.index(correct) + 1
        print(f" Q: {sample_batch[i]['question'][:60]}")
        print(f" Correct passage rank: {rank_of_correct}/{len(scores)}")
        print(f" Top passage: {sample_batch[ranked[0]]['context'][:60]}...")
        print()
    print("=" * 60)
    # Machine-readable summary block.
    print("\n---")
    print(f"val_recall_at_1: {final_r1:.4f}")
    print(f"val_recall_at_5: {final_r5:.4f}")
    print(f"val_mrr: {final_mrr:.4f}")
    print(f"best_recall_at_1: {best_r1:.4f}")
    print(f"training_seconds: {total_training_time:.1f}")
    print(f"total_seconds: {time.time() - t_start:.1f}")
    print(f"num_steps: {step}")
    print(f"num_params: {n_params}")
    print(f"ternary: {'yes' if USE_BITLINEAR else 'no'}")
    print(f"batch_size: {BATCH_SIZE}")
    print(f"temperature: {TEMPERATURE}")
# Script entry point — run training only when executed directly.
if __name__ == '__main__':
    main()