from __future__ import annotations

from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


class HyperNet(nn.Module):
    """MLP that maps a conditioning embedding to a FiLM-style (scale, bias) pair.

    The scale is squashed through a sigmoid into (0, 1) and the bias through a
    tanh into (-1, 1), so the modulation it produces stays bounded.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hypernet = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, output_dim * 2),
        )

        # Small-variance init with zero biases keeps the pre-activation output
        # near zero at start, so scale ~ sigmoid(0) = 0.5 and bias ~ tanh(0) = 0.
        for layer in self.hypernet:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, std=0.01)
                nn.init.zeros_(layer.bias)

    def forward(self, embedding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        hyper_output = self.hypernet(embedding)
        # First half of the output parameterizes the scale, second half the bias.
        scale = torch.sigmoid(hyper_output[:, : self.output_dim])
        bias = torch.tanh(hyper_output[:, self.output_dim :])
        return scale, bias


@dataclass(frozen=True)
class ModelConfig:
    input_dim: int
    output_dim: int = 256
    hidden_dim: int = 128
    dropout: float = 0.1


class IMRNN(nn.Module):
    """
    Adapter-only IMRNN implementation over cached dense embeddings.

    The model keeps the legacy module names (`query_hypernet`, `doc_hypernet`,
    `query_norm`, `doc_norm`) so existing `bihypernet_*.pt` checkpoints can be
    loaded with key remapping and `strict=False`; see the
    `load_legacy_checkpoint` sketch after this class.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.projector = nn.Linear(config.input_dim, config.output_dim)
        self.query_hypernet = HyperNet(config.output_dim, config.hidden_dim, config.output_dim, config.dropout)
        self.doc_hypernet = HyperNet(config.output_dim, config.hidden_dim, config.output_dim, config.dropout)
        self.query_norm = nn.LayerNorm(config.output_dim)
        self.doc_norm = nn.LayerNorm(config.output_dim)

    def project(self, embeddings: torch.Tensor) -> torch.Tensor:
        return F.normalize(self.projector(embeddings), p=2, dim=-1)

    def modulate_documents(
        self,
        query_embeddings: torch.Tensor,
        document_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Apply a query-conditioned scale/bias to every document embedding."""
        q_scale, q_bias = self.query_hypernet(query_embeddings)
        return self.doc_norm(
            document_embeddings * q_scale.unsqueeze(1) + q_bias.unsqueeze(1)
        )

    def modulate_query(
        self,
        query_embeddings: torch.Tensor,
        modulated_documents: torch.Tensor,
    ) -> torch.Tensor:
        """Modulate the query with a scale/bias conditioned on the mean document."""
        document_summary = modulated_documents.mean(dim=1)
        d_scale, d_bias = self.doc_hypernet(document_summary)
        return self.query_norm(query_embeddings * d_scale + d_bias)

    def forward(
        self,
        query_embeddings: torch.Tensor,
        document_embeddings: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            query_embeddings: [batch, input_dim]
            document_embeddings: [batch, docs_per_query, input_dim]
        """
        projected_queries = self.project(query_embeddings)
        projected_documents = self.project(document_embeddings)
        modulated_documents = self.modulate_documents(projected_queries, projected_documents)
        modulated_queries = self.modulate_query(projected_queries, modulated_documents)
        scores = torch.einsum(
            "bd,bkd->bk",
            F.normalize(modulated_queries, p=2, dim=-1),
            F.normalize(modulated_documents, p=2, dim=-1),
        )
        return modulated_queries, modulated_documents, scores

    def score_candidates(
        self,
        query_embedding: torch.Tensor,
        candidate_document_embeddings: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if query_embedding.dim() == 1:
            query_embedding = query_embedding.unsqueeze(0)
        if candidate_document_embeddings.dim() == 2:
            candidate_document_embeddings = candidate_document_embeddings.unsqueeze(0)
        modulated_query, modulated_docs, scores = self.forward(query_embedding, candidate_document_embeddings)
        return modulated_query.squeeze(0), modulated_docs.squeeze(0), scores.squeeze(0)
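

# A minimal, hedged sketch of the checkpoint loading described in the IMRNN
# docstring. The exact key layout of legacy `bihypernet_*.pt` files is not
# shown in this module, so the "module." prefix stripped here is an assumption
# (e.g. a DataParallel wrapper prefix); only the module names listed in the
# docstring are known to line up. `strict=False` tolerates legacy keys that
# have no counterpart in the adapter-only model.
def load_legacy_checkpoint(model: IMRNN, path: str, prefix: str = "module.") -> IMRNN:
    state = torch.load(path, map_location="cpu")
    remapped = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }
    model.load_state_dict(remapped, strict=False)
    return model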


BiHyperNetIR = IMRNN
"""Backward-compatible alias retained for legacy checkpoints and code paths."""