from __future__ import annotations

from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


class HyperNet(nn.Module):
    """Generates per-dimension modulation parameters (scale, bias) from an embedding."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hypernet = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, output_dim * 2),
        )
        # Small-std weights and zero biases push the initial output toward zero,
        # so scale starts near sigmoid(0) = 0.5 and bias near tanh(0) = 0.
        for layer in self.hypernet:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, std=0.01)
                nn.init.zeros_(layer.bias)

    def forward(self, embedding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        hyper_output = self.hypernet(embedding)
        # First half of the output becomes a scale in (0, 1), second half a bias in (-1, 1).
        scale = torch.sigmoid(hyper_output[:, : self.output_dim])
        bias = torch.tanh(hyper_output[:, self.output_dim :])
        return scale, bias


@dataclass(frozen=True)
class ModelConfig:
    input_dim: int
    output_dim: int = 256
    hidden_dim: int = 128
    dropout: float = 0.1


class IMRNN(nn.Module):
    """Adapter-only IMRNN implementation over cached dense embeddings.

    The model keeps the legacy module names (`query_hypernet`, `doc_hypernet`,
    `query_norm`, `doc_norm`) so existing `bihypernet_*.pt` checkpoints can be
    loaded with key remapping and `strict=False`; a hedged loading sketch
    follows the class body.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.projector = nn.Linear(config.input_dim, config.output_dim)
        self.query_hypernet = HyperNet(config.output_dim, config.hidden_dim, config.output_dim, config.dropout)
        self.doc_hypernet = HyperNet(config.output_dim, config.hidden_dim, config.output_dim, config.dropout)
        self.query_norm = nn.LayerNorm(config.output_dim)
        self.doc_norm = nn.LayerNorm(config.output_dim)

    def project(self, embeddings: torch.Tensor) -> torch.Tensor:
        # L2-normalize after projection so modulation operates on unit-norm vectors.
        return F.normalize(self.projector(embeddings), p=2, dim=-1)

    def modulate_documents(
        self,
        query_embeddings: torch.Tensor,
        document_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        # Per-query (scale, bias) broadcast over that query's candidate documents.
        q_scale, q_bias = self.query_hypernet(query_embeddings)
        return self.doc_norm(
            document_embeddings * q_scale.unsqueeze(1) + q_bias.unsqueeze(1)
        )

    def modulate_query(
        self,
        query_embeddings: torch.Tensor,
        modulated_documents: torch.Tensor,
    ) -> torch.Tensor:
        # Mean-pool the modulated candidates into a summary, then use it to
        # generate (scale, bias) for the query side.
        document_summary = modulated_documents.mean(dim=1)
        d_scale, d_bias = self.doc_hypernet(document_summary)
        return self.query_norm(query_embeddings * d_scale + d_bias)

    def forward(
        self,
        query_embeddings: torch.Tensor,
        document_embeddings: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            query_embeddings: [batch, input_dim]
            document_embeddings: [batch, docs_per_query, input_dim]

        Returns:
            Modulated queries [batch, output_dim], modulated documents
            [batch, docs_per_query, output_dim], and cosine scores
            [batch, docs_per_query].
        """
        projected_queries = self.project(query_embeddings)
        projected_documents = self.project(document_embeddings)
        modulated_documents = self.modulate_documents(projected_queries, projected_documents)
        modulated_queries = self.modulate_query(projected_queries, modulated_documents)
        # Cosine similarity between each query and its own candidate set.
        scores = torch.einsum(
            "bd,bkd->bk",
            F.normalize(modulated_queries, p=2, dim=-1),
            F.normalize(modulated_documents, p=2, dim=-1),
        )
        return modulated_queries, modulated_documents, scores

    def score_candidates(
        self,
        query_embedding: torch.Tensor,
        candidate_document_embeddings: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Single-query convenience wrapper: accepts unbatched inputs, adds the
        # batch dimension for `forward`, and strips it from the outputs.
        if query_embedding.dim() == 1:
            query_embedding = query_embedding.unsqueeze(0)
        if candidate_document_embeddings.dim() == 2:
            candidate_document_embeddings = candidate_document_embeddings.unsqueeze(0)
        modulated_query, modulated_docs, scores = self.forward(query_embedding, candidate_document_embeddings)
        return modulated_query.squeeze(0), modulated_docs.squeeze(0), scores.squeeze(0)
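

# Hedged sketch of the legacy-checkpoint loading described in the IMRNN
# docstring. The exact key layout of `bihypernet_*.pt` files is not shown in
# this module, so the "model." prefix strip below is an assumption; adjust the
# remapping to match the keys your checkpoints actually contain.
def load_legacy_checkpoint(model: IMRNN, path: str) -> IMRNN:
    state_dict = torch.load(path, map_location="cpu")
    if isinstance(state_dict, dict) and "state_dict" in state_dict:
        # Some trainers nest the weights under a "state_dict" key.
        state_dict = state_dict["state_dict"]
    remapped = {key.removeprefix("model."): value for key, value in state_dict.items()}
    # `strict=False` tolerates keys present on only one side, as the docstring notes.
    model.load_state_dict(remapped, strict=False)
    return model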


BiHyperNetIR = IMRNN
"""Backward-compatible alias retained for legacy checkpoints and code paths."""