"""
H4 Cross-Encoder Reranker — score (question, passage) with full cross-attention.
The bi-encoder (ranking_model.py) encodes question and passage separately.
Fast, but question and passage never interact, so it cannot compare them directly — it reaches R@5=100%, yet R@1 plateaus at ~40%.
The cross-encoder feeds question + passage through H4 attention as ONE sequence,
so attention heads relate question and passage tokens directly (if the backbone's
attention is causal, passage tokens attend back to the preceding question tokens).
Slower (one forward pass per candidate) but much more precise.
Production pipeline:
1. Bi-encoder retrieves top-k candidates (fast, 20ms for all docs)
2. Cross-encoder reranks k candidates (precise, ~10ms per candidate)
3. Return the top-ranked candidate
Uses the TinyStories language-model checkpoint (perplexity 10.0) as the backbone:
the model already knows English and only needs to learn "does this passage answer this question?"
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from h4_language_model import H4LanguageModel
from bitlinear import BitLinear
class H4CrossEncoder(nn.Module):
    """
    Cross-encoder reranker built on an H4LanguageModel backbone.

    Input:  [question tokens] [SEP] [passage tokens]  (token ids, padded with pad_token_id)
    Output: one scalar relevance score per sequence

    Question and passage are processed as one concatenated sequence, so the
    attention layers can relate the two directly instead of comparing two
    separately computed embeddings.

    NOTE(review): forward() calls every block with use_tree=False; if that flag
    disables ChamberTree routing, the tree path is not used here — confirm
    against h4_language_model. If the backbone's attention is causal, passage
    tokens attend back to the question, but not the reverse.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_heads: int = 8,
        n_layers: int = 8,
        use_bitlinear: bool = True,
        max_seq_len: int = 256,
        *,
        pad_token_id: int = 0,
    ):
        """
        Build the backbone and the scoring head.

        Args:
            vocab_size: vocabulary size of the backbone's token embedding.
                load_lm_backbone() skips checkpoint tensors whose shapes differ.
            d_model: hidden size of the backbone.
            n_heads: attention heads per block (per-head value dim is d_model // n_heads).
            n_layers: number of transformer blocks.
            use_bitlinear: use BitLinear for the backbone and the first
                score-head projection instead of nn.Linear.
            max_seq_len: maximum sequence length passed to the backbone.
            pad_token_id: token id treated as padding when mean-pooling
                (default 0, the value forward() previously hard-coded).
        """
        super().__init__()
        self.d_model = d_model
        self.pad_token_id = pad_token_id

        # Same architecture as the language model so its checkpoint can be loaded.
        self.lm = H4LanguageModel(
            vocab_size=vocab_size,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            d_value=d_model // n_heads,
            d_ffn=d_model * 4,
            max_seq_len=max_seq_len,
            dropout=0.0,
            use_bitlinear=use_bitlinear,
        )

        # Scoring head: pooled hidden state -> d_model // 4 -> scalar.
        # The final 1-unit projection is always a plain nn.Linear.
        proj_cls = BitLinear if use_bitlinear else nn.Linear
        self.score_head = nn.Sequential(
            proj_cls(d_model, d_model // 4, bias=False),
            nn.GELU(),
            nn.Linear(d_model // 4, 1),
        )

    def load_lm_backbone(self, checkpoint_path: str):
        """
        Copy pre-trained language-model weights into ``self.lm``.

        Every tensor in ``ckpt['model_state']`` whose key AND shape match a
        tensor of ``self.lm`` is copied. Every other checkpoint tensor is
        skipped and counted. Parameters of ``self.lm`` that are absent from the
        checkpoint keep their current values. Matching LM-head weights are
        copied as well, but forward() only uses the embeddings, the blocks and
        the final norm.

        Args:
            checkpoint_path: file produced by ``torch.save`` holding a
                ``'model_state'`` dict and optionally a ``'config'`` entry.

        Returns:
            The checkpoint's ``'config'`` entry, or ``{}`` if it has none.

        Raises:
            KeyError: if the checkpoint has no ``'model_state'`` entry.
        """
        # SECURITY NOTE: depending on the torch version's weights_only default,
        # torch.load may unpickle arbitrary objects — only load trusted checkpoints.
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        lm_state = ckpt['model_state']
        model_state = self.lm.state_dict()

        loaded = skipped = 0
        for key, tensor in lm_state.items():
            target = model_state.get(key)
            if target is not None and tensor.shape == target.shape:
                model_state[key] = tensor
                loaded += 1
            else:
                skipped += 1

        self.lm.load_state_dict(model_state)
        print(f"Loaded LM backbone: {loaded} tensors, {skipped} skipped")
        return ckpt.get('config', {})

    @staticmethod
    def _masked_mean_pool(
        x: torch.Tensor, input_ids: torch.Tensor, pad_token_id: int
    ) -> torch.Tensor:
        """
        Average ``x`` over the positions where ``input_ids != pad_token_id``.

        ``x`` is (B, T, D) and ``input_ids`` is (B, T); the result is (B, D)
        with the same dtype as ``x``. Rows that are entirely padding yield zeros.
        """
        # Accumulate in at least float32: a float32 mask would otherwise promote
        # half/bfloat16 activations to float32 and then clash with the score
        # head's weights, and low-precision sums lose accuracy.
        acc_dtype = torch.promote_types(x.dtype, torch.float32)
        mask = (input_ids != pad_token_id).unsqueeze(-1).to(acc_dtype)  # (B, T, 1)
        summed = (x.to(acc_dtype) * mask).sum(dim=1)
        count = mask.sum(dim=1).clamp(min=1)  # avoid 0/0 on all-padding rows
        return (summed / count).to(x.dtype)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Score a batch of tokenized (question + passage) sequences.

        Args:
            input_ids: (B, T) token ids — [question SEP passage], padded with
                ``pad_token_id``.

        Returns:
            (B,) unnormalized relevance scores (no sigmoid is applied).
        """
        seq_len = input_ids.shape[1]
        lm = self.lm

        # Run the backbone by hand to get hidden states instead of LM logits.
        tok_emb = lm.token_emb(input_ids) * lm.emb_scale
        # Match dtype as well as device so a half/bfloat16 model is not promoted.
        pos_emb = lm.pos_enc(seq_len).unsqueeze(0).to(
            device=tok_emb.device, dtype=tok_emb.dtype
        )
        x = lm.emb_dropout(tok_emb + pos_emb)
        for block in lm.blocks:
            x = block(x, use_tree=False)
        x = lm.ln_f(x)

        # getattr keeps fully pickled models from before pad_token_id existed working.
        pad_id = getattr(self, 'pad_token_id', 0)
        pooled = self._masked_mean_pool(x, input_ids, pad_id)  # (B, d_model)
        return self.score_head(pooled).squeeze(-1)  # (B,)

    def count_params(self):
        """Return the total number of parameter elements (trainable or not)."""
        return sum(p.numel() for p in self.parameters())
|