File size: 3,954 Bytes
112ea08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""
H4 Geometric Ranker: score (question, passage) relevance in H4 space.

Architecture:
    1. Encode question with H4 attention (4D geometric heads)
    2. Encode passage with H4 attention (shared weights)
    3. Relevance = dot product on S³ (same metric as ChamberTree attention)

The scoring uses the SAME geometry as attention routing.
No separate scoring function needed — the architecture is the scorer.
"""

import math
import os
import sys
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

# Make sibling modules importable when this file is run from its own directory.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

from h4_hybrid_attention import H4TransformerBlock
from bitlinear import BitLinear


class H4Ranker(nn.Module):
    """
    Score (question, passage) relevance via H4 geometric similarity.

    Both question and passage are encoded by a shared H4-attention encoder
    to 4D unit vectors on S³. Relevance = dot product in H4 space;
    higher = more relevant. Trained with contrastive loss (InfoNCE)
    using in-batch negatives.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 128,
        n_heads: int = 8,
        n_layers: int = 2,
        d_value: int = 16,
        d_ffn: Optional[int] = None,
        use_bitlinear: bool = True,
        max_seq_len: int = 256,
    ):
        """
        Args:
            vocab_size: number of rows in the token embedding table.
            d_model: encoder hidden width.
            n_heads: H4 attention heads per transformer block.
            n_layers: number of stacked H4TransformerBlocks.
            d_value: per-head value dimension forwarded to each block.
            d_ffn: feed-forward width inside each block; defaults to
                4 * d_model when None.
            use_bitlinear: use BitLinear instead of nn.Linear for the
                final 4D projection (and inside the blocks).
            max_seq_len: accepted for API compatibility; not referenced
                anywhere in this module (no positional embeddings here).
        """
        super().__init__()
        self.d_model = d_model
        self.use_bitlinear = use_bitlinear

        if d_ffn is None:
            d_ffn = d_model * 4

        Linear = BitLinear if use_bitlinear else nn.Linear

        # Token embedding (shared between question and passage)
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.emb_scale = math.sqrt(d_model)

        # H4 attention blocks (shared encoder for both sides of the pair)
        self.blocks = nn.ModuleList([
            H4TransformerBlock(
                d_model=d_model,
                n_heads=n_heads,
                d_value=d_value,
                d_ffn=d_ffn,
                dropout=0.0,
                use_bitlinear=use_bitlinear,
            )
            for _ in range(n_layers)
        ])

        self.ln_f = nn.LayerNorm(d_model)

        # Project from d_model to 4D (H4 space) for geometric scoring
        self.to_h4 = Linear(d_model, 4, bias=False)

        self._init_weights()

    def _init_weights(self):
        """Initialize Linear/BitLinear/Embedding weights ~ N(0, 0.02²)."""
        # The original code had separate Linear/BitLinear and Embedding
        # branches applying the identical init; one isinstance covers all.
        for module in self.modules():
            if isinstance(module, (nn.Linear, BitLinear, nn.Embedding)):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def encode(self, token_ids: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of sequences to 4D unit vectors on S³.

        Args:
            token_ids: (B, T) tokenized text (0-padded)
        Returns:
            (B, 4) L2-normalized vectors in H4 space
        """
        # Padding mask: 1.0 for real tokens, 0.0 where the pad id (0) sits.
        pad_mask = (token_ids != 0).float()  # (B, T)
        # NOTE(review): the mask is only applied at the pooling step below —
        # padded positions still participate in attention inside the blocks.
        # Confirm whether H4TransformerBlock can take a key-padding mask.

        x = self.embedding(token_ids) * self.emb_scale  # (B, T, d_model)

        for block in self.blocks:
            x = block(x, use_tree=False)

        x = self.ln_f(x)

        # Masked mean pool (ignore padding); clamp(min=1) guards an
        # all-padding row against division by zero.
        mask = pad_mask.unsqueeze(-1)  # (B, T, 1)
        x = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)  # (B, d_model)

        # Project to H4 and normalize onto the unit 3-sphere.
        h4 = self.to_h4(x)  # (B, 4)
        h4 = F.normalize(h4, dim=-1)

        return h4

    def score(self, q_ids: torch.Tensor, p_ids: torch.Tensor) -> torch.Tensor:
        """
        Score relevance of (question, passage) pairs.

        Args:
            q_ids: (B, T_q) question tokens
            p_ids: (B, T_p) passage tokens
        Returns:
            (B,) scores in [-1, 1]
        """
        q_h4 = self.encode(q_ids)
        p_h4 = self.encode(p_ids)
        # Both encodings are unit vectors, so the elementwise-product sum
        # is exactly cosine similarity on S³.
        return (q_h4 * p_h4).sum(dim=-1)

    def count_params(self) -> int:
        """Return the number of trainable parameters in this module."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)