"""
H4 Geometric Ranker: score (question, passage) relevance in H4 space.
Architecture:
1. Encode question with H4 attention (4D geometric heads)
2. Encode passage with H4 attention (shared weights)
3. Relevance = dot product on S³ (same metric as ChamberTree attention)
The scoring uses the SAME geometry as attention routing.
No separate scoring function needed — the architecture is the scorer.
"""
import math
import os
import sys
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from h4_hybrid_attention import H4TransformerBlock
from bitlinear import BitLinear
class H4Ranker(nn.Module):
    """
    Score (question, passage) relevance via H4 geometric similarity.

    Both question and passage are encoded to 4D unit vectors on S³ by a
    shared H4-attention encoder; relevance is their dot product in H4
    space (higher = more relevant). The scoring uses the same geometry
    as attention routing, so no separate scoring head is needed.
    Intended to be trained with a contrastive loss (InfoNCE) using
    in-batch negatives.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 128,
        n_heads: int = 8,
        n_layers: int = 2,
        d_value: int = 16,
        d_ffn: Optional[int] = None,
        use_bitlinear: bool = True,
        max_seq_len: int = 256,
    ):
        """
        Args:
            vocab_size: size of the token vocabulary (id 0 is padding).
            d_model: encoder hidden width.
            n_heads: number of H4 attention heads per block.
            n_layers: number of stacked H4TransformerBlock layers.
            d_value: per-head value dimension.
            d_ffn: feed-forward width; defaults to ``4 * d_model``.
            use_bitlinear: use BitLinear instead of nn.Linear for the
                final H4 projection (and inside the blocks).
            max_seq_len: accepted for interface compatibility; currently
                unused in this module.
        """
        super().__init__()
        self.d_model = d_model
        self.use_bitlinear = use_bitlinear
        if d_ffn is None:
            d_ffn = d_model * 4
        Linear = BitLinear if use_bitlinear else nn.Linear

        # Token embedding, shared between question and passage sides.
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.emb_scale = math.sqrt(d_model)

        # H4 attention blocks (one shared encoder for both sides).
        self.blocks = nn.ModuleList([
            H4TransformerBlock(
                d_model=d_model,
                n_heads=n_heads,
                d_value=d_value,
                d_ffn=d_ffn,
                dropout=0.0,
                use_bitlinear=use_bitlinear,
            )
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)

        # Project from d_model down to 4D (H4 space) for geometric scoring.
        self.to_h4 = Linear(d_model, 4, bias=False)

        self._init_weights()

    def _init_weights(self) -> None:
        """Initialize all linear and embedding weights with N(0, 0.02²)."""
        # nn.Linear, BitLinear, and nn.Embedding all get the same init,
        # so a single isinstance check covers every case.
        for module in self.modules():
            if isinstance(module, (nn.Linear, BitLinear, nn.Embedding)):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def encode(self, token_ids: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of token sequences to 4D unit vectors on S³.

        Args:
            token_ids: (B, T) tokenized text, 0-padded.

        Returns:
            (B, 4) unit vectors in H4 space.
        """
        # Padding mask: 1.0 for real tokens, 0.0 for pad (id 0).
        pad_mask = (token_ids != 0).float()  # (B, T)
        x = self.embedding(token_ids) * self.emb_scale  # (B, T, d_model)
        # NOTE(review): pad_mask is applied only to the pooling below;
        # the attention blocks receive no mask, so pad positions still
        # participate in attention. Confirm whether H4TransformerBlock
        # accepts a key-padding mask and wire it through if so.
        for block in self.blocks:
            x = block(x, use_tree=False)
        x = self.ln_f(x)

        # Masked mean pool over time, ignoring padding positions.
        mask = pad_mask.unsqueeze(-1)  # (B, T, 1)
        # clamp(min=1) guards against division by zero for all-pad rows.
        x = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)  # (B, d_model)

        # Project to H4 and normalize onto the unit 3-sphere.
        h4 = self.to_h4(x)  # (B, 4)
        return F.normalize(h4, dim=-1)

    def score(self, q_ids: torch.Tensor, p_ids: torch.Tensor) -> torch.Tensor:
        """
        Score relevance of (question, passage) pairs.

        Args:
            q_ids: (B, T_q) question tokens, 0-padded.
            p_ids: (B, T_p) passage tokens, 0-padded.

        Returns:
            (B,) cosine-style scores in [-1, 1] (dot product of unit vectors).
        """
        q_h4 = self.encode(q_ids)
        p_h4 = self.encode(p_ids)
        return (q_h4 * p_h4).sum(dim=-1)

    def count_params(self) -> int:
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)