| """
|
| Train H4 geometric ranker with contrastive learning on SQuAD.
|
|
|
| For each batch:
|
| - Each question is paired with its correct passage (positive)
|
| - All other passages in the batch are negatives (in-batch negatives)
|
| - Loss: InfoNCE — correct passage should score highest
|
|
|
| Metric: Recall@1 (does the top-ranked passage contain the answer?)
|
|
|
| This is a much simpler task than extractive QA:
|
| - Ranking maps two texts to a scalar (not text to text)
|
| - 370K ternary params can learn this
|
| - 5,928 SQuAD pairs provide enough signal
|
| """
|
|
|
| import os
|
| import math
|
| import time
|
| import random
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import numpy as np
|
| import sys
|
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
|
|
| from rag.ranking_model import H4Ranker
|
| from rag.prepare_qa import download_squad_dev
|
| from rag.tokenizer import BPETokenizer
|
|
|
|
|
|
|
|
|
|
|
# --- Hyperparameters (several overridable via environment variables) ---
TIME_BUDGET = int(os.environ.get('RANKER_TIME', 600))    # wall-clock training budget, seconds (RANKER_TIME)

D_MODEL = int(os.environ.get('RANKER_DMODEL', 128))      # transformer hidden size (RANKER_DMODEL)

N_HEADS = 8                                              # attention heads

N_LAYERS = int(os.environ.get('RANKER_LAYERS', 2))       # encoder layers (RANKER_LAYERS)

D_VALUE = D_MODEL // N_HEADS                             # per-head value dimension

D_FFN = D_MODEL * 4                                      # feed-forward width

USE_BITLINEAR = True                                     # ternary (BitLinear) weights when True, float otherwise

LR = float(os.environ.get('RANKER_LR', 3e-3))            # peak learning rate (RANKER_LR)

WEIGHT_DECAY = 0.01                                      # AdamW weight decay

GRAD_CLIP = 1.0                                          # max gradient norm; values <= 0 disable clipping

BATCH_SIZE = int(os.environ.get('RANKER_BATCH', 32))     # batch size; also fixes the in-batch negative count (RANKER_BATCH)

MAX_Q_LEN = 64                                           # question length in tokens (pad/truncate)

MAX_P_LEN = 192                                          # passage length in tokens (pad/truncate)

TEMPERATURE = 0.15                                       # InfoNCE softmax temperature

EVAL_INTERVAL = 200                                      # steps between validation evaluations
|
|
|
|
|
def pad_tokens(ids, max_len, pad_id=0):
    """Pad or truncate a token-id list to exactly ``max_len`` entries.

    Args:
        ids: list of integer token ids.
        max_len: target length.
        pad_id: id used for right-padding. Defaults to 0, preserving the
            previous hard-coded behavior.

    Returns:
        A new list of length ``max_len``; the input list is not mutated.
    """
    ids = ids[:max_len]
    return ids + [pad_id] * (max_len - len(ids))
|
|
|
|
|
def contrastive_loss(q_h4, p_h4, temperature):
    """InfoNCE loss over a batch of question/passage embeddings.

    The positive passage for question i sits at batch index i; every other
    passage in the batch serves as a negative.

    Returns:
        (loss, recall@1, recall@5, mrr) — the loss carries gradients, the
        three ranking metrics are computed without gradient tracking.
    """
    batch = q_h4.shape[0]

    # Scaled similarity matrix: row i compares question i to all passages.
    logits = q_h4.matmul(p_h4.t()) / temperature

    # Target for row i is column i (the paired passage).
    targets = torch.arange(batch, device=logits.device)
    loss = F.cross_entropy(logits, targets)

    with torch.no_grad():
        # Recall@1: fraction of questions whose top-scored passage is correct.
        recall_at_1 = (logits.argmax(dim=1) == targets).float().mean().item()

        # Recall@5: correct passage appears in the top k (k capped at batch).
        k = 5 if batch > 5 else batch
        top_idx = logits.topk(k, dim=1).indices
        hits = 0
        for row in range(batch):
            if targets[row].item() in top_idx[row].tolist():
                hits += 1
        recall_at_5 = hits / batch

        # MRR: mean of 1/rank of the correct passage (ranks are 1-based).
        order = logits.argsort(dim=1, descending=True)
        ranks = (order == targets.unsqueeze(1)).nonzero()[:, 1].float() + 1
        mrr = (1.0 / ranks).mean().item()

    return loss, recall_at_1, recall_at_5, mrr
|
|
|
|
|
def _batch_tensors(tokenizer, batch):
    """Tokenize a list of QA dicts into padded (question, passage) id tensors."""
    q_ids = torch.tensor(
        [pad_tokens(tokenizer.encode(qa['question']), MAX_Q_LEN) for qa in batch],
        dtype=torch.long,
    )
    p_ids = torch.tensor(
        [pad_tokens(tokenizer.encode(qa['context']), MAX_P_LEN) for qa in batch],
        dtype=torch.long,
    )
    return q_ids, p_ids


def _eval_ranking(model, tokenizer, val_data, limit):
    """Compute ranking metrics over up to `limit` validation pairs.

    Returns three parallel lists (recall@1, recall@5, MRR), one entry per
    evaluated batch. Batches of fewer than 2 items are skipped because
    in-batch ranking needs at least one negative.
    """
    r1s, r5s, mrrs = [], [], []
    with torch.no_grad():
        for vi in range(0, min(len(val_data), limit), BATCH_SIZE):
            vbatch = val_data[vi:vi + BATCH_SIZE]
            if len(vbatch) < 2:
                continue
            vq, vp = _batch_tensors(tokenizer, vbatch)
            _, r1, r5, mrr = contrastive_loss(model.encode(vq), model.encode(vp), TEMPERATURE)
            r1s.append(r1)
            r5s.append(r5)
            mrrs.append(mrr)
    return r1s, r5s, mrrs


def main():
    """Train the H4 ranker with in-batch-negative InfoNCE under a time budget."""
    t_start = time.time()
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)

    # --- Data ---
    squad = download_squad_dev()
    if len(squad) < 100:
        print("SQuAD not available. Run: python rag/prepare_qa.py")
        return
    print(f"SQuAD: {len(squad)} QA pairs")

    # Build the BPE vocab on a prefix of the data (contexts + questions).
    tokenizer = BPETokenizer(max_vocab=4096)
    all_texts = [qa['context'] + ' ' + qa['question'] for qa in squad[:2000]]
    tokenizer.build_vocab(all_texts)
    vocab_size = tokenizer.vocab_size

    # Deterministic shuffled split: first n_val shuffled indices -> validation.
    # Both splits keep the shuffled order (we iterate `indices`, not range).
    indices = list(range(len(squad)))
    random.shuffle(indices)
    n_val = 200
    val_indices = set(indices[:n_val])
    train_data = [squad[i] for i in indices if i not in val_indices]
    val_data = [squad[i] for i in indices if i in val_indices]
    print(f"Train: {len(train_data)}, Val: {len(val_data)}")

    # --- Model & optimizer ---
    model = H4Ranker(
        vocab_size=vocab_size,
        d_model=D_MODEL,
        n_heads=N_HEADS,
        n_layers=N_LAYERS,
        d_value=D_VALUE,
        d_ffn=D_FFN,
        use_bitlinear=USE_BITLINEAR,
        max_seq_len=max(MAX_Q_LEN, MAX_P_LEN),
    )
    n_params = model.count_params()
    print(f"Model: {n_params:,} params ({'ternary' if USE_BITLINEAR else 'float'})")

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.95))

    def lr_schedule(step):
        # Linear warmup for 50 steps, then cosine decay toward a 10% floor
        # over a nominal 5000-step horizon (clamped beyond that).
        if step < 50:
            return step / 50
        progress = (step - 50) / max(1, 5000 - 50)
        return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * min(progress, 1.0)))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)

    # --- Training loop (time-budgeted rather than epoch-based) ---
    model.train()
    step = 0
    total_training_time = 0.0
    best_r1 = 0.0

    print(f"\nTraining for {TIME_BUDGET}s, metric=Recall@1")
    print(f"{'step':>6} {'loss':>8} {'R@1':>8} {'R@5':>8} {'MRR':>8} {'lr':>10}")
    print("-" * 56)

    while True:
        t0 = time.time()

        batch = random.sample(train_data, min(BATCH_SIZE, len(train_data)))
        q_ids, p_ids = _batch_tensors(tokenizer, batch)

        q_h4 = model.encode(q_ids)
        p_h4 = model.encode(p_ids)

        loss, r1, r5, mrr = contrastive_loss(q_h4, p_h4, TEMPERATURE)

        optimizer.zero_grad()
        loss.backward()
        if GRAD_CLIP > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()
        scheduler.step()

        # Exclude the first few warm-up steps from the time budget.
        dt = time.time() - t0
        if step > 2:
            total_training_time += dt

        if step % EVAL_INTERVAL == 0:
            model.eval()
            val_r1s, _, _ = _eval_ranking(model, tokenizer, val_data, 100)
            val_r1 = sum(val_r1s) / len(val_r1s) if val_r1s else 0
            if val_r1 > best_r1:
                best_r1 = val_r1

            # NOTE: R@1 here is validation; R@5/MRR come from the last
            # training batch (pre-existing behavior, kept as-is).
            current_lr = scheduler.get_last_lr()[0]
            print(f"{step:6d} {loss.item():8.4f} {val_r1:8.3f} {r5:8.3f} {mrr:8.3f} {current_lr:10.6f}")
            model.train()

        step += 1
        if step > 2 and total_training_time >= TIME_BUDGET:
            break

    # --- Final evaluation ---
    model.eval()
    print("\n" + "=" * 60)
    print("FINAL RANKING EVALUATION:")

    all_r1, all_r5, all_mrr = _eval_ranking(model, tokenizer, val_data, 200)
    final_r1 = sum(all_r1) / len(all_r1) if all_r1 else 0
    final_r5 = sum(all_r5) / len(all_r5) if all_r5 else 0
    final_mrr = sum(all_mrr) / len(all_mrr) if all_mrr else 0

    # Qualitative sanity check on a few validation questions.
    print(f"\nSample rankings (batch of {min(BATCH_SIZE, 8)}):")
    sample_batch = val_data[:min(BATCH_SIZE, 8)]
    sq, sp = _batch_tensors(tokenizer, sample_batch)
    with torch.no_grad():
        sq_h4 = model.encode(sq)
        sp_h4 = model.encode(sp)
        sim = torch.mm(sq_h4, sp_h4.t())

    for i in range(min(3, len(sample_batch))):
        scores = sim[i].tolist()
        ranked = sorted(range(len(scores)), key=lambda j: -scores[j])
        correct = i
        rank_of_correct = ranked.index(correct) + 1
        print(f"  Q: {sample_batch[i]['question'][:60]}")
        print(f"  Correct passage rank: {rank_of_correct}/{len(scores)}")
        print(f"  Top passage: {sample_batch[ranked[0]]['context'][:60]}...")
        print()

    print("=" * 60)
    print("\n---")
    print(f"val_recall_at_1: {final_r1:.4f}")
    print(f"val_recall_at_5: {final_r5:.4f}")
    print(f"val_mrr: {final_mrr:.4f}")
    print(f"best_recall_at_1: {best_r1:.4f}")
    print(f"training_seconds: {total_training_time:.1f}")
    print(f"total_seconds: {time.time() - t_start:.1f}")
    print(f"num_steps: {step}")
    print(f"num_params: {n_params}")
    print(f"ternary: {'yes' if USE_BITLINEAR else 'no'}")
    print(f"batch_size: {BATCH_SIZE}")
    print(f"temperature: {TEMPERATURE}")
|
|
|
|
|
# Entry point when executed as a script (not on import).
if __name__ == '__main__':
    main()
|
|
|