File size: 4,245 Bytes
d223135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
H4 Cross-Encoder Reranker — score (question, passage) with full cross-attention.

The bi-encoder (ranking_model.py) encodes question and passage separately.
Fast but can't compare them directly — gets R@5=100% but R@1 plateaus at ~40%.

The cross-encoder feeds question + passage as ONE sequence through H4 attention.
The attention heads directly attend from question tokens to passage tokens.
Slower (one forward pass per candidate) but much more precise.

Production pipeline:
  1. Bi-encoder retrieves top-k candidates (fast, 20ms for all docs)
  2. Cross-encoder reranks k candidates (precise, ~10ms per candidate)
  3. Return the top-ranked candidate

Uses the PPL 10.0 TinyStories checkpoint as backbone — the model already
knows English, it just needs to learn "does this passage answer this question?"
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

from h4_language_model import H4LanguageModel
from bitlinear import BitLinear


class H4CrossEncoder(nn.Module):
    """
    Cross-encoder reranker using H4 attention.

    Input: [question tokens] [SEP] [passage tokens]
    Output: scalar relevance score

    The H4 attention heads attend across the full concatenated sequence,
    so question tokens can directly attend to passage tokens via
    ChamberTree geometric routing.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_heads: int = 8,
        n_layers: int = 8,
        use_bitlinear: bool = True,
        max_seq_len: int = 256,
        pad_token_id: int = 0,
    ):
        """
        Args:
            vocab_size: size of the token vocabulary (shared with the backbone LM).
            d_model: transformer hidden width.
            n_heads: number of H4 attention heads.
            n_layers: number of transformer blocks.
            use_bitlinear: use BitLinear projections (matches the LM checkpoint)
                for the first scoring-head layer; the final scalar projection
                stays a full-precision nn.Linear either way.
            max_seq_len: maximum concatenated [question SEP passage] length.
            pad_token_id: token id treated as padding when mean-pooling
                (previously hard-coded to 0; default keeps old behavior).
        """
        super().__init__()
        self.d_model = d_model
        self.pad_token_id = pad_token_id

        # Use the same architecture as the language model so pre-trained
        # checkpoint tensors line up shape-for-shape in load_lm_backbone().
        self.lm = H4LanguageModel(
            vocab_size=vocab_size,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            d_value=d_model // n_heads,
            d_ffn=d_model * 4,
            max_seq_len=max_seq_len,
            dropout=0.0,
            use_bitlinear=use_bitlinear,
        )

        # Classification head: pool the sequence, project to scalar.
        # The last layer is always full-precision nn.Linear so the scalar
        # score is not quantized.
        Linear = BitLinear if use_bitlinear else nn.Linear
        self.score_head = nn.Sequential(
            Linear(d_model, d_model // 4, bias=False),
            nn.GELU(),
            nn.Linear(d_model // 4, 1),
        )

    def load_lm_backbone(self, checkpoint_path: str):
        """
        Load pre-trained language model weights as backbone.

        Copies every checkpoint tensor whose key AND shape match this
        model's LM; anything else (e.g. the LM head, or mismatched layers)
        is skipped rather than raising, so partial/older checkpoints load.

        Args:
            checkpoint_path: path to a torch checkpoint with a
                'model_state' dict (and optionally a 'config' dict).
        Returns:
            The checkpoint's 'config' dict, or {} if absent.
        """
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # trusted checkpoints. Consider weights_only=True once the saved
        # 'config' payload is known to be weights_only-compatible.
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        lm_state = ckpt['model_state']
        model_state = self.lm.state_dict()

        loaded = 0
        skipped = 0
        for key in lm_state:
            if key in model_state and lm_state[key].shape == model_state[key].shape:
                model_state[key] = lm_state[key]
                loaded += 1
            else:
                skipped += 1

        self.lm.load_state_dict(model_state)
        print(f"Loaded LM backbone: {loaded} tensors, {skipped} skipped")
        return ckpt.get('config', {})

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Score a batch of (question + passage) sequences.

        Args:
            input_ids: (B, T) — tokenized [question SEP passage]
        Returns:
            (B,) relevance scores
        """
        # Re-run the LM's embedding + transformer stack manually so we can
        # take hidden states and bypass the LM head entirely.
        B, T = input_ids.shape
        tok_emb = self.lm.token_emb(input_ids) * self.lm.emb_scale
        # pos_enc(T) presumably returns a (T, d_model) table — broadcast
        # over the batch. TODO(review): confirm against H4LanguageModel.
        pos_emb = self.lm.pos_enc(T).unsqueeze(0).to(tok_emb.device)
        x = self.lm.emb_dropout(tok_emb + pos_emb)

        for block in self.lm.blocks:
            x = block(x, use_tree=False)

        x = self.lm.ln_f(x)

        # Mean pool over non-padding tokens; clamp guards against an
        # all-padding row producing a divide-by-zero.
        pad_mask = (input_ids != self.pad_token_id).float().unsqueeze(-1)  # (B, T, 1)
        pooled = (x * pad_mask).sum(dim=1) / pad_mask.sum(dim=1).clamp(min=1)  # (B, d_model)

        # Project the pooled representation to one scalar per sequence.
        score = self.score_head(pooled).squeeze(-1)  # (B,)
        return score

    def count_params(self):
        """Total number of parameters (backbone + scoring head)."""
        return sum(p.numel() for p in self.parameters())